from __future__ import annotations
from datetime import date, datetime
def _is_date_like(v: object) -> bool:
if isinstance(v, (date, datetime)):
return True
if isinstance(v, str) and len(v) >= 10 and v[4:5] == "-" and v[7:8] == "-":
return v[:4].isdigit() and v[5:7].isdigit() and v[8:10].isdigit()
return False
def _normalize_date_like(v: object) -> tuple[str | None, str | None]:
if isinstance(v, datetime):
return (v.date().isoformat(), v.time().isoformat())
if isinstance(v, date):
return (v.isoformat(), None)
if isinstance(v, str) and _is_date_like(v):
if "T" in v or " " in v:
parts = v.replace("T", " ").split(" ", 1)
return (parts[0], parts[1] if len(parts) > 1 else None)
return (v[:10], None)
return (None, None)
def assert_rows_equal(
    actual: list[dict],
    expected: list[dict],
    order_matters: bool = True,
) -> None:
    """Assert that two collections of rows are equal, row by row.

    Rows that are not already dicts are converted with ``_row_to_dict``;
    each pair of rows is then checked with ``_assert_row_equal``.

    Args:
        actual: Rows produced by the code under test.
        expected: Rows the test expects.
        order_matters: When true, rows are compared position by position.
            When false, both sides are first sorted by a canonical string
            key so the incoming order is ignored.

    Raises:
        AssertionError: If the row counts differ or a paired row differs.
    """
    actual_dicts = [r if isinstance(r, dict) else _row_to_dict(r) for r in actual]
    expected_dicts = [r if isinstance(r, dict) else _row_to_dict(r) for r in expected]
    if len(actual_dicts) != len(expected_dicts):
        raise AssertionError(
            f"Row count mismatch: got {len(actual_dicts)}, expected {len(expected_dicts)}"
        )
    if not order_matters:

        def key_fn(r: dict) -> str:
            # Floats are rounded through _norm_val before being rendered.
            # Otherwise noise far below the comparison tolerance (for
            # example 0.1 + 0.2 versus 0.3) changes the repr, hence the sort
            # position, and pairs up the wrong rows. Values that straddle a
            # rounding boundary can still sort apart.
            return str(
                sorted((k, _sort_key_val(_norm_val(v))) for k, v in r.items())
            )

        actual_dicts = sorted(actual_dicts, key=key_fn)
        expected_dicts = sorted(expected_dicts, key=key_fn)
    # In unordered mode the reported index refers to the sorted position.
    for i, (a, e) in enumerate(zip(actual_dicts, expected_dicts)):
        _assert_row_equal(a, e, index=i)
def _norm_val(v: object) -> object:
if isinstance(v, float) and not (v != v): return round(v, 10)
return v
def _sort_key_val(v: object) -> tuple:
if v is None:
return (0, "")
return (1, (type(v).__name__, repr(v)))
def _assert_row_equal(actual: dict, expected: dict, index: int = 0) -> None:
keys = set(actual.keys()) | set(expected.keys())
for k in sorted(keys):
if k not in actual:
raise AssertionError(f"Row {index}: missing key '{k}' in actual")
if k not in expected:
raise AssertionError(f"Row {index}: extra key '{k}' in actual")
a, e = actual[k], expected[k]
if isinstance(a, float) and isinstance(e, float):
if a != a and e != e:
continue if abs(a - e) > 1e-9:
raise AssertionError(f"Row {index} key '{k}': {a!r} != {e!r}")
elif _is_date_like(a) or _is_date_like(e):
na, ta = _normalize_date_like(a)
ne, te = _normalize_date_like(e)
if (na, ta) != (ne, te):
raise AssertionError(f"Row {index} key '{k}': {a!r} != {e!r}")
elif a != e:
raise AssertionError(f"Row {index} key '{k}': {a!r} != {e!r}")
def _row_to_dict(r) -> dict:
d = r.asDict() if hasattr(r, "asDict") else dict(r)
out = {}
for k, v in d.items():
if (
v is not None
and hasattr(v, "__iter__")
and not isinstance(v, (str, bytes, dict))
):
try:
out[k] = list(v)
except (TypeError, ValueError):
out[k] = v
else:
out[k] = v
return out
def _try_pyspark():
try:
from pyspark.sql import SparkSession as PySparkSession
from pyspark.sql import functions as F
spark = PySparkSession.builder.master("local[1]").appName("test").getOrCreate()
return spark, F
except Exception:
return None, None
def run_with_pyspark_expected(
    pyspark_fn,
    fallback_expected: list[dict],
) -> list[dict]:
    """Compute expected rows with PySpark, or return the supplied fallback.

    Args:
        pyspark_fn: Callable invoked as ``pyspark_fn(spark, F)`` with the
            values returned by ``_try_pyspark``; its result is iterated and
            each item converted with ``_row_to_dict``.
        fallback_expected: Rows returned when ``_try_pyspark`` yields
            ``None`` values or when running ``pyspark_fn`` raises.

    Returns:
        The converted PySpark rows, or ``fallback_expected``.
    """
    import logging

    pyspark_spark, F = _try_pyspark()
    if pyspark_spark is None or F is None:
        return fallback_expected
    try:
        rows = pyspark_fn(pyspark_spark, F)
        return [_row_to_dict(r) for r in rows]
    except Exception:
        # Still best effort, but leave a trace: silently swapping in the
        # fallback would hide a broken pyspark_fn from whoever reads the
        # test output.
        logging.getLogger(__name__).warning(
            "pyspark_fn failed; using fallback expected rows", exc_info=True
        )
        return fallback_expected