// codemem_embeddings/lib.rs
//! codemem-embeddings: Pluggable embedding providers for Codemem.
//!
//! Supports multiple backends:
//! - **Candle** (default): Local BAAI/bge-base-en-v1.5 via pure Rust ML
//! - **Ollama**: Local Ollama server with any embedding model
//! - **OpenAI**: OpenAI API or any compatible endpoint (Together, Azure, etc.)

// HTTP-based provider backends; selected at runtime via `from_env`.
pub mod ollama;
pub mod openai;

use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config as BertConfig};
use codemem_core::CodememError;
use lru::LruCache;
use std::num::NonZeroUsize;
use std::path::{Path, PathBuf};
use std::sync::Mutex;
use tokenizers::{PaddingParams, PaddingStrategy};
20
/// Default model name (also the directory name under `~/.codemem/models/`).
pub const MODEL_NAME: &str = "bge-base-en-v1.5";

/// HuggingFace model repo ID.
const HF_MODEL_REPO: &str = "BAAI/bge-base-en-v1.5";

/// Default embedding dimensions (hidden size of bge-base-en-v1.5).
pub const DIMENSIONS: usize = 768;

/// Max sequence length for bge-base-en-v1.5; longer inputs are truncated.
const MAX_SEQ_LENGTH: usize = 512;

/// Default LRU cache capacity (number of cached text → vector entries).
pub const CACHE_CAPACITY: usize = 10_000;

// Re-export EmbeddingProvider trait from core
pub use codemem_core::EmbeddingProvider;

// ── Candle Embedding Service ────────────────────────────────────────────────

/// Maximum batch size for batched embedding forward passes.
/// bge-base-en-v1.5 at 768-dim: ~48 MB peak GPU memory per batch of 32.
const BATCH_SIZE: usize = 32;
44
/// Select the best available compute device.
///
/// Tries Metal (macOS GPU) first, then CUDA (NVIDIA GPU), then falls back to CPU.
/// GPU backends are only available when the corresponding feature flag is enabled;
/// with a feature disabled its whole block compiles out, leaving only the CPU path.
fn select_device() -> Device {
    #[cfg(feature = "metal")]
    {
        if let Ok(device) = Device::new_metal(0) {
            tracing::info!("Using Metal GPU for embeddings");
            return device;
        }
        // Feature was compiled in, but the device may still be unavailable at runtime.
        tracing::warn!("Metal feature enabled but device creation failed, falling back");
    }
    #[cfg(feature = "cuda")]
    {
        if let Ok(device) = Device::new_cuda(0) {
            tracing::info!("Using CUDA GPU for embeddings");
            return device;
        }
        tracing::warn!("CUDA feature enabled but device creation failed, falling back");
    }
    tracing::info!("Using CPU for embeddings");
    Device::Cpu
}
69
/// Embedding service with Candle inference and LRU caching.
pub struct EmbeddingService {
    /// BERT model behind a Mutex so forward passes are serialized across threads.
    model: Mutex<BertModel>,
    /// Tokenizer pre-configured with truncation (no padding).
    /// Used directly for single embeds; cloned and augmented with padding for batch.
    tokenizer: tokenizers::Tokenizer,
    /// Compute device selected at construction (Metal / CUDA / CPU).
    device: Device,
    /// LRU cache of text → L2-normalized embedding (capacity `CACHE_CAPACITY`).
    cache: Mutex<LruCache<String, Vec<f32>>>,
}
79
80impl EmbeddingService {
    /// Create a new embedding service, loading model from the given directory.
    /// Expects `model.safetensors`, `config.json`, and `tokenizer.json` in the directory.
    ///
    /// # Errors
    /// Returns `CodememError::Embedding` if any of the three files is missing,
    /// unreadable, or fails to parse/load.
    pub fn new(model_dir: &Path) -> Result<Self, CodememError> {
        let model_path = model_dir.join("model.safetensors");
        let config_path = model_dir.join("config.json");
        let tokenizer_path = model_dir.join("tokenizer.json");

        // Only the weights file is pre-checked with a friendly hint;
        // config/tokenizer problems surface through the loads below.
        if !model_path.exists() {
            return Err(CodememError::Embedding(format!(
                "Model not found at {}. Run `codemem init` to download it.",
                model_path.display()
            )));
        }

        let device = select_device();

        // Load BERT config
        let config_str = std::fs::read_to_string(&config_path)
            .map_err(|e| CodememError::Embedding(format!("Failed to read config: {e}")))?;
        let config: BertConfig = serde_json::from_str(&config_str)
            .map_err(|e| CodememError::Embedding(format!("Failed to parse config: {e}")))?;

        // Load model weights from safetensors via memory-mapped IO.
        // Scope vb so it drops before a potential retry, avoiding two VarBuilders
        // holding materialized Metal tensors simultaneously.
        let model = {
            // SAFETY: the mmap'd safetensors file is assumed not to be modified or
            // truncated by another process while mapped (candle's documented contract).
            let vb = unsafe {
                VarBuilder::from_mmaped_safetensors(&[&model_path], DType::F32, &device)
                    .map_err(|e| CodememError::Embedding(format!("Failed to load weights: {e}")))?
            };
            BertModel::load(vb.pp("bert"), &config)
        };
        // Try with "bert." prefix first (standard HF BERT models), then without
        let model = match model {
            Ok(m) => m,
            Err(_) => {
                // SAFETY: same mmap contract as above.
                let vb2 = unsafe {
                    VarBuilder::from_mmaped_safetensors(&[&model_path], DType::F32, &device)
                        .map_err(|e| {
                            CodememError::Embedding(format!("Failed to load weights: {e}"))
                        })?
                };
                BertModel::load(vb2, &config).map_err(|e| {
                    CodememError::Embedding(format!("Failed to load BERT model: {e}"))
                })?
            }
        };

        let mut tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| CodememError::Embedding(e.to_string()))?;

        // Pre-configure truncation once so we don't need to clone on every embed call.
        tokenizer
            .with_truncation(Some(tokenizers::TruncationParams {
                max_length: MAX_SEQ_LENGTH,
                ..Default::default()
            }))
            .map_err(|e| CodememError::Embedding(format!("Truncation error: {e}")))?;

        let cache = Mutex::new(LruCache::new(
            NonZeroUsize::new(CACHE_CAPACITY).expect("CACHE_CAPACITY is non-zero"),
        ));

        Ok(Self {
            model: Mutex::new(model),
            tokenizer,
            device,
            cache,
        })
    }
151
152    /// Get the default model directory path (~/.codemem/models/{MODEL_NAME}).
153    pub fn default_model_dir() -> PathBuf {
154        dirs::home_dir()
155            .unwrap_or_else(|| PathBuf::from("."))
156            .join(".codemem")
157            .join("models")
158            .join(MODEL_NAME)
159    }
160
161    /// Download the model from HuggingFace Hub to the given directory.
162    /// Returns the directory path. No-ops if model already exists.
163    pub fn download_model(dest_dir: &Path) -> Result<PathBuf, CodememError> {
164        let model_dest = dest_dir.join("model.safetensors");
165        let config_dest = dest_dir.join("config.json");
166        let tokenizer_dest = dest_dir.join("tokenizer.json");
167
168        if model_dest.exists() && config_dest.exists() && tokenizer_dest.exists() {
169            tracing::info!("Model already downloaded at {}", dest_dir.display());
170            return Ok(dest_dir.to_path_buf());
171        }
172
173        std::fs::create_dir_all(dest_dir)
174            .map_err(|e| CodememError::Embedding(format!("Failed to create dir: {e}")))?;
175
176        tracing::info!("Downloading {} from HuggingFace...", HF_MODEL_REPO);
177
178        let api = hf_hub::api::sync::Api::new()
179            .map_err(|e| CodememError::Embedding(format!("HuggingFace API error: {e}")))?;
180        let repo = api.model(HF_MODEL_REPO.to_string());
181
182        let cached_model = repo
183            .get("model.safetensors")
184            .map_err(|e| CodememError::Embedding(format!("Failed to download model: {e}")))?;
185
186        let cached_config = repo
187            .get("config.json")
188            .map_err(|e| CodememError::Embedding(format!("Failed to download config: {e}")))?;
189
190        let cached_tokenizer = repo
191            .get("tokenizer.json")
192            .map_err(|e| CodememError::Embedding(format!("Failed to download tokenizer: {e}")))?;
193
194        std::fs::copy(&cached_model, &model_dest)
195            .map_err(|e| CodememError::Embedding(format!("Failed to copy model: {e}")))?;
196        std::fs::copy(&cached_config, &config_dest)
197            .map_err(|e| CodememError::Embedding(format!("Failed to copy config: {e}")))?;
198        std::fs::copy(&cached_tokenizer, &tokenizer_dest)
199            .map_err(|e| CodememError::Embedding(format!("Failed to copy tokenizer: {e}")))?;
200
201        tracing::info!("Model downloaded to {}", dest_dir.display());
202        Ok(dest_dir.to_path_buf())
203    }
204
205    /// Embed a single text string. Returns a 768-dim L2-normalized vector.
206    /// Uses LRU cache for repeated queries.
207    pub fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
208        // Check cache
209        {
210            let mut cache = self
211                .cache
212                .lock()
213                .map_err(|e| CodememError::LockPoisoned(format!("embedding cache: {e}")))?;
214            if let Some(cached) = cache.get(text) {
215                return Ok(cached.clone());
216            }
217        }
218
219        let embedding = self.embed_uncached(text)?;
220
221        // Store in cache
222        {
223            let mut cache = self
224                .cache
225                .lock()
226                .map_err(|e| CodememError::LockPoisoned(format!("embedding cache: {e}")))?;
227            cache.put(text.to_string(), embedding.clone());
228        }
229
230        Ok(embedding)
231    }
232
    /// Embed without caching. Uses mean pooling with attention mask.
    ///
    /// Pipeline: tokenize (truncated to `MAX_SEQ_LENGTH` by the pre-configured
    /// tokenizer) → BERT forward → attention-masked mean pooling → L2 normalize.
    /// Returns a `hidden_size`-dim vector (768 for bge-base-en-v1.5).
    fn embed_uncached(&self, text: &str) -> Result<Vec<f32>, CodememError> {
        // Tokenize using pre-configured tokenizer (truncation already set in constructor).
        // The `true` flag adds special tokens ([CLS]/[SEP]).
        let encoding = self
            .tokenizer
            .encode(text, true)
            .map_err(|e| CodememError::Embedding(e.to_string()))?;

        let input_ids: Vec<u32> = encoding.get_ids().to_vec();
        let attention_mask: Vec<u32> = encoding.get_attention_mask().to_vec();

        // Build candle tensors with shape [1, seq_len]
        let input_ids_tensor = Tensor::new(&input_ids[..], &self.device)
            .and_then(|t| t.unsqueeze(0))
            .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;

        // Single-segment input: token type ids are all zero.
        let token_type_ids = input_ids_tensor
            .zeros_like()
            .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;

        let attention_mask_tensor = Tensor::new(&attention_mask[..], &self.device)
            .and_then(|t| t.unsqueeze(0))
            .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;

        // Forward pass -> [1, seq_len, hidden_size]
        // Lock is released immediately after the forward via the explicit drop.
        let model = self
            .model
            .lock()
            .map_err(|e| CodememError::LockPoisoned(format!("embedding model: {e}")))?;
        let hidden_states = model
            .forward(
                &input_ids_tensor,
                &token_type_ids,
                Some(&attention_mask_tensor),
            )
            .map_err(|e| CodememError::Embedding(format!("Model forward error: {e}")))?;
        drop(model);

        // Drop intermediate GPU tensors before pooling
        drop(input_ids_tensor);
        drop(token_type_ids);

        // Mean pooling weighted by attention mask
        // attention_mask: [1, seq_len] -> [1, seq_len, 1] for broadcasting
        let mask = attention_mask_tensor
            .to_dtype(DType::F32)
            .and_then(|t| t.unsqueeze(2))
            .map_err(|e| CodememError::Embedding(format!("Mask error: {e}")))?;

        // Number of real (non-padding) tokens; divisor for the mean.
        let sum_mask = mask
            .sum(1)
            .map_err(|e| CodememError::Embedding(format!("Sum error: {e}")))?;

        let pooled = hidden_states
            .broadcast_mul(&mask)
            .and_then(|t| t.sum(1))
            .and_then(|t| t.broadcast_div(&sum_mask))
            .map_err(|e| CodememError::Embedding(format!("Pooling error: {e}")))?;

        // L2 normalize
        let normalized = pooled
            .sqr()
            .and_then(|t| t.sum_keepdim(1))
            .and_then(|t| t.sqrt())
            .and_then(|norm| pooled.broadcast_div(&norm))
            .map_err(|e| CodememError::Embedding(format!("Normalize error: {e}")))?;

        // Extract as Vec<f32> — shape is [1, hidden_size], squeeze to [hidden_size]
        let embedding: Vec<f32> = normalized
            .squeeze(0)
            .and_then(|t| t.to_vec1())
            .map_err(|e| CodememError::Embedding(format!("Extract error: {e}")))?;

        Ok(embedding)
    }
308
309    /// Embed a batch of texts with cache-aware batching.
310    ///
311    /// Checks the LRU cache first and only runs the model on uncached texts,
312    /// using a true batched forward pass (single GPU/CPU kernel launch per chunk).
313    pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
314        if texts.is_empty() {
315            return Ok(vec![]);
316        }
317
318        // Partition into cached and uncached
319        let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
320        let mut uncached_indices = Vec::new();
321        let mut uncached_texts = Vec::new();
322
323        {
324            let mut cache = self
325                .cache
326                .lock()
327                .map_err(|e| CodememError::LockPoisoned(format!("embedding cache: {e}")))?;
328            for (i, text) in texts.iter().enumerate() {
329                if let Some(cached) = cache.get(*text) {
330                    results[i] = Some(cached.clone());
331                } else {
332                    uncached_indices.push(i);
333                    uncached_texts.push(*text);
334                }
335            }
336        }
337
338        if !uncached_texts.is_empty() {
339            let new_embeddings = self.embed_batch_uncached(&uncached_texts)?;
340
341            let mut cache = self
342                .cache
343                .lock()
344                .map_err(|e| CodememError::LockPoisoned(format!("embedding cache: {e}")))?;
345            for (idx, embedding) in uncached_indices.into_iter().zip(new_embeddings) {
346                cache.put(texts[idx].to_string(), embedding.clone());
347                results[idx] = Some(embedding);
348            }
349        }
350
351        // Verify all texts got embeddings — flatten() would silently drop Nones
352        let expected = texts.len();
353        let output: Vec<Vec<f32>> = results
354            .into_iter()
355            .enumerate()
356            .map(|(i, opt)| {
357                opt.ok_or_else(|| {
358                    CodememError::Embedding(format!(
359                        "Missing embedding for text at index {i} (batch size {expected})"
360                    ))
361                })
362            })
363            .collect::<Result<Vec<_>, _>>()?;
364        Ok(output)
365    }
366
367    /// Embed a batch of texts without caching, using a true batched forward pass.
368    ///
369    /// Tokenizes all texts, pads to the longest sequence in each chunk, runs a
370    /// single forward pass per chunk of up to `BATCH_SIZE` texts, then performs
371    /// mean pooling and L2 normalization on the batched output.
372    fn embed_batch_uncached(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
373        if texts.is_empty() {
374            return Ok(vec![]);
375        }
376
377        let mut all_embeddings = Vec::with_capacity(texts.len());
378
379        for chunk in texts.chunks(BATCH_SIZE) {
380            // Clone tokenizer only for batch path — needs per-chunk padding config.
381            // Truncation is already configured on self.tokenizer.
382            let mut tokenizer = self.tokenizer.clone();
383            tokenizer.with_padding(Some(PaddingParams {
384                strategy: PaddingStrategy::BatchLongest,
385                ..Default::default()
386            }));
387
388            let encodings = tokenizer
389                .encode_batch(chunk.to_vec(), true)
390                .map_err(|e| CodememError::Embedding(format!("Batch encode error: {e}")))?;
391
392            let batch_len = encodings.len();
393            let seq_len = encodings[0].get_ids().len();
394
395            // Flatten token IDs and attention masks into contiguous arrays
396            let all_ids: Vec<u32> = encodings
397                .iter()
398                .flat_map(|e| e.get_ids())
399                .copied()
400                .collect();
401            let all_masks: Vec<u32> = encodings
402                .iter()
403                .flat_map(|e| e.get_attention_mask())
404                .copied()
405                .collect();
406
407            // Build tensors with shape [batch_size, seq_len]
408            let input_ids = Tensor::new(all_ids.as_slice(), &self.device)
409                .and_then(|t| t.reshape((batch_len, seq_len)))
410                .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;
411
412            let token_type_ids = input_ids
413                .zeros_like()
414                .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;
415
416            let attention_mask = Tensor::new(all_masks.as_slice(), &self.device)
417                .and_then(|t| t.reshape((batch_len, seq_len)))
418                .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;
419
420            // Single forward pass -> [batch_size, seq_len, hidden_size]
421            let model = self
422                .model
423                .lock()
424                .map_err(|e| CodememError::LockPoisoned(format!("embedding model: {e}")))?;
425            let hidden_states = model
426                .forward(&input_ids, &token_type_ids, Some(&attention_mask))
427                .map_err(|e| CodememError::Embedding(format!("Forward error: {e}")))?;
428            drop(model);
429
430            // Drop intermediate GPU tensors before pooling
431            drop(input_ids);
432            drop(token_type_ids);
433
434            // Mean pooling: mask [batch, seq] -> [batch, seq, 1] for broadcast
435            let mask = attention_mask
436                .to_dtype(DType::F32)
437                .and_then(|t| t.unsqueeze(2))
438                .map_err(|e| CodememError::Embedding(format!("Mask error: {e}")))?;
439
440            let sum_mask = mask
441                .sum(1)
442                .map_err(|e| CodememError::Embedding(format!("Sum error: {e}")))?;
443
444            let pooled = hidden_states
445                .broadcast_mul(&mask)
446                .and_then(|t| t.sum(1))
447                .and_then(|t| t.broadcast_div(&sum_mask))
448                .map_err(|e| CodememError::Embedding(format!("Pooling error: {e}")))?;
449
450            // L2 normalize: [batch, hidden]
451            let norm = pooled
452                .sqr()
453                .and_then(|t| t.sum_keepdim(1))
454                .and_then(|t| t.sqrt())
455                .map_err(|e| CodememError::Embedding(format!("Norm error: {e}")))?;
456
457            let normalized = pooled
458                .broadcast_div(&norm)
459                .map_err(|e| CodememError::Embedding(format!("Normalize error: {e}")))?;
460
461            // Single GPU→CPU blit: flatten all rows, then slice on CPU.
462            // to_vec1() implicitly syncs the GPU pipeline (data must be ready to read).
463            let flat: Vec<f32> = normalized
464                .flatten_all()
465                .and_then(|t| t.to_vec1())
466                .map_err(|e| CodememError::Embedding(format!("Extract error: {e}")))?;
467            for i in 0..batch_len {
468                let start = i * DIMENSIONS;
469                all_embeddings.push(flat[start..start + DIMENSIONS].to_vec());
470            }
471        }
472
473        Ok(all_embeddings)
474    }
475
476    /// Get cache statistics: (current_size, capacity).
477    pub fn cache_stats(&self) -> (usize, usize) {
478        match self.cache.lock() {
479            Ok(cache) => (cache.len(), CACHE_CAPACITY),
480            Err(_) => (0, CACHE_CAPACITY),
481        }
482    }
483}
484
/// Candle-backed implementation of the `EmbeddingProvider` trait.
/// Each method delegates to the inherent `EmbeddingService` method of the same
/// name (inherent methods take precedence in resolution, so there is no recursion).
impl EmbeddingProvider for EmbeddingService {
    fn dimensions(&self) -> usize {
        DIMENSIONS
    }

    fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
        self.embed(text)
    }

    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
        self.embed_batch(texts)
    }

    fn name(&self) -> &str {
        "candle"
    }

    fn cache_stats(&self) -> (usize, usize) {
        self.cache_stats()
    }
}
506
507// ── Cached Provider Wrapper ───────────────────────────────────────────────
508
/// Wraps any `EmbeddingProvider` with an LRU cache.
pub struct CachedProvider {
    /// The wrapped provider that computes embeddings on cache miss.
    inner: Box<dyn EmbeddingProvider>,
    /// LRU cache of text → embedding; capacity fixed at construction.
    cache: Mutex<LruCache<String, Vec<f32>>>,
}
514
515impl CachedProvider {
516    pub fn new(inner: Box<dyn EmbeddingProvider>, capacity: usize) -> Self {
517        // SAFETY: 1 is non-zero, so the inner expect is infallible
518        let cap =
519            NonZeroUsize::new(capacity).unwrap_or(NonZeroUsize::new(1).expect("1 is non-zero"));
520        Self {
521            inner,
522            cache: Mutex::new(LruCache::new(cap)),
523        }
524    }
525}
526
527impl EmbeddingProvider for CachedProvider {
528    fn dimensions(&self) -> usize {
529        self.inner.dimensions()
530    }
531
532    fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
533        {
534            let mut cache = self
535                .cache
536                .lock()
537                .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
538            if let Some(cached) = cache.get(text) {
539                return Ok(cached.clone());
540            }
541        }
542        let embedding = self.inner.embed(text)?;
543        {
544            let mut cache = self
545                .cache
546                .lock()
547                .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
548            cache.put(text.to_string(), embedding.clone());
549        }
550        Ok(embedding)
551    }
552
553    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
554        // Check cache, only forward uncached texts
555        let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
556        let mut uncached = Vec::new();
557        let mut uncached_idx = Vec::new();
558
559        {
560            let mut cache = self
561                .cache
562                .lock()
563                .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
564            for (i, text) in texts.iter().enumerate() {
565                if let Some(cached) = cache.get(*text) {
566                    results[i] = Some(cached.clone());
567                } else {
568                    uncached_idx.push(i);
569                    uncached.push(*text);
570                }
571            }
572        }
573
574        if !uncached.is_empty() {
575            let new_embeddings = self.inner.embed_batch(&uncached)?;
576            let mut cache = self
577                .cache
578                .lock()
579                .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
580            for (idx, embedding) in uncached_idx.into_iter().zip(new_embeddings) {
581                cache.put(texts[idx].to_string(), embedding.clone());
582                results[idx] = Some(embedding);
583            }
584        }
585
586        // Verify all texts got embeddings — flatten() would silently drop Nones
587        let expected = texts.len();
588        let output: Vec<Vec<f32>> = results
589            .into_iter()
590            .enumerate()
591            .map(|(i, opt)| {
592                opt.ok_or_else(|| {
593                    CodememError::Embedding(format!(
594                        "Missing embedding for text at index {i} (batch size {expected})"
595                    ))
596                })
597            })
598            .collect::<Result<Vec<_>, _>>()?;
599        Ok(output)
600    }
601
602    fn name(&self) -> &str {
603        self.inner.name()
604    }
605
606    fn cache_stats(&self) -> (usize, usize) {
607        match self.cache.lock() {
608            Ok(cache) => (cache.len(), cache.cap().into()),
609            Err(_) => (0, 0),
610        }
611    }
612}
613
614// ── Factory ───────────────────────────────────────────────────────────────
615
616/// Create an embedding provider from environment variables.
617///
618/// | Variable | Values | Default |
619/// |----------|--------|---------|
620/// | `CODEMEM_EMBED_PROVIDER` | `candle`, `ollama`, `openai` | `candle` |
621/// | `CODEMEM_EMBED_MODEL` | model name | provider default |
622/// | `CODEMEM_EMBED_URL` | base URL | provider default |
623/// | `CODEMEM_EMBED_API_KEY` | API key | also reads `OPENAI_API_KEY` |
624/// | `CODEMEM_EMBED_DIMENSIONS` | integer | `768` |
625pub fn from_env() -> Result<Box<dyn EmbeddingProvider>, CodememError> {
626    let provider = std::env::var("CODEMEM_EMBED_PROVIDER")
627        .unwrap_or_else(|_| "candle".to_string())
628        .to_lowercase();
629    let dimensions: usize = std::env::var("CODEMEM_EMBED_DIMENSIONS")
630        .ok()
631        .and_then(|s| s.parse().ok())
632        .unwrap_or(DIMENSIONS);
633
634    match provider.as_str() {
635        "ollama" => {
636            let base_url = std::env::var("CODEMEM_EMBED_URL")
637                .unwrap_or_else(|_| ollama::DEFAULT_BASE_URL.to_string());
638            let model = std::env::var("CODEMEM_EMBED_MODEL")
639                .unwrap_or_else(|_| ollama::DEFAULT_MODEL.to_string());
640            let inner = Box::new(ollama::OllamaProvider::new(&base_url, &model, dimensions));
641            Ok(Box::new(CachedProvider::new(inner, CACHE_CAPACITY)))
642        }
643        "openai" => {
644            let api_key = std::env::var("CODEMEM_EMBED_API_KEY")
645                .or_else(|_| std::env::var("OPENAI_API_KEY"))
646                .map_err(|_| {
647                    CodememError::Embedding(
648                        "CODEMEM_EMBED_API_KEY or OPENAI_API_KEY required for OpenAI embeddings"
649                            .into(),
650                    )
651                })?;
652            let model = std::env::var("CODEMEM_EMBED_MODEL")
653                .unwrap_or_else(|_| openai::DEFAULT_MODEL.to_string());
654            let base_url = std::env::var("CODEMEM_EMBED_URL").ok();
655            let inner = Box::new(openai::OpenAIProvider::new(
656                &api_key,
657                &model,
658                dimensions,
659                base_url.as_deref(),
660            ));
661            Ok(Box::new(CachedProvider::new(inner, CACHE_CAPACITY)))
662        }
663        "candle" | "" => {
664            let model_dir = EmbeddingService::default_model_dir();
665            let service = EmbeddingService::new(&model_dir)?;
666            Ok(Box::new(service))
667        }
668        other => Err(CodememError::Embedding(format!(
669            "Unknown embedding provider: '{}'. Use 'candle', 'ollama', or 'openai'.",
670            other
671        ))),
672    }
673}
674
675#[cfg(test)]
676#[path = "tests/lib_tests.rs"]
677mod tests;