#![allow(clippy::cast_possible_truncation, clippy::cast_sign_loss, clippy::too_many_lines)]
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Instant;
use ripvec_core::chunk::CodeChunk;
use ripvec_core::embed::{Scope, SearchConfig};
use ripvec_core::encoder::ripvec::dense::{DEFAULT_MODEL_REPO, StaticEncoder};
use ripvec_core::encoder::ripvec::index::RipvecIndex;
use ripvec_core::encoder::ripvec::ranking::is_symbol_query;
use ripvec_core::hybrid::{SearchMode, pagerank_lookup};
use ripvec_core::profile::Profiler;
use ripvec_core::ranking::{CrossEncoderRerank, RankingLayer, apply_chain};
use ripvec_core::repo_map::build_graph;
use ripvec_core::rerank::{DEFAULT_RERANK_CANDIDATES, DEFAULT_RERANK_MODEL, Reranker};
use serde::Deserialize;
/// Rank cutoff for every reported metric (nDCG@k, recall@k, precision@k):
/// only the first `TOP_K` ranked results of each query are scored, and
/// precision divides the hit count by this value.
const TOP_K: usize = 10;
/// PageRank weight handed to `RipvecIndex::from_root` when a non-empty
/// PageRank lookup could be built (`0.0` is passed instead otherwise).
/// NOTE(review): presumably the blend factor between the retrieval score and
/// the PageRank prior — confirm against `RipvecIndex::from_root`.
const PAGERANK_ALPHA: f32 = 0.5;
/// A relevance target exactly as written in the annotations JSON.
///
/// Because of `#[serde(untagged)]`, each entry may be written either as a bare
/// path string or as an object carrying a `path` plus optional line bounds.
/// It is converted into a [`Target`] before scoring.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum RawTarget {
/// Whole-file target given as a plain string, e.g. `"src/lib.rs"`.
Path(String),
/// Object form; either line bound may be omitted from the JSON.
Span {
path: String,
#[serde(default)]
start_line: Option<usize>,
#[serde(default)]
end_line: Option<usize>,
},
}
/// Normalised relevance target used when scoring search results.
///
/// A chunk matches when its file path matches `path` (see `path_matches`).
/// When BOTH `start_line` and `end_line` are set, the chunk's line range must
/// also overlap `start_line..=end_line`; otherwise any chunk of the file counts.
#[derive(Debug, Clone)]
struct Target {
/// File path; `\` and `/` are treated alike and a path-suffix match is accepted.
path: String,
/// Optional first line of the relevant span.
start_line: Option<usize>,
/// Optional last line of the relevant span (inclusive).
end_line: Option<usize>,
}
impl From<RawTarget> for Target {
fn from(raw: RawTarget) -> Self {
match raw {
RawTarget::Path(path) => Self { path, start_line: None, end_line: None },
RawTarget::Span { path, start_line, end_line } => {
Self { path, start_line, end_line }
}
}
}
}
/// One element of the annotations file, which is deserialized as a JSON array
/// of these objects.
#[derive(Debug, Deserialize)]
struct RawTask {
/// Natural-language or symbol query sent to the index.
query: String,
/// Primary relevant targets; an absent key yields an empty list.
#[serde(default)]
relevant: Vec<RawTarget>,
/// Additional targets; they are appended to `relevant` and scored identically.
#[serde(default)]
secondary: Vec<RawTarget>,
/// Optional explicit category; when absent it is inferred from `query`
/// via `infer_category`.
#[serde(default)]
category: Option<String>,
}
/// A benchmark query ready for scoring, built from a [`RawTask`] when the
/// annotations are loaded.
struct Task {
/// Query text passed to the index search.
query: String,
/// `relevant` followed by `secondary` targets; both count equally toward the metrics.
targets: Vec<Target>,
/// Explicit or inferred category; per-category nDCG@10 is reported under `by_category`.
category: String,
}
/// Fallback category for a task whose annotation has no explicit `category`.
///
/// * a single token (no whitespace once trimmed) → `"symbol"`
/// * a multi-word query whose first word is "how" (ASCII case-insensitive)
///   → `"architecture"`
/// * anything else → `"semantic"`
///
/// Leading/trailing whitespace is ignored by BOTH checks, and any Unicode
/// whitespace (tabs, newlines) separates words, not only `' '`.
fn infer_category(query: &str) -> &'static str {
    let q = query.trim();
    if !q.contains(char::is_whitespace) {
        return "symbol";
    }
    match q.split_whitespace().next() {
        Some(first) if first.eq_ignore_ascii_case("how") => "architecture",
        _ => "semantic",
    }
}
/// Whether two paths refer to the same file, tolerating separator style and
/// differing roots.
///
/// Backslashes are normalised to `/`. The paths match when they are equal or
/// when either one ends with the other at a `/` boundary, e.g.
/// `/repo/src/a.rs` matches `src/a.rs` but `src/ba.rs` does not match `a.rs`.
fn path_matches(file_path: &str, target_path: &str) -> bool {
    let file = file_path.replace('\\', "/");
    let target = target_path.replace('\\', "/");
    // `longer` ends with "/{shorter}": strip `shorter` and require a separator before it.
    let ends_at_component = |longer: &str, shorter: &str| {
        longer
            .strip_suffix(shorter)
            .is_some_and(|head| head.ends_with('/'))
    };
    file == target || ends_at_component(&file, &target) || ends_at_component(&target, &file)
}
/// Whether `chunk` satisfies `target`.
///
/// The chunk's file must match the target path (see `path_matches`). If the
/// target carries both line bounds, the chunk's `start_line..=end_line` must
/// also overlap the target span; with zero or one bound, any chunk in the file
/// matches.
fn target_matches_chunk(chunk: &CodeChunk, target: &Target) -> bool {
    let Some((span_start, span_end)) = target.start_line.zip(target.end_line) else {
        // Incomplete span: fall back to a file-level match.
        return path_matches(&chunk.file_path, &target.path);
    };
    path_matches(&chunk.file_path, &target.path)
        && chunk.start_line <= span_end
        && chunk.end_line >= span_start
}
/// Discounted cumulative gain of a relevance list, where position `i`
/// (0-based) is discounted by `log2(i + 2)`.
fn dcg(rels: &[u8]) -> f64 {
    let mut total = 0.0;
    for (pos, &gain) in rels.iter().enumerate() {
        let discount = ((pos + 2) as f64).log2();
        total += f64::from(gain) / discount;
    }
    total
}

/// Binary nDCG@k.
///
/// `ranks` holds the 1-based result positions that were judged relevant;
/// positions outside `1..=k` are ignored and repeated positions count once.
/// The ideal ranking places `min(k, n_relevant)` relevant results first.
/// Returns `0.0` when `n_relevant` is 0 or the ideal DCG is not positive.
fn ndcg_at_k(ranks: &[usize], n_relevant: usize, k: usize) -> f64 {
    if n_relevant == 0 {
        return 0.0;
    }
    let mut gains = vec![0u8; k];
    for &rank in ranks {
        // `checked_sub` drops rank 0; `get_mut` drops ranks beyond `k`.
        if let Some(slot) = rank.checked_sub(1).and_then(|i| gains.get_mut(i)) {
            *slot = 1;
        }
    }
    let ideal_gains = vec![1u8; n_relevant.min(k)];
    let ideal = dcg(&ideal_gains);
    (ideal > 0.0).then(|| dcg(&gains) / ideal).unwrap_or(0.0)
}
/// Linearly interpolated percentile of an ascending-sorted slice.
///
/// `p` is in percent. It is clamped to `[0, 100]`, so an out-of-range request
/// yields the minimum or maximum instead of indexing past the end of the
/// slice. Returns `0.0` for an empty slice. The caller must sort `sorted`
/// ascending; this function does not check that.
fn percentile(sorted: &[f64], p: f64) -> f64 {
    if sorted.is_empty() {
        return 0.0;
    }
    let last = sorted.len() - 1;
    // After clamping, `pos` lies in [0, last], so `lo` and `hi` are valid indices.
    let pos = (p.clamp(0.0, 100.0) / 100.0) * (last as f64);
    let lo = pos.floor() as usize;
    let hi = pos.ceil() as usize;
    if lo == hi {
        return sorted[lo];
    }
    let frac = pos - lo as f64;
    sorted[lo] * (1.0 - frac) + sorted[hi] * frac
}
fn main() -> anyhow::Result<()> {
let mut args = std::env::args().skip(1);
let corpus: PathBuf = args.next()
.ok_or_else(|| anyhow::anyhow!("usage: corpus_bench <corpus_root> <annotations.json> [opts]"))?
.into();
let ann_path: PathBuf = args.next()
.ok_or_else(|| anyhow::anyhow!("missing annotations.json"))?
.into();
let mut model_repo = DEFAULT_MODEL_REPO.to_string();
let mut rerank_model = DEFAULT_RERANK_MODEL.to_string();
let mut scope = Scope::All;
let mut rerank_override: Option<bool> = None;
let mut repeats: usize = 5;
let mut candidate_k: usize = DEFAULT_RERANK_CANDIDATES;
let mut args: Vec<String> = args.collect();
let mut i = 0;
while i < args.len() {
match args[i].as_str() {
"--model" => { args.remove(i); if i < args.len() { model_repo = args.remove(i); } }
"--rerank-model" => {
args.remove(i);
if i < args.len() { rerank_model = args.remove(i); }
}
"--scope" => {
args.remove(i);
if i < args.len() {
scope = match args.remove(i).as_str() {
"code" => Scope::Code, "docs" => Scope::Docs, "all" => Scope::All,
o => anyhow::bail!("bad scope: {o}"),
};
}
}
"--no-rerank" => { rerank_override = Some(false); args.remove(i); }
"--rerank" => { rerank_override = Some(true); args.remove(i); }
"--repeats" => {
args.remove(i);
if i < args.len() { repeats = args.remove(i).parse()?; }
}
"--candidates" => {
args.remove(i);
if i < args.len() { candidate_k = args.remove(i).parse()?; }
}
_ => i += 1,
}
}
let raw: Vec<RawTask> = serde_json::from_slice(&std::fs::read(&ann_path)?)?;
let tasks: Vec<Task> = raw.into_iter().map(|t| Task {
category: t.category.unwrap_or_else(|| infer_category(&t.query).to_string()),
query: t.query,
targets: t.relevant.into_iter().chain(t.secondary).map(Into::into).collect(),
}).collect();
eprintln!("loading encoder ({model_repo}) + reranker ({rerank_model})...");
let t0 = Instant::now();
let encoder = StaticEncoder::from_pretrained(&model_repo)?;
let reranker = Arc::new(Reranker::from_pretrained(&rerank_model)?);
eprintln!(" loaded in {} ms", t0.elapsed().as_millis());
eprintln!("building repo graph for PageRank...");
let t0 = Instant::now();
let pr_result = build_graph(&corpus);
let (pr_lookup, pr_enabled) = match pr_result {
Ok(graph) => {
let lookup = pagerank_lookup(&graph);
eprintln!(" graph: {} files, pagerank lookup: {} entries",
graph.files.len(), lookup.len());
let enabled = !lookup.is_empty();
(lookup, enabled)
}
Err(e) => {
eprintln!(" repo graph failed ({e}); continuing without PageRank");
(std::collections::HashMap::new(), false)
}
};
eprintln!(" pagerank build: {} ms", t0.elapsed().as_millis());
eprintln!("building RipvecIndex for {}...", corpus.display());
let t0 = Instant::now();
let cfg = SearchConfig { scope, ..SearchConfig::default() };
let profiler = Profiler::noop();
let index = RipvecIndex::from_root(
&corpus, encoder, &cfg, &profiler,
if pr_enabled { Some(pr_lookup) } else { None },
if pr_enabled { PAGERANK_ALPHA } else { 0.0 },
)?;
let index_ms = t0.elapsed().as_secs_f64() * 1000.0;
let corpus_class = index.corpus_class();
eprintln!(" built in {index_ms:.0} ms ({} chunks, corpus={:?}, pr={})",
index.chunks().len(), corpus_class, pr_enabled);
let mut median_latencies: Vec<f64> = Vec::with_capacity(tasks.len());
let mut q_ndcg10s: Vec<f64> = Vec::with_capacity(tasks.len());
let mut q_recall10s: Vec<f64> = Vec::with_capacity(tasks.len());
let mut q_precision10s: Vec<f64> = Vec::with_capacity(tasks.len());
let mut by_cat: std::collections::BTreeMap<String, Vec<f64>> = Default::default();
let mut rerank_fired = 0usize;
for task in &tasks {
let auto_rerank = !is_symbol_query(&task.query) && match scope {
Scope::Code => false,
Scope::Docs => true,
Scope::All => corpus_class.rerank_eligible(),
};
let do_rerank = rerank_override.map_or(auto_rerank, |f| f && !is_symbol_query(&task.query));
if do_rerank { rerank_fired += 1; }
let mut latencies = Vec::with_capacity(repeats);
let mut ranked: Vec<(usize, f32)> = Vec::new();
for _ in 0..repeats {
let t = Instant::now();
ranked = index.search(&task.query, candidate_k, SearchMode::Hybrid, None, None, None);
if do_rerank {
let layer = CrossEncoderRerank::new(reranker.clone(), task.query.clone(), candidate_k);
let layers: Vec<Box<dyn RankingLayer>> = vec![Box::new(layer)];
apply_chain(&mut ranked, index.chunks(), &layers);
}
latencies.push(t.elapsed().as_secs_f64() * 1000.0);
}
latencies.sort_by(|a, b| a.partial_cmp(b).unwrap());
median_latencies.push(latencies[latencies.len() / 2]);
let top_k: Vec<&CodeChunk> = ranked.iter().take(TOP_K)
.filter_map(|(idx, _)| index.chunks().get(*idx)).collect();
let n_rel = task.targets.len();
let mut ranks: Vec<usize> = Vec::new();
for target in &task.targets {
for (i, c) in top_k.iter().enumerate() {
if target_matches_chunk(c, target) {
ranks.push(i + 1);
break;
}
}
}
let hits = ranks.len();
let q_ndcg = ndcg_at_k(&ranks, n_rel, TOP_K);
let q_recall = if n_rel == 0 { 0.0 } else { hits as f64 / n_rel as f64 };
let q_precision = hits as f64 / TOP_K as f64;
q_ndcg10s.push(q_ndcg);
q_recall10s.push(q_recall);
q_precision10s.push(q_precision);
by_cat.entry(task.category.clone()).or_default().push(q_ndcg);
}
median_latencies.sort_by(|a, b| a.partial_cmp(b).unwrap());
let p50 = percentile(&median_latencies, 50.0);
let p90 = percentile(&median_latencies, 90.0);
let p95 = percentile(&median_latencies, 95.0);
let p99 = percentile(&median_latencies, 99.0);
let n = tasks.len() as f64;
let mean = |xs: &[f64]| xs.iter().sum::<f64>() / n;
let by_cat_summary: serde_json::Value = by_cat.iter().map(|(k, v)| {
(k.clone(), serde_json::json!({"n": v.len(), "ndcg10": v.iter().sum::<f64>() / v.len() as f64}))
}).collect::<serde_json::Map<_, _>>().into();
let out = serde_json::json!({
"corpus": corpus.display().to_string(),
"annotations": ann_path.display().to_string(),
"model": model_repo,
"scope": format!("{scope:?}").to_lowercase(),
"corpus_class": format!("{corpus_class:?}").to_lowercase(),
"tasks": tasks.len(),
"rerank_fired": rerank_fired,
"chunks": index.chunks().len(),
"pagerank_enabled": pr_enabled,
"index_ms": index_ms,
"ndcg10": mean(&q_ndcg10s),
"recall10": mean(&q_recall10s),
"precision10": mean(&q_precision10s),
"p50_ms": p50, "p90_ms": p90, "p95_ms": p95, "p99_ms": p99,
"by_category": by_cat_summary,
});
println!("{}", serde_json::to_string_pretty(&out)?);
Ok(())
}