use std::collections::{HashMap, HashSet};
use ndarray::{Array1, Array2, ArrayView1, s};
use rayon::prelude::*;
use crate::chunk::CodeChunk;
use crate::encoder::ripvec::bm25::{Bm25Index, search_bm25};
use crate::encoder::ripvec::penalties::rerank_topk;
use crate::encoder::ripvec::ranking::{apply_query_boost, boost_multi_chunk_files, resolve_alpha};
/// Smoothing constant of Reciprocal Rank Fusion: a hit at 1-based rank `r`
/// contributes `1 / (RRF_K + r)` (see `rrf_scores`).
pub const RRF_K: f32 = 60.0;
/// Over-fetch factor used by `search_hybrid`: each retriever is asked for
/// `top_k * CANDIDATE_MULTIPLIER` candidates before fusion and re-ranking.
const CANDIDATE_MULTIPLIER: usize = 5;
/// Row count at or below which `parallel_sgemv` performs a single `dot` call
/// instead of splitting the rows across rayon workers.
const SGEMV_SERIAL_THRESHOLD: usize = 4096;
/// Computes the matrix–vector product `matrix · vector`, one output entry per
/// matrix row.
///
/// An empty matrix yields an empty array. When the matrix has at most
/// `SGEMV_SERIAL_THRESHOLD` rows, or rayon reports a single worker, the
/// product is a single `dot` call. Otherwise the rows are split into
/// contiguous blocks (one per worker, rounded up), each block is multiplied
/// independently on the rayon pool, and the partial results are written
/// straight into their slot of the output buffer.
///
/// # Panics
///
/// Shape compatibility between `matrix` and `vector` is not checked here; it
/// is left to ndarray's `dot` (NOTE(review): expected to panic on a
/// mismatch — confirm against the ndarray documentation).
#[must_use]
pub fn parallel_sgemv(matrix: &Array2<f32>, vector: &ArrayView1<f32>) -> Array1<f32> {
    let row_count = matrix.nrows();
    if row_count == 0 {
        return Array1::zeros(0);
    }

    let workers = rayon::current_num_threads().max(1);
    let run_serial = workers == 1 || row_count <= SGEMV_SERIAL_THRESHOLD;
    if run_serial {
        return matrix.dot(vector);
    }

    // Ceiling division so that `workers` blocks always cover every row.
    let rows_per_block = row_count.div_ceil(workers);
    let mut output = vec![0.0_f32; row_count];
    output
        .par_chunks_mut(rows_per_block)
        .enumerate()
        .for_each(|(block_idx, dest)| {
            // The output chunk and the matrix row block share the same offsets.
            let first_row = block_idx * rows_per_block;
            let last_row = row_count.min(first_row + dest.len());
            let partial: Array1<f32> = matrix.slice(s![first_row..last_row, ..]).dot(vector);
            dest.copy_from_slice(partial.as_slice().expect("sgemv output contiguous"));
        });
    Array1::from_vec(output)
}
/// Ranks chunks by the dot product between `query_embedding` and each row of
/// `chunk_embeddings`.
///
/// # Arguments
///
/// * `query_embedding` - query vector; its length must equal the number of
///   columns of `chunk_embeddings` (checked with a `debug_assert`).
/// * `chunk_embeddings` - one embedding per row.
/// * `top_k` - maximum number of hits to return.
/// * `selector` - optional allow-list of row indices; rows not listed are
///   dropped, and listed indices that are out of range are ignored.
///
/// # Returns
///
/// At most `top_k` `(row_index, score)` pairs ordered by descending score
/// (using `f32::total_cmp`), ties broken by ascending row index. Returns an
/// empty vector when `top_k == 0` or there are no rows.
#[must_use]
pub fn search_semantic(
    query_embedding: &[f32],
    chunk_embeddings: &Array2<f32>,
    top_k: usize,
    selector: Option<&[usize]>,
) -> Vec<(usize, f32)> {
    let row_count = chunk_embeddings.nrows();
    if row_count == 0 || top_k == 0 {
        return Vec::new();
    }
    debug_assert_eq!(
        query_embedding.len(),
        chunk_embeddings.ncols(),
        "query embedding dim ({}) != chunk embedding dim ({})",
        query_embedding.len(),
        chunk_embeddings.ncols(),
    );

    let query_view: ArrayView1<f32> = ArrayView1::from(query_embedding);
    let similarities: Array1<f32> = parallel_sgemv(chunk_embeddings, &query_view);

    // Best score first; equal scores fall back to the smaller row index so the
    // ordering is total and the result is deterministic.
    let best_first = |lhs: &(usize, f32), rhs: &(usize, f32)| {
        rhs.1.total_cmp(&lhs.1).then_with(|| lhs.0.cmp(&rhs.0))
    };

    let mut hits: Vec<(usize, f32)> = match selector {
        Some(allowed_rows) => {
            let allowed: HashSet<usize> = allowed_rows.iter().copied().collect();
            similarities
                .iter()
                .copied()
                .enumerate()
                .filter(|(row, _)| allowed.contains(row))
                .collect()
        }
        None => similarities.iter().copied().enumerate().collect(),
    };

    // Partition the top_k best to the front before sorting, so only the
    // survivors pay for the full sort.
    if hits.len() > top_k {
        hits.select_nth_unstable_by(top_k - 1, best_first);
        hits.truncate(top_k);
    }
    hits.sort_unstable_by(best_first);
    hits
}
/// Converts a ranked hit list into Reciprocal Rank Fusion scores.
///
/// `ranked` is read in order: the entry at position `p` (0-based) gets the
/// score `1 / (RRF_K + p + 1)`. The original retrieval scores are ignored;
/// only the position matters. Returns a map from hit index to its RRF score.
fn rrf_scores(ranked: &[(usize, f32)]) -> HashMap<usize, f32> {
    let mut fused = HashMap::with_capacity(ranked.len());
    for (position, &(doc_idx, _)) in ranked.iter().enumerate() {
        // RRF ranks are 1-based.
        let rank = position as f32 + 1.0;
        fused.insert(doc_idx, 1.0 / (RRF_K + rank));
    }
    fused
}
#[must_use]
pub fn search_hybrid(
query: &str,
query_embedding: &[f32],
chunk_embeddings: &Array2<f32>,
chunks: &[CodeChunk],
bm25: &Bm25Index,
top_k: usize,
alpha: Option<f32>,
selector: Option<&[usize]>,
) -> Vec<(usize, f32)> {
if top_k == 0 || chunks.is_empty() {
return Vec::new();
}
let alpha_weight = resolve_alpha(query, alpha);
let candidate_count = top_k.saturating_mul(CANDIDATE_MULTIPLIER);
let semantic = search_semantic(query_embedding, chunk_embeddings, candidate_count, selector);
let bm25_hits = search_bm25(query, bm25, candidate_count, selector);
let normalized_semantic = rrf_scores(&semantic);
let normalized_bm25 = rrf_scores(&bm25_hits);
let mut combined: HashMap<usize, f32> = HashMap::new();
let union: HashSet<usize> = normalized_semantic
.keys()
.chain(normalized_bm25.keys())
.copied()
.collect();
for idx in union {
let s = normalized_semantic.get(&idx).copied().unwrap_or(0.0);
let b = normalized_bm25.get(&idx).copied().unwrap_or(0.0);
combined.insert(idx, alpha_weight * s + (1.0 - alpha_weight) * b);
}
boost_multi_chunk_files(&mut combined, chunks);
let boosted = apply_query_boost(&combined, query, chunks);
let penalise_paths = alpha_weight < 1.0;
let scores_vec: Vec<(usize, f32)> = boosted.into_iter().collect();
rerank_topk(&scores_vec, chunks, top_k, penalise_paths)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encoder::ripvec::bm25::Bm25Index;

    /// Builds a one-line `CodeChunk` at `path` whose plain and enriched
    /// content are both `content`.
    fn chunk(path: &str, content: &str) -> CodeChunk {
        let body = content.to_owned();
        CodeChunk {
            file_path: path.to_owned(),
            name: String::new(),
            kind: String::new(),
            start_line: 1,
            end_line: 1,
            content: body.clone(),
            enriched_content: body,
        }
    }

    /// L2-normalises `values`; the norm is clamped to `1e-12` so an all-zero
    /// input does not divide by zero.
    fn unit_vec(values: &[f32]) -> Vec<f32> {
        let squared_sum: f32 = values.iter().map(|v| v * v).sum();
        let length = squared_sum.sqrt().max(1e-12);
        values.iter().map(|v| v / length).collect()
    }

    /// RRF scores follow `1 / (60 + rank)` with 1-based ranks.
    #[test]
    fn rrf_k_60() {
        let rrf = rrf_scores(&[(7, 0.9), (3, 0.8), (5, 0.5)]);
        for (idx, denominator) in [(7_usize, 61.0_f32), (3, 62.0), (5, 63.0)] {
            assert!((rrf[&idx] - 1.0 / denominator).abs() < 1e-7);
        }
    }

    /// With ten one-hot chunks and `top_k = 2`, the chunk aligned with the
    /// query is returned and the result never exceeds `top_k`.
    #[test]
    fn hybrid_candidate_count_5x_top_k() {
        let chunks: Vec<CodeChunk> = (0..10)
            .map(|i| chunk(&format!("src/f{i}.rs"), &format!("content {i}")))
            .collect();
        // Row `i` is the one-hot vector for dimension `i`.
        let embeddings = Array2::<f32>::eye(10);
        let mut one_hot = vec![0.0_f32; 10];
        one_hot[0] = 1.0;
        let query_emb = unit_vec(&one_hot);
        let bm25 = Bm25Index::build(&chunks);

        let results = search_hybrid(
            "content",
            &query_emb,
            &embeddings,
            &chunks,
            &bm25,
            2,
            Some(0.5),
            None,
        );

        assert!(!results.is_empty());
        assert!(results.iter().any(|&(i, _)| i == 0));
        assert!(results.len() <= 2);
    }

    /// A document that BM25 does not return must not receive an RRF score.
    #[test]
    fn hybrid_zero_bm25_excluded_from_fusion() {
        let chunks = vec![chunk("src/a.rs", "alpha"), chunk("src/b.rs", "bravo")];
        let bm25 = Bm25Index::build(&chunks);

        let bm = search_bm25("alpha", &bm25, 10, None);
        assert_eq!(bm.len(), 1);

        let rrf = rrf_scores(&bm);
        assert!(
            !rrf.contains_key(&1),
            "BM25 zero-score doc should be excluded"
        );
    }

    /// Two identical chunks from the same file must come back with strictly
    /// decreasing scores after the re-rank stage.
    #[test]
    fn hybrid_applies_rerank_topk() {
        let chunks: Vec<CodeChunk> = (0..2)
            .map(|_| chunk("src/a.rs", "alpha bravo"))
            .collect();
        let embeddings = Array2::from_shape_vec((2, 2), vec![1.0_f32, 0.0, 1.0, 0.0]).unwrap();
        let bm25 = Bm25Index::build(&chunks);
        let query_emb = vec![1.0_f32, 0.0];

        let results = search_hybrid(
            "alpha",
            &query_emb,
            &embeddings,
            &chunks,
            &bm25,
            2,
            Some(0.5),
            None,
        );

        assert_eq!(results.len(), 2);
        assert!(
            results[0].1 > results[1].1,
            "expected saturation decay; got scores={results:?}"
        );
    }

    /// With an all-zero query embedding, the query "auth" must still rank
    /// `src/auth.rs` first.
    #[test]
    fn hybrid_pipeline_wires_through_boosts_and_rerank() {
        let chunks = vec![
            chunk("src/auth.rs", "fn login() {}"),
            chunk("src/utils.rs", "fn unrelated() {}"),
        ];
        let embeddings = Array2::from_shape_vec((2, 2), vec![1.0_f32, 0.0, 0.0, 1.0]).unwrap();
        let bm25 = Bm25Index::build(&chunks);
        // Zero embedding: every semantic dot product is 0.
        let query_emb = vec![0.0_f32; 2];

        let results = search_hybrid(
            "auth",
            &query_emb,
            &embeddings,
            &chunks,
            &bm25,
            2,
            Some(0.5),
            None,
        );

        assert!(!results.is_empty());
        let top = results[0].0;
        assert_eq!(top, 0, "expected auth.rs first; got {results:?}");
    }
}