1use serde::{Deserialize, Serialize};
26use std::collections::HashMap;
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct Bm25Index {
31 #[serde(rename = "_type")]
33 pub type_marker: String,
34
35 #[serde(rename = "_version")]
37 pub version: String,
38
39 pub options: IndexOptions,
41
42 pub doc_count: usize,
44
45 pub avg_doc_length: f64,
47
48 pub docs: HashMap<String, DocInfo>,
50
51 pub terms: HashMap<String, TermInfo>,
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct IndexOptions {
58 #[serde(default)]
60 pub fields: Vec<String>,
61
62 #[serde(default)]
64 pub id_field: Option<String>,
65
66 #[serde(default = "default_true")]
68 pub lowercase: bool,
69
70 #[serde(default)]
72 pub stopwords: Vec<String>,
73
74 #[serde(default = "default_k1")]
76 pub k1: f64,
77
78 #[serde(default = "default_b")]
80 pub b: f64,
81}
82
/// Serde default provider for boolean options that are enabled unless the
/// serialized input says otherwise.
fn default_true() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
86
/// Serde default provider for the BM25 `k1` parameter.
fn default_k1() -> f64 {
    const K1: f64 = 1.2;
    K1
}
90
/// Serde default provider for the BM25 `b` parameter.
fn default_b() -> f64 {
    const B: f64 = 0.75;
    B
}
94
95impl Default for IndexOptions {
96 fn default() -> Self {
97 Self {
98 fields: Vec::new(),
99 id_field: None,
100 lowercase: true,
101 stopwords: Vec::new(),
102 k1: 1.2,
103 b: 0.75,
104 }
105 }
106}
107
108#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct DocInfo {
111 pub length: usize,
113
114 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
116 pub field_lengths: HashMap<String, usize>,
117
118 #[serde(default, skip_serializing_if = "Option::is_none")]
120 pub source: Option<serde_json::Value>,
121}
122
123#[derive(Debug, Clone, Serialize, Deserialize)]
125pub struct TermInfo {
126 pub df: usize,
128
129 pub postings: HashMap<String, usize>,
131}
132
133#[derive(Debug, Clone, Serialize, Deserialize)]
135pub struct SearchResult {
136 pub id: String,
138
139 pub score: f64,
141
142 pub matches: HashMap<String, Vec<String>>,
144
145 #[serde(default, skip_serializing_if = "Option::is_none")]
147 pub doc: Option<serde_json::Value>,
148}
149
150#[derive(Debug, Clone, Serialize, Deserialize)]
152pub struct ScoreExplanation {
153 pub id: String,
155
156 pub total_score: f64,
158
159 pub term_scores: Vec<TermScoreDetail>,
161}
162
163#[derive(Debug, Clone, Serialize, Deserialize)]
165pub struct TermScoreDetail {
166 pub term: String,
168
169 pub tf: usize,
171
172 pub df: usize,
174
175 pub idf: f64,
177
178 pub tf_component: f64,
180
181 pub score: f64,
183}
184
185impl Bm25Index {
186 pub fn new(options: IndexOptions) -> Self {
188 Self {
189 type_marker: "jpx:bm25_index".to_string(),
190 version: "1.0".to_string(),
191 options,
192 doc_count: 0,
193 avg_doc_length: 0.0,
194 docs: HashMap::new(),
195 terms: HashMap::new(),
196 }
197 }
198
199 pub fn build(docs: &[serde_json::Value], options: IndexOptions) -> Self {
201 let mut index = Self::new(options);
202 let mut total_length = 0usize;
203
204 for (i, doc) in docs.iter().enumerate() {
205 let doc_id = index.get_doc_id(doc, i);
206 let (tokens, field_lengths) = index.tokenize_doc(doc);
207 let doc_length = tokens.len();
208 total_length += doc_length;
209
210 index.docs.insert(
212 doc_id.clone(),
213 DocInfo {
214 length: doc_length,
215 field_lengths,
216 source: Some(doc.clone()),
217 },
218 );
219
220 let mut term_freqs: HashMap<String, usize> = HashMap::new();
222 for token in tokens {
223 *term_freqs.entry(token).or_insert(0) += 1;
224 }
225
226 for (term, freq) in term_freqs {
227 let term_info = index.terms.entry(term).or_insert(TermInfo {
228 df: 0,
229 postings: HashMap::new(),
230 });
231 term_info.df += 1;
232 term_info.postings.insert(doc_id.clone(), freq);
233 }
234
235 index.doc_count += 1;
236 }
237
238 if index.doc_count > 0 {
240 index.avg_doc_length = total_length as f64 / index.doc_count as f64;
241 }
242
243 index
244 }
245
246 fn get_doc_id(&self, doc: &serde_json::Value, index: usize) -> String {
248 if let Some(id) = self
249 .options
250 .id_field
251 .as_ref()
252 .and_then(|id_field| doc.get(id_field))
253 {
254 return match id {
255 serde_json::Value::String(s) => s.clone(),
256 serde_json::Value::Number(n) => n.to_string(),
257 _ => format!("{}", index),
258 };
259 }
260 format!("{}", index)
261 }
262
263 fn tokenize_doc(&self, doc: &serde_json::Value) -> (Vec<String>, HashMap<String, usize>) {
265 let mut tokens = Vec::new();
266 let mut field_lengths = HashMap::new();
267
268 if self.options.fields.is_empty() {
269 let text = self.extract_text(doc);
271 tokens = self.tokenize_text(&text);
272 } else {
273 for field in &self.options.fields {
275 if let Some(value) = doc.get(field) {
276 let text = self.extract_text(value);
277 let field_tokens = self.tokenize_text(&text);
278 field_lengths.insert(field.clone(), field_tokens.len());
279 tokens.extend(field_tokens);
280 }
281 }
282 }
283
284 (tokens, field_lengths)
285 }
286
287 fn extract_text(&self, value: &serde_json::Value) -> String {
289 match value {
290 serde_json::Value::String(s) => s.clone(),
291 serde_json::Value::Array(arr) => arr
292 .iter()
293 .filter_map(|v| {
294 if let serde_json::Value::String(s) = v {
295 Some(s.as_str())
296 } else {
297 None
298 }
299 })
300 .collect::<Vec<_>>()
301 .join(" "),
302 serde_json::Value::Object(obj) => obj
303 .values()
304 .map(|v| self.extract_text(v))
305 .collect::<Vec<_>>()
306 .join(" "),
307 _ => String::new(),
308 }
309 }
310
311 fn tokenize_text(&self, text: &str) -> Vec<String> {
313 let text = if self.options.lowercase {
314 text.to_lowercase()
315 } else {
316 text.to_string()
317 };
318
319 text.split(|c: char| !c.is_alphanumeric() && c != '_')
320 .filter(|s| !s.is_empty())
321 .filter(|s| !self.options.stopwords.contains(&s.to_string()))
322 .map(stem_simple)
323 .collect()
324 }
325}
326
/// Applies a small set of English plural-stripping rules to `term`.
///
/// Lengths are measured in bytes. Terms shorter than three bytes are returned
/// unchanged. Otherwise the first matching rule wins:
///
/// * `...ies` (more than 3 bytes) becomes `...y`;
/// * `...xes` / `...zes` (more than 3 bytes) lose the trailing `es`;
/// * `...sses` / `...shes` (more than 4 bytes) lose the trailing `es`;
/// * a single trailing `s` (not `ss`) is dropped.
fn stem_simple(term: &str) -> String {
    let byte_len = term.len();
    if byte_len < 3 {
        return term.to_owned();
    }

    if byte_len > 3 {
        if let Some(stem) = term.strip_suffix("ies") {
            return format!("{}y", stem);
        }
        if term.ends_with("xes") || term.ends_with("zes") {
            // The suffix is ASCII, so `byte_len - 2` is a char boundary.
            return term[..byte_len - 2].to_owned();
        }
    }

    if byte_len > 4 && (term.ends_with("sses") || term.ends_with("shes")) {
        return term[..byte_len - 2].to_owned();
    }

    match term.strip_suffix('s') {
        // Drop one trailing `s` unless that would split a double `ss`.
        Some(stem) if !stem.ends_with('s') => stem.to_owned(),
        _ => term.to_owned(),
    }
}
374
375impl Bm25Index {
376 fn idf(&self, term: &str) -> f64 {
378 let df = self.terms.get(term).map(|t| t.df as f64).unwrap_or(0.0);
379
380 if df == 0.0 {
381 return 0.0;
382 }
383
384 let n = self.doc_count as f64;
385 ((n - df + 0.5) / (df + 0.5) + 1.0).ln()
387 }
388
389 fn score_doc(&self, doc_id: &str, query_terms: &[String]) -> f64 {
391 let doc_info = match self.docs.get(doc_id) {
392 Some(info) => info,
393 None => return 0.0,
394 };
395
396 let doc_length = doc_info.length as f64;
397 let k1 = self.options.k1;
398 let b = self.options.b;
399 let avgdl = self.avg_doc_length;
400
401 let mut score = 0.0;
402
403 for term in query_terms {
404 let idf = self.idf(term);
405 let tf = self
406 .terms
407 .get(term)
408 .and_then(|t| t.postings.get(doc_id))
409 .copied()
410 .unwrap_or(0) as f64;
411
412 if tf > 0.0 {
413 let numerator = tf * (k1 + 1.0);
415 let denominator = tf + k1 * (1.0 - b + b * doc_length / avgdl);
416 score += idf * numerator / denominator;
417 }
418 }
419
420 score
421 }
422
423 pub fn search(&self, query: &str, top_k: usize) -> Vec<SearchResult> {
425 let query_terms = self.tokenize_text(query);
426
427 if query_terms.is_empty() {
428 return Vec::new();
429 }
430
431 let mut candidates: HashMap<String, f64> = HashMap::new();
433
434 for term in &query_terms {
435 if let Some(term_info) = self.terms.get(term) {
436 for doc_id in term_info.postings.keys() {
437 candidates.entry(doc_id.clone()).or_insert(0.0);
438 }
439 }
440 }
441
442 let mut results: Vec<SearchResult> = candidates
444 .keys()
445 .map(|doc_id| {
446 let score = self.score_doc(doc_id, &query_terms);
447 let matches = self.get_matches(doc_id, &query_terms);
448 let doc = self.docs.get(doc_id).and_then(|d| d.source.clone());
449
450 SearchResult {
451 id: doc_id.clone(),
452 score,
453 matches,
454 doc,
455 }
456 })
457 .filter(|r| r.score > 0.0)
458 .collect();
459
460 results.sort_by(|a, b| {
462 b.score
463 .partial_cmp(&a.score)
464 .unwrap_or(std::cmp::Ordering::Equal)
465 });
466
467 results.truncate(top_k);
469 results
470 }
471
472 fn get_matches(&self, doc_id: &str, query_terms: &[String]) -> HashMap<String, Vec<String>> {
474 let mut matches: HashMap<String, Vec<String>> = HashMap::new();
475
476 for term in query_terms {
477 if self
478 .terms
479 .get(term)
480 .is_some_and(|term_info| term_info.postings.contains_key(doc_id))
481 {
482 matches
484 .entry("_matched".to_string())
485 .or_default()
486 .push(term.clone());
487 }
488 }
489
490 matches
491 }
492
493 pub fn explain(&self, query: &str, doc_id: &str) -> Option<ScoreExplanation> {
495 let doc_info = self.docs.get(doc_id)?;
496 let query_terms = self.tokenize_text(query);
497
498 let doc_length = doc_info.length as f64;
499 let k1 = self.options.k1;
500 let b = self.options.b;
501 let avgdl = self.avg_doc_length;
502
503 let mut total_score = 0.0;
504 let mut term_scores = Vec::new();
505
506 for term in &query_terms {
507 let idf = self.idf(term);
508 let df = self.terms.get(term).map(|t| t.df).unwrap_or(0);
509 let tf = self
510 .terms
511 .get(term)
512 .and_then(|t| t.postings.get(doc_id))
513 .copied()
514 .unwrap_or(0);
515
516 let tf_f64 = tf as f64;
517 let tf_component = if tf > 0 {
518 let numerator = tf_f64 * (k1 + 1.0);
519 let denominator = tf_f64 + k1 * (1.0 - b + b * doc_length / avgdl);
520 numerator / denominator
521 } else {
522 0.0
523 };
524
525 let score = idf * tf_component;
526 total_score += score;
527
528 term_scores.push(TermScoreDetail {
529 term: term.clone(),
530 tf,
531 df,
532 idf,
533 tf_component,
534 score,
535 });
536 }
537
538 Some(ScoreExplanation {
539 id: doc_id.to_string(),
540 total_score,
541 term_scores,
542 })
543 }
544
545 pub fn terms(&self) -> Vec<(String, usize)> {
547 let mut terms: Vec<_> = self
548 .terms
549 .iter()
550 .map(|(t, info)| (t.clone(), info.df))
551 .collect();
552 terms.sort_by(|a, b| b.1.cmp(&a.1)); terms
554 }
555
556 pub fn similar(&self, doc_id: &str, top_k: usize) -> Vec<SearchResult> {
558 let doc_terms: Vec<String> = self
559 .terms
560 .iter()
561 .filter(|(_, info)| info.postings.contains_key(doc_id))
562 .map(|(term, _)| term.clone())
563 .collect();
564
565 if doc_terms.is_empty() {
566 return Vec::new();
567 }
568
569 let mut results: Vec<SearchResult> = self
571 .docs
572 .keys()
573 .filter(|id| *id != doc_id)
574 .map(|id| {
575 let score = self.score_doc(id, &doc_terms);
576 let matches = self.get_matches(id, &doc_terms);
577 let doc = self.docs.get(id).and_then(|d| d.source.clone());
578
579 SearchResult {
580 id: id.clone(),
581 score,
582 matches,
583 doc,
584 }
585 })
586 .filter(|r| r.score > 0.0)
587 .collect();
588
589 results.sort_by(|a, b| {
590 b.score
591 .partial_cmp(&a.score)
592 .unwrap_or(std::cmp::Ordering::Equal)
593 });
594 results.truncate(top_k);
595 results
596 }
597}
598
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{json, Value};

    /// Field set shared by most field-based tests.
    const NAME_AND_DESCRIPTION: [&str; 2] = ["name", "description"];

    /// Builds an index over `fields`, using the `name` field as document id
    /// and defaults for everything else.
    fn index_by_name(docs: &[Value], fields: &[&str]) -> Bm25Index {
        let options = IndexOptions {
            fields: fields.iter().map(|field| field.to_string()).collect(),
            id_field: Some("name".to_string()),
            ..Default::default()
        };
        Bm25Index::build(docs, options)
    }

    /// Plain string documents are indexed with correct document frequencies.
    #[test]
    fn test_build_index_simple() {
        let docs = [
            json!("hello world"),
            json!("hello there"),
            json!("goodbye world"),
        ];

        let index = Bm25Index::build(&docs, IndexOptions::default());

        assert_eq!(index.doc_count, 3);
        for term in ["hello", "world"] {
            assert!(index.terms.contains_key(term));
        }
        assert_eq!(index.terms.get("hello").unwrap().df, 2);
        assert_eq!(index.terms.get("world").unwrap().df, 2);
    }

    /// Field-restricted indexing uses the id field and indexes field text.
    #[test]
    fn test_build_index_with_fields() {
        let docs = [
            json!({"name": "create_cluster", "description": "Create a new cluster"}),
            json!({"name": "delete_cluster", "description": "Delete an existing cluster"}),
            json!({"name": "list_backups", "description": "List all backups"}),
        ];

        let index = index_by_name(&docs, &NAME_AND_DESCRIPTION);

        assert_eq!(index.doc_count, 3);
        assert!(index.docs.contains_key("create_cluster"));
        assert!(index.docs.contains_key("delete_cluster"));
        assert!(index.terms.contains_key("cluster"));
        assert_eq!(index.terms.get("cluster").unwrap().df, 2);
    }

    /// A single-term query returns exactly the documents containing it.
    #[test]
    fn test_search_basic() {
        let docs = [
            json!({"name": "create_cluster", "description": "Create a new Redis cluster"}),
            json!({"name": "delete_cluster", "description": "Delete an existing cluster"}),
            json!({"name": "create_backup", "description": "Create a backup of data"}),
        ];

        let results = index_by_name(&docs, &NAME_AND_DESCRIPTION).search("cluster", 10);

        assert_eq!(results.len(), 2);
        let ids: Vec<&str> = results.iter().map(|hit| hit.id.as_str()).collect();
        assert!(ids.contains(&"create_cluster"));
        assert!(ids.contains(&"delete_cluster"));
    }

    /// The best-scoring document for the query is ranked first.
    #[test]
    fn test_search_ranking() {
        let docs = [
            json!({"name": "cluster_manager", "description": "Manage cluster operations"}),
            json!({"name": "backup_tool", "description": "Backup tool for cluster data"}),
            json!({"name": "monitor", "description": "Monitor system health"}),
        ];

        let results = index_by_name(&docs, &NAME_AND_DESCRIPTION).search("cluster", 10);

        assert!(!results.is_empty());
        assert_eq!(results[0].id, "cluster_manager");
    }

    /// A document matching both query terms ranks first.
    #[test]
    fn test_search_multi_term() {
        let docs = [
            json!({"name": "create_backup", "description": "Create a backup in a region"}),
            json!({"name": "restore_backup", "description": "Restore from backup"}),
            json!({"name": "list_regions", "description": "List available regions"}),
        ];

        let results = index_by_name(&docs, &NAME_AND_DESCRIPTION).search("backup region", 10);

        assert!(!results.is_empty());
        assert_eq!(results[0].id, "create_backup");
    }

    /// An explanation reports a positive score with per-term details.
    #[test]
    fn test_explain() {
        let docs = [json!({"name": "test", "description": "test document with terms"})];

        let index = index_by_name(&docs, &NAME_AND_DESCRIPTION);
        let explanation = index.explain("test", "test").unwrap();

        assert_eq!(explanation.id, "test");
        assert!(explanation.total_score > 0.0);
        assert!(!explanation.term_scores.is_empty());
    }

    /// The document sharing the most terms is reported as most similar.
    #[test]
    fn test_similar() {
        let docs = [
            json!({"name": "create_cluster", "description": "Create a new kubernetes cluster"}),
            json!({"name": "delete_cluster", "description": "Delete an existing kubernetes cluster"}),
            json!({"name": "upload_file", "description": "Upload a file to storage"}),
        ];

        let similar = index_by_name(&docs, &NAME_AND_DESCRIPTION).similar("create_cluster", 10);

        assert!(!similar.is_empty());
        assert_eq!(similar[0].id, "delete_cluster");
    }

    /// Configured stopwords never reach the term dictionary.
    #[test]
    fn test_stopwords() {
        let docs = [json!("the quick brown fox"), json!("the lazy dog")];
        let options = IndexOptions {
            stopwords: vec!["the".to_string()],
            ..Default::default()
        };

        let index = Bm25Index::build(&docs, options);

        assert!(!index.terms.contains_key("the"));
        assert!(index.terms.contains_key("quick"));
    }

    /// With default options matching ignores letter case.
    #[test]
    fn test_case_insensitive() {
        let docs = [json!("Hello World"), json!("HELLO THERE")];

        let hits = Bm25Index::build(&docs, IndexOptions::default()).search("hello", 10);

        assert_eq!(hits.len(), 2);
    }

    /// An index survives a JSON round trip.
    #[test]
    fn test_json_serialization() {
        let docs = [json!({"name": "test", "description": "test doc"})];
        let index = index_by_name(&docs, &["name"]);

        let serialized = serde_json::to_string(&index).unwrap();
        assert!(serialized.contains("jpx:bm25_index"));

        let restored: Bm25Index = serde_json::from_str(&serialized).unwrap();
        assert_eq!(restored.doc_count, 1);
    }

    /// The term listing is ordered by descending document frequency.
    #[test]
    fn test_terms_list() {
        let docs = [
            json!("hello hello world"),
            json!("hello there"),
            json!("goodbye world"),
        ];

        let terms = Bm25Index::build(&docs, IndexOptions::default()).terms();

        assert!(!terms.is_empty());
        assert!(terms[0].1 >= terms.last().unwrap().1);
    }
}