mod neural;
use std::collections::HashMap;
use crate::fts::FtsResult;
use crate::vectordb::SearchResult;
pub use neural::NeuralReranker;
/// Default smoothing constant `k` for Reciprocal Rank Fusion.
///
/// An item at 1-based rank `r` in a ranked list contributes `1 / (k + r)` to its
/// chunk's fused score, so a larger `k` flattens the gap between top-ranked and
/// lower-ranked hits.
pub const DEFAULT_RRF_K: f32 = 20.0;
/// Smaller RRF smoothing constant; it weights a list's top-ranked hits more heavily
/// than [`DEFAULT_RRF_K`] does.
/// NOTE(review): presumably intended as the `exact_k` argument of
/// `rrf_fusion_with_exact` — confirm at the call site.
pub const EXACT_MATCH_RRF_K: f32 = 5.0;
/// One chunk's entry in a fused ranking: the score used for ordering plus the raw
/// per-list scores and 1-based ranks it was derived from.
#[derive(Debug, Clone)]
// NOTE(review): `dead_code` is allowed presumably because not every field is read by
// callers — confirm, and narrow or drop this allow if possible.
#[allow(dead_code)] pub struct FusedResult {
/// Id of the chunk this entry refers to.
pub chunk_id: u32,
/// Score the results are ordered by. The fusion functions store the summed RRF
/// contributions here; `vector_only` stores the raw vector score instead.
pub rrf_score: f32,
/// Raw vector-search score, if the chunk appeared in the vector list.
pub vector_score: Option<f32>,
/// Raw full-text score, if the chunk appeared in a text list. `rrf_fusion_with_exact`
/// stores the mean of the FTS and exact-match scores when both are present.
pub fts_score: Option<f32>,
/// 1-based position in the vector list, if present.
pub vector_rank: Option<usize>,
/// 1-based position in the full-text list, if present. `rrf_fusion_with_exact` falls
/// back to the exact-match rank when the chunk is absent from the FTS list.
pub fts_rank: Option<usize>,
}
/// Per-chunk accumulator used by [`rrf_fusion`]:
/// `(rrf_score, vector_score, fts_score, vector_rank, fts_rank)`.
type ScoreEntry = (f32, Option<f32>, Option<f32>, Option<usize>, Option<usize>);

/// Fuses a vector-search ranking and a full-text ranking with Reciprocal Rank Fusion.
///
/// Each list contributes `1 / (k + rank)` to a chunk's fused score, where `rank` is
/// the chunk's 1-based position in that list; contributions from both lists are
/// summed. A larger `k` flattens the difference between top and lower ranks.
///
/// The raw per-list scores and 1-based ranks are carried through on each
/// [`FusedResult`] (`None` when the chunk did not appear in that list). If a chunk id
/// occurs more than once within one list, every occurrence contributes to its score
/// and the last occurrence's score and rank are kept.
///
/// # Returns
/// One entry per distinct chunk id, sorted by descending `rrf_score`. Ties are broken
/// by ascending `chunk_id`, so the order is deterministic across runs even though the
/// intermediate `HashMap` iterates in a randomized order.
pub fn rrf_fusion(
    vector_results: &[SearchResult],
    fts_results: &[FtsResult],
    k: f32,
) -> Vec<FusedResult> {
    // RRF contribution of a 0-based list position (the `+ 1.0` makes it 1-based).
    let contribution = |rank: usize| 1.0 / (k + rank as f32 + 1.0);

    let mut scores: HashMap<u32, ScoreEntry> =
        HashMap::with_capacity(vector_results.len() + fts_results.len());
    for (rank, result) in vector_results.iter().enumerate() {
        let entry = scores
            .entry(result.id)
            .or_insert((0.0, None, None, None, None));
        entry.0 += contribution(rank);
        entry.1 = Some(result.score);
        entry.3 = Some(rank + 1);
    }
    for (rank, result) in fts_results.iter().enumerate() {
        let entry = scores
            .entry(result.chunk_id)
            .or_insert((0.0, None, None, None, None));
        entry.0 += contribution(rank);
        entry.2 = Some(result.score);
        entry.4 = Some(rank + 1);
    }

    let mut results: Vec<FusedResult> = scores
        .into_iter()
        .map(
            |(chunk_id, (rrf_score, vector_score, fts_score, vector_rank, fts_rank))| FusedResult {
                chunk_id,
                rrf_score,
                vector_score,
                fts_score,
                vector_rank,
                fts_rank,
            },
        )
        .collect();
    // Equal fused scores are common (e.g. ranks (1, 2) vs (2, 1)); without a tie-break
    // their order would follow the randomized HashMap iteration. Chunk ids are unique
    // here, so this comparator is a strict total order and an unstable sort suffices.
    results.sort_unstable_by(|a, b| {
        b.rrf_score
            .total_cmp(&a.rrf_score)
            .then_with(|| a.chunk_id.cmp(&b.chunk_id))
    });
    results
}
/// Wraps vector-search hits as [`FusedResult`]s without any fusion.
///
/// Input order is preserved and each hit gets its 1-based position as `vector_rank`.
/// Because nothing is fused, `rrf_score` is simply the raw vector score, and the
/// full-text fields are left as `None`.
pub fn vector_only(vector_results: &[SearchResult]) -> Vec<FusedResult> {
    let mut wrapped = Vec::with_capacity(vector_results.len());
    for (position, hit) in (1..).zip(vector_results) {
        wrapped.push(FusedResult {
            chunk_id: hit.id,
            // No fusion happens here: the similarity score doubles as the ranking score.
            rrf_score: hit.score,
            vector_score: Some(hit.score),
            fts_score: None,
            vector_rank: Some(position),
            fts_rank: None,
        });
    }
    wrapped
}
/// Fuses vector, full-text and exact-match rankings with Reciprocal Rank Fusion, using
/// a separate smoothing constant for each list.
///
/// Each list contributes `1 / (k + rank)` to a chunk's fused score (with that list's
/// `k` and the chunk's 1-based position in it); contributions are summed. A smaller
/// `k` gives that list's top hits more weight.
///
/// [`FusedResult`] has a single full-text slot, so the two text lists are merged:
/// - `fts_score` is the mean of the FTS and exact-match scores when both are present,
///   otherwise whichever one exists.
/// - `fts_rank` is the FTS rank when present, otherwise the exact-match rank (the two
///   ranks come from different lists).
///
/// If a chunk id occurs more than once within one list, every occurrence contributes
/// and the last occurrence's score and rank are kept.
///
/// # Returns
/// One entry per distinct chunk id, sorted by descending `rrf_score`, with ties broken
/// by ascending `chunk_id` so the order is deterministic across runs.
pub fn rrf_fusion_with_exact(
    vector_results: &[SearchResult],
    fts_results: &[FtsResult],
    exact_results: &[FtsResult],
    vector_k: f32,
    fts_k: f32,
    exact_k: f32,
) -> Vec<FusedResult> {
    // Per-chunk accumulator; each `Option` stays `None` until the chunk is seen in
    // the corresponding list.
    #[derive(Default)]
    struct Accumulator {
        rrf_score: f32,
        vector_score: Option<f32>,
        fts_score: Option<f32>,
        exact_score: Option<f32>,
        vector_rank: Option<usize>,
        fts_rank: Option<usize>,
        exact_rank: Option<usize>,
    }

    // RRF contribution of a 0-based list position under smoothing constant `k`.
    fn contribution(k: f32, rank: usize) -> f32 {
        1.0 / (k + rank as f32 + 1.0)
    }

    let mut scores: HashMap<u32, Accumulator> = HashMap::with_capacity(
        vector_results.len() + fts_results.len() + exact_results.len(),
    );
    for (rank, result) in vector_results.iter().enumerate() {
        let acc = scores.entry(result.id).or_default();
        acc.rrf_score += contribution(vector_k, rank);
        acc.vector_score = Some(result.score);
        acc.vector_rank = Some(rank + 1);
    }
    for (rank, result) in fts_results.iter().enumerate() {
        let acc = scores.entry(result.chunk_id).or_default();
        acc.rrf_score += contribution(fts_k, rank);
        acc.fts_score = Some(result.score);
        acc.fts_rank = Some(rank + 1);
    }
    for (rank, result) in exact_results.iter().enumerate() {
        let acc = scores.entry(result.chunk_id).or_default();
        acc.rrf_score += contribution(exact_k, rank);
        acc.exact_score = Some(result.score);
        acc.exact_rank = Some(rank + 1);
    }

    let mut results: Vec<FusedResult> = scores
        .into_iter()
        .map(|(chunk_id, acc)| {
            let fts_score = match (acc.fts_score, acc.exact_score) {
                (Some(fts), Some(exact)) => Some((fts + exact) / 2.0),
                (fts, exact) => fts.or(exact),
            };
            FusedResult {
                chunk_id,
                rrf_score: acc.rrf_score,
                vector_score: acc.vector_score,
                fts_score,
                vector_rank: acc.vector_rank,
                fts_rank: acc.fts_rank.or(acc.exact_rank),
            }
        })
        .collect();
    // Break score ties by chunk id; otherwise the order of tied entries would follow
    // the randomized HashMap iteration. Chunk ids are unique, so the order is total.
    results.sort_unstable_by(|a, b| {
        b.rrf_score
            .total_cmp(&a.rrf_score)
            .then_with(|| a.chunk_id.cmp(&b.chunk_id))
    });
    results
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a vector hit with placeholder metadata; only `id` and `score` matter here.
    fn make_vector_result(id: u32, score: f32) -> SearchResult {
        SearchResult {
            id,
            score,
            path: format!("file_{}.rs", id),
            content: format!("content {}", id),
            start_line: 1,
            end_line: 10,
            kind: "function".to_string(),
            signature: None,
            context_prev: None,
            context_next: None,
            distance: 0.0,
            context: None,
            docstring: None,
            hash: String::new(),
        }
    }

    fn make_fts_result(id: u32, score: f32) -> FtsResult {
        FtsResult {
            chunk_id: id,
            score,
        }
    }

    #[test]
    fn test_rrf_fusion_basic() {
        let vector_results = vec![
            make_vector_result(1, 0.9),
            make_vector_result(2, 0.8),
            make_vector_result(3, 0.7),
        ];
        let fts_results = vec![
            make_fts_result(2, 10.0),
            make_fts_result(1, 8.0),
            make_fts_result(4, 6.0),
        ];
        let fused = rrf_fusion(&vector_results, &fts_results, 20.0);
        // One entry per distinct chunk id across both lists.
        assert_eq!(fused.len(), 4);
        // Output must be ordered by descending fused score.
        assert!(fused.windows(2).all(|w| w[0].rrf_score >= w[1].rrf_score));
        let id1 = fused.iter().find(|r| r.chunk_id == 1).unwrap();
        let id2 = fused.iter().find(|r| r.chunk_id == 2).unwrap();
        assert!(id1.vector_rank.is_some());
        assert!(id1.fts_rank.is_some());
        assert!(id2.vector_rank.is_some());
        assert!(id2.fts_rank.is_some());
        let id4 = fused.iter().find(|r| r.chunk_id == 4).unwrap();
        assert!(id4.vector_rank.is_none());
        assert!(id4.fts_rank.is_some());
    }

    #[test]
    fn test_rrf_score_calculation() {
        let vector_results = vec![make_vector_result(1, 0.9)];
        let fts_results = vec![make_fts_result(1, 10.0)];
        let fused = rrf_fusion(&vector_results, &fts_results, 20.0);
        assert_eq!(fused.len(), 1);
        let result = &fused[0];
        let expected = 1.0 / 21.0 + 1.0 / 21.0;
        assert!((result.rrf_score - expected).abs() < 0.0001);
    }

    #[test]
    fn test_empty_inputs() {
        assert!(rrf_fusion(&[], &[], 20.0).is_empty());
        assert!(rrf_fusion_with_exact(&[], &[], &[], 20.0, 20.0, 5.0).is_empty());
        assert!(vector_only(&[]).is_empty());
    }

    #[test]
    fn test_vector_only() {
        let vector_results = vec![make_vector_result(1, 0.9), make_vector_result(2, 0.8)];
        let results = vector_only(&vector_results);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].chunk_id, 1);
        assert_eq!(results[0].rrf_score, 0.9);
        assert!(results[0].fts_score.is_none());
    }

    #[test]
    fn test_rrf_fusion_with_exact_merges_text_lists() {
        let vector_results = vec![make_vector_result(1, 0.9), make_vector_result(2, 0.8)];
        let fts_results = vec![make_fts_result(2, 10.0)];
        let exact_results = vec![make_fts_result(3, 5.0), make_fts_result(2, 4.0)];
        let fused = rrf_fusion_with_exact(
            &vector_results,
            &fts_results,
            &exact_results,
            20.0,
            20.0,
            5.0,
        );
        // Scores: id2 = 1/22 + 1/21 + 1/7, id3 = 1/6, id1 = 1/21 — no ties.
        let ids: Vec<u32> = fused.iter().map(|r| r.chunk_id).collect();
        assert_eq!(ids, vec![2, 3, 1]);

        let id2 = &fused[0];
        let expected = 1.0 / 22.0 + 1.0 / 21.0 + 1.0 / 7.0;
        assert!((id2.rrf_score - expected).abs() < 0.0001);
        // Present in both text lists: scores are averaged, FTS rank wins.
        assert_eq!(id2.fts_score, Some(7.0));
        assert_eq!(id2.fts_rank, Some(1));
        assert_eq!(id2.vector_rank, Some(2));

        // Only in the exact list: its score and rank fill the FTS slots.
        let id3 = &fused[1];
        assert_eq!(id3.fts_score, Some(5.0));
        assert_eq!(id3.fts_rank, Some(1));
        assert!(id3.vector_rank.is_none());

        // Only in the vector list.
        let id1 = &fused[2];
        assert!(id1.fts_score.is_none());
        assert!(id1.fts_rank.is_none());
        assert_eq!(id1.vector_rank, Some(1));
    }
}