use crate::confidence::Stage;
use crate::config::Config;
use crate::index::Index;
use crate::lexical::{self, Lex};
use crate::rank::Hit;
use crate::rerank;
/// Stage decision returned by [`decide`]: which stage settled the ranking and what it produced.
#[derive(Debug)]
pub struct Plan {
/// Stage whose result this plan carries (`Lexical`, `Rerank` or `Cosine`).
pub stage: Stage,
/// Candidate rows for the chosen stage. `decide` stores the reranked rows for
/// `Stage::Rerank` and a copy of its input hits for the other stages.
pub rows: Vec<Hit>,
/// The dominant lexical match. `decide` sets this to `Some` only for `Stage::Lexical`.
pub lexical: Option<Lex>,
/// Rows accepted by the stage. Lexical: hits sharing the lexical winner's id.
/// Rerank: rows kept by `rerank::passes`. Cosine: rows kept by `cosine_passed`.
pub passed: Vec<Hit>,
/// Config threshold tied to the stage, as set by `decide`:
/// `lexical_min`, `rerank_min` or `min_similarity`.
pub threshold: f32,
}
/// Returns the hits that clear the plain cosine stage, preserving input order.
///
/// A hit passes when its `score` is at least `cfg.min_similarity` and lies within
/// `cfg.score_margin` of the best score in `hits`. A hit whose id is listed in
/// `cfg.force` bypasses both checks, but only when it carries a positive keyword
/// signal (`keyword > 0.0`).
///
/// Returns an empty vector for empty `hits`.
pub fn cosine_passed(hits: &[Hit], cfg: &Config) -> Vec<Hit> {
    // Use the maximum score instead of `hits[0]` so the margin check does not silently
    // rely on the caller having sorted `hits` by descending score. For sorted input the
    // two are identical. `f32::max` skips NaN operands, so one NaN score cannot
    // poison the reference value.
    let top = hits
        .iter()
        .map(|h| h.score)
        .fold(f32::NEG_INFINITY, f32::max);
    hits.iter()
        .filter(|h| {
            let forced = h.keyword > 0.0 && cfg.force.contains(&h.id);
            forced || (h.score >= cfg.min_similarity && h.score >= top - cfg.score_margin)
        })
        .cloned()
        .collect()
}
/// Picks the stage that settles the ranking of `hits` and builds the matching [`Plan`].
///
/// Stages are tried in this order:
/// 1. **Lexical**: only probed when `rerank::confident_winner` reports no confident
///    winner. If `lexical::dominant` yields a winner, only the hits sharing its id pass,
///    judged against `cfg.lexical_min`.
/// 2. **Rerank**: when `rerank::is_ambiguous` holds and `rerank::rerank` returns rows.
///    Those rows are filtered by `rerank::passes`, with `cfg.rerank_min` as the threshold.
/// 3. **Cosine**: the fallback. The original hits are filtered by [`cosine_passed`],
///    with `cfg.min_similarity` as the threshold.
pub fn decide(hits: &[Hit], idx: &Index, prompt: &str, rerank_query: &str, cfg: &Config) -> Plan {
    // Skip the lexical probe entirely when the ranking already has a confident winner.
    let lexical_winner = if rerank::confident_winner(hits, cfg) {
        None
    } else {
        lexical::dominant(prompt, idx, cfg)
    };
    if let Some(win) = lexical_winner {
        let passed: Vec<Hit> = hits.iter().filter(|h| h.id == win.id).cloned().collect();
        return Plan {
            stage: Stage::Lexical,
            rows: hits.to_vec(),
            lexical: Some(win),
            passed,
            threshold: cfg.lexical_min,
        };
    }

    // Reranking is only attempted for ambiguous rankings, and it may still decline (None).
    let reranked = if rerank::is_ambiguous(hits, cfg) {
        rerank::rerank(hits, idx, rerank_query, cfg)
    } else {
        None
    };
    if let Some(rows) = reranked {
        let passed = rerank::passes(&rows, cfg);
        return Plan {
            stage: Stage::Rerank,
            rows,
            lexical: None,
            passed,
            threshold: cfg.rerank_min,
        };
    }

    Plan {
        stage: Stage::Cosine,
        passed: cosine_passed(hits, cfg),
        rows: hits.to_vec(),
        lexical: None,
        threshold: cfg.min_similarity,
    }
}
/// Returns a short display label for `stage`.
///
/// `model` only appears in the cosine label (`"stage1:<model>"`). The rerank and lexical
/// labels are fixed strings.
pub fn stage_label(stage: Stage, model: &str) -> String {
    match stage {
        Stage::Rerank => String::from("rerank:turbo"),
        Stage::Lexical => String::from("lexical(BM25)"),
        Stage::Cosine => ["stage1:", model].concat(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Hit` whose combined `score` splits into `keyword` plus the remaining
    /// cosine component. All other signal fields are zero.
    fn hit(id: &str, score: f32, keyword: f32) -> Hit {
        Hit {
            id: id.to_string(),
            name: id.to_string(),
            cosine: score - keyword,
            context: 0.0,
            file: 0.0,
            project: 0.0,
            keyword,
            phrase: 0.0,
            score,
        }
    }

    /// Collects the ids of the hits accepted by `cosine_passed`, in output order.
    fn passed_ids(hits: &[Hit], cfg: &Config) -> Vec<String> {
        cosine_passed(hits, cfg).into_iter().map(|h| h.id).collect()
    }

    #[test]
    fn cosine_passed_applies_floor_and_margin() {
        let cfg = Config::default();
        let hits = vec![
            hit("a", 0.90, 0.0),
            hit("b", 0.80, 0.0),
            hit("c", 0.50, 0.0),
            hit("d", 0.10, 0.0),
        ];
        assert_eq!(passed_ids(&hits, &cfg), ["a", "b"]);
    }

    #[test]
    fn cosine_passed_force_bypasses_floor_on_keyword() {
        let cfg = Config {
            force: vec!["x".to_string()],
            ..Default::default()
        };
        let hits = vec![hit("x", 0.10, 0.15), hit("y", 0.20, 0.0)];
        assert_eq!(passed_ids(&hits, &cfg), ["x"]);
    }

    #[test]
    fn cosine_passed_force_requires_keyword_signal() {
        let cfg = Config {
            force: vec!["x".to_string()],
            ..Default::default()
        };
        // Same low score as the forced hit in the previous test, but with no keyword
        // match: the force list does not apply, so the similarity floor drops it.
        let hits = vec![hit("x", 0.10, 0.0)];
        assert!(passed_ids(&hits, &cfg).is_empty());
    }

    #[test]
    fn cosine_passed_empty_input_passes_nothing() {
        assert!(passed_ids(&[], &Config::default()).is_empty());
    }

    #[test]
    fn stage_label_renders_each_stage() {
        assert_eq!(stage_label(Stage::Cosine, "bge"), "stage1:bge");
        assert_eq!(stage_label(Stage::Rerank, "bge"), "rerank:turbo");
        assert_eq!(stage_label(Stage::Lexical, "bge"), "lexical(BM25)");
    }
}