#!/usr/bin/env bash
# Core recovery gate wrapper.
# Every setting below can be overridden from the environment; an unset or
# empty variable falls back to the default shown here.
set -euo pipefail

# Parent of the directory that contains this script.
ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"

# Where the JSON and Markdown reports are written.
: "${ARTIFACTS_DIR:=$ROOT_DIR/artifacts}"
: "${REPORT_PREFIX:=core_recovery_gate}"
: "${OUT_JSON:=$ARTIFACTS_DIR/${REPORT_PREFIX}_report.json}"
: "${OUT_MD:=$ARTIFACTS_DIR/${REPORT_PREFIX}_report.md}"

# Interpreter that runs the embedded gate program.
: "${PYTHON_BIN:=python3}"

# Workload shape. NOTE: DATA_DIR is deleted and recreated on every run.
: "${DATA_DIR:=/tmp/iridium-core-recovery-gate}"
: "${NODES:=6000}"
: "${BATCH_SIZE:=256}"
: "${QUERY_LIMIT:=200}"

# Pass/fail thresholds: latency limits in milliseconds, and the scan bound
# computed as ceil(QUERY_LIMIT * MAX_SCANNED_MULTIPLIER) + MAX_SCAN_SLACK.
: "${MAX_RECOVERY_MS:=250}"
: "${MAX_QUERY_MS:=25}"
: "${MAX_SCANNED_MULTIPLIER:=1}"
: "${MAX_SCAN_SLACK:=1}"

mkdir -p "$ARTIFACTS_DIR"

# Run the recovery gate as an embedded Python program read from stdin; the
# settings are handed over as positional arguments (sys.argv[1..10]).
#
# Exit status of the program (and therefore of this script, via `set -e`):
#   0  every check passed
#   1  at least one check failed (both reports were still written)
#   2  iridium could not be imported, or BATCH_SIZE is not a positive integer
#
# Because `set -e` ends the script as soon as the program exits non-zero, the
# "json_report:" / "md_report:" lines are printed by the program itself right
# after it writes the reports, so they are visible on a failed gate as well.
"$PYTHON_BIN" - \
  "$OUT_JSON" "$OUT_MD" "$DATA_DIR" "$NODES" "$BATCH_SIZE" "$QUERY_LIMIT" \
  "$MAX_RECOVERY_MS" "$MAX_QUERY_MS" "$MAX_SCANNED_MULTIPLIER" "$MAX_SCAN_SLACK" <<'PY'
import gc
import json
import math
import shutil
import sys
import time
from pathlib import Path

try:
    import iridium
except Exception as exc:
    print(f"core_recovery_gate: failed to import iridium ({exc})", file=sys.stderr)
    raise SystemExit(2)


def chunked(values, size):
    """Yield consecutive slices of `values`, each at most `size` items long."""
    for i in range(0, len(values), size):
        yield values[i : i + size]


# Positional arguments, in the order the shell wrapper passes them.
out_json = Path(sys.argv[1])
out_md = Path(sys.argv[2])
data_dir = Path(sys.argv[3])
nodes = int(sys.argv[4])
batch_size = int(sys.argv[5])
query_limit = int(sys.argv[6])
max_recovery_ms = float(sys.argv[7])
max_query_ms = float(sys.argv[8])
max_scanned_multiplier = float(sys.argv[9])
max_scan_slack = int(sys.argv[10])

# Reject an unusable batch size before any state is touched: chunked() would
# raise on 0 and silently yield nothing (so nothing is ingested) on negatives.
if batch_size < 1:
    print(
        f"core_recovery_gate: BATCH_SIZE must be >= 1 (got {batch_size})",
        file=sys.stderr,
    )
    raise SystemExit(2)

# Always start from an empty data directory.
if data_dir.exists():
    shutil.rmtree(data_dir)
data_dir.mkdir(parents=True, exist_ok=True)

# Prepare WAL-heavy state: begin ingest and intentionally skip finish to force replay path.
writer = iridium.Client(
    data_dir=str(data_dir),
    scan_start=0,
    scan_end_exclusive=nodes + 1,
    morsel_size=256,
    parallel_workers=0,
)
writer.begin_ingest()
try:
    # One node row and one edge row per id in 1..nodes, sent in batches.
    node_rows = [(node_id, 1, [node_id + 1]) for node_id in range(1, nodes + 1)]
    for batch in chunked(node_rows, batch_size):
        writer.ingest_nodes_batch(batch)
    edge_rows = [(node_id, 2, f"edge-{node_id}") for node_id in range(1, nodes + 1)]
    for batch in chunked(edge_rows, batch_size):
        writer.ingest_edges_batch(batch)
finally:
    # No finish_ingest on purpose; this validates restart recovery from WAL.
    # Drop the only reference and force a collection pass before reopening.
    # NOTE(review): presumably this releases the client's underlying handle;
    # confirm against iridium's Client lifecycle.
    writer = None
    gc.collect()

# Reopen the same directory with a fresh client.
reader = iridium.Client(
    data_dir=str(data_dir),
    scan_start=0,
    scan_end_exclusive=nodes + 1,
    morsel_size=256,
    parallel_workers=0,
)

# First query after reopening; its wall time is the "recovery plus query" metric.
# Client construction above is deliberately outside the timed region, as before.
t0 = time.perf_counter()
stream = json.loads(reader.query_json(f"MATCH (n) RETURN n LIMIT {query_limit}"))
recovery_plus_query_ms = (time.perf_counter() - t0) * 1000.0

# Second query approximates steady-state query latency after recovery/open path.
t1 = time.perf_counter()
stream_steady = json.loads(reader.query_json(f"MATCH (n) RETURN n LIMIT {query_limit}"))
steady_query_ms = (time.perf_counter() - t1) * 1000.0

# Look for the row whose first column is 1 and require its third column >= 1.
# NOTE(review): presumably row = (node_id, <second column>, delta count) - confirm
# against iridium's query_rows output.
rows = reader.query_rows(f"MATCH (n) RETURN n LIMIT {query_limit}")
node1 = next((row for row in rows if row[0] == 1), None)
node1_has_delta = bool(node1 and node1[2] >= 1)

# The first query may scan at most ceil(limit * multiplier) + slack nodes.
scanned_nodes = int(stream.get("scanned_nodes", 0))
scan_bound = int(math.ceil(query_limit * max_scanned_multiplier)) + max_scan_slack

checks = {
    "recovery_latency_pass": recovery_plus_query_ms <= max_recovery_ms,
    "steady_query_latency_pass": steady_query_ms <= max_query_ms,
    "scan_bound_pass": scanned_nodes <= scan_bound,
    "recovered_delta_pass": node1_has_delta,
}
overall_pass = all(checks.values())

report = {
    "gate": "core_recovery_gate",
    "passed": overall_pass,
    "config": {
        "data_dir": str(data_dir),
        "nodes": nodes,
        "batch_size": batch_size,
        "query_limit": query_limit,
    },
    "thresholds": {
        "max_recovery_ms": max_recovery_ms,
        "max_query_ms": max_query_ms,
        "max_scanned_multiplier": max_scanned_multiplier,
        "max_scan_slack": max_scan_slack,
    },
    "metrics": {
        "recovery_plus_query_ms": recovery_plus_query_ms,
        "steady_query_ms": steady_query_ms,
        "scanned_nodes": scanned_nodes,
        "scan_bound": scan_bound,
        "steady_scanned_nodes": int(stream_steady.get("scanned_nodes", 0)),
    },
    "checks": checks,
}

# The wrapper only creates ARTIFACTS_DIR; an overridden OUT_JSON / OUT_MD may
# live somewhere else, so make sure each report's directory exists.
for report_path in (out_json, out_md):
    report_path.parent.mkdir(parents=True, exist_ok=True)

out_json.write_text(json.dumps(report, indent=2), encoding="utf-8")
out_md.write_text(
    "\n".join(
        [
            "# Core Recovery Gate",
            "",
            f"- passed: {str(overall_pass).lower()}",
            f"- nodes: {nodes}",
            f"- query_limit: {query_limit}",
            "",
            "## Metrics",
            f"- recovery_plus_query_ms: {recovery_plus_query_ms:.3f}",
            f"- steady_query_ms: {steady_query_ms:.3f}",
            f"- scanned_nodes: {scanned_nodes}",
            f"- scan_bound: {scan_bound}",
            "",
            "## Checks",
            f"- recovery_latency_pass: {str(checks['recovery_latency_pass']).lower()}",
            f"- steady_query_latency_pass: {str(checks['steady_query_latency_pass']).lower()}",
            f"- scan_bound_pass: {str(checks['scan_bound_pass']).lower()}",
            f"- recovered_delta_pass: {str(checks['recovered_delta_pass']).lower()}",
        ]
    ),
    encoding="utf-8",
)
print(f"core_recovery_gate passed={str(overall_pass).lower()}")
# Print the report locations exactly as they were passed in, before a possible
# failing exit, so they are shown for failed gates too.
print(f"json_report: {sys.argv[1]}")
print(f"md_report: {sys.argv[2]}")
if not overall_pass:
    raise SystemExit(1)
PY
