from __future__ import annotations

import datetime as dt
import math
from dataclasses import dataclass, field
from decimal import Decimal
from typing import Any, Dict, List, Sequence, Tuple
def _schema_field_type_name(dtype: Any) -> str:
if dtype is None:
return "string"
name = type(dtype).__name__.lower()
if "int" in name or "long" in name:
return "long"
if "string" in name:
return "string"
if "double" in name or "float" in name:
return "double"
if "bool" in name:
return "boolean"
if "date" in name and "time" not in name:
return "date"
if "timestamp" in name or "datetime" in name:
return "timestamp"
if "array" in name:
return "array"
if "struct" in name or "map" in name:
return name
return name.replace("type", "").lower() or "string"
def _dataframe_to_expected_output(df: Any) -> Dict[str, Any]:
    """Serialize a DataFrame-like object into the ``expected_output`` layout.

    Column names come from ``_get_columns`` and rows from ``_collect_rows``.
    When the object exposes a ``schema`` (or ``_schema``) with a ``fields``
    attribute, each field's name, mapped type label and nullability are
    recorded; otherwise every column is described as a nullable string.

    Args:
        df: DataFrame-like object to describe.

    Returns:
        A dict with a single ``"expected_output"`` key holding the schema
        summary, the row dicts and the row count.
    """
    columns = _get_columns(df)
    rows = _collect_rows(df, columns)
    names = list(columns)

    schema = getattr(df, "schema", None) or getattr(df, "_schema", None)
    if schema is not None and hasattr(schema, "fields"):
        fields = [
            {
                "name": entry.name,
                "type": _schema_field_type_name(entry.dataType),
                "nullable": getattr(entry, "nullable", True),
            }
            for entry in schema.fields
        ]
        field_types = [descriptor["type"] for descriptor in fields]
    else:
        fields = [
            {"name": name, "type": "string", "nullable": True} for name in names
        ]
        field_types = ["string"] * len(columns)

    schema_summary = {
        "field_count": len(columns),
        "field_names": names,
        "field_types": field_types,
        "fields": fields,
    }
    return {
        "expected_output": {
            "schema": schema_summary,
            "data": rows,
            "row_count": len(rows),
        }
    }
@dataclass
class ComparisonResult:
    """Outcome of a DataFrame or schema comparison.

    Every field has a default, so ``ComparisonResult()`` yields a passing
    result with no errors and no details. Each instance gets its own fresh
    ``errors`` list and ``details`` dict. Fields may also be supplied as
    arguments when a pre-populated result is needed.
    """

    # True until a mismatch or an unexpected error is recorded.
    equivalent: bool = True
    # Human-readable mismatch descriptions, in the order they were recorded.
    errors: List[str] = field(default_factory=list)
    # Free-form diagnostic data keyed by topic.
    details: Dict[str, Any] = field(default_factory=dict)
def compare_dataframes(
    mock_df: Any,
    expected_output: Dict[str, Any],
    tolerance: float = 1e-6,
    check_schema: bool = True,
    check_data: bool = True,
) -> ComparisonResult:
    """Compare a mock DataFrame against a recorded expected output.

    Checks run in this order: row count, then (optionally) schema, then
    (optionally) cell values. A row-count mismatch returns immediately.
    Rows on both sides are sorted before being paired up; ``id`` is used as
    the leading sort key when present.

    Which cells are compared depends on the column lists:

    * duplicate column names on either side: no name-based comparison;
    * mock has more columns than expected: only the shared names;
    * same number of columns but different names: compared by position;
    * otherwise: every mock column, by name.

    Args:
        mock_df: DataFrame-like object under test.
        expected_output: Dict whose ``"expected_output"`` entry holds
            ``"schema"``, ``"data"`` (list of row dicts) and ``"row_count"``.
        tolerance: Numeric tolerance forwarded to ``_compare_values``.
        check_schema: Whether to run ``compare_schemas``.
        check_data: Whether to compare cell values.

    Returns:
        A ``ComparisonResult``. Any exception raised while comparing is
        caught and recorded as an error instead of propagating.
    """
    result = ComparisonResult()
    try:
        expected = expected_output.get("expected_output", {})
        expected_schema = expected.get("schema", {})
        expected_data = expected.get("data", [])
        expected_row_count = expected.get("row_count", 0)

        mock_columns = _get_columns(mock_df)
        mock_rows = _collect_rows(mock_df, mock_columns)
        mock_row_count = len(mock_rows)

        if mock_row_count != expected_row_count:
            result.equivalent = False
            result.errors.append(
                f"Row count mismatch: mock={mock_row_count}, expected={expected_row_count}"
            )
            return result

        if check_schema:
            schema_result = compare_schemas(mock_df, expected_schema)
            if not schema_result.equivalent:
                result.equivalent = False
                result.errors.extend(schema_result.errors)
            result.details["schema"] = schema_result.details

        if check_data and mock_row_count > 0:
            expected_columns = list(expected_data[0].keys()) if expected_data else []

            # Sort both sides so rows pair up regardless of original order.
            sort_columns = mock_columns
            if "id" in mock_columns and "id" in expected_columns:
                sort_columns = ["id"] + [c for c in mock_columns if c != "id"]
            mock_sorted = _sort_rows(mock_rows, sort_columns)

            expected_sort_columns = expected_columns
            if "id" in expected_columns:
                expected_sort_columns = ["id"] + [
                    c for c in expected_columns if c != "id"
                ]
            expected_sorted = _sort_rows(expected_data, expected_sort_columns)

            has_duplicate_columns = len(mock_columns) != len(set(mock_columns))
            has_duplicate_expected = len(expected_columns) != len(set(expected_columns))
            same_width_different_names = (
                len(mock_columns) == len(expected_columns)
                and mock_columns != expected_columns
            )

            if has_duplicate_columns or has_duplicate_expected:
                columns_to_compare = []
            elif len(mock_columns) > len(expected_columns):
                expected_names = set(expected_columns)
                columns_to_compare = [
                    col for col in mock_columns if col in expected_names
                ]
            elif same_width_different_names:
                columns_to_compare = []
            else:
                columns_to_compare = mock_columns

            # Loop-invariant: whether cells are matched by position rather
            # than by column name.
            compare_by_position = same_width_different_names and not columns_to_compare

            for row_index, (mock_row, expected_row) in enumerate(
                zip(mock_sorted, expected_sorted)
            ):
                if compare_by_position:
                    for mock_col, expected_col in zip(mock_columns, expected_columns):
                        equivalent, error = _compare_values(
                            mock_row.get(mock_col),
                            expected_row.get(expected_col),
                            tolerance,
                            context=f"column '{mock_col}' row {row_index}",
                        )
                        if not equivalent:
                            result.equivalent = False
                            # Report what _compare_values actually found; a
                            # fixed "Null mismatch" label would mislabel
                            # numeric, date and other mismatches.
                            result.errors.append(error)
                else:
                    for col in columns_to_compare:
                        equivalent, error = _compare_values(
                            mock_row.get(col),
                            expected_row.get(col),
                            tolerance,
                            context=f"column '{col}' row {row_index}",
                        )
                        if not equivalent:
                            result.equivalent = False
                            result.errors.append(error)

        result.details["row_count"] = mock_row_count
        result.details["column_count"] = len(mock_columns)
    except Exception as e:
        # Deliberate catch-all: a failure while comparing is reported as a
        # failed comparison rather than raised to the caller.
        result.equivalent = False
        result.errors.append(f"Error during comparison: {str(e)}")
    return result
def _get_columns(df: Any) -> List[str]:
schema = getattr(df, "schema", None) or getattr(df, "_schema", None)
if schema is not None and hasattr(schema, "fields"):
return [field.name for field in schema.fields]
if hasattr(df, "columns"):
return list(df.columns)
data = getattr(df, "data", None)
if data:
first_row = data[0]
if isinstance(first_row, dict):
return list(first_row.keys())
collected = []
if hasattr(df, "collect"):
try:
collected = list(df.collect())
except Exception:
collected = []
for row in collected:
if hasattr(row, "asDict"):
return list(row.asDict().keys())
if isinstance(row, dict):
return list(row.keys())
if isinstance(row, Sequence) and not isinstance(row, (str, bytes, bytearray)):
return [f"col_{idx}" for idx, _ in enumerate(row)]
return []
def _collect_rows(df: Any, columns: Sequence[str]) -> List[Dict[str, Any]]:
rows: List[Dict[str, Any]] = []
if hasattr(df, "collect"):
try:
collected = list(df.collect())
except Exception:
collected = []
elif hasattr(df, "data"):
collected = df.data
else:
collected = []
for row in collected:
rows.append(_row_to_dict(row, columns))
return rows
def _normalize_row_to_dict(value: Any) -> Any:
if hasattr(value, "asDict"):
try:
row_dict = value.asDict(recursive=True)
return _normalize_row_to_dict(row_dict)
except (TypeError, AttributeError):
try:
row_dict = value.asDict()
return _normalize_row_to_dict(row_dict)
except (TypeError, AttributeError):
pass
if isinstance(value, dict):
return {k: _normalize_row_to_dict(v) for k, v in value.items()}
if isinstance(value, (list, tuple)):
return type(value)(_normalize_row_to_dict(item) for item in value)
return value
def _row_to_dict(row: Any, columns: Sequence[str]) -> Dict[str, Any]:
    """Convert one raw row into a dict keyed by ``columns``.

    Strategies are tried in order and the first one that produces a dict
    wins:

    1. ``row.asDict`` (recursive when supported), projected onto ``columns``;
    2. ``row._iter_values()``, matched to ``columns`` by position;
    3. a non-string sequence, matched to ``columns`` by position;
    4. ``row.asDict`` again (same logic as step 1);
    5. a plain dict, projected onto ``columns``;
    6. a non-string sequence whose length equals ``len(columns)``;
    7. per-column lookup via ``row[col]``, falling back to
       ``getattr(row, col, None)``.

    Nested values are passed through ``_normalize_row_to_dict``. Projections
    in steps 1, 4 and 5 use ``dict.get``, so a column missing from the row
    maps to ``None``.

    Args:
        row: Raw row object as produced by the DataFrame.
        columns: Column names to key the result by.

    Returns:
        A dict mapping column name to (normalized) value.
    """
    # Step 1: Row-like objects that can describe themselves as a dict.
    if hasattr(row, "asDict"):
        try:
            base = row.asDict(recursive=True)
        except TypeError:
            # asDict() does not accept the ``recursive`` keyword.
            base = row.asDict()
        normalized_base = _normalize_row_to_dict(base)
        if isinstance(normalized_base, dict):
            return {col: normalized_base.get(col) for col in columns}
    # Step 2: positional access through a private value iterator.
    if hasattr(row, "_iter_values") and callable(getattr(row, "_iter_values")):
        try:
            row_values = list(row._iter_values())
            # Only usable when there is at least one value per column;
            # surplus values are ignored.
            if row_values and len(row_values) >= len(columns):
                result_dict = {
                    col: row_values[idx]
                    for idx, col in enumerate(columns)
                    if idx < len(row_values)
                }
                normalized = _normalize_row_to_dict(result_dict)
                if isinstance(normalized, dict):
                    return normalized
        except (TypeError, AttributeError):
            pass
    # Step 3: generic sequences (tuples, lists, tuple-like rows).
    if isinstance(row, Sequence) and not isinstance(row, (str, bytes, bytearray)):
        try:
            if hasattr(row, "_iter_values") and callable(getattr(row, "_iter_values")):
                row_values = list(row._iter_values())
            else:
                row_values = (
                    list(row)
                    if hasattr(row, "__iter__") and not isinstance(row, dict)
                    else None
                )
            if row_values and len(row_values) >= len(columns):
                result_dict = {
                    col: row_values[idx]
                    for idx, col in enumerate(columns)
                    if idx < len(row_values)
                }
                normalized = _normalize_row_to_dict(result_dict)
                if isinstance(normalized, dict):
                    return normalized
                # NOTE(review): treated as the type-guard fallback for the
                # isinstance check above, mirroring the later
                # ``return result_dict`` / ``return result`` fallbacks.
                # Confirm the intended nesting against the original file.
                return {}
        except (TypeError, AttributeError):
            pass
    # Step 4: repeats the asDict conversion from step 1.
    if hasattr(row, "asDict"):
        try:
            base = row.asDict(recursive=True)
        except TypeError:
            base = row.asDict()
        normalized_base = _normalize_row_to_dict(base)
        if isinstance(normalized_base, dict):
            return {col: normalized_base.get(col) for col in columns}
    # Step 5: plain dict rows.
    if isinstance(row, dict):
        normalized_row = _normalize_row_to_dict(row)
        if isinstance(normalized_row, dict):
            return {col: normalized_row.get(col) for col in columns}
        return {col: row.get(col) for col in columns}
    # Step 6: sequences with exactly one value per column.
    if (
        isinstance(row, Sequence)
        and not isinstance(row, (str, bytes, bytearray))
        and len(row) == len(columns)
    ):
        result_dict = {col: row[idx] for idx, col in enumerate(columns)}
        normalized = _normalize_row_to_dict(result_dict)
        if isinstance(normalized, dict):
            return normalized
        return result_dict
    # Step 7: last resort, look each column up individually.
    result: Dict[str, Any] = {}
    for col in columns:
        value = None
        try:
            value = row[col]
        except Exception:
            # Not subscriptable by name (or key missing): try attribute
            # access, defaulting to None.
            value = getattr(row, col, None)
        result[col] = value
    normalized = _normalize_row_to_dict(result)
    if isinstance(normalized, dict):
        return normalized
    return result
def _sort_rows(
rows: Sequence[Dict[str, Any]], columns: Sequence[str]
) -> List[Dict[str, Any]]:
if not rows:
return list(rows)
unique_columns = []
seen = set()
for col in columns:
if col not in seen:
unique_columns.append(col)
seen.add(col)
try:
return sorted(
rows,
key=lambda row: tuple(
_sortable_value(row.get(col)) for col in unique_columns
),
)
except (TypeError, ValueError):
return list(rows)
def _has_complex_values(rows: Sequence[Dict[str, Any]], columns: Sequence[str]) -> bool:
for row in rows:
for col in columns:
value = row.get(col)
if isinstance(value, (list, tuple, dict, set)):
return True
return False
def _sortable_value(value: Any) -> Tuple[int, Any]:
    """Build a sort key that lets values of mixed types be ordered.

    The first tuple element is a rank that groups values by kind, so keys
    of different kinds never compare their payloads directly.

    Args:
        value: The cell value to rank.

    Returns:
        ``(0, "")`` for nulls, ``(1, value)`` for booleans,
        ``(2, float(value))`` for numerics that convert to float, and
        ``(3, str(value))`` for everything else.
    """
    if _is_null(value):
        return (0, "")
    if isinstance(value, bool):
        return (1, value)
    if _is_numeric(value):
        try:
            return (2, float(value))
        except Exception:
            # Not representable as a float: rank by text form below.
            pass
    # None is already handled by the null check above, so the value can be
    # stringified unconditionally here.
    return (3, str(value))
def _compare_values(
    mock_val: Any, expected_val: Any, tolerance: float, context: str
) -> Tuple[bool, str]:
    """Compare one mock value against its expected counterpart.

    Checks are applied in this order and the first applicable one decides:

    1. nulls (``None``/NaN): equal only when both sides are null;
    2. dates/datetimes: compared as formatted strings;
    3. lists/tuples: same length, then element-wise after sorting both sides;
    4. dicts (after Row normalization): same keys, then value by value;
    5. sets: plain equality;
    6. booleans: compared by truthiness;
    7. numbers: ``math.isclose`` with ``tolerance``, with looser rules for
       extreme magnitudes and for contexts naming ``tan`` or
       ``months_between``;
    8. anything else: ``==``, then equality of the ``str()`` forms.

    Args:
        mock_val: Value produced by the mock.
        expected_val: Recorded reference value.
        tolerance: Relative and absolute tolerance for numeric comparison.
        context: Description of the location, used in error messages and to
            select the function-specific numeric rules.

    Returns:
        ``(True, "")`` when equivalent, otherwise ``(False, message)``.
    """
    mock_is_null = _is_null(mock_val)
    expected_is_null = _is_null(expected_val)
    if mock_is_null and expected_is_null:
        return True, ""
    if mock_is_null != expected_is_null:
        return False, (
            f"Null mismatch in {context}: mock={mock_val!r}, expected={expected_val!r}"
        )

    temporal_types = (dt.date, dt.datetime)
    if isinstance(mock_val, temporal_types) or isinstance(
        expected_val, temporal_types
    ):
        if isinstance(mock_val, dt.datetime):
            # The expected side decides the precision: a reference with a
            # time part gets a full timestamp comparison, a date-only
            # reference compares dates only. str() keeps this safe when the
            # expected value is not a string.
            reference = str(expected_val)
            if "T" in reference or " " in reference:
                mock_str = mock_val.isoformat()
            else:
                mock_str = mock_val.strftime("%Y-%m-%d")
        elif isinstance(mock_val, dt.date):
            mock_str = mock_val.strftime("%Y-%m-%d")
        else:
            mock_str = mock_val

        if isinstance(expected_val, dt.datetime):
            expected_str = expected_val.isoformat()
        elif isinstance(expected_val, dt.date):
            expected_str = expected_val.strftime("%Y-%m-%d")
        else:
            expected_str = expected_val

        if mock_str == expected_str:
            return True, ""
        # isoformat() separates date and time with "T" while recorded
        # strings may use a space; treat the two separators as equivalent.
        if (
            isinstance(mock_str, str)
            and isinstance(expected_str, str)
            and mock_str.replace("T", " ") == expected_str.replace("T", " ")
        ):
            return True, ""
        return False, (
            f"Date/datetime mismatch in {context}: mock={mock_str!r}, expected={expected_str!r}"
        )

    if isinstance(mock_val, (list, tuple)) and isinstance(expected_val, (list, tuple)):
        if len(mock_val) != len(expected_val):
            return False, (
                f"Array length mismatch in {context}: mock={len(mock_val)}, expected={len(expected_val)}"
            )
        # Element order is ignored: both sides are sorted by type name and
        # text form before pairing.
        try:
            mock_sorted = sorted(mock_val, key=lambda x: (type(x).__name__, str(x)))
            expected_sorted = sorted(
                expected_val, key=lambda x: (type(x).__name__, str(x))
            )
        except (TypeError, ValueError):
            mock_sorted = list(mock_val)
            expected_sorted = list(expected_val)
        for idx, (mock_item, expected_item) in enumerate(
            zip(mock_sorted, expected_sorted)
        ):
            equivalent, error = _compare_values(
                mock_item,
                expected_item,
                tolerance,
                f"{context}[{idx}]",
            )
            if not equivalent:
                return False, error
        return True, ""

    normalized_mock = (
        mock_val if isinstance(mock_val, dict) else _normalize_row_to_dict(mock_val)
    )
    normalized_expected = (
        expected_val
        if isinstance(expected_val, dict)
        else _normalize_row_to_dict(expected_val)
    )
    if isinstance(normalized_mock, dict) and isinstance(normalized_expected, dict):
        mock_keys = set(normalized_mock.keys())
        expected_keys = set(normalized_expected.keys())
        if mock_keys != expected_keys:
            missing_in_mock = expected_keys - mock_keys
            extra_in_mock = mock_keys - expected_keys
            error_msg = f"Map key mismatch in {context}:"
            if missing_in_mock:
                error_msg += f" missing keys {sorted(missing_in_mock)}"
            if extra_in_mock:
                error_msg += f" extra keys {sorted(extra_in_mock)}"
            return False, error_msg
        for key in sorted(mock_keys):
            equivalent, error = _compare_values(
                normalized_mock[key],
                normalized_expected[key],
                tolerance,
                f"{context}.{key}",
            )
            if not equivalent:
                return False, error
        return True, ""

    if isinstance(mock_val, set) and isinstance(expected_val, set):
        if mock_val == expected_val:
            return True, ""
        return False, (
            f"Set mismatch in {context}: mock={mock_val!r}, expected={expected_val!r}"
        )

    if isinstance(mock_val, bool) or isinstance(expected_val, bool):
        # Truthiness comparison: the non-bool side is coerced with bool().
        if bool(mock_val) == bool(expected_val):
            return True, ""
        return False, (
            f"Boolean mismatch in {context}: mock={mock_val}, expected={expected_val}"
        )

    if _is_numeric(mock_val) and _is_numeric(expected_val):
        try:
            mock_num = float(mock_val)
            expected_num = float(expected_val)
        except Exception:
            # Not representable as floats: fall through to the generic
            # equality checks at the end.
            mock_num = mock_val
            expected_num = expected_val
        if isinstance(mock_num, float) and isinstance(expected_num, float):
            context_lower = context.lower()
            effective_tolerance = tolerance
            if (
                abs(expected_num) > 1e6
                or abs(expected_num) < 1e-6
                or abs(mock_num - expected_num) > 1
            ):
                effective_tolerance = max(tolerance, 1e-4)
            # Large tan() results are accepted within 1% relative error.
            if abs(expected_num) > 1000 and "tan" in context_lower:
                relative_diff = abs(mock_num - expected_num) / abs(expected_num)
                if relative_diff < 0.01:
                    return True, ""
            # months_between results are accepted within half a month.
            if (
                "months_between" in context_lower
                and abs(mock_num - expected_num) < 0.5
            ):
                return True, ""
            if math.isclose(
                mock_num,
                expected_num,
                rel_tol=effective_tolerance,
                abs_tol=effective_tolerance,
            ):
                return True, ""
            diff = abs(mock_num - expected_num)
            return False, (
                f"Numerical mismatch in {context}: mock={mock_val}, expected={expected_val}, diff={diff}"
            )

    if mock_val == expected_val:
        return True, ""
    if str(mock_val) == str(expected_val):
        return True, ""
    return False, (
        f"Value mismatch in {context}: mock={mock_val!r}, expected={expected_val!r}"
    )
def _is_null(value: Any) -> bool:
if value is None:
return True
return _is_nan(value)
def _is_nan(value: Any) -> bool:
if isinstance(value, float):
return math.isnan(value)
if isinstance(value, Decimal):
return value.is_nan()
return False
def _is_numeric(value: Any) -> bool:
return isinstance(value, (int, float, Decimal)) and not isinstance(value, bool)
def _normalize_column_name(col_name: str) -> str:
import re
nullif_pattern = r"nullif\(([^,]+),\s*([^)]+)\)"
if re.search(nullif_pattern, col_name):
match = re.search(nullif_pattern, col_name)
if match:
col1 = match.group(1).strip()
col2 = match.group(2).strip()
result = f"CASE WHEN ({col1} = {col2}) THEN NULL ELSE {col1} END"
return result.lower()
return col_name.lower()
def compare_schemas(mock_df: Any, expected_schema: Dict[str, Any]) -> ComparisonResult:
    """Compare a mock DataFrame's schema with an expected schema summary.

    Only field counts and field names are examined. A field-count difference
    fails the comparison unless both sides have the same number of distinct
    names and at least one side contains repeated names. Differences in
    names (after ``_normalize_column_name``) are recorded in ``details``
    but do not fail the comparison.

    Args:
        mock_df: Object exposing ``schema`` or ``_schema``; the schema is
            either iterable over fields or exposes ``fields``.
        expected_schema: Dict with ``"field_count"`` and ``"field_names"``.

    Returns:
        A ``ComparisonResult``. Any exception raised while comparing is
        caught and recorded as an error.
    """
    outcome = ComparisonResult()
    try:
        schema = mock_df.schema if hasattr(mock_df, "schema") else mock_df._schema
        schema_fields = schema.fields if hasattr(schema, "fields") else schema

        mock_count = len(schema_fields)
        expected_count = expected_schema.get("field_count", 0)
        outcome.details["field_counts"] = {
            "mock": mock_count,
            "expected": expected_count,
        }

        mock_names = [entry.name for entry in schema_fields]
        expected_names = expected_schema.get("field_names", [])

        distinct_mock = len(set(mock_names))
        distinct_expected = len(set(expected_names))
        # A count difference is tolerated when it is explained by repeated
        # names: same number of distinct names, repeats on either side.
        explained_by_repeats = distinct_mock == distinct_expected and (
            distinct_mock < mock_count or distinct_expected < expected_count
        )
        if not explained_by_repeats and mock_count != expected_count:
            outcome.equivalent = False
            outcome.errors.append(
                f"Schema field count mismatch: mock={mock_count}, expected={expected_count}"
            )
            return outcome

        outcome.details["field_names"] = {
            "mock": mock_names,
            "expected": expected_names,
        }

        normalized_mock = {_normalize_column_name(name) for name in mock_names}
        normalized_expected = {
            _normalize_column_name(name) for name in expected_names
        }
        if normalized_mock == normalized_expected:
            outcome.details["field_types_match"] = True
        else:
            outcome.details["field_names_differ_by_position"] = True
            outcome.details["mock_field_names"] = mock_names
            outcome.details["expected_field_names"] = expected_names
    except Exception as e:
        outcome.equivalent = False
        outcome.errors.append(f"Error comparing schemas: {str(e)}")
    return outcome
def assert_dataframes_equal(
    mock_df: Any,
    expected_output: Dict[str, Any],
    tolerance: float = 1e-6,
    msg: str = "",
) -> None:
    """Assert that a mock DataFrame matches a recorded expected output.

    When the ``"operation"`` entry of ``expected_output`` mentions
    ``current_date`` or ``current_timestamp``, only the row count, the
    column count and the absence of ``None`` in columns whose name contains
    ``"current"`` are checked. Otherwise the full ``compare_dataframes``
    comparison runs.

    Args:
        mock_df: DataFrame-like object under test.
        expected_output: Recorded expectation.
        tolerance: Numeric tolerance forwarded to ``compare_dataframes``.
        msg: Optional headline for the failure message.

    Raises:
        AssertionError: When the DataFrame does not match.
    """
    operation = expected_output.get("operation", "")
    if "current_date" in operation or "current_timestamp" in operation:
        columns = _get_columns(mock_df)
        rows = _collect_rows(mock_df, columns)
        expected = expected_output.get("expected_output", {})
        expected_schema = expected.get("schema", {})
        expected_row_count = expected.get("row_count", 0)

        if len(rows) != expected_row_count:
            raise AssertionError(
                f"Row count mismatch for current_datetime: {len(rows)} vs {expected_row_count}"
            )
        if len(columns) != expected_schema.get("field_count", 0):
            raise AssertionError("Column count mismatch for current_datetime")

        returned_none = any(
            "current" in name.lower() and value is None
            for row in rows
            for name, value in row.items()
        )
        if returned_none:
            raise AssertionError("current_datetime function returned None")
        return

    outcome = compare_dataframes(mock_df, expected_output, tolerance)
    if outcome.equivalent:
        return
    headline = msg or "DataFrames are not equivalent"
    joined_errors = "\n".join(outcome.errors)
    raise AssertionError(f"{headline}:\n{joined_errors}")
def assert_schemas_equal(
    mock_df: Any, expected_schema: Dict[str, Any], msg: str = ""
) -> None:
    """Assert that a mock DataFrame's schema matches the expected summary.

    Args:
        mock_df: DataFrame-like object under test.
        expected_schema: Expected schema summary passed to
            ``compare_schemas``.
        msg: Optional headline for the failure message.

    Raises:
        AssertionError: When ``compare_schemas`` reports a mismatch; the
            message lists every recorded error on its own line.
    """
    outcome = compare_schemas(mock_df, expected_schema)
    if outcome.equivalent:
        return
    headline = msg or "Schemas are not equivalent"
    joined_errors = "\n".join(outcome.errors)
    raise AssertionError(f"{headline}:\n{joined_errors}")