use std::path::PathBuf;
use std::sync::Arc;
use ripvec_core::embed::{Scope, SearchConfig};
use ripvec_core::encoder::ripvec::dense::{DEFAULT_MODEL_REPO, StaticEncoder};
use ripvec_core::encoder::ripvec::index::RipvecIndex;
use ripvec_core::encoder::ripvec::ranking::is_symbol_query;
use ripvec_core::hybrid::SearchMode;
use ripvec_core::profile::Profiler;
use ripvec_core::ranking::{CrossEncoderRerank, RankingLayer, apply_chain};
use ripvec_core::rerank::{DEFAULT_RERANK_CANDIDATES, DEFAULT_RERANK_MODEL, Reranker};
use serde::{Deserialize, Serialize};
/// One task from the annotations JSON file, which is parsed as an array of these.
///
/// Only `query` is read. Serde ignores any other fields present in each entry
/// because `deny_unknown_fields` is not set.
#[derive(Deserialize)]
struct RawTask {
/// Query text handed to `RipvecIndex::search` (and to the reranker when it fires).
query: String,
}
/// One ranked hit for a query, as written to the calibration dump.
///
/// Field names and declaration order are the serialized JSON schema.
#[derive(Serialize)]
struct ResultRecord {
/// 1-based position in the final ranked list. Hits whose chunk index is not
/// found in the index are skipped without renumbering, so gaps are possible.
rank: usize,
/// Source file of the chunk.
file_path: String,
/// First line of the chunk, as stored on the chunk.
start_line: usize,
/// Last line of the chunk, as stored on the chunk.
end_line: usize,
/// Score exactly as it appears in the ranked list: the hybrid search score,
/// or the post-rerank score when `rerank_fired` is true for the query.
/// NOTE(review): despite the name this is not necessarily a cosine similarity — confirm.
raw_cosine: f32,
/// `raw_cosine` divided by the query's top score; 0.0 when the top score is not positive.
normalized_cosine: f32,
/// Chunk content shortened by `truncate` to the `PREVIEW_CHARS` budget.
content_preview: String,
}
/// Everything recorded for one annotated query.
///
/// Field names and declaration order are the serialized JSON schema.
#[derive(Serialize)]
struct QueryRecord {
/// The query text from the annotations file.
query: String,
/// Scope label the run used: "code", "docs", or "all".
scope: String,
/// Whether the cross-encoder rerank layer was applied to this query's results.
rerank_fired: bool,
/// Score of the first ranked result, or 0.0 when the search returned nothing.
top_raw: f32,
/// At most `TOP_K` ranked hits, best first.
results: Vec<ResultRecord>,
}
/// Maximum number of ranked hits recorded per query.
const TOP_K: usize = 100;
/// Length budget passed to `truncate` when building each hit's `content_preview`.
const PREVIEW_CHARS: usize = 600;
/// Shortens `s` to at most `max` characters (Unicode scalar values),
/// appending `…` when anything was removed.
///
/// The budget is counted in characters rather than bytes so that
/// `PREVIEW_CHARS` means what its name says for non-ASCII content. A byte
/// budget keeps only about a third as many CJK characters. Cutting at a
/// `char_indices` offset also guarantees the slice never splits a
/// multi-byte UTF-8 sequence.
fn truncate(s: &str, max: usize) -> String {
    // `nth(max)` is the first character past the budget; its byte offset is
    // exactly where the kept prefix ends. `None` means everything fits.
    match s.char_indices().nth(max) {
        Some((cut, _)) => format!("{}…", &s[..cut]),
        None => s.to_owned(),
    }
}
#[allow(clippy::too_many_lines, reason = "single-file calibration probe")]
fn main() -> anyhow::Result<()> {
let mut args = std::env::args().skip(1).collect::<Vec<_>>();
if args.len() < 2 {
anyhow::bail!(
"usage: threshold_calibration <corpus_root> <annotations.json> --scope code|docs|all --out PATH"
);
}
let corpus: PathBuf = args.remove(0).into();
let ann_path: PathBuf = args.remove(0).into();
let mut scope = Scope::All;
let mut out_path: Option<PathBuf> = None;
let mut i = 0;
while i < args.len() {
match args[i].as_str() {
"--scope" if i + 1 < args.len() => {
scope = match args[i + 1].as_str() {
"code" => Scope::Code,
"docs" => Scope::Docs,
"all" => Scope::All,
o => anyhow::bail!("bad scope: {o}"),
};
args.drain(i..=i + 1);
}
"--out" if i + 1 < args.len() => {
out_path = Some(args[i + 1].clone().into());
args.drain(i..=i + 1);
}
_ => i += 1,
}
}
let out_path = out_path.ok_or_else(|| anyhow::anyhow!("--out PATH required"))?;
let scope_str = match scope {
Scope::Code => "code",
Scope::Docs => "docs",
Scope::All => "all",
};
eprintln!("==> calibration dump");
eprintln!(" corpus: {}", corpus.display());
eprintln!(" annotations: {}", ann_path.display());
eprintln!(" scope: {scope:?}");
eprintln!(" out: {}", out_path.display());
let raw: Vec<RawTask> = serde_json::from_slice(&std::fs::read(&ann_path)?)?;
eprintln!("loading encoder + reranker...");
let encoder = StaticEncoder::from_pretrained(DEFAULT_MODEL_REPO)?;
let reranker = Arc::new(Reranker::from_pretrained(DEFAULT_RERANK_MODEL)?);
eprintln!("building RipvecIndex...");
let cfg = SearchConfig {
scope,
..SearchConfig::default()
};
let index = RipvecIndex::from_root(&corpus, encoder, &cfg, &Profiler::noop(), None, 0.0)?;
let corpus_class = index.corpus_class();
eprintln!(
" {} chunks, corpus_class = {:?}",
index.chunks().len(),
corpus_class
);
let mut output: Vec<QueryRecord> = Vec::new();
for task in &raw {
let rerank_fired = !is_symbol_query(&task.query)
&& match scope {
Scope::Code => false,
Scope::Docs => true,
Scope::All => corpus_class.rerank_eligible(),
};
let mut ranked = index.search(
&task.query,
DEFAULT_RERANK_CANDIDATES,
SearchMode::Hybrid,
None,
None,
None,
);
if rerank_fired {
let layer = CrossEncoderRerank::new(
reranker.clone(),
task.query.clone(),
DEFAULT_RERANK_CANDIDATES,
);
let layers: Vec<Box<dyn RankingLayer>> = vec![Box::new(layer)];
apply_chain(&mut ranked, index.chunks(), &layers);
}
let top_raw = ranked.first().map_or(0.0, |&(_, s)| s);
let mut results = Vec::new();
for (rank, (idx, score)) in ranked.iter().take(TOP_K).enumerate() {
let Some(chunk) = index.chunks().get(*idx) else {
continue;
};
let normalized = if top_raw > 0.0 { score / top_raw } else { 0.0 };
results.push(ResultRecord {
rank: rank + 1,
file_path: chunk.file_path.clone(),
start_line: chunk.start_line,
end_line: chunk.end_line,
raw_cosine: *score,
normalized_cosine: normalized,
content_preview: truncate(&chunk.content, PREVIEW_CHARS),
});
}
output.push(QueryRecord {
query: task.query.clone(),
scope: scope_str.to_string(),
rerank_fired,
top_raw,
results,
});
}
let json = serde_json::to_string_pretty(&output)?;
std::fs::write(&out_path, json)?;
eprintln!("wrote {} queries to {}", output.len(), out_path.display());
Ok(())
}