codemem_embeddings/lib.rs

//! codemem-embeddings: Pluggable embedding providers for Codemem.
//!
//! Supports multiple backends:
//! - **Candle** (default): Local BERT models via pure Rust ML (any HF BERT model)
//! - **Ollama**: Local Ollama server with any embedding model
//! - **OpenAI**: OpenAI API or any compatible endpoint (Together, Azure, etc.)
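//!
//! A minimal end-to-end sketch (hypothetical call site; assumes the default
//! model has already been downloaded via `codemem init`):
//!
//! ```no_run
//! use codemem_embeddings::EmbeddingProvider;
//!
//! // Resolve a provider from CODEMEM_EMBED_* env vars (defaults to Candle).
//! let provider = codemem_embeddings::from_env(None)?;
//! let vector = provider.embed("fn main() {}")?;
//! assert_eq!(vector.len(), provider.dimensions());
//! # Ok::<(), codemem_core::CodememError>(())
//! ```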

pub mod ollama;
pub mod openai;

use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config as BertConfig};
use codemem_core::CodememError;
use lru::LruCache;
use std::num::NonZeroUsize;
use std::path::{Path, PathBuf};
use std::sync::Mutex;
use tokenizers::{PaddingParams, PaddingStrategy};

/// Default model name (short form used for directory naming).
pub const MODEL_NAME: &str = "bge-base-en-v1.5";

/// Default HuggingFace model repo ID.
/// Used internally and by `commands_init` for the default model download.
pub const DEFAULT_HF_REPO: &str = "BAAI/bge-base-en-v1.5";

/// Default embedding dimensions for remote providers (Ollama/OpenAI).
/// Candle reads `hidden_size` from the model's config.json instead.
pub const DEFAULT_REMOTE_DIMENSIONS: usize = 768;

/// Max sequence length for BERT models.
const MAX_SEQ_LENGTH: usize = 512;

/// Default LRU cache capacity.
pub const CACHE_CAPACITY: usize = 10_000;

// Re-export EmbeddingProvider trait from core
pub use codemem_core::EmbeddingProvider;

// ── Candle Embedding Service ────────────────────────────────────────────────

/// Default batch size for batched embedding forward passes.
/// Configurable via `EmbeddingConfig.batch_size` or `CODEMEM_EMBED_BATCH_SIZE`.
pub const DEFAULT_BATCH_SIZE: usize = 16;

/// Select the best available compute device.
///
/// Tries Metal (macOS GPU) first, then CUDA (NVIDIA GPU), then falls back to CPU.
/// GPU backends are only available when the corresponding feature flag is enabled.
fn select_device() -> Device {
    #[cfg(feature = "metal")]
    {
        // Use catch_unwind to handle panics on CI runners without GPU access.
        match std::panic::catch_unwind(|| Device::new_metal(0)) {
            Ok(Ok(device)) => {
                tracing::info!("Using Metal GPU for embeddings");
                return device;
            }
            Ok(Err(e)) => {
                tracing::warn!("Metal device creation failed: {e}, falling back to CPU");
            }
            Err(_) => {
                tracing::warn!("Metal device creation panicked, falling back to CPU");
            }
        }
    }
    #[cfg(feature = "cuda")]
    {
        match std::panic::catch_unwind(|| Device::new_cuda(0)) {
            Ok(Ok(device)) => {
                tracing::info!("Using CUDA GPU for embeddings");
                return device;
            }
            Ok(Err(e)) => {
                tracing::warn!("CUDA device creation failed: {e}, falling back to CPU");
            }
            Err(_) => {
                tracing::warn!("CUDA device creation panicked, falling back to CPU");
            }
        }
    }
    tracing::info!("Using CPU for embeddings");
    Device::Cpu
}

/// Embedding service with Candle inference (no internal cache — use `CachedProvider` wrapper).
pub struct EmbeddingService {
    model: Mutex<BertModel>,
    /// Tokenizer pre-configured with truncation (no padding).
    /// Used directly for single embeds; cloned and augmented with padding for batch.
    tokenizer: tokenizers::Tokenizer,
    device: Device,
    /// Maximum texts per forward pass (GPU memory trade-off).
    batch_size: usize,
    /// Hidden size read from model config (e.g. 768 for bge-base, 384 for bge-small).
    hidden_size: usize,
}

impl EmbeddingService {
    /// Create a new embedding service, loading model from the given directory.
    /// Expects `model.safetensors`, `config.json`, and `tokenizer.json` in the directory.
    ///
    /// `dtype` controls precision: `DType::F32` (default) or `DType::F16` (half memory, faster on Metal).
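    ///
    /// A minimal construction sketch (assumes the default model is already on disk):
    /// ```no_run
    /// use candle_core::DType;
    /// use codemem_embeddings::{EmbeddingService, DEFAULT_BATCH_SIZE};
    ///
    /// let dir = EmbeddingService::default_model_dir();
    /// let _service = EmbeddingService::new(&dir, DEFAULT_BATCH_SIZE, DType::F32)?;
    /// # Ok::<(), codemem_core::CodememError>(())
    /// ```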
    pub fn new(model_dir: &Path, batch_size: usize, dtype: DType) -> Result<Self, CodememError> {
        let model_path = model_dir.join("model.safetensors");
        let config_path = model_dir.join("config.json");
        let tokenizer_path = model_dir.join("tokenizer.json");

        if !model_path.exists() {
            return Err(CodememError::Embedding(format!(
                "Model not found at {}. Run `codemem init` to download it.",
                model_path.display()
            )));
        }

        let device = select_device();

        tracing::info!(
            "Loading model from {} (dtype: {:?}, device: {:?})",
            model_dir.display(),
            dtype,
            device
        );

        // Load BERT config
        let config_str = std::fs::read_to_string(&config_path)
            .map_err(|e| CodememError::Embedding(format!("Failed to read config: {e}")))?;
        let config: BertConfig = serde_json::from_str(&config_str)
            .map_err(|e| CodememError::Embedding(format!("Failed to parse config: {e}")))?;

        let hidden_size = config.hidden_size;

        // Load model weights from safetensors via memory-mapped IO.
        // Try the "bert." prefix first (standard HF BERT models), then without.
        // Scope vb so it drops before a potential retry, avoiding two VarBuilders
        // holding materialized Metal tensors simultaneously.
        let model = {
            let vb = unsafe {
                VarBuilder::from_mmaped_safetensors(&[&model_path], dtype, &device)
                    .map_err(|e| CodememError::Embedding(format!("Failed to load weights: {e}")))?
            };
            BertModel::load(vb.pp("bert"), &config)
        };
        let model = match model {
            Ok(m) => m,
            Err(_) => {
                let vb2 = unsafe {
                    VarBuilder::from_mmaped_safetensors(&[&model_path], dtype, &device).map_err(
                        |e| CodememError::Embedding(format!("Failed to load weights: {e}")),
                    )?
                };
                BertModel::load(vb2, &config).map_err(|e| {
                    CodememError::Embedding(format!("Failed to load BERT model: {e}"))
                })?
            }
        };

        let mut tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| CodememError::Embedding(e.to_string()))?;

        // Pre-configure truncation once so we don't need to clone on every embed call.
        tokenizer
            .with_truncation(Some(tokenizers::TruncationParams {
                max_length: MAX_SEQ_LENGTH,
                ..Default::default()
            }))
            .map_err(|e| CodememError::Embedding(format!("Truncation error: {e}")))?;

        Ok(Self {
            model: Mutex::new(model),
            tokenizer,
            device,
            batch_size,
            hidden_size,
        })
    }

    /// Get the model directory path for a given model name:
    /// `~/.codemem/models/{model_name}` (current directory if home is unavailable).
    pub fn model_dir_for(model_name: &str) -> PathBuf {
        dirs::home_dir()
            .unwrap_or_else(|| PathBuf::from("."))
            .join(".codemem")
            .join("models")
            .join(model_name)
    }

    /// Get the default model directory path (~/.codemem/models/{MODEL_NAME}).
    pub fn default_model_dir() -> PathBuf {
        Self::model_dir_for(MODEL_NAME)
    }

    /// Download a model from HuggingFace Hub to the given directory.
    /// `hf_repo` is the full repo ID (e.g. "BAAI/bge-base-en-v1.5").
    /// Returns the directory path. No-ops if model already exists.
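    ///
    /// A minimal sketch (repo ID is illustrative):
    /// ```no_run
    /// use codemem_embeddings::EmbeddingService;
    ///
    /// let dir = EmbeddingService::model_dir_for("bge-small-en-v1.5");
    /// EmbeddingService::download_model(&dir, "BAAI/bge-small-en-v1.5")?;
    /// # Ok::<(), codemem_core::CodememError>(())
    /// ```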
    pub fn download_model(dest_dir: &Path, hf_repo: &str) -> Result<PathBuf, CodememError> {
        let model_dest = dest_dir.join("model.safetensors");
        let config_dest = dest_dir.join("config.json");
        let tokenizer_dest = dest_dir.join("tokenizer.json");

        if model_dest.exists() && config_dest.exists() && tokenizer_dest.exists() {
            tracing::info!("Model already downloaded at {}", dest_dir.display());
            return Ok(dest_dir.to_path_buf());
        }

        std::fs::create_dir_all(dest_dir)
            .map_err(|e| CodememError::Embedding(format!("Failed to create dir: {e}")))?;

        tracing::info!("Downloading {} from HuggingFace...", hf_repo);

        let api = hf_hub::api::sync::Api::new()
            .map_err(|e| CodememError::Embedding(format!("HuggingFace API error: {e}")))?;
        let repo = api.model(hf_repo.to_string());

        let cached_model = repo
            .get("model.safetensors")
            .map_err(|e| CodememError::Embedding(format!("Failed to download model: {e}")))?;

        let cached_config = repo
            .get("config.json")
            .map_err(|e| CodememError::Embedding(format!("Failed to download config: {e}")))?;

        let cached_tokenizer = repo
            .get("tokenizer.json")
            .map_err(|e| CodememError::Embedding(format!("Failed to download tokenizer: {e}")))?;

        std::fs::copy(&cached_model, &model_dest)
            .map_err(|e| CodememError::Embedding(format!("Failed to copy model: {e}")))?;
        std::fs::copy(&cached_config, &config_dest)
            .map_err(|e| CodememError::Embedding(format!("Failed to copy config: {e}")))?;
        std::fs::copy(&cached_tokenizer, &tokenizer_dest)
            .map_err(|e| CodememError::Embedding(format!("Failed to copy tokenizer: {e}")))?;

        tracing::info!("Model downloaded to {}", dest_dir.display());
        Ok(dest_dir.to_path_buf())
    }

    /// Download the default model (BAAI/bge-base-en-v1.5) to the default directory.
    /// Convenience wrapper for `download_model(&default_model_dir(), DEFAULT_HF_REPO)`.
    pub fn download_default_model() -> Result<PathBuf, CodememError> {
        Self::download_model(&Self::default_model_dir(), DEFAULT_HF_REPO)
    }

    /// Embed a single text string. Returns an L2-normalized vector (dimension = model's hidden_size).
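    ///
    /// A minimal sketch (hypothetical caller with an existing service):
    /// ```no_run
    /// # fn demo(service: &codemem_embeddings::EmbeddingService) -> Result<(), codemem_core::CodememError> {
    /// let v = service.embed("let x = 1;")?;
    /// // L2-normalized: the squared components sum to ~1.0.
    /// let norm_sq: f32 = v.iter().map(|x| x * x).sum();
    /// assert!((norm_sq - 1.0).abs() < 1e-4);
    /// # Ok(())
    /// # }
    /// ```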
    pub fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
        // Tokenize using pre-configured tokenizer (truncation already set in constructor)
        let encoding = self
            .tokenizer
            .encode(text, true)
            .map_err(|e| CodememError::Embedding(e.to_string()))?;

        let input_ids: Vec<u32> = encoding.get_ids().to_vec();
        let attention_mask: Vec<u32> = encoding.get_attention_mask().to_vec();

        // Build candle tensors with shape [1, seq_len]
        let input_ids_tensor = Tensor::new(&input_ids[..], &self.device)
            .and_then(|t| t.unsqueeze(0))
            .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;

        let token_type_ids = input_ids_tensor
            .zeros_like()
            .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;

        let attention_mask_tensor = Tensor::new(&attention_mask[..], &self.device)
            .and_then(|t| t.unsqueeze(0))
            .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;

        // Forward pass -> [1, seq_len, hidden_size]
        let model = self
            .model
            .lock()
            .map_err(|e| CodememError::LockPoisoned(format!("embedding model: {e}")))?;
        let hidden_states = model
            .forward(
                &input_ids_tensor,
                &token_type_ids,
                Some(&attention_mask_tensor),
            )
            .map_err(|e| CodememError::Embedding(format!("Model forward error: {e}")))?;
        drop(model);

        // Drop intermediate GPU tensors before pooling
        drop(input_ids_tensor);
        drop(token_type_ids);

        // Cast hidden states to F32 for pooling math (model may output F16/BF16)
        let hidden_states = hidden_states
            .to_dtype(DType::F32)
            .map_err(|e| CodememError::Embedding(format!("Cast error: {e}")))?;

        // Mean pooling weighted by attention mask
        // attention_mask: [1, seq_len] -> [1, seq_len, 1] for broadcasting
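        // For hidden states h[t] and mask m[t] ∈ {0, 1}, this computes
        //   pooled = Σ_t m[t] · h[t] / Σ_t m[t]
        // (the average over non-padding tokens), followed below by
        //   v = pooled / ‖pooled‖₂
        // so downstream cosine similarity reduces to a dot product.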
        let mask = attention_mask_tensor
            .to_dtype(DType::F32)
            .and_then(|t| t.unsqueeze(2))
            .map_err(|e| CodememError::Embedding(format!("Mask error: {e}")))?;

        let sum_mask = mask
            .sum(1)
            .map_err(|e| CodememError::Embedding(format!("Sum error: {e}")))?;

        let pooled = hidden_states
            .broadcast_mul(&mask)
            .and_then(|t| t.sum(1))
            .and_then(|t| t.broadcast_div(&sum_mask))
            .map_err(|e| CodememError::Embedding(format!("Pooling error: {e}")))?;

        // L2 normalize
        let normalized = pooled
            .sqr()
            .and_then(|t| t.sum_keepdim(1))
            .and_then(|t| t.sqrt())
            .and_then(|norm| pooled.broadcast_div(&norm))
            .map_err(|e| CodememError::Embedding(format!("Normalize error: {e}")))?;

        // Extract as Vec<f32> — shape is [1, hidden_size], squeeze to [hidden_size]
        let embedding: Vec<f32> = normalized
            .squeeze(0)
            .and_then(|t| t.to_vec1())
            .map_err(|e| CodememError::Embedding(format!("Extract error: {e}")))?;

        Ok(embedding)
    }

    /// Embed a batch of texts using a true batched forward pass.
    ///
    /// Tokenizes all texts, pads to the longest sequence in each chunk, runs a
    /// single forward pass per chunk of up to `batch_size` texts, then performs
    /// mean pooling and L2 normalization on the batched output.
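    ///
    /// A minimal sketch (hypothetical caller with an existing service):
    /// ```no_run
    /// # fn demo(service: &codemem_embeddings::EmbeddingService) -> Result<(), codemem_core::CodememError> {
    /// let vectors = service.embed_batch(&["fn a() {}", "fn b() {}"])?;
    /// assert_eq!(vectors.len(), 2);
    /// # Ok(())
    /// # }
    /// ```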
    pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
        if texts.is_empty() {
            return Ok(vec![]);
        }

        let mut all_embeddings = Vec::with_capacity(texts.len());

        // Clone the tokenizer once for the batch path and enable padding on the
        // clone only, keeping self.tokenizer truncation-only for single embeds.
        // BatchLongest pads each encode_batch call to its own longest sequence,
        // so a single configuration serves every chunk.
        let mut tokenizer = self.tokenizer.clone();
        tokenizer.with_padding(Some(PaddingParams {
            strategy: PaddingStrategy::BatchLongest,
            ..Default::default()
        }));

        for chunk in texts.chunks(self.batch_size) {
            let encodings = tokenizer
                .encode_batch(chunk.to_vec(), true)
                .map_err(|e| CodememError::Embedding(format!("Batch encode error: {e}")))?;

            let batch_len = encodings.len();
            let seq_len = encodings[0].get_ids().len();

            // Flatten token IDs and attention masks into contiguous arrays
            let all_ids: Vec<u32> = encodings
                .iter()
                .flat_map(|e| e.get_ids())
                .copied()
                .collect();
            let all_masks: Vec<u32> = encodings
                .iter()
                .flat_map(|e| e.get_attention_mask())
                .copied()
                .collect();

            // Build tensors with shape [batch_size, seq_len]
            let input_ids = Tensor::new(all_ids.as_slice(), &self.device)
                .and_then(|t| t.reshape((batch_len, seq_len)))
                .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;

            let token_type_ids = input_ids
                .zeros_like()
                .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;

            let attention_mask = Tensor::new(all_masks.as_slice(), &self.device)
                .and_then(|t| t.reshape((batch_len, seq_len)))
                .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;

            // Single forward pass -> [batch_size, seq_len, hidden_size]
            let model = self
                .model
                .lock()
                .map_err(|e| CodememError::LockPoisoned(format!("embedding model: {e}")))?;
            let hidden_states = model
                .forward(&input_ids, &token_type_ids, Some(&attention_mask))
                .map_err(|e| CodememError::Embedding(format!("Forward error: {e}")))?;
            drop(model);

            // Drop intermediate GPU tensors before pooling
            drop(input_ids);
            drop(token_type_ids);

            // Cast hidden states to F32 for pooling math (model may output F16/BF16)
            let hidden_states = hidden_states
                .to_dtype(DType::F32)
                .map_err(|e| CodememError::Embedding(format!("Cast error: {e}")))?;

            // Mean pooling: mask [batch, seq] -> [batch, seq, 1] for broadcast
            let mask = attention_mask
                .to_dtype(DType::F32)
                .and_then(|t| t.unsqueeze(2))
                .map_err(|e| CodememError::Embedding(format!("Mask error: {e}")))?;

            let sum_mask = mask
                .sum(1)
                .map_err(|e| CodememError::Embedding(format!("Sum error: {e}")))?;

            let pooled = hidden_states
                .broadcast_mul(&mask)
                .and_then(|t| t.sum(1))
                .and_then(|t| t.broadcast_div(&sum_mask))
                .map_err(|e| CodememError::Embedding(format!("Pooling error: {e}")))?;

            // L2 normalize: [batch, hidden]
            let norm = pooled
                .sqr()
                .and_then(|t| t.sum_keepdim(1))
                .and_then(|t| t.sqrt())
                .map_err(|e| CodememError::Embedding(format!("Norm error: {e}")))?;

            let normalized = pooled
                .broadcast_div(&norm)
                .map_err(|e| CodememError::Embedding(format!("Normalize error: {e}")))?;

            // Single GPU→CPU blit: flatten all rows, then slice on CPU.
            // to_vec1() implicitly syncs the GPU pipeline (data must be ready to read).
            let flat: Vec<f32> = normalized
                .flatten_all()
                .and_then(|t| t.to_vec1())
                .map_err(|e| CodememError::Embedding(format!("Extract error: {e}")))?;
            for i in 0..batch_len {
                let start = i * self.hidden_size;
                all_embeddings.push(flat[start..start + self.hidden_size].to_vec());
            }
        }

        Ok(all_embeddings)
    }
}

impl EmbeddingProvider for EmbeddingService {
    fn dimensions(&self) -> usize {
        self.hidden_size
    }

    fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
        self.embed(text)
    }

    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
        self.embed_batch(texts)
    }

    fn name(&self) -> &str {
        "candle"
    }
}

// ── Cached Provider Wrapper ───────────────────────────────────────────────

/// Wraps any `EmbeddingProvider` with an LRU cache.
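///
/// A minimal sketch (wrapping the Ollama provider; values are illustrative):
/// ```no_run
/// use codemem_embeddings::{ollama, CachedProvider, CACHE_CAPACITY, DEFAULT_REMOTE_DIMENSIONS};
///
/// let inner = Box::new(ollama::OllamaProvider::new(
///     ollama::DEFAULT_BASE_URL,
///     ollama::DEFAULT_MODEL,
///     DEFAULT_REMOTE_DIMENSIONS,
/// ));
/// let _provider = CachedProvider::new(inner, CACHE_CAPACITY);
/// ```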
pub struct CachedProvider {
    inner: Box<dyn EmbeddingProvider>,
    cache: Mutex<LruCache<String, Vec<f32>>>,
}

impl CachedProvider {
    pub fn new(inner: Box<dyn EmbeddingProvider>, capacity: usize) -> Self {
        // A capacity of 0 falls back to 1; NonZeroUsize::new(1) is infallible.
        let cap =
            NonZeroUsize::new(capacity).unwrap_or(NonZeroUsize::new(1).expect("1 is non-zero"));
        Self {
            inner,
            cache: Mutex::new(LruCache::new(cap)),
        }
    }
}

impl EmbeddingProvider for CachedProvider {
    fn dimensions(&self) -> usize {
        self.inner.dimensions()
    }

    fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
        {
            let mut cache = self
                .cache
                .lock()
                .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
            if let Some(cached) = cache.get(text) {
                return Ok(cached.clone());
            }
        }
        let embedding = self.inner.embed(text)?;
        {
            let mut cache = self
                .cache
                .lock()
                .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
            cache.put(text.to_string(), embedding.clone());
        }
        Ok(embedding)
    }

    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
        // Check cache, only forward uncached texts
        let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
        let mut uncached = Vec::new();
        let mut uncached_idx = Vec::new();

        {
            let mut cache = self
                .cache
                .lock()
                .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
            for (i, text) in texts.iter().enumerate() {
                if let Some(cached) = cache.get(*text) {
                    results[i] = Some(cached.clone());
                } else {
                    uncached_idx.push(i);
                    uncached.push(*text);
                }
            }
        }

        if !uncached.is_empty() {
            let new_embeddings = self.inner.embed_batch(&uncached)?;
            let mut cache = self
                .cache
                .lock()
                .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
            for (idx, embedding) in uncached_idx.into_iter().zip(new_embeddings) {
                cache.put(texts[idx].to_string(), embedding.clone());
                results[idx] = Some(embedding);
            }
        }

        // Verify all texts got embeddings — flatten() would silently drop Nones
        let expected = texts.len();
        let output: Vec<Vec<f32>> = results
            .into_iter()
            .enumerate()
            .map(|(i, opt)| {
                opt.ok_or_else(|| {
                    CodememError::Embedding(format!(
                        "Missing embedding for text at index {i} (batch size {expected})"
                    ))
                })
            })
            .collect::<Result<Vec<_>, _>>()?;
        Ok(output)
    }

    fn name(&self) -> &str {
        self.inner.name()
    }

    fn cache_stats(&self) -> (usize, usize) {
        match self.cache.lock() {
            Ok(cache) => (cache.len(), cache.cap().into()),
            Err(_) => (0, 0),
        }
    }
}

// ── Factory ───────────────────────────────────────────────────────────────

/// Parse a dtype string into a Candle DType.
///
/// Supported values: "f32" (default), "f16" (half precision — less memory, faster on Metal),
/// and "bf16" (bfloat16).
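///
/// A quick doc-test sketch:
/// ```
/// use candle_core::DType;
/// use codemem_embeddings::parse_dtype;
///
/// assert_eq!(parse_dtype("f16")?, DType::F16);
/// assert_eq!(parse_dtype("")?, DType::F32); // empty string defaults to f32
/// # Ok::<(), codemem_core::CodememError>(())
/// ```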
pub fn parse_dtype(s: &str) -> Result<DType, CodememError> {
    match s.to_lowercase().as_str() {
        "f32" | "float32" | "" => Ok(DType::F32),
        "f16" | "float16" | "half" => Ok(DType::F16),
        "bf16" | "bfloat16" => Ok(DType::BF16),
        other => Err(CodememError::Embedding(format!(
            "Unknown dtype: '{}'. Use 'f32', 'f16', or 'bf16'.",
            other
        ))),
    }
}

/// Resolve the HuggingFace repo ID and local directory name from a model identifier.
///
/// Accepts:
/// - Full HF repo: `"BAAI/bge-base-en-v1.5"` → repo=`"BAAI/bge-base-en-v1.5"`, dir=`"bge-base-en-v1.5"`
/// - Short name: `"bge-small-en-v1.5"` → repo=`"BAAI/bge-small-en-v1.5"`, dir=`"bge-small-en-v1.5"`
///
/// Returns `Err` if the model identifier is a bare name without an org prefix and isn't
/// a recognized `bge-*` shorthand — HuggingFace requires `org/repo` format.
fn resolve_model_id(model: &str) -> Result<(String, String), CodememError> {
    if model.contains('/') {
        // Full repo ID — directory name is the part after the slash
        let dir_name = model.rsplit('/').next().unwrap_or(model);
        Ok((model.to_string(), dir_name.to_string()))
    } else if model.starts_with("bge-") {
        // Short name — assume BAAI namespace for bge-* models
        Ok((format!("BAAI/{model}"), model.to_string()))
    } else {
        Err(CodememError::Embedding(format!(
            "Model identifier '{}' must be a full HuggingFace repo ID (e.g., 'BAAI/bge-base-en-v1.5' \
             or 'sentence-transformers/all-MiniLM-L6-v2'). Short names are only supported for 'bge-*' models.",
            model
        )))
    }
}

/// Create an embedding provider from environment variables.
///
/// When `config` is provided, its fields serve as defaults; env vars override them.
///
/// | Variable | Values | Default |
/// |----------|--------|---------|
/// | `CODEMEM_EMBED_PROVIDER` | `candle`, `ollama`, `openai` | `candle` |
/// | `CODEMEM_EMBED_MODEL` | model name or HF repo | `BAAI/bge-base-en-v1.5` |
/// | `CODEMEM_EMBED_URL` | base URL | provider default |
/// | `CODEMEM_EMBED_API_KEY` | API key | also reads `OPENAI_API_KEY` |
/// | `CODEMEM_EMBED_DIMENSIONS` | integer | read from model config |
/// | `CODEMEM_EMBED_BATCH_SIZE` | integer | `16` |
/// | `CODEMEM_EMBED_DTYPE` | `f32`, `f16`, `bf16` | `f32` |
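///
/// An illustrative shell setup (values are examples, not requirements):
/// ```text
/// export CODEMEM_EMBED_PROVIDER=ollama
/// export CODEMEM_EMBED_MODEL=nomic-embed-text
/// export CODEMEM_EMBED_DIMENSIONS=768
/// ```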
pub fn from_env(
    config: Option<&codemem_core::EmbeddingConfig>,
) -> Result<Box<dyn EmbeddingProvider>, CodememError> {
    let provider = std::env::var("CODEMEM_EMBED_PROVIDER")
        .unwrap_or_else(|_| {
            config
                .map(|c| c.provider.clone())
                .unwrap_or_else(|| "candle".to_string())
        })
        .to_lowercase();
    // For Ollama/OpenAI, dimensions must be specified explicitly (remote APIs need it).
    // For Candle, this value is ignored — hidden_size is read from the model's config.json.
    let dimensions: usize = std::env::var("CODEMEM_EMBED_DIMENSIONS")
        .ok()
        .and_then(|s| s.parse().ok())
        .unwrap_or_else(|| config.map_or(DEFAULT_REMOTE_DIMENSIONS, |c| c.dimensions));
    let cache_capacity = config.map_or(CACHE_CAPACITY, |c| c.cache_capacity);
    let batch_size: usize = std::env::var("CODEMEM_EMBED_BATCH_SIZE")
        .ok()
        .and_then(|s| s.parse().ok())
        .unwrap_or_else(|| config.map_or(DEFAULT_BATCH_SIZE, |c| c.batch_size));

    match provider.as_str() {
        "ollama" => {
            let base_url = std::env::var("CODEMEM_EMBED_URL").unwrap_or_else(|_| {
                config
                    .filter(|c| !c.url.is_empty())
                    .map(|c| c.url.clone())
                    .unwrap_or_else(|| ollama::DEFAULT_BASE_URL.to_string())
            });
            let model = std::env::var("CODEMEM_EMBED_MODEL").unwrap_or_else(|_| {
                config
                    .filter(|c| !c.model.is_empty())
                    .map(|c| c.model.clone())
                    .unwrap_or_else(|| ollama::DEFAULT_MODEL.to_string())
            });
            let inner = Box::new(ollama::OllamaProvider::new(&base_url, &model, dimensions));
            Ok(Box::new(CachedProvider::new(inner, cache_capacity)))
        }
        "openai" => {
            let api_key = std::env::var("CODEMEM_EMBED_API_KEY")
                .or_else(|_| std::env::var("OPENAI_API_KEY"))
                .map_err(|_| {
                    CodememError::Embedding(
                        "CODEMEM_EMBED_API_KEY or OPENAI_API_KEY required for OpenAI embeddings"
                            .into(),
                    )
                })?;
            let model = std::env::var("CODEMEM_EMBED_MODEL").unwrap_or_else(|_| {
                config
                    .filter(|c| !c.model.is_empty())
                    .map(|c| c.model.clone())
                    .unwrap_or_else(|| openai::DEFAULT_MODEL.to_string())
            });
            let base_url = std::env::var("CODEMEM_EMBED_URL")
                .ok()
                .or_else(|| config.filter(|c| !c.url.is_empty()).map(|c| c.url.clone()));
            let inner = Box::new(openai::OpenAIProvider::new(
                &api_key,
                &model,
                dimensions,
                base_url.as_deref(),
            ));
            Ok(Box::new(CachedProvider::new(inner, cache_capacity)))
        }
        "candle" | "" => {
            let model_id = std::env::var("CODEMEM_EMBED_MODEL").unwrap_or_else(|_| {
                config
                    .filter(|c| !c.model.is_empty())
                    .map(|c| c.model.clone())
                    .unwrap_or_else(|| DEFAULT_HF_REPO.to_string())
            });
            let (hf_repo, dir_name) = resolve_model_id(&model_id)?;
            let model_dir = EmbeddingService::model_dir_for(&dir_name);

            let dtype_str = std::env::var("CODEMEM_EMBED_DTYPE").unwrap_or_else(|_| {
                config
                    .filter(|c| !c.dtype.is_empty())
                    .map(|c| c.dtype.clone())
                    .unwrap_or_else(|| "f32".to_string())
            });
            let dtype = parse_dtype(&dtype_str)?;

            let service = EmbeddingService::new(&model_dir, batch_size, dtype).map_err(|e| {
                // Enhance error message with download hint for non-default models
                if e.to_string().contains("Model not found") && hf_repo != DEFAULT_HF_REPO {
                    CodememError::Embedding(format!(
                        "Model '{}' not found at {}. Download it with:\n  \
                         CODEMEM_EMBED_MODEL={} codemem init",
                        hf_repo,
                        model_dir.display(),
                        hf_repo
                    ))
                } else {
                    e
                }
            })?;
            Ok(Box::new(CachedProvider::new(
                Box::new(service),
                cache_capacity,
            )))
        }
        other => Err(CodememError::Embedding(format!(
            "Unknown embedding provider: '{}'. Use 'candle', 'ollama', or 'openai'.",
            other
        ))),
    }
}

#[cfg(test)]
#[path = "tests/lib_tests.rs"]
mod tests;