use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// A serializable BM25 index over a collection of JSON documents.
///
/// Holds the options the index was built with, per-document length
/// information, and a term table mapping each indexed term to its document
/// frequency and postings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Bm25Index {
/// Type tag, serialized as `_type`. `Bm25Index::new` sets it to `"jpx:bm25_index"`.
#[serde(rename = "_type")]
pub type_marker: String,
/// Format version, serialized as `_version`. `Bm25Index::new` sets it to `"1.0"`.
#[serde(rename = "_version")]
pub version: String,
/// Tokenization and scoring options, applied to documents and queries alike.
pub options: IndexOptions,
/// Number of documents counted while building the index.
pub doc_count: usize,
/// Mean number of tokens per document; stays `0.0` for an empty index.
pub avg_doc_length: f64,
/// Per-document information keyed by document id.
pub docs: HashMap<String, DocInfo>,
/// Per-term statistics keyed by the (stemmed) term.
pub terms: HashMap<String, TermInfo>,
}
/// Options controlling how documents and queries are tokenized and scored.
///
/// Every field has a serde default, so a partially specified (or empty)
/// options object deserializes successfully.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexOptions {
/// Top-level document fields to index. When empty, text is extracted from
/// the whole document instead.
#[serde(default)]
pub fields: Vec<String>,
/// Field whose string or number value becomes the document id. When unset,
/// missing, or of another type, the document's position in the input is used.
#[serde(default)]
pub id_field: Option<String>,
/// Lowercase text before splitting it into tokens. Defaults to `true`.
#[serde(default = "default_true")]
pub lowercase: bool,
/// Tokens dropped before stemming. Compared exactly against each token
/// after the optional lowercasing step.
#[serde(default)]
pub stopwords: Vec<String>,
/// The `k1` parameter of the BM25 scoring formula. Defaults to `1.2`.
#[serde(default = "default_k1")]
pub k1: f64,
/// The `b` parameter of the BM25 scoring formula, weighting the
/// document-length / average-length ratio. Defaults to `0.75`.
#[serde(default = "default_b")]
pub b: f64,
}
/// Serde default for `IndexOptions::lowercase`.
fn default_true() -> bool {
    true
}

/// Serde default for `IndexOptions::k1`.
fn default_k1() -> f64 {
    1.2
}

/// Serde default for `IndexOptions::b`.
fn default_b() -> f64 {
    0.75
}

/// Programmatic defaults for [`IndexOptions`].
///
/// These are the same values serde falls back to when a field is missing
/// from the serialized form.
impl Default for IndexOptions {
    /// Returns options that index the whole document, use positional ids,
    /// lowercase text, filter no stopwords, and use `k1 = 1.2`, `b = 0.75`.
    fn default() -> Self {
        Self {
            fields: Vec::new(),
            id_field: None,
            // Reuse the serde default helpers so `Default::default()` and a
            // deserialized empty options object can never disagree.
            lowercase: default_true(),
            stopwords: Vec::new(),
            k1: default_k1(),
            b: default_b(),
        }
    }
}
/// Per-document data stored in the index.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocInfo {
/// Total number of tokens produced for the document.
pub length: usize,
/// Token count per indexed field. Empty when no fields are configured;
/// omitted from the serialized form when empty.
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub field_lengths: HashMap<String, usize>,
/// Copy of the original document, returned with search results.
/// Omitted from the serialized form when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub source: Option<serde_json::Value>,
}
/// Index statistics for a single term.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TermInfo {
/// Document frequency: how many documents contain the term.
pub df: usize,
/// Term frequency per document, keyed by document id.
pub postings: HashMap<String, usize>,
}
/// One ranked hit returned by `Bm25Index::search` or `Bm25Index::similar`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
/// Id of the matching document.
pub id: String,
/// BM25 score of the document for the query terms.
pub score: f64,
/// Matched query terms. The index currently writes a single key,
/// `"_matched"`, listing the query terms found in the document.
pub matches: HashMap<String, Vec<String>>,
/// The stored source document, if the index kept one.
/// Omitted from the serialized form when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub doc: Option<serde_json::Value>,
}
/// Per-term breakdown of a document's score for a query, as produced by
/// `Bm25Index::explain`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoreExplanation {
/// Id of the explained document.
pub id: String,
/// Sum of the `score` of every entry in `term_scores`.
pub total_score: f64,
/// One entry per query term, in query order.
pub term_scores: Vec<TermScoreDetail>,
}
/// Contribution of a single query term to a document's score.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TermScoreDetail {
/// The query term after tokenization and stemming.
pub term: String,
/// Occurrences of the term in the document (`0` when absent).
pub tf: usize,
/// Number of documents containing the term (`0` when unknown to the index).
pub df: usize,
/// Inverse document frequency of the term.
pub idf: f64,
/// Length-normalized term-frequency factor; `0.0` when `tf` is `0`.
pub tf_component: f64,
/// `idf * tf_component`.
pub score: f64,
}
/// Construction and tokenization.
impl Bm25Index {
    /// Creates an empty index that will tokenize and score with `options`.
    pub fn new(options: IndexOptions) -> Self {
        Self {
            type_marker: "jpx:bm25_index".to_string(),
            version: "1.0".to_string(),
            options,
            doc_count: 0,
            avg_doc_length: 0.0,
            docs: HashMap::new(),
            terms: HashMap::new(),
        }
    }

    /// Builds an index from `docs` using `options`.
    ///
    /// Each document is assigned an id (see `IndexOptions::id_field`),
    /// tokenized, and stored together with a copy of its source JSON.
    ///
    /// If several documents resolve to the same id, the last one wins: the
    /// earlier document's postings and length are removed so that
    /// `doc_count`, `avg_doc_length` and every term's `df` stay consistent
    /// with the contents of `docs`.
    pub fn build(docs: &[serde_json::Value], options: IndexOptions) -> Self {
        let mut index = Self::new(options);
        let mut total_length = 0usize;

        for (position, doc) in docs.iter().enumerate() {
            let doc_id = index.get_doc_id(doc, position);
            let (tokens, field_lengths) = index.tokenize_doc(doc);
            let doc_length = tokens.len();

            let replaced = index.docs.insert(
                doc_id.clone(),
                DocInfo {
                    length: doc_length,
                    field_lengths,
                    source: Some(doc.clone()),
                },
            );
            match replaced {
                // Duplicate id: undo the earlier document's contribution so
                // it is not counted twice.
                Some(previous) => {
                    total_length -= previous.length;
                    index.remove_postings(&doc_id);
                }
                None => index.doc_count += 1,
            }
            total_length += doc_length;

            let mut term_freqs: HashMap<String, usize> = HashMap::new();
            for token in tokens {
                *term_freqs.entry(token).or_insert(0) += 1;
            }
            for (term, freq) in term_freqs {
                let term_info = index.terms.entry(term).or_insert_with(|| TermInfo {
                    df: 0,
                    postings: HashMap::new(),
                });
                term_info.df += 1;
                term_info.postings.insert(doc_id.clone(), freq);
            }
        }

        if index.doc_count > 0 {
            index.avg_doc_length = total_length as f64 / index.doc_count as f64;
        }
        index
    }

    /// Removes `doc_id` from every posting list, decrementing the affected
    /// terms' `df` and dropping terms that no longer occur in any document.
    fn remove_postings(&mut self, doc_id: &str) {
        self.terms.retain(|_, info| {
            if info.postings.remove(doc_id).is_some() {
                info.df = info.df.saturating_sub(1);
            }
            !info.postings.is_empty()
        });
    }

    /// Returns the id for `doc`: the string or number stored under the
    /// configured id field, or the document's input position `index` when
    /// the field is unset, missing, or of any other JSON type.
    fn get_doc_id(&self, doc: &serde_json::Value, index: usize) -> String {
        let explicit = self
            .options
            .id_field
            .as_ref()
            .and_then(|id_field| doc.get(id_field));
        match explicit {
            Some(serde_json::Value::String(s)) => s.clone(),
            Some(serde_json::Value::Number(n)) => n.to_string(),
            _ => index.to_string(),
        }
    }

    /// Tokenizes `doc`, returning all tokens plus the token count per field.
    ///
    /// With no configured fields the whole document is tokenized and the
    /// per-field map is empty. Otherwise only configured fields that are
    /// present in the document contribute.
    fn tokenize_doc(&self, doc: &serde_json::Value) -> (Vec<String>, HashMap<String, usize>) {
        if self.options.fields.is_empty() {
            let text = self.extract_text(doc);
            return (self.tokenize_text(&text), HashMap::new());
        }

        let mut tokens = Vec::new();
        let mut field_lengths = HashMap::new();
        for field in &self.options.fields {
            if let Some(value) = doc.get(field) {
                let field_tokens = self.tokenize_text(&self.extract_text(value));
                field_lengths.insert(field.clone(), field_tokens.len());
                tokens.extend(field_tokens);
            }
        }
        (tokens, field_lengths)
    }

    /// Flattens a JSON value into space-separated text.
    ///
    /// Strings are taken as-is, arrays contribute only their direct string
    /// elements, objects recurse into every value, and all other values
    /// yield an empty string.
    fn extract_text(&self, value: &serde_json::Value) -> String {
        match value {
            serde_json::Value::String(s) => s.clone(),
            serde_json::Value::Array(arr) => arr
                .iter()
                .filter_map(|v| match v {
                    serde_json::Value::String(s) => Some(s.as_str()),
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join(" "),
            serde_json::Value::Object(obj) => obj
                .values()
                .map(|v| self.extract_text(v))
                .collect::<Vec<_>>()
                .join(" "),
            _ => String::new(),
        }
    }

    /// Splits `text` into stemmed tokens.
    ///
    /// The text is optionally lowercased, split on every character that is
    /// neither alphanumeric nor `_`, stripped of empty pieces and stopwords,
    /// and each remaining token is passed through `stem_simple`.
    fn tokenize_text(&self, text: &str) -> Vec<String> {
        // Borrow the input when no lowercasing is needed instead of copying it.
        let normalized: std::borrow::Cow<str> = if self.options.lowercase {
            std::borrow::Cow::Owned(text.to_lowercase())
        } else {
            std::borrow::Cow::Borrowed(text)
        };
        let stopwords = &self.options.stopwords;
        normalized
            .split(|c: char| !c.is_alphanumeric() && c != '_')
            .filter(|token| !token.is_empty())
            // Compare as `&str` so no `String` is allocated per token.
            .filter(|token| !stopwords.iter().any(|stop| stop.as_str() == *token))
            .map(stem_simple)
            .collect()
    }
}
/// Applies a small set of English plural-stripping rules to `term`.
///
/// Terms shorter than three bytes are returned unchanged. Otherwise the
/// first matching rule wins:
/// - `…ies` becomes `…y` (terms longer than three bytes),
/// - `…xes` / `…zes` lose the trailing `es` (terms longer than three bytes),
/// - `…sses` / `…shes` lose the trailing `es` (terms longer than four bytes),
/// - a single trailing `s` (not `ss`) is dropped.
fn stem_simple(term: &str) -> String {
    let byte_len = term.len();
    if byte_len < 3 {
        return term.to_owned();
    }

    // Every suffix below is ASCII, so cutting it off stays on a char boundary.
    let without_es = || term[..byte_len - 2].to_owned();

    if byte_len > 3 {
        if let Some(base) = term.strip_suffix("ies") {
            return [base, "y"].concat();
        }
        if term.ends_with("xes") || term.ends_with("zes") {
            return without_es();
        }
    }
    if byte_len > 4 && (term.ends_with("sses") || term.ends_with("shes")) {
        return without_es();
    }

    match term.strip_suffix('s') {
        Some(base) if !base.ends_with('s') => base.to_owned(),
        _ => term.to_owned(),
    }
}
/// Scoring and querying.
impl Bm25Index {
    /// Inverse document frequency of `term`:
    /// `ln((N - df + 0.5) / (df + 0.5) + 1)`, or `0.0` when the term is
    /// unknown to the index or has a document frequency of zero.
    fn idf(&self, term: &str) -> f64 {
        let df = match self.terms.get(term) {
            Some(info) if info.df > 0 => info.df as f64,
            _ => return 0.0,
        };
        let n = self.doc_count as f64;
        ((n - df + 0.5) / (df + 0.5) + 1.0).ln()
    }

    /// Number of times `term` occurs in document `doc_id` (`0` when absent).
    fn term_frequency(&self, term: &str, doc_id: &str) -> usize {
        self.terms
            .get(term)
            .and_then(|info| info.postings.get(doc_id))
            .copied()
            .unwrap_or(0)
    }

    /// Length-normalized term-frequency factor of the BM25 formula:
    /// `tf * (k1 + 1) / (tf + k1 * (1 - b + b * doc_length / avgdl))`.
    ///
    /// Returns `0.0` when `tf` is zero. Shared by `score_doc` and `explain`
    /// so both always report exactly the same numbers.
    fn tf_component(&self, tf: usize, doc_length: usize) -> f64 {
        if tf == 0 {
            return 0.0;
        }
        let tf = tf as f64;
        let k1 = self.options.k1;
        let b = self.options.b;
        let avgdl = self.avg_doc_length;
        // A hand-built or deserialized index may carry an average length of
        // zero; treat every document as average-length then instead of
        // dividing by zero (which could produce NaN).
        let length_norm = if avgdl > 0.0 {
            b * doc_length as f64 / avgdl
        } else {
            b
        };
        tf * (k1 + 1.0) / (tf + k1 * (1.0 - b + length_norm))
    }

    /// BM25 score of document `doc_id` for `query_terms`; `0.0` when the
    /// document is not in the index.
    fn score_doc(&self, doc_id: &str, query_terms: &[String]) -> f64 {
        let doc_length = match self.docs.get(doc_id) {
            Some(info) => info.length,
            None => return 0.0,
        };
        let mut score = 0.0;
        for term in query_terms {
            let tf = self.term_frequency(term, doc_id);
            if tf > 0 {
                score += self.idf(term) * self.tf_component(tf, doc_length);
            }
        }
        score
    }

    /// Ids of all documents that appear in the postings of any of `query_terms`.
    fn candidate_ids(&self, query_terms: &[String]) -> std::collections::HashSet<&str> {
        let mut candidates = std::collections::HashSet::new();
        for term in query_terms {
            if let Some(info) = self.terms.get(term) {
                candidates.extend(info.postings.keys().map(String::as_str));
            }
        }
        candidates
    }

    /// Scores `candidates` against `query_terms` and returns the `top_k`
    /// results with a positive score, best first.
    ///
    /// Equal scores are ordered by document id so the output does not depend
    /// on hash-map iteration order.
    fn rank<'a, I>(&self, candidates: I, query_terms: &[String], top_k: usize) -> Vec<SearchResult>
    where
        I: IntoIterator<Item = &'a str>,
    {
        let mut results: Vec<SearchResult> = candidates
            .into_iter()
            .filter_map(|doc_id| {
                let score = self.score_doc(doc_id, query_terms);
                // Only build the (cloning) result for documents that are kept.
                if score > 0.0 {
                    Some(SearchResult {
                        id: doc_id.to_string(),
                        score,
                        matches: self.get_matches(doc_id, query_terms),
                        doc: self.docs.get(doc_id).and_then(|d| d.source.clone()),
                    })
                } else {
                    None
                }
            })
            .collect();
        results.sort_by(|a, b| {
            b.score
                .total_cmp(&a.score)
                .then_with(|| a.id.cmp(&b.id))
        });
        results.truncate(top_k);
        results
    }

    /// Searches the index for `query` and returns at most `top_k` documents
    /// ordered by descending BM25 score (ties ordered by id).
    ///
    /// The query is tokenized exactly like indexed text. An empty query, or
    /// one consisting only of stopwords, yields no results.
    pub fn search(&self, query: &str, top_k: usize) -> Vec<SearchResult> {
        let query_terms = self.tokenize_text(query);
        if query_terms.is_empty() {
            return Vec::new();
        }
        let candidates = self.candidate_ids(&query_terms);
        self.rank(candidates, &query_terms, top_k)
    }

    /// Lists which of `query_terms` occur in document `doc_id`.
    ///
    /// The returned map is empty when nothing matches; otherwise it has a
    /// single `"_matched"` entry with the matching terms in query order.
    fn get_matches(&self, doc_id: &str, query_terms: &[String]) -> HashMap<String, Vec<String>> {
        let matched: Vec<String> = query_terms
            .iter()
            .filter(|term| {
                self.terms
                    .get(term.as_str())
                    .is_some_and(|info| info.postings.contains_key(doc_id))
            })
            .cloned()
            .collect();

        let mut matches = HashMap::new();
        if !matched.is_empty() {
            matches.insert("_matched".to_string(), matched);
        }
        matches
    }

    /// Explains how document `doc_id` scores for `query`, term by term.
    ///
    /// Returns `None` when the document is not in the index. Terms that do
    /// not occur in the document are still listed, with a score of `0.0`.
    pub fn explain(&self, query: &str, doc_id: &str) -> Option<ScoreExplanation> {
        let doc_length = self.docs.get(doc_id)?.length;
        let mut total_score = 0.0;
        let mut term_scores = Vec::new();

        for term in self.tokenize_text(query) {
            let df = self.terms.get(&term).map_or(0, |info| info.df);
            let tf = self.term_frequency(&term, doc_id);
            let idf = self.idf(&term);
            let tf_component = self.tf_component(tf, doc_length);
            let score = idf * tf_component;
            total_score += score;
            term_scores.push(TermScoreDetail {
                term,
                tf,
                df,
                idf,
                tf_component,
                score,
            });
        }

        Some(ScoreExplanation {
            id: doc_id.to_string(),
            total_score,
            term_scores,
        })
    }

    /// Returns every indexed term with its document frequency, most frequent
    /// first; terms with equal frequency are ordered alphabetically.
    pub fn terms(&self) -> Vec<(String, usize)> {
        let mut terms: Vec<_> = self
            .terms
            .iter()
            .map(|(term, info)| (term.clone(), info.df))
            .collect();
        terms.sort_unstable_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
        terms
    }

    /// Finds up to `top_k` documents most similar to document `doc_id`, by
    /// scoring every other document against the full set of terms that
    /// `doc_id` contains.
    ///
    /// Returns an empty list when `doc_id` has no postings (including when
    /// it is not in the index).
    pub fn similar(&self, doc_id: &str, top_k: usize) -> Vec<SearchResult> {
        let mut doc_terms: Vec<String> = self
            .terms
            .iter()
            .filter(|(_, info)| info.postings.contains_key(doc_id))
            .map(|(term, _)| term.clone())
            .collect();
        if doc_terms.is_empty() {
            return Vec::new();
        }
        // The term table is a hash map; sort so that score summation order
        // and the `_matched` lists are identical from run to run.
        doc_terms.sort_unstable();

        // Only documents sharing at least one term can score above zero.
        let mut candidates = self.candidate_ids(&doc_terms);
        candidates.remove(doc_id);
        self.rank(candidates, &doc_terms, top_k)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Options indexing `name` and `description`, with `name` as the id.
    fn named_options() -> IndexOptions {
        IndexOptions {
            fields: vec!["name".to_string(), "description".to_string()],
            id_field: Some("name".to_string()),
            ..Default::default()
        }
    }

    /// Builds an index over `docs` with [`named_options`].
    fn named_index(docs: &[serde_json::Value]) -> Bm25Index {
        Bm25Index::build(docs, named_options())
    }

    /// Builds an index over `docs` with the default options.
    fn plain_index(docs: &[serde_json::Value]) -> Bm25Index {
        Bm25Index::build(docs, IndexOptions::default())
    }

    // Index construction

    #[test]
    fn test_build_index_simple() {
        let index = plain_index(&[
            json!("hello world"),
            json!("hello there"),
            json!("goodbye world"),
        ]);

        assert_eq!(index.doc_count, 3);
        for term in ["hello", "world"] {
            assert!(index.terms.contains_key(term));
        }
        assert_eq!(index.terms.get("hello").unwrap().df, 2);
        assert_eq!(index.terms.get("world").unwrap().df, 2);
    }

    #[test]
    fn test_build_index_with_fields() {
        let index = named_index(&[
            json!({"name": "create_cluster", "description": "Create a new cluster"}),
            json!({"name": "delete_cluster", "description": "Delete an existing cluster"}),
            json!({"name": "list_backups", "description": "List all backups"}),
        ]);

        assert_eq!(index.doc_count, 3);
        assert!(index.docs.contains_key("create_cluster"));
        assert!(index.docs.contains_key("delete_cluster"));
        assert!(index.terms.contains_key("cluster"));
        assert_eq!(index.terms.get("cluster").unwrap().df, 2);
    }

    // Search

    #[test]
    fn test_search_basic() {
        let index = named_index(&[
            json!({"name": "create_cluster", "description": "Create a new Redis cluster"}),
            json!({"name": "delete_cluster", "description": "Delete an existing cluster"}),
            json!({"name": "create_backup", "description": "Create a backup of data"}),
        ]);

        let results = index.search("cluster", 10);
        assert_eq!(results.len(), 2);

        let ids: Vec<_> = results.iter().map(|r| r.id.as_str()).collect();
        assert!(ids.contains(&"create_cluster"));
        assert!(ids.contains(&"delete_cluster"));
    }

    #[test]
    fn test_search_ranking() {
        let index = named_index(&[
            json!({"name": "cluster_manager", "description": "Manage cluster operations"}),
            json!({"name": "backup_tool", "description": "Backup tool for cluster data"}),
            json!({"name": "monitor", "description": "Monitor system health"}),
        ]);

        let results = index.search("cluster", 10);
        assert!(!results.is_empty());
        assert_eq!(results[0].id, "cluster_manager");
    }

    #[test]
    fn test_search_multi_term() {
        let index = named_index(&[
            json!({"name": "create_backup", "description": "Create a backup in a region"}),
            json!({"name": "restore_backup", "description": "Restore from backup"}),
            json!({"name": "list_regions", "description": "List available regions"}),
        ]);

        let results = index.search("backup region", 10);
        assert!(!results.is_empty());
        assert_eq!(results[0].id, "create_backup");
    }

    #[test]
    fn test_explain() {
        let index =
            named_index(&[json!({"name": "test", "description": "test document with terms"})]);

        let explanation = index.explain("test", "test").unwrap();
        assert_eq!(explanation.id, "test");
        assert!(explanation.total_score > 0.0);
        assert!(!explanation.term_scores.is_empty());
    }

    #[test]
    fn test_similar() {
        let index = named_index(&[
            json!({"name": "create_cluster", "description": "Create a new kubernetes cluster"}),
            json!({"name": "delete_cluster", "description": "Delete an existing kubernetes cluster"}),
            json!({"name": "upload_file", "description": "Upload a file to storage"}),
        ]);

        let similar = index.similar("create_cluster", 10);
        assert!(!similar.is_empty());
        assert_eq!(similar[0].id, "delete_cluster");
    }

    // Tokenization options

    #[test]
    fn test_stopwords() {
        let options = IndexOptions {
            stopwords: vec!["the".to_string()],
            ..Default::default()
        };
        let index = Bm25Index::build(
            &[json!("the quick brown fox"), json!("the lazy dog")],
            options,
        );

        assert!(!index.terms.contains_key("the"));
        assert!(index.terms.contains_key("quick"));
    }

    #[test]
    fn test_case_insensitive() {
        let index = plain_index(&[json!("Hello World"), json!("HELLO THERE")]);
        assert_eq!(index.search("hello", 10).len(), 2);
    }

    // Serialization and listing

    #[test]
    fn test_json_serialization() {
        let options = IndexOptions {
            fields: vec!["name".to_string()],
            id_field: Some("name".to_string()),
            ..Default::default()
        };
        let index = Bm25Index::build(&[json!({"name": "test", "description": "test doc"})], options);

        let json = serde_json::to_string(&index).unwrap();
        assert!(json.contains("jpx:bm25_index"));

        let restored: Bm25Index = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.doc_count, 1);
    }

    #[test]
    fn test_terms_list() {
        let index = plain_index(&[
            json!("hello hello world"),
            json!("hello there"),
            json!("goodbye world"),
        ]);

        let terms = index.terms();
        assert!(!terms.is_empty());
        assert!(terms[0].1 >= terms.last().unwrap().1);
    }

    // Edge cases

    #[test]
    fn test_empty_index_search() {
        let index = Bm25Index::new(IndexOptions::default());
        assert!(index.search("anything", 10).is_empty());
    }

    #[test]
    fn test_empty_query_search() {
        let index = plain_index(&[json!("hello world"), json!("goodbye world")]);
        assert!(index.search("", 10).is_empty());
    }

    #[test]
    fn test_single_document_index() {
        let index = plain_index(&[json!("the rust programming language")]);
        assert_eq!(index.doc_count, 1);

        let results = index.search("rust", 10);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "0");
        assert!(results[0].score > 0.0);
    }

    // Stemming

    #[test]
    fn test_stem_simple_plural_s() {
        assert_eq!(stem_simple("databases"), "database");
    }

    #[test]
    fn test_stem_simple_plural_ies() {
        assert_eq!(stem_simple("queries"), "query");
    }

    #[test]
    fn test_stem_simple_plural_xes() {
        assert_eq!(stem_simple("boxes"), "box");
    }

    #[test]
    fn test_stem_simple_short_word() {
        assert_eq!(stem_simple("is"), "is");
    }

    #[test]
    fn test_stem_simple_no_change() {
        assert_eq!(stem_simple("data"), "data");
    }

    // Scoring internals

    #[test]
    fn test_idf_zero_for_unknown_term() {
        let index = plain_index(&[json!("hello world"), json!("goodbye world")]);
        assert_eq!(index.idf("nonexistent_term"), 0.0);
    }

    #[test]
    fn test_similar_nonexistent_doc() {
        let index = named_index(&[
            json!({"name": "alpha", "description": "first document about rust"}),
            json!({"name": "beta", "description": "second document about python"}),
        ]);

        assert!(index.similar("nonexistent", 5).is_empty());
    }
}