
memvid_cli/config.rs

//! CLI configuration and environment handling
//!
//! This module provides configuration loading from environment variables,
//! tracing initialization, and embedding runtime management for semantic search.

use std::env;
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::atomic::{AtomicUsize, Ordering};

use anyhow::{anyhow, Result};
use ed25519_dalek::VerifyingKey;

const DEFAULT_API_URL: &str = "https://memvid.com";
const DEFAULT_CACHE_DIR: &str = "~/.cache/memvid";

/// Supported embedding models for semantic search
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum EmbeddingModelChoice {
    /// BGE-small-en-v1.5: Fast, 384-dim, ~78% accuracy (default)
    #[default]
    BgeSmall,
    /// BGE-base-en-v1.5: Balanced, 768-dim, ~85% accuracy
    BgeBase,
    /// Nomic-embed-text-v1.5: High accuracy, 768-dim, ~86% accuracy
    Nomic,
    /// GTE-large-en-v1.5: Best semantic depth, 1024-dim
    GteLarge,
    /// OpenAI text-embedding-3-large: Highest quality, 3072-dim (requires OPENAI_API_KEY)
    OpenAILarge,
    /// OpenAI text-embedding-3-small: Good quality, 1536-dim (requires OPENAI_API_KEY)
    OpenAISmall,
    /// OpenAI text-embedding-ada-002: Legacy model, 1536-dim (requires OPENAI_API_KEY)
    OpenAIAda,
    /// NVIDIA nv-embed-v1: High quality, remote embeddings (requires NVIDIA_API_KEY)
    Nvidia,
    /// Gemini text-embedding-004: Google AI embeddings, 768-dim (requires GOOGLE_API_KEY or GEMINI_API_KEY)
    Gemini,
    /// Mistral mistral-embed: Mistral AI embeddings, 1024-dim (requires MISTRAL_API_KEY)
    Mistral,
}

impl EmbeddingModelChoice {
    /// Check if this is an OpenAI model (requires OPENAI_API_KEY)
    pub fn is_openai(&self) -> bool {
        matches!(
            self,
            EmbeddingModelChoice::OpenAILarge
                | EmbeddingModelChoice::OpenAISmall
                | EmbeddingModelChoice::OpenAIAda
        )
    }

    /// Check if this is a remote/cloud model (not local fastembed)
    pub fn is_remote(&self) -> bool {
        matches!(
            self,
            EmbeddingModelChoice::OpenAILarge
                | EmbeddingModelChoice::OpenAISmall
                | EmbeddingModelChoice::OpenAIAda
                | EmbeddingModelChoice::Nvidia
                | EmbeddingModelChoice::Gemini
                | EmbeddingModelChoice::Mistral
        )
    }

    /// Get the fastembed EmbeddingModel enum value (only for local models)
    ///
    /// # Panics
    /// Panics if called on a remote model (OpenAI, NVIDIA, Gemini, Mistral). Use `is_remote()` to check first.
    #[cfg(feature = "local-embeddings")]
    pub fn to_fastembed_model(&self) -> fastembed::EmbeddingModel {
        match self {
            EmbeddingModelChoice::BgeSmall => fastembed::EmbeddingModel::BGESmallENV15,
            EmbeddingModelChoice::BgeBase => fastembed::EmbeddingModel::BGEBaseENV15,
            EmbeddingModelChoice::Nomic => fastembed::EmbeddingModel::NomicEmbedTextV15,
            EmbeddingModelChoice::GteLarge => fastembed::EmbeddingModel::GTELargeENV15,
            EmbeddingModelChoice::OpenAILarge
            | EmbeddingModelChoice::OpenAISmall
            | EmbeddingModelChoice::OpenAIAda => {
                panic!("OpenAI models don't use fastembed. Check is_remote() first.")
            }
            EmbeddingModelChoice::Nvidia => {
                panic!("NVIDIA embeddings don't use fastembed. Check is_remote() first.")
            }
            EmbeddingModelChoice::Gemini => {
                panic!("Gemini embeddings don't use fastembed. Check is_remote() first.")
            }
            EmbeddingModelChoice::Mistral => {
                panic!("Mistral embeddings don't use fastembed. Check is_remote() first.")
            }
        }
    }

    /// Get human-readable model name
    pub fn name(&self) -> &'static str {
        match self {
            EmbeddingModelChoice::BgeSmall => "bge-small",
            EmbeddingModelChoice::BgeBase => "bge-base",
            EmbeddingModelChoice::Nomic => "nomic",
            EmbeddingModelChoice::GteLarge => "gte-large",
            EmbeddingModelChoice::OpenAILarge => "openai-large",
            EmbeddingModelChoice::OpenAISmall => "openai-small",
            EmbeddingModelChoice::OpenAIAda => "openai-ada",
            EmbeddingModelChoice::Nvidia => "nvidia",
            EmbeddingModelChoice::Gemini => "gemini",
            EmbeddingModelChoice::Mistral => "mistral",
        }
    }

    /// Get the canonical provider model identifier used for persisted metadata.
    ///
    /// This is intended to match upstream provider IDs (OpenAI) and HuggingFace-style IDs
    /// (fastembed/ONNX) so that memories can record an embedding "identity" that other
    /// runtimes can select deterministically.
    pub fn canonical_model_id(&self) -> &'static str {
        match self {
            EmbeddingModelChoice::BgeSmall => "BAAI/bge-small-en-v1.5",
            EmbeddingModelChoice::BgeBase => "BAAI/bge-base-en-v1.5",
            EmbeddingModelChoice::Nomic => "nomic-embed-text-v1.5",
            EmbeddingModelChoice::GteLarge => "thenlper/gte-large",
            EmbeddingModelChoice::OpenAILarge => "text-embedding-3-large",
            EmbeddingModelChoice::OpenAISmall => "text-embedding-3-small",
            EmbeddingModelChoice::OpenAIAda => "text-embedding-ada-002",
            EmbeddingModelChoice::Nvidia => "nvidia/nv-embed-v1",
            EmbeddingModelChoice::Gemini => "text-embedding-004",
            EmbeddingModelChoice::Mistral => "mistral-embed",
        }
    }

    /// Get embedding dimensions
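    ///
    /// Local models report a fixed width; NVIDIA reports 0 because its width
    /// is only known after the first response. A minimal sketch (illustrative,
    /// not taken from the test suite):
    ///
    /// ```ignore
    /// assert_eq!(EmbeddingModelChoice::BgeSmall.dimensions(), 384);
    /// assert_eq!(EmbeddingModelChoice::Nvidia.dimensions(), 0); // inferred lazily
    /// ```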
    pub fn dimensions(&self) -> usize {
        match self {
            EmbeddingModelChoice::BgeSmall => 384,
            EmbeddingModelChoice::BgeBase => 768,
            EmbeddingModelChoice::Nomic => 768,
            EmbeddingModelChoice::GteLarge => 1024,
            EmbeddingModelChoice::OpenAILarge => 3072,
            EmbeddingModelChoice::OpenAISmall => 1536,
            EmbeddingModelChoice::OpenAIAda => 1536,
            // Remote model; infer from the first embedding response.
            EmbeddingModelChoice::Nvidia => 0,
            EmbeddingModelChoice::Gemini => 768,
            EmbeddingModelChoice::Mistral => 1024,
        }
    }
}

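/// Parses user-facing model names, common aliases, and provider-prefixed IDs.
///
/// A minimal sketch of accepted spellings (illustrative, based on the match
/// arms below):
///
/// ```ignore
/// let m: EmbeddingModelChoice = "openai".parse()?; // maps to the large model
/// assert_eq!(m, EmbeddingModelChoice::OpenAILarge);
/// assert_eq!("BAAI/bge-small-en-v1.5".parse::<EmbeddingModelChoice>()?,
///            EmbeddingModelChoice::BgeSmall); // matching is case-insensitive
/// assert!("unknown-model".parse::<EmbeddingModelChoice>().is_err());
/// ```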
impl FromStr for EmbeddingModelChoice {
    type Err = anyhow::Error;

    fn from_str(s: &str) -> Result<Self> {
        let lowered = s.trim().to_ascii_lowercase();
        match lowered.as_str() {
            "bge-small" | "bge_small" | "bgesmall" | "small" => Ok(EmbeddingModelChoice::BgeSmall),
            "baai/bge-small-en-v1.5" => Ok(EmbeddingModelChoice::BgeSmall),
            "bge-base" | "bge_base" | "bgebase" | "base" => Ok(EmbeddingModelChoice::BgeBase),
            "baai/bge-base-en-v1.5" => Ok(EmbeddingModelChoice::BgeBase),
            "nomic" | "nomic-embed" | "nomic_embed" => Ok(EmbeddingModelChoice::Nomic),
            "nomic-embed-text-v1.5" => Ok(EmbeddingModelChoice::Nomic),
            "gte-large" | "gte_large" | "gtelarge" | "gte" => Ok(EmbeddingModelChoice::GteLarge),
            "thenlper/gte-large" => Ok(EmbeddingModelChoice::GteLarge),
            // OpenAI models - default "openai" maps to "openai-large" for highest quality
            "openai" | "openai-large" | "openai_large" | "text-embedding-3-large" => {
                Ok(EmbeddingModelChoice::OpenAILarge)
            }
            "openai-small" | "openai_small" | "text-embedding-3-small" => {
                Ok(EmbeddingModelChoice::OpenAISmall)
            }
            "openai-ada" | "openai_ada" | "text-embedding-ada-002" | "ada" => {
                Ok(EmbeddingModelChoice::OpenAIAda)
            }
            "nvidia" | "nv" | "nv-embed-v1" | "nvidia/nv-embed-v1" => Ok(EmbeddingModelChoice::Nvidia),
            _ if lowered.starts_with("nvidia/") || lowered.starts_with("nvidia:") || lowered.starts_with("nv:") => {
                Ok(EmbeddingModelChoice::Nvidia)
            }
            // Gemini embeddings
            "gemini" | "gemini-embed" | "text-embedding-004" | "gemini-embedding-001" => {
                Ok(EmbeddingModelChoice::Gemini)
            }
            _ if lowered.starts_with("gemini/") || lowered.starts_with("gemini:") || lowered.starts_with("google:") => {
                Ok(EmbeddingModelChoice::Gemini)
            }
            // Mistral embeddings
            "mistral" | "mistral-embed" => Ok(EmbeddingModelChoice::Mistral),
            _ if lowered.starts_with("mistral/") || lowered.starts_with("mistral:") => {
                Ok(EmbeddingModelChoice::Mistral)
            }
            _ => Err(anyhow!(
                "unknown embedding model '{}'. Valid options: bge-small, bge-base, nomic, gte-large, openai, openai-small, openai-ada, nvidia, gemini, mistral",
                s
            )),
        }
    }
}

impl std::fmt::Display for EmbeddingModelChoice {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.name())
    }
}

impl EmbeddingModelChoice {
    /// Infer the most likely embedding model from the vector dimension stored in an MV2 file.
    ///
    /// This enables auto-detection: users don't need to specify --query-embedding-model
    /// if the MV2 file has vectors. The dimension usually narrows the choice to a
    /// compatible model family.
    ///
    /// # Dimension Mapping
    /// - 384  → BGE-small (default local model)
    /// - 768  → BGE-base (Nomic and Gemini share this dimension)
    /// - 1024 → GTE-large (Mistral shares this dimension)
    /// - 1536 → OpenAI small/ada
    /// - 3072 → OpenAI large
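    ///
    /// A minimal sketch (illustrative):
    ///
    /// ```ignore
    /// assert_eq!(
    ///     EmbeddingModelChoice::from_dimension(384),
    ///     Some(EmbeddingModelChoice::BgeSmall)
    /// );
    /// assert_eq!(EmbeddingModelChoice::from_dimension(0), None); // no vectors stored
    /// ```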
    pub fn from_dimension(dim: u32) -> Option<Self> {
        match dim {
            384 => Some(EmbeddingModelChoice::BgeSmall),
            768 => Some(EmbeddingModelChoice::BgeBase), // Could be Nomic, but same dim
            1024 => Some(EmbeddingModelChoice::GteLarge),
            1536 => Some(EmbeddingModelChoice::OpenAISmall), // Could be Ada, same dim
            3072 => Some(EmbeddingModelChoice::OpenAILarge),
            0 => None, // No vectors in file
            _ => {
                tracing::warn!("Unknown embedding dimension {}, using default model", dim);
                None
            }
        }
    }
}

/// CLI configuration loaded from environment variables and config file
#[derive(Debug, Clone)]
pub struct CliConfig {
    pub api_key: Option<String>,
    pub api_url: String,
    /// Default memory ID for dashboard sync (from config file)
    pub memory_id: Option<String>,
    pub cache_dir: PathBuf,
    pub ticket_pubkey: Option<VerifyingKey>,
    pub models_dir: PathBuf,
    pub offline: bool,
    /// Embedding model for semantic search (can be overridden by CLI flag)
    pub embedding_model: EmbeddingModelChoice,
}

impl PartialEq for CliConfig {
    fn eq(&self, other: &Self) -> bool {
        self.api_key == other.api_key
            && self.api_url == other.api_url
            && self.memory_id == other.memory_id
            && self.cache_dir == other.cache_dir
            && self.models_dir == other.models_dir
            && self.offline == other.offline
            && self.embedding_model == other.embedding_model
    }
}

impl Eq for CliConfig {}

impl CliConfig {
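    /// Load configuration from the environment, falling back to the persistent
    /// config file and finally to built-in defaults.
    ///
    /// A minimal sketch of the precedence (illustrative; mirrors the
    /// `env_overrides_are_respected` test below):
    ///
    /// ```ignore
    /// unsafe { std::env::set_var("MEMVID_API_URL", "https://staging.memvid.app") };
    /// let config = CliConfig::load()?;
    /// assert_eq!(config.api_url, "https://staging.memvid.app"); // env beats file/default
    /// ```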
    pub fn load() -> Result<Self> {
        // Load persistent config file (if exists) for fallback values
        let persistent_config = crate::commands::config::PersistentConfig::load().ok();

        // API Key: env var takes precedence, then config file
        let api_key = env::var("MEMVID_API_KEY")
            .ok()
            .and_then(|value| {
                let trimmed = value.trim().to_string();
                (!trimmed.is_empty()).then_some(trimmed)
            })
            .or_else(|| persistent_config.as_ref().and_then(|c| c.api_key.clone()));

        // API URL: env var takes precedence, then config file, then default
        let api_url = env::var("MEMVID_API_URL")
            .ok()
            .or_else(|| persistent_config.as_ref().and_then(|c| c.api_url.clone()))
            .unwrap_or_else(|| DEFAULT_API_URL.to_string());

        // Memory ID: env var takes precedence, then config file (memory.default or legacy memory_id)
        let memory_id = env::var("MEMVID_MEMORY_ID")
            .ok()
            .and_then(|value| {
                let trimmed = value.trim().to_string();
                (!trimmed.is_empty()).then_some(trimmed)
            })
            .or_else(|| {
                persistent_config
                    .as_ref()
                    .and_then(|c| c.default_memory_id())
            });

        let cache_dir_raw =
            env::var("MEMVID_CACHE_DIR").unwrap_or_else(|_| DEFAULT_CACHE_DIR.to_string());
        let cache_dir = expand_path(&cache_dir_raw)?;

        let models_dir_raw =
            env::var("MEMVID_MODELS_DIR").unwrap_or_else(|_| "~/.memvid/models".to_string());
        let models_dir = expand_path(&models_dir_raw)?;

        // Default public key for memvid.com dashboard ticket verification
        // This allows users to use --memory-id without setting MEMVID_TICKET_PUBKEY
        // Must match memvid-core's MEMVID_TICKET_PUBKEY constant
        const DEFAULT_TICKET_PUBKEY: &str = "DFKNhP/yO5i1b9aKL+aHeBaGunz9sMfOF736fzYws4Q=";

        let ticket_pubkey_str = env::var("MEMVID_TICKET_PUBKEY")
            .ok()
            .and_then(|value| {
                let trimmed = value.trim();
                if trimmed.is_empty() {
                    None
                } else {
                    Some(trimmed.to_string())
                }
            })
            .unwrap_or_else(|| DEFAULT_TICKET_PUBKEY.to_string());

        let ticket_pubkey = Some(memvid_core::parse_ed25519_public_key_base64(
            &ticket_pubkey_str,
        )?);

        let offline = env::var("MEMVID_OFFLINE")
            .ok()
            .map(|value| match value.trim().to_ascii_lowercase().as_str() {
                "1" | "true" | "yes" => true,
                _ => false,
            })
            .unwrap_or(false);

        // Load embedding model from env var, default to BGE-small
        let embedding_model = env::var("MEMVID_EMBEDDING_MODEL")
            .ok()
            .and_then(|value| {
                let trimmed = value.trim();
                if trimmed.is_empty() {
                    None
                } else {
                    EmbeddingModelChoice::from_str(trimmed).ok()
                }
            })
            .unwrap_or_default();

        Ok(Self {
            api_key,
            api_url,
            memory_id,
            cache_dir,
            ticket_pubkey,
            models_dir,
            offline,
            embedding_model,
        })
    }

    /// Create a new config with a different embedding model
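    ///
    /// A minimal sketch (illustrative):
    ///
    /// ```ignore
    /// let config = CliConfig::load()?;
    /// let remote = config.with_embedding_model(EmbeddingModelChoice::OpenAILarge);
    /// assert_eq!(remote.embedding_model, EmbeddingModelChoice::OpenAILarge);
    /// ```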
    pub fn with_embedding_model(&self, model: EmbeddingModelChoice) -> Self {
        Self {
            embedding_model: model,
            ..self.clone()
        }
    }
}

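/// Expand a leading `~` to the home directory and anchor relative paths at the
/// current working directory.
///
/// A minimal sketch (illustrative; assumes a Unix home of `/home/user`):
///
/// ```ignore
/// assert_eq!(expand_path("~/x")?, PathBuf::from("/home/user/x"));
/// assert!(expand_path("relative/dir")?.is_absolute()); // joined onto the cwd
/// assert!(expand_path("  ").is_err()); // empty values are rejected
/// ```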
fn expand_path(value: &str) -> Result<PathBuf> {
    if value.trim().is_empty() {
        return Err(anyhow!("cache directory cannot be empty"));
    }

    let expanded = if let Some(stripped) = value.strip_prefix("~/") {
        home_dir()?.join(stripped)
    } else if let Some(stripped) = value.strip_prefix("~\\") {
        // Support Windows-style "~\" prefix.
        home_dir()?.join(stripped)
    } else if value == "~" {
        home_dir()?
    } else {
        PathBuf::from(value)
    };

    if expanded.is_absolute() {
        Ok(expanded)
    } else {
        Ok(env::current_dir()?.join(expanded))
    }
}

fn home_dir() -> Result<PathBuf> {
    if let Some(path) = env::var_os("HOME") {
        if !path.is_empty() {
            return Ok(PathBuf::from(path));
        }
    }

    #[cfg(windows)]
    {
        if let Some(path) = env::var_os("USERPROFILE") {
            if !path.is_empty() {
                return Ok(PathBuf::from(path));
            }
        }
        if let (Some(drive), Some(path)) = (env::var_os("HOMEDRIVE"), env::var_os("HOMEPATH")) {
            if !drive.is_empty() && !path.is_empty() {
                return Ok(PathBuf::from(format!(
                    "{}{}",
                    drive.to_string_lossy(),
                    path.to_string_lossy()
                )));
            }
        }
    }

    Err(anyhow!("unable to resolve home directory"))
}

#[cfg(test)]
mod tests {
    use super::*;
    use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
    use base64::Engine;
    use ed25519_dalek::SigningKey;
    use std::sync::{Mutex, OnceLock};

    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(())).lock().unwrap()
    }

    fn set_or_unset(var: &str, value: Option<String>) {
        match value {
            Some(v) => unsafe { env::set_var(var, v) },
            None => unsafe { env::remove_var(var) },
        }
    }

    #[test]
    fn defaults_expand_using_home_directory() {
        let _guard = env_lock();

        let previous_home = env::var("HOME").ok();
        #[cfg(windows)]
        let previous_userprofile = env::var("USERPROFILE").ok();

        for var in [
            "MEMVID_API_KEY",
            "MEMVID_API_URL",
            "MEMVID_CACHE_DIR",
            "MEMVID_TICKET_PUBKEY",
            "MEMVID_MODELS_DIR",
            "MEMVID_OFFLINE",
        ] {
            unsafe { env::remove_var(var) };
        }

        let tmp = tempfile::tempdir().expect("tmpdir");
        let tmp_path = tmp.path().to_path_buf();
        unsafe { env::set_var("HOME", &tmp_path) };
        #[cfg(windows)]
        unsafe {
            env::set_var("USERPROFILE", &tmp_path)
        };

        let config = CliConfig::load().expect("load");
        assert_eq!(config.api_key, None);
        assert_eq!(config.api_url, "https://memvid.com");
        assert_eq!(config.cache_dir, tmp_path.join(".cache/memvid"));
        // ticket_pubkey has a default value now, so it should be Some
        assert!(config.ticket_pubkey.is_some());
        assert_eq!(config.models_dir, tmp_path.join(".memvid/models"));
        assert!(!config.offline);

        set_or_unset("HOME", previous_home);
        #[cfg(windows)]
        {
            set_or_unset("USERPROFILE", previous_userprofile);
        }
    }

    #[test]
    fn env_overrides_are_respected() {
        let _guard = env_lock();

        let previous_env: Vec<(&'static str, Option<String>)> = [
            "MEMVID_API_KEY",
            "MEMVID_API_URL",
            "MEMVID_CACHE_DIR",
            "MEMVID_TICKET_PUBKEY",
            "MEMVID_MODELS_DIR",
            "MEMVID_OFFLINE",
        ]
        .into_iter()
        .map(|var| (var, env::var(var).ok()))
        .collect();

        unsafe { env::set_var("MEMVID_API_KEY", "abc123") };
        unsafe { env::set_var("MEMVID_API_URL", "https://staging.memvid.app") };
        unsafe { env::set_var("MEMVID_CACHE_DIR", "~/memvid-cache") };
        unsafe { env::set_var("MEMVID_MODELS_DIR", "~/models") };
        unsafe { env::set_var("MEMVID_OFFLINE", "true") };
        let signing = SigningKey::from_bytes(&[9u8; 32]);
        let encoded = BASE64_STANDARD.encode(signing.verifying_key().as_bytes());
        unsafe { env::set_var("MEMVID_TICKET_PUBKEY", encoded) };

        let tmp = tempfile::tempdir().expect("tmpdir");
        let tmp_path = tmp.path().to_path_buf();
        unsafe { env::set_var("HOME", &tmp_path) };
        #[cfg(windows)]
        unsafe {
            env::set_var("USERPROFILE", &tmp_path)
        };

        let config = CliConfig::load().expect("load");
        assert_eq!(config.api_key.as_deref(), Some("abc123"));
        assert_eq!(config.api_url, "https://staging.memvid.app");
        assert_eq!(config.cache_dir, tmp_path.join("memvid-cache"));
        assert_eq!(
            config.ticket_pubkey.expect("pubkey").as_bytes(),
            signing.verifying_key().as_bytes()
        );
        assert_eq!(config.models_dir, tmp_path.join("models"));
        assert!(config.offline);

        for (var, value) in previous_env {
            set_or_unset(var, value);
        }
    }

    #[test]
    fn rejects_empty_cache_dir() {
        let _guard = env_lock();

        let previous = env::var("MEMVID_CACHE_DIR").ok();
        unsafe { env::set_var("MEMVID_CACHE_DIR", " ") };
        let err = CliConfig::load().expect_err("should fail");
        assert!(err.to_string().contains("cache directory"));
        set_or_unset("MEMVID_CACHE_DIR", previous);
    }
}

/// Initialize tracing/logging based on verbosity level
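///
/// A minimal sketch (illustrative; repeated `-v` flags on the CLI map to this count):
///
/// ```ignore
/// init_tracing(2)?; // 0 = warn, 1 = info, 2 = debug, 3+ = trace
/// tracing::debug!("visible at verbosity >= 2, written to stderr");
/// ```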
pub fn init_tracing(verbosity: u8) -> Result<()> {
    use std::io::IsTerminal;
    use tracing_subscriber::{filter::Directive, fmt, EnvFilter};

    let level = match verbosity {
        0 => "warn",
        1 => "info",
        2 => "debug",
        _ => "trace",
    };

    let mut env_filter =
        EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(level));
    for directive_str in ["llama_cpp=error", "llama_cpp_sys=error", "ggml=error"] {
        if let Ok(directive) = directive_str.parse::<Directive>() {
            env_filter = env_filter.add_directive(directive);
        }
    }

    // Disable ANSI color codes when stderr is not a terminal (e.g., piped or
    // combined with `2>&1`). This prevents control characters from polluting
    // JSON output when combined with stdout.
    let use_ansi = std::io::stderr().is_terminal();

    fmt()
        .with_env_filter(env_filter)
        .with_writer(std::io::stderr)
        .with_target(false)
        .without_time()
        .with_ansi(use_ansi)
        .try_init()
        .map_err(|err| anyhow!(err))?;
    Ok(())
}

/// Resolve LLM context budget override from CLI or environment
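///
/// A minimal sketch (illustrative):
///
/// ```ignore
/// // An explicit CLI value wins and must be positive.
/// assert_eq!(resolve_llm_context_budget_override(Some(4096))?, Some(4096));
/// // The env var tolerates `_` separators: "32_000" parses as 32000.
/// unsafe { std::env::set_var("MEMVID_LLM_CONTEXT_BUDGET", "32_000") };
/// assert_eq!(resolve_llm_context_budget_override(None)?, Some(32_000));
/// ```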
pub fn resolve_llm_context_budget_override(cli_value: Option<usize>) -> Result<Option<usize>> {
    use anyhow::bail;

    if let Some(value) = cli_value {
        if value == 0 {
            bail!("--llm-context-depth must be a positive integer");
        }
        return Ok(Some(value));
    }

    let raw_env = match env::var("MEMVID_LLM_CONTEXT_BUDGET") {
        Ok(value) => value,
        Err(_) => return Ok(None),
    };

    let trimmed = raw_env.trim();
    if trimmed.is_empty() {
        return Ok(None);
    }

    let digits: String = trimmed
        .chars()
        .filter(|ch| !ch.is_ascii_whitespace() && *ch != '_')
        .collect();

    if digits.is_empty() {
        bail!("MEMVID_LLM_CONTEXT_BUDGET must be a positive integer value");
    }

    let value: usize = digits.parse().map_err(|err| {
        anyhow!(
            "MEMVID_LLM_CONTEXT_BUDGET value '{}' is not a valid number: {}",
            trimmed,
            err
        )
    })?;

    if value == 0 {
        bail!("MEMVID_LLM_CONTEXT_BUDGET must be a positive integer");
    }

    Ok(Some(value))
}

use crate::gemini_embeddings::GeminiEmbeddingProvider;
use crate::mistral_embeddings::MistralEmbeddingProvider;
use crate::nvidia_embeddings::NvidiaEmbeddingProvider;
use crate::openai_embeddings::OpenAIEmbeddingProvider;

/// Internal embedding backend - local fastembed or remote providers.
#[derive(Clone)]
enum EmbeddingBackend {
    #[cfg(feature = "local-embeddings")]
    FastEmbed(std::sync::Arc<std::sync::Mutex<fastembed::TextEmbedding>>),
    OpenAI(std::sync::Arc<OpenAIEmbeddingProvider>),
    Nvidia(std::sync::Arc<NvidiaEmbeddingProvider>),
    Gemini(std::sync::Arc<GeminiEmbeddingProvider>),
    Mistral(std::sync::Arc<MistralEmbeddingProvider>),
}

/// Embedding runtime wrapper supporting local and remote embeddings
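///
/// A minimal usage sketch (illustrative; `load_embedding_runtime` is defined
/// later in this module):
///
/// ```ignore
/// let runtime = load_embedding_runtime(&config)?;
/// let vector = runtime.embed_passage("hello world")?;
/// assert_eq!(vector.len(), runtime.dimension());
/// ```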
#[derive(Clone)]
pub struct EmbeddingRuntime {
    backend: EmbeddingBackend,
    model: EmbeddingModelChoice,
    dimension: std::sync::Arc<AtomicUsize>,
}

impl EmbeddingRuntime {
    #[cfg(feature = "local-embeddings")]
    fn new_fastembed(
        backend: fastembed::TextEmbedding,
        model: EmbeddingModelChoice,
        dimension: usize,
    ) -> Self {
        Self {
            backend: EmbeddingBackend::FastEmbed(std::sync::Arc::new(std::sync::Mutex::new(
                backend,
            ))),
            model,
            dimension: std::sync::Arc::new(AtomicUsize::new(dimension)),
        }
    }

    fn new_openai(
        provider: OpenAIEmbeddingProvider,
        model: EmbeddingModelChoice,
        dimension: usize,
    ) -> Self {
        Self {
            backend: EmbeddingBackend::OpenAI(std::sync::Arc::new(provider)),
            model,
            dimension: std::sync::Arc::new(AtomicUsize::new(dimension)),
        }
    }

    fn new_nvidia(provider: NvidiaEmbeddingProvider, model: EmbeddingModelChoice) -> Self {
        Self {
            backend: EmbeddingBackend::Nvidia(std::sync::Arc::new(provider)),
            model,
            dimension: std::sync::Arc::new(AtomicUsize::new(0)),
        }
    }

    fn new_gemini(
        provider: GeminiEmbeddingProvider,
        model: EmbeddingModelChoice,
        dimension: usize,
    ) -> Self {
        Self {
            backend: EmbeddingBackend::Gemini(std::sync::Arc::new(provider)),
            model,
            dimension: std::sync::Arc::new(AtomicUsize::new(dimension)),
        }
    }

    fn new_mistral(
        provider: MistralEmbeddingProvider,
        model: EmbeddingModelChoice,
        dimension: usize,
    ) -> Self {
        Self {
            backend: EmbeddingBackend::Mistral(std::sync::Arc::new(provider)),
            model,
            dimension: std::sync::Arc::new(AtomicUsize::new(dimension)),
        }
    }

    const MAX_OPENAI_EMBEDDING_TEXT_LEN: usize = 20_000;
    // NVIDIA Integrate embeddings enforce a 4096 token limit; use a tighter char cap as a guardrail.
    const MAX_NVIDIA_EMBEDDING_TEXT_LEN: usize = 12_000;

    // Gemini enforces an 8192 token limit; a 20k char cap is a conservative estimate.
    const MAX_GEMINI_EMBEDDING_TEXT_LEN: usize = 20_000;
    // Mistral enforces an 8192 token limit; a 20k char cap is a conservative estimate.
    const MAX_MISTRAL_EMBEDDING_TEXT_LEN: usize = 20_000;

    fn max_remote_embedding_chars(&self) -> usize {
        match &self.backend {
            EmbeddingBackend::OpenAI(_) => Self::MAX_OPENAI_EMBEDDING_TEXT_LEN,
            EmbeddingBackend::Nvidia(_) => Self::MAX_NVIDIA_EMBEDDING_TEXT_LEN,
            EmbeddingBackend::Gemini(_) => Self::MAX_GEMINI_EMBEDDING_TEXT_LEN,
            EmbeddingBackend::Mistral(_) => Self::MAX_MISTRAL_EMBEDDING_TEXT_LEN,
            #[cfg(feature = "local-embeddings")]
            EmbeddingBackend::FastEmbed(_) => usize::MAX,
        }
    }

    /// Truncate text for embedding to reduce the risk of provider token-limit errors.
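    ///
    /// A minimal sketch (illustrative):
    ///
    /// ```ignore
    /// // Byte 2 would split the two-byte 'é', so the cut backs up to byte 1.
    /// let out = EmbeddingRuntime::truncate_for_embedding("héllo", 2);
    /// assert_eq!(out.as_ref(), "h");
    /// ```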
    fn truncate_for_embedding<'a>(text: &'a str, max_chars: usize) -> std::borrow::Cow<'a, str> {
        if text.len() <= max_chars {
            std::borrow::Cow::Borrowed(text)
        } else {
            // Walk back to the nearest UTF-8 char boundary at or below the limit;
            // slicing at an arbitrary byte index would panic on multi-byte chars.
            let mut end = max_chars;
            while end > 0 && !text.is_char_boundary(end) {
                end -= 1;
            }
            tracing::info!(
                "Truncated embedding text from {} to {} bytes",
                text.len(),
                end
            );
            std::borrow::Cow::Owned(text[..end].to_string())
        }
    }

    fn note_dimension(&self, observed: usize) -> Result<()> {
        if observed == 0 {
            return Err(anyhow!("embedding provider returned zero-length embedding"));
        }

        let current = self.dimension.load(Ordering::Relaxed);
        if current == 0 {
            self.dimension.store(observed, Ordering::Relaxed);
            return Ok(());
        }

        if current != observed {
            return Err(anyhow!(
                "embedding provider returned {observed}D vectors but runtime expects {current}D"
            ));
        }

        Ok(())
    }

    fn truncate_if_remote<'a>(&self, text: &'a str) -> std::borrow::Cow<'a, str> {
        match &self.backend {
            EmbeddingBackend::OpenAI(_)
            | EmbeddingBackend::Nvidia(_)
            | EmbeddingBackend::Gemini(_)
            | EmbeddingBackend::Mistral(_) => {
                Self::truncate_for_embedding(text, self.max_remote_embedding_chars())
            }
            #[cfg(feature = "local-embeddings")]
            EmbeddingBackend::FastEmbed(_) => std::borrow::Cow::Borrowed(text),
        }
    }

    pub fn embed_passage(&self, text: &str) -> Result<Vec<f32>> {
        let text = self.truncate_if_remote(text);
        let embedding = match &self.backend {
            #[cfg(feature = "local-embeddings")]
            EmbeddingBackend::FastEmbed(model) => {
                let mut guard = model
                    .lock()
                    .map_err(|_| anyhow!("fastembed runtime poisoned"))?;
                let outputs = guard
                    .embed(vec![text.into_owned()], None)
                    .map_err(|err| anyhow!("failed to compute embedding with fastembed: {err}"))?;
                outputs
                    .into_iter()
                    .next()
                    .ok_or_else(|| anyhow!("fastembed returned no embedding output"))?
            }
            EmbeddingBackend::OpenAI(provider) => {
                use memvid_core::EmbeddingProvider;
                provider
                    .embed_text(&text)
                    .map_err(|err| anyhow!("failed to compute embedding with OpenAI: {err}"))?
            }
            EmbeddingBackend::Nvidia(provider) => provider
                .embed_passage(&text)
                .map_err(|err| anyhow!("failed to compute embedding with NVIDIA: {err}"))?,
            EmbeddingBackend::Gemini(provider) => provider
                .embed_text(&text)
                .map_err(|err| anyhow!("failed to compute embedding with Gemini: {err}"))?,
            EmbeddingBackend::Mistral(provider) => provider
                .embed_text(&text)
                .map_err(|err| anyhow!("failed to compute embedding with Mistral: {err}"))?,
        };

        self.note_dimension(embedding.len())?;
        Ok(embedding)
    }

    pub fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
        let text = self.truncate_if_remote(text);
        match &self.backend {
            EmbeddingBackend::Nvidia(provider) => {
                let embedding = provider
                    .embed_query(&text)
                    .map_err(|err| anyhow!("failed to compute embedding with NVIDIA: {err}"))?;
                self.note_dimension(embedding.len())?;
                Ok(embedding)
            }
            _ => self.embed_passage(&text),
        }
    }

    pub fn embed_batch_passages(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let truncated: Vec<std::borrow::Cow<'_, str>> =
            texts.iter().map(|t| self.truncate_if_remote(t)).collect();
        let truncated_refs: Vec<&str> = truncated.iter().map(|c| c.as_ref()).collect();

        let embeddings = match &self.backend {
            #[cfg(feature = "local-embeddings")]
            EmbeddingBackend::FastEmbed(model) => {
                let mut guard = model
                    .lock()
                    .map_err(|_| anyhow!("fastembed runtime poisoned"))?;
                guard
                    .embed(
                        truncated_refs
                            .iter()
                            .map(|s| (*s).to_string())
                            .collect::<Vec<String>>(),
                        None,
                    )
                    .map_err(|err| anyhow!("failed to compute embeddings with fastembed: {err}"))?
            }
            EmbeddingBackend::OpenAI(provider) => {
                use memvid_core::EmbeddingProvider;
                provider
                    .embed_batch(&truncated_refs)
                    .map_err(|err| anyhow!("failed to compute embeddings with OpenAI: {err}"))?
            }
            EmbeddingBackend::Nvidia(provider) => provider
                .embed_passages(&truncated_refs)
                .map_err(|err| anyhow!("failed to compute embeddings with NVIDIA: {err}"))?,
            EmbeddingBackend::Gemini(provider) => provider
                .embed_batch(&truncated_refs)
                .map_err(|err| anyhow!("failed to compute embeddings with Gemini: {err}"))?,
            EmbeddingBackend::Mistral(provider) => provider
                .embed_batch(&truncated_refs)
                .map_err(|err| anyhow!("failed to compute embeddings with Mistral: {err}"))?,
        };

        if let Some(first) = embeddings.first() {
            self.note_dimension(first.len())?;
        }
        if let Some(expected) = embeddings.first().map(|e| e.len()) {
            if embeddings.iter().any(|e| e.len() != expected) {
                return Err(anyhow!(
                    "embedding provider returned mixed vector dimensions"
                ));
            }
        }

        Ok(embeddings)
    }

    pub fn embed_batch_queries(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let truncated: Vec<std::borrow::Cow<'_, str>> =
            texts.iter().map(|t| self.truncate_if_remote(t)).collect();
        let truncated_refs: Vec<&str> = truncated.iter().map(|c| c.as_ref()).collect();

        match &self.backend {
            EmbeddingBackend::Nvidia(provider) => {
                let embeddings = provider
                    .embed_queries(&truncated_refs)
                    .map_err(|err| anyhow!("failed to compute embeddings with NVIDIA: {err}"))?;

                if let Some(first) = embeddings.first() {
                    self.note_dimension(first.len())?;
                }
                if let Some(expected) = embeddings.first().map(|e| e.len()) {
                    if embeddings.iter().any(|e| e.len() != expected) {
                        return Err(anyhow!(
                            "embedding provider returned mixed vector dimensions"
                        ));
                    }
                }

                Ok(embeddings)
            }
            _ => self.embed_batch_passages(&truncated_refs),
        }
    }

    pub fn dimension(&self) -> usize {
        self.dimension.load(Ordering::Relaxed)
    }

    pub fn model_choice(&self) -> EmbeddingModelChoice {
        self.model
    }

    pub fn provider_kind(&self) -> &'static str {
        match &self.backend {
            #[cfg(feature = "local-embeddings")]
            EmbeddingBackend::FastEmbed(_) => "fastembed",
            EmbeddingBackend::OpenAI(_) => "openai",
            EmbeddingBackend::Nvidia(_) => "nvidia",
            EmbeddingBackend::Gemini(_) => "gemini",
            EmbeddingBackend::Mistral(_) => "mistral",
        }
    }

    pub fn provider_model_id(&self) -> String {
        match &self.backend {
            #[cfg(feature = "local-embeddings")]
            EmbeddingBackend::FastEmbed(_) => self.model.canonical_model_id().to_string(),
            EmbeddingBackend::OpenAI(provider) => {
                use memvid_core::EmbeddingProvider;
                provider.model().to_string()
            }
            EmbeddingBackend::Nvidia(provider) => provider.model().to_string(),
            EmbeddingBackend::Gemini(provider) => provider.model().to_string(),
            EmbeddingBackend::Mistral(provider) => provider.model().to_string(),
        }
    }
}

impl memvid_core::VecEmbedder for EmbeddingRuntime {
    fn embed_query(&self, text: &str) -> memvid_core::Result<Vec<f32>> {
        EmbeddingRuntime::embed_query(self, text).map_err(|err| {
            memvid_core::MemvidError::EmbeddingFailed {
                reason: err.to_string().into_boxed_str(),
            }
        })
    }

    fn embedding_dimension(&self) -> usize {
        self.dimension()
    }
}

/// Ensure fastembed cache directory exists
#[cfg(feature = "local-embeddings")]
fn ensure_fastembed_cache(config: &CliConfig) -> Result<PathBuf> {
    use std::fs;

    let cache_dir = config.models_dir.clone();
    fs::create_dir_all(&cache_dir)?;
    Ok(cache_dir)
}

/// Get approximate model size in MB for user-friendly error messages
fn model_size_mb(model: EmbeddingModelChoice) -> usize {
    match model {
        EmbeddingModelChoice::BgeSmall => 33,
        EmbeddingModelChoice::BgeBase => 110,
        EmbeddingModelChoice::Nomic => 137,
        EmbeddingModelChoice::GteLarge => 327,
        // Remote/cloud models don't require local download
        EmbeddingModelChoice::OpenAILarge
        | EmbeddingModelChoice::OpenAISmall
        | EmbeddingModelChoice::OpenAIAda
        | EmbeddingModelChoice::Nvidia
        | EmbeddingModelChoice::Gemini
        | EmbeddingModelChoice::Mistral => 0,
    }
}

/// Instantiate an embedding runtime with the configured model
fn instantiate_embedding_runtime(config: &CliConfig) -> Result<EmbeddingRuntime> {
    use tracing::info;

    let embedding_model = config.embedding_model;

    if embedding_model.dimensions() > 0 {
        info!(
            "Loading embedding model: {} ({}D)",
            embedding_model.name(),
            embedding_model.dimensions()
        );
    } else {
        info!("Loading embedding model: {}", embedding_model.name());
    }

    if config.offline && embedding_model.is_remote() {
        anyhow::bail!(
            "remote embeddings are unavailable while offline; set MEMVID_OFFLINE=0 or use a local embedding model"
        );
    }

    // Check if OpenAI model
    if embedding_model.is_openai() {
        return instantiate_openai_runtime(embedding_model);
    }

    if embedding_model == EmbeddingModelChoice::Nvidia {
        return instantiate_nvidia_runtime(None);
    }

    if embedding_model == EmbeddingModelChoice::Gemini {
        return instantiate_gemini_runtime();
    }

    if embedding_model == EmbeddingModelChoice::Mistral {
        return instantiate_mistral_runtime();
    }

    // Local fastembed model
    #[cfg(feature = "local-embeddings")]
    {
        return instantiate_fastembed_runtime(config, embedding_model);
    }

    #[cfg(not(feature = "local-embeddings"))]
    {
        anyhow::bail!(
            "Local embeddings are not available on this platform. \
            Please use a remote embedding provider:\n\
            - Set OPENAI_API_KEY and use --embedding-model openai-large\n\
            - Set GEMINI_API_KEY and use --embedding-model gemini\n\
            - Set MISTRAL_API_KEY and use --embedding-model mistral\n\
            - Set NVIDIA_API_KEY and use --embedding-model nvidia"
        );
    }
}

/// Instantiate OpenAI embedding runtime
fn instantiate_openai_runtime(embedding_model: EmbeddingModelChoice) -> Result<EmbeddingRuntime> {
    use anyhow::bail;
    use memvid_core::EmbeddingConfig;
    use tracing::info;

    let api_key = std::env::var("OPENAI_API_KEY").map_err(|_| {
        anyhow!("OPENAI_API_KEY environment variable is required for OpenAI embeddings")
    })?;

    if api_key.is_empty() {
        bail!("OPENAI_API_KEY cannot be empty");
    }

    let config = match embedding_model {
        EmbeddingModelChoice::OpenAILarge => EmbeddingConfig::openai_large(),
        EmbeddingModelChoice::OpenAISmall => EmbeddingConfig::openai_small(),
        EmbeddingModelChoice::OpenAIAda => EmbeddingConfig::openai_ada(),
        _ => unreachable!("instantiate_openai_runtime called with a non-OpenAI model"),
    };

    let provider = OpenAIEmbeddingProvider::new(api_key, config.clone())
        .map_err(|err| anyhow!("failed to create OpenAI embedding provider: {err}"))?;

    info!(
        "OpenAI embedding provider ready: model={}, dimension={}",
        config.model, config.dimension
    );

    Ok(EmbeddingRuntime::new_openai(
        provider,
        embedding_model,
        config.dimension,
    ))
}

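/// Normalize a user-supplied NVIDIA model override into a fully-qualified ID,
/// returning `None` when the default model should be used.
///
/// A minimal sketch of the mapping (illustrative, based on the logic below):
///
/// ```ignore
/// assert_eq!(normalize_nvidia_embedding_model_override("nvidia"), None);
/// assert_eq!(
///     normalize_nvidia_embedding_model_override("nv:nv-embed-v1").as_deref(),
///     Some("nvidia/nv-embed-v1")
/// );
/// assert_eq!(
///     normalize_nvidia_embedding_model_override("custom-model").as_deref(),
///     Some("nvidia/custom-model") // bare names get the nvidia/ namespace
/// );
/// ```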
fn normalize_nvidia_embedding_model_override(raw: &str) -> Option<String> {
    let trimmed = raw.trim();
    if trimmed.is_empty() {
        return None;
    }

    let lowered = trimmed.to_ascii_lowercase();
    if lowered == "nvidia" || lowered == "nv" {
        return None;
    }

    let without_prefix = trimmed
        .strip_prefix("nvidia:")
        .or_else(|| trimmed.strip_prefix("nv:"))
        .unwrap_or(trimmed)
        .trim();

    if without_prefix.is_empty() {
        return None;
    }

    if without_prefix.eq_ignore_ascii_case("nv-embed-v1") {
        return Some("nvidia/nv-embed-v1".to_string());
    }

    if without_prefix.contains('/') {
        return Some(without_prefix.to_string());
    }

    Some(format!("nvidia/{without_prefix}"))
}

/// Instantiate NVIDIA embedding runtime
fn instantiate_nvidia_runtime(model_override: Option<&str>) -> Result<EmbeddingRuntime> {
    use tracing::info;

    let normalized = model_override.and_then(normalize_nvidia_embedding_model_override);
    let provider = NvidiaEmbeddingProvider::from_env(normalized.as_deref())
        .map_err(|err| anyhow!("failed to create NVIDIA embedding provider: {err}"))?;

    info!(
        "NVIDIA embedding provider ready: model={}",
        provider.model()
    );

    Ok(EmbeddingRuntime::new_nvidia(
        provider,
        EmbeddingModelChoice::Nvidia,
    ))
}

/// Instantiate Gemini embedding runtime
fn instantiate_gemini_runtime() -> Result<EmbeddingRuntime> {
    use tracing::info;

    let provider = GeminiEmbeddingProvider::from_env()
        .map_err(|err| anyhow!("failed to create Gemini embedding provider: {err}"))?;

    let dimension = provider.dimension();
    info!(
        "Gemini embedding provider ready: model={}, dimension={}",
        provider.model(),
        dimension
    );

    Ok(EmbeddingRuntime::new_gemini(
        provider,
        EmbeddingModelChoice::Gemini,
        dimension,
    ))
}

/// Instantiate Mistral embedding runtime
fn instantiate_mistral_runtime() -> Result<EmbeddingRuntime> {
    use tracing::info;

    let provider = MistralEmbeddingProvider::from_env()
        .map_err(|err| anyhow!("failed to create Mistral embedding provider: {err}"))?;

    let dimension = provider.dimension();
    info!(
        "Mistral embedding provider ready: model={}, dimension={}",
        provider.model(),
        dimension
    );

    Ok(EmbeddingRuntime::new_mistral(
        provider,
        EmbeddingModelChoice::Mistral,
        dimension,
    ))
}

/// Instantiate fastembed (local) embedding runtime
#[cfg(feature = "local-embeddings")]
fn instantiate_fastembed_runtime(
    config: &CliConfig,
    embedding_model: EmbeddingModelChoice,
) -> Result<EmbeddingRuntime> {
    use anyhow::bail;
    use fastembed::{InitOptions, TextEmbedding};
    use std::fs;

    let cache_dir = ensure_fastembed_cache(config)?;

    if config.offline {
        let mut entries = fs::read_dir(&cache_dir)?;
        if entries.next().is_none() {
            bail!(
                "semantic embeddings unavailable while offline; allow one connected run so fastembed can cache model weights"
            );
        }
    }

    let options = InitOptions::new(embedding_model.to_fastembed_model())
        .with_cache_dir(cache_dir)
        .with_show_download_progress(true);
    let mut model = TextEmbedding::try_new(options).map_err(|err| {
        // Provide platform-specific guidance for model download issues
        let platform_hint = if cfg!(target_os = "windows") {
            "\n\nWindows users: If model downloads fail, try:\n\
            1. Run as Administrator\n\
            2. Check your antivirus isn't blocking downloads\n\
            3. Use OpenAI embeddings instead: set OPENAI_API_KEY and use --embedding-model openai"
        } else if cfg!(target_os = "linux") {
            "\n\nLinux users: If model downloads fail, try:\n\
            1. Check disk space in ~/.memvid/models\n\
            2. Ensure you have network access to huggingface.co\n\
            3. Use OpenAI embeddings instead: export OPENAI_API_KEY=... and use --embedding-model openai"
        } else {
            "\n\nIf model downloads fail, try using OpenAI embeddings:\n\
            export OPENAI_API_KEY=your-key && memvid ... --embedding-model openai"
        };

        anyhow!(
            "Failed to initialize embedding model '{}': {err}\n\n\
            This typically means the model couldn't be downloaded or loaded.\n\
            Model size: ~{} MB{}\n\n\
            See: https://docs.memvid.com/embedding-models",
            embedding_model.name(),
            model_size_mb(embedding_model),
            platform_hint
        )
    })?;

    let probe = model
        .embed(vec!["memvid probe".to_string()], None)
        .map_err(|err| anyhow!("failed to determine embedding dimension: {err}"))?;
    let dimension = probe.first().map(|vec| vec.len()).unwrap_or(0);

    if dimension == 0 {
        bail!("fastembed reported zero-length embeddings");
    }

    // Verify dimension matches expected
    if dimension != embedding_model.dimensions() {
        tracing::warn!(
            "Embedding dimension mismatch: expected {}, got {}",
            embedding_model.dimensions(),
            dimension
        );
    }

    Ok(EmbeddingRuntime::new_fastembed(
        model,
        embedding_model,
        dimension,
    ))
}

/// Load embedding runtime (fails if unavailable)
pub fn load_embedding_runtime(config: &CliConfig) -> Result<EmbeddingRuntime> {
    use anyhow::bail;

    match instantiate_embedding_runtime(config) {
        Ok(runtime) => Ok(runtime),
        Err(err) => {
            if config.offline {
                bail!(
                    "semantic embeddings unavailable while offline; allow one connected run so fastembed can cache model weights ({err})"
                );
            }
            Err(err)
        }
    }
}

/// Try to load embedding runtime (returns None if unavailable)
pub fn try_load_embedding_runtime(config: &CliConfig) -> Option<EmbeddingRuntime> {
    use tracing::warn;

    match instantiate_embedding_runtime(config) {
        Ok(runtime) => Some(runtime),
        Err(err) => {
            warn!("semantic embeddings unavailable: {err}");
            None
        }
    }
}

/// Load embedding runtime with an optional model override.
/// If `model_override` is provided, it will be used instead of the config's embedding_model.
pub fn load_embedding_runtime_with_model(
    config: &CliConfig,
    model_override: Option<&str>,
) -> Result<EmbeddingRuntime> {
    use tracing::info;

    let mut raw_override: Option<&str> = None;
    let embedding_model = match model_override {
        Some(model_str) => {
            raw_override = Some(model_str);
            let parsed = model_str.parse::<EmbeddingModelChoice>()?;
            if parsed.dimensions() > 0 {
                info!(
                    "Using embedding model override: {} ({}D)",
                    parsed.name(),
                    parsed.dimensions()
                );
            } else {
                info!("Using embedding model override: {}", parsed.name());
            }
            parsed
        }
        None => config.embedding_model,
    };

    if embedding_model.dimensions() > 0 {
        info!(
            "Loading embedding model: {} ({}D)",
            embedding_model.name(),
            embedding_model.dimensions()
        );
    } else {
        info!("Loading embedding model: {}", embedding_model.name());
    }

    if config.offline && embedding_model.is_remote() {
        anyhow::bail!(
            "remote embeddings are unavailable while offline; set MEMVID_OFFLINE=0 or use a local embedding model"
        );
    }

    if embedding_model.is_openai() {
        return instantiate_openai_runtime(embedding_model);
    }

    if embedding_model == EmbeddingModelChoice::Nvidia {
        return instantiate_nvidia_runtime(raw_override);
    }

    if embedding_model == EmbeddingModelChoice::Gemini {
        return instantiate_gemini_runtime();
    }

    if embedding_model == EmbeddingModelChoice::Mistral {
        return instantiate_mistral_runtime();
    }

    #[cfg(feature = "local-embeddings")]
    {
        return instantiate_fastembed_runtime(config, embedding_model);
    }

    #[cfg(not(feature = "local-embeddings"))]
    {
        anyhow::bail!(
            "Local embeddings are not available on this platform. \
            Please use a remote embedding provider."
        );
    }
}

/// Try to load embedding runtime with model override (returns None if unavailable)
pub fn try_load_embedding_runtime_with_model(
    config: &CliConfig,
    model_override: Option<&str>,
) -> Option<EmbeddingRuntime> {
    use tracing::warn;

    match load_embedding_runtime_with_model(config, model_override) {
        Ok(runtime) => Some(runtime),
        Err(err) => {
            warn!("semantic embeddings unavailable: {err}");
            None
        }
    }
}

/// Load embedding runtime by auto-detecting from MV2 vector dimension.
///
/// Priority:
/// 1. Explicit model override (--query-embedding-model flag)
/// 2. Auto-detect from MV2 file's stored dimension
/// 3. Fall back to config default
///
/// This allows users to omit --query-embedding-model when querying files
/// created with non-default embedding models (like OpenAI).
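///
/// A minimal sketch (illustrative):
///
/// ```ignore
/// // 1536D vectors imply an OpenAI model, so OPENAI_API_KEY must be set;
/// // with no stored vectors (None) the config default is used instead.
/// let runtime = load_embedding_runtime_for_mv2(&config, None, Some(1536))?;
/// ```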
pub fn load_embedding_runtime_for_mv2(
    config: &CliConfig,
    model_override: Option<&str>,
    mv2_dimension: Option<u32>,
) -> Result<EmbeddingRuntime> {
    use tracing::info;

    // Priority 1: Explicit override
    if let Some(model_str) = model_override {
        return load_embedding_runtime_with_model(config, Some(model_str));
    }

    // Priority 2: Auto-detect from MV2 dimension
    if let Some(dim) = mv2_dimension {
        if let Some(detected_model) = EmbeddingModelChoice::from_dimension(dim) {
            info!(
                "Auto-detected embedding model from MV2: {} ({}D)",
                detected_model.name(),
                dim
            );

            // For OpenAI models, check if API key is available
            if detected_model.is_openai() {
                if std::env::var("OPENAI_API_KEY").is_ok() {
                    return load_embedding_runtime_with_model(config, Some(detected_model.name()));
                } else {
                    // OpenAI detected but no API key - provide helpful error
                    return Err(anyhow!(
                        "MV2 file uses OpenAI embeddings ({}D) but OPENAI_API_KEY is not set.\n\n\
                        Options:\n\
                        1. Set OPENAI_API_KEY environment variable\n\
                        2. Use --query-embedding-model to specify a different model\n\
                        3. Use lexical-only search with --mode lex\n\n\
                        See: https://docs.memvid.com/embedding-models",
                        dim
                    ));
                }
            }

            return load_embedding_runtime_with_model(config, Some(detected_model.name()));
        }
    }

    // Priority 3: Fall back to config default
    load_embedding_runtime(config)
}

/// Try to load embedding runtime for MV2 with auto-detection (returns None if unavailable)
pub fn try_load_embedding_runtime_for_mv2(
    config: &CliConfig,
    model_override: Option<&str>,
    mv2_dimension: Option<u32>,
) -> Option<EmbeddingRuntime> {
    use tracing::warn;

    match load_embedding_runtime_for_mv2(config, model_override, mv2_dimension) {
        Ok(runtime) => Some(runtime),
        Err(err) => {
            warn!("semantic embeddings unavailable: {err}");
            None
        }
    }
}