1use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, HashSet, VecDeque};
9
10pub fn tokenize(text: &str) -> Vec<String> {
15 let mut tokens = Vec::new();
16
17 for word in text.split_whitespace() {
19 let segments = split_on_punctuation(word);
21 for segment in segments {
22 let sub_tokens = split_camel_case(&segment);
24 for token in sub_tokens {
25 let lower = token.to_lowercase();
26 if lower.len() >= 2 {
27 tokens.push(lower);
28 }
29 }
30 }
31 }
32
33 tokens
34}
35
/// Breaks `s` into its maximal runs of alphanumeric characters,
/// discarding punctuation and every other non-alphanumeric character.
fn split_on_punctuation(s: &str) -> Vec<String> {
    s.split(|c: char| !c.is_alphanumeric())
        .filter(|run| !run.is_empty())
        .map(str::to_string)
        .collect()
}
58
/// Splits an identifier-like string at camelCase transitions, at the end of
/// uppercase acronym runs (e.g. "HTTPServer" -> "HTTP", "Server"), and at
/// letter/digit boundaries (e.g. "abc123def" -> "abc", "123", "def").
fn split_camel_case(s: &str) -> Vec<String> {
    let chars: Vec<char> = s.chars().collect();
    if chars.is_empty() {
        return Vec::new();
    }

    let mut parts: Vec<String> = Vec::new();
    let mut begin = 0;

    for idx in 1..chars.len() {
        let prev = chars[idx - 1];
        let curr = chars[idx];

        // camelCase boundary: lowercase letter followed by uppercase.
        let camel = prev.is_lowercase() && curr.is_uppercase();

        // End of an uppercase run: two uppercase letters then a lowercase
        // one; the last uppercase letter belongs to the next word.
        let acronym_end = idx >= 2
            && chars[idx - 2].is_uppercase()
            && prev.is_uppercase()
            && curr.is_lowercase();

        // Transition between letters and ASCII digits, either direction.
        let alnum_edge = (prev.is_alphabetic() && curr.is_ascii_digit())
            || (prev.is_ascii_digit() && curr.is_alphabetic());

        if camel || acronym_end || alnum_edge {
            // Acronym splits happen one position early so the final
            // uppercase letter starts the following part.
            let cut = if acronym_end { idx - 1 } else { idx };
            if cut > begin {
                parts.push(chars[begin..cut].iter().collect());
                begin = cut;
            }
        }
    }

    if begin < chars.len() {
        parts.push(chars[begin..].iter().collect());
    }

    parts
}
105
/// A serializable BM25 full-text index with a bounded document capacity.
///
/// Holds per-document term frequencies plus the corpus statistics
/// (document frequencies, document lengths) required to compute scores.
#[derive(Debug, Serialize, Deserialize)]
pub struct Bm25Index {
    /// Number of indexed documents containing each term.
    doc_freq: HashMap<String, usize>,
    /// Token count of each document, keyed by document id.
    doc_lengths: HashMap<String, usize>,
    /// Per-document map of term -> occurrence count.
    doc_terms: HashMap<String, HashMap<String, usize>>,
    /// Total number of currently indexed documents.
    pub doc_count: usize,
    /// Mean document length in tokens (0.0 when the index is empty).
    avg_doc_len: f64,
    /// Running sum of all document lengths; defaults to 0 for payloads
    /// serialized before this field existed (recomputed on load by
    /// `recompute_avg_doc_len`).
    #[serde(default)]
    total_doc_len: usize,
    /// BM25 term-frequency saturation parameter.
    k1: f64,
    /// BM25 document-length normalization parameter.
    b: f64,
    /// Maximum number of documents retained before the oldest is evicted.
    #[serde(default = "default_max_documents")]
    max_documents: usize,
    /// Document ids in insertion order (eviction queue); defaults to
    /// empty for older serialized payloads.
    #[serde(default)]
    insertion_order: VecDeque<String>,
}
141
/// Serde default for `Bm25Index::max_documents`: cap at 100k documents.
fn default_max_documents() -> usize {
    const LIMIT: usize = 100_000;
    LIMIT
}
145
146impl Bm25Index {
147 pub fn new() -> Self {
149 Self {
150 doc_freq: HashMap::new(),
151 doc_lengths: HashMap::new(),
152 doc_terms: HashMap::new(),
153 doc_count: 0,
154 avg_doc_len: 0.0,
155 total_doc_len: 0,
156 k1: 1.2,
157 b: 0.75,
158 max_documents: default_max_documents(),
159 insertion_order: VecDeque::new(),
160 }
161 }
162
163 pub fn add_document(&mut self, id: &str, content: &str) {
167 if self.doc_terms.contains_key(id) {
169 self.remove_document(id);
170 }
171
172 if self.doc_count >= self.max_documents {
174 if let Some(oldest_id) = self.insertion_order.pop_front() {
175 self.remove_document(&oldest_id);
176 }
177 }
178
179 let tokens = tokenize(content);
180 let doc_len = tokens.len();
181
182 let mut term_freqs: HashMap<String, usize> = HashMap::new();
184 for token in &tokens {
185 *term_freqs.entry(token.clone()).or_insert(0) += 1;
186 }
187
188 for term in term_freqs.keys() {
190 *self.doc_freq.entry(term.clone()).or_insert(0) += 1;
191 }
192
193 self.doc_lengths.insert(id.to_string(), doc_len);
195 self.doc_terms.insert(id.to_string(), term_freqs);
196 self.doc_count += 1;
197
198 self.insertion_order.push_back(id.to_string());
200
201 self.total_doc_len += doc_len;
203 self.avg_doc_len = self.total_doc_len as f64 / self.doc_count as f64;
204 }
205
206 pub fn remove_document(&mut self, id: &str) {
208 if let Some(term_freqs) = self.doc_terms.remove(id) {
209 for term in term_freqs.keys() {
211 if let Some(df) = self.doc_freq.get_mut(term) {
212 *df = df.saturating_sub(1);
213 if *df == 0 {
214 self.doc_freq.remove(term);
215 }
216 }
217 }
218
219 let removed_len = self.doc_lengths.remove(id).unwrap_or(0);
220 self.doc_count = self.doc_count.saturating_sub(1);
221
222 self.total_doc_len = self.total_doc_len.saturating_sub(removed_len);
224 if self.doc_count == 0 {
225 self.avg_doc_len = 0.0;
226 } else {
227 self.avg_doc_len = self.total_doc_len as f64 / self.doc_count as f64;
228 }
229
230 self.insertion_order.retain(|x| x != id);
232 }
233 }
234
235 pub fn score(&self, query: &str, doc_id: &str) -> f64 {
246 if self.doc_count == 0 {
247 return 0.0;
248 }
249
250 let query_tokens = tokenize(query);
251 if query_tokens.is_empty() {
252 return 0.0;
253 }
254
255 let doc_len = match self.doc_lengths.get(doc_id) {
256 Some(&len) => len,
257 None => return 0.0,
258 };
259
260 let doc_term_freqs = match self.doc_terms.get(doc_id) {
261 Some(tf) => tf,
262 None => return 0.0,
263 };
264
265 let raw = self.raw_bm25_score(&query_tokens, doc_term_freqs, doc_len);
266
267 let max_score = self.max_possible_score(&query_tokens);
270 if max_score <= 0.0 {
271 return 0.0;
272 }
273
274 (raw / max_score).min(1.0)
275 }
276
277 pub fn score_with_tokens_str(&self, query_tokens: &[&str], doc_id: &str) -> f64 {
282 if self.doc_count == 0 || query_tokens.is_empty() {
283 return 0.0;
284 }
285
286 let doc_len = match self.doc_lengths.get(doc_id) {
287 Some(&len) => len,
288 None => return 0.0,
289 };
290
291 let doc_term_freqs = match self.doc_terms.get(doc_id) {
292 Some(tf) => tf,
293 None => return 0.0,
294 };
295
296 let raw = self.raw_bm25_score(query_tokens, doc_term_freqs, doc_len);
297
298 let max_score = self.max_possible_score(query_tokens);
299 if max_score <= 0.0 {
300 return 0.0;
301 }
302
303 (raw / max_score).min(1.0)
304 }
305
306 pub fn score_text_with_tokens(&self, query_tokens: &[String], document: &str) -> f64 {
310 if self.doc_count == 0 || query_tokens.is_empty() {
311 return 0.0;
312 }
313
314 let doc_tokens = tokenize(document);
315 if doc_tokens.is_empty() {
316 return 0.0;
317 }
318
319 let doc_len = doc_tokens.len();
320
321 let mut doc_term_freqs: HashMap<String, usize> = HashMap::new();
323 for token in &doc_tokens {
324 *doc_term_freqs.entry(token.clone()).or_insert(0) += 1;
325 }
326
327 let raw = self.raw_bm25_score(query_tokens, &doc_term_freqs, doc_len);
328
329 let max_score = self.max_possible_score(query_tokens);
330 if max_score <= 0.0 {
331 return 0.0;
332 }
333
334 (raw / max_score).min(1.0)
335 }
336
337 pub fn score_text_with_tokens_str(&self, query_tokens: &[&str], document: &str) -> f64 {
341 if self.doc_count == 0 || query_tokens.is_empty() {
342 return 0.0;
343 }
344
345 let doc_tokens = tokenize(document);
346 if doc_tokens.is_empty() {
347 return 0.0;
348 }
349
350 let doc_len = doc_tokens.len();
351 let mut doc_term_freqs: HashMap<String, usize> = HashMap::new();
352 for token in &doc_tokens {
353 *doc_term_freqs.entry(token.clone()).or_insert(0) += 1;
354 }
355
356 let raw = self.raw_bm25_score(query_tokens, &doc_term_freqs, doc_len);
357 let max_score = self.max_possible_score(query_tokens);
358 if max_score <= 0.0 {
359 return 0.0;
360 }
361
362 (raw / max_score).min(1.0)
363 }
364
365 pub fn score_text(&self, query: &str, document: &str) -> f64 {
368 let query_tokens = tokenize(query);
369 self.score_text_with_tokens(&query_tokens, document)
370 }
371
372 pub fn build(documents: &[(String, String)]) -> Self {
374 let mut index = Self::new();
375 for (id, content) in documents {
376 index.add_document(id, content);
377 }
378 index
379 }
380
381 fn raw_bm25_score<S: AsRef<str>>(
387 &self,
388 query_tokens: &[S],
389 doc_term_freqs: &HashMap<String, usize>,
390 doc_len: usize,
391 ) -> f64 {
392 let n = self.doc_count as f64;
393 let avgdl = if self.avg_doc_len > 0.0 {
394 self.avg_doc_len
395 } else {
396 1.0
397 };
398 let dl = doc_len as f64;
399
400 let mut score = 0.0;
401
402 let mut seen_query_terms: HashSet<&str> = HashSet::new();
404
405 for qt in query_tokens {
406 let qt_str = qt.as_ref();
407 if !seen_query_terms.insert(qt_str) {
408 continue;
409 }
410
411 let tf = *doc_term_freqs.get(qt_str).unwrap_or(&0) as f64;
413 if tf == 0.0 {
414 continue;
415 }
416
417 let df = *self.doc_freq.get(qt_str).unwrap_or(&0) as f64;
419
420 let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
422
423 let numerator = tf * (self.k1 + 1.0);
425 let denominator = tf + self.k1 * (1.0 - self.b + self.b * dl / avgdl);
426
427 score += idf * numerator / denominator;
428 }
429
430 score
431 }
432
433 fn max_possible_score<S: AsRef<str>>(&self, query_tokens: &[S]) -> f64 {
437 let n = self.doc_count as f64;
438
439 let mut max_score = 0.0;
440 let mut seen: HashSet<&str> = HashSet::new();
441
442 for qt in query_tokens {
443 let qt_str = qt.as_ref();
444 if !seen.insert(qt_str) {
445 continue;
446 }
447
448 let df = *self.doc_freq.get(qt_str).unwrap_or(&0) as f64;
449 let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
450
451 let tf = 10.0_f64;
455 let numerator = tf * (self.k1 + 1.0);
456 let denominator = tf + self.k1; max_score += idf * numerator / denominator;
458 }
459
460 max_score
461 }
462
463 pub(crate) fn recompute_avg_doc_len(&mut self) {
466 let total: usize = self.doc_lengths.values().sum();
467 self.total_doc_len = total;
468 if self.doc_count == 0 {
469 self.avg_doc_len = 0.0;
470 } else {
471 self.avg_doc_len = total as f64 / self.doc_count as f64;
472 }
473 }
474}
475
476impl Bm25Index {
477 pub fn serialize(&self) -> Vec<u8> {
483 serde_json::to_vec(self).unwrap_or_default()
485 }
486
487 pub fn deserialize(data: &[u8]) -> Result<Self, String> {
493 let mut index: Self = serde_json::from_slice(data)
494 .map_err(|e| format!("BM25 deserialization failed: {e}"))?;
495 index.recompute_avg_doc_len();
497 if index.insertion_order.is_empty() && index.doc_count > 0 {
499 index.insertion_order = index.doc_lengths.keys().cloned().collect();
500 }
501 Ok(index)
502 }
503
504 pub fn needs_save(&self) -> bool {
510 self.doc_count > 0
511 }
512}
513
514impl Default for Bm25Index {
515 fn default() -> Self {
516 Self::new()
517 }
518}
519
520#[cfg(test)]
521#[path = "tests/bm25_tests.rs"]
522mod tests;