use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use crate::error::Result;
use crate::types::{Metadata, MetadataValue, VectorId};
/// BM25 `k1` parameter: controls how quickly repeated occurrences of a term
/// stop adding to a document's score.
const K1: f32 = 1.2;
/// BM25 `b` parameter: weight of the document-length normalization
/// (`dl / avgdl`) in the score denominator.
const B: f32 = 0.75;
/// Token statistics stored for one indexed document.
#[derive(Clone, Serialize, Deserialize)]
struct TokenizedDoc {
/// Id the document was added under.
id: VectorId,
/// Number of occurrences of each token in the document's indexed text.
term_frequencies: HashMap<String, u32>,
/// Total number of tokens in the document (0 when no text was extracted).
length: u32,
}
/// Mutable index state: kept behind the lock in `BM25Index` and encoded as a
/// whole by `BM25Index::serialize`.
// NOTE(review): field order is presumably part of the bincode layout — confirm
// before reordering or inserting fields.
#[derive(Default, Serialize, Deserialize)]
struct BM25Inner {
/// Every indexed document keyed by id, including documents with no tokens.
documents: HashMap<VectorId, TokenizedDoc>,
/// Term -> ids of the documents recorded as containing that term.
inverted_index: HashMap<String, HashSet<VectorId>>,
/// Term -> number of documents recorded as containing that term.
doc_frequencies: HashMap<String, u32>,
/// Running total of document token counts; divided by the document count to
/// obtain the average document length used in scoring.
total_doc_length: u64,
}
/// One hit returned by `BM25Index::search`.
#[derive(Debug, Clone)]
pub struct BM25SearchResult {
/// Id of the matching document.
pub id: VectorId,
/// BM25 score summed over the query tokens; results are returned highest first.
pub score: f32,
}
/// BM25 text index built from selected metadata fields.
///
/// All methods take `&self`; the index state is accessed through an `RwLock`.
pub struct BM25Index {
/// Documents, postings and frequency tables.
inner: RwLock<BM25Inner>,
/// Names of the metadata fields whose string values are tokenized and indexed.
indexed_fields: Vec<String>,
}
impl BM25Index {
    /// Creates an empty index that reads text from the given metadata fields.
    pub fn new(indexed_fields: Vec<String>) -> Self {
        Self {
            inner: RwLock::new(BM25Inner::default()),
            indexed_fields,
        }
    }

    /// Returns the metadata field names this index reads text from.
    pub fn indexed_fields(&self) -> &[String] {
        &self.indexed_fields
    }

    /// Indexes the document `id`, replacing any previous version of it.
    ///
    /// The text is taken from the string-valued `indexed_fields` of `metadata`.
    /// A document that yields no tokens (no metadata, no matching fields, or
    /// only too-short tokens) is still registered, with length 0, so it counts
    /// toward `len()` and the average document length.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is kept for interface stability.
    pub fn add(&self, id: &str, metadata: Option<&Metadata>) -> Result<()> {
        // Tokenize before taking the write lock to keep the critical section short.
        let text = self.extract_text(metadata);
        let tokens = self.tokenize(&text);

        // Saturate rather than silently truncate on absurdly long inputs.
        let length = tokens.len().min(u32::MAX as usize) as u32;
        let mut term_frequencies: HashMap<String, u32> = HashMap::new();
        for token in tokens {
            *term_frequencies.entry(token).or_insert(0) += 1;
        }

        let mut inner = self.inner.write();

        // Always drop the previous version first, even when the new version has
        // no tokens; otherwise its postings and length would linger forever.
        Self::detach_document(&mut inner, id);

        for term in term_frequencies.keys() {
            inner
                .inverted_index
                .entry(term.clone())
                .or_default()
                .insert(id.to_string());
            *inner.doc_frequencies.entry(term.clone()).or_insert(0) += 1;
        }
        inner.total_doc_length += u64::from(length);
        inner.documents.insert(
            id.to_string(),
            TokenizedDoc {
                id: id.to_string(),
                term_frequencies,
                length,
            },
        );
        Ok(())
    }

    /// Removes the document `id` from the index.
    ///
    /// Returns `Ok(true)` if the document was indexed and `Ok(false)` if it
    /// was unknown.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is kept for interface stability.
    pub fn remove(&self, id: &str) -> Result<bool> {
        let mut inner = self.inner.write();
        Ok(Self::detach_document(&mut inner, id))
    }

    /// Removes `id` from `inner` together with its postings, its contribution
    /// to the document frequencies and its length. Returns whether it existed.
    fn detach_document(inner: &mut BM25Inner, id: &str) -> bool {
        let Some(doc) = inner.documents.remove(id) else {
            return false;
        };
        inner.total_doc_length = inner
            .total_doc_length
            .saturating_sub(u64::from(doc.length));

        for term in doc.term_frequencies.keys() {
            if let Some(postings) = inner.inverted_index.get_mut(term) {
                postings.remove(id);
                // Drop empty posting sets so the term no longer looks indexed.
                if postings.is_empty() {
                    inner.inverted_index.remove(term);
                }
            }
            if let Some(count) = inner.doc_frequencies.get_mut(term) {
                *count = count.saturating_sub(1);
                if *count == 0 {
                    inner.doc_frequencies.remove(term);
                }
            }
        }
        true
    }

    /// Returns up to `k` documents ranked by BM25 score for `query`.
    ///
    /// For every query token `t` and every document `d` containing it, the
    /// score contribution is
    /// `idf(t) * tf * (K1 + 1) / (tf + K1 * (1 - B + B * dl / avgdl))` with
    /// `idf(t) = ln((N - df + 0.5) / (df + 0.5) + 1)`. A token repeated in the
    /// query contributes once per repetition.
    ///
    /// Results are sorted by descending score; equal scores are ordered by
    /// ascending id so the output is deterministic. Returns an empty vector
    /// when `k` is 0, the query has no tokens, or the index is empty.
    pub fn search(&self, query: &str, k: usize) -> Vec<BM25SearchResult> {
        let query_tokens = self.tokenize(query);
        if k == 0 || query_tokens.is_empty() {
            return Vec::new();
        }

        let guard = self.inner.read();
        let inner: &BM25Inner = &guard;
        if inner.documents.is_empty() {
            return Vec::new();
        }
        let n = inner.documents.len() as f32;
        let avgdl = inner.total_doc_length as f32 / n;

        // Borrow ids while accumulating; only the final top-k are cloned.
        let mut scores: HashMap<&VectorId, f32> = HashMap::new();
        for token in &query_tokens {
            let df = inner.doc_frequencies.get(token).copied().unwrap_or(0) as f32;
            if df == 0.0 {
                continue;
            }
            let Some(postings) = inner.inverted_index.get(token) else {
                continue;
            };
            let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
            for doc_id in postings {
                let Some(doc) = inner.documents.get(doc_id) else {
                    continue;
                };
                let tf = doc.term_frequencies.get(token).copied().unwrap_or(0) as f32;
                let dl = doc.length as f32;
                let score = idf * (tf * (K1 + 1.0)) / (tf + K1 * (1.0 - B + B * dl / avgdl));
                *scores.entry(doc_id).or_insert(0.0) += score;
            }
        }

        let mut ranked: Vec<(&VectorId, f32)> = scores.into_iter().collect();
        // `total_cmp` is a genuine total order (NaN included), and the id
        // tie-break removes any dependence on hash-map iteration order.
        ranked.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(b.0)));
        ranked.truncate(k);
        ranked
            .into_iter()
            .map(|(id, score)| BM25SearchResult {
                id: id.clone(),
                score,
            })
            .collect()
    }

    /// Joins the string values of the indexed fields with single spaces, in
    /// `indexed_fields` order. Missing fields and non-string values are skipped;
    /// absent metadata yields an empty string.
    fn extract_text(&self, metadata: Option<&Metadata>) -> String {
        let Some(meta) = metadata else {
            return String::new();
        };
        self.indexed_fields
            .iter()
            .filter_map(|field| match meta.get(field) {
                Some(MetadataValue::String(s)) => Some(s.clone()),
                _ => None,
            })
            .collect::<Vec<_>>()
            .join(" ")
    }

    /// Lowercases `text` and splits it on every non-alphanumeric character.
    ///
    /// Tokens whose UTF-8 length is at most one byte are dropped. This removes
    /// empty pieces and single ASCII characters, while a single multi-byte
    /// character is kept because the check is on bytes, not characters.
    fn tokenize(&self, text: &str) -> Vec<String> {
        text.to_lowercase()
            .split(|c: char| !c.is_alphanumeric())
            .filter(|token| token.len() > 1)
            .map(String::from)
            .collect()
    }

    /// Number of indexed documents, including those with no tokens.
    pub fn len(&self) -> usize {
        self.inner.read().documents.len()
    }

    /// Returns `true` when no document is indexed.
    pub fn is_empty(&self) -> bool {
        self.inner.read().documents.is_empty()
    }

    /// Removes every document and resets all statistics.
    pub fn clear(&self) {
        let mut inner = self.inner.write();
        inner.documents.clear();
        inner.inverted_index.clear();
        inner.doc_frequencies.clear();
        inner.total_doc_length = 0;
    }

    /// Encodes the index state with bincode.
    ///
    /// `indexed_fields` is not part of the encoded data; pass it again to
    /// [`BM25Index::deserialize`].
    ///
    /// # Errors
    ///
    /// Returns an error if bincode fails to encode the state.
    pub fn serialize(&self) -> Result<Vec<u8>> {
        let inner = self.inner.read();
        let data = bincode::serialize(&*inner)?;
        Ok(data)
    }

    /// Rebuilds an index from bytes produced by [`BM25Index::serialize`].
    ///
    /// The decoded state is used as-is; no consistency check is performed
    /// between the documents, postings and frequency tables.
    ///
    /// # Errors
    ///
    /// Returns an error if `data` cannot be decoded by bincode.
    pub fn deserialize(indexed_fields: Vec<String>, data: &[u8]) -> Result<Self> {
        let inner: BM25Inner = bincode::deserialize(data)?;
        Ok(Self {
            inner: RwLock::new(inner),
            indexed_fields,
        })
    }

    /// Returns a snapshot of the index size statistics.
    pub fn stats(&self) -> BM25Stats {
        let inner = self.inner.read();
        let document_count = inner.documents.len();
        let avg_doc_length = if document_count == 0 {
            0.0
        } else {
            inner.total_doc_length as f32 / document_count as f32
        };
        BM25Stats {
            document_count,
            unique_terms: inner.doc_frequencies.len(),
            // Saturate instead of truncating where `usize` is narrower than `u64`.
            total_tokens: inner.total_doc_length.min(usize::MAX as u64) as usize,
            avg_doc_length,
        }
    }
}
/// Size statistics reported by `BM25Index::stats`.
#[derive(Debug, Clone)]
pub struct BM25Stats {
/// Number of indexed documents.
pub document_count: usize,
/// Number of distinct terms in the document-frequency table.
pub unique_terms: usize,
/// Total token count tracked across documents.
pub total_tokens: usize,
/// `total_tokens / document_count`, or 0.0 when the index is empty.
pub avg_doc_length: f32,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an index over the `title` and `content` metadata fields.
    fn create_index() -> BM25Index {
        let fields: Vec<String> = vec!["title".into(), "content".into()];
        BM25Index::new(fields)
    }

    /// Builds metadata carrying a `title` and a `content` entry.
    fn create_doc(title: &str, content: &str) -> Metadata {
        let mut doc = Metadata::new();
        doc.insert("title", title);
        doc.insert("content", content);
        doc
    }

    #[test]
    fn test_add_and_search() {
        let index = create_index();
        let corpus = [
            (
                "doc-1",
                create_doc("Rust Programming", "Learn Rust programming language"),
            ),
            (
                "doc-2",
                create_doc("Python Guide", "Python is great for beginners"),
            ),
            (
                "doc-3",
                create_doc("JavaScript Tutorial", "JavaScript for web development"),
            ),
        ];
        for (id, meta) in &corpus {
            index.add(id, Some(meta)).unwrap();
        }
        assert_eq!(index.len(), 3);

        // Query terms are matched case-insensitively against title and content.
        let rust_hits = index.search("rust", 10);
        assert_eq!(rust_hits.len(), 1);
        assert_eq!(rust_hits[0].id, "doc-1");

        let programming_hits = index.search("programming", 10);
        assert_eq!(programming_hits.len(), 1);
        assert_eq!(programming_hits[0].id, "doc-1");
    }

    #[test]
    fn test_empty_query() {
        let index = create_index();
        index
            .add("doc-1", Some(&create_doc("Test", "Content")))
            .unwrap();

        assert!(index.search("", 10).is_empty());
    }

    #[test]
    fn test_no_match() {
        let index = create_index();
        index
            .add("doc-1", Some(&create_doc("Rust", "Programming")))
            .unwrap();

        assert!(index.search("python", 10).is_empty());
    }

    #[test]
    fn test_remove() {
        let index = create_index();
        index
            .add("doc-1", Some(&create_doc("Rust", "Programming")))
            .unwrap();
        assert_eq!(index.len(), 1);

        index.remove("doc-1").unwrap();
        assert_eq!(index.len(), 0);

        // A removed document must no longer be reachable through its terms.
        assert!(index.search("rust", 10).is_empty());
    }

    #[test]
    fn test_bm25_ranking() {
        let index = create_index();
        // "rust" occurs three times here (title plus twice in the content)...
        let heavy = create_doc(
            "Rust",
            "Rust is a systems programming language. Rust is fast.",
        );
        // ...and only once here.
        let light = create_doc("Python", "Learn Rust programming");
        index.add("doc-1", Some(&heavy)).unwrap();
        index.add("doc-2", Some(&light)).unwrap();

        let hits = index.search("rust", 10);
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].id, "doc-1");
        assert!(hits[0].score > hits[1].score);
    }

    #[test]
    fn test_clear() {
        let index = create_index();
        index
            .add("doc-1", Some(&create_doc("Test", "Content")))
            .unwrap();

        index.clear();
        assert!(index.is_empty());
    }

    #[test]
    fn test_stats() {
        let index = create_index();
        let doc = create_doc("Rust Programming", "Learn Rust language");
        index.add("doc-1", Some(&doc)).unwrap();

        let snapshot = index.stats();
        assert_eq!(snapshot.document_count, 1);
        assert!(snapshot.unique_terms > 0);
        assert!(snapshot.total_tokens > 0);
    }

    #[test]
    fn test_serialization() {
        let original = create_index();
        let doc = create_doc("Test", "Content for serialization");
        original.add("doc-1", Some(&doc)).unwrap();

        // Round-trip through bytes; the field list is supplied again on restore.
        let bytes = original.serialize().unwrap();
        let fields: Vec<String> = vec!["title".into(), "content".into()];
        let restored = BM25Index::deserialize(fields, &bytes).unwrap();

        assert_eq!(restored.len(), 1);
        let hits = restored.search("serialization", 10);
        assert_eq!(hits.len(), 1);
    }
}