#![allow(clippy::doc_markdown)]
#![allow(clippy::unwrap_or_default)]
use super::posting_list::PostingList;
use parking_lot::RwLock;
use rustc_hash::FxHashMap;
/// Tuning parameters for Okapi BM25 scoring.
#[derive(Clone, Copy)]
#[derive(Debug)]
pub struct Bm25Params {
    /// Term-frequency saturation: with larger values, repeated occurrences of a
    /// term keep increasing a document's score for longer.
    pub k1: f32,
    /// Document-length normalization strength: `0.0` disables length
    /// normalization, `1.0` applies it fully.
    pub b: f32,
}

impl Default for Bm25Params {
    /// Defaults: `k1 = 1.2`, `b = 0.75`.
    fn default() -> Self {
        let (k1, b) = (1.2, 0.75);
        Bm25Params { k1, b }
    }
}
/// Per-document statistics kept for BM25 scoring.
#[derive(Debug)]
#[derive(Clone)]
struct Document {
    /// Number of occurrences of each token in the document.
    term_freqs: FxHashMap<String, u32>,
    /// Total number of tokens in the document, repeats included.
    length: u32,
}
/// In-memory BM25 full-text index keyed by external `u64` point ids.
///
/// External point ids are mapped to dense internal `u32` document ids, which
/// are what the posting lists store; internal ids freed by removals are
/// recycled before new ones are allocated.
///
/// Each piece of state sits behind its own `RwLock`, so a single update
/// acquires several locks in sequence and is not atomic with respect to
/// concurrent readers or writers.
pub struct Bm25Index {
    /// BM25 tuning parameters used when scoring.
    params: Bm25Params,
    /// Term -> internal ids of the documents containing that term.
    inverted_index: RwLock<FxHashMap<String, PostingList>>,
    /// External point id -> indexed document (term frequencies and length).
    documents: RwLock<FxHashMap<u64, Document>>,
    /// External point id -> internal document id.
    point_to_doc: RwLock<FxHashMap<u64, u32>>,
    /// Internal document id -> external point id (inverse of `point_to_doc`).
    doc_to_point: RwLock<FxHashMap<u32, u64>>,
    /// Internal ids released by removals, reused before allocating fresh ones.
    free_doc_ids: RwLock<Vec<u32>>,
    /// Next never-used internal id.
    next_doc_id: RwLock<u32>,
    /// Number of documents currently stored in `documents`.
    doc_count: RwLock<usize>,
    /// Sum of the token lengths of all stored documents (for the average length).
    total_doc_length: RwLock<u64>,
}
impl Bm25Index {
    /// Creates an empty index using the default BM25 parameters
    /// (`k1 = 1.2`, `b = 0.75`).
    #[must_use]
    pub fn new() -> Self {
        Self::with_params(Bm25Params::default())
    }

    /// Creates an empty index that scores with the given BM25 parameters.
    #[must_use]
    pub fn with_params(params: Bm25Params) -> Self {
        Self {
            params,
            inverted_index: RwLock::new(FxHashMap::default()),
            documents: RwLock::new(FxHashMap::default()),
            point_to_doc: RwLock::new(FxHashMap::default()),
            doc_to_point: RwLock::new(FxHashMap::default()),
            free_doc_ids: RwLock::new(Vec::new()),
            next_doc_id: RwLock::new(0),
            doc_count: RwLock::new(0),
            total_doc_length: RwLock::new(0),
        }
    }

    /// Splits `text` into lowercase tokens, treating every non-alphanumeric
    /// character as a separator.
    ///
    /// Tokens whose UTF-8 encoding is a single byte (single ASCII letters or
    /// digits) are dropped. The check is on *byte* length, so a single
    /// non-ASCII character such as `"é"` is kept.
    pub(crate) fn tokenize(text: &str) -> Vec<String> {
        text.to_lowercase()
            .split(|c: char| !c.is_alphanumeric())
            // `len() > 1` also discards the empty fragments between separators.
            .filter(|s| s.len() > 1)
            .map(String::from)
            .collect()
    }

    /// Indexes `text` under the external point `id`, replacing any document
    /// previously stored for that id.
    ///
    /// If `text` yields no tokens, any previously indexed document for `id` is
    /// removed (and its internal id released) and nothing new is stored. If
    /// the internal `u32` id space is exhausted, a new id is silently not
    /// indexed.
    pub fn add_document(&self, id: u64, text: &str) {
        let tokens = Self::tokenize(text);
        if tokens.is_empty() {
            // Upsert semantics: an empty replacement must not leave the old
            // content searchable under this id.
            self.remove_document_internal(id, true);
            return;
        }
        // Saturate instead of silently truncating absurdly long inputs.
        let doc_length = u32::try_from(tokens.len()).unwrap_or(u32::MAX);
        let mut term_freqs: FxHashMap<String, u32> = FxHashMap::default();
        for token in tokens {
            *term_freqs.entry(token).or_default() += 1;
        }
        let doc = Document {
            term_freqs,
            length: doc_length,
        };

        // Drop the old postings and statistics but keep the id mapping, so the
        // same internal id is reused just below.
        self.remove_document_internal(id, false);
        let Some(id_u32) = self.get_or_allocate_doc_id(id) else {
            return;
        };
        {
            let mut inv_idx = self.inverted_index.write();
            for term in doc.term_freqs.keys() {
                inv_idx
                    .entry(term.clone())
                    .or_insert_with(PostingList::new)
                    .insert(id_u32);
            }
        }
        {
            let mut docs = self.documents.write();
            match docs.insert(id, doc) {
                // Normally absent after the removal above; handled defensively
                // in case a concurrent writer re-inserted `id` in between.
                Some(replaced) => {
                    let mut total = self.total_doc_length.write();
                    *total = total.saturating_sub(u64::from(replaced.length));
                }
                None => *self.doc_count.write() += 1,
            }
        }
        *self.total_doc_length.write() += u64::from(doc_length);
    }

    /// Removes the document stored for `id` and releases its internal id for
    /// reuse. Returns `true` if a document was actually removed.
    pub fn remove_document(&self, id: u64) -> bool {
        self.remove_document_internal(id, true)
    }

    /// Scores `query` against the index and returns `(point_id, score)` pairs
    /// ordered by descending score, passed through `super::top_k_partial_sort`
    /// with `k` (presumably keeping the `k` best — confirm against that helper).
    ///
    /// Only documents with a strictly positive score are returned. Returns an
    /// empty vector when the query has no tokens or the index is empty.
    #[allow(clippy::cast_precision_loss)]
    pub fn search(&self, query: &str, k: usize) -> Vec<(u64, f32)> {
        let query_terms = Self::tokenize(query);
        if query_terms.is_empty() {
            return Vec::new();
        }
        let doc_count = *self.doc_count.read();
        if doc_count == 0 {
            return Vec::new();
        }
        let total_length = *self.total_doc_length.read();
        let avgdl = total_length as f32 / doc_count as f32;
        let mut scores = self.score_candidates(&query_terms, doc_count, avgdl);
        Self::top_k_sort(&mut scores, k);
        scores
    }

    /// Scores every document containing at least one query term and keeps
    /// those with a strictly positive score, as unsorted `(point_id, score)`
    /// pairs.
    ///
    /// Read locks on the inverted index, the documents and `doc_to_point` are
    /// held together for the whole pass.
    #[allow(clippy::cast_precision_loss)]
    fn score_candidates(
        &self,
        query_terms: &[String],
        doc_count: usize,
        avgdl: f32,
    ) -> Vec<(u64, f32)> {
        let k1 = self.params.k1;
        let b = self.params.b;
        let inv_idx = self.inverted_index.read();
        let docs = self.documents.read();
        let doc_to_point = self.doc_to_point.read();
        let n = doc_count as f32;
        let idf_cache = Self::build_idf_cache(query_terms, &inv_idx, n);
        let candidate_union = Self::build_candidate_union(query_terms, &inv_idx);
        candidate_union
            .iter()
            .filter_map(|doc_id_u32| {
                // Skip ids whose mapping or document vanished mid-update.
                let doc_id = *doc_to_point.get(&doc_id_u32)?;
                let doc = docs.get(&doc_id)?;
                let score = Self::score_document_fast(doc, query_terms, &idf_cache, k1, b, avgdl);
                (score > 0.0).then_some((doc_id, score))
            })
            .collect()
    }

    /// Maps each query term to its IDF, `ln(1 + (N - df + 0.5) / (df + 0.5))`,
    /// where `N` is `n` and `df` is the length of the term's posting list.
    /// Terms absent from the index get an IDF of `0.0`.
    #[allow(clippy::cast_precision_loss)]
    fn build_idf_cache<'a>(
        query_terms: &'a [String],
        inv_idx: &FxHashMap<String, PostingList>,
        n: f32,
    ) -> FxHashMap<&'a str, f32> {
        query_terms
            .iter()
            .map(|term| {
                let df = inv_idx.get(term).map_or(0, PostingList::len);
                let idf_val = if df == 0 {
                    0.0
                } else {
                    let df_f = df as f32;
                    ((n - df_f + 0.5) / (df_f + 0.5) + 1.0).ln()
                };
                (term.as_str(), idf_val)
            })
            .collect()
    }

    /// Returns the union of the posting lists of all query terms that appear
    /// in the index, i.e. every document that can score above zero.
    fn build_candidate_union(
        query_terms: &[String],
        inv_idx: &FxHashMap<String, PostingList>,
    ) -> PostingList {
        let mut candidate_union = PostingList::new();
        for term in query_terms {
            if let Some(posting_list) = inv_idx.get(term) {
                candidate_union = candidate_union.union(posting_list);
            }
        }
        candidate_union
    }

    /// Orders `scores` by descending score via `super::top_k_partial_sort`.
    /// The comparator is reversed (`b` vs `a`) to get descending order, and
    /// `total_cmp` gives floats a total order.
    fn top_k_sort(scores: &mut Vec<(u64, f32)>, k: usize) {
        super::top_k_partial_sort(scores, k, |a, b| b.1.total_cmp(&a.1));
    }

    /// BM25 score of `doc`: the sum over query terms of
    /// `idf * tf * (k1 + 1) / (tf + k1 * (1 - b + b * dl / avgdl))`.
    /// Duplicate query terms contribute once per occurrence.
    #[allow(clippy::cast_precision_loss)]
    fn score_document_fast(
        doc: &Document,
        query_terms: &[String],
        idf_cache: &FxHashMap<&str, f32>,
        k1: f32,
        b: f32,
        avgdl: f32,
    ) -> f32 {
        let doc_len = doc.length as f32;
        let len_norm = 1.0 - b + b * doc_len / avgdl;
        query_terms
            .iter()
            .map(|term| {
                // Terms missing from the document contribute nothing.
                let Some(&tf) = doc.term_freqs.get(term) else {
                    return 0.0;
                };
                let tf = tf as f32;
                let idf = idf_cache.get(term.as_str()).copied().unwrap_or(0.0);
                let numerator = tf * (k1 + 1.0);
                let denominator = tf + k1 * len_norm;
                idf * numerator / denominator
            })
            .sum()
    }

    /// Number of documents currently indexed.
    #[must_use]
    pub fn len(&self) -> usize {
        *self.doc_count.read()
    }

    /// Returns `true` if no documents are indexed.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Number of distinct terms in the inverted index.
    #[must_use]
    pub fn term_count(&self) -> usize {
        self.inverted_index.read().len()
    }

    /// Returns the internal id mapped to `point_id`, allocating one (recycled
    /// first, otherwise fresh) if none exists. Returns `None` when the fresh
    /// id counter cannot be advanced without overflowing `u32`.
    fn get_or_allocate_doc_id(&self, point_id: u64) -> Option<u32> {
        // The point_to_doc write lock is held for the whole call, so two
        // callers cannot both allocate an id for the same point.
        let mut map = self.point_to_doc.write();
        if let Some(&existing) = map.get(&point_id) {
            return Some(existing);
        }
        // Bind the pop result first so the free-list guard is dropped before
        // `next_doc_id` is locked (an `if let` scrutinee would keep it alive).
        let recycled = self.free_doc_ids.write().pop();
        let allocated = match recycled {
            Some(doc_id) => doc_id,
            None => {
                let mut next = self.next_doc_id.write();
                let current = *next;
                *next = current.checked_add(1)?;
                current
            }
        };
        map.insert(point_id, allocated);
        self.doc_to_point.write().insert(allocated, point_id);
        Some(allocated)
    }

    /// Removes the document for `point_id` from the inverted index and the
    /// statistics. When `release_mapping` is set, the point <-> internal id
    /// mapping is also dropped and the internal id is pushed onto the free
    /// list. Returns `true` if a stored document was removed.
    fn remove_document_internal(&self, point_id: u64, release_mapping: bool) -> bool {
        let Some(doc_id_u32) = self.point_to_doc.read().get(&point_id).copied() else {
            return false;
        };
        let removed_doc = self.documents.write().remove(&point_id);
        let removed = removed_doc.is_some();
        if let Some(doc) = removed_doc {
            {
                let mut inv_idx = self.inverted_index.write();
                for term in doc.term_freqs.keys() {
                    if let Some(posting_list) = inv_idx.get_mut(term) {
                        posting_list.remove(doc_id_u32);
                        // Drop empty lists so `term_count` stays accurate.
                        if posting_list.is_empty() {
                            inv_idx.remove(term);
                        }
                    }
                }
            }
            {
                let mut count = self.doc_count.write();
                *count = count.saturating_sub(1);
            }
            let mut total = self.total_doc_length.write();
            *total = total.saturating_sub(u64::from(doc.length));
        }
        if release_mapping {
            // Only the caller that actually removes this exact mapping may
            // recycle the id. Otherwise two concurrent removals of the same
            // point would push `doc_id_u32` onto the free list twice and could
            // delete a `doc_to_point` entry reassigned in between.
            let released = {
                let mut map = self.point_to_doc.write();
                if map.get(&point_id) == Some(&doc_id_u32) {
                    map.remove(&point_id);
                    true
                } else {
                    false
                }
            };
            if released {
                self.doc_to_point.write().remove(&doc_id_u32);
                self.free_doc_ids.write().push(doc_id_u32);
            }
        }
        removed
    }
}
impl Default for Bm25Index {
fn default() -> Self {
Self::new()
}
}