#![allow(
clippy::cast_possible_truncation,
clippy::cast_sign_loss,
clippy::too_many_lines
)]
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Instant;
use ripvec_core::chunk::CodeChunk;
use ripvec_core::embed::{Scope, SearchConfig};
use ripvec_core::encoder::ripvec::dense::{DEFAULT_MODEL_REPO, StaticEncoder};
use ripvec_core::encoder::ripvec::index::RipvecIndex;
use ripvec_core::encoder::ripvec::ranking::is_symbol_query;
use ripvec_core::hybrid::SearchMode;
use ripvec_core::profile::Profiler;
use ripvec_core::ranking::{CrossEncoderRerank, RankingLayer, apply_chain};
use ripvec_core::rerank::{DEFAULT_RERANK_MODEL, Reranker};
use serde::Deserialize;
/// nDCG cutoff: only the first `TOP_K` ranked chunks of each query are scored.
const TOP_K: usize = 10;
/// Number of candidates requested from `RipvecIndex::search`, and the depth
/// handed to the cross-encoder reranker when reranking is enabled.
const CANDIDATE_K: usize = 100;
/// Timed repetitions per query; the middle value of the sorted timings is
/// recorded as that query's latency.
const LATENCY_RUNS: usize = 5;
/// One relevance target exactly as written in the annotations JSON.
///
/// `#[serde(untagged)]` makes serde try the variants in declaration order:
/// a bare JSON string deserializes as `Path`, while an object with a `path`
/// key (plus optional `start_line` / `end_line`) deserializes as `Span`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum RawTarget {
/// Whole-file target given as a plain path string.
Path(String),
/// Target given as an object; absent line fields default to `None`.
Span {
path: String,
#[serde(default)]
start_line: Option<usize>,
#[serde(default)]
end_line: Option<usize>,
},
}
/// A relevance target normalized from either `RawTarget` form.
///
/// `start_line` / `end_line` describe an inclusive line span. The span is only
/// checked by `target_matches_chunk` when both ends are `Some`; otherwise any
/// chunk from the file matches.
/// NOTE(review): these are compared directly with `CodeChunk` line fields, so
/// annotations must use the same line-numbering base as the chunker — confirm.
#[derive(Debug, Clone)]
struct Target {
path: String,
start_line: Option<usize>,
end_line: Option<usize>,
}
impl From<RawTarget> for Target {
    /// Flatten either JSON form into a `Target`; a bare path string has no line span.
    fn from(raw: RawTarget) -> Self {
        let (path, start_line, end_line) = match raw {
            RawTarget::Path(file) => (file, None, None),
            RawTarget::Span {
                path,
                start_line,
                end_line,
            } => (path, start_line, end_line),
        };
        Self {
            path,
            start_line,
            end_line,
        }
    }
}
/// One query entry as deserialized from the annotations JSON file.
#[derive(Debug, Deserialize)]
struct RawTask {
/// Query text sent to the index.
query: String,
/// Primary relevant targets; an empty list when the key is absent.
#[serde(default)]
relevant: Vec<RawTarget>,
/// Extra relevant targets; appended after `relevant` when the `Task` is
/// built, and scored the same way.
#[serde(default)]
secondary: Vec<RawTarget>,
/// Optional explicit category; when absent, `infer_category` derives one
/// from the query text.
#[serde(default)]
category: Option<String>,
}
/// A query ready for evaluation.
struct Task {
/// Query text sent to the index.
query: String,
/// The raw task's `relevant` targets followed by its `secondary` ones; all
/// count equally toward nDCG.
targets: Vec<Target>,
/// Explicit category from the annotations, or the one `infer_category` chose.
category: String,
}
/// Guess a task category from the query text when the annotation omits one.
///
/// Queries are split on any whitespace, ignoring leading/trailing runs:
/// - zero or one word → `"symbol"` (an identifier lookup);
/// - several words whose first word is `how` (ASCII case-insensitive) →
///   `"architecture"`;
/// - anything else → `"semantic"`.
fn infer_category(query: &str) -> &'static str {
    let mut words = query.split_whitespace();
    // Tuple elements are evaluated left to right, so this is (first, second).
    match (words.next(), words.next()) {
        (Some(first), Some(_)) if first.eq_ignore_ascii_case("how") => "architecture",
        (Some(_), Some(_)) => "semantic",
        _ => "symbol",
    }
}
/// Return `true` when `file_path` and `target_path` refer to the same file.
///
/// Both paths are normalized first: backslashes become `/` and leading `./`
/// components are dropped, so an annotation path like `./src/a.rs` matches an
/// absolute chunk path. After normalization the paths match if they are equal,
/// or if one is a suffix of the other that begins at a `/` boundary. For
/// example, `src/a.rs` matches `/repo/src/a.rs` but not `/repo/xsrc/a.rs`.
fn path_matches(file_path: &str, target_path: &str) -> bool {
    /// Unify separators and strip any leading `./` segments.
    fn normalize(path: &str) -> String {
        path.replace('\\', "/").trim_start_matches("./").to_owned()
    }

    /// `true` if `longer` ends with `/` immediately followed by `shorter`.
    fn is_component_suffix(longer: &str, shorter: &str) -> bool {
        match longer.strip_suffix(shorter) {
            Some(prefix) => prefix.ends_with('/'),
            None => false,
        }
    }

    let f = normalize(file_path);
    let t = normalize(target_path);
    f == t || is_component_suffix(&f, &t) || is_component_suffix(&t, &f)
}
/// Decide whether `chunk` counts as a hit for `target`.
///
/// The file paths must agree according to `path_matches`. When the target has
/// both `start_line` and `end_line`, the chunk's line range must also overlap
/// that inclusive span. A target with either bound missing accepts every chunk
/// of its file.
fn target_matches_chunk(chunk: &CodeChunk, target: &Target) -> bool {
    let same_file = path_matches(&chunk.file_path, &target.path);
    if let (true, Some(span_start), Some(span_end)) = (same_file, target.start_line, target.end_line)
    {
        // Two inclusive ranges overlap iff each one starts no later than the other ends.
        return chunk.start_line <= span_end && chunk.end_line >= span_start;
    }
    // Either the files differ, or there is no complete span to narrow the match.
    same_file
}
/// Discounted cumulative gain of a ranked list of relevance grades.
///
/// The grade at 0-based index `i` contributes `grade / log2(i + 2)`, so the
/// top item is undiscounted. An empty list sums to zero.
fn dcg(relevances: &[u8]) -> f64 {
    (2_usize..)
        .zip(relevances)
        .map(|(discount_base, &grade)| f64::from(grade) / (discount_base as f64).log2())
        .sum()
}
/// nDCG@k with binary relevance.
///
/// `ranks` holds 1-based positions at which relevant items were found.
/// Positions outside `1..=k` are ignored, and a repeated position is credited
/// once. The ideal ranking places `min(k, n_relevant)` relevant items at the
/// top. Returns `0.0` when `n_relevant` is zero or the ideal DCG is not
/// positive (e.g. `k == 0`).
fn ndcg_at_k(ranks: &[usize], n_relevant: usize, k: usize) -> f64 {
    if n_relevant == 0 {
        return 0.0;
    }
    let observed: Vec<u8> = (1..=k).map(|pos| u8::from(ranks.contains(&pos))).collect();
    let ideal_dcg = dcg(&vec![1u8; n_relevant.min(k)]);
    if ideal_dcg > 0.0 {
        dcg(&observed) / ideal_dcg
    } else {
        0.0
    }
}
/// Linearly interpolated `p`-th percentile of an ascending-sorted slice.
///
/// `p` is a percentage and is clamped to `[0, 100]`. Out-of-range requests
/// therefore return the minimum or maximum instead of indexing past the end of
/// the slice. Returns `0.0` for an empty slice. The slice must already be
/// sorted ascending; this is not checked.
fn percentile(sorted_ms: &[f64], p: f64) -> f64 {
    if sorted_ms.is_empty() {
        return 0.0;
    }
    let n = sorted_ms.len();
    // Fractional index into the data. Clamping guarantees `ceil(pos) <= n - 1`.
    let pos = (p.clamp(0.0, 100.0) / 100.0) * ((n - 1) as f64);
    let lo = pos.floor() as usize;
    let hi = pos.ceil() as usize;
    if lo == hi {
        return sorted_ms[lo];
    }
    let frac = pos - lo as f64;
    sorted_ms[lo] * (1.0 - frac) + sorted_ms[hi] * frac
}
fn main() -> anyhow::Result<()> {
let mut args = std::env::args().skip(1);
let repo_root: PathBuf = args
.next()
.ok_or_else(|| {
anyhow::anyhow!(
"usage: semble_bench <repo_root> <annotations.json> [--no-rerank] [--model REPO]"
)
})?
.into();
let ann_path: PathBuf = args
.next()
.ok_or_else(|| anyhow::anyhow!("missing annotations.json"))?
.into();
let mut rerank_override: Option<bool> = None;
let mut scope = Scope::All;
let mut include_ext: Vec<String> = Vec::new();
let mut exclude_ext: Vec<String> = Vec::new();
let mut model_repo = DEFAULT_MODEL_REPO.to_string();
let mut args: Vec<String> = args.collect();
let mut i = 0;
while i < args.len() {
match args[i].as_str() {
"--no-rerank" => {
rerank_override = Some(false);
args.remove(i);
}
"--rerank" => {
rerank_override = Some(true);
args.remove(i);
}
"--model" => {
args.remove(i);
if i < args.len() {
model_repo = args.remove(i);
}
}
"--scope" => {
args.remove(i);
if i < args.len() {
scope = match args.remove(i).as_str() {
"code" => Scope::Code,
"docs" => Scope::Docs,
"all" => Scope::All,
other => anyhow::bail!("unknown --scope: {other}"),
};
}
}
"--include-ext" => {
args.remove(i);
if i < args.len() {
include_ext = args
.remove(i)
.split(',')
.map(|s| s.trim_start_matches('.').to_ascii_lowercase())
.filter(|s| !s.is_empty())
.collect();
}
}
"--exclude-ext" => {
args.remove(i);
if i < args.len() {
exclude_ext = args
.remove(i)
.split(',')
.map(|s| s.trim_start_matches('.').to_ascii_lowercase())
.filter(|s| !s.is_empty())
.collect();
}
}
_ => i += 1,
}
}
let raw: Vec<RawTask> = serde_json::from_slice(&std::fs::read(&ann_path)?)?;
let tasks: Vec<Task> = raw
.into_iter()
.map(|t| Task {
category: t
.category
.unwrap_or_else(|| infer_category(&t.query).to_string()),
query: t.query,
targets: t
.relevant
.into_iter()
.chain(t.secondary)
.map(Into::into)
.collect(),
})
.collect();
let may_rerank = rerank_override.unwrap_or(true);
eprintln!(
"Loading encoder ({model_repo}){}...",
if may_rerank { " + L-12 reranker" } else { "" }
);
let t0 = Instant::now();
let encoder = StaticEncoder::from_pretrained(&model_repo)?;
let reranker = if may_rerank {
Some(Arc::new(Reranker::from_pretrained(DEFAULT_RERANK_MODEL)?))
} else {
None
};
eprintln!(" loaded in {} ms", t0.elapsed().as_millis());
eprintln!("Building RipvecIndex for {}...", repo_root.display());
let t0 = Instant::now();
let cfg = SearchConfig {
scope,
include_extensions: include_ext.clone(),
exclude_extensions: exclude_ext.clone(),
..SearchConfig::default()
};
let profiler = Profiler::noop();
let index = RipvecIndex::from_root(&repo_root, encoder, &cfg, &profiler, None, 0.0)?;
let index_ms = t0.elapsed().as_secs_f64() * 1000.0;
eprintln!(
" built in {index_ms:.0} ms ({} chunks, corpus={:?})",
index.chunks().len(),
index.corpus_class(),
);
let mut ndcg10_sum = 0.0;
let mut median_latencies_ms: Vec<f64> = Vec::with_capacity(tasks.len());
let mut by_category: std::collections::BTreeMap<String, Vec<f64>> =
std::collections::BTreeMap::new();
for task in &tasks {
let mut latencies = Vec::with_capacity(LATENCY_RUNS);
let mut ranked: Vec<(usize, f32)> = Vec::new();
let do_rerank = match rerank_override {
Some(force) => force && !is_symbol_query(&task.query),
None => {
!is_symbol_query(&task.query)
&& match scope {
Scope::Code => false,
Scope::Docs => true,
Scope::All => index.corpus_class().rerank_eligible(),
}
}
};
for _ in 0..LATENCY_RUNS {
let t = Instant::now();
ranked = index.search(
&task.query,
CANDIDATE_K,
SearchMode::Hybrid,
None,
None,
None,
);
if do_rerank {
let r = reranker.as_ref().unwrap().clone();
let layer = CrossEncoderRerank::new(r, task.query.clone(), CANDIDATE_K);
let layers: Vec<Box<dyn RankingLayer>> = vec![Box::new(layer)];
apply_chain(&mut ranked, index.chunks(), &layers);
}
latencies.push(t.elapsed().as_secs_f64() * 1000.0);
}
latencies.sort_by(|a, b| a.partial_cmp(b).unwrap());
median_latencies_ms.push(latencies[latencies.len() / 2]);
let top: Vec<&CodeChunk> = ranked
.iter()
.take(TOP_K)
.filter_map(|(idx, _)| index.chunks().get(*idx))
.collect();
let n_rel = task.targets.len();
let ranks: Vec<usize> = task
.targets
.iter()
.filter_map(|target| {
top.iter()
.position(|c| target_matches_chunk(c, target))
.map(|i| i + 1)
})
.collect();
let q_ndcg10 = ndcg_at_k(&ranks, n_rel, TOP_K);
ndcg10_sum += q_ndcg10;
by_category
.entry(task.category.clone())
.or_default()
.push(q_ndcg10);
}
let n = tasks.len();
let avg_ndcg10 = ndcg10_sum / n as f64;
median_latencies_ms.sort_by(|a, b| a.partial_cmp(b).unwrap());
let p50 = percentile(&median_latencies_ms, 50.0);
let p90 = percentile(&median_latencies_ms, 90.0);
let p95 = percentile(&median_latencies_ms, 95.0);
let p99 = percentile(&median_latencies_ms, 99.0);
let category_summary: serde_json::Value = by_category
.iter()
.map(|(k, v)| {
(
k.clone(),
serde_json::json!({
"n": v.len(),
"ndcg10": v.iter().sum::<f64>() / v.len() as f64,
}),
)
})
.collect::<serde_json::Map<_, _>>()
.into();
let out = serde_json::json!({
"repo_root": repo_root.display().to_string(),
"annotations": ann_path.display().to_string(),
"tasks": n,
"chunks": index.chunks().len(),
"index_ms": index_ms,
"rerank_override": rerank_override,
"scope": format!("{scope:?}").to_lowercase(),
"corpus_class": format!("{:?}", index.corpus_class()).to_lowercase(),
"ndcg10": avg_ndcg10,
"p50_ms": p50,
"p90_ms": p90,
"p95_ms": p95,
"p99_ms": p99,
"by_category": category_summary,
});
println!("{}", serde_json::to_string_pretty(&out)?);
Ok(())
}