// mnem_bench/embed.rs
//! Embedder used by [`crate::adapters::MnemAdapter`].
//!
//! Two flavours, selected via [`BenchEmbedder`]:
//!
//! 1. **`BagOfTokens`** - the original 0.1.0 hashed-bag-of-tokens
//!    embedder. Network-free, deterministic, ~0.2 recall@5 on
//!    LongMemEval (toy ceiling). Always compiled.
//! 2. **`OnnxMiniLm`** - real `sentence-transformers/all-MiniLM-L6-v2`
//!    via `mnem-embed-providers` with the `onnx-bundled` feature.
//!    384-dim, byte-for-byte parity with ChromaDB's
//!    `DefaultEmbeddingFunction`. Default for the smoke gate; this
//!    is the embedder the headline numbers are reported against.
//!    Compiled in when `mnem-bench` is built with the (default-on)
//!    `onnx-minilm` feature.
//!
//! The two flavours share one method surface (`model()` / `dim()` /
//! `embed_text()`) so the adapter and the scorers stay flavour-blind.
//!
//! # Toy embedder rationale (kept)
//!
//! The bag-of-tokens variant stays compiled in for
//! `--no-default-features` builds, embedded targets, and any
//! environment where `ort/download-binaries` is undesirable. It uses
//! double-hashed token buckets (Weinberger et al. 2009) and
//! L2-normalises so cosine similarity collapses to dot product.
26
/// Default embedding dimension. 384 matches MiniLM-L6-v2, so the
/// toy embedder ships the same vector length as the real ONNX one -
/// swapping flavours never invalidates a vector index built with
/// the other.
pub const DEFAULT_DIM: u32 = 384;
31
/// Deterministic hashed bag-of-tokens embedder.
///
/// `embed_text(text)` lowercases, tokenises on non-alphanumeric
/// boundaries (Unicode-aware), hashes each token to two bucket
/// positions (FNV-1a-style) and adds 1.0 to each. The output vector
/// is L2-normalised so dense cosine similarity ranks documents by
/// (count-weighted) shared-token overlap.
#[derive(Clone, Debug)]
pub struct ToyEmbedder {
    /// Model identifier, e.g. `"mnem-bench:bag-of-tokens-384"`;
    /// derived from the (clamped) dimension in `new`.
    model: String,
    /// Output vector length; clamped to >= 8 in `new`.
    dim: u32,
}
44
45impl ToyEmbedder {
46 /// Construct a new embedder with the given dimension. `dim`
47 /// must be > 0; values < 32 lead to heavy hash collisions on
48 /// natural text.
49 #[must_use]
50 pub fn new(dim: u32) -> Self {
51 let d = dim.max(8);
52 Self {
53 model: format!("mnem-bench:bag-of-tokens-{d}"),
54 dim: d,
55 }
56 }
57
58 /// Model identifier (passed to mnem's vector lane so embeddings
59 /// match query vectors at retrieve time).
60 #[must_use]
61 pub fn model(&self) -> &str {
62 &self.model
63 }
64
65 /// Embedding dimension.
66 #[must_use]
67 pub const fn dim(&self) -> u32 {
68 self.dim
69 }
70
71 /// Embed a string into a unit-norm vector.
72 #[must_use]
73 pub fn embed_text(&self, text: &str) -> Vec<f32> {
74 let dim = self.dim as usize;
75 let mut v = vec![0f32; dim];
76
77 for tok in tokenise(text) {
78 // Two buckets per token. Mixing two independent hashes
79 // dampens the worst-case collision distortion of the
80 // hashing trick (Weinberger et al. 2009).
81 let h1 = fnv1a(tok.as_bytes()) as usize;
82 let h2 = fnv1a_seeded(tok.as_bytes(), 0x9E37_79B9_7F4A_7C15) as usize;
83 v[h1 % dim] += 1.0;
84 v[h2 % dim] += 1.0;
85 }
86
87 // L2 normalise so cosine == dot.
88 let mut s = 0f64;
89 for x in &v {
90 s += f64::from(*x) * f64::from(*x);
91 }
92 let norm = s.sqrt() as f32;
93 if norm > 0.0 {
94 for x in &mut v {
95 *x /= norm;
96 }
97 }
98 v
99 }
100}
101
/// Split on non-alphanumeric characters (Unicode-aware), lower-case
/// each piece, drop pieces shorter than 2 bytes (mostly punctuation
/// noise), and cap every token at 64 bytes - cutting back to a char
/// boundary - to bound worst-case hashing work.
fn tokenise(text: &str) -> impl Iterator<Item = String> + '_ {
    text.split(|c: char| !c.is_alphanumeric())
        .filter(|raw| raw.len() >= 2)
        .map(|raw| {
            let mut lower = raw.to_lowercase();
            if lower.len() > 64 {
                // Walk back from byte 64 to the nearest char
                // boundary so truncation never splits a code point.
                let mut cut = 64;
                while cut > 0 && !lower.is_char_boundary(cut) {
                    cut -= 1;
                }
                lower.truncate(cut);
            }
            lower
        })
}
121
/// FNV-1a 64-bit hash over a byte slice (offset basis
/// `0xcbf29ce484222325`, prime `0x100000001b3`).
fn fnv1a(bytes: &[u8]) -> u64 {
    bytes.iter().fold(0xcbf2_9ce4_8422_2325_u64, |acc, &b| {
        (acc ^ u64::from(b)).wrapping_mul(0x100_0000_01b3)
    })
}
131
/// FNV-1a 64-bit with a custom seed XOR-mixed into the offset
/// basis. Used to derive a second, independent hash for
/// double-hashing; `fnv1a_seeded(b, 0)` equals plain FNV-1a.
fn fnv1a_seeded(bytes: &[u8], seed: u64) -> u64 {
    let basis = 0xcbf2_9ce4_8422_2325_u64 ^ seed;
    bytes.iter().fold(basis, |acc, &b| {
        (acc ^ u64::from(b)).wrapping_mul(0x100_0000_01b3)
    })
}
142
// ============================================================
// Unified embedder used by `MnemAdapter`
// ============================================================
146
/// Unified embedder used by [`crate::adapters::MnemAdapter`].
///
/// The two flavours share a method surface (`model()`, `dim()`,
/// `embed_text()`) so the adapter does not branch on the variant on
/// every call. Construction is the only code path that picks one.
///
/// # Variants
///
/// - [`BenchEmbedder::BagOfTokens`] - the always-compiled toy
///   embedder. Selected by [`crate::EmbedderChoice::BagOfTokens`].
/// - [`BenchEmbedder::OnnxMiniLm`] - real MiniLM-L6-v2 via
///   `mnem-embed-providers` (gated on the `onnx-minilm` feature).
///   Selected by [`crate::EmbedderChoice::OnnxMiniLm`].
pub enum BenchEmbedder {
    /// Toy hashed bag-of-tokens. Network-free, ~0.2 recall@5 on
    /// LongMemEval; ships as the offline / WASM-clean fallback.
    BagOfTokens(ToyEmbedder),
    /// Real `all-MiniLM-L6-v2` via `mnem-embed-providers`. The
    /// concrete type is hidden behind a `Box<dyn Embedder>` so this
    /// crate stays agnostic to the underlying ORT session lifetime.
    /// `model_id` and `dim` are cached at construction so the hot
    /// path (per-doc ingest, per-query retrieve) avoids vtable
    /// round-trips.
    #[cfg(feature = "onnx-minilm")]
    OnnxMiniLm {
        /// Boxed provider implementing `mnem_embed_providers::Embedder`.
        inner: Box<dyn mnem_embed_providers::Embedder>,
        /// Cached fully-qualified model id (e.g.
        /// `"onnx:all-MiniLM-L6-v2"`). Identifies the vector lane
        /// keyed on the mnem retriever.
        model_id: String,
        /// Cached output dimension (384 for MiniLM-L6-v2).
        dim: u32,
    },
}
181
182impl std::fmt::Debug for BenchEmbedder {
183 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
184 match self {
185 Self::BagOfTokens(e) => f.debug_tuple("BagOfTokens").field(e).finish(),
186 #[cfg(feature = "onnx-minilm")]
187 Self::OnnxMiniLm { model_id, dim, .. } => f
188 .debug_struct("OnnxMiniLm")
189 .field("model_id", model_id)
190 .field("dim", dim)
191 .finish(),
192 }
193 }
194}
195
196impl BenchEmbedder {
197 /// Construct the toy hashed bag-of-tokens embedder of dimension
198 /// `dim`. Matches 0.1.0 behaviour.
199 #[must_use]
200 pub fn bag_of_tokens(dim: u32) -> Self {
201 Self::BagOfTokens(ToyEmbedder::new(dim))
202 }
203
204 /// Construct the real ONNX MiniLM-L6-v2 embedder via
205 /// `mnem-embed-providers` (`onnx-bundled` flavour). Lazy-
206 /// downloads the model on first call (ORT + tokenizer + weights
207 /// fetched into the HuggingFace cache; ~90MB).
208 ///
209 /// # Errors
210 ///
211 /// Surfaces tokenizer / model-load / ORT-session failures from
212 /// `mnem-embed-providers` verbatim as a `Box<dyn Error>`.
213 #[cfg(feature = "onnx-minilm")]
214 pub fn onnx_minilm() -> Result<Self, Box<dyn std::error::Error>> {
215 use mnem_embed_providers::{OnnxConfig, ProviderConfig, open};
216 let cfg = ProviderConfig::Onnx(OnnxConfig {
217 // Matches the bench-Python adapter (LongMemEval session)
218 // and the `mnem-cli --features bundled-embedder` default.
219 model: "all-MiniLM-L6-v2".to_string(),
220 // None defers to the model's `default_max_length` (256
221 // for MiniLM-L6). LongMemEval sessions are typically
222 // <512 tokens, so the default is fine.
223 max_length: None,
224 });
225 let inner = open(&cfg).map_err(|e| Box::new(e) as Box<dyn std::error::Error>)?;
226 let model_id = inner.model().to_string();
227 let dim = inner.dim();
228 Ok(Self::OnnxMiniLm {
229 inner,
230 model_id,
231 dim,
232 })
233 }
234
235 /// Fully-qualified model identifier. Stamped on every
236 /// `Embedding` and used as the key the retriever's vector lane
237 /// resolves on, so two embedders with the same `model()` MUST
238 /// produce vectors in the same semantic space.
239 #[must_use]
240 pub fn model(&self) -> &str {
241 match self {
242 Self::BagOfTokens(e) => e.model(),
243 #[cfg(feature = "onnx-minilm")]
244 Self::OnnxMiniLm { model_id, .. } => model_id.as_str(),
245 }
246 }
247
248 /// Output vector dimension.
249 #[must_use]
250 pub fn dim(&self) -> u32 {
251 match self {
252 Self::BagOfTokens(e) => e.dim(),
253 #[cfg(feature = "onnx-minilm")]
254 Self::OnnxMiniLm { dim, .. } => *dim,
255 }
256 }
257
258 /// Embed a single string. Errors from the ONNX path (tokenizer,
259 /// session.run) are surfaced as `Box<dyn Error>`. The toy path
260 /// is infallible, so `Result` here is a small ergonomic tax we
261 /// pay so the call site stays variant-blind.
262 ///
263 /// # Errors
264 ///
265 /// Returns the underlying provider error verbatim for the ONNX
266 /// flavour. The bag-of-tokens flavour cannot fail.
267 pub fn embed_text(&self, text: &str) -> Result<Vec<f32>, Box<dyn std::error::Error>> {
268 match self {
269 Self::BagOfTokens(e) => Ok(e.embed_text(text)),
270 #[cfg(feature = "onnx-minilm")]
271 Self::OnnxMiniLm { inner, .. } => inner
272 .embed(text)
273 .map_err(|e| Box::new(e) as Box<dyn std::error::Error>),
274 }
275 }
276}
277
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn embed_is_deterministic() {
        let e = ToyEmbedder::new(64);
        assert_eq!(e.embed_text("hello world"), e.embed_text("hello world"));
    }

    #[test]
    fn empty_yields_zero_vector() {
        let e = ToyEmbedder::new(32);
        let v = e.embed_text("");
        assert_eq!(v.len(), 32);
        assert!(v.iter().all(|x| *x == 0.0));
    }

    // New: pin the `new` contract (dim clamped to >= 8, model id
    // derived from the clamped dimension).
    #[test]
    fn small_dims_are_clamped_to_eight() {
        let e = ToyEmbedder::new(1);
        assert_eq!(e.dim(), 8);
        assert_eq!(e.model(), "mnem-bench:bag-of-tokens-8");
    }

    // New: the similarity test below compares raw dot products,
    // which is only meaningful if outputs are unit-norm.
    #[test]
    fn nonempty_output_is_unit_norm() {
        let e = ToyEmbedder::new(128);
        let v = e.embed_text("alpha beta gamma");
        let norm: f64 = v.iter().map(|x| f64::from(*x) * f64::from(*x)).sum::<f64>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5, "norm={norm}");
    }

    #[test]
    fn related_text_similarity_is_high() {
        let e = ToyEmbedder::new(384);
        let a = e.embed_text("alice climbs in berlin");
        let b = e.embed_text("alice goes climbing in berlin every weekend");
        let c = e.embed_text("the eiffel tower is in paris");
        let dot_ab: f32 = a.iter().zip(&b).map(|(x, y)| x * y).sum();
        let dot_ac: f32 = a.iter().zip(&c).map(|(x, y)| x * y).sum();
        // Shared-token overlap should beat the unrelated pair.
        assert!(dot_ab > dot_ac, "ab={dot_ab} should beat ac={dot_ac}");
    }
}
307}