// jpx_engine/bm25.rs

//! BM25 search indexing for MCP tool discovery.
//!
//! This module provides JSON-in/JSON-out search indexing primitives using
//! the BM25 ranking algorithm. Designed for tool discovery across MCP servers.
//!
//! # Design
//!
//! - Pure JSON serialization for index portability
//! - BM25 chosen over TF-IDF for better term saturation and length normalization
//! - Session-scoped indices by default, but can be saved/restored
//!
//! # BM25 Formula
//!
//! ```text
//! score(D,Q) = Σ IDF(qi) * (f(qi,D) * (k1 + 1)) / (f(qi,D) + k1 * (1 - b + b * |D|/avgdl))
//! IDF(qi)    = ln((N - df(qi) + 0.5) / (df(qi) + 0.5) + 1)
//! ```
//!
//! Where:
//! - f(qi,D) = term frequency of qi in document D
//! - |D| = document length
//! - avgdl = average document length
//! - N = number of documents in the index
//! - df(qi) = number of documents containing qi
//! - k1 = term frequency saturation parameter (default 1.2)
//! - b = length normalization parameter (default 0.75)

use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};

28/// BM25 index structure - fully JSON serializable
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct Bm25Index {
31    /// Type marker for JSON identification
32    #[serde(rename = "_type")]
33    pub type_marker: String,
34
35    /// Version for future compatibility
36    #[serde(rename = "_version")]
37    pub version: String,
38
39    /// Index configuration
40    pub options: IndexOptions,
41
42    /// Total number of documents
43    pub doc_count: usize,
44
45    /// Average document length (in tokens)
46    pub avg_doc_length: f64,
47
48    /// Document metadata: id -> DocInfo
49    pub docs: HashMap<String, DocInfo>,
50
51    /// Inverted index: term -> TermInfo
52    pub terms: HashMap<String, TermInfo>,
53}
54
55/// Index configuration options
56#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct IndexOptions {
58    /// Fields to index (empty = treat input as text)
59    #[serde(default)]
60    pub fields: Vec<String>,
61
62    /// Field to use as document ID (default: array index)
63    #[serde(default)]
64    pub id_field: Option<String>,
65
66    /// Normalize case (default: true)
67    #[serde(default = "default_true")]
68    pub lowercase: bool,
69
70    /// Terms to exclude from indexing
71    #[serde(default)]
72    pub stopwords: Vec<String>,
73
74    /// BM25 k1 parameter (term frequency saturation)
75    #[serde(default = "default_k1")]
76    pub k1: f64,
77
78    /// BM25 b parameter (length normalization)
79    #[serde(default = "default_b")]
80    pub b: f64,
81}
82
83fn default_true() -> bool {
84    true
85}
86
87fn default_k1() -> f64 {
88    1.2
89}
90
91fn default_b() -> f64 {
92    0.75
93}
94
95impl Default for IndexOptions {
96    fn default() -> Self {
97        Self {
98            fields: Vec::new(),
99            id_field: None,
100            lowercase: true,
101            stopwords: Vec::new(),
102            k1: 1.2,
103            b: 0.75,
104        }
105    }
106}
107
108/// Document metadata
109#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct DocInfo {
111    /// Document length in tokens
112    pub length: usize,
113
114    /// Per-field token counts (for multi-field indices)
115    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
116    pub field_lengths: HashMap<String, usize>,
117
118    /// Original document (optional, for returning with results)
119    #[serde(default, skip_serializing_if = "Option::is_none")]
120    pub source: Option<serde_json::Value>,
121}
122
123/// Term information in the inverted index
124#[derive(Debug, Clone, Serialize, Deserialize)]
125pub struct TermInfo {
126    /// Document frequency (number of documents containing this term)
127    pub df: usize,
128
129    /// Postings: doc_id -> term frequency in that document
130    pub postings: HashMap<String, usize>,
131}
132
133/// Search result
134#[derive(Debug, Clone, Serialize, Deserialize)]
135pub struct SearchResult {
136    /// Document ID
137    pub id: String,
138
139    /// BM25 score
140    pub score: f64,
141
142    /// Matched terms and their locations
143    pub matches: HashMap<String, Vec<String>>,
144
145    /// Original document (if stored in index)
146    #[serde(default, skip_serializing_if = "Option::is_none")]
147    pub doc: Option<serde_json::Value>,
148}
149
150/// Score explanation for debugging
151#[derive(Debug, Clone, Serialize, Deserialize)]
152pub struct ScoreExplanation {
153    /// Document ID
154    pub id: String,
155
156    /// Total score
157    pub total_score: f64,
158
159    /// Per-term breakdown
160    pub term_scores: Vec<TermScoreDetail>,
161}
162
163/// Per-term score breakdown
164#[derive(Debug, Clone, Serialize, Deserialize)]
165pub struct TermScoreDetail {
166    /// The term
167    pub term: String,
168
169    /// Term frequency in document
170    pub tf: usize,
171
172    /// Document frequency (corpus-wide)
173    pub df: usize,
174
175    /// IDF component
176    pub idf: f64,
177
178    /// TF saturation component
179    pub tf_component: f64,
180
181    /// Final score contribution
182    pub score: f64,
183}
184
185impl Bm25Index {
186    /// Create a new empty index with the given options
187    pub fn new(options: IndexOptions) -> Self {
188        Self {
189            type_marker: "jpx:bm25_index".to_string(),
190            version: "1.0".to_string(),
191            options,
192            doc_count: 0,
193            avg_doc_length: 0.0,
194            docs: HashMap::new(),
195            terms: HashMap::new(),
196        }
197    }
198
199    /// Build an index from an array of documents
200    pub fn build(docs: &[serde_json::Value], options: IndexOptions) -> Self {
201        let mut index = Self::new(options);
202        let mut total_length = 0usize;
203
204        for (i, doc) in docs.iter().enumerate() {
205            let doc_id = index.get_doc_id(doc, i);
206            let (tokens, field_lengths) = index.tokenize_doc(doc);
207            let doc_length = tokens.len();
208            total_length += doc_length;
209
210            // Store document info
211            index.docs.insert(
212                doc_id.clone(),
213                DocInfo {
214                    length: doc_length,
215                    field_lengths,
216                    source: Some(doc.clone()),
217                },
218            );
219
220            // Update inverted index
221            let mut term_freqs: HashMap<String, usize> = HashMap::new();
222            for token in tokens {
223                *term_freqs.entry(token).or_insert(0) += 1;
224            }
225
226            for (term, freq) in term_freqs {
227                let term_info = index.terms.entry(term).or_insert(TermInfo {
228                    df: 0,
229                    postings: HashMap::new(),
230                });
231                term_info.df += 1;
232                term_info.postings.insert(doc_id.clone(), freq);
233            }
234
235            index.doc_count += 1;
236        }
237
238        // Calculate average document length
239        if index.doc_count > 0 {
240            index.avg_doc_length = total_length as f64 / index.doc_count as f64;
241        }
242
243        index
244    }
245
246    /// Get document ID from a document
247    fn get_doc_id(&self, doc: &serde_json::Value, index: usize) -> String {
248        if let Some(id) = self
249            .options
250            .id_field
251            .as_ref()
252            .and_then(|id_field| doc.get(id_field))
253        {
254            return match id {
255                serde_json::Value::String(s) => s.clone(),
256                serde_json::Value::Number(n) => n.to_string(),
257                _ => format!("{}", index),
258            };
259        }
260        format!("{}", index)
261    }
262
263    /// Tokenize a document into terms
264    fn tokenize_doc(&self, doc: &serde_json::Value) -> (Vec<String>, HashMap<String, usize>) {
265        let mut tokens = Vec::new();
266        let mut field_lengths = HashMap::new();
267
268        if self.options.fields.is_empty() {
269            // Treat entire doc as text
270            let text = self.extract_text(doc);
271            tokens = self.tokenize_text(&text);
272        } else {
273            // Index specific fields
274            for field in &self.options.fields {
275                if let Some(value) = doc.get(field) {
276                    let text = self.extract_text(value);
277                    let field_tokens = self.tokenize_text(&text);
278                    field_lengths.insert(field.clone(), field_tokens.len());
279                    tokens.extend(field_tokens);
280                }
281            }
282        }
283
284        (tokens, field_lengths)
285    }
286
287    /// Extract text from a JSON value
288    fn extract_text(&self, value: &serde_json::Value) -> String {
289        match value {
290            serde_json::Value::String(s) => s.clone(),
291            serde_json::Value::Array(arr) => arr
292                .iter()
293                .filter_map(|v| {
294                    if let serde_json::Value::String(s) = v {
295                        Some(s.as_str())
296                    } else {
297                        None
298                    }
299                })
300                .collect::<Vec<_>>()
301                .join(" "),
302            serde_json::Value::Object(obj) => obj
303                .values()
304                .map(|v| self.extract_text(v))
305                .collect::<Vec<_>>()
306                .join(" "),
307            _ => String::new(),
308        }
309    }
310
311    /// Tokenize text into terms
312    fn tokenize_text(&self, text: &str) -> Vec<String> {
313        let text = if self.options.lowercase {
314            text.to_lowercase()
315        } else {
316            text.to_string()
317        };
318
319        text.split(|c: char| !c.is_alphanumeric() && c != '_')
320            .filter(|s| !s.is_empty())
321            .filter(|s| !self.options.stopwords.contains(&s.to_string()))
322            .map(stem_simple)
323            .collect()
324    }
325}
326
/// Simple plural stemmer for search indexing.
///
/// Rules are tried in this order and the first match wins:
/// - `-ies` -> `-y` ("queries" -> "query", "entries" -> "entry")
/// - `-xes` / `-zes` -> drop `es` ("boxes" -> "box", "buzzes" -> "buzz")
/// - `-sses` -> `-ss` ("classes" -> "class")
/// - `-shes` -> `-sh` ("dishes" -> "dish")
/// - trailing `-s` that is not `-ss` -> drop `s`
///   ("databases" -> "database", "ACLs" -> "ACL"; "class" is left alone)
///
/// Terms shorter than three bytes are returned unchanged, and the suffix
/// rules only fire when at least one byte precedes the suffix. Lengths are
/// measured in bytes; because every suffix is ASCII, the slice points always
/// fall on character boundaries.
///
/// This is intentionally simple - it improves recall for plural/singular
/// matching without the complexity of a full Porter stemmer.
fn stem_simple(term: &str) -> String {
    let len = term.len();

    // Skip very short terms.
    if len < 3 {
        return term.to_string();
    }

    if len > 3 {
        // -ies -> -y (queries -> query, entries -> entry).
        if let Some(stem) = term.strip_suffix("ies") {
            return format!("{stem}y");
        }

        // -xes -> -x and -zes -> -z (boxes -> box, buzzes -> buzz).
        if term.ends_with("xes") || term.ends_with("zes") {
            return term[..len - 2].to_string();
        }
    }

    // -sses -> -ss (classes -> class) and -shes -> -sh (dishes -> dish).
    if len > 4 && (term.ends_with("sses") || term.ends_with("shes")) {
        return term[..len - 2].to_string();
    }

    // Plain -s, but never -ss ("lass", "class", "boss" stay as they are).
    // This covers: databases -> database, caches -> cache, shards -> shard.
    if term.ends_with('s') && !term.ends_with("ss") {
        return term[..len - 1].to_string();
    }

    term.to_string()
}

375impl Bm25Index {
376    /// Calculate IDF for a term
377    fn idf(&self, term: &str) -> f64 {
378        let df = self.terms.get(term).map(|t| t.df as f64).unwrap_or(0.0);
379
380        if df == 0.0 {
381            return 0.0;
382        }
383
384        let n = self.doc_count as f64;
385        // IDF formula: ln((N - df + 0.5) / (df + 0.5) + 1)
386        ((n - df + 0.5) / (df + 0.5) + 1.0).ln()
387    }
388
389    /// Calculate BM25 score for a document given query terms
390    fn score_doc(&self, doc_id: &str, query_terms: &[String]) -> f64 {
391        let doc_info = match self.docs.get(doc_id) {
392            Some(info) => info,
393            None => return 0.0,
394        };
395
396        let doc_length = doc_info.length as f64;
397        let k1 = self.options.k1;
398        let b = self.options.b;
399        let avgdl = self.avg_doc_length;
400
401        let mut score = 0.0;
402
403        for term in query_terms {
404            let idf = self.idf(term);
405            let tf = self
406                .terms
407                .get(term)
408                .and_then(|t| t.postings.get(doc_id))
409                .copied()
410                .unwrap_or(0) as f64;
411
412            if tf > 0.0 {
413                // BM25 formula
414                let numerator = tf * (k1 + 1.0);
415                let denominator = tf + k1 * (1.0 - b + b * doc_length / avgdl);
416                score += idf * numerator / denominator;
417            }
418        }
419
420        score
421    }
422
423    /// Search the index
424    pub fn search(&self, query: &str, top_k: usize) -> Vec<SearchResult> {
425        let query_terms = self.tokenize_text(query);
426
427        if query_terms.is_empty() {
428            return Vec::new();
429        }
430
431        // Find candidate documents (those containing at least one query term)
432        let mut candidates: HashMap<String, f64> = HashMap::new();
433
434        for term in &query_terms {
435            if let Some(term_info) = self.terms.get(term) {
436                for doc_id in term_info.postings.keys() {
437                    candidates.entry(doc_id.clone()).or_insert(0.0);
438                }
439            }
440        }
441
442        // Score all candidates
443        let mut results: Vec<SearchResult> = candidates
444            .keys()
445            .map(|doc_id| {
446                let score = self.score_doc(doc_id, &query_terms);
447                let matches = self.get_matches(doc_id, &query_terms);
448                let doc = self.docs.get(doc_id).and_then(|d| d.source.clone());
449
450                SearchResult {
451                    id: doc_id.clone(),
452                    score,
453                    matches,
454                    doc,
455                }
456            })
457            .filter(|r| r.score > 0.0)
458            .collect();
459
460        // Sort by score descending
461        results.sort_by(|a, b| {
462            b.score
463                .partial_cmp(&a.score)
464                .unwrap_or(std::cmp::Ordering::Equal)
465        });
466
467        // Return top_k results
468        results.truncate(top_k);
469        results
470    }
471
472    /// Get matched terms for a document
473    fn get_matches(&self, doc_id: &str, query_terms: &[String]) -> HashMap<String, Vec<String>> {
474        let mut matches: HashMap<String, Vec<String>> = HashMap::new();
475
476        for term in query_terms {
477            if self
478                .terms
479                .get(term)
480                .is_some_and(|term_info| term_info.postings.contains_key(doc_id))
481            {
482                // For now, just note which field matched (if we have field info)
483                matches
484                    .entry("_matched".to_string())
485                    .or_default()
486                    .push(term.clone());
487            }
488        }
489
490        matches
491    }
492
493    /// Explain scoring for a specific document
494    pub fn explain(&self, query: &str, doc_id: &str) -> Option<ScoreExplanation> {
495        let doc_info = self.docs.get(doc_id)?;
496        let query_terms = self.tokenize_text(query);
497
498        let doc_length = doc_info.length as f64;
499        let k1 = self.options.k1;
500        let b = self.options.b;
501        let avgdl = self.avg_doc_length;
502
503        let mut total_score = 0.0;
504        let mut term_scores = Vec::new();
505
506        for term in &query_terms {
507            let idf = self.idf(term);
508            let df = self.terms.get(term).map(|t| t.df).unwrap_or(0);
509            let tf = self
510                .terms
511                .get(term)
512                .and_then(|t| t.postings.get(doc_id))
513                .copied()
514                .unwrap_or(0);
515
516            let tf_f64 = tf as f64;
517            let tf_component = if tf > 0 {
518                let numerator = tf_f64 * (k1 + 1.0);
519                let denominator = tf_f64 + k1 * (1.0 - b + b * doc_length / avgdl);
520                numerator / denominator
521            } else {
522                0.0
523            };
524
525            let score = idf * tf_component;
526            total_score += score;
527
528            term_scores.push(TermScoreDetail {
529                term: term.clone(),
530                tf,
531                df,
532                idf,
533                tf_component,
534                score,
535            });
536        }
537
538        Some(ScoreExplanation {
539            id: doc_id.to_string(),
540            total_score,
541            term_scores,
542        })
543    }
544
545    /// Get all indexed terms with their document frequencies
546    pub fn terms(&self) -> Vec<(String, usize)> {
547        let mut terms: Vec<_> = self
548            .terms
549            .iter()
550            .map(|(t, info)| (t.clone(), info.df))
551            .collect();
552        terms.sort_by(|a, b| b.1.cmp(&a.1)); // Sort by df descending
553        terms
554    }
555
556    /// Find similar documents using term overlap
557    pub fn similar(&self, doc_id: &str, top_k: usize) -> Vec<SearchResult> {
558        let doc_terms: Vec<String> = self
559            .terms
560            .iter()
561            .filter(|(_, info)| info.postings.contains_key(doc_id))
562            .map(|(term, _)| term.clone())
563            .collect();
564
565        if doc_terms.is_empty() {
566            return Vec::new();
567        }
568
569        // Score all other documents using the source doc's terms as query
570        let mut results: Vec<SearchResult> = self
571            .docs
572            .keys()
573            .filter(|id| *id != doc_id)
574            .map(|id| {
575                let score = self.score_doc(id, &doc_terms);
576                let matches = self.get_matches(id, &doc_terms);
577                let doc = self.docs.get(id).and_then(|d| d.source.clone());
578
579                SearchResult {
580                    id: id.clone(),
581                    score,
582                    matches,
583                    doc,
584                }
585            })
586            .filter(|r| r.score > 0.0)
587            .collect();
588
589        results.sort_by(|a, b| {
590            b.score
591                .partial_cmp(&a.score)
592                .unwrap_or(std::cmp::Ordering::Equal)
593        });
594        results.truncate(top_k);
595        results
596    }
597}
598
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Options used by most tests: index `name` and `description`, and use
    /// `name` as the document ID.
    fn name_and_description_options() -> IndexOptions {
        IndexOptions {
            fields: vec!["name".to_string(), "description".to_string()],
            id_field: Some("name".to_string()),
            ..Default::default()
        }
    }

    #[test]
    fn test_build_index_simple() {
        let docs = vec![
            json!("hello world"),
            json!("hello there"),
            json!("goodbye world"),
        ];

        let index = Bm25Index::build(&docs, IndexOptions::default());

        assert_eq!(index.doc_count, 3);
        assert!(index.terms.contains_key("hello"));
        assert!(index.terms.contains_key("world"));
        assert_eq!(index.terms.get("hello").unwrap().df, 2);
        assert_eq!(index.terms.get("world").unwrap().df, 2);
    }

    #[test]
    fn test_build_index_with_fields() {
        let docs = vec![
            json!({"name": "create_cluster", "description": "Create a new cluster"}),
            json!({"name": "delete_cluster", "description": "Delete an existing cluster"}),
            json!({"name": "list_backups", "description": "List all backups"}),
        ];

        let index = Bm25Index::build(&docs, name_and_description_options());

        assert_eq!(index.doc_count, 3);
        assert!(index.docs.contains_key("create_cluster"));
        assert!(index.docs.contains_key("delete_cluster"));
        assert!(index.terms.contains_key("cluster"));
        assert_eq!(index.terms.get("cluster").unwrap().df, 2);
    }

    #[test]
    fn test_search_basic() {
        let docs = vec![
            json!({"name": "create_cluster", "description": "Create a new Redis cluster"}),
            json!({"name": "delete_cluster", "description": "Delete an existing cluster"}),
            json!({"name": "create_backup", "description": "Create a backup of data"}),
        ];

        let index = Bm25Index::build(&docs, name_and_description_options());
        let results = index.search("cluster", 10);

        assert_eq!(results.len(), 2);
        // Both cluster docs should be returned
        let ids: Vec<_> = results.iter().map(|r| r.id.as_str()).collect();
        assert!(ids.contains(&"create_cluster"));
        assert!(ids.contains(&"delete_cluster"));
    }

    #[test]
    fn test_search_respects_top_k() {
        let docs = vec![
            json!({"name": "create_cluster", "description": "Create a new Redis cluster"}),
            json!({"name": "delete_cluster", "description": "Delete an existing cluster"}),
        ];

        let index = Bm25Index::build(&docs, name_and_description_options());

        assert_eq!(index.search("cluster", 1).len(), 1);
        assert!(index.search("cluster", 0).is_empty());
    }

    #[test]
    fn test_search_empty_query_returns_nothing() {
        let docs = vec![json!("hello world")];

        let index = Bm25Index::build(&docs, IndexOptions::default());

        assert!(index.search("", 10).is_empty());
        assert!(index.search("  ,;  ", 10).is_empty());
    }

    #[test]
    fn test_search_ranking() {
        let docs = vec![
            json!({"name": "cluster_manager", "description": "Manage cluster operations"}),
            json!({"name": "backup_tool", "description": "Backup tool for cluster data"}),
            json!({"name": "monitor", "description": "Monitor system health"}),
        ];

        let index = Bm25Index::build(&docs, name_and_description_options());
        let results = index.search("cluster", 10);

        // cluster_manager should rank higher: same term frequency as
        // backup_tool but a shorter document.
        assert!(!results.is_empty());
        assert_eq!(results[0].id, "cluster_manager");
    }

    #[test]
    fn test_search_multi_term() {
        let docs = vec![
            json!({"name": "create_backup", "description": "Create a backup in a region"}),
            json!({"name": "restore_backup", "description": "Restore from backup"}),
            json!({"name": "list_regions", "description": "List available regions"}),
        ];

        let index = Bm25Index::build(&docs, name_and_description_options());
        let results = index.search("backup region", 10);

        // create_backup should rank highest (has both terms)
        assert!(!results.is_empty());
        assert_eq!(results[0].id, "create_backup");
    }

    #[test]
    fn test_explain() {
        let docs = vec![json!({"name": "test", "description": "test document with terms"})];

        let index = Bm25Index::build(&docs, name_and_description_options());
        let explanation = index.explain("test", "test").unwrap();

        assert_eq!(explanation.id, "test");
        assert!(explanation.total_score > 0.0);
        assert!(!explanation.term_scores.is_empty());
    }

    #[test]
    fn test_similar() {
        let docs = vec![
            json!({"name": "create_cluster", "description": "Create a new kubernetes cluster"}),
            json!({"name": "delete_cluster", "description": "Delete an existing kubernetes cluster"}),
            json!({"name": "upload_file", "description": "Upload a file to storage"}),
        ];

        let index = Bm25Index::build(&docs, name_and_description_options());
        let similar = index.similar("create_cluster", 10);

        // delete_cluster should be most similar (shares "cluster" and "kubernetes")
        assert!(!similar.is_empty());
        assert_eq!(similar[0].id, "delete_cluster");
    }

    #[test]
    fn test_stopwords() {
        let docs = vec![json!("the quick brown fox"), json!("the lazy dog")];

        let options = IndexOptions {
            stopwords: vec!["the".to_string()],
            ..Default::default()
        };

        let index = Bm25Index::build(&docs, options);

        assert!(!index.terms.contains_key("the"));
        assert!(index.terms.contains_key("quick"));
    }

    #[test]
    fn test_case_insensitive() {
        let docs = vec![json!("Hello World"), json!("HELLO THERE")];

        let index = Bm25Index::build(&docs, IndexOptions::default());
        let results = index.search("hello", 10);

        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_stem_simple_plurals() {
        assert_eq!(stem_simple("queries"), "query");
        assert_eq!(stem_simple("boxes"), "box");
        assert_eq!(stem_simple("classes"), "class");
        assert_eq!(stem_simple("dishes"), "dish");
        assert_eq!(stem_simple("databases"), "database");
        assert_eq!(stem_simple("class"), "class");
        assert_eq!(stem_simple("is"), "is");
    }

    #[test]
    fn test_json_serialization() {
        let docs = vec![json!({"name": "test", "description": "test doc"})];

        let options = IndexOptions {
            fields: vec!["name".to_string()],
            id_field: Some("name".to_string()),
            ..Default::default()
        };

        let index = Bm25Index::build(&docs, options);

        // Should serialize to JSON without error
        let json = serde_json::to_string(&index).unwrap();
        assert!(json.contains("jpx:bm25_index"));

        // Should deserialize back
        let restored: Bm25Index = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.doc_count, 1);
    }

    #[test]
    fn test_terms_list() {
        let docs = vec![
            json!("hello hello world"),
            json!("hello there"),
            json!("goodbye world"),
        ];

        let index = Bm25Index::build(&docs, IndexOptions::default());
        let terms = index.terms();

        assert!(!terms.is_empty());
        // Should be sorted by df descending: check every adjacent pair, not
        // just the first and last entries.
        assert!(terms.windows(2).all(|pair| pair[0].1 >= pair[1].1));
        // "hello" appears in 2 docs, "world" in 2 docs
        assert_eq!(terms[0].1, 2);
        assert_eq!(terms[1].1, 2);
    }
}