1#![deny(missing_docs)]
26
/// Tuning parameters for BM25 scoring (see [`score`] and [`rerank`]).
///
/// `k1` governs term-frequency saturation. `b` governs how strongly a document's
/// length, relative to the average document length, normalizes its score.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Bm25Opts {
    /// Term-frequency saturation.
    ///
    /// With `k1 = 0` every matching term contributes the same amount however often it
    /// occurs. Larger values let repeated occurrences keep raising the score for longer.
    pub k1: f32,
    /// Length-normalization strength, normally in `[0, 1]`.
    ///
    /// `0` ignores document length entirely. `1` applies the full length normalization.
    pub b: f32,
}
35
36impl Default for Bm25Opts {
37 fn default() -> Self {
38 Self { k1: 1.2, b: 0.75 }
39 }
40}
41
42pub fn rerank<S: AsRef<str>>(query: &str, docs: &[S], opts: Bm25Opts) -> Vec<usize> {
46 let scores = score(query, docs, opts);
47 let mut indices: Vec<usize> = (0..docs.len()).collect();
48 indices.sort_by(|&a, &b| {
49 scores[b]
50 .partial_cmp(&scores[a])
51 .unwrap_or(std::cmp::Ordering::Equal)
52 });
53 indices
54}
55
56pub fn score<S: AsRef<str>>(query: &str, docs: &[S], opts: Bm25Opts) -> Vec<f32> {
58 let q_terms: Vec<String> = tokenize(query);
59 if q_terms.is_empty() || docs.is_empty() {
60 return vec![0.0; docs.len()];
61 }
62
63 let doc_tokens: Vec<Vec<String>> = docs.iter().map(|d| tokenize(d.as_ref())).collect();
64 let lens: Vec<f32> = doc_tokens.iter().map(|t| t.len() as f32).collect();
65 let avgdl: f32 = if lens.is_empty() {
66 0.0
67 } else {
68 lens.iter().sum::<f32>() / lens.len() as f32
69 };
70 let n = doc_tokens.len() as f32;
71
72 let mut scores = vec![0.0_f32; doc_tokens.len()];
73 for term in &q_terms {
74 let df = doc_tokens
76 .iter()
77 .filter(|t| t.iter().any(|x| x == term))
78 .count() as f32;
79 if df == 0.0 {
80 continue;
81 }
82 let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
84
85 for (i, tokens) in doc_tokens.iter().enumerate() {
86 let tf = tokens.iter().filter(|x| *x == term).count() as f32;
87 if tf == 0.0 {
88 continue;
89 }
90 let dl = lens[i];
91 let denom = tf + opts.k1 * (1.0 - opts.b + opts.b * (dl / avgdl.max(1.0)));
92 scores[i] += idf * (tf * (opts.k1 + 1.0)) / denom;
93 }
94 }
95 scores
96}
97
/// Splits `s` into lowercase tokens.
///
/// Any character that is not alphanumeric (Unicode-aware [`char::is_alphanumeric`])
/// separates tokens, and empty pieces are dropped. Tokens are lowercased with the full
/// Unicode mapping ([`str::to_lowercase`]), so case folding covers the same characters
/// the splitter accepts. ASCII-only folding would turn `"CAFÉ"` into `"cafÉ"`, which
/// would never match `"café"`.
fn tokenize(s: &str) -> Vec<String> {
    s.split(|c: char| !c.is_alphanumeric())
        .filter(|t| !t.is_empty())
        .map(str::to_lowercase)
        .collect()
}