import argparse
import csv
import json
import math
import subprocess
import sys
import time
from pathlib import Path
# Directory containing this script; the ground-truth CSV is looked up beside it.
SCRIPT_DIR = Path(__file__).parent
# Default ground-truth file. Each row holds a query plus up to three
# "path:start-end:relevance" entries (see parse_ground_truth).
GROUND_TRUTH_CSV = SCRIPT_DIR / "code.csv"
def parse_ground_truth(row):
    """Parse one ground-truth CSV row into (query, results).

    Each of the columns ``result1``..``result3`` holds a value of the form
    ``path:start-end:relevance``.  The string is split from the right so
    that colons inside the path itself survive.  Malformed entries are
    skipped with a warning on stderr.

    Args:
        row: dict-like CSV row with a "query" column and optional
            "result1".."result3" columns.

    Returns:
        Tuple of (query string, list of dicts with keys "path",
        "start_line", "end_line", "relevance").
    """
    query = row["query"]
    results = []
    for key in ("result1", "result2", "result3"):
        val = row.get(key, "").strip()
        if not val:
            continue
        # rsplit from the right: only the last two colons are separators.
        parts = val.rsplit(":", 2)
        if len(parts) != 3:
            print(f" WARN: bad format '{val}', skipping", file=sys.stderr)
            continue
        path, line_range, relevance = parts
        try:
            start, end = line_range.split("-")
            results.append({
                "path": path,
                "start_line": int(start),
                "end_line": int(end),
                "relevance": int(relevance),
            })
        except ValueError:
            # Non-numeric range/relevance or missing "-": warn and skip
            # instead of crashing the whole benchmark run.
            print(f" WARN: bad format '{val}', skipping", file=sys.stderr)
            continue
    return query, results
def run_search(query, mode="code", threshold=None, max_results=20):
    """Run ``octocode search`` and return (results, error).

    Args:
        query: search query string.
        mode: octocode search mode (passed to ``--mode``).
        threshold: optional similarity threshold (passed to ``--threshold``).
        max_results: cap on the number of results returned.

    Returns:
        (list_of_results, None) on success — possibly an empty list —
        or (None, error_message) when the subprocess fails, times out,
        the binary is missing, or the output is not valid JSON.
    """
    cmd = ["octocode", "search", query, "--format", "json", "--mode", mode]
    if threshold is not None:
        cmd.extend(["--threshold", str(threshold)])
    try:
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            timeout=60,
        )
    except subprocess.TimeoutExpired:
        return None, "timeout"
    except FileNotFoundError:
        # octocode binary not on PATH; report instead of crashing the run.
        return None, "octocode executable not found"
    if result.returncode != 0:
        return None, result.stderr.strip()
    output = result.stdout.strip()
    if not output:
        return [], None
    try:
        data = json.loads(output)
    except json.JSONDecodeError as e:
        return None, f"json parse error: {e}"
    if isinstance(data, list):
        return data[:max_results], None
    if isinstance(data, dict) and "code_blocks" in data:
        return data["code_blocks"][:max_results], None
    # Unknown JSON shape: treat as no results (original fell through to []).
    return [], None
def ranges_overlap(gt_start, gt_end, res_start, res_end):
    """Return True when closed intervals [gt_start, gt_end] and
    [res_start, res_end] intersect."""
    # Two closed intervals are disjoint iff one ends before the other starts.
    return not (res_end < gt_start or gt_end < res_start)
def match_result(result, ground_truths):
    """Return the highest relevance among ground truths whose path and
    line range overlap this result, or 0 when nothing matches."""
    path = result.get("path", "")
    start = result.get("start_line", 0)
    end = result.get("end_line", 0)
    # Closed-interval overlap check inlined from ranges_overlap.
    overlapping = [
        gt["relevance"]
        for gt in ground_truths
        if gt["path"] == path
        and gt["start_line"] <= end
        and start <= gt["end_line"]
    ]
    return max(overlapping, default=0)
def compute_hit_at_k(results, ground_truths, k):
    """Return 1 if any of the top-k results matches a ground truth, else 0."""
    return int(any(match_result(r, ground_truths) > 0 for r in results[:k]))
def compute_mrr(results, ground_truths):
    """Return the reciprocal rank of the first matching result (0.0 if none)."""
    for rank, res in enumerate(results, start=1):
        if match_result(res, ground_truths) > 0:
            return 1.0 / rank
    return 0.0
def compute_ndcg_at_k(results, ground_truths, k):
    """Return NDCG@k: the DCG of the returned ranking divided by the
    ideal DCG over the ground-truth relevances (0.0 when IDCG is zero)."""
    # Gains in result order; unmatched results contribute a gain of 0 but
    # still occupy their rank position.
    gains = [match_result(r, ground_truths) for r in results[:k]]
    # DCG discount is log2(position + 1) with 1-based positions.
    actual = sum(
        g / math.log2(pos + 1) for pos, g in enumerate(gains, start=1) if g > 0
    )
    ideal_gains = sorted((gt["relevance"] for gt in ground_truths), reverse=True)[:k]
    ideal = sum(g / math.log2(pos + 1) for pos, g in enumerate(ideal_gains, start=1))
    return actual / ideal if ideal else 0.0
def compute_recall_at_k(results, ground_truths, k):
    """Return the fraction of ground-truth entries covered by the top-k
    results (0.0 when there are no ground truths)."""
    if not ground_truths:
        return 0.0
    # Indices of ground truths overlapped by at least one top-k result.
    covered = {
        idx
        for res in results[:k]
        for idx, gt in enumerate(ground_truths)
        if gt["path"] == res.get("path", "")
        and ranges_overlap(
            gt["start_line"],
            gt["end_line"],
            res.get("start_line", 0),
            res.get("end_line", 0),
        )
    }
    return len(covered) / len(ground_truths)
def main():
    """Benchmark octocode retrieval quality against a ground-truth CSV.

    For every query row, runs an octocode search and scores the returned
    results with Hit@5/10, MRR, NDCG@10 and Recall@5/10, then prints a
    summary plus details for missed queries and errors.

    Exits with status 1 when the CSV is missing, when no query succeeds,
    or when average Hit@5 falls below the 0.70 quality gate.
    """
    parser = argparse.ArgumentParser(description="Benchmark octocode retrieval quality")
    parser.add_argument("--threshold", type=float, default=None, help="Similarity threshold")
    parser.add_argument("--max-results", type=int, default=20, help="Max results per query")
    parser.add_argument("--mode", default="code", help="Search mode (default: code)")
    parser.add_argument("--verbose", action="store_true", help="Show per-query details")
    parser.add_argument("--csv", default=str(GROUND_TRUTH_CSV), help="Ground truth CSV path")
    args = parser.parse_args()
    csv_path = Path(args.csv)
    if not csv_path.exists():
        print(f"ERROR: Ground truth CSV not found: {csv_path}", file=sys.stderr)
        sys.exit(1)
    # Load (query, ground_truths) pairs; rows lacking a query or any
    # parseable result entries are skipped.
    queries = []
    with open(csv_path) as f:
        reader = csv.DictReader(f)
        for row in reader:
            query, ground_truths = parse_ground_truth(row)
            if query and ground_truths:
                queries.append((query, ground_truths))
    print(f"Loaded {len(queries)} queries from {csv_path}")
    print(f"Settings: mode={args.mode}, threshold={args.threshold}, max_results={args.max_results}")
    print(f"{'=' * 70}")
    # Per-query scores for each metric; averaged in the summary below.
    metrics = {
        "hit_at_5": [],
        "hit_at_10": [],
        "mrr": [],
        "ndcg_at_10": [],
        "recall_at_5": [],
        "recall_at_10": [],
    }
    failures = []  # queries with no hit in the top 10
    errors = []    # queries where the search command itself failed
    for i, (query, ground_truths) in enumerate(queries):
        results, err = run_search(
            query, mode=args.mode, threshold=args.threshold, max_results=args.max_results
        )
        if results is None:
            # Search failed (non-zero exit, timeout, bad JSON, ...);
            # record it and move on to the next query.
            errors.append((i + 1, query, err))
            if args.verbose:
                print(f" [{i+1:2d}] ERROR: {query[:60]}... => {err}")
            continue
        h5 = compute_hit_at_k(results, ground_truths, 5)
        h10 = compute_hit_at_k(results, ground_truths, 10)
        mrr = compute_mrr(results, ground_truths)
        ndcg = compute_ndcg_at_k(results, ground_truths, 10)
        r5 = compute_recall_at_k(results, ground_truths, 5)
        r10 = compute_recall_at_k(results, ground_truths, 10)
        metrics["hit_at_5"].append(h5)
        metrics["hit_at_10"].append(h10)
        metrics["mrr"].append(mrr)
        metrics["ndcg_at_10"].append(ndcg)
        metrics["recall_at_5"].append(r5)
        metrics["recall_at_10"].append(r10)
        if h10 == 0:
            # Keep the top-3 returned results for the failure report.
            failures.append((i + 1, query, ground_truths, results[:3]))
        if args.verbose:
            status = "HIT" if h5 else ("hit@10" if h10 else "MISS")
            print(f" [{i+1:2d}] {status:6s} MRR={mrr:.2f} NDCG={ndcg:.2f} R@10={r10:.2f} {query[:55]}")
        # Brief pause between queries — presumably rate limiting for the
        # search backend; confirm before removing.
        time.sleep(0.05)
    print(f"\n{'=' * 70}")
    print("RESULTS SUMMARY")
    print(f"{'=' * 70}")
    n = len(metrics["hit_at_5"])
    if n == 0:
        print("No successful queries. Check octocode installation.")
        sys.exit(1)
    def avg(lst):
        # Mean of a list of scores; 0.0 for an empty list.
        return sum(lst) / len(lst) if lst else 0.0
    print(f" Queries evaluated : {n} / {len(queries)}")
    print(f" Errors : {len(errors)}")
    print()
    print(f" Hit@5 : {avg(metrics['hit_at_5']):.3f} ({sum(metrics['hit_at_5'])}/{n})")
    print(f" Hit@10 : {avg(metrics['hit_at_10']):.3f} ({sum(metrics['hit_at_10'])}/{n})")
    print(f" MRR : {avg(metrics['mrr']):.3f}")
    print(f" NDCG@10 : {avg(metrics['ndcg_at_10']):.3f}")
    print(f" Recall@5 : {avg(metrics['recall_at_5']):.3f}")
    print(f" Recall@10 : {avg(metrics['recall_at_10']):.3f}")
    # Detailed report of queries that missed entirely in the top 10,
    # showing expected locations vs. the top-3 actually returned.
    if failures:
        print(f"\n{'=' * 70}")
        print(f"MISSED QUERIES ({len(failures)} queries with no hit in top-10)")
        print(f"{'=' * 70}")
        for idx, query, gts, top3 in failures:
            print(f"\n [{idx}] {query}")
            print(f" Expected:")
            for gt in gts:
                print(f" {gt['path']}:{gt['start_line']}-{gt['end_line']} (rel={gt['relevance']})")
            if top3:
                print(f" Got (top 3):")
                for r in top3:
                    print(f" {r.get('path', '?')}:{r.get('start_line', '?')}-{r.get('end_line', '?')}")
            else:
                print(f" Got: (no results)")
    if errors:
        print(f"\n{'=' * 70}")
        print(f"ERRORS ({len(errors)})")
        print(f"{'=' * 70}")
        for idx, query, err in errors:
            print(f" [{idx}] {query[:60]}... => {err}")
    # Quality gate: fail the run (exit 1) when average Hit@5 drops below 0.70.
    hit5 = avg(metrics["hit_at_5"])
    if hit5 < 0.7:
        print(f"\nWARN: Hit@5 ({hit5:.3f}) is below 0.70 threshold")
        sys.exit(1)
if __name__ == "__main__":
    main()