/// Default embedding width, in vector components.
// NOTE(review): nothing in the visible part of this file reads this constant;
// presumably callers pass it to `ToyEmbedder::new` / `BenchEmbedder::bag_of_tokens` — confirm.
pub const DEFAULT_DIM: u32 = 384;
/// Hash-based bag-of-tokens text embedder.
///
/// Holds only configuration: a model identifier string and the width of the
/// vectors it produces. Instances are built through `ToyEmbedder::new`.
#[derive(Clone, Debug)]
pub struct ToyEmbedder {
/// Model identifier, formatted by `new` as `mnem-bench:bag-of-tokens-{dim}`.
model: String,
/// Number of components in every produced vector (`new` raises it to at least 8).
dim: u32,
}
impl ToyEmbedder {
    /// Creates an embedder that produces vectors with `dim` components.
    ///
    /// Widths below 8 are raised to 8. The model identifier is derived from
    /// the effective width as `mnem-bench:bag-of-tokens-{dim}`.
    #[must_use]
    pub fn new(dim: u32) -> Self {
        let d = dim.max(8);
        Self {
            model: format!("mnem-bench:bag-of-tokens-{d}"),
            dim: d,
        }
    }

    /// Returns the model identifier string built by [`ToyEmbedder::new`].
    #[must_use]
    pub fn model(&self) -> &str {
        &self.model
    }

    /// Returns the number of components in every produced vector.
    #[must_use]
    pub const fn dim(&self) -> u32 {
        self.dim
    }

    /// Embeds `text` as a vector of `self.dim()` components.
    ///
    /// Every token yielded by `tokenise` increments two buckets, chosen by the
    /// plain and the seeded FNV-1a hash of the token bytes (both hashes may
    /// land in the same bucket). The counts are then scaled to unit Euclidean
    /// length. Text without any token yields an all-zero vector.
    #[must_use]
    pub fn embed_text(&self, text: &str) -> Vec<f32> {
        let dim = self.dim as usize;
        // `usize` is never wider than 64 bits, so this widening is lossless.
        let buckets = dim as u64;
        let mut v = vec![0f32; dim];
        for tok in tokenise(text) {
            let bytes = tok.as_bytes();
            // Reduce in u64 *before* narrowing: casting the raw 64-bit hash to
            // `usize` first would drop its upper half on 32-bit targets and
            // select different buckets there than on 64-bit targets.
            let slot_a = (fnv1a(bytes) % buckets) as usize;
            let slot_b = (fnv1a_seeded(bytes, 0x9E37_79B9_7F4A_7C15) % buckets) as usize;
            v[slot_a] += 1.0;
            v[slot_b] += 1.0;
        }
        // Accumulate the squared length in f64, then narrow once.
        let squared = v
            .iter()
            .fold(0f64, |acc, x| acc + f64::from(*x) * f64::from(*x));
        let norm = squared.sqrt() as f32;
        // A zero norm means no bucket was touched; leave the vector as zeros
        // rather than dividing by zero.
        if norm > 0.0 {
            for x in &mut v {
                *x /= norm;
            }
        }
        v
    }
}
/// Splits `text` into lowercase tokens.
///
/// Pieces are the maximal runs of alphanumeric characters. A piece is kept
/// only if it is at least 2 bytes long (measured before lowercasing). Each
/// kept piece is lowercased and, if the lowercase form exceeds 64 bytes, cut
/// back to the nearest character boundary at or below byte 64.
fn tokenise(text: &str) -> impl Iterator<Item = String> + '_ {
    const MAX_TOKEN_BYTES: usize = 64;
    text.split(|ch: char| !ch.is_alphanumeric())
        .filter_map(|piece| {
            if piece.len() < 2 {
                return None;
            }
            let mut token = piece.to_lowercase();
            if token.len() > MAX_TOKEN_BYTES {
                // Offset 0 is always a boundary, so the search cannot fail.
                let cut = (0..=MAX_TOKEN_BYTES)
                    .rev()
                    .find(|&idx| token.is_char_boundary(idx))
                    .unwrap_or(0);
                token.truncate(cut);
            }
            Some(token)
        })
}
/// Computes the 64-bit FNV-1a hash of `bytes`.
///
/// Starting from the offset basis, each byte is XOR-ed into the state, which
/// is then multiplied (wrapping) by the FNV prime.
fn fnv1a(bytes: &[u8]) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const PRIME: u64 = 0x100_0000_01b3;
    bytes.iter().fold(OFFSET_BASIS, |state, &byte| {
        (state ^ u64::from(byte)).wrapping_mul(PRIME)
    })
}
/// Computes a 64-bit FNV-1a style hash of `bytes` whose initial state is the
/// offset basis XOR-ed with `seed`.
///
/// With `seed == 0` the start state is the unmodified offset basis.
fn fnv1a_seeded(bytes: &[u8], seed: u64) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const PRIME: u64 = 0x100_0000_01b3;
    let start = OFFSET_BASIS ^ seed;
    bytes.iter().fold(start, |state, &byte| {
        (state ^ u64::from(byte)).wrapping_mul(PRIME)
    })
}
/// Embedding backend selector: either the built-in hash-based embedder or,
/// when the `onnx-minilm` feature is enabled, a provider from
/// `mnem_embed_providers`.
pub enum BenchEmbedder {
/// Hash-based backend wrapping a [`ToyEmbedder`].
BagOfTokens(ToyEmbedder),
/// Provider-backed variant; only compiled with the `onnx-minilm` feature.
#[cfg(feature = "onnx-minilm")]
OnnxMiniLm {
/// Boxed provider that performs the embedding.
inner: Box<dyn mnem_embed_providers::Embedder>,
/// Model identifier, read from the provider when it was opened.
model_id: String,
/// Vector width, read from the provider when it was opened.
dim: u32,
},
}
impl std::fmt::Debug for BenchEmbedder {
    /// Formats the variant name together with its printable fields.
    ///
    /// `BagOfTokens` is rendered as a tuple around the wrapped embedder.
    /// `OnnxMiniLm` is rendered as a struct with `model_id` and `dim`; the
    /// boxed provider is left out of the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::BagOfTokens(toy) => {
                let mut out = f.debug_tuple("BagOfTokens");
                out.field(toy);
                out.finish()
            }
            #[cfg(feature = "onnx-minilm")]
            Self::OnnxMiniLm {
                model_id: id,
                dim: width,
                ..
            } => {
                let mut out = f.debug_struct("OnnxMiniLm");
                out.field("model_id", id);
                out.field("dim", width);
                out.finish()
            }
        }
    }
}
impl BenchEmbedder {
    /// Builds the hash-based backend with the requested vector width.
    ///
    /// The width is forwarded unchanged to [`ToyEmbedder::new`].
    #[must_use]
    pub fn bag_of_tokens(dim: u32) -> Self {
        let toy = ToyEmbedder::new(dim);
        Self::BagOfTokens(toy)
    }

    /// Opens the ONNX provider for the `all-MiniLM-L6-v2` model and records
    /// the model identifier and vector width it reports.
    ///
    /// # Errors
    ///
    /// Returns the provider's error, boxed, when opening fails.
    #[cfg(feature = "onnx-minilm")]
    pub fn onnx_minilm() -> Result<Self, Box<dyn std::error::Error>> {
        use mnem_embed_providers::{OnnxConfig, ProviderConfig, open};
        let cfg = ProviderConfig::Onnx(OnnxConfig {
            model: "all-MiniLM-L6-v2".to_string(),
            max_length: None,
        });
        let provider = match open(&cfg) {
            Ok(opened) => opened,
            Err(err) => return Err(Box::new(err)),
        };
        // Query the provider before moving it into the variant.
        let model_id = provider.model().to_string();
        let dim = provider.dim();
        Ok(Self::OnnxMiniLm {
            inner: provider,
            model_id,
            dim,
        })
    }

    /// Returns the model identifier of the active backend.
    #[must_use]
    pub fn model(&self) -> &str {
        match self {
            Self::BagOfTokens(toy) => toy.model(),
            #[cfg(feature = "onnx-minilm")]
            Self::OnnxMiniLm { model_id: id, .. } => id.as_str(),
        }
    }

    /// Returns the vector width of the active backend.
    #[must_use]
    pub fn dim(&self) -> u32 {
        match self {
            Self::BagOfTokens(toy) => toy.dim(),
            #[cfg(feature = "onnx-minilm")]
            Self::OnnxMiniLm { dim: width, .. } => *width,
        }
    }

    /// Embeds `text` with the active backend.
    ///
    /// # Errors
    ///
    /// The hash-based backend never fails. The provider-backed variant
    /// returns the provider's error, boxed.
    pub fn embed_text(&self, text: &str) -> Result<Vec<f32>, Box<dyn std::error::Error>> {
        match self {
            Self::BagOfTokens(toy) => {
                let vector = toy.embed_text(text);
                Ok(vector)
            }
            #[cfg(feature = "onnx-minilm")]
            Self::OnnxMiniLm { inner, .. } => match inner.embed(text) {
                Ok(vector) => Ok(vector),
                Err(err) => Err(Box::new(err)),
            },
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Dot product of two vectors, pairing components in order.
    fn dot(lhs: &[f32], rhs: &[f32]) -> f32 {
        lhs.iter().zip(rhs).map(|(x, y)| x * y).sum()
    }

    /// Embedding the same text twice must give the same vector.
    #[test]
    fn embed_is_deterministic() {
        let embedder = ToyEmbedder::new(64);
        let first = embedder.embed_text("hello world");
        let second = embedder.embed_text("hello world");
        assert_eq!(first, second);
    }

    /// Text without tokens produces a full-width vector of zeros.
    #[test]
    fn empty_yields_zero_vector() {
        let embedder = ToyEmbedder::new(32);
        let vector = embedder.embed_text("");
        assert_eq!(vector.len(), 32);
        assert!(!vector.iter().any(|component| *component != 0.0));
    }

    /// Texts sharing tokens score a higher dot product than unrelated texts.
    #[test]
    fn related_text_similarity_is_high() {
        let embedder = ToyEmbedder::new(384);
        let a = embedder.embed_text("alice climbs in berlin");
        let b = embedder.embed_text("alice goes climbing in berlin every weekend");
        let c = embedder.embed_text("the eiffel tower is in paris");
        let dot_ab = dot(&a, &b);
        let dot_ac = dot(&a, &c);
        assert!(dot_ab > dot_ac, "ab={dot_ab} should beat ac={dot_ac}");
    }
}