// bm25_vectorizer/bm25_vectorizer.rs

//! # BM25 Vectorizer
//!
//! This module implements part of the BM25 (Best Matching 25) ranking function, a probabilistic
//! ranking algorithm commonly used in information retrieval and search engines.
//!
//! ## BM25 Algorithm
//!
//! BM25 is a bag-of-words retrieval function that ranks matching documents by their relevance
//! to a query.
//!
//! The BM25 value for a term in a document is calculated as:
//!
//! ```text
//! BM25(t,d) = (tf(t,d) * (k1 + 1)) / (tf(t,d) + k1 * (1 - b + b * (|d| / avgdl)))
//! ```
//!
//! The BM25+ value for a term in a document is calculated as:
//!
//! ```text
//! BM25(t,d) = (tf(t,d) * (k1 + 1)) / (tf(t,d) + k1 * (1 - b + b * (|d| / avgdl))) + δ
//! ```
//!
//! Where:
//! - `tf(t,d)` is the term frequency in the document
//! - `k1` controls term frequency saturation (typically 1.2)
//! - `b` controls length normalisation (typically 0.75)
//! - `|d|` is the document length
//! - `avgdl` is the average document length in the corpus
//! - `δ` (delta) is a lower bound for term frequency scoring
//!
//! ## Usage
//!
//! ```rust
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! use bm25_vectorizer::{Bm25VectorizerBuilder, MockWhitespaceTokenizer, MockHashTokenIndexer};
//!
//! let corpus = vec!["hello world", "world of rust", "hello rust"];
//! let tokenizer = MockWhitespaceTokenizer;
//! let indexer = MockHashTokenIndexer;
//!
//! let vectorizer = Bm25VectorizerBuilder::new()
//!     .tokenizer(tokenizer)
//!     .token_indexer(indexer)
//!     .k1(1.2)
//!     .b(0.75)
//!     .fit(&corpus)?
//!     .build()?;
//!
//! let vector = vectorizer.vectorize("hello world");
//! # Ok(())
//! # }
//! ```

use std::collections::BTreeMap;
use std::fmt::Debug;
use std::hash::Hash;

#[cfg(feature = "parallelism")]
use rayon::prelude::*;

use crate::bm25_tokenizer::Bm25Tokenizer;
use crate::bm25_vectorizer::Bm25VectorizerError::{
    InvalidAverageDocumentLength, InvalidLengthNormalisation, InvalidTermFrequencyLowerBound,
    InvalidTermRelevanceSaturation, MissingAverageDocumentLength, MissingTokenIndexer,
    MissingTokenizer,
};
use crate::Bm25TokenIndexer;

/// Represents a token with its index and BM25 value.
///
/// # Type Parameters
/// - `T`: The type of the token index (must implement required traits for the vectorizer)
///
/// # Fields
/// - `index`: The token's unique identifier
/// - `value`: The computed BM25 value for this token
#[derive(PartialEq, Debug, Clone, PartialOrd)]
pub struct TokenIndexValue<T> {
    /// The unique identifier for the token
    pub index: T,
    /// The computed BM25 value for this token
    pub value: f32,
}

/// A sparse vector representation containing token indices and their BM25 values.
///
/// Only tokens that actually occur in the vectorized text are present;
/// entries are ordered by token index (see `Bm25Vectorizer::vectorize`).
///
/// # Type Parameters
/// - `T`: The type of token indices used
///
/// # Examples
/// ```rust
/// use bm25_vectorizer::{TokenIndexValue, SparseRepresentation};
///
/// let tokens = vec![
///     TokenIndexValue { index: 0, value: 1.2 },
///     TokenIndexValue { index: 5, value: 0.8 },
/// ];
/// let sparse_vector = SparseRepresentation(tokens);
/// ```
#[derive(PartialEq, Debug, Clone, PartialOrd)]
pub struct SparseRepresentation<T>(pub Vec<TokenIndexValue<T>>);

/// Controls term frequency saturation.
///
/// Newtype wrapper around the BM25 `k1` parameter (typically 1.2).
#[derive(Debug)]
pub struct TermRelevanceSaturation {
    // Raw k1 value; validated (>= 0) in `Bm25VectorizerBuilder::build`.
    k1: f32,
}

/// The additional δ parameter for BM25+.
///
/// A delta of 0.0 (the builder default) yields plain BM25 scoring.
#[derive(Debug)]
pub struct TermFrequencyLowerBound {
    // Raw δ value; validated (>= 0) in `Bm25VectorizerBuilder::build`.
    delta: f32,
}

/// Controls document length normalisation.
///
/// Newtype wrapper around the BM25 `b` parameter (typically 0.75).
#[derive(Debug)]
pub struct LengthNormalisation {
    // Raw b value; validated (0 <= b <= 1) in `Bm25VectorizerBuilder::build`.
    b: f32,
}

/// Represents the average document length in the corpus.
///
/// This value is used for document length normalisation.
/// It's typically computed during the fitting process by analysing the training corpus.
#[derive(Debug)]
pub struct AverageDocumentLength {
    // Average token count per document; validated (> 0) in `Bm25VectorizerBuilder::build`.
    avgdl: f32,
}

/// The main BM25 vectorizer that converts text into sparse vector representations.
///
/// This struct encapsulates all the parameters and components needed to perform BM25
/// vectorization. It uses a tokenizer to break text into tokens and a token indexer
/// to map tokens to indices.
///
/// # Type Parameters
/// - `TokenIndexer`: Implementation of `Bm25TokenIndexer` trait for mapping tokens to indices
/// - `Tokenizer`: Implementation of `Bm25Tokenizer` trait for text tokenization
///
/// # Examples
/// ```rust
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// use bm25_vectorizer::{Bm25VectorizerBuilder, MockWhitespaceTokenizer, MockHashTokenIndexer};
///
/// let corpus = vec!["hello world", "world of rust"];
/// let vectorizer = Bm25VectorizerBuilder::new()
///     .tokenizer(MockWhitespaceTokenizer)
///     .token_indexer(MockHashTokenIndexer)
///     .fit(&corpus)?
///     .build()?;
///
/// let result = vectorizer.vectorize("hello rust");
/// # Ok(())
/// # }
/// ```
#[derive(Debug)]
pub struct Bm25Vectorizer<TokenIndexer, Tokenizer> {
    // Splits input text into tokens.
    tokenizer: Tokenizer,
    // Term frequency saturation (k1).
    k1: TermRelevanceSaturation,
    // Length normalisation strength (b).
    b: LengthNormalisation,
    // Average document length of the fitted corpus.
    avgdl: AverageDocumentLength,
    // BM25+ lower bound (δ); 0.0 for plain BM25.
    delta: TermFrequencyLowerBound,
    // Maps tokens to sparse-vector indices.
    token_indexer: TokenIndexer,
}

164impl<TokenIndexer, Tokenizer> Bm25Vectorizer<TokenIndexer, Tokenizer> {
165    /// Returns the average document length used for normalisation.
166    ///
167    /// # Examples
168    /// ```rust
169    /// # use bm25_vectorizer::*;
170    /// # let vectorizer = Bm25VectorizerBuilder::new()
171    /// #     .tokenizer(MockWhitespaceTokenizer)
172    /// #     .token_indexer(MockHashTokenIndexer)
173    /// #     .avgdl(10.5)
174    /// #     .build().unwrap();
175    /// assert_eq!(vectorizer.avgdl(), 10.5);
176    /// ```
177    pub fn avgdl(&self) -> f32 {
178        self.avgdl.avgdl
179    }
180
181    /// Returns the k1 parameter controlling term frequency saturation.
182    ///
183    /// # Examples
184    /// ```rust
185    /// # use bm25_vectorizer::*;
186    /// # let vectorizer = Bm25VectorizerBuilder::new()
187    /// #     .tokenizer(MockWhitespaceTokenizer)
188    /// #     .token_indexer(MockHashTokenIndexer)
189    /// #     .k1(1.5)
190    /// #     .avgdl(10.0)
191    /// #     .build().unwrap();
192    /// assert_eq!(vectorizer.k1(), 1.5);
193    /// ```
194    pub fn k1(&self) -> f32 {
195        self.k1.k1
196    }
197
198    /// Returns the b parameter controlling length normalisation.
199    ///
200    /// # Examples
201    /// ```rust
202    /// # use bm25_vectorizer::*;
203    /// # let vectorizer = Bm25VectorizerBuilder::new()
204    /// #     .tokenizer(MockWhitespaceTokenizer)
205    /// #     .token_indexer(MockHashTokenIndexer)
206    /// #     .b(0.8)
207    /// #     .avgdl(10.0)
208    /// #     .build().unwrap();
209    /// assert_eq!(vectorizer.b(), 0.8);
210    /// ```
211    pub fn b(&self) -> f32 {
212        self.b.b
213    }
214
215    /// Returns the delta parameter used as a lower bound for term values.
216    ///
217    /// # Examples
218    /// ```rust
219    /// # use bm25_vectorizer::*;
220    /// # let vectorizer = Bm25VectorizerBuilder::new()
221    /// #     .tokenizer(MockWhitespaceTokenizer)
222    /// #     .token_indexer(MockHashTokenIndexer)
223    /// #     .delta(0.25)
224    /// #     .avgdl(10.0)
225    /// #     .build().unwrap();
226    /// assert_eq!(vectorizer.delta(), 0.25);
227    /// ```
228    pub fn delta(&self) -> f32 {
229        self.delta.delta
230    }
231
232    /// Converts input text into a sparse BM25 vector representation.
233    ///
234    /// This method tokenizes the input text, and computes BM25 term frequencies to
235    /// generate a sparse vector representation that can then be uploaded to a vector database.
236    ///
237    /// NOTE: Vector databases might require to specify an IDF modifier when setting up the
238    /// vector store to instruct them to calculate IDF statistics automatically.
239    /// This implementation produces only the normalised term frequency (TF) component in document
240    /// vectors and expects the inverse document frequency (IDF) to be computed by the vector database.
241    ///
242    /// # Arguments
243    /// - `text`: The input text to vectorize
244    ///
245    /// # Returns
246    /// A `SparseRepresentation` containing token indices and their BM25 values
247    ///
248    /// # Examples
249    /// ```rust
250    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
251    /// use bm25_vectorizer::{Bm25VectorizerBuilder, MockWhitespaceTokenizer, MockHashTokenIndexer};
252    ///
253    /// let corpus = vec!["hello world", "world rust"];
254    /// let vectorizer = Bm25VectorizerBuilder::new()
255    ///     .tokenizer(MockWhitespaceTokenizer)
256    ///     .token_indexer(MockHashTokenIndexer)
257    ///     .fit(&corpus)?
258    ///     .build()?;
259    ///
260    /// let result = vectorizer.vectorize("hello world");
261    /// // Result contains BM25 values for tokens "hello" and "world"
262    /// assert_eq!(result.0.len(), 2);
263    /// # Ok(())
264    /// # }
265    /// ```
266    pub fn vectorize(&self, text: &str) -> SparseRepresentation<TokenIndexer::Bm25TokenIndex>
267    where
268        TokenIndexer: Bm25TokenIndexer,
269        TokenIndexer::Bm25TokenIndex: Eq + Hash + Clone + Debug + Ord,
270        Tokenizer: Bm25Tokenizer,
271    {
272        let tokens = self.tokenizer.tokenize(text);
273        let doc_length = tokens.len() as f32;
274
275        // Build unique map of indices to their term frequencies
276        // Using tree map for deterministic results
277        let mut index_counts: BTreeMap<TokenIndexer::Bm25TokenIndex, usize> = BTreeMap::new();
278
279        for token in tokens.iter() {
280            let index = self.token_indexer.index(token);
281            *index_counts.entry(index).or_insert(0) += 1;
282        }
283
284        let embeddings: Vec<TokenIndexValue<TokenIndexer::Bm25TokenIndex>> = index_counts
285            .into_iter()
286            .map(|(index, count)| {
287                let token_frequency = count as f32;
288                let numerator = token_frequency * (self.k1() + 1.0);
289                let denominator = token_frequency
290                    + self.k1() * (1.0 - self.b() + self.b() * (doc_length / self.avgdl()));
291
292                // BM25+: adds delta (δ) to ensure minimum contribution from matching terms
293                let value = (numerator / denominator) + self.delta();
294
295                TokenIndexValue { index, value }
296            })
297            .collect();
298
299        SparseRepresentation(embeddings)
300    }
301}
302
/// Builder for creating and configuring a `Bm25Vectorizer`.
///
/// It supports fitting on a corpus to automatically compute the
/// average document length, and validates all parameters before building.
///
/// # Type Parameters
/// - `TokenIndexer`: Implementation of `Bm25TokenIndexer` trait
/// - `Tokenizer`: Implementation of `Bm25Tokenizer` trait
///
/// # Examples
///
/// Basic usage with manual avgdl:
/// ```rust
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// use bm25_vectorizer::{Bm25VectorizerBuilder, MockWhitespaceTokenizer, MockHashTokenIndexer};
///
/// let vectorizer = Bm25VectorizerBuilder::new()
///     .tokenizer(MockWhitespaceTokenizer)
///     .token_indexer(MockHashTokenIndexer)
///     .k1(1.2)
///     .b(0.75)
///     .avgdl(10.0)
///     .build()?;
/// # Ok(())
/// # }
/// ```
///
/// Usage with corpus fitting:
/// ```rust
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// use bm25_vectorizer::{Bm25VectorizerBuilder, MockWhitespaceTokenizer, MockHashTokenIndexer};
///
/// let corpus = vec!["hello world", "world of rust", "hello rust programming"];
/// let vectorizer = Bm25VectorizerBuilder::new()
///     .tokenizer(MockWhitespaceTokenizer)
///     .token_indexer(MockHashTokenIndexer)
///     .k1(1.2)
///     .b(0.75)
///     .fit(&corpus)?  // Automatically computes avgdl
///     .build()?;
/// # Ok(())
/// # }
/// ```
pub struct Bm25VectorizerBuilder<TokenIndexer, Tokenizer> {
    // Required; `build` fails with `MissingTokenizer` if unset.
    tokenizer: Option<Tokenizer>,
    // Defaults to 1.2.
    k1: TermRelevanceSaturation,
    // Defaults to 0.75.
    b: LengthNormalisation,
    // Required; set explicitly via `avgdl` or computed by the `fit*` methods.
    avgdl: Option<AverageDocumentLength>,
    // Defaults to 0.0 (plain BM25 rather than BM25+).
    delta: TermFrequencyLowerBound,
    // Required; `build` fails with `MissingTokenIndexer` if unset.
    token_indexer: Option<TokenIndexer>,
}

355impl<TokenIndexer, Tokenizer> Bm25VectorizerBuilder<TokenIndexer, Tokenizer> {
356    pub fn new() -> Self {
357        Self {
358            tokenizer: None,
359            k1: TermRelevanceSaturation { k1: 1.2 },
360            b: LengthNormalisation { b: 0.75 },
361            avgdl: None,
362            delta: TermFrequencyLowerBound { delta: 0.0 },
363            token_indexer: None,
364        }
365    }
366
367    pub fn k1(mut self, k1: f32) -> Self {
368        self.k1 = TermRelevanceSaturation { k1 };
369        self
370    }
371
372    pub fn b(mut self, b: f32) -> Self {
373        self.b = LengthNormalisation { b };
374        self
375    }
376
377    pub fn delta(mut self, delta: f32) -> Self {
378        self.delta = TermFrequencyLowerBound { delta };
379        self
380    }
381
382    pub fn avgdl(mut self, avgdl: f32) -> Self {
383        self.avgdl = Some(AverageDocumentLength { avgdl });
384        self
385    }
386
387    pub fn tokenizer(mut self, tokenizer: Tokenizer) -> Self {
388        self.tokenizer = Some(tokenizer);
389        self
390    }
391
392    pub fn token_indexer(mut self, token_indexer: TokenIndexer) -> Self {
393        self.token_indexer = Some(token_indexer);
394        self
395    }
396
397    pub fn fit(mut self, corpus: &[&str]) -> Result<Self, Bm25VectorizerError>
398    where
399        Tokenizer: Bm25Tokenizer + Sync,
400    {
401        if let Some(ref tokenizer) = self.tokenizer {
402            let doc_count = corpus.len();
403            if doc_count == 0 {
404                return Err(Bm25VectorizerError::EmptyCorpus);
405            }
406
407            #[cfg(not(feature = "parallelism"))]
408            let corpus_iter = corpus.iter();
409            #[cfg(feature = "parallelism")]
410            let corpus_iter = corpus.par_iter();
411
412            let total_length: usize = corpus_iter.map(|doc| tokenizer.tokenize(doc).len()).sum();
413            self.avgdl = Some(AverageDocumentLength {
414                avgdl: total_length as f32 / doc_count as f32,
415            });
416        }
417        Ok(self)
418    }
419
420    pub fn fit_iter<I, S>(mut self, corpus: I) -> Result<Self, Bm25VectorizerError>
421    where
422        I: IntoIterator<Item = S>,
423        S: AsRef<str>,
424        Tokenizer: Bm25Tokenizer + Sync,
425    {
426        if let Some(ref tokenizer) = self.tokenizer {
427            let (doc_count, total_length) = corpus
428                .into_iter()
429                .map(|doc| tokenizer.tokenize(doc.as_ref()).len())
430                .fold((0usize, 0usize), |(count, sum), len| (count + 1, sum + len));
431
432            self.avgdl = Some(AverageDocumentLength {
433                avgdl: total_length as f32 / doc_count as f32,
434            });
435        }
436        Ok(self)
437    }
438
439    #[cfg(feature = "parallelism")]
440    pub fn fit_par_iter<I, S>(mut self, corpus: I) -> Result<Self, Bm25VectorizerError>
441    where
442        I: IntoIterator<Item = S>,
443        I::IntoIter: Send,
444        S: AsRef<str> + Send,
445        Tokenizer: Bm25Tokenizer + Sync,
446    {
447        if let Some(ref tokenizer) = self.tokenizer {
448            let (doc_count, total_length) = {
449                use rayon::iter::ParallelBridge;
450                corpus
451                    .into_iter()
452                    .par_bridge()
453                    .map(|doc| tokenizer.tokenize(doc.as_ref()).len())
454                    .fold(
455                        || (0usize, 0usize),
456                        |(count, sum), len| (count + 1, sum + len),
457                    )
458                    .reduce(|| (0, 0), |(c1, s1), (c2, s2)| (c1 + c2, s1 + s2))
459            };
460
461            if doc_count == 0 {
462                return Err(Bm25VectorizerError::EmptyCorpus);
463            }
464
465            self.avgdl = Some(AverageDocumentLength {
466                avgdl: total_length as f32 / doc_count as f32,
467            });
468        }
469        Ok(self)
470    }
471
472    pub fn build(self) -> Result<Bm25Vectorizer<TokenIndexer, Tokenizer>, Bm25VectorizerError> {
473        let tokenizer = self.tokenizer.ok_or(MissingTokenizer)?;
474        let token_indexer = self.token_indexer.ok_or(MissingTokenIndexer)?;
475        let avgdl = self.avgdl.ok_or(MissingAverageDocumentLength)?;
476
477        if &self.k1.k1 < &0.0 {
478            return Err(InvalidTermRelevanceSaturation);
479        }
480        if &self.b.b < &0.0 || &self.b.b > &1.0 {
481            return Err(InvalidTermRelevanceSaturation);
482        }
483        if &avgdl.avgdl <= &0.0 {
484            return Err(InvalidAverageDocumentLength);
485        }
486        if &self.delta.delta < &0.0 {
487            return Err(InvalidTermFrequencyLowerBound);
488        }
489
490        Ok(Bm25Vectorizer {
491            tokenizer,
492            k1: self.k1,
493            b: self.b,
494            avgdl,
495            delta: self.delta,
496            token_indexer,
497        })
498    }
499}
500
/// Errors produced while fitting or building a `Bm25Vectorizer`.
#[derive(Debug, thiserror::Error)]
pub enum Bm25VectorizerError {
    /// The corpus passed to a `fit*` method contained no documents.
    #[error("Cannot fit on empty corpus.")]
    EmptyCorpus,
    /// `build` was called without `avgdl(..)` or a successful `fit*`.
    #[error("Average document length must be provided or computed via fit().")]
    MissingAverageDocumentLength,
    /// `build` was called without `tokenizer(..)`.
    #[error("Tokenizer must be provided.")]
    MissingTokenizer,
    /// `build` was called without `token_indexer(..)`.
    #[error("Token indexer must be provided.")]
    MissingTokenIndexer,
    /// The `b` parameter was outside the `[0, 1]` range.
    #[error("Invalid b value: must be between 0 and 1.")]
    InvalidLengthNormalisation,
    /// The `k1` parameter was negative.
    #[error(
        "Invalid k1 value: should normally fall within the 0 to 3 range. However, there is no strict enforcement preventing values higher than 3."
    )]
    InvalidTermRelevanceSaturation,
    /// The average document length was zero or negative.
    #[error("Invalid average document length: value must be greater than 0.")]
    InvalidAverageDocumentLength,
    /// The BM25+ delta (δ) parameter was negative.
    #[error("Invalid delta (δ) value: must be 0 or greater.")]
    InvalidTermFrequencyLowerBound,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::mocking::{
        MockDictionaryTokenIndexer, MockHashTokenIndexer, MockWhitespaceTokenizer,
    };

    #[test]
    fn test_builder_new_defaults() {
        let builder = Bm25VectorizerBuilder::<MockHashTokenIndexer, MockWhitespaceTokenizer>::new();

        // Check default values
        assert_eq!(builder.k1.k1, 1.2);
        assert_eq!(builder.b.b, 0.75);
        assert_eq!(builder.delta.delta, 0.0);
        assert!(builder.tokenizer.is_none());
        assert!(builder.token_indexer.is_none());
        assert!(builder.avgdl.is_none());
    }

    #[test]
    fn test_builder_parameter_setting() {
        let builder = Bm25VectorizerBuilder::<MockHashTokenIndexer, MockWhitespaceTokenizer>::new()
            .k1(2.0)
            .b(0.5)
            .delta(0.25)
            .avgdl(15.0);

        assert_eq!(builder.k1.k1, 2.0);
        assert_eq!(builder.b.b, 0.5);
        assert_eq!(builder.delta.delta, 0.25);
        assert_eq!(builder.avgdl.unwrap().avgdl, 15.0);
    }

    #[test]
    fn test_builder_missing_components() {
        let result = Bm25VectorizerBuilder::<MockHashTokenIndexer, MockWhitespaceTokenizer>::new()
            .avgdl(10.0)
            .build();

        assert!(matches!(result, Err(MissingTokenizer)));

        let result = Bm25VectorizerBuilder::<MockHashTokenIndexer, MockWhitespaceTokenizer>::new()
            .tokenizer(MockWhitespaceTokenizer)
            .avgdl(10.0)
            .build();

        assert!(matches!(result, Err(MissingTokenIndexer)));

        let result = Bm25VectorizerBuilder::<MockHashTokenIndexer, MockWhitespaceTokenizer>::new()
            .tokenizer(MockWhitespaceTokenizer)
            .token_indexer(MockHashTokenIndexer)
            .build();

        assert!(matches!(result, Err(MissingAverageDocumentLength)));
    }

    #[test]
    fn test_builder_invalid_parameters() {
        // Test negative k1
        let result = Bm25VectorizerBuilder::new()
            .tokenizer(MockWhitespaceTokenizer)
            .token_indexer(MockHashTokenIndexer)
            .k1(-1.0)
            .avgdl(10.0)
            .build();

        assert!(matches!(result, Err(InvalidTermRelevanceSaturation)));

        // Test invalid b values: these must surface the b-specific error,
        // not the k1 error (previous assertions encoded a bug in build()).
        let result = Bm25VectorizerBuilder::new()
            .tokenizer(MockWhitespaceTokenizer)
            .token_indexer(MockHashTokenIndexer)
            .b(-0.1)
            .avgdl(10.0)
            .build();

        assert!(matches!(
            result,
            Err(Bm25VectorizerError::InvalidLengthNormalisation)
        ));

        let result = Bm25VectorizerBuilder::new()
            .tokenizer(MockWhitespaceTokenizer)
            .token_indexer(MockHashTokenIndexer)
            .b(1.1)
            .avgdl(10.0)
            .build();

        assert!(matches!(
            result,
            Err(Bm25VectorizerError::InvalidLengthNormalisation)
        ));

        // Test invalid avgdl
        let result = Bm25VectorizerBuilder::new()
            .tokenizer(MockWhitespaceTokenizer)
            .token_indexer(MockHashTokenIndexer)
            .avgdl(0.0)
            .build();

        assert!(matches!(result, Err(InvalidAverageDocumentLength)));

        // Test negative delta
        let result = Bm25VectorizerBuilder::new()
            .tokenizer(MockWhitespaceTokenizer)
            .token_indexer(MockHashTokenIndexer)
            .delta(-0.1)
            .avgdl(10.0)
            .build();

        assert!(matches!(result, Err(InvalidTermFrequencyLowerBound)));
    }

    #[test]
    fn test_successful_build() {
        let vectorizer = Bm25VectorizerBuilder::new()
            .tokenizer(MockWhitespaceTokenizer)
            .token_indexer(MockHashTokenIndexer)
            .k1(1.5)
            .b(0.8)
            .delta(0.25)
            .avgdl(12.0)
            .build()
            .unwrap();

        assert_eq!(vectorizer.k1(), 1.5);
        assert_eq!(vectorizer.b(), 0.8);
        assert_eq!(vectorizer.delta(), 0.25);
        assert_eq!(vectorizer.avgdl(), 12.0);
    }

    #[test]
    fn test_fit_corpus() {
        let corpus = vec!["hello world", "world of rust", "hello rust programming"];
        let builder = Bm25VectorizerBuilder::new()
            .tokenizer(MockWhitespaceTokenizer)
            .token_indexer(MockHashTokenIndexer)
            .fit(&corpus)
            .unwrap();

        // Average document length should be (2 + 3 + 3) / 3 = 2.67 (approximately)
        let expected_avgdl = (2.0 + 3.0 + 3.0) / 3.0;
        assert_eq!(builder.avgdl.unwrap().avgdl, expected_avgdl);
    }

    #[test]
    fn test_fit_empty_corpus() {
        let corpus: Vec<&str> = vec![];
        let result = Bm25VectorizerBuilder::<MockHashTokenIndexer, MockWhitespaceTokenizer>::new()
            .tokenizer(MockWhitespaceTokenizer)
            .fit(&corpus);

        assert!(matches!(result, Err(Bm25VectorizerError::EmptyCorpus)));
    }

    #[test]
    fn test_vectorize_basic() {
        let vectorizer = Bm25VectorizerBuilder::new()
            .tokenizer(MockWhitespaceTokenizer)
            .token_indexer(MockDictionaryTokenIndexer::new())
            .avgdl(2.0)
            .build()
            .unwrap();

        let result = vectorizer.vectorize("hello world");

        // Should have 2 tokens
        assert_eq!(result.0.len(), 2);

        // All values should be positive due to BM25 formula
        for token in &result.0 {
            assert!(token.value > 0.0);
        }
    }

    #[test]
    fn test_vectorize_repeated_tokens() {
        let vectorizer = Bm25VectorizerBuilder::new()
            .tokenizer(MockWhitespaceTokenizer)
            .token_indexer(MockDictionaryTokenIndexer::new())
            .avgdl(3.0)
            .build()
            .unwrap();

        let result = vectorizer.vectorize("hello hello world");

        // Should have 2 unique tokens (hello appears twice, world once)
        assert_eq!(result.0.len(), 2);

        // Token for "hello" should have higher value due to higher frequency
        let hello_value = result.0.iter().find(|t| t.index == 0).unwrap().value; // "hello" gets index 0
        let world_value = result.0.iter().find(|t| t.index == 1).unwrap().value; // "world" gets index 1

        assert!(hello_value > world_value);
    }

    #[test]
    fn test_vectorize_empty_text() {
        let vectorizer = Bm25VectorizerBuilder::new()
            .tokenizer(MockWhitespaceTokenizer)
            .token_indexer(MockHashTokenIndexer)
            .avgdl(2.0)
            .build()
            .unwrap();

        let result = vectorizer.vectorize("");
        assert_eq!(result.0.len(), 0);
    }

    #[test]
    fn test_bm25_parameters_effect() {
        // Test that changing k1 affects the values
        let vectorizer_low_k1 = Bm25VectorizerBuilder::new()
            .tokenizer(MockWhitespaceTokenizer)
            .token_indexer(MockDictionaryTokenIndexer::new())
            .k1(0.5)
            .avgdl(2.0)
            .build()
            .unwrap();

        let vectorizer_high_k1 = Bm25VectorizerBuilder::new()
            .tokenizer(MockWhitespaceTokenizer)
            .token_indexer(MockDictionaryTokenIndexer::new())
            .k1(3.0)
            .avgdl(2.0)
            .build()
            .unwrap();

        let result_low = vectorizer_low_k1.vectorize("hello hello");
        let result_high = vectorizer_high_k1.vectorize("hello hello");

        // Higher k1 should result in higher values for repeated terms
        assert!(result_high.0[0].value > result_low.0[0].value);
    }

    #[test]
    fn test_length_normalisation_effect() {
        let vectorizer_no_norm = Bm25VectorizerBuilder::new()
            .tokenizer(MockWhitespaceTokenizer)
            .token_indexer(MockDictionaryTokenIndexer::new())
            .b(0.0) // No length normalisation
            .avgdl(5.0)
            .build()
            .unwrap();

        let vectorizer_full_norm = Bm25VectorizerBuilder::new()
            .tokenizer(MockWhitespaceTokenizer)
            .token_indexer(MockDictionaryTokenIndexer::new())
            .b(1.0) // Full length normalisation
            .avgdl(5.0)
            .build()
            .unwrap();

        // Test with a longer document
        let long_text = "hello world this is a long document";
        let short_text = "hello world";

        let long_no_norm = vectorizer_no_norm.vectorize(long_text);
        let long_full_norm = vectorizer_full_norm.vectorize(long_text);
        let short_no_norm = vectorizer_no_norm.vectorize(short_text);

        // With no normalisation, longer docs don't get penalised as much
        // With full normalisation, values should be more similar between docs
        let hello_long_no_norm = long_no_norm.0.iter().find(|t| t.index == 0).unwrap().value;
        let hello_long_full_norm = long_full_norm
            .0
            .iter()
            .find(|t| t.index == 0)
            .unwrap()
            .value;
        let hello_short_no_norm = short_no_norm.0.iter().find(|t| t.index == 0).unwrap().value;

        // Length normalisation should make long document values lower
        assert!(hello_long_no_norm > hello_long_full_norm);
        assert!(hello_short_no_norm > hello_long_full_norm);
    }

    #[test]
    fn test_delta_effect() {
        let vectorizer_no_delta = Bm25VectorizerBuilder::new()
            .tokenizer(MockWhitespaceTokenizer)
            .token_indexer(MockDictionaryTokenIndexer::new())
            .delta(0.0)
            .avgdl(2.0)
            .build()
            .unwrap();

        let vectorizer_with_delta = Bm25VectorizerBuilder::new()
            .tokenizer(MockWhitespaceTokenizer)
            .token_indexer(MockDictionaryTokenIndexer::new())
            .delta(0.5)
            .avgdl(2.0)
            .build()
            .unwrap();

        let result_no_delta = vectorizer_no_delta.vectorize("hello");
        let result_with_delta = vectorizer_with_delta.vectorize("hello");

        // Delta should add to all values
        assert_eq!(
            result_with_delta.0[0].value,
            result_no_delta.0[0].value + 0.5
        );
    }

    #[cfg(not(feature = "parallelism"))]
    #[test]
    fn test_fit_iter() {
        let corpus = vec!["hello world", "world rust", "hello programming"];
        let builder = Bm25VectorizerBuilder::<MockHashTokenIndexer, MockWhitespaceTokenizer>::new()
            .tokenizer(MockWhitespaceTokenizer)
            .fit_iter(corpus)
            .unwrap();

        let expected_avgdl = (2.0 + 2.0 + 2.0) / 3.0;
        assert_eq!(builder.avgdl.unwrap().avgdl, expected_avgdl);
    }

    #[cfg(feature = "parallelism")]
    #[test]
    fn test_fit_par_iter() {
        let corpus = vec!["hello world", "world rust", "hello programming"];
        let builder = Bm25VectorizerBuilder::<MockHashTokenIndexer, MockWhitespaceTokenizer>::new()
            .tokenizer(MockWhitespaceTokenizer)
            .fit_par_iter(corpus)
            .unwrap();

        let expected_avgdl = (2.0 + 2.0 + 2.0) / 3.0;
        assert_eq!(builder.avgdl.unwrap().avgdl, expected_avgdl);
    }
}