Skip to main content

inference/backend/
static_backend.rs

1//! Model2Vec static embedding backend.
2//!
3//! Embeds text via a pre-distilled vocabulary matrix — no neural network
4//! forward pass at all.  Each token maps directly to a pre-computed vector;
5//! the final embedding is the mean pool of the token vectors followed by L2
6//! normalisation.
7//!
8//! **Performance**: <0.1 ms per text on CPU, >50,000 embeddings/second.
9//! **Memory**: ~30 MB vocab matrix vs. ~520 MB ONNX session pool.
10//!
11//! # Trade-off
12//!
13//! Quality is ~8–15% lower on MTEB relative to the full BGE-Large transformer.
14//! This backend is intended for the *write path* (fast ingest) only.  The
15//! [`TieredEngine`](crate::tiered::TieredEngine) uses the full transformer for
16//! the *recall* query path, preserving recall quality.
17//!
18//! # Model artifact
19//!
20//! The distilled vocabulary matrix is downloaded at runtime from
21//! `dakera-ai/bge-large-model2vec-256d` on HuggingFace Hub.  File:
22//! `vocab_matrix.bin` (flat `f32` array, `vocab_size × dimension`).
//! On load, the file is checked to hold a whole number of `f32` values (its
//! length must be a multiple of 4 bytes).
24
25use crate::backend::{BackendKind, EmbeddingBackend};
26use crate::error::{InferenceError, Result};
27use crate::models::ModelConfig;
28use async_trait::async_trait;
29use std::sync::Arc;
30use tokenizers::Tokenizer;
31use tracing::{debug, info, instrument};
32
33/// Model2Vec static embedding backend.
34///
35/// Holds the vocabulary matrix in memory (optionally memory-mapped for cold-start
36/// speed).  The tokenizer used is the *same* tokenizer.json as the source model
37/// (BGE-Large / ModernBERT), so token IDs are consistent with the ONNX/Candle paths.
38pub struct StaticBackend {
39    /// Flat row-major f32 array: `[vocab_size × dimension]`
40    vocab_matrix: Arc<Vec<f32>>,
41    tokenizer: Arc<Tokenizer>,
42    dimension: usize,
43    vocab_size: usize,
44}
45
46impl StaticBackend {
47    /// Build a new `StaticBackend`.
48    ///
49    /// Downloads `vocab_matrix.bin` and the model tokenizer from HuggingFace
50    /// on first run; subsequent calls use the local cache.
51    #[instrument(skip_all)]
52    pub async fn new(config: &ModelConfig) -> Result<Self> {
53        let config = config.clone();
54        info!("Initialising StaticBackend (Model2Vec)");
55
56        let dim = Self::model2vec_dimension();
57
58        // Download tokenizer.json from the source model repo
59        let model_id = config.model.model_id();
60        let cache_dir = crate::backend::onnx::OnnxBackend::model_cache_dir(model_id)?;
61
62        // Download tokenizer if not cached
63        if !cache_dir.join("tokenizer.json").exists() {
64            let model_id_owned = model_id.to_string();
65            let cache_dir_clone = cache_dir.clone();
66            tokio::task::spawn_blocking(move || {
67                crate::backend::onnx::OnnxBackend::download_hf_file(
68                    &model_id_owned,
69                    "tokenizer.json",
70                    &cache_dir_clone,
71                )
72                .map_err(InferenceError::HubError)
73            })
74            .await
75            .map_err(|e| InferenceError::HubError(format!("Download panicked: {e}")))??;
76        }
77
78        let tokenizer_path = cache_dir.join("tokenizer.json");
79        let tokenizer = Tokenizer::from_file(&tokenizer_path)
80            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
81
82        // Download vocabulary matrix
83        let vocab_matrix = Self::load_vocab_matrix(&config, dim).await?;
84        let vocab_size = vocab_matrix.len() / dim;
85
86        info!(
87            "StaticBackend ready: vocab_size={}, dimension={}",
88            vocab_size, dim
89        );
90
91        Ok(Self {
92            vocab_matrix: Arc::new(vocab_matrix),
93            tokenizer: Arc::new(tokenizer),
94            dimension: dim,
95            vocab_size,
96        })
97    }
98
99    /// Build from a pre-loaded vocabulary matrix (useful for tests).
100    ///
101    /// `matrix` must be a flat row-major array of shape `[vocab_size × dimension]`.
102    pub fn from_matrix(matrix: Vec<f32>, tokenizer: Tokenizer, dimension: usize) -> Result<Self> {
103        if !matrix.len().is_multiple_of(dimension) {
104            return Err(InferenceError::InvalidInput(format!(
105                "vocab_matrix length {} is not divisible by dimension {}",
106                matrix.len(),
107                dimension
108            )));
109        }
110        let vocab_size = matrix.len() / dimension;
111        Ok(Self {
112            vocab_matrix: Arc::new(matrix),
113            tokenizer: Arc::new(tokenizer),
114            dimension,
115            vocab_size,
116        })
117    }
118
119    /// Configured Model2Vec output dimension (`DAKERA_MRL_DIM`, default 256).
120    pub fn model2vec_dimension() -> usize {
121        std::env::var("DAKERA_MRL_DIM")
122            .ok()
123            .and_then(|v| v.parse::<usize>().ok())
124            .filter(|&d| d > 0)
125            .unwrap_or(256)
126    }
127
128    /// Embed a single text via token-lookup + mean pooling.
129    #[instrument(skip(self, text), fields(text_len = text.len()))]
130    fn embed_single(&self, text: &str) -> Vec<f32> {
131        // Tokenize — encode returns token IDs
132        let encoding = match self.tokenizer.encode(text, false) {
133            Ok(enc) => enc,
134            Err(_) => return vec![0.0; self.dimension],
135        };
136
137        let ids = encoding.get_ids();
138        if ids.is_empty() {
139            return vec![0.0; self.dimension];
140        }
141
142        // Mean pool token vectors
143        let mut result = vec![0.0f32; self.dimension];
144        let mut valid_tokens = 0usize;
145
146        for &id in ids {
147            let idx = id as usize;
148            if idx >= self.vocab_size {
149                // OOV: skip (treat as zero vector contribution)
150                continue;
151            }
152            let offset = idx * self.dimension;
153            let row = &self.vocab_matrix[offset..offset + self.dimension];
154            for (r, v) in result.iter_mut().zip(row.iter()) {
155                *r += v;
156            }
157            valid_tokens += 1;
158        }
159
160        if valid_tokens == 0 {
161            return vec![0.0; self.dimension];
162        }
163
164        let n = valid_tokens as f32;
165        for v in result.iter_mut() {
166            *v /= n;
167        }
168
169        // L2 normalise
170        let norm: f32 = result.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
171        for v in result.iter_mut() {
172            *v /= norm;
173        }
174
175        result
176    }
177
178    /// Load (or download) the Model2Vec vocabulary matrix.
179    async fn load_vocab_matrix(config: &ModelConfig, _dim: usize) -> Result<Vec<f32>> {
180        // Determine cache path
181        let model2vec_repo = config.model.model2vec_repo_id();
182        let cache_dir = crate::backend::onnx::OnnxBackend::model_cache_dir(model2vec_repo)?;
183        let matrix_path = cache_dir.join("vocab_matrix.bin");
184
185        if !matrix_path.exists() {
186            info!("Downloading Model2Vec vocab matrix from {}", model2vec_repo);
187            let repo = model2vec_repo.to_string();
188            let cache = cache_dir.clone();
189            tokio::task::spawn_blocking(move || {
190                crate::backend::onnx::OnnxBackend::download_hf_file(
191                    &repo,
192                    "vocab_matrix.bin",
193                    &cache,
194                )
195                .map_err(InferenceError::HubError)
196            })
197            .await
198            .map_err(|e| InferenceError::HubError(format!("Download panicked: {e}")))??;
199        }
200
201        // Read as raw f32 bytes
202        info!("Loading vocab matrix from {:?}", matrix_path);
203        let bytes = std::fs::read(&matrix_path)?;
204        if bytes.len() % 4 != 0 {
205            return Err(InferenceError::ModelLoadError(format!(
206                "vocab_matrix.bin size {} is not a multiple of 4 bytes",
207                bytes.len()
208            )));
209        }
210
211        let floats: Vec<f32> = bytes
212            .chunks_exact(4)
213            .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
214            .collect();
215
216        debug!("Vocab matrix loaded: {} f32 values", floats.len());
217        Ok(floats)
218    }
219}
220
221#[async_trait]
222impl EmbeddingBackend for StaticBackend {
223    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
224        if texts.is_empty() {
225            return Ok(vec![]);
226        }
227        // Pure CPU — no async needed; avoid spawn_blocking overhead for small batches
228        let results: Vec<Vec<f32>> = texts.iter().map(|t| self.embed_single(t)).collect();
229        Ok(results)
230    }
231
232    fn dimension(&self) -> usize {
233        self.dimension
234    }
235
236    fn backend_kind(&self) -> BackendKind {
237        BackendKind::Static
238    }
239}
240
#[cfg(test)]
mod tests {
    use super::*;
    use tokenizers::models::wordlevel::WordLevel;
    use tokenizers::pre_tokenizers::whitespace::Whitespace;

    /// Build a whitespace-split word-level tokenizer whose token IDs are the
    /// positions of `words`; unknown words map to `"[UNK]"`.
    fn make_test_tokenizer(words: &[&str]) -> Tokenizer {
        let mut vocab = std::collections::HashMap::new();
        for (i, w) in words.iter().enumerate() {
            vocab.insert(w.to_string(), i as u32);
        }
        let model = WordLevel::builder()
            .vocab(vocab)
            .unk_token("[UNK]".to_string())
            .build()
            .unwrap();
        let mut tok = Tokenizer::new(model);
        tok.with_pre_tokenizer(Some(Whitespace {}));
        tok
    }

    fn make_identity_matrix(vocab_size: usize, dim: usize) -> Vec<f32> {
        // Each token's vector is a one-hot-like: token[i][i % dim] = 1.0
        let mut m = vec![0.0f32; vocab_size * dim];
        for i in 0..vocab_size {
            m[i * dim + (i % dim)] = 1.0;
        }
        m
    }

    /// Assert two vectors have equal length and match element-wise within 1e-5.
    fn assert_close(actual: &[f32], expected: &[f32]) {
        assert_eq!(actual.len(), expected.len());
        for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!((a - e).abs() < 1e-5, "component {i}: got {a}, expected {e}");
        }
    }

    #[test]
    fn test_static_backend_from_matrix_dimension() {
        let words = ["[UNK]", "hello", "world", "test", "foo"];
        let tok = make_test_tokenizer(&words);
        let matrix = make_identity_matrix(5, 4);
        let backend = StaticBackend::from_matrix(matrix, tok, 4).unwrap();
        assert_eq!(backend.dimension(), 4);
    }

    #[test]
    fn test_static_backend_from_matrix_vocab_size() {
        let words = ["[UNK]", "a", "b", "c"];
        let tok = make_test_tokenizer(&words);
        let matrix = make_identity_matrix(4, 8);
        let backend = StaticBackend::from_matrix(matrix, tok, 8).unwrap();
        assert_eq!(backend.vocab_size, 4);
    }

    #[test]
    fn test_static_backend_kind() {
        let words = ["[UNK]", "hello"];
        let tok = make_test_tokenizer(&words);
        let matrix = vec![0.0f32; 2 * 4];
        let backend = StaticBackend::from_matrix(matrix, tok, 4).unwrap();
        assert_eq!(backend.backend_kind(), BackendKind::Static);
    }

    #[test]
    fn test_static_embed_empty_text_returns_zeros() {
        let words = ["[UNK]", "hello"];
        let tok = make_test_tokenizer(&words);
        let matrix = vec![1.0f32; 2 * 4]; // all ones
        let backend = StaticBackend::from_matrix(matrix, tok, 4).unwrap();
        let result = backend.embed_single("");
        // Empty text → no tokens → zero vector
        assert_close(&result, &[0.0; 4]);
    }

    #[test]
    fn test_static_embed_single_token_normalized() {
        let words = ["[UNK]", "hello", "world"];
        let tok = make_test_tokenizer(&words);
        // hello → token id 1; row 1 = [1, 0, 0, 0]
        let mut matrix = vec![0.0f32; 3 * 4];
        matrix[4] = 1.0; // token 1, dim 0
        let backend = StaticBackend::from_matrix(matrix, tok, 4).unwrap();
        let emb = backend.embed_single("hello");
        // Normalized [1,0,0,0] is still [1,0,0,0]
        assert_close(&emb, &[1.0, 0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_static_embed_mean_pools_and_normalizes() {
        let words = ["[UNK]", "hello", "world"];
        let tok = make_test_tokenizer(&words);
        // hello → [0, 1, 0, 0], world → [0, 0, 1, 0]
        let matrix = make_identity_matrix(3, 4);
        let backend = StaticBackend::from_matrix(matrix, tok, 4).unwrap();
        let emb = backend.embed_single("hello world");
        // Mean [0, .5, .5, 0] normalised to unit length.
        let h = std::f32::consts::FRAC_1_SQRT_2;
        assert_close(&emb, &[0.0, h, h, 0.0]);
    }

    #[test]
    fn test_static_embed_skips_out_of_vocab_ids() {
        // The tokenizer knows three words but the matrix only has rows for
        // IDs 0 and 1, so "world" (ID 2) has no vector.
        let words = ["[UNK]", "hello", "world"];
        let tok = make_test_tokenizer(&words);
        let mut matrix = vec![0.0f32; 2 * 4];
        matrix[4 + 2] = 1.0; // token 1 (hello), dim 2
        let backend = StaticBackend::from_matrix(matrix, tok, 4).unwrap();
        // Only OOV tokens → zero vector.
        assert_close(&backend.embed_single("world"), &[0.0; 4]);
        // OOV tokens do not dilute the mean of the in-vocabulary ones.
        assert_close(&backend.embed_single("hello world"), &[0.0, 0.0, 1.0, 0.0]);
    }

    #[test]
    fn test_static_embed_invalid_matrix_dimension_error() {
        let words = ["[UNK]", "hello"];
        let tok = make_test_tokenizer(&words);
        // 5 floats not divisible by dim=4
        let matrix = vec![1.0f32; 5];
        let result = StaticBackend::from_matrix(matrix, tok, 4);
        assert!(result.is_err());
    }

    #[test]
    fn test_model2vec_dimension_default() {
        // Should return 256 when env var is unset
        std::env::remove_var("DAKERA_MRL_DIM");
        assert_eq!(StaticBackend::model2vec_dimension(), 256);
    }

    #[tokio::test]
    async fn test_static_embed_batch_empty() {
        let words = ["[UNK]", "hello"];
        let tok = make_test_tokenizer(&words);
        let matrix = vec![0.0f32; 2 * 4];
        let backend = StaticBackend::from_matrix(matrix, tok, 4).unwrap();
        let result = backend.embed_batch(&[]).await.unwrap();
        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn test_static_embed_batch_multiple() {
        let words = ["[UNK]", "hello", "world"];
        let tok = make_test_tokenizer(&words);
        let matrix = make_identity_matrix(3, 4);
        let backend = StaticBackend::from_matrix(matrix, tok, 4).unwrap();
        let texts = vec!["hello".to_string(), "world".to_string()];
        let results = backend.embed_batch(&texts).await.unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].len(), 4);
        assert_eq!(results[1].len(), 4);
    }

    #[tokio::test]
    async fn test_static_embed_batch_preserves_order() {
        let words = ["[UNK]", "hello", "world"];
        let tok = make_test_tokenizer(&words);
        // hello → [1, 0, 0, 0], world → [0, 1, 0, 0]
        let mut matrix = vec![0.0f32; 3 * 4];
        matrix[4] = 1.0; // token 1 (hello), dim 0
        matrix[9] = 1.0; // token 2 (world), dim 1
        let backend = StaticBackend::from_matrix(matrix, tok, 4).unwrap();
        let texts = vec!["hello".to_string(), "world".to_string()];
        let results = backend.embed_batch(&texts).await.unwrap();
        // hello embedding: dim 0 dominant
        assert!(results[0][0] > results[0][1]);
        // world embedding: dim 1 dominant
        assert!(results[1][1] > results[1][0]);
    }
}