// memvid_cli/config.rs

//! CLI configuration and environment handling
//!
//! This module provides configuration loading from environment variables,
//! tracing initialization, and embedding runtime management for semantic search.

6use std::env;
7use std::path::PathBuf;
8use std::str::FromStr;
9use std::sync::atomic::{AtomicUsize, Ordering};
10
11use anyhow::{anyhow, Result};
12use ed25519_dalek::VerifyingKey;
13
/// Dashboard API base URL used when neither `MEMVID_API_URL` nor the config file sets one.
const DEFAULT_API_URL: &str = "https://memvid.com";
/// Cache location used when `MEMVID_CACHE_DIR` is unset; the `~` is expanded by `expand_path`.
const DEFAULT_CACHE_DIR: &str = "~/.cache/memvid";
16
/// Supported embedding models for semantic search
///
/// Parsed from user input via `FromStr` (short aliases and canonical IDs) and
/// rendered via `Display` using the short `name()` form.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum EmbeddingModelChoice {
    /// BGE-small-en-v1.5: Fast, 384-dim, ~78% accuracy (default)
    #[default]
    BgeSmall,
    /// BGE-base-en-v1.5: Balanced, 768-dim, ~85% accuracy
    BgeBase,
    /// Nomic-embed-text-v1.5: High accuracy, 768-dim, ~86% accuracy
    Nomic,
    /// GTE-large-en-v1.5: Best semantic depth, 1024-dim
    GteLarge,
    /// OpenAI text-embedding-3-large: Highest quality, 3072-dim (requires OPENAI_API_KEY)
    OpenAILarge,
    /// OpenAI text-embedding-3-small: Good quality, 1536-dim (requires OPENAI_API_KEY)
    OpenAISmall,
    /// OpenAI text-embedding-ada-002: Legacy model, 1536-dim (requires OPENAI_API_KEY)
    OpenAIAda,
    /// NVIDIA nv-embed-v1: High quality, remote embeddings (requires NVIDIA_API_KEY)
    Nvidia,
    /// Gemini text-embedding-004: Google AI embeddings, 768-dim (requires GOOGLE_API_KEY or GEMINI_API_KEY)
    Gemini,
    /// Mistral mistral-embed: Mistral AI embeddings, 1024-dim (requires MISTRAL_API_KEY)
    Mistral,
}
42
impl EmbeddingModelChoice {
    /// Check if this is an OpenAI model (requires OPENAI_API_KEY)
    pub fn is_openai(&self) -> bool {
        matches!(
            self,
            EmbeddingModelChoice::OpenAILarge
                | EmbeddingModelChoice::OpenAISmall
                | EmbeddingModelChoice::OpenAIAda
        )
    }

    /// Check if this is a remote/cloud model (not local fastembed)
    pub fn is_remote(&self) -> bool {
        matches!(
            self,
            EmbeddingModelChoice::OpenAILarge
                | EmbeddingModelChoice::OpenAISmall
                | EmbeddingModelChoice::OpenAIAda
                | EmbeddingModelChoice::Nvidia
                | EmbeddingModelChoice::Gemini
                | EmbeddingModelChoice::Mistral
        )
    }

    /// Get the fastembed EmbeddingModel enum value (only for local models)
    ///
    /// # Panics
    /// Panics if called on any remote model (OpenAI, NVIDIA, Gemini, Mistral).
    /// Use `is_remote()` to check first.
    pub fn to_fastembed_model(&self) -> fastembed::EmbeddingModel {
        match self {
            EmbeddingModelChoice::BgeSmall => fastembed::EmbeddingModel::BGESmallENV15,
            EmbeddingModelChoice::BgeBase => fastembed::EmbeddingModel::BGEBaseENV15,
            EmbeddingModelChoice::Nomic => fastembed::EmbeddingModel::NomicEmbedTextV15,
            EmbeddingModelChoice::GteLarge => fastembed::EmbeddingModel::GTELargeENV15,
            EmbeddingModelChoice::OpenAILarge
            | EmbeddingModelChoice::OpenAISmall
            | EmbeddingModelChoice::OpenAIAda => {
                panic!("OpenAI models don't use fastembed. Check is_remote() first.")
            }
            EmbeddingModelChoice::Nvidia => {
                panic!("NVIDIA embeddings don't use fastembed. Check is_remote() first.")
            }
            EmbeddingModelChoice::Gemini => {
                panic!("Gemini embeddings don't use fastembed. Check is_remote() first.")
            }
            EmbeddingModelChoice::Mistral => {
                panic!("Mistral embeddings don't use fastembed. Check is_remote() first.")
            }
        }
    }

    /// Get human-readable model name
    pub fn name(&self) -> &'static str {
        match self {
            EmbeddingModelChoice::BgeSmall => "bge-small",
            EmbeddingModelChoice::BgeBase => "bge-base",
            EmbeddingModelChoice::Nomic => "nomic",
            EmbeddingModelChoice::GteLarge => "gte-large",
            EmbeddingModelChoice::OpenAILarge => "openai-large",
            EmbeddingModelChoice::OpenAISmall => "openai-small",
            EmbeddingModelChoice::OpenAIAda => "openai-ada",
            EmbeddingModelChoice::Nvidia => "nvidia",
            EmbeddingModelChoice::Gemini => "gemini",
            EmbeddingModelChoice::Mistral => "mistral",
        }
    }

    /// Get the canonical provider model identifier used for persisted metadata.
    ///
    /// This is intended to match upstream provider IDs (OpenAI) and HuggingFace-style IDs
    /// (fastembed/ONNX) so that memories can record an embedding "identity" that other
    /// runtimes can select deterministically.
    pub fn canonical_model_id(&self) -> &'static str {
        match self {
            EmbeddingModelChoice::BgeSmall => "BAAI/bge-small-en-v1.5",
            EmbeddingModelChoice::BgeBase => "BAAI/bge-base-en-v1.5",
            EmbeddingModelChoice::Nomic => "nomic-embed-text-v1.5",
            // NOTE(review): the enum doc and fastembed mapping reference
            // GTE-large-en-v1.5 (GTELargeENV15), but this canonical ID points
            // at the v1 repo — confirm which identity should be persisted.
            EmbeddingModelChoice::GteLarge => "thenlper/gte-large",
            EmbeddingModelChoice::OpenAILarge => "text-embedding-3-large",
            EmbeddingModelChoice::OpenAISmall => "text-embedding-3-small",
            EmbeddingModelChoice::OpenAIAda => "text-embedding-ada-002",
            EmbeddingModelChoice::Nvidia => "nvidia/nv-embed-v1",
            EmbeddingModelChoice::Gemini => "text-embedding-004",
            EmbeddingModelChoice::Mistral => "mistral-embed",
        }
    }

    /// Get embedding dimensions
    ///
    /// Returns 0 for NVIDIA, whose dimension is only known after the first
    /// embedding response (see `EmbeddingRuntime::note_dimension`).
    pub fn dimensions(&self) -> usize {
        match self {
            EmbeddingModelChoice::BgeSmall => 384,
            EmbeddingModelChoice::BgeBase => 768,
            EmbeddingModelChoice::Nomic => 768,
            EmbeddingModelChoice::GteLarge => 1024,
            EmbeddingModelChoice::OpenAILarge => 3072,
            EmbeddingModelChoice::OpenAISmall => 1536,
            EmbeddingModelChoice::OpenAIAda => 1536,
            // Remote model; infer from the first embedding response.
            EmbeddingModelChoice::Nvidia => 0,
            EmbeddingModelChoice::Gemini => 768,
            EmbeddingModelChoice::Mistral => 1024,
        }
    }
}
147
148impl FromStr for EmbeddingModelChoice {
149    type Err = anyhow::Error;
150
151    fn from_str(s: &str) -> Result<Self> {
152        let lowered = s.trim().to_ascii_lowercase();
153        match lowered.as_str() {
154            "bge-small" | "bge_small" | "bgesmall" | "small" => Ok(EmbeddingModelChoice::BgeSmall),
155            "baai/bge-small-en-v1.5" => Ok(EmbeddingModelChoice::BgeSmall),
156            "bge-base" | "bge_base" | "bgebase" | "base" => Ok(EmbeddingModelChoice::BgeBase),
157            "baai/bge-base-en-v1.5" => Ok(EmbeddingModelChoice::BgeBase),
158            "nomic" | "nomic-embed" | "nomic_embed" => Ok(EmbeddingModelChoice::Nomic),
159            "nomic-embed-text-v1.5" => Ok(EmbeddingModelChoice::Nomic),
160            "gte-large" | "gte_large" | "gtelarge" | "gte" => Ok(EmbeddingModelChoice::GteLarge),
161            "thenlper/gte-large" => Ok(EmbeddingModelChoice::GteLarge),
162            // OpenAI models - default "openai" maps to "openai-large" for highest quality
163            "openai" | "openai-large" | "openai_large" | "text-embedding-3-large" => {
164                Ok(EmbeddingModelChoice::OpenAILarge)
165            }
166            "openai-small" | "openai_small" | "text-embedding-3-small" => {
167                Ok(EmbeddingModelChoice::OpenAISmall)
168            }
169            "openai-ada" | "openai_ada" | "text-embedding-ada-002" | "ada" => {
170                Ok(EmbeddingModelChoice::OpenAIAda)
171            }
172            "nvidia" | "nv" | "nv-embed-v1" | "nvidia/nv-embed-v1" => Ok(EmbeddingModelChoice::Nvidia),
173            _ if lowered.starts_with("nvidia/") || lowered.starts_with("nvidia:") || lowered.starts_with("nv:") => {
174                Ok(EmbeddingModelChoice::Nvidia)
175            }
176            // Gemini embeddings
177            "gemini" | "gemini-embed" | "text-embedding-004" | "gemini-embedding-001" => {
178                Ok(EmbeddingModelChoice::Gemini)
179            }
180            _ if lowered.starts_with("gemini/") || lowered.starts_with("gemini:") || lowered.starts_with("google:") => {
181                Ok(EmbeddingModelChoice::Gemini)
182            }
183            // Mistral embeddings
184            "mistral" | "mistral-embed" => Ok(EmbeddingModelChoice::Mistral),
185            _ if lowered.starts_with("mistral/") || lowered.starts_with("mistral:") => {
186                Ok(EmbeddingModelChoice::Mistral)
187            }
188            _ => Err(anyhow!(
189                "unknown embedding model '{}'. Valid options: bge-small, bge-base, nomic, gte-large, openai, openai-small, openai-ada, nvidia, gemini, mistral",
190                s
191            )),
192        }
193    }
194}
195
196impl std::fmt::Display for EmbeddingModelChoice {
197    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
198        write!(f, "{}", self.name())
199    }
200}
201
202impl EmbeddingModelChoice {
203    /// Infer the best embedding model from vector dimension stored in MV2 file.
204    ///
205    /// This enables auto-detection: users don't need to specify --query-embedding-model
206    /// if the MV2 file has vectors. The dimension uniquely identifies the model family.
207    ///
208    /// # Dimension Mapping
209    /// - 384  → BGE-small (default local model)
210    /// - 768  → BGE-base (could also be Nomic, but same dimension works)
211    /// - 1024 → GTE-large
212    /// - 1536 → OpenAI small/ada
213    /// - 3072 → OpenAI large
214    pub fn from_dimension(dim: u32) -> Option<Self> {
215        match dim {
216            384 => Some(EmbeddingModelChoice::BgeSmall),
217            768 => Some(EmbeddingModelChoice::BgeBase), // Could be Nomic, but same dim
218            1024 => Some(EmbeddingModelChoice::GteLarge),
219            1536 => Some(EmbeddingModelChoice::OpenAISmall), // Could be Ada, same dim
220            3072 => Some(EmbeddingModelChoice::OpenAILarge),
221            0 => None, // No vectors in file
222            _ => {
223                tracing::warn!("Unknown embedding dimension {}, using default model", dim);
224                None
225            }
226        }
227    }
228}
229
/// CLI configuration loaded from environment variables and config file
#[derive(Debug, Clone)]
pub struct CliConfig {
    /// API key from `MEMVID_API_KEY`, falling back to the persistent config file.
    pub api_key: Option<String>,
    /// Dashboard API base URL (`MEMVID_API_URL`, config file, or the built-in default).
    pub api_url: String,
    /// Default memory ID for dashboard sync (from config file)
    pub memory_id: Option<String>,
    /// Cache directory (`MEMVID_CACHE_DIR`, default `~/.cache/memvid`), tilde-expanded.
    pub cache_dir: PathBuf,
    /// Ed25519 key for dashboard ticket verification. Always `Some` after
    /// `load()` — a built-in default key is used when `MEMVID_TICKET_PUBKEY` is unset.
    pub ticket_pubkey: Option<VerifyingKey>,
    /// Model storage directory (`MEMVID_MODELS_DIR`, default `~/.memvid/models`), tilde-expanded.
    pub models_dir: PathBuf,
    /// Offline mode (`MEMVID_OFFLINE` set to 1/true/yes).
    pub offline: bool,
    /// Embedding model for semantic search (can be overridden by CLI flag)
    pub embedding_model: EmbeddingModelChoice,
}
244
// Manual equality: compares every field except `ticket_pubkey`.
// NOTE(review): the verification key is presumably excluded deliberately
// (configs differing only in key treated as equivalent) — confirm intent
// before relying on `==` where the key matters.
impl PartialEq for CliConfig {
    fn eq(&self, other: &Self) -> bool {
        self.api_key == other.api_key
            && self.api_url == other.api_url
            && self.memory_id == other.memory_id
            && self.cache_dir == other.cache_dir
            && self.models_dir == other.models_dir
            && self.offline == other.offline
            && self.embedding_model == other.embedding_model
    }
}

impl Eq for CliConfig {}
258
259impl CliConfig {
260    pub fn load() -> Result<Self> {
261        // Load persistent config file (if exists) for fallback values
262        let persistent_config = crate::commands::config::PersistentConfig::load().ok();
263
264        // API Key: env var takes precedence, then config file
265        let api_key = env::var("MEMVID_API_KEY")
266            .ok()
267            .and_then(|value| {
268                let trimmed = value.trim().to_string();
269                (!trimmed.is_empty()).then_some(trimmed)
270            })
271            .or_else(|| persistent_config.as_ref().and_then(|c| c.api_key.clone()));
272
273        // API URL: env var takes precedence, then config file, then default
274        let api_url = env::var("MEMVID_API_URL")
275            .ok()
276            .or_else(|| persistent_config.as_ref().and_then(|c| c.api_url.clone()))
277            .unwrap_or_else(|| DEFAULT_API_URL.to_string());
278
279        // Memory ID: env var takes precedence, then config file (memory.default or legacy memory_id)
280        let memory_id = env::var("MEMVID_MEMORY_ID")
281            .ok()
282            .and_then(|value| {
283                let trimmed = value.trim().to_string();
284                (!trimmed.is_empty()).then_some(trimmed)
285            })
286            .or_else(|| persistent_config.as_ref().and_then(|c| c.default_memory_id()));
287
288        let cache_dir_raw =
289            env::var("MEMVID_CACHE_DIR").unwrap_or_else(|_| DEFAULT_CACHE_DIR.to_string());
290        let cache_dir = expand_path(&cache_dir_raw)?;
291
292        let models_dir_raw =
293            env::var("MEMVID_MODELS_DIR").unwrap_or_else(|_| "~/.memvid/models".to_string());
294        let models_dir = expand_path(&models_dir_raw)?;
295
296        // Default public key for memvid.com dashboard ticket verification
297        // This allows users to use --memory-id without setting MEMVID_TICKET_PUBKEY
298        const DEFAULT_TICKET_PUBKEY: &str = "8wP1J2H+Tlx3PM3eT0lN2wDvoYrvl1DREKGKVb/V2cw=";
299
300        let ticket_pubkey_str = env::var("MEMVID_TICKET_PUBKEY")
301            .ok()
302            .and_then(|value| {
303                let trimmed = value.trim();
304                if trimmed.is_empty() {
305                    None
306                } else {
307                    Some(trimmed.to_string())
308                }
309            })
310            .unwrap_or_else(|| DEFAULT_TICKET_PUBKEY.to_string());
311
312        let ticket_pubkey = Some(memvid_core::parse_ed25519_public_key_base64(&ticket_pubkey_str)?);
313
314        let offline = env::var("MEMVID_OFFLINE")
315            .ok()
316            .map(|value| match value.trim().to_ascii_lowercase().as_str() {
317                "1" | "true" | "yes" => true,
318                _ => false,
319            })
320            .unwrap_or(false);
321
322        // Load embedding model from env var, default to BGE-small
323        let embedding_model = env::var("MEMVID_EMBEDDING_MODEL")
324            .ok()
325            .and_then(|value| {
326                let trimmed = value.trim();
327                if trimmed.is_empty() {
328                    None
329                } else {
330                    EmbeddingModelChoice::from_str(trimmed).ok()
331                }
332            })
333            .unwrap_or_default();
334
335        Ok(Self {
336            api_key,
337            api_url,
338            memory_id,
339            cache_dir,
340            ticket_pubkey,
341            models_dir,
342            offline,
343            embedding_model,
344        })
345    }
346
347    /// Create a new config with a different embedding model
348    pub fn with_embedding_model(&self, model: EmbeddingModelChoice) -> Self {
349        Self {
350            embedding_model: model,
351            ..self.clone()
352        }
353    }
354}
355
356fn expand_path(value: &str) -> Result<PathBuf> {
357    if value.trim().is_empty() {
358        return Err(anyhow!("cache directory cannot be empty"));
359    }
360
361    let expanded = if let Some(stripped) = value.strip_prefix("~/") {
362        home_dir()?.join(stripped)
363    } else if let Some(stripped) = value.strip_prefix("~\\") {
364        // Support Windows-style "~\" prefix.
365        home_dir()?.join(stripped)
366    } else if value == "~" {
367        home_dir()?
368    } else {
369        PathBuf::from(value)
370    };
371
372    if expanded.is_absolute() {
373        Ok(expanded)
374    } else {
375        Ok(env::current_dir()?.join(expanded))
376    }
377}
378
379fn home_dir() -> Result<PathBuf> {
380    if let Some(path) = env::var_os("HOME") {
381        if !path.is_empty() {
382            return Ok(PathBuf::from(path));
383        }
384    }
385
386    #[cfg(windows)]
387    {
388        if let Some(path) = env::var_os("USERPROFILE") {
389            if !path.is_empty() {
390                return Ok(PathBuf::from(path));
391            }
392        }
393        if let (Some(drive), Some(path)) = (env::var_os("HOMEDRIVE"), env::var_os("HOMEPATH")) {
394            if !drive.is_empty() && !path.is_empty() {
395                return Ok(PathBuf::from(format!(
396                    "{}{}",
397                    drive.to_string_lossy(),
398                    path.to_string_lossy()
399                )));
400            }
401        }
402    }
403
404    Err(anyhow!("unable to resolve home directory"))
405}
406
#[cfg(test)]
mod tests {
    use super::*;
    use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
    use base64::Engine;
    use ed25519_dalek::SigningKey;
    use std::sync::{Mutex, OnceLock};

    /// Serialize env-mutating tests: the process environment is global state.
    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        // A panicking test poisons the mutex; the guarded data is (), so
        // recovering the guard is always safe and keeps later tests running.
        LOCK.get_or_init(|| Mutex::new(()))
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Restore a variable to its pre-test state (set or absent).
    fn set_or_unset(var: &str, value: Option<String>) {
        match value {
            Some(v) => unsafe { env::set_var(var, v) },
            None => unsafe { env::remove_var(var) },
        }
    }

    #[test]
    fn defaults_expand_using_home_directory() {
        let _guard = env_lock();

        let previous_home = env::var("HOME").ok();
        #[cfg(windows)]
        let previous_userprofile = env::var("USERPROFILE").ok();

        // Clear every env input load() reads so ambient values can't leak in.
        for var in [
            "MEMVID_API_KEY",
            "MEMVID_API_URL",
            "MEMVID_CACHE_DIR",
            "MEMVID_TICKET_PUBKEY",
            "MEMVID_MODELS_DIR",
            "MEMVID_OFFLINE",
            "MEMVID_MEMORY_ID",
            "MEMVID_EMBEDDING_MODEL",
        ] {
            unsafe { env::remove_var(var) };
        }

        let tmp = tempfile::tempdir().expect("tmpdir");
        let tmp_path = tmp.path().to_path_buf();
        unsafe { env::set_var("HOME", &tmp_path) };
        #[cfg(windows)]
        unsafe {
            env::set_var("USERPROFILE", &tmp_path)
        };

        let config = CliConfig::load().expect("load");
        assert_eq!(config.api_key, None);
        assert_eq!(config.api_url, "https://memvid.com");
        assert_eq!(config.cache_dir, tmp_path.join(".cache/memvid"));
        // ticket_pubkey has a default value now, so it should be Some
        assert!(config.ticket_pubkey.is_some());
        assert_eq!(config.models_dir, tmp_path.join(".memvid/models"));
        assert!(!config.offline);

        set_or_unset("HOME", previous_home);
        #[cfg(windows)]
        {
            set_or_unset("USERPROFILE", previous_userprofile);
        }
    }

    #[test]
    fn env_overrides_are_respected() {
        let _guard = env_lock();

        // Save HOME/USERPROFILE too: this test redirects them to a tempdir and
        // previously leaked that (soon-deleted) path to later tests.
        let previous_home = env::var("HOME").ok();
        #[cfg(windows)]
        let previous_userprofile = env::var("USERPROFILE").ok();

        let previous_env: Vec<(&'static str, Option<String>)> = [
            "MEMVID_API_KEY",
            "MEMVID_API_URL",
            "MEMVID_CACHE_DIR",
            "MEMVID_TICKET_PUBKEY",
            "MEMVID_MODELS_DIR",
            "MEMVID_OFFLINE",
        ]
        .into_iter()
        .map(|var| (var, env::var(var).ok()))
        .collect();

        unsafe { env::set_var("MEMVID_API_KEY", "abc123") };
        unsafe { env::set_var("MEMVID_API_URL", "https://staging.memvid.app") };
        unsafe { env::set_var("MEMVID_CACHE_DIR", "~/memvid-cache") };
        unsafe { env::set_var("MEMVID_MODELS_DIR", "~/models") };
        unsafe { env::set_var("MEMVID_OFFLINE", "true") };
        let signing = SigningKey::from_bytes(&[9u8; 32]);
        let encoded = BASE64_STANDARD.encode(signing.verifying_key().as_bytes());
        unsafe { env::set_var("MEMVID_TICKET_PUBKEY", encoded) };

        let tmp = tempfile::tempdir().expect("tmpdir");
        let tmp_path = tmp.path().to_path_buf();
        unsafe { env::set_var("HOME", &tmp_path) };
        #[cfg(windows)]
        unsafe {
            env::set_var("USERPROFILE", &tmp_path)
        };

        let config = CliConfig::load().expect("load");
        assert_eq!(config.api_key.as_deref(), Some("abc123"));
        assert_eq!(config.api_url, "https://staging.memvid.app");
        assert_eq!(config.cache_dir, tmp_path.join("memvid-cache"));
        assert_eq!(
            config.ticket_pubkey.expect("pubkey").as_bytes(),
            signing.verifying_key().as_bytes()
        );
        assert_eq!(config.models_dir, tmp_path.join("models"));
        assert!(config.offline);

        for (var, value) in previous_env {
            set_or_unset(var, value);
        }
        set_or_unset("HOME", previous_home);
        #[cfg(windows)]
        {
            set_or_unset("USERPROFILE", previous_userprofile);
        }
    }

    #[test]
    fn rejects_empty_cache_dir() {
        let _guard = env_lock();

        let previous = env::var("MEMVID_CACHE_DIR").ok();
        unsafe { env::set_var("MEMVID_CACHE_DIR", " ") };
        let err = CliConfig::load().expect_err("should fail");
        assert!(err.to_string().contains("cache directory"));
        set_or_unset("MEMVID_CACHE_DIR", previous);
    }
}
530
531/// Initialize tracing/logging based on verbosity level
532pub fn init_tracing(verbosity: u8) -> Result<()> {
533    use std::io::IsTerminal;
534    use tracing_subscriber::{filter::Directive, fmt, EnvFilter};
535
536    let level = match verbosity {
537        0 => "warn",
538        1 => "info",
539        2 => "debug",
540        _ => "trace",
541    };
542
543    let mut env_filter =
544        EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(level));
545    for directive_str in ["llama_cpp=error", "llama_cpp_sys=error", "ggml=error"] {
546        if let Ok(directive) = directive_str.parse::<Directive>() {
547            env_filter = env_filter.add_directive(directive);
548        }
549    }
550
551    // Disable ANSI color codes when stderr is not a terminal (e.g., piped or
552    // combined with `2>&1`). This prevents control characters from polluting
553    // JSON output when combined with stdout.
554    let use_ansi = std::io::stderr().is_terminal();
555
556    fmt()
557        .with_env_filter(env_filter)
558        .with_writer(std::io::stderr)
559        .with_target(false)
560        .without_time()
561        .with_ansi(use_ansi)
562        .try_init()
563        .map_err(|err| anyhow!(err))?;
564    Ok(())
565}
566
567/// Resolve LLM context budget override from CLI or environment
568pub fn resolve_llm_context_budget_override(cli_value: Option<usize>) -> Result<Option<usize>> {
569    use anyhow::bail;
570
571    if let Some(value) = cli_value {
572        if value == 0 {
573            bail!("--llm-context-depth must be a positive integer");
574        }
575        return Ok(Some(value));
576    }
577
578    let raw_env = match env::var("MEMVID_LLM_CONTEXT_BUDGET") {
579        Ok(value) => value,
580        Err(_) => return Ok(None),
581    };
582
583    let trimmed = raw_env.trim();
584    if trimmed.is_empty() {
585        return Ok(None);
586    }
587
588    let digits: String = trimmed
589        .chars()
590        .filter(|ch| !ch.is_ascii_whitespace() && *ch != '_')
591        .collect();
592
593    if digits.is_empty() {
594        bail!("MEMVID_LLM_CONTEXT_BUDGET must be a positive integer value");
595    }
596
597    let value: usize = digits.parse().map_err(|err| {
598        anyhow!(
599            "MEMVID_LLM_CONTEXT_BUDGET value '{}' is not a valid number: {}",
600            trimmed,
601            err
602        )
603    })?;
604
605    if value == 0 {
606        bail!("MEMVID_LLM_CONTEXT_BUDGET must be a positive integer");
607    }
608
609    Ok(Some(value))
610}
611
612use crate::gemini_embeddings::GeminiEmbeddingProvider;
613use crate::mistral_embeddings::MistralEmbeddingProvider;
614use crate::nvidia_embeddings::NvidiaEmbeddingProvider;
615use crate::openai_embeddings::OpenAIEmbeddingProvider;
616
/// Internal embedding backend - local fastembed or remote providers.
#[derive(Clone)]
enum EmbeddingBackend {
    /// Local ONNX model. `Mutex` because embedding requires mutable access to
    /// the engine; `Arc` lets cloned runtimes share one loaded model.
    FastEmbed(std::sync::Arc<std::sync::Mutex<fastembed::TextEmbedding>>),
    /// Remote OpenAI embeddings API.
    OpenAI(std::sync::Arc<OpenAIEmbeddingProvider>),
    /// Remote NVIDIA embeddings (vector dimension learned from first response).
    Nvidia(std::sync::Arc<NvidiaEmbeddingProvider>),
    /// Remote Google Gemini embeddings.
    Gemini(std::sync::Arc<GeminiEmbeddingProvider>),
    /// Remote Mistral embeddings.
    Mistral(std::sync::Arc<MistralEmbeddingProvider>),
}
626
/// Embedding runtime wrapper supporting local and remote embeddings
#[derive(Clone)]
pub struct EmbeddingRuntime {
    // Backend that all embed calls dispatch to.
    backend: EmbeddingBackend,
    // Model this runtime was constructed for.
    // NOTE(review): not read within this part of the file — presumably
    // exposed via an accessor elsewhere; confirm before removing.
    model: EmbeddingModelChoice,
    // Observed vector dimension, shared across clones. 0 means "unknown until
    // the first response" (NVIDIA starts at 0; see note_dimension).
    dimension: std::sync::Arc<AtomicUsize>,
}
634
635impl EmbeddingRuntime {
636    fn new_fastembed(
637        backend: fastembed::TextEmbedding,
638        model: EmbeddingModelChoice,
639        dimension: usize,
640    ) -> Self {
641        Self {
642            backend: EmbeddingBackend::FastEmbed(std::sync::Arc::new(std::sync::Mutex::new(
643                backend,
644            ))),
645            model,
646            dimension: std::sync::Arc::new(AtomicUsize::new(dimension)),
647        }
648    }
649
650    fn new_openai(
651        provider: OpenAIEmbeddingProvider,
652        model: EmbeddingModelChoice,
653        dimension: usize,
654    ) -> Self {
655        Self {
656            backend: EmbeddingBackend::OpenAI(std::sync::Arc::new(provider)),
657            model,
658            dimension: std::sync::Arc::new(AtomicUsize::new(dimension)),
659        }
660    }
661
662    fn new_nvidia(provider: NvidiaEmbeddingProvider, model: EmbeddingModelChoice) -> Self {
663        Self {
664            backend: EmbeddingBackend::Nvidia(std::sync::Arc::new(provider)),
665            model,
666            dimension: std::sync::Arc::new(AtomicUsize::new(0)),
667        }
668    }
669
670    fn new_gemini(
671        provider: GeminiEmbeddingProvider,
672        model: EmbeddingModelChoice,
673        dimension: usize,
674    ) -> Self {
675        Self {
676            backend: EmbeddingBackend::Gemini(std::sync::Arc::new(provider)),
677            model,
678            dimension: std::sync::Arc::new(AtomicUsize::new(dimension)),
679        }
680    }
681
682    fn new_mistral(
683        provider: MistralEmbeddingProvider,
684        model: EmbeddingModelChoice,
685        dimension: usize,
686    ) -> Self {
687        Self {
688            backend: EmbeddingBackend::Mistral(std::sync::Arc::new(provider)),
689            model,
690            dimension: std::sync::Arc::new(AtomicUsize::new(dimension)),
691        }
692    }
693
    // Character caps applied to text sent to remote providers. Characters are
    // a coarse guardrail against provider token limits, not an exact count.
    const MAX_OPENAI_EMBEDDING_TEXT_LEN: usize = 20_000;
    // NVIDIA Integrate embeddings enforce a 4096 token limit; use a tighter char cap as a guardrail.
    const MAX_NVIDIA_EMBEDDING_TEXT_LEN: usize = 12_000;

    // Gemini has an 8192 token limit, using conservative estimate
    const MAX_GEMINI_EMBEDDING_TEXT_LEN: usize = 20_000;
    // Mistral has an 8192 token limit, using conservative estimate
    const MAX_MISTRAL_EMBEDDING_TEXT_LEN: usize = 20_000;
702
703    fn max_remote_embedding_chars(&self) -> usize {
704        match &self.backend {
705            EmbeddingBackend::OpenAI(_) => Self::MAX_OPENAI_EMBEDDING_TEXT_LEN,
706            EmbeddingBackend::Nvidia(_) => Self::MAX_NVIDIA_EMBEDDING_TEXT_LEN,
707            EmbeddingBackend::Gemini(_) => Self::MAX_GEMINI_EMBEDDING_TEXT_LEN,
708            EmbeddingBackend::Mistral(_) => Self::MAX_MISTRAL_EMBEDDING_TEXT_LEN,
709            EmbeddingBackend::FastEmbed(_) => usize::MAX,
710        }
711    }
712
713    /// Truncate text for embedding to reduce the risk of provider token-limit errors.
714    fn truncate_for_embedding<'a>(
715        text: &'a str,
716        max_chars: usize,
717    ) -> std::borrow::Cow<'a, str> {
718        if text.len() <= max_chars {
719            std::borrow::Cow::Borrowed(text)
720        } else {
721            // Find the last valid UTF-8 char boundary within the limit
722            let truncated = &text[..max_chars];
723            let end = truncated
724                .char_indices()
725                .rev()
726                .next()
727                .map(|(i, c)| i + c.len_utf8())
728                .unwrap_or(max_chars);
729            tracing::info!("Truncated embedding text from {} to {} chars", text.len(), end);
730            std::borrow::Cow::Owned(text[..end].to_string())
731        }
732    }
733
734    fn note_dimension(&self, observed: usize) -> Result<()> {
735        if observed == 0 {
736            return Err(anyhow!("embedding provider returned zero-length embedding"));
737        }
738
739        let current = self.dimension.load(Ordering::Relaxed);
740        if current == 0 {
741            self.dimension.store(observed, Ordering::Relaxed);
742            return Ok(());
743        }
744
745        if current != observed {
746            return Err(anyhow!(
747                "embedding provider returned {observed}D vectors but runtime expects {current}D"
748            ));
749        }
750
751        Ok(())
752    }
753
754    fn truncate_if_remote<'a>(&self, text: &'a str) -> std::borrow::Cow<'a, str> {
755        match &self.backend {
756            EmbeddingBackend::OpenAI(_)
757            | EmbeddingBackend::Nvidia(_)
758            | EmbeddingBackend::Gemini(_)
759            | EmbeddingBackend::Mistral(_) => {
760                Self::truncate_for_embedding(text, self.max_remote_embedding_chars())
761            }
762            EmbeddingBackend::FastEmbed(_) => std::borrow::Cow::Borrowed(text),
763        }
764    }
765
766    pub fn embed_passage(&self, text: &str) -> Result<Vec<f32>> {
767        let text = self.truncate_if_remote(text);
768        let embedding = match &self.backend {
769            EmbeddingBackend::FastEmbed(model) => {
770                let mut guard = model
771                    .lock()
772                    .map_err(|_| anyhow!("fastembed runtime poisoned"))?;
773                let outputs = guard
774                    .embed(vec![text.into_owned()], None)
775                    .map_err(|err| anyhow!("failed to compute embedding with fastembed: {err}"))?;
776                outputs
777                    .into_iter()
778                    .next()
779                    .ok_or_else(|| anyhow!("fastembed returned no embedding output"))?
780            }
781            EmbeddingBackend::OpenAI(provider) => {
782                use memvid_core::EmbeddingProvider;
783                provider
784                    .embed_text(&text)
785                    .map_err(|err| anyhow!("failed to compute embedding with OpenAI: {err}"))?
786            }
787            EmbeddingBackend::Nvidia(provider) => provider
788                .embed_passage(&text)
789                .map_err(|err| anyhow!("failed to compute embedding with NVIDIA: {err}"))?,
790            EmbeddingBackend::Gemini(provider) => provider
791                .embed_text(&text)
792                .map_err(|err| anyhow!("failed to compute embedding with Gemini: {err}"))?,
793            EmbeddingBackend::Mistral(provider) => provider
794                .embed_text(&text)
795                .map_err(|err| anyhow!("failed to compute embedding with Mistral: {err}"))?,
796        };
797
798        self.note_dimension(embedding.len())?;
799        Ok(embedding)
800    }
801
802    pub fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
803        let text = self.truncate_if_remote(text);
804        match &self.backend {
805            EmbeddingBackend::Nvidia(provider) => {
806                let embedding = provider
807                    .embed_query(&text)
808                    .map_err(|err| anyhow!("failed to compute embedding with NVIDIA: {err}"))?;
809                self.note_dimension(embedding.len())?;
810                Ok(embedding)
811            }
812            _ => self.embed_passage(&text),
813        }
814    }
815
816    pub fn embed_batch_passages(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
817        if texts.is_empty() {
818            return Ok(Vec::new());
819        }
820
821        let truncated: Vec<std::borrow::Cow<'_, str>> =
822            texts.iter().map(|t| self.truncate_if_remote(t)).collect();
823        let truncated_refs: Vec<&str> = truncated.iter().map(|c| c.as_ref()).collect();
824
825        let embeddings = match &self.backend {
826            EmbeddingBackend::FastEmbed(model) => {
827                let mut guard = model
828                    .lock()
829                    .map_err(|_| anyhow!("fastembed runtime poisoned"))?;
830                guard
831                    .embed(
832                        truncated_refs
833                            .iter()
834                            .map(|s| (*s).to_string())
835                            .collect::<Vec<String>>(),
836                        None,
837                    )
838                    .map_err(|err| anyhow!("failed to compute embeddings with fastembed: {err}"))?
839            }
840            EmbeddingBackend::OpenAI(provider) => {
841                use memvid_core::EmbeddingProvider;
842                provider
843                    .embed_batch(&truncated_refs)
844                    .map_err(|err| anyhow!("failed to compute embeddings with OpenAI: {err}"))?
845            }
846            EmbeddingBackend::Nvidia(provider) => provider
847                .embed_passages(&truncated_refs)
848                .map_err(|err| anyhow!("failed to compute embeddings with NVIDIA: {err}"))?,
849            EmbeddingBackend::Gemini(provider) => provider
850                .embed_batch(&truncated_refs)
851                .map_err(|err| anyhow!("failed to compute embeddings with Gemini: {err}"))?,
852            EmbeddingBackend::Mistral(provider) => provider
853                .embed_batch(&truncated_refs)
854                .map_err(|err| anyhow!("failed to compute embeddings with Mistral: {err}"))?,
855        };
856
857        if let Some(first) = embeddings.first() {
858            self.note_dimension(first.len())?;
859        }
860        if let Some(expected) = embeddings.first().map(|e| e.len()) {
861            if embeddings.iter().any(|e| e.len() != expected) {
862                return Err(anyhow!("embedding provider returned mixed vector dimensions"));
863            }
864        }
865
866        Ok(embeddings)
867    }
868
869    pub fn embed_batch_queries(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
870        if texts.is_empty() {
871            return Ok(Vec::new());
872        }
873
874        let truncated: Vec<std::borrow::Cow<'_, str>> =
875            texts.iter().map(|t| self.truncate_if_remote(t)).collect();
876        let truncated_refs: Vec<&str> = truncated.iter().map(|c| c.as_ref()).collect();
877
878        match &self.backend {
879            EmbeddingBackend::Nvidia(provider) => {
880                let embeddings = provider
881                    .embed_queries(&truncated_refs)
882                    .map_err(|err| anyhow!("failed to compute embeddings with NVIDIA: {err}"))?;
883
884                if let Some(first) = embeddings.first() {
885                    self.note_dimension(first.len())?;
886                }
887                if let Some(expected) = embeddings.first().map(|e| e.len()) {
888                    if embeddings.iter().any(|e| e.len() != expected) {
889                        return Err(anyhow!("embedding provider returned mixed vector dimensions"));
890                    }
891                }
892
893                Ok(embeddings)
894            }
895            _ => self.embed_batch_passages(&truncated_refs),
896        }
897    }
898
    /// Embedding dimension as recorded at construction or by the most recent
    /// embed call (via `note_dimension`).
    pub fn dimension(&self) -> usize {
        self.dimension.load(Ordering::Relaxed)
    }
902
    /// The embedding model choice this runtime was built for.
    pub fn model_choice(&self) -> EmbeddingModelChoice {
        self.model
    }
906
    /// Static identifier for the active backend family (e.g. "fastembed",
    /// "openai"), independent of the concrete model id.
    pub fn provider_kind(&self) -> &'static str {
        match &self.backend {
            EmbeddingBackend::FastEmbed(_) => "fastembed",
            EmbeddingBackend::OpenAI(_) => "openai",
            EmbeddingBackend::Nvidia(_) => "nvidia",
            EmbeddingBackend::Gemini(_) => "gemini",
            EmbeddingBackend::Mistral(_) => "mistral",
        }
    }
916
    /// Concrete model identifier as reported by the active backend; local
    /// fastembed models fall back to the canonical id of the configured
    /// model choice.
    pub fn provider_model_id(&self) -> String {
        match &self.backend {
            EmbeddingBackend::FastEmbed(_) => self.model.canonical_model_id().to_string(),
            EmbeddingBackend::OpenAI(provider) => {
                use memvid_core::EmbeddingProvider;
                provider.model().to_string()
            }
            EmbeddingBackend::Nvidia(provider) => provider.model().to_string(),
            EmbeddingBackend::Gemini(provider) => provider.model().to_string(),
            EmbeddingBackend::Mistral(provider) => provider.model().to_string(),
        }
    }
929}
930
931impl memvid_core::VecEmbedder for EmbeddingRuntime {
932    fn embed_query(&self, text: &str) -> memvid_core::Result<Vec<f32>> {
933        EmbeddingRuntime::embed_query(self, text).map_err(|err| {
934            memvid_core::MemvidError::EmbeddingFailed {
935                reason: err.to_string().into_boxed_str(),
936            }
937        })
938    }
939
940    fn embedding_dimension(&self) -> usize {
941        self.dimension()
942    }
943}
944
945/// Ensure fastembed cache directory exists
946fn ensure_fastembed_cache(config: &CliConfig) -> Result<PathBuf> {
947    use std::fs;
948
949    let cache_dir = config.models_dir.clone();
950    fs::create_dir_all(&cache_dir)?;
951    Ok(cache_dir)
952}
953
954/// Get approximate model size in MB for user-friendly error messages
955fn model_size_mb(model: EmbeddingModelChoice) -> usize {
956    match model {
957        EmbeddingModelChoice::BgeSmall => 33,
958        EmbeddingModelChoice::BgeBase => 110,
959        EmbeddingModelChoice::Nomic => 137,
960        EmbeddingModelChoice::GteLarge => 327,
961        // Remote/cloud models don't require local download
962        EmbeddingModelChoice::OpenAILarge
963        | EmbeddingModelChoice::OpenAISmall
964        | EmbeddingModelChoice::OpenAIAda
965        | EmbeddingModelChoice::Nvidia
966        | EmbeddingModelChoice::Gemini
967        | EmbeddingModelChoice::Mistral => 0,
968    }
969}
970
971/// Instantiate an embedding runtime with the configured model
972fn instantiate_embedding_runtime(config: &CliConfig) -> Result<EmbeddingRuntime> {
973    use tracing::info;
974
975    let embedding_model = config.embedding_model;
976
977    if embedding_model.dimensions() > 0 {
978        info!(
979            "Loading embedding model: {} ({}D)",
980            embedding_model.name(),
981            embedding_model.dimensions()
982        );
983    } else {
984        info!("Loading embedding model: {}", embedding_model.name());
985    }
986
987    if config.offline && embedding_model.is_remote() {
988        anyhow::bail!(
989            "remote embeddings are unavailable while offline; set MEMVID_OFFLINE=0 or use a local embedding model"
990        );
991    }
992
993    // Check if OpenAI model
994    if embedding_model.is_openai() {
995        return instantiate_openai_runtime(embedding_model);
996    }
997
998    if embedding_model == EmbeddingModelChoice::Nvidia {
999        return instantiate_nvidia_runtime(None);
1000    }
1001
1002    if embedding_model == EmbeddingModelChoice::Gemini {
1003        return instantiate_gemini_runtime();
1004    }
1005
1006    if embedding_model == EmbeddingModelChoice::Mistral {
1007        return instantiate_mistral_runtime();
1008    }
1009
1010    // Local fastembed model
1011    instantiate_fastembed_runtime(config, embedding_model)
1012}
1013
1014/// Instantiate OpenAI embedding runtime
1015fn instantiate_openai_runtime(embedding_model: EmbeddingModelChoice) -> Result<EmbeddingRuntime> {
1016    use anyhow::bail;
1017    use memvid_core::EmbeddingConfig;
1018    use tracing::info;
1019
1020    let api_key = std::env::var("OPENAI_API_KEY").map_err(|_| {
1021        anyhow!("OPENAI_API_KEY environment variable is required for OpenAI embeddings")
1022    })?;
1023
1024    if api_key.is_empty() {
1025        bail!("OPENAI_API_KEY cannot be empty");
1026    }
1027
1028    let config = match embedding_model {
1029        EmbeddingModelChoice::OpenAILarge => EmbeddingConfig::openai_large(),
1030        EmbeddingModelChoice::OpenAISmall => EmbeddingConfig::openai_small(),
1031        EmbeddingModelChoice::OpenAIAda => EmbeddingConfig::openai_ada(),
1032        _ => unreachable!("is_openai() should have been false"),
1033    };
1034
1035    let provider = OpenAIEmbeddingProvider::new(api_key, config.clone())
1036        .map_err(|err| anyhow!("failed to create OpenAI embedding provider: {err}"))?;
1037
1038    info!(
1039        "OpenAI embedding provider ready: model={}, dimension={}",
1040        config.model, config.dimension
1041    );
1042
1043    Ok(EmbeddingRuntime::new_openai(
1044        provider,
1045        embedding_model,
1046        config.dimension,
1047    ))
1048}
1049
/// Normalize a user-supplied NVIDIA embedding model override into a fully
/// qualified model id (e.g. `nvidia/nv-embed-v1`).
///
/// Returns `None` when the override is empty or merely selects the default
/// NVIDIA provider (`nvidia`/`nv` in any casing, optionally with an empty
/// suffix after the colon).
///
/// Accepted inputs (whitespace-trimmed):
/// - `nvidia:<model>` / `nv:<model>` — the prefix is stripped. The strip is
///   now case-insensitive to match the alias check below; previously only
///   lowercase prefixes were recognized, so `NVIDIA:foo` produced the bogus
///   id `nvidia/NVIDIA:foo`.
/// - `<org>/<model>` — passed through unchanged.
/// - a bare model name — qualified under the `nvidia/` org.
fn normalize_nvidia_embedding_model_override(raw: &str) -> Option<String> {
    let trimmed = raw.trim();
    if trimmed.is_empty() {
        return None;
    }

    let lowered = trimmed.to_ascii_lowercase();
    if lowered == "nvidia" || lowered == "nv" {
        return None;
    }

    // Strip an optional provider prefix case-insensitively.
    // `to_ascii_lowercase` preserves byte length, so byte offsets computed on
    // `lowered` are valid indices into `trimmed`.
    let without_prefix = if lowered.starts_with("nvidia:") {
        &trimmed["nvidia:".len()..]
    } else if lowered.starts_with("nv:") {
        &trimmed["nv:".len()..]
    } else {
        trimmed
    }
    .trim();

    if without_prefix.is_empty() {
        return None;
    }

    // Canonical default model id, any casing.
    if without_prefix.eq_ignore_ascii_case("nv-embed-v1") {
        return Some("nvidia/nv-embed-v1".to_string());
    }

    // Already org-qualified: keep as-is.
    if without_prefix.contains('/') {
        return Some(without_prefix.to_string());
    }

    Some(format!("nvidia/{without_prefix}"))
}
1081
1082/// Instantiate NVIDIA embedding runtime
1083fn instantiate_nvidia_runtime(model_override: Option<&str>) -> Result<EmbeddingRuntime> {
1084    use tracing::info;
1085
1086    let normalized = model_override.and_then(normalize_nvidia_embedding_model_override);
1087    let provider = NvidiaEmbeddingProvider::from_env(normalized.as_deref())
1088        .map_err(|err| anyhow!("failed to create NVIDIA embedding provider: {err}"))?;
1089
1090    info!(
1091        "NVIDIA embedding provider ready: model={}",
1092        provider.model()
1093    );
1094
1095    Ok(EmbeddingRuntime::new_nvidia(
1096        provider,
1097        EmbeddingModelChoice::Nvidia,
1098    ))
1099}
1100
1101/// Instantiate Gemini embedding runtime
1102fn instantiate_gemini_runtime() -> Result<EmbeddingRuntime> {
1103    use tracing::info;
1104
1105    let provider = GeminiEmbeddingProvider::from_env()
1106        .map_err(|err| anyhow!("failed to create Gemini embedding provider: {err}"))?;
1107
1108    let dimension = provider.dimension();
1109    info!(
1110        "Gemini embedding provider ready: model={}, dimension={}",
1111        provider.model(),
1112        dimension
1113    );
1114
1115    Ok(EmbeddingRuntime::new_gemini(
1116        provider,
1117        EmbeddingModelChoice::Gemini,
1118        dimension,
1119    ))
1120}
1121
1122/// Instantiate Mistral embedding runtime
1123fn instantiate_mistral_runtime() -> Result<EmbeddingRuntime> {
1124    use tracing::info;
1125
1126    let provider = MistralEmbeddingProvider::from_env()
1127        .map_err(|err| anyhow!("failed to create Mistral embedding provider: {err}"))?;
1128
1129    let dimension = provider.dimension();
1130    info!(
1131        "Mistral embedding provider ready: model={}, dimension={}",
1132        provider.model(),
1133        dimension
1134    );
1135
1136    Ok(EmbeddingRuntime::new_mistral(
1137        provider,
1138        EmbeddingModelChoice::Mistral,
1139        dimension,
1140    ))
1141}
1142
/// Instantiate fastembed (local) embedding runtime.
///
/// Ensures the model cache directory exists, refuses to proceed offline when
/// the cache is empty, initializes the ONNX model (downloading weights on the
/// first connected run), then probes it once to learn the real output
/// dimension before wrapping it in an [`EmbeddingRuntime`].
fn instantiate_fastembed_runtime(
    config: &CliConfig,
    embedding_model: EmbeddingModelChoice,
) -> Result<EmbeddingRuntime> {
    use anyhow::bail;
    use fastembed::{InitOptions, TextEmbedding};
    use std::fs;

    let cache_dir = ensure_fastembed_cache(config)?;

    // Offline with an empty cache means weights can never be fetched, so fail
    // up front with actionable advice instead of a cryptic download error.
    if config.offline {
        let mut entries = fs::read_dir(&cache_dir)?;
        if entries.next().is_none() {
            bail!(
                "semantic embeddings unavailable while offline; allow one connected run so fastembed can cache model weights"
            );
        }
    }

    let options = InitOptions::new(embedding_model.to_fastembed_model())
        .with_cache_dir(cache_dir)
        .with_show_download_progress(true);
    // NOTE(review): the Linux hint below mentions ~/.memvid/models, but the
    // cache actually lives at config.models_dir (ensure_fastembed_cache) —
    // confirm and align the message text.
    let mut model = TextEmbedding::try_new(options).map_err(|err| {
        // Provide platform-specific guidance for model download issues
        let platform_hint = if cfg!(target_os = "windows") {
            "\n\nWindows users: If model downloads fail, try:\n\
            1. Run as Administrator\n\
            2. Check your antivirus isn't blocking downloads\n\
            3. Use OpenAI embeddings instead: set OPENAI_API_KEY and use --embedding-model openai"
        } else if cfg!(target_os = "linux") {
            "\n\nLinux users: If model downloads fail, try:\n\
            1. Check disk space in ~/.memvid/models\n\
            2. Ensure you have network access to huggingface.co\n\
            3. Use OpenAI embeddings instead: export OPENAI_API_KEY=... and use --embedding-model openai"
        } else {
            "\n\nIf model downloads fail, try using OpenAI embeddings:\n\
            export OPENAI_API_KEY=your-key && memvid ... --embedding-model openai"
        };

        anyhow!(
            "Failed to initialize embedding model '{}': {err}\n\n\
            This typically means the model couldn't be downloaded or loaded.\n\
            Model size: ~{} MB{}\n\n\
            See: https://docs.memvid.com/embedding-models",
            embedding_model.name(),
            model_size_mb(embedding_model),
            platform_hint
        )
    })?;

    // Embed a fixed probe string to discover the model's actual output width
    // rather than trusting the compile-time dimension table.
    let probe = model
        .embed(vec!["memvid probe".to_string()], None)
        .map_err(|err| anyhow!("failed to determine embedding dimension: {err}"))?;
    let dimension = probe.first().map(|vec| vec.len()).unwrap_or(0);

    if dimension == 0 {
        bail!("fastembed reported zero-length embeddings");
    }

    // Verify dimension matches expected
    if dimension != embedding_model.dimensions() {
        tracing::warn!(
            "Embedding dimension mismatch: expected {}, got {}",
            embedding_model.dimensions(),
            dimension
        );
    }

    Ok(EmbeddingRuntime::new_fastembed(model, embedding_model, dimension))
}
1214
1215/// Load embedding runtime (fails if unavailable)
1216pub fn load_embedding_runtime(config: &CliConfig) -> Result<EmbeddingRuntime> {
1217    use anyhow::bail;
1218
1219    match instantiate_embedding_runtime(config) {
1220        Ok(runtime) => Ok(runtime),
1221        Err(err) => {
1222            if config.offline {
1223                bail!(
1224                    "semantic embeddings unavailable while offline; allow one connected run so fastembed can cache model weights ({err})"
1225                );
1226            }
1227            Err(err)
1228        }
1229    }
1230}
1231
1232/// Try to load embedding runtime (returns None if unavailable)
1233pub fn try_load_embedding_runtime(config: &CliConfig) -> Option<EmbeddingRuntime> {
1234    use tracing::warn;
1235
1236    match instantiate_embedding_runtime(config) {
1237        Ok(runtime) => Some(runtime),
1238        Err(err) => {
1239            warn!("semantic embeddings unavailable: {err}");
1240            None
1241        }
1242    }
1243}
1244
1245/// Load embedding runtime with an optional model override.
1246/// If `model_override` is provided, it will be used instead of the config's embedding_model.
1247pub fn load_embedding_runtime_with_model(
1248    config: &CliConfig,
1249    model_override: Option<&str>,
1250) -> Result<EmbeddingRuntime> {
1251    use tracing::info;
1252
1253    let mut raw_override: Option<&str> = None;
1254    let embedding_model = match model_override {
1255        Some(model_str) => {
1256            raw_override = Some(model_str);
1257            let parsed = model_str.parse::<EmbeddingModelChoice>()?;
1258            if parsed.dimensions() > 0 {
1259                info!(
1260                    "Using embedding model override: {} ({}D)",
1261                    parsed.name(),
1262                    parsed.dimensions()
1263                );
1264            } else {
1265                info!("Using embedding model override: {}", parsed.name());
1266            }
1267            parsed
1268        }
1269        None => config.embedding_model,
1270    };
1271
1272    if embedding_model.dimensions() > 0 {
1273        info!(
1274            "Loading embedding model: {} ({}D)",
1275            embedding_model.name(),
1276            embedding_model.dimensions()
1277        );
1278    } else {
1279        info!("Loading embedding model: {}", embedding_model.name());
1280    }
1281
1282    if config.offline && embedding_model.is_remote() {
1283        anyhow::bail!(
1284            "remote embeddings are unavailable while offline; set MEMVID_OFFLINE=0 or use a local embedding model"
1285        );
1286    }
1287
1288    if embedding_model.is_openai() {
1289        return instantiate_openai_runtime(embedding_model);
1290    }
1291
1292    if embedding_model == EmbeddingModelChoice::Nvidia {
1293        return instantiate_nvidia_runtime(raw_override);
1294    }
1295
1296    if embedding_model == EmbeddingModelChoice::Gemini {
1297        return instantiate_gemini_runtime();
1298    }
1299
1300    if embedding_model == EmbeddingModelChoice::Mistral {
1301        return instantiate_mistral_runtime();
1302    }
1303
1304    instantiate_fastembed_runtime(config, embedding_model)
1305}
1306
1307/// Try to load embedding runtime with model override (returns None if unavailable)
1308pub fn try_load_embedding_runtime_with_model(
1309    config: &CliConfig,
1310    model_override: Option<&str>,
1311) -> Option<EmbeddingRuntime> {
1312    use tracing::warn;
1313
1314    match load_embedding_runtime_with_model(config, model_override) {
1315        Ok(runtime) => Some(runtime),
1316        Err(err) => {
1317            warn!("semantic embeddings unavailable: {err}");
1318            None
1319        }
1320    }
1321}
1322
1323/// Load embedding runtime by auto-detecting from MV2 vector dimension.
1324///
1325/// Priority:
1326/// 1. Explicit model override (--query-embedding-model flag)
1327/// 2. Auto-detect from MV2 file's stored dimension
1328/// 3. Fall back to config default
1329///
1330/// This allows users to omit --query-embedding-model when querying files
1331/// created with non-default embedding models (like OpenAI).
1332pub fn load_embedding_runtime_for_mv2(
1333    config: &CliConfig,
1334    model_override: Option<&str>,
1335    mv2_dimension: Option<u32>,
1336) -> Result<EmbeddingRuntime> {
1337    use tracing::info;
1338
1339    // Priority 1: Explicit override
1340    if let Some(model_str) = model_override {
1341        return load_embedding_runtime_with_model(config, Some(model_str));
1342    }
1343
1344    // Priority 2: Auto-detect from MV2 dimension
1345    if let Some(dim) = mv2_dimension {
1346        if let Some(detected_model) = EmbeddingModelChoice::from_dimension(dim) {
1347            info!(
1348                "Auto-detected embedding model from MV2: {} ({}D)",
1349                detected_model.name(),
1350                dim
1351            );
1352
1353            // For OpenAI models, check if API key is available
1354            if detected_model.is_openai() {
1355                if std::env::var("OPENAI_API_KEY").is_ok() {
1356                    return load_embedding_runtime_with_model(config, Some(detected_model.name()));
1357                } else {
1358                    // OpenAI detected but no API key - provide helpful error
1359                    return Err(anyhow!(
1360                        "MV2 file uses OpenAI embeddings ({}D) but OPENAI_API_KEY is not set.\n\n\
1361                        Options:\n\
1362                        1. Set OPENAI_API_KEY environment variable\n\
1363                        2. Use --query-embedding-model to specify a different model\n\
1364                        3. Use lexical-only search with --mode lex\n\n\
1365                        See: https://docs.memvid.com/embedding-models",
1366                        dim
1367                    ));
1368                }
1369            }
1370
1371            return load_embedding_runtime_with_model(config, Some(detected_model.name()));
1372        }
1373    }
1374
1375    // Priority 3: Fall back to config default
1376    load_embedding_runtime(config)
1377}
1378
1379/// Try to load embedding runtime for MV2 with auto-detection (returns None if unavailable)
1380pub fn try_load_embedding_runtime_for_mv2(
1381    config: &CliConfig,
1382    model_override: Option<&str>,
1383    mv2_dimension: Option<u32>,
1384) -> Option<EmbeddingRuntime> {
1385    use tracing::warn;
1386
1387    match load_embedding_runtime_for_mv2(config, model_override, mv2_dimension) {
1388        Ok(runtime) => Some(runtime),
1389        Err(err) => {
1390            warn!("semantic embeddings unavailable: {err}");
1391            None
1392        }
1393    }
1394}