1use std::env;
7use std::path::PathBuf;
8use std::str::FromStr;
9use std::sync::atomic::{AtomicUsize, Ordering};
10
11use anyhow::{anyhow, Result};
12use ed25519_dalek::VerifyingKey;
13
14const DEFAULT_API_URL: &str = "https://memvid.com";
15const DEFAULT_CACHE_DIR: &str = "~/.cache/memvid";
16
/// Which embedding model the CLI should use.
///
/// The first four variants are local models served through fastembed; the
/// remaining variants are remote, API-backed providers (see `is_remote`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum EmbeddingModelChoice {
    /// BAAI/bge-small-en-v1.5, 384 dimensions — the default local model.
    #[default]
    BgeSmall,
    /// BAAI/bge-base-en-v1.5, 768 dimensions.
    BgeBase,
    /// nomic-embed-text-v1.5, 768 dimensions.
    Nomic,
    /// thenlper/gte-large, 1024 dimensions.
    GteLarge,
    /// OpenAI text-embedding-3-large, 3072 dimensions (remote).
    OpenAILarge,
    /// OpenAI text-embedding-3-small, 1536 dimensions (remote).
    OpenAISmall,
    /// OpenAI text-embedding-ada-002, 1536 dimensions (remote).
    OpenAIAda,
    /// NVIDIA nv-embed-v1 (remote); dimension reported as 0 until observed.
    Nvidia,
    /// Google text-embedding-004, 768 dimensions (remote).
    Gemini,
    /// mistral-embed, 1024 dimensions (remote).
    Mistral,
}
42
43impl EmbeddingModelChoice {
44 pub fn is_openai(&self) -> bool {
46 matches!(
47 self,
48 EmbeddingModelChoice::OpenAILarge
49 | EmbeddingModelChoice::OpenAISmall
50 | EmbeddingModelChoice::OpenAIAda
51 )
52 }
53
54 pub fn is_remote(&self) -> bool {
56 matches!(
57 self,
58 EmbeddingModelChoice::OpenAILarge
59 | EmbeddingModelChoice::OpenAISmall
60 | EmbeddingModelChoice::OpenAIAda
61 | EmbeddingModelChoice::Nvidia
62 | EmbeddingModelChoice::Gemini
63 | EmbeddingModelChoice::Mistral
64 )
65 }
66
67 pub fn to_fastembed_model(&self) -> fastembed::EmbeddingModel {
72 match self {
73 EmbeddingModelChoice::BgeSmall => fastembed::EmbeddingModel::BGESmallENV15,
74 EmbeddingModelChoice::BgeBase => fastembed::EmbeddingModel::BGEBaseENV15,
75 EmbeddingModelChoice::Nomic => fastembed::EmbeddingModel::NomicEmbedTextV15,
76 EmbeddingModelChoice::GteLarge => fastembed::EmbeddingModel::GTELargeENV15,
77 EmbeddingModelChoice::OpenAILarge
78 | EmbeddingModelChoice::OpenAISmall
79 | EmbeddingModelChoice::OpenAIAda => {
80 panic!("OpenAI models don't use fastembed. Check is_remote() first.")
81 }
82 EmbeddingModelChoice::Nvidia => {
83 panic!("NVIDIA embeddings don't use fastembed. Check is_remote() first.")
84 }
85 EmbeddingModelChoice::Gemini => {
86 panic!("Gemini embeddings don't use fastembed. Check is_remote() first.")
87 }
88 EmbeddingModelChoice::Mistral => {
89 panic!("Mistral embeddings don't use fastembed. Check is_remote() first.")
90 }
91 }
92 }
93
94 pub fn name(&self) -> &'static str {
96 match self {
97 EmbeddingModelChoice::BgeSmall => "bge-small",
98 EmbeddingModelChoice::BgeBase => "bge-base",
99 EmbeddingModelChoice::Nomic => "nomic",
100 EmbeddingModelChoice::GteLarge => "gte-large",
101 EmbeddingModelChoice::OpenAILarge => "openai-large",
102 EmbeddingModelChoice::OpenAISmall => "openai-small",
103 EmbeddingModelChoice::OpenAIAda => "openai-ada",
104 EmbeddingModelChoice::Nvidia => "nvidia",
105 EmbeddingModelChoice::Gemini => "gemini",
106 EmbeddingModelChoice::Mistral => "mistral",
107 }
108 }
109
110 pub fn canonical_model_id(&self) -> &'static str {
116 match self {
117 EmbeddingModelChoice::BgeSmall => "BAAI/bge-small-en-v1.5",
118 EmbeddingModelChoice::BgeBase => "BAAI/bge-base-en-v1.5",
119 EmbeddingModelChoice::Nomic => "nomic-embed-text-v1.5",
120 EmbeddingModelChoice::GteLarge => "thenlper/gte-large",
121 EmbeddingModelChoice::OpenAILarge => "text-embedding-3-large",
122 EmbeddingModelChoice::OpenAISmall => "text-embedding-3-small",
123 EmbeddingModelChoice::OpenAIAda => "text-embedding-ada-002",
124 EmbeddingModelChoice::Nvidia => "nvidia/nv-embed-v1",
125 EmbeddingModelChoice::Gemini => "text-embedding-004",
126 EmbeddingModelChoice::Mistral => "mistral-embed",
127 }
128 }
129
130 pub fn dimensions(&self) -> usize {
132 match self {
133 EmbeddingModelChoice::BgeSmall => 384,
134 EmbeddingModelChoice::BgeBase => 768,
135 EmbeddingModelChoice::Nomic => 768,
136 EmbeddingModelChoice::GteLarge => 1024,
137 EmbeddingModelChoice::OpenAILarge => 3072,
138 EmbeddingModelChoice::OpenAISmall => 1536,
139 EmbeddingModelChoice::OpenAIAda => 1536,
140 EmbeddingModelChoice::Nvidia => 0,
142 EmbeddingModelChoice::Gemini => 768,
143 EmbeddingModelChoice::Mistral => 1024,
144 }
145 }
146}
147
148impl FromStr for EmbeddingModelChoice {
149 type Err = anyhow::Error;
150
151 fn from_str(s: &str) -> Result<Self> {
152 let lowered = s.trim().to_ascii_lowercase();
153 match lowered.as_str() {
154 "bge-small" | "bge_small" | "bgesmall" | "small" => Ok(EmbeddingModelChoice::BgeSmall),
155 "baai/bge-small-en-v1.5" => Ok(EmbeddingModelChoice::BgeSmall),
156 "bge-base" | "bge_base" | "bgebase" | "base" => Ok(EmbeddingModelChoice::BgeBase),
157 "baai/bge-base-en-v1.5" => Ok(EmbeddingModelChoice::BgeBase),
158 "nomic" | "nomic-embed" | "nomic_embed" => Ok(EmbeddingModelChoice::Nomic),
159 "nomic-embed-text-v1.5" => Ok(EmbeddingModelChoice::Nomic),
160 "gte-large" | "gte_large" | "gtelarge" | "gte" => Ok(EmbeddingModelChoice::GteLarge),
161 "thenlper/gte-large" => Ok(EmbeddingModelChoice::GteLarge),
162 "openai" | "openai-large" | "openai_large" | "text-embedding-3-large" => {
164 Ok(EmbeddingModelChoice::OpenAILarge)
165 }
166 "openai-small" | "openai_small" | "text-embedding-3-small" => {
167 Ok(EmbeddingModelChoice::OpenAISmall)
168 }
169 "openai-ada" | "openai_ada" | "text-embedding-ada-002" | "ada" => {
170 Ok(EmbeddingModelChoice::OpenAIAda)
171 }
172 "nvidia" | "nv" | "nv-embed-v1" | "nvidia/nv-embed-v1" => Ok(EmbeddingModelChoice::Nvidia),
173 _ if lowered.starts_with("nvidia/") || lowered.starts_with("nvidia:") || lowered.starts_with("nv:") => {
174 Ok(EmbeddingModelChoice::Nvidia)
175 }
176 "gemini" | "gemini-embed" | "text-embedding-004" | "gemini-embedding-001" => {
178 Ok(EmbeddingModelChoice::Gemini)
179 }
180 _ if lowered.starts_with("gemini/") || lowered.starts_with("gemini:") || lowered.starts_with("google:") => {
181 Ok(EmbeddingModelChoice::Gemini)
182 }
183 "mistral" | "mistral-embed" => Ok(EmbeddingModelChoice::Mistral),
185 _ if lowered.starts_with("mistral/") || lowered.starts_with("mistral:") => {
186 Ok(EmbeddingModelChoice::Mistral)
187 }
188 _ => Err(anyhow!(
189 "unknown embedding model '{}'. Valid options: bge-small, bge-base, nomic, gte-large, openai, openai-small, openai-ada, nvidia, gemini, mistral",
190 s
191 )),
192 }
193 }
194}
195
196impl std::fmt::Display for EmbeddingModelChoice {
197 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
198 write!(f, "{}", self.name())
199 }
200}
201
202impl EmbeddingModelChoice {
203 pub fn from_dimension(dim: u32) -> Option<Self> {
215 match dim {
216 384 => Some(EmbeddingModelChoice::BgeSmall),
217 768 => Some(EmbeddingModelChoice::BgeBase), 1024 => Some(EmbeddingModelChoice::GteLarge),
219 1536 => Some(EmbeddingModelChoice::OpenAISmall), 3072 => Some(EmbeddingModelChoice::OpenAILarge),
221 0 => None, _ => {
223 tracing::warn!("Unknown embedding dimension {}, using default model", dim);
224 None
225 }
226 }
227 }
228}
229
/// Resolved CLI configuration, assembled from `MEMVID_*` environment
/// variables by [`CliConfig::load`].
#[derive(Debug, Clone)]
pub struct CliConfig {
    /// MEMVID_API_KEY, trimmed; `None` when unset or blank.
    pub api_key: Option<String>,
    /// MEMVID_API_URL, default "https://memvid.com".
    pub api_url: String,
    /// MEMVID_CACHE_DIR (default "~/.cache/memvid"), expanded to absolute.
    pub cache_dir: PathBuf,
    /// Ed25519 key for ticket verification (MEMVID_TICKET_PUBKEY or the
    /// baked-in default); always `Some` after a successful `load()`.
    pub ticket_pubkey: Option<VerifyingKey>,
    /// MEMVID_MODELS_DIR (default "~/.memvid/models"), expanded to absolute.
    pub models_dir: PathBuf,
    /// True when MEMVID_OFFLINE is "1"/"true"/"yes" (case-insensitive).
    pub offline: bool,
    /// MEMVID_EMBEDDING_MODEL; falls back to the default model.
    pub embedding_model: EmbeddingModelChoice,
}
242
243impl PartialEq for CliConfig {
244 fn eq(&self, other: &Self) -> bool {
245 self.api_key == other.api_key
246 && self.api_url == other.api_url
247 && self.cache_dir == other.cache_dir
248 && self.models_dir == other.models_dir
249 && self.offline == other.offline
250 && self.embedding_model == other.embedding_model
251 }
252}
253
254impl Eq for CliConfig {}
255
256impl CliConfig {
257 pub fn load() -> Result<Self> {
258 let api_key = env::var("MEMVID_API_KEY").ok().and_then(|value| {
259 let trimmed = value.trim().to_string();
260 (!trimmed.is_empty()).then_some(trimmed)
261 });
262
263 let api_url = env::var("MEMVID_API_URL").unwrap_or_else(|_| DEFAULT_API_URL.to_string());
264
265 let cache_dir_raw =
266 env::var("MEMVID_CACHE_DIR").unwrap_or_else(|_| DEFAULT_CACHE_DIR.to_string());
267 let cache_dir = expand_path(&cache_dir_raw)?;
268
269 let models_dir_raw =
270 env::var("MEMVID_MODELS_DIR").unwrap_or_else(|_| "~/.memvid/models".to_string());
271 let models_dir = expand_path(&models_dir_raw)?;
272
273 const DEFAULT_TICKET_PUBKEY: &str = "8wP1J2H+Tlx3PM3eT0lN2wDvoYrvl1DREKGKVb/V2cw=";
276
277 let ticket_pubkey_str = env::var("MEMVID_TICKET_PUBKEY")
278 .ok()
279 .and_then(|value| {
280 let trimmed = value.trim();
281 if trimmed.is_empty() {
282 None
283 } else {
284 Some(trimmed.to_string())
285 }
286 })
287 .unwrap_or_else(|| DEFAULT_TICKET_PUBKEY.to_string());
288
289 let ticket_pubkey = Some(memvid_core::parse_ed25519_public_key_base64(&ticket_pubkey_str)?);
290
291 let offline = env::var("MEMVID_OFFLINE")
292 .ok()
293 .map(|value| match value.trim().to_ascii_lowercase().as_str() {
294 "1" | "true" | "yes" => true,
295 _ => false,
296 })
297 .unwrap_or(false);
298
299 let embedding_model = env::var("MEMVID_EMBEDDING_MODEL")
301 .ok()
302 .and_then(|value| {
303 let trimmed = value.trim();
304 if trimmed.is_empty() {
305 None
306 } else {
307 EmbeddingModelChoice::from_str(trimmed).ok()
308 }
309 })
310 .unwrap_or_default();
311
312 Ok(Self {
313 api_key,
314 api_url,
315 cache_dir,
316 ticket_pubkey,
317 models_dir,
318 offline,
319 embedding_model,
320 })
321 }
322
323 pub fn with_embedding_model(&self, model: EmbeddingModelChoice) -> Self {
325 Self {
326 embedding_model: model,
327 ..self.clone()
328 }
329 }
330}
331
332fn expand_path(value: &str) -> Result<PathBuf> {
333 if value.trim().is_empty() {
334 return Err(anyhow!("cache directory cannot be empty"));
335 }
336
337 let expanded = if let Some(stripped) = value.strip_prefix("~/") {
338 home_dir()?.join(stripped)
339 } else if let Some(stripped) = value.strip_prefix("~\\") {
340 home_dir()?.join(stripped)
342 } else if value == "~" {
343 home_dir()?
344 } else {
345 PathBuf::from(value)
346 };
347
348 if expanded.is_absolute() {
349 Ok(expanded)
350 } else {
351 Ok(env::current_dir()?.join(expanded))
352 }
353}
354
355fn home_dir() -> Result<PathBuf> {
356 if let Some(path) = env::var_os("HOME") {
357 if !path.is_empty() {
358 return Ok(PathBuf::from(path));
359 }
360 }
361
362 #[cfg(windows)]
363 {
364 if let Some(path) = env::var_os("USERPROFILE") {
365 if !path.is_empty() {
366 return Ok(PathBuf::from(path));
367 }
368 }
369 if let (Some(drive), Some(path)) = (env::var_os("HOMEDRIVE"), env::var_os("HOMEPATH")) {
370 if !drive.is_empty() && !path.is_empty() {
371 return Ok(PathBuf::from(format!(
372 "{}{}",
373 drive.to_string_lossy(),
374 path.to_string_lossy()
375 )));
376 }
377 }
378 }
379
380 Err(anyhow!("unable to resolve home directory"))
381}
382
#[cfg(test)]
mod tests {
    use super::*;
    use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
    use base64::Engine;
    use ed25519_dalek::SigningKey;
    use std::sync::{Mutex, OnceLock};

    // Serializes every test that touches process-wide environment variables;
    // without it, parallel test threads would race on set_var/remove_var.
    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(())).lock().unwrap()
    }

    // Restores (or clears) an environment variable captured before a test.
    // set_var/remove_var are `unsafe` because they are not thread-safe; the
    // env_lock() guard held by each caller is what makes the calls sound.
    fn set_or_unset(var: &str, value: Option<String>) {
        match value {
            Some(v) => unsafe { env::set_var(var, v) },
            None => unsafe { env::remove_var(var) },
        }
    }

    #[test]
    fn defaults_expand_using_home_directory() {
        let _guard = env_lock();

        // Only HOME/USERPROFILE are captured for restoration.
        // NOTE(review): the MEMVID_* variables removed below are NOT restored
        // afterwards — confirm no later test depends on pre-set values.
        let previous_home = env::var("HOME").ok();
        #[cfg(windows)]
        let previous_userprofile = env::var("USERPROFILE").ok();

        // Clear every override so load() exercises its documented defaults.
        for var in [
            "MEMVID_API_KEY",
            "MEMVID_API_URL",
            "MEMVID_CACHE_DIR",
            "MEMVID_TICKET_PUBKEY",
            "MEMVID_MODELS_DIR",
            "MEMVID_OFFLINE",
        ] {
            unsafe { env::remove_var(var) };
        }

        // Point HOME at a temp dir so "~" expansion is observable.
        let tmp = tempfile::tempdir().expect("tmpdir");
        let tmp_path = tmp.path().to_path_buf();
        unsafe { env::set_var("HOME", &tmp_path) };
        #[cfg(windows)]
        unsafe {
            env::set_var("USERPROFILE", &tmp_path)
        };

        let config = CliConfig::load().expect("load");
        assert_eq!(config.api_key, None);
        assert_eq!(config.api_url, "https://memvid.com");
        assert_eq!(config.cache_dir, tmp_path.join(".cache/memvid"));
        // Default baked-in ticket key is parsed when the env var is unset.
        assert!(config.ticket_pubkey.is_some());
        assert_eq!(config.models_dir, tmp_path.join(".memvid/models"));
        assert!(!config.offline);

        set_or_unset("HOME", previous_home);
        #[cfg(windows)]
        {
            set_or_unset("USERPROFILE", previous_userprofile);
        }
    }

    #[test]
    fn env_overrides_are_respected() {
        let _guard = env_lock();

        // Capture every variable this test mutates so it can restore them.
        let previous_env: Vec<(&'static str, Option<String>)> = [
            "MEMVID_API_KEY",
            "MEMVID_API_URL",
            "MEMVID_CACHE_DIR",
            "MEMVID_TICKET_PUBKEY",
            "MEMVID_MODELS_DIR",
            "MEMVID_OFFLINE",
        ]
        .into_iter()
        .map(|var| (var, env::var(var).ok()))
        .collect();

        unsafe { env::set_var("MEMVID_API_KEY", "abc123") };
        unsafe { env::set_var("MEMVID_API_URL", "https://staging.memvid.app") };
        unsafe { env::set_var("MEMVID_CACHE_DIR", "~/memvid-cache") };
        unsafe { env::set_var("MEMVID_MODELS_DIR", "~/models") };
        unsafe { env::set_var("MEMVID_OFFLINE", "true") };
        // A deterministic key pair: the public half goes into the env var and
        // load() must hand back exactly that verifying key.
        let signing = SigningKey::from_bytes(&[9u8; 32]);
        let encoded = BASE64_STANDARD.encode(signing.verifying_key().as_bytes());
        unsafe { env::set_var("MEMVID_TICKET_PUBKEY", encoded) };

        let tmp = tempfile::tempdir().expect("tmpdir");
        let tmp_path = tmp.path().to_path_buf();
        unsafe { env::set_var("HOME", &tmp_path) };
        #[cfg(windows)]
        unsafe {
            env::set_var("USERPROFILE", &tmp_path)
        };

        let config = CliConfig::load().expect("load");
        assert_eq!(config.api_key.as_deref(), Some("abc123"));
        assert_eq!(config.api_url, "https://staging.memvid.app");
        assert_eq!(config.cache_dir, tmp_path.join("memvid-cache"));
        assert_eq!(
            config.ticket_pubkey.expect("pubkey").as_bytes(),
            signing.verifying_key().as_bytes()
        );
        assert_eq!(config.models_dir, tmp_path.join("models"));
        assert!(config.offline);

        for (var, value) in previous_env {
            set_or_unset(var, value);
        }
    }

    #[test]
    fn rejects_empty_cache_dir() {
        let _guard = env_lock();

        // A whitespace-only value must be rejected by expand_path().
        let previous = env::var("MEMVID_CACHE_DIR").ok();
        unsafe { env::set_var("MEMVID_CACHE_DIR", " ") };
        let err = CliConfig::load().expect_err("should fail");
        assert!(err.to_string().contains("cache directory"));
        set_or_unset("MEMVID_CACHE_DIR", previous);
    }
}
506
507pub fn init_tracing(verbosity: u8) -> Result<()> {
509 use std::io::IsTerminal;
510 use tracing_subscriber::{filter::Directive, fmt, EnvFilter};
511
512 let level = match verbosity {
513 0 => "warn",
514 1 => "info",
515 2 => "debug",
516 _ => "trace",
517 };
518
519 let mut env_filter =
520 EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(level));
521 for directive_str in ["llama_cpp=error", "llama_cpp_sys=error", "ggml=error"] {
522 if let Ok(directive) = directive_str.parse::<Directive>() {
523 env_filter = env_filter.add_directive(directive);
524 }
525 }
526
527 let use_ansi = std::io::stderr().is_terminal();
531
532 fmt()
533 .with_env_filter(env_filter)
534 .with_writer(std::io::stderr)
535 .with_target(false)
536 .without_time()
537 .with_ansi(use_ansi)
538 .try_init()
539 .map_err(|err| anyhow!(err))?;
540 Ok(())
541}
542
543pub fn resolve_llm_context_budget_override(cli_value: Option<usize>) -> Result<Option<usize>> {
545 use anyhow::bail;
546
547 if let Some(value) = cli_value {
548 if value == 0 {
549 bail!("--llm-context-depth must be a positive integer");
550 }
551 return Ok(Some(value));
552 }
553
554 let raw_env = match env::var("MEMVID_LLM_CONTEXT_BUDGET") {
555 Ok(value) => value,
556 Err(_) => return Ok(None),
557 };
558
559 let trimmed = raw_env.trim();
560 if trimmed.is_empty() {
561 return Ok(None);
562 }
563
564 let digits: String = trimmed
565 .chars()
566 .filter(|ch| !ch.is_ascii_whitespace() && *ch != '_')
567 .collect();
568
569 if digits.is_empty() {
570 bail!("MEMVID_LLM_CONTEXT_BUDGET must be a positive integer value");
571 }
572
573 let value: usize = digits.parse().map_err(|err| {
574 anyhow!(
575 "MEMVID_LLM_CONTEXT_BUDGET value '{}' is not a valid number: {}",
576 trimmed,
577 err
578 )
579 })?;
580
581 if value == 0 {
582 bail!("MEMVID_LLM_CONTEXT_BUDGET must be a positive integer");
583 }
584
585 Ok(Some(value))
586}
587
588use crate::gemini_embeddings::GeminiEmbeddingProvider;
589use crate::mistral_embeddings::MistralEmbeddingProvider;
590use crate::nvidia_embeddings::NvidiaEmbeddingProvider;
591use crate::openai_embeddings::OpenAIEmbeddingProvider;
592
/// The concrete engine behind an `EmbeddingRuntime`.
///
/// Each variant is `Arc`-wrapped so the runtime stays cheaply cloneable;
/// only the local fastembed model needs a `Mutex`, because its `embed`
/// call requires mutable access.
#[derive(Clone)]
enum EmbeddingBackend {
    /// Local model executed in-process via fastembed.
    FastEmbed(std::sync::Arc<std::sync::Mutex<fastembed::TextEmbedding>>),
    /// Remote OpenAI embeddings API.
    OpenAI(std::sync::Arc<OpenAIEmbeddingProvider>),
    /// Remote NVIDIA embeddings API.
    Nvidia(std::sync::Arc<NvidiaEmbeddingProvider>),
    /// Remote Gemini embeddings API.
    Gemini(std::sync::Arc<GeminiEmbeddingProvider>),
    /// Remote Mistral embeddings API.
    Mistral(std::sync::Arc<MistralEmbeddingProvider>),
}
602
/// Cloneable handle over a single embedding backend.
///
/// `dimension` is shared between clones: it starts at the statically known
/// size (0 for NVIDIA, whose size is unknown until the first response) and
/// is validated/updated against real provider output by `note_dimension`.
#[derive(Clone)]
pub struct EmbeddingRuntime {
    // Which engine produces the vectors.
    backend: EmbeddingBackend,
    // The user-facing model selection this runtime was built for.
    model: EmbeddingModelChoice,
    // Last known embedding dimension; 0 means "not observed yet".
    dimension: std::sync::Arc<AtomicUsize>,
}
610
611impl EmbeddingRuntime {
612 fn new_fastembed(
613 backend: fastembed::TextEmbedding,
614 model: EmbeddingModelChoice,
615 dimension: usize,
616 ) -> Self {
617 Self {
618 backend: EmbeddingBackend::FastEmbed(std::sync::Arc::new(std::sync::Mutex::new(
619 backend,
620 ))),
621 model,
622 dimension: std::sync::Arc::new(AtomicUsize::new(dimension)),
623 }
624 }
625
626 fn new_openai(
627 provider: OpenAIEmbeddingProvider,
628 model: EmbeddingModelChoice,
629 dimension: usize,
630 ) -> Self {
631 Self {
632 backend: EmbeddingBackend::OpenAI(std::sync::Arc::new(provider)),
633 model,
634 dimension: std::sync::Arc::new(AtomicUsize::new(dimension)),
635 }
636 }
637
638 fn new_nvidia(provider: NvidiaEmbeddingProvider, model: EmbeddingModelChoice) -> Self {
639 Self {
640 backend: EmbeddingBackend::Nvidia(std::sync::Arc::new(provider)),
641 model,
642 dimension: std::sync::Arc::new(AtomicUsize::new(0)),
643 }
644 }
645
646 fn new_gemini(
647 provider: GeminiEmbeddingProvider,
648 model: EmbeddingModelChoice,
649 dimension: usize,
650 ) -> Self {
651 Self {
652 backend: EmbeddingBackend::Gemini(std::sync::Arc::new(provider)),
653 model,
654 dimension: std::sync::Arc::new(AtomicUsize::new(dimension)),
655 }
656 }
657
658 fn new_mistral(
659 provider: MistralEmbeddingProvider,
660 model: EmbeddingModelChoice,
661 dimension: usize,
662 ) -> Self {
663 Self {
664 backend: EmbeddingBackend::Mistral(std::sync::Arc::new(provider)),
665 model,
666 dimension: std::sync::Arc::new(AtomicUsize::new(dimension)),
667 }
668 }
669
670 const MAX_OPENAI_EMBEDDING_TEXT_LEN: usize = 20_000;
671 const MAX_NVIDIA_EMBEDDING_TEXT_LEN: usize = 12_000;
673
674 const MAX_GEMINI_EMBEDDING_TEXT_LEN: usize = 20_000;
676 const MAX_MISTRAL_EMBEDDING_TEXT_LEN: usize = 20_000;
678
679 fn max_remote_embedding_chars(&self) -> usize {
680 match &self.backend {
681 EmbeddingBackend::OpenAI(_) => Self::MAX_OPENAI_EMBEDDING_TEXT_LEN,
682 EmbeddingBackend::Nvidia(_) => Self::MAX_NVIDIA_EMBEDDING_TEXT_LEN,
683 EmbeddingBackend::Gemini(_) => Self::MAX_GEMINI_EMBEDDING_TEXT_LEN,
684 EmbeddingBackend::Mistral(_) => Self::MAX_MISTRAL_EMBEDDING_TEXT_LEN,
685 EmbeddingBackend::FastEmbed(_) => usize::MAX,
686 }
687 }
688
689 fn truncate_for_embedding<'a>(
691 text: &'a str,
692 max_chars: usize,
693 ) -> std::borrow::Cow<'a, str> {
694 if text.len() <= max_chars {
695 std::borrow::Cow::Borrowed(text)
696 } else {
697 let truncated = &text[..max_chars];
699 let end = truncated
700 .char_indices()
701 .rev()
702 .next()
703 .map(|(i, c)| i + c.len_utf8())
704 .unwrap_or(max_chars);
705 tracing::info!("Truncated embedding text from {} to {} chars", text.len(), end);
706 std::borrow::Cow::Owned(text[..end].to_string())
707 }
708 }
709
710 fn note_dimension(&self, observed: usize) -> Result<()> {
711 if observed == 0 {
712 return Err(anyhow!("embedding provider returned zero-length embedding"));
713 }
714
715 let current = self.dimension.load(Ordering::Relaxed);
716 if current == 0 {
717 self.dimension.store(observed, Ordering::Relaxed);
718 return Ok(());
719 }
720
721 if current != observed {
722 return Err(anyhow!(
723 "embedding provider returned {observed}D vectors but runtime expects {current}D"
724 ));
725 }
726
727 Ok(())
728 }
729
730 fn truncate_if_remote<'a>(&self, text: &'a str) -> std::borrow::Cow<'a, str> {
731 match &self.backend {
732 EmbeddingBackend::OpenAI(_)
733 | EmbeddingBackend::Nvidia(_)
734 | EmbeddingBackend::Gemini(_)
735 | EmbeddingBackend::Mistral(_) => {
736 Self::truncate_for_embedding(text, self.max_remote_embedding_chars())
737 }
738 EmbeddingBackend::FastEmbed(_) => std::borrow::Cow::Borrowed(text),
739 }
740 }
741
742 pub fn embed_passage(&self, text: &str) -> Result<Vec<f32>> {
743 let text = self.truncate_if_remote(text);
744 let embedding = match &self.backend {
745 EmbeddingBackend::FastEmbed(model) => {
746 let mut guard = model
747 .lock()
748 .map_err(|_| anyhow!("fastembed runtime poisoned"))?;
749 let outputs = guard
750 .embed(vec![text.into_owned()], None)
751 .map_err(|err| anyhow!("failed to compute embedding with fastembed: {err}"))?;
752 outputs
753 .into_iter()
754 .next()
755 .ok_or_else(|| anyhow!("fastembed returned no embedding output"))?
756 }
757 EmbeddingBackend::OpenAI(provider) => {
758 use memvid_core::EmbeddingProvider;
759 provider
760 .embed_text(&text)
761 .map_err(|err| anyhow!("failed to compute embedding with OpenAI: {err}"))?
762 }
763 EmbeddingBackend::Nvidia(provider) => provider
764 .embed_passage(&text)
765 .map_err(|err| anyhow!("failed to compute embedding with NVIDIA: {err}"))?,
766 EmbeddingBackend::Gemini(provider) => provider
767 .embed_text(&text)
768 .map_err(|err| anyhow!("failed to compute embedding with Gemini: {err}"))?,
769 EmbeddingBackend::Mistral(provider) => provider
770 .embed_text(&text)
771 .map_err(|err| anyhow!("failed to compute embedding with Mistral: {err}"))?,
772 };
773
774 self.note_dimension(embedding.len())?;
775 Ok(embedding)
776 }
777
778 pub fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
779 let text = self.truncate_if_remote(text);
780 match &self.backend {
781 EmbeddingBackend::Nvidia(provider) => {
782 let embedding = provider
783 .embed_query(&text)
784 .map_err(|err| anyhow!("failed to compute embedding with NVIDIA: {err}"))?;
785 self.note_dimension(embedding.len())?;
786 Ok(embedding)
787 }
788 _ => self.embed_passage(&text),
789 }
790 }
791
792 pub fn embed_batch_passages(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
793 if texts.is_empty() {
794 return Ok(Vec::new());
795 }
796
797 let truncated: Vec<std::borrow::Cow<'_, str>> =
798 texts.iter().map(|t| self.truncate_if_remote(t)).collect();
799 let truncated_refs: Vec<&str> = truncated.iter().map(|c| c.as_ref()).collect();
800
801 let embeddings = match &self.backend {
802 EmbeddingBackend::FastEmbed(model) => {
803 let mut guard = model
804 .lock()
805 .map_err(|_| anyhow!("fastembed runtime poisoned"))?;
806 guard
807 .embed(
808 truncated_refs
809 .iter()
810 .map(|s| (*s).to_string())
811 .collect::<Vec<String>>(),
812 None,
813 )
814 .map_err(|err| anyhow!("failed to compute embeddings with fastembed: {err}"))?
815 }
816 EmbeddingBackend::OpenAI(provider) => {
817 use memvid_core::EmbeddingProvider;
818 provider
819 .embed_batch(&truncated_refs)
820 .map_err(|err| anyhow!("failed to compute embeddings with OpenAI: {err}"))?
821 }
822 EmbeddingBackend::Nvidia(provider) => provider
823 .embed_passages(&truncated_refs)
824 .map_err(|err| anyhow!("failed to compute embeddings with NVIDIA: {err}"))?,
825 EmbeddingBackend::Gemini(provider) => provider
826 .embed_batch(&truncated_refs)
827 .map_err(|err| anyhow!("failed to compute embeddings with Gemini: {err}"))?,
828 EmbeddingBackend::Mistral(provider) => provider
829 .embed_batch(&truncated_refs)
830 .map_err(|err| anyhow!("failed to compute embeddings with Mistral: {err}"))?,
831 };
832
833 if let Some(first) = embeddings.first() {
834 self.note_dimension(first.len())?;
835 }
836 if let Some(expected) = embeddings.first().map(|e| e.len()) {
837 if embeddings.iter().any(|e| e.len() != expected) {
838 return Err(anyhow!("embedding provider returned mixed vector dimensions"));
839 }
840 }
841
842 Ok(embeddings)
843 }
844
845 pub fn embed_batch_queries(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
846 if texts.is_empty() {
847 return Ok(Vec::new());
848 }
849
850 let truncated: Vec<std::borrow::Cow<'_, str>> =
851 texts.iter().map(|t| self.truncate_if_remote(t)).collect();
852 let truncated_refs: Vec<&str> = truncated.iter().map(|c| c.as_ref()).collect();
853
854 match &self.backend {
855 EmbeddingBackend::Nvidia(provider) => {
856 let embeddings = provider
857 .embed_queries(&truncated_refs)
858 .map_err(|err| anyhow!("failed to compute embeddings with NVIDIA: {err}"))?;
859
860 if let Some(first) = embeddings.first() {
861 self.note_dimension(first.len())?;
862 }
863 if let Some(expected) = embeddings.first().map(|e| e.len()) {
864 if embeddings.iter().any(|e| e.len() != expected) {
865 return Err(anyhow!("embedding provider returned mixed vector dimensions"));
866 }
867 }
868
869 Ok(embeddings)
870 }
871 _ => self.embed_batch_passages(&truncated_refs),
872 }
873 }
874
875 pub fn dimension(&self) -> usize {
876 self.dimension.load(Ordering::Relaxed)
877 }
878
879 pub fn model_choice(&self) -> EmbeddingModelChoice {
880 self.model
881 }
882
883 pub fn provider_kind(&self) -> &'static str {
884 match &self.backend {
885 EmbeddingBackend::FastEmbed(_) => "fastembed",
886 EmbeddingBackend::OpenAI(_) => "openai",
887 EmbeddingBackend::Nvidia(_) => "nvidia",
888 EmbeddingBackend::Gemini(_) => "gemini",
889 EmbeddingBackend::Mistral(_) => "mistral",
890 }
891 }
892
893 pub fn provider_model_id(&self) -> String {
894 match &self.backend {
895 EmbeddingBackend::FastEmbed(_) => self.model.canonical_model_id().to_string(),
896 EmbeddingBackend::OpenAI(provider) => {
897 use memvid_core::EmbeddingProvider;
898 provider.model().to_string()
899 }
900 EmbeddingBackend::Nvidia(provider) => provider.model().to_string(),
901 EmbeddingBackend::Gemini(provider) => provider.model().to_string(),
902 EmbeddingBackend::Mistral(provider) => provider.model().to_string(),
903 }
904 }
905}
906
907impl memvid_core::VecEmbedder for EmbeddingRuntime {
908 fn embed_query(&self, text: &str) -> memvid_core::Result<Vec<f32>> {
909 EmbeddingRuntime::embed_query(self, text).map_err(|err| {
910 memvid_core::MemvidError::EmbeddingFailed {
911 reason: err.to_string().into_boxed_str(),
912 }
913 })
914 }
915
916 fn embedding_dimension(&self) -> usize {
917 self.dimension()
918 }
919}
920
921fn ensure_fastembed_cache(config: &CliConfig) -> Result<PathBuf> {
923 use std::fs;
924
925 let cache_dir = config.models_dir.clone();
926 fs::create_dir_all(&cache_dir)?;
927 Ok(cache_dir)
928}
929
930fn model_size_mb(model: EmbeddingModelChoice) -> usize {
932 match model {
933 EmbeddingModelChoice::BgeSmall => 33,
934 EmbeddingModelChoice::BgeBase => 110,
935 EmbeddingModelChoice::Nomic => 137,
936 EmbeddingModelChoice::GteLarge => 327,
937 EmbeddingModelChoice::OpenAILarge
939 | EmbeddingModelChoice::OpenAISmall
940 | EmbeddingModelChoice::OpenAIAda
941 | EmbeddingModelChoice::Nvidia
942 | EmbeddingModelChoice::Gemini
943 | EmbeddingModelChoice::Mistral => 0,
944 }
945}
946
947fn instantiate_embedding_runtime(config: &CliConfig) -> Result<EmbeddingRuntime> {
949 use tracing::info;
950
951 let embedding_model = config.embedding_model;
952
953 if embedding_model.dimensions() > 0 {
954 info!(
955 "Loading embedding model: {} ({}D)",
956 embedding_model.name(),
957 embedding_model.dimensions()
958 );
959 } else {
960 info!("Loading embedding model: {}", embedding_model.name());
961 }
962
963 if config.offline && embedding_model.is_remote() {
964 anyhow::bail!(
965 "remote embeddings are unavailable while offline; set MEMVID_OFFLINE=0 or use a local embedding model"
966 );
967 }
968
969 if embedding_model.is_openai() {
971 return instantiate_openai_runtime(embedding_model);
972 }
973
974 if embedding_model == EmbeddingModelChoice::Nvidia {
975 return instantiate_nvidia_runtime(None);
976 }
977
978 if embedding_model == EmbeddingModelChoice::Gemini {
979 return instantiate_gemini_runtime();
980 }
981
982 if embedding_model == EmbeddingModelChoice::Mistral {
983 return instantiate_mistral_runtime();
984 }
985
986 instantiate_fastembed_runtime(config, embedding_model)
988}
989
990fn instantiate_openai_runtime(embedding_model: EmbeddingModelChoice) -> Result<EmbeddingRuntime> {
992 use anyhow::bail;
993 use memvid_core::EmbeddingConfig;
994 use tracing::info;
995
996 let api_key = std::env::var("OPENAI_API_KEY").map_err(|_| {
997 anyhow!("OPENAI_API_KEY environment variable is required for OpenAI embeddings")
998 })?;
999
1000 if api_key.is_empty() {
1001 bail!("OPENAI_API_KEY cannot be empty");
1002 }
1003
1004 let config = match embedding_model {
1005 EmbeddingModelChoice::OpenAILarge => EmbeddingConfig::openai_large(),
1006 EmbeddingModelChoice::OpenAISmall => EmbeddingConfig::openai_small(),
1007 EmbeddingModelChoice::OpenAIAda => EmbeddingConfig::openai_ada(),
1008 _ => unreachable!("is_openai() should have been false"),
1009 };
1010
1011 let provider = OpenAIEmbeddingProvider::new(api_key, config.clone())
1012 .map_err(|err| anyhow!("failed to create OpenAI embedding provider: {err}"))?;
1013
1014 info!(
1015 "OpenAI embedding provider ready: model={}, dimension={}",
1016 config.model, config.dimension
1017 );
1018
1019 Ok(EmbeddingRuntime::new_openai(
1020 provider,
1021 embedding_model,
1022 config.dimension,
1023 ))
1024}
1025
/// Normalizes a user-supplied NVIDIA model override into a fully-qualified
/// `vendor/model` id.
///
/// Returns `None` when the input is empty or merely selects the provider
/// ("nvidia"/"nv"), letting the caller fall back to the default model.
/// A `nvidia:`/`nv:` prefix is stripped, the bare `nv-embed-v1` name is
/// canonicalized, and names without a vendor segment are placed under
/// `nvidia/`.
fn normalize_nvidia_embedding_model_override(raw: &str) -> Option<String> {
    let trimmed = raw.trim();
    if trimmed.is_empty()
        || trimmed.eq_ignore_ascii_case("nvidia")
        || trimmed.eq_ignore_ascii_case("nv")
    {
        return None;
    }

    let name = trimmed
        .strip_prefix("nvidia:")
        .or_else(|| trimmed.strip_prefix("nv:"))
        .unwrap_or(trimmed)
        .trim();

    match name {
        "" => None,
        n if n.eq_ignore_ascii_case("nv-embed-v1") => Some("nvidia/nv-embed-v1".to_string()),
        n if n.contains('/') => Some(n.to_string()),
        n => Some(format!("nvidia/{n}")),
    }
}
1057
1058fn instantiate_nvidia_runtime(model_override: Option<&str>) -> Result<EmbeddingRuntime> {
1060 use tracing::info;
1061
1062 let normalized = model_override.and_then(normalize_nvidia_embedding_model_override);
1063 let provider = NvidiaEmbeddingProvider::from_env(normalized.as_deref())
1064 .map_err(|err| anyhow!("failed to create NVIDIA embedding provider: {err}"))?;
1065
1066 info!(
1067 "NVIDIA embedding provider ready: model={}",
1068 provider.model()
1069 );
1070
1071 Ok(EmbeddingRuntime::new_nvidia(
1072 provider,
1073 EmbeddingModelChoice::Nvidia,
1074 ))
1075}
1076
1077fn instantiate_gemini_runtime() -> Result<EmbeddingRuntime> {
1079 use tracing::info;
1080
1081 let provider = GeminiEmbeddingProvider::from_env()
1082 .map_err(|err| anyhow!("failed to create Gemini embedding provider: {err}"))?;
1083
1084 let dimension = provider.dimension();
1085 info!(
1086 "Gemini embedding provider ready: model={}, dimension={}",
1087 provider.model(),
1088 dimension
1089 );
1090
1091 Ok(EmbeddingRuntime::new_gemini(
1092 provider,
1093 EmbeddingModelChoice::Gemini,
1094 dimension,
1095 ))
1096}
1097
1098fn instantiate_mistral_runtime() -> Result<EmbeddingRuntime> {
1100 use tracing::info;
1101
1102 let provider = MistralEmbeddingProvider::from_env()
1103 .map_err(|err| anyhow!("failed to create Mistral embedding provider: {err}"))?;
1104
1105 let dimension = provider.dimension();
1106 info!(
1107 "Mistral embedding provider ready: model={}, dimension={}",
1108 provider.model(),
1109 dimension
1110 );
1111
1112 Ok(EmbeddingRuntime::new_mistral(
1113 provider,
1114 EmbeddingModelChoice::Mistral,
1115 dimension,
1116 ))
1117}
1118
1119fn instantiate_fastembed_runtime(
1121 config: &CliConfig,
1122 embedding_model: EmbeddingModelChoice,
1123) -> Result<EmbeddingRuntime> {
1124 use anyhow::bail;
1125 use fastembed::{InitOptions, TextEmbedding};
1126 use std::fs;
1127
1128 let cache_dir = ensure_fastembed_cache(config)?;
1129
1130 if config.offline {
1131 let mut entries = fs::read_dir(&cache_dir)?;
1132 if entries.next().is_none() {
1133 bail!(
1134 "semantic embeddings unavailable while offline; allow one connected run so fastembed can cache model weights"
1135 );
1136 }
1137 }
1138
1139 let options = InitOptions::new(embedding_model.to_fastembed_model())
1140 .with_cache_dir(cache_dir)
1141 .with_show_download_progress(true);
1142 let mut model = TextEmbedding::try_new(options).map_err(|err| {
1143 let platform_hint = if cfg!(target_os = "windows") {
1145 "\n\nWindows users: If model downloads fail, try:\n\
1146 1. Run as Administrator\n\
1147 2. Check your antivirus isn't blocking downloads\n\
1148 3. Use OpenAI embeddings instead: set OPENAI_API_KEY and use --embedding-model openai"
1149 } else if cfg!(target_os = "linux") {
1150 "\n\nLinux users: If model downloads fail, try:\n\
1151 1. Check disk space in ~/.memvid/models\n\
1152 2. Ensure you have network access to huggingface.co\n\
1153 3. Use OpenAI embeddings instead: export OPENAI_API_KEY=... and use --embedding-model openai"
1154 } else {
1155 "\n\nIf model downloads fail, try using OpenAI embeddings:\n\
1156 export OPENAI_API_KEY=your-key && memvid ... --embedding-model openai"
1157 };
1158
1159 anyhow!(
1160 "Failed to initialize embedding model '{}': {err}\n\n\
1161 This typically means the model couldn't be downloaded or loaded.\n\
1162 Model size: ~{} MB{}\n\n\
1163 See: https://docs.memvid.com/embedding-models",
1164 embedding_model.name(),
1165 model_size_mb(embedding_model),
1166 platform_hint
1167 )
1168 })?;
1169
1170 let probe = model
1171 .embed(vec!["memvid probe".to_string()], None)
1172 .map_err(|err| anyhow!("failed to determine embedding dimension: {err}"))?;
1173 let dimension = probe.first().map(|vec| vec.len()).unwrap_or(0);
1174
1175 if dimension == 0 {
1176 bail!("fastembed reported zero-length embeddings");
1177 }
1178
1179 if dimension != embedding_model.dimensions() {
1181 tracing::warn!(
1182 "Embedding dimension mismatch: expected {}, got {}",
1183 embedding_model.dimensions(),
1184 dimension
1185 );
1186 }
1187
1188 Ok(EmbeddingRuntime::new_fastembed(model, embedding_model, dimension))
1189}
1190
1191pub fn load_embedding_runtime(config: &CliConfig) -> Result<EmbeddingRuntime> {
1193 use anyhow::bail;
1194
1195 match instantiate_embedding_runtime(config) {
1196 Ok(runtime) => Ok(runtime),
1197 Err(err) => {
1198 if config.offline {
1199 bail!(
1200 "semantic embeddings unavailable while offline; allow one connected run so fastembed can cache model weights ({err})"
1201 );
1202 }
1203 Err(err)
1204 }
1205 }
1206}
1207
1208pub fn try_load_embedding_runtime(config: &CliConfig) -> Option<EmbeddingRuntime> {
1210 use tracing::warn;
1211
1212 match instantiate_embedding_runtime(config) {
1213 Ok(runtime) => Some(runtime),
1214 Err(err) => {
1215 warn!("semantic embeddings unavailable: {err}");
1216 None
1217 }
1218 }
1219}
1220
1221pub fn load_embedding_runtime_with_model(
1224 config: &CliConfig,
1225 model_override: Option<&str>,
1226) -> Result<EmbeddingRuntime> {
1227 use tracing::info;
1228
1229 let mut raw_override: Option<&str> = None;
1230 let embedding_model = match model_override {
1231 Some(model_str) => {
1232 raw_override = Some(model_str);
1233 let parsed = model_str.parse::<EmbeddingModelChoice>()?;
1234 if parsed.dimensions() > 0 {
1235 info!(
1236 "Using embedding model override: {} ({}D)",
1237 parsed.name(),
1238 parsed.dimensions()
1239 );
1240 } else {
1241 info!("Using embedding model override: {}", parsed.name());
1242 }
1243 parsed
1244 }
1245 None => config.embedding_model,
1246 };
1247
1248 if embedding_model.dimensions() > 0 {
1249 info!(
1250 "Loading embedding model: {} ({}D)",
1251 embedding_model.name(),
1252 embedding_model.dimensions()
1253 );
1254 } else {
1255 info!("Loading embedding model: {}", embedding_model.name());
1256 }
1257
1258 if config.offline && embedding_model.is_remote() {
1259 anyhow::bail!(
1260 "remote embeddings are unavailable while offline; set MEMVID_OFFLINE=0 or use a local embedding model"
1261 );
1262 }
1263
1264 if embedding_model.is_openai() {
1265 return instantiate_openai_runtime(embedding_model);
1266 }
1267
1268 if embedding_model == EmbeddingModelChoice::Nvidia {
1269 return instantiate_nvidia_runtime(raw_override);
1270 }
1271
1272 if embedding_model == EmbeddingModelChoice::Gemini {
1273 return instantiate_gemini_runtime();
1274 }
1275
1276 if embedding_model == EmbeddingModelChoice::Mistral {
1277 return instantiate_mistral_runtime();
1278 }
1279
1280 instantiate_fastembed_runtime(config, embedding_model)
1281}
1282
1283pub fn try_load_embedding_runtime_with_model(
1285 config: &CliConfig,
1286 model_override: Option<&str>,
1287) -> Option<EmbeddingRuntime> {
1288 use tracing::warn;
1289
1290 match load_embedding_runtime_with_model(config, model_override) {
1291 Ok(runtime) => Some(runtime),
1292 Err(err) => {
1293 warn!("semantic embeddings unavailable: {err}");
1294 None
1295 }
1296 }
1297}
1298
1299pub fn load_embedding_runtime_for_mv2(
1309 config: &CliConfig,
1310 model_override: Option<&str>,
1311 mv2_dimension: Option<u32>,
1312) -> Result<EmbeddingRuntime> {
1313 use tracing::info;
1314
1315 if let Some(model_str) = model_override {
1317 return load_embedding_runtime_with_model(config, Some(model_str));
1318 }
1319
1320 if let Some(dim) = mv2_dimension {
1322 if let Some(detected_model) = EmbeddingModelChoice::from_dimension(dim) {
1323 info!(
1324 "Auto-detected embedding model from MV2: {} ({}D)",
1325 detected_model.name(),
1326 dim
1327 );
1328
1329 if detected_model.is_openai() {
1331 if std::env::var("OPENAI_API_KEY").is_ok() {
1332 return load_embedding_runtime_with_model(config, Some(detected_model.name()));
1333 } else {
1334 return Err(anyhow!(
1336 "MV2 file uses OpenAI embeddings ({}D) but OPENAI_API_KEY is not set.\n\n\
1337 Options:\n\
1338 1. Set OPENAI_API_KEY environment variable\n\
1339 2. Use --query-embedding-model to specify a different model\n\
1340 3. Use lexical-only search with --mode lex\n\n\
1341 See: https://docs.memvid.com/embedding-models",
1342 dim
1343 ));
1344 }
1345 }
1346
1347 return load_embedding_runtime_with_model(config, Some(detected_model.name()));
1348 }
1349 }
1350
1351 load_embedding_runtime(config)
1353}
1354
1355pub fn try_load_embedding_runtime_for_mv2(
1357 config: &CliConfig,
1358 model_override: Option<&str>,
1359 mv2_dimension: Option<u32>,
1360) -> Option<EmbeddingRuntime> {
1361 use tracing::warn;
1362
1363 match load_embedding_runtime_for_mv2(config, model_override, mv2_dimension) {
1364 Ok(runtime) => Some(runtime),
1365 Err(err) => {
1366 warn!("semantic embeddings unavailable: {err}");
1367 None
1368 }
1369 }
1370}