Skip to main content

coding_agent_search/search/
fastembed_embedder.rs

//! FastEmbed-based ML embedders.
//!
//! Loads local ONNX model + tokenizer bundles and produces semantic embeddings.
//! This implementation never downloads model assets; it expects the model files
//! to be present on disk and returns a clear error when they are missing.
//!
//! Supports multiple models:
//! - MiniLM (baseline)
//! - EmbeddingGemma (bake-off candidate)
//! - Qwen3-Embedding (bake-off candidate)
//! - ModernBERT-embed (bake-off candidate)
//! - Snowflake Arctic Embed (bake-off candidate)
//! - Nomic Embed Text (bake-off candidate)
//!
//! Only MiniLM, Snowflake Arctic Embed and Nomic Embed Text have built-in name
//! lookups in this file (`FastEmbedder::config_for` / `FastEmbedder::model_dir_for`);
//! any other model has to be loaded through `FastEmbedder::load_with_config`
//! with an explicit configuration.
14
15use std::fs;
16use std::path::{Path, PathBuf};
17use std::sync::Mutex;
18
19use fastembed::{
20    InitOptionsUserDefined, Pooling, TextEmbedding, TokenizerFiles, UserDefinedEmbeddingModel,
21};
22
23use super::embedder::{Embedder, EmbedderError, EmbedderResult};
24use frankensearch::{ModelCategory, ModelTier};
25
// Baseline MiniLM model: identifiers, on-disk directory name, and output width.
/// Model identifier reported for the MiniLM baseline.
const MINILM_MODEL_ID: &str = "all-minilm-l6-v2";
/// Directory name (under `<data_dir>/models`) holding the MiniLM files.
const MINILM_DIR_NAME: &str = "all-MiniLM-L6-v2";
/// Embedder identifier for the MiniLM baseline.
const MINILM_EMBEDDER_ID: &str = "minilm-384";
/// Length of the vectors produced by the MiniLM baseline.
const MINILM_DIMENSION: usize = 384;

// File names looked up inside a model directory. The ONNX model is searched
// for in the `onnx/` subdirectory first (modern layout), then at the top level
// of the model directory (legacy layout).
/// Preferred (modern layout) location of the ONNX model, relative to the model directory.
pub const MODEL_ONNX_SUBDIR: &str = "onnx/model.onnx";
/// Fallback (legacy flat layout) location of the ONNX model, relative to the model directory.
pub const MODEL_ONNX_LEGACY: &str = "model.onnx";
/// Tokenizer definition file.
const TOKENIZER_JSON: &str = "tokenizer.json";
/// Model configuration file.
const CONFIG_JSON: &str = "config.json";
/// Special-tokens map file.
const SPECIAL_TOKENS_JSON: &str = "special_tokens_map.json";
/// Tokenizer configuration file.
const TOKENIZER_CONFIG_JSON: &str = "tokenizer_config.json";
39
40/// Configuration for loading an ONNX embedder.
41#[derive(Debug, Clone)]
42pub struct OnnxEmbedderConfig {
43    /// Unique embedder ID (e.g., "minilm-384").
44    pub embedder_id: String,
45    /// Model identifier for logging.
46    pub model_id: String,
47    /// Output embedding dimension.
48    pub dimension: usize,
49    /// Pooling strategy.
50    pub pooling: Pooling,
51}
52
53impl Default for OnnxEmbedderConfig {
54    fn default() -> Self {
55        Self {
56            embedder_id: MINILM_EMBEDDER_ID.to_string(),
57            model_id: MINILM_MODEL_ID.to_string(),
58            dimension: MINILM_DIMENSION,
59            pooling: Pooling::Mean,
60        }
61    }
62}
63
64/// FastEmbed-backed semantic embedder.
65///
66/// Supports multiple ONNX models with configurable dimensions and pooling.
67pub struct FastEmbedder {
68    model: Mutex<TextEmbedding>,
69    id: String,
70    model_id: String,
71    dimension: usize,
72}
73
74impl FastEmbedder {
75    /// Stable embedder identifier for MiniLM (matches vector index naming).
76    pub fn embedder_id_static() -> &'static str {
77        MINILM_EMBEDDER_ID
78    }
79
80    /// Stable model identifier for MiniLM.
81    pub fn model_id_static() -> &'static str {
82        MINILM_MODEL_ID
83    }
84
85    /// Required non-model files for any ONNX embedder.
86    ///
87    /// The ONNX model itself can live at `onnx/model.onnx` (modern) or
88    /// `model.onnx` (legacy) — use [`select_model_file`] to find it.
89    pub fn required_model_files() -> &'static [&'static str] {
90        &[
91            TOKENIZER_JSON,
92            CONFIG_JSON,
93            SPECIAL_TOKENS_JSON,
94            TOKENIZER_CONFIG_JSON,
95        ]
96    }
97
98    /// Candidate ONNX model locations, ordered from preferred to legacy.
99    pub fn model_file_candidates() -> &'static [&'static str] {
100        &[MODEL_ONNX_SUBDIR, MODEL_ONNX_LEGACY]
101    }
102
103    /// Select the ONNX model file, preferring `onnx/model.onnx` over `model.onnx`.
104    pub fn select_model_file(model_dir: &Path) -> Option<PathBuf> {
105        for candidate in Self::model_file_candidates() {
106            let path = model_dir.join(candidate);
107            if path.is_file() {
108                return Some(path);
109            }
110        }
111        None
112    }
113
114    /// Default MiniLM model directory relative to the cass data dir.
115    pub fn default_model_dir(data_dir: &Path) -> PathBuf {
116        data_dir.join("models").join(MINILM_DIR_NAME)
117    }
118
119    /// Get model directory for a specific embedder name.
120    pub fn model_dir_for(data_dir: &Path, embedder_name: &str) -> Option<PathBuf> {
121        let dir_name = match embedder_name {
122            "minilm" => MINILM_DIR_NAME,
123            "snowflake-arctic-s" => "snowflake-arctic-embed-s",
124            "nomic-embed" => "nomic-embed-text-v1.5",
125            _ => return None,
126        };
127        Some(data_dir.join("models").join(dir_name))
128    }
129
130    /// Get config for a specific embedder by name.
131    pub fn config_for(embedder_name: &str) -> Option<OnnxEmbedderConfig> {
132        match embedder_name {
133            "minilm" => Some(OnnxEmbedderConfig {
134                embedder_id: "minilm-384".to_string(),
135                model_id: "all-minilm-l6-v2".to_string(),
136                dimension: 384,
137                pooling: Pooling::Mean,
138            }),
139            "snowflake-arctic-s" => Some(OnnxEmbedderConfig {
140                embedder_id: "snowflake-arctic-s-384".to_string(),
141                model_id: "snowflake-arctic-embed-s".to_string(),
142                dimension: 384,
143                pooling: Pooling::Mean,
144            }),
145            "nomic-embed" => Some(OnnxEmbedderConfig {
146                embedder_id: "nomic-embed-768".to_string(),
147                model_id: "nomic-embed-text-v1.5".to_string(),
148                dimension: 768,
149                pooling: Pooling::Mean,
150            }),
151            _ => None,
152        }
153    }
154
155    /// Load the MiniLM model (convenience wrapper).
156    pub fn load_from_dir(model_dir: &Path) -> EmbedderResult<Self> {
157        Self::load_with_config(model_dir, OnnxEmbedderConfig::default())
158    }
159
160    /// Load an ONNX embedder with custom configuration.
161    pub fn load_with_config(model_dir: &Path, config: OnnxEmbedderConfig) -> EmbedderResult<Self> {
162        if !model_dir.is_dir() {
163            return Err(Self::unavailable_error(
164                &config.embedder_id,
165                format!("model directory not found: {}", model_dir.display()),
166            ));
167        }
168
169        let onnx_path = Self::select_model_file(model_dir).ok_or_else(|| {
170            Self::unavailable_error(
171                &config.embedder_id,
172                format!(
173                    "no ONNX model file in {} (checked {} and {})",
174                    model_dir.display(),
175                    MODEL_ONNX_SUBDIR,
176                    MODEL_ONNX_LEGACY
177                ),
178            )
179        })?;
180
181        let required = Self::required_model_files();
182        let mut missing = Vec::new();
183        for name in required {
184            let path = model_dir.join(name);
185            if !path.is_file() {
186                missing.push(*name);
187            }
188        }
189        if !missing.is_empty() {
190            return Err(Self::unavailable_error(
191                &config.embedder_id,
192                format!(
193                    "model files missing in {}: {}",
194                    model_dir.display(),
195                    missing.join(", ")
196                ),
197            ));
198        }
199
200        let model_file = Self::read_required(onnx_path, "model.onnx", &config.embedder_id)?;
201        let tokenizer_file = Self::read_required(
202            model_dir.join(TOKENIZER_JSON),
203            TOKENIZER_JSON,
204            &config.embedder_id,
205        )?;
206        let config_file = Self::read_required(
207            model_dir.join(CONFIG_JSON),
208            CONFIG_JSON,
209            &config.embedder_id,
210        )?;
211        let special_tokens_map_file = Self::read_required(
212            model_dir.join(SPECIAL_TOKENS_JSON),
213            SPECIAL_TOKENS_JSON,
214            &config.embedder_id,
215        )?;
216        let tokenizer_config_file = Self::read_required(
217            model_dir.join(TOKENIZER_CONFIG_JSON),
218            TOKENIZER_CONFIG_JSON,
219            &config.embedder_id,
220        )?;
221
222        let tokenizer_files = TokenizerFiles {
223            tokenizer_file,
224            config_file,
225            special_tokens_map_file,
226            tokenizer_config_file,
227        };
228
229        let mut model = UserDefinedEmbeddingModel::new(model_file, tokenizer_files);
230        model.pooling = Some(config.pooling);
231
232        let init_options = InitOptionsUserDefined::new();
233
234        let model = TextEmbedding::try_new_from_user_defined(model, init_options).map_err(|e| {
235            EmbedderError::EmbeddingFailed {
236                model: config.embedder_id.clone(),
237                source: Box::new(std::io::Error::other(format!("fastembed init failed: {e}"))),
238            }
239        })?;
240
241        Ok(Self {
242            model: Mutex::new(model),
243            id: config.embedder_id,
244            model_id: config.model_id,
245            dimension: config.dimension,
246        })
247    }
248
249    /// Load an embedder by name from the data directory.
250    pub fn load_by_name(data_dir: &Path, embedder_name: &str) -> EmbedderResult<Self> {
251        let model_dir = Self::model_dir_for(data_dir, embedder_name).ok_or_else(|| {
252            Self::unavailable_error(
253                embedder_name,
254                format!("unknown embedder: {}", embedder_name),
255            )
256        })?;
257        let config = Self::config_for(embedder_name).ok_or_else(|| {
258            Self::unavailable_error(
259                embedder_name,
260                format!("no config for embedder: {}", embedder_name),
261            )
262        })?;
263        Self::load_with_config(&model_dir, config)
264    }
265
266    /// Stable model identifier for compatibility checks.
267    pub fn model_id(&self) -> &str {
268        &self.model_id
269    }
270
271    fn read_required(path: PathBuf, label: &str, model_id: &str) -> EmbedderResult<Vec<u8>> {
272        fs::read(&path).map_err(|e| {
273            Self::unavailable_error(
274                model_id,
275                format!("unable to read {label} at {}: {e}", path.display()),
276            )
277        })
278    }
279
280    fn unavailable_error(model: impl Into<String>, reason: impl Into<String>) -> EmbedderError {
281        EmbedderError::EmbedderUnavailable {
282            model: model.into(),
283            reason: reason.into(),
284        }
285    }
286
287    fn normalize_in_place(embedding: &mut [f32]) {
288        let norm_sq: f32 = embedding.iter().map(|x| x * x).sum();
289        if norm_sq.is_finite() && norm_sq > f32::EPSILON {
290            let inv_norm = 1.0 / norm_sq.sqrt();
291            for v in embedding.iter_mut() {
292                *v *= inv_norm;
293            }
294        } else {
295            // NaN/Inf contamination — zero out to prevent poisoning similarity search.
296            embedding.fill(0.0);
297        }
298    }
299}
300
301impl Embedder for FastEmbedder {
302    fn embed_sync(&self, text: &str) -> EmbedderResult<Vec<f32>> {
303        if text.is_empty() {
304            return Err(EmbedderError::InvalidConfig {
305                field: "input_text".to_string(),
306                value: "(empty)".to_string(),
307                reason: "empty text".to_string(),
308            });
309        }
310
311        #[allow(unused_mut)]
312        let mut model = self
313            .model
314            .lock()
315            .map_err(|_| EmbedderError::SubsystemError {
316                subsystem: "embedder",
317                source: Box::new(std::io::Error::other("fastembed lock poisoned")),
318            })?;
319
320        let embeddings =
321            model
322                .embed(vec![text], None)
323                .map_err(|e| EmbedderError::EmbeddingFailed {
324                    model: self.id.clone(),
325                    source: Box::new(std::io::Error::other(format!(
326                        "fastembed embed failed: {e}"
327                    ))),
328                })?;
329
330        let mut embedding =
331            embeddings
332                .into_iter()
333                .next()
334                .ok_or_else(|| EmbedderError::EmbeddingFailed {
335                    model: self.id.clone(),
336                    source: Box::new(std::io::Error::other("fastembed returned no embedding")),
337                })?;
338
339        if embedding.len() != self.dimension {
340            return Err(EmbedderError::EmbeddingFailed {
341                model: self.id.clone(),
342                source: Box::new(std::io::Error::other(format!(
343                    "fastembed dimension mismatch: expected {}, got {}",
344                    self.dimension,
345                    embedding.len()
346                ))),
347            });
348        }
349
350        Self::normalize_in_place(&mut embedding);
351        Ok(embedding)
352    }
353
354    fn embed_batch_sync(&self, texts: &[&str]) -> EmbedderResult<Vec<Vec<f32>>> {
355        for text in texts {
356            if text.is_empty() {
357                return Err(EmbedderError::InvalidConfig {
358                    field: "input_text".to_string(),
359                    value: "(empty)".to_string(),
360                    reason: "empty text in batch".to_string(),
361                });
362            }
363        }
364
365        if texts.is_empty() {
366            return Ok(Vec::new());
367        }
368
369        #[allow(unused_mut)]
370        let mut model = self
371            .model
372            .lock()
373            .map_err(|_| EmbedderError::SubsystemError {
374                subsystem: "embedder",
375                source: Box::new(std::io::Error::other("fastembed lock poisoned")),
376            })?;
377
378        let inputs = texts.to_vec();
379        let mut embeddings =
380            model
381                .embed(inputs, None)
382                .map_err(|e| EmbedderError::EmbeddingFailed {
383                    model: self.id.clone(),
384                    source: Box::new(std::io::Error::other(format!(
385                        "fastembed embed failed: {e}"
386                    ))),
387                })?;
388
389        for embedding in embeddings.iter_mut() {
390            if embedding.len() != self.dimension {
391                return Err(EmbedderError::EmbeddingFailed {
392                    model: self.id.clone(),
393                    source: Box::new(std::io::Error::other(format!(
394                        "fastembed dimension mismatch: expected {}, got {}",
395                        self.dimension,
396                        embedding.len()
397                    ))),
398                });
399            }
400            Self::normalize_in_place(embedding);
401        }
402
403        Ok(embeddings)
404    }
405
406    fn dimension(&self) -> usize {
407        self.dimension
408    }
409
410    fn id(&self) -> &str {
411        &self.id
412    }
413
414    fn model_name(&self) -> &str {
415        &self.model_id
416    }
417
418    fn is_semantic(&self) -> bool {
419        true
420    }
421
422    fn category(&self) -> ModelCategory {
423        ModelCategory::TransformerEmbedder
424    }
425
426    fn tier(&self) -> ModelTier {
427        ModelTier::Quality
428    }
429}
430
#[cfg(test)]
mod tests {
    use super::*;

    /// An existing but empty directory has no model files to load.
    #[test]
    fn fastembed_missing_files_returns_unavailable() {
        let tmp = tempfile::tempdir().expect("tempdir");
        let err = FastEmbedder::load_from_dir(tmp.path())
            .err()
            .expect("missing model should fail");
        assert!(
            matches!(err, EmbedderError::EmbedderUnavailable { .. }),
            "expected EmbedderUnavailable, got {err:?}"
        );
    }

    /// A path that is not a directory is rejected before any file lookup.
    #[test]
    fn fastembed_missing_dir_returns_unavailable() {
        let tmp = tempfile::tempdir().expect("tempdir");
        let absent = tmp.path().join("does-not-exist");
        let err = FastEmbedder::load_from_dir(&absent)
            .err()
            .expect("missing directory should fail");
        match err {
            EmbedderError::EmbedderUnavailable { reason, .. } => {
                assert!(
                    reason.contains("model directory not found"),
                    "unexpected reason: {reason}"
                );
            }
            other => panic!("expected EmbedderUnavailable, got {other:?}"),
        }
    }

    #[test]
    fn unavailable_error_preserves_shape() {
        let err = FastEmbedder::unavailable_error("test-model", "missing files");
        assert!(std::error::Error::source(&err).is_none());
        match err {
            EmbedderError::EmbedderUnavailable { model, reason } => {
                assert_eq!(model, "test-model");
                assert_eq!(reason, "missing files");
            }
            other => panic!("expected EmbedderUnavailable, got {other:?}"),
        }
    }

    #[test]
    fn select_model_file_prefers_modern_onnx_layout() {
        let tmp = tempfile::tempdir().expect("tempdir");
        std::fs::create_dir_all(tmp.path().join("onnx")).unwrap();
        std::fs::write(tmp.path().join("onnx/model.onnx"), b"modern").unwrap();
        std::fs::write(tmp.path().join("model.onnx"), b"legacy").unwrap();

        let selected = FastEmbedder::select_model_file(tmp.path()).unwrap();
        // Exact comparison: a suffix check on "model.onnx" could not tell the
        // two layouts apart.
        assert_eq!(
            selected,
            tmp.path().join(MODEL_ONNX_SUBDIR),
            "should prefer onnx/ subdir"
        );
    }

    #[test]
    fn select_model_file_falls_back_to_legacy() {
        let tmp = tempfile::tempdir().expect("tempdir");
        std::fs::write(tmp.path().join("model.onnx"), b"legacy").unwrap();

        let selected = FastEmbedder::select_model_file(tmp.path()).unwrap();
        assert_eq!(
            selected,
            tmp.path().join(MODEL_ONNX_LEGACY),
            "should fall back to legacy"
        );
    }

    #[test]
    fn select_model_file_returns_none_for_empty_dir() {
        let tmp = tempfile::tempdir().expect("tempdir");
        assert!(FastEmbedder::select_model_file(tmp.path()).is_none());
    }

    #[test]
    fn config_for_known_models() {
        let minilm = FastEmbedder::config_for("minilm").unwrap();
        assert_eq!(minilm.dimension, 384);
        // The named MiniLM config must agree with the static accessors.
        assert_eq!(minilm.embedder_id, FastEmbedder::embedder_id_static());
        assert_eq!(minilm.model_id, FastEmbedder::model_id_static());

        let snowflake = FastEmbedder::config_for("snowflake-arctic-s").unwrap();
        assert_eq!(snowflake.dimension, 384);

        let nomic = FastEmbedder::config_for("nomic-embed").unwrap();
        assert_eq!(nomic.dimension, 768);

        assert!(FastEmbedder::config_for("unknown").is_none());
    }

    #[test]
    fn model_dir_for_known_and_unknown_names() {
        let data_dir = std::path::Path::new("data");
        assert_eq!(
            FastEmbedder::model_dir_for(data_dir, "minilm"),
            Some(FastEmbedder::default_model_dir(data_dir))
        );
        assert!(FastEmbedder::model_dir_for(data_dir, "snowflake-arctic-s").is_some());
        assert!(FastEmbedder::model_dir_for(data_dir, "nomic-embed").is_some());
        assert!(FastEmbedder::model_dir_for(data_dir, "unknown").is_none());
    }

    #[test]
    fn normalize_in_place_produces_unit_vector() {
        let mut v = vec![3.0_f32, 4.0];
        FastEmbedder::normalize_in_place(&mut v);
        assert!((v[0] - 0.6).abs() < 1e-6, "got {v:?}");
        assert!((v[1] - 0.8).abs() < 1e-6, "got {v:?}");
    }

    #[test]
    fn normalize_in_place_zeroes_degenerate_vectors() {
        let mut zero = vec![0.0_f32; 4];
        FastEmbedder::normalize_in_place(&mut zero);
        assert_eq!(zero, vec![0.0_f32; 4]);

        let mut with_nan = vec![1.0_f32, f32::NAN];
        FastEmbedder::normalize_in_place(&mut with_nan);
        assert_eq!(with_nan, vec![0.0_f32, 0.0]);

        let mut with_inf = vec![f32::INFINITY, 1.0];
        FastEmbedder::normalize_in_place(&mut with_inf);
        assert_eq!(with_inf, vec![0.0_f32, 0.0]);
    }
}