coding_agent_search/search/
fastembed_embedder.rs1use std::fs;
16use std::path::{Path, PathBuf};
17use std::sync::Mutex;
18
19use fastembed::{
20 InitOptionsUserDefined, Pooling, TextEmbedding, TokenizerFiles, UserDefinedEmbeddingModel,
21};
22
23use super::embedder::{Embedder, EmbedderError, EmbedderResult};
24use frankensearch::{ModelCategory, ModelTier};
25
// --- Default (MiniLM) model identity -------------------------------------------------

/// Model identifier reported for the default all-MiniLM-L6-v2 model.
const MINILM_MODEL_ID: &'static str = "all-minilm-l6-v2";
/// Directory name (under `<data_dir>/models`) holding the MiniLM files.
const MINILM_DIR_NAME: &'static str = "all-MiniLM-L6-v2";
/// Embedder identifier reported for the default MiniLM embedder.
const MINILM_EMBEDDER_ID: &'static str = "minilm-384";
/// Length of the vectors produced by MiniLM.
const MINILM_DIMENSION: usize = 384;

// --- Model directory layout -----------------------------------------------------------

/// Preferred ONNX location, relative to the model directory.
pub const MODEL_ONNX_SUBDIR: &'static str = "onnx/model.onnx";
/// Fallback ONNX location at the root of the model directory.
pub const MODEL_ONNX_LEGACY: &'static str = "model.onnx";
/// Tokenizer definition file.
const TOKENIZER_JSON: &'static str = "tokenizer.json";
/// Model configuration file.
const CONFIG_JSON: &'static str = "config.json";
/// Special-token mapping file.
const SPECIAL_TOKENS_JSON: &'static str = "special_tokens_map.json";
/// Tokenizer configuration file.
const TOKENIZER_CONFIG_JSON: &'static str = "tokenizer_config.json";
39
40#[derive(Debug, Clone)]
42pub struct OnnxEmbedderConfig {
43 pub embedder_id: String,
45 pub model_id: String,
47 pub dimension: usize,
49 pub pooling: Pooling,
51}
52
53impl Default for OnnxEmbedderConfig {
54 fn default() -> Self {
55 Self {
56 embedder_id: MINILM_EMBEDDER_ID.to_string(),
57 model_id: MINILM_MODEL_ID.to_string(),
58 dimension: MINILM_DIMENSION,
59 pooling: Pooling::Mean,
60 }
61 }
62}
63
64pub struct FastEmbedder {
68 model: Mutex<TextEmbedding>,
69 id: String,
70 model_id: String,
71 dimension: usize,
72}
73
74impl FastEmbedder {
75 pub fn embedder_id_static() -> &'static str {
77 MINILM_EMBEDDER_ID
78 }
79
80 pub fn model_id_static() -> &'static str {
82 MINILM_MODEL_ID
83 }
84
85 pub fn required_model_files() -> &'static [&'static str] {
90 &[
91 TOKENIZER_JSON,
92 CONFIG_JSON,
93 SPECIAL_TOKENS_JSON,
94 TOKENIZER_CONFIG_JSON,
95 ]
96 }
97
98 pub fn model_file_candidates() -> &'static [&'static str] {
100 &[MODEL_ONNX_SUBDIR, MODEL_ONNX_LEGACY]
101 }
102
103 pub fn select_model_file(model_dir: &Path) -> Option<PathBuf> {
105 for candidate in Self::model_file_candidates() {
106 let path = model_dir.join(candidate);
107 if path.is_file() {
108 return Some(path);
109 }
110 }
111 None
112 }
113
114 pub fn default_model_dir(data_dir: &Path) -> PathBuf {
116 data_dir.join("models").join(MINILM_DIR_NAME)
117 }
118
119 pub fn model_dir_for(data_dir: &Path, embedder_name: &str) -> Option<PathBuf> {
121 let dir_name = match embedder_name {
122 "minilm" => MINILM_DIR_NAME,
123 "snowflake-arctic-s" => "snowflake-arctic-embed-s",
124 "nomic-embed" => "nomic-embed-text-v1.5",
125 _ => return None,
126 };
127 Some(data_dir.join("models").join(dir_name))
128 }
129
130 pub fn config_for(embedder_name: &str) -> Option<OnnxEmbedderConfig> {
132 match embedder_name {
133 "minilm" => Some(OnnxEmbedderConfig {
134 embedder_id: "minilm-384".to_string(),
135 model_id: "all-minilm-l6-v2".to_string(),
136 dimension: 384,
137 pooling: Pooling::Mean,
138 }),
139 "snowflake-arctic-s" => Some(OnnxEmbedderConfig {
140 embedder_id: "snowflake-arctic-s-384".to_string(),
141 model_id: "snowflake-arctic-embed-s".to_string(),
142 dimension: 384,
143 pooling: Pooling::Mean,
144 }),
145 "nomic-embed" => Some(OnnxEmbedderConfig {
146 embedder_id: "nomic-embed-768".to_string(),
147 model_id: "nomic-embed-text-v1.5".to_string(),
148 dimension: 768,
149 pooling: Pooling::Mean,
150 }),
151 _ => None,
152 }
153 }
154
155 pub fn load_from_dir(model_dir: &Path) -> EmbedderResult<Self> {
157 Self::load_with_config(model_dir, OnnxEmbedderConfig::default())
158 }
159
160 pub fn load_with_config(model_dir: &Path, config: OnnxEmbedderConfig) -> EmbedderResult<Self> {
162 if !model_dir.is_dir() {
163 return Err(Self::unavailable_error(
164 &config.embedder_id,
165 format!("model directory not found: {}", model_dir.display()),
166 ));
167 }
168
169 let onnx_path = Self::select_model_file(model_dir).ok_or_else(|| {
170 Self::unavailable_error(
171 &config.embedder_id,
172 format!(
173 "no ONNX model file in {} (checked {} and {})",
174 model_dir.display(),
175 MODEL_ONNX_SUBDIR,
176 MODEL_ONNX_LEGACY
177 ),
178 )
179 })?;
180
181 let required = Self::required_model_files();
182 let mut missing = Vec::new();
183 for name in required {
184 let path = model_dir.join(name);
185 if !path.is_file() {
186 missing.push(*name);
187 }
188 }
189 if !missing.is_empty() {
190 return Err(Self::unavailable_error(
191 &config.embedder_id,
192 format!(
193 "model files missing in {}: {}",
194 model_dir.display(),
195 missing.join(", ")
196 ),
197 ));
198 }
199
200 let model_file = Self::read_required(onnx_path, "model.onnx", &config.embedder_id)?;
201 let tokenizer_file = Self::read_required(
202 model_dir.join(TOKENIZER_JSON),
203 TOKENIZER_JSON,
204 &config.embedder_id,
205 )?;
206 let config_file = Self::read_required(
207 model_dir.join(CONFIG_JSON),
208 CONFIG_JSON,
209 &config.embedder_id,
210 )?;
211 let special_tokens_map_file = Self::read_required(
212 model_dir.join(SPECIAL_TOKENS_JSON),
213 SPECIAL_TOKENS_JSON,
214 &config.embedder_id,
215 )?;
216 let tokenizer_config_file = Self::read_required(
217 model_dir.join(TOKENIZER_CONFIG_JSON),
218 TOKENIZER_CONFIG_JSON,
219 &config.embedder_id,
220 )?;
221
222 let tokenizer_files = TokenizerFiles {
223 tokenizer_file,
224 config_file,
225 special_tokens_map_file,
226 tokenizer_config_file,
227 };
228
229 let mut model = UserDefinedEmbeddingModel::new(model_file, tokenizer_files);
230 model.pooling = Some(config.pooling);
231
232 let init_options = InitOptionsUserDefined::new();
233
234 let model = TextEmbedding::try_new_from_user_defined(model, init_options).map_err(|e| {
235 EmbedderError::EmbeddingFailed {
236 model: config.embedder_id.clone(),
237 source: Box::new(std::io::Error::other(format!("fastembed init failed: {e}"))),
238 }
239 })?;
240
241 Ok(Self {
242 model: Mutex::new(model),
243 id: config.embedder_id,
244 model_id: config.model_id,
245 dimension: config.dimension,
246 })
247 }
248
249 pub fn load_by_name(data_dir: &Path, embedder_name: &str) -> EmbedderResult<Self> {
251 let model_dir = Self::model_dir_for(data_dir, embedder_name).ok_or_else(|| {
252 Self::unavailable_error(
253 embedder_name,
254 format!("unknown embedder: {}", embedder_name),
255 )
256 })?;
257 let config = Self::config_for(embedder_name).ok_or_else(|| {
258 Self::unavailable_error(
259 embedder_name,
260 format!("no config for embedder: {}", embedder_name),
261 )
262 })?;
263 Self::load_with_config(&model_dir, config)
264 }
265
266 pub fn model_id(&self) -> &str {
268 &self.model_id
269 }
270
271 fn read_required(path: PathBuf, label: &str, model_id: &str) -> EmbedderResult<Vec<u8>> {
272 fs::read(&path).map_err(|e| {
273 Self::unavailable_error(
274 model_id,
275 format!("unable to read {label} at {}: {e}", path.display()),
276 )
277 })
278 }
279
280 fn unavailable_error(model: impl Into<String>, reason: impl Into<String>) -> EmbedderError {
281 EmbedderError::EmbedderUnavailable {
282 model: model.into(),
283 reason: reason.into(),
284 }
285 }
286
287 fn normalize_in_place(embedding: &mut [f32]) {
288 let norm_sq: f32 = embedding.iter().map(|x| x * x).sum();
289 if norm_sq.is_finite() && norm_sq > f32::EPSILON {
290 let inv_norm = 1.0 / norm_sq.sqrt();
291 for v in embedding.iter_mut() {
292 *v *= inv_norm;
293 }
294 } else {
295 embedding.fill(0.0);
297 }
298 }
299}
300
301impl Embedder for FastEmbedder {
302 fn embed_sync(&self, text: &str) -> EmbedderResult<Vec<f32>> {
303 if text.is_empty() {
304 return Err(EmbedderError::InvalidConfig {
305 field: "input_text".to_string(),
306 value: "(empty)".to_string(),
307 reason: "empty text".to_string(),
308 });
309 }
310
311 #[allow(unused_mut)]
312 let mut model = self
313 .model
314 .lock()
315 .map_err(|_| EmbedderError::SubsystemError {
316 subsystem: "embedder",
317 source: Box::new(std::io::Error::other("fastembed lock poisoned")),
318 })?;
319
320 let embeddings =
321 model
322 .embed(vec![text], None)
323 .map_err(|e| EmbedderError::EmbeddingFailed {
324 model: self.id.clone(),
325 source: Box::new(std::io::Error::other(format!(
326 "fastembed embed failed: {e}"
327 ))),
328 })?;
329
330 let mut embedding =
331 embeddings
332 .into_iter()
333 .next()
334 .ok_or_else(|| EmbedderError::EmbeddingFailed {
335 model: self.id.clone(),
336 source: Box::new(std::io::Error::other("fastembed returned no embedding")),
337 })?;
338
339 if embedding.len() != self.dimension {
340 return Err(EmbedderError::EmbeddingFailed {
341 model: self.id.clone(),
342 source: Box::new(std::io::Error::other(format!(
343 "fastembed dimension mismatch: expected {}, got {}",
344 self.dimension,
345 embedding.len()
346 ))),
347 });
348 }
349
350 Self::normalize_in_place(&mut embedding);
351 Ok(embedding)
352 }
353
354 fn embed_batch_sync(&self, texts: &[&str]) -> EmbedderResult<Vec<Vec<f32>>> {
355 for text in texts {
356 if text.is_empty() {
357 return Err(EmbedderError::InvalidConfig {
358 field: "input_text".to_string(),
359 value: "(empty)".to_string(),
360 reason: "empty text in batch".to_string(),
361 });
362 }
363 }
364
365 if texts.is_empty() {
366 return Ok(Vec::new());
367 }
368
369 #[allow(unused_mut)]
370 let mut model = self
371 .model
372 .lock()
373 .map_err(|_| EmbedderError::SubsystemError {
374 subsystem: "embedder",
375 source: Box::new(std::io::Error::other("fastembed lock poisoned")),
376 })?;
377
378 let inputs = texts.to_vec();
379 let mut embeddings =
380 model
381 .embed(inputs, None)
382 .map_err(|e| EmbedderError::EmbeddingFailed {
383 model: self.id.clone(),
384 source: Box::new(std::io::Error::other(format!(
385 "fastembed embed failed: {e}"
386 ))),
387 })?;
388
389 for embedding in embeddings.iter_mut() {
390 if embedding.len() != self.dimension {
391 return Err(EmbedderError::EmbeddingFailed {
392 model: self.id.clone(),
393 source: Box::new(std::io::Error::other(format!(
394 "fastembed dimension mismatch: expected {}, got {}",
395 self.dimension,
396 embedding.len()
397 ))),
398 });
399 }
400 Self::normalize_in_place(embedding);
401 }
402
403 Ok(embeddings)
404 }
405
406 fn dimension(&self) -> usize {
407 self.dimension
408 }
409
410 fn id(&self) -> &str {
411 &self.id
412 }
413
414 fn model_name(&self) -> &str {
415 &self.model_id
416 }
417
418 fn is_semantic(&self) -> bool {
419 true
420 }
421
422 fn category(&self) -> ModelCategory {
423 ModelCategory::TransformerEmbedder
424 }
425
426 fn tier(&self) -> ModelTier {
427 ModelTier::Quality
428 }
429}
430
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn fastembed_missing_files_returns_unavailable() {
        let tmp = tempfile::tempdir().expect("tempdir");
        let err = FastEmbedder::load_from_dir(tmp.path())
            .err()
            .expect("missing model should fail");
        assert!(
            matches!(err, EmbedderError::EmbedderUnavailable { .. }),
            "expected EmbedderUnavailable, got {err:?}"
        );
    }

    #[test]
    fn load_reports_every_missing_tokenizer_file() {
        let tmp = tempfile::tempdir().expect("tempdir");
        std::fs::create_dir_all(tmp.path().join("onnx")).unwrap();
        std::fs::write(tmp.path().join(MODEL_ONNX_SUBDIR), b"stub").unwrap();
        std::fs::write(tmp.path().join(TOKENIZER_JSON), b"{}").unwrap();

        let err = FastEmbedder::load_from_dir(tmp.path())
            .err()
            .expect("incomplete model dir should fail");
        match err {
            EmbedderError::EmbedderUnavailable { model, reason } => {
                assert_eq!(model, FastEmbedder::embedder_id_static());
                for name in [CONFIG_JSON, SPECIAL_TOKENS_JSON, TOKENIZER_CONFIG_JSON] {
                    assert!(reason.contains(name), "{name} should be reported: {reason}");
                }
                assert!(
                    !reason.contains(TOKENIZER_JSON),
                    "present file must not be reported: {reason}"
                );
            }
            other => panic!("expected EmbedderUnavailable, got {other:?}"),
        }
    }

    #[test]
    fn load_by_name_rejects_unknown_embedder() {
        let tmp = tempfile::tempdir().expect("tempdir");
        let err = FastEmbedder::load_by_name(tmp.path(), "does-not-exist")
            .err()
            .expect("unknown embedder should fail");
        match err {
            EmbedderError::EmbedderUnavailable { model, reason } => {
                assert_eq!(model, "does-not-exist");
                assert!(reason.contains("unknown embedder"), "reason: {reason}");
            }
            other => panic!("expected EmbedderUnavailable, got {other:?}"),
        }
    }

    #[test]
    fn unavailable_error_preserves_shape() {
        let err = FastEmbedder::unavailable_error("test-model", "missing files");
        assert!(std::error::Error::source(&err).is_none());
        match err {
            EmbedderError::EmbedderUnavailable { model, reason } => {
                assert_eq!(model, "test-model");
                assert_eq!(reason, "missing files");
            }
            other => panic!("expected EmbedderUnavailable, got {other:?}"),
        }
    }

    #[test]
    fn select_model_file_prefers_modern_onnx_layout() {
        let tmp = tempfile::tempdir().expect("tempdir");
        std::fs::create_dir_all(tmp.path().join("onnx")).unwrap();
        std::fs::write(tmp.path().join("onnx/model.onnx"), b"modern").unwrap();
        std::fs::write(tmp.path().join("model.onnx"), b"legacy").unwrap();

        let selected = FastEmbedder::select_model_file(tmp.path()).unwrap();
        assert!(
            selected.ends_with("onnx/model.onnx"),
            "should prefer onnx/ subdir: {selected:?}"
        );
    }

    #[test]
    fn select_model_file_falls_back_to_legacy() {
        let tmp = tempfile::tempdir().expect("tempdir");
        std::fs::write(tmp.path().join("model.onnx"), b"legacy").unwrap();

        let selected = FastEmbedder::select_model_file(tmp.path()).unwrap();
        assert!(
            selected.ends_with("model.onnx"),
            "should fall back to legacy: {selected:?}"
        );
    }

    #[test]
    fn select_model_file_returns_none_for_empty_dir() {
        let tmp = tempfile::tempdir().expect("tempdir");
        assert!(FastEmbedder::select_model_file(tmp.path()).is_none());
    }

    #[test]
    fn model_dir_for_known_and_unknown_names() {
        let base = std::path::Path::new("data-root");
        assert_eq!(
            FastEmbedder::model_dir_for(base, "minilm"),
            Some(FastEmbedder::default_model_dir(base))
        );
        assert_eq!(
            FastEmbedder::model_dir_for(base, "nomic-embed"),
            Some(base.join("models").join("nomic-embed-text-v1.5"))
        );
        assert!(FastEmbedder::model_dir_for(base, "unknown").is_none());
    }

    #[test]
    fn config_for_known_models() {
        let minilm = FastEmbedder::config_for("minilm").unwrap();
        assert_eq!(minilm.dimension, 384);
        let default = OnnxEmbedderConfig::default();
        assert_eq!(minilm.embedder_id, default.embedder_id);
        assert_eq!(minilm.model_id, default.model_id);
        assert_eq!(minilm.dimension, default.dimension);

        let snowflake = FastEmbedder::config_for("snowflake-arctic-s").unwrap();
        assert_eq!(snowflake.dimension, 384);

        let nomic = FastEmbedder::config_for("nomic-embed").unwrap();
        assert_eq!(nomic.dimension, 768);

        assert!(FastEmbedder::config_for("unknown").is_none());
    }

    #[test]
    fn normalize_in_place_produces_unit_vector() {
        let mut v = [3.0_f32, 4.0];
        FastEmbedder::normalize_in_place(&mut v);
        assert!((v[0] - 0.6).abs() < 1e-6, "{v:?}");
        assert!((v[1] - 0.8).abs() < 1e-6, "{v:?}");
    }

    #[test]
    fn normalize_in_place_zeroes_degenerate_vectors() {
        let mut zero = [0.0_f32; 4];
        FastEmbedder::normalize_in_place(&mut zero);
        assert!(zero.iter().all(|&x| x == 0.0));

        let mut with_nan = [f32::NAN, 1.0];
        FastEmbedder::normalize_in_place(&mut with_nan);
        assert!(with_nan.iter().all(|&x| x == 0.0), "{with_nan:?}");
    }
}