#!/usr/bin/env bash
set -euo pipefail

# ROOT_DIR is the parent of the directory that contains this script.
script_dir="$(dirname "${BASH_SOURCE[0]}")"
ROOT_DIR="$(cd "$script_dir/.." && pwd)"

# Every setting below can be overridden from the environment; an unset or
# empty variable falls back to the default shown here.

# Where the reports are written.
: "${ARTIFACTS_DIR:=$ROOT_DIR/artifacts}"
: "${REPORT_PREFIX:=bitmap_gate}"
: "${OUT_JSON:=$ARTIFACTS_DIR/${REPORT_PREFIX}_report.json}"
: "${OUT_MD:=$ARTIFACTS_DIR/${REPORT_PREFIX}_report.md}"

# Scratch directory handed to the Python program, which deletes and recreates it.
: "${DATA_DIR:=/tmp/iridium-bitmap-gate-data}"

# Workload size and engine settings forwarded to the Python program.
: "${NODES:=8000}"
: "${SELECTIVE_STRIDE:=80}"
: "${ITERATIONS:=20}"
: "${MORSEL_SIZE:=256}"
: "${PARALLEL_WORKERS:=0}"
: "${SCAN_START:=0}"

mkdir -p "$ARTIFACTS_DIR"

# Run the embedded Python program: `python3 -` reads it from stdin (the heredoc
# below). The nine positional arguments become sys.argv[1..9] in the order
# given here. The quoted delimiter ('PY') keeps the shell from expanding
# anything inside the heredoc body.
# NOTE: with `set -e` active, a non-zero exit from Python (2 when iridium
# cannot be imported, 1 when the gate fails) stops this script at this command,
# so the lines after the heredoc only run when the gate passes.
python3 - "$OUT_JSON" "$OUT_MD" "$DATA_DIR" "$NODES" "$SELECTIVE_STRIDE" "$ITERATIONS" "$MORSEL_SIZE" "$PARALLEL_WORKERS" "$SCAN_START" <<'PY'
import json
import math
import os
import shutil
import statistics
import sys
import time
from pathlib import Path

# iridium provides the Client used for ingest and queries below. If it cannot
# be imported, report why on stderr and exit with status 2; the gate-failure
# path at the end of the script exits with status 1 instead.
try:
    import iridium
except Exception as exc:  # catches any import-time failure, not only ImportError
    print(f"bitmap_gate: failed to import iridium ({exc})", file=sys.stderr)
    raise SystemExit(2)


def percentile(sorted_values, p):
    """Return the nearest-rank percentile of an ascending-sorted sequence.

    Args:
        sorted_values: Values already sorted in ascending order.
        p: Percentile to look up, on a 0-100 scale.

    Returns:
        The element at 1-based rank ceil(p * len(sorted_values) / 100), with
        the rank clamped into the valid index range, or 0.0 when the sequence
        is empty.
    """
    if not sorted_values:
        return 0.0
    count = len(sorted_values)
    # Multiply before dividing. (p / 100.0) * count can land just above an
    # integer (0.07 * 100 == 7.000000000000001), and ceil() would then select
    # the next rank instead of the intended one.
    rank = int(math.ceil((p * count) / 100.0))
    idx = max(0, min(rank - 1, count - 1))
    return sorted_values[idx]


def summarize(values):
    """Summarize a collection of numbers as avg, p50, p95 and max.

    Args:
        values: Numeric samples in any order.

    Returns:
        A dict with the keys "avg", "p50", "p95" and "max". Every entry is 0.0
        when ``values`` is empty.
    """
    if values:
        ordered = sorted(values)
        stats = {}
        stats["avg"] = statistics.fmean(ordered)
        stats["p50"] = percentile(ordered, 50)
        stats["p95"] = percentile(ordered, 95)
        # The list is sorted ascending, so the last element is the largest.
        stats["max"] = ordered[-1]
        return stats
    return {"avg": 0.0, "p50": 0.0, "p95": 0.0, "max": 0.0}


def _config_error(message):
    """Report an invalid setting on stderr and exit with status 2.

    Status 2 matches the import-failure path, so status 1 stays reserved for
    "the gate ran and failed".
    """
    print(f"bitmap_gate: {message}", file=sys.stderr)
    raise SystemExit(2)


def _int_setting(position, name):
    """Parse sys.argv[position] as an integer; ``name`` is used in the error."""
    raw = sys.argv[position]
    try:
        return int(raw)
    except ValueError:
        _config_error(f"{name} must be an integer, got {raw!r}")


# Positional arguments arrive from the wrapping shell command in this order.
out_json = Path(sys.argv[1])
out_md = Path(sys.argv[2])
data_dir = Path(sys.argv[3])
nodes = _int_setting(4, "NODES")
stride = _int_setting(5, "SELECTIVE_STRIDE")
iterations = _int_setting(6, "ITERATIONS")
morsel_size = _int_setting(7, "MORSEL_SIZE")
parallel_workers = _int_setting(8, "PARALLEL_WORKERS")
scan_start = _int_setting(9, "SCAN_START")

# These three drive range() bounds, a modulo divisor and the sample count, so
# they must be strictly positive.
if nodes <= 0:
    _config_error("NODES must be > 0")
if stride <= 0:
    _config_error("SELECTIVE_STRIDE must be > 0")
if iterations <= 0:
    _config_error("ITERATIONS must be > 0")

# Start from an empty data directory: remove whatever an earlier run left
# behind, then recreate the directory.
if data_dir.exists():
    shutil.rmtree(data_dir)
data_dir.mkdir(parents=True, exist_ok=True)

# NOTE(review): the scan window ends 8 past the highest node id; the reason
# for the margin of 8 is not visible here - confirm against the iridium API.
scan_end_exclusive = nodes + 8
client_options = {
    "data_dir": str(data_dir),
    "scan_start": scan_start,
    "scan_end_exclusive": scan_end_exclusive,
    "morsel_size": morsel_size,
    "parallel_workers": parallel_workers,
}
client = iridium.Client(**client_options)

# The index is created before ingest begins; finish_ingest() runs even when an
# ingest call raises.
client.create_bitmap_index("idx_country", "n.country")
client.begin_ingest()
try:
    for current_id in range(1, nodes + 1):
        # NOTE(review): presumably (node id, label/type, adjacency list) -
        # verify the argument meaning against the iridium API.
        client.ingest_node(current_id, 1, [current_id + 1])
        if current_id % stride:
            continue
        # Only every stride-th node gets a "US" posting in the bitmap index.
        client.bitmap_add_posting("idx_country", "US", current_id)
finally:
    client.finish_ingest()

# Node ids that received a "US" posting during ingest (every stride-th id).
expected_ids = [node_id for node_id in range(1, nodes + 1) if node_id % stride == 0]

baseline_query = f"MATCH (n) RETURN n LIMIT {nodes + 8}"
# The LIMIT must never cut off the expected result set, or completeness would
# fail for large NODES / small SELECTIVE_STRIDE even when the index is right.
# For up to 100000 expected rows the query text is the historical one.
bitmap_limit = max(100000, len(expected_ids))
bitmap_query = (
    "MATCH (n) WHERE bitmap.contains(idx_country, US) = 1 "
    f"RETURN n LIMIT {bitmap_limit}"
)

baseline_lat_us = []
bitmap_lat_us = []
baseline_scanned = []
bitmap_scanned = []
last_ids = []
mismatch_ids = None

for _ in range(iterations):
    baseline = json.loads(client.query_json(baseline_query))
    bitmap = json.loads(client.query_json(bitmap_query))
    # Metrics missing from a response are recorded as 0.
    baseline_lat_us.append(int(baseline.get("latency_micros", 0)))
    bitmap_lat_us.append(int(bitmap.get("latency_micros", 0)))
    baseline_scanned.append(int(baseline.get("scanned_nodes", 0)))
    bitmap_scanned.append(int(bitmap.get("scanned_nodes", 0)))
    # Check completeness on every iteration, not just the last one, and keep
    # the first mismatching result so the report shows what went wrong.
    last_ids = sorted(int(row["node_id"]) for row in bitmap.get("rows", []))
    if mismatch_ids is None and last_ids != expected_ids:
        mismatch_ids = last_ids

# Report the first failing iteration if there was one, else the final result.
actual_ids = last_ids if mismatch_ids is None else mismatch_ids
completeness_pass = actual_ids == expected_ids
scan_avg_baseline = statistics.fmean(baseline_scanned) if baseline_scanned else 0.0
scan_avg_bitmap = statistics.fmean(bitmap_scanned) if bitmap_scanned else 0.0
# The reduction factor is reported as 0.0 when the bitmap query scanned nothing.
scan_reduction = (scan_avg_baseline / scan_avg_bitmap) if scan_avg_bitmap > 0 else 0.0
scan_gate_pass = scan_avg_bitmap < scan_avg_baseline

lat_base = summarize(baseline_lat_us)
lat_bitmap = summarize(bitmap_lat_us)
# The bitmap query must not be slower than the baseline at p95.
latency_gate_pass = lat_bitmap["p95"] <= lat_base["p95"]

# The gate passes only when all three checks pass.
passed = completeness_pass and scan_gate_pass and latency_gate_pass

# Assemble the machine-readable report section by section. Key order is kept
# because it determines the layout of the serialized JSON.
config_section = {
    "data_dir": str(data_dir),
    "nodes": nodes,
    "selective_stride": stride,
    "iterations": iterations,
    "scan_start": scan_start,
    "scan_end_exclusive": scan_end_exclusive,
    "morsel_size": morsel_size,
    "parallel_workers": parallel_workers,
}
queries_section = {
    "baseline": baseline_query,
    "bitmap": bitmap_query,
}
scan_section = {
    "baseline_avg_scanned_nodes": scan_avg_baseline,
    "bitmap_avg_scanned_nodes": scan_avg_bitmap,
    "reduction_factor": scan_reduction,
    "pass": scan_gate_pass,
}
latency_section = {
    "baseline": lat_base,
    "bitmap": lat_bitmap,
    "pass": latency_gate_pass,
}

report = {
    "gate": "bitmap_gate",
    "passed": passed,
    "config": config_section,
    "queries": queries_section,
    "expected_rows": len(expected_ids),
    "actual_rows": len(actual_ids),
    "completeness_pass": completeness_pass,
    "scan_reduction": scan_section,
    "latency_micros": latency_section,
}

serialized_report = json.dumps(report, indent=2)
out_json.write_text(serialized_report, encoding="utf-8")

# Build the human-readable report one section at a time. Booleans are written
# in lowercase ("true"/"false").
lines = []

# Title, overall verdict and run configuration.
lines.extend([
    "# Bitmap Gate Report",
    "",
    f"- passed: {str(passed).lower()}",
    f"- data_dir: {data_dir}",
    f"- nodes: {nodes}",
    f"- selective_stride: {stride}",
    f"- iterations: {iterations}",
    "",
])

# Row counts of the bitmap query versus the expected id set.
lines.extend([
    "## Completeness",
    f"- expected_rows: {len(expected_ids)}",
    f"- actual_rows: {len(actual_ids)}",
    f"- pass: {str(completeness_pass).lower()}",
    "",
])

# Average scanned-node counts for both queries.
lines.extend([
    "## Scan Reduction",
    f"- baseline_avg_scanned_nodes: {scan_avg_baseline:.2f}",
    f"- bitmap_avg_scanned_nodes: {scan_avg_bitmap:.2f}",
    f"- reduction_factor: {scan_reduction:.2f}",
    f"- pass: {str(scan_gate_pass).lower()}",
    "",
])

# p50/p95 latencies for both queries.
lines.extend([
    "## Latency Micros",
    f"- baseline_p50: {lat_base['p50']:.2f}",
    f"- baseline_p95: {lat_base['p95']:.2f}",
    f"- bitmap_p50: {lat_bitmap['p50']:.2f}",
    f"- bitmap_p95: {lat_bitmap['p95']:.2f}",
    f"- pass: {str(latency_gate_pass).lower()}",
    "",
])

# The trailing empty entry makes the file end with a newline after the join.
markdown_text = "\n".join(lines)
out_md.write_text(markdown_text, encoding="utf-8")

# One-line summary on stdout, then turn the verdict into the exit status:
# 0 when the gate passed, 1 when it failed.
print(f"bitmap_gate passed={str(passed).lower()} reduction_factor={scan_reduction:.2f}")
exit_status = 0 if passed else 1
if exit_status != 0:
    raise SystemExit(exit_status)
PY

# Tell the caller where the reports were written. With `set -e` active this is
# reached only when the Python program above exited with status 0.
printf 'json_report: %s\n' "$OUT_JSON"
printf 'md_report: %s\n' "$OUT_MD"
