use std::collections::{HashMap, HashSet};
2
/// A single search hit: a document id paired with its BM25 relevance score.
#[derive(Debug, Clone)]
pub struct SearchResult {
    // Id the document was indexed under via `add_document`.
    pub id: String,
    // Relevance score; higher means a better match.
    pub score: f64,
}
9
/// Minimal interface for an in-memory full-text search index.
pub trait SearchEngine {
    /// Adds a document to the index.
    ///
    /// `fields` is a list of `(field_name, field_text, weight)` triples;
    /// `weight` scales that field's contribution to the final score.
    fn add_document(&mut self, id: &str, fields: &[(&str, &str, f64)]);

    /// Returns up to `limit` matching documents, best score first.
    fn search(&self, query: &str, limit: usize) -> Vec<SearchResult>;
}
21
/// Tokenized contents of a single field of an indexed document.
#[derive(Debug)]
struct FieldData {
    // Lowercased tokens produced by `BM25Index::tokenize`.
    tokens: Vec<String>,
    // Score multiplier supplied by the caller of `add_document`.
    weight: f64,
}
28
/// One indexed document: its tokenized fields, keyed by field name.
#[derive(Debug)]
struct Document {
    fields: HashMap<String, FieldData>,
}
34
/// In-memory inverted index scored with field-weighted BM25.
#[derive(Debug)]
pub struct BM25Index {
    // All indexed documents, keyed by document id.
    documents: HashMap<String, Document>,
    // Total token count per field name across all documents
    // (numerator of the average field length).
    field_total_tokens: HashMap<String, usize>,
    // Number of documents contributing each field name
    // (denominator of the average field length).
    field_doc_count: HashMap<String, usize>,
    // term -> ids of documents containing that term; the list length is the
    // term's document frequency.
    doc_freq: HashMap<String, Vec<String>>,
}
50
impl Default for BM25Index {
    /// Same as [`BM25Index::new`]: an empty index.
    fn default() -> Self {
        Self::new()
    }
}
56
57impl BM25Index {
58 const K1: f64 = 1.2;
60 const B: f64 = 0.75;
62
63 pub fn new() -> Self {
64 Self {
65 documents: HashMap::new(),
66 field_total_tokens: HashMap::new(),
67 field_doc_count: HashMap::new(),
68 doc_freq: HashMap::new(),
69 }
70 }
71
72 fn tokenize(text: &str) -> Vec<String> {
74 text.split(|c: char| !c.is_alphanumeric() && c != '_')
75 .map(|s| s.to_lowercase())
76 .filter(|s| s.len() > 1)
77 .collect()
78 }
79
80 fn idf(&self, term: &str) -> f64 {
83 let n = self.documents.len() as f64;
84 let df = self.doc_freq.get(term).map_or(0, |docs| docs.len()) as f64;
85 f64::ln((n - df + 0.5) / (df + 0.5) + 1.0)
86 }
87
88 fn avg_field_len(&self, field_name: &str) -> f64 {
90 let total = *self.field_total_tokens.get(field_name).unwrap_or(&0) as f64;
91 let count = *self.field_doc_count.get(field_name).unwrap_or(&0) as f64;
92 if count == 0.0 {
93 return 0.0;
94 }
95 total / count
96 }
97}
98
99impl SearchEngine for BM25Index {
100 fn add_document(&mut self, id: &str, fields: &[(&str, &str, f64)]) {
101 let mut doc_fields = HashMap::new();
102 let mut seen_terms: HashMap<String, bool> = HashMap::new();
103
104 for &(field_name, field_value, weight) in fields {
105 let tokens = Self::tokenize(field_value);
106
107 *self
109 .field_total_tokens
110 .entry(field_name.to_string())
111 .or_insert(0) += tokens.len();
112 *self
113 .field_doc_count
114 .entry(field_name.to_string())
115 .or_insert(0) += 1;
116
117 for token in &tokens {
119 seen_terms.entry(token.clone()).or_insert(true);
120 }
121
122 doc_fields.insert(field_name.to_string(), FieldData { tokens, weight });
123 }
124
125 for term in seen_terms.keys() {
127 self.doc_freq
128 .entry(term.clone())
129 .or_default()
130 .push(id.to_string());
131 }
132
133 self.documents
134 .insert(id.to_string(), Document { fields: doc_fields });
135 }
136
137 fn search(&self, query: &str, limit: usize) -> Vec<SearchResult> {
138 let query_tokens = Self::tokenize(query);
139 if query_tokens.is_empty() || self.documents.is_empty() {
140 return Vec::new();
141 }
142
143 let mut scores: HashMap<&str, f64> = HashMap::new();
144
145 for term in &query_tokens {
146 let idf = self.idf(term);
147
148 for (doc_id, doc) in &self.documents {
149 let mut doc_term_score = 0.0;
150
151 for (field_name, field_data) in &doc.fields {
152 let tf = field_data.tokens.iter().filter(|t| *t == term).count() as f64;
153
154 if tf == 0.0 {
155 continue;
156 }
157
158 let field_len = field_data.tokens.len() as f64;
159 let avg_fl = self.avg_field_len(field_name);
160
161 let tf_norm = if avg_fl == 0.0 {
162 0.0
163 } else {
164 (tf * (Self::K1 + 1.0))
165 / (tf + Self::K1 * (1.0 - Self::B + Self::B * field_len / avg_fl))
166 };
167
168 doc_term_score += idf * tf_norm * field_data.weight;
169 }
170
171 if doc_term_score > 0.0 {
172 *scores.entry(doc_id.as_str()).or_insert(0.0) += doc_term_score;
173 }
174 }
175 }
176
177 let mut results: Vec<SearchResult> = scores
178 .into_iter()
179 .map(|(id, score)| SearchResult {
180 id: id.to_string(),
181 score,
182 })
183 .collect();
184
185 results.sort_by(|a, b| {
186 b.score
187 .partial_cmp(&a.score)
188 .unwrap_or(std::cmp::Ordering::Equal)
189 });
190 results.truncate(limit);
191 results
192 }
193}