use std::collections::HashMap;
use std::marker::PhantomData;
use triblespace_core::inline::encodings::genid::GenId;
use triblespace_core::inline::{RawInline, IntoInline, Inline, InlineEncoding};
const DEFAULT_K1: f32 = 1.5;
const DEFAULT_B: f32 = 0.75;
pub struct BM25Builder<D: InlineEncoding = GenId, T: InlineEncoding = crate::tokens::WordHash> {
pub(crate) docs: Vec<(RawInline, Vec<RawInline>)>,
pub(crate) k1: f32,
pub(crate) b: f32,
pub(crate) _phantom: PhantomData<(D, T)>,
}
impl<D: InlineEncoding, T: InlineEncoding> Clone for BM25Builder<D, T> {
fn clone(&self) -> Self {
Self {
docs: self.docs.clone(),
k1: self.k1,
b: self.b,
_phantom: PhantomData,
}
}
}
impl<D: InlineEncoding, T: InlineEncoding> Default for BM25Builder<D, T> {
fn default() -> Self {
Self::new()
}
}
impl<D: InlineEncoding, T: InlineEncoding> BM25Builder<D, T> {
pub fn new() -> Self {
Self {
docs: Vec::new(),
k1: DEFAULT_K1,
b: DEFAULT_B,
_phantom: PhantomData,
}
}
pub fn k1(mut self, k1: f32) -> Self {
self.k1 = k1;
self
}
pub fn b(mut self, b: f32) -> Self {
self.b = b;
self
}
pub fn insert<K, I>(&mut self, key: K, terms: I)
where
K: IntoInline<D>,
I: IntoIterator<Item = Inline<T>>,
{
let key_val: Inline<D> = key.to_inline();
let term_rows: Vec<RawInline> = terms.into_iter().map(|v| v.raw).collect();
self.docs.push((key_val.raw, term_rows));
}
pub fn build(self) -> crate::succinct::SuccinctBM25Index<D, T> {
crate::succinct::SuccinctBM25Index::from_builder(self)
}
pub fn build_naive(self) -> BM25Index<D, T> {
self.build_naive_with_threads(1)
}
pub fn build_naive_with_threads(self, threads: usize) -> BM25Index<D, T> {
let Self { docs, k1, b, _phantom } = self;
let n_docs = docs.len();
let doc_lens: Vec<u32> = docs.iter().map(|(_, t)| t.len() as u32).collect();
let avg_doc_len = if n_docs == 0 {
0.0
} else {
doc_lens.iter().map(|&n| n as f64).sum::<f64>() as f32 / n_docs as f32
};
let keys: Vec<RawInline> = docs.iter().map(|(key, _)| *key).collect();
let term_to_tfs = if threads <= 1 || n_docs < 2 {
let mut m: HashMap<RawInline, HashMap<u32, u32>> = HashMap::new();
for (doc_idx, (_, terms)) in docs.into_iter().enumerate() {
accumulate_tfs(&mut m, doc_idx as u32, terms);
}
m
} else {
let threads = threads.min(n_docs);
let base_chunk = n_docs / threads;
let extra = n_docs % threads;
let mut starts = Vec::with_capacity(threads);
let mut chunks: Vec<Vec<(RawInline, Vec<RawInline>)>> = Vec::with_capacity(threads);
let mut docs_iter = docs.into_iter();
let mut idx = 0usize;
for t in 0..threads {
let size = base_chunk + if t < extra { 1 } else { 0 };
let chunk: Vec<_> = (&mut docs_iter).take(size).collect();
starts.push(idx);
idx += size;
chunks.push(chunk);
}
let locals: Vec<HashMap<RawInline, HashMap<u32, u32>>> = std::thread::scope(|s| {
let mut handles = Vec::with_capacity(threads);
for (shard_start, chunk) in starts.iter().zip(chunks) {
let start = *shard_start as u32;
handles.push(s.spawn(move || {
let mut m: HashMap<RawInline, HashMap<u32, u32>> = HashMap::new();
for (i, (_, terms)) in chunk.into_iter().enumerate() {
accumulate_tfs(&mut m, start + i as u32, terms);
}
m
}));
}
handles.into_iter().map(|h| h.join().unwrap()).collect()
});
let mut merged: HashMap<RawInline, HashMap<u32, u32>> = HashMap::new();
for local in locals {
for (term, tfs) in local {
merged.entry(term).or_default().extend(tfs);
}
}
merged
};
let mut terms: Vec<RawInline> = term_to_tfs.keys().copied().collect();
terms.sort_unstable();
let mut offsets: Vec<u32> = Vec::with_capacity(terms.len() + 1);
offsets.push(0);
let mut postings: Vec<(u32, f32)> = Vec::new();
let n = n_docs as f32;
for term in &terms {
let tfs = &term_to_tfs[term];
let df = tfs.len() as f32;
let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
let mut entries: Vec<(u32, f32)> = tfs
.iter()
.map(|(&doc_idx, &tf)| {
let tf = tf as f32;
let dl = doc_lens[doc_idx as usize] as f32;
let norm = if avg_doc_len > 0.0 {
1.0 - b + b * (dl / avg_doc_len)
} else {
1.0
};
let score = idf * (tf * (k1 + 1.0)) / (tf + k1 * norm);
(doc_idx, score)
})
.collect();
entries.sort_unstable_by_key(|&(idx, _)| idx);
postings.extend(entries);
offsets.push(postings.len() as u32);
}
BM25Index {
keys,
doc_lens,
avg_doc_len,
terms,
offsets,
postings,
k1,
b,
_phantom,
}
}
}
fn accumulate_tfs(
m: &mut HashMap<RawInline, HashMap<u32, u32>>,
doc_idx: u32,
terms: Vec<RawInline>,
) {
for term in terms {
let entry = m.entry(term).or_default().entry(doc_idx).or_insert(0);
*entry += 1;
}
}
#[doc(hidden)]
pub struct BM25Index<D: InlineEncoding = GenId, T: InlineEncoding = crate::tokens::WordHash> {
keys: Vec<RawInline>,
doc_lens: Vec<u32>,
avg_doc_len: f32,
terms: Vec<RawInline>,
offsets: Vec<u32>,
postings: Vec<(u32, f32)>,
k1: f32,
b: f32,
_phantom: PhantomData<(D, T)>,
}
impl<D: InlineEncoding, T: InlineEncoding> std::fmt::Debug for BM25Index<D, T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("BM25Index")
.field("n_docs", &self.keys.len())
.field("n_terms", &self.terms.len())
.field("avg_doc_len", &self.avg_doc_len)
.field("k1", &self.k1)
.field("b", &self.b)
.finish()
}
}
impl<D: InlineEncoding, T: InlineEncoding> Clone for BM25Index<D, T> {
fn clone(&self) -> Self {
Self {
keys: self.keys.clone(),
doc_lens: self.doc_lens.clone(),
avg_doc_len: self.avg_doc_len,
terms: self.terms.clone(),
offsets: self.offsets.clone(),
postings: self.postings.clone(),
k1: self.k1,
b: self.b,
_phantom: PhantomData,
}
}
}
impl<D: InlineEncoding, T: InlineEncoding> BM25Index<D, T> {
pub fn doc_count(&self) -> usize {
self.keys.len()
}
pub fn term_count(&self) -> usize {
self.terms.len()
}
pub fn avg_doc_len(&self) -> f32 {
self.avg_doc_len
}
pub fn query_term<'a>(
&'a self,
term: &Inline<T>,
) -> impl Iterator<Item = (Inline<D>, f32)> + 'a {
let lo = self.terms.binary_search(&term.raw).ok();
let range = match lo {
Some(i) => self.offsets[i] as usize..self.offsets[i + 1] as usize,
None => 0..0,
};
self.postings[range]
.iter()
.map(|&(doc_idx, score)| (Inline::<D>::new(self.keys[doc_idx as usize]), score))
}
pub fn query_multi(&self, terms: &[Inline<T>]) -> Vec<(Inline<D>, f32)> {
let mut acc: HashMap<RawInline, f32> = HashMap::new();
for term in terms {
for (doc, score) in self.query_term(term) {
*acc.entry(doc.raw).or_insert(0.0) += score;
}
}
let mut out: Vec<(Inline<D>, f32)> =
acc.into_iter().map(|(raw, s)| (Inline::<D>::new(raw), s)).collect();
out.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
out
}
pub fn doc_frequency(&self, term: &Inline<T>) -> usize {
self.terms
.binary_search(&term.raw)
.ok()
.map(|i| (self.offsets[i + 1] - self.offsets[i]) as usize)
.unwrap_or(0)
}
pub fn k1(&self) -> f32 {
self.k1
}
pub fn b(&self) -> f32 {
self.b
}
pub fn doc_lens(&self) -> &[u32] {
&self.doc_lens
}
pub fn keys(&self) -> &[RawInline] {
&self.keys
}
pub fn terms_slice(&self) -> &[RawInline] {
&self.terms
}
pub fn postings_for(&self, t: usize) -> &[(u32, f32)] {
if t >= self.terms.len() {
return &[];
}
let lo = self.offsets[t] as usize;
let hi = self.offsets[t + 1] as usize;
&self.postings[lo..hi]
}
pub fn byte_size(&self) -> usize {
20 + self.keys.len() * 32
+ self.doc_lens.len() * 4
+ self.terms.len() * 32
+ self.offsets.len() * 4
+ self.postings.len() * 8
}
}
impl<D: InlineEncoding, T: InlineEncoding> PartialEq for BM25Index<D, T> {
fn eq(&self, other: &Self) -> bool {
fn f32_bit_eq(a: f32, b: f32) -> bool {
a.to_bits() == b.to_bits()
}
self.keys == other.keys
&& self.doc_lens == other.doc_lens
&& f32_bit_eq(self.avg_doc_len, other.avg_doc_len)
&& self.terms == other.terms
&& self.offsets == other.offsets
&& self.postings.len() == other.postings.len()
&& self
.postings
.iter()
.zip(other.postings.iter())
.all(|(a, b)| a.0 == b.0 && f32_bit_eq(a.1, b.1))
&& f32_bit_eq(self.k1, other.k1)
&& f32_bit_eq(self.b, other.b)
}
}
impl<D: InlineEncoding, T: InlineEncoding> Eq for BM25Index<D, T> {}
#[cfg(test)]
mod tests {
use super::*;
use crate::tokens::hash_tokens;
use triblespace_core::id::Id;
fn id(byte: u8) -> Id {
assert!(byte != 0, "test fixture: 0 is the nil Id");
Id::new([byte; 16]).unwrap()
}
fn id_key(byte: u8) -> RawInline {
let v: Inline<GenId> = id(byte).to_inline();
v.raw
}
#[test]
fn empty_index_is_queryable() {
let idx = BM25Builder::<GenId, crate::tokens::WordHash>::new().build();
assert_eq!(idx.doc_count(), 0);
assert_eq!(idx.term_count(), 0);
let term: Inline<crate::tokens::WordHash> = Inline::new([0u8; 32]);
assert!(idx.query_term(&term).next().is_none());
}
#[test]
fn insert_indexes_by_string_key() {
use triblespace_core::inline::encodings::shortstring::ShortString;
use triblespace_core::inline::{IntoInline, Inline};
let mut b: BM25Builder<ShortString> = BM25Builder::new();
let red: Inline<ShortString> = "red".to_inline();
let blue: Inline<ShortString> = "blue".to_inline();
b.insert(red, hash_tokens("a tomato is red"));
b.insert(blue, hash_tokens("the ocean is blue"));
let idx = b.build();
assert_eq!(idx.doc_count(), 2);
let red_hits: Vec<_> = idx.query_term(&hash_tokens("red")[0]).collect();
let blue_hits: Vec<_> = idx.query_term(&hash_tokens("blue")[0]).collect();
assert_eq!(red_hits.len(), 1);
assert_eq!(blue_hits.len(), 1);
let red_raw: Inline<ShortString> = "red".to_inline();
let blue_raw: Inline<ShortString> = "blue".to_inline();
assert_eq!(red_hits[0].0, red_raw);
assert_eq!(blue_hits[0].0, blue_raw);
}
#[test]
fn three_docs_basic() {
let mut b: BM25Builder = BM25Builder::new();
b.insert(id(1), hash_tokens("the quick brown fox"));
b.insert(id(2), hash_tokens("the lazy brown dog"));
b.insert(id(3), hash_tokens("quick silver fox"));
let idx = b.build();
assert_eq!(idx.doc_count(), 3);
let fox = hash_tokens("fox");
let hits: Vec<_> = idx.query_term(&fox[0]).collect();
assert_eq!(hits.len(), 2);
let doc_ids: Vec<_> = hits.iter().map(|(d, _)| *d).collect();
assert!(doc_ids.iter().any(|v| v.raw == id_key(1)));
assert!(doc_ids.iter().any(|v| v.raw == id_key(3)));
let the = hash_tokens("the");
assert_eq!(idx.doc_frequency(&the[0]), 2);
let missing = hash_tokens("banana");
assert!(idx.query_term(&missing[0]).next().is_none());
}
#[test]
fn idf_prefers_rare_terms() {
let mut b: BM25Builder = BM25Builder::new();
for i in 1..=10 {
b.insert(id(i), hash_tokens("common common"));
}
b.insert(id(100), hash_tokens("common rare"));
let idx = b.build();
let rare = hash_tokens("rare");
let common = hash_tokens("common");
let rare_score = idx.query_term(&rare[0]).next().unwrap().1;
let common_score = idx
.query_term(&common[0])
.max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
.unwrap()
.1;
assert!(
rare_score > common_score,
"rare_score={rare_score}, common_score={common_score}"
);
}
#[test]
fn term_frequency_saturates() {
let mut b: BM25Builder = BM25Builder::new();
b.insert(id(1), hash_tokens("foo bar baz"));
let many: String = "foo ".repeat(100);
b.insert(id(2), hash_tokens(&many));
let idx = b.build();
let foo = hash_tokens("foo");
let scores: HashMap<RawInline, f32> = idx
.query_term(&foo[0])
.map(|(v, s)| (v.raw, s))
.collect();
let s1 = scores[&id_key(1)];
let s2 = scores[&id_key(2)];
assert!(s2 > s1);
assert!(
s2 < s1 * 20.0,
"tf saturation should keep ratio moderate: {s1} -> {s2}"
);
}
#[test]
fn multi_term_query_sums() {
let mut b: BM25Builder = BM25Builder::new();
b.insert(id(1), hash_tokens("quick brown fox"));
b.insert(id(2), hash_tokens("quick red fox"));
b.insert(id(3), hash_tokens("slow brown dog"));
let idx = b.build();
let q = hash_tokens("quick fox");
let ranked = idx.query_multi(&q);
assert_eq!(ranked.len(), 2);
let top_ids: Vec<[u8; 32]> = ranked.iter().map(|(d, _)| d.raw).collect();
assert!(top_ids.contains(&id_key(1)));
assert!(top_ids.contains(&id_key(2)));
assert!(ranked[0].1 >= ranked[1].1);
}
#[test]
fn tuning_params_round_trip() {
let b = BM25Builder::<GenId, crate::tokens::WordHash>::new()
.k1(1.2)
.b(0.5);
let idx = b.build();
assert!((idx.k1() - 1.2).abs() < 1e-6);
assert!((idx.b() - 0.5).abs() < 1e-6);
}
fn build_sample_index() -> BM25Index {
let mut b: BM25Builder = BM25Builder::new().k1(1.4).b(0.72);
b.insert(id(1), hash_tokens("the quick brown fox"));
b.insert(id(2), hash_tokens("the lazy brown dog"));
b.insert(id(3), hash_tokens("quick silver fox jumps"));
b.build_naive()
}
#[test]
fn byte_size_matches_naive_layout() {
let idx = build_sample_index();
let expected = 20
+ idx.keys().len() * 32
+ idx.doc_lens().len() * 4
+ idx.terms_slice().len() * 32
+ (idx.terms_slice().len() + 1) * 4
+ (0..idx.term_count())
.map(|t| idx.postings_for(t).len())
.sum::<usize>()
* 8;
assert_eq!(idx.byte_size(), expected);
}
#[test]
fn build_is_deterministic() {
let a = build_sample_index();
let b = build_sample_index();
assert_eq!(a, b);
}
#[test]
fn parallel_build_matches_single_thread() {
fn build(threads: usize) -> BM25Index {
let mut b: BM25Builder = BM25Builder::new();
for i in 1..=50u32 {
let text = format!(
"doc {i} text about {} {}",
(i % 5) + 1,
(i.wrapping_mul(7)) % 13
);
let byte = (i as u8).max(1);
b.insert(id(byte), hash_tokens(&text));
}
b.build_naive_with_threads(threads)
}
let serial = build(1);
for t in [2usize, 3, 4, 8] {
assert_eq!(
build(t),
serial,
"threads={t} produced a different index than serial"
);
}
}
#[test]
fn parallel_build_on_empty_corpus() {
let idx = BM25Builder::<GenId, crate::tokens::WordHash>::new()
.build_naive_with_threads(4);
assert_eq!(idx.doc_count(), 0);
assert_eq!(idx.term_count(), 0);
}
#[test]
fn parallel_build_threads_cap_at_doc_count() {
let mut b: BM25Builder = BM25Builder::new();
b.insert(id(1), hash_tokens("one two three"));
b.insert(id(2), hash_tokens("two three four"));
b.insert(id(3), hash_tokens("three four five"));
let idx = b.build_naive_with_threads(16);
assert_eq!(idx.doc_count(), 3);
let three = hash_tokens("three")[0];
assert_eq!(idx.doc_frequency(&three), 3);
}
#[test]
fn ngrams_enable_prefix_queries() {
use crate::tokens::{ngram_tokens, NgramHash};
let mut b: BM25Builder<GenId, NgramHash> = BM25Builder::new();
b.insert(id(1), ngram_tokens("foxes are cunning", 3));
b.insert(id(2), ngram_tokens("the dog barks", 3));
b.insert(id(3), ngram_tokens("silver fox at night", 3));
let idx = b.build();
let q = ngram_tokens("fox", 3);
let hits: Vec<_> = idx.query_multi(&q);
let doc_ids: Vec<_> = hits.iter().map(|(d, _)| *d).collect();
assert!(doc_ids.iter().any(|v| v.raw == id_key(1)), "prefix should match 'foxes'");
assert!(doc_ids.iter().any(|v| v.raw == id_key(3)), "prefix should match 'fox'");
assert!(!doc_ids.iter().any(|v| v.raw == id_key(2)), "must not match 'dog'");
}
}