use super::tokenizer::{SimpleTokenizer, Tokenizer};
use grafeo_common::types::NodeId;
use std::collections::HashMap;
/// Tuning parameters for the BM25 ranking function.
#[derive(Debug, Clone)]
pub struct BM25Config {
    /// Term-frequency saturation (`k1`): larger values let repeated
    /// occurrences of a term keep raising the score for longer.
    pub k1: f64,
    /// Document-length normalization strength (`b`) in `[0, 1]`:
    /// `0.0` ignores document length entirely, `1.0` normalizes fully.
    pub b: f64,
}
impl Default for BM25Config {
fn default() -> Self {
Self { k1: 1.2, b: 0.75 }
}
}
/// One document's occurrence record for a single term.
#[derive(Debug, Clone)]
struct Posting {
    // Document (node) that contains the term.
    node_id: NodeId,
    // Number of times the term occurs in that document.
    term_freq: u32,
}
/// All postings for one term. Order is whatever insertion produced;
/// nothing in this file relies on it being sorted.
#[derive(Debug, Clone, Default)]
struct PostingList {
    postings: Vec<Posting>,
}
/// In-memory inverted index with BM25 relevance scoring.
///
/// Invariant maintained by `insert`/`remove`: every `node_id` appearing in
/// `postings` has an entry in `doc_lengths`, and `total_length` is the sum
/// of all values in `doc_lengths`.
pub struct InvertedIndex {
    // term -> documents containing it, with per-document frequency.
    postings: HashMap<String, PostingList>,
    // document -> its length in tokens; also the authority on membership.
    doc_lengths: HashMap<NodeId, u32>,
    // Sum of all document lengths; used to derive the average length.
    total_length: u64,
    // Splits text into terms for both indexing and querying.
    tokenizer: Box<dyn Tokenizer>,
    config: BM25Config,
}
impl InvertedIndex {
    /// Creates an empty index that tokenizes with the default
    /// [`SimpleTokenizer`].
    #[must_use]
    pub fn new(config: BM25Config) -> Self {
        Self {
            postings: HashMap::new(),
            doc_lengths: HashMap::new(),
            total_length: 0,
            tokenizer: Box::new(SimpleTokenizer::new()),
            config,
        }
    }

    /// Creates an empty index that uses a caller-supplied tokenizer.
    #[must_use]
    pub fn with_tokenizer(config: BM25Config, tokenizer: Box<dyn Tokenizer>) -> Self {
        Self {
            postings: HashMap::new(),
            doc_lengths: HashMap::new(),
            total_length: 0,
            tokenizer,
            config,
        }
    }

    /// Indexes `text` under `id`, replacing any previous document with the
    /// same id. A document that tokenizes to nothing is not indexed at all
    /// (and any previous version of it is still removed).
    pub fn insert(&mut self, id: NodeId, text: &str) {
        // Re-inserting an existing id is an update: purge the old postings
        // first so term frequencies don't accumulate across versions.
        if self.doc_lengths.contains_key(&id) {
            self.remove(id);
        }
        let tokens = self.tokenizer.tokenize(text);
        // Saturate instead of `as`-casting: a (pathological) document of
        // exactly 2^32 tokens would otherwise truncate to length 0 and be
        // silently dropped by the early return below.
        let doc_len = u32::try_from(tokens.len()).unwrap_or(u32::MAX);
        if doc_len == 0 {
            return;
        }
        // Aggregate per-term frequencies before touching the posting lists
        // so each term gets exactly one posting for this document.
        let mut term_freqs: HashMap<&str, u32> = HashMap::new();
        for token in &tokens {
            *term_freqs.entry(token.as_str()).or_insert(0) += 1;
        }
        for (term, freq) in term_freqs {
            self.postings
                .entry(term.to_string())
                .or_default()
                .postings
                .push(Posting {
                    node_id: id,
                    term_freq: freq,
                });
        }
        self.doc_lengths.insert(id, doc_len);
        self.total_length += u64::from(doc_len);
    }

    /// Removes document `id` from the index; returns `false` if it was not
    /// indexed.
    ///
    /// NOTE: without a reverse (document -> terms) map this scans every
    /// posting list, so removal is O(total postings).
    pub fn remove(&mut self, id: NodeId) -> bool {
        let Some(doc_len) = self.doc_lengths.remove(&id) else {
            return false;
        };
        self.total_length -= u64::from(doc_len);
        // Drop the document from every list, and discard lists that become
        // empty so `term_count` and idf statistics stay accurate.
        self.postings.retain(|_, list| {
            list.postings.retain(|p| p.node_id != id);
            !list.postings.is_empty()
        });
        true
    }

    /// Scores every document matching at least one query term with BM25 and
    /// returns the top `k` as `(id, score)` pairs, best first.
    ///
    /// Returns an empty vector for an empty index or a query that tokenizes
    /// to nothing.
    pub fn search(&self, query: &str, k: usize) -> Vec<(NodeId, f64)> {
        let query_tokens = self.tokenizer.tokenize(query);
        if query_tokens.is_empty() || self.doc_lengths.is_empty() {
            return Vec::new();
        }
        let n = self.doc_lengths.len() as f64;
        // Division is safe: the empty-index case returned above, so n >= 1.
        let avg_dl = self.total_length as f64 / n;
        let mut scores: HashMap<NodeId, f64> = HashMap::new();
        for token in &query_tokens {
            let Some(posting_list) = self.postings.get(token.as_str()) else {
                continue;
            };
            let df = posting_list.postings.len() as f64;
            // Smoothed idf (the "+ 1.0" inside the log) stays positive even
            // for terms occurring in more than half the documents.
            let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
            for posting in &posting_list.postings {
                let tf = f64::from(posting.term_freq);
                // Every posting's document should be in `doc_lengths`
                // (insert/remove keep them in sync); fall back defensively.
                let dl = f64::from(self.doc_lengths.get(&posting.node_id).copied().unwrap_or(0));
                let tf_component = (tf * (self.config.k1 + 1.0))
                    / (tf + self.config.k1 * (1.0 - self.config.b + self.config.b * dl / avg_dl));
                *scores.entry(posting.node_id).or_insert(0.0) += idf * tf_component;
            }
        }
        let mut results: Vec<(NodeId, f64)> = scores.into_iter().collect();
        // `total_cmp` is a true total order on f64 (no NaN ambiguity), and
        // the unstable sort avoids the stable sort's allocation; node ids
        // are unique here so stability is irrelevant.
        results.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
        results.truncate(k);
        results
    }

    /// Returns `true` if document `id` is currently indexed.
    #[must_use]
    pub fn contains(&self, id: NodeId) -> bool {
        self.doc_lengths.contains_key(&id)
    }

    /// Number of indexed documents.
    #[must_use]
    pub fn len(&self) -> usize {
        self.doc_lengths.len()
    }

    /// Returns `true` if no documents are indexed.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.doc_lengths.is_empty()
    }

    /// Number of distinct terms currently in the index.
    #[must_use]
    pub fn term_count(&self) -> usize {
        self.postings.len()
    }

    /// Current BM25 parameters.
    #[must_use]
    pub fn config(&self) -> &BM25Config {
        &self.config
    }

    /// Serializable view of the index contents, deterministically ordered
    /// (terms lexicographically, documents by id) so two snapshots of equal
    /// indexes compare equal. Rebuild with [`Self::restore`].
    #[must_use]
    pub fn snapshot(&self) -> (Vec<(String, Vec<(NodeId, u32)>)>, Vec<(NodeId, u32)>, u64) {
        let mut postings: Vec<(String, Vec<(NodeId, u32)>)> = self
            .postings
            .iter()
            .map(|(term, pl)| {
                let entries: Vec<(NodeId, u32)> = pl
                    .postings
                    .iter()
                    .map(|p| (p.node_id, p.term_freq))
                    .collect();
                (term.clone(), entries)
            })
            .collect();
        // Keys are unique (they come from HashMaps), so unstable sorts are
        // safe and skip the stable sort's temporary allocation.
        postings.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
        let mut doc_lengths: Vec<(NodeId, u32)> = self
            .doc_lengths
            .iter()
            .map(|(id, len)| (*id, *len))
            .collect();
        doc_lengths.sort_unstable_by_key(|(id, _)| *id);
        (postings, doc_lengths, self.total_length)
    }

    /// Replaces the BM25 parameters; takes effect on the next `search`.
    pub fn set_config(&mut self, config: BM25Config) {
        self.config = config;
    }

    /// Rebuilds the index from a [`Self::snapshot`], discarding current
    /// contents. The tokenizer and config are kept as-is; `total_length` is
    /// trusted from the snapshot rather than recomputed.
    pub fn restore(
        &mut self,
        postings: Vec<(String, Vec<(NodeId, u32)>)>,
        doc_lengths: Vec<(NodeId, u32)>,
        total_length: u64,
    ) {
        self.postings.clear();
        for (term, entries) in postings {
            let posting_list = PostingList {
                postings: entries
                    .into_iter()
                    .map(|(node_id, term_freq)| Posting { node_id, term_freq })
                    .collect(),
            };
            self.postings.insert(term, posting_list);
        }
        self.doc_lengths = doc_lengths.into_iter().collect();
        self.total_length = total_length;
    }

    /// Rough estimate of heap bytes owned by the index (map buckets, term
    /// strings, posting vectors). Not exact: it ignores allocator overhead
    /// and approximates the hash maps' real bucket layout with a
    /// one-control-byte-per-slot model.
    #[must_use]
    pub fn heap_memory_bytes(&self) -> usize {
        let postings_overhead = self.postings.capacity()
            * (std::mem::size_of::<String>() + std::mem::size_of::<PostingList>() + 1);
        let postings_data: usize = self
            .postings
            .iter()
            .map(|(term, pl)| term.len() + pl.postings.capacity() * std::mem::size_of::<Posting>())
            .sum();
        let doc_lengths_bytes = self.doc_lengths.capacity()
            * (std::mem::size_of::<NodeId>() + std::mem::size_of::<u32>() + 1);
        postings_overhead + postings_data + doc_lengths_bytes
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Basic round trip: index three documents, query two terms; the doc
    // containing both query terms should rank first.
    #[test]
    fn test_insert_and_search() {
        let mut index = InvertedIndex::new(BM25Config::default());
        index.insert(
            NodeId::new(1),
            "the quick brown fox jumps over the lazy dog",
        );
        index.insert(NodeId::new(2), "a fast red car drives on the highway");
        index.insert(NodeId::new(3), "the brown dog sleeps all day");
        let results = index.search("brown dog", 10);
        assert!(!results.is_empty());
        assert_eq!(results[0].0, NodeId::new(3));
    }

    // Searching an index with no documents returns nothing.
    #[test]
    fn test_empty_index_search() {
        let index = InvertedIndex::new(BM25Config::default());
        let results = index.search("anything", 10);
        assert!(results.is_empty());
    }

    // An empty query string tokenizes to nothing and matches nothing.
    #[test]
    fn test_empty_query() {
        let mut index = InvertedIndex::new(BM25Config::default());
        index.insert(NodeId::new(1), "hello world");
        let results = index.search("", 10);
        assert!(results.is_empty());
    }

    // Relies on SimpleTokenizer stripping stop words ("the", "a", "an"),
    // leaving an empty token list — TODO confirm against the tokenizer.
    #[test]
    fn test_stop_word_only_query() {
        let mut index = InvertedIndex::new(BM25Config::default());
        index.insert(NodeId::new(1), "hello world");
        let results = index.search("the a an", 10);
        assert!(results.is_empty());
    }

    // Removal shrinks the index and purges the doc from posting lists.
    #[test]
    fn test_remove() {
        let mut index = InvertedIndex::new(BM25Config::default());
        index.insert(NodeId::new(1), "hello world");
        index.insert(NodeId::new(2), "hello rust");
        assert_eq!(index.len(), 2);
        assert!(index.remove(NodeId::new(1)));
        assert_eq!(index.len(), 1);
        let results = index.search("hello", 10);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, NodeId::new(2));
    }

    // Removing an id that was never indexed reports false, not a panic.
    #[test]
    fn test_remove_nonexistent() {
        let mut index = InvertedIndex::new(BM25Config::default());
        assert!(!index.remove(NodeId::new(999)));
    }

    // Re-inserting the same id replaces the old document entirely:
    // old terms stop matching, new terms match.
    #[test]
    fn test_reinsert() {
        let mut index = InvertedIndex::new(BM25Config::default());
        index.insert(NodeId::new(1), "old text");
        index.insert(NodeId::new(1), "new text completely different");
        assert_eq!(index.len(), 1);
        let results = index.search("old", 10);
        assert!(results.is_empty());
        let results = index.search("completely different", 10);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, NodeId::new(1));
    }

    // Membership check mirrors what was inserted.
    #[test]
    fn test_contains() {
        let mut index = InvertedIndex::new(BM25Config::default());
        index.insert(NodeId::new(1), "hello world");
        assert!(index.contains(NodeId::new(1)));
        assert!(!index.contains(NodeId::new(2)));
    }

    // Distinct terms across both docs: hello, world, rust -> 3.
    #[test]
    fn test_term_count() {
        let mut index = InvertedIndex::new(BM25Config::default());
        index.insert(NodeId::new(1), "hello world");
        index.insert(NodeId::new(2), "hello rust");
        assert_eq!(index.term_count(), 3);
    }

    // `k` caps the number of results even when more documents match.
    #[test]
    fn test_k_limit() {
        let mut index = InvertedIndex::new(BM25Config::default());
        for i in 1..=10 {
            index.insert(NodeId::new(i), &format!("document number {}", i));
        }
        let results = index.search("document", 3);
        assert_eq!(results.len(), 3);
    }

    // BM25 length normalization: with equal term frequencies, the shorter
    // document must score strictly higher.
    #[test]
    fn test_bm25_scoring_prefers_shorter_docs() {
        let mut index = InvertedIndex::new(BM25Config::default());
        index.insert(NodeId::new(1), "rust database");
        index.insert(
            NodeId::new(2),
            "rust programming language systems web server framework database engine query optimizer",
        );
        let results = index.search("rust database", 10);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, NodeId::new(1));
        assert!(results[0].1 > results[1].1);
    }

    // Terms absent from the index yield no matches.
    #[test]
    fn test_no_match() {
        let mut index = InvertedIndex::new(BM25Config::default());
        index.insert(NodeId::new(1), "hello world");
        let results = index.search("nonexistent term", 10);
        assert!(results.is_empty());
    }

    // Rare terms match only their document; common terms match all docs.
    #[test]
    fn test_idf_weighting() {
        let mut index = InvertedIndex::new(BM25Config::default());
        index.insert(NodeId::new(1), "common rare word");
        index.insert(NodeId::new(2), "common another word");
        index.insert(NodeId::new(3), "common third word");
        let results = index.search("rare", 10);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, NodeId::new(1));
        let results = index.search("common", 10);
        assert_eq!(results.len(), 3);
    }
}