use std::collections::HashMap;
pub use nodedb_document::text_analyzer::analyze;
/// In-memory inverted index with BM25 scoring.
///
/// Built incrementally via [`InvertedIndex::index_document`]; all statistics
/// needed by BM25 (document frequencies, document lengths, corpus size and
/// total length) are maintained on every insert/remove.
#[derive(Debug, Default)]
pub struct InvertedIndex {
// token -> (doc_id -> term frequency within that document)
postings: HashMap<String, HashMap<String, u32>>,
// doc_id -> analyzed token count of the document
doc_lengths: HashMap<String, u32>,
// number of currently indexed documents
doc_count: u32,
// sum of all document lengths; used to derive the average document length
total_length: u64,
}
/// A single search hit: the matched document id and its accumulated BM25 score.
#[derive(Debug, Clone)]
pub struct TextSearchResult {
pub doc_id: String,
pub score: f64,
}
/// Tunable BM25 parameters (see the scoring loop in `InvertedIndex::search`).
#[derive(Debug, Clone, Copy)]
pub struct Bm25Params {
// term-frequency saturation: larger k1 lets repeated terms keep adding score
pub k1: f64,
// document-length normalization strength in [0, 1]
pub b: f64,
}
impl Default for Bm25Params {
fn default() -> Self {
Self { k1: 1.2, b: 0.75 }
}
}
/// How multi-token queries combine in `InvertedIndex::search`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryMode {
// a document must contain every query token
And,
// a document may match any subset of the query tokens
Or,
}
impl InvertedIndex {
    /// Creates an empty index.
    pub fn new() -> Self {
        Self::default()
    }

    /// Analyzes `text` and indexes it under `doc_id`, replacing any content
    /// previously stored for the same id.
    ///
    /// Documents whose analyzed token stream is empty are not stored at all.
    pub fn index_document(&mut self, doc_id: &str, text: &str) {
        // Re-indexing must not double-count corpus statistics.
        self.remove_document(doc_id);
        let tokens = analyze(text);
        if tokens.is_empty() {
            return;
        }
        let doc_len = tokens.len() as u32;
        self.doc_lengths.insert(doc_id.to_string(), doc_len);
        self.doc_count += 1;
        self.total_length += u64::from(doc_len);
        // Per-document term frequencies; consuming `tokens` avoids a clone per token.
        let mut tf: HashMap<String, u32> = HashMap::with_capacity(tokens.len());
        for token in tokens {
            *tf.entry(token).or_insert(0) += 1;
        }
        for (token, freq) in tf {
            self.postings
                .entry(token)
                .or_default()
                .insert(doc_id.to_string(), freq);
        }
    }

    /// Removes `doc_id` from the index; a no-op if the id is unknown.
    ///
    /// NOTE: this walks the whole vocabulary (O(#tokens)) because there is no
    /// doc -> tokens reverse map; acceptable for the current usage pattern.
    pub fn remove_document(&mut self, doc_id: &str) {
        if let Some(old_len) = self.doc_lengths.remove(doc_id) {
            self.doc_count = self.doc_count.saturating_sub(1);
            self.total_length = self.total_length.saturating_sub(u64::from(old_len));
            self.postings.retain(|_, docs| {
                docs.remove(doc_id);
                // Drop tokens that no longer occur in any document.
                !docs.is_empty()
            });
        }
    }

    /// Runs a BM25-scored search for `query`, returning up to `top_k` results
    /// sorted by descending score (ties broken by doc id for determinism).
    ///
    /// In [`QueryMode::And`] only documents containing *every* query token are
    /// returned; in [`QueryMode::Or`] any partial match qualifies.
    pub fn search(
        &self,
        query: &str,
        top_k: usize,
        mode: QueryMode,
        params: Bm25Params,
    ) -> Vec<TextSearchResult> {
        let tokens = analyze(query);
        if tokens.is_empty() {
            return Vec::new();
        }
        let mut scores = self.bm25_scores(tokens.iter().map(String::as_str), params);
        if mode == QueryMode::And {
            // Keep only documents that contain every query token.
            scores.retain(|doc_id, _| {
                tokens.iter().all(|t| {
                    self.postings
                        .get(t)
                        .is_some_and(|p| p.contains_key(doc_id))
                })
            });
        }
        Self::rank(scores, top_k)
    }

    /// Fuzzy search: each query token is expanded to every indexed token
    /// within `max_distance` Levenshtein edits, and the expansion terms are
    /// scored with BM25 in OR mode.
    pub fn search_fuzzy(
        &self,
        query: &str,
        max_distance: usize,
        top_k: usize,
        params: Bm25Params,
    ) -> Vec<TextSearchResult> {
        let tokens = analyze(query);
        if tokens.is_empty() {
            return Vec::new();
        }
        // Score the expansion terms directly instead of rebuilding a query
        // string and re-running `analyze` on it: the index keys are already
        // analyzed, and a second pass through the analyzer could mutate or
        // drop already-stemmed tokens.
        let mut expanded: Vec<&str> = Vec::new();
        for token in &tokens {
            expanded.extend(
                self.postings
                    .keys()
                    .filter(|idx_token| levenshtein(token, idx_token) <= max_distance)
                    .map(String::as_str),
            );
        }
        if expanded.is_empty() {
            return Vec::new();
        }
        let scores = self.bm25_scores(expanded.iter().copied(), params);
        Self::rank(scores, top_k)
    }

    /// Number of indexed documents.
    pub fn doc_count(&self) -> u32 {
        self.doc_count
    }

    /// Number of distinct tokens in the index.
    pub fn token_count(&self) -> usize {
        self.postings.len()
    }

    /// Accumulates the BM25 contribution of every term in `terms` into a
    /// per-document score map. Duplicate terms contribute once per occurrence
    /// (implicit query-term-frequency weighting).
    fn bm25_scores<'a>(
        &self,
        terms: impl IntoIterator<Item = &'a str>,
        params: Bm25Params,
    ) -> HashMap<String, f64> {
        let avg_dl = if self.doc_count > 0 {
            self.total_length as f64 / f64::from(self.doc_count)
        } else {
            1.0 // unreachable in practice: no docs means no postings to score
        };
        let mut scores: HashMap<String, f64> = HashMap::new();
        for term in terms {
            let Some(posting) = self.postings.get(term) else {
                continue;
            };
            let df = posting.len() as f64;
            // The `+ 1.0` inside ln keeps idf positive even for terms that
            // occur in more than half the corpus (BM25+ style smoothing).
            let idf = ((f64::from(self.doc_count) - df + 0.5) / (df + 0.5) + 1.0).ln();
            for (doc_id, &tf) in posting {
                let dl = f64::from(*self.doc_lengths.get(doc_id).unwrap_or(&1));
                let tf_f = f64::from(tf);
                let numerator = tf_f * (params.k1 + 1.0);
                let denominator =
                    tf_f + params.k1 * (1.0 - params.b + params.b * dl / avg_dl);
                *scores.entry(doc_id.clone()).or_insert(0.0) += idf * numerator / denominator;
            }
        }
        scores
    }

    /// Turns a score map into a result list: descending score, doc-id
    /// tie-break (so equal scores order deterministically), truncated to `top_k`.
    fn rank(scores: HashMap<String, f64>, top_k: usize) -> Vec<TextSearchResult> {
        let mut results: Vec<TextSearchResult> = scores
            .into_iter()
            .map(|(doc_id, score)| TextSearchResult { doc_id, score })
            .collect();
        results.sort_unstable_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| a.doc_id.cmp(&b.doc_id))
        });
        results.truncate(top_k);
        results
    }
}
/// Levenshtein edit distance between `a` and `b`; thin delegation to the
/// `nodedb_document` implementation (tests below pin the standard semantics,
/// e.g. distance("kitten", "sitting") == 3).
fn levenshtein(a: &str, b: &str) -> usize {
nodedb_document::levenshtein(a, b)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn analyze_basic() {
        let tokens = analyze("The quick brown fox jumps over the lazy dog");
        assert!(!tokens.is_empty());
        // Stop words such as "the" must have been filtered out.
        assert!(!tokens.iter().any(|t| t == "the"));
    }

    #[test]
    fn analyze_stemming() {
        let tokens = analyze("running jumps quickly");
        for stem in ["run", "jump", "quick"] {
            assert!(tokens.contains(&stem.to_string()));
        }
    }

    #[test]
    fn index_and_search() {
        let mut index = InvertedIndex::new();
        index.index_document("d1", "Rust is a systems programming language");
        index.index_document("d2", "Python is great for machine learning");
        index.index_document("d3", "Rust and Python are both great languages");
        let hits = index.search("rust programming", 10, QueryMode::Or, Bm25Params::default());
        assert!(!hits.is_empty());
        // d1 matches both query terms and should rank first.
        assert_eq!(hits[0].doc_id, "d1");
    }

    #[test]
    fn and_mode() {
        let mut index = InvertedIndex::new();
        index.index_document("d1", "Rust programming language");
        index.index_document("d2", "Python programming language");
        let hits = index.search(
            "rust programming",
            10,
            QueryMode::And,
            Bm25Params::default(),
        );
        // Only d1 contains every query token.
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].doc_id, "d1");
    }

    #[test]
    fn fuzzy_search() {
        let mut index = InvertedIndex::new();
        index.index_document("d1", "programming language design");
        index.index_document("d2", "progrmmng language review");
        let hits = index.search_fuzzy("programming", 3, 10, Bm25Params::default());
        assert!(!hits.is_empty(), "fuzzy search should find matches");
        let ids: Vec<&str> = hits.iter().map(|hit| hit.doc_id.as_str()).collect();
        assert!(ids.contains(&"d1"), "should find d1 (exact match)");
    }

    #[test]
    fn remove_document() {
        let mut index = InvertedIndex::new();
        index.index_document("d1", "hello world");
        assert_eq!(index.doc_count(), 1);
        index.remove_document("d1");
        assert_eq!(index.doc_count(), 0);
        let hits = index.search("hello", 10, QueryMode::Or, Bm25Params::default());
        assert!(hits.is_empty());
    }

    #[test]
    fn levenshtein_basic() {
        let cases = [
            ("kitten", "sitting", 3),
            ("", "abc", 3),
            ("abc", "abc", 0),
            ("abc", "ab", 1),
        ];
        for (a, b, expected) in cases {
            assert_eq!(levenshtein(a, b), expected);
        }
    }

    #[test]
    fn reindex_replaces() {
        let mut index = InvertedIndex::new();
        index.index_document("d1", "old content");
        // Re-indexing the same id fully replaces the previous content.
        index.index_document("d1", "new content");
        assert_eq!(index.doc_count(), 1);
        let hits = index.search("old", 10, QueryMode::Or, Bm25Params::default());
        assert!(hits.is_empty());
    }
}