// mnem_bench/embed.rs
//! Embedder used by [`crate::adapters::MnemAdapter`].
//!
//! Two flavours, selected via [`BenchEmbedder`]:
//!
//! 1. **`BagOfTokens`** - the original 0.1.0 hashed-bag-of-tokens
//!    embedder. Network-free, deterministic, ~0.2 recall@5 on
//!    LongMemEval (toy ceiling). Always compiled.
//! 2. **`OnnxMiniLm`** - real `sentence-transformers/all-MiniLM-L6-v2`
//!    via `mnem-embed-providers` with the `onnx-bundled` feature.
//!    384-dim, byte-for-byte parity with ChromaDB's
//!    `DefaultEmbeddingFunction`. Default for the smoke gate; this
//!    is the embedder the headline numbers are reported against.
//!    Compiled in when `mnem-bench` is built with the (default-on)
//!    `onnx-minilm` feature.
//!
//! The two flavours share one method surface (`model() / dim() /
//! embed_text()`) so the adapter and the scorers stay flavour-blind.
//!
//! # Toy embedder rationale (kept)
//!
//! The bag-of-tokens variant stays compiled in for
//! `--no-default-features` builds, embedded targets, and any
//! environment where `ort/download-binaries` is undesirable. It uses
//! double-hashed token buckets (Weinberger et al. 2009) and
//! L2-normalises so cosine similarity collapses to dot product.

/// Default embedding dimension. 384 matches MiniLM-L6-v2 so the
/// toy embedder ships byte-compatible vector lengths with the real
/// ONNX one - swapping flavours never invalidates a vector index.
pub const DEFAULT_DIM: u32 = 384;
31
/// Deterministic hashed bag-of-tokens embedder.
///
/// `embed_text(text)` lowercases, tokenises on non-alphanumeric
/// boundaries, hashes each token to two bucket positions
/// (FNV-1a-style) and adds 1.0 to each. The output vector is L2-
/// normalised so dense cosine similarity ranks documents by
/// (count-weighted) shared-token overlap.
#[derive(Clone, Debug)]
pub struct ToyEmbedder {
    // Model identifier, e.g. "mnem-bench:bag-of-tokens-384";
    // derived from the (clamped) dimension in `new`.
    model: String,
    // Output vector dimension / hash-bucket count (>= 8; see `new`).
    dim: u32,
}
44
45impl ToyEmbedder {
46    /// Construct a new embedder with the given dimension. `dim`
47    /// must be > 0; values < 32 lead to heavy hash collisions on
48    /// natural text.
49    #[must_use]
50    pub fn new(dim: u32) -> Self {
51        let d = dim.max(8);
52        Self {
53            model: format!("mnem-bench:bag-of-tokens-{d}"),
54            dim: d,
55        }
56    }
57
58    /// Model identifier (passed to mnem's vector lane so embeddings
59    /// match query vectors at retrieve time).
60    #[must_use]
61    pub fn model(&self) -> &str {
62        &self.model
63    }
64
65    /// Embedding dimension.
66    #[must_use]
67    pub const fn dim(&self) -> u32 {
68        self.dim
69    }
70
71    /// Embed a string into a unit-norm vector.
72    #[must_use]
73    pub fn embed_text(&self, text: &str) -> Vec<f32> {
74        let dim = self.dim as usize;
75        let mut v = vec![0f32; dim];
76
77        for tok in tokenise(text) {
78            // Two buckets per token. Mixing two independent hashes
79            // dampens the worst-case collision distortion of the
80            // hashing trick (Weinberger et al. 2009).
81            let h1 = fnv1a(tok.as_bytes()) as usize;
82            let h2 = fnv1a_seeded(tok.as_bytes(), 0x9E37_79B9_7F4A_7C15) as usize;
83            v[h1 % dim] += 1.0;
84            v[h2 % dim] += 1.0;
85        }
86
87        // L2 normalise so cosine == dot.
88        let mut s = 0f64;
89        for x in &v {
90            s += f64::from(*x) * f64::from(*x);
91        }
92        let norm = s.sqrt() as f32;
93        if norm > 0.0 {
94            for x in &mut v {
95                *x /= norm;
96            }
97        }
98        v
99    }
100}
101
/// Lower-case and tokenise on non-alphanumeric (`char::is_alphanumeric`)
/// boundaries.
///
/// Drops tokens shorter than 2 *bytes* (mostly punctuation noise and
/// single ASCII characters) and truncates each token to at most 64
/// bytes — snapped back to a `char` boundary — to bound worst-case
/// hashing cost. Note both limits are byte counts, not character
/// counts; a single 2-byte character survives the length filter.
fn tokenise(text: &str) -> impl Iterator<Item = String> + '_ {
    text.split(|c: char| !c.is_alphanumeric())
        .filter(|t| t.len() >= 2)
        .map(|t| {
            let mut lower = t.to_lowercase();
            if lower.len() > 64 {
                // Walk back to the nearest char boundary so the
                // truncation never splits a multi-byte character.
                // `is_char_boundary(0)` is always true, so this
                // terminates without an explicit `end > 0` guard.
                let mut end = 64;
                while !lower.is_char_boundary(end) {
                    end -= 1;
                }
                // Truncate in place instead of allocating a fresh
                // String via `lower[..end].to_string()`.
                lower.truncate(end);
            }
            lower
        })
}
121
/// FNV-1a 64-bit hash over a byte slice.
///
/// Uses the standard 64-bit offset basis and prime; per byte the
/// input is XORed in first, then multiplied (the "1a" ordering).
fn fnv1a(bytes: &[u8]) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const PRIME: u64 = 0x100_0000_01b3;
    bytes
        .iter()
        .fold(OFFSET_BASIS, |h, &b| (h ^ u64::from(b)).wrapping_mul(PRIME))
}
131
/// FNV-1a with a custom 64-bit seed mixed (XORed) into the offset
/// basis. Used to derive a second, independent hash for
/// double-hashing; a seed of 0 reproduces plain [`fnv1a`].
fn fnv1a_seeded(bytes: &[u8], seed: u64) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const PRIME: u64 = 0x100_0000_01b3;
    bytes
        .iter()
        .fold(OFFSET_BASIS ^ seed, |h, &b| (h ^ u64::from(b)).wrapping_mul(PRIME))
}
142
// ============================================================
// Unified embedder used by `MnemAdapter`
// ============================================================

/// Unified embedder used by [`crate::adapters::MnemAdapter`].
///
/// The two flavours share a method surface (`model()`, `dim()`,
/// `embed_text()`) so the adapter does not branch on the variant on
/// every call. Construction is the only code path that picks one.
///
/// # Variants
///
/// - [`BenchEmbedder::BagOfTokens`] - the always-compiled toy
///   embedder. Selected by [`crate::EmbedderChoice::BagOfTokens`].
/// - [`BenchEmbedder::OnnxMiniLm`] - real MiniLM-L6-v2 via
///   `mnem-embed-providers` (gated on the `onnx-minilm` feature).
///   Selected by [`crate::EmbedderChoice::OnnxMiniLm`].
pub enum BenchEmbedder {
    /// Toy hashed bag-of-tokens. Network-free, ~0.2 recall@5 on
    /// LongMemEval; ships as the offline / WASM-clean fallback.
    BagOfTokens(ToyEmbedder),
    /// Real `all-MiniLM-L6-v2` via `mnem-embed-providers`. The
    /// concrete type is hidden behind a `Box<dyn Embedder>` so this
    /// crate stays agnostic to the underlying ORT session lifetime.
    /// `model_id` and `dim` are cached so the hot path (per-doc
    /// ingest, per-query retrieve) avoids vtable round-trips.
    #[cfg(feature = "onnx-minilm")]
    OnnxMiniLm {
        /// Boxed provider implementing `mnem_embed_providers::Embedder`.
        inner: Box<dyn mnem_embed_providers::Embedder>,
        /// Cached fully-qualified model id (e.g.
        /// `"onnx:all-MiniLM-L6-v2"`). Identifies the vector lane
        /// keyed on the mnem retriever.
        model_id: String,
        /// Cached output dimension (384 for MiniLM-L6-v2).
        dim: u32,
    },
}
181
182impl std::fmt::Debug for BenchEmbedder {
183    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
184        match self {
185            Self::BagOfTokens(e) => f.debug_tuple("BagOfTokens").field(e).finish(),
186            #[cfg(feature = "onnx-minilm")]
187            Self::OnnxMiniLm { model_id, dim, .. } => f
188                .debug_struct("OnnxMiniLm")
189                .field("model_id", model_id)
190                .field("dim", dim)
191                .finish(),
192        }
193    }
194}
195
196impl BenchEmbedder {
197    /// Construct the toy hashed bag-of-tokens embedder of dimension
198    /// `dim`. Matches 0.1.0 behaviour.
199    #[must_use]
200    pub fn bag_of_tokens(dim: u32) -> Self {
201        Self::BagOfTokens(ToyEmbedder::new(dim))
202    }
203
204    /// Construct the real ONNX MiniLM-L6-v2 embedder via
205    /// `mnem-embed-providers` (`onnx-bundled` flavour). Lazy-
206    /// downloads the model on first call (ORT + tokenizer + weights
207    /// fetched into the HuggingFace cache; ~90MB).
208    ///
209    /// # Errors
210    ///
211    /// Surfaces tokenizer / model-load / ORT-session failures from
212    /// `mnem-embed-providers` verbatim as a `Box<dyn Error>`.
213    #[cfg(feature = "onnx-minilm")]
214    pub fn onnx_minilm() -> Result<Self, Box<dyn std::error::Error>> {
215        use mnem_embed_providers::{OnnxConfig, ProviderConfig, open};
216        let cfg = ProviderConfig::Onnx(OnnxConfig {
217            // Matches the bench-Python adapter (LongMemEval session)
218            // and the `mnem-cli --features bundled-embedder` default.
219            model: "all-MiniLM-L6-v2".to_string(),
220            // None defers to the model's `default_max_length` (256
221            // for MiniLM-L6). LongMemEval sessions are typically
222            // <512 tokens, so the default is fine.
223            max_length: None,
224        });
225        let inner = open(&cfg).map_err(|e| Box::new(e) as Box<dyn std::error::Error>)?;
226        let model_id = inner.model().to_string();
227        let dim = inner.dim();
228        Ok(Self::OnnxMiniLm {
229            inner,
230            model_id,
231            dim,
232        })
233    }
234
235    /// Fully-qualified model identifier. Stamped on every
236    /// `Embedding` and used as the key the retriever's vector lane
237    /// resolves on, so two embedders with the same `model()` MUST
238    /// produce vectors in the same semantic space.
239    #[must_use]
240    pub fn model(&self) -> &str {
241        match self {
242            Self::BagOfTokens(e) => e.model(),
243            #[cfg(feature = "onnx-minilm")]
244            Self::OnnxMiniLm { model_id, .. } => model_id.as_str(),
245        }
246    }
247
248    /// Output vector dimension.
249    #[must_use]
250    pub fn dim(&self) -> u32 {
251        match self {
252            Self::BagOfTokens(e) => e.dim(),
253            #[cfg(feature = "onnx-minilm")]
254            Self::OnnxMiniLm { dim, .. } => *dim,
255        }
256    }
257
258    /// Embed a single string. Errors from the ONNX path (tokenizer,
259    /// session.run) are surfaced as `Box<dyn Error>`. The toy path
260    /// is infallible, so `Result` here is a small ergonomic tax we
261    /// pay so the call site stays variant-blind.
262    ///
263    /// # Errors
264    ///
265    /// Returns the underlying provider error verbatim for the ONNX
266    /// flavour. The bag-of-tokens flavour cannot fail.
267    pub fn embed_text(&self, text: &str) -> Result<Vec<f32>, Box<dyn std::error::Error>> {
268        match self {
269            Self::BagOfTokens(e) => Ok(e.embed_text(text)),
270            #[cfg(feature = "onnx-minilm")]
271            Self::OnnxMiniLm { inner, .. } => inner
272                .embed(text)
273                .map_err(|e| Box::new(e) as Box<dyn std::error::Error>),
274        }
275    }
276}
277
#[cfg(test)]
mod tests {
    use super::*;

    // Same input must always hash to the same vector.
    #[test]
    fn embed_is_deterministic() {
        let e = ToyEmbedder::new(64);
        assert_eq!(e.embed_text("hello world"), e.embed_text("hello world"));
    }

    // `new` clamps the dimension to a floor of 8 buckets.
    #[test]
    fn tiny_dim_is_clamped_to_eight() {
        assert_eq!(ToyEmbedder::new(0).dim(), 8);
        assert_eq!(ToyEmbedder::new(3).dim(), 8);
        assert_eq!(ToyEmbedder::new(384).dim(), 384);
    }

    // Token-free input yields the (un-normalised) zero vector.
    #[test]
    fn empty_yields_zero_vector() {
        let e = ToyEmbedder::new(32);
        let v = e.embed_text("");
        assert_eq!(v.len(), 32);
        assert!(v.iter().all(|x| *x == 0.0));
    }

    // Non-empty embeddings are L2-normalised, as documented.
    #[test]
    fn nonempty_embedding_is_unit_norm() {
        let e = ToyEmbedder::new(128);
        let v = e.embed_text("the quick brown fox");
        let sq: f32 = v.iter().map(|x| x * x).sum();
        assert!((sq - 1.0).abs() < 1e-5, "squared norm was {sq}");
    }

    #[test]
    fn related_text_similarity_is_high() {
        let e = ToyEmbedder::new(384);
        let a = e.embed_text("alice climbs in berlin");
        let b = e.embed_text("alice goes climbing in berlin every weekend");
        let c = e.embed_text("the eiffel tower is in paris");
        let dot_ab: f32 = a.iter().zip(&b).map(|(x, y)| x * y).sum();
        let dot_ac: f32 = a.iter().zip(&c).map(|(x, y)| x * y).sum();
        // Shared-token overlap should beat the unrelated pair.
        assert!(dot_ab > dot_ac, "ab={dot_ab} should beat ac={dot_ac}");
    }
}