#!/usr/bin/env bash
set -euo pipefail
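
# core_perf_gate: builds a throwaway iridium dataset, times three query shapes
# (graph scan, vector filter, bitmap filter), and exits non-zero when latency
# or scanned-node budgets are exceeded. JSON and Markdown reports land under
# ARTIFACTS_DIR by default. Every knob below can be overridden via environment
# variables, e.g. NODES=50000 ITERATIONS=20.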

ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
ARTIFACTS_DIR="${ARTIFACTS_DIR:-$ROOT_DIR/artifacts}"
REPORT_PREFIX="${REPORT_PREFIX:-core_perf_gate}"
OUT_JSON="${OUT_JSON:-$ARTIFACTS_DIR/${REPORT_PREFIX}_report.json}"
OUT_MD="${OUT_MD:-$ARTIFACTS_DIR/${REPORT_PREFIX}_report.md}"

PYTHON_BIN="${PYTHON_BIN:-python3}"
DATA_DIR="${DATA_DIR:-/tmp/iridium-core-perf-gate}"
NODES="${NODES:-12000}"
ITERATIONS="${ITERATIONS:-10}"
LIMIT="${LIMIT:-200}"
MORSEL_SIZE="${MORSEL_SIZE:-256}"
PARALLEL_WORKERS="${PARALLEL_WORKERS:-0}"

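# Gate thresholds: scanned-node budgets are multiples of LIMIT; latency
# budgets are per-query mean milliseconds.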
MAX_VECTOR_SCAN_MULTIPLIER="${MAX_VECTOR_SCAN_MULTIPLIER:-1000}"
MAX_BITMAP_SCAN_MULTIPLIER="${MAX_BITMAP_SCAN_MULTIPLIER:-1}"
MAX_GRAPH_AVG_MS="${MAX_GRAPH_AVG_MS:-10}"
MAX_BITMAP_AVG_MS="${MAX_BITMAP_AVG_MS:-10}"
MAX_VECTOR_AVG_MS="${MAX_VECTOR_AVG_MS:-200}"

mkdir -p "$ARTIFACTS_DIR"

"$PYTHON_BIN" - \
  "$OUT_JSON" "$OUT_MD" "$DATA_DIR" "$NODES" "$ITERATIONS" "$LIMIT" "$MORSEL_SIZE" "$PARALLEL_WORKERS" \
  "$MAX_VECTOR_SCAN_MULTIPLIER" "$MAX_BITMAP_SCAN_MULTIPLIER" "$MAX_GRAPH_AVG_MS" "$MAX_BITMAP_AVG_MS" "$MAX_VECTOR_AVG_MS" <<'PY'
import json
import math
import shutil
import statistics
import sys
import time
from pathlib import Path

try:
    import iridium
except Exception as exc:
    print(f"core_perf_gate: failed to import iridium ({exc})", file=sys.stderr)
    raise SystemExit(2)


def percentile(values, p):
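    """Nearest-rank percentile of *values* for p in [0, 100]."""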
    if not values:
        return 0.0
    vals = sorted(values)
    idx = int(math.ceil((p / 100.0) * len(vals))) - 1
    idx = max(0, min(idx, len(vals) - 1))
    return vals[idx]


def summarize(values):
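    """Latency summary (milliseconds): count, mean, p50, p95, max."""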
    if not values:
        return {"count": 0.0, "avg_ms": 0.0, "p50_ms": 0.0, "p95_ms": 0.0, "max_ms": 0.0}
    vals = sorted(values)
    return {
        "count": float(len(vals)),
        "avg_ms": statistics.fmean(vals),
        "p50_ms": percentile(vals, 50),
        "p95_ms": percentile(vals, 95),
        "max_ms": vals[-1],
    }


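# Positional arguments forwarded by the wrapping shell script, in the same
# order as the invocation above.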
out_json = Path(sys.argv[1])
out_md = Path(sys.argv[2])
data_dir = Path(sys.argv[3])
nodes = int(sys.argv[4])
iterations = int(sys.argv[5])
limit = int(sys.argv[6])
morsel_size = int(sys.argv[7])
parallel_workers = int(sys.argv[8])
max_vector_scan_multiplier = float(sys.argv[9])
max_bitmap_scan_multiplier = float(sys.argv[10])
max_graph_avg_ms = float(sys.argv[11])
max_bitmap_avg_ms = float(sys.argv[12])
max_vector_avg_ms = float(sys.argv[13])

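# Rebuild the benchmark dataset from scratch on every run.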
if data_dir.exists():
    shutil.rmtree(data_dir)
data_dir.mkdir(parents=True, exist_ok=True)

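# Open a client over the fresh data directory with a scan window of
# [0, nodes + 1) plus the configured morsel/parallelism knobs.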
client = iridium.Client(
    data_dir=str(data_dir),
    scan_start=0,
    scan_end_exclusive=nodes + 1,
    morsel_size=morsel_size,
    parallel_workers=parallel_workers,
)

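# Ingest node ids 1..nodes; every third node is also added to the idx_country
# bitmap index under the "US" key.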
client.create_bitmap_index("idx_country", "n.country")
client.begin_ingest()
try:
    for node_id in range(1, nodes + 1):
        client.ingest_node(node_id, 1, [node_id + 1])
        if node_id % 3 == 0:
            client.bitmap_add_posting("idx_country", "US", node_id)
finally:
    client.finish_ingest()

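# The three gated query shapes: a plain graph scan, a vector-similarity
# filter, and a bitmap-index filter, all capped at the same LIMIT.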
queries = {
    "graph_scan": f"MATCH (n) RETURN n LIMIT {limit}",
    "vector_filter": f"MATCH (n) WHERE vector.cosine(n.embedding, $risk) > 0.65 RETURN n LIMIT {limit}",
    "bitmap_filter": f"MATCH (n) WHERE bitmap.contains(idx_country, US) = 1 RETURN n LIMIT {limit}",
}

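# Timed runs: record wall-clock latency (ms) and the scanned-node count
# reported by each query.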
latencies = {name: [] for name in queries}
scanned = {name: [] for name in queries}
for _ in range(iterations):
    for name, query in queries.items():
        t0 = time.perf_counter()
        out = json.loads(client.query_json(query))
        latencies[name].append((time.perf_counter() - t0) * 1000.0)
        scanned[name].append(int(out.get("scanned_nodes", 0)))

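# Cache probe: issue the bitmap query a few more times and compare the first
# probe run ("cold") against the mean of the remaining runs.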
warmups = []
for _ in range(4):
    t0 = time.perf_counter()
    _ = json.loads(client.query_json(f"MATCH (n) WHERE bitmap.contains(idx_country, US) = 1 RETURN n LIMIT {limit}"))
    warmups.append((time.perf_counter() - t0) * 1000.0)
cold_ms = warmups[0]
warm_avg_ms = statistics.fmean(warmups[1:]) if len(warmups) > 1 else cold_ms
warmup_ratio = warm_avg_ms / cold_ms if cold_ms > 0 else 1.0

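# Per-query latency summary plus the mean scanned-node count.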
stats = {}
for name in queries:
    stats[name] = {
        **summarize(latencies[name]),
        "avg_scanned_nodes": statistics.fmean(scanned[name]) if scanned[name] else 0.0,
    }

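# Pass/fail checks: scanned-node budgets are multiples of LIMIT, latency
# budgets are mean milliseconds per query.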
vector_scan_pass = stats["vector_filter"]["avg_scanned_nodes"] <= (limit * max_vector_scan_multiplier)
bitmap_scan_pass = stats["bitmap_filter"]["avg_scanned_nodes"] <= (limit * max_bitmap_scan_multiplier)
graph_latency_pass = stats["graph_scan"]["avg_ms"] <= max_graph_avg_ms
bitmap_latency_pass = stats["bitmap_filter"]["avg_ms"] <= max_bitmap_avg_ms
vector_latency_pass = stats["vector_filter"]["avg_ms"] <= max_vector_avg_ms
overall_pass = (
    vector_scan_pass
    and bitmap_scan_pass
    and graph_latency_pass
    and bitmap_latency_pass
    and vector_latency_pass
)

report = {
    "gate": "core_perf_gate",
    "passed": overall_pass,
    "config": {
        "data_dir": str(data_dir),
        "nodes": nodes,
        "iterations": iterations,
        "limit": limit,
        "morsel_size": morsel_size,
        "parallel_workers": parallel_workers,
    },
    "thresholds": {
        "max_vector_scan_multiplier": max_vector_scan_multiplier,
        "max_bitmap_scan_multiplier": max_bitmap_scan_multiplier,
        "max_graph_avg_ms": max_graph_avg_ms,
        "max_bitmap_avg_ms": max_bitmap_avg_ms,
        "max_vector_avg_ms": max_vector_avg_ms,
    },
    "queries": stats,
    "cache_probe": {
        "cold_ms": cold_ms,
        "warm_avg_ms": warm_avg_ms,
        "warmup_ratio": warmup_ratio,
    },
    "checks": {
        "vector_scan_pass": vector_scan_pass,
        "bitmap_scan_pass": bitmap_scan_pass,
        "graph_latency_pass": graph_latency_pass,
        "bitmap_latency_pass": bitmap_latency_pass,
        "vector_latency_pass": vector_latency_pass,
    },
}

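# Emit the JSON report, then a compact Markdown summary of the same numbers.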
out_json.write_text(json.dumps(report, indent=2), encoding="utf-8")
lines = [
    "# Core Performance Gate",
    "",
    f"- passed: {str(overall_pass).lower()}",
    f"- nodes: {nodes}",
    f"- iterations: {iterations}",
    f"- limit: {limit}",
    "",
    "## Query Stats",
]
for name, q in stats.items():
    lines.extend(
        [
            f"### {name}",
            f"- avg_ms: {q['avg_ms']:.3f}",
            f"- p95_ms: {q['p95_ms']:.3f}",
            f"- avg_scanned_nodes: {q['avg_scanned_nodes']:.2f}",
            "",
        ]
    )
lines.extend(
    [
        "## Cache Probe",
        f"- cold_ms: {cold_ms:.3f}",
        f"- warm_avg_ms: {warm_avg_ms:.3f}",
        f"- warmup_ratio: {warmup_ratio:.3f}",
        "",
        "## Checks",
        f"- vector_scan_pass: {str(vector_scan_pass).lower()}",
        f"- bitmap_scan_pass: {str(bitmap_scan_pass).lower()}",
        f"- graph_latency_pass: {str(graph_latency_pass).lower()}",
        f"- bitmap_latency_pass: {str(bitmap_latency_pass).lower()}",
        f"- vector_latency_pass: {str(vector_latency_pass).lower()}",
    ]
)
out_md.write_text("\n".join(lines) + "\n", encoding="utf-8")
print(f"core_perf_gate passed={str(overall_pass).lower()}")
if not overall_pass:
    raise SystemExit(1)
PY

echo "json_report: $OUT_JSON"
echo "md_report: $OUT_MD"
