from __future__ import annotations
import argparse
import json
import math
import shutil
import socket
import statistics
import subprocess
import sys
import tempfile
import time
import urllib.error
import urllib.parse
import urllib.request
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Protocol, Sequence
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT / "sdk" / "python"))
from iridium_driver import IridiumClient
@dataclass(frozen=True)
class NodeFixture:
    """One synthetic graph node of the canonical retrieval fixture."""

    node_id: int  # unique identifier, used for ingestion and result matching
    adjacency: tuple[int, ...]  # ids of neighbouring nodes in the fixture graph
    vector: tuple[float, ...]  # 2-D embedding scored with cosine similarity
@dataclass(frozen=True)
class QueryFixture:
    """One canonical probe query with its relevance cut-off."""

    name: str  # stable key used in the per-query report section
    query_vector: tuple[float, ...]  # embedding the query compares nodes against
    threshold: float  # cosine score strictly above this counts as relevant
class QueryClient(Protocol):
    """Structural interface shared by the embedded SDK client and the HTTP client."""

    def query(self, query: str) -> Dict[str, Any]:
        """Execute *query* and return the decoded JSON result payload."""
        pass
# Canonical 12-node fixture: three tight clusters, one per probe query below
# (ids 1-3 near (1, 0), ids 4-6 near (0, 1), ids 7-9 near (-1, 0)), plus three
# off-axis distractors (ids 10-12) whose vectors point elsewhere.
NODES: Sequence[NodeFixture] = (
    NodeFixture(1, (2, 3), (1.0, 0.0)),
    NodeFixture(2, (1, 3), (0.98, 0.02)),
    NodeFixture(3, (1, 2), (0.96, 0.10)),
    NodeFixture(4, (5, 6), (0.0, 1.0)),
    NodeFixture(5, (4, 6), (0.02, 0.99)),
    NodeFixture(6, (4, 5), (-0.10, 0.97)),
    NodeFixture(7, (8, 9), (-1.0, 0.0)),
    NodeFixture(8, (7, 9), (-0.98, 0.02)),
    NodeFixture(9, (7, 8), (-0.96, -0.10)),
    NodeFixture(10, (11, 12), (0.70, 0.70)),
    NodeFixture(11, (10, 12), (0.0, -1.0)),
    NodeFixture(12, (10, 11), (0.60, -0.80)),
)
# One probe query per cluster axis; the 0.95 threshold defines ground-truth
# relevance for precision/recall/MRR scoring in evaluate_query.
QUERIES: Sequence[QueryFixture] = (
    QueryFixture("risk_east", (1.0, 0.0), 0.95),
    QueryFixture("risk_north", (0.0, 1.0), 0.95),
    QueryFixture("risk_west", (-1.0, 0.0), 0.95),
)
# Dataset identity embedded verbatim into the JSON/Markdown evidence reports.
DATASET_ID = "canonical-fixture-v1"
DATASET_LABEL = "deterministic embedded retrieval fixture"
DATASET_DESCRIPTION = (
    "Synthetic three-cluster vector fixture used to validate retrieval quality deterministically "
    "for the Iridium release-candidate path. This report is quality evidence only and is not a "
    "performance benchmark."
)
class ServiceQueryClient:
    """QueryClient implementation that talks to a running Iridium service over HTTP."""

    def __init__(self, listen: str) -> None:
        # `listen` is a host:port pair; every request targets this base URL.
        self.base_url = f"http://{listen}"

    def query(self, query: str) -> Dict[str, Any]:
        """Send *query* to the service's /v1/query endpoint and decode the JSON reply."""
        url = f"{self.base_url}/v1/query?q={urllib.parse.quote_plus(query)}"
        http_request = urllib.request.Request(url, method="GET")
        with urllib.request.urlopen(http_request, timeout=2) as reply:
            raw = reply.read()
        return json.loads(raw.decode("utf-8"))
def cosine_similarity(lhs: Sequence[float], rhs: Sequence[float]) -> float:
    """Return the cosine similarity of two vectors; 0.0 when either has zero norm."""
    norm_a = math.sqrt(sum(x * x for x in lhs))
    norm_b = math.sqrt(sum(y * y for y in rhs))
    # Degenerate vectors have no direction — define their similarity as 0.
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    dot_product = sum(x * y for x, y in zip(lhs, rhs))
    return dot_product / (norm_a * norm_b)
def inline_vector(values: Iterable[float]) -> str:
    """Encode *values* as the inline query-vector literal ``$q:v1:v2:...`` (``%g`` format)."""
    components = [format(component, "g") for component in values]
    return "$q:" + ":".join(components)
def precision_at_k(returned_ids: Sequence[int], relevant_ids: set[int], k: int) -> float:
    """Precision@k: relevant hits in the first *k* results divided by *k*.

    Returns 0.0 when nothing was returned. The denominator is always *k*,
    so returning fewer than *k* results is penalised.
    """
    top = returned_ids[:k]
    if not top:
        return 0.0
    relevant_count = len([candidate for candidate in top if candidate in relevant_ids])
    return relevant_count / float(k)
def recall_at_k(returned_ids: Sequence[int], relevant_ids: set[int], k: int) -> float:
    """Recall@k: relevant hits in the first *k* results over the total relevant count.

    Vacuously 1.0 when there are no relevant ids at all.
    """
    if not relevant_ids:
        return 1.0
    # Count occurrences (not distinct ids) to mirror the original accounting.
    found = 0
    for candidate in returned_ids[:k]:
        if candidate in relevant_ids:
            found += 1
    return found / float(len(relevant_ids))
def reciprocal_rank(returned_ids: Sequence[int], relevant_ids: set[int]) -> float:
    """Return 1/rank of the first relevant id in *returned_ids*, or 0.0 if none appears."""
    rank = 1
    for candidate in returned_ids:
        if candidate in relevant_ids:
            return 1.0 / float(rank)
        rank += 1
    return 0.0
def run_command(command: list[str]) -> None:
    """Run *command* from the repo root with output captured.

    Raises subprocess.CalledProcessError (carrying the captured output)
    when the command exits non-zero.
    """
    # capture_output=True is the documented shorthand for stdout=PIPE, stderr=PIPE.
    subprocess.run(command, cwd=ROOT, capture_output=True, text=True, check=True)
def seed_fixture_via_cli(binary: str, data_dir: Path) -> None:
    """Ingest every fixture node and its vector through the CLI, one command pair per node."""
    base = [binary, "--data", str(data_dir)]
    for node in NODES:
        adjacency_csv = ",".join(str(neighbor) for neighbor in node.adjacency)
        run_command(base + ["ingest-node", str(node.node_id), "1", adjacency_csv])
        vector_csv = ",".join(f"{component:g}" for component in node.vector)
        run_command(base + ["ingest-vector", str(node.node_id), "2", vector_csv])
def seed_fixture_embedded(data_dir: Path, binary: str) -> QueryClient:
    """Ingest the fixture through the embedded SDK client and return that client for querying."""
    client = IridiumClient(data_dir=str(data_dir), binary=binary)
    for fixture_node in NODES:
        client.ingest_node(fixture_node.node_id, 1, fixture_node.adjacency)
        client.ingest_vector(fixture_node.node_id, 2, fixture_node.vector)
    return client
def evaluate_query(
    client: QueryClient,
    fixture: QueryFixture,
    eval_k: int,
) -> Dict[str, Any]:
    """Run one fixture query against *client* and score it against exact cosine ranking.

    Ground truth is recomputed locally: every fixture node whose exact cosine
    score exceeds the fixture threshold is relevant, ranked by descending score
    with ascending node id as the tie-breaker. Returns a per-query report dict.
    """
    vector_literal = inline_vector(fixture.query_vector)
    query = (
        "MATCH (n) "
        f"WHERE vector.cosine(n.embedding, {vector_literal}) > {fixture.threshold} "
        f"RETURN n LIMIT {len(NODES)}"
    )
    result = client.query(query)
    rows = result.get("rows", [])
    returned_ids = [int(row["node_id"]) for row in rows]

    # Exact scores for every fixture node, best-first.
    scored = [
        (node.node_id, cosine_similarity(fixture.query_vector, node.vector))
        for node in NODES
    ]
    scored.sort(key=lambda pair: (-pair[1], pair[0]))
    relevant_ids = {node_id for node_id, score in scored if score > fixture.threshold}
    expected_ranked_ids = [node_id for node_id, _ in scored if node_id in relevant_ids]

    if rows:
        # `or 0.0` guards rows whose "score" field is present but null.
        avg_score = statistics.fmean(
            row.get("score", 0.0) or 0.0 for row in rows[:eval_k]
        )
    else:
        avg_score = 0.0

    return {
        "query": query,
        "threshold": fixture.threshold,
        "returned_ids": returned_ids,
        "expected_relevant_ids": expected_ranked_ids,
        "precision_at_k": precision_at_k(returned_ids, relevant_ids, eval_k),
        "recall_at_k": recall_at_k(returned_ids, relevant_ids, eval_k),
        "mrr": reciprocal_rank(returned_ids, relevant_ids),
        "avg_returned_score": avg_score,
        "scanned_nodes": int(result.get("scanned_nodes", 0)),
        "rerank_batches": int(result.get("rerank_batches", 0)),
        "returned_count": len(returned_ids),
    }
def write_reports(report_dir: Path, prefix: str, report: Dict[str, Any]) -> None:
    """Persist *report* as ``{prefix}_report.json`` plus a Markdown summary in *report_dir*."""
    report_dir.mkdir(parents=True, exist_ok=True)
    json_target = report_dir / f"{prefix}_report.json"
    md_target = report_dir / f"{prefix}_report.md"
    json_target.write_text(json.dumps(report, indent=2), encoding="utf-8")

    dataset = report["dataset"]
    config = report["config"]
    evidence = report["evidence"]
    macro = report["macro"]
    lines = [
        "# Retrieval Quality Report",
        "",
        f"- schema: {report['schema']}",
        f"- release_candidate: {report['release_candidate']}",
        f"- runtime_mode: {report['runtime_mode']}",
        f"- dataset_id: {dataset['id']}",
        f"- dataset_label: {dataset['label']}",
        f"- eval_k: {config['eval_k']}",
        f"- binary: {config['binary']}",
        f"- work_dir: {config['work_dir']}",
        f"- evidence_class: {evidence['class']}",
        f"- quality_claim: {evidence['quality_claim']}",
        f"- performance_note: {evidence['performance_note']}",
        "",
        "## Macro Metrics",
        f"- precision_at_k: {macro['precision_at_k']:.3f}",
        f"- recall_at_k: {macro['recall_at_k']:.3f}",
        f"- mrr: {macro['mrr']:.3f}",
        f"- overall_pass: {str(report['quality_gate']['overall_pass']).lower()}",
        f"- alloy_contract_path: {report['alloy_contract_path']}",
    ]
    # The service report line only exists for service-mode runs.
    if report["service_report_path"]:
        lines.append(f"- service_report_path: {report['service_report_path']}")
    lines += ["", "## Per Query"]
    for query_name, details in report["queries"].items():
        lines += [
            f"### {query_name}",
            f"- expected_relevant_ids: {details['expected_relevant_ids']}",
            f"- returned_ids: {details['returned_ids']}",
            f"- precision_at_k: {details['precision_at_k']:.3f}",
            f"- recall_at_k: {details['recall_at_k']:.3f}",
            f"- mrr: {details['mrr']:.3f}",
            f"- scanned_nodes: {details['scanned_nodes']}",
            f"- rerank_batches: {details['rerank_batches']}",
            "",
        ]
    md_target.write_text("\n".join(lines), encoding="utf-8")
    print(f"wrote: {json_target}")
    print(f"wrote: {md_target}")
def parse_args(argv: Sequence[str]) -> argparse.Namespace:
    """Build the evaluation-harness argument parser and parse *argv*."""
    parser = argparse.ArgumentParser(
        description="Run deterministic retrieval-quality evaluation for Iridium"
    )
    add = parser.add_argument
    add(
        "--work-dir",
        default="/tmp/iridium-retrieval-quality",
        help="Scratch directory used for the data root",
    )
    add(
        "--binary",
        default=str(ROOT / "target" / "debug" / "ir"),
        help="Path to the Iridium CLI binary used by local and service paths",
    )
    add("--report-dir", default="artifacts", help="Directory for JSON/Markdown reports")
    add("--report-prefix", default="retrieval_quality_gate")
    # Quality-gate knobs; ranges are validated in main().
    add("--eval-k", type=int, default=3)
    add("--min-precision-at-k", type=float, default=1.0)
    add("--min-recall-at-k", type=float, default=1.0)
    add("--min-mrr", type=float, default=1.0)
    add("--runtime-mode", choices=("embedded", "service"), default="embedded")
    # Service-mode options; port 0 means "pick a free port" (see resolve_listen).
    add("--listen", default="127.0.0.1:0")
    add("--telemetry-endpoint", default="stdout")
    add("--tls", default="operator-optional")
    add("--admin-token", default="local-dev")
    return parser.parse_args(argv)
def load_alloy_contract(binary: str, report_dir: Path, prefix: str) -> str:
    """Dump the binary's Alloy retrieval-quality contract into *report_dir*.

    Runs ``<binary> contract-report retrieval-quality`` and writes its stdout
    to ``{prefix}_alloy_contract.json``. Returns the written file's name (not
    its full path). Raises subprocess.CalledProcessError on a non-zero exit.
    """
    # Bug fix: this runs before write_reports() ever creates the report
    # directory, and Path.write_text does not create parents — a fresh
    # checkout without the report dir crashed with FileNotFoundError.
    report_dir.mkdir(parents=True, exist_ok=True)
    contract_path = report_dir / f"{prefix}_alloy_contract.json"
    output = subprocess.run(
        [binary, "contract-report", "retrieval-quality"],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        check=True,
    )
    contract_path.write_text(output.stdout, encoding="utf-8")
    return contract_path.name
def load_service_report(binary: str, listen: str, report_dir: Path, prefix: str) -> str:
    """Capture the service self-report for *listen* into *report_dir*.

    Runs ``<binary> service-report --listen <listen>`` and writes its stdout
    to ``{prefix}_service_report.json``. Returns the written file's name.
    Raises subprocess.CalledProcessError on a non-zero exit.
    """
    # Bug fix: called before write_reports() creates the report directory,
    # and Path.write_text does not create parents — ensure the dir exists.
    report_dir.mkdir(parents=True, exist_ok=True)
    service_path = report_dir / f"{prefix}_service_report.json"
    output = subprocess.run(
        [binary, "service-report", "--listen", listen],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        check=True,
    )
    service_path.write_text(output.stdout, encoding="utf-8")
    return service_path.name
def wait_for_service(listen: str, attempts: int = 20) -> None:
    """Poll the service /livez endpoint until it answers 200 or attempts run out.

    Sleeps 0.1s between probes; raises RuntimeError if the service never
    becomes ready within *attempts* probes.
    """
    probe_url = f"http://{listen}/livez"
    remaining = attempts
    while remaining > 0:
        remaining -= 1
        try:
            with urllib.request.urlopen(probe_url, timeout=2) as response:
                if response.status == 200:
                    return
        except (urllib.error.URLError, TimeoutError):
            # Not up yet (connection refused or timed out); retry after a pause.
            pass
        time.sleep(0.1)
    raise RuntimeError("service failed to become ready for retrieval-quality gate")
def resolve_listen(listen: str) -> str:
    """Resolve a ``host:0`` listen spec to a concrete free port; pass anything else through.

    NOTE(review): the probe socket is closed before the service binds the
    port, so another process could grab it in between — acceptable for a
    local test harness, but not race-free.
    """
    if not listen.endswith(":0"):
        return listen
    host = listen.rsplit(":", 1)[0]
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.bind((host, 0))
        port = probe.getsockname()[1]
    finally:
        probe.close()
    return f"{host}:{port}"
def main(argv: Sequence[str]) -> int:
    """Run the retrieval-quality gate end to end and return the process exit code.

    Seeds the deterministic fixture (embedded SDK, or CLI + spawned service
    depending on --runtime-mode), evaluates every canonical query, aggregates
    macro metrics, applies the configured minimums, and writes JSON/Markdown
    evidence reports. Returns 0 when all quality-gate checks pass, 1 otherwise.
    Raises ValueError for out-of-range gate configuration.
    """
    args = parse_args(argv)
    # Validate gate configuration up front so a bad invocation fails fast.
    if args.eval_k <= 0:
        raise ValueError("--eval-k must be > 0")
    if not 0.0 <= args.min_precision_at_k <= 1.0:
        raise ValueError("--min-precision-at-k must be between 0 and 1")
    if not 0.0 <= args.min_recall_at_k <= 1.0:
        raise ValueError("--min-recall-at-k must be between 0 and 1")
    if not 0.0 <= args.min_mrr <= 1.0:
        raise ValueError("--min-mrr must be between 0 and 1")
    work_dir = Path(args.work_dir).resolve()
    data_dir = work_dir / "data"
    # Wipe any previous data root so every run starts from identical state.
    if data_dir.exists():
        shutil.rmtree(data_dir)
    data_dir.mkdir(parents=True, exist_ok=True)
    report_dir = Path(args.report_dir)
    service_process: subprocess.Popen[str] | None = None
    service_report_path: str | None = None
    client: QueryClient
    listen = resolve_listen(args.listen)
    try:
        if args.runtime_mode == "embedded":
            client = seed_fixture_embedded(data_dir, args.binary)
        else:
            # Service mode: seed via the CLI, then serve the same data dir.
            seed_fixture_via_cli(args.binary, data_dir)
            service_process = subprocess.Popen(
                [
                    args.binary,
                    "--data",
                    str(data_dir),
                    "service-serve",
                    "--listen",
                    listen,
                    "--telemetry-endpoint",
                    args.telemetry_endpoint,
                    "--tls",
                    args.tls,
                    "--admin-token",
                    args.admin_token,
                    "--max-requests",
                    # One request per query plus one extra — presumably
                    # headroom for a probe request; confirm server accounting.
                    str(len(QUERIES) + 1),
                ],
                cwd=ROOT,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
            )
            wait_for_service(listen)
            # If the process died right after the readiness probe, surface
            # its stderr instead of failing later with connection errors.
            if service_process.poll() is not None:
                stderr = service_process.stderr.read() if service_process.stderr else ""
                raise RuntimeError(f"service failed to stay up: {stderr}")
            client = ServiceQueryClient(listen)
            service_report_path = load_service_report(
                args.binary, listen, report_dir, args.report_prefix
            )
        # Evaluate every canonical query through the selected client.
        query_results = {
            fixture.name: evaluate_query(client, fixture, args.eval_k) for fixture in QUERIES
        }
        # Macro metrics: unweighted mean over the per-query scores.
        macro = {
            "precision_at_k": statistics.fmean(
                item["precision_at_k"] for item in query_results.values()
            ),
            "recall_at_k": statistics.fmean(
                item["recall_at_k"] for item in query_results.values()
            ),
            "mrr": statistics.fmean(item["mrr"] for item in query_results.values()),
        }
        # Gate decision: every macro metric must reach its configured minimum.
        quality_gate = {
            "minimums": {
                "precision_at_k": args.min_precision_at_k,
                "recall_at_k": args.min_recall_at_k,
                "mrr": args.min_mrr,
            },
            "checks": {
                "precision_at_k": macro["precision_at_k"] >= args.min_precision_at_k,
                "recall_at_k": macro["recall_at_k"] >= args.min_recall_at_k,
                "mrr": macro["mrr"] >= args.min_mrr,
            },
        }
        quality_gate["overall_pass"] = all(quality_gate["checks"].values())
        alloy_contract_path = load_alloy_contract(args.binary, report_dir, args.report_prefix)
        report = {
            "schema": "iridium.retrieval-quality.v1",
            "release_candidate": (
                "service-phase-2" if args.runtime_mode == "service" else "embedded-phase-1"
            ),
            "runtime_mode": args.runtime_mode,
            "config": {
                "work_dir": str(work_dir),
                "data_dir": str(data_dir),
                "binary": str(Path(args.binary).resolve()),
                "eval_k": args.eval_k,
                "node_count": len(NODES),
                "query_count": len(QUERIES),
                "listen": listen if args.runtime_mode == "service" else None,
            },
            "dataset": {
                "id": DATASET_ID,
                "label": DATASET_LABEL,
                "description": DATASET_DESCRIPTION,
                "canonical_for_release_candidate": True,
                "node_count": len(NODES),
                "query_count": len(QUERIES),
            },
            "evidence": {
                "class": "quality-report",
                "quality_claim": (
                    f"Deterministic {args.runtime_mode} retrieval quality on the canonical "
                    "fixture meets the release-candidate minimums."
                ),
                "performance_note": (
                    "Latency and throughput are intentionally out of scope for this report and "
                    "remain covered by separate benchmark gates."
                ),
            },
            "alloy_contract_path": alloy_contract_path,
            "service_report_path": service_report_path,
            "macro": macro,
            "quality_gate": quality_gate,
            "queries": query_results,
        }
        write_reports(report_dir, args.report_prefix, report)
        return 0 if quality_gate["overall_pass"] else 1
    finally:
        # Always tear down a spawned service: terminate first, escalate to
        # kill if it does not exit within the grace period.
        if service_process and service_process.poll() is None:
            service_process.terminate()
            try:
                service_process.wait(timeout=2)
            except subprocess.TimeoutExpired:
                service_process.kill()
                service_process.wait(timeout=2)
if __name__ == "__main__":
    # SystemExit propagates the gate's pass/fail result as the process exit code.
    raise SystemExit(main(sys.argv[1:]))