use std::cmp::{Ordering, Reverse};
use std::collections::{BinaryHeap, HashMap, HashSet};
use parking_lot::RwLock;
use crate::bm25::{BM25Config, BM25Scorer, tokenize_minimal};
/// Identifier of an indexed document.
pub type DocId = u64;
/// Zero-based token offset of a term occurrence within a document.
pub type Position = u32;
/// Number of occurrences of a term in a document.
pub type TermFreq = u32;

/// A candidate document and its BM25 score, used by the bounded top-k heap in
/// `InvertedIndex::search_tokens`.
///
/// The ordering is "better is greater": a higher score ranks higher, and among
/// equal scores the *lower* doc id ranks higher. This matches the order of the
/// final search results (score descending, doc id ascending). When a bounded
/// min-heap evicts its smallest element on a tie, it therefore discards the
/// larger doc id, which is also the one that falls outside the full-sort prefix.
#[derive(Debug, Clone, Copy)]
struct ScoredDoc {
    score: f32,
    doc_id: DocId,
}

impl PartialEq for ScoredDoc {
    /// Equality is defined through [`Ord::cmp`] so `Eq` and `Ord` always agree.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for ScoredDoc {}

impl Ord for ScoredDoc {
    fn cmp(&self, other: &Self) -> Ordering {
        // `total_cmp` gives floats a total order, so `Ord` is lawful.
        self.score
            .total_cmp(&other.score)
            // Reversed on purpose: the smaller id is the better-ranked document.
            .then_with(|| other.doc_id.cmp(&self.doc_id))
    }
}

impl PartialOrd for ScoredDoc {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
#[derive(Debug, Clone)]
pub struct Posting {
pub doc_id: DocId,
pub term_freq: TermFreq,
pub positions: Option<Vec<Position>>,
}
impl Posting {
pub fn new(doc_id: DocId, term_freq: TermFreq) -> Self {
Self {
doc_id,
term_freq,
positions: None,
}
}
pub fn with_positions(doc_id: DocId, positions: Vec<Position>) -> Self {
Self {
doc_id,
term_freq: positions.len() as TermFreq,
positions: Some(positions),
}
}
}
#[derive(Debug, Clone, Default)]
pub struct PostingList {
postings: Vec<Posting>,
}
impl PostingList {
pub fn new() -> Self {
Self {
postings: Vec::new(),
}
}
pub fn add(&mut self, posting: Posting) {
match self
.postings
.binary_search_by_key(&posting.doc_id, |p| p.doc_id)
{
Ok(idx) => {
self.postings[idx] = posting;
}
Err(idx) => {
self.postings.insert(idx, posting);
}
}
}
pub fn get(&self, doc_id: DocId) -> Option<&Posting> {
self.postings
.binary_search_by_key(&doc_id, |p| p.doc_id)
.ok()
.map(|idx| &self.postings[idx])
}
pub fn doc_freq(&self) -> usize {
self.postings.len()
}
pub fn iter(&self) -> impl Iterator<Item = &Posting> {
self.postings.iter()
}
pub fn doc_ids(&self) -> Vec<DocId> {
self.postings.iter().map(|p| p.doc_id).collect()
}
}
#[derive(Debug, Clone)]
pub struct DocumentInfo {
pub length: u32,
pub term_freqs: HashMap<String, TermFreq>,
}
/// BM25-scored inverted index.
///
/// Each piece of mutable state sits behind its own `parking_lot::RwLock`, so
/// every method takes `&self`. Locks are acquired separately, which means a
/// multi-step update is not atomic with respect to concurrent readers.
pub struct InvertedIndex {
    /// When `true`, documents added from now on record token positions in their postings.
    store_positions: bool,
    /// Next id handed out by automatic id allocation.
    next_doc_id: RwLock<DocId>,
    /// Term -> postings of every document containing that term.
    index: RwLock<HashMap<String, PostingList>>,
    /// Document id -> length and term frequencies of that document.
    docs: RwLock<HashMap<DocId, DocumentInfo>>,
    /// BM25 scorer that is fed every added/removed document.
    scorer: RwLock<BM25Scorer>,
}
impl InvertedIndex {
    /// Creates an empty index that scores with `config`.
    ///
    /// Positions are not recorded; call [`InvertedIndex::with_positions`] to enable them.
    pub fn new(config: BM25Config) -> Self {
        Self {
            index: RwLock::new(HashMap::new()),
            docs: RwLock::new(HashMap::new()),
            scorer: RwLock::new(BM25Scorer::new(config)),
            next_doc_id: RwLock::new(0),
            store_positions: false,
        }
    }

    /// Enables recording of token positions for documents added afterwards.
    ///
    /// Documents already in the index keep their position-less postings.
    pub fn with_positions(mut self) -> Self {
        self.store_positions = true;
        self
    }

    /// Tokenizes `text` with `tokenize_minimal` and indexes it under a freshly
    /// allocated id, which is returned.
    pub fn add_document(&self, text: &str) -> DocId {
        let tokens = tokenize_minimal(text);
        self.add_document_tokens(&tokens)
    }

    /// Tokenizes `text` and indexes it under `doc_id`, replacing any document
    /// already stored under that id.
    pub fn add_document_with_id(&self, doc_id: DocId, text: &str) {
        let tokens = tokenize_minimal(text);
        self.add_document_tokens_with_id(doc_id, &tokens);
    }

    /// Removes every document and resets the auto-id counter to 0.
    ///
    /// The scorer is recreated with the same configuration.
    pub fn clear(&self) {
        let config = self.scorer.read().config();
        self.index.write().clear();
        self.docs.write().clear();
        *self.scorer.write() = BM25Scorer::new(config);
        *self.next_doc_id.write() = 0;
    }

    /// Replaces the whole index contents with `documents`.
    ///
    /// Afterwards, auto-allocated ids resume one past the largest restored id
    /// (or at 0 when `documents` is empty). If an id appears more than once, the
    /// last text wins.
    pub fn rebuild_from_documents<'a, I>(&self, documents: I)
    where
        I: IntoIterator<Item = (DocId, &'a str)>,
    {
        self.clear();
        // `add_document_tokens_with_id` advances `next_doc_id` past every explicit
        // id, so no separate max-id bookkeeping is needed.
        for (doc_id, text) in documents {
            self.add_document_with_id(doc_id, text);
        }
    }

    /// Indexes pre-tokenized text under a freshly allocated id and returns that id.
    pub fn add_document_tokens(&self, tokens: &[String]) -> DocId {
        let doc_id = {
            let mut next = self.next_doc_id.write();
            let id = *next;
            *next += 1;
            id
        };
        self.add_document_tokens_with_id(doc_id, tokens);
        doc_id
    }

    /// Indexes pre-tokenized text under `doc_id` (insert or replace).
    ///
    /// If `doc_id` is already indexed, the old version is removed first, so its
    /// postings and its contribution to the BM25 statistics do not linger. The
    /// auto-id counter is advanced past `doc_id`, so [`InvertedIndex::add_document`]
    /// never hands out an id that was assigned explicitly.
    pub fn add_document_tokens_with_id(&self, doc_id: DocId, tokens: &[String]) {
        {
            let mut next = self.next_doc_id.write();
            *next = (*next).max(doc_id.saturating_add(1));
        }
        // Upsert: drop any previous version. This is a no-op when the id is new.
        self.remove_document(doc_id);

        let mut term_freqs: HashMap<String, TermFreq> = HashMap::new();
        let mut term_positions: HashMap<&str, Vec<Position>> = HashMap::new();
        for (pos, token) in tokens.iter().enumerate() {
            // Look up by `&str` first so repeated tokens do not allocate a new key.
            match term_freqs.get_mut(token.as_str()) {
                Some(tf) => *tf += 1,
                None => {
                    term_freqs.insert(token.clone(), 1);
                }
            }
            if self.store_positions {
                term_positions
                    .entry(token.as_str())
                    .or_default()
                    .push(pos as Position);
            }
        }
        {
            let mut index = self.index.write();
            for (term, &tf) in &term_freqs {
                // `term_positions` is only populated when positions are stored, so it
                // decides which kind of posting to build. `remove` moves the vector
                // out instead of cloning it.
                let posting = match term_positions.remove(term.as_str()) {
                    Some(positions) => Posting::with_positions(doc_id, positions),
                    None => Posting::new(doc_id, tf),
                };
                index.entry(term.clone()).or_default().add(posting);
            }
        }
        {
            let mut docs = self.docs.write();
            docs.insert(
                doc_id,
                DocumentInfo {
                    length: tokens.len() as u32,
                    term_freqs,
                },
            );
        }
        {
            let mut scorer = self.scorer.write();
            scorer.add_document(tokens.iter().map(|s| s.as_str()));
        }
    }

    /// Removes `doc_id` from the postings, the document table and the scorer.
    ///
    /// Returns `true` if the document was present. Terms left with no postings
    /// are dropped from the vocabulary.
    pub fn remove_document(&self, doc_id: DocId) -> bool {
        // The temporary write guard is released at the end of this statement.
        let removed = self.docs.write().remove(&doc_id);
        let Some(info) = removed else {
            return false;
        };
        {
            let mut index = self.index.write();
            for term in info.term_freqs.keys() {
                let now_empty = match index.get_mut(term) {
                    Some(posting_list) => {
                        posting_list.postings.retain(|p| p.doc_id != doc_id);
                        posting_list.postings.is_empty()
                    }
                    None => false,
                };
                if now_empty {
                    index.remove(term);
                }
            }
        }
        {
            let mut scorer = self.scorer.write();
            scorer.remove_document(
                info.term_freqs.keys().map(|s| s.as_str()),
                info.length as usize,
            );
        }
        true
    }

    /// Tokenizes `query` with `tokenize_minimal` and returns the top `limit`
    /// matches. See [`InvertedIndex::search_tokens`].
    pub fn search(&self, query: &str, limit: usize) -> Vec<(DocId, f32)> {
        let query_tokens = tokenize_minimal(query);
        self.search_tokens(&query_tokens, limit)
    }

    /// Returns up to `limit` `(doc_id, score)` pairs for documents containing at
    /// least one query token.
    ///
    /// Results are sorted by descending score, with ties broken by ascending doc
    /// id. Documents scoring `<= 0.0` are excluded. An empty query or
    /// `limit == 0` yields an empty vector. `limit` may be arbitrarily large
    /// (e.g. `usize::MAX`).
    pub fn search_tokens(&self, query_tokens: &[String], limit: usize) -> Vec<(DocId, f32)> {
        if query_tokens.is_empty() || limit == 0 {
            return Vec::new();
        }
        let index = self.index.read();
        let docs = self.docs.read();
        let scorer = self.scorer.read();

        let candidates: HashSet<DocId> = query_tokens
            .iter()
            .filter_map(|token| index.get(token))
            .flat_map(|posting_list| posting_list.iter().map(|p| p.doc_id))
            .collect();

        // Never reserve more than can actually be stored. `limit + 1` would
        // overflow for `usize::MAX`, and a huge `limit` would over-allocate.
        let capacity = limit.min(candidates.len()).saturating_add(1);
        let mut heap: BinaryHeap<Reverse<ScoredDoc>> = BinaryHeap::with_capacity(capacity);
        for doc_id in candidates {
            // A posting can briefly outlive its document entry while a concurrent
            // `remove_document` is between its lock scopes; skip such documents.
            let Some(doc_info) = docs.get(&doc_id) else {
                continue;
            };
            let score = scorer.score_with_tf_u32(
                query_tokens,
                &doc_info.term_freqs,
                doc_info.length as usize,
            );
            if score <= 0.0 {
                continue;
            }
            heap.push(Reverse(ScoredDoc { score, doc_id }));
            if heap.len() > limit {
                heap.pop();
            }
        }

        let mut results: Vec<(DocId, f32)> = heap
            .into_iter()
            .map(|Reverse(sd)| (sd.doc_id, sd.score))
            .collect();
        // Doc ids are unique, so an unstable sort gives a deterministic order.
        results.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
        results
    }

    /// Returns a copy of the posting list for `term`, which is lowercased before lookup.
    pub fn get_posting_list(&self, term: &str) -> Option<PostingList> {
        self.index.read().get(&term.to_lowercase()).cloned()
    }

    /// Number of documents currently indexed.
    pub fn num_documents(&self) -> usize {
        self.docs.read().len()
    }

    /// Number of distinct terms with at least one posting.
    pub fn vocab_size(&self) -> usize {
        self.index.read().len()
    }

    /// Returns a copy of the stored statistics for `doc_id`, if indexed.
    pub fn get_document_info(&self, doc_id: DocId) -> Option<DocumentInfo> {
        self.docs.read().get(&doc_id).cloned()
    }

    /// Returns `true` if `doc_id` is currently indexed.
    pub fn has_document(&self, doc_id: DocId) -> bool {
        self.docs.read().contains_key(&doc_id)
    }
}
/// Builder that creates an [`InvertedIndex`] and bulk-loads documents with explicit ids.
pub struct InvertedIndexBuilder {
    config: BM25Config,
    store_positions: bool,
}

impl InvertedIndexBuilder {
    /// Creates a builder with `BM25Config::default()` and position recording disabled.
    pub fn new() -> Self {
        Self {
            config: BM25Config::default(),
            store_positions: false,
        }
    }

    /// Uses `config` for BM25 scoring.
    pub fn with_config(mut self, config: BM25Config) -> Self {
        self.config = config;
        self
    }

    /// Enables recording of token positions in postings.
    pub fn with_positions(mut self) -> Self {
        self.store_positions = true;
        self
    }

    /// Builds the index, indexing each `(doc_id, text)` pair under its explicit id.
    ///
    /// Auto-allocated ids (from [`InvertedIndex::add_document`]) continue one past
    /// the largest id supplied here, so later additions cannot overwrite built
    /// documents.
    pub fn build<I>(self, documents: I) -> InvertedIndex
    where
        I: IntoIterator<Item = (DocId, String)>,
    {
        let mut index = InvertedIndex::new(self.config);
        if self.store_positions {
            index = index.with_positions();
        }
        let mut next_id: DocId = 0;
        for (doc_id, text) in documents {
            index.add_document_with_id(doc_id, &text);
            next_id = next_id.max(doc_id.saturating_add(1));
        }
        {
            // Same invariant `rebuild_from_documents` establishes: never hand out an
            // auto id that collides with an explicitly assigned one.
            let mut next = index.next_doc_id.write();
            *next = (*next).max(next_id);
        }
        index
    }
}

impl Default for InvertedIndexBuilder {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Postings come back sorted by doc id regardless of insertion order.
    #[test]
    fn test_posting_list() {
        let mut list = PostingList::new();
        list.add(Posting::new(1, 2));
        list.add(Posting::new(3, 1));
        list.add(Posting::new(2, 3));
        assert_eq!(list.doc_freq(), 3);
        let ids = list.doc_ids();
        assert_eq!(ids, vec![1, 2, 3]);
        let p = list.get(2).unwrap();
        assert_eq!(p.term_freq, 3);
    }

    /// Auto-allocated ids start at 0 and increase by one.
    #[test]
    fn test_add_document() {
        let index = InvertedIndex::new(BM25Config::default());
        let doc1 = index.add_document("hello world");
        let doc2 = index.add_document("hello there");
        assert_eq!(doc1, 0);
        assert_eq!(doc2, 1);
        assert_eq!(index.num_documents(), 2);
        let hello_list = index.get_posting_list("hello").unwrap();
        assert_eq!(hello_list.doc_freq(), 2);
    }

    /// The document repeating the query term most often ranks first.
    #[test]
    fn test_search() {
        let index = InvertedIndex::new(BM25Config::default());
        index.add_document("the quick brown fox jumps over the lazy dog");
        index.add_document("quick quick quick fox");
        index.add_document("lazy lazy lazy dog");
        let results = index.search("quick", 10);
        assert!(!results.is_empty());
        assert_eq!(results[0].0, 1);
    }

    /// The document matching every query term ranks first.
    #[test]
    fn test_search_multi_term() {
        let index = InvertedIndex::new(BM25Config::default());
        index.add_document("apple banana cherry");
        index.add_document("apple banana");
        index.add_document("apple");
        let results = index.search("apple banana cherry", 10);
        assert_eq!(results[0].0, 0);
    }

    /// A zero limit or an empty query returns nothing.
    #[test]
    fn test_search_zero_limit() {
        let index = InvertedIndex::new(BM25Config::default());
        index.add_document("hello world");
        assert!(index.search("hello", 0).is_empty());
        assert!(index.search("", 10).is_empty());
    }

    /// The bounded top-k result equals the prefix of an unbounded search.
    #[test]
    fn test_search_topk_bound_matches_full_sort() {
        let index = InvertedIndex::new(BM25Config::default());
        for i in 0..20 {
            let body = vec!["alpha"; i + 1].join(" ");
            index.add_document(&format!("{body} doc{i}"));
        }
        let limit = 5;
        let topk = index.search("alpha", limit);
        assert_eq!(topk.len(), limit, "must return exactly `limit` results");
        for w in topk.windows(2) {
            assert!(
                w[0].1 >= w[1].1,
                "results must be sorted by score descending"
            );
        }
        let full = index.search("alpha", 1000);
        let full_prefix: Vec<DocId> = full.iter().take(limit).map(|(id, _)| *id).collect();
        let topk_ids: Vec<DocId> = topk.iter().map(|(id, _)| *id).collect();
        assert_eq!(
            topk_ids, full_prefix,
            "bounded top-k must equal full-sort prefix"
        );
    }

    /// Rebuilding discards prior state, reproduces rankings, and resumes auto ids.
    #[test]
    fn test_rebuild_reproduces_index() {
        let corpus: Vec<(u64, &str)> = vec![
            (10, "the quick brown fox"),
            (11, "the lazy dog sleeps"),
            (12, "quick foxes jump high"),
            (13, "lazy dogs and quick cats"),
        ];
        let reference = InvertedIndex::new(BM25Config::default());
        for (id, text) in &corpus {
            reference.add_document_with_id(*id, text);
        }
        let rebuilt = InvertedIndex::new(BM25Config::default());
        rebuilt.add_document_with_id(99, "noise document that should vanish");
        rebuilt.add_document_with_id(98, "more transient noise quick fox");
        rebuilt.remove_document(99);
        rebuilt.rebuild_from_documents(corpus.iter().map(|(id, t)| (*id, *t)));
        assert!(!rebuilt.has_document(99));
        assert!(!rebuilt.has_document(98));
        for (id, _) in &corpus {
            assert!(rebuilt.has_document(*id));
        }
        for q in ["quick", "lazy dog", "fox", "the quick brown"] {
            assert_eq!(
                rebuilt.search(q, 10),
                reference.search(q, 10),
                "rebuilt ranking diverges for query {q:?}",
            );
        }
        let next = rebuilt.add_document("brand new quick doc");
        assert_eq!(next, 14, "auto-id must resume one past max restored id");
    }

    /// Removing a document drops its postings and leaves other documents intact.
    #[test]
    fn test_remove_document() {
        let index = InvertedIndex::new(BM25Config::default());
        let doc1 = index.add_document("hello world");
        let doc2 = index.add_document("hello there");
        assert!(index.has_document(doc1));
        assert!(index.remove_document(doc1));
        assert!(!index.has_document(doc1));
        assert!(index.has_document(doc2));
        let hello_list = index.get_posting_list("hello").unwrap();
        assert_eq!(hello_list.doc_freq(), 1);
        assert!(index.get_posting_list("world").is_none());
    }

    /// Adding then removing a document leaves the index bit-identical to never adding it.
    #[test]
    fn test_add_remove_equals_never_added() {
        let with_removed = InvertedIndex::new(BM25Config::default());
        with_removed.add_document_with_id(1, "the quick brown fox");
        with_removed.add_document_with_id(2, "lazy dog sleeps all day");
        let transient = with_removed.add_document("ephemeral zebra quagga");
        assert!(with_removed.remove_document(transient));
        let never_added = InvertedIndex::new(BM25Config::default());
        never_added.add_document_with_id(1, "the quick brown fox");
        never_added.add_document_with_id(2, "lazy dog sleeps all day");
        assert_eq!(with_removed.num_documents(), never_added.num_documents());
        assert_eq!(with_removed.vocab_size(), never_added.vocab_size());
        assert!(with_removed.get_posting_list("zebra").is_none());
        assert!(with_removed.get_posting_list("quagga").is_none());
        for q in ["quick", "dog", "fox sleeps"] {
            let a = with_removed.search(q, 10);
            let b = never_added.search(q, 10);
            assert_eq!(a.len(), b.len(), "result-count mismatch for {q:?}");
            for (x, y) in a.iter().zip(b.iter()) {
                assert_eq!(x.0, y.0, "doc_id mismatch for {q:?}");
                assert_eq!(x.1.to_bits(), y.1.to_bits(), "score mismatch for {q:?}");
            }
        }
    }

    /// The builder indexes every supplied document.
    #[test]
    fn test_builder() {
        let documents = vec![
            (0, "hello world".to_string()),
            (1, "hello there".to_string()),
            (2, "goodbye world".to_string()),
        ];
        let index = InvertedIndexBuilder::new()
            .with_config(BM25Config::lucene())
            .build(documents);
        assert_eq!(index.num_documents(), 3);
        assert!(index.vocab_size() > 0);
    }

    /// With positions enabled, postings record every token offset of the term.
    #[test]
    fn test_positions() {
        let index = InvertedIndex::new(BM25Config::default()).with_positions();
        let doc_id = index.add_document("hello world hello");
        let hello_list = index.get_posting_list("hello").unwrap();
        let posting = hello_list.get(doc_id).unwrap();
        assert_eq!(posting.positions, Some(vec![0, 2]));
    }
}