use std::collections::{HashMap, HashSet};
use crate::chunk::CodeChunk;
use crate::encoder::ripvec::bm25::{Bm25Index, search_bm25};
use crate::encoder::ripvec::penalties::rerank_topk;
use crate::encoder::ripvec::ranking::{apply_query_boost, boost_multi_chunk_files, resolve_alpha};
/// Smoothing constant `k` for reciprocal-rank fusion: a hit at 1-based rank `r`
/// contributes `1 / (k + r)` to its chunk's fused score.
pub const RRF_K: f32 = 60_f32;
/// Each retriever is asked for `top_k * CANDIDATE_MULTIPLIER` hits before fusion,
/// boosting and final reranking trim the list back down.
const CANDIDATE_MULTIPLIER: usize = 5_usize;
/// Inner product of two embedding vectors.
///
/// Equal lengths are only asserted in debug builds; in release builds the
/// operands are zipped, so the extra tail of the longer slice is ignored.
fn dot(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "embedding length mismatch");
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum::<f32>()
}
#[must_use]
pub fn search_semantic(
query_embedding: &[f32],
chunk_embeddings: &[Vec<f32>],
top_k: usize,
selector: Option<&[usize]>,
) -> Vec<(usize, f32)> {
if top_k == 0 || chunk_embeddings.is_empty() {
return Vec::new();
}
let selector_set: Option<HashSet<usize>> = selector.map(|s| s.iter().copied().collect());
let mut scored: Vec<(usize, f32)> = chunk_embeddings
.iter()
.enumerate()
.filter(|(i, _)| selector_set.as_ref().is_none_or(|s| s.contains(i)))
.map(|(i, emb)| (i, dot(query_embedding, emb)))
.collect();
scored.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
scored.truncate(top_k);
scored
}
/// Converts a ranked hit list into reciprocal-rank-fusion scores.
///
/// The hit at 0-based position `p` receives `1 / (RRF_K + p + 1)`. The input's own
/// scores are ignored; only the order matters. If an index occurs more than once,
/// its first (best-ranked) occurrence wins instead of being overwritten by a worse rank.
fn rrf_scores(ranked: &[(usize, f32)]) -> HashMap<usize, f32> {
    let mut scores = HashMap::with_capacity(ranked.len());
    for (rank0, &(idx, _)) in ranked.iter().enumerate() {
        // RRF ranks are 1-based.
        let rank = rank0 as f32 + 1.0;
        scores.entry(idx).or_insert(1.0 / (RRF_K + rank));
    }
    scores
}
/// Hybrid search: fuses dense (semantic) and BM25 (lexical) rankings, applies boosts,
/// and reranks.
///
/// Pipeline:
/// 1. `resolve_alpha(query, alpha)` yields the semantic weight `alpha_weight`.
///    Presumably it falls back to a query-dependent default when `alpha` is `None`;
///    confirm in `ranking`.
/// 2. Both retrievers fetch `top_k * CANDIDATE_MULTIPLIER` candidates (saturating),
///    restricted by the same `selector`.
/// 3. Each list is converted to RRF scores and combined as
///    `alpha_weight * semantic + (1 - alpha_weight) * lexical`. A chunk missing from
///    one list contributes 0 for that side.
/// 4. `boost_multi_chunk_files` and `apply_query_boost` adjust the fused scores.
/// 5. `rerank_topk` produces the final list. Its path penalties are enabled whenever
///    `alpha_weight < 1.0`.
///
/// Returns an empty vector when `top_k == 0` or `chunks` is empty.
#[must_use]
pub fn search_hybrid(
    query: &str,
    query_embedding: &[f32],
    chunk_embeddings: &[Vec<f32>],
    chunks: &[CodeChunk],
    bm25: &Bm25Index,
    top_k: usize,
    alpha: Option<f32>,
    selector: Option<&[usize]>,
) -> Vec<(usize, f32)> {
    if top_k == 0 || chunks.is_empty() {
        return Vec::new();
    }
    let alpha_weight = resolve_alpha(query, alpha);
    // Over-fetch so fusion and reranking have headroom beyond the final `top_k`.
    let candidate_count = top_k.saturating_mul(CANDIDATE_MULTIPLIER);
    let semantic = search_semantic(query_embedding, chunk_embeddings, candidate_count, selector);
    let bm25_hits = search_bm25(query, bm25, candidate_count, selector);
    let mut combined = fuse_weighted_rrf(&semantic, &bm25_hits, alpha_weight);
    boost_multi_chunk_files(&mut combined, chunks);
    let boosted = apply_query_boost(&combined, query, chunks);
    let mut scores_vec: Vec<(usize, f32)> = boosted.into_iter().collect();
    // Hash-map iteration order is randomised per process, and RRF ties are common.
    // Sorting deterministically keeps the reranker's input, and therefore its
    // tie-breaking, reproducible across runs.
    scores_vec.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
    let penalise_paths = alpha_weight < 1.0;
    rerank_topk(&scores_vec, chunks, top_k, penalise_paths)
}

/// Fuses two ranked hit lists into one score per chunk with weighted RRF:
/// `alpha * rrf(semantic) + (1 - alpha) * rrf(lexical)`. A chunk absent from one
/// list contributes 0 for that side.
fn fuse_weighted_rrf(
    semantic: &[(usize, f32)],
    lexical: &[(usize, f32)],
    alpha: f32,
) -> HashMap<usize, f32> {
    let semantic_rrf = rrf_scores(semantic);
    let lexical_rrf = rrf_scores(lexical);
    // Chunks present in both lists appear twice in the chain. Both occurrences compute
    // the same value, so the duplicate insert is harmless.
    semantic_rrf
        .keys()
        .chain(lexical_rrf.keys())
        .map(|&idx| {
            let s = semantic_rrf.get(&idx).copied().unwrap_or(0.0);
            let b = lexical_rrf.get(&idx).copied().unwrap_or(0.0);
            (idx, alpha * s + (1.0 - alpha) * b)
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encoder::ripvec::bm25::Bm25Index;

    /// Builds a single-line chunk whose raw and enriched content are identical.
    fn chunk(path: &str, content: &str) -> CodeChunk {
        CodeChunk {
            file_path: path.to_string(),
            name: String::new(),
            kind: String::new(),
            start_line: 1,
            end_line: 1,
            content: content.to_string(),
            enriched_content: content.to_string(),
        }
    }

    /// L2-normalises `values`; the `1e-12` floor avoids dividing by zero.
    fn unit_vec(values: &[f32]) -> Vec<f32> {
        let norm: f32 = values.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
        values.iter().map(|x| x / norm).collect()
    }

    #[test]
    fn rrf_k_60() {
        let ranked = vec![(7, 0.9), (3, 0.8), (5, 0.5)];
        let rrf = rrf_scores(&ranked);
        assert!((rrf[&7] - 1.0 / 61.0).abs() < 1e-7);
        assert!((rrf[&3] - 1.0 / 62.0).abs() < 1e-7);
        assert!((rrf[&5] - 1.0 / 63.0).abs() < 1e-7);
    }

    #[test]
    fn rrf_scores_empty_input() {
        assert!(rrf_scores(&[]).is_empty());
    }

    #[test]
    fn semantic_orders_by_score_then_index() {
        let embeddings = vec![vec![1.0_f32, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]];
        let query = [1.0_f32, 0.0];
        let order: Vec<usize> = search_semantic(&query, &embeddings, 10, None)
            .iter()
            .map(|&(i, _)| i)
            .collect();
        // Chunks 0 and 2 tie on score; the lower index comes first.
        assert_eq!(order, vec![0, 2, 1]);
        assert_eq!(search_semantic(&query, &embeddings, 1, None), vec![(0, 1.0)]);
    }

    #[test]
    fn semantic_selector_ignores_duplicates_and_out_of_range() {
        let embeddings = vec![vec![1.0_f32, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]];
        let query = [1.0_f32, 0.0];
        let sel: &[usize] = &[2, 1, 2, 99];
        let hits = search_semantic(&query, &embeddings, 10, Some(sel));
        assert_eq!(hits, vec![(2, 1.0), (1, 0.0)]);
        let none: &[usize] = &[];
        assert!(search_semantic(&query, &embeddings, 10, Some(none)).is_empty());
    }

    #[test]
    fn semantic_degenerate_inputs_return_empty() {
        let embeddings = vec![vec![1.0_f32, 0.0]];
        let query = [1.0_f32, 0.0];
        assert!(search_semantic(&query, &embeddings, 0, None).is_empty());
        assert!(search_semantic(&query, &[], 5, None).is_empty());
    }

    /// Smoke test for the hybrid path with `top_k = 2`. It does not directly observe
    /// the internal candidate count; it only checks that the result is bounded and
    /// contains the semantically best chunk.
    #[test]
    fn hybrid_candidate_count_5x_top_k() {
        let chunks: Vec<CodeChunk> = (0..10)
            .map(|i| chunk(&format!("src/f{i}.rs"), &format!("content {i}")))
            .collect();
        let embeddings: Vec<Vec<f32>> = (0..10)
            .map(|i| {
                let mut v = vec![0.0_f32; 10];
                v[i] = 1.0;
                v
            })
            .collect();
        let query_emb = unit_vec(&{
            let mut q = vec![0.0_f32; 10];
            q[0] = 1.0;
            q
        });
        let bm25 = Bm25Index::build(&chunks);
        let results = search_hybrid(
            "content",
            &query_emb,
            &embeddings,
            &chunks,
            &bm25,
            2,
            Some(0.5),
            None,
        );
        assert!(!results.is_empty());
        assert!(results.iter().any(|(i, _)| *i == 0));
        assert!(results.len() <= 2);
    }

    #[test]
    fn hybrid_zero_bm25_excluded_from_fusion() {
        let chunks = vec![chunk("src/a.rs", "alpha"), chunk("src/b.rs", "bravo")];
        let bm25 = Bm25Index::build(&chunks);
        let bm = search_bm25("alpha", &bm25, 10, None);
        assert_eq!(bm.len(), 1);
        let rrf = rrf_scores(&bm);
        assert!(
            !rrf.contains_key(&1),
            "BM25 zero-score doc should be excluded"
        );
    }

    #[test]
    fn hybrid_applies_rerank_topk() {
        let chunks = vec![
            chunk("src/a.rs", "alpha bravo"),
            chunk("src/a.rs", "alpha bravo"),
        ];
        let embeddings = vec![vec![1.0_f32, 0.0], vec![1.0_f32, 0.0]];
        let bm25 = Bm25Index::build(&chunks);
        let query_emb = vec![1.0_f32, 0.0];
        let results = search_hybrid(
            "alpha",
            &query_emb,
            &embeddings,
            &chunks,
            &bm25,
            2,
            Some(0.5),
            None,
        );
        assert_eq!(results.len(), 2);
        assert!(
            results[0].1 > results[1].1,
            "expected saturation decay; got scores={results:?}"
        );
    }

    #[test]
    fn hybrid_pipeline_wires_through_boosts_and_rerank() {
        let chunks = vec![
            chunk("src/auth.rs", "fn login() {}"),
            chunk("src/utils.rs", "fn unrelated() {}"),
        ];
        let embeddings = vec![vec![1.0_f32, 0.0], vec![0.0, 1.0]];
        let bm25 = Bm25Index::build(&chunks);
        // A zero query embedding gives every chunk a semantic score of 0.0, so the
        // semantic ranking reduces to the ascending-index tie-break (which also
        // favours chunk 0). The path-based query boost is expected to keep auth.rs first.
        let query_emb = vec![0.0_f32, 0.0];
        let results = search_hybrid(
            "auth",
            &query_emb,
            &embeddings,
            &chunks,
            &bm25,
            2,
            Some(0.5),
            None,
        );
        assert!(!results.is_empty());
        let top = results[0].0;
        assert_eq!(top, 0, "expected auth.rs first; got {results:?}");
    }
}