1use std::env;
7use std::path::PathBuf;
8use std::str::FromStr;
9use std::sync::atomic::{AtomicUsize, Ordering};
10
11use anyhow::{anyhow, Result};
12use ed25519_dalek::VerifyingKey;
13
/// Default memvid API endpoint, used when `MEMVID_API_URL` is unset.
const DEFAULT_API_URL: &str = "https://kgpfm35ddc.us-east-2.awsapprunner.com";
/// Default cache location; the leading `~` is expanded by `expand_path`.
const DEFAULT_CACHE_DIR: &str = "~/.cache/memvid";
16
/// Embedding model selected for a CLI run, covering both local fastembed
/// models and remote (OpenAI / NVIDIA) providers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum EmbeddingModelChoice {
    /// BAAI/bge-small-en-v1.5 (384D) — the default local model.
    #[default]
    BgeSmall,
    /// BAAI/bge-base-en-v1.5 (768D), local.
    BgeBase,
    /// nomic-embed-text-v1.5 (768D), local.
    Nomic,
    /// thenlper/gte-large (1024D), local.
    GteLarge,
    /// OpenAI text-embedding-3-large (3072D), remote.
    OpenAILarge,
    /// OpenAI text-embedding-3-small (1536D), remote.
    OpenAISmall,
    /// OpenAI text-embedding-ada-002 (1536D), remote.
    OpenAIAda,
    /// NVIDIA nv-embed family, remote; dimension discovered at runtime.
    Nvidia,
}
38
impl EmbeddingModelChoice {
    /// True for the OpenAI-hosted models, which are served remotely and never
    /// go through fastembed.
    pub fn is_openai(&self) -> bool {
        matches!(
            self,
            EmbeddingModelChoice::OpenAILarge
                | EmbeddingModelChoice::OpenAISmall
                | EmbeddingModelChoice::OpenAIAda
        )
    }

    /// Maps a local model choice to its fastembed counterpart.
    ///
    /// # Panics
    /// Panics for the remote choices (OpenAI, NVIDIA); callers must route on
    /// `is_openai()` / the `Nvidia` variant before calling this.
    pub fn to_fastembed_model(&self) -> fastembed::EmbeddingModel {
        match self {
            EmbeddingModelChoice::BgeSmall => fastembed::EmbeddingModel::BGESmallENV15,
            EmbeddingModelChoice::BgeBase => fastembed::EmbeddingModel::BGEBaseENV15,
            EmbeddingModelChoice::Nomic => fastembed::EmbeddingModel::NomicEmbedTextV15,
            EmbeddingModelChoice::GteLarge => fastembed::EmbeddingModel::GTELargeENV15,
            EmbeddingModelChoice::OpenAILarge
            | EmbeddingModelChoice::OpenAISmall
            | EmbeddingModelChoice::OpenAIAda => {
                panic!("OpenAI models don't use fastembed. Check is_openai() first.")
            }
            EmbeddingModelChoice::Nvidia => {
                panic!("NVIDIA embeddings don't use fastembed. Check is_openai()/model choice first.")
            }
        }
    }

    /// Short CLI-facing name; also what `Display` prints.
    pub fn name(&self) -> &'static str {
        match self {
            EmbeddingModelChoice::BgeSmall => "bge-small",
            EmbeddingModelChoice::BgeBase => "bge-base",
            EmbeddingModelChoice::Nomic => "nomic",
            EmbeddingModelChoice::GteLarge => "gte-large",
            EmbeddingModelChoice::OpenAILarge => "openai-large",
            EmbeddingModelChoice::OpenAISmall => "openai-small",
            EmbeddingModelChoice::OpenAIAda => "openai-ada",
            EmbeddingModelChoice::Nvidia => "nvidia",
        }
    }

    /// Canonical upstream identifier (Hugging Face repo id or API model id).
    pub fn canonical_model_id(&self) -> &'static str {
        match self {
            EmbeddingModelChoice::BgeSmall => "BAAI/bge-small-en-v1.5",
            EmbeddingModelChoice::BgeBase => "BAAI/bge-base-en-v1.5",
            EmbeddingModelChoice::Nomic => "nomic-embed-text-v1.5",
            EmbeddingModelChoice::GteLarge => "thenlper/gte-large",
            EmbeddingModelChoice::OpenAILarge => "text-embedding-3-large",
            EmbeddingModelChoice::OpenAISmall => "text-embedding-3-small",
            EmbeddingModelChoice::OpenAIAda => "text-embedding-ada-002",
            EmbeddingModelChoice::Nvidia => "nvidia/nv-embed-v1",
        }
    }

    /// Expected embedding width in floats.
    ///
    /// Returns 0 for NVIDIA, whose dimension is only learned from the first
    /// API response — callers treat 0 as "unknown", not as a real width.
    pub fn dimensions(&self) -> usize {
        match self {
            EmbeddingModelChoice::BgeSmall => 384,
            EmbeddingModelChoice::BgeBase => 768,
            EmbeddingModelChoice::Nomic => 768,
            EmbeddingModelChoice::GteLarge => 1024,
            EmbeddingModelChoice::OpenAILarge => 3072,
            EmbeddingModelChoice::OpenAISmall => 1536,
            EmbeddingModelChoice::OpenAIAda => 1536,
            EmbeddingModelChoice::Nvidia => 0,
        }
    }
}
118
impl FromStr for EmbeddingModelChoice {
    type Err = anyhow::Error;

    /// Parses a user-supplied model name, accepting short aliases, underscore
    /// variants, and the canonical upstream model ids. Matching is
    /// case-insensitive (input is lowercased first).
    fn from_str(s: &str) -> Result<Self> {
        let lowered = s.trim().to_ascii_lowercase();
        match lowered.as_str() {
            "bge-small" | "bge_small" | "bgesmall" | "small" => Ok(EmbeddingModelChoice::BgeSmall),
            "baai/bge-small-en-v1.5" => Ok(EmbeddingModelChoice::BgeSmall),
            "bge-base" | "bge_base" | "bgebase" | "base" => Ok(EmbeddingModelChoice::BgeBase),
            "baai/bge-base-en-v1.5" => Ok(EmbeddingModelChoice::BgeBase),
            "nomic" | "nomic-embed" | "nomic_embed" => Ok(EmbeddingModelChoice::Nomic),
            "nomic-embed-text-v1.5" => Ok(EmbeddingModelChoice::Nomic),
            "gte-large" | "gte_large" | "gtelarge" | "gte" => Ok(EmbeddingModelChoice::GteLarge),
            "thenlper/gte-large" => Ok(EmbeddingModelChoice::GteLarge),
            // Bare "openai" means the large model.
            "openai" | "openai-large" | "openai_large" | "text-embedding-3-large" => {
                Ok(EmbeddingModelChoice::OpenAILarge)
            }
            "openai-small" | "openai_small" | "text-embedding-3-small" => {
                Ok(EmbeddingModelChoice::OpenAISmall)
            }
            "openai-ada" | "openai_ada" | "text-embedding-ada-002" | "ada" => {
                Ok(EmbeddingModelChoice::OpenAIAda)
            }
            "nvidia" | "nv" | "nv-embed-v1" | "nvidia/nv-embed-v1" => Ok(EmbeddingModelChoice::Nvidia),
            // Any other "nvidia/..." / "nvidia:..." / "nv:..." string is an
            // NVIDIA model override; the exact id is normalized elsewhere.
            _ if lowered.starts_with("nvidia/") || lowered.starts_with("nvidia:") || lowered.starts_with("nv:") => {
                Ok(EmbeddingModelChoice::Nvidia)
            }
            _ => Err(anyhow!(
                "unknown embedding model '{}'. Valid options: bge-small, bge-base, nomic, gte-large, openai, openai-small, openai-ada, nvidia",
                s
            )),
        }
    }
}
154
155impl std::fmt::Display for EmbeddingModelChoice {
156 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
157 write!(f, "{}", self.name())
158 }
159}
160
161impl EmbeddingModelChoice {
162 pub fn from_dimension(dim: u32) -> Option<Self> {
174 match dim {
175 384 => Some(EmbeddingModelChoice::BgeSmall),
176 768 => Some(EmbeddingModelChoice::BgeBase), 1024 => Some(EmbeddingModelChoice::GteLarge),
178 1536 => Some(EmbeddingModelChoice::OpenAISmall), 3072 => Some(EmbeddingModelChoice::OpenAILarge),
180 0 => None, _ => {
182 tracing::warn!("Unknown embedding dimension {}, using default model", dim);
183 None
184 }
185 }
186 }
187}
188
/// CLI configuration resolved entirely from environment variables
/// (see [`CliConfig::load`]).
#[derive(Debug, Clone)]
pub struct CliConfig {
    /// `MEMVID_API_KEY`, if set and non-blank.
    pub api_key: Option<String>,
    /// `MEMVID_API_URL`, falling back to `DEFAULT_API_URL`.
    pub api_url: String,
    /// Expanded `MEMVID_CACHE_DIR` (default `~/.cache/memvid`).
    pub cache_dir: PathBuf,
    /// Ed25519 key parsed from base64 `MEMVID_TICKET_PUBKEY`, if provided.
    pub ticket_pubkey: Option<VerifyingKey>,
    /// Expanded `MEMVID_MODELS_DIR` (default `~/.memvid/models`).
    pub models_dir: PathBuf,
    /// `MEMVID_OFFLINE` truthiness ("1" / "true" / "yes").
    pub offline: bool,
    /// Parsed `MEMVID_EMBEDDING_MODEL`, or the default choice.
    pub embedding_model: EmbeddingModelChoice,
}
201
impl PartialEq for CliConfig {
    // NOTE(review): `ticket_pubkey` is excluded from equality here — two
    // configs differing only in pubkey compare equal. Confirm this is
    // intentional rather than an oversight.
    fn eq(&self, other: &Self) -> bool {
        self.api_key == other.api_key
            && self.api_url == other.api_url
            && self.cache_dir == other.cache_dir
            && self.models_dir == other.models_dir
            && self.offline == other.offline
            && self.embedding_model == other.embedding_model
    }
}

impl Eq for CliConfig {}
214
impl CliConfig {
    /// Builds the configuration from environment variables. The only
    /// filesystem interaction is resolving the current/home directory while
    /// expanding `~` in path values.
    ///
    /// # Errors
    /// Fails when a path value is blank or cannot be expanded, or when
    /// `MEMVID_TICKET_PUBKEY` is set but not a valid base64 Ed25519 key.
    pub fn load() -> Result<Self> {
        // A set-but-blank API key is treated the same as unset.
        let api_key = env::var("MEMVID_API_KEY").ok().and_then(|value| {
            let trimmed = value.trim().to_string();
            (!trimmed.is_empty()).then_some(trimmed)
        });

        let api_url = env::var("MEMVID_API_URL").unwrap_or_else(|_| DEFAULT_API_URL.to_string());

        let cache_dir_raw =
            env::var("MEMVID_CACHE_DIR").unwrap_or_else(|_| DEFAULT_CACHE_DIR.to_string());
        let cache_dir = expand_path(&cache_dir_raw)?;

        let models_dir_raw =
            env::var("MEMVID_MODELS_DIR").unwrap_or_else(|_| "~/.memvid/models".to_string());
        let models_dir = expand_path(&models_dir_raw)?;

        // Blank value -> ignored; malformed value -> hard error (transpose
        // surfaces the inner parse Result).
        let ticket_pubkey = env::var("MEMVID_TICKET_PUBKEY")
            .ok()
            .and_then(|value| {
                let trimmed = value.trim();
                if trimmed.is_empty() {
                    None
                } else {
                    Some(memvid_core::parse_ed25519_public_key_base64(trimmed))
                }
            })
            .transpose()?;

        // Anything other than "1"/"true"/"yes" (case-insensitive) is false.
        let offline = env::var("MEMVID_OFFLINE")
            .ok()
            .map(|value| match value.trim().to_ascii_lowercase().as_str() {
                "1" | "true" | "yes" => true,
                _ => false,
            })
            .unwrap_or(false);

        // NOTE(review): an unrecognized MEMVID_EMBEDDING_MODEL silently falls
        // back to the default model (`.ok()` drops the parse error) — confirm
        // this lenience is intended rather than surfacing the error.
        let embedding_model = env::var("MEMVID_EMBEDDING_MODEL")
            .ok()
            .and_then(|value| {
                let trimmed = value.trim();
                if trimmed.is_empty() {
                    None
                } else {
                    EmbeddingModelChoice::from_str(trimmed).ok()
                }
            })
            .unwrap_or_default();

        Ok(Self {
            api_key,
            api_url,
            cache_dir,
            ticket_pubkey,
            models_dir,
            offline,
            embedding_model,
        })
    }

    /// Returns a copy of this config with only the embedding model replaced.
    pub fn with_embedding_model(&self, model: EmbeddingModelChoice) -> Self {
        Self {
            embedding_model: model,
            ..self.clone()
        }
    }
}
284
285fn expand_path(value: &str) -> Result<PathBuf> {
286 if value.trim().is_empty() {
287 return Err(anyhow!("cache directory cannot be empty"));
288 }
289
290 let expanded = if let Some(stripped) = value.strip_prefix("~/") {
291 home_dir()?.join(stripped)
292 } else if let Some(stripped) = value.strip_prefix("~\\") {
293 home_dir()?.join(stripped)
295 } else if value == "~" {
296 home_dir()?
297 } else {
298 PathBuf::from(value)
299 };
300
301 if expanded.is_absolute() {
302 Ok(expanded)
303 } else {
304 Ok(env::current_dir()?.join(expanded))
305 }
306}
307
/// Resolves the user's home directory from environment variables.
///
/// Tries `HOME` first on every platform; on Windows it additionally falls
/// back to `USERPROFILE` and then `HOMEDRIVE` + `HOMEPATH`.
///
/// # Errors
/// Fails when none of the variables yield a non-empty value.
fn home_dir() -> Result<PathBuf> {
    if let Some(path) = env::var_os("HOME") {
        if !path.is_empty() {
            return Ok(PathBuf::from(path));
        }
    }

    #[cfg(windows)]
    {
        if let Some(path) = env::var_os("USERPROFILE") {
            if !path.is_empty() {
                return Ok(PathBuf::from(path));
            }
        }
        // Some environments only provide the drive and path halves separately.
        if let (Some(drive), Some(path)) = (env::var_os("HOMEDRIVE"), env::var_os("HOMEPATH")) {
            if !drive.is_empty() && !path.is_empty() {
                return Ok(PathBuf::from(format!(
                    "{}{}",
                    drive.to_string_lossy(),
                    path.to_string_lossy()
                )));
            }
        }
    }

    Err(anyhow!("unable to resolve home directory"))
}
335
#[cfg(test)]
mod tests {
    use super::*;
    use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
    use base64::Engine;
    use ed25519_dalek::SigningKey;
    use std::sync::{Mutex, OnceLock};

    // Process-wide lock so tests that mutate environment variables never
    // interleave (env vars are process-global state).
    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(())).lock().unwrap()
    }

    // Restore helper: re-set a variable to its saved value, or remove it if
    // it was previously unset.
    fn set_or_unset(var: &str, value: Option<String>) {
        match value {
            Some(v) => unsafe { env::set_var(var, v) },
            None => unsafe { env::remove_var(var) },
        }
    }

    #[test]
    fn defaults_expand_using_home_directory() {
        let _guard = env_lock();

        let previous_home = env::var("HOME").ok();
        #[cfg(windows)]
        let previous_userprofile = env::var("USERPROFILE").ok();

        // Clear every MEMVID_* override so load() exercises its defaults.
        // NOTE(review): these removed values are not restored afterwards —
        // fine for CI, but worth confirming for local runs with env set.
        for var in [
            "MEMVID_API_KEY",
            "MEMVID_API_URL",
            "MEMVID_CACHE_DIR",
            "MEMVID_TICKET_PUBKEY",
            "MEMVID_MODELS_DIR",
            "MEMVID_OFFLINE",
        ] {
            unsafe { env::remove_var(var) };
        }

        // Point HOME at a temp dir so "~" expansion is observable.
        let tmp = tempfile::tempdir().expect("tmpdir");
        let tmp_path = tmp.path().to_path_buf();
        unsafe { env::set_var("HOME", &tmp_path) };
        #[cfg(windows)]
        unsafe {
            env::set_var("USERPROFILE", &tmp_path)
        };

        let config = CliConfig::load().expect("load");
        assert_eq!(config.api_key, None);
        assert_eq!(config.api_url, DEFAULT_API_URL);
        assert_eq!(config.cache_dir, tmp_path.join(".cache/memvid"));
        assert!(config.ticket_pubkey.is_none());
        assert_eq!(config.models_dir, tmp_path.join(".memvid/models"));
        assert!(!config.offline);

        set_or_unset("HOME", previous_home);
        #[cfg(windows)]
        {
            set_or_unset("USERPROFILE", previous_userprofile);
        }
    }

    #[test]
    fn env_overrides_are_respected() {
        let _guard = env_lock();

        // Snapshot current values so they can be restored at the end.
        let previous_env: Vec<(&'static str, Option<String>)> = [
            "MEMVID_API_KEY",
            "MEMVID_API_URL",
            "MEMVID_CACHE_DIR",
            "MEMVID_TICKET_PUBKEY",
            "MEMVID_MODELS_DIR",
            "MEMVID_OFFLINE",
        ]
        .into_iter()
        .map(|var| (var, env::var(var).ok()))
        .collect();

        unsafe { env::set_var("MEMVID_API_KEY", "abc123") };
        unsafe { env::set_var("MEMVID_API_URL", "https://staging.memvid.app") };
        unsafe { env::set_var("MEMVID_CACHE_DIR", "~/memvid-cache") };
        unsafe { env::set_var("MEMVID_MODELS_DIR", "~/models") };
        unsafe { env::set_var("MEMVID_OFFLINE", "true") };
        // Fixed-seed key makes the expected pubkey deterministic.
        let signing = SigningKey::from_bytes(&[9u8; 32]);
        let encoded = BASE64_STANDARD.encode(signing.verifying_key().as_bytes());
        unsafe { env::set_var("MEMVID_TICKET_PUBKEY", encoded) };

        // NOTE(review): HOME is overwritten below but not captured in
        // previous_env, so it is not restored — confirm later tests don't
        // depend on the real HOME.
        let tmp = tempfile::tempdir().expect("tmpdir");
        let tmp_path = tmp.path().to_path_buf();
        unsafe { env::set_var("HOME", &tmp_path) };
        #[cfg(windows)]
        unsafe {
            env::set_var("USERPROFILE", &tmp_path)
        };

        let config = CliConfig::load().expect("load");
        assert_eq!(config.api_key.as_deref(), Some("abc123"));
        assert_eq!(config.api_url, "https://staging.memvid.app");
        assert_eq!(config.cache_dir, tmp_path.join("memvid-cache"));
        assert_eq!(
            config.ticket_pubkey.expect("pubkey").as_bytes(),
            signing.verifying_key().as_bytes()
        );
        assert_eq!(config.models_dir, tmp_path.join("models"));
        assert!(config.offline);

        for (var, value) in previous_env {
            set_or_unset(var, value);
        }
    }

    #[test]
    fn rejects_empty_cache_dir() {
        let _guard = env_lock();

        // A whitespace-only cache dir must be rejected by expand_path.
        let previous = env::var("MEMVID_CACHE_DIR").ok();
        unsafe { env::set_var("MEMVID_CACHE_DIR", " ") };
        let err = CliConfig::load().expect_err("should fail");
        assert!(err.to_string().contains("cache directory"));
        set_or_unset("MEMVID_CACHE_DIR", previous);
    }
}
458
459pub fn init_tracing(verbosity: u8) -> Result<()> {
461 use std::io::IsTerminal;
462 use tracing_subscriber::{filter::Directive, fmt, EnvFilter};
463
464 let level = match verbosity {
465 0 => "warn",
466 1 => "info",
467 2 => "debug",
468 _ => "trace",
469 };
470
471 let mut env_filter =
472 EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(level));
473 for directive_str in ["llama_cpp=error", "llama_cpp_sys=error", "ggml=error"] {
474 if let Ok(directive) = directive_str.parse::<Directive>() {
475 env_filter = env_filter.add_directive(directive);
476 }
477 }
478
479 let use_ansi = std::io::stderr().is_terminal();
483
484 fmt()
485 .with_env_filter(env_filter)
486 .with_writer(std::io::stderr)
487 .with_target(false)
488 .without_time()
489 .with_ansi(use_ansi)
490 .try_init()
491 .map_err(|err| anyhow!(err))?;
492 Ok(())
493}
494
495pub fn resolve_llm_context_budget_override(cli_value: Option<usize>) -> Result<Option<usize>> {
497 use anyhow::bail;
498
499 if let Some(value) = cli_value {
500 if value == 0 {
501 bail!("--llm-context-depth must be a positive integer");
502 }
503 return Ok(Some(value));
504 }
505
506 let raw_env = match env::var("MEMVID_LLM_CONTEXT_BUDGET") {
507 Ok(value) => value,
508 Err(_) => return Ok(None),
509 };
510
511 let trimmed = raw_env.trim();
512 if trimmed.is_empty() {
513 return Ok(None);
514 }
515
516 let digits: String = trimmed
517 .chars()
518 .filter(|ch| !ch.is_ascii_whitespace() && *ch != '_')
519 .collect();
520
521 if digits.is_empty() {
522 bail!("MEMVID_LLM_CONTEXT_BUDGET must be a positive integer value");
523 }
524
525 let value: usize = digits.parse().map_err(|err| {
526 anyhow!(
527 "MEMVID_LLM_CONTEXT_BUDGET value '{}' is not a valid number: {}",
528 trimmed,
529 err
530 )
531 })?;
532
533 if value == 0 {
534 bail!("MEMVID_LLM_CONTEXT_BUDGET must be a positive integer");
535 }
536
537 Ok(Some(value))
538}
539
540use crate::nvidia_embeddings::NvidiaEmbeddingProvider;
541use crate::openai_embeddings::OpenAIEmbeddingProvider;
542
/// The concrete provider behind an [`EmbeddingRuntime`].
#[derive(Clone)]
enum EmbeddingBackend {
    // Local model; Mutex because fastembed's embed() needs &mut access.
    FastEmbed(std::sync::Arc<std::sync::Mutex<fastembed::TextEmbedding>>),
    // Remote OpenAI API provider.
    OpenAI(std::sync::Arc<OpenAIEmbeddingProvider>),
    // Remote NVIDIA API provider.
    Nvidia(std::sync::Arc<NvidiaEmbeddingProvider>),
}
550
/// Cheap-to-clone embedding handle shared across the CLI; all clones share
/// one backend and one observed-dimension cell.
#[derive(Clone)]
pub struct EmbeddingRuntime {
    // Provider that actually computes vectors.
    backend: EmbeddingBackend,
    // Model this runtime was constructed for.
    model: EmbeddingModelChoice,
    // Observed vector width; 0 until the first embedding arrives (NVIDIA case).
    dimension: std::sync::Arc<AtomicUsize>,
}
558
559impl EmbeddingRuntime {
560 fn new_fastembed(
561 backend: fastembed::TextEmbedding,
562 model: EmbeddingModelChoice,
563 dimension: usize,
564 ) -> Self {
565 Self {
566 backend: EmbeddingBackend::FastEmbed(std::sync::Arc::new(std::sync::Mutex::new(
567 backend,
568 ))),
569 model,
570 dimension: std::sync::Arc::new(AtomicUsize::new(dimension)),
571 }
572 }
573
574 fn new_openai(
575 provider: OpenAIEmbeddingProvider,
576 model: EmbeddingModelChoice,
577 dimension: usize,
578 ) -> Self {
579 Self {
580 backend: EmbeddingBackend::OpenAI(std::sync::Arc::new(provider)),
581 model,
582 dimension: std::sync::Arc::new(AtomicUsize::new(dimension)),
583 }
584 }
585
586 fn new_nvidia(provider: NvidiaEmbeddingProvider, model: EmbeddingModelChoice) -> Self {
587 Self {
588 backend: EmbeddingBackend::Nvidia(std::sync::Arc::new(provider)),
589 model,
590 dimension: std::sync::Arc::new(AtomicUsize::new(0)),
591 }
592 }
593
594 const MAX_OPENAI_EMBEDDING_TEXT_LEN: usize = 20_000;
595 const MAX_NVIDIA_EMBEDDING_TEXT_LEN: usize = 12_000;
597
598 fn max_remote_embedding_chars(&self) -> usize {
599 match &self.backend {
600 EmbeddingBackend::OpenAI(_) => Self::MAX_OPENAI_EMBEDDING_TEXT_LEN,
601 EmbeddingBackend::Nvidia(_) => Self::MAX_NVIDIA_EMBEDDING_TEXT_LEN,
602 EmbeddingBackend::FastEmbed(_) => usize::MAX,
603 }
604 }
605
606 fn truncate_for_embedding<'a>(
608 text: &'a str,
609 max_chars: usize,
610 ) -> std::borrow::Cow<'a, str> {
611 if text.len() <= max_chars {
612 std::borrow::Cow::Borrowed(text)
613 } else {
614 let truncated = &text[..max_chars];
616 let end = truncated
617 .char_indices()
618 .rev()
619 .next()
620 .map(|(i, c)| i + c.len_utf8())
621 .unwrap_or(max_chars);
622 tracing::info!("Truncated embedding text from {} to {} chars", text.len(), end);
623 std::borrow::Cow::Owned(text[..end].to_string())
624 }
625 }
626
627 fn note_dimension(&self, observed: usize) -> Result<()> {
628 if observed == 0 {
629 return Err(anyhow!("embedding provider returned zero-length embedding"));
630 }
631
632 let current = self.dimension.load(Ordering::Relaxed);
633 if current == 0 {
634 self.dimension.store(observed, Ordering::Relaxed);
635 return Ok(());
636 }
637
638 if current != observed {
639 return Err(anyhow!(
640 "embedding provider returned {observed}D vectors but runtime expects {current}D"
641 ));
642 }
643
644 Ok(())
645 }
646
647 fn truncate_if_remote<'a>(&self, text: &'a str) -> std::borrow::Cow<'a, str> {
648 match &self.backend {
649 EmbeddingBackend::OpenAI(_) | EmbeddingBackend::Nvidia(_) => {
650 Self::truncate_for_embedding(text, self.max_remote_embedding_chars())
651 }
652 _ => std::borrow::Cow::Borrowed(text),
653 }
654 }
655
656 pub fn embed_passage(&self, text: &str) -> Result<Vec<f32>> {
657 let text = self.truncate_if_remote(text);
658 let embedding = match &self.backend {
659 EmbeddingBackend::FastEmbed(model) => {
660 let mut guard = model
661 .lock()
662 .map_err(|_| anyhow!("fastembed runtime poisoned"))?;
663 let outputs = guard
664 .embed(vec![text.into_owned()], None)
665 .map_err(|err| anyhow!("failed to compute embedding with fastembed: {err}"))?;
666 outputs
667 .into_iter()
668 .next()
669 .ok_or_else(|| anyhow!("fastembed returned no embedding output"))?
670 }
671 EmbeddingBackend::OpenAI(provider) => {
672 use memvid_core::EmbeddingProvider;
673 provider
674 .embed_text(&text)
675 .map_err(|err| anyhow!("failed to compute embedding with OpenAI: {err}"))?
676 }
677 EmbeddingBackend::Nvidia(provider) => provider
678 .embed_passage(&text)
679 .map_err(|err| anyhow!("failed to compute embedding with NVIDIA: {err}"))?,
680 };
681
682 self.note_dimension(embedding.len())?;
683 Ok(embedding)
684 }
685
686 pub fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
687 let text = self.truncate_if_remote(text);
688 match &self.backend {
689 EmbeddingBackend::Nvidia(provider) => {
690 let embedding = provider
691 .embed_query(&text)
692 .map_err(|err| anyhow!("failed to compute embedding with NVIDIA: {err}"))?;
693 self.note_dimension(embedding.len())?;
694 Ok(embedding)
695 }
696 _ => self.embed_passage(&text),
697 }
698 }
699
700 pub fn embed_batch_passages(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
701 if texts.is_empty() {
702 return Ok(Vec::new());
703 }
704
705 let truncated: Vec<std::borrow::Cow<'_, str>> =
706 texts.iter().map(|t| self.truncate_if_remote(t)).collect();
707 let truncated_refs: Vec<&str> = truncated.iter().map(|c| c.as_ref()).collect();
708
709 let embeddings = match &self.backend {
710 EmbeddingBackend::FastEmbed(model) => {
711 let mut guard = model
712 .lock()
713 .map_err(|_| anyhow!("fastembed runtime poisoned"))?;
714 guard
715 .embed(
716 truncated_refs
717 .iter()
718 .map(|s| (*s).to_string())
719 .collect::<Vec<String>>(),
720 None,
721 )
722 .map_err(|err| anyhow!("failed to compute embeddings with fastembed: {err}"))?
723 }
724 EmbeddingBackend::OpenAI(provider) => {
725 use memvid_core::EmbeddingProvider;
726 provider
727 .embed_batch(&truncated_refs)
728 .map_err(|err| anyhow!("failed to compute embeddings with OpenAI: {err}"))?
729 }
730 EmbeddingBackend::Nvidia(provider) => provider
731 .embed_passages(&truncated_refs)
732 .map_err(|err| anyhow!("failed to compute embeddings with NVIDIA: {err}"))?,
733 };
734
735 if let Some(first) = embeddings.first() {
736 self.note_dimension(first.len())?;
737 }
738 if let Some(expected) = embeddings.first().map(|e| e.len()) {
739 if embeddings.iter().any(|e| e.len() != expected) {
740 return Err(anyhow!("embedding provider returned mixed vector dimensions"));
741 }
742 }
743
744 Ok(embeddings)
745 }
746
747 pub fn embed_batch_queries(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
748 if texts.is_empty() {
749 return Ok(Vec::new());
750 }
751
752 let truncated: Vec<std::borrow::Cow<'_, str>> =
753 texts.iter().map(|t| self.truncate_if_remote(t)).collect();
754 let truncated_refs: Vec<&str> = truncated.iter().map(|c| c.as_ref()).collect();
755
756 match &self.backend {
757 EmbeddingBackend::Nvidia(provider) => {
758 let embeddings = provider
759 .embed_queries(&truncated_refs)
760 .map_err(|err| anyhow!("failed to compute embeddings with NVIDIA: {err}"))?;
761
762 if let Some(first) = embeddings.first() {
763 self.note_dimension(first.len())?;
764 }
765 if let Some(expected) = embeddings.first().map(|e| e.len()) {
766 if embeddings.iter().any(|e| e.len() != expected) {
767 return Err(anyhow!("embedding provider returned mixed vector dimensions"));
768 }
769 }
770
771 Ok(embeddings)
772 }
773 _ => self.embed_batch_passages(&truncated_refs),
774 }
775 }
776
777 pub fn dimension(&self) -> usize {
778 self.dimension.load(Ordering::Relaxed)
779 }
780
781 pub fn model_choice(&self) -> EmbeddingModelChoice {
782 self.model
783 }
784
785 pub fn provider_kind(&self) -> &'static str {
786 match &self.backend {
787 EmbeddingBackend::FastEmbed(_) => "fastembed",
788 EmbeddingBackend::OpenAI(_) => "openai",
789 EmbeddingBackend::Nvidia(_) => "nvidia",
790 }
791 }
792
793 pub fn provider_model_id(&self) -> String {
794 match &self.backend {
795 EmbeddingBackend::FastEmbed(_) => self.model.canonical_model_id().to_string(),
796 EmbeddingBackend::OpenAI(provider) => {
797 use memvid_core::EmbeddingProvider;
798 provider.model().to_string()
799 }
800 EmbeddingBackend::Nvidia(provider) => provider.model().to_string(),
801 }
802 }
803}
804
805impl memvid_core::VecEmbedder for EmbeddingRuntime {
806 fn embed_query(&self, text: &str) -> memvid_core::Result<Vec<f32>> {
807 EmbeddingRuntime::embed_query(self, text).map_err(|err| {
808 memvid_core::MemvidError::EmbeddingFailed {
809 reason: err.to_string().into_boxed_str(),
810 }
811 })
812 }
813
814 fn embedding_dimension(&self) -> usize {
815 self.dimension()
816 }
817}
818
819fn ensure_fastembed_cache(config: &CliConfig) -> Result<PathBuf> {
821 use std::fs;
822
823 let cache_dir = config.models_dir.clone();
824 fs::create_dir_all(&cache_dir)?;
825 Ok(cache_dir)
826}
827
828fn model_size_mb(model: EmbeddingModelChoice) -> usize {
830 match model {
831 EmbeddingModelChoice::BgeSmall => 33,
832 EmbeddingModelChoice::BgeBase => 110,
833 EmbeddingModelChoice::Nomic => 137,
834 EmbeddingModelChoice::GteLarge => 327,
835 EmbeddingModelChoice::OpenAILarge
837 | EmbeddingModelChoice::OpenAISmall
838 | EmbeddingModelChoice::OpenAIAda
839 | EmbeddingModelChoice::Nvidia => 0,
840 }
841}
842
/// Builds an [`EmbeddingRuntime`] for the model configured in `config`,
/// dispatching to the OpenAI, NVIDIA, or fastembed constructors.
///
/// # Errors
/// Fails when offline with a remote model selected, or when the chosen
/// provider cannot be initialized.
fn instantiate_embedding_runtime(config: &CliConfig) -> Result<EmbeddingRuntime> {
    use tracing::info;

    let embedding_model = config.embedding_model;

    // NVIDIA reports 0 dimensions until first use, so only log the width
    // when it is actually known.
    if embedding_model.dimensions() > 0 {
        info!(
            "Loading embedding model: {} ({}D)",
            embedding_model.name(),
            embedding_model.dimensions()
        );
    } else {
        info!("Loading embedding model: {}", embedding_model.name());
    }

    // Remote providers need network access; refuse them early when offline.
    if config.offline && (embedding_model.is_openai() || embedding_model == EmbeddingModelChoice::Nvidia)
    {
        anyhow::bail!(
            "remote embeddings are unavailable while offline; set MEMVID_OFFLINE=0 or use a local embedding model"
        );
    }

    if embedding_model.is_openai() {
        return instantiate_openai_runtime(embedding_model);
    }

    if embedding_model == EmbeddingModelChoice::Nvidia {
        return instantiate_nvidia_runtime(None);
    }

    instantiate_fastembed_runtime(config, embedding_model)
}
878
/// Builds a runtime backed by the OpenAI embeddings API.
///
/// # Errors
/// Fails when `OPENAI_API_KEY` is missing/empty or the provider cannot be
/// constructed.
fn instantiate_openai_runtime(embedding_model: EmbeddingModelChoice) -> Result<EmbeddingRuntime> {
    use anyhow::bail;
    use memvid_core::EmbeddingConfig;
    use tracing::info;

    let api_key = std::env::var("OPENAI_API_KEY").map_err(|_| {
        anyhow!("OPENAI_API_KEY environment variable is required for OpenAI embeddings")
    })?;

    if api_key.is_empty() {
        bail!("OPENAI_API_KEY cannot be empty");
    }

    // Select the memvid-core embedding config matching the chosen model.
    let config = match embedding_model {
        EmbeddingModelChoice::OpenAILarge => EmbeddingConfig::openai_large(),
        EmbeddingModelChoice::OpenAISmall => EmbeddingConfig::openai_small(),
        EmbeddingModelChoice::OpenAIAda => EmbeddingConfig::openai_ada(),
        // Callers only reach here after is_openai() returned true.
        _ => unreachable!("is_openai() should have been false"),
    };

    let provider = OpenAIEmbeddingProvider::new(api_key, config.clone())
        .map_err(|err| anyhow!("failed to create OpenAI embedding provider: {err}"))?;

    info!(
        "OpenAI embedding provider ready: model={}, dimension={}",
        config.model, config.dimension
    );

    Ok(EmbeddingRuntime::new_openai(
        provider,
        embedding_model,
        config.dimension,
    ))
}
914
/// Normalizes a user-supplied NVIDIA model override into a full model id.
///
/// Returns `None` when the input is blank or just a bare provider alias
/// ("nvidia" / "nv"), meaning "use the default model". Otherwise strips an
/// optional "nvidia:" / "nv:" prefix, canonicalizes "nv-embed-v1", and
/// prepends "nvidia/" to ids without an org segment.
///
/// Fix: prefix stripping is now case-insensitive, matching the rest of the
/// parser (the bare-alias check and the FromStr routing already lowercase);
/// previously "NVIDIA:foo" produced "nvidia/NVIDIA:foo".
fn normalize_nvidia_embedding_model_override(raw: &str) -> Option<String> {
    let trimmed = raw.trim();
    if trimmed.is_empty() {
        return None;
    }

    let lowered = trimmed.to_ascii_lowercase();
    if lowered == "nvidia" || lowered == "nv" {
        return None;
    }

    // to_ascii_lowercase preserves byte offsets, so slicing `trimmed` by the
    // prefix length found in `lowered` is safe.
    let without_prefix = if lowered.starts_with("nvidia:") {
        trimmed["nvidia:".len()..].trim()
    } else if lowered.starts_with("nv:") {
        trimmed["nv:".len()..].trim()
    } else {
        trimmed
    };

    if without_prefix.is_empty() {
        return None;
    }

    if without_prefix.eq_ignore_ascii_case("nv-embed-v1") {
        return Some("nvidia/nv-embed-v1".to_string());
    }

    if without_prefix.contains('/') {
        return Some(without_prefix.to_string());
    }

    Some(format!("nvidia/{without_prefix}"))
}
946
947fn instantiate_nvidia_runtime(model_override: Option<&str>) -> Result<EmbeddingRuntime> {
949 use tracing::info;
950
951 let normalized = model_override.and_then(normalize_nvidia_embedding_model_override);
952 let provider = NvidiaEmbeddingProvider::from_env(normalized.as_deref())
953 .map_err(|err| anyhow!("failed to create NVIDIA embedding provider: {err}"))?;
954
955 info!(
956 "NVIDIA embedding provider ready: model={}",
957 provider.model()
958 );
959
960 Ok(EmbeddingRuntime::new_nvidia(
961 provider,
962 EmbeddingModelChoice::Nvidia,
963 ))
964}
965
/// Builds a runtime backed by a local fastembed model, downloading weights
/// into the models cache on first use.
///
/// # Errors
/// Fails when offline with an empty cache, when the model cannot be
/// downloaded/loaded, or when the probe embedding comes back empty.
fn instantiate_fastembed_runtime(
    config: &CliConfig,
    embedding_model: EmbeddingModelChoice,
) -> Result<EmbeddingRuntime> {
    use anyhow::bail;
    use fastembed::{InitOptions, TextEmbedding};
    use std::fs;

    let cache_dir = ensure_fastembed_cache(config)?;

    // Offline with nothing cached means the download below would fail; give
    // a clearer error up front. (An empty read_dir means an empty cache.)
    if config.offline {
        let mut entries = fs::read_dir(&cache_dir)?;
        if entries.next().is_none() {
            bail!(
                "semantic embeddings unavailable while offline; allow one connected run so fastembed can cache model weights"
            );
        }
    }

    let options = InitOptions::new(embedding_model.to_fastembed_model())
        .with_cache_dir(cache_dir)
        .with_show_download_progress(true);
    let mut model = TextEmbedding::try_new(options).map_err(|err| {
        // Platform-specific troubleshooting advice for failed downloads.
        let platform_hint = if cfg!(target_os = "windows") {
            "\n\nWindows users: If model downloads fail, try:\n\
            1. Run as Administrator\n\
            2. Check your antivirus isn't blocking downloads\n\
            3. Use OpenAI embeddings instead: set OPENAI_API_KEY and use --embedding-model openai"
        } else if cfg!(target_os = "linux") {
            "\n\nLinux users: If model downloads fail, try:\n\
            1. Check disk space in ~/.memvid/models\n\
            2. Ensure you have network access to huggingface.co\n\
            3. Use OpenAI embeddings instead: export OPENAI_API_KEY=... and use --embedding-model openai"
        } else {
            "\n\nIf model downloads fail, try using OpenAI embeddings:\n\
            export OPENAI_API_KEY=your-key && memvid ... --embedding-model openai"
        };

        anyhow!(
            "Failed to initialize embedding model '{}': {err}\n\n\
            This typically means the model couldn't be downloaded or loaded.\n\
            Model size: ~{} MB{}\n\n\
            See: https://docs.memvid.com/embedding-models",
            embedding_model.name(),
            model_size_mb(embedding_model),
            platform_hint
        )
    })?;

    // Probe once to learn the actual output width of the loaded model.
    let probe = model
        .embed(vec!["memvid probe".to_string()], None)
        .map_err(|err| anyhow!("failed to determine embedding dimension: {err}"))?;
    let dimension = probe.first().map(|vec| vec.len()).unwrap_or(0);

    if dimension == 0 {
        bail!("fastembed reported zero-length embeddings");
    }

    // A mismatch with the static table is suspicious but not fatal: the
    // probed value is what actually gets used.
    if dimension != embedding_model.dimensions() {
        tracing::warn!(
            "Embedding dimension mismatch: expected {}, got {}",
            embedding_model.dimensions(),
            dimension
        );
    }

    Ok(EmbeddingRuntime::new_fastembed(model, embedding_model, dimension))
}
1037
1038pub fn load_embedding_runtime(config: &CliConfig) -> Result<EmbeddingRuntime> {
1040 use anyhow::bail;
1041
1042 match instantiate_embedding_runtime(config) {
1043 Ok(runtime) => Ok(runtime),
1044 Err(err) => {
1045 if config.offline {
1046 bail!(
1047 "semantic embeddings unavailable while offline; allow one connected run so fastembed can cache model weights ({err})"
1048 );
1049 }
1050 Err(err)
1051 }
1052 }
1053}
1054
1055pub fn try_load_embedding_runtime(config: &CliConfig) -> Option<EmbeddingRuntime> {
1057 use tracing::warn;
1058
1059 match instantiate_embedding_runtime(config) {
1060 Ok(runtime) => Some(runtime),
1061 Err(err) => {
1062 warn!("semantic embeddings unavailable: {err}");
1063 None
1064 }
1065 }
1066}
1067
/// Like `instantiate_embedding_runtime`, but allows a string override of the
/// configured model (e.g. from a CLI flag). The raw override string is
/// forwarded to the NVIDIA path so custom model ids survive normalization.
///
/// # Errors
/// Fails when the override does not parse, when offline with a remote model,
/// or when the provider cannot be initialized.
pub fn load_embedding_runtime_with_model(
    config: &CliConfig,
    model_override: Option<&str>,
) -> Result<EmbeddingRuntime> {
    use tracing::info;

    // Keep the unparsed override around: for NVIDIA it may carry a specific
    // model id ("nvidia:...") that the enum cannot represent.
    let mut raw_override: Option<&str> = None;
    let embedding_model = match model_override {
        Some(model_str) => {
            raw_override = Some(model_str);
            let parsed = model_str.parse::<EmbeddingModelChoice>()?;
            if parsed.dimensions() > 0 {
                info!(
                    "Using embedding model override: {} ({}D)",
                    parsed.name(),
                    parsed.dimensions()
                );
            } else {
                info!("Using embedding model override: {}", parsed.name());
            }
            parsed
        }
        None => config.embedding_model,
    };

    if embedding_model.dimensions() > 0 {
        info!(
            "Loading embedding model: {} ({}D)",
            embedding_model.name(),
            embedding_model.dimensions()
        );
    } else {
        info!("Loading embedding model: {}", embedding_model.name());
    }

    // Remote providers need network access; refuse them early when offline.
    if config.offline && (embedding_model.is_openai() || embedding_model == EmbeddingModelChoice::Nvidia)
    {
        anyhow::bail!(
            "remote embeddings are unavailable while offline; set MEMVID_OFFLINE=0 or use a local embedding model"
        );
    }

    if embedding_model.is_openai() {
        return instantiate_openai_runtime(embedding_model);
    }

    if embedding_model == EmbeddingModelChoice::Nvidia {
        return instantiate_nvidia_runtime(raw_override);
    }

    instantiate_fastembed_runtime(config, embedding_model)
}
1122
1123pub fn try_load_embedding_runtime_with_model(
1125 config: &CliConfig,
1126 model_override: Option<&str>,
1127) -> Option<EmbeddingRuntime> {
1128 use tracing::warn;
1129
1130 match load_embedding_runtime_with_model(config, model_override) {
1131 Ok(runtime) => Some(runtime),
1132 Err(err) => {
1133 warn!("semantic embeddings unavailable: {err}");
1134 None
1135 }
1136 }
1137}
1138
/// Loads an embedding runtime suited to an existing MV2 file.
///
/// Resolution order: explicit `model_override` wins; otherwise the model is
/// auto-detected from the file's stored vector dimension; otherwise the
/// configured default is used.
///
/// NOTE(review): dimension detection is ambiguous — 768D maps to bge-base
/// even if the file was built with nomic, and 1536D maps to openai-small
/// even for ada-002. Confirm downstream compatibility assumptions.
///
/// # Errors
/// Fails when the file needs OpenAI embeddings but OPENAI_API_KEY is unset,
/// or when the selected provider cannot be initialized.
pub fn load_embedding_runtime_for_mv2(
    config: &CliConfig,
    model_override: Option<&str>,
    mv2_dimension: Option<u32>,
) -> Result<EmbeddingRuntime> {
    use tracing::info;

    // An explicit override always wins over auto-detection.
    if let Some(model_str) = model_override {
        return load_embedding_runtime_with_model(config, Some(model_str));
    }

    if let Some(dim) = mv2_dimension {
        if let Some(detected_model) = EmbeddingModelChoice::from_dimension(dim) {
            info!(
                "Auto-detected embedding model from MV2: {} ({}D)",
                detected_model.name(),
                dim
            );

            // Detected OpenAI models need a key; fail with guidance if absent.
            if detected_model.is_openai() {
                if std::env::var("OPENAI_API_KEY").is_ok() {
                    return load_embedding_runtime_with_model(config, Some(detected_model.name()));
                } else {
                    return Err(anyhow!(
                        "MV2 file uses OpenAI embeddings ({}D) but OPENAI_API_KEY is not set.\n\n\
                        Options:\n\
                        1. Set OPENAI_API_KEY environment variable\n\
                        2. Use --query-embedding-model to specify a different model\n\
                        3. Use lexical-only search with --mode lex\n\n\
                        See: https://docs.memvid.com/embedding-models",
                        dim
                    ));
                }
            }

            return load_embedding_runtime_with_model(config, Some(detected_model.name()));
        }
    }

    // No dimension info: fall back to the configured default model.
    load_embedding_runtime(config)
}
1194
1195pub fn try_load_embedding_runtime_for_mv2(
1197 config: &CliConfig,
1198 model_override: Option<&str>,
1199 mv2_dimension: Option<u32>,
1200) -> Option<EmbeddingRuntime> {
1201 use tracing::warn;
1202
1203 match load_embedding_runtime_for_mv2(config, model_override, mv2_dimension) {
1204 Ok(runtime) => Some(runtime),
1205 Err(err) => {
1206 warn!("semantic embeddings unavailable: {err}");
1207 None
1208 }
1209 }
1210}