1use std::sync::Arc;
2
3use serde::{Deserialize, Serialize};
4
5use crate::engine::EmbeddingEngine;
6use crate::error::EmbeddingResult;
7use crate::mock::{MockEmbeddingEngine, MockVectorMode};
8use crate::ollama::OllamaEmbeddingEngine;
9use crate::openai_compatible::OpenAICompatibleEmbeddingEngine;
10use crate::provider::EmbeddingProvider;
11
12#[cfg(feature = "onnx")]
13use crate::onnx::OnnxEmbeddingEngine;
14#[cfg(feature = "onnx")]
15use std::path::PathBuf;
16
/// Vector size assumed when a model is not listed in [`known_model_dimensions`]
/// and no explicit size has been configured.
const FALLBACK_DIMENSIONS: usize = 384;
28
29pub fn known_model_dimensions(provider: EmbeddingProvider, model: &str) -> Option<usize> {
47 let bare = model.rsplit('/').next().unwrap_or(model);
50 let key = bare.to_ascii_lowercase();
51 let dim = match key.as_str() {
52 "text-embedding-3-large" => 3072,
54 "text-embedding-3-small" => 1536,
55 "text-embedding-ada-002" => 1536,
56 "bge-small-v1.5" | "bge-small-en-v1.5" => 384,
58 "bge-base-en-v1.5" => 768,
59 "bge-large-en-v1.5" => 1024,
60 "all-minilm-l6-v2" => 384,
62 "nomic-embed-text" => 768,
64 "mxbai-embed-large" => 1024,
65 _ => return None,
66 };
67 let _ = provider; Some(dim)
69}
70
/// Settings for the local ONNX embedding engine.
///
/// Only compiled when the `onnx` crate feature is enabled. Ready-made presets
/// are available through [`OnnxEmbeddingConfig::bge_small`] and
/// [`OnnxEmbeddingConfig::minilm_l6`].
#[cfg(feature = "onnx")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OnnxEmbeddingConfig {
    /// Location of the `.onnx` model file.
    pub model_path: PathBuf,

    /// Location of the tokenizer JSON file.
    pub tokenizer_path: PathBuf,

    /// Model name, e.g. `bge-small-en-v1.5`.
    pub model_name: String,

    /// Size of the vectors the model produces (384 for both presets).
    pub dimensions: usize,

    /// Sequence-length limit (512 / 256 in the presets).
    /// NOTE(review): presumably counted in tokens per input — confirm against the engine.
    pub max_sequence_length: usize,

    /// Batch size (32 in the presets).
    /// NOTE(review): presumably the number of texts per inference call — confirm against the engine.
    pub batch_size: usize,
}
96
#[cfg(feature = "onnx")]
impl Default for OnnxEmbeddingConfig {
    /// Returns the BGE-Small preset rooted at `./target/models`.
    fn default() -> Self {
        let model_dir = PathBuf::from("./target/models");
        Self::bge_small(model_dir)
    }
}
103
#[cfg(feature = "onnx")]
impl OnnxEmbeddingConfig {
    /// Preset for `bge-small-en-v1.5` (384-dimensional vectors).
    ///
    /// The returned paths point at `BGE-Small-v1.5-model_quantized.onnx` and
    /// `bge-small-tokenizer.json` inside `model_dir`. Nothing is read from
    /// disk here; the paths are only assembled.
    pub fn bge_small(model_dir: impl Into<PathBuf>) -> Self {
        let dir: PathBuf = model_dir.into();
        Self {
            model_path: dir.join("BGE-Small-v1.5-model_quantized.onnx"),
            tokenizer_path: dir.join("bge-small-tokenizer.json"),
            model_name: String::from("bge-small-en-v1.5"),
            dimensions: 384,
            max_sequence_length: 512,
            batch_size: 32,
        }
    }

    /// Preset for `all-MiniLM-L6-v2` (384-dimensional vectors).
    ///
    /// The returned paths point at `all-MiniLM-L6-v2.onnx` and
    /// `minilm-l6-tokenizer.json` inside `model_dir`. Nothing is read from
    /// disk here; the paths are only assembled.
    pub fn minilm_l6(model_dir: impl Into<PathBuf>) -> Self {
        let dir: PathBuf = model_dir.into();
        Self {
            model_path: dir.join("all-MiniLM-L6-v2.onnx"),
            tokenizer_path: dir.join("minilm-l6-tokenizer.json"),
            model_name: String::from("all-MiniLM-L6-v2"),
            dimensions: 384,
            max_sequence_length: 256,
            batch_size: 32,
        }
    }
}
136
/// Complete configuration for creating an embedding engine.
///
/// Build one with [`Default`], with [`EmbeddingConfig::from_env`], or by
/// deserializing it; then call [`EmbeddingConfig::create_engine`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    /// Backend to use. Overridden by [`EmbeddingConfig::mock`] when that flag
    /// is set (see [`EmbeddingConfig::effective_provider`]).
    pub provider: EmbeddingProvider,

    /// Model name (`EMBEDDING_MODEL` in `from_env`).
    pub model: String,

    /// Size of the embedding vectors (`EMBEDDING_DIMENSIONS` in `from_env`,
    /// otherwise derived from the model name).
    pub dimensions: usize,

    /// Optional service base URL (`EMBEDDING_ENDPOINT` in `from_env`).
    pub endpoint: Option<String>,

    /// Optional API key (`EMBEDDING_API_KEY`, else `LLM_API_KEY`, in `from_env`).
    pub api_key: Option<String>,

    /// Optional API version string (`EMBEDDING_API_VERSION` in `from_env`).
    /// NOTE(review): how the engines interpret it is not visible in this file.
    pub api_version: Option<String>,

    /// Token limit, 8191 by default (`EMBEDDING_MAX_COMPLETION_TOKENS` in `from_env`).
    /// NOTE(review): presumably the maximum tokens per embedded input — confirm against the engines.
    pub max_completion_tokens: usize,

    /// Batch size, 36 by default (`EMBEDDING_BATCH_SIZE` in `from_env`).
    /// NOTE(review): presumably the number of texts per request — confirm against the engines.
    pub batch_size: usize,

    /// When `true`, the mock engine is used regardless of `provider`.
    pub mock: bool,

    /// Vector generation mode handed to the mock engine. Falls back to the
    /// type's `Default` when the field is missing during deserialization.
    #[serde(default)]
    pub mock_mode: MockVectorMode,

    /// Settings for the ONNX engine; only present with the `onnx` feature.
    #[cfg(feature = "onnx")]
    pub onnx: OnnxEmbeddingConfig,

    /// Optional tokenizer identifier (`HUGGINGFACE_TOKENIZER` in `from_env`).
    /// NOTE(review): not consumed anywhere in this file — verify its users.
    pub huggingface_tokenizer: Option<String>,
}
199
200impl Default for EmbeddingConfig {
201 fn default() -> Self {
202 #[cfg(all(feature = "onnx", target_os = "android"))]
205 let (provider, model, dimensions, endpoint) = {
206 let onnx_cfg = OnnxEmbeddingConfig::default();
207 (
208 EmbeddingProvider::Onnx,
209 onnx_cfg.model_name.clone(),
210 onnx_cfg.dimensions,
211 None,
212 )
213 };
214 #[cfg(all(feature = "onnx", not(target_os = "android")))]
215 let (provider, model, dimensions, endpoint) = {
216 let m = "text-embedding-3-small".to_string();
217 let d = known_model_dimensions(EmbeddingProvider::OpenAi, &m)
219 .unwrap_or(FALLBACK_DIMENSIONS);
220 (
221 EmbeddingProvider::OpenAi,
222 m,
223 d,
224 Some("https://api.openai.com/v1".to_string()),
225 )
226 };
227 #[cfg(not(feature = "onnx"))]
228 let (provider, model, dimensions, endpoint) = {
229 let m = "text-embedding-3-small".to_string();
230 let d = known_model_dimensions(EmbeddingProvider::OpenAi, &m)
232 .unwrap_or(FALLBACK_DIMENSIONS);
233 (
234 EmbeddingProvider::OpenAi,
235 m,
236 d,
237 Some("https://api.openai.com/v1".to_string()),
238 )
239 };
240
241 Self {
242 provider,
243 model,
244 dimensions,
245 endpoint,
246 api_key: None,
247 api_version: None,
248 max_completion_tokens: 8191,
249 batch_size: 36,
250 mock: false,
251 mock_mode: MockVectorMode::Zero,
252 #[cfg(feature = "onnx")]
253 onnx: OnnxEmbeddingConfig::default(),
254 huggingface_tokenizer: None,
255 }
256 }
257}
258
259impl EmbeddingConfig {
260 pub fn from_env() -> Self {
265 let mut config = Self::default();
266
267 if let Ok(val) = std::env::var("MOCK_EMBEDDING") {
271 let val = val.trim().to_lowercase();
272 if val == "deterministic" || val == "hash" {
273 config.mock = true;
274 config.provider = EmbeddingProvider::Mock;
275 config.mock_mode = MockVectorMode::Deterministic;
276 return config;
277 }
278 if val == "true" || val == "1" || val == "yes" {
279 config.mock = true;
280 config.provider = EmbeddingProvider::Mock;
281 config.mock_mode = MockVectorMode::Zero;
282 return config;
283 }
284 }
285
286 if let Ok(val) = std::env::var("EMBEDDING_PROVIDER") {
288 let val = val.trim().to_lowercase();
289 match val.as_str() {
290 "onnx" => config.provider = EmbeddingProvider::Onnx,
291 "fastembed" => config.provider = EmbeddingProvider::Fastembed,
292 "openai" => config.provider = EmbeddingProvider::OpenAi,
293 "openai_compatible" => config.provider = EmbeddingProvider::OpenAiCompatible,
294 "ollama" => config.provider = EmbeddingProvider::Ollama,
295 "mock" => {
296 config.mock = true;
297 config.provider = EmbeddingProvider::Mock;
298 }
299 _ => {
300 }
304 }
305 }
306
307 if config.provider == EmbeddingProvider::Ollama {
313 config.model = "avr/sfr-embedding-mistral:latest".to_string();
314 }
315
316 if let Ok(val) = std::env::var("EMBEDDING_MODEL") {
318 let val = val.trim().to_string();
319 if !val.is_empty() {
320 config.model = val;
321 }
322 }
323
324 let explicit_dims = std::env::var("EMBEDDING_DIMENSIONS")
332 .ok()
333 .and_then(|v| v.trim().parse::<usize>().ok());
334
335 let resolve_from_table = |config: &EmbeddingConfig| match known_model_dimensions(
340 config.provider.clone(),
341 &config.model,
342 ) {
343 Some(d) => d,
345 None => {
347 tracing::warn!(
348 provider = ?config.provider,
349 model = %config.model,
350 fallback = FALLBACK_DIMENSIONS,
351 "Could not auto-derive embedding dimensions; set \
352 EMBEDDING_DIMENSIONS explicitly if your embedder produces \
353 a different vector size, otherwise the first vector write \
354 will fail with a shape mismatch."
355 );
356 FALLBACK_DIMENSIONS
357 }
358 };
359
360 config.dimensions = match explicit_dims {
361 Some(d) => d,
363 None => {
364 #[cfg(feature = "onnx")]
368 {
369 if matches!(
370 config.provider,
371 EmbeddingProvider::Onnx | EmbeddingProvider::Fastembed
372 ) {
373 config.onnx.dimensions
374 } else {
375 resolve_from_table(&config)
376 }
377 }
378 #[cfg(not(feature = "onnx"))]
379 {
380 resolve_from_table(&config)
381 }
382 }
383 };
384
385 if let Ok(val) = std::env::var("EMBEDDING_ENDPOINT") {
387 let val = val.trim().to_string();
388 if !val.is_empty() {
389 config.endpoint = Some(val);
390 }
391 }
392
393 if let Ok(val) = std::env::var("EMBEDDING_API_KEY") {
395 let val = val.trim().to_string();
396 if !val.is_empty() {
397 config.api_key = Some(val);
398 }
399 } else if let Ok(val) = std::env::var("LLM_API_KEY") {
400 let val = val.trim().to_string();
401 if !val.is_empty() {
402 config.api_key = Some(val);
403 }
404 }
405
406 if let Ok(val) = std::env::var("EMBEDDING_API_VERSION") {
408 let val = val.trim().to_string();
409 if !val.is_empty() {
410 config.api_version = Some(val);
411 }
412 }
413
414 if let Ok(val) = std::env::var("EMBEDDING_MAX_COMPLETION_TOKENS")
416 && let Ok(n) = val.trim().parse::<usize>()
417 {
418 config.max_completion_tokens = n;
419 }
420
421 if let Ok(val) = std::env::var("EMBEDDING_BATCH_SIZE")
423 && let Ok(n) = val.trim().parse::<usize>()
424 {
425 config.batch_size = n;
426 }
427
428 if let Ok(val) = std::env::var("HUGGINGFACE_TOKENIZER") {
430 let val = val.trim().to_string();
431 if !val.is_empty() {
432 config.huggingface_tokenizer = Some(val);
433 }
434 }
435
436 config
437 }
438
439 pub fn effective_provider(&self) -> EmbeddingProvider {
441 if self.mock {
442 EmbeddingProvider::Mock
443 } else {
444 self.provider.clone()
445 }
446 }
447
448 pub async fn create_engine(&self) -> EmbeddingResult<Arc<dyn EmbeddingEngine>> {
454 match self.effective_provider() {
455 #[cfg(feature = "onnx")]
456 EmbeddingProvider::Onnx | EmbeddingProvider::Fastembed => {
457 let engine = OnnxEmbeddingEngine::with_auto_download(self.onnx.clone()).await?;
458 Ok(Arc::new(engine))
459 }
460 #[cfg(not(feature = "onnx"))]
461 EmbeddingProvider::Onnx | EmbeddingProvider::Fastembed => {
462 Err(crate::error::EmbeddingError::NotImplemented(
463 "ONNX embedding engine requires the `onnx` crate feature".to_string(),
464 ))
465 }
466 EmbeddingProvider::OpenAi | EmbeddingProvider::OpenAiCompatible => {
467 let engine = OpenAICompatibleEmbeddingEngine::new(self)?;
468 Ok(Arc::new(engine))
469 }
470 EmbeddingProvider::Ollama => {
471 let engine = OllamaEmbeddingEngine::new(self)?;
472 Ok(Arc::new(engine))
473 }
474 EmbeddingProvider::Mock => Ok(Arc::new(
475 MockEmbeddingEngine::new(self.dimensions).with_mode(self.mock_mode),
476 )),
477 }
478 }
479}
480
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;

    /// Every environment variable `EmbeddingConfig::from_env` reads.
    const CONFIG_ENV_VARS: &[&str] = &[
        "MOCK_EMBEDDING",
        "EMBEDDING_PROVIDER",
        "EMBEDDING_MODEL",
        "EMBEDDING_DIMENSIONS",
        "EMBEDDING_ENDPOINT",
        "EMBEDDING_API_KEY",
        "LLM_API_KEY",
        "EMBEDDING_API_VERSION",
        "EMBEDDING_MAX_COMPLETION_TOKENS",
        "EMBEDDING_BATCH_SIZE",
        "HUGGINGFACE_TOKENIZER",
    ];

    /// Scoped environment for `from_env` tests.
    ///
    /// On creation it remembers the current value of every variable in
    /// `CONFIG_ENV_VARS`, clears them all, and applies the overrides. On drop
    /// (including during a panic unwind) it puts the remembered values back,
    /// so a failing assertion cannot leak state into the next test and an
    /// ambient variable in the developer's shell cannot skew a result.
    struct EnvSandbox {
        saved: Vec<(&'static str, Option<std::ffi::OsString>)>,
    }

    impl EnvSandbox {
        fn with(overrides: &[(&str, &str)]) -> Self {
            let saved: Vec<_> = CONFIG_ENV_VARS
                .iter()
                .map(|&name| (name, std::env::var_os(name)))
                .collect();
            // SAFETY: every test in this module that touches the environment
            // is `#[serial]`, so they never run concurrently with each other.
            // NOTE(review): tests in other modules that read the environment
            // without `#[serial]` would still race — confirm there are none.
            unsafe {
                for &name in CONFIG_ENV_VARS {
                    std::env::remove_var(name);
                }
                for &(name, value) in overrides {
                    std::env::set_var(name, value);
                }
            }
            Self { saved }
        }
    }

    impl Drop for EnvSandbox {
        fn drop(&mut self) {
            // SAFETY: same serialization argument as in `EnvSandbox::with`.
            unsafe {
                for (name, previous) in &self.saved {
                    match previous {
                        Some(value) => std::env::set_var(name, value),
                        None => std::env::remove_var(name),
                    }
                }
            }
        }
    }

    /// Runs `EmbeddingConfig::from_env` with exactly `overrides` set.
    fn from_env_with(overrides: &[(&str, &str)]) -> EmbeddingConfig {
        let _sandbox = EnvSandbox::with(overrides);
        EmbeddingConfig::from_env()
    }

    // --- Defaults -----------------------------------------------------------

    #[test]
    #[cfg(all(feature = "onnx", target_os = "android"))]
    fn test_default_is_onnx_on_android() {
        let config = EmbeddingConfig::default();
        assert_eq!(config.provider, EmbeddingProvider::Onnx);
        assert_eq!(config.dimensions, 384);
        assert_eq!(config.batch_size, 36);
        assert_eq!(config.max_completion_tokens, 8191);
        assert!(!config.mock);
    }

    #[test]
    #[cfg(not(target_os = "android"))]
    fn test_default_is_openai_off_android() {
        let config = EmbeddingConfig::default();
        assert_eq!(config.provider, EmbeddingProvider::OpenAi);
        assert_eq!(config.model, "text-embedding-3-small");
        assert_eq!(config.dimensions, 1536);
        assert_eq!(
            config.endpoint.as_deref(),
            Some("https://api.openai.com/v1")
        );
        assert!(!config.mock);
    }

    // --- effective_provider -------------------------------------------------

    #[test]
    fn test_effective_provider_mock_override() {
        let config = EmbeddingConfig {
            mock: true,
            ..Default::default()
        };
        assert_eq!(config.effective_provider(), EmbeddingProvider::Mock);
    }

    #[test]
    #[cfg(all(feature = "onnx", target_os = "android"))]
    fn test_effective_provider_passthrough_onnx() {
        let config = EmbeddingConfig::default();
        assert_eq!(config.effective_provider(), EmbeddingProvider::Onnx);
    }

    #[test]
    #[cfg(not(target_os = "android"))]
    fn test_effective_provider_passthrough_openai() {
        let config = EmbeddingConfig::default();
        assert_eq!(config.effective_provider(), EmbeddingProvider::OpenAi);
    }

    // --- from_env: mock selection --------------------------------------------

    #[test]
    #[serial]
    fn test_from_env_mock_embedding_true() {
        let config = from_env_with(&[("MOCK_EMBEDDING", "true")]);
        assert!(config.mock);
        assert_eq!(config.effective_provider(), EmbeddingProvider::Mock);
    }

    #[test]
    #[serial]
    fn test_from_env_mock_embedding_numeric() {
        let config = from_env_with(&[("MOCK_EMBEDDING", "1")]);
        assert!(config.mock);
        assert_eq!(config.mock_mode, MockVectorMode::Zero);
    }

    // Serialized like its siblings instead of being `#[ignore]`d, so it runs
    // in a normal `cargo test`.
    #[test]
    #[serial]
    fn test_from_env_mock_embedding_deterministic() {
        let config = from_env_with(&[("MOCK_EMBEDDING", "deterministic")]);
        assert!(config.mock);
        assert_eq!(config.effective_provider(), EmbeddingProvider::Mock);
        assert_eq!(config.mock_mode, MockVectorMode::Deterministic);
    }

    // --- from_env: provider / dimensions / API key ---------------------------

    #[test]
    #[serial]
    fn test_from_env_provider() {
        let config = from_env_with(&[("EMBEDDING_PROVIDER", "openai")]);
        assert_eq!(config.provider, EmbeddingProvider::OpenAi);
    }

    #[test]
    #[serial]
    fn test_from_env_fastembed_alias() {
        let config = from_env_with(&[("EMBEDDING_PROVIDER", "fastembed")]);
        assert_eq!(config.provider, EmbeddingProvider::Fastembed);
    }

    #[test]
    #[serial]
    fn test_from_env_dimensions() {
        // 2048 matches no built-in default, so this fails if the variable is ignored.
        let config = from_env_with(&[("EMBEDDING_DIMENSIONS", "2048")]);
        assert_eq!(config.dimensions, 2048);
    }

    #[test]
    #[serial]
    fn test_from_env_api_key_fallback() {
        let config = from_env_with(&[("LLM_API_KEY", "my-llm-key")]);
        assert_eq!(config.api_key, Some("my-llm-key".to_string()));
    }

    #[test]
    #[serial]
    fn test_from_env_api_key_prefers_embedding() {
        let config = from_env_with(&[
            ("EMBEDDING_API_KEY", "embed-key"),
            ("LLM_API_KEY", "llm-key"),
        ]);
        assert_eq!(config.api_key, Some("embed-key".to_string()));
    }

    // --- ONNX presets ---------------------------------------------------------

    #[test]
    #[cfg(feature = "onnx")]
    fn test_onnx_config_bge_small() {
        let cfg = OnnxEmbeddingConfig::bge_small("/models");
        assert_eq!(cfg.dimensions, 384);
        assert_eq!(cfg.max_sequence_length, 512);
        assert_eq!(cfg.model_name, "bge-small-en-v1.5");
    }

    #[test]
    #[cfg(feature = "onnx")]
    fn test_onnx_config_minilm_l6() {
        let cfg = OnnxEmbeddingConfig::minilm_l6("/models");
        assert_eq!(cfg.dimensions, 384);
        assert_eq!(cfg.max_sequence_length, 256);
        assert_eq!(cfg.model_name, "all-MiniLM-L6-v2");
    }

    // --- known_model_dimensions ----------------------------------------------

    #[test]
    fn known_dims_openai_large() {
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::OpenAi, "text-embedding-3-large"),
            Some(3072),
        );
    }

    #[test]
    fn known_dims_openai_small() {
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::OpenAi, "text-embedding-3-small"),
            Some(1536),
        );
    }

    #[test]
    fn known_dims_ada_002() {
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::OpenAi, "text-embedding-ada-002"),
            Some(1536),
        );
    }

    #[test]
    fn known_dims_prefix_stripped() {
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::OpenAi, "openai/text-embedding-3-small"),
            Some(1536),
        );
        assert_eq!(
            known_model_dimensions(
                EmbeddingProvider::OpenAiCompatible,
                "azure/text-embedding-3-large"
            ),
            Some(3072),
        );
    }

    #[test]
    fn known_dims_bge_small() {
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::Onnx, "bge-small-en-v1.5"),
            Some(384),
        );
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::Onnx, "BGE-Small-v1.5"),
            Some(384),
        );
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::Fastembed, "BAAI/bge-small-en-v1.5"),
            Some(384),
        );
    }

    #[test]
    fn known_dims_bge_large() {
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::Fastembed, "bge-large-en-v1.5"),
            Some(1024),
        );
    }

    #[test]
    fn known_dims_unknown_returns_none() {
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::OpenAi, "some-unknown-model"),
            None,
        );
    }

    // --- from_env: dimension resolution ---------------------------------------

    #[test]
    #[serial]
    fn from_env_explicit_override_wins() {
        let config = from_env_with(&[
            ("EMBEDDING_PROVIDER", "openai"),
            ("EMBEDDING_MODEL", "text-embedding-3-large"),
            ("EMBEDDING_DIMENSIONS", "999"),
        ]);
        assert_eq!(config.dimensions, 999);
    }

    #[test]
    #[serial]
    fn from_env_model_change_resolves() {
        let config = from_env_with(&[
            ("EMBEDDING_PROVIDER", "openai"),
            ("EMBEDDING_MODEL", "text-embedding-3-large"),
        ]);
        assert_eq!(config.dimensions, 3072);
    }

    #[test]
    #[serial]
    fn from_env_unknown_falls_back() {
        let config = from_env_with(&[
            ("EMBEDDING_PROVIDER", "openai"),
            ("EMBEDDING_MODEL", "some-unknown-model-xyz"),
        ]);
        assert_eq!(config.dimensions, FALLBACK_DIMENSIONS);
    }
}