import argparse
import json
from pathlib import Path
import numpy as np
def _load_output(path, b, t, label):
    """Load a raw float32 dump and reshape it to (b, t).

    Raises SystemExit with a descriptive message when the file does not hold
    exactly b * t float32 values (instead of an opaque reshape ValueError).
    """
    data = np.fromfile(path, dtype=np.float32)
    if data.size != b * t:
        raise SystemExit(
            f"PARITY CHECK FAILED: {label} output {path!r} has {data.size} "
            f"float32 values, expected {b} x {t} = {b * t}"
        )
    return data.reshape(b, t)


def _parity_metrics(y_py, y_rs):
    """Return (mae, rmse, max_abs, corr) for the difference y_rs - y_py.

    Computed in float64 so the statistics themselves add no float32 rounding
    error. corr is the Pearson correlation of the flattened arrays; it is NaN
    when either array is constant.
    """
    ref = y_py.astype(np.float64)
    got = y_rs.astype(np.float64)
    diff = got - ref
    abs_diff = np.abs(diff)
    mae = float(np.mean(abs_diff))
    rmse = float(np.sqrt(np.mean(diff * diff)))
    max_abs = float(np.max(abs_diff))
    # corrcoef divides by the standard deviations; a constant array yields NaN.
    # Silence the floating-point warning, since corr is informational only.
    with np.errstate(invalid="ignore", divide="ignore"):
        corr = float(np.corrcoef(ref.reshape(-1), got.reshape(-1))[0, 1])
    return mae, rmse, max_abs, corr


def main(argv=None):
    """Check numeric parity between Python and Rust output dumps.

    Command-line arguments (parsed from ``argv``, or ``sys.argv[1:]`` if None):
      --meta         JSON file with integer keys "batch" and "out_samples".
      --python       raw float32 dump (native byte order) from Python.
      --rust         raw float32 dump (native byte order) from Rust.
      --max-abs-tol  largest allowed |rust - python| (default 1e-5).
      --rmse-tol     largest allowed RMSE of rust - python (default 1e-6).

    Prints MAE, RMSE, max absolute error and the Pearson correlation, then
    "PARITY CHECK PASSED" on success.

    Raises:
        SystemExit: if batch/out_samples are not positive, if a dump does not
            contain batch * out_samples values, if any difference is
            non-finite (NaN/inf), or if a tolerance is exceeded.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--meta", required=True)
    ap.add_argument("--python", required=True)
    ap.add_argument("--rust", required=True)
    ap.add_argument("--max-abs-tol", type=float, default=1e-5)
    ap.add_argument("--rmse-tol", type=float, default=1e-6)
    args = ap.parse_args(argv)

    meta = json.loads(Path(args.meta).read_text())
    b = int(meta["batch"])
    t = int(meta["out_samples"])
    # A negative size would silently become a reshape wildcard, and a zero size
    # would crash np.max; an empty comparison proves nothing either way.
    if b <= 0 or t <= 0:
        raise SystemExit(
            f"PARITY CHECK FAILED: invalid shape batch={b}, out_samples={t}"
        )

    y_py = _load_output(args.python, b, t, "python")
    y_rs = _load_output(args.rust, b, t, "rust")
    mae, rmse, max_abs, corr = _parity_metrics(y_py, y_rs)

    print(f"MAE : {mae:.10f}")
    print(f"RMSE : {rmse:.10f}")
    print(f"MAX ABS : {max_abs:.10f}")
    print(f"Pearson : {corr:.10f}")

    # NaN compares False against every threshold, so a bare `x > tol` test
    # would let NaN outputs pass. Require finite metrics within tolerance.
    within = (
        bool(np.isfinite(max_abs))
        and bool(np.isfinite(rmse))
        and max_abs <= args.max_abs_tol
        and rmse <= args.rmse_tol
    )
    if not within:
        raise SystemExit("PARITY CHECK FAILED")
    print("PARITY CHECK PASSED")
if __name__ == "__main__":
    # Run the parity check only when executed as a script, never on import.
    main()