use crate::errors::Error;
use crate::sqlite::Memory;
use std::collections::HashMap;
#[derive(Debug, Clone, Copy)]
pub struct RrfConfig {
pub k: f64,
}
impl Default for RrfConfig {
fn default() -> Self {
Self { k: 25.0 }
}
}
pub fn rrf_fusion(
result_lists: Vec<Vec<Memory>>,
config: Option<RrfConfig>,
) -> Result<Vec<Memory>, Error> {
let config = config.unwrap_or_default();
if result_lists.is_empty() {
return Ok(vec![]);
}
let mut fused_results: HashMap<String, (Memory, f64)> = HashMap::new();
for result_list in result_lists {
for (rank, mut result) in result_list.into_iter().enumerate() {
let rank = rank + 1; let rrf_score = 1.0f64 / (config.k + rank as f64);
let id = result.id.clone();
match fused_results.get_mut(&id) {
Some((_, accumulated_score)) => {
*accumulated_score += rrf_score;
}
None => {
result.similarity = Some(rrf_score);
fused_results.insert(id, (result, rrf_score));
}
}
}
}
let mut fused_vec: Vec<(Memory, f64)> = fused_results.into_values().collect();
fused_vec.sort_by(|a, b| {
b.1.partial_cmp(&a.1)
.unwrap_or(std::cmp::Ordering::Equal)
.then_with(|| a.0.id.cmp(&b.0.id))
});
let final_results = fused_vec
.into_iter()
.map(|(mut result, score)| {
result.similarity = Some(score);
result
})
.collect();
Ok(final_results)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::embedding::EMBEDDING_DIMS;
fn create_test_memory(
id: &str,
content: &str,
project_id: &str,
similarity: Option<f64>,
) -> Memory {
Memory {
id: id.to_string(),
project_id: project_id.to_string(),
content: content.to_string(),
metadata: None,
embedding: vec![0.1f32; EMBEDDING_DIMS],
similarity,
created_at: "2024-01-01T00:00:00Z".to_string(),
updated_at: "2024-01-01T00:00:00Z".to_string(),
memory_type: "fact".to_string(),
status: "active".to_string(),
superseded_by: None,
}
}
#[test]
fn test_rrf_fusion_basic() {
let semantic_results = vec![
create_test_memory("mem-1", "rust programming", "proj-a", Some(0.9)),
create_test_memory("mem-2", "python code", "proj-a", Some(0.7)),
];
let bm25_results = vec![
create_test_memory("mem-2", "python code", "proj-a", Some(0.5)),
create_test_memory("mem-3", "database query", "proj-a", Some(0.3)),
];
let fused = rrf_fusion(vec![semantic_results, bm25_results], None).unwrap();
assert_eq!(fused.len(), 3);
assert_eq!(fused[0].id, "mem-2");
assert!(fused[0].similarity.unwrap() > fused[1].similarity.unwrap());
}
#[test]
fn test_rrf_fusion_empty_lists() {
let fused = rrf_fusion(vec![], None).unwrap();
assert!(fused.is_empty());
}
#[test]
fn test_rrf_fusion_single_list() {
let results = vec![
create_test_memory("mem-1", "content 1", "proj-a", Some(0.9)),
create_test_memory("mem-2", "content 2", "proj-a", Some(0.8)),
];
let fused = rrf_fusion(vec![results], None).unwrap();
assert_eq!(fused.len(), 2);
assert_eq!(fused[0].id, "mem-1");
}
#[test]
fn test_rrf_fusion_duplicate_documents() {
let memory = create_test_memory("mem-1", "shared content", "proj-a", Some(0.8));
let list1 = vec![memory.clone()];
let list2 = vec![memory.clone()];
let list3 = vec![memory];
let default_config = RrfConfig::default();
let fused = rrf_fusion(vec![list1, list2, list3], None).unwrap();
assert_eq!(fused.len(), 1);
let expected_score = 3.0 * (1.0 / (default_config.k + 1.0));
assert!((fused[0].similarity.unwrap() - expected_score).abs() < 0.001);
}
#[test]
fn test_rrf_fusion_different_k_values() {
let configs = [RrfConfig { k: 10.0 }, RrfConfig { k: 100.0 }];
let semantic_results = vec![create_test_memory("mem-1", "doc1", "proj-a", Some(0.9))];
let bm25_results = vec![create_test_memory("mem-1", "doc1", "proj-a", Some(0.9))];
for &config in &[configs[0], configs[1]] {
let fused = rrf_fusion(
vec![semantic_results.clone(), bm25_results.clone()],
Some(config),
)
.unwrap();
let expected_score = 2.0 * (1.0 / (config.k + 1.0));
assert!((fused[0].similarity.unwrap() - expected_score).abs() < 0.001);
}
}
#[test]
fn test_rrf_fusion_ranking_priority() {
let semantic_results = vec![
create_test_memory("mem-1", "doc_a", "proj-a", Some(0.9)), create_test_memory("mem-3", "doc_c", "proj-a", Some(0.8)), create_test_memory("mem-2", "doc_b", "proj-a", Some(0.7)), ];
let bm25_results = vec![
create_test_memory("mem-2", "doc_b", "proj-a", Some(0.9)), create_test_memory("mem-3", "doc_c", "proj-a", Some(0.8)), create_test_memory("mem-1", "doc_a", "proj-a", Some(0.7)), ];
let fused = rrf_fusion(vec![semantic_results, bm25_results], None).unwrap();
assert_eq!(fused.len(), 3);
assert_eq!(fused[0].similarity.unwrap(), fused[1].similarity.unwrap());
assert!(fused[2].similarity.unwrap() < fused[0].similarity.unwrap());
}
#[test]
fn test_rrf_fusion_preserves_metadata() {
let memory = Memory {
id: "mem-1".to_string(),
project_id: "proj-a".to_string(),
content: "test content".to_string(),
metadata: Some("metadata".to_string()),
embedding: vec![0.1f32; EMBEDDING_DIMS],
similarity: Some(0.9),
created_at: "2024-01-01T00:00:00Z".to_string(),
updated_at: "2024-01-01T00:00:00Z".to_string(),
memory_type: "fact".to_string(),
status: "active".to_string(),
superseded_by: None,
};
let fused = rrf_fusion(vec![vec![memory]], None).unwrap();
assert_eq!(fused.len(), 1);
assert_eq!(fused[0].metadata, Some("metadata".to_string()));
assert_eq!(fused[0].project_id, "proj-a");
}
#[test]
fn test_rrf_fusion_empty_lists_in_vector() {
let results = vec![create_test_memory("mem-1", "content", "proj-a", Some(0.9))];
let fused = rrf_fusion(vec![vec![], results.clone()], None).unwrap();
assert_eq!(fused.len(), 1);
assert_eq!(fused[0].id, "mem-1");
}
#[test]
fn test_rrf_fusion_many_results() {
let list1: Vec<Memory> = (1..=10)
.map(|i| {
create_test_memory(
&format!("mem-{}", i),
&format!("content {}", i),
"proj-a",
None,
)
})
.collect();
let list2: Vec<Memory> = (5..=15)
.map(|i| {
create_test_memory(
&format!("mem-{}", i),
&format!("content {}", i),
"proj-a",
None,
)
})
.collect();
let fused = rrf_fusion(vec![list1, list2], None).unwrap();
assert_eq!(fused.len(), 15);
verify_top_results_are_overlap(&fused, 5, 10);
}
fn verify_top_results_are_overlap(results: &[Memory], start: u32, end: u32) {
let mut overlap_count = 0;
for result in results.iter().take((end - start + 1) as usize) {
let id_num = result
.id
.strip_prefix("mem-")
.unwrap()
.parse::<u32>()
.unwrap();
if id_num >= start && id_num <= end {
overlap_count += 1;
}
}
assert!(
overlap_count >= 3,
"Expected at least 3 overlapping documents in top results, got {}",
overlap_count
);
}
#[test]
fn test_rrf_fusion_order_consistency() {
let list1 = vec![
create_test_memory("mem-1", "a", "proj-a", None),
create_test_memory("mem-2", "b", "proj-a", None),
];
let list2 = vec![
create_test_memory("mem-2", "b", "proj-a", None),
create_test_memory("mem-1", "a", "proj-a", None),
];
let fused1 = rrf_fusion(vec![list1.clone(), list2.clone()], None).unwrap();
let fused2 = rrf_fusion(vec![list1, list2], None).unwrap();
assert_eq!(fused1.len(), fused2.len());
for (r1, r2) in fused1.iter().zip(fused2.iter()) {
assert_eq!(r1.id, r2.id);
}
}
}