1use std::fs;
16use std::path::{Path, PathBuf};
17use std::sync::Mutex;
18
19use fastembed::{
20 InitOptionsUserDefined, Pooling, TextEmbedding, TokenizerFiles, UserDefinedEmbeddingModel,
21};
22
23use super::embedder::{Embedder, EmbedderError, EmbedderResult};
24use frankensearch::{ModelCategory, ModelTier};
25
/// Model name reported by the default MiniLM configuration.
const MINILM_MODEL_ID: &str = "all-minilm-l6-v2";
/// Directory name under `<data_dir>/models` that holds the MiniLM files.
const MINILM_DIR_NAME: &str = "all-MiniLM-L6-v2";
/// Embedder identifier used by the default MiniLM configuration.
const MINILM_EMBEDDER_ID: &str = "minilm-384";
/// Expected embedding length for the default MiniLM configuration.
const MINILM_DIMENSION: usize = 384;

/// ONNX model path, relative to the model directory, that is tried first.
pub const MODEL_ONNX_SUBDIR: &str = "onnx/model.onnx";
/// ONNX model path, relative to the model directory, tried as a fallback.
pub const MODEL_ONNX_LEGACY: &str = "model.onnx";
/// Tokenizer definition file required in the model directory.
const TOKENIZER_JSON: &str = "tokenizer.json";
/// Model configuration file required in the model directory.
const CONFIG_JSON: &str = "config.json";
/// Special-tokens map file required in the model directory.
const SPECIAL_TOKENS_JSON: &str = "special_tokens_map.json";
/// Tokenizer configuration file required in the model directory.
const TOKENIZER_CONFIG_JSON: &str = "tokenizer_config.json";
39
/// Settings describing which ONNX embedding model to load and how the loaded
/// embedder identifies itself.
#[derive(Debug, Clone)]
pub struct OnnxEmbedderConfig {
    /// Identifier the loaded embedder reports through `Embedder::id`.
    pub embedder_id: String,
    /// Model name the loaded embedder reports through `Embedder::model_name`.
    pub model_id: String,
    /// Expected embedding length; outputs of any other length are rejected.
    pub dimension: usize,
    /// Pooling strategy handed to the fastembed user-defined model.
    pub pooling: Pooling,
}
52
53impl Default for OnnxEmbedderConfig {
54 fn default() -> Self {
55 Self {
56 embedder_id: MINILM_EMBEDDER_ID.to_string(),
57 model_id: MINILM_MODEL_ID.to_string(),
58 dimension: MINILM_DIMENSION,
59 pooling: Pooling::Mean,
60 }
61 }
62}
63
/// Text embedder backed by a fastembed `TextEmbedding` loaded from local ONNX
/// and tokenizer files.
pub struct FastEmbedder {
    /// The fastembed model; every embedding call locks this mutex first.
    model: Mutex<TextEmbedding>,
    /// Embedder identifier returned by `Embedder::id`.
    id: String,
    /// Model name returned by `model_id` and `Embedder::model_name`.
    model_id: String,
    /// Expected embedding length; outputs of any other length are rejected.
    dimension: usize,
}
73
74impl FastEmbedder {
/// Returns the embedder identifier of the default MiniLM configuration.
pub fn embedder_id_static() -> &'static str {
    MINILM_EMBEDDER_ID
}
79
/// Returns the model name of the default MiniLM configuration.
pub fn model_id_static() -> &'static str {
    MINILM_MODEL_ID
}
84
85 pub fn required_model_files() -> &'static [&'static str] {
90 &[
91 TOKENIZER_JSON,
92 CONFIG_JSON,
93 SPECIAL_TOKENS_JSON,
94 TOKENIZER_CONFIG_JSON,
95 ]
96 }
97
98 pub fn model_file_candidates() -> &'static [&'static str] {
100 &[MODEL_ONNX_SUBDIR, MODEL_ONNX_LEGACY]
101 }
102
103 pub fn select_model_file(model_dir: &Path) -> Option<PathBuf> {
105 for candidate in Self::model_file_candidates() {
106 let path = model_dir.join(candidate);
107 if path.is_file() {
108 return Some(path);
109 }
110 }
111 None
112 }
113
114 pub fn default_model_dir(data_dir: &Path) -> PathBuf {
116 data_dir.join("models").join(MINILM_DIR_NAME)
117 }
118
119 pub fn model_dir_for(data_dir: &Path, embedder_name: &str) -> Option<PathBuf> {
121 let dir_name = match Self::canonical_name(embedder_name)? {
122 "minilm" => MINILM_DIR_NAME,
123 "snowflake-arctic-s" => "snowflake-arctic-embed-s",
124 "nomic-embed" => "nomic-embed-text-v1.5",
125 _ => return None,
126 };
127 Some(data_dir.join("models").join(dir_name))
128 }
129
130 pub fn runtime_model_dir_for(data_dir: &Path, embedder_name: &str) -> Option<PathBuf> {
136 model_dir_override().or_else(|| Self::model_dir_for(data_dir, embedder_name))
137 }
138
/// Normalizes an embedder alias to its canonical short name.
///
/// The input is trimmed and ASCII-lowercased before lookup. Returns
/// `Some("minilm")`, `Some("snowflake-arctic-s")` or `Some("nomic-embed")`
/// for a recognized alias and `None` otherwise.
pub fn canonical_name(embedder_name: &str) -> Option<&'static str> {
    // Each entry pairs a canonical name with every alias that maps to it.
    const ALIASES: &[(&str, &[&str])] = &[
        (
            "minilm",
            &["fastembed", "minilm", "all-minilm-l6-v2", "minilm-384"],
        ),
        (
            "snowflake-arctic-s",
            &[
                "snowflake",
                "snowflake-arctic-s",
                "snowflake-arctic-embed-s",
                "snowflake-arctic-s-384",
            ],
        ),
        (
            "nomic-embed",
            &[
                "nomic",
                "nomic-embed",
                "nomic-embed-text-v1.5",
                "nomic-embed-768",
            ],
        ),
    ];

    let needle = embedder_name.trim().to_ascii_lowercase();
    ALIASES
        .iter()
        .find(|(_, aliases)| aliases.iter().any(|alias| *alias == needle.as_str()))
        .map(|(canonical, _)| *canonical)
}
152
153 pub fn config_for(embedder_name: &str) -> Option<OnnxEmbedderConfig> {
155 match Self::canonical_name(embedder_name)? {
156 "minilm" => Some(OnnxEmbedderConfig {
157 embedder_id: "minilm-384".to_string(),
158 model_id: "all-minilm-l6-v2".to_string(),
159 dimension: 384,
160 pooling: Pooling::Mean,
161 }),
162 "snowflake-arctic-s" => Some(OnnxEmbedderConfig {
163 embedder_id: "snowflake-arctic-s-384".to_string(),
164 model_id: "snowflake-arctic-embed-s".to_string(),
165 dimension: 384,
166 pooling: Pooling::Mean,
167 }),
168 "nomic-embed" => Some(OnnxEmbedderConfig {
169 embedder_id: "nomic-embed-768".to_string(),
170 model_id: "nomic-embed-text-v1.5".to_string(),
171 dimension: 768,
172 pooling: Pooling::Mean,
173 }),
174 _ => None,
175 }
176 }
177
178 pub fn load_from_dir(model_dir: &Path) -> EmbedderResult<Self> {
180 Self::load_with_config(model_dir, OnnxEmbedderConfig::default())
181 }
182
183 pub fn load_with_config(model_dir: &Path, config: OnnxEmbedderConfig) -> EmbedderResult<Self> {
185 if !model_dir.is_dir() {
186 return Err(Self::unavailable_error(
187 &config.embedder_id,
188 format!("model directory not found: {}", model_dir.display()),
189 ));
190 }
191
192 let onnx_path = Self::select_model_file(model_dir).ok_or_else(|| {
193 Self::unavailable_error(
194 &config.embedder_id,
195 format!(
196 "no ONNX model file in {} (checked {} and {})",
197 model_dir.display(),
198 MODEL_ONNX_SUBDIR,
199 MODEL_ONNX_LEGACY
200 ),
201 )
202 })?;
203
204 let required = Self::required_model_files();
205 let mut missing = Vec::new();
206 for name in required {
207 let path = model_dir.join(name);
208 if !path.is_file() {
209 missing.push(*name);
210 }
211 }
212 if !missing.is_empty() {
213 return Err(Self::unavailable_error(
214 &config.embedder_id,
215 format!(
216 "model files missing in {}: {}",
217 model_dir.display(),
218 missing.join(", ")
219 ),
220 ));
221 }
222
223 let model_file = Self::read_required(onnx_path, "model.onnx", &config.embedder_id)?;
224 let tokenizer_file = Self::read_required(
225 model_dir.join(TOKENIZER_JSON),
226 TOKENIZER_JSON,
227 &config.embedder_id,
228 )?;
229 let config_file = Self::read_required(
230 model_dir.join(CONFIG_JSON),
231 CONFIG_JSON,
232 &config.embedder_id,
233 )?;
234 let special_tokens_map_file = Self::read_required(
235 model_dir.join(SPECIAL_TOKENS_JSON),
236 SPECIAL_TOKENS_JSON,
237 &config.embedder_id,
238 )?;
239 let tokenizer_config_file = Self::read_required(
240 model_dir.join(TOKENIZER_CONFIG_JSON),
241 TOKENIZER_CONFIG_JSON,
242 &config.embedder_id,
243 )?;
244
245 let tokenizer_files = TokenizerFiles {
246 tokenizer_file,
247 config_file,
248 special_tokens_map_file,
249 tokenizer_config_file,
250 };
251
252 let mut model = UserDefinedEmbeddingModel::new(model_file, tokenizer_files);
253 model.pooling = Some(config.pooling);
254
255 let init_options = InitOptionsUserDefined::new();
256
257 let model = TextEmbedding::try_new_from_user_defined(model, init_options).map_err(|e| {
258 EmbedderError::EmbeddingFailed {
259 model: config.embedder_id.clone(),
260 source: Box::new(std::io::Error::other(format!("fastembed init failed: {e}"))),
261 }
262 })?;
263
264 Ok(Self {
265 model: Mutex::new(model),
266 id: config.embedder_id,
267 model_id: config.model_id,
268 dimension: config.dimension,
269 })
270 }
271
272 pub fn load_by_name(data_dir: &Path, embedder_name: &str) -> EmbedderResult<Self> {
274 let canonical_name = Self::canonical_name(embedder_name).ok_or_else(|| {
275 Self::unavailable_error(
276 embedder_name,
277 format!("unknown embedder: {}", embedder_name),
278 )
279 })?;
280 let model_dir = Self::runtime_model_dir_for(data_dir, canonical_name).ok_or_else(|| {
281 Self::unavailable_error(
282 embedder_name,
283 format!("unknown embedder: {}", embedder_name),
284 )
285 })?;
286 let config = Self::config_for(canonical_name).ok_or_else(|| {
287 Self::unavailable_error(
288 embedder_name,
289 format!("no config for embedder: {}", embedder_name),
290 )
291 })?;
292 Self::load_with_config(&model_dir, config)
293 }
294
/// Returns the model name this embedder was loaded with.
pub fn model_id(&self) -> &str {
    &self.model_id
}
299
300 fn read_required(path: PathBuf, label: &str, model_id: &str) -> EmbedderResult<Vec<u8>> {
301 fs::read(&path).map_err(|e| {
302 Self::unavailable_error(
303 model_id,
304 format!("unable to read {label} at {}: {e}", path.display()),
305 )
306 })
307 }
308
309 fn unavailable_error(model: impl Into<String>, reason: impl Into<String>) -> EmbedderError {
310 EmbedderError::EmbedderUnavailable {
311 model: model.into(),
312 reason: reason.into(),
313 }
314 }
315
/// Rescales `embedding` in place to unit L2 length.
///
/// When the squared norm is not finite, or is not greater than
/// `f32::EPSILON`, the vector cannot be normalized meaningfully and every
/// component is set to `0.0` instead.
fn normalize_in_place(embedding: &mut [f32]) {
    let norm_sq: f32 = embedding.iter().map(|component| component * component).sum();

    let normalizable = norm_sq.is_finite() && norm_sq > f32::EPSILON;
    if !normalizable {
        embedding.fill(0.0);
        return;
    }

    let scale = norm_sq.sqrt().recip();
    embedding.iter_mut().for_each(|component| *component *= scale);
}
328}
329
330pub fn model_dir_override() -> Option<PathBuf> {
331 dotenvy::var("FRANKENSEARCH_MODEL_DIR")
332 .ok()
333 .map(|raw| raw.trim().to_string())
334 .filter(|raw| !raw.is_empty())
335 .map(|raw| expand_model_dir_override(&raw))
336}
337
338fn expand_model_dir_override(raw: &str) -> PathBuf {
339 if raw == "~" {
340 return dotenvy::var("HOME")
341 .map(PathBuf::from)
342 .unwrap_or_else(|_| PathBuf::from(raw));
343 }
344 if let Some(rest) = raw.strip_prefix("~/") {
345 return dotenvy::var("HOME")
346 .map(|home| PathBuf::from(home).join(rest))
347 .unwrap_or_else(|_| PathBuf::from(raw));
348 }
349 PathBuf::from(raw)
350}
351
352impl Embedder for FastEmbedder {
353 fn embed_sync(&self, text: &str) -> EmbedderResult<Vec<f32>> {
354 if text.is_empty() {
355 return Err(EmbedderError::InvalidConfig {
356 field: "input_text".to_string(),
357 value: "(empty)".to_string(),
358 reason: "empty text".to_string(),
359 });
360 }
361
362 #[allow(unused_mut)]
363 let mut model = self
364 .model
365 .lock()
366 .map_err(|_| EmbedderError::SubsystemError {
367 subsystem: "embedder",
368 source: Box::new(std::io::Error::other("fastembed lock poisoned")),
369 })?;
370
371 let embeddings =
372 model
373 .embed(vec![text], None)
374 .map_err(|e| EmbedderError::EmbeddingFailed {
375 model: self.id.clone(),
376 source: Box::new(std::io::Error::other(format!(
377 "fastembed embed failed: {e}"
378 ))),
379 })?;
380
381 let mut embedding =
382 embeddings
383 .into_iter()
384 .next()
385 .ok_or_else(|| EmbedderError::EmbeddingFailed {
386 model: self.id.clone(),
387 source: Box::new(std::io::Error::other("fastembed returned no embedding")),
388 })?;
389
390 if embedding.len() != self.dimension {
391 return Err(EmbedderError::EmbeddingFailed {
392 model: self.id.clone(),
393 source: Box::new(std::io::Error::other(format!(
394 "fastembed dimension mismatch: expected {}, got {}",
395 self.dimension,
396 embedding.len()
397 ))),
398 });
399 }
400
401 Self::normalize_in_place(&mut embedding);
402 Ok(embedding)
403 }
404
405 fn embed_batch_sync(&self, texts: &[&str]) -> EmbedderResult<Vec<Vec<f32>>> {
406 for text in texts {
407 if text.is_empty() {
408 return Err(EmbedderError::InvalidConfig {
409 field: "input_text".to_string(),
410 value: "(empty)".to_string(),
411 reason: "empty text in batch".to_string(),
412 });
413 }
414 }
415
416 if texts.is_empty() {
417 return Ok(Vec::new());
418 }
419
420 #[allow(unused_mut)]
421 let mut model = self
422 .model
423 .lock()
424 .map_err(|_| EmbedderError::SubsystemError {
425 subsystem: "embedder",
426 source: Box::new(std::io::Error::other("fastembed lock poisoned")),
427 })?;
428
429 let inputs = texts.to_vec();
430 let mut embeddings =
431 model
432 .embed(inputs, None)
433 .map_err(|e| EmbedderError::EmbeddingFailed {
434 model: self.id.clone(),
435 source: Box::new(std::io::Error::other(format!(
436 "fastembed embed failed: {e}"
437 ))),
438 })?;
439
440 for embedding in embeddings.iter_mut() {
441 if embedding.len() != self.dimension {
442 return Err(EmbedderError::EmbeddingFailed {
443 model: self.id.clone(),
444 source: Box::new(std::io::Error::other(format!(
445 "fastembed dimension mismatch: expected {}, got {}",
446 self.dimension,
447 embedding.len()
448 ))),
449 });
450 }
451 Self::normalize_in_place(embedding);
452 }
453
454 Ok(embeddings)
455 }
456
/// Returns the embedding length this embedder was configured with.
fn dimension(&self) -> usize {
    self.dimension
}
460
/// Returns the embedder identifier this embedder was configured with.
fn id(&self) -> &str {
    &self.id
}
464
/// Returns the model name this embedder was configured with.
fn model_name(&self) -> &str {
    &self.model_id
}
468
/// Always reports `true` for this embedder.
fn is_semantic(&self) -> bool {
    true
}
472
/// Always reports `ModelCategory::TransformerEmbedder`.
fn category(&self) -> ModelCategory {
    ModelCategory::TransformerEmbedder
}
476
/// Always reports `ModelTier::Quality`.
fn tier(&self) -> ModelTier {
    ModelTier::Quality
}
480}
481
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;

    /// Sets an environment variable for as long as the guard lives and puts
    /// the previous state back on drop, including when the test panics.
    struct EnvVarGuard {
        key: &'static str,
        previous: Option<String>,
    }

    impl EnvVarGuard {
        /// Remembers the current value of `key` (if any), then sets it to `value`.
        fn set(key: &'static str, value: &str) -> Self {
            let previous = dotenvy::var(key).ok();
            // SAFETY: callers are `#[serial]` tests, so no other serial test
            // mutates the environment concurrently. NOTE(review): assumes no
            // non-serial test reads or writes the environment in parallel —
            // confirm.
            unsafe {
                std::env::set_var(key, value);
            }
            Self { key, previous }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: same reasoning as in `EnvVarGuard::set`.
            unsafe {
                match self.previous.take() {
                    Some(value) => std::env::set_var(self.key, value),
                    None => std::env::remove_var(self.key),
                }
            }
        }
    }

    /// Loading from an empty directory reports the embedder as unavailable.
    #[test]
    fn fastembed_missing_files_returns_unavailable() {
        let tmp = tempfile::tempdir().expect("tempdir");
        let err = FastEmbedder::load_from_dir(tmp.path())
            .err()
            .expect("missing model should fail");
        assert!(
            matches!(err, EmbedderError::EmbedderUnavailable { .. }),
            "expected EmbedderUnavailable, got {err:?}"
        );
    }

    /// `unavailable_error` carries the model and reason through unchanged and
    /// has no underlying source error.
    #[test]
    fn unavailable_error_preserves_shape() {
        let err = FastEmbedder::unavailable_error("test-model", "missing files");
        assert!(std::error::Error::source(&err).is_none());
        match err {
            EmbedderError::EmbedderUnavailable { model, reason } => {
                assert_eq!(model, "test-model");
                assert_eq!(reason, "missing files");
            }
            other => panic!("expected EmbedderUnavailable, got {other:?}"),
        }
    }

    /// With both layouts present, the `onnx/` subdirectory file wins.
    #[test]
    fn select_model_file_prefers_modern_onnx_layout() {
        let tmp = tempfile::tempdir().expect("tempdir");
        std::fs::create_dir_all(tmp.path().join("onnx")).unwrap();
        std::fs::write(tmp.path().join("onnx/model.onnx"), b"modern").unwrap();
        std::fs::write(tmp.path().join("model.onnx"), b"legacy").unwrap();

        let selected = FastEmbedder::select_model_file(tmp.path()).unwrap();
        assert!(
            selected.ends_with("onnx/model.onnx"),
            "should prefer onnx/ subdir: {selected:?}"
        );
    }

    /// With only the top-level file present, the legacy path is selected.
    #[test]
    fn select_model_file_falls_back_to_legacy() {
        let tmp = tempfile::tempdir().expect("tempdir");
        std::fs::write(tmp.path().join("model.onnx"), b"legacy").unwrap();

        let selected = FastEmbedder::select_model_file(tmp.path()).unwrap();
        assert!(
            selected.ends_with("model.onnx"),
            "should fall back to legacy: {selected:?}"
        );
        // `ends_with("model.onnx")` alone also matches the modern layout, so
        // rule that path out explicitly.
        assert!(
            !selected.ends_with("onnx/model.onnx"),
            "legacy fallback must not point into onnx/: {selected:?}"
        );
    }

    /// An empty directory yields no model file.
    #[test]
    fn select_model_file_returns_none_for_empty_dir() {
        let tmp = tempfile::tempdir().expect("tempdir");
        assert!(FastEmbedder::select_model_file(tmp.path()).is_none());
    }

    /// Known names map to the expected dimensions; unknown names map to `None`.
    #[test]
    fn config_for_known_models() {
        let minilm = FastEmbedder::config_for("minilm").unwrap();
        assert_eq!(minilm.dimension, 384);

        let snowflake = FastEmbedder::config_for("snowflake-arctic-s").unwrap();
        assert_eq!(snowflake.dimension, 384);

        let nomic = FastEmbedder::config_for("nomic-embed").unwrap();
        assert_eq!(nomic.dimension, 768);

        assert!(FastEmbedder::config_for("unknown").is_none());
    }

    /// Aliases resolve to their canonical short names.
    #[test]
    fn canonical_name_accepts_policy_and_index_aliases() {
        assert_eq!(FastEmbedder::canonical_name("fastembed"), Some("minilm"));
        assert_eq!(
            FastEmbedder::canonical_name("snowflake-arctic-s-384"),
            Some("snowflake-arctic-s")
        );
        assert_eq!(
            FastEmbedder::canonical_name("nomic-embed-text-v1.5"),
            Some("nomic-embed")
        );
    }

    /// `FRANKENSEARCH_MODEL_DIR` takes precedence over the data directory and
    /// a leading `~/` is expanded using `HOME`.
    #[test]
    #[serial]
    fn runtime_model_dir_honors_frankensearch_override_and_expands_home() {
        // Guards drop in reverse order, restoring the override first and then
        // HOME, even if the assertion below panics.
        let _home = EnvVarGuard::set("HOME", "/tmp/cass-home-for-model-test");
        let _model_dir = EnvVarGuard::set("FRANKENSEARCH_MODEL_DIR", "~/models/snowflake");

        let resolved = FastEmbedder::runtime_model_dir_for(Path::new("/tmp/cass"), "snowflake")
            .expect("runtime model dir");
        assert_eq!(
            resolved,
            PathBuf::from("/tmp/cass-home-for-model-test/models/snowflake")
        );
    }
}