from __future__ import annotations
import os
import sys
import time
import datetime as dt
from contextlib import redirect_stdout
from dataclasses import dataclass
from enum import Enum
from io import StringIO
from typing import Any, Callable, Dict, List, Optional
def _ensure_java_home() -> None:
if "JAVA_HOME" in os.environ and os.environ["JAVA_HOME"].strip():
return
candidates = [
"/opt/homebrew/opt/openjdk@11",
"/opt/homebrew/opt/openjdk@17",
"/opt/homebrew/opt/openjdk",
]
for candidate in candidates:
java_bin = os.path.join(candidate, "bin", "java")
if not os.path.exists(java_bin):
continue
try:
actual_java_path = os.path.realpath(java_bin)
actual_java_home = os.path.dirname(os.path.dirname(actual_java_path))
if os.path.exists(os.path.join(actual_java_home, "bin", "java")):
os.environ["JAVA_HOME"] = actual_java_home
if actual_java_home:
bin_dir = os.path.join(actual_java_home, "bin")
if bin_dir not in os.environ.get("PATH", ""):
os.environ["PATH"] = f"{bin_dir}:{os.environ.get('PATH', '')}"
return
except Exception:
os.environ["JAVA_HOME"] = candidate
bin_dir = os.path.join(candidate, "bin")
if bin_dir not in os.environ.get("PATH", ""):
os.environ["PATH"] = f"{bin_dir}:{os.environ.get('PATH', '')}"
return
class Backend(Enum):
    """Engines a scenario can be executed against.

    Each member's value is the string used to label results for that
    backend (see ``run_one_scenario`` / ``run_scenarios``).
    """

    # Session comes from ``pyspark.sql.SparkSession``.
    PYSPARK = "pyspark"
    # Session comes from ``sparkless.sql.SparkSession``.
    SPARKLESS = "sparkless"
@dataclass(frozen=True)
class RunMeta:
    """Immutable description of the environment a run was started in."""

    # Wall-clock start time in seconds since the epoch (``time.time()``).
    started_at_epoch_s: float
    # Interpreter version string, e.g. the first token of ``sys.version``.
    python: str
    # Platform identifier as reported by ``sys.platform``.
    platform: str
@dataclass
class Scenario:
    """A single named piece of work to execute on every backend."""

    # Identifier; used as the key for this scenario's results.
    id: str
    # Title of the scenario (presumably for display; not read in this module).
    title: str
    # Called with the backend session. ``run_one_scenario`` accepts either a
    # single result object or a 2-tuple that it unpacks as ``(df, col)``.
    fn: Callable[[Any], Any]
    # Free-form labels attached to the scenario.
    tags: List[str]
@dataclass
class CapturedUI:
    """Text a backend renders for a scenario's result; every field defaults to ""."""

    # Stdout produced by ``df.show()``, or a "<show failed: ...>" marker.
    show: str = ""
    # Stdout produced by ``df.printSchema()``, or a failure marker.
    print_schema: str = ""
    # Stdout produced by ``df.explain(...)``, or a failure marker.
    explain: str = ""
    # ``repr()`` of the DataFrame-like result.
    repr_df: str = ""
    # ``repr()`` of the second tuple element, when the scenario returned one.
    repr_col: str = ""
@dataclass
class CapturedError:
    """Summary of an exception raised while running a scenario."""

    # Class name of the exception (``type(e).__name__``).
    exc_type: str
    # ``str(e)`` of the exception.
    message: str
@dataclass
class ScenarioResult:
    """Outcome of running one scenario on one backend."""

    # ``Backend`` value string, e.g. "pyspark".
    backend: str
    # ``Scenario.id`` this result belongs to.
    scenario_id: str
    # True when the scenario ran to completion without raising.
    ok: bool
    # Collected rows as plain dicts; None when nothing was collected.
    data: Optional[List[Dict[str, Any]]] = None
    # Textual schema of the result; None when nothing was collected.
    schema: Optional[str] = None
    # Captured rendered output; left as None on failure results.
    ui: Optional[CapturedUI] = None
    # Exception summary; populated on failure results.
    error: Optional[CapturedError] = None
@dataclass
class RunResults:
    """All results of a run, grouped by scenario id and then backend name."""

    # Environment description recorded when the run started.
    meta: RunMeta
    # scenario id -> backend name -> result for that pair.
    by_scenario: Dict[str, Dict[str, ScenarioResult]]

    def to_jsonable(self) -> dict:
        """Return the results as nested plain dicts under a "by_scenario" key.

        NOTE(review): ``meta`` is not part of the returned structure --
        confirm that this omission is intended.
        """
        return {
            "by_scenario": {
                scenario_id: {
                    backend_name: _sr_to_jsonable(result)
                    for backend_name, result in per_backend.items()
                }
                for scenario_id, per_backend in self.by_scenario.items()
            }
        }
def _row_to_dict(r: Any) -> Dict[str, Any]:
d = r.asDict() if hasattr(r, "asDict") else dict(r)
out: Dict[str, Any] = {}
for k, v in d.items():
if isinstance(v, (dt.date, dt.datetime)):
out[k] = v.isoformat()
continue
if (
v is not None
and hasattr(v, "__iter__")
and not isinstance(v, (str, bytes, dict))
):
try:
out[k] = list(v)
except Exception:
out[k] = v
else:
out[k] = v
return out
def _capture_show(df: Any) -> str:
buf = StringIO()
with redirect_stdout(buf):
df.show()
return buf.getvalue()
def _capture_print_schema(df: Any) -> str:
buf = StringIO()
with redirect_stdout(buf):
df.printSchema()
return buf.getvalue()
def _capture_explain(df: Any) -> str:
buf = StringIO()
with redirect_stdout(buf):
try:
df.explain(True)
except Exception:
df.explain()
return buf.getvalue()
def _best_effort_schema(df: Any) -> str:
s = getattr(df, "schema", None)
if s is None:
return ""
if hasattr(s, "simpleString") and callable(s.simpleString):
try:
return s.simpleString()
except Exception:
pass
if hasattr(s, "json") and callable(s.json):
try:
return s.json()
except Exception:
pass
if hasattr(s, "simpleString") and callable(s.simpleString):
try:
return s.simpleString()
except Exception:
pass
return str(s)
def _sr_to_jsonable(r: ScenarioResult) -> dict:
return {
"backend": r.backend,
"scenario_id": r.scenario_id,
"ok": r.ok,
"data": r.data,
"schema": r.schema,
"ui": None if r.ui is None else r.ui.__dict__,
"error": None if r.error is None else r.error.__dict__,
}
def _make_session(backend: Backend):
    """Create a session named "parity_hunt" for the requested backend.

    For ``Backend.PYSPARK`` this first makes sure ``JAVA_HOME`` is set, points
    both PySpark interpreter variables at the running interpreter, stops any
    currently active session, and then builds a ``local[1]`` session.  Any
    other backend gets a ``sparkless`` session.

    The backend packages are imported lazily so only the one in use needs to
    be importable.
    """
    if backend != Backend.PYSPARK:
        from sparkless.sql import SparkSession

        return SparkSession.builder.appName("parity_hunt").getOrCreate()

    _ensure_java_home()
    # Use the running interpreter for both the driver and the workers.
    for variable in ("PYSPARK_DRIVER_PYTHON", "PYSPARK_PYTHON"):
        os.environ[variable] = sys.executable

    from pyspark.sql import SparkSession

    try:
        leftover = SparkSession.getActiveSession()
        if leftover is not None:
            leftover.stop()
    except Exception:
        # Best effort only: failing to stop a previous session is ignored.
        pass

    builder = SparkSession.builder.master("local[1]").appName("parity_hunt")
    builder = builder.config(
        "spark.pyspark.driver.python", os.environ["PYSPARK_DRIVER_PYTHON"]
    )
    builder = builder.config("spark.pyspark.python", os.environ["PYSPARK_PYTHON"])
    return builder.getOrCreate()
def run_one_scenario(s: Scenario, backend: Backend) -> ScenarioResult:
    """Run scenario *s* on *backend* and package the outcome.

    A fresh session is created, ``s.fn`` is called with it, and the rendered
    output (``show`` / ``printSchema`` / ``explain`` / reprs), collected rows
    and schema of the result are captured.  Any exception raised after the
    session exists is turned into an ``ok=False`` result carrying the
    exception's type name and message.  The session is always stopped
    afterwards, ignoring errors from ``stop()``.

    NOTE(review): session creation happens before the ``try``, so a failure
    there propagates to the caller instead of becoming an ``ok=False``
    result -- confirm that this is intended.
    """
    session = _make_session(backend)
    try:
        captured = CapturedUI()
        outcome = s.fn(session)

        # A 2-tuple is unpacked as (df, col); anything else is the df itself.
        if isinstance(outcome, tuple) and len(outcome) == 2:
            frame, column = outcome
        else:
            frame, column = outcome, None

        if frame is not None:
            captured.repr_df = repr(frame)
            # Each rendering is captured independently so that one failing
            # call does not prevent the others from being recorded.
            try:
                captured.show = _capture_show(frame)
            except Exception as e:
                captured.show = f"<show failed: {type(e).__name__}: {e}>"
            try:
                captured.print_schema = _capture_print_schema(frame)
            except Exception as e:
                captured.print_schema = f"<printSchema failed: {type(e).__name__}: {e}>"
            try:
                captured.explain = _capture_explain(frame)
            except Exception as e:
                captured.explain = f"<explain failed: {type(e).__name__}: {e}>"

        if column is not None:
            captured.repr_col = repr(column)

        if frame is not None and hasattr(frame, "collect"):
            records = [_row_to_dict(row) for row in frame.collect()]
            schema_text = _best_effort_schema(frame)
        else:
            records = None
            schema_text = None

        return ScenarioResult(
            backend=backend.value,
            scenario_id=s.id,
            ok=True,
            data=records,
            schema=schema_text,
            ui=captured,
        )
    except Exception as e:
        failure = CapturedError(exc_type=type(e).__name__, message=str(e))
        return ScenarioResult(
            backend=backend.value,
            scenario_id=s.id,
            ok=False,
            error=failure,
        )
    finally:
        try:
            session.stop()
        except Exception:
            # Best effort only: a failing stop() must not mask the result.
            pass
def run_scenarios(scenarios: List[Scenario], backends: List[Backend]) -> RunResults:
    """Run every scenario on every backend and collect the results.

    Scenarios are processed in order, and for each one the backends are run
    in order.  Results are keyed by scenario id and then by backend value.
    The run metadata (start time, Python version, platform) is recorded
    before the first scenario starts.
    """
    meta = RunMeta(
        started_at_epoch_s=time.time(),
        python=sys.version.split()[0],
        platform=sys.platform,
    )
    by_scenario: Dict[str, Dict[str, ScenarioResult]] = {
        scenario.id: {
            backend.value: run_one_scenario(scenario, backend)
            for backend in backends
        }
        for scenario in scenarios
    }
    return RunResults(meta=meta, by_scenario=by_scenario)