use std::collections::HashMap;
/// Rewrites free-form text as a space-separated sequence of double-quoted
/// tokens suitable for use as an FTS5 query string.
///
/// The input is split on whitespace. Tokens that contain no alphanumeric
/// character at all (e.g. `->`, `||`, a lone emoji) are dropped. Every
/// remaining token is wrapped in double quotes, with each embedded `"`
/// doubled, so words such as `AND`, `NEAR` or `foo*` end up inside a quoted
/// string instead of appearing bare in the query.
///
/// Returns `None` when no token survives the filter (empty, whitespace-only
/// or punctuation-only input); otherwise `Some` with the quoted tokens joined
/// by single spaces.
pub(super) fn escape_fts5_query(query: &str) -> Option<String> {
    let mut quoted = String::new();
    for word in query.split_whitespace() {
        // Punctuation-only words carry nothing searchable; skip them.
        if !word.chars().any(char::is_alphanumeric) {
            continue;
        }
        if !quoted.is_empty() {
            quoted.push(' ');
        }
        quoted.push('"');
        // A literal quote inside a quoted string is written as two quotes.
        quoted.push_str(&word.replace('"', "\"\""));
        quoted.push('"');
    }
    // Every kept word contributes at least its two surrounding quotes, so an
    // empty buffer means nothing was kept.
    if quoted.is_empty() {
        None
    } else {
        Some(quoted)
    }
}
/// One entry of a ranked result list, as consumed by `rrf_merge`.
#[derive(Debug)]
pub(super) struct RankedItem {
/// Identifier of the entry. `rrf_merge` accumulates one score per distinct
/// id, so the same id appearing in several lists is treated as one result.
pub id: String,
/// Rank of the entry within its list. `rrf_merge` adds it to `k` to form the
/// denominator of the entry's score contribution, so smaller ranks score
/// higher. NOTE(review): the unit tests in this file use ranks starting at 1;
/// presumably callers do the same — confirm.
pub rank: u32,
}
/// Merges several ranked lists into one using reciprocal-rank fusion.
///
/// Every `RankedItem` contributes `1 / (k + rank)` to the score of its `id`;
/// contributions for the same id are summed across all lists (and across
/// repeated occurrences inside one list).
///
/// # Parameters
/// * `lists` - the ranked lists to fuse; may be empty or contain empty lists.
/// * `k` - constant added to each rank in the denominator. A larger `k`
///   narrows the score gap between neighbouring ranks.
///
/// # Returns
/// `(id, score)` pairs sorted by descending score. Ids with equal scores are
/// ordered by ascending id, so the output is deterministic and does not
/// depend on hash-map iteration order.
///
/// No validation is performed: if `k + rank` is zero the contribution is
/// `+inf`, which simply sorts first.
pub(super) fn rrf_merge(lists: &[&[RankedItem]], k: u32) -> Vec<(String, f64)> {
    let k = f64::from(k);

    // Keys borrow the ids from the input, so each id is copied only once,
    // when the final result vector is built.
    let mut scores: HashMap<&str, f64> = HashMap::new();
    for item in lists.iter().flat_map(|list| list.iter()) {
        *scores.entry(item.id.as_str()).or_default() += 1.0 / (k + f64::from(item.rank));
    }

    let mut results: Vec<(String, f64)> = scores
        .into_iter()
        .map(|(id, score)| (id.to_owned(), score))
        .collect();

    // Highest score first. The id tie-break makes the order total (ids are
    // unique map keys), which is also why an unstable sort is safe here.
    results.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
    results
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used when comparing floating-point scores.
    const EPS: f64 = 1e-10;

    /// Builds a `RankedItem` without spelling out the struct literal each time.
    fn item(id: &str, rank: u32) -> RankedItem {
        RankedItem {
            id: id.into(),
            rank,
        }
    }

    /// Wraps the expected escaped query in the `Some(String)` the escaper returns.
    fn escaped(expected: &str) -> Option<String> {
        Some(expected.to_string())
    }

    /// True when two scores agree within `EPS`.
    fn close(left: f64, right: f64) -> bool {
        (left - right).abs() < EPS
    }

    // ---- escape_fts5_query ----

    #[test]
    fn escape_simple_query() {
        let got = escape_fts5_query("hello world");
        assert_eq!(got, escaped(r#""hello" "world""#));
    }

    #[test]
    fn escape_operators() {
        // Operator keywords must come out quoted like any other word.
        let got = escape_fts5_query("foo AND bar");
        assert_eq!(got, escaped(r#""foo" "AND" "bar""#));
    }

    #[test]
    fn escape_quotes_in_tokens() {
        // Embedded double quotes are doubled inside the quoted token.
        let got = escape_fts5_query(r#"say "hello""#);
        assert_eq!(got, escaped(r#""say" """hello""""#));
    }

    #[test]
    fn escape_special_chars() {
        let got = escape_fts5_query("foo* NEAR bar");
        assert_eq!(got, escaped(r#""foo*" "NEAR" "bar""#));
    }

    #[test]
    fn escape_empty_query() {
        assert_eq!(escape_fts5_query(""), None);
        assert_eq!(escape_fts5_query(" "), None);
    }

    #[test]
    fn escape_drops_punctuation_only_tokens() {
        assert_eq!(
            escape_fts5_query("error -> fix"),
            escaped(r#""error" "fix""#)
        );
        assert_eq!(escape_fts5_query("|| operator"), escaped(r#""operator""#));
        assert_eq!(escape_fts5_query("-> => ||"), None);
    }

    #[test]
    fn escape_adversarial_inputs() {
        // A token made only of quotes has nothing alphanumeric and is dropped.
        assert_eq!(escape_fts5_query(r#""""#), None);
        // Non-ASCII letters count as alphanumeric.
        assert_eq!(
            escape_fts5_query("rust 错误处理"),
            escaped(r#""rust" "错误处理""#)
        );
        // An emoji-only token is dropped.
        assert_eq!(escape_fts5_query("🦀 rust"), escaped(r#""rust""#));
        // Tabs and newlines separate tokens just like spaces.
        assert_eq!(
            escape_fts5_query("hello\t\nworld"),
            escaped(r#""hello" "world""#)
        );
    }

    #[test]
    fn escape_single_token() {
        assert_eq!(escape_fts5_query("rust"), escaped(r#""rust""#));
    }

    // ---- rrf_merge ----

    #[test]
    fn rrf_single_list() {
        let list = vec![item("a", 1), item("b", 2), item("c", 3)];
        let results = rrf_merge(&[&list], 60);
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].0, "a");
        assert_eq!(results[1].0, "b");
        assert_eq!(results[2].0, "c");
        assert!(close(results[0].1, 1.0 / 61.0));
    }

    #[test]
    fn rrf_two_lists_boosts_overlap() {
        let vec_list = vec![item("a", 1), item("b", 2)];
        let fts_list = vec![item("b", 1), item("c", 2)];
        let results = rrf_merge(&[&vec_list, &fts_list], 60);
        // "b" appears in both lists, so its summed score wins.
        assert_eq!(results[0].0, "b");
        let expected_b = 1.0 / 62.0 + 1.0 / 61.0;
        assert!(close(results[0].1, expected_b));
    }

    #[test]
    fn rrf_empty_lists() {
        assert!(rrf_merge(&[], 60).is_empty());

        let empty: Vec<RankedItem> = Vec::new();
        assert!(rrf_merge(&[&empty], 60).is_empty());
    }

    #[test]
    fn rrf_equal_ranks_equal_scores() {
        let list = vec![item("a", 1), item("b", 1)];
        let results = rrf_merge(&[&list], 60);
        assert_eq!(results.len(), 2);
        assert!(close(results[0].1, results[1].1));
    }

    #[test]
    fn rrf_k_affects_score_spread() {
        let list = vec![item("a", 1), item("b", 2)];
        // Gap between the first and second score for a given k.
        let gap = |k: u32| {
            let merged = rrf_merge(&[&list], k);
            merged[0].1 - merged[1].1
        };
        let gap_k1 = gap(1);
        let gap_k60 = gap(60);
        assert!(gap_k1 > gap_k60);
    }

    #[test]
    fn rrf_three_lists() {
        let l1 = vec![item("a", 1)];
        let l2 = vec![item("a", 1)];
        let l3 = vec![item("a", 1)];
        let results = rrf_merge(&[&l1, &l2, &l3], 60);
        assert_eq!(results.len(), 1);
        assert!(close(results[0].1, 3.0 / 61.0));
    }
}