memvid_cli/config.rs

//! CLI configuration and environment handling
//!
//! This module provides configuration loading from environment variables,
//! tracing initialization, and embedding runtime management for semantic search.

use std::env;
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::atomic::{AtomicUsize, Ordering};

use anyhow::{anyhow, Result};
use ed25519_dalek::VerifyingKey;

const DEFAULT_API_URL: &str = "https://kgpfm35ddc.us-east-2.awsapprunner.com";
const DEFAULT_CACHE_DIR: &str = "~/.cache/memvid";

/// Supported embedding models for semantic search
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum EmbeddingModelChoice {
    /// BGE-small-en-v1.5: Fast, 384-dim, ~78% accuracy (default)
    #[default]
    BgeSmall,
    /// BGE-base-en-v1.5: Balanced, 768-dim, ~85% accuracy
    BgeBase,
    /// Nomic-embed-text-v1.5: High quality, 768-dim, ~86% accuracy
    Nomic,
    /// GTE-large-en-v1.5: Best semantic depth, 1024-dim
    GteLarge,
    /// OpenAI text-embedding-3-large: Highest quality, 3072-dim (requires OPENAI_API_KEY)
    OpenAILarge,
    /// OpenAI text-embedding-3-small: Good quality, 1536-dim (requires OPENAI_API_KEY)
    OpenAISmall,
    /// OpenAI text-embedding-ada-002: Legacy model, 1536-dim (requires OPENAI_API_KEY)
    OpenAIAda,
    /// NVIDIA nv-embed-v1: High quality, remote embeddings (requires NVIDIA_API_KEY)
    Nvidia,
}

impl EmbeddingModelChoice {
    /// Check if this is an OpenAI model (requires OPENAI_API_KEY)
    pub fn is_openai(&self) -> bool {
        matches!(
            self,
            EmbeddingModelChoice::OpenAILarge
                | EmbeddingModelChoice::OpenAISmall
                | EmbeddingModelChoice::OpenAIAda
        )
    }

    /// Get the fastembed EmbeddingModel enum value (only for local models)
    ///
    /// # Panics
    /// Panics if called on a remote model (OpenAI or NVIDIA); check `is_openai()`
    /// or the model choice first.
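    ///
    /// A minimal usage sketch (illustrative, not compiled as a doctest):
    ///
    /// ```ignore
    /// let model = EmbeddingModelChoice::BgeSmall;
    /// if !model.is_openai() && model != EmbeddingModelChoice::Nvidia {
    ///     let local = model.to_fastembed_model(); // safe: local model
    /// }
    /// ```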
    pub fn to_fastembed_model(&self) -> fastembed::EmbeddingModel {
        match self {
            EmbeddingModelChoice::BgeSmall => fastembed::EmbeddingModel::BGESmallENV15,
            EmbeddingModelChoice::BgeBase => fastembed::EmbeddingModel::BGEBaseENV15,
            EmbeddingModelChoice::Nomic => fastembed::EmbeddingModel::NomicEmbedTextV15,
            EmbeddingModelChoice::GteLarge => fastembed::EmbeddingModel::GTELargeENV15,
            EmbeddingModelChoice::OpenAILarge
            | EmbeddingModelChoice::OpenAISmall
            | EmbeddingModelChoice::OpenAIAda => {
                panic!("OpenAI models don't use fastembed. Check is_openai() first.")
            }
            EmbeddingModelChoice::Nvidia => {
                panic!("NVIDIA embeddings don't use fastembed. Check the model choice first.")
            }
        }
    }

    /// Get human-readable model name
    pub fn name(&self) -> &'static str {
        match self {
            EmbeddingModelChoice::BgeSmall => "bge-small",
            EmbeddingModelChoice::BgeBase => "bge-base",
            EmbeddingModelChoice::Nomic => "nomic",
            EmbeddingModelChoice::GteLarge => "gte-large",
            EmbeddingModelChoice::OpenAILarge => "openai-large",
            EmbeddingModelChoice::OpenAISmall => "openai-small",
            EmbeddingModelChoice::OpenAIAda => "openai-ada",
            EmbeddingModelChoice::Nvidia => "nvidia",
        }
    }

    /// Get the canonical provider model identifier used for persisted metadata.
    ///
    /// This is intended to match upstream provider IDs (OpenAI) and HuggingFace-style IDs
    /// (fastembed/ONNX) so that memories can record an embedding "identity" that other
    /// runtimes can select deterministically.
    pub fn canonical_model_id(&self) -> &'static str {
        match self {
            EmbeddingModelChoice::BgeSmall => "BAAI/bge-small-en-v1.5",
            EmbeddingModelChoice::BgeBase => "BAAI/bge-base-en-v1.5",
            EmbeddingModelChoice::Nomic => "nomic-embed-text-v1.5",
            EmbeddingModelChoice::GteLarge => "thenlper/gte-large",
            EmbeddingModelChoice::OpenAILarge => "text-embedding-3-large",
            EmbeddingModelChoice::OpenAISmall => "text-embedding-3-small",
            EmbeddingModelChoice::OpenAIAda => "text-embedding-ada-002",
            EmbeddingModelChoice::Nvidia => "nvidia/nv-embed-v1",
        }
    }

    /// Get embedding dimensions
    pub fn dimensions(&self) -> usize {
        match self {
            EmbeddingModelChoice::BgeSmall => 384,
            EmbeddingModelChoice::BgeBase => 768,
            EmbeddingModelChoice::Nomic => 768,
            EmbeddingModelChoice::GteLarge => 1024,
            EmbeddingModelChoice::OpenAILarge => 3072,
            EmbeddingModelChoice::OpenAISmall => 1536,
            EmbeddingModelChoice::OpenAIAda => 1536,
            // Remote model; infer from the first embedding response.
            EmbeddingModelChoice::Nvidia => 0,
        }
    }
}

impl FromStr for EmbeddingModelChoice {
    type Err = anyhow::Error;

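    /// Parse a model alias or provider/HuggingFace-style ID (case-insensitive).
    ///
    /// Accepted spellings, illustrated (added sketch):
    ///
    /// ```ignore
    /// let m: EmbeddingModelChoice = "OpenAI".parse().unwrap();
    /// assert_eq!(m, EmbeddingModelChoice::OpenAILarge); // bare "openai" picks the large model
    /// assert!("text-embedding-3-small".parse::<EmbeddingModelChoice>().is_ok());
    /// ```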
    fn from_str(s: &str) -> Result<Self> {
        let lowered = s.trim().to_ascii_lowercase();
        match lowered.as_str() {
            "bge-small" | "bge_small" | "bgesmall" | "small" => Ok(EmbeddingModelChoice::BgeSmall),
            "baai/bge-small-en-v1.5" => Ok(EmbeddingModelChoice::BgeSmall),
            "bge-base" | "bge_base" | "bgebase" | "base" => Ok(EmbeddingModelChoice::BgeBase),
            "baai/bge-base-en-v1.5" => Ok(EmbeddingModelChoice::BgeBase),
            "nomic" | "nomic-embed" | "nomic_embed" => Ok(EmbeddingModelChoice::Nomic),
            "nomic-embed-text-v1.5" => Ok(EmbeddingModelChoice::Nomic),
            "gte-large" | "gte_large" | "gtelarge" | "gte" => Ok(EmbeddingModelChoice::GteLarge),
            "thenlper/gte-large" => Ok(EmbeddingModelChoice::GteLarge),
            // OpenAI models - default "openai" maps to "openai-large" for highest quality
            "openai" | "openai-large" | "openai_large" | "text-embedding-3-large" => {
                Ok(EmbeddingModelChoice::OpenAILarge)
            }
            "openai-small" | "openai_small" | "text-embedding-3-small" => {
                Ok(EmbeddingModelChoice::OpenAISmall)
            }
            "openai-ada" | "openai_ada" | "text-embedding-ada-002" | "ada" => {
                Ok(EmbeddingModelChoice::OpenAIAda)
            }
            "nvidia" | "nv" | "nv-embed-v1" | "nvidia/nv-embed-v1" => Ok(EmbeddingModelChoice::Nvidia),
            _ if lowered.starts_with("nvidia/") || lowered.starts_with("nvidia:") || lowered.starts_with("nv:") => {
                Ok(EmbeddingModelChoice::Nvidia)
            }
            _ => Err(anyhow!(
                "unknown embedding model '{}'. Valid options: bge-small, bge-base, nomic, gte-large, openai, openai-small, openai-ada, nvidia",
                s
            )),
        }
    }
}

impl std::fmt::Display for EmbeddingModelChoice {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.name())
    }
}

impl EmbeddingModelChoice {
    /// Infer the most likely embedding model from the vector dimension stored in an MV2 file.
    ///
    /// This enables auto-detection: users don't need to specify --query-embedding-model
    /// if the MV2 file has vectors. The dimension uniquely identifies the model family.
    ///
    /// # Dimension Mapping
    /// - 384  → BGE-small (default local model)
    /// - 768  → BGE-base (could also be Nomic, but same dimension works)
    /// - 1024 → GTE-large
    /// - 1536 → OpenAI small/ada
    /// - 3072 → OpenAI large
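    ///
    /// Illustrative results (added example, values from the table above):
    ///
    /// ```ignore
    /// assert_eq!(EmbeddingModelChoice::from_dimension(384), Some(EmbeddingModelChoice::BgeSmall));
    /// assert_eq!(EmbeddingModelChoice::from_dimension(3072), Some(EmbeddingModelChoice::OpenAILarge));
    /// assert_eq!(EmbeddingModelChoice::from_dimension(0), None); // file stores no vectors
    /// ```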
    pub fn from_dimension(dim: u32) -> Option<Self> {
        match dim {
            384 => Some(EmbeddingModelChoice::BgeSmall),
            768 => Some(EmbeddingModelChoice::BgeBase), // Could be Nomic, but same dim
            1024 => Some(EmbeddingModelChoice::GteLarge),
            1536 => Some(EmbeddingModelChoice::OpenAISmall), // Could be Ada, same dim
            3072 => Some(EmbeddingModelChoice::OpenAILarge),
            0 => None, // No vectors in file
            _ => {
                tracing::warn!("Unknown embedding dimension {}, using default model", dim);
                None
            }
        }
    }
}

/// CLI configuration loaded from environment variables
#[derive(Debug, Clone)]
pub struct CliConfig {
    pub api_key: Option<String>,
    pub api_url: String,
    pub cache_dir: PathBuf,
    pub ticket_pubkey: Option<VerifyingKey>,
    pub models_dir: PathBuf,
    pub offline: bool,
    /// Embedding model for semantic search (can be overridden by CLI flag)
    pub embedding_model: EmbeddingModelChoice,
}

impl PartialEq for CliConfig {
    fn eq(&self, other: &Self) -> bool {
        self.api_key == other.api_key
            && self.api_url == other.api_url
            && self.cache_dir == other.cache_dir
            && self.models_dir == other.models_dir
            && self.offline == other.offline
            && self.embedding_model == other.embedding_model
    }
}

impl Eq for CliConfig {}

impl CliConfig {
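    /// Load configuration from the environment.
    ///
    /// Recognized variables (all optional, see the body below): `MEMVID_API_KEY`,
    /// `MEMVID_API_URL`, `MEMVID_CACHE_DIR`, `MEMVID_MODELS_DIR`,
    /// `MEMVID_TICKET_PUBKEY`, `MEMVID_OFFLINE`, and `MEMVID_EMBEDDING_MODEL`.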
    pub fn load() -> Result<Self> {
        let api_key = env::var("MEMVID_API_KEY").ok().and_then(|value| {
            let trimmed = value.trim().to_string();
            (!trimmed.is_empty()).then_some(trimmed)
        });

        let api_url = env::var("MEMVID_API_URL").unwrap_or_else(|_| DEFAULT_API_URL.to_string());

        let cache_dir_raw =
            env::var("MEMVID_CACHE_DIR").unwrap_or_else(|_| DEFAULT_CACHE_DIR.to_string());
        let cache_dir = expand_path(&cache_dir_raw)?;

        let models_dir_raw =
            env::var("MEMVID_MODELS_DIR").unwrap_or_else(|_| "~/.memvid/models".to_string());
        let models_dir = expand_path(&models_dir_raw)?;

        let ticket_pubkey = env::var("MEMVID_TICKET_PUBKEY")
            .ok()
            .and_then(|value| {
                let trimmed = value.trim();
                if trimmed.is_empty() {
                    None
                } else {
                    Some(memvid_core::parse_ed25519_public_key_base64(trimmed))
                }
            })
            .transpose()?;

        let offline = env::var("MEMVID_OFFLINE")
            .ok()
            .map(|value| matches!(value.trim().to_ascii_lowercase().as_str(), "1" | "true" | "yes"))
            .unwrap_or(false);

        // Load embedding model from env var, default to BGE-small
        let embedding_model = env::var("MEMVID_EMBEDDING_MODEL")
            .ok()
            .and_then(|value| {
                let trimmed = value.trim();
                if trimmed.is_empty() {
                    None
                } else {
                    EmbeddingModelChoice::from_str(trimmed).ok()
                }
            })
            .unwrap_or_default();

        Ok(Self {
            api_key,
            api_url,
            cache_dir,
            ticket_pubkey,
            models_dir,
            offline,
            embedding_model,
        })
    }

    /// Create a new config with a different embedding model
    pub fn with_embedding_model(&self, model: EmbeddingModelChoice) -> Self {
        Self {
            embedding_model: model,
            ..self.clone()
        }
    }
}

fn expand_path(value: &str) -> Result<PathBuf> {
    if value.trim().is_empty() {
        return Err(anyhow!("cache directory cannot be empty"));
    }

    let expanded = if let Some(stripped) = value.strip_prefix("~/") {
        home_dir()?.join(stripped)
    } else if let Some(stripped) = value.strip_prefix("~\\") {
        // Support Windows-style "~\" prefix.
        home_dir()?.join(stripped)
    } else if value == "~" {
        home_dir()?
    } else {
        PathBuf::from(value)
    };

    if expanded.is_absolute() {
        Ok(expanded)
    } else {
        Ok(env::current_dir()?.join(expanded))
    }
}

fn home_dir() -> Result<PathBuf> {
    if let Some(path) = env::var_os("HOME") {
        if !path.is_empty() {
            return Ok(PathBuf::from(path));
        }
    }

    #[cfg(windows)]
    {
        if let Some(path) = env::var_os("USERPROFILE") {
            if !path.is_empty() {
                return Ok(PathBuf::from(path));
            }
        }
        if let (Some(drive), Some(path)) = (env::var_os("HOMEDRIVE"), env::var_os("HOMEPATH")) {
            if !drive.is_empty() && !path.is_empty() {
                return Ok(PathBuf::from(format!(
                    "{}{}",
                    drive.to_string_lossy(),
                    path.to_string_lossy()
                )));
            }
        }
    }

    Err(anyhow!("unable to resolve home directory"))
}

#[cfg(test)]
mod tests {
    use super::*;
    use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
    use base64::Engine;
    use ed25519_dalek::SigningKey;
    use std::sync::{Mutex, OnceLock};

    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(())).lock().unwrap()
    }

    fn set_or_unset(var: &str, value: Option<String>) {
        match value {
            Some(v) => unsafe { env::set_var(var, v) },
            None => unsafe { env::remove_var(var) },
        }
    }

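    // Added sanity test (sketch): expand_path resolves "~" against HOME and
    // rejects blank input, per the implementation above.
    #[test]
    fn expand_path_handles_tilde_and_blank_input() {
        let _guard = env_lock();
        let previous_home = env::var("HOME").ok();

        let tmp = tempfile::tempdir().expect("tmpdir");
        unsafe { env::set_var("HOME", tmp.path()) };

        assert_eq!(expand_path("~").expect("home"), tmp.path());
        assert_eq!(expand_path("~/sub").expect("sub"), tmp.path().join("sub"));
        assert!(expand_path("   ").is_err());

        set_or_unset("HOME", previous_home);
    }
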
    #[test]
    fn defaults_expand_using_home_directory() {
        let _guard = env_lock();

        let previous_home = env::var("HOME").ok();
        #[cfg(windows)]
        let previous_userprofile = env::var("USERPROFILE").ok();

        for var in [
            "MEMVID_API_KEY",
            "MEMVID_API_URL",
            "MEMVID_CACHE_DIR",
            "MEMVID_TICKET_PUBKEY",
            "MEMVID_MODELS_DIR",
            "MEMVID_OFFLINE",
        ] {
            unsafe { env::remove_var(var) };
        }

        let tmp = tempfile::tempdir().expect("tmpdir");
        let tmp_path = tmp.path().to_path_buf();
        unsafe { env::set_var("HOME", &tmp_path) };
        #[cfg(windows)]
        unsafe {
            env::set_var("USERPROFILE", &tmp_path)
        };

        let config = CliConfig::load().expect("load");
        assert_eq!(config.api_key, None);
        assert_eq!(config.api_url, DEFAULT_API_URL);
        assert_eq!(config.cache_dir, tmp_path.join(".cache/memvid"));
        assert!(config.ticket_pubkey.is_none());
        assert_eq!(config.models_dir, tmp_path.join(".memvid/models"));
        assert!(!config.offline);

        set_or_unset("HOME", previous_home);
        #[cfg(windows)]
        {
            set_or_unset("USERPROFILE", previous_userprofile);
        }
    }

    #[test]
    fn env_overrides_are_respected() {
        let _guard = env_lock();

        let previous_env: Vec<(&'static str, Option<String>)> = [
            "MEMVID_API_KEY",
            "MEMVID_API_URL",
            "MEMVID_CACHE_DIR",
            "MEMVID_TICKET_PUBKEY",
            "MEMVID_MODELS_DIR",
            "MEMVID_OFFLINE",
        ]
        .into_iter()
        .map(|var| (var, env::var(var).ok()))
        .collect();

        unsafe { env::set_var("MEMVID_API_KEY", "abc123") };
        unsafe { env::set_var("MEMVID_API_URL", "https://staging.memvid.app") };
        unsafe { env::set_var("MEMVID_CACHE_DIR", "~/memvid-cache") };
        unsafe { env::set_var("MEMVID_MODELS_DIR", "~/models") };
        unsafe { env::set_var("MEMVID_OFFLINE", "true") };
        let signing = SigningKey::from_bytes(&[9u8; 32]);
        let encoded = BASE64_STANDARD.encode(signing.verifying_key().as_bytes());
        unsafe { env::set_var("MEMVID_TICKET_PUBKEY", encoded) };

        let tmp = tempfile::tempdir().expect("tmpdir");
        let tmp_path = tmp.path().to_path_buf();
        unsafe { env::set_var("HOME", &tmp_path) };
        #[cfg(windows)]
        unsafe {
            env::set_var("USERPROFILE", &tmp_path)
        };

        let config = CliConfig::load().expect("load");
        assert_eq!(config.api_key.as_deref(), Some("abc123"));
        assert_eq!(config.api_url, "https://staging.memvid.app");
        assert_eq!(config.cache_dir, tmp_path.join("memvid-cache"));
        assert_eq!(
            config.ticket_pubkey.expect("pubkey").as_bytes(),
            signing.verifying_key().as_bytes()
        );
        assert_eq!(config.models_dir, tmp_path.join("models"));
        assert!(config.offline);

        for (var, value) in previous_env {
            set_or_unset(var, value);
        }
    }
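
    // Added sanity test (sketch): remote-provider truncation never splits a
    // multi-byte UTF-8 character, and short inputs are borrowed unchanged.
    #[test]
    fn truncation_respects_char_boundaries() {
        let text = "héllo wörld";
        let cow = EmbeddingRuntime::truncate_for_embedding(text, 2);
        assert_eq!(cow.as_ref(), "h"); // byte 2 falls inside 'é'
        assert!(text.is_char_boundary(cow.len()));
        assert!(matches!(
            EmbeddingRuntime::truncate_for_embedding("ok", 10),
            std::borrow::Cow::Borrowed("ok")
        ));
    }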

    #[test]
    fn rejects_empty_cache_dir() {
        let _guard = env_lock();

        let previous = env::var("MEMVID_CACHE_DIR").ok();
        unsafe { env::set_var("MEMVID_CACHE_DIR", " ") };
        let err = CliConfig::load().expect_err("should fail");
        assert!(err.to_string().contains("cache directory"));
        set_or_unset("MEMVID_CACHE_DIR", previous);
    }
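
    // Added sanity test (sketch): alias parsing is case-insensitive and
    // from_dimension round-trips each listed model's advertised dimension.
    #[test]
    fn embedding_model_parsing_and_dimensions_agree() {
        assert_eq!(
            "OPENAI".parse::<EmbeddingModelChoice>().unwrap(),
            EmbeddingModelChoice::OpenAILarge
        );
        assert_eq!(
            "BAAI/bge-small-en-v1.5".parse::<EmbeddingModelChoice>().unwrap(),
            EmbeddingModelChoice::BgeSmall
        );
        for model in [
            EmbeddingModelChoice::BgeSmall,
            EmbeddingModelChoice::GteLarge,
            EmbeddingModelChoice::OpenAILarge,
        ] {
            assert_eq!(
                EmbeddingModelChoice::from_dimension(model.dimensions() as u32),
                Some(model)
            );
        }
    }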
}

/// Initialize tracing/logging based on verbosity level
pub fn init_tracing(verbosity: u8) -> Result<()> {
    use std::io::IsTerminal;
    use tracing_subscriber::{filter::Directive, fmt, EnvFilter};

    let level = match verbosity {
        0 => "warn",
        1 => "info",
        2 => "debug",
        _ => "trace",
    };

    let mut env_filter =
        EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(level));
    for directive_str in ["llama_cpp=error", "llama_cpp_sys=error", "ggml=error"] {
        if let Ok(directive) = directive_str.parse::<Directive>() {
            env_filter = env_filter.add_directive(directive);
        }
    }

    // Disable ANSI color codes when stderr is not a terminal (e.g., piped or
    // combined with `2>&1`). This prevents control characters from polluting
    // JSON output when combined with stdout.
    let use_ansi = std::io::stderr().is_terminal();

    fmt()
        .with_env_filter(env_filter)
        .with_writer(std::io::stderr)
        .with_target(false)
        .without_time()
        .with_ansi(use_ansi)
        .try_init()
        .map_err(|err| anyhow!(err))?;
    Ok(())
}

/// Resolve LLM context budget override from CLI or environment
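///
/// Precedence sketch (added example): a CLI value wins outright and must be
/// non-zero; otherwise the env var is consulted, tolerating `_` digit grouping.
///
/// ```ignore
/// // With MEMVID_LLM_CONTEXT_BUDGET="32_000" in the environment:
/// assert_eq!(resolve_llm_context_budget_override(None).unwrap(), Some(32_000));
/// assert_eq!(resolve_llm_context_budget_override(Some(8192)).unwrap(), Some(8192));
/// ```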
pub fn resolve_llm_context_budget_override(cli_value: Option<usize>) -> Result<Option<usize>> {
    use anyhow::bail;

    if let Some(value) = cli_value {
        if value == 0 {
            bail!("--llm-context-depth must be a positive integer");
        }
        return Ok(Some(value));
    }

    let raw_env = match env::var("MEMVID_LLM_CONTEXT_BUDGET") {
        Ok(value) => value,
        Err(_) => return Ok(None),
    };

    let trimmed = raw_env.trim();
    if trimmed.is_empty() {
        return Ok(None);
    }

    let digits: String = trimmed
        .chars()
        .filter(|ch| !ch.is_ascii_whitespace() && *ch != '_')
        .collect();

    if digits.is_empty() {
        bail!("MEMVID_LLM_CONTEXT_BUDGET must be a positive integer value");
    }

    let value: usize = digits.parse().map_err(|err| {
        anyhow!(
            "MEMVID_LLM_CONTEXT_BUDGET value '{}' is not a valid number: {}",
            trimmed,
            err
        )
    })?;

    if value == 0 {
        bail!("MEMVID_LLM_CONTEXT_BUDGET must be a positive integer");
    }

    Ok(Some(value))
}

use crate::nvidia_embeddings::NvidiaEmbeddingProvider;
use crate::openai_embeddings::OpenAIEmbeddingProvider;

/// Internal embedding backend - local fastembed or remote providers.
#[derive(Clone)]
enum EmbeddingBackend {
    FastEmbed(std::sync::Arc<std::sync::Mutex<fastembed::TextEmbedding>>),
    OpenAI(std::sync::Arc<OpenAIEmbeddingProvider>),
    Nvidia(std::sync::Arc<NvidiaEmbeddingProvider>),
}

/// Embedding runtime wrapper supporting local and remote embeddings
#[derive(Clone)]
pub struct EmbeddingRuntime {
    backend: EmbeddingBackend,
    model: EmbeddingModelChoice,
    dimension: std::sync::Arc<AtomicUsize>,
}

impl EmbeddingRuntime {
    fn new_fastembed(
        backend: fastembed::TextEmbedding,
        model: EmbeddingModelChoice,
        dimension: usize,
    ) -> Self {
        Self {
            backend: EmbeddingBackend::FastEmbed(std::sync::Arc::new(std::sync::Mutex::new(
                backend,
            ))),
            model,
            dimension: std::sync::Arc::new(AtomicUsize::new(dimension)),
        }
    }

    fn new_openai(
        provider: OpenAIEmbeddingProvider,
        model: EmbeddingModelChoice,
        dimension: usize,
    ) -> Self {
        Self {
            backend: EmbeddingBackend::OpenAI(std::sync::Arc::new(provider)),
            model,
            dimension: std::sync::Arc::new(AtomicUsize::new(dimension)),
        }
    }

    fn new_nvidia(provider: NvidiaEmbeddingProvider, model: EmbeddingModelChoice) -> Self {
        Self {
            backend: EmbeddingBackend::Nvidia(std::sync::Arc::new(provider)),
            model,
            dimension: std::sync::Arc::new(AtomicUsize::new(0)),
        }
    }

    const MAX_OPENAI_EMBEDDING_TEXT_LEN: usize = 20_000;
    // NVIDIA Integrate embeddings enforce a 4096 token limit; use a tighter char cap as a guardrail.
    const MAX_NVIDIA_EMBEDDING_TEXT_LEN: usize = 12_000;

    fn max_remote_embedding_chars(&self) -> usize {
        match &self.backend {
            EmbeddingBackend::OpenAI(_) => Self::MAX_OPENAI_EMBEDDING_TEXT_LEN,
            EmbeddingBackend::Nvidia(_) => Self::MAX_NVIDIA_EMBEDDING_TEXT_LEN,
            EmbeddingBackend::FastEmbed(_) => usize::MAX,
        }
    }

    /// Truncate text for embedding to reduce the risk of provider token-limit errors.
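    ///
    /// Boundary handling, illustrated (added sketch):
    ///
    /// ```ignore
    /// // Byte 2 falls inside the two-byte 'é', so the cut backs up to byte 1.
    /// let cow = EmbeddingRuntime::truncate_for_embedding("héllo", 2);
    /// assert_eq!(cow.as_ref(), "h");
    /// ```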
    fn truncate_for_embedding<'a>(
        text: &'a str,
        max_chars: usize,
    ) -> std::borrow::Cow<'a, str> {
        if text.len() <= max_chars {
            std::borrow::Cow::Borrowed(text)
        } else {
            // Back up to the nearest UTF-8 char boundary; slicing the str at an
            // arbitrary byte offset would panic on multi-byte characters.
            let mut end = max_chars;
            while end > 0 && !text.is_char_boundary(end) {
                end -= 1;
            }
            tracing::info!("Truncated embedding text from {} to {} chars", text.len(), end);
            std::borrow::Cow::Owned(text[..end].to_string())
        }
    }

    fn note_dimension(&self, observed: usize) -> Result<()> {
        if observed == 0 {
            return Err(anyhow!("embedding provider returned zero-length embedding"));
        }

        let current = self.dimension.load(Ordering::Relaxed);
        if current == 0 {
            self.dimension.store(observed, Ordering::Relaxed);
            return Ok(());
        }

        if current != observed {
            return Err(anyhow!(
                "embedding provider returned {observed}D vectors but runtime expects {current}D"
            ));
        }

        Ok(())
    }

    fn truncate_if_remote<'a>(&self, text: &'a str) -> std::borrow::Cow<'a, str> {
        match &self.backend {
            EmbeddingBackend::OpenAI(_) | EmbeddingBackend::Nvidia(_) => {
                Self::truncate_for_embedding(text, self.max_remote_embedding_chars())
            }
            _ => std::borrow::Cow::Borrowed(text),
        }
    }

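    /// Embed a single passage, truncating first when the backend is remote.
    ///
    /// Usage sketch (illustrative):
    ///
    /// ```ignore
    /// let vector = runtime.embed_passage("some memory text")?;
    /// assert_eq!(vector.len(), runtime.dimension()); // NVIDIA learns its dimension on first use
    /// ```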
    pub fn embed_passage(&self, text: &str) -> Result<Vec<f32>> {
        let text = self.truncate_if_remote(text);
        let embedding = match &self.backend {
            EmbeddingBackend::FastEmbed(model) => {
                let mut guard = model
                    .lock()
                    .map_err(|_| anyhow!("fastembed runtime poisoned"))?;
                let outputs = guard
                    .embed(vec![text.into_owned()], None)
                    .map_err(|err| anyhow!("failed to compute embedding with fastembed: {err}"))?;
                outputs
                    .into_iter()
                    .next()
                    .ok_or_else(|| anyhow!("fastembed returned no embedding output"))?
            }
            EmbeddingBackend::OpenAI(provider) => {
                use memvid_core::EmbeddingProvider;
                provider
                    .embed_text(&text)
                    .map_err(|err| anyhow!("failed to compute embedding with OpenAI: {err}"))?
            }
            EmbeddingBackend::Nvidia(provider) => provider
                .embed_passage(&text)
                .map_err(|err| anyhow!("failed to compute embedding with NVIDIA: {err}"))?,
        };

        self.note_dimension(embedding.len())?;
        Ok(embedding)
    }

    pub fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
        let text = self.truncate_if_remote(text);
        match &self.backend {
            EmbeddingBackend::Nvidia(provider) => {
                let embedding = provider
                    .embed_query(&text)
                    .map_err(|err| anyhow!("failed to compute embedding with NVIDIA: {err}"))?;
                self.note_dimension(embedding.len())?;
                Ok(embedding)
            }
            _ => self.embed_passage(&text),
        }
    }

    pub fn embed_batch_passages(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let truncated: Vec<std::borrow::Cow<'_, str>> =
            texts.iter().map(|t| self.truncate_if_remote(t)).collect();
        let truncated_refs: Vec<&str> = truncated.iter().map(|c| c.as_ref()).collect();

        let embeddings = match &self.backend {
            EmbeddingBackend::FastEmbed(model) => {
                let mut guard = model
                    .lock()
                    .map_err(|_| anyhow!("fastembed runtime poisoned"))?;
                guard
                    .embed(
                        truncated_refs
                            .iter()
                            .map(|s| (*s).to_string())
                            .collect::<Vec<String>>(),
                        None,
                    )
                    .map_err(|err| anyhow!("failed to compute embeddings with fastembed: {err}"))?
            }
            EmbeddingBackend::OpenAI(provider) => {
                use memvid_core::EmbeddingProvider;
                provider
                    .embed_batch(&truncated_refs)
                    .map_err(|err| anyhow!("failed to compute embeddings with OpenAI: {err}"))?
            }
            EmbeddingBackend::Nvidia(provider) => provider
                .embed_passages(&truncated_refs)
                .map_err(|err| anyhow!("failed to compute embeddings with NVIDIA: {err}"))?,
        };

        if let Some(first) = embeddings.first() {
            self.note_dimension(first.len())?;
        }
        if let Some(expected) = embeddings.first().map(|e| e.len()) {
            if embeddings.iter().any(|e| e.len() != expected) {
                return Err(anyhow!("embedding provider returned mixed vector dimensions"));
            }
        }

        Ok(embeddings)
    }

    pub fn embed_batch_queries(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let truncated: Vec<std::borrow::Cow<'_, str>> =
            texts.iter().map(|t| self.truncate_if_remote(t)).collect();
        let truncated_refs: Vec<&str> = truncated.iter().map(|c| c.as_ref()).collect();

        match &self.backend {
            EmbeddingBackend::Nvidia(provider) => {
                let embeddings = provider
                    .embed_queries(&truncated_refs)
                    .map_err(|err| anyhow!("failed to compute embeddings with NVIDIA: {err}"))?;

                if let Some(first) = embeddings.first() {
                    self.note_dimension(first.len())?;
                }
                if let Some(expected) = embeddings.first().map(|e| e.len()) {
                    if embeddings.iter().any(|e| e.len() != expected) {
                        return Err(anyhow!("embedding provider returned mixed vector dimensions"));
                    }
                }

                Ok(embeddings)
            }
            _ => self.embed_batch_passages(&truncated_refs),
        }
    }

    pub fn dimension(&self) -> usize {
        self.dimension.load(Ordering::Relaxed)
    }

    pub fn model_choice(&self) -> EmbeddingModelChoice {
        self.model
    }

    pub fn provider_kind(&self) -> &'static str {
        match &self.backend {
            EmbeddingBackend::FastEmbed(_) => "fastembed",
            EmbeddingBackend::OpenAI(_) => "openai",
            EmbeddingBackend::Nvidia(_) => "nvidia",
        }
    }

    pub fn provider_model_id(&self) -> String {
        match &self.backend {
            EmbeddingBackend::FastEmbed(_) => self.model.canonical_model_id().to_string(),
            EmbeddingBackend::OpenAI(provider) => {
                use memvid_core::EmbeddingProvider;
                provider.model().to_string()
            }
            EmbeddingBackend::Nvidia(provider) => provider.model().to_string(),
        }
    }
}

impl memvid_core::VecEmbedder for EmbeddingRuntime {
    fn embed_query(&self, text: &str) -> memvid_core::Result<Vec<f32>> {
        EmbeddingRuntime::embed_query(self, text).map_err(|err| {
            memvid_core::MemvidError::EmbeddingFailed {
                reason: err.to_string().into_boxed_str(),
            }
        })
    }

    fn embedding_dimension(&self) -> usize {
        self.dimension()
    }
}

/// Ensure fastembed cache directory exists
fn ensure_fastembed_cache(config: &CliConfig) -> Result<PathBuf> {
    use std::fs;

    let cache_dir = config.models_dir.clone();
    fs::create_dir_all(&cache_dir)?;
    Ok(cache_dir)
}

/// Get approximate model size in MB for user-friendly error messages
fn model_size_mb(model: EmbeddingModelChoice) -> usize {
    match model {
        EmbeddingModelChoice::BgeSmall => 33,
        EmbeddingModelChoice::BgeBase => 110,
        EmbeddingModelChoice::Nomic => 137,
        EmbeddingModelChoice::GteLarge => 327,
        // Remote models (OpenAI, NVIDIA) don't require a local download
        EmbeddingModelChoice::OpenAILarge
        | EmbeddingModelChoice::OpenAISmall
        | EmbeddingModelChoice::OpenAIAda
        | EmbeddingModelChoice::Nvidia => 0,
    }
}

/// Instantiate an embedding runtime with the configured model
fn instantiate_embedding_runtime(config: &CliConfig) -> Result<EmbeddingRuntime> {
    use tracing::info;

    let embedding_model = config.embedding_model;

    if embedding_model.dimensions() > 0 {
        info!(
            "Loading embedding model: {} ({}D)",
            embedding_model.name(),
            embedding_model.dimensions()
        );
    } else {
        info!("Loading embedding model: {}", embedding_model.name());
    }

    if config.offline && (embedding_model.is_openai() || embedding_model == EmbeddingModelChoice::Nvidia)
    {
        anyhow::bail!(
            "remote embeddings are unavailable while offline; set MEMVID_OFFLINE=0 or use a local embedding model"
        );
    }

    // Check if OpenAI model
    if embedding_model.is_openai() {
        return instantiate_openai_runtime(embedding_model);
    }

    if embedding_model == EmbeddingModelChoice::Nvidia {
        return instantiate_nvidia_runtime(None);
    }

    // Local fastembed model
    instantiate_fastembed_runtime(config, embedding_model)
}

/// Instantiate OpenAI embedding runtime
fn instantiate_openai_runtime(embedding_model: EmbeddingModelChoice) -> Result<EmbeddingRuntime> {
    use anyhow::bail;
    use memvid_core::EmbeddingConfig;
    use tracing::info;

    let api_key = std::env::var("OPENAI_API_KEY").map_err(|_| {
        anyhow!("OPENAI_API_KEY environment variable is required for OpenAI embeddings")
    })?;

    if api_key.is_empty() {
        bail!("OPENAI_API_KEY cannot be empty");
    }

    let config = match embedding_model {
        EmbeddingModelChoice::OpenAILarge => EmbeddingConfig::openai_large(),
        EmbeddingModelChoice::OpenAISmall => EmbeddingConfig::openai_small(),
        EmbeddingModelChoice::OpenAIAda => EmbeddingConfig::openai_ada(),
        _ => unreachable!("non-OpenAI model passed to instantiate_openai_runtime"),
    };

    let provider = OpenAIEmbeddingProvider::new(api_key, config.clone())
        .map_err(|err| anyhow!("failed to create OpenAI embedding provider: {err}"))?;

    info!(
        "OpenAI embedding provider ready: model={}, dimension={}",
        config.model, config.dimension
    );

    Ok(EmbeddingRuntime::new_openai(
        provider,
        embedding_model,
        config.dimension,
    ))
}

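/// Normalize a user-supplied NVIDIA model override to a full "nvidia/<model>" ID.
///
/// Illustrative behavior (added sketch, mirrors the branches below):
///
/// ```ignore
/// assert_eq!(normalize_nvidia_embedding_model_override("nv"), None); // bare alias: provider default
/// assert_eq!(
///     normalize_nvidia_embedding_model_override("nv:nv-embed-v1").as_deref(),
///     Some("nvidia/nv-embed-v1")
/// );
/// ```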
fn normalize_nvidia_embedding_model_override(raw: &str) -> Option<String> {
    let trimmed = raw.trim();
    if trimmed.is_empty() {
        return None;
    }

    let lowered = trimmed.to_ascii_lowercase();
    if lowered == "nvidia" || lowered == "nv" {
        return None;
    }

    let without_prefix = trimmed
        .strip_prefix("nvidia:")
        .or_else(|| trimmed.strip_prefix("nv:"))
        .unwrap_or(trimmed)
        .trim();

    if without_prefix.is_empty() {
        return None;
    }

    if without_prefix.eq_ignore_ascii_case("nv-embed-v1") {
        return Some("nvidia/nv-embed-v1".to_string());
    }

    if without_prefix.contains('/') {
        return Some(without_prefix.to_string());
    }

    Some(format!("nvidia/{without_prefix}"))
}

/// Instantiate NVIDIA embedding runtime
fn instantiate_nvidia_runtime(model_override: Option<&str>) -> Result<EmbeddingRuntime> {
    use tracing::info;

    let normalized = model_override.and_then(normalize_nvidia_embedding_model_override);
    let provider = NvidiaEmbeddingProvider::from_env(normalized.as_deref())
        .map_err(|err| anyhow!("failed to create NVIDIA embedding provider: {err}"))?;

    info!(
        "NVIDIA embedding provider ready: model={}",
        provider.model()
    );

    Ok(EmbeddingRuntime::new_nvidia(
        provider,
        EmbeddingModelChoice::Nvidia,
    ))
}

/// Instantiate fastembed (local) embedding runtime
fn instantiate_fastembed_runtime(
    config: &CliConfig,
    embedding_model: EmbeddingModelChoice,
) -> Result<EmbeddingRuntime> {
    use anyhow::bail;
    use fastembed::{InitOptions, TextEmbedding};
    use std::fs;

    let cache_dir = ensure_fastembed_cache(config)?;

    if config.offline {
        let mut entries = fs::read_dir(&cache_dir)?;
        if entries.next().is_none() {
            bail!(
                "semantic embeddings unavailable while offline; allow one connected run so fastembed can cache model weights"
            );
        }
    }

    let options = InitOptions::new(embedding_model.to_fastembed_model())
        .with_cache_dir(cache_dir)
        .with_show_download_progress(true);
    let mut model = TextEmbedding::try_new(options).map_err(|err| {
        // Provide platform-specific guidance for model download issues
        let platform_hint = if cfg!(target_os = "windows") {
            "\n\nWindows users: If model downloads fail, try:\n\
            1. Run as Administrator\n\
            2. Check your antivirus isn't blocking downloads\n\
            3. Use OpenAI embeddings instead: set OPENAI_API_KEY and use --embedding-model openai"
        } else if cfg!(target_os = "linux") {
            "\n\nLinux users: If model downloads fail, try:\n\
            1. Check disk space in ~/.memvid/models\n\
            2. Ensure you have network access to huggingface.co\n\
            3. Use OpenAI embeddings instead: export OPENAI_API_KEY=... and use --embedding-model openai"
        } else {
            "\n\nIf model downloads fail, try using OpenAI embeddings:\n\
            export OPENAI_API_KEY=your-key && memvid ... --embedding-model openai"
        };

        anyhow!(
            "Failed to initialize embedding model '{}': {err}\n\n\
            This typically means the model couldn't be downloaded or loaded.\n\
            Model size: ~{} MB{}\n\n\
            See: https://docs.memvid.com/embedding-models",
            embedding_model.name(),
            model_size_mb(embedding_model),
            platform_hint
        )
    })?;

    let probe = model
        .embed(vec!["memvid probe".to_string()], None)
        .map_err(|err| anyhow!("failed to determine embedding dimension: {err}"))?;
    let dimension = probe.first().map(|vec| vec.len()).unwrap_or(0);

    if dimension == 0 {
        bail!("fastembed reported zero-length embeddings");
    }

    // Verify dimension matches expected
    if dimension != embedding_model.dimensions() {
        tracing::warn!(
            "Embedding dimension mismatch: expected {}, got {}",
            embedding_model.dimensions(),
            dimension
        );
    }

    Ok(EmbeddingRuntime::new_fastembed(model, embedding_model, dimension))
}

/// Load embedding runtime (fails if unavailable)
pub fn load_embedding_runtime(config: &CliConfig) -> Result<EmbeddingRuntime> {
    use anyhow::bail;

    match instantiate_embedding_runtime(config) {
        Ok(runtime) => Ok(runtime),
        Err(err) => {
            if config.offline {
                bail!(
                    "semantic embeddings unavailable while offline; allow one connected run so fastembed can cache model weights ({err})"
                );
            }
            Err(err)
        }
    }
}

/// Try to load embedding runtime (returns None if unavailable)
pub fn try_load_embedding_runtime(config: &CliConfig) -> Option<EmbeddingRuntime> {
    use tracing::warn;

    match instantiate_embedding_runtime(config) {
        Ok(runtime) => Some(runtime),
        Err(err) => {
            warn!("semantic embeddings unavailable: {err}");
            None
        }
    }
}

/// Load embedding runtime with an optional model override.
/// If `model_override` is provided, it will be used instead of the config's embedding_model.
pub fn load_embedding_runtime_with_model(
    config: &CliConfig,
    model_override: Option<&str>,
) -> Result<EmbeddingRuntime> {
    use tracing::info;

    let mut raw_override: Option<&str> = None;
    let embedding_model = match model_override {
        Some(model_str) => {
            raw_override = Some(model_str);
            let parsed = model_str.parse::<EmbeddingModelChoice>()?;
            if parsed.dimensions() > 0 {
                info!(
                    "Using embedding model override: {} ({}D)",
                    parsed.name(),
                    parsed.dimensions()
                );
            } else {
                info!("Using embedding model override: {}", parsed.name());
            }
            parsed
        }
        None => config.embedding_model,
    };

    if embedding_model.dimensions() > 0 {
        info!(
            "Loading embedding model: {} ({}D)",
            embedding_model.name(),
            embedding_model.dimensions()
        );
    } else {
        info!("Loading embedding model: {}", embedding_model.name());
    }

    if config.offline && (embedding_model.is_openai() || embedding_model == EmbeddingModelChoice::Nvidia)
    {
        anyhow::bail!(
            "remote embeddings are unavailable while offline; set MEMVID_OFFLINE=0 or use a local embedding model"
        );
    }

    if embedding_model.is_openai() {
        return instantiate_openai_runtime(embedding_model);
    }

    if embedding_model == EmbeddingModelChoice::Nvidia {
        return instantiate_nvidia_runtime(raw_override);
    }

    instantiate_fastembed_runtime(config, embedding_model)
}

/// Try to load embedding runtime with model override (returns None if unavailable)
pub fn try_load_embedding_runtime_with_model(
    config: &CliConfig,
    model_override: Option<&str>,
) -> Option<EmbeddingRuntime> {
    use tracing::warn;

    match load_embedding_runtime_with_model(config, model_override) {
        Ok(runtime) => Some(runtime),
        Err(err) => {
            warn!("semantic embeddings unavailable: {err}");
            None
        }
    }
}

/// Load embedding runtime by auto-detecting from MV2 vector dimension.
///
/// Priority:
/// 1. Explicit model override (--query-embedding-model flag)
/// 2. Auto-detect from MV2 file's stored dimension
/// 3. Fall back to config default
///
/// This allows users to omit --query-embedding-model when querying files
/// created with non-default embedding models (like OpenAI).
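///
/// Call-site sketch (hypothetical variable names):
///
/// ```ignore
/// // `dim` comes from the MV2 header; None when the file stores no vectors.
/// let runtime = load_embedding_runtime_for_mv2(&config, cli_model_flag.as_deref(), dim)?;
/// ```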
pub fn load_embedding_runtime_for_mv2(
    config: &CliConfig,
    model_override: Option<&str>,
    mv2_dimension: Option<u32>,
) -> Result<EmbeddingRuntime> {
    use tracing::info;

    // Priority 1: Explicit override
    if let Some(model_str) = model_override {
        return load_embedding_runtime_with_model(config, Some(model_str));
    }

    // Priority 2: Auto-detect from MV2 dimension
    if let Some(dim) = mv2_dimension {
        if let Some(detected_model) = EmbeddingModelChoice::from_dimension(dim) {
            info!(
                "Auto-detected embedding model from MV2: {} ({}D)",
                detected_model.name(),
                dim
            );

            // For OpenAI models, check if API key is available
            if detected_model.is_openai() {
                if std::env::var("OPENAI_API_KEY").is_ok() {
                    return load_embedding_runtime_with_model(config, Some(detected_model.name()));
                } else {
                    // OpenAI detected but no API key - provide helpful error
                    return Err(anyhow!(
                        "MV2 file uses OpenAI embeddings ({}D) but OPENAI_API_KEY is not set.\n\n\
                        Options:\n\
                        1. Set OPENAI_API_KEY environment variable\n\
                        2. Use --query-embedding-model to specify a different model\n\
                        3. Use lexical-only search with --mode lex\n\n\
                        See: https://docs.memvid.com/embedding-models",
                        dim
                    ));
                }
            }

            return load_embedding_runtime_with_model(config, Some(detected_model.name()));
        }
    }

    // Priority 3: Fall back to config default
    load_embedding_runtime(config)
}

/// Try to load embedding runtime for MV2 with auto-detection (returns None if unavailable)
pub fn try_load_embedding_runtime_for_mv2(
    config: &CliConfig,
    model_override: Option<&str>,
    mv2_dimension: Option<u32>,
) -> Option<EmbeddingRuntime> {
    use tracing::warn;

    match load_embedding_runtime_for_mv2(config, model_override, mv2_dimension) {
        Ok(runtime) => Some(runtime),
        Err(err) => {
            warn!("semantic embeddings unavailable: {err}");
            None
        }
    }
}