
codemem_embeddings/lib.rs

//! codemem-embeddings: Pluggable embedding providers for Codemem.
//!
//! Supports multiple backends:
//! - **Candle** (default): Local BERT models via pure Rust ML (any HF BERT model)
//! - **Ollama**: Local Ollama server with any embedding model
//! - **OpenAI**: OpenAI API or any compatible endpoint (Together, Azure, etc.)
//! - **Gemini**: Google Generative Language API (text-embedding-004)
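//!
//! # Example
//!
//! A minimal usage sketch (assumes the default Candle model has already been
//! downloaded via `codemem init`):
//!
//! ```no_run
//! # use codemem_embeddings::EmbeddingProvider;
//! // With no env vars set, `from_env` falls back to the Candle backend and the
//! // default BAAI/bge-base-en-v1.5 model.
//! let provider = codemem_embeddings::from_env(None)?;
//! let vector = provider.embed("fn main() {}")?;
//! assert_eq!(vector.len(), provider.dimensions());
//! # Ok::<(), codemem_core::CodememError>(())
//! ```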

pub mod gemini;
pub mod ollama;
pub mod openai;

use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config as BertConfig};
use codemem_core::CodememError;
use lru::LruCache;
use std::num::NonZeroUsize;
use std::path::{Path, PathBuf};
use std::sync::Mutex;
use tokenizers::{PaddingParams, PaddingStrategy};

/// Default model name (short form used for directory naming).
pub const MODEL_NAME: &str = "bge-base-en-v1.5";

/// Default HuggingFace model repo ID.
/// Used internally and by `commands_init` for the default model download.
pub const DEFAULT_HF_REPO: &str = "BAAI/bge-base-en-v1.5";

/// Default embedding dimensions for remote providers (Ollama/OpenAI).
/// Candle reads `hidden_size` from the model's config.json instead.
pub const DEFAULT_REMOTE_DIMENSIONS: usize = 768;

/// Max sequence length for BERT models.
const MAX_SEQ_LENGTH: usize = 512;

/// Default LRU cache capacity.
pub const CACHE_CAPACITY: usize = 10_000;

// Re-export EmbeddingProvider trait from core
pub use codemem_core::EmbeddingProvider;

// ── Candle Embedding Service ────────────────────────────────────────────────

/// Default batch size for batched embedding forward passes.
/// Configurable via `EmbeddingConfig.batch_size` or `CODEMEM_EMBED_BATCH_SIZE`.
pub const DEFAULT_BATCH_SIZE: usize = 16;

/// Select the best available compute device.
///
/// Tries Metal (macOS GPU) first, then CUDA (NVIDIA GPU), then falls back to CPU.
/// GPU backends are only available when the corresponding feature flag is enabled.
fn select_device() -> Device {
    #[cfg(feature = "metal")]
    {
        // Use catch_unwind to handle SIGBUS/panics on CI runners without GPU access.
        match std::panic::catch_unwind(|| Device::new_metal(0)) {
            Ok(Ok(device)) => {
                tracing::info!("Using Metal GPU for embeddings");
                return device;
            }
            Ok(Err(e)) => {
                tracing::warn!("Metal device creation failed: {e}, falling back to CPU");
            }
            Err(_) => {
                tracing::warn!("Metal device creation panicked, falling back to CPU");
            }
        }
    }
    #[cfg(feature = "cuda")]
    {
        match std::panic::catch_unwind(|| Device::new_cuda(0)) {
            Ok(Ok(device)) => {
                tracing::info!("Using CUDA GPU for embeddings");
                return device;
            }
            Ok(Err(e)) => {
                tracing::warn!("CUDA device creation failed: {e}, falling back to CPU");
            }
            Err(_) => {
                tracing::warn!("CUDA device creation panicked, falling back to CPU");
            }
        }
    }
    tracing::info!("Using CPU for embeddings");
    Device::Cpu
}

/// Embedding service with Candle inference (no internal cache — use `CachedProvider` wrapper).
pub struct EmbeddingService {
    model: Mutex<BertModel>,
    /// Tokenizer pre-configured with truncation (no padding).
    /// Used directly for single embeds; cloned and augmented with padding for batch.
    tokenizer: tokenizers::Tokenizer,
    device: Device,
    /// Maximum texts per forward pass (GPU memory trade-off).
    batch_size: usize,
    /// Hidden size read from model config (e.g. 768 for bge-base, 384 for bge-small).
    hidden_size: usize,
}

impl EmbeddingService {
    /// Create a new embedding service, loading the model from the given directory.
    /// Expects `model.safetensors`, `config.json`, and `tokenizer.json` in the directory.
    ///
    /// `dtype` controls precision: `DType::F32` (default) or `DType::F16` (half the memory, faster on Metal).
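    ///
    /// # Example
    ///
    /// A minimal sketch (assumes the model files are already present locally):
    ///
    /// ```no_run
    /// # use codemem_embeddings::{EmbeddingService, DEFAULT_BATCH_SIZE};
    /// # use candle_core::DType;
    /// let dir = EmbeddingService::default_model_dir();
    /// let _service = EmbeddingService::new(&dir, DEFAULT_BATCH_SIZE, DType::F16)?;
    /// # Ok::<(), codemem_core::CodememError>(())
    /// ```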
    pub fn new(model_dir: &Path, batch_size: usize, dtype: DType) -> Result<Self, CodememError> {
        let model_path = model_dir.join("model.safetensors");
        let config_path = model_dir.join("config.json");
        let tokenizer_path = model_dir.join("tokenizer.json");

        if !model_path.exists() {
            return Err(CodememError::Embedding(format!(
                "Model not found at {}. Run `codemem init` to download it.",
                model_path.display()
            )));
        }

        let device = select_device();

        tracing::info!(
            "Loading model from {} (dtype: {:?}, device: {:?})",
            model_dir.display(),
            dtype,
            device
        );

        // Load BERT config
        let config_str = std::fs::read_to_string(&config_path)
            .map_err(|e| CodememError::Embedding(format!("Failed to read config: {e}")))?;
        let config: BertConfig = serde_json::from_str(&config_str)
            .map_err(|e| CodememError::Embedding(format!("Failed to parse config: {e}")))?;

        let hidden_size = config.hidden_size;

        // Load model weights from safetensors via memory-mapped IO.
        // Try the "bert." prefix first (standard HF BERT checkpoints), then retry without it.
        // Scope `vb` so it drops before a potential retry, avoiding two VarBuilders
        // holding materialized Metal tensors simultaneously.
        let model = {
            let vb = unsafe {
                VarBuilder::from_mmaped_safetensors(&[&model_path], dtype, &device)
                    .map_err(|e| CodememError::Embedding(format!("Failed to load weights: {e}")))?
            };
            BertModel::load(vb.pp("bert"), &config)
        };
        let model = match model {
            Ok(m) => m,
            Err(_) => {
                let vb2 = unsafe {
                    VarBuilder::from_mmaped_safetensors(&[&model_path], dtype, &device).map_err(
                        |e| CodememError::Embedding(format!("Failed to load weights: {e}")),
                    )?
                };
                BertModel::load(vb2, &config).map_err(|e| {
                    CodememError::Embedding(format!("Failed to load BERT model: {e}"))
                })?
            }
        };

        let mut tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| CodememError::Embedding(e.to_string()))?;

        // Pre-configure truncation once so we don't need to clone on every embed call.
        tokenizer
            .with_truncation(Some(tokenizers::TruncationParams {
                max_length: MAX_SEQ_LENGTH,
                ..Default::default()
            }))
            .map_err(|e| CodememError::Embedding(format!("Truncation error: {e}")))?;

        Ok(Self {
            model: Mutex::new(model),
            tokenizer,
            device,
            batch_size,
            hidden_size,
        })
    }

    /// Get the model directory path for a given model name.
    /// Resolves to `~/.codemem/models/{model_name}` (falling back to the current
    /// directory if no home directory is found).
    pub fn model_dir_for(model_name: &str) -> PathBuf {
        dirs::home_dir()
            .unwrap_or_else(|| PathBuf::from("."))
            .join(".codemem")
            .join("models")
            .join(model_name)
    }

    /// Get the default model directory path (`~/.codemem/models/{MODEL_NAME}`).
    pub fn default_model_dir() -> PathBuf {
        Self::model_dir_for(MODEL_NAME)
    }

    /// Download a model from HuggingFace Hub to the given directory.
    /// `hf_repo` is the full repo ID (e.g. "BAAI/bge-base-en-v1.5").
    /// Returns the directory path. No-ops if the model already exists.
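    ///
    /// # Example
    ///
    /// A minimal sketch (the smaller BGE variant here is just an illustration):
    ///
    /// ```no_run
    /// # use codemem_embeddings::EmbeddingService;
    /// let dir = EmbeddingService::model_dir_for("bge-small-en-v1.5");
    /// EmbeddingService::download_model(&dir, "BAAI/bge-small-en-v1.5")?;
    /// # Ok::<(), codemem_core::CodememError>(())
    /// ```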
    pub fn download_model(dest_dir: &Path, hf_repo: &str) -> Result<PathBuf, CodememError> {
        let model_dest = dest_dir.join("model.safetensors");
        let config_dest = dest_dir.join("config.json");
        let tokenizer_dest = dest_dir.join("tokenizer.json");

        if model_dest.exists() && config_dest.exists() && tokenizer_dest.exists() {
            tracing::info!("Model already downloaded at {}", dest_dir.display());
            return Ok(dest_dir.to_path_buf());
        }

        std::fs::create_dir_all(dest_dir)
            .map_err(|e| CodememError::Embedding(format!("Failed to create dir: {e}")))?;

        tracing::info!("Downloading {} from HuggingFace...", hf_repo);

        let api = hf_hub::api::sync::Api::new()
            .map_err(|e| CodememError::Embedding(format!("HuggingFace API error: {e}")))?;
        let repo = api.model(hf_repo.to_string());

        let cached_model = repo
            .get("model.safetensors")
            .map_err(|e| CodememError::Embedding(format!("Failed to download model: {e}")))?;

        let cached_config = repo
            .get("config.json")
            .map_err(|e| CodememError::Embedding(format!("Failed to download config: {e}")))?;

        let cached_tokenizer = repo
            .get("tokenizer.json")
            .map_err(|e| CodememError::Embedding(format!("Failed to download tokenizer: {e}")))?;

        std::fs::copy(&cached_model, &model_dest)
            .map_err(|e| CodememError::Embedding(format!("Failed to copy model: {e}")))?;
        std::fs::copy(&cached_config, &config_dest)
            .map_err(|e| CodememError::Embedding(format!("Failed to copy config: {e}")))?;
        std::fs::copy(&cached_tokenizer, &tokenizer_dest)
            .map_err(|e| CodememError::Embedding(format!("Failed to copy tokenizer: {e}")))?;

        tracing::info!("Model downloaded to {}", dest_dir.display());
        Ok(dest_dir.to_path_buf())
    }

    /// Download the default model (BAAI/bge-base-en-v1.5) to the default directory.
    /// Convenience wrapper for `download_model(&default_model_dir(), DEFAULT_HF_REPO)`.
    pub fn download_default_model() -> Result<PathBuf, CodememError> {
        Self::download_model(&Self::default_model_dir(), DEFAULT_HF_REPO)
    }

    /// Embed a single text string. Returns an L2-normalized vector (dimension = model's hidden_size).
    pub fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
        // Tokenize using pre-configured tokenizer (truncation already set in constructor)
        let encoding = self
            .tokenizer
            .encode(text, true)
            .map_err(|e| CodememError::Embedding(e.to_string()))?;

        let input_ids: Vec<u32> = encoding.get_ids().to_vec();
        let attention_mask: Vec<u32> = encoding.get_attention_mask().to_vec();

        // Build candle tensors with shape [1, seq_len]
        let input_ids_tensor = Tensor::new(&input_ids[..], &self.device)
            .and_then(|t| t.unsqueeze(0))
            .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;

        let token_type_ids = input_ids_tensor
            .zeros_like()
            .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;

        let attention_mask_tensor = Tensor::new(&attention_mask[..], &self.device)
            .and_then(|t| t.unsqueeze(0))
            .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;

        // Forward pass -> [1, seq_len, hidden_size]
        let model = self
            .model
            .lock()
            .map_err(|e| CodememError::LockPoisoned(format!("embedding model: {e}")))?;
        let hidden_states = model
            .forward(
                &input_ids_tensor,
                &token_type_ids,
                Some(&attention_mask_tensor),
            )
            .map_err(|e| CodememError::Embedding(format!("Model forward error: {e}")))?;
        drop(model);

        // Drop intermediate GPU tensors before pooling
        drop(input_ids_tensor);
        drop(token_type_ids);

        // Cast hidden states to F32 for pooling math (model may output F16/BF16)
        let hidden_states = hidden_states
            .to_dtype(DType::F32)
            .map_err(|e| CodememError::Embedding(format!("Cast error: {e}")))?;

        // Mean pooling weighted by attention mask
        // attention_mask: [1, seq_len] -> [1, seq_len, 1] for broadcasting
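        // Reference formula (sketch): pooled_j = Σ_i mask_i · h_{i,j} / Σ_i mask_i,
        // followed by e = pooled / ‖pooled‖₂ in the normalization step below.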
        let mask = attention_mask_tensor
            .to_dtype(DType::F32)
            .and_then(|t| t.unsqueeze(2))
            .map_err(|e| CodememError::Embedding(format!("Mask error: {e}")))?;

        let sum_mask = mask
            .sum(1)
            .map_err(|e| CodememError::Embedding(format!("Sum error: {e}")))?;

        let pooled = hidden_states
            .broadcast_mul(&mask)
            .and_then(|t| t.sum(1))
            .and_then(|t| t.broadcast_div(&sum_mask))
            .map_err(|e| CodememError::Embedding(format!("Pooling error: {e}")))?;

        // L2 normalize
        let normalized = pooled
            .sqr()
            .and_then(|t| t.sum_keepdim(1))
            .and_then(|t| t.sqrt())
            .and_then(|norm| pooled.broadcast_div(&norm))
            .map_err(|e| CodememError::Embedding(format!("Normalize error: {e}")))?;

        // Extract as Vec<f32> — shape is [1, hidden_size], squeeze to [hidden_size]
        let embedding: Vec<f32> = normalized
            .squeeze(0)
            .and_then(|t| t.to_vec1())
            .map_err(|e| CodememError::Embedding(format!("Extract error: {e}")))?;

        Ok(embedding)
    }

    /// Embed a batch of texts using a true batched forward pass.
    ///
    /// Tokenizes all texts, pads to the longest sequence in each chunk, runs a
    /// single forward pass per chunk of up to `batch_size` texts, then performs
    /// mean pooling and L2 normalization on the batched output.
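    ///
    /// # Example
    ///
    /// A minimal sketch, assuming the default model is already downloaded:
    ///
    /// ```no_run
    /// # use codemem_embeddings::{EmbeddingService, DEFAULT_BATCH_SIZE};
    /// # use candle_core::DType;
    /// let dir = EmbeddingService::default_model_dir();
    /// let service = EmbeddingService::new(&dir, DEFAULT_BATCH_SIZE, DType::F32)?;
    /// let vectors = service.embed_batch(&["fn main() {}", "struct Point;"])?;
    /// assert_eq!(vectors.len(), 2);
    /// # Ok::<(), codemem_core::CodememError>(())
    /// ```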
    pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
        if texts.is_empty() {
            return Ok(vec![]);
        }

        // Clone the tokenizer once for the batch path and enable padding.
        // `BatchLongest` pads each `encode_batch` call to that chunk's longest
        // sequence; truncation is already configured on `self.tokenizer`.
        let mut tokenizer = self.tokenizer.clone();
        tokenizer.with_padding(Some(PaddingParams {
            strategy: PaddingStrategy::BatchLongest,
            ..Default::default()
        }));

        let mut all_embeddings = Vec::with_capacity(texts.len());

        for chunk in texts.chunks(self.batch_size) {
            let encodings = tokenizer
                .encode_batch(chunk.to_vec(), true)
                .map_err(|e| CodememError::Embedding(format!("Batch encode error: {e}")))?;

            let batch_len = encodings.len();
            let seq_len = encodings[0].get_ids().len();

            // Flatten token IDs and attention masks into contiguous arrays
            let all_ids: Vec<u32> = encodings
                .iter()
                .flat_map(|e| e.get_ids())
                .copied()
                .collect();
            let all_masks: Vec<u32> = encodings
                .iter()
                .flat_map(|e| e.get_attention_mask())
                .copied()
                .collect();

            // Build tensors with shape [batch_len, seq_len]
            let input_ids = Tensor::new(all_ids.as_slice(), &self.device)
                .and_then(|t| t.reshape((batch_len, seq_len)))
                .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;

            let token_type_ids = input_ids
                .zeros_like()
                .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;

            let attention_mask = Tensor::new(all_masks.as_slice(), &self.device)
                .and_then(|t| t.reshape((batch_len, seq_len)))
                .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;

            // Single forward pass -> [batch_len, seq_len, hidden_size]
            let model = self
                .model
                .lock()
                .map_err(|e| CodememError::LockPoisoned(format!("embedding model: {e}")))?;
            let hidden_states = model
                .forward(&input_ids, &token_type_ids, Some(&attention_mask))
                .map_err(|e| CodememError::Embedding(format!("Forward error: {e}")))?;
            drop(model);

            // Drop intermediate GPU tensors before pooling
            drop(input_ids);
            drop(token_type_ids);

            // Cast hidden states to F32 for pooling math (model may output F16/BF16)
            let hidden_states = hidden_states
                .to_dtype(DType::F32)
                .map_err(|e| CodememError::Embedding(format!("Cast error: {e}")))?;

            // Mean pooling: mask [batch, seq] -> [batch, seq, 1] for broadcast
            let mask = attention_mask
                .to_dtype(DType::F32)
                .and_then(|t| t.unsqueeze(2))
                .map_err(|e| CodememError::Embedding(format!("Mask error: {e}")))?;

            let sum_mask = mask
                .sum(1)
                .map_err(|e| CodememError::Embedding(format!("Sum error: {e}")))?;

            let pooled = hidden_states
                .broadcast_mul(&mask)
                .and_then(|t| t.sum(1))
                .and_then(|t| t.broadcast_div(&sum_mask))
                .map_err(|e| CodememError::Embedding(format!("Pooling error: {e}")))?;

            // L2 normalize: [batch, hidden]
            let norm = pooled
                .sqr()
                .and_then(|t| t.sum_keepdim(1))
                .and_then(|t| t.sqrt())
                .map_err(|e| CodememError::Embedding(format!("Norm error: {e}")))?;

            let normalized = pooled
                .broadcast_div(&norm)
                .map_err(|e| CodememError::Embedding(format!("Normalize error: {e}")))?;

            // Single GPU→CPU blit: flatten all rows, then slice on CPU.
            // to_vec1() implicitly syncs the GPU pipeline (data must be ready to read).
            let flat: Vec<f32> = normalized
                .flatten_all()
                .and_then(|t| t.to_vec1())
                .map_err(|e| CodememError::Embedding(format!("Extract error: {e}")))?;
            for i in 0..batch_len {
                let start = i * self.hidden_size;
                all_embeddings.push(flat[start..start + self.hidden_size].to_vec());
            }
        }

        Ok(all_embeddings)
    }
}

impl EmbeddingProvider for EmbeddingService {
    fn dimensions(&self) -> usize {
        self.hidden_size
    }

    fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
        self.embed(text)
    }

    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
        self.embed_batch(texts)
    }

    fn name(&self) -> &str {
        "candle"
    }
}

// ── Cached Provider Wrapper ───────────────────────────────────────────────

/// Wraps any `EmbeddingProvider` with an LRU cache.
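///
/// # Example
///
/// A minimal sketch, assuming the default Candle model is available locally:
///
/// ```no_run
/// # use codemem_embeddings::{CachedProvider, EmbeddingProvider, EmbeddingService, CACHE_CAPACITY, DEFAULT_BATCH_SIZE};
/// # use candle_core::DType;
/// let dir = EmbeddingService::default_model_dir();
/// let service = EmbeddingService::new(&dir, DEFAULT_BATCH_SIZE, DType::F32)?;
/// let cached = CachedProvider::new(Box::new(service), CACHE_CAPACITY);
/// let first = cached.embed("hello")?;  // computed by the inner provider
/// let second = cached.embed("hello")?; // served from the LRU cache
/// assert_eq!(first, second);
/// # Ok::<(), codemem_core::CodememError>(())
/// ```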
pub struct CachedProvider {
    inner: Box<dyn EmbeddingProvider>,
    cache: Mutex<LruCache<String, Vec<f32>>>,
}

impl CachedProvider {
    pub fn new(inner: Box<dyn EmbeddingProvider>, capacity: usize) -> Self {
        // A zero capacity falls back to 1; the inner `expect` is infallible since 1 is non-zero.
        let cap =
            NonZeroUsize::new(capacity).unwrap_or(NonZeroUsize::new(1).expect("1 is non-zero"));
        Self {
            inner,
            cache: Mutex::new(LruCache::new(cap)),
        }
    }
}

impl EmbeddingProvider for CachedProvider {
    fn dimensions(&self) -> usize {
        self.inner.dimensions()
    }

    fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
        {
            let mut cache = self
                .cache
                .lock()
                .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
            if let Some(cached) = cache.get(text) {
                return Ok(cached.clone());
            }
        }
        let embedding = self.inner.embed(text)?;
        {
            let mut cache = self
                .cache
                .lock()
                .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
            cache.put(text.to_string(), embedding.clone());
        }
        Ok(embedding)
    }

    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
        // Check the cache and only forward uncached texts to the inner provider.
        let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
        let mut uncached = Vec::new();
        let mut uncached_idx = Vec::new();

        {
            let mut cache = self
                .cache
                .lock()
                .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
            for (i, text) in texts.iter().enumerate() {
                if let Some(cached) = cache.get(*text) {
                    results[i] = Some(cached.clone());
                } else {
                    uncached_idx.push(i);
                    uncached.push(*text);
                }
            }
        }

        if !uncached.is_empty() {
            let new_embeddings = self.inner.embed_batch(&uncached)?;
            let mut cache = self
                .cache
                .lock()
                .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
            for (idx, embedding) in uncached_idx.into_iter().zip(new_embeddings) {
                cache.put(texts[idx].to_string(), embedding.clone());
                results[idx] = Some(embedding);
            }
        }

        // Verify every text got an embedding — flatten() would silently drop Nones.
        let expected = texts.len();
        let output: Vec<Vec<f32>> = results
            .into_iter()
            .enumerate()
            .map(|(i, opt)| {
                opt.ok_or_else(|| {
                    CodememError::Embedding(format!(
                        "Missing embedding for text at index {i} (batch size {expected})"
                    ))
                })
            })
            .collect::<Result<Vec<_>, _>>()?;
        Ok(output)
    }

    fn name(&self) -> &str {
        self.inner.name()
    }

    fn cache_stats(&self) -> (usize, usize) {
        match self.cache.lock() {
            Ok(cache) => (cache.len(), cache.cap().into()),
            Err(_) => (0, 0),
        }
    }
}

// ── Factory ───────────────────────────────────────────────────────────────

/// Parse a dtype string into a Candle DType.
///
/// Supported values: "f32" (default), "f16" (half precision — less memory, faster on Metal),
/// and "bf16".
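///
/// # Example
///
/// ```
/// # use codemem_embeddings::parse_dtype;
/// # use candle_core::DType;
/// assert!(matches!(parse_dtype("F16"), Ok(DType::F16))); // case-insensitive
/// assert!(parse_dtype("int8").is_err());
/// ```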
pub fn parse_dtype(s: &str) -> Result<DType, CodememError> {
    match s.to_lowercase().as_str() {
        "f32" | "float32" | "" => Ok(DType::F32),
        "f16" | "float16" | "half" => Ok(DType::F16),
        "bf16" | "bfloat16" => Ok(DType::BF16),
        other => Err(CodememError::Embedding(format!(
            "Unknown dtype: '{}'. Use 'f32', 'f16', or 'bf16'.",
            other
        ))),
    }
}

/// Resolve the HuggingFace repo ID and local directory name from a model identifier.
///
/// Accepts:
/// - Full HF repo: `"BAAI/bge-base-en-v1.5"` → repo=`"BAAI/bge-base-en-v1.5"`, dir=`"bge-base-en-v1.5"`
/// - Short name: `"bge-small-en-v1.5"` → repo=`"BAAI/bge-small-en-v1.5"`, dir=`"bge-small-en-v1.5"`
///
/// Returns `Err` if the model identifier is a bare name without an org prefix and isn't
/// a recognized `bge-*` shorthand — HuggingFace requires `org/repo` format.
fn resolve_model_id(model: &str) -> Result<(String, String), CodememError> {
    if model.contains('/') {
        // Full repo ID — directory name is the part after the slash
        let dir_name = model.rsplit('/').next().unwrap_or(model);
        Ok((model.to_string(), dir_name.to_string()))
    } else if model.starts_with("bge-") {
        // Short name — assume BAAI namespace for bge-* models
        Ok((format!("BAAI/{model}"), model.to_string()))
    } else {
        Err(CodememError::Embedding(format!(
            "Model identifier '{}' must be a full HuggingFace repo ID (e.g., 'BAAI/bge-base-en-v1.5' \
             or 'sentence-transformers/all-MiniLM-L6-v2'). Short names are only supported for 'bge-*' models.",
            model
        )))
    }
}

/// Create an embedding provider from environment variables.
///
/// When `config` is provided, its fields serve as defaults; env vars override them.
///
/// | Variable | Values | Default |
/// |----------|--------|---------|
/// | `CODEMEM_EMBED_PROVIDER` | `candle`, `ollama`, `openai`, `gemini` | `candle` |
/// | `CODEMEM_EMBED_MODEL` | model name or HF repo | `BAAI/bge-base-en-v1.5` |
/// | `CODEMEM_EMBED_URL` | base URL | provider default |
/// | `CODEMEM_EMBED_API_KEY` | API key | also reads `OPENAI_API_KEY` / `GEMINI_API_KEY` / `GOOGLE_API_KEY` |
/// | `CODEMEM_EMBED_DIMENSIONS` | integer | read from model config |
/// | `CODEMEM_EMBED_BATCH_SIZE` | integer | `16` |
/// | `CODEMEM_EMBED_DTYPE` | `f32`, `f16`, `bf16` | `f32` |
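///
/// # Example
///
/// A sketch of selecting the Ollama backend purely via the environment (the
/// model name and dimensions below are illustrative, not defaults):
///
/// ```no_run
/// # use codemem_embeddings::EmbeddingProvider;
/// // Typically exported in the shell before launching:
/// //   CODEMEM_EMBED_PROVIDER=ollama
/// //   CODEMEM_EMBED_MODEL=nomic-embed-text
/// //   CODEMEM_EMBED_DIMENSIONS=768
/// let provider = codemem_embeddings::from_env(None)?;
/// let _vector = provider.embed("query text")?;
/// # Ok::<(), codemem_core::CodememError>(())
/// ```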
pub fn from_env(
    config: Option<&codemem_core::EmbeddingConfig>,
) -> Result<Box<dyn EmbeddingProvider>, CodememError> {
    let provider = std::env::var("CODEMEM_EMBED_PROVIDER")
        .unwrap_or_else(|_| {
            config
                .map(|c| c.provider.clone())
                .unwrap_or_else(|| "candle".to_string())
        })
        .to_lowercase();
    // For Ollama/OpenAI, dimensions must be specified explicitly (remote APIs need it).
    // For Candle, this value is ignored — hidden_size is read from the model's config.json.
    let dimensions: usize = std::env::var("CODEMEM_EMBED_DIMENSIONS")
        .ok()
        .and_then(|s| s.parse().ok())
        .unwrap_or_else(|| config.map_or(DEFAULT_REMOTE_DIMENSIONS, |c| c.dimensions));
    let cache_capacity = config.map_or(CACHE_CAPACITY, |c| c.cache_capacity);
    let batch_size: usize = std::env::var("CODEMEM_EMBED_BATCH_SIZE")
        .ok()
        .and_then(|s| s.parse().ok())
        .unwrap_or_else(|| config.map_or(DEFAULT_BATCH_SIZE, |c| c.batch_size));

    match provider.as_str() {
        "ollama" => {
            let base_url = std::env::var("CODEMEM_EMBED_URL").unwrap_or_else(|_| {
                config
                    .filter(|c| !c.url.is_empty())
                    .map(|c| c.url.clone())
                    .unwrap_or_else(|| ollama::DEFAULT_BASE_URL.to_string())
            });
            let model = std::env::var("CODEMEM_EMBED_MODEL").unwrap_or_else(|_| {
                config
                    .filter(|c| !c.model.is_empty())
                    .map(|c| c.model.clone())
                    .unwrap_or_else(|| ollama::DEFAULT_MODEL.to_string())
            });
            let inner = Box::new(ollama::OllamaProvider::new(&base_url, &model, dimensions));
            Ok(Box::new(CachedProvider::new(inner, cache_capacity)))
        }
        "openai" => {
            let api_key = std::env::var("CODEMEM_EMBED_API_KEY")
                .or_else(|_| std::env::var("OPENAI_API_KEY"))
                .map_err(|_| {
                    CodememError::Embedding(
                        "CODEMEM_EMBED_API_KEY or OPENAI_API_KEY required for OpenAI embeddings"
                            .into(),
                    )
                })?;
            let model = std::env::var("CODEMEM_EMBED_MODEL").unwrap_or_else(|_| {
                config
                    .filter(|c| !c.model.is_empty())
                    .map(|c| c.model.clone())
                    .unwrap_or_else(|| openai::DEFAULT_MODEL.to_string())
            });
            let base_url = std::env::var("CODEMEM_EMBED_URL")
                .ok()
                .or_else(|| config.filter(|c| !c.url.is_empty()).map(|c| c.url.clone()));
            let inner = Box::new(openai::OpenAIProvider::new(
                &api_key,
                &model,
                dimensions,
                base_url.as_deref(),
            ));
            Ok(Box::new(CachedProvider::new(inner, cache_capacity)))
        }
        "gemini" | "google" => {
            let api_key = std::env::var("CODEMEM_EMBED_API_KEY")
                .or_else(|_| std::env::var("GEMINI_API_KEY"))
                .or_else(|_| std::env::var("GOOGLE_API_KEY"))
                .map_err(|_| {
                    CodememError::Embedding(
                        "CODEMEM_EMBED_API_KEY, GEMINI_API_KEY, or GOOGLE_API_KEY required for Gemini embeddings"
                            .into(),
                    )
                })?;
            let model = std::env::var("CODEMEM_EMBED_MODEL").unwrap_or_else(|_| {
                config
                    .filter(|c| !c.model.is_empty())
                    .map(|c| c.model.clone())
                    .unwrap_or_else(|| gemini::DEFAULT_MODEL.to_string())
            });
            let base_url = std::env::var("CODEMEM_EMBED_URL")
                .ok()
                .or_else(|| config.filter(|c| !c.url.is_empty()).map(|c| c.url.clone()));
            let inner = Box::new(gemini::GeminiProvider::new(
                &api_key,
                &model,
                dimensions,
                base_url.as_deref(),
            ));
            Ok(Box::new(CachedProvider::new(inner, cache_capacity)))
        }
        "candle" | "" => {
            let model_id = std::env::var("CODEMEM_EMBED_MODEL").unwrap_or_else(|_| {
                config
                    .filter(|c| !c.model.is_empty())
                    .map(|c| c.model.clone())
                    .unwrap_or_else(|| DEFAULT_HF_REPO.to_string())
            });
            let (hf_repo, dir_name) = resolve_model_id(&model_id)?;
            let model_dir = EmbeddingService::model_dir_for(&dir_name);

            let dtype_str = std::env::var("CODEMEM_EMBED_DTYPE").unwrap_or_else(|_| {
                config
                    .filter(|c| !c.dtype.is_empty())
                    .map(|c| c.dtype.clone())
                    .unwrap_or_else(|| "f32".to_string())
            });
            let dtype = parse_dtype(&dtype_str)?;

            let service = EmbeddingService::new(&model_dir, batch_size, dtype).map_err(|e| {
                // Enhance the error message with a download hint for non-default models
                if e.to_string().contains("Model not found") && hf_repo != DEFAULT_HF_REPO {
                    CodememError::Embedding(format!(
                        "Model '{}' not found at {}. Download it with:\n  \
                         CODEMEM_EMBED_MODEL={} codemem init",
                        hf_repo,
                        model_dir.display(),
                        hf_repo
                    ))
                } else {
                    e
                }
            })?;
            Ok(Box::new(CachedProvider::new(
                Box::new(service),
                cache_capacity,
            )))
        }
        other => Err(CodememError::Embedding(format!(
            "Unknown embedding provider: '{}'. Use 'candle', 'ollama', 'openai', or 'gemini'.",
            other
        ))),
    }
}

#[cfg(test)]
#[path = "tests/lib_tests.rs"]
mod tests;