Skip to main content

coding_agent_search/search/
fastembed_embedder.rs

1//! FastEmbed-based ML embedders.
2//!
3//! Loads local ONNX model + tokenizer bundles and produces semantic embeddings.
4//! This implementation never downloads model assets; it expects the model files
5//! to be present on disk and returns a clear error when they are missing.
6//!
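//! Named configurations (see `FastEmbedder::config_for`) are provided for:
//! - MiniLM (`all-MiniLM-L6-v2`, 384-dim baseline)
//! - Snowflake Arctic Embed S (384-dim, bake-off candidate)
//! - Nomic Embed Text v1.5 (768-dim, bake-off candidate)
//!
//! The other bake-off candidates (EmbeddingGemma, Qwen3-Embedding, ModernBERT-embed)
//! are not resolvable by name in this module; they would need an explicit
//! `OnnxEmbedderConfig` passed to `FastEmbedder::load_with_config`.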
14
15use std::fs;
16use std::path::{Path, PathBuf};
17use std::sync::Mutex;
18
19use fastembed::{
20    InitOptionsUserDefined, Pooling, TextEmbedding, TokenizerFiles, UserDefinedEmbeddingModel,
21};
22
23use super::embedder::{Embedder, EmbedderError, EmbedderResult};
24use frankensearch::{ModelCategory, ModelTier};
25
// MiniLM constants (baseline)
/// Model identifier reported for the MiniLM baseline (logging / compatibility checks).
const MINILM_MODEL_ID: &str = "all-minilm-l6-v2";
/// Bundle directory name under `<data_dir>/models/` for MiniLM.
const MINILM_DIR_NAME: &str = "all-MiniLM-L6-v2";
/// Embedder id for MiniLM; matches vector index naming.
const MINILM_EMBEDDER_ID: &str = "minilm-384";
/// Output embedding length of the MiniLM baseline.
const MINILM_DIMENSION: usize = 384;

// Standard ONNX file names — prefer onnx/ subdir (modern layout), fall back to flat (legacy).
// All paths below are relative to a model bundle directory.
/// Modern location of the ONNX graph inside a model bundle (checked first).
pub const MODEL_ONNX_SUBDIR: &str = "onnx/model.onnx";
/// Legacy flat location of the ONNX graph (checked when the modern one is absent).
pub const MODEL_ONNX_LEGACY: &str = "model.onnx";
/// Tokenizer definition handed to fastembed.
const TOKENIZER_JSON: &str = "tokenizer.json";
/// Model config handed to fastembed.
const CONFIG_JSON: &str = "config.json";
/// Special-token map handed to fastembed.
const SPECIAL_TOKENS_JSON: &str = "special_tokens_map.json";
/// Tokenizer config handed to fastembed.
const TOKENIZER_CONFIG_JSON: &str = "tokenizer_config.json";
39
/// Configuration for loading an ONNX embedder.
///
/// Consumed by [`FastEmbedder::load_with_config`]. [`FastEmbedder::config_for`]
/// builds the named presets, and `Default` yields the MiniLM baseline.
#[derive(Debug, Clone)]
pub struct OnnxEmbedderConfig {
    /// Unique embedder ID (e.g., "minilm-384"). Also used as the `model` field of
    /// errors raised while loading or embedding.
    pub embedder_id: String,
    /// Model identifier for logging and compatibility checks.
    pub model_id: String,
    /// Expected output embedding dimension; outputs of any other length are rejected.
    pub dimension: usize,
    /// Pooling strategy handed to fastembed (`UserDefinedEmbeddingModel::pooling`).
    pub pooling: Pooling,
}
52
53impl Default for OnnxEmbedderConfig {
54    fn default() -> Self {
55        Self {
56            embedder_id: MINILM_EMBEDDER_ID.to_string(),
57            model_id: MINILM_MODEL_ID.to_string(),
58            dimension: MINILM_DIMENSION,
59            pooling: Pooling::Mean,
60        }
61    }
62}
63
/// FastEmbed-backed semantic embedder.
///
/// Supports multiple ONNX models with configurable dimensions and pooling.
/// Every embed call holds the internal mutex while fastembed runs, so a single
/// instance performs one inference at a time.
pub struct FastEmbedder {
    /// Loaded fastembed model; locked for each embed / batch-embed call.
    model: Mutex<TextEmbedding>,
    /// Embedder id (e.g. "minilm-384"); returned by `Embedder::id` and used in error reports.
    id: String,
    /// Model identifier; returned by `model_id()` and `Embedder::model_name`.
    model_id: String,
    /// Expected embedding length; every produced embedding is checked against it.
    dimension: usize,
}
73
74impl FastEmbedder {
75    /// Stable embedder identifier for MiniLM (matches vector index naming).
76    pub fn embedder_id_static() -> &'static str {
77        MINILM_EMBEDDER_ID
78    }
79
80    /// Stable model identifier for MiniLM.
81    pub fn model_id_static() -> &'static str {
82        MINILM_MODEL_ID
83    }
84
85    /// Required non-model files for any ONNX embedder.
86    ///
87    /// The ONNX model itself can live at `onnx/model.onnx` (modern) or
88    /// `model.onnx` (legacy) — use [`select_model_file`] to find it.
89    pub fn required_model_files() -> &'static [&'static str] {
90        &[
91            TOKENIZER_JSON,
92            CONFIG_JSON,
93            SPECIAL_TOKENS_JSON,
94            TOKENIZER_CONFIG_JSON,
95        ]
96    }
97
98    /// Candidate ONNX model locations, ordered from preferred to legacy.
99    pub fn model_file_candidates() -> &'static [&'static str] {
100        &[MODEL_ONNX_SUBDIR, MODEL_ONNX_LEGACY]
101    }
102
103    /// Select the ONNX model file, preferring `onnx/model.onnx` over `model.onnx`.
104    pub fn select_model_file(model_dir: &Path) -> Option<PathBuf> {
105        for candidate in Self::model_file_candidates() {
106            let path = model_dir.join(candidate);
107            if path.is_file() {
108                return Some(path);
109            }
110        }
111        None
112    }
113
114    /// Default MiniLM model directory relative to the cass data dir.
115    pub fn default_model_dir(data_dir: &Path) -> PathBuf {
116        data_dir.join("models").join(MINILM_DIR_NAME)
117    }
118
119    /// Get model directory for a specific embedder name.
120    pub fn model_dir_for(data_dir: &Path, embedder_name: &str) -> Option<PathBuf> {
121        let dir_name = match Self::canonical_name(embedder_name)? {
122            "minilm" => MINILM_DIR_NAME,
123            "snowflake-arctic-s" => "snowflake-arctic-embed-s",
124            "nomic-embed" => "nomic-embed-text-v1.5",
125            _ => return None,
126        };
127        Some(data_dir.join("models").join(dir_name))
128    }
129
130    /// Resolve the runtime model directory for an embedder.
131    ///
132    /// `model_dir_for` is the cass-managed cache location. This variant honors
133    /// the explicit FRANKENSEARCH_MODEL_DIR override used by operators who
134    /// pre-stage a model bundle outside the cass data directory.
135    pub fn runtime_model_dir_for(data_dir: &Path, embedder_name: &str) -> Option<PathBuf> {
136        model_dir_override().or_else(|| Self::model_dir_for(data_dir, embedder_name))
137    }
138
139    pub fn canonical_name(embedder_name: &str) -> Option<&'static str> {
140        match embedder_name.trim().to_ascii_lowercase().as_str() {
141            "fastembed" | "minilm" | "all-minilm-l6-v2" | "minilm-384" => Some("minilm"),
142            "snowflake"
143            | "snowflake-arctic-s"
144            | "snowflake-arctic-embed-s"
145            | "snowflake-arctic-s-384" => Some("snowflake-arctic-s"),
146            "nomic" | "nomic-embed" | "nomic-embed-text-v1.5" | "nomic-embed-768" => {
147                Some("nomic-embed")
148            }
149            _ => None,
150        }
151    }
152
153    /// Get config for a specific embedder by name.
154    pub fn config_for(embedder_name: &str) -> Option<OnnxEmbedderConfig> {
155        match Self::canonical_name(embedder_name)? {
156            "minilm" => Some(OnnxEmbedderConfig {
157                embedder_id: "minilm-384".to_string(),
158                model_id: "all-minilm-l6-v2".to_string(),
159                dimension: 384,
160                pooling: Pooling::Mean,
161            }),
162            "snowflake-arctic-s" => Some(OnnxEmbedderConfig {
163                embedder_id: "snowflake-arctic-s-384".to_string(),
164                model_id: "snowflake-arctic-embed-s".to_string(),
165                dimension: 384,
166                pooling: Pooling::Mean,
167            }),
168            "nomic-embed" => Some(OnnxEmbedderConfig {
169                embedder_id: "nomic-embed-768".to_string(),
170                model_id: "nomic-embed-text-v1.5".to_string(),
171                dimension: 768,
172                pooling: Pooling::Mean,
173            }),
174            _ => None,
175        }
176    }
177
178    /// Load the MiniLM model (convenience wrapper).
179    pub fn load_from_dir(model_dir: &Path) -> EmbedderResult<Self> {
180        Self::load_with_config(model_dir, OnnxEmbedderConfig::default())
181    }
182
183    /// Load an ONNX embedder with custom configuration.
184    pub fn load_with_config(model_dir: &Path, config: OnnxEmbedderConfig) -> EmbedderResult<Self> {
185        if !model_dir.is_dir() {
186            return Err(Self::unavailable_error(
187                &config.embedder_id,
188                format!("model directory not found: {}", model_dir.display()),
189            ));
190        }
191
192        let onnx_path = Self::select_model_file(model_dir).ok_or_else(|| {
193            Self::unavailable_error(
194                &config.embedder_id,
195                format!(
196                    "no ONNX model file in {} (checked {} and {})",
197                    model_dir.display(),
198                    MODEL_ONNX_SUBDIR,
199                    MODEL_ONNX_LEGACY
200                ),
201            )
202        })?;
203
204        let required = Self::required_model_files();
205        let mut missing = Vec::new();
206        for name in required {
207            let path = model_dir.join(name);
208            if !path.is_file() {
209                missing.push(*name);
210            }
211        }
212        if !missing.is_empty() {
213            return Err(Self::unavailable_error(
214                &config.embedder_id,
215                format!(
216                    "model files missing in {}: {}",
217                    model_dir.display(),
218                    missing.join(", ")
219                ),
220            ));
221        }
222
223        let model_file = Self::read_required(onnx_path, "model.onnx", &config.embedder_id)?;
224        let tokenizer_file = Self::read_required(
225            model_dir.join(TOKENIZER_JSON),
226            TOKENIZER_JSON,
227            &config.embedder_id,
228        )?;
229        let config_file = Self::read_required(
230            model_dir.join(CONFIG_JSON),
231            CONFIG_JSON,
232            &config.embedder_id,
233        )?;
234        let special_tokens_map_file = Self::read_required(
235            model_dir.join(SPECIAL_TOKENS_JSON),
236            SPECIAL_TOKENS_JSON,
237            &config.embedder_id,
238        )?;
239        let tokenizer_config_file = Self::read_required(
240            model_dir.join(TOKENIZER_CONFIG_JSON),
241            TOKENIZER_CONFIG_JSON,
242            &config.embedder_id,
243        )?;
244
245        let tokenizer_files = TokenizerFiles {
246            tokenizer_file,
247            config_file,
248            special_tokens_map_file,
249            tokenizer_config_file,
250        };
251
252        let mut model = UserDefinedEmbeddingModel::new(model_file, tokenizer_files);
253        model.pooling = Some(config.pooling);
254
255        let init_options = InitOptionsUserDefined::new();
256
257        let model = TextEmbedding::try_new_from_user_defined(model, init_options).map_err(|e| {
258            EmbedderError::EmbeddingFailed {
259                model: config.embedder_id.clone(),
260                source: Box::new(std::io::Error::other(format!("fastembed init failed: {e}"))),
261            }
262        })?;
263
264        Ok(Self {
265            model: Mutex::new(model),
266            id: config.embedder_id,
267            model_id: config.model_id,
268            dimension: config.dimension,
269        })
270    }
271
272    /// Load an embedder by name from the data directory.
273    pub fn load_by_name(data_dir: &Path, embedder_name: &str) -> EmbedderResult<Self> {
274        let canonical_name = Self::canonical_name(embedder_name).ok_or_else(|| {
275            Self::unavailable_error(
276                embedder_name,
277                format!("unknown embedder: {}", embedder_name),
278            )
279        })?;
280        let model_dir = Self::runtime_model_dir_for(data_dir, canonical_name).ok_or_else(|| {
281            Self::unavailable_error(
282                embedder_name,
283                format!("unknown embedder: {}", embedder_name),
284            )
285        })?;
286        let config = Self::config_for(canonical_name).ok_or_else(|| {
287            Self::unavailable_error(
288                embedder_name,
289                format!("no config for embedder: {}", embedder_name),
290            )
291        })?;
292        Self::load_with_config(&model_dir, config)
293    }
294
295    /// Stable model identifier for compatibility checks.
296    pub fn model_id(&self) -> &str {
297        &self.model_id
298    }
299
300    fn read_required(path: PathBuf, label: &str, model_id: &str) -> EmbedderResult<Vec<u8>> {
301        fs::read(&path).map_err(|e| {
302            Self::unavailable_error(
303                model_id,
304                format!("unable to read {label} at {}: {e}", path.display()),
305            )
306        })
307    }
308
309    fn unavailable_error(model: impl Into<String>, reason: impl Into<String>) -> EmbedderError {
310        EmbedderError::EmbedderUnavailable {
311            model: model.into(),
312            reason: reason.into(),
313        }
314    }
315
316    fn normalize_in_place(embedding: &mut [f32]) {
317        let norm_sq: f32 = embedding.iter().map(|x| x * x).sum();
318        if norm_sq.is_finite() && norm_sq > f32::EPSILON {
319            let inv_norm = 1.0 / norm_sq.sqrt();
320            for v in embedding.iter_mut() {
321                *v *= inv_norm;
322            }
323        } else {
324            // NaN/Inf contamination — zero out to prevent poisoning similarity search.
325            embedding.fill(0.0);
326        }
327    }
328}
329
330pub fn model_dir_override() -> Option<PathBuf> {
331    dotenvy::var("FRANKENSEARCH_MODEL_DIR")
332        .ok()
333        .map(|raw| raw.trim().to_string())
334        .filter(|raw| !raw.is_empty())
335        .map(|raw| expand_model_dir_override(&raw))
336}
337
338fn expand_model_dir_override(raw: &str) -> PathBuf {
339    if raw == "~" {
340        return dotenvy::var("HOME")
341            .map(PathBuf::from)
342            .unwrap_or_else(|_| PathBuf::from(raw));
343    }
344    if let Some(rest) = raw.strip_prefix("~/") {
345        return dotenvy::var("HOME")
346            .map(|home| PathBuf::from(home).join(rest))
347            .unwrap_or_else(|_| PathBuf::from(raw));
348    }
349    PathBuf::from(raw)
350}
351
352impl Embedder for FastEmbedder {
353    fn embed_sync(&self, text: &str) -> EmbedderResult<Vec<f32>> {
354        if text.is_empty() {
355            return Err(EmbedderError::InvalidConfig {
356                field: "input_text".to_string(),
357                value: "(empty)".to_string(),
358                reason: "empty text".to_string(),
359            });
360        }
361
362        #[allow(unused_mut)]
363        let mut model = self
364            .model
365            .lock()
366            .map_err(|_| EmbedderError::SubsystemError {
367                subsystem: "embedder",
368                source: Box::new(std::io::Error::other("fastembed lock poisoned")),
369            })?;
370
371        let embeddings =
372            model
373                .embed(vec![text], None)
374                .map_err(|e| EmbedderError::EmbeddingFailed {
375                    model: self.id.clone(),
376                    source: Box::new(std::io::Error::other(format!(
377                        "fastembed embed failed: {e}"
378                    ))),
379                })?;
380
381        let mut embedding =
382            embeddings
383                .into_iter()
384                .next()
385                .ok_or_else(|| EmbedderError::EmbeddingFailed {
386                    model: self.id.clone(),
387                    source: Box::new(std::io::Error::other("fastembed returned no embedding")),
388                })?;
389
390        if embedding.len() != self.dimension {
391            return Err(EmbedderError::EmbeddingFailed {
392                model: self.id.clone(),
393                source: Box::new(std::io::Error::other(format!(
394                    "fastembed dimension mismatch: expected {}, got {}",
395                    self.dimension,
396                    embedding.len()
397                ))),
398            });
399        }
400
401        Self::normalize_in_place(&mut embedding);
402        Ok(embedding)
403    }
404
405    fn embed_batch_sync(&self, texts: &[&str]) -> EmbedderResult<Vec<Vec<f32>>> {
406        for text in texts {
407            if text.is_empty() {
408                return Err(EmbedderError::InvalidConfig {
409                    field: "input_text".to_string(),
410                    value: "(empty)".to_string(),
411                    reason: "empty text in batch".to_string(),
412                });
413            }
414        }
415
416        if texts.is_empty() {
417            return Ok(Vec::new());
418        }
419
420        #[allow(unused_mut)]
421        let mut model = self
422            .model
423            .lock()
424            .map_err(|_| EmbedderError::SubsystemError {
425                subsystem: "embedder",
426                source: Box::new(std::io::Error::other("fastembed lock poisoned")),
427            })?;
428
429        let inputs = texts.to_vec();
430        let mut embeddings =
431            model
432                .embed(inputs, None)
433                .map_err(|e| EmbedderError::EmbeddingFailed {
434                    model: self.id.clone(),
435                    source: Box::new(std::io::Error::other(format!(
436                        "fastembed embed failed: {e}"
437                    ))),
438                })?;
439
440        for embedding in embeddings.iter_mut() {
441            if embedding.len() != self.dimension {
442                return Err(EmbedderError::EmbeddingFailed {
443                    model: self.id.clone(),
444                    source: Box::new(std::io::Error::other(format!(
445                        "fastembed dimension mismatch: expected {}, got {}",
446                        self.dimension,
447                        embedding.len()
448                    ))),
449                });
450            }
451            Self::normalize_in_place(embedding);
452        }
453
454        Ok(embeddings)
455    }
456
457    fn dimension(&self) -> usize {
458        self.dimension
459    }
460
461    fn id(&self) -> &str {
462        &self.id
463    }
464
465    fn model_name(&self) -> &str {
466        &self.model_id
467    }
468
469    fn is_semantic(&self) -> bool {
470        true
471    }
472
473    fn category(&self) -> ModelCategory {
474        ModelCategory::TransformerEmbedder
475    }
476
477    fn tier(&self) -> ModelTier {
478        ModelTier::Quality
479    }
480}
481
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;

    /// Sets environment variables for a test and restores their previous values
    /// on drop, so a failing assertion cannot leak test-only values into later tests.
    struct EnvVarGuard {
        saved: Vec<(&'static str, Option<String>)>,
    }

    impl EnvVarGuard {
        fn set(vars: &[(&'static str, &str)]) -> Self {
            // Snapshot every value (read the same way the code under test reads
            // them) before overwriting any of them.
            let saved = vars
                .iter()
                .map(|&(key, _)| (key, dotenvy::var(key).ok()))
                .collect();
            for &(key, value) in vars {
                // SAFETY: environment mutation is process-global. Callers are
                // `#[serial]`, and no other test in this module reads these
                // variables; assumes env-touching tests elsewhere in the crate
                // are `#[serial]` as well.
                unsafe { std::env::set_var(key, value) };
            }
            Self { saved }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            for (key, previous) in &self.saved {
                // SAFETY: same invariant as in `EnvVarGuard::set`.
                unsafe {
                    match previous {
                        Some(value) => std::env::set_var(key, value),
                        None => std::env::remove_var(key),
                    }
                }
            }
        }
    }

    #[test]
    fn fastembed_missing_files_returns_unavailable() {
        let tmp = tempfile::tempdir().expect("tempdir");
        let err = FastEmbedder::load_from_dir(tmp.path())
            .err()
            .expect("missing model should fail");
        assert!(
            matches!(err, EmbedderError::EmbedderUnavailable { .. }),
            "expected EmbedderUnavailable, got {err:?}"
        );
    }

    #[test]
    fn unavailable_error_preserves_shape() {
        let err = FastEmbedder::unavailable_error("test-model", "missing files");
        assert!(std::error::Error::source(&err).is_none());
        match err {
            EmbedderError::EmbedderUnavailable { model, reason } => {
                assert_eq!(model, "test-model");
                assert_eq!(reason, "missing files");
            }
            other => panic!("expected EmbedderUnavailable, got {other:?}"),
        }
    }

    #[test]
    fn select_model_file_prefers_modern_onnx_layout() {
        let tmp = tempfile::tempdir().expect("tempdir");
        std::fs::create_dir_all(tmp.path().join("onnx")).unwrap();
        std::fs::write(tmp.path().join("onnx/model.onnx"), b"modern").unwrap();
        std::fs::write(tmp.path().join("model.onnx"), b"legacy").unwrap();

        let selected = FastEmbedder::select_model_file(tmp.path()).unwrap();
        assert_eq!(
            selected,
            tmp.path().join(MODEL_ONNX_SUBDIR),
            "should prefer onnx/ subdir"
        );
    }

    #[test]
    fn select_model_file_falls_back_to_legacy() {
        let tmp = tempfile::tempdir().expect("tempdir");
        std::fs::write(tmp.path().join("model.onnx"), b"legacy").unwrap();

        // Exact comparison: `ends_with("model.onnx")` would also accept onnx/model.onnx.
        let selected = FastEmbedder::select_model_file(tmp.path()).unwrap();
        assert_eq!(
            selected,
            tmp.path().join(MODEL_ONNX_LEGACY),
            "should fall back to legacy"
        );
    }

    #[test]
    fn select_model_file_returns_none_for_empty_dir() {
        let tmp = tempfile::tempdir().expect("tempdir");
        assert!(FastEmbedder::select_model_file(tmp.path()).is_none());
    }

    #[test]
    fn config_for_known_models() {
        let minilm = FastEmbedder::config_for("minilm").unwrap();
        assert_eq!(minilm.dimension, 384);

        // The named MiniLM preset must agree with the Default baseline.
        let baseline = OnnxEmbedderConfig::default();
        assert_eq!(minilm.embedder_id, baseline.embedder_id);
        assert_eq!(minilm.model_id, baseline.model_id);
        assert_eq!(minilm.dimension, baseline.dimension);

        let snowflake = FastEmbedder::config_for("snowflake-arctic-s").unwrap();
        assert_eq!(snowflake.dimension, 384);

        let nomic = FastEmbedder::config_for("nomic-embed").unwrap();
        assert_eq!(nomic.dimension, 768);

        assert!(FastEmbedder::config_for("unknown").is_none());
    }

    #[test]
    fn canonical_name_accepts_policy_and_index_aliases() {
        assert_eq!(FastEmbedder::canonical_name("fastembed"), Some("minilm"));
        assert_eq!(
            FastEmbedder::canonical_name("snowflake-arctic-s-384"),
            Some("snowflake-arctic-s")
        );
        assert_eq!(
            FastEmbedder::canonical_name("nomic-embed-text-v1.5"),
            Some("nomic-embed")
        );
        // Case and surrounding whitespace are ignored.
        assert_eq!(FastEmbedder::canonical_name("  All-MiniLM-L6-v2 "), Some("minilm"));
        assert_eq!(FastEmbedder::canonical_name("not-a-model"), None);
    }

    #[test]
    fn model_dir_for_maps_aliases_to_bundle_dirs() {
        let data_dir = Path::new("/tmp/cass");
        assert_eq!(
            FastEmbedder::model_dir_for(data_dir, "nomic"),
            Some(PathBuf::from("/tmp/cass/models/nomic-embed-text-v1.5"))
        );
        assert_eq!(
            FastEmbedder::model_dir_for(data_dir, "minilm-384"),
            Some(FastEmbedder::default_model_dir(data_dir))
        );
        assert!(FastEmbedder::model_dir_for(data_dir, "not-a-model").is_none());
    }

    #[test]
    fn normalize_in_place_yields_unit_vectors_and_zeroes_degenerate_input() {
        let mut v = [3.0_f32, 4.0];
        FastEmbedder::normalize_in_place(&mut v);
        assert!(
            (v[0] - 0.6).abs() < 1e-6 && (v[1] - 0.8).abs() < 1e-6,
            "expected [0.6, 0.8], got {v:?}"
        );

        let mut nan = [1.0_f32, f32::NAN];
        FastEmbedder::normalize_in_place(&mut nan);
        assert_eq!(nan, [0.0, 0.0]);

        let mut zero = [0.0_f32; 3];
        FastEmbedder::normalize_in_place(&mut zero);
        assert_eq!(zero, [0.0; 3]);
    }

    #[test]
    #[serial]
    fn runtime_model_dir_honors_frankensearch_override_and_expands_home() {
        // Restored on drop, even if the assertion below panics.
        let _env = EnvVarGuard::set(&[
            ("HOME", "/tmp/cass-home-for-model-test"),
            ("FRANKENSEARCH_MODEL_DIR", "~/models/snowflake"),
        ]);

        let resolved = FastEmbedder::runtime_model_dir_for(Path::new("/tmp/cass"), "snowflake")
            .expect("runtime model dir");
        assert_eq!(
            resolved,
            PathBuf::from("/tmp/cass-home-for-model-test/models/snowflake")
        );
    }
}