use std::env;
use std::path::PathBuf;
use std::str::FromStr;

use anyhow::{anyhow, Result};
use ed25519_dalek::VerifyingKey;

const DEFAULT_API_URL: &str = "https://kgpfm35ddc.us-east-2.awsapprunner.com";
const DEFAULT_CACHE_DIR: &str = "~/.cache/memvid";

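/// Embedding model selection shared by the CLI flag and the
/// `MEMVID_EMBEDDING_MODEL` environment variable. The `Bge*`, `Nomic`, and
/// `GteLarge` variants run locally through fastembed; the `OpenAI*` variants
/// call the OpenAI embeddings API instead.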
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum EmbeddingModelChoice {
    #[default]
    BgeSmall,
    BgeBase,
    Nomic,
    GteLarge,
    OpenAILarge,
    OpenAISmall,
    OpenAIAda,
}

impl EmbeddingModelChoice {
    pub fn is_openai(&self) -> bool {
        matches!(
            self,
            EmbeddingModelChoice::OpenAILarge
                | EmbeddingModelChoice::OpenAISmall
                | EmbeddingModelChoice::OpenAIAda
        )
    }

    pub fn to_fastembed_model(&self) -> fastembed::EmbeddingModel {
        match self {
            EmbeddingModelChoice::BgeSmall => fastembed::EmbeddingModel::BGESmallENV15,
            EmbeddingModelChoice::BgeBase => fastembed::EmbeddingModel::BGEBaseENV15,
            EmbeddingModelChoice::Nomic => fastembed::EmbeddingModel::NomicEmbedTextV15,
            EmbeddingModelChoice::GteLarge => fastembed::EmbeddingModel::GTELargeENV15,
            EmbeddingModelChoice::OpenAILarge
            | EmbeddingModelChoice::OpenAISmall
            | EmbeddingModelChoice::OpenAIAda => {
                panic!("OpenAI models don't use fastembed. Check is_openai() first.")
            }
        }
    }

    pub fn name(&self) -> &'static str {
        match self {
            EmbeddingModelChoice::BgeSmall => "bge-small",
            EmbeddingModelChoice::BgeBase => "bge-base",
            EmbeddingModelChoice::Nomic => "nomic",
            EmbeddingModelChoice::GteLarge => "gte-large",
            EmbeddingModelChoice::OpenAILarge => "openai-large",
            EmbeddingModelChoice::OpenAISmall => "openai-small",
            EmbeddingModelChoice::OpenAIAda => "openai-ada",
        }
    }

    pub fn dimensions(&self) -> usize {
        match self {
            EmbeddingModelChoice::BgeSmall => 384,
            EmbeddingModelChoice::BgeBase => 768,
            EmbeddingModelChoice::Nomic => 768,
            EmbeddingModelChoice::GteLarge => 1024,
            EmbeddingModelChoice::OpenAILarge => 3072,
            EmbeddingModelChoice::OpenAISmall => 1536,
            EmbeddingModelChoice::OpenAIAda => 1536,
        }
    }
}

impl FromStr for EmbeddingModelChoice {
    type Err = anyhow::Error;

    fn from_str(s: &str) -> Result<Self> {
        match s.to_lowercase().as_str() {
            "bge-small" | "bge_small" | "bgesmall" | "small" => Ok(EmbeddingModelChoice::BgeSmall),
            "bge-base" | "bge_base" | "bgebase" | "base" => Ok(EmbeddingModelChoice::BgeBase),
            "nomic" | "nomic-embed" | "nomic_embed" => Ok(EmbeddingModelChoice::Nomic),
            "gte-large" | "gte_large" | "gtelarge" | "gte" => Ok(EmbeddingModelChoice::GteLarge),
            "openai" | "openai-large" | "openai_large" | "text-embedding-3-large" => {
                Ok(EmbeddingModelChoice::OpenAILarge)
            }
            "openai-small" | "openai_small" | "text-embedding-3-small" => {
                Ok(EmbeddingModelChoice::OpenAISmall)
            }
            "openai-ada" | "openai_ada" | "text-embedding-ada-002" | "ada" => {
                Ok(EmbeddingModelChoice::OpenAIAda)
            }
            _ => Err(anyhow!(
                "unknown embedding model '{}'. Valid options: bge-small, bge-base, nomic, gte-large, openai, openai-small, openai-ada",
                s
            )),
        }
    }
}

impl std::fmt::Display for EmbeddingModelChoice {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.name())
    }
}

impl EmbeddingModelChoice {
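    /// Best-effort reverse lookup from a stored vector dimension to the model
    /// that likely produced it. Returns `None` for unknown dimensions so
    /// callers can fall back to the configured default.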
    pub fn from_dimension(dim: u32) -> Option<Self> {
        match dim {
            384 => Some(EmbeddingModelChoice::BgeSmall),
            // 768 is ambiguous (bge-base and nomic share it); bge-base is assumed.
            768 => Some(EmbeddingModelChoice::BgeBase),
            1024 => Some(EmbeddingModelChoice::GteLarge),
            // 1536 is ambiguous (openai-small and openai-ada); openai-small is assumed.
            1536 => Some(EmbeddingModelChoice::OpenAISmall),
            3072 => Some(EmbeddingModelChoice::OpenAILarge),
            0 => None,
            _ => {
                tracing::warn!("Unknown embedding dimension {}, using default model", dim);
                None
            }
        }
    }
}

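/// CLI configuration resolved from `MEMVID_*` environment variables, with
/// defaults for anything unset.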
#[derive(Debug, Clone)]
pub struct CliConfig {
    pub api_key: Option<String>,
    pub api_url: String,
    pub cache_dir: PathBuf,
    pub ticket_pubkey: Option<VerifyingKey>,
    pub models_dir: PathBuf,
    pub offline: bool,
    pub embedding_model: EmbeddingModelChoice,
}

// Note: equality does not compare `ticket_pubkey`.
impl PartialEq for CliConfig {
    fn eq(&self, other: &Self) -> bool {
        self.api_key == other.api_key
            && self.api_url == other.api_url
            && self.cache_dir == other.cache_dir
            && self.models_dir == other.models_dir
            && self.offline == other.offline
            && self.embedding_model == other.embedding_model
    }
}

impl Eq for CliConfig {}

impl CliConfig {
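    /// Reads configuration from the environment. Unset variables fall back
    /// to defaults, and `~`-prefixed paths are expanded to the home
    /// directory.
    ///
    /// Illustrative usage (marked `ignore` so it is not run as a doctest):
    ///
    /// ```ignore
    /// let config = CliConfig::load()?;
    /// println!("cache dir: {}", config.cache_dir.display());
    /// ```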
    pub fn load() -> Result<Self> {
        let api_key = env::var("MEMVID_API_KEY").ok().and_then(|value| {
            let trimmed = value.trim().to_string();
            (!trimmed.is_empty()).then_some(trimmed)
        });

        let api_url = env::var("MEMVID_API_URL").unwrap_or_else(|_| DEFAULT_API_URL.to_string());

        let cache_dir_raw =
            env::var("MEMVID_CACHE_DIR").unwrap_or_else(|_| DEFAULT_CACHE_DIR.to_string());
        let cache_dir = expand_path(&cache_dir_raw)?;

        let models_dir_raw =
            env::var("MEMVID_MODELS_DIR").unwrap_or_else(|_| "~/.memvid/models".to_string());
        let models_dir = expand_path(&models_dir_raw)?;

        let ticket_pubkey = env::var("MEMVID_TICKET_PUBKEY")
            .ok()
            .and_then(|value| {
                let trimmed = value.trim();
                if trimmed.is_empty() {
                    None
                } else {
                    Some(memvid_core::parse_ed25519_public_key_base64(trimmed))
                }
            })
            .transpose()?;

        let offline = env::var("MEMVID_OFFLINE")
            .ok()
            .map(|value| matches!(value.trim().to_ascii_lowercase().as_str(), "1" | "true" | "yes"))
            .unwrap_or(false);

        let embedding_model = env::var("MEMVID_EMBEDDING_MODEL")
            .ok()
            .and_then(|value| {
                let trimmed = value.trim();
                if trimmed.is_empty() {
                    None
                } else {
                    EmbeddingModelChoice::from_str(trimmed).ok()
                }
            })
            .unwrap_or_default();

        Ok(Self {
            api_key,
            api_url,
            cache_dir,
            ticket_pubkey,
            models_dir,
            offline,
            embedding_model,
        })
    }

    pub fn with_embedding_model(&self, model: EmbeddingModelChoice) -> Self {
        Self {
            embedding_model: model,
            ..self.clone()
        }
    }
}

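/// Expands a leading `~` to the home directory and resolves relative paths
/// against the current working directory, so callers always get an absolute
/// path.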
fn expand_path(value: &str) -> Result<PathBuf> {
    if value.trim().is_empty() {
        return Err(anyhow!("cache directory cannot be empty"));
    }

    let expanded = if let Some(stripped) = value.strip_prefix("~/") {
        home_dir()?.join(stripped)
    } else if let Some(stripped) = value.strip_prefix("~\\") {
        home_dir()?.join(stripped)
    } else if value == "~" {
        home_dir()?
    } else {
        PathBuf::from(value)
    };

    if expanded.is_absolute() {
        Ok(expanded)
    } else {
        Ok(env::current_dir()?.join(expanded))
    }
}

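/// Resolves the home directory from `HOME`, falling back to `USERPROFILE`
/// and then `HOMEDRIVE` + `HOMEPATH` on Windows.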
fn home_dir() -> Result<PathBuf> {
    if let Some(path) = env::var_os("HOME") {
        if !path.is_empty() {
            return Ok(PathBuf::from(path));
        }
    }

    #[cfg(windows)]
    {
        if let Some(path) = env::var_os("USERPROFILE") {
            if !path.is_empty() {
                return Ok(PathBuf::from(path));
            }
        }
        if let (Some(drive), Some(path)) = (env::var_os("HOMEDRIVE"), env::var_os("HOMEPATH")) {
            if !drive.is_empty() && !path.is_empty() {
                return Ok(PathBuf::from(format!(
                    "{}{}",
                    drive.to_string_lossy(),
                    path.to_string_lossy()
                )));
            }
        }
    }

    Err(anyhow!("unable to resolve home directory"))
}

#[cfg(test)]
mod tests {
    use super::*;
    use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
    use base64::Engine;
    use ed25519_dalek::SigningKey;
    use std::sync::{Mutex, OnceLock};

    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(())).lock().unwrap()
    }

    fn set_or_unset(var: &str, value: Option<String>) {
        match value {
            Some(v) => unsafe { env::set_var(var, v) },
            None => unsafe { env::remove_var(var) },
        }
    }

    #[test]
    fn defaults_expand_using_home_directory() {
        let _guard = env_lock();

        let previous_home = env::var("HOME").ok();
        #[cfg(windows)]
        let previous_userprofile = env::var("USERPROFILE").ok();

        for var in [
            "MEMVID_API_KEY",
            "MEMVID_API_URL",
            "MEMVID_CACHE_DIR",
            "MEMVID_TICKET_PUBKEY",
            "MEMVID_MODELS_DIR",
            "MEMVID_OFFLINE",
        ] {
            unsafe { env::remove_var(var) };
        }

        let tmp = tempfile::tempdir().expect("tmpdir");
        let tmp_path = tmp.path().to_path_buf();
        unsafe { env::set_var("HOME", &tmp_path) };
        #[cfg(windows)]
        unsafe {
            env::set_var("USERPROFILE", &tmp_path)
        };

        let config = CliConfig::load().expect("load");
        assert_eq!(config.api_key, None);
        assert_eq!(config.api_url, DEFAULT_API_URL);
        assert_eq!(config.cache_dir, tmp_path.join(".cache/memvid"));
        assert!(config.ticket_pubkey.is_none());
        assert_eq!(config.models_dir, tmp_path.join(".memvid/models"));
        assert!(!config.offline);

        set_or_unset("HOME", previous_home);
        #[cfg(windows)]
        {
            set_or_unset("USERPROFILE", previous_userprofile);
        }
    }

    #[test]
    fn env_overrides_are_respected() {
        let _guard = env_lock();

        let previous_env: Vec<(&'static str, Option<String>)> = [
            "MEMVID_API_KEY",
            "MEMVID_API_URL",
            "MEMVID_CACHE_DIR",
            "MEMVID_TICKET_PUBKEY",
            "MEMVID_MODELS_DIR",
            "MEMVID_OFFLINE",
        ]
        .into_iter()
        .map(|var| (var, env::var(var).ok()))
        .collect();

        unsafe { env::set_var("MEMVID_API_KEY", "abc123") };
        unsafe { env::set_var("MEMVID_API_URL", "https://staging.memvid.app") };
        unsafe { env::set_var("MEMVID_CACHE_DIR", "~/memvid-cache") };
        unsafe { env::set_var("MEMVID_MODELS_DIR", "~/models") };
        unsafe { env::set_var("MEMVID_OFFLINE", "true") };
        let signing = SigningKey::from_bytes(&[9u8; 32]);
        let encoded = BASE64_STANDARD.encode(signing.verifying_key().as_bytes());
        unsafe { env::set_var("MEMVID_TICKET_PUBKEY", encoded) };

        let tmp = tempfile::tempdir().expect("tmpdir");
        let tmp_path = tmp.path().to_path_buf();
        unsafe { env::set_var("HOME", &tmp_path) };
        #[cfg(windows)]
        unsafe {
            env::set_var("USERPROFILE", &tmp_path)
        };

        let config = CliConfig::load().expect("load");
        assert_eq!(config.api_key.as_deref(), Some("abc123"));
        assert_eq!(config.api_url, "https://staging.memvid.app");
        assert_eq!(config.cache_dir, tmp_path.join("memvid-cache"));
        assert_eq!(
            config.ticket_pubkey.expect("pubkey").as_bytes(),
            signing.verifying_key().as_bytes()
        );
        assert_eq!(config.models_dir, tmp_path.join("models"));
        assert!(config.offline);

        for (var, value) in previous_env {
            set_or_unset(var, value);
        }
    }

    #[test]
    fn rejects_empty_cache_dir() {
        let _guard = env_lock();

        let previous = env::var("MEMVID_CACHE_DIR").ok();
        unsafe { env::set_var("MEMVID_CACHE_DIR", " ") };
        let err = CliConfig::load().expect_err("should fail");
        assert!(err.to_string().contains("cache directory"));
        set_or_unset("MEMVID_CACHE_DIR", previous);
    }
}

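/// Installs the global tracing subscriber. `verbosity` maps 0/1/2/3+ to
/// warn/info/debug/trace; an explicit `RUST_LOG` takes precedence, and noisy
/// llama.cpp targets are pinned to `error` either way.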
pub fn init_tracing(verbosity: u8) -> Result<()> {
    use std::io::IsTerminal;
    use tracing_subscriber::{filter::Directive, fmt, EnvFilter};

    let level = match verbosity {
        0 => "warn",
        1 => "info",
        2 => "debug",
        _ => "trace",
    };

    let mut env_filter =
        EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(level));
    for directive_str in ["llama_cpp=error", "llama_cpp_sys=error", "ggml=error"] {
        if let Ok(directive) = directive_str.parse::<Directive>() {
            env_filter = env_filter.add_directive(directive);
        }
    }

    let use_ansi = std::io::stderr().is_terminal();

    fmt()
        .with_env_filter(env_filter)
        .with_writer(std::io::stderr)
        .with_target(false)
        .without_time()
        .with_ansi(use_ansi)
        .try_init()
        .map_err(|err| anyhow!(err))?;
    Ok(())
}

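/// Resolves the LLM context budget override: an explicit CLI value wins,
/// then `MEMVID_LLM_CONTEXT_BUDGET` (digit separators such as `_` and spaces
/// are tolerated); `None` means no override. Zero is rejected from either
/// source.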
pub fn resolve_llm_context_budget_override(cli_value: Option<usize>) -> Result<Option<usize>> {
    use anyhow::bail;

    if let Some(value) = cli_value {
        if value == 0 {
            bail!("--llm-context-depth must be a positive integer");
        }
        return Ok(Some(value));
    }

    let raw_env = match env::var("MEMVID_LLM_CONTEXT_BUDGET") {
        Ok(value) => value,
        Err(_) => return Ok(None),
    };

    let trimmed = raw_env.trim();
    if trimmed.is_empty() {
        return Ok(None);
    }

    let digits: String = trimmed
        .chars()
        .filter(|ch| !ch.is_ascii_whitespace() && *ch != '_')
        .collect();

    if digits.is_empty() {
        bail!("MEMVID_LLM_CONTEXT_BUDGET must be a positive integer value");
    }

    let value: usize = digits.parse().map_err(|err| {
        anyhow!(
            "MEMVID_LLM_CONTEXT_BUDGET value '{}' is not a valid number: {}",
            trimmed,
            err
        )
    })?;

    if value == 0 {
        bail!("MEMVID_LLM_CONTEXT_BUDGET must be a positive integer");
    }

    Ok(Some(value))
}

use crate::openai_embeddings::OpenAIEmbeddingProvider;

#[derive(Clone)]
enum EmbeddingBackend {
    FastEmbed(std::sync::Arc<std::sync::Mutex<fastembed::TextEmbedding>>),
    OpenAI(std::sync::Arc<OpenAIEmbeddingProvider>),
}

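/// Handle to a loaded embedding backend. Cloning is cheap: both variants live
/// behind an `Arc`, and the fastembed model is additionally wrapped in a
/// `Mutex` because inference needs exclusive access.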
#[derive(Clone)]
pub struct EmbeddingRuntime {
    backend: EmbeddingBackend,
    dimension: usize,
}

impl EmbeddingRuntime {
    fn new_fastembed(model: fastembed::TextEmbedding, dimension: usize) -> Self {
        Self {
            backend: EmbeddingBackend::FastEmbed(std::sync::Arc::new(std::sync::Mutex::new(model))),
            dimension,
        }
    }

    fn new_openai(provider: OpenAIEmbeddingProvider, dimension: usize) -> Self {
        Self {
            backend: EmbeddingBackend::OpenAI(std::sync::Arc::new(provider)),
            dimension,
        }
    }

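    // Cap in bytes (`str::len`), used as a rough stand-in for the OpenAI
    // per-request token limit.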
    const MAX_EMBEDDING_TEXT_LEN: usize = 20_000;

    fn truncate_for_embedding(text: &str) -> std::borrow::Cow<'_, str> {
        if text.len() <= Self::MAX_EMBEDDING_TEXT_LEN {
            std::borrow::Cow::Borrowed(text)
        } else {
            // Back off to the nearest char boundary so slicing cannot panic
            // in the middle of a multi-byte UTF-8 sequence.
            let mut end = Self::MAX_EMBEDDING_TEXT_LEN;
            while !text.is_char_boundary(end) {
                end -= 1;
            }
            tracing::info!(
                "Truncated embedding text from {} to {} bytes to fit OpenAI token limit",
                text.len(),
                end
            );
            std::borrow::Cow::Owned(text[..end].to_string())
        }
    }

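    /// Embeds a single text. Input for the OpenAI backend is truncated to
    /// `MAX_EMBEDDING_TEXT_LEN` bytes first; fastembed input is passed
    /// through unchanged.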
    pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
        let text = match &self.backend {
            EmbeddingBackend::OpenAI(_) => Self::truncate_for_embedding(text),
            _ => std::borrow::Cow::Borrowed(text),
        };

        match &self.backend {
            EmbeddingBackend::FastEmbed(model) => {
                let mut guard = model
                    .lock()
                    .map_err(|_| anyhow!("fastembed runtime poisoned"))?;
                let outputs = guard
                    .embed(vec![text.into_owned()], None)
                    .map_err(|err| anyhow!("failed to compute embedding with fastembed: {err}"))?;
                outputs
                    .into_iter()
                    .next()
                    .ok_or_else(|| anyhow!("fastembed returned no embedding output"))
            }
            EmbeddingBackend::OpenAI(provider) => {
                use memvid_core::EmbeddingProvider;
                provider
                    .embed_text(&text)
                    .map_err(|err| anyhow!("failed to compute embedding with OpenAI: {err}"))
            }
        }
    }

    pub fn dimension(&self) -> usize {
        self.dimension
    }
}

impl memvid_core::VecEmbedder for EmbeddingRuntime {
    fn embed_query(&self, text: &str) -> memvid_core::Result<Vec<f32>> {
        self.embed(text)
            .map_err(|err| memvid_core::MemvidError::EmbeddingFailed {
                reason: err.to_string().into_boxed_str(),
            })
    }

    fn embedding_dimension(&self) -> usize {
        self.dimension()
    }
}

fn ensure_fastembed_cache(config: &CliConfig) -> Result<PathBuf> {
    use std::fs;

    let cache_dir = config.models_dir.clone();
    fs::create_dir_all(&cache_dir)?;
    Ok(cache_dir)
}

// Approximate on-disk size of each local model, used in error messages;
// remote OpenAI models report 0.
fn model_size_mb(model: EmbeddingModelChoice) -> usize {
    match model {
        EmbeddingModelChoice::BgeSmall => 33,
        EmbeddingModelChoice::BgeBase => 110,
        EmbeddingModelChoice::Nomic => 137,
        EmbeddingModelChoice::GteLarge => 327,
        EmbeddingModelChoice::OpenAILarge
        | EmbeddingModelChoice::OpenAISmall
        | EmbeddingModelChoice::OpenAIAda => 0,
    }
}

fn instantiate_embedding_runtime(config: &CliConfig) -> Result<EmbeddingRuntime> {
    use tracing::info;

    let embedding_model = config.embedding_model;

    info!(
        "Loading embedding model: {} ({}D)",
        embedding_model.name(),
        embedding_model.dimensions()
    );

    if embedding_model.is_openai() {
        return instantiate_openai_runtime(embedding_model);
    }

    instantiate_fastembed_runtime(config, embedding_model)
}

fn instantiate_openai_runtime(embedding_model: EmbeddingModelChoice) -> Result<EmbeddingRuntime> {
    use anyhow::bail;
    use memvid_core::EmbeddingConfig;
    use tracing::info;

    let api_key = std::env::var("OPENAI_API_KEY").map_err(|_| {
        anyhow!("OPENAI_API_KEY environment variable is required for OpenAI embeddings")
    })?;

    if api_key.trim().is_empty() {
        bail!("OPENAI_API_KEY cannot be empty");
    }

    let config = match embedding_model {
        EmbeddingModelChoice::OpenAILarge => EmbeddingConfig::openai_large(),
        EmbeddingModelChoice::OpenAISmall => EmbeddingConfig::openai_small(),
        EmbeddingModelChoice::OpenAIAda => EmbeddingConfig::openai_ada(),
        _ => unreachable!("non-OpenAI model passed to instantiate_openai_runtime"),
    };

    let provider = OpenAIEmbeddingProvider::new(api_key, config.clone())
        .map_err(|err| anyhow!("failed to create OpenAI embedding provider: {err}"))?;

    info!(
        "OpenAI embedding provider ready: model={}, dimension={}",
        config.model, config.dimension
    );

    Ok(EmbeddingRuntime::new_openai(provider, config.dimension))
}

fn instantiate_fastembed_runtime(
    config: &CliConfig,
    embedding_model: EmbeddingModelChoice,
) -> Result<EmbeddingRuntime> {
    use anyhow::bail;
    use fastembed::{InitOptions, TextEmbedding};
    use std::fs;

    let cache_dir = ensure_fastembed_cache(config)?;

    if config.offline {
        let mut entries = fs::read_dir(&cache_dir)?;
        if entries.next().is_none() {
            bail!(
                "semantic embeddings unavailable while offline; allow one connected run so fastembed can cache model weights"
            );
        }
    }

    let options = InitOptions::new(embedding_model.to_fastembed_model())
        .with_cache_dir(cache_dir)
        .with_show_download_progress(true);
    let mut model = TextEmbedding::try_new(options).map_err(|err| {
        let platform_hint = if cfg!(target_os = "windows") {
            "\n\nWindows users: If model downloads fail, try:\n\
             1. Run as Administrator\n\
             2. Check your antivirus isn't blocking downloads\n\
             3. Use OpenAI embeddings instead: set OPENAI_API_KEY and use --embedding-model openai"
        } else if cfg!(target_os = "linux") {
            "\n\nLinux users: If model downloads fail, try:\n\
             1. Check disk space in ~/.memvid/models\n\
             2. Ensure you have network access to huggingface.co\n\
             3. Use OpenAI embeddings instead: export OPENAI_API_KEY=... and use --embedding-model openai"
        } else {
            "\n\nIf model downloads fail, try using OpenAI embeddings:\n\
             export OPENAI_API_KEY=your-key && memvid ... --embedding-model openai"
        };

        anyhow!(
            "Failed to initialize embedding model '{}': {err}\n\n\
             This typically means the model couldn't be downloaded or loaded.\n\
             Model size: ~{} MB{}\n\n\
             See: https://docs.memvid.com/embedding-models",
            embedding_model.name(),
            model_size_mb(embedding_model),
            platform_hint
        )
    })?;

    let probe = model
        .embed(vec!["memvid probe".to_string()], None)
        .map_err(|err| anyhow!("failed to determine embedding dimension: {err}"))?;
    let dimension = probe.first().map(|vec| vec.len()).unwrap_or(0);

    if dimension == 0 {
        bail!("fastembed reported zero-length embeddings");
    }

    if dimension != embedding_model.dimensions() {
        tracing::warn!(
            "Embedding dimension mismatch: expected {}, got {}",
            embedding_model.dimensions(),
            dimension
        );
    }

    Ok(EmbeddingRuntime::new_fastembed(model, dimension))
}

pub fn load_embedding_runtime(config: &CliConfig) -> Result<EmbeddingRuntime> {
    use anyhow::bail;

    match instantiate_embedding_runtime(config) {
        Ok(runtime) => Ok(runtime),
        Err(err) => {
            if config.offline {
                bail!(
                    "semantic embeddings unavailable while offline; allow one connected run so fastembed can cache model weights ({err})"
                );
            }
            Err(err)
        }
    }
}

pub fn try_load_embedding_runtime(config: &CliConfig) -> Option<EmbeddingRuntime> {
    use tracing::warn;

    match instantiate_embedding_runtime(config) {
        Ok(runtime) => Some(runtime),
        Err(err) => {
            warn!("semantic embeddings unavailable: {err}");
            None
        }
    }
}

pub fn load_embedding_runtime_with_model(
    config: &CliConfig,
    model_override: Option<&str>,
) -> Result<EmbeddingRuntime> {
    use tracing::info;

    let embedding_model = match model_override {
        Some(model_str) => {
            let parsed = model_str.parse::<EmbeddingModelChoice>()?;
            info!(
                "Using embedding model override: {} ({}D)",
                parsed.name(),
                parsed.dimensions()
            );
            parsed
        }
        None => config.embedding_model,
    };

    info!(
        "Loading embedding model: {} ({}D)",
        embedding_model.name(),
        embedding_model.dimensions()
    );

    if embedding_model.is_openai() {
        return instantiate_openai_runtime(embedding_model);
    }

    instantiate_fastembed_runtime(config, embedding_model)
}

pub fn try_load_embedding_runtime_with_model(
    config: &CliConfig,
    model_override: Option<&str>,
) -> Option<EmbeddingRuntime> {
    use tracing::warn;

    match load_embedding_runtime_with_model(config, model_override) {
        Ok(runtime) => Some(runtime),
        Err(err) => {
            warn!("semantic embeddings unavailable: {err}");
            None
        }
    }
}

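/// Chooses the embedding runtime for an existing MV2 file. Resolution order:
/// an explicit `model_override` wins, then a model auto-detected from the MV2
/// vector dimension, then the configured default. Auto-detected OpenAI models
/// require `OPENAI_API_KEY` and fail with remediation steps otherwise.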
pub fn load_embedding_runtime_for_mv2(
    config: &CliConfig,
    model_override: Option<&str>,
    mv2_dimension: Option<u32>,
) -> Result<EmbeddingRuntime> {
    use tracing::info;

    if let Some(model_str) = model_override {
        return load_embedding_runtime_with_model(config, Some(model_str));
    }

    if let Some(dim) = mv2_dimension {
        if let Some(detected_model) = EmbeddingModelChoice::from_dimension(dim) {
            info!(
                "Auto-detected embedding model from MV2: {} ({}D)",
                detected_model.name(),
                dim
            );

            if detected_model.is_openai() {
                if std::env::var("OPENAI_API_KEY").is_ok() {
                    return load_embedding_runtime_with_model(config, Some(detected_model.name()));
                } else {
                    return Err(anyhow!(
                        "MV2 file uses OpenAI embeddings ({}D) but OPENAI_API_KEY is not set.\n\n\
                         Options:\n\
                         1. Set OPENAI_API_KEY environment variable\n\
                         2. Use --query-embedding-model to specify a different model\n\
                         3. Use lexical-only search with --mode lex\n\n\
                         See: https://docs.memvid.com/embedding-models",
                        dim
                    ));
                }
            }

            return load_embedding_runtime_with_model(config, Some(detected_model.name()));
        }
    }

    load_embedding_runtime(config)
}

pub fn try_load_embedding_runtime_for_mv2(
    config: &CliConfig,
    model_override: Option<&str>,
    mv2_dimension: Option<u32>,
) -> Option<EmbeddingRuntime> {
    use tracing::warn;

    match load_embedding_runtime_for_mv2(config, model_override, mv2_dimension) {
        Ok(runtime) => Some(runtime),
        Err(err) => {
            warn!("semantic embeddings unavailable: {err}");
            None
        }
    }
}