1use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, HashSet, VecDeque};
9
10pub fn tokenize(text: &str) -> Vec<String> {
15 let mut tokens = Vec::new();
16
17 for word in text.split_whitespace() {
19 let segments = split_on_punctuation(word);
21 for segment in segments {
22 let sub_tokens = split_camel_case(&segment);
24 for token in sub_tokens {
25 let lower = token.to_lowercase();
26 if lower.len() >= 2 {
27 tokens.push(lower);
28 }
29 }
30 }
31 }
32
33 tokens
34}
35
/// Splits `s` on every non-alphanumeric character, discarding the
/// separators and any empty segments they produce.
fn split_on_punctuation(s: &str) -> Vec<String> {
    s.split(|ch: char| !ch.is_alphanumeric())
        .filter(|segment| !segment.is_empty())
        .map(str::to_string)
        .collect()
}
58
/// Splits `s` at camelCase-style boundaries:
/// - a lowercase letter followed by an uppercase one ("camelCase"),
/// - the end of an uppercase run before a lowercase letter, cutting so
///   the last capital starts the next part ("HTTPServer" -> "HTTP", "Server"),
/// - any transition between letters and ASCII digits.
fn split_camel_case(s: &str) -> Vec<String> {
    if s.is_empty() {
        return Vec::new();
    }

    let cs: Vec<char> = s.chars().collect();

    // First pass: collect the indices where a new part begins.
    let mut cuts: Vec<usize> = Vec::new();
    for idx in 1..cs.len() {
        let (prev, curr) = (cs[idx - 1], cs[idx]);

        let case_flip = prev.is_lowercase() && curr.is_uppercase();
        // End of an all-caps run: split so the final capital joins the
        // lowercase tail that follows it.
        let acronym_tail =
            idx >= 2 && cs[idx - 2].is_uppercase() && prev.is_uppercase() && curr.is_lowercase();
        let alpha_digit = (prev.is_alphabetic() && curr.is_ascii_digit())
            || (prev.is_ascii_digit() && curr.is_alphabetic());

        if case_flip || acronym_tail || alpha_digit {
            let cut = if acronym_tail { idx - 1 } else { idx };
            // Only keep strictly increasing cut points (mirrors the
            // `split_at > start` guard of a running-start formulation).
            if cuts.last().map_or(cut > 0, |&last| cut > last) {
                cuts.push(cut);
            }
        }
    }

    // Second pass: slice the char buffer between consecutive cuts.
    let mut parts = Vec::new();
    let mut begin = 0;
    for &cut in &cuts {
        parts.push(cs[begin..cut].iter().collect());
        begin = cut;
    }
    if begin < cs.len() {
        parts.push(cs[begin..].iter().collect());
    }

    parts
}
105
/// Incrementally maintained BM25 index over tokenized documents.
///
/// Stores per-document term frequencies plus the global statistics
/// (document frequencies, average length) needed to compute BM25
/// scores, and caps memory via FIFO eviction of the oldest documents.
#[derive(Debug, Serialize, Deserialize)]
pub struct Bm25Index {
    /// Number of documents containing each term (document frequency).
    doc_freq: HashMap<String, usize>,
    /// Token count of each document, keyed by document id.
    doc_lengths: HashMap<String, usize>,
    /// Per-document term frequencies: doc id -> (term -> count).
    doc_terms: HashMap<String, HashMap<String, usize>>,
    /// Total number of currently indexed documents.
    pub doc_count: usize,
    /// Average document length in tokens (0.0 when the index is empty).
    avg_doc_len: f64,
    /// Running sum of all document lengths. Defaults to 0 when absent
    /// from older serialized data; `deserialize` recomputes it.
    #[serde(default)]
    total_doc_len: usize,
    /// BM25 term-frequency saturation parameter.
    k1: f64,
    /// BM25 length-normalization parameter.
    b: f64,
    /// Capacity limit; when reached, the oldest document is evicted.
    #[serde(default = "default_max_documents")]
    max_documents: usize,
    /// Document ids in insertion order, used for FIFO eviction.
    #[serde(default)]
    insertion_order: VecDeque<String>,
}
141
/// Default cap on the number of documents the index retains.
fn default_max_documents() -> usize {
    100_000
}
145
impl Bm25Index {
    /// Creates an empty index with conventional BM25 defaults
    /// (k1 = 1.2, b = 0.75) and the default document capacity.
    pub fn new() -> Self {
        Self {
            doc_freq: HashMap::new(),
            doc_lengths: HashMap::new(),
            doc_terms: HashMap::new(),
            doc_count: 0,
            avg_doc_len: 0.0,
            total_doc_len: 0,
            k1: 1.2,
            b: 0.75,
            max_documents: default_max_documents(),
            insertion_order: VecDeque::new(),
        }
    }

    /// Tokenizes `content` and indexes it under `id`.
    ///
    /// Re-adding an existing id replaces the previous entry. When the
    /// index is at capacity, the oldest document (by insertion order)
    /// is evicted before the new one is added.
    pub fn add_document(&mut self, id: &str, content: &str) {
        // Replace semantics: remove any previous version first so its
        // statistics are not double-counted.
        if self.doc_terms.contains_key(id) {
            self.remove_document(id);
        }

        // FIFO eviction when at capacity.
        if self.doc_count >= self.max_documents {
            if let Some(oldest_id) = self.insertion_order.pop_front() {
                self.remove_document(&oldest_id);
            }
        }

        let tokens = tokenize(content);
        let doc_len = tokens.len();

        // Term frequencies within this document.
        let mut term_freqs: HashMap<String, usize> = HashMap::new();
        for token in &tokens {
            *term_freqs.entry(token.clone()).or_insert(0) += 1;
        }

        // Each distinct term counts once toward its document frequency.
        for term in term_freqs.keys() {
            *self.doc_freq.entry(term.clone()).or_insert(0) += 1;
        }

        self.doc_lengths.insert(id.to_string(), doc_len);
        self.doc_terms.insert(id.to_string(), term_freqs);
        self.doc_count += 1;

        self.insertion_order.push_back(id.to_string());

        // Maintain the average length incrementally from the running sum.
        self.total_doc_len += doc_len;
        self.avg_doc_len = self.total_doc_len as f64 / self.doc_count as f64;
    }

    /// Removes document `id` and rolls back its contribution to all
    /// index statistics. Unknown ids are ignored.
    pub fn remove_document(&mut self, id: &str) {
        if let Some(term_freqs) = self.doc_terms.remove(id) {
            // Decrement document frequencies; drop terms that no longer
            // appear in any document.
            for term in term_freqs.keys() {
                if let Some(df) = self.doc_freq.get_mut(term) {
                    *df = df.saturating_sub(1);
                    if *df == 0 {
                        self.doc_freq.remove(term);
                    }
                }
            }

            let removed_len = self.doc_lengths.remove(id).unwrap_or(0);
            self.doc_count = self.doc_count.saturating_sub(1);

            self.total_doc_len = self.total_doc_len.saturating_sub(removed_len);
            if self.doc_count == 0 {
                self.avg_doc_len = 0.0;
            } else {
                self.avg_doc_len = self.total_doc_len as f64 / self.doc_count as f64;
            }

            self.insertion_order.retain(|x| x != id);
        }
    }

    /// Scores an indexed document against a free-text query.
    ///
    /// The raw BM25 score is divided by an upper-bound estimate
    /// (`max_possible_score`) so the result lands in [0, 1]. Returns
    /// 0.0 for an empty index, an empty query, or an unknown `doc_id`.
    pub fn score(&self, query: &str, doc_id: &str) -> f64 {
        if self.doc_count == 0 {
            return 0.0;
        }

        let query_tokens = tokenize(query);
        if query_tokens.is_empty() {
            return 0.0;
        }

        let doc_len = match self.doc_lengths.get(doc_id) {
            Some(&len) => len,
            None => return 0.0,
        };

        let doc_term_freqs = match self.doc_terms.get(doc_id) {
            Some(tf) => tf,
            None => return 0.0,
        };

        let raw = self.raw_bm25_score(&query_tokens, doc_term_freqs, doc_len);

        let max_score = self.max_possible_score(&query_tokens);
        if max_score <= 0.0 {
            return 0.0;
        }

        // Clamp in case the raw score exceeds the estimated upper bound.
        (raw / max_score).min(1.0)
    }

    /// Like [`Bm25Index::score`] but takes pre-tokenized query terms,
    /// avoiding repeated tokenization when scoring many documents.
    pub fn score_with_tokens_str(&self, query_tokens: &[&str], doc_id: &str) -> f64 {
        if self.doc_count == 0 || query_tokens.is_empty() {
            return 0.0;
        }

        let doc_len = match self.doc_lengths.get(doc_id) {
            Some(&len) => len,
            None => return 0.0,
        };

        let doc_term_freqs = match self.doc_terms.get(doc_id) {
            Some(tf) => tf,
            None => return 0.0,
        };

        let raw = self.raw_bm25_score(query_tokens, doc_term_freqs, doc_len);

        let max_score = self.max_possible_score(query_tokens);
        if max_score <= 0.0 {
            return 0.0;
        }

        (raw / max_score).min(1.0)
    }

    /// Scores an arbitrary (not necessarily indexed) document text
    /// against pre-tokenized query terms, using the index only for its
    /// global statistics (document frequencies, average length).
    pub fn score_text_with_tokens(&self, query_tokens: &[String], document: &str) -> f64 {
        if self.doc_count == 0 || query_tokens.is_empty() {
            return 0.0;
        }

        let doc_tokens = tokenize(document);
        if doc_tokens.is_empty() {
            return 0.0;
        }

        let doc_len = doc_tokens.len();

        // Term frequencies computed on the fly for this one text.
        let mut doc_term_freqs: HashMap<String, usize> = HashMap::new();
        for token in &doc_tokens {
            *doc_term_freqs.entry(token.clone()).or_insert(0) += 1;
        }

        let raw = self.raw_bm25_score(query_tokens, &doc_term_freqs, doc_len);

        let max_score = self.max_possible_score(query_tokens);
        if max_score <= 0.0 {
            return 0.0;
        }

        (raw / max_score).min(1.0)
    }

    /// `&str`-slice variant of [`Bm25Index::score_text_with_tokens`].
    pub fn score_text_with_tokens_str(&self, query_tokens: &[&str], document: &str) -> f64 {
        if self.doc_count == 0 || query_tokens.is_empty() {
            return 0.0;
        }

        let doc_tokens = tokenize(document);
        if doc_tokens.is_empty() {
            return 0.0;
        }

        let doc_len = doc_tokens.len();
        let mut doc_term_freqs: HashMap<String, usize> = HashMap::new();
        for token in &doc_tokens {
            *doc_term_freqs.entry(token.clone()).or_insert(0) += 1;
        }

        let raw = self.raw_bm25_score(query_tokens, &doc_term_freqs, doc_len);
        let max_score = self.max_possible_score(query_tokens);
        if max_score <= 0.0 {
            return 0.0;
        }

        (raw / max_score).min(1.0)
    }

    /// Convenience wrapper: tokenizes `query` and scores `document`
    /// via [`Bm25Index::score_text_with_tokens`].
    pub fn score_text(&self, query: &str, document: &str) -> f64 {
        let query_tokens = tokenize(query);
        self.score_text_with_tokens(&query_tokens, document)
    }

    /// Test helper: builds an index from `(id, content)` pairs.
    #[cfg(test)]
    pub fn build(documents: &[(String, String)]) -> Self {
        let mut index = Self::new();
        for (id, content) in documents {
            index.add_document(id, content);
        }
        index
    }

    /// Raw (unnormalized) BM25 score of a document, given its term
    /// frequencies and length. Duplicate query terms are counted once.
    fn raw_bm25_score<S: AsRef<str>>(
        &self,
        query_tokens: &[S],
        doc_term_freqs: &HashMap<String, usize>,
        doc_len: usize,
    ) -> f64 {
        let n = self.doc_count as f64;
        // Guard against division by zero in the length-normalization term.
        let avgdl = if self.avg_doc_len > 0.0 {
            self.avg_doc_len
        } else {
            1.0
        };
        let dl = doc_len as f64;

        let mut score = 0.0;

        // Deduplicate query terms so repeats don't inflate the score.
        let mut seen_query_terms: HashSet<&str> = HashSet::new();

        for qt in query_tokens {
            let qt_str = qt.as_ref();
            if !seen_query_terms.insert(qt_str) {
                continue;
            }

            let tf = *doc_term_freqs.get(qt_str).unwrap_or(&0) as f64;
            if tf == 0.0 {
                continue;
            }

            let df = *self.doc_freq.get(qt_str).unwrap_or(&0) as f64;

            // Lucene-style idf: ln(1 + (N - df + 0.5) / (df + 0.5)),
            // which is always non-negative.
            let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();

            // Classic BM25 term weight with tf saturation (k1) and
            // document-length normalization (b).
            let numerator = tf * (self.k1 + 1.0);
            let denominator = tf + self.k1 * (1.0 - self.b + self.b * dl / avgdl);

            score += idf * numerator / denominator;
        }

        score
    }

    /// Estimated upper bound on the raw BM25 score for these query
    /// terms, used to normalize scores into [0, 1]. Assumes a saturated
    /// term frequency (tf = 10) and an average-length document, for
    /// which the length factor `1 - b + b*dl/avgdl` collapses to 1 and
    /// the denominator becomes `tf + k1`.
    fn max_possible_score<S: AsRef<str>>(&self, query_tokens: &[S]) -> f64 {
        let n = self.doc_count as f64;

        let mut max_score = 0.0;
        // Deduplicate to mirror raw_bm25_score's treatment of repeats.
        let mut seen: HashSet<&str> = HashSet::new();

        for qt in query_tokens {
            let qt_str = qt.as_ref();
            if !seen.insert(qt_str) {
                continue;
            }

            let df = *self.doc_freq.get(qt_str).unwrap_or(&0) as f64;
            let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();

            // Near-saturation tf: higher counts add little under k1 = 1.2.
            let tf = 10.0_f64;
            let numerator = tf * (self.k1 + 1.0);
            let denominator = tf + self.k1; max_score += idf * numerator / denominator;
        }

        max_score
    }

    /// Rebuilds `total_doc_len` and `avg_doc_len` from `doc_lengths`,
    /// e.g. after loading serialized data that predates `total_doc_len`.
    pub(crate) fn recompute_avg_doc_len(&mut self) {
        let total: usize = self.doc_lengths.values().sum();
        self.total_doc_len = total;
        if self.doc_count == 0 {
            self.avg_doc_len = 0.0;
        } else {
            self.avg_doc_len = total as f64 / self.doc_count as f64;
        }
    }
}
476
477impl Bm25Index {
478 pub fn serialize(&self) -> Vec<u8> {
484 serde_json::to_vec(self).unwrap_or_default()
486 }
487
488 pub fn deserialize(data: &[u8]) -> Result<Self, String> {
494 let mut index: Self = serde_json::from_slice(data)
495 .map_err(|e| format!("BM25 deserialization failed: {e}"))?;
496 index.recompute_avg_doc_len();
498 if index.insertion_order.is_empty() && index.doc_count > 0 {
500 index.insertion_order = index.doc_lengths.keys().cloned().collect();
501 }
502 Ok(index)
503 }
504
505 pub fn needs_save(&self) -> bool {
511 self.doc_count > 0
512 }
513}
514
515impl Default for Bm25Index {
516 fn default() -> Self {
517 Self::new()
518 }
519}
520
521#[cfg(test)]
522#[path = "tests/bm25_tests.rs"]
523mod tests;