use std::collections::HashMap;
use crate::index::SearchHit;
use crate::search::ast_pattern::AstMatch;
use crate::search::text::TextMatch;
/// One chunk's entry in a fused ranking, as produced by the fusion functions in this module.
///
/// The optional per-source scores record which result lists the chunk appeared in:
/// a source that did not return the chunk leaves its score as `None`.
#[derive(Debug, Clone)]
pub struct FusedResult {
/// Identifier of the chunk this entry describes.
pub chunk_id: u64,
/// Score the final ranking is based on: the sum of reciprocal-rank contributions in
/// the RRF functions, or the source's own score in the single-source helpers.
pub fused_score: f32,
/// Score reported by the text search, if the chunk was a text match.
pub text_score: Option<f32>,
/// Score reported by the semantic search, if the chunk was a semantic hit.
pub semantic_score: Option<f32>,
/// Score reported by the AST pattern search, if the chunk was an AST match.
pub ast_score: Option<f32>,
/// Copied from the text match's `matched_lines`; empty when the chunk has no text match.
/// NOTE(review): presumably line indices within the chunk — confirm against `TextMatch`.
pub matched_lines: Vec<usize>,
}
/// Merges text and semantic result lists using Reciprocal Rank Fusion (RRF).
///
/// Every entry adds `1 / (k + rank + 1)` to its chunk's fused score, where `rank` is the
/// entry's zero-based position in its list. Only the position matters for ranking; the
/// per-source `score` values are carried through for reporting. A chunk id that occurs
/// more than once in a list accumulates one contribution per occurrence.
///
/// # Arguments
///
/// * `text_results` - text matches, taken to be in ranked order (best first).
/// * `semantic_results` - semantic hits, taken to be in ranked order (best first).
/// * `k` - RRF smoothing constant added to every rank.
/// * `top_k` - maximum number of results to return.
///
/// # Returns
///
/// At most `top_k` results ordered by descending fused score. Chunks with equal fused
/// scores are ordered by ascending `chunk_id`, so the output (and what survives
/// truncation) does not depend on hash-map iteration order.
pub fn rrf_fuse(
    text_results: &[TextMatch],
    semantic_results: &[SearchHit],
    k: u32,
    top_k: usize,
) -> Vec<FusedResult> {
    let mut scores: HashMap<u64, FusedResult> = HashMap::new();

    for (rank, result) in text_results.iter().enumerate() {
        let rrf_score = 1.0 / (k as f32 + rank as f32 + 1.0);
        scores
            .entry(result.chunk_id)
            .and_modify(|e| {
                e.fused_score += rrf_score;
                e.text_score = Some(result.score);
                e.matched_lines.clone_from(&result.matched_lines);
            })
            // Built lazily so the line list is only cloned for a new chunk.
            .or_insert_with(|| FusedResult {
                chunk_id: result.chunk_id,
                fused_score: rrf_score,
                text_score: Some(result.score),
                semantic_score: None,
                ast_score: None,
                matched_lines: result.matched_lines.clone(),
            });
    }

    for (rank, result) in semantic_results.iter().enumerate() {
        let rrf_score = 1.0 / (k as f32 + rank as f32 + 1.0);
        scores
            .entry(result.chunk_id)
            .and_modify(|e| {
                e.fused_score += rrf_score;
                e.semantic_score = Some(result.score);
            })
            .or_insert_with(|| FusedResult {
                chunk_id: result.chunk_id,
                fused_score: rrf_score,
                text_score: None,
                semantic_score: Some(result.score),
                ast_score: None,
                matched_lines: Vec::new(),
            });
    }

    let mut fused: Vec<FusedResult> = scores.into_values().collect();
    // Chunk ids are unique map keys, so the tie-break makes this a total order and an
    // unstable sort still yields a deterministic result.
    fused.sort_unstable_by(|a, b| {
        b.fused_score
            .partial_cmp(&a.fused_score)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| a.chunk_id.cmp(&b.chunk_id))
    });
    fused.truncate(top_k);
    fused
}
/// Merges text, semantic and AST result lists using Reciprocal Rank Fusion (RRF).
///
/// Every entry adds `1 / (k + rank + 1)` to its chunk's fused score, where `rank` is the
/// entry's zero-based position in its list. Only the position matters for ranking; the
/// per-source `score` values are carried through for reporting. A chunk id that occurs
/// more than once in a list accumulates one contribution per occurrence.
///
/// # Arguments
///
/// * `text_results` - text matches, taken to be in ranked order (best first).
/// * `semantic_results` - semantic hits, taken to be in ranked order (best first).
/// * `ast_results` - AST pattern matches, taken to be in ranked order (best first).
/// * `k` - RRF smoothing constant added to every rank.
/// * `top_k` - maximum number of results to return.
///
/// # Returns
///
/// At most `top_k` results ordered by descending fused score. Chunks with equal fused
/// scores are ordered by ascending `chunk_id`, so the output (and what survives
/// truncation) does not depend on hash-map iteration order.
pub fn rrf_fuse_three(
    text_results: &[TextMatch],
    semantic_results: &[SearchHit],
    ast_results: &[AstMatch],
    k: u32,
    top_k: usize,
) -> Vec<FusedResult> {
    // Reciprocal-rank contribution of the entry at zero-based position `rank`.
    let weight = |rank: usize| 1.0 / (k as f32 + rank as f32 + 1.0);

    let mut scores: HashMap<u64, FusedResult> = HashMap::new();

    for (rank, result) in text_results.iter().enumerate() {
        let rrf_score = weight(rank);
        scores
            .entry(result.chunk_id)
            .and_modify(|e| {
                e.fused_score += rrf_score;
                e.text_score = Some(result.score);
                e.matched_lines.clone_from(&result.matched_lines);
            })
            // Built lazily so the line list is only cloned for a new chunk.
            .or_insert_with(|| FusedResult {
                chunk_id: result.chunk_id,
                fused_score: rrf_score,
                text_score: Some(result.score),
                semantic_score: None,
                ast_score: None,
                matched_lines: result.matched_lines.clone(),
            });
    }

    for (rank, result) in semantic_results.iter().enumerate() {
        let rrf_score = weight(rank);
        scores
            .entry(result.chunk_id)
            .and_modify(|e| {
                e.fused_score += rrf_score;
                e.semantic_score = Some(result.score);
            })
            .or_insert_with(|| FusedResult {
                chunk_id: result.chunk_id,
                fused_score: rrf_score,
                text_score: None,
                semantic_score: Some(result.score),
                ast_score: None,
                matched_lines: Vec::new(),
            });
    }

    for (rank, result) in ast_results.iter().enumerate() {
        let rrf_score = weight(rank);
        scores
            .entry(result.chunk_id)
            .and_modify(|e| {
                e.fused_score += rrf_score;
                e.ast_score = Some(result.score);
            })
            .or_insert_with(|| FusedResult {
                chunk_id: result.chunk_id,
                fused_score: rrf_score,
                text_score: None,
                semantic_score: None,
                ast_score: Some(result.score),
                matched_lines: Vec::new(),
            });
    }

    let mut fused: Vec<FusedResult> = scores.into_values().collect();
    // Chunk ids are unique map keys, so the tie-break makes this a total order and an
    // unstable sort still yields a deterministic result.
    fused.sort_unstable_by(|a, b| {
        b.fused_score
            .partial_cmp(&a.fused_score)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| a.chunk_id.cmp(&b.chunk_id))
    });
    fused.truncate(top_k);
    fused
}
/// Wraps semantic hits as [`FusedResult`]s when no other source is available.
///
/// Keeps the first `top_k` hits in their input order (no re-sorting happens here) and
/// uses each hit's own score as both `fused_score` and `semantic_score`. The other
/// per-source scores are `None` and `matched_lines` is empty.
pub fn fuse_semantic_only(semantic_results: &[SearchHit], top_k: usize) -> Vec<FusedResult> {
    let keep = top_k.min(semantic_results.len());
    let mut fused = Vec::with_capacity(keep);
    for hit in &semantic_results[..keep] {
        fused.push(FusedResult {
            chunk_id: hit.chunk_id,
            fused_score: hit.score,
            text_score: None,
            semantic_score: Some(hit.score),
            ast_score: None,
            matched_lines: Vec::new(),
        });
    }
    fused
}
/// Wraps text matches as [`FusedResult`]s when no other source is available.
///
/// Keeps the first `top_k` matches in their input order (no re-sorting happens here),
/// uses each match's own score as both `fused_score` and `text_score`, and copies its
/// `matched_lines`. The other per-source scores are `None`.
pub fn fuse_text_only(text_results: &[TextMatch], top_k: usize) -> Vec<FusedResult> {
    let to_fused = |text_match: &TextMatch| {
        let score = text_match.score;
        FusedResult {
            chunk_id: text_match.chunk_id,
            fused_score: score,
            text_score: Some(score),
            semantic_score: None,
            ast_score: None,
            matched_lines: text_match.matched_lines.clone(),
        }
    };

    let mut fused = Vec::new();
    for text_match in text_results.iter().take(top_k) {
        fused.push(to_fused(text_match));
    }
    fused
}
/// Wraps AST pattern matches as [`FusedResult`]s when no other source is available.
///
/// Keeps the first `top_k` matches in their input order (no re-sorting happens here) and
/// uses each match's own score as both `fused_score` and `ast_score`. The other
/// per-source scores are `None` and `matched_lines` is empty.
pub fn fuse_ast_only(ast_results: &[AstMatch], top_k: usize) -> Vec<FusedResult> {
    let limit = ast_results.len().min(top_k);
    let (selected, _rest) = ast_results.split_at(limit);

    let mut fused = Vec::with_capacity(limit);
    fused.extend(selected.iter().map(|ast_match| FusedResult {
        chunk_id: ast_match.chunk_id,
        fused_score: ast_match.score,
        text_score: None,
        semantic_score: None,
        ast_score: Some(ast_match.score),
        matched_lines: Vec::new(),
    }));
    fused
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds text matches in the given order; scores descend from `len` down to 1 and
    /// every match reports line 0.
    fn make_text_matches(chunk_ids: &[u64]) -> Vec<TextMatch> {
        let total = chunk_ids.len();
        let mut matches = Vec::with_capacity(total);
        for (position, &chunk_id) in chunk_ids.iter().enumerate() {
            matches.push(TextMatch {
                chunk_id,
                matched_lines: vec![0],
                score: (total - position) as f32,
            });
        }
        matches
    }

    /// Builds semantic hits in the given order; scores start at 1.0 and drop by 0.1 per position.
    fn make_semantic_hits(chunk_ids: &[u64]) -> Vec<SearchHit> {
        let mut hits = Vec::with_capacity(chunk_ids.len());
        for (position, &chunk_id) in chunk_ids.iter().enumerate() {
            hits.push(SearchHit {
                chunk_id,
                score: 1.0 - (position as f32 * 0.1),
            });
        }
        hits
    }

    /// Builds AST matches in the given order; scores start at 1.0 and drop by 0.1 per position.
    fn make_ast_matches(chunk_ids: &[u64]) -> Vec<AstMatch> {
        let mut matches = Vec::with_capacity(chunk_ids.len());
        for (position, &chunk_id) in chunk_ids.iter().enumerate() {
            matches.push(AstMatch {
                chunk_id,
                score: 1.0 - (position as f32 * 0.1),
            });
        }
        matches
    }

    /// Fused score of `chunk_id`; panics if the chunk is missing from `results`.
    fn fused_score_of(results: &[FusedResult], chunk_id: u64) -> f32 {
        results
            .iter()
            .find(|entry| entry.chunk_id == chunk_id)
            .unwrap()
            .fused_score
    }

    #[test]
    fn test_rrf_basic_fusion() {
        let fused = rrf_fuse(
            &make_text_matches(&[1, 2, 3]),
            &make_semantic_hits(&[2, 3, 4]),
            60,
            10,
        );
        assert!(!fused.is_empty());

        // Chunk 2 is in both lists, chunk 1 only in the text list.
        let in_both = fused_score_of(&fused, 2);
        let text_only = fused_score_of(&fused, 1);
        assert!(
            in_both > text_only,
            "Chunk appearing in both lists should rank higher"
        );
    }

    #[test]
    fn test_rrf_preserves_all_unique_results() {
        let fused = rrf_fuse(
            &make_text_matches(&[1, 2]),
            &make_semantic_hits(&[3, 4]),
            60,
            10,
        );
        assert_eq!(fused.len(), 4, "All unique chunks should be in results");
    }

    #[test]
    fn test_rrf_top_k_truncation() {
        let fused = rrf_fuse(
            &make_text_matches(&[1, 2, 3, 4, 5]),
            &make_semantic_hits(&[6, 7, 8, 9, 10]),
            60,
            3,
        );
        assert_eq!(fused.len(), 3, "Should respect top-k");
    }

    #[test]
    fn test_rrf_empty_inputs() {
        assert!(rrf_fuse(&[], &[], 60, 10).is_empty());
    }

    #[test]
    fn test_fuse_semantic_only() {
        let fused = fuse_semantic_only(&make_semantic_hits(&[1, 2, 3]), 2);
        assert_eq!(fused.len(), 2);

        let first = &fused[0];
        assert!(first.text_score.is_none());
        assert!(first.semantic_score.is_some());
    }

    #[test]
    fn test_fuse_text_only() {
        let fused = fuse_text_only(&make_text_matches(&[1, 2, 3]), 2);
        assert_eq!(fused.len(), 2);

        let first = &fused[0];
        assert!(first.text_score.is_some());
        assert!(first.semantic_score.is_none());
        assert!(first.ast_score.is_none());
    }

    #[test]
    fn test_fuse_ast_only() {
        let fused = fuse_ast_only(&make_ast_matches(&[1, 2, 3]), 2);
        assert_eq!(fused.len(), 2);

        let first = &fused[0];
        assert!(first.text_score.is_none());
        assert!(first.semantic_score.is_none());
        assert!(first.ast_score.is_some());
    }

    #[test]
    fn test_rrf_three_way_fusion() {
        let fused = rrf_fuse_three(
            &make_text_matches(&[1, 2]),
            &make_semantic_hits(&[2, 3]),
            &make_ast_matches(&[3, 4]),
            60,
            10,
        );
        assert_eq!(fused.len(), 4);

        // Chunks 2 and 3 each appear in two lists; chunks 1 and 4 appear in one.
        let score_2 = fused_score_of(&fused, 2);
        let score_3 = fused_score_of(&fused, 3);
        let score_1 = fused_score_of(&fused, 1);
        let score_4 = fused_score_of(&fused, 4);
        assert!(score_2 > score_1);
        assert!(score_3 > score_4);
    }
}