1use std::path::{Path, PathBuf};
28
29#[cfg(feature = "semantic")]
30use std::fs;
31#[cfg(feature = "semantic")]
32use std::sync::Mutex;
33
34#[cfg(feature = "semantic")]
35use fastembed::{
36 InitOptionsUserDefined, Pooling, TextEmbedding, TokenizerFiles, UserDefinedEmbeddingModel,
37};
38
39use super::embedder::{Embedder, EmbedderError, EmbedderResult};
40use frankensearch::{ModelCategory, ModelTier};
41
/// Stand-in for `fastembed::Pooling` used when the `semantic` feature is off.
///
/// Only the variant this module constructs (`Mean`) is provided, so that
/// [`OnnxEmbedderConfig`] keeps the same shape in every build.
#[cfg(not(feature = "semantic"))]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Pooling {
    /// Average the token embeddings.
    Mean,
}
53
/// Model identifier of the default MiniLM embedder.
const MINILM_MODEL_ID: &str = "all-minilm-l6-v2";
/// Directory name (under `<data_dir>/models`) holding the MiniLM files.
const MINILM_DIR_NAME: &str = "all-MiniLM-L6-v2";
/// Embedder id reported for the default MiniLM embedder.
const MINILM_EMBEDDER_ID: &str = "minilm-384";
/// Output vector length of the MiniLM model.
const MINILM_DIMENSION: usize = 384;

/// Preferred ONNX model location, relative to the model directory.
pub const MODEL_ONNX_SUBDIR: &str = "onnx/model.onnx";
/// Legacy ONNX model location, checked after [`MODEL_ONNX_SUBDIR`].
pub const MODEL_ONNX_LEGACY: &str = "model.onnx";

// Tokenizer/config companion files that must sit next to the ONNX model.
const TOKENIZER_JSON: &str = "tokenizer.json";
const CONFIG_JSON: &str = "config.json";
const SPECIAL_TOKENS_JSON: &str = "special_tokens_map.json";
const TOKENIZER_CONFIG_JSON: &str = "tokenizer_config.json";
67
68#[derive(Debug, Clone)]
70pub struct OnnxEmbedderConfig {
71 pub embedder_id: String,
73 pub model_id: String,
75 pub dimension: usize,
77 pub pooling: Pooling,
79}
80
81impl Default for OnnxEmbedderConfig {
82 fn default() -> Self {
83 Self {
84 embedder_id: MINILM_EMBEDDER_ID.to_string(),
85 model_id: MINILM_MODEL_ID.to_string(),
86 dimension: MINILM_DIMENSION,
87 pooling: Pooling::Mean,
88 }
89 }
90}
91
/// Embedder backed by a fastembed `TextEmbedding` loaded from local ONNX files.
///
/// Without the `semantic` feature the model handle is compiled out and every
/// loader returns an "unavailable" error.
pub struct FastEmbedder {
    /// The loaded model, guarded by a mutex before each `embed` call.
    #[cfg(feature = "semantic")]
    model: Mutex<TextEmbedding>,
    /// Embedder id (e.g. `"minilm-384"`).
    id: String,
    /// Model identifier (e.g. `"all-minilm-l6-v2"`).
    model_id: String,
    /// Expected embedding length; outputs of any other length are rejected.
    dimension: usize,
}
106
107impl FastEmbedder {
108 pub fn embedder_id_static() -> &'static str {
110 MINILM_EMBEDDER_ID
111 }
112
113 pub fn model_id_static() -> &'static str {
115 MINILM_MODEL_ID
116 }
117
118 pub fn required_model_files() -> &'static [&'static str] {
123 &[
124 TOKENIZER_JSON,
125 CONFIG_JSON,
126 SPECIAL_TOKENS_JSON,
127 TOKENIZER_CONFIG_JSON,
128 ]
129 }
130
131 pub fn model_file_candidates() -> &'static [&'static str] {
133 &[MODEL_ONNX_SUBDIR, MODEL_ONNX_LEGACY]
134 }
135
136 pub fn select_model_file(model_dir: &Path) -> Option<PathBuf> {
138 for candidate in Self::model_file_candidates() {
139 let path = model_dir.join(candidate);
140 if path.is_file() {
141 return Some(path);
142 }
143 }
144 None
145 }
146
147 pub fn default_model_dir(data_dir: &Path) -> PathBuf {
149 data_dir.join("models").join(MINILM_DIR_NAME)
150 }
151
152 pub fn model_dir_for(data_dir: &Path, embedder_name: &str) -> Option<PathBuf> {
154 let dir_name = match Self::canonical_name(embedder_name)? {
155 "minilm" => MINILM_DIR_NAME,
156 "snowflake-arctic-s" => "snowflake-arctic-embed-s",
157 "nomic-embed" => "nomic-embed-text-v1.5",
158 _ => return None,
159 };
160 Some(data_dir.join("models").join(dir_name))
161 }
162
163 pub fn runtime_model_dir_for(data_dir: &Path, embedder_name: &str) -> Option<PathBuf> {
169 model_dir_override().or_else(|| Self::model_dir_for(data_dir, embedder_name))
170 }
171
172 pub fn canonical_name(embedder_name: &str) -> Option<&'static str> {
173 match embedder_name.trim().to_ascii_lowercase().as_str() {
174 "fastembed" | "minilm" | "all-minilm-l6-v2" | "minilm-384" => Some("minilm"),
175 "snowflake"
176 | "snowflake-arctic-s"
177 | "snowflake-arctic-embed-s"
178 | "snowflake-arctic-s-384" => Some("snowflake-arctic-s"),
179 "nomic" | "nomic-embed" | "nomic-embed-text-v1.5" | "nomic-embed-768" => {
180 Some("nomic-embed")
181 }
182 _ => None,
183 }
184 }
185
186 pub fn config_for(embedder_name: &str) -> Option<OnnxEmbedderConfig> {
188 match Self::canonical_name(embedder_name)? {
189 "minilm" => Some(OnnxEmbedderConfig {
190 embedder_id: "minilm-384".to_string(),
191 model_id: "all-minilm-l6-v2".to_string(),
192 dimension: 384,
193 pooling: Pooling::Mean,
194 }),
195 "snowflake-arctic-s" => Some(OnnxEmbedderConfig {
196 embedder_id: "snowflake-arctic-s-384".to_string(),
197 model_id: "snowflake-arctic-embed-s".to_string(),
198 dimension: 384,
199 pooling: Pooling::Mean,
200 }),
201 "nomic-embed" => Some(OnnxEmbedderConfig {
202 embedder_id: "nomic-embed-768".to_string(),
203 model_id: "nomic-embed-text-v1.5".to_string(),
204 dimension: 768,
205 pooling: Pooling::Mean,
206 }),
207 _ => None,
208 }
209 }
210
211 #[cfg(feature = "semantic")]
217 pub fn load_from_dir(model_dir: &Path) -> EmbedderResult<Self> {
218 Self::load_with_config(model_dir, OnnxEmbedderConfig::default())
219 }
220
221 #[cfg(not(feature = "semantic"))]
222 pub fn load_from_dir(_model_dir: &Path) -> EmbedderResult<Self> {
223 Err(Self::unavailable_error(
224 MINILM_EMBEDDER_ID,
225 "semantic search is not available in this build (cass was built without the `semantic` feature; rebuild with `--features semantic` or use the full release artifact)",
226 ))
227 }
228
229 #[cfg(feature = "semantic")]
231 pub fn load_with_config(model_dir: &Path, config: OnnxEmbedderConfig) -> EmbedderResult<Self> {
232 if !model_dir.is_dir() {
233 return Err(Self::unavailable_error(
234 &config.embedder_id,
235 format!("model directory not found: {}", model_dir.display()),
236 ));
237 }
238
239 let onnx_path = Self::select_model_file(model_dir).ok_or_else(|| {
240 Self::unavailable_error(
241 &config.embedder_id,
242 format!(
243 "no ONNX model file in {} (checked {} and {})",
244 model_dir.display(),
245 MODEL_ONNX_SUBDIR,
246 MODEL_ONNX_LEGACY
247 ),
248 )
249 })?;
250
251 let required = Self::required_model_files();
252 let mut missing = Vec::new();
253 for name in required {
254 let path = model_dir.join(name);
255 if !path.is_file() {
256 missing.push(*name);
257 }
258 }
259 if !missing.is_empty() {
260 return Err(Self::unavailable_error(
261 &config.embedder_id,
262 format!(
263 "model files missing in {}: {}",
264 model_dir.display(),
265 missing.join(", ")
266 ),
267 ));
268 }
269
270 let model_file = Self::read_required(onnx_path, "model.onnx", &config.embedder_id)?;
271 let tokenizer_file = Self::read_required(
272 model_dir.join(TOKENIZER_JSON),
273 TOKENIZER_JSON,
274 &config.embedder_id,
275 )?;
276 let config_file = Self::read_required(
277 model_dir.join(CONFIG_JSON),
278 CONFIG_JSON,
279 &config.embedder_id,
280 )?;
281 let special_tokens_map_file = Self::read_required(
282 model_dir.join(SPECIAL_TOKENS_JSON),
283 SPECIAL_TOKENS_JSON,
284 &config.embedder_id,
285 )?;
286 let tokenizer_config_file = Self::read_required(
287 model_dir.join(TOKENIZER_CONFIG_JSON),
288 TOKENIZER_CONFIG_JSON,
289 &config.embedder_id,
290 )?;
291
292 let tokenizer_files = TokenizerFiles {
293 tokenizer_file,
294 config_file,
295 special_tokens_map_file,
296 tokenizer_config_file,
297 };
298
299 let mut model = UserDefinedEmbeddingModel::new(model_file, tokenizer_files);
300 model.pooling = Some(config.pooling);
301
302 let init_options = InitOptionsUserDefined::new();
303
304 let model = TextEmbedding::try_new_from_user_defined(model, init_options).map_err(|e| {
305 EmbedderError::EmbeddingFailed {
306 model: config.embedder_id.clone(),
307 source: Box::new(std::io::Error::other(format!("fastembed init failed: {e}"))),
308 }
309 })?;
310
311 Ok(Self {
312 model: Mutex::new(model),
313 id: config.embedder_id,
314 model_id: config.model_id,
315 dimension: config.dimension,
316 })
317 }
318
319 #[cfg(not(feature = "semantic"))]
321 pub fn load_with_config(_model_dir: &Path, config: OnnxEmbedderConfig) -> EmbedderResult<Self> {
322 Err(Self::unavailable_error(
323 &config.embedder_id,
324 "semantic search is not available in this build (cass was built without the `semantic` feature; rebuild with `--features semantic` or use the full release artifact)",
325 ))
326 }
327
328 #[cfg(feature = "semantic")]
330 pub fn load_by_name(data_dir: &Path, embedder_name: &str) -> EmbedderResult<Self> {
331 let canonical_name = Self::canonical_name(embedder_name).ok_or_else(|| {
332 Self::unavailable_error(
333 embedder_name,
334 format!("unknown embedder: {}", embedder_name),
335 )
336 })?;
337 let model_dir = Self::runtime_model_dir_for(data_dir, canonical_name).ok_or_else(|| {
338 Self::unavailable_error(
339 embedder_name,
340 format!("unknown embedder: {}", embedder_name),
341 )
342 })?;
343 let config = Self::config_for(canonical_name).ok_or_else(|| {
344 Self::unavailable_error(
345 embedder_name,
346 format!("no config for embedder: {}", embedder_name),
347 )
348 })?;
349 Self::load_with_config(&model_dir, config)
350 }
351
352 #[cfg(not(feature = "semantic"))]
353 pub fn load_by_name(_data_dir: &Path, embedder_name: &str) -> EmbedderResult<Self> {
354 Err(Self::unavailable_error(
355 embedder_name,
356 "semantic search is not available in this build (cass was built without the `semantic` feature; rebuild with `--features semantic` or use the full release artifact)",
357 ))
358 }
359
360 pub fn model_id(&self) -> &str {
362 &self.model_id
363 }
364
365 #[cfg(feature = "semantic")]
366 fn read_required(path: PathBuf, label: &str, model_id: &str) -> EmbedderResult<Vec<u8>> {
367 fs::read(&path).map_err(|e| {
368 Self::unavailable_error(
369 model_id,
370 format!("unable to read {label} at {}: {e}", path.display()),
371 )
372 })
373 }
374
375 fn unavailable_error(model: impl Into<String>, reason: impl Into<String>) -> EmbedderError {
376 EmbedderError::EmbedderUnavailable {
377 model: model.into(),
378 reason: reason.into(),
379 }
380 }
381
382 #[cfg(feature = "semantic")]
383 fn normalize_in_place(embedding: &mut [f32]) {
384 let norm_sq: f32 = embedding.iter().map(|x| x * x).sum();
385 if norm_sq.is_finite() && norm_sq > f32::EPSILON {
386 let inv_norm = 1.0 / norm_sq.sqrt();
387 for v in embedding.iter_mut() {
388 *v *= inv_norm;
389 }
390 } else {
391 embedding.fill(0.0);
393 }
394 }
395}
396
397pub fn model_dir_override() -> Option<PathBuf> {
398 dotenvy::var("FRANKENSEARCH_MODEL_DIR")
399 .ok()
400 .map(|raw| raw.trim().to_string())
401 .filter(|raw| !raw.is_empty())
402 .map(|raw| expand_model_dir_override(&raw))
403}
404
405fn expand_model_dir_override(raw: &str) -> PathBuf {
406 if raw == "~" {
407 return dotenvy::var("HOME")
408 .map(PathBuf::from)
409 .unwrap_or_else(|_| PathBuf::from(raw));
410 }
411 if let Some(rest) = raw.strip_prefix("~/") {
412 return dotenvy::var("HOME")
413 .map(|home| PathBuf::from(home).join(rest))
414 .unwrap_or_else(|_| PathBuf::from(raw));
415 }
416 PathBuf::from(raw)
417}
418
/// Shared plumbing for the semantic `Embedder` implementation below.
#[cfg(feature = "semantic")]
impl FastEmbedder {
    /// Locks the fastembed model. A poisoned mutex is reported as a
    /// `SubsystemError`.
    fn lock_model(&self) -> EmbedderResult<std::sync::MutexGuard<'_, TextEmbedding>> {
        self.model
            .lock()
            .map_err(|_| EmbedderError::SubsystemError {
                subsystem: "embedder",
                source: Box::new(std::io::Error::other("fastembed lock poisoned")),
            })
    }

    /// Builds an `EmbeddingFailed` error attributed to this embedder.
    fn embedding_failed(&self, message: impl Into<String>) -> EmbedderError {
        EmbedderError::EmbeddingFailed {
            model: self.id.clone(),
            source: Box::new(std::io::Error::other(message.into())),
        }
    }

    /// Rejects vectors whose length differs from the configured dimension,
    /// then L2-normalizes the vector in place.
    fn finish_embedding(&self, embedding: &mut [f32]) -> EmbedderResult<()> {
        if embedding.len() != self.dimension {
            return Err(self.embedding_failed(format!(
                "fastembed dimension mismatch: expected {}, got {}",
                self.dimension,
                embedding.len()
            )));
        }
        Self::normalize_in_place(embedding);
        Ok(())
    }
}

#[cfg(feature = "semantic")]
impl Embedder for FastEmbedder {
    /// Embeds a single non-empty text into an L2-normalized vector.
    fn embed_sync(&self, text: &str) -> EmbedderResult<Vec<f32>> {
        if text.is_empty() {
            return Err(EmbedderError::InvalidConfig {
                field: "input_text".to_string(),
                value: "(empty)".to_string(),
                reason: "empty text".to_string(),
            });
        }

        let embeddings = {
            // NOTE(review): `mut` is presumably needed because `TextEmbedding::embed`
            // takes `&mut self` in some fastembed releases and `&self` in others.
            #[allow(unused_mut)]
            let mut model = self.lock_model()?;
            model
                .embed(vec![text], None)
                .map_err(|e| self.embedding_failed(format!("fastembed embed failed: {e}")))?
        }; // lock released here; normalization does not need the model

        let mut embedding = embeddings
            .into_iter()
            .next()
            .ok_or_else(|| self.embedding_failed("fastembed returned no embedding"))?;

        self.finish_embedding(&mut embedding)?;
        Ok(embedding)
    }

    /// Embeds every text in `texts` (none may be empty).
    ///
    /// Returns one L2-normalized vector per input, in input order. An empty
    /// slice yields an empty result.
    fn embed_batch_sync(&self, texts: &[&str]) -> EmbedderResult<Vec<Vec<f32>>> {
        if texts.iter().any(|text| text.is_empty()) {
            return Err(EmbedderError::InvalidConfig {
                field: "input_text".to_string(),
                value: "(empty)".to_string(),
                reason: "empty text in batch".to_string(),
            });
        }
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let mut embeddings = {
            // NOTE(review): see `embed_sync` for why this binding is `mut`.
            #[allow(unused_mut)]
            let mut model = self.lock_model()?;
            model
                .embed(texts.to_vec(), None)
                .map_err(|e| self.embedding_failed(format!("fastembed embed failed: {e}")))?
        };

        // Callers pair outputs with inputs by position, so a count mismatch must
        // be an error rather than a silently misaligned result.
        if embeddings.len() != texts.len() {
            return Err(self.embedding_failed(format!(
                "fastembed returned {} embeddings for {} inputs",
                embeddings.len(),
                texts.len()
            )));
        }

        for embedding in &mut embeddings {
            self.finish_embedding(embedding)?;
        }
        Ok(embeddings)
    }

    fn dimension(&self) -> usize {
        self.dimension
    }

    fn id(&self) -> &str {
        &self.id
    }

    fn model_name(&self) -> &str {
        &self.model_id
    }

    fn is_semantic(&self) -> bool {
        true
    }

    fn category(&self) -> ModelCategory {
        ModelCategory::TransformerEmbedder
    }

    fn tier(&self) -> ModelTier {
        ModelTier::Quality
    }
}
549
550#[cfg(not(feature = "semantic"))]
556impl Embedder for FastEmbedder {
557 fn embed_sync(&self, _text: &str) -> EmbedderResult<Vec<f32>> {
558 Err(Self::unavailable_error(
559 &self.id,
560 "semantic search is not available in this build (cass was built without the `semantic` feature; rebuild with `--features semantic` or use the full release artifact)",
561 ))
562 }
563
564 fn embed_batch_sync(&self, _texts: &[&str]) -> EmbedderResult<Vec<Vec<f32>>> {
565 Err(Self::unavailable_error(
566 &self.id,
567 "semantic search is not available in this build (cass was built without the `semantic` feature; rebuild with `--features semantic` or use the full release artifact)",
568 ))
569 }
570
571 fn dimension(&self) -> usize {
572 self.dimension
573 }
574
575 fn id(&self) -> &str {
576 &self.id
577 }
578
579 fn model_name(&self) -> &str {
580 &self.model_id
581 }
582
583 fn is_semantic(&self) -> bool {
584 true
585 }
586
587 fn category(&self) -> ModelCategory {
588 ModelCategory::TransformerEmbedder
589 }
590
591 fn tier(&self) -> ModelTier {
592 ModelTier::Quality
593 }
594}
595
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;

    /// Sets environment variables for a test and restores their previous
    /// values on drop. Restoration therefore also happens when an assertion
    /// panics, and a failing test cannot leak state into later tests.
    struct EnvVarGuard {
        saved: Vec<(&'static str, Option<String>)>,
    }

    impl EnvVarGuard {
        fn set(vars: &[(&'static str, &str)]) -> Self {
            // Snapshot through `dotenvy`, the same way the code under test
            // reads the environment.
            let saved = vars
                .iter()
                .map(|&(key, _)| (key, dotenvy::var(key).ok()))
                .collect();
            for &(key, value) in vars {
                // SAFETY: only used from `#[serial]` tests. This assumes no
                // concurrently running test touches the process environment.
                // NOTE(review): `#[serial]` does not exclude non-serial tests.
                unsafe { std::env::set_var(key, value) };
            }
            Self { saved }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            for (key, previous) in self.saved.drain(..).rev() {
                // SAFETY: same single-threaded-environment assumption as in
                // `EnvVarGuard::set`.
                unsafe {
                    match previous {
                        Some(value) => std::env::set_var(key, value),
                        None => std::env::remove_var(key),
                    }
                }
            }
        }
    }

    #[test]
    fn fastembed_missing_files_returns_unavailable() {
        let tmp = tempfile::tempdir().expect("tempdir");
        let err = FastEmbedder::load_from_dir(tmp.path())
            .err()
            .expect("missing model should fail");
        assert!(
            matches!(err, EmbedderError::EmbedderUnavailable { .. }),
            "expected EmbedderUnavailable, got {err:?}"
        );
    }

    #[test]
    fn unavailable_error_preserves_shape() {
        let err = FastEmbedder::unavailable_error("test-model", "missing files");
        assert!(std::error::Error::source(&err).is_none());
        match err {
            EmbedderError::EmbedderUnavailable { model, reason } => {
                assert_eq!(model, "test-model");
                assert_eq!(reason, "missing files");
            }
            other => panic!("expected EmbedderUnavailable, got {other:?}"),
        }
    }

    #[test]
    fn select_model_file_prefers_modern_onnx_layout() {
        let tmp = tempfile::tempdir().expect("tempdir");
        std::fs::create_dir_all(tmp.path().join("onnx")).unwrap();
        std::fs::write(tmp.path().join("onnx/model.onnx"), b"modern").unwrap();
        std::fs::write(tmp.path().join("model.onnx"), b"legacy").unwrap();

        let selected = FastEmbedder::select_model_file(tmp.path()).unwrap();
        assert!(
            selected.ends_with("onnx/model.onnx"),
            "should prefer onnx/ subdir: {selected:?}"
        );
    }

    #[test]
    fn select_model_file_falls_back_to_legacy() {
        let tmp = tempfile::tempdir().expect("tempdir");
        std::fs::write(tmp.path().join("model.onnx"), b"legacy").unwrap();

        let selected = FastEmbedder::select_model_file(tmp.path()).unwrap();
        assert!(
            selected.ends_with("model.onnx"),
            "should fall back to legacy: {selected:?}"
        );
    }

    #[test]
    fn select_model_file_returns_none_for_empty_dir() {
        let tmp = tempfile::tempdir().expect("tempdir");
        assert!(FastEmbedder::select_model_file(tmp.path()).is_none());
    }

    #[test]
    fn config_for_known_models() {
        let minilm = FastEmbedder::config_for("minilm").unwrap();
        assert_eq!(minilm.dimension, 384);

        let snowflake = FastEmbedder::config_for("snowflake-arctic-s").unwrap();
        assert_eq!(snowflake.dimension, 384);

        let nomic = FastEmbedder::config_for("nomic-embed").unwrap();
        assert_eq!(nomic.dimension, 768);

        assert!(FastEmbedder::config_for("unknown").is_none());
    }

    #[test]
    fn canonical_name_accepts_policy_and_index_aliases() {
        assert_eq!(FastEmbedder::canonical_name("fastembed"), Some("minilm"));
        assert_eq!(
            FastEmbedder::canonical_name("snowflake-arctic-s-384"),
            Some("snowflake-arctic-s")
        );
        assert_eq!(
            FastEmbedder::canonical_name("nomic-embed-text-v1.5"),
            Some("nomic-embed")
        );
    }

    #[test]
    #[serial]
    fn runtime_model_dir_honors_frankensearch_override_and_expands_home() {
        // Restored when `_env` drops, even if the assertion below fails.
        let _env = EnvVarGuard::set(&[
            ("HOME", "/tmp/cass-home-for-model-test"),
            ("FRANKENSEARCH_MODEL_DIR", "~/models/snowflake"),
        ]);

        let resolved = FastEmbedder::runtime_model_dir_for(Path::new("/tmp/cass"), "snowflake")
            .expect("runtime model dir");
        assert_eq!(
            resolved,
            PathBuf::from("/tmp/cass-home-for-model-test/models/snowflake")
        );
    }
}