Skip to main content

codemem_engine/
bm25.rs

1//! BM25 scoring module for code-aware text ranking.
2//!
3//! Implements Okapi BM25 with a code-aware tokenizer that handles camelCase,
4//! snake_case, and other programming conventions. Replaces the naive
5//! split+intersect token overlap in hybrid scoring.
6
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, HashSet, VecDeque};
9
10// ── Code-Aware Tokenizer ────────────────────────────────────────────────────
11
12/// Tokenize text with code-awareness: splits camelCase, snake_case,
13/// punctuation boundaries, lowercases, and filters short tokens.
14pub fn tokenize(text: &str) -> Vec<String> {
15    let mut tokens = Vec::new();
16
17    // First split on whitespace, then process each word
18    for word in text.split_whitespace() {
19        // Split on punctuation and non-alphanumeric chars (but keep segments)
20        let segments = split_on_punctuation(word);
21        for segment in segments {
22            // Split camelCase and PascalCase
23            let sub_tokens = split_camel_case(&segment);
24            for token in sub_tokens {
25                let lower = token.to_lowercase();
26                if lower.len() >= 2 {
27                    tokens.push(lower);
28                }
29            }
30        }
31    }
32
33    tokens
34}
35
36/// Split a string on punctuation and non-alphanumeric characters.
37/// Keeps alphanumeric segments, discards separators.
/// Split a string on punctuation and non-alphanumeric characters.
/// Keeps alphanumeric segments, discards separators (underscore, dot,
/// dash, etc.). Consecutive separators yield no empty segments.
fn split_on_punctuation(s: &str) -> Vec<String> {
    s.split(|c: char| !c.is_alphanumeric())
        .filter(|segment| !segment.is_empty())
        .map(str::to_string)
        .collect()
}
58
59/// Split a camelCase or PascalCase string into its components.
60/// "processRequest" -> ["process", "Request"]
61/// "HTMLParser" -> ["HTML", "Parser"]
62/// "getHTTPResponse" -> ["get", "HTTP", "Response"]
/// Split a camelCase or PascalCase string into its components.
/// "processRequest" -> ["process", "Request"]
/// "HTMLParser" -> ["HTML", "Parser"]
/// "getHTTPResponse" -> ["get", "HTTP", "Response"]
///
/// Also splits at letter/digit transitions: "item2count" -> ["item", "2", "count"].
fn split_camel_case(s: &str) -> Vec<String> {
    let chars: Vec<char> = s.chars().collect();
    if chars.is_empty() {
        return Vec::new();
    }

    let mut parts = Vec::new();
    let mut seg_start = 0usize;

    for idx in 1..chars.len() {
        let prev = chars[idx - 1];
        let curr = chars[idx];

        // lowercase -> uppercase edge: "processRequest"
        let camel_edge = prev.is_lowercase() && curr.is_uppercase();

        // End of an uppercase run followed by a lowercase letter:
        // "HTMLParser" splits before the 'P' so the acronym stays intact.
        let acronym_edge = idx >= 2
            && chars[idx - 2].is_uppercase()
            && prev.is_uppercase()
            && curr.is_lowercase();

        // Letter <-> ASCII digit transitions: "item2count" -> "item" | "2" | "count"
        let digit_edge = (prev.is_alphabetic() && curr.is_ascii_digit())
            || (prev.is_ascii_digit() && curr.is_alphabetic());

        if camel_edge || acronym_edge || digit_edge {
            // For an acronym edge the cut goes one char earlier, so the final
            // uppercase letter starts the next segment ("...L|Parser").
            let cut = if acronym_edge { idx - 1 } else { idx };
            if cut > seg_start {
                parts.push(chars[seg_start..cut].iter().collect());
                seg_start = cut;
            }
        }
    }

    // Flush the trailing segment.
    if seg_start < chars.len() {
        parts.push(chars[seg_start..].iter().collect());
    }

    parts
}
105
106// ── BM25 Index ──────────────────────────────────────────────────────────────
107
108/// BM25 index for scoring query-document relevance.
109///
110/// Maintains document frequency statistics incrementally. Documents are
111/// identified by string IDs and can be added/removed dynamically.
112///
113/// Supports serialization via `serialize()`/`deserialize()` for persistence
114/// across restarts, avoiding the need to rebuild from all stored memories.
#[derive(Serialize, Deserialize)]
pub struct Bm25Index {
    /// Number of documents containing each term: term -> count.
    /// Entries are removed when the count drops to zero.
    doc_freq: HashMap<String, usize>,
    /// Document lengths (in tokens): doc_id -> length
    doc_lengths: HashMap<String, usize>,
    /// Per-document term frequencies for removal support: doc_id -> (term -> count)
    doc_terms: HashMap<String, HashMap<String, usize>>,
    /// Total number of documents
    pub doc_count: usize,
    /// Average document length (maintained incrementally on add/remove)
    avg_doc_len: f64,
    /// Running total of all document lengths for incremental avg_doc_len tracking.
    /// `#[serde(default)]` so data serialized before this field existed still loads.
    #[serde(default)]
    total_doc_len: usize,
    /// BM25 k1 parameter (term frequency saturation)
    k1: f64,
    /// BM25 b parameter (length normalization)
    b: f64,
    /// Maximum number of documents before eviction kicks in
    #[serde(default = "default_max_documents")]
    max_documents: usize,
    /// Insertion order for FIFO eviction when max_documents is exceeded.
    /// Defaulted on deserialize for pre-eviction data; rebuilt in `deserialize()`.
    #[serde(default)]
    insertion_order: VecDeque<String>,
}
141
/// Serde default for `max_documents` (also used by `new()`): cap the index
/// at 100k documents before FIFO eviction kicks in.
fn default_max_documents() -> usize {
    100_000
}
145
146impl Bm25Index {
147    /// Create a new empty BM25 index with default parameters (k1=1.2, b=0.75).
148    pub fn new() -> Self {
149        Self {
150            doc_freq: HashMap::new(),
151            doc_lengths: HashMap::new(),
152            doc_terms: HashMap::new(),
153            doc_count: 0,
154            avg_doc_len: 0.0,
155            total_doc_len: 0,
156            k1: 1.2,
157            b: 0.75,
158            max_documents: default_max_documents(),
159            insertion_order: VecDeque::new(),
160        }
161    }
162
163    /// Add a document to the index, updating term frequencies and statistics.
164    /// If a document with the same ID already exists, it is replaced.
165    /// When the index exceeds `max_documents`, the oldest document is evicted.
166    pub fn add_document(&mut self, id: &str, content: &str) {
167        // Remove old version if exists (to avoid double-counting)
168        if self.doc_terms.contains_key(id) {
169            self.remove_document(id);
170        }
171
172        // Evict oldest document if at capacity
173        if self.doc_count >= self.max_documents {
174            if let Some(oldest_id) = self.insertion_order.pop_front() {
175                self.remove_document(&oldest_id);
176            }
177        }
178
179        let tokens = tokenize(content);
180        let doc_len = tokens.len();
181
182        // Count term frequencies for this document
183        let mut term_freqs: HashMap<String, usize> = HashMap::new();
184        for token in &tokens {
185            *term_freqs.entry(token.clone()).or_insert(0) += 1;
186        }
187
188        // Update global document frequency (each unique term in this doc)
189        for term in term_freqs.keys() {
190            *self.doc_freq.entry(term.clone()).or_insert(0) += 1;
191        }
192
193        // Store document info
194        self.doc_lengths.insert(id.to_string(), doc_len);
195        self.doc_terms.insert(id.to_string(), term_freqs);
196        self.doc_count += 1;
197
198        // Track insertion order for FIFO eviction
199        self.insertion_order.push_back(id.to_string());
200
201        // Incrementally update average document length
202        self.total_doc_len += doc_len;
203        self.avg_doc_len = self.total_doc_len as f64 / self.doc_count as f64;
204    }
205
206    /// Remove a document from the index, updating all statistics.
207    pub fn remove_document(&mut self, id: &str) {
208        if let Some(term_freqs) = self.doc_terms.remove(id) {
209            // Decrement document frequency for each term
210            for term in term_freqs.keys() {
211                if let Some(df) = self.doc_freq.get_mut(term) {
212                    *df = df.saturating_sub(1);
213                    if *df == 0 {
214                        self.doc_freq.remove(term);
215                    }
216                }
217            }
218
219            let removed_len = self.doc_lengths.remove(id).unwrap_or(0);
220            self.doc_count = self.doc_count.saturating_sub(1);
221
222            // Incrementally update average document length
223            self.total_doc_len = self.total_doc_len.saturating_sub(removed_len);
224            if self.doc_count == 0 {
225                self.avg_doc_len = 0.0;
226            } else {
227                self.avg_doc_len = self.total_doc_len as f64 / self.doc_count as f64;
228            }
229
230            // Remove from insertion order
231            self.insertion_order.retain(|x| x != id);
232        }
233    }
234
235    /// Score a query against a specific document using BM25.
236    ///
237    /// The score is computed as:
238    /// ```text
239    /// score(q, d) = Σ IDF(qi) * (f(qi,d) * (k1+1)) / (f(qi,d) + k1*(1 - b + b*|d|/avgdl))
240    /// IDF(qi) = ln((N - n(qi) + 0.5) / (n(qi) + 0.5) + 1)
241    /// ```
242    ///
243    /// The returned score is normalized to [0, 1] by dividing by the maximum
244    /// possible score (perfect self-match with all query terms).
245    pub fn score(&self, query: &str, doc_id: &str) -> f64 {
246        if self.doc_count == 0 {
247            return 0.0;
248        }
249
250        let query_tokens = tokenize(query);
251        if query_tokens.is_empty() {
252            return 0.0;
253        }
254
255        let doc_len = match self.doc_lengths.get(doc_id) {
256            Some(&len) => len,
257            None => return 0.0,
258        };
259
260        let doc_term_freqs = match self.doc_terms.get(doc_id) {
261            Some(tf) => tf,
262            None => return 0.0,
263        };
264
265        let raw = self.raw_bm25_score(&query_tokens, doc_term_freqs, doc_len);
266
267        // Normalize: compute max possible score (if every query term appeared
268        // with high frequency in a short document).
269        let max_score = self.max_possible_score(&query_tokens);
270        if max_score <= 0.0 {
271            return 0.0;
272        }
273
274        (raw / max_score).min(1.0)
275    }
276
277    /// Score pre-tokenized query tokens (as `&str` slices) against a specific indexed document.
278    /// Use this when scoring multiple documents against the same query to avoid
279    /// re-tokenizing the query each time. Zero-allocation: passes slices directly
280    /// to the generic internal helpers.
281    pub fn score_with_tokens_str(&self, query_tokens: &[&str], doc_id: &str) -> f64 {
282        if self.doc_count == 0 || query_tokens.is_empty() {
283            return 0.0;
284        }
285
286        let doc_len = match self.doc_lengths.get(doc_id) {
287            Some(&len) => len,
288            None => return 0.0,
289        };
290
291        let doc_term_freqs = match self.doc_terms.get(doc_id) {
292            Some(tf) => tf,
293            None => return 0.0,
294        };
295
296        let raw = self.raw_bm25_score(query_tokens, doc_term_freqs, doc_len);
297
298        let max_score = self.max_possible_score(query_tokens);
299        if max_score <= 0.0 {
300            return 0.0;
301        }
302
303        (raw / max_score).min(1.0)
304    }
305
306    /// Score pre-tokenized query tokens against arbitrary text (not necessarily in the index).
307    /// Use this when scoring multiple documents against the same query to avoid
308    /// re-tokenizing the query each time.
309    pub fn score_text_with_tokens(&self, query_tokens: &[String], document: &str) -> f64 {
310        if self.doc_count == 0 || query_tokens.is_empty() {
311            return 0.0;
312        }
313
314        let doc_tokens = tokenize(document);
315        if doc_tokens.is_empty() {
316            return 0.0;
317        }
318
319        let doc_len = doc_tokens.len();
320
321        // Build term frequencies for this document
322        let mut doc_term_freqs: HashMap<String, usize> = HashMap::new();
323        for token in &doc_tokens {
324            *doc_term_freqs.entry(token.clone()).or_insert(0) += 1;
325        }
326
327        let raw = self.raw_bm25_score(query_tokens, &doc_term_freqs, doc_len);
328
329        let max_score = self.max_possible_score(query_tokens);
330        if max_score <= 0.0 {
331            return 0.0;
332        }
333
334        (raw / max_score).min(1.0)
335    }
336
337    /// Score pre-tokenized query tokens (as `&str` slices) against arbitrary text.
338    /// Use this when scoring multiple documents against the same query to avoid
339    /// re-tokenizing the query each time. Zero-allocation: generic helpers accept `&[&str]`.
340    pub fn score_text_with_tokens_str(&self, query_tokens: &[&str], document: &str) -> f64 {
341        if self.doc_count == 0 || query_tokens.is_empty() {
342            return 0.0;
343        }
344
345        let doc_tokens = tokenize(document);
346        if doc_tokens.is_empty() {
347            return 0.0;
348        }
349
350        let doc_len = doc_tokens.len();
351        let mut doc_term_freqs: HashMap<String, usize> = HashMap::new();
352        for token in &doc_tokens {
353            *doc_term_freqs.entry(token.clone()).or_insert(0) += 1;
354        }
355
356        let raw = self.raw_bm25_score(query_tokens, &doc_term_freqs, doc_len);
357        let max_score = self.max_possible_score(query_tokens);
358        if max_score <= 0.0 {
359            return 0.0;
360        }
361
362        (raw / max_score).min(1.0)
363    }
364
365    /// Score a query against arbitrary text (not necessarily in the index).
366    /// Useful for scoring documents that haven't been indexed yet.
367    pub fn score_text(&self, query: &str, document: &str) -> f64 {
368        let query_tokens = tokenize(query);
369        self.score_text_with_tokens(&query_tokens, document)
370    }
371
372    /// Build a BM25 index from a slice of (id, content) pairs.
373    pub fn build(documents: &[(String, String)]) -> Self {
374        let mut index = Self::new();
375        for (id, content) in documents {
376            index.add_document(id, content);
377        }
378        index
379    }
380
381    // ── Internal helpers ────────────────────────────────────────────────
382
383    /// Compute the raw (unnormalized) BM25 score.
384    /// Generic over `AsRef<str>` so both `&[String]` and `&[&str]` work
385    /// without per-call allocation.
386    fn raw_bm25_score<S: AsRef<str>>(
387        &self,
388        query_tokens: &[S],
389        doc_term_freqs: &HashMap<String, usize>,
390        doc_len: usize,
391    ) -> f64 {
392        let n = self.doc_count as f64;
393        let avgdl = if self.avg_doc_len > 0.0 {
394            self.avg_doc_len
395        } else {
396            1.0
397        };
398        let dl = doc_len as f64;
399
400        let mut score = 0.0;
401
402        // De-duplicate query tokens for scoring (each unique term scored once)
403        let mut seen_query_terms: HashSet<&str> = HashSet::new();
404
405        for qt in query_tokens {
406            let qt_str = qt.as_ref();
407            if !seen_query_terms.insert(qt_str) {
408                continue;
409            }
410
411            // Term frequency in document
412            let tf = *doc_term_freqs.get(qt_str).unwrap_or(&0) as f64;
413            if tf == 0.0 {
414                continue;
415            }
416
417            // Document frequency (number of docs containing this term)
418            let df = *self.doc_freq.get(qt_str).unwrap_or(&0) as f64;
419
420            // IDF: ln((N - n(qi) + 0.5) / (n(qi) + 0.5) + 1)
421            let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
422
423            // BM25 term score
424            let numerator = tf * (self.k1 + 1.0);
425            let denominator = tf + self.k1 * (1.0 - self.b + self.b * dl / avgdl);
426
427            score += idf * numerator / denominator;
428        }
429
430        score
431    }
432
433    /// Compute the maximum possible BM25 score for a query (for normalization).
434    /// Assumes a perfect document that contains every query term with high tf
435    /// and has average length.
436    fn max_possible_score<S: AsRef<str>>(&self, query_tokens: &[S]) -> f64 {
437        let n = self.doc_count as f64;
438
439        let mut max_score = 0.0;
440        let mut seen: HashSet<&str> = HashSet::new();
441
442        for qt in query_tokens {
443            let qt_str = qt.as_ref();
444            if !seen.insert(qt_str) {
445                continue;
446            }
447
448            let df = *self.doc_freq.get(qt_str).unwrap_or(&0) as f64;
449            let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
450
451            // Best case: high tf in average-length document
452            // With tf large, the term score approaches idf * (k1 + 1) / 1 = idf * (k1 + 1)
453            // But for normalization we use tf=10 in avg-length doc for a realistic ceiling.
454            let tf = 10.0_f64;
455            let numerator = tf * (self.k1 + 1.0);
456            let denominator = tf + self.k1; // b * (dl/avgdl) = b * 1.0 when dl == avgdl
457            max_score += idf * numerator / denominator;
458        }
459
460        max_score
461    }
462
463    /// Recompute average document length and total_doc_len from stored lengths.
464    /// Public for use after deserialization to reconstruct the running total.
465    pub fn recompute_avg_doc_len(&mut self) {
466        let total: usize = self.doc_lengths.values().sum();
467        self.total_doc_len = total;
468        if self.doc_count == 0 {
469            self.avg_doc_len = 0.0;
470        } else {
471            self.avg_doc_len = total as f64 / self.doc_count as f64;
472        }
473    }
474}
475
476impl Bm25Index {
477    /// Serialize the BM25 index to a byte vector for persistence.
478    ///
479    /// Uses bincode for compact binary representation. The serialized format
480    /// includes all document frequency statistics, term frequencies, and
481    /// parameters, enabling fast startup without re-indexing all memories.
482    pub fn serialize(&self) -> Vec<u8> {
483        // Use JSON for reliable serialization (bincode not in deps)
484        serde_json::to_vec(self).unwrap_or_default()
485    }
486
487    /// Deserialize a BM25 index from bytes previously produced by `serialize()`.
488    ///
489    /// Returns `Err` if the data is corrupt or from an incompatible version.
490    /// Reconstructs `total_doc_len` and `insertion_order` if they were missing
491    /// from older serialized data (via `#[serde(default)]`).
492    pub fn deserialize(data: &[u8]) -> Result<Self, String> {
493        let mut index: Self = serde_json::from_slice(data)
494            .map_err(|e| format!("BM25 deserialization failed: {e}"))?;
495        // Reconstruct total_doc_len from stored doc_lengths (handles old data without the field)
496        index.recompute_avg_doc_len();
497        // Rebuild insertion_order if it was empty (old data without the field)
498        if index.insertion_order.is_empty() && index.doc_count > 0 {
499            index.insertion_order = index.doc_lengths.keys().cloned().collect();
500        }
501        Ok(index)
502    }
503
504    /// Whether the index contains any documents and may need saving.
505    ///
506    /// Useful for batch operations: call `persist_memory_no_save()` in a loop,
507    /// then check `needs_save()` before writing the vector index to disk once
508    /// at the end. This avoids O(N) disk writes during bulk inserts.
509    pub fn needs_save(&self) -> bool {
510        self.doc_count > 0
511    }
512}
513
impl Default for Bm25Index {
    /// Equivalent to [`Bm25Index::new`]: empty index with k1=1.2, b=0.75.
    fn default() -> Self {
        Self::new()
    }
}
519
520#[cfg(test)]
521#[path = "tests/bm25_tests.rs"]
522mod tests;