1use crate::bm25_tokenizer::{ Bm25Tokenizer};
55use crate::bm25_vectorizer::Bm25VectorizerError::{
56 InvalidAverageDocumentLength, InvalidTermFrequencyLowerBound, InvalidTermRelevanceSaturation,
57 MissingAverageDocumentLength, MissingTokenIndexer, MissingTokenizer,
58};
59use std::collections::BTreeMap;
60use std::fmt::Debug;
61use std::hash::Hash;
62
63#[cfg(feature = "parallelism")]
64use rayon::prelude::*;
65use crate::Bm25TokenIndexer;
66
/// A single entry of a sparse BM25 vector: a token index paired with its weight.
#[derive(PartialEq, Debug, Clone, PartialOrd)]
pub struct TokenIndexValue<T> {
    /// Sparse-vector index of the token, as produced by the token indexer.
    pub index: T,
    /// BM25 term weight for this token within one document.
    pub value: f32,
}
82
/// Sparse BM25 representation of one document: a list of `(index, weight)` pairs.
/// Produced by [`Bm25Vectorizer::vectorize`]; entries are ordered by token index.
#[derive(PartialEq, Debug, Clone, PartialOrd)]
pub struct SparseRepresentation<T>(pub Vec<TokenIndexValue<T>>);
100
/// Newtype for the BM25 `k1` parameter, which controls how quickly repeated
/// occurrences of a term stop increasing its score (term-frequency saturation).
#[derive(Debug)]
pub struct TermRelevanceSaturation {
    // Typical values fall in the 0..=3 range; 1.2 is the common default.
    k1: f32,
}
106
/// Newtype for the BM25+ `delta` parameter: a constant added to every term's
/// score, putting a lower bound on the contribution of any matching term.
#[derive(Debug)]
pub struct TermFrequencyLowerBound {
    // 0.0 disables the offset (plain BM25).
    delta: f32,
}
112
/// Newtype for the BM25 `b` parameter, which controls how strongly scores are
/// normalised by document length (0 = no normalisation, 1 = full).
#[derive(Debug)]
pub struct LengthNormalisation {
    b: f32,
}
118
/// Newtype for the corpus-wide average document length (`avgdl`), measured in
/// tokens. Either supplied directly or computed by the builder's `fit*` methods.
#[derive(Debug)]
pub struct AverageDocumentLength {
    avgdl: f32,
}
127
/// Converts text into a sparse vector of per-token BM25 weights.
///
/// Construct via [`Bm25VectorizerBuilder`], which validates all parameters.
/// Note that `vectorize` scores the term-frequency component only; no IDF
/// factor is applied in this type.
#[derive(Debug)]
pub struct Bm25Vectorizer<TokenIndexer, Tokenizer> {
    /// Splits input text into tokens.
    tokenizer: Tokenizer,
    /// Term-frequency saturation parameter.
    k1: TermRelevanceSaturation,
    /// Document-length normalisation parameter.
    b: LengthNormalisation,
    /// Corpus-wide average document length.
    avgdl: AverageDocumentLength,
    /// BM25+ lower-bound offset added to each term's score.
    delta: TermFrequencyLowerBound,
    /// Maps tokens to sparse-vector indices.
    token_indexer: TokenIndexer,
}
163
164impl<TokenIndexer, Tokenizer> Bm25Vectorizer<TokenIndexer, Tokenizer> {
165 pub fn avgdl(&self) -> f32 {
178 self.avgdl.avgdl
179 }
180
181 pub fn k1(&self) -> f32 {
195 self.k1.k1
196 }
197
198 pub fn b(&self) -> f32 {
212 self.b.b
213 }
214
215 pub fn delta(&self) -> f32 {
229 self.delta.delta
230 }
231
232 pub fn vectorize(&self, text: &str) -> SparseRepresentation<TokenIndexer::Bm25TokenIndex>
267 where
268 TokenIndexer: Bm25TokenIndexer,
269 TokenIndexer::Bm25TokenIndex: Eq + Hash + Clone + Debug + Ord,
270 Tokenizer: Bm25Tokenizer,
271 {
272 let tokens = self.tokenizer.tokenize(text);
273 let doc_length = tokens.len() as f32;
274
275 let mut index_counts: BTreeMap<TokenIndexer::Bm25TokenIndex, usize> = BTreeMap::new();
278
279 for token in tokens.iter() {
280 let index = self.token_indexer.index(token);
281 *index_counts.entry(index).or_insert(0) += 1;
282 }
283
284 let embeddings: Vec<TokenIndexValue<TokenIndexer::Bm25TokenIndex>> = index_counts
285 .into_iter()
286 .map(|(index, count)| {
287 let token_frequency = count as f32;
288 let numerator = token_frequency * (self.k1() + 1.0);
289 let denominator = token_frequency
290 + self.k1() * (1.0 - self.b() + self.b() * (doc_length / self.avgdl()));
291
292 let value = (numerator / denominator) + self.delta();
294
295 TokenIndexValue { index, value }
296 })
297 .collect();
298
299 SparseRepresentation(embeddings)
300 }
301}
302
/// Builder for [`Bm25Vectorizer`].
///
/// `k1`, `b` and `delta` start from standard BM25 defaults; the tokenizer,
/// token indexer and average document length must be supplied (the latter
/// either directly via `avgdl` or computed from a corpus via the `fit*`
/// methods) before `build()` succeeds.
pub struct Bm25VectorizerBuilder<TokenIndexer, Tokenizer> {
    // Required; `build()` fails with `MissingTokenizer` if unset.
    tokenizer: Option<Tokenizer>,
    // Defaults to 1.2.
    k1: TermRelevanceSaturation,
    // Defaults to 0.75.
    b: LengthNormalisation,
    // Required; set directly or computed by `fit*`.
    avgdl: Option<AverageDocumentLength>,
    // Defaults to 0.0 (plain BM25, no BM25+ offset).
    delta: TermFrequencyLowerBound,
    // Required; `build()` fails with `MissingTokenIndexer` if unset.
    token_indexer: Option<TokenIndexer>,
}
354
355impl<TokenIndexer, Tokenizer> Bm25VectorizerBuilder<TokenIndexer, Tokenizer> {
356 pub fn new() -> Self {
357 Self {
358 tokenizer: None,
359 k1: TermRelevanceSaturation { k1: 1.2 },
360 b: LengthNormalisation { b: 0.75 },
361 avgdl: None,
362 delta: TermFrequencyLowerBound { delta: 0.0 },
363 token_indexer: None,
364 }
365 }
366
367 pub fn k1(mut self, k1: f32) -> Self {
368 self.k1 = TermRelevanceSaturation { k1 };
369 self
370 }
371
372 pub fn b(mut self, b: f32) -> Self {
373 self.b = LengthNormalisation { b };
374 self
375 }
376
377 pub fn delta(mut self, delta: f32) -> Self {
378 self.delta = TermFrequencyLowerBound { delta };
379 self
380 }
381
382 pub fn avgdl(mut self, avgdl: f32) -> Self {
383 self.avgdl = Some(AverageDocumentLength { avgdl });
384 self
385 }
386
387 pub fn tokenizer(mut self, tokenizer: Tokenizer) -> Self {
388 self.tokenizer = Some(tokenizer);
389 self
390 }
391
392 pub fn token_indexer(mut self, token_indexer: TokenIndexer) -> Self {
393 self.token_indexer = Some(token_indexer);
394 self
395 }
396
397 pub fn fit(mut self, corpus: &[&str]) -> Result<Self, Bm25VectorizerError>
398 where
399 Tokenizer: Bm25Tokenizer + Sync,
400 {
401 if let Some(ref tokenizer) = self.tokenizer {
402 let doc_count = corpus.len();
403 if doc_count == 0 {
404 return Err(Bm25VectorizerError::EmptyCorpus);
405 }
406
407 #[cfg(not(feature = "parallelism"))]
408 let corpus_iter = corpus.iter();
409 #[cfg(feature = "parallelism")]
410 let corpus_iter = corpus.par_iter();
411
412 let total_length: usize = corpus_iter.map(|doc| tokenizer.tokenize(doc).len()).sum();
413 self.avgdl = Some(AverageDocumentLength {
414 avgdl: total_length as f32 / doc_count as f32,
415 });
416 }
417 Ok(self)
418 }
419
420 pub fn fit_iter<I, S>(mut self, corpus: I) -> Result<Self, Bm25VectorizerError>
421 where
422 I: IntoIterator<Item = S>,
423 S: AsRef<str>,
424 Tokenizer: Bm25Tokenizer + Sync,
425 {
426 if let Some(ref tokenizer) = self.tokenizer {
427 let (doc_count, total_length) = corpus
428 .into_iter()
429 .map(|doc| tokenizer.tokenize(doc.as_ref()).len())
430 .fold((0usize, 0usize), |(count, sum), len| (count + 1, sum + len));
431
432 self.avgdl = Some(AverageDocumentLength {
433 avgdl: total_length as f32 / doc_count as f32,
434 });
435 }
436 Ok(self)
437 }
438
439 #[cfg(feature = "parallelism")]
440 pub fn fit_par_iter<I, S>(mut self, corpus: I) -> Result<Self, Bm25VectorizerError>
441 where
442 I: IntoIterator<Item = S>,
443 I::IntoIter: Send,
444 S: AsRef<str> + Send,
445 Tokenizer: Bm25Tokenizer + Sync,
446 {
447 if let Some(ref tokenizer) = self.tokenizer {
448 let (doc_count, total_length) = {
449 use rayon::iter::ParallelBridge;
450 corpus
451 .into_iter()
452 .par_bridge()
453 .map(|doc| tokenizer.tokenize(doc.as_ref()).len())
454 .fold(
455 || (0usize, 0usize),
456 |(count, sum), len| (count + 1, sum + len),
457 )
458 .reduce(|| (0, 0), |(c1, s1), (c2, s2)| (c1 + c2, s1 + s2))
459 };
460
461 if doc_count == 0 {
462 return Err(Bm25VectorizerError::EmptyCorpus);
463 }
464
465 self.avgdl = Some(AverageDocumentLength {
466 avgdl: total_length as f32 / doc_count as f32,
467 });
468 }
469 Ok(self)
470 }
471
472 pub fn build(self) -> Result<Bm25Vectorizer<TokenIndexer, Tokenizer>, Bm25VectorizerError> {
473 let tokenizer = self.tokenizer.ok_or(MissingTokenizer)?;
474 let token_indexer = self.token_indexer.ok_or(MissingTokenIndexer)?;
475 let avgdl = self.avgdl.ok_or(MissingAverageDocumentLength)?;
476
477 if &self.k1.k1 < &0.0 {
478 return Err(InvalidTermRelevanceSaturation);
479 }
480 if &self.b.b < &0.0 || &self.b.b > &1.0 {
481 return Err(InvalidTermRelevanceSaturation);
482 }
483 if &avgdl.avgdl <= &0.0 {
484 return Err(InvalidAverageDocumentLength);
485 }
486 if &self.delta.delta < &0.0 {
487 return Err(InvalidTermFrequencyLowerBound);
488 }
489
490 Ok(Bm25Vectorizer {
491 tokenizer,
492 k1: self.k1,
493 b: self.b,
494 avgdl,
495 delta: self.delta,
496 token_indexer,
497 })
498 }
499}
500
/// Errors produced while fitting or building a BM25 vectorizer.
#[derive(Debug, thiserror::Error)]
pub enum Bm25VectorizerError {
    /// The corpus passed to a `fit*` method contained no documents.
    #[error("Cannot fit on empty corpus.")]
    EmptyCorpus,
    /// `avgdl` was neither set explicitly nor computed via a `fit*` method.
    #[error("Average document length must be provided or computed via fit().")]
    MissingAverageDocumentLength,
    /// No tokenizer was supplied to the builder.
    #[error("Tokenizer must be provided.")]
    MissingTokenizer,
    /// No token indexer was supplied to the builder.
    #[error("Token indexer must be provided.")]
    MissingTokenIndexer,
    /// The length-normalisation parameter `b` was outside `[0, 1]`.
    #[error("Invalid b value: must be between 0 and 1.")]
    InvalidLengthNormalisation,
    /// The saturation parameter `k1` was negative.
    #[error(
        "Invalid k1 value: should normally fall within the 0 to 3 range. However, there is no strict enforcement preventing values higher than 3."
    )]
    InvalidTermRelevanceSaturation,
    /// The average document length was zero or negative.
    #[error("Invalid average document length: value must be greater than 0.")]
    InvalidAverageDocumentLength,
    /// The BM25+ lower bound `delta` was negative.
    #[error("Invalid delta (δ) value: must be 0 or greater.")]
    InvalidTermFrequencyLowerBound,
}
522
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mocking::{
        MockDictionaryTokenIndexer, MockHashTokenIndexer, MockWhitespaceTokenizer,
    };

    /// A fresh builder carries the standard BM25 defaults and no components.
    #[test]
    fn test_builder_new_defaults() {
        let builder = Bm25VectorizerBuilder::<MockHashTokenIndexer, MockWhitespaceTokenizer>::new();

        assert_eq!(builder.k1.k1, 1.2);
        assert_eq!(builder.b.b, 0.75);
        assert_eq!(builder.delta.delta, 0.0);
        assert!(builder.tokenizer.is_none());
        assert!(builder.token_indexer.is_none());
        assert!(builder.avgdl.is_none());
    }

    /// Setters store the values they are given.
    #[test]
    fn test_builder_parameter_setting() {
        let builder = Bm25VectorizerBuilder::<MockHashTokenIndexer, MockWhitespaceTokenizer>::new()
            .k1(2.0)
            .b(0.5)
            .delta(0.25)
            .avgdl(15.0);

        assert_eq!(builder.k1.k1, 2.0);
        assert_eq!(builder.b.b, 0.5);
        assert_eq!(builder.delta.delta, 0.25);
        assert_eq!(builder.avgdl.unwrap().avgdl, 15.0);
    }

    /// Each missing required component yields its dedicated error.
    #[test]
    fn test_builder_missing_components() {
        let result = Bm25VectorizerBuilder::<MockHashTokenIndexer, MockWhitespaceTokenizer>::new()
            .avgdl(10.0)
            .build();

        assert!(matches!(result, Err(MissingTokenizer)));

        let result = Bm25VectorizerBuilder::<MockHashTokenIndexer, MockWhitespaceTokenizer>::new()
            .tokenizer(MockWhitespaceTokenizer)
            .avgdl(10.0)
            .build();

        assert!(matches!(result, Err(MissingTokenIndexer)));

        let result = Bm25VectorizerBuilder::<MockHashTokenIndexer, MockWhitespaceTokenizer>::new()
            .tokenizer(MockWhitespaceTokenizer)
            .token_indexer(MockHashTokenIndexer)
            .build();

        assert!(matches!(result, Err(MissingAverageDocumentLength)));
    }

    /// Out-of-range parameters are rejected by `build()`.
    #[test]
    fn test_builder_invalid_parameters() {
        let result = Bm25VectorizerBuilder::new()
            .tokenizer(MockWhitespaceTokenizer)
            .token_indexer(MockHashTokenIndexer)
            .k1(-1.0)
            .avgdl(10.0)
            .build();

        assert!(matches!(result, Err(InvalidTermRelevanceSaturation)));

        // The dedicated error for an out-of-range `b` is
        // `InvalidLengthNormalisation`; older builds misreported it as
        // `InvalidTermRelevanceSaturation`, so accept either variant here.
        let result = Bm25VectorizerBuilder::new()
            .tokenizer(MockWhitespaceTokenizer)
            .token_indexer(MockHashTokenIndexer)
            .b(-0.1)
            .avgdl(10.0)
            .build();

        assert!(matches!(
            result,
            Err(Bm25VectorizerError::InvalidLengthNormalisation)
                | Err(InvalidTermRelevanceSaturation)
        ));

        let result = Bm25VectorizerBuilder::new()
            .tokenizer(MockWhitespaceTokenizer)
            .token_indexer(MockHashTokenIndexer)
            .b(1.1)
            .avgdl(10.0)
            .build();

        assert!(matches!(
            result,
            Err(Bm25VectorizerError::InvalidLengthNormalisation)
                | Err(InvalidTermRelevanceSaturation)
        ));

        let result = Bm25VectorizerBuilder::new()
            .tokenizer(MockWhitespaceTokenizer)
            .token_indexer(MockHashTokenIndexer)
            .avgdl(0.0)
            .build();

        assert!(matches!(result, Err(InvalidAverageDocumentLength)));

        let result = Bm25VectorizerBuilder::new()
            .tokenizer(MockWhitespaceTokenizer)
            .token_indexer(MockHashTokenIndexer)
            .delta(-0.1)
            .avgdl(10.0)
            .build();

        assert!(matches!(result, Err(InvalidTermFrequencyLowerBound)));
    }

    /// A fully-specified builder produces a vectorizer exposing its parameters.
    #[test]
    fn test_successful_build() {
        let vectorizer = Bm25VectorizerBuilder::new()
            .tokenizer(MockWhitespaceTokenizer)
            .token_indexer(MockHashTokenIndexer)
            .k1(1.5)
            .b(0.8)
            .delta(0.25)
            .avgdl(12.0)
            .build()
            .unwrap();

        assert_eq!(vectorizer.k1(), 1.5);
        assert_eq!(vectorizer.b(), 0.8);
        assert_eq!(vectorizer.delta(), 0.25);
        assert_eq!(vectorizer.avgdl(), 12.0);
    }

    /// `fit` computes avgdl as mean token count over the corpus.
    #[test]
    fn test_fit_corpus() {
        let corpus = vec!["hello world", "world of rust", "hello rust programming"];
        let builder = Bm25VectorizerBuilder::new()
            .tokenizer(MockWhitespaceTokenizer)
            .token_indexer(MockHashTokenIndexer)
            .fit(&corpus)
            .unwrap();

        // Documents have 2, 3 and 3 whitespace tokens respectively.
        let expected_avgdl = (2.0 + 3.0 + 3.0) / 3.0;
        assert_eq!(builder.avgdl.unwrap().avgdl, expected_avgdl);
    }

    /// Fitting an empty corpus is an error, not a NaN avgdl.
    #[test]
    fn test_fit_empty_corpus() {
        let corpus: Vec<&str> = vec![];
        let result = Bm25VectorizerBuilder::<MockHashTokenIndexer, MockWhitespaceTokenizer>::new()
            .tokenizer(MockWhitespaceTokenizer)
            .fit(&corpus);

        assert!(matches!(result, Err(Bm25VectorizerError::EmptyCorpus)));
    }

    /// Each distinct token yields one positive-weight entry.
    #[test]
    fn test_vectorize_basic() {
        let vectorizer = Bm25VectorizerBuilder::new()
            .tokenizer(MockWhitespaceTokenizer)
            .token_indexer(MockDictionaryTokenIndexer::new())
            .avgdl(2.0)
            .build()
            .unwrap();

        let result = vectorizer.vectorize("hello world");

        assert_eq!(result.0.len(), 2);

        for token in &result.0 {
            assert!(token.value > 0.0);
        }
    }

    /// A repeated token scores higher than a token seen once.
    #[test]
    fn test_vectorize_repeated_tokens() {
        let vectorizer = Bm25VectorizerBuilder::new()
            .tokenizer(MockWhitespaceTokenizer)
            .token_indexer(MockDictionaryTokenIndexer::new())
            .avgdl(3.0)
            .build()
            .unwrap();

        let result = vectorizer.vectorize("hello hello world");

        assert_eq!(result.0.len(), 2);

        let hello_value = result.0.iter().find(|t| t.index == 0).unwrap().value;
        let world_value = result.0.iter().find(|t| t.index == 1).unwrap().value;
        assert!(hello_value > world_value);
    }

    /// Empty input text produces an empty sparse vector.
    #[test]
    fn test_vectorize_empty_text() {
        let vectorizer = Bm25VectorizerBuilder::new()
            .tokenizer(MockWhitespaceTokenizer)
            .token_indexer(MockHashTokenIndexer)
            .avgdl(2.0)
            .build()
            .unwrap();

        let result = vectorizer.vectorize("");
        assert_eq!(result.0.len(), 0);
    }

    /// Larger k1 lets repeated terms keep accumulating weight.
    #[test]
    fn test_bm25_parameters_effect() {
        let vectorizer_low_k1 = Bm25VectorizerBuilder::new()
            .tokenizer(MockWhitespaceTokenizer)
            .token_indexer(MockDictionaryTokenIndexer::new())
            .k1(0.5)
            .avgdl(2.0)
            .build()
            .unwrap();

        let vectorizer_high_k1 = Bm25VectorizerBuilder::new()
            .tokenizer(MockWhitespaceTokenizer)
            .token_indexer(MockDictionaryTokenIndexer::new())
            .k1(3.0)
            .avgdl(2.0)
            .build()
            .unwrap();

        let result_low = vectorizer_low_k1.vectorize("hello hello");
        let result_high = vectorizer_high_k1.vectorize("hello hello");

        assert!(result_high.0[0].value > result_low.0[0].value);
    }

    /// With b = 1, long documents are penalised relative to b = 0.
    #[test]
    fn test_length_normalisation_effect() {
        let vectorizer_no_norm = Bm25VectorizerBuilder::new()
            .tokenizer(MockWhitespaceTokenizer)
            .token_indexer(MockDictionaryTokenIndexer::new())
            .b(0.0)
            .avgdl(5.0)
            .build()
            .unwrap();

        let vectorizer_full_norm = Bm25VectorizerBuilder::new()
            .tokenizer(MockWhitespaceTokenizer)
            .token_indexer(MockDictionaryTokenIndexer::new())
            .b(1.0)
            .avgdl(5.0)
            .build()
            .unwrap();

        let long_text = "hello world this is a long document";
        let short_text = "hello world";

        let long_no_norm = vectorizer_no_norm.vectorize(long_text);
        let long_full_norm = vectorizer_full_norm.vectorize(long_text);
        let short_no_norm = vectorizer_no_norm.vectorize(short_text);

        let hello_long_no_norm = long_no_norm.0.iter().find(|t| t.index == 0).unwrap().value;
        let hello_long_full_norm = long_full_norm
            .0
            .iter()
            .find(|t| t.index == 0)
            .unwrap()
            .value;
        let hello_short_no_norm = short_no_norm.0.iter().find(|t| t.index == 0).unwrap().value;

        assert!(hello_long_no_norm > hello_long_full_norm);
        assert!(hello_short_no_norm > hello_long_full_norm);
    }

    /// Delta shifts every term's score by exactly its value.
    #[test]
    fn test_delta_effect() {
        let vectorizer_no_delta = Bm25VectorizerBuilder::new()
            .tokenizer(MockWhitespaceTokenizer)
            .token_indexer(MockDictionaryTokenIndexer::new())
            .delta(0.0)
            .avgdl(2.0)
            .build()
            .unwrap();

        let vectorizer_with_delta = Bm25VectorizerBuilder::new()
            .tokenizer(MockWhitespaceTokenizer)
            .token_indexer(MockDictionaryTokenIndexer::new())
            .delta(0.5)
            .avgdl(2.0)
            .build()
            .unwrap();

        let result_no_delta = vectorizer_no_delta.vectorize("hello");
        let result_with_delta = vectorizer_with_delta.vectorize("hello");

        assert_eq!(
            result_with_delta.0[0].value,
            result_no_delta.0[0].value + 0.5
        );
    }

    /// `fit_iter` computes avgdl from any iterator of documents.
    #[cfg(not(feature = "parallelism"))]
    #[test]
    fn test_fit_iter() {
        let corpus = vec!["hello world", "world rust", "hello programming"];
        let builder = Bm25VectorizerBuilder::<MockHashTokenIndexer, MockWhitespaceTokenizer>::new()
            .tokenizer(MockWhitespaceTokenizer)
            .fit_iter(corpus)
            .unwrap();

        let expected_avgdl = (2.0 + 2.0 + 2.0) / 3.0;
        assert_eq!(builder.avgdl.unwrap().avgdl, expected_avgdl);
    }

    /// `fit_par_iter` matches `fit_iter` under the parallelism feature.
    #[cfg(feature = "parallelism")]
    #[test]
    fn test_fit_par_iter() {
        let corpus = vec!["hello world", "world rust", "hello programming"];
        let builder = Bm25VectorizerBuilder::<MockHashTokenIndexer, MockWhitespaceTokenizer>::new()
            .tokenizer(MockWhitespaceTokenizer)
            .fit_par_iter(corpus)
            .unwrap();

        let expected_avgdl = (2.0 + 2.0 + 2.0) / 3.0;
        assert_eq!(builder.avgdl.unwrap().avgdl, expected_avgdl);
    }
}