// bm_25/embedder.rs

use crate::tokenizer::Tokenizer;
use fxhash::{hash, hash32, hash64};
#[cfg(feature = "parallelism")]
use rayon::prelude::*;
use std::{
    collections::HashMap,
    fmt::{self, Debug, Display},
    hash::Hash,
    marker::PhantomData,
    ops::{Deref, DerefMut},
};

pub type DefaultTokenEmbedder = u32;
pub type DefaultEmbeddingSpace = u32;

/// The default tokenizer is available via the `default_tokenizer` feature. It should fit most
/// use-cases. It splits on whitespace and punctuation, removes stop words and stems the
/// remaining words. It can also detect languages via the `language_detection` feature. This crate
/// uses `DefaultTokenizer` as the default concrete type for things that are generic
/// over a `Tokenizer`.
#[cfg(feature = "default_tokenizer")]
pub type DefaultTokenizer = crate::default_tokenizer::DefaultTokenizer;

/// A dummy type to represent the absence of a default tokenizer. If a compile error led you here,
/// you either need to enable the `default_tokenizer` feature, or specify your custom tokenizer as
/// a type parameter to whatever you're trying to construct.
#[cfg(not(feature = "default_tokenizer"))]
pub struct NoDefaultTokenizer {}
/// The default tokenizer is available via the `default_tokenizer` feature. It should fit most
/// use-cases. It splits on whitespace and punctuation, removes stop words and stems the
/// remaining words. It can also detect languages via the `language_detection` feature. This crate
/// uses `DefaultTokenizer` as the default concrete type for things that are generic
/// over a `Tokenizer`.
#[cfg(not(feature = "default_tokenizer"))]
pub type DefaultTokenizer = NoDefaultTokenizer;

/// Represents a token embedded in a D-dimensional space.
#[derive(PartialEq, Debug, Clone, PartialOrd)]
pub struct TokenEmbedding<D = DefaultEmbeddingSpace> {
    /// The index of the token in the embedding space.
    pub index: D,
    /// The value of the token in the embedding space.
    pub value: f32,
}

impl<D: Debug> Display for TokenEmbedding<D> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{self:?}")
    }
}

/// Represents a document embedded in a D-dimensional space.
#[derive(PartialEq, Debug, Clone, PartialOrd)]
pub struct Embedding<D = DefaultEmbeddingSpace>(pub Vec<TokenEmbedding<D>>);

impl<D> Deref for Embedding<D> {
    type Target = Vec<TokenEmbedding<D>>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl<D> DerefMut for Embedding<D> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

impl<D> Embedding<D> {
    /// Returns an iterator over the indices of the embedding.
    pub fn indices(&self) -> impl Iterator<Item = &D> {
        self.iter().map(|TokenEmbedding { index, .. }| index)
    }

    /// Returns an iterator over the values of the embedding.
    pub fn values(&self) -> impl Iterator<Item = &f32> {
        self.iter().map(|TokenEmbedding { value, .. }| value)
    }
}

impl<D: Debug> Display for Embedding<D> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{self:?}")
    }
}

/// A trait for embedding. Implement this to customise the embedding space and function.
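///
/// # Example
///
/// A minimal sketch of a custom embedder (`FixedEmbedder` is a hypothetical
/// name; the `it_allows_customisation_of_embedder` test below does the same
/// thing). Marked `ignore` because it is a standalone fragment:
///
/// ```ignore
/// struct FixedEmbedder;
///
/// impl TokenEmbedder for FixedEmbedder {
///     type EmbeddingSpace = u32;
///     // Maps every token to the same index; useful only for illustration.
///     fn embed(_token: &str) -> u32 {
///         42
///     }
/// }
/// ```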
pub trait TokenEmbedder {
    /// The output type of the embedder, i.e., the embedding space.
    type EmbeddingSpace;
    /// Embeds a token into the embedding space.
    fn embed(token: &str) -> Self::EmbeddingSpace;
}

impl TokenEmbedder for u32 {
    type EmbeddingSpace = Self;
    fn embed(token: &str) -> u32 {
        hash32(token)
    }
}

impl TokenEmbedder for u64 {
    type EmbeddingSpace = Self;
    fn embed(token: &str) -> u64 {
        hash64(token)
    }
}

impl TokenEmbedder for usize {
    type EmbeddingSpace = Self;
    fn embed(token: &str) -> usize {
        hash(token)
    }
}

/// Creates sparse embeddings from text. D is the type of the token embedder and T is the type of
/// the tokenizer.
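///
/// # Example
///
/// A minimal usage sketch mirroring the tests below (marked `ignore` because
/// it assumes the `default_tokenizer` feature and crate-level imports):
///
/// ```ignore
/// let embedder: Embedder = EmbedderBuilder::with_avgdl(3.0).build();
/// let embedding = embedder.embed("banana apple orange");
/// assert_eq!(embedding.len(), 3);
/// ```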
#[derive(Debug)]
pub struct Embedder<D = DefaultTokenEmbedder, T = DefaultTokenizer> {
    tokenizer: T,
    k1: f32,
    b: f32,
    avgdl: f32,
    token_embedder_type: PhantomData<D>,
}

impl<D, T> Embedder<D, T> {
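    // Used when avgdl is non-positive (see `embed`) or when it is fit to an
    // empty corpus (see `EmbedderBuilder::with_tokenizer_and_fit_to_corpus`).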
    const FALLBACK_AVGDL: f32 = 256.0;

    /// Returns the average document length used by the embedder.
    pub fn avgdl(&self) -> f32 {
        self.avgdl
    }

    /// Embeds the given text into the embedding space.
    pub fn embed(&self, text: &str) -> Embedding<D::EmbeddingSpace>
    where
        D: TokenEmbedder,
        D::EmbeddingSpace: Eq + Hash,
        T: Tokenizer,
    {
        let avgdl = if self.avgdl <= 0.0 {
            Self::FALLBACK_AVGDL
        } else {
            self.avgdl
        };
        let indices: Vec<D::EmbeddingSpace> = self
            .tokenizer
            .tokenize(text)
            .map(|s| D::embed(&s))
            .collect();
        let len = indices.len();
        let counts = indices.iter().fold(HashMap::new(), |mut acc, token| {
            let count = acc.entry(token).or_insert(0);
            *count += 1;
            acc
        });
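        // BM25 term weight for each token occurrence:
        //     tf * (k1 + 1) / (tf + k1 * (1 - b + b * dl / avgdl))
        // where tf is the token's count in this document and dl is the
        // document's length in tokens. k1 controls term-frequency saturation;
        // b controls how strongly document length normalises the weight.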
        let values: Vec<f32> = indices
            .iter()
            .map(|i| {
                let token_frequency = *counts.get(i).unwrap_or(&0) as f32;
                let numerator = token_frequency * (self.k1 + 1.0);
                let denominator =
                    token_frequency + self.k1 * (1.0 - self.b + self.b * (len as f32 / avgdl));
                numerator / denominator
            })
            .collect();

        Embedding(
            indices
                .into_iter()
                .zip(values)
                .map(|(index, value)| TokenEmbedding { index, value })
                .collect(),
        )
    }
}

/// A consuming builder for Embedder.
pub struct EmbedderBuilder<D = DefaultTokenEmbedder, T = DefaultTokenizer> {
    k1: f32,
    b: f32,
    avgdl: f32,
    tokenizer: T,
    token_embedder_type: PhantomData<D>,
}

impl<D, T> EmbedderBuilder<D, T> {
    /// Constructs a new EmbedderBuilder with the given average document length. Use this if you
    /// know the average document length in advance. If you don't, but you have your full corpus
    /// ahead of time, use `with_fit_to_corpus` or `with_tokenizer_and_fit_to_corpus` instead.
    ///
    /// If you have neither the full corpus nor a sample of it, you can configure the embedder to
    /// disregard document length by setting `b` to 0.0. In this case, it doesn't matter what
    /// value you pass to `with_avgdl`.
    ///
    /// The average document length is the average number of tokens in a document from your corpus;
    /// if you need access to this value, you can construct an Embedder and call `avgdl` on it.
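    ///
    /// # Example
    ///
    /// A sketch of the "disregard document length" configuration described
    /// above (marked `ignore` because it assumes the `default_tokenizer`
    /// feature and crate-level imports):
    ///
    /// ```ignore
    /// // With b = 0.0, document length plays no role, so any avgdl works.
    /// let embedder: Embedder = EmbedderBuilder::with_avgdl(1.0).b(0.0).build();
    /// ```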
    pub fn with_avgdl(avgdl: f32) -> EmbedderBuilder<D, T>
    where
        T: Default,
    {
        EmbedderBuilder {
            k1: 1.2,
            b: 0.75,
            avgdl,
            tokenizer: T::default(),
            token_embedder_type: PhantomData,
        }
    }

    /// Constructs a new EmbedderBuilder with its average document length fit to the given corpus.
    /// Use this if you have the full corpus (or a sample of it) available in advance. The embedder
    /// will assume the given tokenizer. Use the `parallelism` feature to speed the fitting process
    /// up for large corpora.
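    ///
    /// # Example
    ///
    /// A sketch with a custom tokenizer (`MyTokenizer` is a hypothetical type
    /// implementing `Tokenizer + Sync`; marked `ignore` for that reason):
    ///
    /// ```ignore
    /// let corpus = ["cup of tea", "pot of coffee"];
    /// let embedder: Embedder<u32, MyTokenizer> =
    ///     EmbedderBuilder::with_tokenizer_and_fit_to_corpus(MyTokenizer::default(), &corpus)
    ///         .build();
    /// ```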
    pub fn with_tokenizer_and_fit_to_corpus(tokenizer: T, corpus: &[&str]) -> EmbedderBuilder<D, T>
    where
        T: Tokenizer + Sync,
    {
        let avgdl = if corpus.is_empty() {
            Embedder::<D>::FALLBACK_AVGDL
        } else {
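            // Token counts are summed in u64 and the mean is taken in f64 to
            // limit precision loss on large corpora before narrowing to f32.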
            #[cfg(not(feature = "parallelism"))]
            let corpus_iter = corpus.iter();
            #[cfg(feature = "parallelism")]
            let corpus_iter = corpus.par_iter();
            let total_len: u64 = corpus_iter
                .map(|doc| tokenizer.tokenize(doc).count() as u64)
                .sum();
            (total_len as f64 / corpus.len() as f64) as f32
        };

        EmbedderBuilder {
            k1: 1.2,
            b: 0.75,
            avgdl,
            tokenizer,
            token_embedder_type: PhantomData,
        }
    }

    /// Sets the k1 parameter for the embedder. The default value is 1.2.
    pub fn k1(self, k1: f32) -> EmbedderBuilder<D, T> {
        EmbedderBuilder { k1, ..self }
    }

    /// Sets the b parameter for the embedder. The default value is 0.75.
    pub fn b(self, b: f32) -> EmbedderBuilder<D, T> {
        EmbedderBuilder { b, ..self }
    }

    /// Overrides the average document length for the embedder.
    pub fn avgdl(self, avgdl: f32) -> EmbedderBuilder<D, T> {
        EmbedderBuilder { avgdl, ..self }
    }

    /// Sets the tokenizer for the embedder.
    pub fn tokenizer(self, tokenizer: T) -> EmbedderBuilder<D, T> {
        EmbedderBuilder { tokenizer, ..self }
    }

    /// Builds the Embedder.
    pub fn build(self) -> Embedder<D, T> {
        Embedder {
            tokenizer: self.tokenizer,
            k1: self.k1,
            b: self.b,
            avgdl: self.avgdl,
            token_embedder_type: PhantomData,
        }
    }
}

#[cfg(feature = "default_tokenizer")]
impl<D> EmbedderBuilder<D, DefaultTokenizer> {
    /// Constructs a new EmbedderBuilder with its average document length fit to the given corpus.
    /// Use this if you have the full corpus (or a sample of it) available in advance. This
    /// function uses the default tokenizer configured with the input language mode. The embedder
    /// will assume this tokenizer. Use the `parallelism` feature to speed the fitting process up
    /// for large corpora.
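    ///
    /// # Example
    ///
    /// A minimal sketch mirroring the snapshot tests below (marked `ignore`
    /// because it assumes the `default_tokenizer` feature and crate-level
    /// imports such as `Language`):
    ///
    /// ```ignore
    /// let corpus = ["space station", "space shuttle"];
    /// let embedder: Embedder =
    ///     EmbedderBuilder::with_fit_to_corpus(Language::English, &corpus).build();
    /// ```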
    pub fn with_fit_to_corpus(
        language_mode: impl Into<crate::LanguageMode>,
        corpus: &[&str],
    ) -> EmbedderBuilder<D, DefaultTokenizer> {
        let tokenizer = DefaultTokenizer::new(language_mode);
        EmbedderBuilder::with_tokenizer_and_fit_to_corpus(tokenizer, corpus)
    }

    /// Sets the language mode for the embedder tokenizer.
    pub fn language_mode(
        self,
        language_mode: impl Into<crate::LanguageMode>,
    ) -> EmbedderBuilder<D, DefaultTokenizer> {
        let tokenizer = DefaultTokenizer::new(language_mode);
        EmbedderBuilder { tokenizer, ..self }
    }
}

#[cfg(test)]
#[allow(missing_docs)]
mod tests {
    use insta::assert_debug_snapshot;

    use crate::{
        test_data_loader::tests::{read_recipes, Recipe},
        Language, LanguageMode,
    };

    use super::*;

    impl Embedding {
        pub fn any() -> Self {
            Embedding(vec![TokenEmbedding {
                index: 1,
                value: 1.0,
            }])
        }
    }

    impl<D> TokenEmbedding<D> {
        pub fn new(index: D, value: f32) -> Self {
            TokenEmbedding { index, value }
        }
    }

    fn embed_recipes(recipe_file: &str, language_mode: LanguageMode) -> Vec<Embedding> {
        let recipes = read_recipes(recipe_file);
        let embedder: Embedder = EmbedderBuilder::with_fit_to_corpus(
            language_mode,
            &recipes
                .iter()
                .map(|Recipe { recipe, .. }| recipe.as_str())
                .collect::<Vec<_>>(),
        )
        .build();

        recipes
            .iter()
            .map(|Recipe { recipe, .. }| recipe.as_str())
            .map(|recipe| embedder.embed(recipe))
            .collect::<Vec<_>>()
    }

    #[test]
    fn it_weights_unique_words_equally() {
        let embedder = EmbedderBuilder::<u32>::with_avgdl(3.0).build();
        let embedding = embedder.embed("banana apple orange");

        assert!(embedding.len() == 3);
        assert!(embedding.windows(2).all(|e| e[0].value == e[1].value));
    }

    #[test]
    fn it_weights_repeated_words_unequally() {
        let embedder = EmbedderBuilder::<u32>::with_avgdl(3.0)
            .tokenizer(DefaultTokenizer::new(Language::English))
            .build();
        let embedding = embedder.embed("space station station");

        assert!(
            *embedding
                == vec![
                    TokenEmbedding::new(866767497, 1.0),
                    TokenEmbedding::new(666609503, 1.375),
                    TokenEmbedding::new(666609503, 1.375)
                ]
        );
    }

    #[test]
    fn it_constrains_avgdl() {
        let embedder = EmbedderBuilder::<u32>::with_avgdl(0.0)
            .language_mode(Language::English)
            .build();

        let embedding = embedder.embed("space station");

        assert!(!embedding.is_empty());
        assert!(embedding.iter().all(|e| e.value > 0.0));
    }

    #[test]
    fn it_handles_empty_corpus() {
        let embedder = EmbedderBuilder::<u32>::with_fit_to_corpus(Language::English, &[]).build();

        let embedding = embedder.embed("space station");

        assert!(!embedding.is_empty());
    }

    #[test]
    fn it_handles_empty_input() {
        let embedder = EmbedderBuilder::<u32>::with_avgdl(1.0).build();

        let embedding = embedder.embed("");

        assert!(embedding.is_empty());
    }

    #[test]
    fn it_allows_customisation_of_embedder() {
        #[derive(Eq, PartialEq, Hash, Clone, Debug)]
        struct MyType(u32);

        impl TokenEmbedder for MyType {
            type EmbeddingSpace = Self;
            fn embed(_: &str) -> Self {
                MyType(42)
            }
        }

        let embedder = EmbedderBuilder::<MyType>::with_avgdl(2.0).build();

        let embedding = embedder.embed("space station");

        assert_eq!(
            embedding.indices().cloned().collect::<Vec<_>>(),
            vec![MyType(42), MyType(42)]
        );
    }

    #[test]
    fn it_matches_snapshot_en() {
        let embeddings = embed_recipes("recipes_en.csv", LanguageMode::Fixed(Language::English));

        insta::with_settings!({snapshot_path => "../snapshots"}, {
            assert_debug_snapshot!(embeddings);
        });
    }

    #[test]
    fn it_matches_snapshot_de() {
        let embeddings = embed_recipes("recipes_de.csv", LanguageMode::Fixed(Language::German));

        insta::with_settings!({snapshot_path => "../snapshots"}, {
            assert_debug_snapshot!(embeddings);
        });
    }

    #[test]
    fn it_allows_customisation_of_tokenizer() {
        #[derive(Default)]
        struct MyTokenizer {}

        impl Tokenizer for MyTokenizer {
            fn tokenize<'a>(&'a self, input_text: &'a str) -> impl Iterator<Item = String> + 'a {
                input_text
                    .split('T')
                    .filter(|s| !s.is_empty())
                    .map(str::to_string)
            }
        }

        let embedder = EmbedderBuilder::<u32, MyTokenizer>::with_avgdl(1.0).build();

        let embedding = embedder.embed("CupTofTtea");

        assert_eq!(
            embedding.indices().cloned().collect::<Vec<_>>(),
            vec![3568447556, 3221979461, 415655421]
        );
    }
}