use std::collections::{HashMap, HashSet};
/// A single search hit: the document id and its BM25 relevance score.
///
/// `PartialEq` is derived so results can be compared directly in tests and
/// deduplication logic; `Eq`/`Hash` are deliberately omitted because `score`
/// is an `f64`.
#[derive(Debug, Clone, PartialEq)]
pub struct SearchResult {
    /// Id the document was indexed under.
    pub id: String,
    /// Accumulated BM25 score across all query terms and fields (higher is better).
    pub score: f64,
}
/// Minimal interface for an in-memory, field-aware full-text search engine.
pub trait SearchEngine {
    /// Indexes a document under `id`.
    ///
    /// `fields` is a list of `(field_name, field_text, field_weight)` tuples;
    /// the weight scales that field's contribution to the relevance score.
    fn add_document(&mut self, id: &str, fields: &[(&str, &str, f64)]);
    /// Returns up to `limit` results for `query`, best match first.
    fn search(&self, query: &str, limit: usize) -> Vec<SearchResult>;
}
/// Tokenized contents of a single field of one indexed document.
#[derive(Debug)]
struct FieldData {
    // Lowercased tokens in original order; duplicates are kept so term
    // frequency can be counted at search time.
    tokens: Vec<String>,
    // Multiplier applied to this field's BM25 contribution.
    weight: f64,
}
/// One indexed document: its field data keyed by field name.
#[derive(Debug)]
struct Document {
    fields: HashMap<String, FieldData>,
}
/// In-memory index scoring documents with a per-field-weighted BM25 variant
/// (the idf uses the `ln(x + 1)` form, so it never goes negative).
#[derive(Debug)]
pub struct BM25Index {
    // All indexed documents, keyed by document id.
    documents: HashMap<String, Document>,
    // field name -> total token count over all documents (for average field length).
    field_total_tokens: HashMap<String, usize>,
    // field name -> number of documents that supplied this field.
    field_doc_count: HashMap<String, usize>,
    // term -> ids of documents containing the term (its length is the document frequency).
    doc_freq: HashMap<String, Vec<String>>,
}
impl Default for BM25Index {
    /// Equivalent to [`BM25Index::new`]: an empty index.
    fn default() -> Self {
        Self::new()
    }
}
impl BM25Index {
    /// BM25 term-frequency saturation parameter.
    const K1: f64 = 1.2;
    /// BM25 length-normalization strength (0 = none, 1 = full).
    const B: f64 = 0.75;

    /// Creates an empty index with no documents or statistics.
    pub fn new() -> Self {
        Self {
            documents: HashMap::default(),
            field_total_tokens: HashMap::default(),
            field_doc_count: HashMap::default(),
            doc_freq: HashMap::default(),
        }
    }

    /// Splits `text` on every character that is neither alphanumeric nor `_`,
    /// lowercases each piece, and keeps only pieces longer than one byte
    /// (which also discards the empty pieces produced by adjacent separators).
    fn tokenize(text: &str) -> Vec<String> {
        let is_separator = |c: char| !c.is_alphanumeric() && c != '_';
        let mut tokens = Vec::new();
        for piece in text.split(is_separator) {
            let lowered = piece.to_lowercase();
            if lowered.len() > 1 {
                tokens.push(lowered);
            }
        }
        tokens
    }

    /// Inverse document frequency of `term`, in the non-negative
    /// `ln(((N - df + 0.5) / (df + 0.5)) + 1)` form.
    fn idf(&self, term: &str) -> f64 {
        let doc_total = self.documents.len() as f64;
        let df = match self.doc_freq.get(term) {
            Some(ids) => ids.len() as f64,
            None => 0.0,
        };
        ((doc_total - df + 0.5) / (df + 0.5) + 1.0).ln()
    }

    /// Mean token count of `field_name` over the documents that supplied it;
    /// returns 0.0 when no document has the field.
    fn avg_field_len(&self, field_name: &str) -> f64 {
        let doc_count = self.field_doc_count.get(field_name).copied().unwrap_or(0);
        if doc_count == 0 {
            return 0.0;
        }
        let token_total = self.field_total_tokens.get(field_name).copied().unwrap_or(0);
        token_total as f64 / doc_count as f64
    }
}
impl SearchEngine for BM25Index {
fn add_document(&mut self, id: &str, fields: &[(&str, &str, f64)]) {
let mut doc_fields = HashMap::new();
let mut seen_terms: HashMap<String, bool> = HashMap::new();
for &(field_name, field_value, weight) in fields {
let tokens = Self::tokenize(field_value);
*self
.field_total_tokens
.entry(field_name.to_string())
.or_insert(0) += tokens.len();
*self
.field_doc_count
.entry(field_name.to_string())
.or_insert(0) += 1;
for token in &tokens {
seen_terms.entry(token.clone()).or_insert(true);
}
doc_fields.insert(field_name.to_string(), FieldData { tokens, weight });
}
for term in seen_terms.keys() {
self.doc_freq
.entry(term.clone())
.or_default()
.push(id.to_string());
}
self.documents
.insert(id.to_string(), Document { fields: doc_fields });
}
fn search(&self, query: &str, limit: usize) -> Vec<SearchResult> {
let query_tokens = Self::tokenize(query);
if query_tokens.is_empty() || self.documents.is_empty() {
return Vec::new();
}
let mut scores: HashMap<&str, f64> = HashMap::new();
for term in &query_tokens {
let idf = self.idf(term);
for (doc_id, doc) in &self.documents {
let mut doc_term_score = 0.0;
for (field_name, field_data) in &doc.fields {
let tf = field_data.tokens.iter().filter(|t| *t == term).count() as f64;
if tf == 0.0 {
continue;
}
let field_len = field_data.tokens.len() as f64;
let avg_fl = self.avg_field_len(field_name);
let tf_norm = if avg_fl == 0.0 {
0.0
} else {
(tf * (Self::K1 + 1.0))
/ (tf + Self::K1 * (1.0 - Self::B + Self::B * field_len / avg_fl))
};
doc_term_score += idf * tf_norm * field_data.weight;
}
if doc_term_score > 0.0 {
*scores.entry(doc_id.as_str()).or_insert(0.0) += doc_term_score;
}
}
}
let mut results: Vec<SearchResult> = scores
.into_iter()
.map(|(id, score)| SearchResult {
id: id.to_string(),
score,
})
.collect();
results.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
results.truncate(limit);
results
}
}