use std::collections::HashMap;
use uuid::Uuid;
/// Strategy used to merge the full-text and vector rankings into one list.
///
/// [`fuse_results`] dispatches on this value.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum FusionStrategy {
/// Reciprocal Rank Fusion, handled by [`reciprocal_rank_fusion`]. This is the default.
#[default]
Rrf,
/// Weighted reciprocal-rank scoring, handled by [`weighted_score_fusion`].
WeightedScore,
}
/// Tunable parameters for a hybrid (full-text + vector) search.
#[derive(Debug, Clone)]
pub struct SearchConfig {
/// Maximum number of fused results returned; the fusion functions truncate to this length.
pub limit: usize,
/// The `k` constant in the RRF contribution `1 / (k + rank)`.
pub rrf_k: u32,
/// Whether full-text search is enabled.
// NOTE(review): not read by the fusion code in this file; presumably consumed by the caller.
pub use_fts: bool,
/// Whether vector search is enabled.
// NOTE(review): not read by the fusion code in this file; presumably consumed by the caller.
pub use_vector: bool,
/// Minimum normalized score a result must reach to be kept. The filter is only
/// applied when this value is greater than zero.
pub min_score: f32,
/// NOTE(review): not read anywhere in this file; presumably the number of candidates
/// fetched from each backend before fusion. Confirm against the caller.
pub pre_fusion_limit: usize,
/// Which fusion algorithm [`fuse_results`] runs.
pub fusion_strategy: FusionStrategy,
/// Multiplier applied to full-text contributions by [`weighted_score_fusion`].
pub fts_weight: f32,
/// Multiplier applied to vector contributions by [`weighted_score_fusion`].
pub vector_weight: f32,
}
impl Default for SearchConfig {
fn default() -> Self {
Self {
limit: 10,
rrf_k: 60,
use_fts: true,
use_vector: true,
min_score: 0.0,
pre_fusion_limit: 50,
fusion_strategy: FusionStrategy::default(),
fts_weight: 0.5,
vector_weight: 0.5,
}
}
}
/// Chainable setters for [`SearchConfig`]. Each takes the config by value and
/// returns the updated config.
impl SearchConfig {
    /// Sets the maximum number of fused results to return.
    pub fn with_limit(mut self, limit: usize) -> Self {
        self.limit = limit;
        self
    }

    /// Sets the `k` constant used by Reciprocal Rank Fusion.
    pub fn with_rrf_k(mut self, k: u32) -> Self {
        self.rrf_k = k;
        self
    }

    /// Enables vector search and disables full-text search.
    pub fn vector_only(mut self) -> Self {
        self.use_fts = false;
        self.use_vector = true;
        self
    }

    /// Enables full-text search and disables vector search.
    pub fn fts_only(mut self) -> Self {
        self.use_fts = true;
        self.use_vector = false;
        self
    }

    /// Sets the minimum normalized score, clamped into `[0.0, 1.0]`.
    ///
    /// Infinite values clamp to the nearest bound. A NaN is ignored and the
    /// current threshold is kept: `f32::clamp` would otherwise pass the NaN
    /// straight through into the configuration.
    pub fn with_min_score(mut self, score: f32) -> Self {
        if !score.is_nan() {
            self.min_score = score.clamp(0.0, 1.0);
        }
        self
    }

    /// Selects the algorithm used to fuse the two rankings.
    pub fn with_fusion_strategy(mut self, strategy: FusionStrategy) -> Self {
        self.fusion_strategy = strategy;
        self
    }

    /// Sets the full-text weight. Negative, NaN, or infinite values are
    /// ignored and the current weight is kept.
    pub fn with_fts_weight(mut self, weight: f32) -> Self {
        if weight.is_finite() && weight >= 0.0 {
            self.fts_weight = weight;
        }
        self
    }

    /// Sets the vector weight. Negative, NaN, or infinite values are ignored
    /// and the current weight is kept.
    pub fn with_vector_weight(mut self, weight: f32) -> Self {
        if weight.is_finite() && weight >= 0.0 {
            self.vector_weight = weight;
        }
        self
    }
}
/// One fused search hit, as returned by the fusion functions in this file.
#[derive(Debug, Clone)]
pub struct SearchResult {
/// Identifier of the document the chunk belongs to.
pub document_id: Uuid,
/// Path of the document the chunk belongs to.
pub document_path: String,
/// Identifier of the matched chunk; fusion merges hits that share this id.
pub chunk_id: Uuid,
/// Text content of the matched chunk.
pub content: String,
/// Fused score. The fusion functions divide by the highest score, so the top
/// result is 1.0 whenever that highest score is positive.
pub score: f32,
/// Rank of this chunk in the full-text results, or `None` if it was not in them.
pub fts_rank: Option<u32>,
/// Rank of this chunk in the vector results, or `None` if it was not in them.
pub vector_rank: Option<u32>,
}
/// Provenance helpers: which retrieval method(s) produced this hit.
impl SearchResult {
    /// Returns `true` when the chunk carries a full-text rank.
    pub fn from_fts(&self) -> bool {
        self.fts_rank.is_some()
    }

    /// Returns `true` when the chunk carries a vector rank.
    pub fn from_vector(&self) -> bool {
        self.vector_rank.is_some()
    }

    /// Returns `true` when the chunk was found by both retrieval methods.
    pub fn is_hybrid(&self) -> bool {
        self.from_fts() && self.from_vector()
    }
}
/// A single hit from one retrieval method, before fusion.
#[derive(Debug, Clone)]
pub struct RankedResult {
/// Identifier of the matched chunk; used as the merge key during fusion.
pub chunk_id: Uuid,
/// Identifier of the document the chunk belongs to.
pub document_id: Uuid,
/// Path of the document the chunk belongs to.
pub document_path: String,
/// Text content of the matched chunk.
pub content: String,
/// Position of the hit in its source ranking. Fusion scores are reciprocals of
/// this value, so lower ranks score higher; the tests in this file use 1-based ranks.
pub rank: u32, }
pub fn fuse_results(
fts_results: Vec<RankedResult>,
vector_results: Vec<RankedResult>,
config: &SearchConfig,
) -> Vec<SearchResult> {
match config.fusion_strategy {
FusionStrategy::Rrf => reciprocal_rank_fusion(fts_results, vector_results, config),
FusionStrategy::WeightedScore => weighted_score_fusion(fts_results, vector_results, config),
}
}
pub fn reciprocal_rank_fusion(
fts_results: Vec<RankedResult>,
vector_results: Vec<RankedResult>,
config: &SearchConfig,
) -> Vec<SearchResult> {
let k = config.rrf_k as f32;
struct ChunkInfo {
document_id: Uuid,
document_path: String,
content: String,
score: f32,
fts_rank: Option<u32>,
vector_rank: Option<u32>,
}
let mut chunk_scores: HashMap<Uuid, ChunkInfo> = HashMap::new();
for result in fts_results {
let rrf_score = 1.0 / (k + result.rank as f32);
chunk_scores
.entry(result.chunk_id)
.and_modify(|info| {
info.score += rrf_score;
info.fts_rank = Some(result.rank);
})
.or_insert(ChunkInfo {
document_id: result.document_id,
document_path: result.document_path,
content: result.content,
score: rrf_score,
fts_rank: Some(result.rank),
vector_rank: None,
});
}
for result in vector_results {
let rrf_score = 1.0 / (k + result.rank as f32);
chunk_scores
.entry(result.chunk_id)
.and_modify(|info| {
info.score += rrf_score;
info.vector_rank = Some(result.rank);
})
.or_insert(ChunkInfo {
document_id: result.document_id,
document_path: result.document_path,
content: result.content,
score: rrf_score,
fts_rank: None,
vector_rank: Some(result.rank),
});
}
let mut results: Vec<SearchResult> = chunk_scores
.into_iter()
.map(|(chunk_id, info)| SearchResult {
document_id: info.document_id,
document_path: info.document_path,
chunk_id,
content: info.content,
score: info.score,
fts_rank: info.fts_rank,
vector_rank: info.vector_rank,
})
.collect();
if let Some(max_score) = results.iter().map(|r| r.score).reduce(f32::max)
&& max_score > 0.0
{
for result in &mut results {
result.score /= max_score;
}
}
if config.min_score > 0.0 {
results.retain(|r| r.score >= config.min_score);
}
results.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
results.truncate(config.limit);
results
}
pub fn weighted_score_fusion(
fts_results: Vec<RankedResult>,
vector_results: Vec<RankedResult>,
config: &SearchConfig,
) -> Vec<SearchResult> {
struct ChunkInfo {
document_id: Uuid,
document_path: String,
content: String,
score: f32,
fts_rank: Option<u32>,
vector_rank: Option<u32>,
}
let mut chunk_scores: HashMap<Uuid, ChunkInfo> = HashMap::new();
for result in fts_results {
let score = config.fts_weight * (1.0 / result.rank as f32);
chunk_scores
.entry(result.chunk_id)
.and_modify(|info| {
info.score += score;
info.fts_rank = Some(result.rank);
})
.or_insert(ChunkInfo {
document_id: result.document_id,
document_path: result.document_path,
content: result.content,
score,
fts_rank: Some(result.rank),
vector_rank: None,
});
}
for result in vector_results {
let score = config.vector_weight * (1.0 / result.rank as f32);
chunk_scores
.entry(result.chunk_id)
.and_modify(|info| {
info.score += score;
info.vector_rank = Some(result.rank);
})
.or_insert(ChunkInfo {
document_id: result.document_id,
document_path: result.document_path,
content: result.content,
score,
fts_rank: None,
vector_rank: Some(result.rank),
});
}
let mut results: Vec<SearchResult> = chunk_scores
.into_iter()
.map(|(chunk_id, info)| SearchResult {
document_id: info.document_id,
document_path: info.document_path,
chunk_id,
content: info.content,
score: info.score,
fts_rank: info.fts_rank,
vector_rank: info.vector_rank,
})
.collect();
if let Some(max_score) = results.iter().map(|r| r.score).reduce(f32::max)
&& max_score > 0.0
{
for result in &mut results {
result.score /= max_score;
}
}
if config.min_score > 0.0 {
results.retain(|r| r.score >= config.min_score);
}
results.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
results.truncate(config.limit);
results
}
// Unit tests for the configuration builders and both fusion functions.
#[cfg(test)]
mod tests {
use super::*;
// Builds a hit whose document path and content are derived from the given ids.
fn make_result(chunk_id: Uuid, doc_id: Uuid, rank: u32) -> RankedResult {
RankedResult {
chunk_id,
document_id: doc_id,
document_path: format!("docs/{}.md", doc_id),
content: format!("content for chunk {}", chunk_id),
rank,
}
}
// Like `make_result`, but with an explicit document path.
fn make_result_with_path(chunk_id: Uuid, doc_id: Uuid, path: &str, rank: u32) -> RankedResult {
RankedResult {
chunk_id,
document_id: doc_id,
document_path: path.to_string(),
content: format!("content for chunk {}", chunk_id),
rank,
}
}
// The fused results must carry the original file paths, not a UUID in their place.
#[test]
fn test_rrf_propagates_document_path() {
let config = SearchConfig::default().with_limit(10);
let doc_a = Uuid::new_v4();
let doc_b = Uuid::new_v4();
let chunk1 = Uuid::new_v4();
let chunk2 = Uuid::new_v4();
let chunk3 = Uuid::new_v4();
let fts_results = vec![
make_result_with_path(chunk1, doc_a, "notes/todo.md", 1),
make_result_with_path(chunk2, doc_b, "journal/2024-01-15.md", 2),
];
let vector_results = vec![
make_result_with_path(chunk1, doc_a, "notes/todo.md", 1),
make_result_with_path(chunk3, doc_b, "journal/2024-01-15.md", 2),
];
let results = reciprocal_rank_fusion(fts_results, vector_results, &config);
// A path that parses as a UUID would mean an id was stored where the path belongs.
for result in &results {
assert!(
Uuid::parse_str(&result.document_path).is_err(),
"document_path looks like a UUID ('{}'), expected a file path",
result.document_path
);
}
let paths: Vec<&str> = results.iter().map(|r| r.document_path.as_str()).collect();
assert!(
paths.contains(&"notes/todo.md"),
"missing notes/todo.md in {:?}",
paths
);
assert!(
paths.contains(&"journal/2024-01-15.md"),
"missing journal/2024-01-15.md in {:?}",
paths
);
// chunk1 is in both input lists, so it must come out as a hybrid hit.
let hybrid = results.iter().find(|r| r.chunk_id == chunk1).unwrap();
assert_eq!(hybrid.document_path, "notes/todo.md");
assert!(hybrid.is_hybrid());
}
// With only full-text input, results keep rank order and carry no vector rank.
#[test]
fn test_rrf_single_method() {
let config = SearchConfig::default().with_limit(10);
let chunk1 = Uuid::new_v4();
let chunk2 = Uuid::new_v4();
let doc = Uuid::new_v4();
let fts_results = vec![make_result(chunk1, doc, 1), make_result(chunk2, doc, 2)];
let results = reciprocal_rank_fusion(fts_results, Vec::new(), &config);
assert_eq!(results.len(), 2);
assert!(results[0].score > results[1].score);
assert!(results.iter().all(|r| r.fts_rank.is_some()));
assert!(results.iter().all(|r| r.vector_rank.is_none()));
}
// A chunk found by both methods outranks chunks found by only one.
#[test]
fn test_rrf_hybrid_match_boosted() {
let config = SearchConfig::default().with_limit(10);
let chunk1 = Uuid::new_v4(); let chunk2 = Uuid::new_v4(); let chunk3 = Uuid::new_v4(); let doc = Uuid::new_v4();
let fts_results = vec![make_result(chunk1, doc, 1), make_result(chunk2, doc, 2)];
let vector_results = vec![make_result(chunk1, doc, 1), make_result(chunk3, doc, 2)];
let results = reciprocal_rank_fusion(fts_results, vector_results, &config);
assert_eq!(results.len(), 3);
assert_eq!(results[0].chunk_id, chunk1);
assert!(results[0].is_hybrid());
assert!(results[0].score > results[1].score);
// chunk2 and chunk3 have equal scores; only their non-hybrid status is checked.
assert!(!results[1].is_hybrid());
assert!(!results[2].is_hybrid());
}
// A lone result is normalized to a score of 1.0.
#[test]
fn test_rrf_score_normalization() {
let config = SearchConfig::default();
let chunk1 = Uuid::new_v4();
let doc = Uuid::new_v4();
let fts_results = vec![make_result(chunk1, doc, 1)];
let results = reciprocal_rank_fusion(fts_results, Vec::new(), &config);
assert_eq!(results.len(), 1);
assert!((results[0].score - 1.0).abs() < 0.001);
}
// Every result that survives the filter meets the configured minimum score.
#[test]
fn test_rrf_min_score_filter() {
let config = SearchConfig::default().with_limit(10).with_min_score(0.5);
let chunk1 = Uuid::new_v4();
let chunk2 = Uuid::new_v4();
let chunk3 = Uuid::new_v4();
let doc = Uuid::new_v4();
let fts_results = vec![
make_result(chunk1, doc, 1),
make_result(chunk2, doc, 50),
make_result(chunk3, doc, 100),
];
let results = reciprocal_rank_fusion(fts_results, Vec::new(), &config);
for result in &results {
assert!(result.score >= 0.5);
}
}
// The output is truncated to `limit` entries.
#[test]
fn test_rrf_limit() {
let config = SearchConfig::default().with_limit(2);
let doc = Uuid::new_v4();
let fts_results: Vec<_> = (1..=5)
.map(|i| make_result(Uuid::new_v4(), doc, i))
.collect();
let results = reciprocal_rank_fusion(fts_results, Vec::new(), &config);
assert_eq!(results.len(), 2);
}
// A smaller k yields a larger normalized gap between ranks 1 and 2.
#[test]
fn test_rrf_k_parameter() {
let chunk1 = Uuid::new_v4();
let chunk2 = Uuid::new_v4();
let doc = Uuid::new_v4();
let fts_results = vec![make_result(chunk1, doc, 1), make_result(chunk2, doc, 2)];
let config_low_k = SearchConfig::default().with_rrf_k(10);
let results_low = reciprocal_rank_fusion(fts_results.clone(), Vec::new(), &config_low_k);
let config_high_k = SearchConfig::default().with_rrf_k(100);
let results_high = reciprocal_rank_fusion(fts_results, Vec::new(), &config_high_k);
let diff_low = results_low[0].score - results_low[1].score;
let diff_high = results_high[0].score - results_high[1].score;
assert!(diff_low > diff_high);
}
// Each builder method updates its own field and leaves the others alone.
#[test]
fn test_search_config_builders() {
let config = SearchConfig::default()
.with_limit(20)
.with_rrf_k(30)
.with_min_score(0.1);
assert_eq!(config.limit, 20);
assert_eq!(config.rrf_k, 30);
assert!((config.min_score - 0.1).abs() < 0.001);
assert!(config.use_fts);
assert!(config.use_vector);
let fts_only = SearchConfig::default().fts_only();
assert!(fts_only.use_fts);
assert!(!fts_only.use_vector);
let vector_only = SearchConfig::default().vector_only();
assert!(!vector_only.use_fts);
assert!(vector_only.use_vector);
let weighted = SearchConfig::default()
.with_fusion_strategy(FusionStrategy::WeightedScore)
.with_fts_weight(0.8)
.with_vector_weight(0.2);
assert_eq!(weighted.fusion_strategy, FusionStrategy::WeightedScore);
assert!((weighted.fts_weight - 0.8).abs() < 0.001);
assert!((weighted.vector_weight - 0.2).abs() < 0.001);
}
// With equal weights, the chunk found by both methods ranks first.
#[test]
fn test_weighted_fusion_basic() {
let config = SearchConfig::default()
.with_fusion_strategy(FusionStrategy::WeightedScore)
.with_fts_weight(1.0)
.with_vector_weight(1.0)
.with_limit(10);
let chunk1 = Uuid::new_v4(); let chunk2 = Uuid::new_v4(); let chunk3 = Uuid::new_v4(); let doc = Uuid::new_v4();
let fts = vec![make_result(chunk1, doc, 1), make_result(chunk2, doc, 2)];
let vec_results = vec![make_result(chunk1, doc, 1), make_result(chunk3, doc, 2)];
let results = weighted_score_fusion(fts, vec_results, &config);
assert_eq!(results.len(), 3);
assert_eq!(results[0].chunk_id, chunk1);
assert!(results[0].is_hybrid());
assert!(results[0].score > results[1].score);
}
// At the same rank, the source with the larger weight wins.
#[test]
fn test_weighted_fusion_fts_boost() {
let config = SearchConfig::default()
.with_fusion_strategy(FusionStrategy::WeightedScore)
.with_fts_weight(2.0)
.with_vector_weight(0.5)
.with_limit(10);
let chunk_fts = Uuid::new_v4(); let chunk_vec = Uuid::new_v4(); let doc = Uuid::new_v4();
let fts = vec![make_result(chunk_fts, doc, 2)];
let vec_results = vec![make_result(chunk_vec, doc, 2)];
let results = weighted_score_fusion(fts, vec_results, &config);
assert_eq!(results.len(), 2);
assert_eq!(results[0].chunk_id, chunk_fts);
assert!(results[0].from_fts());
assert!(!results[0].from_vector());
}
// With only full-text input, weighted fusion keeps rank order and normalizes the top to 1.0.
#[test]
fn test_weighted_fusion_single_source() {
let config = SearchConfig::default()
.with_fusion_strategy(FusionStrategy::WeightedScore)
.with_limit(10);
let chunk1 = Uuid::new_v4();
let chunk2 = Uuid::new_v4();
let doc = Uuid::new_v4();
let fts = vec![make_result(chunk1, doc, 1), make_result(chunk2, doc, 3)];
let results = weighted_score_fusion(fts, Vec::new(), &config);
assert_eq!(results.len(), 2);
assert_eq!(results[0].chunk_id, chunk1);
assert!(results[0].score > results[1].score);
assert!((results[0].score - 1.0).abs() < 0.001);
}
// NaN, infinite, and negative weights are ignored; finite non-negative ones are stored.
#[test]
fn test_weight_setters_reject_invalid() {
let config = SearchConfig::default();
let original_fts = config.fts_weight;
let original_vec = config.vector_weight;
let c = config.clone().with_fts_weight(f32::NAN);
assert!((c.fts_weight - original_fts).abs() < 0.001);
let c = config.clone().with_vector_weight(f32::INFINITY);
assert!((c.vector_weight - original_vec).abs() < 0.001);
let c = config.clone().with_fts_weight(-1.0);
assert!((c.fts_weight - original_fts).abs() < 0.001);
let c = config.clone().with_vector_weight(f32::NEG_INFINITY);
assert!((c.vector_weight - original_vec).abs() < 0.001);
let c = config.clone().with_fts_weight(2.0);
assert!((c.fts_weight - 2.0).abs() < 0.001);
let c = config.clone().with_vector_weight(0.0);
assert!(c.vector_weight.abs() < 0.001);
}
// `fuse_results` produces a normalized single result under either strategy.
#[test]
fn test_fuse_results_dispatches_correctly() {
let chunk1 = Uuid::new_v4();
let doc = Uuid::new_v4();
let fts = vec![make_result(chunk1, doc, 1)];
let rrf_config = SearchConfig::default().with_limit(10);
let rrf_results = fuse_results(fts.clone(), Vec::new(), &rrf_config);
assert_eq!(rrf_results.len(), 1);
let weighted_config = SearchConfig::default()
.with_fusion_strategy(FusionStrategy::WeightedScore)
.with_limit(10);
let weighted_results = fuse_results(fts, Vec::new(), &weighted_config);
assert_eq!(weighted_results.len(), 1);
assert!((rrf_results[0].score - 1.0).abs() < 0.001);
assert!((weighted_results[0].score - 1.0).abs() < 0.001);
}
// Two empty inputs produce an empty output.
#[test]
fn test_rrf_both_empty() {
let config = SearchConfig::default();
let results = reciprocal_rank_fusion(Vec::new(), Vec::new(), &config);
assert!(results.is_empty());
}
// Full-text-only input: every result is FTS-sourced and scores are non-increasing.
#[test]
fn test_rrf_fts_only_no_vector() {
let config = SearchConfig::default().with_limit(10);
let chunk1 = Uuid::new_v4();
let chunk2 = Uuid::new_v4();
let chunk3 = Uuid::new_v4();
let doc = Uuid::new_v4();
let fts_results = vec![
make_result(chunk1, doc, 1),
make_result(chunk2, doc, 2),
make_result(chunk3, doc, 3),
];
let results = reciprocal_rank_fusion(fts_results, Vec::new(), &config);
assert_eq!(results.len(), 3);
assert!(results.iter().all(|r| r.from_fts()));
assert!(results.iter().all(|r| !r.from_vector()));
assert!(results.iter().all(|r| !r.is_hybrid()));
for w in results.windows(2) {
assert!(w[0].score >= w[1].score);
}
}
// Vector-only input: every result is vector-sourced and scores are non-increasing.
#[test]
fn test_rrf_vector_only_no_fts() {
let config = SearchConfig::default().with_limit(10);
let chunk1 = Uuid::new_v4();
let chunk2 = Uuid::new_v4();
let chunk3 = Uuid::new_v4();
let doc = Uuid::new_v4();
let vector_results = vec![
make_result(chunk1, doc, 1),
make_result(chunk2, doc, 2),
make_result(chunk3, doc, 3),
];
let results = reciprocal_rank_fusion(Vec::new(), vector_results, &config);
assert_eq!(results.len(), 3);
assert!(results.iter().all(|r| r.from_vector()));
assert!(results.iter().all(|r| !r.from_fts()));
assert!(results.iter().all(|r| !r.is_hybrid()));
for w in results.windows(2) {
assert!(w[0].score >= w[1].score);
}
}
// A chunk present in both lists is merged into one result that keeps both ranks.
#[test]
fn test_rrf_duplicate_chunks_merged() {
let config = SearchConfig::default().with_limit(10);
let shared_chunk = Uuid::new_v4();
let fts_only_chunk = Uuid::new_v4();
let vector_only_chunk = Uuid::new_v4();
let doc = Uuid::new_v4();
let fts_results = vec![
make_result(fts_only_chunk, doc, 1),
make_result(shared_chunk, doc, 2),
];
let vector_results = vec![
make_result(vector_only_chunk, doc, 1),
make_result(shared_chunk, doc, 3),
];
let results = reciprocal_rank_fusion(fts_results, vector_results, &config);
assert_eq!(results.len(), 3);
let shared = results.iter().find(|r| r.chunk_id == shared_chunk).unwrap();
assert!(shared.is_hybrid());
assert_eq!(shared.fts_rank, Some(2));
assert_eq!(shared.vector_rank, Some(3));
// Two lower-ranked contributions (1/62 + 1/63) still beat a single rank-1 hit (1/61).
assert_eq!(results[0].chunk_id, shared_chunk);
}
// A limit of zero truncates the output to nothing.
#[test]
fn test_rrf_limit_zero_returns_empty() {
let config = SearchConfig::default().with_limit(0);
let doc = Uuid::new_v4();
let fts_results = vec![
make_result(Uuid::new_v4(), doc, 1),
make_result(Uuid::new_v4(), doc, 2),
];
let results = reciprocal_rank_fusion(fts_results, Vec::new(), &config);
assert!(results.is_empty());
}
// Despite the name, one result survives: normalization gives the top hit a score of
// 1.0 and the filter keeps scores >= min_score, so only the lower-ranked hits are dropped.
#[test]
fn test_rrf_min_score_one_filters_all() {
let config = SearchConfig::default().with_limit(10).with_min_score(1.0);
let doc = Uuid::new_v4();
let fts_results = vec![
make_result(Uuid::new_v4(), doc, 1),
make_result(Uuid::new_v4(), doc, 2),
make_result(Uuid::new_v4(), doc, 3),
];
let results = reciprocal_rank_fusion(fts_results, Vec::new(), &config);
assert_eq!(results.len(), 1);
assert!((results[0].score - 1.0).abs() < 0.001);
}
// `fts_only` flips only the backend flags; the other defaults are untouched.
#[test]
fn test_search_config_fts_only() {
let config = SearchConfig::default().fts_only();
assert!(config.use_fts);
assert!(!config.use_vector);
assert_eq!(config.limit, 10);
assert_eq!(config.rrf_k, 60);
assert!((config.min_score - 0.0).abs() < f32::EPSILON);
}
// `vector_only` flips only the backend flags; the other defaults are untouched.
#[test]
fn test_search_config_vector_only() {
let config = SearchConfig::default().vector_only();
assert!(!config.use_fts);
assert!(config.use_vector);
assert_eq!(config.limit, 10);
assert_eq!(config.rrf_k, 60);
assert!((config.min_score - 0.0).abs() < f32::EPSILON);
}
}