Skip to main content

cognee_embedding/
config.rs

1use std::sync::Arc;
2
3use serde::{Deserialize, Serialize};
4
5use crate::engine::EmbeddingEngine;
6use crate::error::EmbeddingResult;
7use crate::mock::{MockEmbeddingEngine, MockVectorMode};
8use crate::ollama::OllamaEmbeddingEngine;
9use crate::openai_compatible::OpenAICompatibleEmbeddingEngine;
10use crate::provider::EmbeddingProvider;
11
12#[cfg(feature = "onnx")]
13use crate::onnx::OnnxEmbeddingEngine;
14#[cfg(feature = "onnx")]
15use std::path::PathBuf;
16
/// Fallback dimension when the model is unknown AND `EMBEDDING_DIMENSIONS` is unset.
///
/// Consumed in two places: [`EmbeddingConfig::from_env`] returns it (after logging a
/// warning) when [`known_model_dimensions`] has no entry for the configured model, and
/// `EmbeddingConfig::default` uses it as the `unwrap_or` value when resolving the
/// built-in OpenAI model. If the real model emits a different size, the first vector
/// write fails with a shape mismatch, so set `EMBEDDING_DIMENSIONS` for unknown models.
///
/// 384 matches the ONNX BGE-Small edge model (Android default). On non-Android,
/// the default model (`text-embedding-3-small` → 1536) resolves via the
/// [`known_model_dimensions`] table, so this fallback is only hit for truly
/// unknown models.
///
/// **Note:** Python uses 3072 as its fallback (matching the OpenAI default).
/// Rust deliberately uses 384 because the Rust edge default is BGE-Small, not
/// `text-embedding-3-large`.
const FALLBACK_DIMENSIONS: usize = 384;
28
29/// Best-effort lookup of the output vector dimension for a known embedding model.
30///
31/// Mirrors the dimension table that Python resolves dynamically via the
32/// `litellm` and `fastembed` registries (`_resolve_embedding_dimensions` in
33/// `cognee/infrastructure/databases/vector/embeddings/config.py`). Rust
34/// hardcodes a small, curated table instead of depending on those Python
35/// packages.
36///
37/// Resolution rules (matches Python semantics):
38/// - Strips a leading provider segment: `"openai/text-embedding-3-large"` →
39///   `"text-embedding-3-large"` (uses the last `/`-separated component).
40/// - Matching is **case-insensitive**.
41/// - Returns `None` for unknown models so the caller can fall back with a
42///   warning rather than silently using a wrong dimension.
43///
44/// The `provider` argument is accepted for forward-compatibility (future
45/// provider-scoped overrides) but is not used in the current table.
46pub fn known_model_dimensions(provider: EmbeddingProvider, model: &str) -> Option<usize> {
47    // Strip a provider prefix: "openai/text-embedding-3-large" -> "text-embedding-3-large"
48    // rsplit('/').next() is infallible for any &str (always yields ≥ 1 element).
49    let bare = model.rsplit('/').next().unwrap_or(model);
50    let key = bare.to_ascii_lowercase();
51    let dim = match key.as_str() {
52        // --- OpenAI models (verified via litellm registry) ---
53        "text-embedding-3-large" => 3072,
54        "text-embedding-3-small" => 1536,
55        "text-embedding-ada-002" => 1536,
56        // --- BGE family (fastembed registry + ONNX defaults) ---
57        "bge-small-v1.5" | "bge-small-en-v1.5" => 384,
58        "bge-base-en-v1.5" => 768,
59        "bge-large-en-v1.5" => 1024,
60        // --- MiniLM ---
61        "all-minilm-l6-v2" => 384,
62        // --- Common Ollama models ---
63        "nomic-embed-text" => 768,
64        "mxbai-embed-large" => 1024,
65        _ => return None,
66    };
67    let _ = provider; // provider currently unused; kept for future provider-scoped dims
68    Some(dim)
69}
70
/// ONNX-specific configuration.
///
/// Only used when `EmbeddingConfig::provider` is `Onnx` or `Fastembed`.
/// All other providers use the top-level `EmbeddingConfig` fields only.
///
/// Build one with [`OnnxEmbeddingConfig::bge_small`] or
/// [`OnnxEmbeddingConfig::minilm_l6`]; `Default` is BGE-Small under `./target/models`.
#[cfg(feature = "onnx")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OnnxEmbeddingConfig {
    /// Path to ONNX model file (.onnx)
    pub model_path: PathBuf,

    /// Path to tokenizer.json file
    pub tokenizer_path: PathBuf,

    /// Model name for logging/identification and auto-download selection
    pub model_name: String,

    /// Embedding dimensions (must match model output).
    ///
    /// `EmbeddingConfig::from_env` uses this value as the embedding dimension for
    /// ONNX/Fastembed when `EMBEDDING_DIMENSIONS` is unset.
    pub dimensions: usize,

    /// Maximum sequence length in tokens (truncate if longer)
    pub max_sequence_length: usize,

    /// Batch size for ONNX inference (max texts per inference call).
    /// Overridable via `EMBEDDING_ONNX_BATCH_SIZE`.
    pub batch_size: usize,
}
96
#[cfg(feature = "onnx")]
impl Default for OnnxEmbeddingConfig {
    /// BGE-Small-v1.5, with model and tokenizer files expected under
    /// `./target/models` (relative to the process working directory).
    fn default() -> Self {
        let model_dir = PathBuf::from("./target/models");
        Self::bge_small(model_dir)
    }
}
103
#[cfg(feature = "onnx")]
impl OnnxEmbeddingConfig {
    /// Preset for BGE-Small-v1.5: 384-dim output, 512-token window, batches of 32.
    ///
    /// Looks for `BGE-Small-v1.5-model_quantized.onnx` and `bge-small-tokenizer.json`
    /// inside `model_dir`.
    pub fn bge_small(model_dir: impl Into<PathBuf>) -> Self {
        let dir: PathBuf = model_dir.into();
        Self {
            model_path: dir.join("BGE-Small-v1.5-model_quantized.onnx"),
            tokenizer_path: dir.join("bge-small-tokenizer.json"),
            model_name: String::from("bge-small-en-v1.5"),
            dimensions: 384,
            max_sequence_length: 512,
            batch_size: 32,
        }
    }

    /// Preset for all-MiniLM-L6-v2: 384-dim output, 256-token window, batches of 32.
    ///
    /// Looks for `all-MiniLM-L6-v2.onnx` and `minilm-l6-tokenizer.json` inside
    /// `model_dir`.
    pub fn minilm_l6(model_dir: impl Into<PathBuf>) -> Self {
        let dir: PathBuf = model_dir.into();
        Self {
            model_path: dir.join("all-MiniLM-L6-v2.onnx"),
            tokenizer_path: dir.join("minilm-l6-tokenizer.json"),
            model_name: String::from("all-MiniLM-L6-v2"),
            dimensions: 384,
            max_sequence_length: 256,
            batch_size: 32,
        }
    }
}
136
/// Unified embedding configuration.
///
/// Provider-agnostic; holds fields for all supported backends.
/// Load from environment variables via [`EmbeddingConfig::from_env`], or construct
/// programmatically and pass to [`EmbeddingConfig::create_engine`].
///
/// Environment variables (match Python SDK names):
/// - `EMBEDDING_PROVIDER` — backend selection: `openai`, `openai_compatible`, `ollama`,
///   `onnx`, `fastembed`, or `mock` (default: `openai`; `onnx` on Android builds that
///   enable the `onnx` feature)
/// - `MOCK_EMBEDDING` — `true`/`1`/`yes` forces zero-vector mock mode;
///   `deterministic`/`hash` forces content-derived mock vectors
/// - `EMBEDDING_MODEL` — model identifier
/// - `EMBEDDING_DIMENSIONS` — vector size; when unset it is derived from the model
///   (see [`known_model_dimensions`]) or, for ONNX/Fastembed with the `onnx` feature,
///   taken from the ONNX config
/// - `EMBEDDING_ENDPOINT` — API endpoint URL
/// - `EMBEDDING_API_KEY` — API key (fallback: `LLM_API_KEY`)
/// - `EMBEDDING_API_VERSION` — API version string
/// - `EMBEDDING_MAX_COMPLETION_TOKENS` — maximum tokens (default: 8191)
/// - `EMBEDDING_BATCH_SIZE` — texts per embedding request (default: 36)
/// - `EMBEDDING_ONNX_BATCH_SIZE` — ONNX inference batch size (default: 32; `onnx` feature only)
/// - `HUGGINGFACE_TOKENIZER` — HuggingFace tokenizer identifier
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    /// Which backend to use for embedding generation.
    pub provider: EmbeddingProvider,

    /// Model identifier. For ONNX this is informational; for API providers this is sent in
    /// the request body.
    ///
    /// Default: `text-embedding-3-small`, or the ONNX config's model name
    /// (`bge-small-en-v1.5`) on Android builds with the `onnx` feature. `from_env`
    /// substitutes `avr/sfr-embedding-mistral:latest` when `EMBEDDING_PROVIDER=ollama`
    /// and `EMBEDDING_MODEL` is unset.
    pub model: String,

    /// Embedding vector dimensionality. Must match the model output.
    pub dimensions: usize,

    /// API endpoint URL (used by OpenAI-compatible and Ollama providers).
    ///
    /// Default: `https://api.openai.com/v1`, or `None` on Android builds with the
    /// `onnx` feature.
    pub endpoint: Option<String>,

    /// API key. Reads `EMBEDDING_API_KEY` first, falls back to `LLM_API_KEY`.
    pub api_key: Option<String>,

    /// API version string (e.g. "2023-05-15" for Azure OpenAI).
    pub api_version: Option<String>,

    /// Maximum tokens (default: 8191).
    ///
    /// NOTE(review): despite the "completion" name, for an embedding config this
    /// presumably caps input tokens per text — confirm against the engines.
    pub max_completion_tokens: usize,

    /// Number of texts to send in a single embedding request (default: 36).
    ///
    /// Matches the Python SDK and stays within the small client-batch limits of
    /// the self-hosted servers this adapter targets (e.g. TEI defaults to 32).
    /// Raise it via `EMBEDDING_BATCH_SIZE` for providers that accept larger
    /// batches. For the OpenAI-compatible engine, up to `MAX_CONCURRENT_BATCHES`
    /// sub-batches are also dispatched concurrently.
    pub batch_size: usize,

    /// If true, use mock embeddings regardless of `provider`:
    /// [`EmbeddingConfig::effective_provider`] then returns `Mock`.
    /// Set by `MOCK_EMBEDDING` (see above) or `EMBEDDING_PROVIDER=mock`.
    pub mock: bool,

    /// How the mock engine generates vectors when `provider` is `Mock`.
    ///
    /// `Default` sets [`MockVectorMode::Zero`]; when the field is missing from
    /// serialized input, serde uses `MockVectorMode::default()`. Set via
    /// `MOCK_EMBEDDING=deterministic` (or `hash`) to derive content-stable vectors
    /// from `sha256(text)`.
    #[serde(default)]
    pub mock_mode: MockVectorMode,

    /// ONNX-specific configuration. Only consulted when provider is `Onnx` or `Fastembed`.
    #[cfg(feature = "onnx")]
    pub onnx: OnnxEmbeddingConfig,

    /// HuggingFace tokenizer identifier for chunking token counting.
    /// When set, used by `HuggingFaceTokenCounter` in the chunking crate.
    pub huggingface_tokenizer: Option<String>,
}
206
207impl Default for EmbeddingConfig {
208    fn default() -> Self {
209        // On Android, local ONNX inference is the right default (edge deployment).
210        // Everywhere else, match the Python SDK default: OpenAI text-embedding-3-small.
211        #[cfg(all(feature = "onnx", target_os = "android"))]
212        let (provider, model, dimensions, endpoint) = {
213            let onnx_cfg = OnnxEmbeddingConfig::default();
214            (
215                EmbeddingProvider::Onnx,
216                onnx_cfg.model_name.clone(),
217                onnx_cfg.dimensions,
218                None,
219            )
220        };
221        #[cfg(all(feature = "onnx", not(target_os = "android")))]
222        let (provider, model, dimensions, endpoint) = {
223            let m = "text-embedding-3-small".to_string();
224            // Resolve via the known-model table so there is one source of truth.
225            let d = known_model_dimensions(EmbeddingProvider::OpenAi, &m)
226                .unwrap_or(FALLBACK_DIMENSIONS);
227            (
228                EmbeddingProvider::OpenAi,
229                m,
230                d,
231                Some("https://api.openai.com/v1".to_string()),
232            )
233        };
234        #[cfg(not(feature = "onnx"))]
235        let (provider, model, dimensions, endpoint) = {
236            let m = "text-embedding-3-small".to_string();
237            // Resolve via the known-model table so there is one source of truth.
238            let d = known_model_dimensions(EmbeddingProvider::OpenAi, &m)
239                .unwrap_or(FALLBACK_DIMENSIONS);
240            (
241                EmbeddingProvider::OpenAi,
242                m,
243                d,
244                Some("https://api.openai.com/v1".to_string()),
245            )
246        };
247
248        Self {
249            provider,
250            model,
251            dimensions,
252            endpoint,
253            api_key: None,
254            api_version: None,
255            max_completion_tokens: 8191,
256            batch_size: 36,
257            mock: false,
258            mock_mode: MockVectorMode::Zero,
259            #[cfg(feature = "onnx")]
260            onnx: OnnxEmbeddingConfig::default(),
261            huggingface_tokenizer: None,
262        }
263    }
264}
265
266impl EmbeddingConfig {
267    /// Load configuration from environment variables.
268    ///
269    /// Reads the same env var names as the Python SDK so that a shared `.env` file
270    /// works across both implementations without modification.
271    pub fn from_env() -> Self {
272        let mut config = Self::default();
273
274        // Parse MOCK_EMBEDDING first — it overrides everything else if set.
275        // `deterministic` (or `hash`) selects the SHA-256-derived deterministic
276        // mode; other truthy values keep the legacy zero-vector mode.
277        if let Ok(val) = std::env::var("MOCK_EMBEDDING") {
278            let val = val.trim().to_lowercase();
279            if val == "deterministic" || val == "hash" {
280                config.mock = true;
281                config.provider = EmbeddingProvider::Mock;
282                config.mock_mode = MockVectorMode::Deterministic;
283                return config;
284            }
285            if val == "true" || val == "1" || val == "yes" {
286                config.mock = true;
287                config.provider = EmbeddingProvider::Mock;
288                config.mock_mode = MockVectorMode::Zero;
289                return config;
290            }
291        }
292
293        // Parse EMBEDDING_PROVIDER
294        if let Ok(val) = std::env::var("EMBEDDING_PROVIDER") {
295            let val = val.trim().to_lowercase();
296            match val.as_str() {
297                "onnx" => config.provider = EmbeddingProvider::Onnx,
298                "fastembed" => config.provider = EmbeddingProvider::Fastembed,
299                "openai" => config.provider = EmbeddingProvider::OpenAi,
300                "openai_compatible" => config.provider = EmbeddingProvider::OpenAiCompatible,
301                "ollama" => config.provider = EmbeddingProvider::Ollama,
302                "mock" => {
303                    config.mock = true;
304                    config.provider = EmbeddingProvider::Mock;
305                }
306                _ => {
307                    // Unknown provider — leave the platform default (OpenAI, or
308                    // ONNX on Android) and log nothing; the caller will get a
309                    // clear error from create_engine() if needed.
310                }
311            }
312        }
313
314        // Apply provider-specific model defaults before checking env var overrides.
315        // This ensures that when a user switches to EMBEDDING_PROVIDER=ollama
316        // without setting EMBEDDING_MODEL explicitly, they get a sensible Ollama
317        // default model name rather than the ONNX model name.
318        // (Dimension is resolved below via the known-model table, not hardcoded here.)
319        if config.provider == EmbeddingProvider::Ollama {
320            config.model = "avr/sfr-embedding-mistral:latest".to_string();
321        }
322
323        // EMBEDDING_MODEL
324        if let Ok(val) = std::env::var("EMBEDDING_MODEL") {
325            let val = val.trim().to_string();
326            if !val.is_empty() {
327                config.model = val;
328            }
329        }
330
331        // EMBEDDING_DIMENSIONS — resolution order (mirrors Python model_post_init):
332        //   1. Explicit EMBEDDING_DIMENSIONS env var — always wins.
333        //   2. known_model_dimensions(provider, model) — table lookup.
334        //   3. Fallback FALLBACK_DIMENSIONS (384) with a tracing::warn! so the user
335        //      knows to set EMBEDDING_DIMENSIONS explicitly for unknown models.
336        // For ONNX the model file dictates the true dimension, so we prefer the
337        // onnx_cfg.dimensions unless the user set EMBEDDING_DIMENSIONS explicitly.
338        let explicit_dims = std::env::var("EMBEDDING_DIMENSIONS")
339            .ok()
340            .and_then(|v| v.trim().parse::<usize>().ok());
341
342        // Resolve via the known-model table, falling back to FALLBACK_DIMENSIONS
343        // with a warning when the model is unknown (parity with Python
344        // model_post_init). Used for every provider except ONNX/Fastembed, whose
345        // dimension is dictated by the model file (handled below).
346        let resolve_from_table = |config: &EmbeddingConfig| match known_model_dimensions(
347            config.provider.clone(),
348            &config.model,
349        ) {
350            // 2. Known model — derived dimension.
351            Some(d) => d,
352            // 3. Unknown model — fallback with warning.
353            None => {
354                tracing::warn!(
355                    provider = ?config.provider,
356                    model = %config.model,
357                    fallback = FALLBACK_DIMENSIONS,
358                    "Could not auto-derive embedding dimensions; set \
359                     EMBEDDING_DIMENSIONS explicitly if your embedder produces \
360                     a different vector size, otherwise the first vector write \
361                     will fail with a shape mismatch."
362                );
363                FALLBACK_DIMENSIONS
364            }
365        };
366
367        config.dimensions = match explicit_dims {
368            // 1. Explicit override always wins.
369            Some(d) => d,
370            None => {
371                // For ONNX/Fastembed the model file carries the authoritative
372                // dimension, so use onnx.dimensions rather than the text table —
373                // this keeps custom ONNX models working.
374                #[cfg(feature = "onnx")]
375                {
376                    if matches!(
377                        config.provider,
378                        EmbeddingProvider::Onnx | EmbeddingProvider::Fastembed
379                    ) {
380                        config.onnx.dimensions
381                    } else {
382                        resolve_from_table(&config)
383                    }
384                }
385                #[cfg(not(feature = "onnx"))]
386                {
387                    resolve_from_table(&config)
388                }
389            }
390        };
391
392        // EMBEDDING_ENDPOINT
393        if let Ok(val) = std::env::var("EMBEDDING_ENDPOINT") {
394            let val = val.trim().to_string();
395            if !val.is_empty() {
396                config.endpoint = Some(val);
397            }
398        }
399
400        // EMBEDDING_API_KEY, fallback to LLM_API_KEY
401        if let Ok(val) = std::env::var("EMBEDDING_API_KEY") {
402            let val = val.trim().to_string();
403            if !val.is_empty() {
404                config.api_key = Some(val);
405            }
406        } else if let Ok(val) = std::env::var("LLM_API_KEY") {
407            let val = val.trim().to_string();
408            if !val.is_empty() {
409                config.api_key = Some(val);
410            }
411        }
412
413        // EMBEDDING_API_VERSION
414        if let Ok(val) = std::env::var("EMBEDDING_API_VERSION") {
415            let val = val.trim().to_string();
416            if !val.is_empty() {
417                config.api_version = Some(val);
418            }
419        }
420
421        // EMBEDDING_MAX_COMPLETION_TOKENS
422        if let Ok(val) = std::env::var("EMBEDDING_MAX_COMPLETION_TOKENS")
423            && let Ok(n) = val.trim().parse::<usize>()
424        {
425            config.max_completion_tokens = n;
426        }
427
428        // EMBEDDING_BATCH_SIZE
429        if let Ok(val) = std::env::var("EMBEDDING_BATCH_SIZE")
430            && let Ok(n) = val.trim().parse::<usize>()
431        {
432            config.batch_size = n;
433        }
434
435        #[cfg(feature = "onnx")]
436        if let Ok(val) = std::env::var("EMBEDDING_ONNX_BATCH_SIZE")
437            && let Ok(n) = val.trim().parse::<usize>()
438            && n > 0
439        {
440            config.onnx.batch_size = n;
441        }
442
443        // HUGGINGFACE_TOKENIZER
444        if let Ok(val) = std::env::var("HUGGINGFACE_TOKENIZER") {
445            let val = val.trim().to_string();
446            if !val.is_empty() {
447                config.huggingface_tokenizer = Some(val);
448            }
449        }
450
451        config
452    }
453
454    /// Returns the effective provider, substituting Mock when `self.mock` is true.
455    pub fn effective_provider(&self) -> EmbeddingProvider {
456        if self.mock {
457            EmbeddingProvider::Mock
458        } else {
459            self.provider.clone()
460        }
461    }
462
463    /// Create an embedding engine based on this configuration.
464    ///
465    /// Dispatches to the appropriate engine implementation based on
466    /// [`EmbeddingConfig::effective_provider`]. Providers not yet implemented
467    /// return [`EmbeddingError::NotImplemented`].
468    pub async fn create_engine(&self) -> EmbeddingResult<Arc<dyn EmbeddingEngine>> {
469        match self.effective_provider() {
470            #[cfg(feature = "onnx")]
471            EmbeddingProvider::Onnx | EmbeddingProvider::Fastembed => {
472                let engine = OnnxEmbeddingEngine::with_auto_download(self.onnx.clone()).await?;
473                Ok(Arc::new(engine))
474            }
475            #[cfg(not(feature = "onnx"))]
476            EmbeddingProvider::Onnx | EmbeddingProvider::Fastembed => {
477                Err(crate::error::EmbeddingError::NotImplemented(
478                    "ONNX embedding engine requires the `onnx` crate feature".to_string(),
479                ))
480            }
481            EmbeddingProvider::OpenAi | EmbeddingProvider::OpenAiCompatible => {
482                let engine = OpenAICompatibleEmbeddingEngine::new(self)?;
483                Ok(Arc::new(engine))
484            }
485            EmbeddingProvider::Ollama => {
486                let engine = OllamaEmbeddingEngine::new(self)?;
487                Ok(Arc::new(engine))
488            }
489            EmbeddingProvider::Mock => Ok(Arc::new(
490                MockEmbeddingEngine::new(self.dimensions).with_mode(self.mock_mode),
491            )),
492        }
493    }
494}
495
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;

    #[test]
    #[cfg(all(feature = "onnx", target_os = "android"))]
    fn test_default_is_onnx_on_android() {
        let config = EmbeddingConfig::default();
        assert_eq!(config.provider, EmbeddingProvider::Onnx);
        assert_eq!(config.dimensions, 384);
        assert_eq!(config.batch_size, 36);
        assert_eq!(config.max_completion_tokens, 8191);
        assert!(!config.mock);
    }

    #[test]
    #[cfg(not(target_os = "android"))]
    fn test_default_is_openai_off_android() {
        let config = EmbeddingConfig::default();
        assert_eq!(config.provider, EmbeddingProvider::OpenAi);
        assert_eq!(config.model, "text-embedding-3-small");
        assert_eq!(config.dimensions, 1536);
        assert_eq!(
            config.endpoint.as_deref(),
            Some("https://api.openai.com/v1")
        );
        assert!(!config.mock);
    }

    #[test]
    fn test_effective_provider_mock_override() {
        let config = EmbeddingConfig {
            mock: true,
            ..Default::default()
        };
        assert_eq!(config.effective_provider(), EmbeddingProvider::Mock);
    }

    #[test]
    #[cfg(all(feature = "onnx", target_os = "android"))]
    fn test_effective_provider_passthrough_onnx() {
        let config = EmbeddingConfig::default();
        assert_eq!(config.effective_provider(), EmbeddingProvider::Onnx);
    }

    #[test]
    #[cfg(not(target_os = "android"))]
    fn test_effective_provider_passthrough_openai() {
        let config = EmbeddingConfig::default();
        assert_eq!(config.effective_provider(), EmbeddingProvider::OpenAi);
    }

    // env-var tests mutate global process state, so they are serialized with
    // #[serial] to prevent races with each other. All env-mutating tests in this
    // crate live in this single test binary, so serial_test (which serializes
    // within a process) is sufficient; each test also cleans up its own vars.

    #[test]
    #[serial]
    fn test_from_env_mock_embedding_true() {
        // SAFETY: env var mutation is safe because #[serial] guarantees no other
        // env-mutating test in this binary runs concurrently.
        unsafe { std::env::set_var("MOCK_EMBEDDING", "true") };
        let config = EmbeddingConfig::from_env();
        unsafe { std::env::remove_var("MOCK_EMBEDDING") };
        assert!(config.mock);
        assert_eq!(config.effective_provider(), EmbeddingProvider::Mock);
    }

    #[test]
    #[serial]
    fn test_from_env_mock_embedding_numeric() {
        // SAFETY: see test_from_env_mock_embedding_true
        unsafe { std::env::set_var("MOCK_EMBEDDING", "1") };
        let config = EmbeddingConfig::from_env();
        unsafe { std::env::remove_var("MOCK_EMBEDDING") };
        assert!(config.mock);
        // Legacy truthy values keep the zero-vector mode.
        assert_eq!(config.mock_mode, MockVectorMode::Zero);
    }

    // Serialized like every other env-mutating test so it runs in a normal
    // `cargo test` instead of only under `--ignored`.
    #[test]
    #[serial]
    fn test_from_env_mock_embedding_deterministic() {
        // SAFETY: see test_from_env_mock_embedding_true
        unsafe { std::env::set_var("MOCK_EMBEDDING", "deterministic") };
        let config = EmbeddingConfig::from_env();
        unsafe { std::env::remove_var("MOCK_EMBEDDING") };
        assert!(config.mock);
        assert_eq!(config.effective_provider(), EmbeddingProvider::Mock);
        assert_eq!(config.mock_mode, MockVectorMode::Deterministic);
    }

    #[test]
    #[serial]
    fn test_from_env_provider() {
        // SAFETY: see test_from_env_mock_embedding_true
        unsafe { std::env::set_var("EMBEDDING_PROVIDER", "openai") };
        let config = EmbeddingConfig::from_env();
        unsafe { std::env::remove_var("EMBEDDING_PROVIDER") };
        assert_eq!(config.provider, EmbeddingProvider::OpenAi);
    }

    #[test]
    #[serial]
    fn test_from_env_fastembed_alias() {
        // SAFETY: see test_from_env_mock_embedding_true
        unsafe { std::env::set_var("EMBEDDING_PROVIDER", "fastembed") };
        let config = EmbeddingConfig::from_env();
        unsafe { std::env::remove_var("EMBEDDING_PROVIDER") };
        assert_eq!(config.provider, EmbeddingProvider::Fastembed);
    }

    #[test]
    #[serial]
    fn test_from_env_dimensions() {
        // SAFETY: see test_from_env_mock_embedding_true
        unsafe { std::env::set_var("EMBEDDING_DIMENSIONS", "1536") };
        let config = EmbeddingConfig::from_env();
        unsafe { std::env::remove_var("EMBEDDING_DIMENSIONS") };
        assert_eq!(config.dimensions, 1536);
    }

    #[test]
    #[serial]
    fn test_from_env_api_key_fallback() {
        // SAFETY: see test_from_env_mock_embedding_true
        unsafe { std::env::remove_var("EMBEDDING_API_KEY") };
        unsafe { std::env::set_var("LLM_API_KEY", "my-llm-key") };
        let config = EmbeddingConfig::from_env();
        unsafe { std::env::remove_var("LLM_API_KEY") };
        assert_eq!(config.api_key, Some("my-llm-key".to_string()));
    }

    #[test]
    #[serial]
    fn test_from_env_api_key_prefers_embedding() {
        // SAFETY: see test_from_env_mock_embedding_true
        unsafe { std::env::set_var("EMBEDDING_API_KEY", "embed-key") };
        unsafe { std::env::set_var("LLM_API_KEY", "llm-key") };
        let config = EmbeddingConfig::from_env();
        unsafe { std::env::remove_var("EMBEDDING_API_KEY") };
        unsafe { std::env::remove_var("LLM_API_KEY") };
        assert_eq!(config.api_key, Some("embed-key".to_string()));
    }

    #[test]
    #[cfg(feature = "onnx")]
    #[serial]
    fn from_env_onnx_batch_size_override() {
        // SAFETY: see test_from_env_mock_embedding_true
        unsafe { std::env::set_var("EMBEDDING_ONNX_BATCH_SIZE", "8") };
        let config = EmbeddingConfig::from_env();
        unsafe { std::env::remove_var("EMBEDDING_ONNX_BATCH_SIZE") };
        assert_eq!(config.onnx.batch_size, 8);
    }

    #[test]
    #[cfg(feature = "onnx")]
    fn test_onnx_config_bge_small() {
        let cfg = OnnxEmbeddingConfig::bge_small("/models");
        assert_eq!(cfg.dimensions, 384);
        assert_eq!(cfg.max_sequence_length, 512);
        assert_eq!(cfg.model_name, "bge-small-en-v1.5");
    }

    #[test]
    #[cfg(feature = "onnx")]
    fn test_onnx_config_minilm_l6() {
        let cfg = OnnxEmbeddingConfig::minilm_l6("/models");
        assert_eq!(cfg.dimensions, 384);
        assert_eq!(cfg.max_sequence_length, 256);
        assert_eq!(cfg.model_name, "all-MiniLM-L6-v2");
    }

    // ── known_model_dimensions unit tests ──────────────────────────────────
    // These are pure lookup tests — no env vars, no network, no model files.

    #[test]
    fn known_dims_openai_large() {
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::OpenAi, "text-embedding-3-large"),
            Some(3072),
        );
    }

    #[test]
    fn known_dims_openai_small() {
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::OpenAi, "text-embedding-3-small"),
            Some(1536),
        );
    }

    #[test]
    fn known_dims_ada_002() {
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::OpenAi, "text-embedding-ada-002"),
            Some(1536),
        );
    }

    /// Matching is case-insensitive on the prefix-stripped model name.
    #[test]
    fn known_dims_case_insensitive() {
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::OpenAi, "Text-Embedding-3-Large"),
            Some(3072),
        );
    }

    /// Verify that a provider-prefixed model name is normalized before lookup.
    /// Python uses `model.split("/")[-1]`; Rust uses `rsplit('/').next()`.
    #[test]
    fn known_dims_prefix_stripped() {
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::OpenAi, "openai/text-embedding-3-small"),
            Some(1536),
        );
        // Azure-prefixed variant
        assert_eq!(
            known_model_dimensions(
                EmbeddingProvider::OpenAiCompatible,
                "azure/text-embedding-3-large"
            ),
            Some(3072),
        );
    }

    /// BGE-Small variants: bare name (both v1.5 spellings) and BAAI-prefixed.
    #[test]
    fn known_dims_bge_small() {
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::Onnx, "bge-small-en-v1.5"),
            Some(384),
        );
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::Onnx, "BGE-Small-v1.5"),
            Some(384),
        );
        // fastembed-style prefix stripped correctly
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::Fastembed, "BAAI/bge-small-en-v1.5"),
            Some(384),
        );
    }

    #[test]
    fn known_dims_bge_large() {
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::Fastembed, "bge-large-en-v1.5"),
            Some(1024),
        );
    }

    /// Untagged Ollama library names resolve.
    #[test]
    fn known_dims_ollama_models() {
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::Ollama, "nomic-embed-text"),
            Some(768),
        );
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::Ollama, "mxbai-embed-large"),
            Some(1024),
        );
    }

    #[test]
    fn known_dims_unknown_returns_none() {
        assert_eq!(
            known_model_dimensions(EmbeddingProvider::OpenAi, "some-unknown-model"),
            None,
        );
    }

    // ── from_env dimension-resolution tests ────────────────────────────────
    // These mutate process env vars and must not run in parallel.

    /// Explicit EMBEDDING_DIMENSIONS always overrides the table lookup.
    #[test]
    #[serial]
    fn from_env_explicit_override_wins() {
        // SAFETY: #[serial] guarantees no concurrent env readers in this binary.
        unsafe {
            std::env::set_var("EMBEDDING_PROVIDER", "openai");
            std::env::set_var("EMBEDDING_MODEL", "text-embedding-3-large");
            std::env::set_var("EMBEDDING_DIMENSIONS", "999");
        }
        let config = EmbeddingConfig::from_env();
        unsafe {
            std::env::remove_var("EMBEDDING_PROVIDER");
            std::env::remove_var("EMBEDDING_MODEL");
            std::env::remove_var("EMBEDDING_DIMENSIONS");
        }
        // Explicit env var must win over the table value (3072).
        assert_eq!(config.dimensions, 999);
    }

    /// Changing EMBEDDING_MODEL to a known model (without EMBEDDING_DIMENSIONS) must
    /// resolve the correct dimension — not silently keep a fixed default such as 384.
    /// Regression guard for audit B7.2.
    #[test]
    #[serial]
    fn from_env_model_change_resolves() {
        // SAFETY: see from_env_explicit_override_wins
        unsafe {
            std::env::set_var("EMBEDDING_PROVIDER", "openai");
            std::env::set_var("EMBEDDING_MODEL", "text-embedding-3-large");
            std::env::remove_var("EMBEDDING_DIMENSIONS");
        }
        let config = EmbeddingConfig::from_env();
        unsafe {
            std::env::remove_var("EMBEDDING_PROVIDER");
            std::env::remove_var("EMBEDDING_MODEL");
        }
        // Must come from the table (3072), not from the ONNX default (384).
        assert_eq!(config.dimensions, 3072);
    }

    /// An unknown model with no explicit EMBEDDING_DIMENSIONS must fall back to
    /// FALLBACK_DIMENSIONS (384) and log a warning (we only assert the dimension here).
    #[test]
    #[serial]
    fn from_env_unknown_falls_back() {
        // SAFETY: see from_env_explicit_override_wins
        unsafe {
            std::env::set_var("EMBEDDING_PROVIDER", "openai");
            std::env::set_var("EMBEDDING_MODEL", "some-unknown-model-xyz");
            std::env::remove_var("EMBEDDING_DIMENSIONS");
        }
        let config = EmbeddingConfig::from_env();
        unsafe {
            std::env::remove_var("EMBEDDING_PROVIDER");
            std::env::remove_var("EMBEDDING_MODEL");
        }
        assert_eq!(config.dimensions, FALLBACK_DIMENSIONS);
    }
}