use std::collections::{BTreeMap, HashMap};
use super::bm25::{Bm25Params, score as bm25_score};
use super::tokenizer::tokenize;
/// In-memory inverted index over tokenized documents, keyed by `i64` row id.
///
/// The index keeps per-term posting lists plus the per-document token counts
/// that the scoring helpers pass to the BM25 scorer.
#[derive(Debug, Default, Clone)]
pub struct PostingList {
/// term -> (rowid -> number of times the term occurs in that document).
postings: BTreeMap<String, BTreeMap<i64, u32>>,
/// rowid -> number of tokens produced for that document.
doc_lengths: BTreeMap<i64, u32>,
/// Running sum of all values in `doc_lengths`; used to compute the average document length.
total_tokens: u64,
}
impl PostingList {
    /// Creates an empty index with no documents and no postings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the number of indexed documents.
    ///
    /// A document whose text produced zero tokens is still counted, because
    /// `insert` records a length for every row id it is given.
    pub fn len(&self) -> usize {
        self.doc_lengths.len()
    }

    /// Returns `true` when no document is indexed.
    pub fn is_empty(&self) -> bool {
        self.doc_lengths.is_empty()
    }

    /// Returns the mean token count per indexed document, or `0.0` for an empty index.
    pub fn avg_doc_len(&self) -> f64 {
        if self.doc_lengths.is_empty() {
            0.0
        } else {
            self.total_tokens as f64 / self.doc_lengths.len() as f64
        }
    }

    /// Exports `(rowid, token_count)` pairs in ascending row id order.
    ///
    /// The output can be fed back into [`PostingList::from_persisted_postings`].
    pub fn serialize_doc_lengths(&self) -> Vec<(i64, u32)> {
        self.doc_lengths
            .iter()
            .map(|(&rowid, &len)| (rowid, len))
            .collect()
    }

    /// Exports every term with its `(rowid, frequency)` entries.
    ///
    /// Terms are emitted in ascending order, and the entries of each term in
    /// ascending row id order. The output can be fed back into
    /// [`PostingList::from_persisted_postings`].
    pub fn serialize_postings(&self) -> Vec<(String, Vec<(i64, u32)>)> {
        self.postings
            .iter()
            .map(|(term, docs)| {
                let entries = docs.iter().map(|(&rowid, &freq)| (rowid, freq)).collect();
                (term.clone(), entries)
            })
            .collect()
    }

    /// Rebuilds an index from previously serialized document lengths and postings.
    ///
    /// The input is normalized so the rebuilt index is internally consistent:
    ///
    /// * If a row id appears more than once in `doc_lengths`, the last length
    ///   wins and the token total counts that row only once.
    /// * Posting entries whose row id has no entry in `doc_lengths` are dropped,
    ///   so they can neither be returned by `query` nor inflate per-term
    ///   document counts.
    /// * Terms left with no entries are not stored. If a term appears more than
    ///   once in `postings`, the last non-empty occurrence replaces earlier ones.
    pub fn from_persisted_postings<I, J>(doc_lengths: I, postings: J) -> Self
    where
        I: IntoIterator<Item = (i64, u32)>,
        J: IntoIterator<Item = (String, Vec<(i64, u32)>)>,
    {
        let mut doc_lengths_map: BTreeMap<i64, u32> = BTreeMap::new();
        for (rowid, len) in doc_lengths {
            doc_lengths_map.insert(rowid, len);
        }
        // Sum the deduplicated map rather than the raw input so that repeated
        // row ids cannot make the total disagree with the stored lengths.
        let total_tokens: u64 = doc_lengths_map.values().map(|&len| u64::from(len)).sum();

        let mut postings_map: BTreeMap<String, BTreeMap<i64, u32>> = BTreeMap::new();
        for (term, entries) in postings {
            let inner: BTreeMap<i64, u32> = entries
                .into_iter()
                .filter(|(rowid, _)| doc_lengths_map.contains_key(rowid))
                .collect();
            if !inner.is_empty() {
                postings_map.insert(term, inner);
            }
        }

        Self {
            postings: postings_map,
            doc_lengths: doc_lengths_map,
            total_tokens,
        }
    }

    /// Indexes `text` under `rowid`, replacing any document previously stored under that id.
    ///
    /// Token counts are stored as `u32`; a count that does not fit is clamped
    /// to `u32::MAX` instead of wrapping.
    pub fn insert(&mut self, rowid: i64, text: &str) {
        // `remove` is a no-op for unknown row ids, so no existence check is needed.
        self.remove(rowid);

        let tokens = tokenize(text);
        let doc_len = tokens.len().min(u32::MAX as usize) as u32;
        self.total_tokens += u64::from(doc_len);
        self.doc_lengths.insert(rowid, doc_len);

        // Count occurrences per distinct token before touching the shared postings map.
        let mut tf: HashMap<&str, u32> = HashMap::new();
        for tok in &tokens {
            let count = tf.entry(tok.as_str()).or_insert(0);
            *count = count.saturating_add(1);
        }
        for (term, freq) in tf {
            self.postings
                .entry(term.to_string())
                .or_default()
                .insert(rowid, freq);
        }
    }

    /// Removes the document stored under `rowid`; does nothing if the id is not indexed.
    ///
    /// Terms that no longer occur in any document are dropped from the index.
    pub fn remove(&mut self, rowid: i64) {
        let Some(doc_len) = self.doc_lengths.remove(&rowid) else {
            return;
        };
        self.total_tokens -= u64::from(doc_len);

        // Drop this row from every posting list and prune lists that become
        // empty, in one pass and without cloning any term.
        self.postings.retain(|_, docs| {
            docs.remove(&rowid);
            !docs.is_empty()
        });
    }

    /// Returns `true` when `rowid` is indexed and contains at least one token of `query`.
    pub fn matches(&self, rowid: i64, query: &str) -> bool {
        self.doc_lengths.contains_key(&rowid)
            && tokenize(query).iter().any(|term| {
                self.postings
                    .get(term)
                    .map_or(false, |docs| docs.contains_key(&rowid))
            })
    }

    /// Scores a single document against `query`.
    ///
    /// Returns `0.0` when `rowid` is not indexed or when `query` yields no
    /// tokens. Otherwise returns the value produced by the BM25 scorer for this
    /// document's term frequencies, its length, the average document length,
    /// the per-term document counts and the total document count.
    pub fn score(&self, rowid: i64, query: &str, params: &Bm25Params) -> f64 {
        let Some(&doc_len) = self.doc_lengths.get(&rowid) else {
            return 0.0;
        };
        let query_terms = tokenize(query);
        if query_terms.is_empty() {
            return 0.0;
        }
        let term_freq = self.term_freq_for_doc(rowid, &query_terms);
        let n_docs_with = self.n_docs_with_for_terms(&query_terms);
        bm25_score(
            &query_terms,
            &term_freq,
            doc_len,
            self.avg_doc_len(),
            &n_docs_with,
            self.doc_lengths.len() as u32,
            params,
        )
    }

    /// Scores every document that contains at least one token of `query`.
    ///
    /// Returns `(rowid, score)` pairs sorted by descending score, with ties
    /// broken by ascending row id. NaN scores, should the scorer ever produce
    /// one, sort after all other scores so the ordering is always well defined.
    /// The result is empty when the query yields no tokens, the index is empty,
    /// or no document contains any query token.
    pub fn query(&self, query: &str, params: &Bm25Params) -> Vec<(i64, f64)> {
        let query_terms = tokenize(query);
        if query_terms.is_empty() || self.doc_lengths.is_empty() {
            return Vec::new();
        }

        // Union of the row ids found in the posting list of any query term.
        let candidates: std::collections::BTreeSet<i64> = query_terms
            .iter()
            .filter_map(|term| self.postings.get(term))
            .flat_map(|docs| docs.keys().copied())
            .collect();
        if candidates.is_empty() {
            return Vec::new();
        }

        // These values are identical for every candidate; compute them once.
        let n_docs_with = self.n_docs_with_for_terms(&query_terms);
        let avg = self.avg_doc_len();
        let total_docs = self.doc_lengths.len() as u32;

        let mut scored: Vec<(i64, f64)> = candidates
            .into_iter()
            .filter_map(|rowid| {
                // A posting without a recorded length cannot be scored; skip it
                // rather than panic.
                let doc_len = *self.doc_lengths.get(&rowid)?;
                let tf = self.term_freq_for_doc(rowid, &query_terms);
                let s = bm25_score(
                    &query_terms,
                    &tf,
                    doc_len,
                    avg,
                    &n_docs_with,
                    total_docs,
                    params,
                );
                Some((rowid, s))
            })
            .collect();

        scored.sort_by(|a, b| {
            b.1.partial_cmp(&a.1)
                // `partial_cmp` is `None` only when a NaN is involved. Order NaN
                // after real numbers so the comparator remains a total order.
                .unwrap_or_else(|| a.1.is_nan().cmp(&b.1.is_nan()))
                .then_with(|| a.0.cmp(&b.0))
        });
        scored
    }

    /// Maps each distinct query term to its frequency in document `rowid` (`0` when absent).
    fn term_freq_for_doc(&self, rowid: i64, query_terms: &[String]) -> HashMap<String, u32> {
        let mut tf: HashMap<String, u32> = HashMap::with_capacity(query_terms.len());
        for term in query_terms {
            // Repeated query terms are looked up (and cloned) only once.
            if !tf.contains_key(term) {
                let freq = self
                    .postings
                    .get(term)
                    .and_then(|docs| docs.get(&rowid))
                    .copied()
                    .unwrap_or(0);
                tf.insert(term.clone(), freq);
            }
        }
        tf
    }

    /// Maps each distinct query term to the number of documents containing it (`0` when unknown).
    fn n_docs_with_for_terms(&self, query_terms: &[String]) -> HashMap<String, u32> {
        let mut n: HashMap<String, u32> = HashMap::with_capacity(query_terms.len());
        for term in query_terms {
            // Repeated query terms are looked up (and cloned) only once.
            if !n.contains_key(term) {
                let count = self
                    .postings
                    .get(term)
                    .map_or(0, |docs| docs.len() as u32);
                n.insert(term.clone(), count);
            }
        }
        n
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created index reports no documents and answers every lookup with "nothing".
    #[test]
    fn empty_list_is_empty() {
        let pl = PostingList::new();
        assert!(pl.is_empty());
        assert_eq!(pl.len(), 0);
        assert_eq!(pl.avg_doc_len(), 0.0);
        assert!(pl.query("anything", &Bm25Params::default()).is_empty());
        assert_eq!(pl.score(1, "anything", &Bm25Params::default()), 0.0);
        assert!(!pl.matches(1, "anything"));
    }

    /// Queries that produce no tokens return no results and a zero score.
    #[test]
    fn empty_query_returns_empty_results() {
        let mut pl = PostingList::new();
        pl.insert(1, "rust embedded database");
        assert!(pl.query("", &Bm25Params::default()).is_empty());
        assert!(pl.query("!!!", &Bm25Params::default()).is_empty());
        assert_eq!(pl.score(1, "", &Bm25Params::default()), 0.0);
    }

    /// Both documents containing the term are returned with positive, descending scores.
    #[test]
    fn insert_and_query_two_docs_ranks_correctly() {
        let mut pl = PostingList::new();
        pl.insert(1, "rust rust embedded database");
        pl.insert(2, "rust language");
        let res = pl.query("rust", &Bm25Params::default());
        assert_eq!(res.len(), 2);
        let (id_a, s_a) = res[0];
        let (id_b, s_b) = res[1];
        assert!(s_a > 0.0 && s_b > 0.0);
        assert!(s_a >= s_b);
        assert!(
            (id_a == 1 || id_a == 2) && (id_b == 1 || id_b == 2) && id_a != id_b,
            "result rowids should be {{1,2}}, got ({}, {})",
            id_a,
            id_b
        );
        assert!(pl.matches(1, "rust"));
        assert!(pl.matches(2, "rust"));
        assert!(!pl.matches(1, "python"));
    }

    /// `score` for a single row agrees with the score `query` reports for that row.
    #[test]
    fn score_method_matches_bulk_query() {
        let mut pl = PostingList::new();
        pl.insert(10, "rust embedded database");
        pl.insert(20, "go embedded database");
        pl.insert(30, "python web framework");
        let params = Bm25Params::default();
        let bulk = pl.query("embedded", &params);
        for (rowid, score) in &bulk {
            let direct = pl.score(*rowid, "embedded", &params);
            assert!(
                (direct - score).abs() < f64::EPSILON * 16.0,
                "score({}, ...) = {} vs query() reported {}",
                rowid,
                direct,
                score
            );
        }
        // Row 30 does not contain the query term at all.
        assert_eq!(pl.score(30, "embedded", &params), 0.0);
    }

    /// Removing a document updates the counters and drops terms that no longer occur anywhere.
    #[test]
    fn remove_clears_doc_and_prunes_empty_terms() {
        let mut pl = PostingList::new();
        pl.insert(1, "rust");
        pl.insert(2, "rust embedded");
        assert_eq!(pl.len(), 2);
        assert_eq!(pl.total_tokens, 3);
        assert!(pl.postings.contains_key("rust"));
        assert!(pl.postings.contains_key("embedded"));
        pl.remove(2);
        assert_eq!(pl.len(), 1);
        assert_eq!(pl.total_tokens, 1);
        assert!(!pl.postings.contains_key("embedded"));
        assert!(pl.postings.contains_key("rust"));
        pl.remove(1);
        assert!(pl.is_empty());
        assert!(pl.postings.is_empty());
        assert_eq!(pl.total_tokens, 0);
        // Removing an already-removed or never-inserted row must be a no-op.
        pl.remove(1);
        pl.remove(99);
        assert!(pl.is_empty());
    }

    /// Inserting under an existing row id replaces the old document entirely.
    #[test]
    fn reinsert_replaces_prior_postings() {
        let mut pl = PostingList::new();
        pl.insert(1, "rust rust rust");
        assert_eq!(pl.postings["rust"][&1], 3);
        assert_eq!(pl.total_tokens, 3);
        pl.insert(1, "go");
        assert_eq!(pl.len(), 1);
        assert_eq!(pl.total_tokens, 1);
        assert!(!pl.postings.contains_key("rust"));
        assert_eq!(pl.postings["go"][&1], 1);
    }

    /// Documents with identical scores are returned in ascending row id order.
    #[test]
    fn tie_break_orders_by_rowid_ascending() {
        let mut pl = PostingList::new();
        pl.insert(7, "alpha beta");
        pl.insert(3, "alpha beta");
        pl.insert(5, "alpha beta");
        let res = pl.query("alpha", &Bm25Params::default());
        let ids: Vec<i64> = res.iter().map(|(id, _)| *id).collect();
        assert_eq!(ids, vec![3, 5, 7]);
        let s = res[0].1;
        for (_, score) in &res {
            assert_eq!(*score, s);
        }
    }

    /// A multi-term query returns every document containing any of the terms.
    #[test]
    fn multi_term_query_unions_candidates_any_term() {
        let mut pl = PostingList::new();
        pl.insert(1, "rust embedded");
        pl.insert(2, "rust web");
        pl.insert(3, "go embedded");
        pl.insert(4, "python web");
        let res = pl.query("rust embedded", &Bm25Params::default());
        let ids: std::collections::BTreeSet<i64> = res.iter().map(|(id, _)| *id).collect();
        assert_eq!(ids, [1, 2, 3].iter().copied().collect());
        // Row 1 is the only document containing both terms, so it ranks first.
        assert_eq!(res[0].0, 1);
    }

    /// Serializing and rebuilding an index preserves counts and query results.
    #[test]
    fn serialize_round_trips_through_from_persisted() {
        let mut pl = PostingList::new();
        pl.insert(1, "rust embedded database");
        pl.insert(2, "rust web framework");
        // An empty document must survive the round trip as well.
        pl.insert(3, "");
        pl.insert(4, "rust rust rust embedded power");
        let docs = pl.serialize_doc_lengths();
        let postings = pl.serialize_postings();
        let roundtripped = PostingList::from_persisted_postings(docs, postings);
        assert_eq!(roundtripped.len(), pl.len(), "doc count");
        assert_eq!(roundtripped.avg_doc_len(), pl.avg_doc_len(), "avg_doc_len");
        let q = pl.query("rust", &Bm25Params::default());
        let q2 = roundtripped.query("rust", &Bm25Params::default());
        assert_eq!(q, q2, "query results must match after round-trip");
        assert!(roundtripped.matches(1, "rust"));
        assert!(!roundtripped.matches(3, "rust"));
    }

    /// A rare term in a 1000-document corpus returns exactly its documents, deterministically.
    #[test]
    fn synthetic_thousand_doc_corpus_top_ten_is_stable() {
        let mut pl = PostingList::new();
        let rare_rows: [i64; 5] = [137, 248, 391, 642, 873];
        for i in 0..1000_i64 {
            let words = ["alpha", "beta", "gamma", "delta", "epsilon", "zeta"];
            let pick_a = words[((i as usize) * 7) % words.len()];
            let pick_b = words[((i as usize) * 13 + 1) % words.len()];
            let body = if rare_rows.contains(&i) {
                format!("quasar {} {}", pick_a, pick_b)
            } else {
                format!("{} {}", pick_a, pick_b)
            };
            pl.insert(i, &body);
        }
        assert_eq!(pl.len(), 1000);
        let res = pl.query("quasar", &Bm25Params::default());
        assert_eq!(res.len(), 5, "exactly five docs should contain 'quasar'");
        let returned: std::collections::BTreeSet<i64> = res.iter().map(|(id, _)| *id).collect();
        let expected: std::collections::BTreeSet<i64> = rare_rows.iter().copied().collect();
        assert_eq!(returned, expected);
        // Running the same query twice must give the same ordered result.
        let res2 = pl.query("quasar", &Bm25Params::default());
        assert_eq!(res, res2);
    }
}