#!/usr/bin/env bash
#
# bitmap_perf_matrix: ingest the same synthetic graph into iridium with and
# without a bitmap index, sweep bitmap-filtered queries over several
# selectivities, and write JSON + Markdown reports with pass/fail gates.
#
# Every knob below can be overridden from the environment:
#   ARTIFACTS_DIR, REPORT_PREFIX, OUT_JSON, OUT_MD  report locations
#   PYTHON_BIN       interpreter that can `import iridium`
#   DATA_ROOT        scratch root; its plain/ and bitmap/ subdirectories are
#                    deleted and recreated on every run
#   NODES, ITERATIONS, SELECTIVITY_PCTS, INGEST_SELECTIVITY_PCT,
#   MORSEL_SIZE, PARALLEL_WORKERS                   workload shape
#   MAX_INGEST_REGRESSION_PCT, MIN_SCAN_REDUCTION_FACTOR,
#   SELECTIVE_LATENCY_MAX_PCT, MAX_SELECTIVE_P95_LATENCY_RATIO  gate thresholds
#
# Exits non-zero when any gate fails or the harness errors (status 2 when the
# iridium module cannot be imported).
set -euo pipefail

ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
ARTIFACTS_DIR="${ARTIFACTS_DIR:-$ROOT_DIR/artifacts}"
REPORT_PREFIX="${REPORT_PREFIX:-bitmap_perf_matrix}"
OUT_JSON="${OUT_JSON:-$ARTIFACTS_DIR/${REPORT_PREFIX}_report.json}"
OUT_MD="${OUT_MD:-$ARTIFACTS_DIR/${REPORT_PREFIX}_report.md}"

PYTHON_BIN="${PYTHON_BIN:-python3}"
DATA_ROOT="${DATA_ROOT:-/tmp/iridium-bitmap-perf}"
NODES="${NODES:-10000}"
ITERATIONS="${ITERATIONS:-20}"
SELECTIVITY_PCTS="${SELECTIVITY_PCTS:-0.1,1,5,20}"
INGEST_SELECTIVITY_PCT="${INGEST_SELECTIVITY_PCT:-5}"
MORSEL_SIZE="${MORSEL_SIZE:-256}"
PARALLEL_WORKERS="${PARALLEL_WORKERS:-0}"

MAX_INGEST_REGRESSION_PCT="${MAX_INGEST_REGRESSION_PCT:-20}"
MIN_SCAN_REDUCTION_FACTOR="${MIN_SCAN_REDUCTION_FACTOR:-2.0}"
SELECTIVE_LATENCY_MAX_PCT="${SELECTIVE_LATENCY_MAX_PCT:-5}"
MAX_SELECTIVE_P95_LATENCY_RATIO="${MAX_SELECTIVE_P95_LATENCY_RATIO:-1.0}"

# OUT_JSON / OUT_MD can be pointed outside ARTIFACTS_DIR. Create their parent
# directories up front so the reports written at the very end of a long run
# cannot fail on a missing directory.
mkdir -p -- "$(dirname -- "$OUT_JSON")" "$(dirname -- "$OUT_MD")"

mkdir -p "$ARTIFACTS_DIR"

# Run the embedded Python harness. `-` makes the interpreter read its program
# from stdin, which is the quoted here-document below (quoted delimiter, so no
# shell expansion happens inside it). The positional arguments arrive as
# sys.argv[1]..sys.argv[13], in exactly this order. Under `set -e` a non-zero
# exit from the harness ends this script before the report paths are echoed.
"$PYTHON_BIN" - \
  "$OUT_JSON" "$OUT_MD" "$DATA_ROOT" "$NODES" "$ITERATIONS" "$SELECTIVITY_PCTS" \
  "$INGEST_SELECTIVITY_PCT" "$MORSEL_SIZE" "$PARALLEL_WORKERS" \
  "$MAX_INGEST_REGRESSION_PCT" "$MIN_SCAN_REDUCTION_FACTOR" "$SELECTIVE_LATENCY_MAX_PCT" \
  "$MAX_SELECTIVE_P95_LATENCY_RATIO" <<'PY'
import json
import math
import shutil
import statistics
import sys
import time
from pathlib import Path

try:
    import iridium
except Exception as exc:
    print(f"bitmap_perf_matrix: failed to import iridium ({exc})", file=sys.stderr)
    raise SystemExit(2)


def percentile(values, p):
    """Return the nearest-rank ``p``-th percentile of ``values``.

    Args:
        values: sequence of comparable numbers; need not be sorted.
        p: percentile, normally in [0, 100]. Ranks outside the sample are
            clamped to the smallest / largest value.

    Returns:
        The selected sample, or 0.0 when ``values`` is empty.
    """
    if not values:
        return 0.0
    ordered = sorted(values)
    n = len(ordered)
    # Multiply before dividing. (p / 100.0) * n picks up binary rounding error
    # (e.g. 0.07 * 100 == 7.000000000000001), which ceil() turns into an
    # off-by-one rank. p * n / 100.0 is exact whenever p is a whole number.
    rank = math.ceil(p * n / 100.0)
    idx = min(max(rank - 1, 0), n - 1)
    return ordered[idx]


def summarize(values):
    if not values:
        return {"avg": 0.0, "p50": 0.0, "p95": 0.0, "p99": 0.0, "max": 0.0}
    vals = sorted(values)
    return {
        "avg": statistics.fmean(vals),
        "p50": percentile(vals, 50),
        "p95": percentile(vals, 95),
        "p99": percentile(vals, 99),
        "max": vals[-1],
    }


def select_ids(nodes, pct):
    """Pick roughly ``pct`` percent of node ids 1..``nodes`` at a fixed stride.

    The stride is round(100 / pct), floored at 1. ``pct`` is clamped to at
    least 0.0001. When the stride exceeds ``nodes`` (nothing would be
    selected), the single id ``nodes`` is returned so the result is never
    empty.
    """
    effective_pct = max(0.0001, float(pct))
    step = max(1, int(round(100.0 / effective_pct)))
    picked = list(range(step, nodes + 1, step))
    return picked or [nodes]


def make_client(path, nodes, morsel_size, parallel_workers):
    return iridium.Client(
        data_dir=str(path),
        scan_start=0,
        scan_end_exclusive=nodes + 8,
        morsel_size=morsel_size,
        parallel_workers=parallel_workers,
    )


def ingest_dataset(path, nodes, morsel_size, parallel_workers, bitmap_enabled, ingest_pct):
    if path.exists():
        shutil.rmtree(path)
    path.mkdir(parents=True, exist_ok=True)

    indexed_ids = select_ids(nodes, ingest_pct)
    indexed_set = set(indexed_ids)
    client = make_client(path, nodes, morsel_size, parallel_workers)
    if bitmap_enabled:
        client.create_bitmap_index("idx_country", "n.country")

    start = time.perf_counter()
    client.begin_ingest()
    try:
        for node_id in range(1, nodes + 1):
            client.ingest_node(node_id, 1, [node_id + 1])
            if bitmap_enabled and node_id in indexed_set:
                client.bitmap_add_posting("idx_country", "ingest_sample", node_id)
    finally:
        client.finish_ingest()
    elapsed = max(1e-9, time.perf_counter() - start)
    return {
        "elapsed_seconds": elapsed,
        "nodes": nodes,
        "indexed_nodes": len(indexed_ids) if bitmap_enabled else 0,
        "nodes_per_second": nodes / elapsed,
    }


def _config_error(message):
    """Print a configuration error to stderr and exit with status 2.

    Status 2 matches the iridium import failure handling and keeps setup
    problems distinguishable from a failed gate, which exits with status 1.
    """
    print(f"bitmap_perf_matrix: {message}", file=sys.stderr)
    raise SystemExit(2)


def _parse_arg(position, env_name, convert):
    """Convert ``sys.argv[position]`` with ``convert``; exit cleanly on bad input.

    ``env_name`` is the shell variable that fed this argument, so the error
    message names the knob the user actually set.
    """
    raw = sys.argv[position]
    try:
        return convert(raw)
    except ValueError:
        _config_error(f"{env_name}: cannot parse {raw!r}")


def _parse_pct_list(raw):
    """Parse a comma-separated list of percentages, skipping empty entries."""
    return [float(item.strip()) for item in raw.split(",") if item.strip()]


# Positional arguments, in the order the shell wrapper passes them.
out_json = Path(sys.argv[1])
out_md = Path(sys.argv[2])
data_root = Path(sys.argv[3])
nodes = _parse_arg(4, "NODES", int)
iterations = _parse_arg(5, "ITERATIONS", int)
selectivity_pcts = _parse_arg(6, "SELECTIVITY_PCTS", _parse_pct_list)
ingest_selectivity_pct = _parse_arg(7, "INGEST_SELECTIVITY_PCT", float)
morsel_size = _parse_arg(8, "MORSEL_SIZE", int)
parallel_workers = _parse_arg(9, "PARALLEL_WORKERS", int)
max_ingest_regression_pct = _parse_arg(10, "MAX_INGEST_REGRESSION_PCT", float)
min_scan_reduction_factor = _parse_arg(11, "MIN_SCAN_REDUCTION_FACTOR", float)
selective_latency_max_pct = _parse_arg(12, "SELECTIVE_LATENCY_MAX_PCT", float)
max_selective_p95_latency_ratio = _parse_arg(13, "MAX_SELECTIVE_P95_LATENCY_RATIO", float)

if nodes <= 0:
    _config_error("NODES must be > 0")
if iterations <= 0:
    _config_error("ITERATIONS must be > 0")
if morsel_size <= 0:
    _config_error("MORSEL_SIZE must be > 0")
if not selectivity_pcts:
    _config_error("SELECTIVITY_PCTS must not be empty")
# select_ids() silently clamps non-positive percentages to 0.0001%, and the
# derived value key would contain '-' for negatives, so reject them up front.
if any(pct <= 0 for pct in selectivity_pcts):
    _config_error("SELECTIVITY_PCTS entries must be > 0")

plain_data = data_root / "plain"
bitmap_data = data_root / "bitmap"

# Same graph, same knobs. The only difference between the two runs is whether
# the bitmap index is created and populated during ingest. Plain runs first.
_ingest_knobs = dict(
    nodes=nodes,
    morsel_size=morsel_size,
    parallel_workers=parallel_workers,
    ingest_pct=ingest_selectivity_pct,
)
ingest_plain = ingest_dataset(plain_data, bitmap_enabled=False, **_ingest_knobs)
ingest_bitmap = ingest_dataset(bitmap_data, bitmap_enabled=True, **_ingest_knobs)

# Positive regression = bitmap ingest is slower, as a percentage of the plain
# throughput.
_plain_rate = ingest_plain["nodes_per_second"]
_bitmap_rate = ingest_bitmap["nodes_per_second"]
if _plain_rate > 0:
    ingest_regression_pct = (_plain_rate - _bitmap_rate) / _plain_rate * 100.0
else:
    ingest_regression_pct = 0.0
ingest_gate_pass = max_ingest_regression_pct >= ingest_regression_pct

def _selectivity_key(pct):
    """Return the bitmap value key for selectivity ``pct`` (e.g. 0.1 -> "sel_0_1").

    Seeding and querying both derive keys here, so the value posted for a
    selectivity and the value queried for it cannot drift apart.
    """
    return f"sel_{str(pct).replace('.', '_')}"


# Add one bitmap value per swept selectivity to the bitmap dataset, in a
# separate, untimed ingest session after the ingest benchmark.
client = make_client(bitmap_data, nodes, morsel_size, parallel_workers)
client.begin_ingest()
try:
    for pct in selectivity_pcts:
        key = _selectivity_key(pct)
        for node_id in select_ids(nodes, pct):
            client.bitmap_add_posting("idx_country", key, int(node_id))
finally:
    client.finish_ingest()

# Unfiltered query used as the comparison point for every selectivity.
baseline_query = f"MATCH (n) RETURN n LIMIT {nodes + 8}"
cases = []
for pct in selectivity_pcts:
    value_key = _selectivity_key(pct)
    bitmap_query = f"MATCH (n) WHERE bitmap.contains(idx_country, {value_key}) = 1 RETURN n LIMIT {nodes + 8}"
    expected_ids = select_ids(nodes, pct)

    baseline_lat_us = []
    bitmap_lat_us = []
    baseline_scanned = []
    bitmap_scanned = []
    bitmap_ids_last = []
    complete_iterations = 0

    for _ in range(iterations):
        baseline = json.loads(client.query_json(baseline_query))
        bitmap = json.loads(client.query_json(bitmap_query))
        # NOTE(review): missing metric fields silently default to 0.
        baseline_lat_us.append(int(baseline.get("latency_micros", 0)))
        bitmap_lat_us.append(int(bitmap.get("latency_micros", 0)))
        baseline_scanned.append(int(baseline.get("scanned_nodes", 0)))
        bitmap_scanned.append(int(bitmap.get("scanned_nodes", 0)))
        bitmap_ids_last = sorted(int(row["node_id"]) for row in bitmap.get("rows", []))
        if bitmap_ids_last == expected_ids:
            complete_iterations += 1

    # Require the exact expected id set on every iteration, not just the last
    # one: an intermittently wrong bitmap result must fail the gate.
    completeness_pass = complete_iterations == iterations
    baseline_scan_avg = statistics.fmean(baseline_scanned) if baseline_scanned else 0.0
    bitmap_scan_avg = statistics.fmean(bitmap_scanned) if bitmap_scanned else 0.0
    # NOTE(review): a bitmap average of 0 scanned nodes (including an absent
    # "scanned_nodes" field) yields 0.0 here and therefore fails the scan gate.
    scan_reduction = (baseline_scan_avg / bitmap_scan_avg) if bitmap_scan_avg > 0 else 0.0
    scan_gate_pass = scan_reduction >= min_scan_reduction_factor

    baseline_lat = summarize(baseline_lat_us)
    bitmap_lat = summarize(bitmap_lat_us)
    # NOTE(review): when the baseline p95 is 0 (e.g. "latency_micros" absent)
    # the ratio is reported as 0.0 and the latency gate passes vacuously.
    p95_ratio = (bitmap_lat["p95"] / baseline_lat["p95"]) if baseline_lat["p95"] > 0 else 0.0
    # The latency gate only applies to selective cases (pct <= threshold).
    selective_case = pct <= selective_latency_max_pct
    latency_gate_pass = (not selective_case) or (p95_ratio <= max_selective_p95_latency_ratio)

    case_pass = completeness_pass and scan_gate_pass and latency_gate_pass
    cases.append(
        {
            "selectivity_pct": pct,
            "expected_rows": len(expected_ids),
            "actual_rows": len(bitmap_ids_last),
            "completeness_pass": completeness_pass,
            "complete_iterations": complete_iterations,
            "scan": {
                "baseline_avg_scanned_nodes": baseline_scan_avg,
                "bitmap_avg_scanned_nodes": bitmap_scan_avg,
                "reduction_factor": scan_reduction,
                "pass": scan_gate_pass,
            },
            "latency_micros": {
                "baseline": baseline_lat,
                "bitmap": bitmap_lat,
                "bitmap_to_baseline_p95_ratio": p95_ratio,
                "selective_case": selective_case,
                "pass": latency_gate_pass,
            },
            "pass": case_pass,
        }
    )

# Every query case must pass, and so must the ingest overhead gate.
query_gate_pass = not any(not c["pass"] for c in cases)
overall_pass = all((ingest_gate_pass, query_gate_pass))

# Assemble the report section by section; key order is the order written below.
_thresholds = {
    "max_ingest_regression_pct": max_ingest_regression_pct,
    "min_scan_reduction_factor": min_scan_reduction_factor,
    "selective_latency_max_pct": selective_latency_max_pct,
    "max_selective_p95_latency_ratio": max_selective_p95_latency_ratio,
}
_config = {
    "data_root": str(data_root),
    "nodes": nodes,
    "iterations": iterations,
    "selectivity_pcts": selectivity_pcts,
    "ingest_selectivity_pct": ingest_selectivity_pct,
    "morsel_size": morsel_size,
    "parallel_workers": parallel_workers,
}
_ingest_summary = {
    "baseline_no_bitmap": ingest_plain,
    "bitmap_enabled": ingest_bitmap,
    "regression_pct": ingest_regression_pct,
    "pass": ingest_gate_pass,
}
report = {
    "gate": "bitmap_perf_matrix",
    "passed": overall_pass,
    "thresholds": _thresholds,
    "config": _config,
    "ingest": _ingest_summary,
    "query_cases": cases,
}
out_json.write_text(json.dumps(report, indent=2), encoding="utf-8")

def _lower_bool(flag):
    """Render a bool as the lowercase 'true' / 'false' used in the report output."""
    return str(flag).lower()


def _md_case_section(case):
    """Markdown lines for one selectivity case, ending with a blank separator line."""
    scan = case["scan"]
    latency = case["latency_micros"]
    return [
        f"### selectivity_{case['selectivity_pct']}pct",
        f"- expected_rows: {case['expected_rows']}",
        f"- actual_rows: {case['actual_rows']}",
        f"- completeness_pass: {_lower_bool(case['completeness_pass'])}",
        f"- scan_reduction_factor: {scan['reduction_factor']:.2f}",
        f"- scan_pass: {_lower_bool(scan['pass'])}",
        f"- p95_latency_ratio_bitmap_over_baseline: {latency['bitmap_to_baseline_p95_ratio']:.4f}",
        f"- latency_pass: {_lower_bool(latency['pass'])}",
        f"- pass: {_lower_bool(case['pass'])}",
        "",
    ]


md_lines = [
    "# Bitmap Performance Matrix Report",
    "",
    f"- passed: {_lower_bool(overall_pass)}",
    f"- nodes: {nodes}",
    f"- iterations: {iterations}",
    "- selectivity_pcts: " + ",".join(map(str, selectivity_pcts)),
    "",
    "## Ingest Overhead",
    f"- baseline_nodes_per_second: {ingest_plain['nodes_per_second']:.2f}",
    f"- bitmap_nodes_per_second: {ingest_bitmap['nodes_per_second']:.2f}",
    f"- regression_pct: {ingest_regression_pct:.2f}",
    f"- threshold_max_regression_pct: {max_ingest_regression_pct:.2f}",
    f"- pass: {_lower_bool(ingest_gate_pass)}",
    "",
    "## Query Selectivity Sweep",
]
for case in cases:
    md_lines.extend(_md_case_section(case))

out_md.write_text("\n".join(md_lines), encoding="utf-8")
print(f"bitmap_perf_matrix passed={_lower_bool(overall_pass)} ingest_regression_pct={ingest_regression_pct:.2f}")
# Exit status 1 signals a failed gate to the calling shell script.
if not overall_pass:
    raise SystemExit(1)
PY

# Point the caller at the generated reports. Because of `set -e`, this is
# reached only when the harness above exited with status 0.
printf 'json_report: %s\n' "$OUT_JSON"
printf 'md_report: %s\n' "$OUT_MD"
