memvid_cli/config.rs

//! CLI configuration and environment handling
//!
//! This module provides configuration loading from environment variables,
//! tracing initialization, and embedding runtime management for semantic search.
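//!
//! Recognized environment variables (see the corresponding readers below):
//! `MEMVID_API_KEY`, `MEMVID_API_URL`, `MEMVID_CACHE_DIR`, `MEMVID_MODELS_DIR`,
//! `MEMVID_TICKET_PUBKEY`, `MEMVID_OFFLINE`, `MEMVID_EMBEDDING_MODEL`, and
//! `MEMVID_LLM_CONTEXT_BUDGET`; remote embedding providers additionally read
//! their own API keys (e.g. `OPENAI_API_KEY`).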

use std::env;
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::atomic::{AtomicUsize, Ordering};

use anyhow::{anyhow, Result};
use ed25519_dalek::VerifyingKey;

const DEFAULT_API_URL: &str = "https://memvid.com";
const DEFAULT_CACHE_DIR: &str = "~/.cache/memvid";

/// Supported embedding models for semantic search
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum EmbeddingModelChoice {
    /// BGE-small-en-v1.5: Fast, 384-dim, ~78% accuracy (default)
    #[default]
    BgeSmall,
    /// BGE-base-en-v1.5: Balanced, 768-dim, ~85% accuracy
    BgeBase,
    /// Nomic-embed-text-v1.5: 768-dim, ~86% accuracy
    Nomic,
    /// GTE-large-en-v1.5: Best semantic depth, 1024-dim
    GteLarge,
    /// OpenAI text-embedding-3-large: Highest quality, 3072-dim (requires OPENAI_API_KEY)
    OpenAILarge,
    /// OpenAI text-embedding-3-small: Good quality, 1536-dim (requires OPENAI_API_KEY)
    OpenAISmall,
    /// OpenAI text-embedding-ada-002: Legacy model, 1536-dim (requires OPENAI_API_KEY)
    OpenAIAda,
    /// NVIDIA nv-embed-v1: High quality, remote embeddings (requires NVIDIA_API_KEY)
    Nvidia,
    /// Gemini text-embedding-004: Google AI embeddings, 768-dim (requires GOOGLE_API_KEY or GEMINI_API_KEY)
    Gemini,
    /// Mistral mistral-embed: Mistral AI embeddings, 1024-dim (requires MISTRAL_API_KEY)
    Mistral,
}

impl EmbeddingModelChoice {
    /// Check if this is an OpenAI model (requires OPENAI_API_KEY)
    pub fn is_openai(&self) -> bool {
        matches!(
            self,
            EmbeddingModelChoice::OpenAILarge
                | EmbeddingModelChoice::OpenAISmall
                | EmbeddingModelChoice::OpenAIAda
        )
    }

    /// Check if this is a remote/cloud model (not local fastembed)
    pub fn is_remote(&self) -> bool {
        matches!(
            self,
            EmbeddingModelChoice::OpenAILarge
                | EmbeddingModelChoice::OpenAISmall
                | EmbeddingModelChoice::OpenAIAda
                | EmbeddingModelChoice::Nvidia
                | EmbeddingModelChoice::Gemini
                | EmbeddingModelChoice::Mistral
        )
    }

    /// Get the fastembed EmbeddingModel enum value (only for local models)
    ///
    /// # Panics
    /// Panics if called on a remote model (OpenAI, NVIDIA, Gemini, or Mistral).
    /// Use `is_remote()` to check first.
    pub fn to_fastembed_model(&self) -> fastembed::EmbeddingModel {
        match self {
            EmbeddingModelChoice::BgeSmall => fastembed::EmbeddingModel::BGESmallENV15,
            EmbeddingModelChoice::BgeBase => fastembed::EmbeddingModel::BGEBaseENV15,
            EmbeddingModelChoice::Nomic => fastembed::EmbeddingModel::NomicEmbedTextV15,
            EmbeddingModelChoice::GteLarge => fastembed::EmbeddingModel::GTELargeENV15,
            EmbeddingModelChoice::OpenAILarge
            | EmbeddingModelChoice::OpenAISmall
            | EmbeddingModelChoice::OpenAIAda => {
                panic!("OpenAI models don't use fastembed. Check is_remote() first.")
            }
            EmbeddingModelChoice::Nvidia => {
                panic!("NVIDIA embeddings don't use fastembed. Check is_remote() first.")
            }
            EmbeddingModelChoice::Gemini => {
                panic!("Gemini embeddings don't use fastembed. Check is_remote() first.")
            }
            EmbeddingModelChoice::Mistral => {
                panic!("Mistral embeddings don't use fastembed. Check is_remote() first.")
            }
        }
    }

    /// Get human-readable model name
    pub fn name(&self) -> &'static str {
        match self {
            EmbeddingModelChoice::BgeSmall => "bge-small",
            EmbeddingModelChoice::BgeBase => "bge-base",
            EmbeddingModelChoice::Nomic => "nomic",
            EmbeddingModelChoice::GteLarge => "gte-large",
            EmbeddingModelChoice::OpenAILarge => "openai-large",
            EmbeddingModelChoice::OpenAISmall => "openai-small",
            EmbeddingModelChoice::OpenAIAda => "openai-ada",
            EmbeddingModelChoice::Nvidia => "nvidia",
            EmbeddingModelChoice::Gemini => "gemini",
            EmbeddingModelChoice::Mistral => "mistral",
        }
    }

    /// Get the canonical provider model identifier used for persisted metadata.
    ///
    /// This is intended to match upstream provider IDs (OpenAI) and HuggingFace-style IDs
    /// (fastembed/ONNX) so that memories can record an embedding "identity" that other
    /// runtimes can select deterministically.
    pub fn canonical_model_id(&self) -> &'static str {
        match self {
            EmbeddingModelChoice::BgeSmall => "BAAI/bge-small-en-v1.5",
            EmbeddingModelChoice::BgeBase => "BAAI/bge-base-en-v1.5",
            EmbeddingModelChoice::Nomic => "nomic-embed-text-v1.5",
            EmbeddingModelChoice::GteLarge => "thenlper/gte-large",
            EmbeddingModelChoice::OpenAILarge => "text-embedding-3-large",
            EmbeddingModelChoice::OpenAISmall => "text-embedding-3-small",
            EmbeddingModelChoice::OpenAIAda => "text-embedding-ada-002",
            EmbeddingModelChoice::Nvidia => "nvidia/nv-embed-v1",
            EmbeddingModelChoice::Gemini => "text-embedding-004",
            EmbeddingModelChoice::Mistral => "mistral-embed",
        }
    }

    /// Get embedding dimensions
    pub fn dimensions(&self) -> usize {
        match self {
            EmbeddingModelChoice::BgeSmall => 384,
            EmbeddingModelChoice::BgeBase => 768,
            EmbeddingModelChoice::Nomic => 768,
            EmbeddingModelChoice::GteLarge => 1024,
            EmbeddingModelChoice::OpenAILarge => 3072,
            EmbeddingModelChoice::OpenAISmall => 1536,
            EmbeddingModelChoice::OpenAIAda => 1536,
            // Remote model; infer from the first embedding response.
            EmbeddingModelChoice::Nvidia => 0,
            EmbeddingModelChoice::Gemini => 768,
            EmbeddingModelChoice::Mistral => 1024,
        }
    }
}

impl FromStr for EmbeddingModelChoice {
    type Err = anyhow::Error;

    fn from_str(s: &str) -> Result<Self> {
        let lowered = s.trim().to_ascii_lowercase();
        match lowered.as_str() {
            "bge-small" | "bge_small" | "bgesmall" | "small" => Ok(EmbeddingModelChoice::BgeSmall),
            "baai/bge-small-en-v1.5" => Ok(EmbeddingModelChoice::BgeSmall),
            "bge-base" | "bge_base" | "bgebase" | "base" => Ok(EmbeddingModelChoice::BgeBase),
            "baai/bge-base-en-v1.5" => Ok(EmbeddingModelChoice::BgeBase),
            "nomic" | "nomic-embed" | "nomic_embed" => Ok(EmbeddingModelChoice::Nomic),
            "nomic-embed-text-v1.5" => Ok(EmbeddingModelChoice::Nomic),
            "gte-large" | "gte_large" | "gtelarge" | "gte" => Ok(EmbeddingModelChoice::GteLarge),
            "thenlper/gte-large" => Ok(EmbeddingModelChoice::GteLarge),
            // OpenAI models - default "openai" maps to "openai-large" for highest quality
            "openai" | "openai-large" | "openai_large" | "text-embedding-3-large" => {
                Ok(EmbeddingModelChoice::OpenAILarge)
            }
            "openai-small" | "openai_small" | "text-embedding-3-small" => {
                Ok(EmbeddingModelChoice::OpenAISmall)
            }
            "openai-ada" | "openai_ada" | "text-embedding-ada-002" | "ada" => {
                Ok(EmbeddingModelChoice::OpenAIAda)
            }
            "nvidia" | "nv" | "nv-embed-v1" | "nvidia/nv-embed-v1" => Ok(EmbeddingModelChoice::Nvidia),
            _ if lowered.starts_with("nvidia/") || lowered.starts_with("nvidia:") || lowered.starts_with("nv:") => {
                Ok(EmbeddingModelChoice::Nvidia)
            }
            // Gemini embeddings
            "gemini" | "gemini-embed" | "text-embedding-004" | "gemini-embedding-001" => {
                Ok(EmbeddingModelChoice::Gemini)
            }
            _ if lowered.starts_with("gemini/") || lowered.starts_with("gemini:") || lowered.starts_with("google:") => {
                Ok(EmbeddingModelChoice::Gemini)
            }
            // Mistral embeddings
            "mistral" | "mistral-embed" => Ok(EmbeddingModelChoice::Mistral),
            _ if lowered.starts_with("mistral/") || lowered.starts_with("mistral:") => {
                Ok(EmbeddingModelChoice::Mistral)
            }
            _ => Err(anyhow!(
                "unknown embedding model '{}'. Valid options: bge-small, bge-base, nomic, gte-large, openai, openai-small, openai-ada, nvidia, gemini, mistral",
                s
            )),
        }
    }
}

impl std::fmt::Display for EmbeddingModelChoice {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.name())
    }
}

impl EmbeddingModelChoice {
    /// Infer the embedding model from the vector dimension stored in an MV2 file.
    ///
    /// This enables auto-detection: users don't need to pass --query-embedding-model
    /// if the MV2 file has vectors. Note that a dimension narrows the choice to a
    /// model family rather than identifying one model uniquely; collisions resolve
    /// to the first model listed below.
    ///
    /// # Dimension Mapping
    /// - 384  → BGE-small (default local model)
    /// - 768  → BGE-base (Nomic and Gemini share this dimension)
    /// - 1024 → GTE-large (Mistral shares this dimension)
    /// - 1536 → OpenAI small (ada shares this dimension)
    /// - 3072 → OpenAI large
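    ///
    /// A usage sketch (marked `ignore`; the `memvid_cli::config` path assumes
    /// this module is exported from the crate root):
    ///
    /// ```ignore
    /// use memvid_cli::config::EmbeddingModelChoice;
    ///
    /// assert_eq!(EmbeddingModelChoice::from_dimension(384), Some(EmbeddingModelChoice::BgeSmall));
    /// assert_eq!(EmbeddingModelChoice::from_dimension(0), None); // no vectors in file
    /// ```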
    pub fn from_dimension(dim: u32) -> Option<Self> {
        match dim {
            384 => Some(EmbeddingModelChoice::BgeSmall),
            768 => Some(EmbeddingModelChoice::BgeBase), // Could be Nomic, but same dim
            1024 => Some(EmbeddingModelChoice::GteLarge),
            1536 => Some(EmbeddingModelChoice::OpenAISmall), // Could be Ada, same dim
            3072 => Some(EmbeddingModelChoice::OpenAILarge),
            0 => None, // No vectors in file
            _ => {
                tracing::warn!("Unknown embedding dimension {}, using default model", dim);
                None
            }
        }
    }
}

/// CLI configuration loaded from environment variables
#[derive(Debug, Clone)]
pub struct CliConfig {
    pub api_key: Option<String>,
    pub api_url: String,
    pub cache_dir: PathBuf,
    pub ticket_pubkey: Option<VerifyingKey>,
    pub models_dir: PathBuf,
    pub offline: bool,
    /// Embedding model for semantic search (can be overridden by CLI flag)
    pub embedding_model: EmbeddingModelChoice,
}

impl PartialEq for CliConfig {
    fn eq(&self, other: &Self) -> bool {
        self.api_key == other.api_key
            && self.api_url == other.api_url
            && self.cache_dir == other.cache_dir
            && self.models_dir == other.models_dir
            && self.offline == other.offline
            && self.embedding_model == other.embedding_model
    }
}

impl Eq for CliConfig {}

impl CliConfig {
    pub fn load() -> Result<Self> {
        let api_key = env::var("MEMVID_API_KEY").ok().and_then(|value| {
            let trimmed = value.trim().to_string();
            (!trimmed.is_empty()).then_some(trimmed)
        });

        let api_url = env::var("MEMVID_API_URL").unwrap_or_else(|_| DEFAULT_API_URL.to_string());

        let cache_dir_raw =
            env::var("MEMVID_CACHE_DIR").unwrap_or_else(|_| DEFAULT_CACHE_DIR.to_string());
        let cache_dir = expand_path(&cache_dir_raw)?;

        let models_dir_raw =
            env::var("MEMVID_MODELS_DIR").unwrap_or_else(|_| "~/.memvid/models".to_string());
        let models_dir = expand_path(&models_dir_raw)?;

        // Default public key for memvid.com dashboard ticket verification
        // This allows users to use --memory-id without setting MEMVID_TICKET_PUBKEY
        const DEFAULT_TICKET_PUBKEY: &str = "8wP1J2H+Tlx3PM3eT0lN2wDvoYrvl1DREKGKVb/V2cw=";

        let ticket_pubkey_str = env::var("MEMVID_TICKET_PUBKEY")
            .ok()
            .and_then(|value| {
                let trimmed = value.trim();
                if trimmed.is_empty() {
                    None
                } else {
                    Some(trimmed.to_string())
                }
            })
            .unwrap_or_else(|| DEFAULT_TICKET_PUBKEY.to_string());

        let ticket_pubkey = Some(memvid_core::parse_ed25519_public_key_base64(&ticket_pubkey_str)?);

        let offline = env::var("MEMVID_OFFLINE")
            .ok()
            .map(|value| matches!(value.trim().to_ascii_lowercase().as_str(), "1" | "true" | "yes"))
            .unwrap_or(false);

        // Load embedding model from env var; empty or unparseable values fall
        // back to the default (BGE-small).
        let embedding_model = env::var("MEMVID_EMBEDDING_MODEL")
            .ok()
            .and_then(|value| {
                let trimmed = value.trim();
                if trimmed.is_empty() {
                    None
                } else {
                    EmbeddingModelChoice::from_str(trimmed).ok()
                }
            })
            .unwrap_or_default();

        Ok(Self {
            api_key,
            api_url,
            cache_dir,
            ticket_pubkey,
            models_dir,
            offline,
            embedding_model,
        })
    }

    /// Create a new config with a different embedding model
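    ///
    /// Illustrative (marked `ignore`; module path assumed public):
    ///
    /// ```ignore
    /// use memvid_cli::config::{CliConfig, EmbeddingModelChoice};
    ///
    /// fn demo(config: &CliConfig) {
    ///     let gte = config.with_embedding_model(EmbeddingModelChoice::GteLarge);
    ///     assert_eq!(gte.embedding_model, EmbeddingModelChoice::GteLarge);
    /// }
    /// ```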
    pub fn with_embedding_model(&self, model: EmbeddingModelChoice) -> Self {
        Self {
            embedding_model: model,
            ..self.clone()
        }
    }
}

fn expand_path(value: &str) -> Result<PathBuf> {
    if value.trim().is_empty() {
        // This helper expands both cache and models directories, so the error
        // message stays generic.
        return Err(anyhow!("directory path cannot be empty"));
    }

    let expanded = if let Some(stripped) = value.strip_prefix("~/") {
        home_dir()?.join(stripped)
    } else if let Some(stripped) = value.strip_prefix("~\\") {
        // Support Windows-style "~\" prefix.
        home_dir()?.join(stripped)
    } else if value == "~" {
        home_dir()?
    } else {
        PathBuf::from(value)
    };

    if expanded.is_absolute() {
        Ok(expanded)
    } else {
        Ok(env::current_dir()?.join(expanded))
    }
}

fn home_dir() -> Result<PathBuf> {
    if let Some(path) = env::var_os("HOME") {
        if !path.is_empty() {
            return Ok(PathBuf::from(path));
        }
    }

    #[cfg(windows)]
    {
        if let Some(path) = env::var_os("USERPROFILE") {
            if !path.is_empty() {
                return Ok(PathBuf::from(path));
            }
        }
        if let (Some(drive), Some(path)) = (env::var_os("HOMEDRIVE"), env::var_os("HOMEPATH")) {
            if !drive.is_empty() && !path.is_empty() {
                return Ok(PathBuf::from(format!(
                    "{}{}",
                    drive.to_string_lossy(),
                    path.to_string_lossy()
                )));
            }
        }
    }

    Err(anyhow!("unable to resolve home directory"))
}

#[cfg(test)]
mod tests {
    use super::*;
    use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
    use base64::Engine;
    use ed25519_dalek::SigningKey;
    use std::sync::{Mutex, OnceLock};

    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(())).lock().unwrap()
    }

    fn set_or_unset(var: &str, value: Option<String>) {
        match value {
            Some(v) => unsafe { env::set_var(var, v) },
            None => unsafe { env::remove_var(var) },
        }
    }

    #[test]
    fn defaults_expand_using_home_directory() {
        let _guard = env_lock();

        let previous_home = env::var("HOME").ok();
        #[cfg(windows)]
        let previous_userprofile = env::var("USERPROFILE").ok();

        for var in [
            "MEMVID_API_KEY",
            "MEMVID_API_URL",
            "MEMVID_CACHE_DIR",
            "MEMVID_TICKET_PUBKEY",
            "MEMVID_MODELS_DIR",
            "MEMVID_OFFLINE",
        ] {
            unsafe { env::remove_var(var) };
        }

        let tmp = tempfile::tempdir().expect("tmpdir");
        let tmp_path = tmp.path().to_path_buf();
        unsafe { env::set_var("HOME", &tmp_path) };
        #[cfg(windows)]
        unsafe {
            env::set_var("USERPROFILE", &tmp_path)
        };

        let config = CliConfig::load().expect("load");
        assert_eq!(config.api_key, None);
        assert_eq!(config.api_url, "https://memvid.com");
        assert_eq!(config.cache_dir, tmp_path.join(".cache/memvid"));
        // ticket_pubkey has a default value now, so it should be Some
        assert!(config.ticket_pubkey.is_some());
        assert_eq!(config.models_dir, tmp_path.join(".memvid/models"));
        assert!(!config.offline);

        set_or_unset("HOME", previous_home);
        #[cfg(windows)]
        {
            set_or_unset("USERPROFILE", previous_userprofile);
        }
    }

    #[test]
    fn env_overrides_are_respected() {
        let _guard = env_lock();

        let previous_env: Vec<(&'static str, Option<String>)> = [
            "MEMVID_API_KEY",
            "MEMVID_API_URL",
            "MEMVID_CACHE_DIR",
            "MEMVID_TICKET_PUBKEY",
            "MEMVID_MODELS_DIR",
            "MEMVID_OFFLINE",
        ]
        .into_iter()
        .map(|var| (var, env::var(var).ok()))
        .collect();

        unsafe { env::set_var("MEMVID_API_KEY", "abc123") };
        unsafe { env::set_var("MEMVID_API_URL", "https://staging.memvid.app") };
        unsafe { env::set_var("MEMVID_CACHE_DIR", "~/memvid-cache") };
        unsafe { env::set_var("MEMVID_MODELS_DIR", "~/models") };
        unsafe { env::set_var("MEMVID_OFFLINE", "true") };
        let signing = SigningKey::from_bytes(&[9u8; 32]);
        let encoded = BASE64_STANDARD.encode(signing.verifying_key().as_bytes());
        unsafe { env::set_var("MEMVID_TICKET_PUBKEY", encoded) };

        let tmp = tempfile::tempdir().expect("tmpdir");
        let tmp_path = tmp.path().to_path_buf();
        unsafe { env::set_var("HOME", &tmp_path) };
        #[cfg(windows)]
        unsafe {
            env::set_var("USERPROFILE", &tmp_path)
        };

        let config = CliConfig::load().expect("load");
        assert_eq!(config.api_key.as_deref(), Some("abc123"));
        assert_eq!(config.api_url, "https://staging.memvid.app");
        assert_eq!(config.cache_dir, tmp_path.join("memvid-cache"));
        assert_eq!(
            config.ticket_pubkey.expect("pubkey").as_bytes(),
            signing.verifying_key().as_bytes()
        );
        assert_eq!(config.models_dir, tmp_path.join("models"));
        assert!(config.offline);

        for (var, value) in previous_env {
            set_or_unset(var, value);
        }
    }

    #[test]
    fn rejects_empty_cache_dir() {
        let _guard = env_lock();

        let previous = env::var("MEMVID_CACHE_DIR").ok();
        unsafe { env::set_var("MEMVID_CACHE_DIR", " ") };
        let err = CliConfig::load().expect_err("should fail");
        assert!(err.to_string().contains("directory path"));
        set_or_unset("MEMVID_CACHE_DIR", previous);
    }
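
    // Sketch tests for the pure helpers above; the assertions mirror the alias
    // and dimension tables documented in this file, not provider behavior.
    #[test]
    fn embedding_model_aliases_parse() {
        assert_eq!(
            "openai".parse::<EmbeddingModelChoice>().unwrap(),
            EmbeddingModelChoice::OpenAILarge
        );
        assert_eq!(
            "BAAI/bge-small-en-v1.5".parse::<EmbeddingModelChoice>().unwrap(),
            EmbeddingModelChoice::BgeSmall
        );
        assert_eq!(
            "mistral:custom".parse::<EmbeddingModelChoice>().unwrap(),
            EmbeddingModelChoice::Mistral
        );
        assert!("not-a-model".parse::<EmbeddingModelChoice>().is_err());
    }

    #[test]
    fn dimension_detection_matches_documented_table() {
        assert_eq!(
            EmbeddingModelChoice::from_dimension(384),
            Some(EmbeddingModelChoice::BgeSmall)
        );
        assert_eq!(
            EmbeddingModelChoice::from_dimension(3072),
            Some(EmbeddingModelChoice::OpenAILarge)
        );
        assert_eq!(EmbeddingModelChoice::from_dimension(0), None);
    }

    #[test]
    fn truncation_respects_utf8_boundaries() {
        // "é" is two bytes in UTF-8; a one-byte budget must back up to zero.
        let cut = EmbeddingRuntime::truncate_for_embedding("é", 1);
        assert_eq!(cut.as_ref(), "");
        let kept = EmbeddingRuntime::truncate_for_embedding("abc", 10);
        assert_eq!(kept.as_ref(), "abc");
    }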
}

/// Initialize tracing/logging based on verbosity level
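///
/// Verbosity maps 0 → warn, 1 → info, 2 → debug, 3+ → trace; an existing
/// `RUST_LOG` value takes precedence over the derived level. Usage sketch
/// (marked `ignore`; module path assumed public):
///
/// ```ignore
/// memvid_cli::config::init_tracing(1).expect("tracing"); // info level
/// ```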
pub fn init_tracing(verbosity: u8) -> Result<()> {
    use std::io::IsTerminal;
    use tracing_subscriber::{filter::Directive, fmt, EnvFilter};

    let level = match verbosity {
        0 => "warn",
        1 => "info",
        2 => "debug",
        _ => "trace",
    };

    let mut env_filter =
        EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(level));
    for directive_str in ["llama_cpp=error", "llama_cpp_sys=error", "ggml=error"] {
        if let Ok(directive) = directive_str.parse::<Directive>() {
            env_filter = env_filter.add_directive(directive);
        }
    }

    // Disable ANSI color codes when stderr is not a terminal (e.g., piped or
    // combined with `2>&1`). This prevents control characters from polluting
    // JSON output when combined with stdout.
    let use_ansi = std::io::stderr().is_terminal();

    fmt()
        .with_env_filter(env_filter)
        .with_writer(std::io::stderr)
        .with_target(false)
        .without_time()
        .with_ansi(use_ansi)
        .try_init()
        .map_err(|err| anyhow!(err))?;
    Ok(())
}

/// Resolve LLM context budget override from CLI or environment
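///
/// The CLI value wins over `MEMVID_LLM_CONTEXT_BUDGET`. Separators are tolerated
/// in the env value: "32000", "32_000", and "32 000" all resolve to 32000.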
pub fn resolve_llm_context_budget_override(cli_value: Option<usize>) -> Result<Option<usize>> {
    use anyhow::bail;

    if let Some(value) = cli_value {
        if value == 0 {
            bail!("--llm-context-depth must be a positive integer");
        }
        return Ok(Some(value));
    }

    let raw_env = match env::var("MEMVID_LLM_CONTEXT_BUDGET") {
        Ok(value) => value,
        Err(_) => return Ok(None),
    };

    let trimmed = raw_env.trim();
    if trimmed.is_empty() {
        return Ok(None);
    }

    // Drop whitespace and underscore separators; anything else still has to parse.
    let normalized: String = trimmed
        .chars()
        .filter(|ch| !ch.is_ascii_whitespace() && *ch != '_')
        .collect();

    if normalized.is_empty() {
        bail!("MEMVID_LLM_CONTEXT_BUDGET must be a positive integer value");
    }

    let value: usize = normalized.parse().map_err(|err| {
        anyhow!(
            "MEMVID_LLM_CONTEXT_BUDGET value '{}' is not a valid number: {}",
            trimmed,
            err
        )
    })?;

    if value == 0 {
        bail!("MEMVID_LLM_CONTEXT_BUDGET must be a positive integer");
    }

    Ok(Some(value))
}

use crate::gemini_embeddings::GeminiEmbeddingProvider;
use crate::mistral_embeddings::MistralEmbeddingProvider;
use crate::nvidia_embeddings::NvidiaEmbeddingProvider;
use crate::openai_embeddings::OpenAIEmbeddingProvider;

/// Internal embedding backend - local fastembed or remote providers.
#[derive(Clone)]
enum EmbeddingBackend {
    FastEmbed(std::sync::Arc<std::sync::Mutex<fastembed::TextEmbedding>>),
    OpenAI(std::sync::Arc<OpenAIEmbeddingProvider>),
    Nvidia(std::sync::Arc<NvidiaEmbeddingProvider>),
    Gemini(std::sync::Arc<GeminiEmbeddingProvider>),
    Mistral(std::sync::Arc<MistralEmbeddingProvider>),
}

/// Embedding runtime wrapper supporting local and remote embeddings.
///
/// Cloning is cheap: clones share the backend and the observed embedding
/// dimension, which starts at 0 for providers that only report it with the
/// first response and is pinned once observed.
#[derive(Clone)]
pub struct EmbeddingRuntime {
    backend: EmbeddingBackend,
    model: EmbeddingModelChoice,
    dimension: std::sync::Arc<AtomicUsize>,
}

impl EmbeddingRuntime {
    fn new_fastembed(
        backend: fastembed::TextEmbedding,
        model: EmbeddingModelChoice,
        dimension: usize,
    ) -> Self {
        Self {
            backend: EmbeddingBackend::FastEmbed(std::sync::Arc::new(std::sync::Mutex::new(
                backend,
            ))),
            model,
            dimension: std::sync::Arc::new(AtomicUsize::new(dimension)),
        }
    }

    fn new_openai(
        provider: OpenAIEmbeddingProvider,
        model: EmbeddingModelChoice,
        dimension: usize,
    ) -> Self {
        Self {
            backend: EmbeddingBackend::OpenAI(std::sync::Arc::new(provider)),
            model,
            dimension: std::sync::Arc::new(AtomicUsize::new(dimension)),
        }
    }

    fn new_nvidia(provider: NvidiaEmbeddingProvider, model: EmbeddingModelChoice) -> Self {
        Self {
            backend: EmbeddingBackend::Nvidia(std::sync::Arc::new(provider)),
            model,
            dimension: std::sync::Arc::new(AtomicUsize::new(0)),
        }
    }

    fn new_gemini(
        provider: GeminiEmbeddingProvider,
        model: EmbeddingModelChoice,
        dimension: usize,
    ) -> Self {
        Self {
            backend: EmbeddingBackend::Gemini(std::sync::Arc::new(provider)),
            model,
            dimension: std::sync::Arc::new(AtomicUsize::new(dimension)),
        }
    }

    fn new_mistral(
        provider: MistralEmbeddingProvider,
        model: EmbeddingModelChoice,
        dimension: usize,
    ) -> Self {
        Self {
            backend: EmbeddingBackend::Mistral(std::sync::Arc::new(provider)),
            model,
            dimension: std::sync::Arc::new(AtomicUsize::new(dimension)),
        }
    }

    // The caps below are byte budgets over the input text, sized well under
    // each provider's token limit.
    const MAX_OPENAI_EMBEDDING_TEXT_LEN: usize = 20_000;
    // NVIDIA Integrate embeddings enforce a 4096 token limit; use a tighter byte cap as a guardrail.
    const MAX_NVIDIA_EMBEDDING_TEXT_LEN: usize = 12_000;

    // Gemini has an 8192 token limit; 20k bytes is a conservative estimate.
    const MAX_GEMINI_EMBEDDING_TEXT_LEN: usize = 20_000;
    // Mistral has an 8192 token limit; 20k bytes is a conservative estimate.
    const MAX_MISTRAL_EMBEDDING_TEXT_LEN: usize = 20_000;

    fn max_remote_embedding_bytes(&self) -> usize {
        match &self.backend {
            EmbeddingBackend::OpenAI(_) => Self::MAX_OPENAI_EMBEDDING_TEXT_LEN,
            EmbeddingBackend::Nvidia(_) => Self::MAX_NVIDIA_EMBEDDING_TEXT_LEN,
            EmbeddingBackend::Gemini(_) => Self::MAX_GEMINI_EMBEDDING_TEXT_LEN,
            EmbeddingBackend::Mistral(_) => Self::MAX_MISTRAL_EMBEDDING_TEXT_LEN,
            EmbeddingBackend::FastEmbed(_) => usize::MAX,
        }
    }

    /// Truncate text for embedding to reduce the risk of provider token-limit errors.
    ///
    /// `max_bytes` is a byte budget; the cut always lands on a UTF-8 char
    /// boundary at or below it.
    fn truncate_for_embedding<'a>(
        text: &'a str,
        max_bytes: usize,
    ) -> std::borrow::Cow<'a, str> {
        if text.len() <= max_bytes {
            std::borrow::Cow::Borrowed(text)
        } else {
            // Walk back to the nearest char boundary; slicing at an arbitrary
            // byte index would panic on multi-byte characters.
            let mut end = max_bytes;
            while !text.is_char_boundary(end) {
                end -= 1;
            }
            tracing::info!("Truncated embedding text from {} to {} bytes", text.len(), end);
            std::borrow::Cow::Owned(text[..end].to_string())
        }
    }

    fn note_dimension(&self, observed: usize) -> Result<()> {
        if observed == 0 {
            return Err(anyhow!("embedding provider returned zero-length embedding"));
        }

        let current = self.dimension.load(Ordering::Relaxed);
        if current == 0 {
            self.dimension.store(observed, Ordering::Relaxed);
            return Ok(());
        }

        if current != observed {
            return Err(anyhow!(
                "embedding provider returned {observed}D vectors but runtime expects {current}D"
            ));
        }

        Ok(())
    }

    fn truncate_if_remote<'a>(&self, text: &'a str) -> std::borrow::Cow<'a, str> {
        match &self.backend {
            EmbeddingBackend::OpenAI(_)
            | EmbeddingBackend::Nvidia(_)
            | EmbeddingBackend::Gemini(_)
            | EmbeddingBackend::Mistral(_) => {
                Self::truncate_for_embedding(text, self.max_remote_embedding_bytes())
            }
            EmbeddingBackend::FastEmbed(_) => std::borrow::Cow::Borrowed(text),
        }
    }
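    /// Embed a single passage, truncating first for remote providers.
    ///
    /// The first successful call pins the runtime's dimension; any later call
    /// returning a different length fails in `note_dimension`.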
    pub fn embed_passage(&self, text: &str) -> Result<Vec<f32>> {
        let text = self.truncate_if_remote(text);
        let embedding = match &self.backend {
            EmbeddingBackend::FastEmbed(model) => {
                let mut guard = model
                    .lock()
                    .map_err(|_| anyhow!("fastembed runtime poisoned"))?;
                let outputs = guard
                    .embed(vec![text.into_owned()], None)
                    .map_err(|err| anyhow!("failed to compute embedding with fastembed: {err}"))?;
                outputs
                    .into_iter()
                    .next()
                    .ok_or_else(|| anyhow!("fastembed returned no embedding output"))?
            }
            EmbeddingBackend::OpenAI(provider) => {
                use memvid_core::EmbeddingProvider;
                provider
                    .embed_text(&text)
                    .map_err(|err| anyhow!("failed to compute embedding with OpenAI: {err}"))?
            }
            EmbeddingBackend::Nvidia(provider) => provider
                .embed_passage(&text)
                .map_err(|err| anyhow!("failed to compute embedding with NVIDIA: {err}"))?,
            EmbeddingBackend::Gemini(provider) => provider
                .embed_text(&text)
                .map_err(|err| anyhow!("failed to compute embedding with Gemini: {err}"))?,
            EmbeddingBackend::Mistral(provider) => provider
                .embed_text(&text)
                .map_err(|err| anyhow!("failed to compute embedding with Mistral: {err}"))?,
        };

        self.note_dimension(embedding.len())?;
        Ok(embedding)
    }
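    /// Embed a query string. NVIDIA uses a dedicated query mode; all other
    /// backends share the passage path.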
    pub fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
        let text = self.truncate_if_remote(text);
        match &self.backend {
            EmbeddingBackend::Nvidia(provider) => {
                let embedding = provider
                    .embed_query(&text)
                    .map_err(|err| anyhow!("failed to compute embedding with NVIDIA: {err}"))?;
                self.note_dimension(embedding.len())?;
                Ok(embedding)
            }
            _ => self.embed_passage(&text),
        }
    }

    pub fn embed_batch_passages(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let truncated: Vec<std::borrow::Cow<'_, str>> =
            texts.iter().map(|t| self.truncate_if_remote(t)).collect();
        let truncated_refs: Vec<&str> = truncated.iter().map(|c| c.as_ref()).collect();

        let embeddings = match &self.backend {
            EmbeddingBackend::FastEmbed(model) => {
                let mut guard = model
                    .lock()
                    .map_err(|_| anyhow!("fastembed runtime poisoned"))?;
                guard
                    .embed(
                        truncated_refs
                            .iter()
                            .map(|s| (*s).to_string())
                            .collect::<Vec<String>>(),
                        None,
                    )
                    .map_err(|err| anyhow!("failed to compute embeddings with fastembed: {err}"))?
            }
            EmbeddingBackend::OpenAI(provider) => {
                use memvid_core::EmbeddingProvider;
                provider
                    .embed_batch(&truncated_refs)
                    .map_err(|err| anyhow!("failed to compute embeddings with OpenAI: {err}"))?
            }
            EmbeddingBackend::Nvidia(provider) => provider
                .embed_passages(&truncated_refs)
                .map_err(|err| anyhow!("failed to compute embeddings with NVIDIA: {err}"))?,
            EmbeddingBackend::Gemini(provider) => provider
                .embed_batch(&truncated_refs)
                .map_err(|err| anyhow!("failed to compute embeddings with Gemini: {err}"))?,
            EmbeddingBackend::Mistral(provider) => provider
                .embed_batch(&truncated_refs)
                .map_err(|err| anyhow!("failed to compute embeddings with Mistral: {err}"))?,
        };

        if let Some(first) = embeddings.first() {
            self.note_dimension(first.len())?;
        }
        if let Some(expected) = embeddings.first().map(|e| e.len()) {
            if embeddings.iter().any(|e| e.len() != expected) {
                return Err(anyhow!("embedding provider returned mixed vector dimensions"));
            }
        }

        Ok(embeddings)
    }

    pub fn embed_batch_queries(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let truncated: Vec<std::borrow::Cow<'_, str>> =
            texts.iter().map(|t| self.truncate_if_remote(t)).collect();
        let truncated_refs: Vec<&str> = truncated.iter().map(|c| c.as_ref()).collect();

        match &self.backend {
            EmbeddingBackend::Nvidia(provider) => {
                let embeddings = provider
                    .embed_queries(&truncated_refs)
                    .map_err(|err| anyhow!("failed to compute embeddings with NVIDIA: {err}"))?;

                if let Some(first) = embeddings.first() {
                    self.note_dimension(first.len())?;
                }
                if let Some(expected) = embeddings.first().map(|e| e.len()) {
                    if embeddings.iter().any(|e| e.len() != expected) {
                        return Err(anyhow!("embedding provider returned mixed vector dimensions"));
                    }
                }

                Ok(embeddings)
            }
            _ => self.embed_batch_passages(&truncated_refs),
        }
    }

    pub fn dimension(&self) -> usize {
        self.dimension.load(Ordering::Relaxed)
    }

    pub fn model_choice(&self) -> EmbeddingModelChoice {
        self.model
    }

    pub fn provider_kind(&self) -> &'static str {
        match &self.backend {
            EmbeddingBackend::FastEmbed(_) => "fastembed",
            EmbeddingBackend::OpenAI(_) => "openai",
            EmbeddingBackend::Nvidia(_) => "nvidia",
            EmbeddingBackend::Gemini(_) => "gemini",
            EmbeddingBackend::Mistral(_) => "mistral",
        }
    }

    pub fn provider_model_id(&self) -> String {
        match &self.backend {
            EmbeddingBackend::FastEmbed(_) => self.model.canonical_model_id().to_string(),
            EmbeddingBackend::OpenAI(provider) => {
                use memvid_core::EmbeddingProvider;
                provider.model().to_string()
            }
            EmbeddingBackend::Nvidia(provider) => provider.model().to_string(),
            EmbeddingBackend::Gemini(provider) => provider.model().to_string(),
            EmbeddingBackend::Mistral(provider) => provider.model().to_string(),
        }
    }
}

impl memvid_core::VecEmbedder for EmbeddingRuntime {
    fn embed_query(&self, text: &str) -> memvid_core::Result<Vec<f32>> {
        EmbeddingRuntime::embed_query(self, text).map_err(|err| {
            memvid_core::MemvidError::EmbeddingFailed {
                reason: err.to_string().into_boxed_str(),
            }
        })
    }

    fn embedding_dimension(&self) -> usize {
        self.dimension()
    }
}

/// Ensure fastembed cache directory exists
fn ensure_fastembed_cache(config: &CliConfig) -> Result<PathBuf> {
    use std::fs;

    let cache_dir = config.models_dir.clone();
    fs::create_dir_all(&cache_dir)?;
    Ok(cache_dir)
}

/// Get approximate model size in MB for user-friendly error messages
fn model_size_mb(model: EmbeddingModelChoice) -> usize {
    match model {
        EmbeddingModelChoice::BgeSmall => 33,
        EmbeddingModelChoice::BgeBase => 110,
        EmbeddingModelChoice::Nomic => 137,
        EmbeddingModelChoice::GteLarge => 327,
        // Remote/cloud models don't require local download
        EmbeddingModelChoice::OpenAILarge
        | EmbeddingModelChoice::OpenAISmall
        | EmbeddingModelChoice::OpenAIAda
        | EmbeddingModelChoice::Nvidia
        | EmbeddingModelChoice::Gemini
        | EmbeddingModelChoice::Mistral => 0,
    }
}

/// Instantiate an embedding runtime with the configured model
fn instantiate_embedding_runtime(config: &CliConfig) -> Result<EmbeddingRuntime> {
    use tracing::info;

    let embedding_model = config.embedding_model;

    if embedding_model.dimensions() > 0 {
        info!(
            "Loading embedding model: {} ({}D)",
            embedding_model.name(),
            embedding_model.dimensions()
        );
    } else {
        info!("Loading embedding model: {}", embedding_model.name());
    }

    if config.offline && embedding_model.is_remote() {
        anyhow::bail!(
            "remote embeddings are unavailable while offline; set MEMVID_OFFLINE=0 or use a local embedding model"
        );
    }

    // Check if OpenAI model
    if embedding_model.is_openai() {
        return instantiate_openai_runtime(embedding_model);
    }

    if embedding_model == EmbeddingModelChoice::Nvidia {
        return instantiate_nvidia_runtime(None);
    }

    if embedding_model == EmbeddingModelChoice::Gemini {
        return instantiate_gemini_runtime();
    }

    if embedding_model == EmbeddingModelChoice::Mistral {
        return instantiate_mistral_runtime();
    }

    // Local fastembed model
    instantiate_fastembed_runtime(config, embedding_model)
}

/// Instantiate OpenAI embedding runtime
fn instantiate_openai_runtime(embedding_model: EmbeddingModelChoice) -> Result<EmbeddingRuntime> {
    use anyhow::bail;
    use memvid_core::EmbeddingConfig;
    use tracing::info;

    let api_key = std::env::var("OPENAI_API_KEY").map_err(|_| {
        anyhow!("OPENAI_API_KEY environment variable is required for OpenAI embeddings")
    })?;

    if api_key.is_empty() {
        bail!("OPENAI_API_KEY cannot be empty");
    }

    let config = match embedding_model {
        EmbeddingModelChoice::OpenAILarge => EmbeddingConfig::openai_large(),
        EmbeddingModelChoice::OpenAISmall => EmbeddingConfig::openai_small(),
        EmbeddingModelChoice::OpenAIAda => EmbeddingConfig::openai_ada(),
        _ => unreachable!("callers guard this path with is_openai()"),
    };

    let provider = OpenAIEmbeddingProvider::new(api_key, config.clone())
        .map_err(|err| anyhow!("failed to create OpenAI embedding provider: {err}"))?;

    info!(
        "OpenAI embedding provider ready: model={}, dimension={}",
        config.model, config.dimension
    );

    Ok(EmbeddingRuntime::new_openai(
        provider,
        embedding_model,
        config.dimension,
    ))
}

fn normalize_nvidia_embedding_model_override(raw: &str) -> Option<String> {
    let trimmed = raw.trim();
    if trimmed.is_empty() {
        return None;
    }

    let lowered = trimmed.to_ascii_lowercase();
    if lowered == "nvidia" || lowered == "nv" {
        return None;
    }

    let without_prefix = trimmed
        .strip_prefix("nvidia:")
        .or_else(|| trimmed.strip_prefix("nv:"))
        .unwrap_or(trimmed)
        .trim();

    if without_prefix.is_empty() {
        return None;
    }

    if without_prefix.eq_ignore_ascii_case("nv-embed-v1") {
        return Some("nvidia/nv-embed-v1".to_string());
    }

    if without_prefix.contains('/') {
        return Some(without_prefix.to_string());
    }

    Some(format!("nvidia/{without_prefix}"))
}

/// Instantiate NVIDIA embedding runtime
fn instantiate_nvidia_runtime(model_override: Option<&str>) -> Result<EmbeddingRuntime> {
    use tracing::info;

    let normalized = model_override.and_then(normalize_nvidia_embedding_model_override);
    let provider = NvidiaEmbeddingProvider::from_env(normalized.as_deref())
        .map_err(|err| anyhow!("failed to create NVIDIA embedding provider: {err}"))?;

    info!(
        "NVIDIA embedding provider ready: model={}",
        provider.model()
    );

    Ok(EmbeddingRuntime::new_nvidia(
        provider,
        EmbeddingModelChoice::Nvidia,
    ))
}

/// Instantiate Gemini embedding runtime
fn instantiate_gemini_runtime() -> Result<EmbeddingRuntime> {
    use tracing::info;

    let provider = GeminiEmbeddingProvider::from_env()
        .map_err(|err| anyhow!("failed to create Gemini embedding provider: {err}"))?;

    let dimension = provider.dimension();
    info!(
        "Gemini embedding provider ready: model={}, dimension={}",
        provider.model(),
        dimension
    );

    Ok(EmbeddingRuntime::new_gemini(
        provider,
        EmbeddingModelChoice::Gemini,
        dimension,
    ))
}

/// Instantiate Mistral embedding runtime
fn instantiate_mistral_runtime() -> Result<EmbeddingRuntime> {
    use tracing::info;

    let provider = MistralEmbeddingProvider::from_env()
        .map_err(|err| anyhow!("failed to create Mistral embedding provider: {err}"))?;

    let dimension = provider.dimension();
    info!(
        "Mistral embedding provider ready: model={}, dimension={}",
        provider.model(),
        dimension
    );

    Ok(EmbeddingRuntime::new_mistral(
        provider,
        EmbeddingModelChoice::Mistral,
        dimension,
    ))
}

/// Instantiate fastembed (local) embedding runtime
fn instantiate_fastembed_runtime(
    config: &CliConfig,
    embedding_model: EmbeddingModelChoice,
) -> Result<EmbeddingRuntime> {
    use anyhow::bail;
    use fastembed::{InitOptions, TextEmbedding};
    use std::fs;

    let cache_dir = ensure_fastembed_cache(config)?;

    if config.offline {
        let mut entries = fs::read_dir(&cache_dir)?;
        if entries.next().is_none() {
            bail!(
                "semantic embeddings unavailable while offline; allow one connected run so fastembed can cache model weights"
            );
        }
    }

    let options = InitOptions::new(embedding_model.to_fastembed_model())
        .with_cache_dir(cache_dir)
        .with_show_download_progress(true);
    let mut model = TextEmbedding::try_new(options).map_err(|err| {
        // Provide platform-specific guidance for model download issues
        let platform_hint = if cfg!(target_os = "windows") {
            "\n\nWindows users: If model downloads fail, try:\n\
            1. Run as Administrator\n\
            2. Check your antivirus isn't blocking downloads\n\
            3. Use OpenAI embeddings instead: set OPENAI_API_KEY and use --embedding-model openai"
        } else if cfg!(target_os = "linux") {
            "\n\nLinux users: If model downloads fail, try:\n\
            1. Check disk space in ~/.memvid/models\n\
            2. Ensure you have network access to huggingface.co\n\
            3. Use OpenAI embeddings instead: export OPENAI_API_KEY=... and use --embedding-model openai"
        } else {
            "\n\nIf model downloads fail, try using OpenAI embeddings:\n\
            export OPENAI_API_KEY=your-key && memvid ... --embedding-model openai"
        };

        anyhow!(
            "Failed to initialize embedding model '{}': {err}\n\n\
            This typically means the model couldn't be downloaded or loaded.\n\
            Model size: ~{} MB{}\n\n\
            See: https://docs.memvid.com/embedding-models",
            embedding_model.name(),
            model_size_mb(embedding_model),
            platform_hint
        )
    })?;

    let probe = model
        .embed(vec!["memvid probe".to_string()], None)
        .map_err(|err| anyhow!("failed to determine embedding dimension: {err}"))?;
    let dimension = probe.first().map(|vec| vec.len()).unwrap_or(0);

    if dimension == 0 {
        bail!("fastembed reported zero-length embeddings");
    }

    // Verify dimension matches expected
    if dimension != embedding_model.dimensions() {
        tracing::warn!(
            "Embedding dimension mismatch: expected {}, got {}",
            embedding_model.dimensions(),
            dimension
        );
    }

    Ok(EmbeddingRuntime::new_fastembed(model, embedding_model, dimension))
}

/// Load embedding runtime (fails if unavailable)
pub fn load_embedding_runtime(config: &CliConfig) -> Result<EmbeddingRuntime> {
    use anyhow::bail;

    match instantiate_embedding_runtime(config) {
        Ok(runtime) => Ok(runtime),
        Err(err) => {
            if config.offline {
                bail!(
                    "semantic embeddings unavailable while offline; allow one connected run so fastembed can cache model weights ({err})"
                );
            }
            Err(err)
        }
    }
}

/// Try to load embedding runtime (returns None if unavailable)
pub fn try_load_embedding_runtime(config: &CliConfig) -> Option<EmbeddingRuntime> {
    use tracing::warn;

    match instantiate_embedding_runtime(config) {
        Ok(runtime) => Some(runtime),
        Err(err) => {
            warn!("semantic embeddings unavailable: {err}");
            None
        }
    }
}

/// Load embedding runtime with an optional model override.
/// If `model_override` is provided, it will be used instead of the config's embedding_model.
pub fn load_embedding_runtime_with_model(
    config: &CliConfig,
    model_override: Option<&str>,
) -> Result<EmbeddingRuntime> {
    use tracing::info;

    let mut raw_override: Option<&str> = None;
    let embedding_model = match model_override {
        Some(model_str) => {
            raw_override = Some(model_str);
            let parsed = model_str.parse::<EmbeddingModelChoice>()?;
            if parsed.dimensions() > 0 {
                info!(
                    "Using embedding model override: {} ({}D)",
                    parsed.name(),
                    parsed.dimensions()
                );
            } else {
                info!("Using embedding model override: {}", parsed.name());
            }
            parsed
        }
        None => config.embedding_model,
    };

    if embedding_model.dimensions() > 0 {
        info!(
            "Loading embedding model: {} ({}D)",
            embedding_model.name(),
            embedding_model.dimensions()
        );
    } else {
        info!("Loading embedding model: {}", embedding_model.name());
    }

    if config.offline && embedding_model.is_remote() {
        anyhow::bail!(
            "remote embeddings are unavailable while offline; set MEMVID_OFFLINE=0 or use a local embedding model"
        );
    }

    if embedding_model.is_openai() {
        return instantiate_openai_runtime(embedding_model);
    }

    if embedding_model == EmbeddingModelChoice::Nvidia {
        return instantiate_nvidia_runtime(raw_override);
    }

    if embedding_model == EmbeddingModelChoice::Gemini {
        return instantiate_gemini_runtime();
    }

    if embedding_model == EmbeddingModelChoice::Mistral {
        return instantiate_mistral_runtime();
    }

    instantiate_fastembed_runtime(config, embedding_model)
}

/// Try to load embedding runtime with model override (returns None if unavailable)
pub fn try_load_embedding_runtime_with_model(
    config: &CliConfig,
    model_override: Option<&str>,
) -> Option<EmbeddingRuntime> {
    use tracing::warn;

    match load_embedding_runtime_with_model(config, model_override) {
        Ok(runtime) => Some(runtime),
        Err(err) => {
            warn!("semantic embeddings unavailable: {err}");
            None
        }
    }
}

/// Load embedding runtime by auto-detecting from MV2 vector dimension.
///
/// Priority:
/// 1. Explicit model override (--query-embedding-model flag)
/// 2. Auto-detect from MV2 file's stored dimension
/// 3. Fall back to config default
///
/// This allows users to omit --query-embedding-model when querying files
/// created with non-default embedding models (like OpenAI).
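///
/// Usage sketch (marked `ignore`; the dimension is typically read from the MV2
/// header, and the path assumes this module is exported from the crate root):
///
/// ```ignore
/// use memvid_cli::config::{load_embedding_runtime_for_mv2, CliConfig};
///
/// fn demo(config: &CliConfig) -> anyhow::Result<()> {
///     // 384 auto-detects the default local model, so no flag is needed.
///     let runtime = load_embedding_runtime_for_mv2(config, None, Some(384))?;
///     Ok(())
/// }
/// ```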
pub fn load_embedding_runtime_for_mv2(
    config: &CliConfig,
    model_override: Option<&str>,
    mv2_dimension: Option<u32>,
) -> Result<EmbeddingRuntime> {
    use tracing::info;

    // Priority 1: Explicit override
    if let Some(model_str) = model_override {
        return load_embedding_runtime_with_model(config, Some(model_str));
    }

    // Priority 2: Auto-detect from MV2 dimension
    if let Some(dim) = mv2_dimension {
        if let Some(detected_model) = EmbeddingModelChoice::from_dimension(dim) {
            info!(
                "Auto-detected embedding model from MV2: {} ({}D)",
                detected_model.name(),
                dim
            );

            // For OpenAI models, check if API key is available
            if detected_model.is_openai() {
                if std::env::var("OPENAI_API_KEY").is_ok() {
                    return load_embedding_runtime_with_model(config, Some(detected_model.name()));
                } else {
                    // OpenAI detected but no API key - provide helpful error
                    return Err(anyhow!(
                        "MV2 file uses OpenAI embeddings ({}D) but OPENAI_API_KEY is not set.\n\n\
                        Options:\n\
                        1. Set OPENAI_API_KEY environment variable\n\
                        2. Use --query-embedding-model to specify a different model\n\
                        3. Use lexical-only search with --mode lex\n\n\
                        See: https://docs.memvid.com/embedding-models",
                        dim
                    ));
                }
            }

            return load_embedding_runtime_with_model(config, Some(detected_model.name()));
        }
    }

    // Priority 3: Fall back to config default
    load_embedding_runtime(config)
}

/// Try to load embedding runtime for MV2 with auto-detection (returns None if unavailable)
pub fn try_load_embedding_runtime_for_mv2(
    config: &CliConfig,
    model_override: Option<&str>,
    mv2_dimension: Option<u32>,
) -> Option<EmbeddingRuntime> {
    use tracing::warn;

    match load_embedding_runtime_for_mv2(config, model_override, mv2_dimension) {
        Ok(runtime) => Some(runtime),
        Err(err) => {
            warn!("semantic embeddings unavailable: {err}");
            None
        }
    }
}