1use crate::backend::{BackendKind, EmbeddingBackend};
26use crate::error::{InferenceError, Result};
27use crate::models::ModelConfig;
28use async_trait::async_trait;
29use std::sync::Arc;
30use tokenizers::Tokenizer;
31use tracing::{debug, info, instrument};
32
33pub struct StaticBackend {
39 vocab_matrix: Arc<Vec<f32>>,
41 tokenizer: Arc<Tokenizer>,
42 dimension: usize,
43 vocab_size: usize,
44}
45
46impl StaticBackend {
47 #[instrument(skip_all)]
52 pub async fn new(config: &ModelConfig) -> Result<Self> {
53 let config = config.clone();
54 info!("Initialising StaticBackend (Model2Vec)");
55
56 let dim = Self::model2vec_dimension();
57
58 let model_id = config.model.model_id();
60 let cache_dir = crate::backend::onnx::OnnxBackend::model_cache_dir(model_id)?;
61
62 if !cache_dir.join("tokenizer.json").exists() {
64 let model_id_owned = model_id.to_string();
65 let cache_dir_clone = cache_dir.clone();
66 tokio::task::spawn_blocking(move || {
67 crate::backend::onnx::OnnxBackend::download_hf_file(
68 &model_id_owned,
69 "tokenizer.json",
70 &cache_dir_clone,
71 )
72 .map_err(InferenceError::HubError)
73 })
74 .await
75 .map_err(|e| InferenceError::HubError(format!("Download panicked: {e}")))??;
76 }
77
78 let tokenizer_path = cache_dir.join("tokenizer.json");
79 let tokenizer = Tokenizer::from_file(&tokenizer_path)
80 .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
81
82 let vocab_matrix = Self::load_vocab_matrix(&config, dim).await?;
84 let vocab_size = vocab_matrix.len() / dim;
85
86 info!(
87 "StaticBackend ready: vocab_size={}, dimension={}",
88 vocab_size, dim
89 );
90
91 Ok(Self {
92 vocab_matrix: Arc::new(vocab_matrix),
93 tokenizer: Arc::new(tokenizer),
94 dimension: dim,
95 vocab_size,
96 })
97 }
98
99 pub fn from_matrix(matrix: Vec<f32>, tokenizer: Tokenizer, dimension: usize) -> Result<Self> {
103 if !matrix.len().is_multiple_of(dimension) {
104 return Err(InferenceError::InvalidInput(format!(
105 "vocab_matrix length {} is not divisible by dimension {}",
106 matrix.len(),
107 dimension
108 )));
109 }
110 let vocab_size = matrix.len() / dimension;
111 Ok(Self {
112 vocab_matrix: Arc::new(matrix),
113 tokenizer: Arc::new(tokenizer),
114 dimension,
115 vocab_size,
116 })
117 }
118
119 pub fn model2vec_dimension() -> usize {
121 std::env::var("DAKERA_MRL_DIM")
122 .ok()
123 .and_then(|v| v.parse::<usize>().ok())
124 .filter(|&d| d > 0)
125 .unwrap_or(256)
126 }
127
128 #[instrument(skip(self, text), fields(text_len = text.len()))]
130 fn embed_single(&self, text: &str) -> Vec<f32> {
131 let encoding = match self.tokenizer.encode(text, false) {
133 Ok(enc) => enc,
134 Err(_) => return vec![0.0; self.dimension],
135 };
136
137 let ids = encoding.get_ids();
138 if ids.is_empty() {
139 return vec![0.0; self.dimension];
140 }
141
142 let mut result = vec![0.0f32; self.dimension];
144 let mut valid_tokens = 0usize;
145
146 for &id in ids {
147 let idx = id as usize;
148 if idx >= self.vocab_size {
149 continue;
151 }
152 let offset = idx * self.dimension;
153 let row = &self.vocab_matrix[offset..offset + self.dimension];
154 for (r, v) in result.iter_mut().zip(row.iter()) {
155 *r += v;
156 }
157 valid_tokens += 1;
158 }
159
160 if valid_tokens == 0 {
161 return vec![0.0; self.dimension];
162 }
163
164 let n = valid_tokens as f32;
165 for v in result.iter_mut() {
166 *v /= n;
167 }
168
169 let norm: f32 = result.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
171 for v in result.iter_mut() {
172 *v /= norm;
173 }
174
175 result
176 }
177
178 async fn load_vocab_matrix(config: &ModelConfig, _dim: usize) -> Result<Vec<f32>> {
180 let model2vec_repo = config.model.model2vec_repo_id();
182 let cache_dir = crate::backend::onnx::OnnxBackend::model_cache_dir(model2vec_repo)?;
183 let matrix_path = cache_dir.join("vocab_matrix.bin");
184
185 if !matrix_path.exists() {
186 info!("Downloading Model2Vec vocab matrix from {}", model2vec_repo);
187 let repo = model2vec_repo.to_string();
188 let cache = cache_dir.clone();
189 tokio::task::spawn_blocking(move || {
190 crate::backend::onnx::OnnxBackend::download_hf_file(
191 &repo,
192 "vocab_matrix.bin",
193 &cache,
194 )
195 .map_err(InferenceError::HubError)
196 })
197 .await
198 .map_err(|e| InferenceError::HubError(format!("Download panicked: {e}")))??;
199 }
200
201 info!("Loading vocab matrix from {:?}", matrix_path);
203 let bytes = std::fs::read(&matrix_path)?;
204 if bytes.len() % 4 != 0 {
205 return Err(InferenceError::ModelLoadError(format!(
206 "vocab_matrix.bin size {} is not a multiple of 4 bytes",
207 bytes.len()
208 )));
209 }
210
211 let floats: Vec<f32> = bytes
212 .chunks_exact(4)
213 .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
214 .collect();
215
216 debug!("Vocab matrix loaded: {} f32 values", floats.len());
217 Ok(floats)
218 }
219}
220
221#[async_trait]
222impl EmbeddingBackend for StaticBackend {
223 async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
224 if texts.is_empty() {
225 return Ok(vec![]);
226 }
227 let results: Vec<Vec<f32>> = texts.iter().map(|t| self.embed_single(t)).collect();
229 Ok(results)
230 }
231
232 fn dimension(&self) -> usize {
233 self.dimension
234 }
235
236 fn backend_kind(&self) -> BackendKind {
237 BackendKind::Static
238 }
239}
240
#[cfg(test)]
mod tests {
    use super::*;
    use tokenizers::models::wordlevel::WordLevel;
    use tokenizers::pre_tokenizers::whitespace::Whitespace;

    /// Builds a whitespace-split word-level tokenizer where `words[i]` gets id
    /// `i`. Unknown words map to `[UNK]`, which callers must include in `words`.
    fn make_test_tokenizer(words: &[&str]) -> Tokenizer {
        let mut vocab = std::collections::HashMap::new();
        for (i, w) in words.iter().enumerate() {
            vocab.insert(w.to_string(), i as u32);
        }
        let model = WordLevel::builder()
            .vocab(vocab)
            .unk_token("[UNK]".to_string())
            .build()
            .unwrap();
        let mut tok = Tokenizer::new(model);
        tok.with_pre_tokenizer(Some(Whitespace {}));
        tok
    }

    /// Builds a `vocab_size * dim` matrix whose row `i` is the one-hot vector
    /// with a 1.0 in column `i % dim`.
    fn make_identity_matrix(vocab_size: usize, dim: usize) -> Vec<f32> {
        let mut m = vec![0.0f32; vocab_size * dim];
        for i in 0..vocab_size {
            m[i * dim + (i % dim)] = 1.0;
        }
        m
    }

    /// Euclidean length of `v`, used to check normalisation.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    #[test]
    fn test_static_backend_from_matrix_dimension() {
        let words = ["[UNK]", "hello", "world", "test", "foo"];
        let tok = make_test_tokenizer(&words);
        let matrix = make_identity_matrix(5, 4);
        let backend = StaticBackend::from_matrix(matrix, tok, 4).unwrap();
        assert_eq!(backend.dimension(), 4);
    }

    #[test]
    fn test_static_backend_from_matrix_vocab_size() {
        let words = ["[UNK]", "a", "b", "c"];
        let tok = make_test_tokenizer(&words);
        let matrix = make_identity_matrix(4, 8);
        let backend = StaticBackend::from_matrix(matrix, tok, 8).unwrap();
        assert_eq!(backend.vocab_size, 4);
    }

    #[test]
    fn test_static_backend_kind() {
        let words = ["[UNK]", "hello"];
        let tok = make_test_tokenizer(&words);
        let matrix = vec![0.0f32; 2 * 4];
        let backend = StaticBackend::from_matrix(matrix, tok, 4).unwrap();
        assert_eq!(backend.backend_kind(), BackendKind::Static);
    }

    #[test]
    fn test_static_embed_empty_text_returns_zeros() {
        let words = ["[UNK]", "hello"];
        let tok = make_test_tokenizer(&words);
        let matrix = vec![1.0f32; 2 * 4];
        let backend = StaticBackend::from_matrix(matrix, tok, 4).unwrap();
        let result = backend.embed_single("");
        assert_eq!(result.len(), 4);
        assert!(result.iter().all(|&v| v.abs() < 1e-6));
    }

    #[test]
    fn test_static_embed_single_token_normalized() {
        let words = ["[UNK]", "hello", "world"];
        let tok = make_test_tokenizer(&words);
        let mut matrix = vec![0.0f32; 3 * 4];
        // Row 1 ("hello") = [1, 0, 0, 0].
        matrix[4] = 1.0;
        let backend = StaticBackend::from_matrix(matrix, tok, 4).unwrap();
        let emb = backend.embed_single("hello");
        assert_eq!(emb.len(), 4);
        assert!((emb[0] - 1.0).abs() < 1e-5);
        assert!(emb[1].abs() < 1e-5);
        assert!(emb[2].abs() < 1e-5);
        assert!(emb[3].abs() < 1e-5);
    }

    #[test]
    fn test_static_embed_mean_pools_multiple_tokens() {
        let words = ["[UNK]", "hello", "world"];
        let tok = make_test_tokenizer(&words);
        // Row 1 = e1, row 2 = e2, so "hello world" pools to (e1 + e2) / |e1 + e2|.
        let matrix = make_identity_matrix(3, 4);
        let backend = StaticBackend::from_matrix(matrix, tok, 4).unwrap();
        let emb = backend.embed_single("hello world");
        let expected = std::f32::consts::FRAC_1_SQRT_2;
        assert!(emb[0].abs() < 1e-5);
        assert!((emb[1] - expected).abs() < 1e-5);
        assert!((emb[2] - expected).abs() < 1e-5);
        assert!(emb[3].abs() < 1e-5);
        assert!((l2_norm(&emb) - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_static_embed_skips_ids_without_matrix_row() {
        // The tokenizer knows "world" (id 2) but the matrix only has rows 0 and 1.
        let words = ["[UNK]", "hello", "world"];
        let tok = make_test_tokenizer(&words);
        let matrix = make_identity_matrix(2, 4);
        let backend = StaticBackend::from_matrix(matrix, tok, 4).unwrap();

        let only_missing = backend.embed_single("world");
        assert!(only_missing.iter().all(|&v| v.abs() < 1e-6));

        let mixed = backend.embed_single("hello world");
        assert!(mixed[0].abs() < 1e-5);
        assert!((mixed[1] - 1.0).abs() < 1e-5);
        assert!(mixed[2].abs() < 1e-5);
        assert!(mixed[3].abs() < 1e-5);
    }

    #[test]
    fn test_static_embed_invalid_matrix_dimension_error() {
        let words = ["[UNK]", "hello"];
        let tok = make_test_tokenizer(&words);
        // 5 values cannot be split into rows of 4.
        let matrix = vec![1.0f32; 5];
        let result = StaticBackend::from_matrix(matrix, tok, 4);
        assert!(result.is_err());
    }

    #[test]
    fn test_model2vec_dimension_default() {
        // NOTE(review): process-global env mutation; if other tests in the crate
        // set DAKERA_MRL_DIM concurrently this can race.
        std::env::remove_var("DAKERA_MRL_DIM");
        assert_eq!(StaticBackend::model2vec_dimension(), 256);
    }

    #[tokio::test]
    async fn test_static_embed_batch_empty() {
        let words = ["[UNK]", "hello"];
        let tok = make_test_tokenizer(&words);
        let matrix = vec![0.0f32; 2 * 4];
        let backend = StaticBackend::from_matrix(matrix, tok, 4).unwrap();
        let result = backend.embed_batch(&[]).await.unwrap();
        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn test_static_embed_batch_multiple() {
        let words = ["[UNK]", "hello", "world"];
        let tok = make_test_tokenizer(&words);
        let matrix = make_identity_matrix(3, 4);
        let backend = StaticBackend::from_matrix(matrix, tok, 4).unwrap();
        let texts = vec!["hello".to_string(), "world".to_string()];
        let results = backend.embed_batch(&texts).await.unwrap();
        assert_eq!(results.len(), 2);
        for emb in &results {
            assert_eq!(emb.len(), 4);
            assert!((l2_norm(emb) - 1.0).abs() < 1e-5);
        }
    }

    #[tokio::test]
    async fn test_static_embed_batch_preserves_order() {
        let words = ["[UNK]", "hello", "world"];
        let tok = make_test_tokenizer(&words);
        let mut matrix = vec![0.0f32; 3 * 4];
        // Row 1 ("hello") = e0, row 2 ("world") = e1.
        matrix[4] = 1.0;
        matrix[9] = 1.0;
        let backend = StaticBackend::from_matrix(matrix, tok, 4).unwrap();
        let texts = vec!["hello".to_string(), "world".to_string()];
        let results = backend.embed_batch(&texts).await.unwrap();
        assert!(results[0][0] > results[0][1]);
        assert!(results[1][1] > results[1][0]);
    }
}