Skip to main content

cognee_embedding/
config.rs

1use std::sync::Arc;
2
3use serde::{Deserialize, Serialize};
4
5use crate::engine::EmbeddingEngine;
6use crate::error::EmbeddingResult;
7use crate::mock::{MockEmbeddingEngine, MockVectorMode};
8use crate::ollama::OllamaEmbeddingEngine;
9use crate::openai_compatible::OpenAICompatibleEmbeddingEngine;
10use crate::provider::EmbeddingProvider;
11
12#[cfg(feature = "onnx")]
13use crate::onnx::OnnxEmbeddingEngine;
14#[cfg(feature = "onnx")]
15use std::path::PathBuf;
16
/// Fallback dimension when the model is unknown AND `EMBEDDING_DIMENSIONS` is unset.
///
/// 384 matches the ONNX BGE-Small edge model (Android default). On non-Android,
/// the default model (`text-embedding-3-small` → 1536) resolves via the
/// [`known_model_dimensions`] table, so this fallback is only hit for truly
/// unknown models.
///
/// Consumed in two places in this file: `EmbeddingConfig::default` (as the
/// `unwrap_or` value after the table lookup) and `EmbeddingConfig::from_env`
/// (where hitting it also emits a `tracing::warn!`).
///
/// **Note:** Python uses 3072 as its fallback (matching the OpenAI default).
/// Rust deliberately uses 384 because the Rust edge default is BGE-Small, not
/// `text-embedding-3-large`.
const FALLBACK_DIMENSIONS: usize = 384;
28
29/// Best-effort lookup of the output vector dimension for a known embedding model.
30///
31/// Mirrors the dimension table that Python resolves dynamically via the
32/// `litellm` and `fastembed` registries (`_resolve_embedding_dimensions` in
33/// `cognee/infrastructure/databases/vector/embeddings/config.py`). Rust
34/// hardcodes a small, curated table instead of depending on those Python
35/// packages.
36///
37/// Resolution rules (matches Python semantics):
38/// - Strips a leading provider segment: `"openai/text-embedding-3-large"` →
39///   `"text-embedding-3-large"` (uses the last `/`-separated component).
40/// - Matching is **case-insensitive**.
41/// - Returns `None` for unknown models so the caller can fall back with a
42///   warning rather than silently using a wrong dimension.
43///
44/// The `provider` argument is accepted for forward-compatibility (future
45/// provider-scoped overrides) but is not used in the current table.
46pub fn known_model_dimensions(provider: EmbeddingProvider, model: &str) -> Option<usize> {
47    // Strip a provider prefix: "openai/text-embedding-3-large" -> "text-embedding-3-large"
48    // rsplit('/').next() is infallible for any &str (always yields ≥ 1 element).
49    let bare = model.rsplit('/').next().unwrap_or(model);
50    let key = bare.to_ascii_lowercase();
51    let dim = match key.as_str() {
52        // --- OpenAI models (verified via litellm registry) ---
53        "text-embedding-3-large" => 3072,
54        "text-embedding-3-small" => 1536,
55        "text-embedding-ada-002" => 1536,
56        // --- BGE family (fastembed registry + ONNX defaults) ---
57        "bge-small-v1.5" | "bge-small-en-v1.5" => 384,
58        "bge-base-en-v1.5" => 768,
59        "bge-large-en-v1.5" => 1024,
60        // --- MiniLM ---
61        "all-minilm-l6-v2" => 384,
62        // --- Common Ollama models ---
63        "nomic-embed-text" => 768,
64        "mxbai-embed-large" => 1024,
65        _ => return None,
66    };
67    let _ = provider; // provider currently unused; kept for future provider-scoped dims
68    Some(dim)
69}
70
/// ONNX-specific configuration.
///
/// Only used when `EmbeddingConfig::provider` is `Onnx` or `Fastembed`.
/// All other providers use the top-level `EmbeddingConfig` fields only.
///
/// `EmbeddingConfig::create_engine` clones this struct and hands it to
/// `OnnxEmbeddingEngine::with_auto_download`. Ready-made presets are available
/// via `OnnxEmbeddingConfig::bge_small` and `OnnxEmbeddingConfig::minilm_l6`.
#[cfg(feature = "onnx")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OnnxEmbeddingConfig {
    /// Path to ONNX model file (.onnx)
    pub model_path: PathBuf,

    /// Path to tokenizer.json file
    pub tokenizer_path: PathBuf,

    /// Model name for logging/identification and auto-download selection
    pub model_name: String,

    /// Embedding dimensions (must match model output).
    /// `EmbeddingConfig::from_env` copies this into `EmbeddingConfig::dimensions`
    /// for the ONNX/Fastembed providers unless `EMBEDDING_DIMENSIONS` is set.
    pub dimensions: usize,

    /// Maximum sequence length in tokens (truncate if longer)
    pub max_sequence_length: usize,

    /// Batch size for ONNX inference (max texts per inference call)
    pub batch_size: usize,
}
96
#[cfg(feature = "onnx")]
impl Default for OnnxEmbeddingConfig {
    /// The BGE-Small-v1.5 preset with model files expected under `./target/models`.
    fn default() -> Self {
        let model_dir = PathBuf::from("./target/models");
        OnnxEmbeddingConfig::bge_small(model_dir)
    }
}
103
#[cfg(feature = "onnx")]
impl OnnxEmbeddingConfig {
    /// Preset for the quantized BGE-Small-v1.5 model.
    ///
    /// `model_dir` is the directory holding both the `.onnx` file and the
    /// tokenizer JSON; the file names are fixed by this preset.
    pub fn bge_small(model_dir: impl Into<PathBuf>) -> Self {
        let dir: PathBuf = model_dir.into();
        Self {
            model_path: dir.join("BGE-Small-v1.5-model_quantized.onnx"),
            tokenizer_path: dir.join("bge-small-tokenizer.json"),
            model_name: String::from("bge-small-en-v1.5"),
            dimensions: 384,
            max_sequence_length: 512,
            batch_size: 32,
        }
    }

    /// Preset for the all-MiniLM-L6-v2 model.
    ///
    /// `model_dir` is the directory holding both the `.onnx` file and the
    /// tokenizer JSON; the file names are fixed by this preset.
    pub fn minilm_l6(model_dir: impl Into<PathBuf>) -> Self {
        let dir: PathBuf = model_dir.into();
        Self {
            model_path: dir.join("all-MiniLM-L6-v2.onnx"),
            tokenizer_path: dir.join("minilm-l6-tokenizer.json"),
            model_name: String::from("all-MiniLM-L6-v2"),
            dimensions: 384,
            max_sequence_length: 256,
            batch_size: 32,
        }
    }
}
136
/// Unified embedding configuration.
///
/// Provider-agnostic; holds fields for all supported backends.
/// Load from environment variables via [`EmbeddingConfig::from_env`], or construct
/// programmatically and pass to [`EmbeddingConfig::create_engine`].
///
/// Environment variables (match Python SDK names):
/// - `EMBEDDING_PROVIDER` — backend selection (default: `openai`; `onnx` on
///   Android builds with the `onnx` feature)
/// - `MOCK_EMBEDDING` — set to `true`/`1`/`yes` to force mock mode with zero
///   vectors, or `deterministic`/`hash` for content-derived mock vectors
/// - `EMBEDDING_MODEL` — model identifier
/// - `EMBEDDING_DIMENSIONS` — vector size (when unset, derived from the ONNX
///   config or the known-model table)
/// - `EMBEDDING_ENDPOINT` — API endpoint URL
/// - `EMBEDDING_API_KEY` — API key (fallback: `LLM_API_KEY`)
/// - `EMBEDDING_API_VERSION` — API version string
/// - `EMBEDDING_MAX_COMPLETION_TOKENS` — maximum tokens (default: 8191)
/// - `EMBEDDING_BATCH_SIZE` — texts per batch (default: 36)
/// - `HUGGINGFACE_TOKENIZER` — HuggingFace tokenizer identifier
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    /// Which backend to use for embedding generation.
    pub provider: EmbeddingProvider,

    /// Model identifier. For ONNX this is informational (the ONNX engine is built
    /// from the `onnx` field instead); for API providers it names the model to use.
    /// Defaults to `text-embedding-3-small` (or the ONNX preset's model name on
    /// Android builds with the `onnx` feature). `from_env` switches the default to
    /// `avr/sfr-embedding-mistral:latest` when the provider is Ollama.
    pub model: String,

    /// Embedding vector dimensionality. Must match the model output.
    pub dimensions: usize,

    /// API endpoint URL (used by OpenAI-compatible and Ollama providers).
    pub endpoint: Option<String>,

    /// API key. Reads `EMBEDDING_API_KEY` first, falls back to `LLM_API_KEY`.
    pub api_key: Option<String>,

    /// API version string (e.g. "2023-05-15" for Azure OpenAI).
    pub api_version: Option<String>,

    /// Maximum tokens for completion requests (default: 8191).
    pub max_completion_tokens: usize,

    /// Number of texts to send in a single embedding request (default: 36).
    pub batch_size: usize,

    /// If true, use mock embeddings regardless of `provider`
    /// (see `effective_provider`). Set via `MOCK_EMBEDDING=true`.
    pub mock: bool,

    /// How the mock engine generates vectors when `provider` is `Mock`.
    /// Defaults to [`MockVectorMode::Zero`]. Set via `MOCK_EMBEDDING=deterministic`
    /// to derive content-stable vectors from `sha256(text)`.
    #[serde(default)]
    pub mock_mode: MockVectorMode,

    /// ONNX-specific configuration. Only consulted when provider is `Onnx` or `Fastembed`.
    #[cfg(feature = "onnx")]
    pub onnx: OnnxEmbeddingConfig,

    /// HuggingFace tokenizer identifier for chunking token counting.
    /// When set, used by `HuggingFaceTokenCounter` in the chunking crate.
    pub huggingface_tokenizer: Option<String>,
}
199
200impl Default for EmbeddingConfig {
201    fn default() -> Self {
202        // On Android, local ONNX inference is the right default (edge deployment).
203        // Everywhere else, match the Python SDK default: OpenAI text-embedding-3-small.
204        #[cfg(all(feature = "onnx", target_os = "android"))]
205        let (provider, model, dimensions, endpoint) = {
206            let onnx_cfg = OnnxEmbeddingConfig::default();
207            (
208                EmbeddingProvider::Onnx,
209                onnx_cfg.model_name.clone(),
210                onnx_cfg.dimensions,
211                None,
212            )
213        };
214        #[cfg(all(feature = "onnx", not(target_os = "android")))]
215        let (provider, model, dimensions, endpoint) = {
216            let m = "text-embedding-3-small".to_string();
217            // Resolve via the known-model table so there is one source of truth.
218            let d = known_model_dimensions(EmbeddingProvider::OpenAi, &m)
219                .unwrap_or(FALLBACK_DIMENSIONS);
220            (
221                EmbeddingProvider::OpenAi,
222                m,
223                d,
224                Some("https://api.openai.com/v1".to_string()),
225            )
226        };
227        #[cfg(not(feature = "onnx"))]
228        let (provider, model, dimensions, endpoint) = {
229            let m = "text-embedding-3-small".to_string();
230            // Resolve via the known-model table so there is one source of truth.
231            let d = known_model_dimensions(EmbeddingProvider::OpenAi, &m)
232                .unwrap_or(FALLBACK_DIMENSIONS);
233            (
234                EmbeddingProvider::OpenAi,
235                m,
236                d,
237                Some("https://api.openai.com/v1".to_string()),
238            )
239        };
240
241        Self {
242            provider,
243            model,
244            dimensions,
245            endpoint,
246            api_key: None,
247            api_version: None,
248            max_completion_tokens: 8191,
249            batch_size: 36,
250            mock: false,
251            mock_mode: MockVectorMode::Zero,
252            #[cfg(feature = "onnx")]
253            onnx: OnnxEmbeddingConfig::default(),
254            huggingface_tokenizer: None,
255        }
256    }
257}
258
259impl EmbeddingConfig {
260    /// Load configuration from environment variables.
261    ///
262    /// Reads the same env var names as the Python SDK so that a shared `.env` file
263    /// works across both implementations without modification.
264    pub fn from_env() -> Self {
265        let mut config = Self::default();
266
267        // Parse MOCK_EMBEDDING first — it overrides everything else if set.
268        // `deterministic` (or `hash`) selects the SHA-256-derived deterministic
269        // mode; other truthy values keep the legacy zero-vector mode.
270        if let Ok(val) = std::env::var("MOCK_EMBEDDING") {
271            let val = val.trim().to_lowercase();
272            if val == "deterministic" || val == "hash" {
273                config.mock = true;
274                config.provider = EmbeddingProvider::Mock;
275                config.mock_mode = MockVectorMode::Deterministic;
276                return config;
277            }
278            if val == "true" || val == "1" || val == "yes" {
279                config.mock = true;
280                config.provider = EmbeddingProvider::Mock;
281                config.mock_mode = MockVectorMode::Zero;
282                return config;
283            }
284        }
285
286        // Parse EMBEDDING_PROVIDER
287        if let Ok(val) = std::env::var("EMBEDDING_PROVIDER") {
288            let val = val.trim().to_lowercase();
289            match val.as_str() {
290                "onnx" => config.provider = EmbeddingProvider::Onnx,
291                "fastembed" => config.provider = EmbeddingProvider::Fastembed,
292                "openai" => config.provider = EmbeddingProvider::OpenAi,
293                "openai_compatible" => config.provider = EmbeddingProvider::OpenAiCompatible,
294                "ollama" => config.provider = EmbeddingProvider::Ollama,
295                "mock" => {
296                    config.mock = true;
297                    config.provider = EmbeddingProvider::Mock;
298                }
299                _ => {
300                    // Unknown provider — leave the platform default (OpenAI, or
301                    // ONNX on Android) and log nothing; the caller will get a
302                    // clear error from create_engine() if needed.
303                }
304            }
305        }
306
307        // Apply provider-specific model defaults before checking env var overrides.
308        // This ensures that when a user switches to EMBEDDING_PROVIDER=ollama
309        // without setting EMBEDDING_MODEL explicitly, they get a sensible Ollama
310        // default model name rather than the ONNX model name.
311        // (Dimension is resolved below via the known-model table, not hardcoded here.)
312        if config.provider == EmbeddingProvider::Ollama {
313            config.model = "avr/sfr-embedding-mistral:latest".to_string();
314        }
315
316        // EMBEDDING_MODEL
317        if let Ok(val) = std::env::var("EMBEDDING_MODEL") {
318            let val = val.trim().to_string();
319            if !val.is_empty() {
320                config.model = val;
321            }
322        }
323
324        // EMBEDDING_DIMENSIONS — resolution order (mirrors Python model_post_init):
325        //   1. Explicit EMBEDDING_DIMENSIONS env var — always wins.
326        //   2. known_model_dimensions(provider, model) — table lookup.
327        //   3. Fallback FALLBACK_DIMENSIONS (384) with a tracing::warn! so the user
328        //      knows to set EMBEDDING_DIMENSIONS explicitly for unknown models.
329        // For ONNX the model file dictates the true dimension, so we prefer the
330        // onnx_cfg.dimensions unless the user set EMBEDDING_DIMENSIONS explicitly.
331        let explicit_dims = std::env::var("EMBEDDING_DIMENSIONS")
332            .ok()
333            .and_then(|v| v.trim().parse::<usize>().ok());
334
335        // Resolve via the known-model table, falling back to FALLBACK_DIMENSIONS
336        // with a warning when the model is unknown (parity with Python
337        // model_post_init). Used for every provider except ONNX/Fastembed, whose
338        // dimension is dictated by the model file (handled below).
339        let resolve_from_table = |config: &EmbeddingConfig| match known_model_dimensions(
340            config.provider.clone(),
341            &config.model,
342        ) {
343            // 2. Known model — derived dimension.
344            Some(d) => d,
345            // 3. Unknown model — fallback with warning.
346            None => {
347                tracing::warn!(
348                    provider = ?config.provider,
349                    model = %config.model,
350                    fallback = FALLBACK_DIMENSIONS,
351                    "Could not auto-derive embedding dimensions; set \
352                     EMBEDDING_DIMENSIONS explicitly if your embedder produces \
353                     a different vector size, otherwise the first vector write \
354                     will fail with a shape mismatch."
355                );
356                FALLBACK_DIMENSIONS
357            }
358        };
359
360        config.dimensions = match explicit_dims {
361            // 1. Explicit override always wins.
362            Some(d) => d,
363            None => {
364                // For ONNX/Fastembed the model file carries the authoritative
365                // dimension, so use onnx.dimensions rather than the text table —
366                // this keeps custom ONNX models working.
367                #[cfg(feature = "onnx")]
368                {
369                    if matches!(
370                        config.provider,
371                        EmbeddingProvider::Onnx | EmbeddingProvider::Fastembed
372                    ) {
373                        config.onnx.dimensions
374                    } else {
375                        resolve_from_table(&config)
376                    }
377                }
378                #[cfg(not(feature = "onnx"))]
379                {
380                    resolve_from_table(&config)
381                }
382            }
383        };
384
385        // EMBEDDING_ENDPOINT
386        if let Ok(val) = std::env::var("EMBEDDING_ENDPOINT") {
387            let val = val.trim().to_string();
388            if !val.is_empty() {
389                config.endpoint = Some(val);
390            }
391        }
392
393        // EMBEDDING_API_KEY, fallback to LLM_API_KEY
394        if let Ok(val) = std::env::var("EMBEDDING_API_KEY") {
395            let val = val.trim().to_string();
396            if !val.is_empty() {
397                config.api_key = Some(val);
398            }
399        } else if let Ok(val) = std::env::var("LLM_API_KEY") {
400            let val = val.trim().to_string();
401            if !val.is_empty() {
402                config.api_key = Some(val);
403            }
404        }
405
406        // EMBEDDING_API_VERSION
407        if let Ok(val) = std::env::var("EMBEDDING_API_VERSION") {
408            let val = val.trim().to_string();
409            if !val.is_empty() {
410                config.api_version = Some(val);
411            }
412        }
413
414        // EMBEDDING_MAX_COMPLETION_TOKENS
415        if let Ok(val) = std::env::var("EMBEDDING_MAX_COMPLETION_TOKENS")
416            && let Ok(n) = val.trim().parse::<usize>()
417        {
418            config.max_completion_tokens = n;
419        }
420
421        // EMBEDDING_BATCH_SIZE
422        if let Ok(val) = std::env::var("EMBEDDING_BATCH_SIZE")
423            && let Ok(n) = val.trim().parse::<usize>()
424        {
425            config.batch_size = n;
426        }
427
428        // HUGGINGFACE_TOKENIZER
429        if let Ok(val) = std::env::var("HUGGINGFACE_TOKENIZER") {
430            let val = val.trim().to_string();
431            if !val.is_empty() {
432                config.huggingface_tokenizer = Some(val);
433            }
434        }
435
436        config
437    }
438
439    /// Returns the effective provider, substituting Mock when `self.mock` is true.
440    pub fn effective_provider(&self) -> EmbeddingProvider {
441        if self.mock {
442            EmbeddingProvider::Mock
443        } else {
444            self.provider.clone()
445        }
446    }
447
448    /// Create an embedding engine based on this configuration.
449    ///
450    /// Dispatches to the appropriate engine implementation based on
451    /// [`EmbeddingConfig::effective_provider`]. Providers not yet implemented
452    /// return [`EmbeddingError::NotImplemented`].
453    pub async fn create_engine(&self) -> EmbeddingResult<Arc<dyn EmbeddingEngine>> {
454        match self.effective_provider() {
455            #[cfg(feature = "onnx")]
456            EmbeddingProvider::Onnx | EmbeddingProvider::Fastembed => {
457                let engine = OnnxEmbeddingEngine::with_auto_download(self.onnx.clone()).await?;
458                Ok(Arc::new(engine))
459            }
460            #[cfg(not(feature = "onnx"))]
461            EmbeddingProvider::Onnx | EmbeddingProvider::Fastembed => {
462                Err(crate::error::EmbeddingError::NotImplemented(
463                    "ONNX embedding engine requires the `onnx` crate feature".to_string(),
464                ))
465            }
466            EmbeddingProvider::OpenAi | EmbeddingProvider::OpenAiCompatible => {
467                let engine = OpenAICompatibleEmbeddingEngine::new(self)?;
468                Ok(Arc::new(engine))
469            }
470            EmbeddingProvider::Ollama => {
471                let engine = OllamaEmbeddingEngine::new(self)?;
472                Ok(Arc::new(engine))
473            }
474            EmbeddingProvider::Mock => Ok(Arc::new(
475                MockEmbeddingEngine::new(self.dimensions).with_mode(self.mock_mode),
476            )),
477        }
478    }
479}
480
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;

    #[test]
    #[cfg(all(feature = "onnx", target_os = "android"))]
    fn test_default_is_onnx_on_android() {
        let config = EmbeddingConfig::default();
        assert_eq!(config.provider, EmbeddingProvider::Onnx);
        assert_eq!(config.dimensions, 384);
        assert_eq!(config.batch_size, 36);
        assert_eq!(config.max_completion_tokens, 8191);
        assert!(!config.mock);
    }

    #[test]
    #[cfg(not(target_os = "android"))]
    fn test_default_is_openai_off_android() {
        let config = EmbeddingConfig::default();
        assert_eq!(config.provider, EmbeddingProvider::OpenAi);
        assert_eq!(config.model, "text-embedding-3-small");
        assert_eq!(config.dimensions, 1536);
        assert_eq!(
            config.endpoint.as_deref(),
            Some("https://api.openai.com/v1")
        );
        assert!(!config.mock);
    }

    #[test]
    fn test_effective_provider_mock_override() {
        let config = EmbeddingConfig {
            mock: true,
            ..Default::default()
        };
        assert_eq!(config.effective_provider(), EmbeddingProvider::Mock);
    }

    #[test]
    #[cfg(all(feature = "onnx", target_os = "android"))]
    fn test_effective_provider_passthrough_onnx() {
        let config = EmbeddingConfig::default();
        assert_eq!(config.effective_provider(), EmbeddingProvider::Onnx);
    }

    #[test]
    #[cfg(not(target_os = "android"))]
    fn test_effective_provider_passthrough_openai() {
        let config = EmbeddingConfig::default();
        assert_eq!(config.effective_provider(), EmbeddingProvider::OpenAi);
    }

    // env-var tests mutate global process state, so they are serialized with
    // #[serial] to prevent races with each other. All env-mutating tests in this
    // crate live in this single test binary, so serial_test (which serializes
    // within a process) is sufficient; each test also cleans up its own vars.

    #[test]
    #[serial]
    fn test_from_env_mock_embedding_true() {
        // SAFETY: env var mutation is safe because #[serial] guarantees no other
        // env-mutating test in this binary runs concurrently.
        unsafe { std::env::set_var("MOCK_EMBEDDING", "true") };
        let config = EmbeddingConfig::from_env();
        unsafe { std::env::remove_var("MOCK_EMBEDDING") };
        assert!(config.mock);
        assert_eq!(config.effective_provider(), EmbeddingProvider::Mock);
    }

    #[test]
    #[serial]
    fn test_from_env_mock_embedding_numeric() {
        // SAFETY: see test_from_env_mock_embedding_true
        unsafe { std::env::set_var("MOCK_EMBEDDING", "1") };
        let config = EmbeddingConfig::from_env();
        unsafe { std::env::remove_var("MOCK_EMBEDDING") };
        assert!(config.mock);
        // Legacy truthy values keep the zero-vector mode.
        assert_eq!(config.mock_mode, MockVectorMode::Zero);
    }

    // Serialized like every other env-mutating test here, so it runs in the
    // default suite instead of requiring `--ignored --test-threads=1`.
    #[test]
    #[serial]
    fn test_from_env_mock_embedding_deterministic() {
        // SAFETY: see test_from_env_mock_embedding_true
        unsafe { std::env::set_var("MOCK_EMBEDDING", "deterministic") };
        let config = EmbeddingConfig::from_env();
        unsafe { std::env::remove_var("MOCK_EMBEDDING") };
        assert!(config.mock);
        assert_eq!(config.effective_provider(), EmbeddingProvider::Mock);
        assert_eq!(config.mock_mode, MockVectorMode::Deterministic);
    }

    #[test]
    #[serial]
    fn test_from_env_provider() {
        // SAFETY: see test_from_env_mock_embedding_true
        unsafe { std::env::set_var("EMBEDDING_PROVIDER", "openai") };
        let config = EmbeddingConfig::from_env();
        unsafe { std::env::remove_var("EMBEDDING_PROVIDER") };
        assert_eq!(config.provider, EmbeddingProvider::OpenAi);
    }

    #[test]
    #[serial]
    fn test_from_env_fastembed_alias() {
        // SAFETY: see test_from_env_mock_embedding_true
        unsafe { std::env::set_var("EMBEDDING_PROVIDER", "fastembed") };
        let config = EmbeddingConfig::from_env();
        unsafe { std::env::remove_var("EMBEDDING_PROVIDER") };
        assert_eq!(config.provider, EmbeddingProvider::Fastembed);
    }

    #[test]
    #[serial]
    fn test_from_env_dimensions() {
        // SAFETY: see test_from_env_mock_embedding_true
        unsafe { std::env::set_var("EMBEDDING_DIMENSIONS", "1536") };
        let config = EmbeddingConfig::from_env();
        unsafe { std::env::remove_var("EMBEDDING_DIMENSIONS") };
        assert_eq!(config.dimensions, 1536);
    }

    #[test]
    #[serial]
    fn test_from_env_api_key_fallback() {
        // SAFETY: see test_from_env_mock_embedding_true
        unsafe { std::env::remove_var("EMBEDDING_API_KEY") };
        unsafe { std::env::set_var("LLM_API_KEY", "my-llm-key") };
        let config = EmbeddingConfig::from_env();
        unsafe { std::env::remove_var("LLM_API_KEY") };
        assert_eq!(config.api_key, Some("my-llm-key".to_string()));
    }

    #[test]
    #[serial]
    fn test_from_env_api_key_prefers_embedding() {
        // SAFETY: see test_from_env_mock_embedding_true
        unsafe { std::env::set_var("EMBEDDING_API_KEY", "embed-key") };
        unsafe { std::env::set_var("LLM_API_KEY", "llm-key") };
        let config = EmbeddingConfig::from_env();
        unsafe { std::env::remove_var("EMBEDDING_API_KEY") };
        unsafe { std::env::remove_var("LLM_API_KEY") };
        assert_eq!(config.api_key, Some("embed-key".to_string()));
    }

    #[test]
    #[cfg(feature = "onnx")]
    fn test_onnx_config_bge_small() {
        let cfg = OnnxEmbeddingConfig::bge_small("/models");
        assert_eq!(cfg.dimensions, 384);
        assert_eq!(cfg.max_sequence_length, 512);
        assert_eq!(cfg.model_name, "bge-small-en-v1.5");
    }

    #[test]
    #[cfg(feature = "onnx")]
    fn test_onnx_config_minilm_l6() {
        let cfg = OnnxEmbeddingConfig::minilm_l6("/models");
        assert_eq!(cfg.dimensions, 384);
        assert_eq!(cfg.max_sequence_length, 256);
        assert_eq!(cfg.model_name, "all-MiniLM-L6-v2");
    }

    // ── known_model_dimensions unit tests ──────────────────────────────────
    // These are pure lookup tests — no env vars, no network, no model files.

    #[test]
    fn known_dims_openai_large() {
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::OpenAi, "text-embedding-3-large"),
            Some(3072),
        );
    }

    #[test]
    fn known_dims_openai_small() {
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::OpenAi, "text-embedding-3-small"),
            Some(1536),
        );
    }

    #[test]
    fn known_dims_ada_002() {
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::OpenAi, "text-embedding-ada-002"),
            Some(1536),
        );
    }

    /// Verify that a provider-prefixed model name is normalized before lookup.
    /// Python uses `model.split("/")[-1]`; Rust uses `rsplit('/').next()`.
    #[test]
    fn known_dims_prefix_stripped() {
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::OpenAi, "openai/text-embedding-3-small"),
            Some(1536),
        );
        // Azure-prefixed variant
        assert_eq!(
            known_model_dimensions(
                EmbeddingProvider::OpenAiCompatible,
                "azure/text-embedding-3-large"
            ),
            Some(3072),
        );
    }

    /// BGE-Small variants: bare name (both v1.5 spellings) and BAAI-prefixed.
    #[test]
    fn known_dims_bge_small() {
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::Onnx, "bge-small-en-v1.5"),
            Some(384),
        );
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::Onnx, "BGE-Small-v1.5"),
            Some(384),
        );
        // fastembed-style prefix stripped correctly
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::Fastembed, "BAAI/bge-small-en-v1.5"),
            Some(384),
        );
    }

    #[test]
    fn known_dims_bge_large() {
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::Fastembed, "bge-large-en-v1.5"),
            Some(1024),
        );
    }

    #[test]
    fn known_dims_unknown_returns_none() {
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::OpenAi, "some-unknown-model"),
            None,
        );
    }

    // ── from_env dimension-resolution tests ────────────────────────────────
    // These mutate process env vars and must not run in parallel.

    /// Explicit EMBEDDING_DIMENSIONS always overrides the table lookup.
    #[test]
    #[serial]
    fn from_env_explicit_override_wins() {
        // SAFETY: #[serial] guarantees no concurrent env readers in this binary.
        unsafe {
            std::env::set_var("EMBEDDING_PROVIDER", "openai");
            std::env::set_var("EMBEDDING_MODEL", "text-embedding-3-large");
            std::env::set_var("EMBEDDING_DIMENSIONS", "999");
        }
        let config = EmbeddingConfig::from_env();
        unsafe {
            std::env::remove_var("EMBEDDING_PROVIDER");
            std::env::remove_var("EMBEDDING_MODEL");
            std::env::remove_var("EMBEDDING_DIMENSIONS");
        }
        // Explicit env var must win over the table value (3072).
        assert_eq!(config.dimensions, 999);
    }

    /// Changing EMBEDDING_MODEL to a known model (without EMBEDDING_DIMENSIONS) must
    /// resolve the correct dimension — not silently keep the default 384.
    /// This is the regression this task fixes (audit B7.2).
    #[test]
    #[serial]
    fn from_env_model_change_resolves() {
        // SAFETY: see from_env_explicit_override_wins
        unsafe {
            std::env::set_var("EMBEDDING_PROVIDER", "openai");
            std::env::set_var("EMBEDDING_MODEL", "text-embedding-3-large");
            std::env::remove_var("EMBEDDING_DIMENSIONS");
        }
        let config = EmbeddingConfig::from_env();
        unsafe {
            std::env::remove_var("EMBEDDING_PROVIDER");
            std::env::remove_var("EMBEDDING_MODEL");
        }
        // Previously returned 384 (the ONNX default); now must return 3072.
        assert_eq!(config.dimensions, 3072);
    }

    /// An unknown model with no explicit EMBEDDING_DIMENSIONS must fall back to
    /// FALLBACK_DIMENSIONS (384) and log a warning (we only assert the dimension here).
    #[test]
    #[serial]
    fn from_env_unknown_falls_back() {
        // SAFETY: see from_env_explicit_override_wins
        unsafe {
            std::env::set_var("EMBEDDING_PROVIDER", "openai");
            std::env::set_var("EMBEDDING_MODEL", "some-unknown-model-xyz");
            std::env::remove_var("EMBEDDING_DIMENSIONS");
        }
        let config = EmbeddingConfig::from_env();
        unsafe {
            std::env::remove_var("EMBEDDING_PROVIDER");
            std::env::remove_var("EMBEDDING_MODEL");
        }
        assert_eq!(config.dimensions, FALLBACK_DIMENSIONS);
    }
}
783            std::env::remove_var("EMBEDDING_MODEL");
784        }
785        assert_eq!(config.dimensions, FALLBACK_DIMENSIONS);
786    }
787}