1use std::env;
7use std::path::PathBuf;
8use std::str::FromStr;
9use std::sync::atomic::{AtomicUsize, Ordering};
10
11use anyhow::{anyhow, Result};
12use ed25519_dalek::VerifyingKey;
13
// Fallback API endpoint used when MEMVID_API_URL is not set anywhere.
const DEFAULT_API_URL: &str = "https://memvid.com";
// Fallback cache location; the "~" is expanded by `expand_path` at load time.
const DEFAULT_CACHE_DIR: &str = "~/.cache/memvid";
16
/// User-selectable embedding model, covering both local fastembed models and
/// remote API-backed providers. Parsed from CLI/env strings via `FromStr`;
/// use `is_remote()` / `is_openai()` to branch on the provider kind.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum EmbeddingModelChoice {
    /// Local BGE-small (the default choice).
    #[default]
    BgeSmall,
    /// Local BGE-base.
    BgeBase,
    /// Local Nomic embed-text.
    Nomic,
    /// Local GTE-large.
    GteLarge,
    /// Remote OpenAI text-embedding-3-large.
    OpenAILarge,
    /// Remote OpenAI text-embedding-3-small.
    OpenAISmall,
    /// Remote OpenAI text-embedding-ada-002.
    OpenAIAda,
    /// Remote NVIDIA NV-Embed.
    Nvidia,
    /// Remote Google Gemini embeddings.
    Gemini,
    /// Remote Mistral embeddings.
    Mistral,
}
42
43impl EmbeddingModelChoice {
44 pub fn is_openai(&self) -> bool {
46 matches!(
47 self,
48 EmbeddingModelChoice::OpenAILarge
49 | EmbeddingModelChoice::OpenAISmall
50 | EmbeddingModelChoice::OpenAIAda
51 )
52 }
53
54 pub fn is_remote(&self) -> bool {
56 matches!(
57 self,
58 EmbeddingModelChoice::OpenAILarge
59 | EmbeddingModelChoice::OpenAISmall
60 | EmbeddingModelChoice::OpenAIAda
61 | EmbeddingModelChoice::Nvidia
62 | EmbeddingModelChoice::Gemini
63 | EmbeddingModelChoice::Mistral
64 )
65 }
66
67 pub fn to_fastembed_model(&self) -> fastembed::EmbeddingModel {
72 match self {
73 EmbeddingModelChoice::BgeSmall => fastembed::EmbeddingModel::BGESmallENV15,
74 EmbeddingModelChoice::BgeBase => fastembed::EmbeddingModel::BGEBaseENV15,
75 EmbeddingModelChoice::Nomic => fastembed::EmbeddingModel::NomicEmbedTextV15,
76 EmbeddingModelChoice::GteLarge => fastembed::EmbeddingModel::GTELargeENV15,
77 EmbeddingModelChoice::OpenAILarge
78 | EmbeddingModelChoice::OpenAISmall
79 | EmbeddingModelChoice::OpenAIAda => {
80 panic!("OpenAI models don't use fastembed. Check is_remote() first.")
81 }
82 EmbeddingModelChoice::Nvidia => {
83 panic!("NVIDIA embeddings don't use fastembed. Check is_remote() first.")
84 }
85 EmbeddingModelChoice::Gemini => {
86 panic!("Gemini embeddings don't use fastembed. Check is_remote() first.")
87 }
88 EmbeddingModelChoice::Mistral => {
89 panic!("Mistral embeddings don't use fastembed. Check is_remote() first.")
90 }
91 }
92 }
93
94 pub fn name(&self) -> &'static str {
96 match self {
97 EmbeddingModelChoice::BgeSmall => "bge-small",
98 EmbeddingModelChoice::BgeBase => "bge-base",
99 EmbeddingModelChoice::Nomic => "nomic",
100 EmbeddingModelChoice::GteLarge => "gte-large",
101 EmbeddingModelChoice::OpenAILarge => "openai-large",
102 EmbeddingModelChoice::OpenAISmall => "openai-small",
103 EmbeddingModelChoice::OpenAIAda => "openai-ada",
104 EmbeddingModelChoice::Nvidia => "nvidia",
105 EmbeddingModelChoice::Gemini => "gemini",
106 EmbeddingModelChoice::Mistral => "mistral",
107 }
108 }
109
110 pub fn canonical_model_id(&self) -> &'static str {
116 match self {
117 EmbeddingModelChoice::BgeSmall => "BAAI/bge-small-en-v1.5",
118 EmbeddingModelChoice::BgeBase => "BAAI/bge-base-en-v1.5",
119 EmbeddingModelChoice::Nomic => "nomic-embed-text-v1.5",
120 EmbeddingModelChoice::GteLarge => "thenlper/gte-large",
121 EmbeddingModelChoice::OpenAILarge => "text-embedding-3-large",
122 EmbeddingModelChoice::OpenAISmall => "text-embedding-3-small",
123 EmbeddingModelChoice::OpenAIAda => "text-embedding-ada-002",
124 EmbeddingModelChoice::Nvidia => "nvidia/nv-embed-v1",
125 EmbeddingModelChoice::Gemini => "text-embedding-004",
126 EmbeddingModelChoice::Mistral => "mistral-embed",
127 }
128 }
129
130 pub fn dimensions(&self) -> usize {
132 match self {
133 EmbeddingModelChoice::BgeSmall => 384,
134 EmbeddingModelChoice::BgeBase => 768,
135 EmbeddingModelChoice::Nomic => 768,
136 EmbeddingModelChoice::GteLarge => 1024,
137 EmbeddingModelChoice::OpenAILarge => 3072,
138 EmbeddingModelChoice::OpenAISmall => 1536,
139 EmbeddingModelChoice::OpenAIAda => 1536,
140 EmbeddingModelChoice::Nvidia => 0,
142 EmbeddingModelChoice::Gemini => 768,
143 EmbeddingModelChoice::Mistral => 1024,
144 }
145 }
146}
147
148impl FromStr for EmbeddingModelChoice {
149 type Err = anyhow::Error;
150
151 fn from_str(s: &str) -> Result<Self> {
152 let lowered = s.trim().to_ascii_lowercase();
153 match lowered.as_str() {
154 "bge-small" | "bge_small" | "bgesmall" | "small" => Ok(EmbeddingModelChoice::BgeSmall),
155 "baai/bge-small-en-v1.5" => Ok(EmbeddingModelChoice::BgeSmall),
156 "bge-base" | "bge_base" | "bgebase" | "base" => Ok(EmbeddingModelChoice::BgeBase),
157 "baai/bge-base-en-v1.5" => Ok(EmbeddingModelChoice::BgeBase),
158 "nomic" | "nomic-embed" | "nomic_embed" => Ok(EmbeddingModelChoice::Nomic),
159 "nomic-embed-text-v1.5" => Ok(EmbeddingModelChoice::Nomic),
160 "gte-large" | "gte_large" | "gtelarge" | "gte" => Ok(EmbeddingModelChoice::GteLarge),
161 "thenlper/gte-large" => Ok(EmbeddingModelChoice::GteLarge),
162 "openai" | "openai-large" | "openai_large" | "text-embedding-3-large" => {
164 Ok(EmbeddingModelChoice::OpenAILarge)
165 }
166 "openai-small" | "openai_small" | "text-embedding-3-small" => {
167 Ok(EmbeddingModelChoice::OpenAISmall)
168 }
169 "openai-ada" | "openai_ada" | "text-embedding-ada-002" | "ada" => {
170 Ok(EmbeddingModelChoice::OpenAIAda)
171 }
172 "nvidia" | "nv" | "nv-embed-v1" | "nvidia/nv-embed-v1" => Ok(EmbeddingModelChoice::Nvidia),
173 _ if lowered.starts_with("nvidia/") || lowered.starts_with("nvidia:") || lowered.starts_with("nv:") => {
174 Ok(EmbeddingModelChoice::Nvidia)
175 }
176 "gemini" | "gemini-embed" | "text-embedding-004" | "gemini-embedding-001" => {
178 Ok(EmbeddingModelChoice::Gemini)
179 }
180 _ if lowered.starts_with("gemini/") || lowered.starts_with("gemini:") || lowered.starts_with("google:") => {
181 Ok(EmbeddingModelChoice::Gemini)
182 }
183 "mistral" | "mistral-embed" => Ok(EmbeddingModelChoice::Mistral),
185 _ if lowered.starts_with("mistral/") || lowered.starts_with("mistral:") => {
186 Ok(EmbeddingModelChoice::Mistral)
187 }
188 _ => Err(anyhow!(
189 "unknown embedding model '{}'. Valid options: bge-small, bge-base, nomic, gte-large, openai, openai-small, openai-ada, nvidia, gemini, mistral",
190 s
191 )),
192 }
193 }
194}
195
196impl std::fmt::Display for EmbeddingModelChoice {
197 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
198 write!(f, "{}", self.name())
199 }
200}
201
202impl EmbeddingModelChoice {
203 pub fn from_dimension(dim: u32) -> Option<Self> {
215 match dim {
216 384 => Some(EmbeddingModelChoice::BgeSmall),
217 768 => Some(EmbeddingModelChoice::BgeBase), 1024 => Some(EmbeddingModelChoice::GteLarge),
219 1536 => Some(EmbeddingModelChoice::OpenAISmall), 3072 => Some(EmbeddingModelChoice::OpenAILarge),
221 0 => None, _ => {
223 tracing::warn!("Unknown embedding dimension {}, using default model", dim);
224 None
225 }
226 }
227 }
228}
229
/// Resolved runtime configuration for the CLI, assembled in `load()` from
/// environment variables layered over the persisted config file (env wins).
#[derive(Debug, Clone)]
pub struct CliConfig {
    /// API key, if configured (MEMVID_API_KEY or persisted config).
    pub api_key: Option<String>,
    /// Base API URL; defaults to DEFAULT_API_URL.
    pub api_url: String,
    /// Default memory to operate on, if configured.
    pub memory_id: Option<String>,
    /// General cache directory ("~" already expanded, absolute).
    pub cache_dir: PathBuf,
    /// Key used to verify tickets; `Some` after a successful `load()`.
    pub ticket_pubkey: Option<VerifyingKey>,
    /// Directory where fastembed model weights are cached.
    pub models_dir: PathBuf,
    /// When true, any network-dependent feature is refused.
    pub offline: bool,
    /// Selected embedding model (MEMVID_EMBEDDING_MODEL or the default).
    pub embedding_model: EmbeddingModelChoice,
}
244
// Manual equality: `ticket_pubkey` is left out of the comparison.
// NOTE(review): presumably intentional (two configs differing only in the
// ticket verification key compare equal) — confirm this is the desired
// semantics rather than an oversight.
impl PartialEq for CliConfig {
    fn eq(&self, other: &Self) -> bool {
        self.api_key == other.api_key
            && self.api_url == other.api_url
            && self.memory_id == other.memory_id
            && self.cache_dir == other.cache_dir
            && self.models_dir == other.models_dir
            && self.offline == other.offline
            && self.embedding_model == other.embedding_model
    }
}

impl Eq for CliConfig {}
258
impl CliConfig {
    /// Assembles the effective CLI configuration.
    ///
    /// Per-field precedence: environment variable first, then the persisted
    /// config file, then a built-in default. Whitespace-only env values are
    /// treated as unset for api_key / memory_id / ticket pubkey.
    ///
    /// # Errors
    /// Fails when a directory env var is blank or cannot be expanded, or when
    /// the ticket public key fails base64/ed25519 parsing.
    pub fn load() -> Result<Self> {
        // Best effort: a missing or unreadable persisted config simply
        // removes that fallback layer.
        let persistent_config = crate::commands::config::PersistentConfig::load().ok();

        let api_key = env::var("MEMVID_API_KEY")
            .ok()
            .and_then(|value| {
                let trimmed = value.trim().to_string();
                (!trimmed.is_empty()).then_some(trimmed)
            })
            .or_else(|| persistent_config.as_ref().and_then(|c| c.api_key.clone()));

        // NOTE(review): unlike api_key, an empty/whitespace MEMVID_API_URL is
        // honored verbatim here — confirm that is intended.
        let api_url = env::var("MEMVID_API_URL")
            .ok()
            .or_else(|| persistent_config.as_ref().and_then(|c| c.api_url.clone()))
            .unwrap_or_else(|| DEFAULT_API_URL.to_string());

        let memory_id = env::var("MEMVID_MEMORY_ID")
            .ok()
            .and_then(|value| {
                let trimmed = value.trim().to_string();
                (!trimmed.is_empty()).then_some(trimmed)
            })
            .or_else(|| persistent_config.as_ref().and_then(|c| c.default_memory_id()));

        // Paths support "~" expansion and are anchored to cwd when relative.
        let cache_dir_raw =
            env::var("MEMVID_CACHE_DIR").unwrap_or_else(|_| DEFAULT_CACHE_DIR.to_string());
        let cache_dir = expand_path(&cache_dir_raw)?;

        let models_dir_raw =
            env::var("MEMVID_MODELS_DIR").unwrap_or_else(|_| "~/.memvid/models".to_string());
        let models_dir = expand_path(&models_dir_raw)?;

        // Built-in ticket verification key, overridable via env.
        const DEFAULT_TICKET_PUBKEY: &str = "8wP1J2H+Tlx3PM3eT0lN2wDvoYrvl1DREKGKVb/V2cw=";

        let ticket_pubkey_str = env::var("MEMVID_TICKET_PUBKEY")
            .ok()
            .and_then(|value| {
                let trimmed = value.trim();
                if trimmed.is_empty() {
                    None
                } else {
                    Some(trimmed.to_string())
                }
            })
            .unwrap_or_else(|| DEFAULT_TICKET_PUBKEY.to_string());

        // Always Some on success; the field is Option to match the struct.
        let ticket_pubkey = Some(memvid_core::parse_ed25519_public_key_base64(&ticket_pubkey_str)?);

        // Only "1"/"true"/"yes" (case-insensitive) enable offline mode.
        let offline = env::var("MEMVID_OFFLINE")
            .ok()
            .map(|value| match value.trim().to_ascii_lowercase().as_str() {
                "1" | "true" | "yes" => true,
                _ => false,
            })
            .unwrap_or(false);

        // Unparseable model names silently fall back to the default model.
        let embedding_model = env::var("MEMVID_EMBEDDING_MODEL")
            .ok()
            .and_then(|value| {
                let trimmed = value.trim();
                if trimmed.is_empty() {
                    None
                } else {
                    EmbeddingModelChoice::from_str(trimmed).ok()
                }
            })
            .unwrap_or_default();

        Ok(Self {
            api_key,
            api_url,
            memory_id,
            cache_dir,
            ticket_pubkey,
            models_dir,
            offline,
            embedding_model,
        })
    }

    /// Returns a copy of this config with a different embedding model.
    pub fn with_embedding_model(&self, model: EmbeddingModelChoice) -> Self {
        Self {
            embedding_model: model,
            ..self.clone()
        }
    }
}
355
356fn expand_path(value: &str) -> Result<PathBuf> {
357 if value.trim().is_empty() {
358 return Err(anyhow!("cache directory cannot be empty"));
359 }
360
361 let expanded = if let Some(stripped) = value.strip_prefix("~/") {
362 home_dir()?.join(stripped)
363 } else if let Some(stripped) = value.strip_prefix("~\\") {
364 home_dir()?.join(stripped)
366 } else if value == "~" {
367 home_dir()?
368 } else {
369 PathBuf::from(value)
370 };
371
372 if expanded.is_absolute() {
373 Ok(expanded)
374 } else {
375 Ok(env::current_dir()?.join(expanded))
376 }
377}
378
/// Resolves the user's home directory from environment variables.
///
/// Checks a non-empty `HOME` first on all platforms; on Windows it then falls
/// back to `USERPROFILE` and finally `HOMEDRIVE` + `HOMEPATH` concatenated.
///
/// # Errors
/// Fails when none of the relevant variables yield a non-empty value.
fn home_dir() -> Result<PathBuf> {
    if let Some(path) = env::var_os("HOME") {
        if !path.is_empty() {
            return Ok(PathBuf::from(path));
        }
    }

    #[cfg(windows)]
    {
        if let Some(path) = env::var_os("USERPROFILE") {
            if !path.is_empty() {
                return Ok(PathBuf::from(path));
            }
        }
        // Legacy fallback: drive letter and path live in separate variables.
        if let (Some(drive), Some(path)) = (env::var_os("HOMEDRIVE"), env::var_os("HOMEPATH")) {
            if !drive.is_empty() && !path.is_empty() {
                return Ok(PathBuf::from(format!(
                    "{}{}",
                    drive.to_string_lossy(),
                    path.to_string_lossy()
                )));
            }
        }
    }

    Err(anyhow!("unable to resolve home directory"))
}
406
#[cfg(test)]
mod tests {
    use super::*;
    use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
    use base64::Engine;
    use ed25519_dalek::SigningKey;
    use std::sync::{Mutex, OnceLock};

    /// Every config env var `CliConfig::load` reads. Kept in one list so the
    /// tests save/clear/restore all of them together; previously
    /// MEMVID_MEMORY_ID and MEMVID_EMBEDDING_MODEL were omitted, letting
    /// ambient values leak into (and out of) individual tests.
    const CONFIG_ENV_VARS: [&str; 8] = [
        "MEMVID_API_KEY",
        "MEMVID_API_URL",
        "MEMVID_MEMORY_ID",
        "MEMVID_CACHE_DIR",
        "MEMVID_TICKET_PUBKEY",
        "MEMVID_MODELS_DIR",
        "MEMVID_OFFLINE",
        "MEMVID_EMBEDDING_MODEL",
    ];

    /// Serializes env-mutating tests; process environment is global state.
    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(())).lock().unwrap()
    }

    /// Restores a variable to a previously captured value (or unsets it).
    fn set_or_unset(var: &str, value: Option<String>) {
        match value {
            Some(v) => unsafe { env::set_var(var, v) },
            None => unsafe { env::remove_var(var) },
        }
    }

    /// Captures the current value of each config var and clears it.
    fn save_and_clear_config_env() -> Vec<(&'static str, Option<String>)> {
        CONFIG_ENV_VARS
            .into_iter()
            .map(|var| {
                let prev = env::var(var).ok();
                unsafe { env::remove_var(var) };
                (var, prev)
            })
            .collect()
    }

    #[test]
    fn defaults_expand_using_home_directory() {
        let _guard = env_lock();

        let previous_home = env::var("HOME").ok();
        #[cfg(windows)]
        let previous_userprofile = env::var("USERPROFILE").ok();
        let previous_env = save_and_clear_config_env();

        let tmp = tempfile::tempdir().expect("tmpdir");
        let tmp_path = tmp.path().to_path_buf();
        unsafe { env::set_var("HOME", &tmp_path) };
        #[cfg(windows)]
        unsafe {
            env::set_var("USERPROFILE", &tmp_path)
        };

        let config = CliConfig::load().expect("load");
        assert_eq!(config.api_key, None);
        assert_eq!(config.api_url, "https://memvid.com");
        assert_eq!(config.cache_dir, tmp_path.join(".cache/memvid"));
        assert!(config.ticket_pubkey.is_some());
        assert_eq!(config.models_dir, tmp_path.join(".memvid/models"));
        assert!(!config.offline);
        // With MEMVID_EMBEDDING_MODEL cleared, the enum default applies.
        assert_eq!(config.embedding_model, EmbeddingModelChoice::BgeSmall);

        for (var, value) in previous_env {
            set_or_unset(var, value);
        }
        set_or_unset("HOME", previous_home);
        #[cfg(windows)]
        {
            set_or_unset("USERPROFILE", previous_userprofile);
        }
    }

    #[test]
    fn env_overrides_are_respected() {
        let _guard = env_lock();

        // Previously HOME/USERPROFILE were overwritten without being
        // restored, leaking the tempdir path to later tests.
        let previous_home = env::var("HOME").ok();
        #[cfg(windows)]
        let previous_userprofile = env::var("USERPROFILE").ok();
        let previous_env = save_and_clear_config_env();

        unsafe { env::set_var("MEMVID_API_KEY", "abc123") };
        unsafe { env::set_var("MEMVID_API_URL", "https://staging.memvid.app") };
        unsafe { env::set_var("MEMVID_CACHE_DIR", "~/memvid-cache") };
        unsafe { env::set_var("MEMVID_MODELS_DIR", "~/models") };
        unsafe { env::set_var("MEMVID_OFFLINE", "true") };
        let signing = SigningKey::from_bytes(&[9u8; 32]);
        let encoded = BASE64_STANDARD.encode(signing.verifying_key().as_bytes());
        unsafe { env::set_var("MEMVID_TICKET_PUBKEY", encoded) };

        let tmp = tempfile::tempdir().expect("tmpdir");
        let tmp_path = tmp.path().to_path_buf();
        unsafe { env::set_var("HOME", &tmp_path) };
        #[cfg(windows)]
        unsafe {
            env::set_var("USERPROFILE", &tmp_path)
        };

        let config = CliConfig::load().expect("load");
        assert_eq!(config.api_key.as_deref(), Some("abc123"));
        assert_eq!(config.api_url, "https://staging.memvid.app");
        assert_eq!(config.cache_dir, tmp_path.join("memvid-cache"));
        assert_eq!(
            config.ticket_pubkey.expect("pubkey").as_bytes(),
            signing.verifying_key().as_bytes()
        );
        assert_eq!(config.models_dir, tmp_path.join("models"));
        assert!(config.offline);

        for (var, value) in previous_env {
            set_or_unset(var, value);
        }
        set_or_unset("HOME", previous_home);
        #[cfg(windows)]
        {
            set_or_unset("USERPROFILE", previous_userprofile);
        }
    }

    #[test]
    fn rejects_empty_cache_dir() {
        let _guard = env_lock();

        let previous = env::var("MEMVID_CACHE_DIR").ok();
        unsafe { env::set_var("MEMVID_CACHE_DIR", " ") };
        let err = CliConfig::load().expect_err("should fail");
        assert!(err.to_string().contains("cache directory"));
        set_or_unset("MEMVID_CACHE_DIR", previous);
    }
}
530
531pub fn init_tracing(verbosity: u8) -> Result<()> {
533 use std::io::IsTerminal;
534 use tracing_subscriber::{filter::Directive, fmt, EnvFilter};
535
536 let level = match verbosity {
537 0 => "warn",
538 1 => "info",
539 2 => "debug",
540 _ => "trace",
541 };
542
543 let mut env_filter =
544 EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(level));
545 for directive_str in ["llama_cpp=error", "llama_cpp_sys=error", "ggml=error"] {
546 if let Ok(directive) = directive_str.parse::<Directive>() {
547 env_filter = env_filter.add_directive(directive);
548 }
549 }
550
551 let use_ansi = std::io::stderr().is_terminal();
555
556 fmt()
557 .with_env_filter(env_filter)
558 .with_writer(std::io::stderr)
559 .with_target(false)
560 .without_time()
561 .with_ansi(use_ansi)
562 .try_init()
563 .map_err(|err| anyhow!(err))?;
564 Ok(())
565}
566
567pub fn resolve_llm_context_budget_override(cli_value: Option<usize>) -> Result<Option<usize>> {
569 use anyhow::bail;
570
571 if let Some(value) = cli_value {
572 if value == 0 {
573 bail!("--llm-context-depth must be a positive integer");
574 }
575 return Ok(Some(value));
576 }
577
578 let raw_env = match env::var("MEMVID_LLM_CONTEXT_BUDGET") {
579 Ok(value) => value,
580 Err(_) => return Ok(None),
581 };
582
583 let trimmed = raw_env.trim();
584 if trimmed.is_empty() {
585 return Ok(None);
586 }
587
588 let digits: String = trimmed
589 .chars()
590 .filter(|ch| !ch.is_ascii_whitespace() && *ch != '_')
591 .collect();
592
593 if digits.is_empty() {
594 bail!("MEMVID_LLM_CONTEXT_BUDGET must be a positive integer value");
595 }
596
597 let value: usize = digits.parse().map_err(|err| {
598 anyhow!(
599 "MEMVID_LLM_CONTEXT_BUDGET value '{}' is not a valid number: {}",
600 trimmed,
601 err
602 )
603 })?;
604
605 if value == 0 {
606 bail!("MEMVID_LLM_CONTEXT_BUDGET must be a positive integer");
607 }
608
609 Ok(Some(value))
610}
611
612use crate::gemini_embeddings::GeminiEmbeddingProvider;
613use crate::mistral_embeddings::MistralEmbeddingProvider;
614use crate::nvidia_embeddings::NvidiaEmbeddingProvider;
615use crate::openai_embeddings::OpenAIEmbeddingProvider;
616
/// The concrete engine behind an `EmbeddingRuntime`. Every variant is
/// reference-counted so the runtime clones cheaply; fastembed is additionally
/// wrapped in a `Mutex` because embedding goes through a mutable handle.
#[derive(Clone)]
enum EmbeddingBackend {
    /// Locally-loaded fastembed model.
    FastEmbed(std::sync::Arc<std::sync::Mutex<fastembed::TextEmbedding>>),
    /// Remote OpenAI API provider.
    OpenAI(std::sync::Arc<OpenAIEmbeddingProvider>),
    /// Remote NVIDIA API provider (distinct query/passage embedding calls).
    Nvidia(std::sync::Arc<NvidiaEmbeddingProvider>),
    /// Remote Gemini API provider.
    Gemini(std::sync::Arc<GeminiEmbeddingProvider>),
    /// Remote Mistral API provider.
    Mistral(std::sync::Arc<MistralEmbeddingProvider>),
}
626
/// Cheaply-cloneable handle over an embedding backend.
///
/// The vector width is tracked in a shared atomic: it starts at the
/// configured value (0 when unknown, e.g. NVIDIA) and is filled in and then
/// validated against observed provider outputs.
#[derive(Clone)]
pub struct EmbeddingRuntime {
    backend: EmbeddingBackend,
    model: EmbeddingModelChoice,
    // Shared across clones so a width learned by one handle is seen by all.
    dimension: std::sync::Arc<AtomicUsize>,
}
634
635impl EmbeddingRuntime {
636 fn new_fastembed(
637 backend: fastembed::TextEmbedding,
638 model: EmbeddingModelChoice,
639 dimension: usize,
640 ) -> Self {
641 Self {
642 backend: EmbeddingBackend::FastEmbed(std::sync::Arc::new(std::sync::Mutex::new(
643 backend,
644 ))),
645 model,
646 dimension: std::sync::Arc::new(AtomicUsize::new(dimension)),
647 }
648 }
649
650 fn new_openai(
651 provider: OpenAIEmbeddingProvider,
652 model: EmbeddingModelChoice,
653 dimension: usize,
654 ) -> Self {
655 Self {
656 backend: EmbeddingBackend::OpenAI(std::sync::Arc::new(provider)),
657 model,
658 dimension: std::sync::Arc::new(AtomicUsize::new(dimension)),
659 }
660 }
661
662 fn new_nvidia(provider: NvidiaEmbeddingProvider, model: EmbeddingModelChoice) -> Self {
663 Self {
664 backend: EmbeddingBackend::Nvidia(std::sync::Arc::new(provider)),
665 model,
666 dimension: std::sync::Arc::new(AtomicUsize::new(0)),
667 }
668 }
669
670 fn new_gemini(
671 provider: GeminiEmbeddingProvider,
672 model: EmbeddingModelChoice,
673 dimension: usize,
674 ) -> Self {
675 Self {
676 backend: EmbeddingBackend::Gemini(std::sync::Arc::new(provider)),
677 model,
678 dimension: std::sync::Arc::new(AtomicUsize::new(dimension)),
679 }
680 }
681
682 fn new_mistral(
683 provider: MistralEmbeddingProvider,
684 model: EmbeddingModelChoice,
685 dimension: usize,
686 ) -> Self {
687 Self {
688 backend: EmbeddingBackend::Mistral(std::sync::Arc::new(provider)),
689 model,
690 dimension: std::sync::Arc::new(AtomicUsize::new(dimension)),
691 }
692 }
693
694 const MAX_OPENAI_EMBEDDING_TEXT_LEN: usize = 20_000;
695 const MAX_NVIDIA_EMBEDDING_TEXT_LEN: usize = 12_000;
697
698 const MAX_GEMINI_EMBEDDING_TEXT_LEN: usize = 20_000;
700 const MAX_MISTRAL_EMBEDDING_TEXT_LEN: usize = 20_000;
702
703 fn max_remote_embedding_chars(&self) -> usize {
704 match &self.backend {
705 EmbeddingBackend::OpenAI(_) => Self::MAX_OPENAI_EMBEDDING_TEXT_LEN,
706 EmbeddingBackend::Nvidia(_) => Self::MAX_NVIDIA_EMBEDDING_TEXT_LEN,
707 EmbeddingBackend::Gemini(_) => Self::MAX_GEMINI_EMBEDDING_TEXT_LEN,
708 EmbeddingBackend::Mistral(_) => Self::MAX_MISTRAL_EMBEDDING_TEXT_LEN,
709 EmbeddingBackend::FastEmbed(_) => usize::MAX,
710 }
711 }
712
713 fn truncate_for_embedding<'a>(
715 text: &'a str,
716 max_chars: usize,
717 ) -> std::borrow::Cow<'a, str> {
718 if text.len() <= max_chars {
719 std::borrow::Cow::Borrowed(text)
720 } else {
721 let truncated = &text[..max_chars];
723 let end = truncated
724 .char_indices()
725 .rev()
726 .next()
727 .map(|(i, c)| i + c.len_utf8())
728 .unwrap_or(max_chars);
729 tracing::info!("Truncated embedding text from {} to {} chars", text.len(), end);
730 std::borrow::Cow::Owned(text[..end].to_string())
731 }
732 }
733
734 fn note_dimension(&self, observed: usize) -> Result<()> {
735 if observed == 0 {
736 return Err(anyhow!("embedding provider returned zero-length embedding"));
737 }
738
739 let current = self.dimension.load(Ordering::Relaxed);
740 if current == 0 {
741 self.dimension.store(observed, Ordering::Relaxed);
742 return Ok(());
743 }
744
745 if current != observed {
746 return Err(anyhow!(
747 "embedding provider returned {observed}D vectors but runtime expects {current}D"
748 ));
749 }
750
751 Ok(())
752 }
753
754 fn truncate_if_remote<'a>(&self, text: &'a str) -> std::borrow::Cow<'a, str> {
755 match &self.backend {
756 EmbeddingBackend::OpenAI(_)
757 | EmbeddingBackend::Nvidia(_)
758 | EmbeddingBackend::Gemini(_)
759 | EmbeddingBackend::Mistral(_) => {
760 Self::truncate_for_embedding(text, self.max_remote_embedding_chars())
761 }
762 EmbeddingBackend::FastEmbed(_) => std::borrow::Cow::Borrowed(text),
763 }
764 }
765
766 pub fn embed_passage(&self, text: &str) -> Result<Vec<f32>> {
767 let text = self.truncate_if_remote(text);
768 let embedding = match &self.backend {
769 EmbeddingBackend::FastEmbed(model) => {
770 let mut guard = model
771 .lock()
772 .map_err(|_| anyhow!("fastembed runtime poisoned"))?;
773 let outputs = guard
774 .embed(vec![text.into_owned()], None)
775 .map_err(|err| anyhow!("failed to compute embedding with fastembed: {err}"))?;
776 outputs
777 .into_iter()
778 .next()
779 .ok_or_else(|| anyhow!("fastembed returned no embedding output"))?
780 }
781 EmbeddingBackend::OpenAI(provider) => {
782 use memvid_core::EmbeddingProvider;
783 provider
784 .embed_text(&text)
785 .map_err(|err| anyhow!("failed to compute embedding with OpenAI: {err}"))?
786 }
787 EmbeddingBackend::Nvidia(provider) => provider
788 .embed_passage(&text)
789 .map_err(|err| anyhow!("failed to compute embedding with NVIDIA: {err}"))?,
790 EmbeddingBackend::Gemini(provider) => provider
791 .embed_text(&text)
792 .map_err(|err| anyhow!("failed to compute embedding with Gemini: {err}"))?,
793 EmbeddingBackend::Mistral(provider) => provider
794 .embed_text(&text)
795 .map_err(|err| anyhow!("failed to compute embedding with Mistral: {err}"))?,
796 };
797
798 self.note_dimension(embedding.len())?;
799 Ok(embedding)
800 }
801
802 pub fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
803 let text = self.truncate_if_remote(text);
804 match &self.backend {
805 EmbeddingBackend::Nvidia(provider) => {
806 let embedding = provider
807 .embed_query(&text)
808 .map_err(|err| anyhow!("failed to compute embedding with NVIDIA: {err}"))?;
809 self.note_dimension(embedding.len())?;
810 Ok(embedding)
811 }
812 _ => self.embed_passage(&text),
813 }
814 }
815
816 pub fn embed_batch_passages(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
817 if texts.is_empty() {
818 return Ok(Vec::new());
819 }
820
821 let truncated: Vec<std::borrow::Cow<'_, str>> =
822 texts.iter().map(|t| self.truncate_if_remote(t)).collect();
823 let truncated_refs: Vec<&str> = truncated.iter().map(|c| c.as_ref()).collect();
824
825 let embeddings = match &self.backend {
826 EmbeddingBackend::FastEmbed(model) => {
827 let mut guard = model
828 .lock()
829 .map_err(|_| anyhow!("fastembed runtime poisoned"))?;
830 guard
831 .embed(
832 truncated_refs
833 .iter()
834 .map(|s| (*s).to_string())
835 .collect::<Vec<String>>(),
836 None,
837 )
838 .map_err(|err| anyhow!("failed to compute embeddings with fastembed: {err}"))?
839 }
840 EmbeddingBackend::OpenAI(provider) => {
841 use memvid_core::EmbeddingProvider;
842 provider
843 .embed_batch(&truncated_refs)
844 .map_err(|err| anyhow!("failed to compute embeddings with OpenAI: {err}"))?
845 }
846 EmbeddingBackend::Nvidia(provider) => provider
847 .embed_passages(&truncated_refs)
848 .map_err(|err| anyhow!("failed to compute embeddings with NVIDIA: {err}"))?,
849 EmbeddingBackend::Gemini(provider) => provider
850 .embed_batch(&truncated_refs)
851 .map_err(|err| anyhow!("failed to compute embeddings with Gemini: {err}"))?,
852 EmbeddingBackend::Mistral(provider) => provider
853 .embed_batch(&truncated_refs)
854 .map_err(|err| anyhow!("failed to compute embeddings with Mistral: {err}"))?,
855 };
856
857 if let Some(first) = embeddings.first() {
858 self.note_dimension(first.len())?;
859 }
860 if let Some(expected) = embeddings.first().map(|e| e.len()) {
861 if embeddings.iter().any(|e| e.len() != expected) {
862 return Err(anyhow!("embedding provider returned mixed vector dimensions"));
863 }
864 }
865
866 Ok(embeddings)
867 }
868
869 pub fn embed_batch_queries(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
870 if texts.is_empty() {
871 return Ok(Vec::new());
872 }
873
874 let truncated: Vec<std::borrow::Cow<'_, str>> =
875 texts.iter().map(|t| self.truncate_if_remote(t)).collect();
876 let truncated_refs: Vec<&str> = truncated.iter().map(|c| c.as_ref()).collect();
877
878 match &self.backend {
879 EmbeddingBackend::Nvidia(provider) => {
880 let embeddings = provider
881 .embed_queries(&truncated_refs)
882 .map_err(|err| anyhow!("failed to compute embeddings with NVIDIA: {err}"))?;
883
884 if let Some(first) = embeddings.first() {
885 self.note_dimension(first.len())?;
886 }
887 if let Some(expected) = embeddings.first().map(|e| e.len()) {
888 if embeddings.iter().any(|e| e.len() != expected) {
889 return Err(anyhow!("embedding provider returned mixed vector dimensions"));
890 }
891 }
892
893 Ok(embeddings)
894 }
895 _ => self.embed_batch_passages(&truncated_refs),
896 }
897 }
898
899 pub fn dimension(&self) -> usize {
900 self.dimension.load(Ordering::Relaxed)
901 }
902
903 pub fn model_choice(&self) -> EmbeddingModelChoice {
904 self.model
905 }
906
907 pub fn provider_kind(&self) -> &'static str {
908 match &self.backend {
909 EmbeddingBackend::FastEmbed(_) => "fastembed",
910 EmbeddingBackend::OpenAI(_) => "openai",
911 EmbeddingBackend::Nvidia(_) => "nvidia",
912 EmbeddingBackend::Gemini(_) => "gemini",
913 EmbeddingBackend::Mistral(_) => "mistral",
914 }
915 }
916
917 pub fn provider_model_id(&self) -> String {
918 match &self.backend {
919 EmbeddingBackend::FastEmbed(_) => self.model.canonical_model_id().to_string(),
920 EmbeddingBackend::OpenAI(provider) => {
921 use memvid_core::EmbeddingProvider;
922 provider.model().to_string()
923 }
924 EmbeddingBackend::Nvidia(provider) => provider.model().to_string(),
925 EmbeddingBackend::Gemini(provider) => provider.model().to_string(),
926 EmbeddingBackend::Mistral(provider) => provider.model().to_string(),
927 }
928 }
929}
930
/// Bridges the runtime into memvid-core's embedder trait; any error is
/// flattened into `MemvidError::EmbeddingFailed` with the message as reason.
impl memvid_core::VecEmbedder for EmbeddingRuntime {
    fn embed_query(&self, text: &str) -> memvid_core::Result<Vec<f32>> {
        // Fully-qualified call to avoid recursing into this trait method.
        EmbeddingRuntime::embed_query(self, text).map_err(|err| {
            memvid_core::MemvidError::EmbeddingFailed {
                reason: err.to_string().into_boxed_str(),
            }
        })
    }

    fn embedding_dimension(&self) -> usize {
        self.dimension()
    }
}
944
945fn ensure_fastembed_cache(config: &CliConfig) -> Result<PathBuf> {
947 use std::fs;
948
949 let cache_dir = config.models_dir.clone();
950 fs::create_dir_all(&cache_dir)?;
951 Ok(cache_dir)
952}
953
954fn model_size_mb(model: EmbeddingModelChoice) -> usize {
956 match model {
957 EmbeddingModelChoice::BgeSmall => 33,
958 EmbeddingModelChoice::BgeBase => 110,
959 EmbeddingModelChoice::Nomic => 137,
960 EmbeddingModelChoice::GteLarge => 327,
961 EmbeddingModelChoice::OpenAILarge
963 | EmbeddingModelChoice::OpenAISmall
964 | EmbeddingModelChoice::OpenAIAda
965 | EmbeddingModelChoice::Nvidia
966 | EmbeddingModelChoice::Gemini
967 | EmbeddingModelChoice::Mistral => 0,
968 }
969}
970
/// Builds the embedding runtime selected in `config`, dispatching to the
/// matching provider constructor.
///
/// # Errors
/// Fails when offline mode is combined with a remote model, or when the
/// chosen provider cannot be initialized (missing API key, download failure).
fn instantiate_embedding_runtime(config: &CliConfig) -> Result<EmbeddingRuntime> {
    use tracing::info;

    let embedding_model = config.embedding_model;

    // A dimension of 0 means "unknown until first response" — omit it in logs.
    if embedding_model.dimensions() > 0 {
        info!(
            "Loading embedding model: {} ({}D)",
            embedding_model.name(),
            embedding_model.dimensions()
        );
    } else {
        info!("Loading embedding model: {}", embedding_model.name());
    }

    // Remote providers need the network; refuse early in offline mode.
    if config.offline && embedding_model.is_remote() {
        anyhow::bail!(
            "remote embeddings are unavailable while offline; set MEMVID_OFFLINE=0 or use a local embedding model"
        );
    }

    if embedding_model.is_openai() {
        return instantiate_openai_runtime(embedding_model);
    }

    if embedding_model == EmbeddingModelChoice::Nvidia {
        // No model override at this call site; provider default applies.
        return instantiate_nvidia_runtime(None);
    }

    if embedding_model == EmbeddingModelChoice::Gemini {
        return instantiate_gemini_runtime();
    }

    if embedding_model == EmbeddingModelChoice::Mistral {
        return instantiate_mistral_runtime();
    }

    // Everything else is a local fastembed model.
    instantiate_fastembed_runtime(config, embedding_model)
}
1013
/// Builds an OpenAI-backed embedding runtime for one of the OpenAI model
/// choices, reading the API key from OPENAI_API_KEY.
///
/// # Errors
/// Fails when OPENAI_API_KEY is missing or empty, or when the provider
/// cannot be constructed.
fn instantiate_openai_runtime(embedding_model: EmbeddingModelChoice) -> Result<EmbeddingRuntime> {
    use anyhow::bail;
    use memvid_core::EmbeddingConfig;
    use tracing::info;

    let api_key = std::env::var("OPENAI_API_KEY").map_err(|_| {
        anyhow!("OPENAI_API_KEY environment variable is required for OpenAI embeddings")
    })?;

    if api_key.is_empty() {
        bail!("OPENAI_API_KEY cannot be empty");
    }

    let config = match embedding_model {
        EmbeddingModelChoice::OpenAILarge => EmbeddingConfig::openai_large(),
        EmbeddingModelChoice::OpenAISmall => EmbeddingConfig::openai_small(),
        EmbeddingModelChoice::OpenAIAda => EmbeddingConfig::openai_ada(),
        // Callers only reach here through the is_openai() dispatch.
        _ => unreachable!("is_openai() should have been false"),
    };

    let provider = OpenAIEmbeddingProvider::new(api_key, config.clone())
        .map_err(|err| anyhow!("failed to create OpenAI embedding provider: {err}"))?;

    info!(
        "OpenAI embedding provider ready: model={}, dimension={}",
        config.model, config.dimension
    );

    Ok(EmbeddingRuntime::new_openai(
        provider,
        embedding_model,
        config.dimension,
    ))
}
1049
/// Normalizes a user-supplied NVIDIA model override into the fully-qualified
/// "org/model" form.
///
/// Returns `None` for blank input or a bare provider alias ("nvidia"/"nv"),
/// meaning "use the provider default". A leading "nvidia:" or "nv:" prefix is
/// stripped; "nv-embed-v1" (any case) maps to its canonical id; anything
/// already containing '/' passes through; otherwise "nvidia/" is prepended.
fn normalize_nvidia_embedding_model_override(raw: &str) -> Option<String> {
    let trimmed = raw.trim();
    if trimmed.is_empty() {
        return None;
    }

    // Bare provider aliases select the default model.
    if trimmed.eq_ignore_ascii_case("nvidia") || trimmed.eq_ignore_ascii_case("nv") {
        return None;
    }

    let model = trimmed
        .strip_prefix("nvidia:")
        .or_else(|| trimmed.strip_prefix("nv:"))
        .unwrap_or(trimmed)
        .trim();

    match model {
        "" => None,
        m if m.eq_ignore_ascii_case("nv-embed-v1") => Some("nvidia/nv-embed-v1".to_string()),
        m if m.contains('/') => Some(m.to_string()),
        m => Some(format!("nvidia/{m}")),
    }
}
1081
/// Builds an NVIDIA-backed embedding runtime.
///
/// `model_override`, when present, is normalized (see
/// `normalize_nvidia_embedding_model_override`) before being handed to the
/// provider; `None` after normalization means "provider default".
///
/// # Errors
/// Fails when the provider cannot be constructed from the environment.
fn instantiate_nvidia_runtime(model_override: Option<&str>) -> Result<EmbeddingRuntime> {
    use tracing::info;

    let normalized = model_override.and_then(normalize_nvidia_embedding_model_override);
    let provider = NvidiaEmbeddingProvider::from_env(normalized.as_deref())
        .map_err(|err| anyhow!("failed to create NVIDIA embedding provider: {err}"))?;

    // No dimension logged: NVIDIA's width is learned from the first response.
    info!(
        "NVIDIA embedding provider ready: model={}",
        provider.model()
    );

    Ok(EmbeddingRuntime::new_nvidia(
        provider,
        EmbeddingModelChoice::Nvidia,
    ))
}
1100
/// Builds a Gemini-backed embedding runtime from the environment.
///
/// # Errors
/// Fails when the provider cannot be constructed from the environment.
fn instantiate_gemini_runtime() -> Result<EmbeddingRuntime> {
    use tracing::info;

    let provider = GeminiEmbeddingProvider::from_env()
        .map_err(|err| anyhow!("failed to create Gemini embedding provider: {err}"))?;

    // The provider reports its own vector width up front.
    let dimension = provider.dimension();
    info!(
        "Gemini embedding provider ready: model={}, dimension={}",
        provider.model(),
        dimension
    );

    Ok(EmbeddingRuntime::new_gemini(
        provider,
        EmbeddingModelChoice::Gemini,
        dimension,
    ))
}
1121
/// Builds a Mistral-backed embedding runtime from the environment.
///
/// # Errors
/// Fails when the provider cannot be constructed from the environment.
fn instantiate_mistral_runtime() -> Result<EmbeddingRuntime> {
    use tracing::info;

    let provider = MistralEmbeddingProvider::from_env()
        .map_err(|err| anyhow!("failed to create Mistral embedding provider: {err}"))?;

    // The provider reports its own vector width up front.
    let dimension = provider.dimension();
    info!(
        "Mistral embedding provider ready: model={}, dimension={}",
        provider.model(),
        dimension
    );

    Ok(EmbeddingRuntime::new_mistral(
        provider,
        EmbeddingModelChoice::Mistral,
        dimension,
    ))
}
1142
/// Build an [`EmbeddingRuntime`] backed by a local fastembed model.
///
/// Ensures the model cache directory exists, refuses to proceed offline with
/// an empty cache, initializes the fastembed model (with platform-specific
/// troubleshooting hints on failure), and probes a single embedding to learn
/// the runtime dimension.
fn instantiate_fastembed_runtime(
    config: &CliConfig,
    embedding_model: EmbeddingModelChoice,
) -> Result<EmbeddingRuntime> {
    use anyhow::bail;
    use fastembed::{InitOptions, TextEmbedding};
    use std::fs;

    let cache_dir = ensure_fastembed_cache(config)?;

    // Offline mode: fastembed would need to download weights on a cold cache,
    // so an empty cache directory means we cannot serve embeddings at all.
    if config.offline {
        let mut entries = fs::read_dir(&cache_dir)?;
        if entries.next().is_none() {
            bail!(
                "semantic embeddings unavailable while offline; allow one connected run so fastembed can cache model weights"
            );
        }
    }

    let options = InitOptions::new(embedding_model.to_fastembed_model())
        .with_cache_dir(cache_dir)
        .with_show_download_progress(true);
    // On failure, attach an OS-specific hint: download problems have different
    // likely causes (and fixes) on Windows vs Linux.
    let mut model = TextEmbedding::try_new(options).map_err(|err| {
        let platform_hint = if cfg!(target_os = "windows") {
            "\n\nWindows users: If model downloads fail, try:\n\
             1. Run as Administrator\n\
             2. Check your antivirus isn't blocking downloads\n\
             3. Use OpenAI embeddings instead: set OPENAI_API_KEY and use --embedding-model openai"
        } else if cfg!(target_os = "linux") {
            // NOTE(review): this hint names ~/.memvid/models, but the cache dir
            // actually comes from ensure_fastembed_cache(config) — confirm the
            // paths agree.
            "\n\nLinux users: If model downloads fail, try:\n\
             1. Check disk space in ~/.memvid/models\n\
             2. Ensure you have network access to huggingface.co\n\
             3. Use OpenAI embeddings instead: export OPENAI_API_KEY=... and use --embedding-model openai"
        } else {
            "\n\nIf model downloads fail, try using OpenAI embeddings:\n\
             export OPENAI_API_KEY=your-key && memvid ... --embedding-model openai"
        };

        anyhow!(
            "Failed to initialize embedding model '{}': {err}\n\n\
             This typically means the model couldn't be downloaded or loaded.\n\
             Model size: ~{} MB{}\n\n\
             See: https://docs.memvid.com/embedding-models",
            embedding_model.name(),
            model_size_mb(embedding_model),
            platform_hint
        )
    })?;

    // Embed a throwaway string to discover the actual output dimension rather
    // than trusting static metadata.
    let probe = model
        .embed(vec!["memvid probe".to_string()], None)
        .map_err(|err| anyhow!("failed to determine embedding dimension: {err}"))?;
    let dimension = probe.first().map(|vec| vec.len()).unwrap_or(0);

    if dimension == 0 {
        bail!("fastembed reported zero-length embeddings");
    }

    // A mismatch against the model's declared dimension is suspicious but not
    // fatal; the probed value is what the runtime actually uses.
    if dimension != embedding_model.dimensions() {
        tracing::warn!(
            "Embedding dimension mismatch: expected {}, got {}",
            embedding_model.dimensions(),
            dimension
        );
    }

    Ok(EmbeddingRuntime::new_fastembed(model, embedding_model, dimension))
}
1214
1215pub fn load_embedding_runtime(config: &CliConfig) -> Result<EmbeddingRuntime> {
1217 use anyhow::bail;
1218
1219 match instantiate_embedding_runtime(config) {
1220 Ok(runtime) => Ok(runtime),
1221 Err(err) => {
1222 if config.offline {
1223 bail!(
1224 "semantic embeddings unavailable while offline; allow one connected run so fastembed can cache model weights ({err})"
1225 );
1226 }
1227 Err(err)
1228 }
1229 }
1230}
1231
1232pub fn try_load_embedding_runtime(config: &CliConfig) -> Option<EmbeddingRuntime> {
1234 use tracing::warn;
1235
1236 match instantiate_embedding_runtime(config) {
1237 Ok(runtime) => Some(runtime),
1238 Err(err) => {
1239 warn!("semantic embeddings unavailable: {err}");
1240 None
1241 }
1242 }
1243}
1244
1245pub fn load_embedding_runtime_with_model(
1248 config: &CliConfig,
1249 model_override: Option<&str>,
1250) -> Result<EmbeddingRuntime> {
1251 use tracing::info;
1252
1253 let mut raw_override: Option<&str> = None;
1254 let embedding_model = match model_override {
1255 Some(model_str) => {
1256 raw_override = Some(model_str);
1257 let parsed = model_str.parse::<EmbeddingModelChoice>()?;
1258 if parsed.dimensions() > 0 {
1259 info!(
1260 "Using embedding model override: {} ({}D)",
1261 parsed.name(),
1262 parsed.dimensions()
1263 );
1264 } else {
1265 info!("Using embedding model override: {}", parsed.name());
1266 }
1267 parsed
1268 }
1269 None => config.embedding_model,
1270 };
1271
1272 if embedding_model.dimensions() > 0 {
1273 info!(
1274 "Loading embedding model: {} ({}D)",
1275 embedding_model.name(),
1276 embedding_model.dimensions()
1277 );
1278 } else {
1279 info!("Loading embedding model: {}", embedding_model.name());
1280 }
1281
1282 if config.offline && embedding_model.is_remote() {
1283 anyhow::bail!(
1284 "remote embeddings are unavailable while offline; set MEMVID_OFFLINE=0 or use a local embedding model"
1285 );
1286 }
1287
1288 if embedding_model.is_openai() {
1289 return instantiate_openai_runtime(embedding_model);
1290 }
1291
1292 if embedding_model == EmbeddingModelChoice::Nvidia {
1293 return instantiate_nvidia_runtime(raw_override);
1294 }
1295
1296 if embedding_model == EmbeddingModelChoice::Gemini {
1297 return instantiate_gemini_runtime();
1298 }
1299
1300 if embedding_model == EmbeddingModelChoice::Mistral {
1301 return instantiate_mistral_runtime();
1302 }
1303
1304 instantiate_fastembed_runtime(config, embedding_model)
1305}
1306
1307pub fn try_load_embedding_runtime_with_model(
1309 config: &CliConfig,
1310 model_override: Option<&str>,
1311) -> Option<EmbeddingRuntime> {
1312 use tracing::warn;
1313
1314 match load_embedding_runtime_with_model(config, model_override) {
1315 Ok(runtime) => Some(runtime),
1316 Err(err) => {
1317 warn!("semantic embeddings unavailable: {err}");
1318 None
1319 }
1320 }
1321}
1322
1323pub fn load_embedding_runtime_for_mv2(
1333 config: &CliConfig,
1334 model_override: Option<&str>,
1335 mv2_dimension: Option<u32>,
1336) -> Result<EmbeddingRuntime> {
1337 use tracing::info;
1338
1339 if let Some(model_str) = model_override {
1341 return load_embedding_runtime_with_model(config, Some(model_str));
1342 }
1343
1344 if let Some(dim) = mv2_dimension {
1346 if let Some(detected_model) = EmbeddingModelChoice::from_dimension(dim) {
1347 info!(
1348 "Auto-detected embedding model from MV2: {} ({}D)",
1349 detected_model.name(),
1350 dim
1351 );
1352
1353 if detected_model.is_openai() {
1355 if std::env::var("OPENAI_API_KEY").is_ok() {
1356 return load_embedding_runtime_with_model(config, Some(detected_model.name()));
1357 } else {
1358 return Err(anyhow!(
1360 "MV2 file uses OpenAI embeddings ({}D) but OPENAI_API_KEY is not set.\n\n\
1361 Options:\n\
1362 1. Set OPENAI_API_KEY environment variable\n\
1363 2. Use --query-embedding-model to specify a different model\n\
1364 3. Use lexical-only search with --mode lex\n\n\
1365 See: https://docs.memvid.com/embedding-models",
1366 dim
1367 ));
1368 }
1369 }
1370
1371 return load_embedding_runtime_with_model(config, Some(detected_model.name()));
1372 }
1373 }
1374
1375 load_embedding_runtime(config)
1377}
1378
1379pub fn try_load_embedding_runtime_for_mv2(
1381 config: &CliConfig,
1382 model_override: Option<&str>,
1383 mv2_dimension: Option<u32>,
1384) -> Option<EmbeddingRuntime> {
1385 use tracing::warn;
1386
1387 match load_embedding_runtime_for_mv2(config, model_override, mv2_dimension) {
1388 Ok(runtime) => Some(runtime),
1389 Err(err) => {
1390 warn!("semantic embeddings unavailable: {err}");
1391 None
1392 }
1393 }
1394}