use crate::tokenizer::Tokenizer;
use fxhash::{hash, hash32, hash64};
#[cfg(feature = "parallelism")]
use rayon::prelude::*;
use std::{
collections::HashMap,
fmt::{self, Debug, Display},
hash::Hash,
marker::PhantomData,
ops::{Deref, DerefMut},
};
/// Default token embedder: hashes tokens to `u32` indices (FxHash).
pub type DefaultTokenEmbedder = u32;
/// Default index type of the sparse embedding space.
pub type DefaultEmbeddingSpace = u32;
#[cfg(feature = "default_tokenizer")]
pub type DefaultTokenizer = crate::default_tokenizer::DefaultTokenizer;
/// Placeholder tokenizer used when the `default_tokenizer` feature is disabled.
#[cfg(not(feature = "default_tokenizer"))]
pub struct NoDefaultTokenizer {}
#[cfg(not(feature = "default_tokenizer"))]
pub type DefaultTokenizer = NoDefaultTokenizer;

/// One dimension of a sparse embedding: an index into the embedding space `D`
/// paired with the BM25 weight of the token at that index.
#[derive(PartialEq, Debug, Clone, PartialOrd)]
pub struct TokenEmbedding<D = DefaultEmbeddingSpace> {
    /// Position of the token in the embedding space (e.g. its hash).
    pub index: D,
    /// BM25 weight of the token within its document.
    pub value: f32,
}

// Generic over `D` (previously implemented only for the default embedding
// space) so that custom embedding spaces get `Display` too, matching the
// generic `Display` impl on `Embedding` below.
impl<D: Debug> Display for TokenEmbedding<D> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{self:?}")
    }
}

/// A sparse document embedding: one `TokenEmbedding` per token occurrence.
#[derive(PartialEq, Debug, Clone, PartialOrd)]
pub struct Embedding<D = DefaultEmbeddingSpace>(pub Vec<TokenEmbedding<D>>);

impl<D> Deref for Embedding<D> {
    type Target = Vec<TokenEmbedding<D>>;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

// Generic over `D` (previously implemented only for the default embedding
// space) to stay consistent with the `Deref` impl above.
impl<D> DerefMut for Embedding<D> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

impl<D> Embedding<D> {
    /// Iterates over the embedding-space indices of all token embeddings.
    pub fn indices(&self) -> impl Iterator<Item = &D> {
        self.iter().map(|TokenEmbedding { index, .. }| index)
    }
    /// Iterates over the BM25 weights of all token embeddings.
    pub fn values(&self) -> impl Iterator<Item = &f32> {
        self.iter().map(|TokenEmbedding { value, .. }| value)
    }
}

impl<D: Debug> Display for Embedding<D> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{self:?}")
    }
}
/// Maps a token string to a point in an embedding space.
pub trait TokenEmbedder {
    /// The type used to index the embedding space.
    type EmbeddingSpace;
    /// Embeds a single token into the embedding space.
    fn embed(token: &str) -> Self::EmbeddingSpace;
}
/// Embeds a token as its 32-bit FxHash.
impl TokenEmbedder for u32 {
    type EmbeddingSpace = Self;
    fn embed(token: &str) -> u32 {
        hash32(token)
    }
}
/// Embeds a token as its 64-bit FxHash.
impl TokenEmbedder for u64 {
    type EmbeddingSpace = Self;
    fn embed(token: &str) -> u64 {
        hash64(token)
    }
}
/// Embeds a token as its platform-width FxHash.
impl TokenEmbedder for usize {
    type EmbeddingSpace = Self;
    fn embed(token: &str) -> usize {
        hash(token)
    }
}
/// BM25 sparse embedder: tokenizes text with `T`, maps each token into the
/// embedding space via `D`, and weights every token occurrence with BM25.
#[derive(Debug)]
pub struct Embedder<D = DefaultTokenEmbedder, T = DefaultTokenizer> {
    // Splits input text into tokens.
    tokenizer: T,
    // BM25 term-frequency saturation parameter.
    k1: f32,
    // BM25 length-normalization parameter.
    b: f32,
    // Average document length (in tokens) of the fitted corpus.
    avgdl: f32,
    // Zero-sized marker tying this embedder to its token-embedder type `D`.
    token_embedder_type: PhantomData<D>,
}
impl<D, T> Embedder<D, T> {
    // Average document length used when the configured `avgdl` is unusable
    // (zero or negative, which would break the BM25 normalization term).
    const FALLBACK_AVGDL: f32 = 256.0;
    /// Returns the average document length this embedder was configured with.
    pub fn avgdl(&self) -> f32 {
        self.avgdl
    }
    /// Embeds `text` into a sparse BM25 embedding: one `TokenEmbedding` per
    /// token occurrence (repeated tokens appear repeatedly, each carrying the
    /// same BM25 weight). Returns an empty embedding for empty input.
    pub fn embed(&self, text: &str) -> Embedding<D::EmbeddingSpace>
    where
        D: TokenEmbedder,
        D::EmbeddingSpace: Eq + Hash,
        T: Tokenizer,
    {
        let tokens = self.tokenizer.tokenize(text);
        // Guard against division by zero / negative normalization below.
        let avgdl = if self.avgdl <= 0.0 {
            Self::FALLBACK_AVGDL
        } else {
            self.avgdl
        };
        // Map every token into the embedding space (one entry per occurrence).
        let indices: Vec<D::EmbeddingSpace> = tokens.iter().map(|s| D::embed(s)).collect();
        // Term frequency per distinct embedding-space index.
        let counts = indices.iter().fold(HashMap::new(), |mut acc, token| {
            let count = acc.entry(token).or_insert(0);
            *count += 1;
            acc
        });
        // BM25 term weight per occurrence:
        //   tf * (k1 + 1) / (tf + k1 * (1 - b + b * dl / avgdl))
        // where tf is the token's frequency and dl the document length.
        let values: Vec<f32> = indices
            .iter()
            .map(|i| {
                let token_frequency = *counts.get(i).unwrap_or(&0) as f32;
                let numerator = token_frequency * (self.k1 + 1.0);
                let denominator = token_frequency
                    + self.k1 * (1.0 - self.b + self.b * (tokens.len() as f32 / avgdl));
                numerator / denominator
            })
            .collect();
        // Pair each index with its weight, preserving token order.
        Embedding(
            indices
                .into_iter()
                .zip(values)
                .map(|(index, value)| TokenEmbedding { index, value })
                .collect(),
        )
    }
}
/// Builder for [`Embedder`], configuring BM25 parameters, the tokenizer, and
/// the average document length (either given directly or fitted to a corpus).
pub struct EmbedderBuilder<D = DefaultTokenEmbedder, T = DefaultTokenizer> {
    // BM25 term-frequency saturation parameter.
    k1: f32,
    // BM25 length-normalization parameter.
    b: f32,
    // Average document length (in tokens).
    avgdl: f32,
    // Tokenizer handed to the built embedder.
    tokenizer: T,
    // Zero-sized marker for the token-embedder type `D`.
    token_embedder_type: PhantomData<D>,
}
impl<D, T> EmbedderBuilder<D, T> {
    // BM25 defaults were previously duplicated as magic numbers in both
    // constructors; named consts keep them in one place.
    /// Default BM25 `k1` (term-frequency saturation) parameter.
    const DEFAULT_K1: f32 = 1.2;
    /// Default BM25 `b` (length-normalization) parameter.
    const DEFAULT_B: f32 = 0.75;

    /// Creates a builder with the given average document length and a
    /// default-constructed tokenizer.
    pub fn with_avgdl(avgdl: f32) -> EmbedderBuilder<D, T>
    where
        T: Default,
    {
        EmbedderBuilder {
            k1: Self::DEFAULT_K1,
            b: Self::DEFAULT_B,
            avgdl,
            tokenizer: T::default(),
            token_embedder_type: PhantomData,
        }
    }

    /// Creates a builder whose average document length is fitted to `corpus`
    /// using `tokenizer`. Falls back to `Embedder::FALLBACK_AVGDL` when the
    /// corpus is empty.
    pub fn with_tokenizer_and_fit_to_corpus(tokenizer: T, corpus: &[&str]) -> EmbedderBuilder<D, T>
    where
        T: Tokenizer + Sync,
    {
        let avgdl = if corpus.is_empty() {
            Embedder::<D>::FALLBACK_AVGDL
        } else {
            // Tokenize the corpus in parallel when the feature is enabled.
            #[cfg(not(feature = "parallelism"))]
            let corpus_iter = corpus.iter();
            #[cfg(feature = "parallelism")]
            let corpus_iter = corpus.par_iter();
            let total_len: u64 = corpus_iter
                .map(|doc| tokenizer.tokenize(doc).len() as u64)
                .sum();
            // Average in f64 to limit rounding error before narrowing to f32.
            (total_len as f64 / corpus.len() as f64) as f32
        };
        EmbedderBuilder {
            k1: Self::DEFAULT_K1,
            b: Self::DEFAULT_B,
            avgdl,
            tokenizer,
            token_embedder_type: PhantomData,
        }
    }

    /// Overrides the BM25 `k1` parameter.
    pub fn k1(self, k1: f32) -> EmbedderBuilder<D, T> {
        EmbedderBuilder { k1, ..self }
    }

    /// Overrides the BM25 `b` parameter.
    pub fn b(self, b: f32) -> EmbedderBuilder<D, T> {
        EmbedderBuilder { b, ..self }
    }

    /// Overrides the average document length.
    pub fn avgdl(self, avgdl: f32) -> EmbedderBuilder<D, T> {
        EmbedderBuilder { avgdl, ..self }
    }

    /// Overrides the tokenizer.
    pub fn tokenizer(self, tokenizer: T) -> EmbedderBuilder<D, T> {
        EmbedderBuilder { tokenizer, ..self }
    }

    /// Builds the configured [`Embedder`].
    pub fn build(self) -> Embedder<D, T> {
        Embedder {
            tokenizer: self.tokenizer,
            k1: self.k1,
            b: self.b,
            avgdl: self.avgdl,
            token_embedder_type: PhantomData,
        }
    }
}
#[cfg(feature = "default_tokenizer")]
impl<D> EmbedderBuilder<D, DefaultTokenizer> {
    /// Creates a builder fitted to `corpus`, using a [`DefaultTokenizer`]
    /// configured for the given language mode.
    pub fn with_fit_to_corpus(
        language_mode: impl Into<crate::LanguageMode>,
        corpus: &[&str],
    ) -> EmbedderBuilder<D, DefaultTokenizer> {
        let tokenizer = DefaultTokenizer::new(language_mode);
        EmbedderBuilder::with_tokenizer_and_fit_to_corpus(tokenizer, corpus)
    }
    /// Replaces the tokenizer with a fresh [`DefaultTokenizer`] for the given
    /// language mode; all other settings are kept.
    pub fn language_mode(
        self,
        language_mode: impl Into<crate::LanguageMode>,
    ) -> EmbedderBuilder<D, DefaultTokenizer> {
        let tokenizer = DefaultTokenizer::new(language_mode);
        EmbedderBuilder { tokenizer, ..self }
    }
}
#[cfg(test)]
#[allow(missing_docs)]
mod tests {
    use insta::assert_debug_snapshot;

    use crate::{
        test_data_loader::tests::{read_recipes, Recipe},
        Language, LanguageMode,
    };

    use super::*;

    // Test-only constructor producing an arbitrary non-empty embedding.
    impl Embedding {
        pub fn any() -> Self {
            Embedding(vec![TokenEmbedding {
                index: 1,
                value: 1.0,
            }])
        }
    }

    // Test-only shorthand constructor.
    impl<D> TokenEmbedding<D> {
        pub fn new(index: D, value: f32) -> Self {
            TokenEmbedding { index, value }
        }
    }

    // Fits an embedder to the recipes in `recipe_file` and embeds each recipe.
    fn embed_recipes(recipe_file: &str, language_mode: LanguageMode) -> Vec<Embedding> {
        let recipes = read_recipes(recipe_file);
        let embedder: Embedder = EmbedderBuilder::with_fit_to_corpus(
            language_mode,
            &recipes
                .iter()
                .map(|Recipe { recipe, .. }| recipe.as_str())
                .collect::<Vec<_>>(),
        )
        .build();
        recipes
            .iter()
            .map(|Recipe { recipe, .. }| recipe.as_str())
            .map(|recipe| embedder.embed(recipe))
            .collect::<Vec<_>>()
    }

    // Distinct tokens in a document should all receive the same BM25 weight.
    #[test]
    fn it_weights_unique_words_equally() {
        let embedder = EmbedderBuilder::<u32>::with_avgdl(3.0).build();
        let embedding = embedder.embed("banana apple orange");
        assert!(embedding.len() == 3);
        assert!(embedding.windows(2).all(|e| e[0].value == e[1].value));
    }

    // Repeated tokens should be weighted differently from unique ones.
    #[test]
    fn it_weights_repeated_words_unequally() {
        let embedder = EmbedderBuilder::<u32>::with_avgdl(3.0)
            .tokenizer(DefaultTokenizer::new(Language::English))
            .build();
        let embedding = embedder.embed("space station station");
        assert!(
            *embedding
                == vec![
                    TokenEmbedding::new(866767497, 1.0),
                    TokenEmbedding::new(666609503, 1.375),
                    TokenEmbedding::new(666609503, 1.375)
                ]
        );
    }

    // A non-positive avgdl must fall back instead of producing zero/NaN weights.
    #[test]
    fn it_constrains_avgdl() {
        let embedder = EmbedderBuilder::<u32>::with_avgdl(0.0)
            .language_mode(Language::English)
            .build();
        let embedding = embedder.embed("space station");
        assert!(!embedding.is_empty());
        assert!(embedding.iter().all(|e| e.value > 0.0));
    }

    // Fitting to an empty corpus must still yield a usable embedder.
    #[test]
    fn it_handles_empty_corpus() {
        let embedder = EmbedderBuilder::<u32>::with_fit_to_corpus(Language::English, &[]).build();
        let embedding = embedder.embed("space station");
        assert!(!embedding.is_empty());
    }

    // Empty input text must produce an empty embedding.
    #[test]
    fn it_handles_empty_input() {
        let embedder = EmbedderBuilder::<u32>::with_avgdl(1.0).build();
        let embedding = embedder.embed("");
        assert!(embedding.is_empty());
    }

    // A user-defined TokenEmbedder type should be usable as the space `D`.
    #[test]
    fn it_allows_customisation_of_embedder() {
        #[derive(Eq, PartialEq, Hash, Clone, Debug)]
        struct MyType(u32);
        impl TokenEmbedder for MyType {
            type EmbeddingSpace = Self;
            fn embed(_: &str) -> Self {
                MyType(42)
            }
        }
        let embedder = EmbedderBuilder::<MyType>::with_avgdl(2.0).build();
        let embedding = embedder.embed("space station");
        assert_eq!(
            embedding.indices().cloned().collect::<Vec<_>>(),
            vec![MyType(42), MyType(42)]
        );
    }

    // Snapshot of embeddings for the English recipe corpus.
    #[test]
    fn it_matches_snapshot_en() {
        let embeddings = embed_recipes("recipes_en.csv", LanguageMode::Fixed(Language::English));
        insta::with_settings!({snapshot_path => "../snapshots"}, {
            assert_debug_snapshot!(embeddings);
        });
    }

    // Snapshot of embeddings for the German recipe corpus.
    #[test]
    fn it_matches_snapshot_de() {
        let embeddings = embed_recipes("recipes_de.csv", LanguageMode::Fixed(Language::German));
        insta::with_settings!({snapshot_path => "../snapshots"}, {
            assert_debug_snapshot!(embeddings);
        });
    }

    // A user-defined Tokenizer implementation should be usable as `T`.
    #[test]
    fn it_allows_customisation_of_tokenizer() {
        #[derive(Default)]
        struct MyTokenizer {}
        impl Tokenizer for MyTokenizer {
            fn tokenize(&self, input_text: &str) -> Vec<String> {
                input_text
                    .split("T")
                    .filter(|s| !s.is_empty())
                    .map(str::to_string)
                    .collect()
            }
        }
        let embedder = EmbedderBuilder::<u32, MyTokenizer>::with_avgdl(1.0).build();
        let embedding = embedder.embed("CupTofTtea");
        assert_eq!(
            embedding.indices().cloned().collect::<Vec<_>>(),
            vec![3568447556, 3221979461, 415655421]
        );
    }
}