use std::collections::HashMap;
use uuid::Uuid;
/// Tunable parameters for a hybrid (full-text + vector) search and the rank
/// fusion step that merges the two result lists.
///
/// Start from [`SearchConfig::default`] and adjust with the `with_*` /
/// `*_only` builder methods, each of which consumes and returns the config.
#[derive(Debug, Clone)]
pub struct SearchConfig {
    /// Maximum number of results kept after fusion.
    pub limit: usize,
    /// The `k` constant of reciprocal rank fusion: every hit contributes
    /// `1 / (rrf_k + rank)` to its chunk's score.
    pub rrf_k: u32,
    /// Whether the full-text search method is enabled.
    ///
    /// NOTE(review): not consulted by the fusion code in this file; presumably
    /// the caller uses it to decide whether to run the query at all — confirm.
    pub use_fts: bool,
    /// Whether the vector search method is enabled.
    ///
    /// NOTE(review): not consulted by the fusion code in this file; presumably
    /// the caller uses it to decide whether to run the query at all — confirm.
    pub use_vector: bool,
    /// Minimum normalised score a fused result must reach to be kept.
    /// The filter is only applied when this is greater than `0.0`.
    pub min_score: f32,
    /// NOTE(review): not read anywhere in this file; presumably the number of
    /// candidates fetched from each search method before fusion — confirm.
    pub pre_fusion_limit: usize,
}

impl Default for SearchConfig {
    /// Both search methods enabled, 10 results, `rrf_k = 60`, no score
    /// threshold, and a pre-fusion limit of 50.
    fn default() -> Self {
        Self {
            limit: 10,
            rrf_k: 60,
            use_fts: true,
            use_vector: true,
            min_score: 0.0,
            pre_fusion_limit: 50,
        }
    }
}

impl SearchConfig {
    /// Sets the maximum number of results kept after fusion.
    #[must_use]
    pub fn with_limit(mut self, limit: usize) -> Self {
        self.limit = limit;
        self
    }

    /// Sets the reciprocal rank fusion constant `k`.
    #[must_use]
    pub fn with_rrf_k(mut self, k: u32) -> Self {
        self.rrf_k = k;
        self
    }

    /// Enables vector search and disables full-text search.
    #[must_use]
    pub fn vector_only(mut self) -> Self {
        self.use_fts = false;
        self.use_vector = true;
        self
    }

    /// Enables full-text search and disables vector search.
    #[must_use]
    pub fn fts_only(mut self) -> Self {
        self.use_fts = true;
        self.use_vector = false;
        self
    }

    /// Sets the minimum score threshold, clamped to `0.0..=1.0`.
    ///
    /// A NaN input is stored as `0.0` (no threshold): `f32::clamp` passes NaN
    /// through unchanged, which would break the documented range.
    #[must_use]
    pub fn with_min_score(mut self, score: f32) -> Self {
        self.min_score = if score.is_nan() {
            0.0
        } else {
            score.clamp(0.0, 1.0)
        };
        self
    }
}
/// A fused search hit: one chunk together with its combined score and the
/// rank it held in each contributing result list.
#[derive(Debug, Clone)]
pub struct SearchResult {
/// Document identifier copied from the input entry that first introduced
/// this chunk during fusion.
pub document_id: Uuid,
/// Chunk identifier; `reciprocal_rank_fusion` uses it as the key when
/// merging the result lists.
pub chunk_id: Uuid,
/// Chunk text copied from the input entry that first introduced this chunk
/// during fusion.
pub content: String,
/// Combined score. `reciprocal_rank_fusion` divides every score by the
/// largest one, so the best hit it returns scores `1.0`.
pub score: f32,
/// Rank of this chunk in the full-text result list, or `None` if the chunk
/// did not appear in that list.
pub fts_rank: Option<u32>,
/// Rank of this chunk in the vector result list, or `None` if the chunk
/// did not appear in that list.
pub vector_rank: Option<u32>,
}
impl SearchResult {
    /// Reports whether the full-text result list contributed this chunk,
    /// i.e. whether `fts_rank` is set.
    pub fn from_fts(&self) -> bool {
        self.fts_rank.is_some()
    }

    /// Reports whether the vector result list contributed this chunk,
    /// i.e. whether `vector_rank` is set.
    pub fn from_vector(&self) -> bool {
        self.vector_rank.is_some()
    }

    /// Reports whether both result lists contributed this chunk.
    pub fn is_hybrid(&self) -> bool {
        self.from_fts() && self.from_vector()
    }
}
/// One entry of a single ranked result list, as fed into
/// `reciprocal_rank_fusion`.
#[derive(Debug, Clone)]
pub struct RankedResult {
/// Chunk identifier; entries sharing a `chunk_id` are merged during fusion.
pub chunk_id: Uuid,
/// Document identifier carried through to the fused result.
pub document_id: Uuid,
/// Chunk text carried through to the fused result.
pub content: String,
/// Position of this entry in its list; fusion scores it as
/// `1 / (rrf_k + rank)`, so a smaller rank contributes more.
/// NOTE(review): the tests in this file use 1 for the best hit, so ranks
/// are presumably 1-based — confirm with the code that produces them.
pub rank: u32, }
pub fn reciprocal_rank_fusion(
fts_results: Vec<RankedResult>,
vector_results: Vec<RankedResult>,
config: &SearchConfig,
) -> Vec<SearchResult> {
let k = config.rrf_k as f32;
struct ChunkInfo {
document_id: Uuid,
content: String,
score: f32,
fts_rank: Option<u32>,
vector_rank: Option<u32>,
}
let mut chunk_scores: HashMap<Uuid, ChunkInfo> = HashMap::new();
for result in fts_results {
let rrf_score = 1.0 / (k + result.rank as f32);
chunk_scores
.entry(result.chunk_id)
.and_modify(|info| {
info.score += rrf_score;
info.fts_rank = Some(result.rank);
})
.or_insert(ChunkInfo {
document_id: result.document_id,
content: result.content,
score: rrf_score,
fts_rank: Some(result.rank),
vector_rank: None,
});
}
for result in vector_results {
let rrf_score = 1.0 / (k + result.rank as f32);
chunk_scores
.entry(result.chunk_id)
.and_modify(|info| {
info.score += rrf_score;
info.vector_rank = Some(result.rank);
})
.or_insert(ChunkInfo {
document_id: result.document_id,
content: result.content,
score: rrf_score,
fts_rank: None,
vector_rank: Some(result.rank),
});
}
let mut results: Vec<SearchResult> = chunk_scores
.into_iter()
.map(|(chunk_id, info)| SearchResult {
document_id: info.document_id,
chunk_id,
content: info.content,
score: info.score,
fts_rank: info.fts_rank,
vector_rank: info.vector_rank,
})
.collect();
if let Some(max_score) = results.iter().map(|r| r.score).reduce(f32::max)
&& max_score > 0.0
{
for result in &mut results {
result.score /= max_score;
}
}
if config.min_score > 0.0 {
results.retain(|r| r.score >= config.min_score);
}
results.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
results.truncate(config.limit);
results
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a ranked hit whose content is derived from its chunk id.
    fn make_result(chunk_id: Uuid, doc_id: Uuid, rank: u32) -> RankedResult {
        RankedResult {
            chunk_id,
            document_id: doc_id,
            content: format!("content for chunk {}", chunk_id),
            rank,
        }
    }

    /// A single input list yields hits that carry only that list's rank.
    #[test]
    fn test_rrf_single_method() {
        let config = SearchConfig::default().with_limit(10);
        let chunk1 = Uuid::new_v4();
        let chunk2 = Uuid::new_v4();
        let doc = Uuid::new_v4();
        let fts_results = vec![make_result(chunk1, doc, 1), make_result(chunk2, doc, 2)];

        let results = reciprocal_rank_fusion(fts_results, Vec::new(), &config);

        assert_eq!(results.len(), 2);
        assert!(results[0].score > results[1].score);
        assert!(results.iter().all(|r| r.fts_rank.is_some()));
        assert!(results.iter().all(|r| r.vector_rank.is_none()));
    }

    /// A chunk found by both methods outranks chunks found by only one.
    #[test]
    fn test_rrf_hybrid_match_boosted() {
        let config = SearchConfig::default().with_limit(10);
        let chunk1 = Uuid::new_v4();
        let chunk2 = Uuid::new_v4();
        let chunk3 = Uuid::new_v4();
        let doc = Uuid::new_v4();
        let fts_results = vec![make_result(chunk1, doc, 1), make_result(chunk2, doc, 2)];
        let vector_results = vec![make_result(chunk1, doc, 1), make_result(chunk3, doc, 2)];

        let results = reciprocal_rank_fusion(fts_results, vector_results, &config);

        assert_eq!(results.len(), 3);
        assert_eq!(results[0].chunk_id, chunk1);
        assert!(results[0].is_hybrid());
        assert!(results[0].score > results[1].score);
        assert!(!results[1].is_hybrid());
        assert!(!results[2].is_hybrid());
    }

    /// The best hit is normalised to a score of 1.0.
    #[test]
    fn test_rrf_score_normalization() {
        let config = SearchConfig::default();
        let chunk1 = Uuid::new_v4();
        let doc = Uuid::new_v4();
        let fts_results = vec![make_result(chunk1, doc, 1)];

        let results = reciprocal_rank_fusion(fts_results, Vec::new(), &config);

        assert_eq!(results.len(), 1);
        assert!((results[0].score - 1.0).abs() < 0.001);
    }

    /// Hits whose normalised score is below the threshold are dropped.
    #[test]
    fn test_rrf_min_score_filter() {
        let config = SearchConfig::default().with_limit(10).with_min_score(0.5);
        let chunk1 = Uuid::new_v4();
        let chunk2 = Uuid::new_v4();
        let chunk3 = Uuid::new_v4();
        let doc = Uuid::new_v4();
        let fts_results = vec![
            make_result(chunk1, doc, 1),
            make_result(chunk2, doc, 50),
            make_result(chunk3, doc, 100),
        ];

        let results = reciprocal_rank_fusion(fts_results, Vec::new(), &config);

        // With k = 60 the normalised scores are 1.0, 61/110 (~0.55) and
        // 61/160 (~0.38), so exactly the first two hits survive. Pinning the
        // survivors keeps the test from passing vacuously on an empty result.
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].chunk_id, chunk1);
        assert_eq!(results[1].chunk_id, chunk2);
        assert!(results.iter().all(|r| r.chunk_id != chunk3));
        for result in &results {
            assert!(result.score >= 0.5);
        }
    }

    /// Only the `limit` best hits are returned.
    #[test]
    fn test_rrf_limit() {
        let config = SearchConfig::default().with_limit(2);
        let doc = Uuid::new_v4();
        let fts_results: Vec<_> = (1..=5)
            .map(|i| make_result(Uuid::new_v4(), doc, i))
            .collect();

        let results = reciprocal_rank_fusion(fts_results, Vec::new(), &config);

        assert_eq!(results.len(), 2);
        // Truncation happens after sorting, so the two best ranks remain.
        assert_eq!(results[0].fts_rank, Some(1));
        assert_eq!(results[1].fts_rank, Some(2));
    }

    /// A smaller `k` spreads neighbouring ranks further apart.
    #[test]
    fn test_rrf_k_parameter() {
        let chunk1 = Uuid::new_v4();
        let chunk2 = Uuid::new_v4();
        let doc = Uuid::new_v4();
        let fts_results = vec![make_result(chunk1, doc, 1), make_result(chunk2, doc, 2)];

        let config_low_k = SearchConfig::default().with_rrf_k(10);
        let results_low = reciprocal_rank_fusion(fts_results.clone(), Vec::new(), &config_low_k);
        let config_high_k = SearchConfig::default().with_rrf_k(100);
        let results_high = reciprocal_rank_fusion(fts_results, Vec::new(), &config_high_k);

        let diff_low = results_low[0].score - results_low[1].score;
        let diff_high = results_high[0].score - results_high[1].score;
        assert!(diff_low > diff_high);
    }

    /// Builder methods set exactly the fields they are named after.
    #[test]
    fn test_search_config_builders() {
        let config = SearchConfig::default()
            .with_limit(20)
            .with_rrf_k(30)
            .with_min_score(0.1);
        assert_eq!(config.limit, 20);
        assert_eq!(config.rrf_k, 30);
        assert!((config.min_score - 0.1).abs() < 0.001);
        assert!(config.use_fts);
        assert!(config.use_vector);

        let fts_only = SearchConfig::default().fts_only();
        assert!(fts_only.use_fts);
        assert!(!fts_only.use_vector);

        let vector_only = SearchConfig::default().vector_only();
        assert!(!vector_only.use_fts);
        assert!(vector_only.use_vector);
    }
}