use chroma_types::SparseVector;
use thiserror::Error;
use crate::embed::bm25_tokenizer::Bm25Tokenizer;
use crate::embed::murmur3_abs_hasher::Murmur3AbsHasher;
use crate::embed::{EmbeddingFunction, TokenHasher, Tokenizer};
/// Errors that can occur while computing BM25 sparse embeddings.
///
/// Currently empty: `encode` has no fallible steps, but the enum keeps the
/// `Result`-based API stable so fallible variants can be added later
/// without breaking callers.
#[derive(Debug, Error)]
pub enum BM25SparseEmbeddingError {}
/// BM25-weighted sparse embedder parameterized over a tokenizer `T` and a
/// token hasher `H`.
///
/// Trait bounds live on the impl blocks that need them rather than on the
/// struct definition, per standard Rust API guidance.
pub struct BM25SparseEmbeddingFunction<T, H> {
    /// When true, emitted sparse vectors carry the token text alongside
    /// each dimension index (see `encode`).
    pub include_tokens: bool,
    /// Splits input text into tokens.
    pub tokenizer: T,
    /// Maps each token to a sparse dimension index.
    pub hasher: H,
    /// BM25 term-frequency saturation parameter (commonly called `k1`).
    pub k: f32,
    /// BM25 length-normalization parameter in `[0, 1]`.
    pub b: f32,
    /// Assumed average document length (in tokens) used for normalization.
    pub avg_len: f32,
}
impl BM25SparseEmbeddingFunction<Bm25Tokenizer, Murmur3AbsHasher> {
pub fn default_murmur3_abs() -> Self {
Self {
include_tokens: true,
tokenizer: Bm25Tokenizer::default(),
hasher: Murmur3AbsHasher::default(),
k: 1.2,
b: 0.75,
avg_len: 256.0,
}
}
}
impl<T, H> BM25SparseEmbeddingFunction<T, H>
where
    T: Tokenizer,
    H: TokenHasher,
{
    /// BM25 weight for a term with frequency `tf`.
    ///
    /// `len_norm` is the precomputed, document-wide length-normalization
    /// term `k * (1 - b + b * doc_len / avg_len)`; it is hoisted out of
    /// the per-term loop in `encode` because it does not depend on `tf`.
    #[inline]
    fn bm25_weight(&self, tf: f32, len_norm: f32) -> f32 {
        tf * (self.k + 1.0) / (tf + len_norm)
    }

    /// Encodes `text` as a BM25-weighted sparse vector.
    ///
    /// The text is tokenized, each token is hashed to a dimension index,
    /// equal indices are merged, and each merged index is scored with the
    /// BM25 term-frequency formula using `self.k`, `self.b`, and
    /// `self.avg_len`. When `self.include_tokens` is true, the vector also
    /// carries, for every index, the first token (in `(index, token)` sort
    /// order) that hashed to it.
    ///
    /// # Errors
    ///
    /// Currently infallible; the `Result` keeps the API stable should
    /// fallible steps be added.
    pub fn encode(&self, text: &str) -> Result<SparseVector, BM25SparseEmbeddingError> {
        let tokens = self.tokenizer.tokenize(text);
        let doc_len = tokens.len() as f32;
        // Loop-invariant part of the BM25 denominator: depends only on the
        // document length, not on individual term frequencies.
        let len_norm = self.k * (1.0 - self.b + self.b * doc_len / self.avg_len);
        if self.include_tokens {
            let mut token_ids = Vec::with_capacity(tokens.len());
            for token in tokens {
                let id = self.hasher.hash(&token);
                token_ids.push((id, token));
            }
            // Sort by id (ties broken by token) so equal ids are adjacent
            // and can be merged by `chunk_by`.
            token_ids.sort_unstable();
            let sparse_triples = token_ids.chunk_by(|a, b| a.0 == b.0).map(|chunk| {
                let id = chunk[0].0;
                let tk = chunk[0].1.clone();
                let tf = chunk.len() as f32;
                (tk, id, self.bm25_weight(tf, len_norm))
            });
            Ok(SparseVector::from_triples(sparse_triples))
        } else {
            let mut token_ids = Vec::with_capacity(tokens.len());
            for token in tokens {
                token_ids.push(self.hasher.hash(&token));
            }
            token_ids.sort_unstable();
            let sparse_pairs = token_ids.chunk_by(|a, b| a == b).map(|chunk| {
                let tf = chunk.len() as f32;
                (chunk[0], self.bm25_weight(tf, len_norm))
            });
            Ok(SparseVector::from_pairs(sparse_pairs))
        }
    }
}
#[async_trait::async_trait]
impl<T, H> EmbeddingFunction for BM25SparseEmbeddingFunction<T, H>
where
    T: Tokenizer + Send + Sync + 'static,
    H: TokenHasher + Send + Sync + 'static,
{
    type Embedding = SparseVector;
    type Error = BM25SparseEmbeddingError;

    /// Encodes every string in `batches` with `encode`, returning the
    /// embeddings in input order and short-circuiting on the first error.
    async fn embed_strs(&self, batches: &[&str]) -> Result<Vec<Self::Embedding>, Self::Error> {
        let mut embeddings = Vec::with_capacity(batches.len());
        for text in batches {
            embeddings.push(self.encode(text)?);
        }
        Ok(embeddings)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `text` with the default Murmur3-abs BM25 embedder,
    /// unwrapping the (currently infallible) result.
    fn encode_default(text: &str) -> SparseVector {
        BM25SparseEmbeddingFunction::default_murmur3_abs()
            .encode(text)
            .unwrap()
    }

    /// Asserts every value in `result` is within 1e-5 of `expected`.
    /// The fixtures below contain only unique tokens, so every term
    /// shares the same BM25 weight.
    fn assert_uniform_values(result: &SparseVector, expected: f32) {
        for &value in &result.values {
            assert!(
                (value - expected).abs() < 1e-5,
                "value {value} differs from expected {expected}"
            );
        }
    }

    #[test]
    fn test_bm25_comprehensive_tokenization() {
        let result = encode_default("Usain Bolt's top speed reached ~27.8 mph (44.72 km/h)");
        let expected_indices = vec![
            230246813, 395514983, 458027949, 488165615, 729632045, 734978415, 997512866,
            1114505193, 1381820790, 1501587190, 1649421877, 1837285388,
        ];
        assert_eq!(result.indices.len(), 12);
        assert_eq!(result.indices, expected_indices);
        assert_uniform_values(&result, 1.6391153);
    }

    #[test]
    fn test_bm25_stopwords_and_punctuation() {
        let result = encode_default("The space-time continuum WARPS near massive objects...");
        let expected_indices = vec![
            90097469, 519064992, 737893654, 1110755108, 1950894484, 2031641008, 2058513491,
        ];
        assert_eq!(result.indices.len(), 7);
        assert_eq!(result.indices, expected_indices);
        assert_uniform_values(&result, 1.660867);
    }
}