memvid_cli/
config.rs

//! CLI configuration and environment handling.
//!
//! This module loads configuration from environment variables, initializes
//! tracing, resolves the LLM context-budget override, and manages the embedding
//! runtime used for semantic search.
5
6use std::env;
7use std::path::PathBuf;
8use std::str::FromStr;
9
10use anyhow::{anyhow, Result};
11use ed25519_dalek::VerifyingKey;
12
/// Base URL of the remote API, used when `MEMVID_API_URL` is not set.
const DEFAULT_API_URL: &str = "https://kgpfm35ddc.us-east-2.awsapprunner.com";

/// Cache directory used when `MEMVID_CACHE_DIR` is not set. The leading `~`
/// is expanded against the user's home directory when the config is loaded.
const DEFAULT_CACHE_DIR: &str = "~/.cache/memvid";
15
/// Embedding models the CLI can use for semantic search.
///
/// The first four variants run locally through `fastembed`; the `OpenAI*`
/// variants call the OpenAI embeddings API and require `OPENAI_API_KEY`.
/// The default is `BgeSmall`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum EmbeddingModelChoice {
    /// BGE-small-en-v1.5 — fast, 384 dimensions, ~78% accuracy. The default.
    #[default]
    BgeSmall,
    /// BGE-base-en-v1.5 — balanced, 768 dimensions, ~85% accuracy.
    BgeBase,
    /// Nomic-embed-text-v1.5 — high accuracy, 768 dimensions, ~86% accuracy.
    Nomic,
    /// GTE-large-en-v1.5 — best semantic depth of the local models, 1024 dimensions.
    GteLarge,
    /// OpenAI text-embedding-3-large — highest quality, 3072 dimensions (needs `OPENAI_API_KEY`).
    OpenAILarge,
    /// OpenAI text-embedding-3-small — good quality, 1536 dimensions (needs `OPENAI_API_KEY`).
    OpenAISmall,
    /// OpenAI text-embedding-ada-002 — legacy model, 1536 dimensions (needs `OPENAI_API_KEY`).
    OpenAIAda,
}
35
36impl EmbeddingModelChoice {
37    /// Check if this is an OpenAI model (requires OPENAI_API_KEY)
38    pub fn is_openai(&self) -> bool {
39        matches!(
40            self,
41            EmbeddingModelChoice::OpenAILarge
42                | EmbeddingModelChoice::OpenAISmall
43                | EmbeddingModelChoice::OpenAIAda
44        )
45    }
46
47    /// Get the fastembed EmbeddingModel enum value (only for local models)
48    ///
49    /// # Panics
50    /// Panics if called on an OpenAI model. Use `is_openai()` to check first.
51    pub fn to_fastembed_model(&self) -> fastembed::EmbeddingModel {
52        match self {
53            EmbeddingModelChoice::BgeSmall => fastembed::EmbeddingModel::BGESmallENV15,
54            EmbeddingModelChoice::BgeBase => fastembed::EmbeddingModel::BGEBaseENV15,
55            EmbeddingModelChoice::Nomic => fastembed::EmbeddingModel::NomicEmbedTextV15,
56            EmbeddingModelChoice::GteLarge => fastembed::EmbeddingModel::GTELargeENV15,
57            EmbeddingModelChoice::OpenAILarge
58            | EmbeddingModelChoice::OpenAISmall
59            | EmbeddingModelChoice::OpenAIAda => {
60                panic!("OpenAI models don't use fastembed. Check is_openai() first.")
61            }
62        }
63    }
64
65    /// Get human-readable model name
66    pub fn name(&self) -> &'static str {
67        match self {
68            EmbeddingModelChoice::BgeSmall => "bge-small",
69            EmbeddingModelChoice::BgeBase => "bge-base",
70            EmbeddingModelChoice::Nomic => "nomic",
71            EmbeddingModelChoice::GteLarge => "gte-large",
72            EmbeddingModelChoice::OpenAILarge => "openai-large",
73            EmbeddingModelChoice::OpenAISmall => "openai-small",
74            EmbeddingModelChoice::OpenAIAda => "openai-ada",
75        }
76    }
77
78    /// Get embedding dimensions
79    pub fn dimensions(&self) -> usize {
80        match self {
81            EmbeddingModelChoice::BgeSmall => 384,
82            EmbeddingModelChoice::BgeBase => 768,
83            EmbeddingModelChoice::Nomic => 768,
84            EmbeddingModelChoice::GteLarge => 1024,
85            EmbeddingModelChoice::OpenAILarge => 3072,
86            EmbeddingModelChoice::OpenAISmall => 1536,
87            EmbeddingModelChoice::OpenAIAda => 1536,
88        }
89    }
90}
91
92impl FromStr for EmbeddingModelChoice {
93    type Err = anyhow::Error;
94
95    fn from_str(s: &str) -> Result<Self> {
96        match s.to_lowercase().as_str() {
97            "bge-small" | "bge_small" | "bgesmall" | "small" => Ok(EmbeddingModelChoice::BgeSmall),
98            "bge-base" | "bge_base" | "bgebase" | "base" => Ok(EmbeddingModelChoice::BgeBase),
99            "nomic" | "nomic-embed" | "nomic_embed" => Ok(EmbeddingModelChoice::Nomic),
100            "gte-large" | "gte_large" | "gtelarge" | "gte" => Ok(EmbeddingModelChoice::GteLarge),
101            // OpenAI models - default "openai" maps to "openai-large" for highest quality
102            "openai" | "openai-large" | "openai_large" | "text-embedding-3-large" => {
103                Ok(EmbeddingModelChoice::OpenAILarge)
104            }
105            "openai-small" | "openai_small" | "text-embedding-3-small" => {
106                Ok(EmbeddingModelChoice::OpenAISmall)
107            }
108            "openai-ada" | "openai_ada" | "text-embedding-ada-002" | "ada" => {
109                Ok(EmbeddingModelChoice::OpenAIAda)
110            }
111            _ => Err(anyhow!(
112                "unknown embedding model '{}'. Valid options: bge-small, bge-base, nomic, gte-large, openai, openai-small, openai-ada",
113                s
114            )),
115        }
116    }
117}
118
119impl std::fmt::Display for EmbeddingModelChoice {
120    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121        write!(f, "{}", self.name())
122    }
123}
124
125impl EmbeddingModelChoice {
126    /// Infer the best embedding model from vector dimension stored in MV2 file.
127    ///
128    /// This enables auto-detection: users don't need to specify --query-embedding-model
129    /// if the MV2 file has vectors. The dimension uniquely identifies the model family.
130    ///
131    /// # Dimension Mapping
132    /// - 384  → BGE-small (default local model)
133    /// - 768  → BGE-base (could also be Nomic, but same dimension works)
134    /// - 1024 → GTE-large
135    /// - 1536 → OpenAI small/ada
136    /// - 3072 → OpenAI large
137    pub fn from_dimension(dim: u32) -> Option<Self> {
138        match dim {
139            384 => Some(EmbeddingModelChoice::BgeSmall),
140            768 => Some(EmbeddingModelChoice::BgeBase), // Could be Nomic, but same dim
141            1024 => Some(EmbeddingModelChoice::GteLarge),
142            1536 => Some(EmbeddingModelChoice::OpenAISmall), // Could be Ada, same dim
143            3072 => Some(EmbeddingModelChoice::OpenAILarge),
144            0 => None, // No vectors in file
145            _ => {
146                tracing::warn!(
147                    "Unknown embedding dimension {}, using default model",
148                    dim
149                );
150                None
151            }
152        }
153    }
154}
155
156/// CLI configuration loaded from environment variables
157#[derive(Debug, Clone)]
158pub struct CliConfig {
159    pub api_key: Option<String>,
160    pub api_url: String,
161    pub cache_dir: PathBuf,
162    pub ticket_pubkey: Option<VerifyingKey>,
163    pub models_dir: PathBuf,
164    pub offline: bool,
165    /// Embedding model for semantic search (can be overridden by CLI flag)
166    pub embedding_model: EmbeddingModelChoice,
167}
168
169impl PartialEq for CliConfig {
170    fn eq(&self, other: &Self) -> bool {
171        self.api_key == other.api_key
172            && self.api_url == other.api_url
173            && self.cache_dir == other.cache_dir
174            && self.models_dir == other.models_dir
175            && self.offline == other.offline
176            && self.embedding_model == other.embedding_model
177    }
178}
179
180impl Eq for CliConfig {}
181
182impl CliConfig {
183    pub fn load() -> Result<Self> {
184        let api_key = env::var("MEMVID_API_KEY").ok().and_then(|value| {
185            let trimmed = value.trim().to_string();
186            (!trimmed.is_empty()).then_some(trimmed)
187        });
188
189        let api_url = env::var("MEMVID_API_URL").unwrap_or_else(|_| DEFAULT_API_URL.to_string());
190
191        let cache_dir_raw =
192            env::var("MEMVID_CACHE_DIR").unwrap_or_else(|_| DEFAULT_CACHE_DIR.to_string());
193        let cache_dir = expand_path(&cache_dir_raw)?;
194
195        let models_dir_raw =
196            env::var("MEMVID_MODELS_DIR").unwrap_or_else(|_| "~/.memvid/models".to_string());
197        let models_dir = expand_path(&models_dir_raw)?;
198
199        let ticket_pubkey = env::var("MEMVID_TICKET_PUBKEY")
200            .ok()
201            .and_then(|value| {
202                let trimmed = value.trim();
203                if trimmed.is_empty() {
204                    None
205                } else {
206                    Some(memvid_core::parse_ed25519_public_key_base64(trimmed))
207                }
208            })
209            .transpose()?;
210
211        let offline = env::var("MEMVID_OFFLINE")
212            .ok()
213            .map(|value| match value.trim().to_ascii_lowercase().as_str() {
214                "1" | "true" | "yes" => true,
215                _ => false,
216            })
217            .unwrap_or(false);
218
219        // Load embedding model from env var, default to BGE-small
220        let embedding_model = env::var("MEMVID_EMBEDDING_MODEL")
221            .ok()
222            .and_then(|value| {
223                let trimmed = value.trim();
224                if trimmed.is_empty() {
225                    None
226                } else {
227                    EmbeddingModelChoice::from_str(trimmed).ok()
228                }
229            })
230            .unwrap_or_default();
231
232        Ok(Self {
233            api_key,
234            api_url,
235            cache_dir,
236            ticket_pubkey,
237            models_dir,
238            offline,
239            embedding_model,
240        })
241    }
242
243    /// Create a new config with a different embedding model
244    pub fn with_embedding_model(&self, model: EmbeddingModelChoice) -> Self {
245        Self {
246            embedding_model: model,
247            ..self.clone()
248        }
249    }
250}
251
252fn expand_path(value: &str) -> Result<PathBuf> {
253    if value.trim().is_empty() {
254        return Err(anyhow!("cache directory cannot be empty"));
255    }
256
257    let expanded = if let Some(stripped) = value.strip_prefix("~/") {
258        home_dir()?.join(stripped)
259    } else if let Some(stripped) = value.strip_prefix("~\\") {
260        // Support Windows-style "~\" prefix.
261        home_dir()?.join(stripped)
262    } else if value == "~" {
263        home_dir()?
264    } else {
265        PathBuf::from(value)
266    };
267
268    if expanded.is_absolute() {
269        Ok(expanded)
270    } else {
271        Ok(env::current_dir()?.join(expanded))
272    }
273}
274
275fn home_dir() -> Result<PathBuf> {
276    if let Some(path) = env::var_os("HOME") {
277        if !path.is_empty() {
278            return Ok(PathBuf::from(path));
279        }
280    }
281
282    #[cfg(windows)]
283    {
284        if let Some(path) = env::var_os("USERPROFILE") {
285            if !path.is_empty() {
286                return Ok(PathBuf::from(path));
287            }
288        }
289        if let (Some(drive), Some(path)) = (env::var_os("HOMEDRIVE"), env::var_os("HOMEPATH")) {
290            if !drive.is_empty() && !path.is_empty() {
291                return Ok(PathBuf::from(format!(
292                    "{}{}",
293                    drive.to_string_lossy(),
294                    path.to_string_lossy()
295                )));
296            }
297        }
298    }
299
300    Err(anyhow!("unable to resolve home directory"))
301}
302
#[cfg(test)]
mod tests {
    use super::*;
    use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
    use base64::Engine;
    use ed25519_dalek::SigningKey;
    use std::ffi::{OsStr, OsString};
    use std::sync::{Mutex, MutexGuard, OnceLock};

    /// Every environment variable `CliConfig::load` reads, plus the home-directory
    /// variables these tests override. Each test snapshots and restores all of them.
    const TOUCHED_VARS: [&str; 9] = [
        "MEMVID_API_KEY",
        "MEMVID_API_URL",
        "MEMVID_CACHE_DIR",
        "MEMVID_TICKET_PUBKEY",
        "MEMVID_MODELS_DIR",
        "MEMVID_OFFLINE",
        "MEMVID_EMBEDDING_MODEL",
        "HOME",
        "USERPROFILE",
    ];

    /// Serialises the tests that mutate the process environment.
    ///
    /// Recovers from poisoning so that one failing test does not turn every
    /// later test into a `PoisonError` panic.
    fn env_lock() -> MutexGuard<'static, ()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(()))
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Captures `TOUCHED_VARS` and restores them on drop, including when a test
    /// panics. Declare it after the `env_lock()` guard so it restores while the
    /// lock is still held.
    struct EnvSnapshot(Vec<(&'static str, Option<OsString>)>);

    impl EnvSnapshot {
        fn capture() -> Self {
            Self(TOUCHED_VARS.iter().map(|&var| (var, env::var_os(var))).collect())
        }
    }

    impl Drop for EnvSnapshot {
        fn drop(&mut self) {
            for (var, value) in self.0.drain(..) {
                match value {
                    Some(value) => set(var, value),
                    None => unset(var),
                }
            }
        }
    }

    fn set(var: &str, value: impl AsRef<OsStr>) {
        // SAFETY: every test here that mutates the environment holds `env_lock()`,
        // so these writes never race each other. (Assumes no other test in this
        // binary touches the environment without the lock.)
        unsafe { env::set_var(var, value) };
    }

    fn unset(var: &str) {
        // SAFETY: same reasoning as `set`.
        unsafe { env::remove_var(var) };
    }

    /// Clears every `MEMVID_*` variable and points the home directory at `home`.
    fn isolate_env(home: &std::path::Path) {
        for var in TOUCHED_VARS.iter().filter(|var| var.starts_with("MEMVID_")) {
            unset(var);
        }
        set("HOME", home);
        #[cfg(windows)]
        set("USERPROFILE", home);
    }

    #[test]
    fn defaults_expand_using_home_directory() {
        let _lock = env_lock();
        let _restore = EnvSnapshot::capture();
        let tmp = tempfile::tempdir().expect("tmpdir");
        let tmp_path = tmp.path().to_path_buf();
        isolate_env(&tmp_path);

        let config = CliConfig::load().expect("load");
        assert_eq!(config.api_key, None);
        assert_eq!(config.api_url, DEFAULT_API_URL);
        assert_eq!(config.cache_dir, tmp_path.join(".cache/memvid"));
        assert!(config.ticket_pubkey.is_none());
        assert_eq!(config.models_dir, tmp_path.join(".memvid/models"));
        assert!(!config.offline);
        assert_eq!(config.embedding_model, EmbeddingModelChoice::BgeSmall);
    }

    #[test]
    fn env_overrides_are_respected() {
        let _lock = env_lock();
        let _restore = EnvSnapshot::capture();
        let tmp = tempfile::tempdir().expect("tmpdir");
        let tmp_path = tmp.path().to_path_buf();
        isolate_env(&tmp_path);

        let signing = SigningKey::from_bytes(&[9u8; 32]);
        set("MEMVID_API_KEY", "abc123");
        set("MEMVID_API_URL", "https://staging.memvid.app");
        set("MEMVID_CACHE_DIR", "~/memvid-cache");
        set("MEMVID_MODELS_DIR", "~/models");
        set("MEMVID_OFFLINE", "true");
        set("MEMVID_EMBEDDING_MODEL", "nomic");
        set(
            "MEMVID_TICKET_PUBKEY",
            BASE64_STANDARD.encode(signing.verifying_key().as_bytes()),
        );

        let config = CliConfig::load().expect("load");
        assert_eq!(config.api_key.as_deref(), Some("abc123"));
        assert_eq!(config.api_url, "https://staging.memvid.app");
        assert_eq!(config.cache_dir, tmp_path.join("memvid-cache"));
        assert_eq!(
            config.ticket_pubkey.expect("pubkey").as_bytes(),
            signing.verifying_key().as_bytes()
        );
        assert_eq!(config.models_dir, tmp_path.join("models"));
        assert!(config.offline);
        assert_eq!(config.embedding_model, EmbeddingModelChoice::Nomic);
    }

    #[test]
    fn unknown_embedding_model_falls_back_to_default() {
        let _lock = env_lock();
        let _restore = EnvSnapshot::capture();
        let tmp = tempfile::tempdir().expect("tmpdir");
        isolate_env(tmp.path());
        set("MEMVID_EMBEDDING_MODEL", "not-a-model");

        let config = CliConfig::load().expect("load");
        assert_eq!(config.embedding_model, EmbeddingModelChoice::default());
    }

    #[test]
    fn rejects_empty_cache_dir() {
        let _lock = env_lock();
        let _restore = EnvSnapshot::capture();
        set("MEMVID_CACHE_DIR", " ");

        let err = CliConfig::load().expect_err("should fail");
        // `{:#}` renders the whole error chain, whatever context was added.
        assert!(format!("{err:#}").contains("cache directory"));
    }
}
425
426/// Initialize tracing/logging based on verbosity level
427pub fn init_tracing(verbosity: u8) -> Result<()> {
428    use std::io::IsTerminal;
429    use tracing_subscriber::{filter::Directive, fmt, EnvFilter};
430
431    let level = match verbosity {
432        0 => "warn",
433        1 => "info",
434        2 => "debug",
435        _ => "trace",
436    };
437
438    let mut env_filter =
439        EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(level));
440    for directive_str in ["llama_cpp=error", "llama_cpp_sys=error", "ggml=error"] {
441        if let Ok(directive) = directive_str.parse::<Directive>() {
442            env_filter = env_filter.add_directive(directive);
443        }
444    }
445
446    // Disable ANSI color codes when stderr is not a terminal (e.g., piped or
447    // combined with `2>&1`). This prevents control characters from polluting
448    // JSON output when combined with stdout.
449    let use_ansi = std::io::stderr().is_terminal();
450
451    fmt()
452        .with_env_filter(env_filter)
453        .with_writer(std::io::stderr)
454        .with_target(false)
455        .without_time()
456        .with_ansi(use_ansi)
457        .try_init()
458        .map_err(|err| anyhow!(err))?;
459    Ok(())
460}
461
462/// Resolve LLM context budget override from CLI or environment
463pub fn resolve_llm_context_budget_override(cli_value: Option<usize>) -> Result<Option<usize>> {
464    use anyhow::bail;
465
466    if let Some(value) = cli_value {
467        if value == 0 {
468            bail!("--llm-context-depth must be a positive integer");
469        }
470        return Ok(Some(value));
471    }
472
473    let raw_env = match env::var("MEMVID_LLM_CONTEXT_BUDGET") {
474        Ok(value) => value,
475        Err(_) => return Ok(None),
476    };
477
478    let trimmed = raw_env.trim();
479    if trimmed.is_empty() {
480        return Ok(None);
481    }
482
483    let digits: String = trimmed
484        .chars()
485        .filter(|ch| !ch.is_ascii_whitespace() && *ch != '_')
486        .collect();
487
488    if digits.is_empty() {
489        bail!("MEMVID_LLM_CONTEXT_BUDGET must be a positive integer value");
490    }
491
492    let value: usize = digits.parse().map_err(|err| {
493        anyhow!(
494            "MEMVID_LLM_CONTEXT_BUDGET value '{}' is not a valid number: {}",
495            trimmed,
496            err
497        )
498    })?;
499
500    if value == 0 {
501        bail!("MEMVID_LLM_CONTEXT_BUDGET must be a positive integer");
502    }
503
504    Ok(Some(value))
505}
506
507use crate::openai_embeddings::OpenAIEmbeddingProvider;
508
509/// Internal embedding backend - either local fastembed or OpenAI API
510#[derive(Clone)]
511enum EmbeddingBackend {
512    FastEmbed(std::sync::Arc<std::sync::Mutex<fastembed::TextEmbedding>>),
513    OpenAI(std::sync::Arc<OpenAIEmbeddingProvider>),
514}
515
516/// Embedding runtime wrapper supporting both local and OpenAI embeddings
517#[derive(Clone)]
518pub struct EmbeddingRuntime {
519    backend: EmbeddingBackend,
520    dimension: usize,
521}
522
523impl EmbeddingRuntime {
524    fn new_fastembed(model: fastembed::TextEmbedding, dimension: usize) -> Self {
525        Self {
526            backend: EmbeddingBackend::FastEmbed(std::sync::Arc::new(std::sync::Mutex::new(
527                model,
528            ))),
529            dimension,
530        }
531    }
532
533    fn new_openai(provider: OpenAIEmbeddingProvider, dimension: usize) -> Self {
534        Self {
535            backend: EmbeddingBackend::OpenAI(std::sync::Arc::new(provider)),
536            dimension,
537        }
538    }
539
540    /// Maximum characters for embedding text to avoid exceeding OpenAI's 8192 token limit.
541    /// Using ~3 chars/token estimate (conservative for dense content), 20K chars ≈ 6.6K tokens.
542    const MAX_EMBEDDING_TEXT_LEN: usize = 20_000;
543
544    /// Truncate text for embedding to fit within token limits.
545    fn truncate_for_embedding(text: &str) -> std::borrow::Cow<'_, str> {
546        if text.len() <= Self::MAX_EMBEDDING_TEXT_LEN {
547            std::borrow::Cow::Borrowed(text)
548        } else {
549            // Find the last valid UTF-8 char boundary within the limit
550            let truncated = &text[..Self::MAX_EMBEDDING_TEXT_LEN];
551            let end = truncated
552                .char_indices()
553                .rev()
554                .next()
555                .map(|(i, c)| i + c.len_utf8())
556                .unwrap_or(Self::MAX_EMBEDDING_TEXT_LEN);
557            tracing::info!(
558                "Truncated embedding text from {} to {} chars to fit OpenAI token limit",
559                text.len(),
560                end
561            );
562            std::borrow::Cow::Owned(text[..end].to_string())
563        }
564    }
565
566    pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
567        // Truncate text for OpenAI to avoid token limit errors
568        let text = match &self.backend {
569            EmbeddingBackend::OpenAI(_) => Self::truncate_for_embedding(text),
570            _ => std::borrow::Cow::Borrowed(text),
571        };
572
573        match &self.backend {
574            EmbeddingBackend::FastEmbed(model) => {
575                let mut guard = model
576                    .lock()
577                    .map_err(|_| anyhow!("fastembed runtime poisoned"))?;
578                let outputs = guard
579                    .embed(vec![text.into_owned()], None)
580                    .map_err(|err| anyhow!("failed to compute embedding with fastembed: {err}"))?;
581                outputs
582                    .into_iter()
583                    .next()
584                    .ok_or_else(|| anyhow!("fastembed returned no embedding output"))
585            }
586            EmbeddingBackend::OpenAI(provider) => {
587                use memvid_core::EmbeddingProvider;
588                provider
589                    .embed_text(&text)
590                    .map_err(|err| anyhow!("failed to compute embedding with OpenAI: {err}"))
591            }
592        }
593    }
594
595    pub fn dimension(&self) -> usize {
596        self.dimension
597    }
598}
599
600impl memvid_core::VecEmbedder for EmbeddingRuntime {
601    fn embed_query(&self, text: &str) -> memvid_core::Result<Vec<f32>> {
602        self.embed(text)
603            .map_err(|err| memvid_core::MemvidError::EmbeddingFailed {
604                reason: err.to_string().into_boxed_str(),
605            })
606    }
607
608    fn embedding_dimension(&self) -> usize {
609        self.dimension()
610    }
611}
612
613/// Ensure fastembed cache directory exists
614fn ensure_fastembed_cache(config: &CliConfig) -> Result<PathBuf> {
615    use std::fs;
616
617    let cache_dir = config.models_dir.clone();
618    fs::create_dir_all(&cache_dir)?;
619    Ok(cache_dir)
620}
621
622/// Get approximate model size in MB for user-friendly error messages
623fn model_size_mb(model: EmbeddingModelChoice) -> usize {
624    match model {
625        EmbeddingModelChoice::BgeSmall => 33,
626        EmbeddingModelChoice::BgeBase => 110,
627        EmbeddingModelChoice::Nomic => 137,
628        EmbeddingModelChoice::GteLarge => 327,
629        // OpenAI models don't require local download
630        EmbeddingModelChoice::OpenAILarge
631        | EmbeddingModelChoice::OpenAISmall
632        | EmbeddingModelChoice::OpenAIAda => 0,
633    }
634}
635
636/// Instantiate an embedding runtime with the configured model
637fn instantiate_embedding_runtime(config: &CliConfig) -> Result<EmbeddingRuntime> {
638    use tracing::info;
639
640    let embedding_model = config.embedding_model;
641
642    info!(
643        "Loading embedding model: {} ({}D)",
644        embedding_model.name(),
645        embedding_model.dimensions()
646    );
647
648    // Check if OpenAI model
649    if embedding_model.is_openai() {
650        return instantiate_openai_runtime(embedding_model);
651    }
652
653    // Local fastembed model
654    instantiate_fastembed_runtime(config, embedding_model)
655}
656
657/// Instantiate OpenAI embedding runtime
658fn instantiate_openai_runtime(embedding_model: EmbeddingModelChoice) -> Result<EmbeddingRuntime> {
659    use anyhow::bail;
660    use memvid_core::EmbeddingConfig;
661    use tracing::info;
662
663    let api_key = std::env::var("OPENAI_API_KEY")
664        .map_err(|_| anyhow!("OPENAI_API_KEY environment variable is required for OpenAI embeddings"))?;
665
666    if api_key.is_empty() {
667        bail!("OPENAI_API_KEY cannot be empty");
668    }
669
670    let config = match embedding_model {
671        EmbeddingModelChoice::OpenAILarge => EmbeddingConfig::openai_large(),
672        EmbeddingModelChoice::OpenAISmall => EmbeddingConfig::openai_small(),
673        EmbeddingModelChoice::OpenAIAda => EmbeddingConfig::openai_ada(),
674        _ => unreachable!("is_openai() should have been false"),
675    };
676
677    let provider = OpenAIEmbeddingProvider::new(api_key, config.clone())
678        .map_err(|err| anyhow!("failed to create OpenAI embedding provider: {err}"))?;
679
680    info!(
681        "OpenAI embedding provider ready: model={}, dimension={}",
682        config.model, config.dimension
683    );
684
685    Ok(EmbeddingRuntime::new_openai(provider, config.dimension))
686}
687
688/// Instantiate fastembed (local) embedding runtime
689fn instantiate_fastembed_runtime(
690    config: &CliConfig,
691    embedding_model: EmbeddingModelChoice,
692) -> Result<EmbeddingRuntime> {
693    use anyhow::bail;
694    use fastembed::{InitOptions, TextEmbedding};
695    use std::fs;
696
697    let cache_dir = ensure_fastembed_cache(config)?;
698
699    if config.offline {
700        let mut entries = fs::read_dir(&cache_dir)?;
701        if entries.next().is_none() {
702            bail!(
703                "semantic embeddings unavailable while offline; allow one connected run so fastembed can cache model weights"
704            );
705        }
706    }
707
708    let options = InitOptions::new(embedding_model.to_fastembed_model())
709        .with_cache_dir(cache_dir)
710        .with_show_download_progress(true);
711    let mut model = TextEmbedding::try_new(options).map_err(|err| {
712        // Provide platform-specific guidance for model download issues
713        let platform_hint = if cfg!(target_os = "windows") {
714            "\n\nWindows users: If model downloads fail, try:\n\
715            1. Run as Administrator\n\
716            2. Check your antivirus isn't blocking downloads\n\
717            3. Use OpenAI embeddings instead: set OPENAI_API_KEY and use --embedding-model openai"
718        } else if cfg!(target_os = "linux") {
719            "\n\nLinux users: If model downloads fail, try:\n\
720            1. Check disk space in ~/.memvid/models\n\
721            2. Ensure you have network access to huggingface.co\n\
722            3. Use OpenAI embeddings instead: export OPENAI_API_KEY=... and use --embedding-model openai"
723        } else {
724            "\n\nIf model downloads fail, try using OpenAI embeddings:\n\
725            export OPENAI_API_KEY=your-key && memvid ... --embedding-model openai"
726        };
727
728        anyhow!(
729            "Failed to initialize embedding model '{}': {err}\n\n\
730            This typically means the model couldn't be downloaded or loaded.\n\
731            Model size: ~{} MB{}\n\n\
732            See: https://docs.memvid.com/embedding-models",
733            embedding_model.name(),
734            model_size_mb(embedding_model),
735            platform_hint
736        )
737    })?;
738
739    let probe = model
740        .embed(vec!["memvid probe".to_string()], None)
741        .map_err(|err| anyhow!("failed to determine embedding dimension: {err}"))?;
742    let dimension = probe.first().map(|vec| vec.len()).unwrap_or(0);
743
744    if dimension == 0 {
745        bail!("fastembed reported zero-length embeddings");
746    }
747
748    // Verify dimension matches expected
749    if dimension != embedding_model.dimensions() {
750        tracing::warn!(
751            "Embedding dimension mismatch: expected {}, got {}",
752            embedding_model.dimensions(),
753            dimension
754        );
755    }
756
757    Ok(EmbeddingRuntime::new_fastembed(model, dimension))
758}
759
760/// Load embedding runtime (fails if unavailable)
761pub fn load_embedding_runtime(config: &CliConfig) -> Result<EmbeddingRuntime> {
762    use anyhow::bail;
763
764    match instantiate_embedding_runtime(config) {
765        Ok(runtime) => Ok(runtime),
766        Err(err) => {
767            if config.offline {
768                bail!(
769                    "semantic embeddings unavailable while offline; allow one connected run so fastembed can cache model weights ({err})"
770                );
771            }
772            Err(err)
773        }
774    }
775}
776
777/// Try to load embedding runtime (returns None if unavailable)
778pub fn try_load_embedding_runtime(config: &CliConfig) -> Option<EmbeddingRuntime> {
779    use tracing::warn;
780
781    match instantiate_embedding_runtime(config) {
782        Ok(runtime) => Some(runtime),
783        Err(err) => {
784            warn!("semantic embeddings unavailable: {err}");
785            None
786        }
787    }
788}
789
790/// Load embedding runtime with an optional model override.
791/// If `model_override` is provided, it will be used instead of the config's embedding_model.
792pub fn load_embedding_runtime_with_model(
793    config: &CliConfig,
794    model_override: Option<&str>,
795) -> Result<EmbeddingRuntime> {
796    use tracing::info;
797
798    let embedding_model = match model_override {
799        Some(model_str) => {
800            let parsed = model_str.parse::<EmbeddingModelChoice>()?;
801            info!("Using embedding model override: {} ({}D)", parsed.name(), parsed.dimensions());
802            parsed
803        }
804        None => config.embedding_model,
805    };
806
807    info!(
808        "Loading embedding model: {} ({}D)",
809        embedding_model.name(),
810        embedding_model.dimensions()
811    );
812
813    if embedding_model.is_openai() {
814        return instantiate_openai_runtime(embedding_model);
815    }
816
817    instantiate_fastembed_runtime(config, embedding_model)
818}
819
820/// Try to load embedding runtime with model override (returns None if unavailable)
821pub fn try_load_embedding_runtime_with_model(
822    config: &CliConfig,
823    model_override: Option<&str>,
824) -> Option<EmbeddingRuntime> {
825    use tracing::warn;
826
827    match load_embedding_runtime_with_model(config, model_override) {
828        Ok(runtime) => Some(runtime),
829        Err(err) => {
830            warn!("semantic embeddings unavailable: {err}");
831            None
832        }
833    }
834}
835
836/// Load embedding runtime by auto-detecting from MV2 vector dimension.
837///
838/// Priority:
839/// 1. Explicit model override (--query-embedding-model flag)
840/// 2. Auto-detect from MV2 file's stored dimension
841/// 3. Fall back to config default
842///
843/// This allows users to omit --query-embedding-model when querying files
844/// created with non-default embedding models (like OpenAI).
845pub fn load_embedding_runtime_for_mv2(
846    config: &CliConfig,
847    model_override: Option<&str>,
848    mv2_dimension: Option<u32>,
849) -> Result<EmbeddingRuntime> {
850    use tracing::info;
851
852    // Priority 1: Explicit override
853    if let Some(model_str) = model_override {
854        return load_embedding_runtime_with_model(config, Some(model_str));
855    }
856
857    // Priority 2: Auto-detect from MV2 dimension
858    if let Some(dim) = mv2_dimension {
859        if let Some(detected_model) = EmbeddingModelChoice::from_dimension(dim) {
860            info!(
861                "Auto-detected embedding model from MV2: {} ({}D)",
862                detected_model.name(),
863                dim
864            );
865
866            // For OpenAI models, check if API key is available
867            if detected_model.is_openai() {
868                if std::env::var("OPENAI_API_KEY").is_ok() {
869                    return load_embedding_runtime_with_model(config, Some(detected_model.name()));
870                } else {
871                    // OpenAI detected but no API key - provide helpful error
872                    return Err(anyhow!(
873                        "MV2 file uses OpenAI embeddings ({}D) but OPENAI_API_KEY is not set.\n\n\
874                        Options:\n\
875                        1. Set OPENAI_API_KEY environment variable\n\
876                        2. Use --query-embedding-model to specify a different model\n\
877                        3. Use lexical-only search with --mode lex\n\n\
878                        See: https://docs.memvid.com/embedding-models",
879                        dim
880                    ));
881                }
882            }
883
884            return load_embedding_runtime_with_model(config, Some(detected_model.name()));
885        }
886    }
887
888    // Priority 3: Fall back to config default
889    load_embedding_runtime(config)
890}
891
892/// Try to load embedding runtime for MV2 with auto-detection (returns None if unavailable)
893pub fn try_load_embedding_runtime_for_mv2(
894    config: &CliConfig,
895    model_override: Option<&str>,
896    mv2_dimension: Option<u32>,
897) -> Option<EmbeddingRuntime> {
898    use tracing::warn;
899
900    match load_embedding_runtime_for_mv2(config, model_override, mv2_dimension) {
901        Ok(runtime) => Some(runtime),
902        Err(err) => {
903            warn!("semantic embeddings unavailable: {err}");
904            None
905        }
906    }
907}