1use crate::tokenizer::Tokenizer;
2use fxhash::{hash, hash32, hash64};
3#[cfg(feature = "parallelism")]
4use rayon::prelude::*;
5use std::{
6 collections::HashMap,
7 fmt::{self, Debug, Display},
8 hash::Hash,
9 marker::PhantomData,
10 ops::{Deref, DerefMut},
11};
12
/// Default token-embedder type parameter (`D`) for [`Embedder`].
pub type DefaultTokenEmbedder = u32;
/// Default embedding-space index type for [`TokenEmbedding`].
pub type DefaultEmbeddingSpace = u32;

/// With the `default_tokenizer` feature enabled, the crate-provided tokenizer.
#[cfg(feature = "default_tokenizer")]
pub type DefaultTokenizer = crate::default_tokenizer::DefaultTokenizer;

/// Placeholder standing in for the tokenizer when the `default_tokenizer`
/// feature is disabled; callers must then supply their own tokenizer.
#[cfg(not(feature = "default_tokenizer"))]
pub struct NoDefaultTokenizer {}
/// Without the `default_tokenizer` feature, the default resolves to the
/// non-functional placeholder above.
#[cfg(not(feature = "default_tokenizer"))]
pub type DefaultTokenizer = NoDefaultTokenizer;
36
/// A single token's contribution to an [`Embedding`]: its index in the
/// embedding space together with its weight.
#[derive(PartialEq, Debug, Clone, PartialOrd)]
pub struct TokenEmbedding<D = DefaultEmbeddingSpace> {
    /// The token's position in the embedding space (e.g. a token hash for
    /// the integer embedders defined below).
    pub index: D,
    /// The token's BM25 weight within its document.
    pub value: f32,
}
45
46impl Display for TokenEmbedding {
47 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
48 write!(f, "{self:?}")
49 }
50}
51
/// A sparse embedding of a document: one [`TokenEmbedding`] per token
/// occurrence, in tokenization order.
#[derive(PartialEq, Debug, Clone, PartialOrd)]
pub struct Embedding<D = DefaultEmbeddingSpace>(pub Vec<TokenEmbedding<D>>);
55
56impl<D> Deref for Embedding<D> {
57 type Target = Vec<TokenEmbedding<D>>;
58
59 fn deref(&self) -> &Self::Target {
60 &self.0
61 }
62}
63
64impl DerefMut for Embedding {
65 fn deref_mut(&mut self) -> &mut Self::Target {
66 &mut self.0
67 }
68}
69
70impl<D> Embedding<D> {
71 pub fn indices(&self) -> impl Iterator<Item = &D> {
73 self.iter().map(|TokenEmbedding { index, .. }| index)
74 }
75
76 pub fn values(&self) -> impl Iterator<Item = &f32> {
78 self.iter().map(|TokenEmbedding { value, .. }| value)
79 }
80}
81
82impl<D: Debug> Display for Embedding<D> {
83 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
84 write!(f, "{self:?}")
85 }
86}
87
/// Maps string tokens into an embedding space.
pub trait TokenEmbedder {
    /// The space tokens are embedded into (e.g. `u32` for 32-bit hashes).
    type EmbeddingSpace;
    /// Embeds `token` into the embedding space.
    fn embed(token: &str) -> Self::EmbeddingSpace;
}
95
96impl TokenEmbedder for u32 {
97 type EmbeddingSpace = Self;
98 fn embed(token: &str) -> u32 {
99 hash32(token)
100 }
101}
102
103impl TokenEmbedder for u64 {
104 type EmbeddingSpace = Self;
105 fn embed(token: &str) -> u64 {
106 hash64(token)
107 }
108}
109
110impl TokenEmbedder for usize {
111 type EmbeddingSpace = Self;
112 fn embed(token: &str) -> usize {
113 hash(token)
114 }
115}
116
/// Embeds text into sparse BM25 vectors.
///
/// `D` selects the [`TokenEmbedder`] mapping tokens into an embedding space;
/// `T` is the [`Tokenizer`] used to split input text. Construct via
/// [`EmbedderBuilder`].
#[derive(Debug)]
pub struct Embedder<D = DefaultTokenEmbedder, T = DefaultTokenizer> {
    // Splits input text into tokens before embedding.
    tokenizer: T,
    // BM25 term-frequency saturation parameter.
    k1: f32,
    // BM25 document-length normalization parameter.
    b: f32,
    // Average document length (in tokens) this embedder was configured with.
    avgdl: f32,
    // Zero-sized marker tying the embedder to its token-embedder type `D`.
    token_embedder_type: PhantomData<D>,
}
127
128impl<D, T> Embedder<D, T> {
129 const FALLBACK_AVGDL: f32 = 256.0;
130
131 pub fn avgdl(&self) -> f32 {
133 self.avgdl
134 }
135
136 pub fn embed(&self, text: &str) -> Embedding<D::EmbeddingSpace>
138 where
139 D: TokenEmbedder,
140 D::EmbeddingSpace: Eq + Hash,
141 T: Tokenizer,
142 {
143 let avgdl = if self.avgdl <= 0.0 {
144 Self::FALLBACK_AVGDL
145 } else {
146 self.avgdl
147 };
148 let indices: Vec<D::EmbeddingSpace> = self
149 .tokenizer
150 .tokenize(text)
151 .map(|s| D::embed(&s))
152 .collect();
153 let len = indices.len();
154 let counts = indices.iter().fold(HashMap::new(), |mut acc, token| {
155 let count = acc.entry(token).or_insert(0);
156 *count += 1;
157 acc
158 });
159 let values: Vec<f32> = indices
160 .iter()
161 .map(|i| {
162 let token_frequency = *counts.get(i).unwrap_or(&0) as f32;
163 let numerator = token_frequency * (self.k1 + 1.0);
164 let denominator =
165 token_frequency + self.k1 * (1.0 - self.b + self.b * (len as f32 / avgdl));
166 numerator / denominator
167 })
168 .collect();
169
170 Embedding(
171 indices
172 .into_iter()
173 .zip(values)
174 .map(|(index, value)| TokenEmbedding { index, value })
175 .collect(),
176 )
177 }
178}
179
/// Builder for [`Embedder`], allowing BM25 parameters and the tokenizer to
/// be configured before construction.
pub struct EmbedderBuilder<D = DefaultTokenEmbedder, T = DefaultTokenizer> {
    // BM25 term-frequency saturation parameter.
    k1: f32,
    // BM25 document-length normalization parameter.
    b: f32,
    // Average document length (in tokens), supplied or fitted from a corpus.
    avgdl: f32,
    // Tokenizer the built embedder will use.
    tokenizer: T,
    // Zero-sized marker tying the builder to its token-embedder type `D`.
    token_embedder_type: PhantomData<D>,
}
188
189impl<D, T> EmbedderBuilder<D, T> {
190 pub fn with_avgdl(avgdl: f32) -> EmbedderBuilder<D, T>
201 where
202 T: Default,
203 {
204 EmbedderBuilder {
205 k1: 1.2,
206 b: 0.75,
207 avgdl,
208 tokenizer: T::default(),
209 token_embedder_type: PhantomData,
210 }
211 }
212
213 pub fn with_tokenizer_and_fit_to_corpus(tokenizer: T, corpus: &[&str]) -> EmbedderBuilder<D, T>
218 where
219 T: Tokenizer + Sync,
220 {
221 let avgdl = if corpus.is_empty() {
222 Embedder::<D>::FALLBACK_AVGDL
223 } else {
224 #[cfg(not(feature = "parallelism"))]
225 let corpus_iter = corpus.iter();
226 #[cfg(feature = "parallelism")]
227 let corpus_iter = corpus.par_iter();
228 let total_len: u64 = corpus_iter
229 .map(|doc| tokenizer.tokenize(doc).count() as u64)
230 .sum();
231 (total_len as f64 / corpus.len() as f64) as f32
232 };
233
234 EmbedderBuilder {
235 k1: 1.2,
236 b: 0.75,
237 avgdl,
238 tokenizer,
239 token_embedder_type: PhantomData,
240 }
241 }
242
243 pub fn k1(self, k1: f32) -> EmbedderBuilder<D, T> {
245 EmbedderBuilder { k1, ..self }
246 }
247
248 pub fn b(self, b: f32) -> EmbedderBuilder<D, T> {
250 EmbedderBuilder { b, ..self }
251 }
252
253 pub fn avgdl(self, avgdl: f32) -> EmbedderBuilder<D, T> {
255 EmbedderBuilder { avgdl, ..self }
256 }
257
258 pub fn tokenizer(self, tokenizer: T) -> EmbedderBuilder<D, T> {
260 EmbedderBuilder { tokenizer, ..self }
261 }
262
263 pub fn build(self) -> Embedder<D, T> {
265 Embedder {
266 tokenizer: self.tokenizer,
267 k1: self.k1,
268 b: self.b,
269 avgdl: self.avgdl,
270 token_embedder_type: PhantomData,
271 }
272 }
273}
274
#[cfg(feature = "default_tokenizer")]
impl<D> EmbedderBuilder<D, DefaultTokenizer> {
    /// Starts a builder whose `avgdl` is fitted to `corpus` using a
    /// [`DefaultTokenizer`] configured for `language_mode`.
    pub fn with_fit_to_corpus(
        language_mode: impl Into<crate::LanguageMode>,
        corpus: &[&str],
    ) -> EmbedderBuilder<D, DefaultTokenizer> {
        EmbedderBuilder::with_tokenizer_and_fit_to_corpus(
            DefaultTokenizer::new(language_mode),
            corpus,
        )
    }

    /// Replaces the tokenizer with a [`DefaultTokenizer`] configured for
    /// `language_mode`, keeping all other settings.
    pub fn language_mode(
        self,
        language_mode: impl Into<crate::LanguageMode>,
    ) -> EmbedderBuilder<D, DefaultTokenizer> {
        EmbedderBuilder {
            tokenizer: DefaultTokenizer::new(language_mode),
            ..self
        }
    }
}
299
#[cfg(test)]
#[allow(missing_docs)]
mod tests {
    use insta::assert_debug_snapshot;

    use crate::{
        test_data_loader::tests::{read_recipes, Recipe},
        Language, LanguageMode,
    };

    use super::*;

    // Test-only helper producing an arbitrary non-empty embedding.
    impl Embedding {
        pub fn any() -> Self {
            Embedding(vec![TokenEmbedding {
                index: 1,
                value: 1.0,
            }])
        }
    }

    // Test-only shorthand constructor for expected values.
    impl<D> TokenEmbedding<D> {
        pub fn new(index: D, value: f32) -> Self {
            TokenEmbedding { index, value }
        }
    }

    // Fits an embedder to the recipes in `recipe_file` and embeds each
    // recipe with it; used by the snapshot tests below.
    fn embed_recipes(recipe_file: &str, language_mode: LanguageMode) -> Vec<Embedding> {
        let recipes = read_recipes(recipe_file);
        let embedder: Embedder = EmbedderBuilder::with_fit_to_corpus(
            language_mode,
            &recipes
                .iter()
                .map(|Recipe { recipe, .. }| recipe.as_str())
                .collect::<Vec<_>>(),
        )
        .build();

        recipes
            .iter()
            .map(|Recipe { recipe, .. }| recipe.as_str())
            .map(|recipe| embedder.embed(recipe))
            .collect::<Vec<_>>()
    }

    // With all-distinct tokens, every token gets the same BM25 weight.
    #[test]
    fn it_weights_unique_words_equally() {
        let embedder = EmbedderBuilder::<u32>::with_avgdl(3.0).build();
        let embedding = embedder.embed("banana apple orange");

        assert!(embedding.len() == 3);
        assert!(embedding.windows(2).all(|e| e[0].value == e[1].value));
    }

    // Repeated tokens saturate differently from unique ones; expected
    // indices are the fxhash32 values of the stemmed tokens.
    #[test]
    fn it_weights_repeated_words_unequally() {
        let embedder = EmbedderBuilder::<u32>::with_avgdl(3.0)
            .tokenizer(DefaultTokenizer::new(Language::English))
            .build();
        let embedding = embedder.embed("space station station");

        assert!(
            *embedding
                == vec![
                    TokenEmbedding::new(866767497, 1.0),
                    TokenEmbedding::new(666609503, 1.375),
                    TokenEmbedding::new(666609503, 1.375)
                ]
        );
    }

    // An avgdl of 0 must not produce zero/NaN weights (fallback kicks in).
    #[test]
    fn it_constrains_avgdl() {
        let embedder = EmbedderBuilder::<u32>::with_avgdl(0.0)
            .language_mode(Language::English)
            .build();

        let embedding = embedder.embed("space station");

        assert!(!embedding.is_empty());
        assert!(embedding.iter().all(|e| e.value > 0.0));
    }

    // Fitting to an empty corpus must still yield a usable embedder.
    #[test]
    fn it_handles_empty_corpus() {
        let embedder = EmbedderBuilder::<u32>::with_fit_to_corpus(Language::English, &[]).build();

        let embedding = embedder.embed("space station");

        assert!(!embedding.is_empty());
    }

    // Empty input text embeds to an empty vector.
    #[test]
    fn it_handles_empty_input() {
        let embedder = EmbedderBuilder::<u32>::with_avgdl(1.0).build();

        let embedding = embedder.embed("");

        assert!(embedding.is_empty());
    }

    // A user-defined TokenEmbedder plugs in via the builder's `D` parameter.
    #[test]
    fn it_allows_customisation_of_embedder() {
        #[derive(Eq, PartialEq, Hash, Clone, Debug)]
        struct MyType(u32);

        impl TokenEmbedder for MyType {
            type EmbeddingSpace = Self;
            fn embed(_: &str) -> Self {
                MyType(42)
            }
        }

        let embedder = EmbedderBuilder::<MyType>::with_avgdl(2.0).build();

        let embedding = embedder.embed("space station");

        assert_eq!(
            embedding.indices().cloned().collect::<Vec<_>>(),
            vec![MyType(42), MyType(42)]
        );
    }

    // Snapshot test pinning embeddings for the English recipe corpus.
    #[test]
    fn it_matches_snapshot_en() {
        let embeddings = embed_recipes("recipes_en.csv", LanguageMode::Fixed(Language::English));

        insta::with_settings!({snapshot_path => "../snapshots"}, {
            assert_debug_snapshot!(embeddings);
        });
    }

    // Snapshot test pinning embeddings for the German recipe corpus.
    #[test]
    fn it_matches_snapshot_de() {
        let embeddings = embed_recipes("recipes_de.csv", LanguageMode::Fixed(Language::German));

        insta::with_settings!({snapshot_path => "../snapshots"}, {
            assert_debug_snapshot!(embeddings);
        });
    }

    // A user-defined Tokenizer plugs in via the builder's `T` parameter.
    #[test]
    fn it_allows_customisation_of_tokenizer() {
        #[derive(Default)]
        struct MyTokenizer {}

        impl Tokenizer for MyTokenizer {
            fn tokenize<'a>(&'a self, input_text: &'a str) -> impl Iterator<Item = String> + 'a {
                input_text
                    .split('T')
                    .filter(|s| !s.is_empty())
                    .map(str::to_string)
            }
        }

        let embedder = EmbedderBuilder::<u32, MyTokenizer>::with_avgdl(1.0).build();

        let embedding = embedder.embed("CupTofTtea");

        assert_eq!(
            embedding.indices().cloned().collect::<Vec<_>>(),
            vec![3568447556, 3221979461, 415655421]
        );
    }
}