use crate::Bm25TokenIndexer;
use crate::bm25_tokenizer::Bm25Tokenizer;
use crate::bm25_vectorizer::Bm25VectorizerError::{
    InvalidAverageDocumentLength, InvalidLengthNormalisation, InvalidTermFrequencyLowerBound,
    InvalidTermRelevanceSaturation, MissingAverageDocumentLength, MissingTokenIndexer,
    MissingTokenizer,
};
#[cfg(feature = "parallelism")]
use rayon::prelude::*;
use std::collections::BTreeMap;
use std::fmt::Debug;
use std::hash::Hash;
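/// A single entry of a sparse BM25 vector: a token index paired with its
/// computed term weight.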
#[derive(PartialEq, Debug, Clone, PartialOrd)]
pub struct TokenIndexValue<T> {
pub index: T,
pub value: f32,
}
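/// A sparse BM25 vector: the `(index, value)` entries for one document,
/// ordered by token index.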
#[derive(PartialEq, Debug, Clone, PartialOrd)]
pub struct SparseRepresentation<T>(pub Vec<TokenIndexValue<T>>);
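/// Newtype for the BM25 `k1` parameter, which controls how quickly repeated
/// occurrences of a term saturate its contribution to the weight.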
#[derive(Debug)]
pub struct TermRelevanceSaturation {
k1: f32,
}
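/// Newtype for the `delta` parameter (as in BM25+): a constant added to every
/// matched term's weight, acting as a lower bound.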
#[derive(Debug)]
pub struct TermFrequencyLowerBound {
delta: f32,
}
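/// Newtype for the BM25 `b` parameter, which scales document-length
/// normalisation from none (`0.0`) to full (`1.0`).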
#[derive(Debug)]
pub struct LengthNormalisation {
b: f32,
}
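/// Newtype for `avgdl`, the average document length of the corpus in tokens.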
#[derive(Debug)]
pub struct AverageDocumentLength {
avgdl: f32,
}
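/// Computes sparse BM25 representations of text using a pluggable tokenizer
/// and token indexer. Construct one via [`Bm25VectorizerBuilder`].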
#[derive(Debug)]
pub struct Bm25Vectorizer<TokenIndexer, Tokenizer> {
tokenizer: Tokenizer,
k1: TermRelevanceSaturation,
b: LengthNormalisation,
avgdl: AverageDocumentLength,
delta: TermFrequencyLowerBound,
token_indexer: TokenIndexer,
}
impl<TokenIndexer, Tokenizer> Bm25Vectorizer<TokenIndexer, Tokenizer> {
pub fn avgdl(&self) -> f32 {
self.avgdl.avgdl
}
pub fn k1(&self) -> f32 {
self.k1.k1
}
pub fn b(&self) -> f32 {
self.b.b
}
pub fn delta(&self) -> f32 {
self.delta.delta
}
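    /// Tokenizes `text` and returns its sparse BM25 representation, one entry
    /// per distinct token index. Each entry is weighted as
    ///
    /// ```text
    ///          tf * (k1 + 1)
    /// ------------------------------------ + delta
    ///  tf + k1 * (1 - b + b * dl / avgdl)
    /// ```
    ///
    /// where `tf` is the token's count in `text` and `dl` is the total token
    /// count of `text`. No inverse document frequency factor is applied; this
    /// is the document-side term weight only.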
pub fn vectorize(&self, text: &str) -> SparseRepresentation<TokenIndexer::Bm25TokenIndex>
where
TokenIndexer: Bm25TokenIndexer,
TokenIndexer::Bm25TokenIndex: Eq + Hash + Clone + Debug + Ord,
Tokenizer: Bm25Tokenizer,
{
        let tokens = self.tokenizer.tokenize(text);
        let doc_length = tokens.len() as f32;
        // Count occurrences of each token index; a BTreeMap keeps the
        // resulting sparse representation ordered by index.
        let mut index_counts: BTreeMap<TokenIndexer::Bm25TokenIndex, usize> = BTreeMap::new();
        for token in tokens.iter() {
            let index = self.token_indexer.index(token);
            *index_counts.entry(index).or_insert(0) += 1;
        }
        // Apply the BM25+ term-weighting formula to each distinct index.
        let embeddings: Vec<TokenIndexValue<TokenIndexer::Bm25TokenIndex>> = index_counts
            .into_iter()
            .map(|(index, count)| {
                let token_frequency = count as f32;
                let numerator = token_frequency * (self.k1() + 1.0);
                let denominator = token_frequency
                    + self.k1() * (1.0 - self.b() + self.b() * (doc_length / self.avgdl()));
                let value = (numerator / denominator) + self.delta();
                TokenIndexValue { index, value }
            })
            .collect();
        SparseRepresentation(embeddings)
}
}
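/// Builder for [`Bm25Vectorizer`]. Defaults to `k1 = 1.2`, `b = 0.75`, and
/// `delta = 0.0`; a tokenizer, a token indexer, and an average document
/// length (set via [`avgdl`](Self::avgdl) or computed by one of the `fit*`
/// methods) are required before [`build`](Self::build) succeeds.
///
/// A minimal sketch, assuming `MyTokenizer` and `MyIndexer` are
/// caller-supplied implementations of `Bm25Tokenizer` and `Bm25TokenIndexer`:
///
/// ```ignore
/// let vectorizer = Bm25VectorizerBuilder::new()
///     .tokenizer(MyTokenizer)
///     .token_indexer(MyIndexer)
///     .fit(&["hello world", "hello rust"])? // computes avgdl from the corpus
///     .build()?;
/// let sparse = vectorizer.vectorize("hello bm25");
/// ```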
pub struct Bm25VectorizerBuilder<TokenIndexer, Tokenizer> {
tokenizer: Option<Tokenizer>,
k1: TermRelevanceSaturation,
b: LengthNormalisation,
avgdl: Option<AverageDocumentLength>,
delta: TermFrequencyLowerBound,
token_indexer: Option<TokenIndexer>,
}
impl<TokenIndexer, Tokenizer> Bm25VectorizerBuilder<TokenIndexer, Tokenizer> {
pub fn new() -> Self {
Self {
tokenizer: None,
k1: TermRelevanceSaturation { k1: 1.2 },
b: LengthNormalisation { b: 0.75 },
avgdl: None,
delta: TermFrequencyLowerBound { delta: 0.0 },
token_indexer: None,
}
}
pub fn k1(mut self, k1: f32) -> Self {
self.k1 = TermRelevanceSaturation { k1 };
self
}
pub fn b(mut self, b: f32) -> Self {
self.b = LengthNormalisation { b };
self
}
pub fn delta(mut self, delta: f32) -> Self {
self.delta = TermFrequencyLowerBound { delta };
self
}
pub fn avgdl(mut self, avgdl: f32) -> Self {
self.avgdl = Some(AverageDocumentLength { avgdl });
self
}
pub fn tokenizer(mut self, tokenizer: Tokenizer) -> Self {
self.tokenizer = Some(tokenizer);
self
}
pub fn token_indexer(mut self, token_indexer: TokenIndexer) -> Self {
self.token_indexer = Some(token_indexer);
self
}
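    /// Computes `avgdl` as the mean token count over `corpus`, tokenizing in
    /// parallel when the `parallelism` feature is enabled. Errors on an empty
    /// corpus; does nothing if no tokenizer has been set yet.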
pub fn fit(mut self, corpus: &[&str]) -> Result<Self, Bm25VectorizerError>
where
Tokenizer: Bm25Tokenizer + Sync,
{
if let Some(ref tokenizer) = self.tokenizer {
let doc_count = corpus.len();
if doc_count == 0 {
return Err(Bm25VectorizerError::EmptyCorpus);
}
#[cfg(not(feature = "parallelism"))]
let corpus_iter = corpus.iter();
#[cfg(feature = "parallelism")]
let corpus_iter = corpus.par_iter();
let total_length: usize = corpus_iter.map(|doc| tokenizer.tokenize(doc).len()).sum();
self.avgdl = Some(AverageDocumentLength {
avgdl: total_length as f32 / doc_count as f32,
});
}
Ok(self)
}
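    /// Like [`fit`](Self::fit), but consumes any iterator of documents in a
    /// single sequential pass, so the corpus never needs to be collected into
    /// a slice first.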
pub fn fit_iter<I, S>(mut self, corpus: I) -> Result<Self, Bm25VectorizerError>
where
I: IntoIterator<Item = S>,
S: AsRef<str>,
        Tokenizer: Bm25Tokenizer,
{
if let Some(ref tokenizer) = self.tokenizer {
let (doc_count, total_length) = corpus
.into_iter()
.map(|doc| tokenizer.tokenize(doc.as_ref()).len())
                .fold((0usize, 0usize), |(count, sum), len| (count + 1, sum + len));
            // Guard against division by zero (and a NaN avgdl) on an empty corpus.
            if doc_count == 0 {
                return Err(Bm25VectorizerError::EmptyCorpus);
            }
            self.avgdl = Some(AverageDocumentLength {
                avgdl: total_length as f32 / doc_count as f32,
            });
}
Ok(self)
}
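    /// Like [`fit_iter`](Self::fit_iter), but bridges the iterator onto the
    /// rayon thread pool so documents are tokenized in parallel. Only
    /// available with the `parallelism` feature.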
#[cfg(feature = "parallelism")]
pub fn fit_par_iter<I, S>(mut self, corpus: I) -> Result<Self, Bm25VectorizerError>
where
I: IntoIterator<Item = S>,
I::IntoIter: Send,
S: AsRef<str> + Send,
Tokenizer: Bm25Tokenizer + Sync,
{
if let Some(ref tokenizer) = self.tokenizer {
let (doc_count, total_length) = {
use rayon::iter::ParallelBridge;
corpus
.into_iter()
.par_bridge()
.map(|doc| tokenizer.tokenize(doc.as_ref()).len())
.fold(
|| (0usize, 0usize),
|(count, sum), len| (count + 1, sum + len),
)
.reduce(|| (0, 0), |(c1, s1), (c2, s2)| (c1 + c2, s1 + s2))
};
if doc_count == 0 {
return Err(Bm25VectorizerError::EmptyCorpus);
}
self.avgdl = Some(AverageDocumentLength {
avgdl: total_length as f32 / doc_count as f32,
});
}
Ok(self)
}
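    /// Validates the configuration and constructs the [`Bm25Vectorizer`].
    /// Requires a tokenizer, a token indexer, and an `avgdl`, and checks that
    /// `k1 >= 0`, `0 <= b <= 1`, `avgdl > 0`, and `delta >= 0`.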
pub fn build(self) -> Result<Bm25Vectorizer<TokenIndexer, Tokenizer>, Bm25VectorizerError> {
let tokenizer = self.tokenizer.ok_or(MissingTokenizer)?;
let token_indexer = self.token_indexer.ok_or(MissingTokenIndexer)?;
let avgdl = self.avgdl.ok_or(MissingAverageDocumentLength)?;
        if self.k1.k1 < 0.0 {
            return Err(InvalidTermRelevanceSaturation);
        }
        if !(0.0..=1.0).contains(&self.b.b) {
            return Err(InvalidLengthNormalisation);
        }
        if avgdl.avgdl <= 0.0 {
            return Err(InvalidAverageDocumentLength);
        }
        if self.delta.delta < 0.0 {
            return Err(InvalidTermFrequencyLowerBound);
        }
Ok(Bm25Vectorizer {
tokenizer,
k1: self.k1,
b: self.b,
avgdl,
delta: self.delta,
token_indexer,
})
}
}
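/// Errors raised while fitting, configuring, or building a [`Bm25Vectorizer`].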
#[derive(Debug, thiserror::Error)]
pub enum Bm25VectorizerError {
#[error("Cannot fit on empty corpus.")]
EmptyCorpus,
#[error("Average document length must be provided or computed via fit().")]
MissingAverageDocumentLength,
#[error("Tokenizer must be provided.")]
MissingTokenizer,
#[error("Token indexer must be provided.")]
MissingTokenIndexer,
#[error("Invalid b value: must be between 0 and 1.")]
InvalidLengthNormalisation,
    #[error(
        "Invalid k1 value: must be 0 or greater. Typical values fall in the 0 to 3 range, but larger values are permitted."
    )]
InvalidTermRelevanceSaturation,
#[error("Invalid average document length: value must be greater than 0.")]
InvalidAverageDocumentLength,
#[error("Invalid delta (δ) value: must be 0 or greater.")]
InvalidTermFrequencyLowerBound,
}
#[cfg(test)]
mod tests {
use super::*;
use crate::mocking::{
MockDictionaryTokenIndexer, MockHashTokenIndexer, MockWhitespaceTokenizer,
};
#[test]
fn test_builder_new_defaults() {
let builder = Bm25VectorizerBuilder::<MockHashTokenIndexer, MockWhitespaceTokenizer>::new();
assert_eq!(builder.k1.k1, 1.2);
assert_eq!(builder.b.b, 0.75);
assert_eq!(builder.delta.delta, 0.0);
assert!(builder.tokenizer.is_none());
assert!(builder.token_indexer.is_none());
assert!(builder.avgdl.is_none());
}
#[test]
fn test_builder_parameter_setting() {
let builder = Bm25VectorizerBuilder::<MockHashTokenIndexer, MockWhitespaceTokenizer>::new()
.k1(2.0)
.b(0.5)
.delta(0.25)
.avgdl(15.0);
assert_eq!(builder.k1.k1, 2.0);
assert_eq!(builder.b.b, 0.5);
assert_eq!(builder.delta.delta, 0.25);
assert_eq!(builder.avgdl.unwrap().avgdl, 15.0);
}
#[test]
fn test_builder_missing_components() {
let result = Bm25VectorizerBuilder::<MockHashTokenIndexer, MockWhitespaceTokenizer>::new()
.avgdl(10.0)
.build();
assert!(matches!(result, Err(MissingTokenizer)));
let result = Bm25VectorizerBuilder::<MockHashTokenIndexer, MockWhitespaceTokenizer>::new()
.tokenizer(MockWhitespaceTokenizer)
.avgdl(10.0)
.build();
assert!(matches!(result, Err(MissingTokenIndexer)));
let result = Bm25VectorizerBuilder::<MockHashTokenIndexer, MockWhitespaceTokenizer>::new()
.tokenizer(MockWhitespaceTokenizer)
.token_indexer(MockHashTokenIndexer)
.build();
assert!(matches!(result, Err(MissingAverageDocumentLength)));
}
#[test]
fn test_builder_invalid_parameters() {
let result = Bm25VectorizerBuilder::new()
.tokenizer(MockWhitespaceTokenizer)
.token_indexer(MockHashTokenIndexer)
.k1(-1.0)
.avgdl(10.0)
.build();
assert!(matches!(result, Err(InvalidTermRelevanceSaturation)));
let result = Bm25VectorizerBuilder::new()
.tokenizer(MockWhitespaceTokenizer)
.token_indexer(MockHashTokenIndexer)
.b(-0.1)
.avgdl(10.0)
.build();
        assert!(matches!(result, Err(InvalidLengthNormalisation)));
let result = Bm25VectorizerBuilder::new()
.tokenizer(MockWhitespaceTokenizer)
.token_indexer(MockHashTokenIndexer)
.b(1.1)
.avgdl(10.0)
.build();
        assert!(matches!(result, Err(InvalidLengthNormalisation)));
let result = Bm25VectorizerBuilder::new()
.tokenizer(MockWhitespaceTokenizer)
.token_indexer(MockHashTokenIndexer)
.avgdl(0.0)
.build();
assert!(matches!(result, Err(InvalidAverageDocumentLength)));
let result = Bm25VectorizerBuilder::new()
.tokenizer(MockWhitespaceTokenizer)
.token_indexer(MockHashTokenIndexer)
.delta(-0.1)
.avgdl(10.0)
.build();
assert!(matches!(result, Err(InvalidTermFrequencyLowerBound)));
}
#[test]
fn test_successful_build() {
let vectorizer = Bm25VectorizerBuilder::new()
.tokenizer(MockWhitespaceTokenizer)
.token_indexer(MockHashTokenIndexer)
.k1(1.5)
.b(0.8)
.delta(0.25)
.avgdl(12.0)
.build()
.unwrap();
assert_eq!(vectorizer.k1(), 1.5);
assert_eq!(vectorizer.b(), 0.8);
assert_eq!(vectorizer.delta(), 0.25);
assert_eq!(vectorizer.avgdl(), 12.0);
}
#[test]
fn test_fit_corpus() {
let corpus = vec!["hello world", "world of rust", "hello rust programming"];
let builder = Bm25VectorizerBuilder::new()
.tokenizer(MockWhitespaceTokenizer)
.token_indexer(MockHashTokenIndexer)
.fit(&corpus)
.unwrap();
let expected_avgdl = (2.0 + 3.0 + 3.0) / 3.0;
assert_eq!(builder.avgdl.unwrap().avgdl, expected_avgdl);
}
#[test]
fn test_fit_empty_corpus() {
let corpus: Vec<&str> = vec![];
let result = Bm25VectorizerBuilder::<MockHashTokenIndexer, MockWhitespaceTokenizer>::new()
.tokenizer(MockWhitespaceTokenizer)
.fit(&corpus);
assert!(matches!(result, Err(Bm25VectorizerError::EmptyCorpus)));
}
#[test]
fn test_vectorize_basic() {
let vectorizer = Bm25VectorizerBuilder::new()
.tokenizer(MockWhitespaceTokenizer)
.token_indexer(MockDictionaryTokenIndexer::new())
.avgdl(2.0)
.build()
.unwrap();
let result = vectorizer.vectorize("hello world");
assert_eq!(result.0.len(), 2);
for token in &result.0 {
assert!(token.value > 0.0);
}
}
#[test]
fn test_vectorize_repeated_tokens() {
let vectorizer = Bm25VectorizerBuilder::new()
.tokenizer(MockWhitespaceTokenizer)
.token_indexer(MockDictionaryTokenIndexer::new())
.avgdl(3.0)
.build()
.unwrap();
let result = vectorizer.vectorize("hello hello world");
assert_eq!(result.0.len(), 2);
        let hello_value = result.0.iter().find(|t| t.index == 0).unwrap().value;
        let world_value = result.0.iter().find(|t| t.index == 1).unwrap().value;
assert!(hello_value > world_value);
}
#[test]
fn test_vectorize_empty_text() {
let vectorizer = Bm25VectorizerBuilder::new()
.tokenizer(MockWhitespaceTokenizer)
.token_indexer(MockHashTokenIndexer)
.avgdl(2.0)
.build()
.unwrap();
let result = vectorizer.vectorize("");
assert_eq!(result.0.len(), 0);
}
#[test]
fn test_bm25_parameters_effect() {
let vectorizer_low_k1 = Bm25VectorizerBuilder::new()
.tokenizer(MockWhitespaceTokenizer)
.token_indexer(MockDictionaryTokenIndexer::new())
.k1(0.5)
.avgdl(2.0)
.build()
.unwrap();
let vectorizer_high_k1 = Bm25VectorizerBuilder::new()
.tokenizer(MockWhitespaceTokenizer)
.token_indexer(MockDictionaryTokenIndexer::new())
.k1(3.0)
.avgdl(2.0)
.build()
.unwrap();
let result_low = vectorizer_low_k1.vectorize("hello hello");
let result_high = vectorizer_high_k1.vectorize("hello hello");
assert!(result_high.0[0].value > result_low.0[0].value);
}
#[test]
fn test_length_normalisation_effect() {
let vectorizer_no_norm = Bm25VectorizerBuilder::new()
.tokenizer(MockWhitespaceTokenizer)
.token_indexer(MockDictionaryTokenIndexer::new())
        .b(0.0)
        .avgdl(5.0)
.build()
.unwrap();
let vectorizer_full_norm = Bm25VectorizerBuilder::new()
.tokenizer(MockWhitespaceTokenizer)
.token_indexer(MockDictionaryTokenIndexer::new())
        .b(1.0)
        .avgdl(5.0)
.build()
.unwrap();
let long_text = "hello world this is a long document";
let short_text = "hello world";
let long_no_norm = vectorizer_no_norm.vectorize(long_text);
let long_full_norm = vectorizer_full_norm.vectorize(long_text);
let short_no_norm = vectorizer_no_norm.vectorize(short_text);
let hello_long_no_norm = long_no_norm.0.iter().find(|t| t.index == 0).unwrap().value;
let hello_long_full_norm = long_full_norm
.0
.iter()
.find(|t| t.index == 0)
.unwrap()
.value;
let hello_short_no_norm = short_no_norm.0.iter().find(|t| t.index == 0).unwrap().value;
assert!(hello_long_no_norm > hello_long_full_norm);
assert!(hello_short_no_norm > hello_long_full_norm);
}
#[test]
fn test_delta_effect() {
let vectorizer_no_delta = Bm25VectorizerBuilder::new()
.tokenizer(MockWhitespaceTokenizer)
.token_indexer(MockDictionaryTokenIndexer::new())
.delta(0.0)
.avgdl(2.0)
.build()
.unwrap();
let vectorizer_with_delta = Bm25VectorizerBuilder::new()
.tokenizer(MockWhitespaceTokenizer)
.token_indexer(MockDictionaryTokenIndexer::new())
.delta(0.5)
.avgdl(2.0)
.build()
.unwrap();
let result_no_delta = vectorizer_no_delta.vectorize("hello");
let result_with_delta = vectorizer_with_delta.vectorize("hello");
assert_eq!(
result_with_delta.0[0].value,
result_no_delta.0[0].value + 0.5
);
}
#[cfg(not(feature = "parallelism"))]
#[test]
fn test_fit_iter() {
let corpus = vec!["hello world", "world rust", "hello programming"];
let builder = Bm25VectorizerBuilder::<MockHashTokenIndexer, MockWhitespaceTokenizer>::new()
.tokenizer(MockWhitespaceTokenizer)
.fit_iter(corpus)
.unwrap();
let expected_avgdl = (2.0 + 2.0 + 2.0) / 3.0;
assert_eq!(builder.avgdl.unwrap().avgdl, expected_avgdl);
}
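    // Exercises the empty-corpus guard in fit_iter; mirrors test_fit_empty_corpus.
    #[cfg(not(feature = "parallelism"))]
    #[test]
    fn test_fit_iter_empty_corpus() {
        let corpus: Vec<&str> = vec![];
        let result = Bm25VectorizerBuilder::<MockHashTokenIndexer, MockWhitespaceTokenizer>::new()
            .tokenizer(MockWhitespaceTokenizer)
            .fit_iter(corpus);
        assert!(matches!(result, Err(Bm25VectorizerError::EmptyCorpus)));
    }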
#[cfg(feature = "parallelism")]
#[test]
fn test_fit_par_iter() {
let corpus = vec!["hello world", "world rust", "hello programming"];
let builder = Bm25VectorizerBuilder::<MockHashTokenIndexer, MockWhitespaceTokenizer>::new()
.tokenizer(MockWhitespaceTokenizer)
.fit_par_iter(corpus)
.unwrap();
let expected_avgdl = (2.0 + 2.0 + 2.0) / 3.0;
assert_eq!(builder.avgdl.unwrap().avgdl, expected_avgdl);
}
}