1use std::sync::Arc;
2
3use serde::{Deserialize, Serialize};
4
5use crate::engine::EmbeddingEngine;
6use crate::error::EmbeddingResult;
7use crate::mock::{MockEmbeddingEngine, MockVectorMode};
8use crate::ollama::OllamaEmbeddingEngine;
9use crate::openai_compatible::OpenAICompatibleEmbeddingEngine;
10use crate::provider::EmbeddingProvider;
11
12#[cfg(feature = "onnx")]
13use crate::onnx::OnnxEmbeddingEngine;
14#[cfg(feature = "onnx")]
15use std::path::PathBuf;
16
/// Vector size assumed when a model's dimensionality cannot be looked up in
/// [`known_model_dimensions`] and no explicit override is supplied.
const FALLBACK_DIMENSIONS: usize = 384;
28
29pub fn known_model_dimensions(provider: EmbeddingProvider, model: &str) -> Option<usize> {
47 let bare = model.rsplit('/').next().unwrap_or(model);
50 let key = bare.to_ascii_lowercase();
51 let dim = match key.as_str() {
52 "text-embedding-3-large" => 3072,
54 "text-embedding-3-small" => 1536,
55 "text-embedding-ada-002" => 1536,
56 "bge-small-v1.5" | "bge-small-en-v1.5" => 384,
58 "bge-base-en-v1.5" => 768,
59 "bge-large-en-v1.5" => 1024,
60 "all-minilm-l6-v2" => 384,
62 "nomic-embed-text" => 768,
64 "mxbai-embed-large" => 1024,
65 _ => return None,
66 };
67 let _ = provider; Some(dim)
69}
70
/// Settings for the local ONNX embedding engine.
///
/// Only compiled when the `onnx` crate feature is enabled. Ready-made presets
/// are available through [`OnnxEmbeddingConfig::bge_small`] and
/// [`OnnxEmbeddingConfig::minilm_l6`].
#[cfg(feature = "onnx")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OnnxEmbeddingConfig {
    /// Location of the `.onnx` model file.
    pub model_path: PathBuf,

    /// Location of the tokenizer JSON file that accompanies the model.
    pub tokenizer_path: PathBuf,

    /// Model name; the Android default `EmbeddingConfig` copies this into its
    /// `model` field.
    pub model_name: String,

    /// Size of the vectors the model produces. `EmbeddingConfig::from_env`
    /// uses this value for the ONNX/Fastembed providers when
    /// `EMBEDDING_DIMENSIONS` is not set.
    pub dimensions: usize,

    /// Maximum input sequence length (presumably measured in tokens — confirm
    /// against the ONNX engine).
    pub max_sequence_length: usize,

    /// Batch size for the ONNX engine; can be overridden with the
    /// `EMBEDDING_ONNX_BATCH_SIZE` environment variable.
    pub batch_size: usize,
}
96
#[cfg(feature = "onnx")]
impl Default for OnnxEmbeddingConfig {
    /// Uses the BGE-small preset with model files expected under
    /// `./target/models`.
    fn default() -> Self {
        let model_dir = PathBuf::from("./target/models");
        OnnxEmbeddingConfig::bge_small(model_dir)
    }
}
103
#[cfg(feature = "onnx")]
impl OnnxEmbeddingConfig {
    /// Shared constructor for the bundled presets, which all produce
    /// 384-dimensional vectors and use a batch size of 32.
    fn preset(
        dir: PathBuf,
        model_file: &str,
        tokenizer_file: &str,
        model_name: &str,
        max_sequence_length: usize,
    ) -> Self {
        Self {
            model_path: dir.join(model_file),
            tokenizer_path: dir.join(tokenizer_file),
            model_name: model_name.to_owned(),
            dimensions: 384,
            max_sequence_length,
            batch_size: 32,
        }
    }

    /// Preset for the quantized BGE-small v1.5 model whose files live in
    /// `model_dir` (384 dimensions, sequence length 512).
    pub fn bge_small(model_dir: impl Into<PathBuf>) -> Self {
        Self::preset(
            model_dir.into(),
            "BGE-Small-v1.5-model_quantized.onnx",
            "bge-small-tokenizer.json",
            "bge-small-en-v1.5",
            512,
        )
    }

    /// Preset for the all-MiniLM-L6-v2 model whose files live in `model_dir`
    /// (384 dimensions, sequence length 256).
    pub fn minilm_l6(model_dir: impl Into<PathBuf>) -> Self {
        Self::preset(
            model_dir.into(),
            "all-MiniLM-L6-v2.onnx",
            "minilm-l6-tokenizer.json",
            "all-MiniLM-L6-v2",
            256,
        )
    }
}
136
/// Complete configuration for creating an embedding engine.
///
/// Build one with [`Default`], override individual fields, or call
/// [`EmbeddingConfig::from_env`] to apply environment-variable overrides.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    /// Backend that produces the embeddings. Ignored in favour of the mock
    /// engine while [`mock`](Self::mock) is `true`.
    pub provider: EmbeddingProvider,

    /// Model identifier. It is also used to look up the vector size in
    /// [`known_model_dimensions`].
    pub model: String,

    /// Size of the vectors the engine is expected to produce.
    pub dimensions: usize,

    /// Base URL of the embedding service, if the provider needs one.
    pub endpoint: Option<String>,

    /// API key for the embedding service. `from_env` reads it from
    /// `EMBEDDING_API_KEY`, falling back to `LLM_API_KEY`.
    pub api_key: Option<String>,

    /// Optional API version string (`EMBEDDING_API_VERSION`).
    pub api_version: Option<String>,

    /// Token limit setting; defaults to 8191 (presumably the per-input token
    /// limit — confirm against the engines that read it).
    pub max_completion_tokens: usize,

    /// Batch size setting; defaults to 36.
    pub batch_size: usize,

    /// When `true`, [`EmbeddingConfig::effective_provider`] reports the mock
    /// provider regardless of [`provider`](Self::provider).
    pub mock: bool,

    /// How the mock engine generates vectors. Falls back to the type's
    /// default when missing from serialized input.
    #[serde(default)]
    pub mock_mode: MockVectorMode,

    /// Settings for the local ONNX engine (only with the `onnx` feature).
    #[cfg(feature = "onnx")]
    pub onnx: OnnxEmbeddingConfig,

    /// Value of `HUGGINGFACE_TOKENIZER` (presumably a Hugging Face tokenizer
    /// identifier — confirm against the consumer of this field).
    pub huggingface_tokenizer: Option<String>,
}
206
207impl Default for EmbeddingConfig {
208 fn default() -> Self {
209 #[cfg(all(feature = "onnx", target_os = "android"))]
212 let (provider, model, dimensions, endpoint) = {
213 let onnx_cfg = OnnxEmbeddingConfig::default();
214 (
215 EmbeddingProvider::Onnx,
216 onnx_cfg.model_name.clone(),
217 onnx_cfg.dimensions,
218 None,
219 )
220 };
221 #[cfg(all(feature = "onnx", not(target_os = "android")))]
222 let (provider, model, dimensions, endpoint) = {
223 let m = "text-embedding-3-small".to_string();
224 let d = known_model_dimensions(EmbeddingProvider::OpenAi, &m)
226 .unwrap_or(FALLBACK_DIMENSIONS);
227 (
228 EmbeddingProvider::OpenAi,
229 m,
230 d,
231 Some("https://api.openai.com/v1".to_string()),
232 )
233 };
234 #[cfg(not(feature = "onnx"))]
235 let (provider, model, dimensions, endpoint) = {
236 let m = "text-embedding-3-small".to_string();
237 let d = known_model_dimensions(EmbeddingProvider::OpenAi, &m)
239 .unwrap_or(FALLBACK_DIMENSIONS);
240 (
241 EmbeddingProvider::OpenAi,
242 m,
243 d,
244 Some("https://api.openai.com/v1".to_string()),
245 )
246 };
247
248 Self {
249 provider,
250 model,
251 dimensions,
252 endpoint,
253 api_key: None,
254 api_version: None,
255 max_completion_tokens: 8191,
256 batch_size: 36,
257 mock: false,
258 mock_mode: MockVectorMode::Zero,
259 #[cfg(feature = "onnx")]
260 onnx: OnnxEmbeddingConfig::default(),
261 huggingface_tokenizer: None,
262 }
263 }
264}
265
266impl EmbeddingConfig {
267 pub fn from_env() -> Self {
272 let mut config = Self::default();
273
274 if let Ok(val) = std::env::var("MOCK_EMBEDDING") {
278 let val = val.trim().to_lowercase();
279 if val == "deterministic" || val == "hash" {
280 config.mock = true;
281 config.provider = EmbeddingProvider::Mock;
282 config.mock_mode = MockVectorMode::Deterministic;
283 return config;
284 }
285 if val == "true" || val == "1" || val == "yes" {
286 config.mock = true;
287 config.provider = EmbeddingProvider::Mock;
288 config.mock_mode = MockVectorMode::Zero;
289 return config;
290 }
291 }
292
293 if let Ok(val) = std::env::var("EMBEDDING_PROVIDER") {
295 let val = val.trim().to_lowercase();
296 match val.as_str() {
297 "onnx" => config.provider = EmbeddingProvider::Onnx,
298 "fastembed" => config.provider = EmbeddingProvider::Fastembed,
299 "openai" => config.provider = EmbeddingProvider::OpenAi,
300 "openai_compatible" => config.provider = EmbeddingProvider::OpenAiCompatible,
301 "ollama" => config.provider = EmbeddingProvider::Ollama,
302 "mock" => {
303 config.mock = true;
304 config.provider = EmbeddingProvider::Mock;
305 }
306 _ => {
307 }
311 }
312 }
313
314 if config.provider == EmbeddingProvider::Ollama {
320 config.model = "avr/sfr-embedding-mistral:latest".to_string();
321 }
322
323 if let Ok(val) = std::env::var("EMBEDDING_MODEL") {
325 let val = val.trim().to_string();
326 if !val.is_empty() {
327 config.model = val;
328 }
329 }
330
331 let explicit_dims = std::env::var("EMBEDDING_DIMENSIONS")
339 .ok()
340 .and_then(|v| v.trim().parse::<usize>().ok());
341
342 let resolve_from_table = |config: &EmbeddingConfig| match known_model_dimensions(
347 config.provider.clone(),
348 &config.model,
349 ) {
350 Some(d) => d,
352 None => {
354 tracing::warn!(
355 provider = ?config.provider,
356 model = %config.model,
357 fallback = FALLBACK_DIMENSIONS,
358 "Could not auto-derive embedding dimensions; set \
359 EMBEDDING_DIMENSIONS explicitly if your embedder produces \
360 a different vector size, otherwise the first vector write \
361 will fail with a shape mismatch."
362 );
363 FALLBACK_DIMENSIONS
364 }
365 };
366
367 config.dimensions = match explicit_dims {
368 Some(d) => d,
370 None => {
371 #[cfg(feature = "onnx")]
375 {
376 if matches!(
377 config.provider,
378 EmbeddingProvider::Onnx | EmbeddingProvider::Fastembed
379 ) {
380 config.onnx.dimensions
381 } else {
382 resolve_from_table(&config)
383 }
384 }
385 #[cfg(not(feature = "onnx"))]
386 {
387 resolve_from_table(&config)
388 }
389 }
390 };
391
392 if let Ok(val) = std::env::var("EMBEDDING_ENDPOINT") {
394 let val = val.trim().to_string();
395 if !val.is_empty() {
396 config.endpoint = Some(val);
397 }
398 }
399
400 if let Ok(val) = std::env::var("EMBEDDING_API_KEY") {
402 let val = val.trim().to_string();
403 if !val.is_empty() {
404 config.api_key = Some(val);
405 }
406 } else if let Ok(val) = std::env::var("LLM_API_KEY") {
407 let val = val.trim().to_string();
408 if !val.is_empty() {
409 config.api_key = Some(val);
410 }
411 }
412
413 if let Ok(val) = std::env::var("EMBEDDING_API_VERSION") {
415 let val = val.trim().to_string();
416 if !val.is_empty() {
417 config.api_version = Some(val);
418 }
419 }
420
421 if let Ok(val) = std::env::var("EMBEDDING_MAX_COMPLETION_TOKENS")
423 && let Ok(n) = val.trim().parse::<usize>()
424 {
425 config.max_completion_tokens = n;
426 }
427
428 if let Ok(val) = std::env::var("EMBEDDING_BATCH_SIZE")
430 && let Ok(n) = val.trim().parse::<usize>()
431 {
432 config.batch_size = n;
433 }
434
435 #[cfg(feature = "onnx")]
436 if let Ok(val) = std::env::var("EMBEDDING_ONNX_BATCH_SIZE")
437 && let Ok(n) = val.trim().parse::<usize>()
438 && n > 0
439 {
440 config.onnx.batch_size = n;
441 }
442
443 if let Ok(val) = std::env::var("HUGGINGFACE_TOKENIZER") {
445 let val = val.trim().to_string();
446 if !val.is_empty() {
447 config.huggingface_tokenizer = Some(val);
448 }
449 }
450
451 config
452 }
453
454 pub fn effective_provider(&self) -> EmbeddingProvider {
456 if self.mock {
457 EmbeddingProvider::Mock
458 } else {
459 self.provider.clone()
460 }
461 }
462
463 pub async fn create_engine(&self) -> EmbeddingResult<Arc<dyn EmbeddingEngine>> {
469 match self.effective_provider() {
470 #[cfg(feature = "onnx")]
471 EmbeddingProvider::Onnx | EmbeddingProvider::Fastembed => {
472 let engine = OnnxEmbeddingEngine::with_auto_download(self.onnx.clone()).await?;
473 Ok(Arc::new(engine))
474 }
475 #[cfg(not(feature = "onnx"))]
476 EmbeddingProvider::Onnx | EmbeddingProvider::Fastembed => {
477 Err(crate::error::EmbeddingError::NotImplemented(
478 "ONNX embedding engine requires the `onnx` crate feature".to_string(),
479 ))
480 }
481 EmbeddingProvider::OpenAi | EmbeddingProvider::OpenAiCompatible => {
482 let engine = OpenAICompatibleEmbeddingEngine::new(self)?;
483 Ok(Arc::new(engine))
484 }
485 EmbeddingProvider::Ollama => {
486 let engine = OllamaEmbeddingEngine::new(self)?;
487 Ok(Arc::new(engine))
488 }
489 EmbeddingProvider::Mock => Ok(Arc::new(
490 MockEmbeddingEngine::new(self.dimensions).with_mode(self.mock_mode),
491 )),
492 }
493 }
494}
495
// Unit tests for the embedding configuration.
//
// Every test that mutates process environment variables is marked `#[serial]`
// so those tests never overlap; the remaining tests in this module do not read
// the environment.
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;

    #[test]
    #[cfg(all(feature = "onnx", target_os = "android"))]
    fn test_default_is_onnx_on_android() {
        let config = EmbeddingConfig::default();
        assert_eq!(config.provider, EmbeddingProvider::Onnx);
        assert_eq!(config.dimensions, 384);
        assert_eq!(config.batch_size, 36);
        assert_eq!(config.max_completion_tokens, 8191);
        assert!(!config.mock);
    }

    #[test]
    #[cfg(not(target_os = "android"))]
    fn test_default_is_openai_off_android() {
        let config = EmbeddingConfig::default();
        assert_eq!(config.provider, EmbeddingProvider::OpenAi);
        assert_eq!(config.model, "text-embedding-3-small");
        assert_eq!(config.dimensions, 1536);
        assert_eq!(
            config.endpoint.as_deref(),
            Some("https://api.openai.com/v1")
        );
        assert!(!config.mock);
    }

    #[test]
    fn test_effective_provider_mock_override() {
        let config = EmbeddingConfig {
            mock: true,
            ..Default::default()
        };
        assert_eq!(config.effective_provider(), EmbeddingProvider::Mock);
    }

    #[test]
    #[cfg(all(feature = "onnx", target_os = "android"))]
    fn test_effective_provider_passthrough_onnx() {
        let config = EmbeddingConfig::default();
        assert_eq!(config.effective_provider(), EmbeddingProvider::Onnx);
    }

    #[test]
    #[cfg(not(target_os = "android"))]
    fn test_effective_provider_passthrough_openai() {
        let config = EmbeddingConfig::default();
        assert_eq!(config.effective_provider(), EmbeddingProvider::OpenAi);
    }

    #[test]
    #[serial]
    fn test_from_env_mock_embedding_true() {
        // SAFETY: `#[serial]` keeps env-mutating tests from running concurrently.
        unsafe { std::env::set_var("MOCK_EMBEDDING", "true") };
        let config = EmbeddingConfig::from_env();
        // SAFETY: see above.
        unsafe { std::env::remove_var("MOCK_EMBEDDING") };
        assert!(config.mock);
        assert_eq!(config.effective_provider(), EmbeddingProvider::Mock);
    }

    #[test]
    #[serial]
    fn test_from_env_mock_embedding_numeric() {
        // SAFETY: `#[serial]` keeps env-mutating tests from running concurrently.
        unsafe { std::env::set_var("MOCK_EMBEDDING", "1") };
        let config = EmbeddingConfig::from_env();
        // SAFETY: see above.
        unsafe { std::env::remove_var("MOCK_EMBEDDING") };
        assert!(config.mock);
        assert_eq!(config.mock_mode, MockVectorMode::Zero);
    }

    // Serialized like every other env-mutating test so it runs in the normal
    // suite instead of only under `--ignored`.
    #[test]
    #[serial]
    fn test_from_env_mock_embedding_deterministic() {
        // SAFETY: `#[serial]` keeps env-mutating tests from running concurrently.
        unsafe { std::env::set_var("MOCK_EMBEDDING", "deterministic") };
        let config = EmbeddingConfig::from_env();
        // SAFETY: see above.
        unsafe { std::env::remove_var("MOCK_EMBEDDING") };
        assert!(config.mock);
        assert_eq!(config.effective_provider(), EmbeddingProvider::Mock);
        assert_eq!(config.mock_mode, MockVectorMode::Deterministic);
    }

    #[test]
    #[serial]
    fn test_from_env_provider() {
        // SAFETY: `#[serial]` keeps env-mutating tests from running concurrently.
        unsafe { std::env::set_var("EMBEDDING_PROVIDER", "openai") };
        let config = EmbeddingConfig::from_env();
        // SAFETY: see above.
        unsafe { std::env::remove_var("EMBEDDING_PROVIDER") };
        assert_eq!(config.provider, EmbeddingProvider::OpenAi);
    }

    #[test]
    #[serial]
    fn test_from_env_fastembed_alias() {
        // SAFETY: `#[serial]` keeps env-mutating tests from running concurrently.
        unsafe { std::env::set_var("EMBEDDING_PROVIDER", "fastembed") };
        let config = EmbeddingConfig::from_env();
        // SAFETY: see above.
        unsafe { std::env::remove_var("EMBEDDING_PROVIDER") };
        assert_eq!(config.provider, EmbeddingProvider::Fastembed);
    }

    #[test]
    #[serial]
    fn test_from_env_dimensions() {
        // SAFETY: `#[serial]` keeps env-mutating tests from running concurrently.
        unsafe { std::env::set_var("EMBEDDING_DIMENSIONS", "1536") };
        let config = EmbeddingConfig::from_env();
        // SAFETY: see above.
        unsafe { std::env::remove_var("EMBEDDING_DIMENSIONS") };
        assert_eq!(config.dimensions, 1536);
    }

    #[test]
    #[serial]
    fn test_from_env_api_key_fallback() {
        // SAFETY: `#[serial]` keeps env-mutating tests from running concurrently.
        unsafe { std::env::remove_var("EMBEDDING_API_KEY") };
        // SAFETY: see above.
        unsafe { std::env::set_var("LLM_API_KEY", "my-llm-key") };
        let config = EmbeddingConfig::from_env();
        // SAFETY: see above.
        unsafe { std::env::remove_var("LLM_API_KEY") };
        assert_eq!(config.api_key, Some("my-llm-key".to_string()));
    }

    #[test]
    #[serial]
    fn test_from_env_api_key_prefers_embedding() {
        // SAFETY: `#[serial]` keeps env-mutating tests from running concurrently.
        unsafe { std::env::set_var("EMBEDDING_API_KEY", "embed-key") };
        // SAFETY: see above.
        unsafe { std::env::set_var("LLM_API_KEY", "llm-key") };
        let config = EmbeddingConfig::from_env();
        // SAFETY: see above.
        unsafe { std::env::remove_var("EMBEDDING_API_KEY") };
        // SAFETY: see above.
        unsafe { std::env::remove_var("LLM_API_KEY") };
        assert_eq!(config.api_key, Some("embed-key".to_string()));
    }

    #[test]
    #[cfg(feature = "onnx")]
    #[serial]
    fn from_env_onnx_batch_size_override() {
        // SAFETY: `#[serial]` keeps env-mutating tests from running concurrently.
        unsafe { std::env::set_var("EMBEDDING_ONNX_BATCH_SIZE", "8") };
        let config = EmbeddingConfig::from_env();
        // SAFETY: see above.
        unsafe { std::env::remove_var("EMBEDDING_ONNX_BATCH_SIZE") };
        assert_eq!(config.onnx.batch_size, 8);
    }

    #[test]
    #[cfg(feature = "onnx")]
    fn test_onnx_config_bge_small() {
        let cfg = OnnxEmbeddingConfig::bge_small("/models");
        assert_eq!(cfg.dimensions, 384);
        assert_eq!(cfg.max_sequence_length, 512);
        assert_eq!(cfg.model_name, "bge-small-en-v1.5");
    }

    #[test]
    #[cfg(feature = "onnx")]
    fn test_onnx_config_minilm_l6() {
        let cfg = OnnxEmbeddingConfig::minilm_l6("/models");
        assert_eq!(cfg.dimensions, 384);
        assert_eq!(cfg.max_sequence_length, 256);
        assert_eq!(cfg.model_name, "all-MiniLM-L6-v2");
    }

    #[test]
    fn known_dims_openai_large() {
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::OpenAi, "text-embedding-3-large"),
            Some(3072),
        );
    }

    #[test]
    fn known_dims_openai_small() {
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::OpenAi, "text-embedding-3-small"),
            Some(1536),
        );
    }

    #[test]
    fn known_dims_ada_002() {
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::OpenAi, "text-embedding-ada-002"),
            Some(1536),
        );
    }

    // A leading `provider/` segment must not affect the lookup.
    #[test]
    fn known_dims_prefix_stripped() {
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::OpenAi, "openai/text-embedding-3-small"),
            Some(1536),
        );
        assert_eq!(
            known_model_dimensions(
                EmbeddingProvider::OpenAiCompatible,
                "azure/text-embedding-3-large"
            ),
            Some(3072),
        );
    }

    // Covers the canonical name, the mixed-case short alias, and the
    // organisation-prefixed form.
    #[test]
    fn known_dims_bge_small() {
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::Onnx, "bge-small-en-v1.5"),
            Some(384),
        );
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::Onnx, "BGE-Small-v1.5"),
            Some(384),
        );
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::Fastembed, "BAAI/bge-small-en-v1.5"),
            Some(384),
        );
    }

    #[test]
    fn known_dims_bge_large() {
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::Fastembed, "bge-large-en-v1.5"),
            Some(1024),
        );
    }

    #[test]
    fn known_dims_unknown_returns_none() {
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::OpenAi, "some-unknown-model"),
            None,
        );
    }

    // An explicit EMBEDDING_DIMENSIONS beats the table value for the model.
    #[test]
    #[serial]
    fn from_env_explicit_override_wins() {
        // SAFETY: `#[serial]` keeps env-mutating tests from running concurrently.
        unsafe {
            std::env::set_var("EMBEDDING_PROVIDER", "openai");
            std::env::set_var("EMBEDDING_MODEL", "text-embedding-3-large");
            std::env::set_var("EMBEDDING_DIMENSIONS", "999");
        }
        let config = EmbeddingConfig::from_env();
        // SAFETY: see above.
        unsafe {
            std::env::remove_var("EMBEDDING_PROVIDER");
            std::env::remove_var("EMBEDDING_MODEL");
            std::env::remove_var("EMBEDDING_DIMENSIONS");
        }
        assert_eq!(config.dimensions, 999);
    }

    // Changing only the model re-derives the dimensions from the table.
    #[test]
    #[serial]
    fn from_env_model_change_resolves() {
        // SAFETY: `#[serial]` keeps env-mutating tests from running concurrently.
        unsafe {
            std::env::set_var("EMBEDDING_PROVIDER", "openai");
            std::env::set_var("EMBEDDING_MODEL", "text-embedding-3-large");
            std::env::remove_var("EMBEDDING_DIMENSIONS");
        }
        let config = EmbeddingConfig::from_env();
        // SAFETY: see above.
        unsafe {
            std::env::remove_var("EMBEDDING_PROVIDER");
            std::env::remove_var("EMBEDDING_MODEL");
        }
        assert_eq!(config.dimensions, 3072);
    }

    // An unknown model with no explicit dimensions uses the fallback size.
    #[test]
    #[serial]
    fn from_env_unknown_falls_back() {
        // SAFETY: `#[serial]` keeps env-mutating tests from running concurrently.
        unsafe {
            std::env::set_var("EMBEDDING_PROVIDER", "openai");
            std::env::set_var("EMBEDDING_MODEL", "some-unknown-model-xyz");
            std::env::remove_var("EMBEDDING_DIMENSIONS");
        }
        let config = EmbeddingConfig::from_env();
        // SAFETY: see above.
        unsafe {
            std::env::remove_var("EMBEDDING_PROVIDER");
            std::env::remove_var("EMBEDDING_MODEL");
        }
        assert_eq!(config.dimensions, FALLBACK_DIMENSIONS);
    }
}