// codemem_embeddings/lib.rs
1//! codemem-embeddings: Pluggable embedding providers for Codemem.
2//!
3//! Supports multiple backends:
4//! - **Candle** (default): Local BAAI/bge-base-en-v1.5 via pure Rust ML
5//! - **Ollama**: Local Ollama server with any embedding model
6//! - **OpenAI**: OpenAI API or any compatible endpoint (Together, Azure, etc.)
7
8pub mod ollama;
9pub mod openai;
10
11use candle_core::{DType, Device, Tensor};
12use candle_nn::VarBuilder;
13use candle_transformers::models::bert::{BertModel, Config as BertConfig};
14use codemem_core::CodememError;
15use lru::LruCache;
16use std::num::NonZeroUsize;
17use std::path::{Path, PathBuf};
18use std::sync::Mutex;
19use tokenizers::{PaddingParams, PaddingStrategy, TruncationParams};
20
/// Default model name (also the on-disk directory name under `~/.codemem/models/`).
pub const MODEL_NAME: &str = "bge-base-en-v1.5";

/// HuggingFace model repo ID the weights are downloaded from.
const HF_MODEL_REPO: &str = "BAAI/bge-base-en-v1.5";

/// Default embedding dimensions (hidden size of bge-base-en-v1.5).
pub const DIMENSIONS: usize = 768;

/// Max sequence length for bge-base-en-v1.5; longer inputs are truncated.
const MAX_SEQ_LENGTH: usize = 512;

/// Default LRU cache capacity (number of cached text -> embedding entries).
pub const CACHE_CAPACITY: usize = 10_000;
36// ── Embedding Provider Trait ────────────────────────────────────────────────
37
38/// Trait for pluggable embedding providers.
39pub trait EmbeddingProvider: Send + Sync {
40    /// Embedding vector dimensions.
41    fn dimensions(&self) -> usize;
42
43    /// Embed a single text string.
44    fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError>;
45
46    /// Embed a batch of texts (default: sequential).
47    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
48        texts.iter().map(|t| self.embed(t)).collect()
49    }
50
51    /// Provider name for display.
52    fn name(&self) -> &str;
53
54    /// Cache statistics: (current_size, capacity). Returns (0, 0) if no cache.
55    fn cache_stats(&self) -> (usize, usize) {
56        (0, 0)
57    }
58}
59
60// ── Candle Embedding Service ────────────────────────────────────────────────
61
/// Maximum batch size for batched embedding forward passes.
/// NOTE(review): 32 presumably balances throughput against peak memory per
/// forward pass — not benchmarked here; confirm before tuning.
const BATCH_SIZE: usize = 32;
64
/// Select the best available compute device.
///
/// Tries Metal (macOS GPU) first, then CUDA (NVIDIA GPU), then falls back to CPU.
/// GPU backends are only available when the corresponding feature flag is enabled.
fn select_device() -> Device {
    // Each GPU path is compiled in only when its feature flag is enabled; a
    // failed device creation is non-fatal and falls through to the next option.
    #[cfg(feature = "metal")]
    {
        if let Ok(device) = Device::new_metal(0) {
            tracing::info!("Using Metal GPU for embeddings");
            return device;
        }
        tracing::warn!("Metal feature enabled but device creation failed, falling back");
    }
    #[cfg(feature = "cuda")]
    {
        if let Ok(device) = Device::new_cuda(0) {
            tracing::info!("Using CUDA GPU for embeddings");
            return device;
        }
        tracing::warn!("CUDA feature enabled but device creation failed, falling back");
    }
    // CPU always succeeds; this is the default when no GPU feature is enabled
    // or when GPU initialization failed above.
    tracing::info!("Using CPU for embeddings");
    Device::Cpu
}
89
/// Embedding service with Candle inference and LRU caching.
pub struct EmbeddingService {
    // BERT model guarded by a Mutex so the service can be shared across
    // threads. NOTE(review): assumes forward passes must be serialized —
    // confirm whether candle's BertModel::forward needs exclusive access.
    model: Mutex<BertModel>,
    // Base tokenizer; cloned per call before truncation/padding is configured
    // (those setters mutate the tokenizer).
    tokenizer: tokenizers::Tokenizer,
    // Compute device chosen at construction (Metal / CUDA / CPU).
    device: Device,
    // text -> L2-normalized embedding cache, bounded at CACHE_CAPACITY.
    cache: Mutex<LruCache<String, Vec<f32>>>,
}
97
98impl EmbeddingService {
99    /// Create a new embedding service, loading model from the given directory.
100    /// Expects `model.safetensors`, `config.json`, and `tokenizer.json` in the directory.
101    pub fn new(model_dir: &Path) -> Result<Self, CodememError> {
102        let model_path = model_dir.join("model.safetensors");
103        let config_path = model_dir.join("config.json");
104        let tokenizer_path = model_dir.join("tokenizer.json");
105
106        if !model_path.exists() {
107            return Err(CodememError::Embedding(format!(
108                "Model not found at {}. Run `codemem init` to download it.",
109                model_path.display()
110            )));
111        }
112
113        let device = select_device();
114
115        // Load BERT config
116        let config_str = std::fs::read_to_string(&config_path)
117            .map_err(|e| CodememError::Embedding(format!("Failed to read config: {e}")))?;
118        let config: BertConfig = serde_json::from_str(&config_str)
119            .map_err(|e| CodememError::Embedding(format!("Failed to parse config: {e}")))?;
120
121        // Load model weights from safetensors via memory-mapped IO
122        let vb = unsafe {
123            VarBuilder::from_mmaped_safetensors(&[&model_path], DType::F32, &device)
124                .map_err(|e| CodememError::Embedding(format!("Failed to load weights: {e}")))?
125        };
126
127        // Try with "bert." prefix first (standard HF BERT models), then without
128        let model = BertModel::load(vb.pp("bert"), &config)
129            .or_else(|_| {
130                let vb2 = unsafe {
131                    VarBuilder::from_mmaped_safetensors(&[&model_path], DType::F32, &device)
132                        .map_err(|e| {
133                            candle_core::Error::Msg(format!("Failed to load weights: {e}"))
134                        })
135                }?;
136                BertModel::load(vb2, &config)
137            })
138            .map_err(|e| CodememError::Embedding(format!("Failed to load BERT model: {e}")))?;
139
140        let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
141            .map_err(|e| CodememError::Embedding(e.to_string()))?;
142
143        let cache = Mutex::new(LruCache::new(
144            NonZeroUsize::new(CACHE_CAPACITY).expect("CACHE_CAPACITY is non-zero"),
145        ));
146
147        Ok(Self {
148            model: Mutex::new(model),
149            tokenizer,
150            device,
151            cache,
152        })
153    }
154
155    /// Get the default model directory path (~/.codemem/models/{MODEL_NAME}).
156    pub fn default_model_dir() -> PathBuf {
157        dirs::home_dir()
158            .unwrap_or_else(|| PathBuf::from("."))
159            .join(".codemem")
160            .join("models")
161            .join(MODEL_NAME)
162    }
163
164    /// Download the model from HuggingFace Hub to the given directory.
165    /// Returns the directory path. No-ops if model already exists.
166    pub fn download_model(dest_dir: &Path) -> Result<PathBuf, CodememError> {
167        let model_dest = dest_dir.join("model.safetensors");
168        let config_dest = dest_dir.join("config.json");
169        let tokenizer_dest = dest_dir.join("tokenizer.json");
170
171        if model_dest.exists() && config_dest.exists() && tokenizer_dest.exists() {
172            tracing::info!("Model already downloaded at {}", dest_dir.display());
173            return Ok(dest_dir.to_path_buf());
174        }
175
176        std::fs::create_dir_all(dest_dir)
177            .map_err(|e| CodememError::Embedding(format!("Failed to create dir: {e}")))?;
178
179        tracing::info!("Downloading {} from HuggingFace...", HF_MODEL_REPO);
180
181        let api = hf_hub::api::sync::Api::new()
182            .map_err(|e| CodememError::Embedding(format!("HuggingFace API error: {e}")))?;
183        let repo = api.model(HF_MODEL_REPO.to_string());
184
185        let cached_model = repo
186            .get("model.safetensors")
187            .map_err(|e| CodememError::Embedding(format!("Failed to download model: {e}")))?;
188
189        let cached_config = repo
190            .get("config.json")
191            .map_err(|e| CodememError::Embedding(format!("Failed to download config: {e}")))?;
192
193        let cached_tokenizer = repo
194            .get("tokenizer.json")
195            .map_err(|e| CodememError::Embedding(format!("Failed to download tokenizer: {e}")))?;
196
197        std::fs::copy(&cached_model, &model_dest)
198            .map_err(|e| CodememError::Embedding(format!("Failed to copy model: {e}")))?;
199        std::fs::copy(&cached_config, &config_dest)
200            .map_err(|e| CodememError::Embedding(format!("Failed to copy config: {e}")))?;
201        std::fs::copy(&cached_tokenizer, &tokenizer_dest)
202            .map_err(|e| CodememError::Embedding(format!("Failed to copy tokenizer: {e}")))?;
203
204        tracing::info!("Model downloaded to {}", dest_dir.display());
205        Ok(dest_dir.to_path_buf())
206    }
207
208    /// Embed a single text string. Returns a 768-dim L2-normalized vector.
209    /// Uses LRU cache for repeated queries.
210    pub fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
211        // Check cache
212        {
213            let mut cache = self
214                .cache
215                .lock()
216                .map_err(|e| CodememError::LockPoisoned(format!("embedding cache: {e}")))?;
217            if let Some(cached) = cache.get(text) {
218                return Ok(cached.clone());
219            }
220        }
221
222        let embedding = self.embed_uncached(text)?;
223
224        // Store in cache
225        {
226            let mut cache = self
227                .cache
228                .lock()
229                .map_err(|e| CodememError::LockPoisoned(format!("embedding cache: {e}")))?;
230            cache.put(text.to_string(), embedding.clone());
231        }
232
233        Ok(embedding)
234    }
235
236    /// Embed without caching. Uses mean pooling with attention mask.
237    fn embed_uncached(&self, text: &str) -> Result<Vec<f32>, CodememError> {
238        // Tokenize with truncation
239        let mut tokenizer = self.tokenizer.clone();
240
241        tokenizer
242            .with_truncation(Some(tokenizers::TruncationParams {
243                max_length: MAX_SEQ_LENGTH,
244                ..Default::default()
245            }))
246            .map_err(|e| CodememError::Embedding(format!("Truncation error: {e}")))?;
247
248        let encoding = tokenizer
249            .encode(text, true)
250            .map_err(|e| CodememError::Embedding(e.to_string()))?;
251
252        let input_ids: Vec<u32> = encoding.get_ids().to_vec();
253        let attention_mask: Vec<u32> = encoding.get_attention_mask().to_vec();
254
255        // Build candle tensors with shape [1, seq_len]
256        let input_ids_tensor = Tensor::new(&input_ids[..], &self.device)
257            .and_then(|t| t.unsqueeze(0))
258            .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;
259
260        let token_type_ids = input_ids_tensor
261            .zeros_like()
262            .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;
263
264        let attention_mask_tensor = Tensor::new(&attention_mask[..], &self.device)
265            .and_then(|t| t.unsqueeze(0))
266            .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;
267
268        // Forward pass -> [1, seq_len, hidden_size]
269        let model = self
270            .model
271            .lock()
272            .map_err(|e| CodememError::LockPoisoned(format!("embedding model: {e}")))?;
273        let hidden_states = model
274            .forward(
275                &input_ids_tensor,
276                &token_type_ids,
277                Some(&attention_mask_tensor),
278            )
279            .map_err(|e| CodememError::Embedding(format!("Model forward error: {e}")))?;
280        drop(model);
281
282        // Mean pooling weighted by attention mask
283        // attention_mask: [1, seq_len] -> [1, seq_len, 1] for broadcasting
284        let mask = attention_mask_tensor
285            .to_dtype(DType::F32)
286            .and_then(|t| t.unsqueeze(2))
287            .map_err(|e| CodememError::Embedding(format!("Mask error: {e}")))?;
288
289        let sum_mask = mask
290            .sum(1)
291            .map_err(|e| CodememError::Embedding(format!("Sum error: {e}")))?;
292
293        let pooled = hidden_states
294            .broadcast_mul(&mask)
295            .and_then(|t| t.sum(1))
296            .and_then(|t| t.broadcast_div(&sum_mask))
297            .map_err(|e| CodememError::Embedding(format!("Pooling error: {e}")))?;
298
299        // L2 normalize
300        let normalized = pooled
301            .sqr()
302            .and_then(|t| t.sum_keepdim(1))
303            .and_then(|t| t.sqrt())
304            .and_then(|norm| pooled.broadcast_div(&norm))
305            .map_err(|e| CodememError::Embedding(format!("Normalize error: {e}")))?;
306
307        // Extract as Vec<f32> — shape is [1, hidden_size], squeeze to [hidden_size]
308        let embedding: Vec<f32> = normalized
309            .squeeze(0)
310            .and_then(|t| t.to_vec1())
311            .map_err(|e| CodememError::Embedding(format!("Extract error: {e}")))?;
312
313        Ok(embedding)
314    }
315
316    /// Embed a batch of texts with cache-aware batching.
317    ///
318    /// Checks the LRU cache first and only runs the model on uncached texts,
319    /// using a true batched forward pass (single GPU/CPU kernel launch per chunk).
320    pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
321        if texts.is_empty() {
322            return Ok(vec![]);
323        }
324
325        // Partition into cached and uncached
326        let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
327        let mut uncached_indices = Vec::new();
328        let mut uncached_texts = Vec::new();
329
330        {
331            let mut cache = self
332                .cache
333                .lock()
334                .map_err(|e| CodememError::LockPoisoned(format!("embedding cache: {e}")))?;
335            for (i, text) in texts.iter().enumerate() {
336                if let Some(cached) = cache.get(*text) {
337                    results[i] = Some(cached.clone());
338                } else {
339                    uncached_indices.push(i);
340                    uncached_texts.push(*text);
341                }
342            }
343        }
344
345        if !uncached_texts.is_empty() {
346            let new_embeddings = self.embed_batch_uncached(&uncached_texts)?;
347
348            let mut cache = self
349                .cache
350                .lock()
351                .map_err(|e| CodememError::LockPoisoned(format!("embedding cache: {e}")))?;
352            for (idx, embedding) in uncached_indices.into_iter().zip(new_embeddings) {
353                cache.put(texts[idx].to_string(), embedding.clone());
354                results[idx] = Some(embedding);
355            }
356        }
357
358        Ok(results.into_iter().flatten().collect())
359    }
360
361    /// Embed a batch of texts without caching, using a true batched forward pass.
362    ///
363    /// Tokenizes all texts, pads to the longest sequence in each chunk, runs a
364    /// single forward pass per chunk of up to `BATCH_SIZE` texts, then performs
365    /// mean pooling and L2 normalization on the batched output.
366    fn embed_batch_uncached(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
367        if texts.is_empty() {
368            return Ok(vec![]);
369        }
370
371        let mut all_embeddings = Vec::with_capacity(texts.len());
372
373        for chunk in texts.chunks(BATCH_SIZE) {
374            let mut tokenizer = self.tokenizer.clone();
375
376            tokenizer
377                .with_truncation(Some(TruncationParams {
378                    max_length: MAX_SEQ_LENGTH,
379                    ..Default::default()
380                }))
381                .map_err(|e| CodememError::Embedding(format!("Truncation error: {e}")))?;
382
383            // Pad all sequences in this chunk to the length of the longest
384            tokenizer.with_padding(Some(PaddingParams {
385                strategy: PaddingStrategy::BatchLongest,
386                ..Default::default()
387            }));
388
389            let encodings = tokenizer
390                .encode_batch(chunk.to_vec(), true)
391                .map_err(|e| CodememError::Embedding(format!("Batch encode error: {e}")))?;
392
393            let batch_len = encodings.len();
394            let seq_len = encodings[0].get_ids().len();
395
396            // Flatten token IDs and attention masks into contiguous arrays
397            let all_ids: Vec<u32> = encodings
398                .iter()
399                .flat_map(|e| e.get_ids())
400                .copied()
401                .collect();
402            let all_masks: Vec<u32> = encodings
403                .iter()
404                .flat_map(|e| e.get_attention_mask())
405                .copied()
406                .collect();
407
408            // Build tensors with shape [batch_size, seq_len]
409            let input_ids = Tensor::new(all_ids.as_slice(), &self.device)
410                .and_then(|t| t.reshape((batch_len, seq_len)))
411                .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;
412
413            let token_type_ids = input_ids
414                .zeros_like()
415                .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;
416
417            let attention_mask = Tensor::new(all_masks.as_slice(), &self.device)
418                .and_then(|t| t.reshape((batch_len, seq_len)))
419                .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;
420
421            // Single forward pass -> [batch_size, seq_len, hidden_size]
422            let model = self
423                .model
424                .lock()
425                .map_err(|e| CodememError::LockPoisoned(format!("embedding model: {e}")))?;
426            let hidden_states = model
427                .forward(&input_ids, &token_type_ids, Some(&attention_mask))
428                .map_err(|e| CodememError::Embedding(format!("Forward error: {e}")))?;
429            drop(model);
430
431            // Mean pooling: mask [batch, seq] -> [batch, seq, 1] for broadcast
432            let mask = attention_mask
433                .to_dtype(DType::F32)
434                .and_then(|t| t.unsqueeze(2))
435                .map_err(|e| CodememError::Embedding(format!("Mask error: {e}")))?;
436
437            let sum_mask = mask
438                .sum(1)
439                .map_err(|e| CodememError::Embedding(format!("Sum error: {e}")))?;
440
441            let pooled = hidden_states
442                .broadcast_mul(&mask)
443                .and_then(|t| t.sum(1))
444                .and_then(|t| t.broadcast_div(&sum_mask))
445                .map_err(|e| CodememError::Embedding(format!("Pooling error: {e}")))?;
446
447            // L2 normalize: [batch, hidden]
448            let norm = pooled
449                .sqr()
450                .and_then(|t| t.sum_keepdim(1))
451                .and_then(|t| t.sqrt())
452                .map_err(|e| CodememError::Embedding(format!("Norm error: {e}")))?;
453
454            let normalized = pooled
455                .broadcast_div(&norm)
456                .map_err(|e| CodememError::Embedding(format!("Normalize error: {e}")))?;
457
458            // Extract each row as Vec<f32>
459            for i in 0..batch_len {
460                let row: Vec<f32> = normalized
461                    .get(i)
462                    .and_then(|t| t.to_vec1())
463                    .map_err(|e| CodememError::Embedding(format!("Extract error: {e}")))?;
464                all_embeddings.push(row);
465            }
466        }
467
468        Ok(all_embeddings)
469    }
470
471    /// Get cache statistics: (current_size, capacity).
472    pub fn cache_stats(&self) -> (usize, usize) {
473        match self.cache.lock() {
474            Ok(cache) => (cache.len(), CACHE_CAPACITY),
475            Err(_) => (0, CACHE_CAPACITY),
476        }
477    }
478}
479
// Trait adapter: delegates to the inherent methods so the Candle service can
// be used behind `Box<dyn EmbeddingProvider>`. Inherent methods take priority
// over trait methods in resolution, so these calls do not recurse.
impl EmbeddingProvider for EmbeddingService {
    // Fixed output size of bge-base-en-v1.5.
    fn dimensions(&self) -> usize {
        DIMENSIONS
    }

    fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
        self.embed(text)
    }

    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
        self.embed_batch(texts)
    }

    fn name(&self) -> &str {
        "candle"
    }

    fn cache_stats(&self) -> (usize, usize) {
        self.cache_stats()
    }
}
501
502// ── Cached Provider Wrapper ───────────────────────────────────────────────
503
/// Wraps any `EmbeddingProvider` with an LRU cache.
pub struct CachedProvider {
    // The wrapped provider that actually computes embeddings on cache miss.
    inner: Box<dyn EmbeddingProvider>,
    // text -> embedding cache; capacity is fixed at construction.
    cache: Mutex<LruCache<String, Vec<f32>>>,
}
509
510impl CachedProvider {
511    pub fn new(inner: Box<dyn EmbeddingProvider>, capacity: usize) -> Self {
512        // SAFETY: 1 is non-zero, so the inner expect is infallible
513        let cap =
514            NonZeroUsize::new(capacity).unwrap_or(NonZeroUsize::new(1).expect("1 is non-zero"));
515        Self {
516            inner,
517            cache: Mutex::new(LruCache::new(cap)),
518        }
519    }
520}
521
522impl EmbeddingProvider for CachedProvider {
523    fn dimensions(&self) -> usize {
524        self.inner.dimensions()
525    }
526
527    fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
528        {
529            let mut cache = self
530                .cache
531                .lock()
532                .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
533            if let Some(cached) = cache.get(text) {
534                return Ok(cached.clone());
535            }
536        }
537        let embedding = self.inner.embed(text)?;
538        {
539            let mut cache = self
540                .cache
541                .lock()
542                .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
543            cache.put(text.to_string(), embedding.clone());
544        }
545        Ok(embedding)
546    }
547
548    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
549        // Check cache, only forward uncached texts
550        let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
551        let mut uncached = Vec::new();
552        let mut uncached_idx = Vec::new();
553
554        {
555            let mut cache = self
556                .cache
557                .lock()
558                .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
559            for (i, text) in texts.iter().enumerate() {
560                if let Some(cached) = cache.get(*text) {
561                    results[i] = Some(cached.clone());
562                } else {
563                    uncached_idx.push(i);
564                    uncached.push(*text);
565                }
566            }
567        }
568
569        if !uncached.is_empty() {
570            let new_embeddings = self.inner.embed_batch(&uncached)?;
571            let mut cache = self
572                .cache
573                .lock()
574                .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
575            for (idx, embedding) in uncached_idx.into_iter().zip(new_embeddings) {
576                cache.put(texts[idx].to_string(), embedding.clone());
577                results[idx] = Some(embedding);
578            }
579        }
580
581        Ok(results.into_iter().flatten().collect())
582    }
583
584    fn name(&self) -> &str {
585        self.inner.name()
586    }
587
588    fn cache_stats(&self) -> (usize, usize) {
589        match self.cache.lock() {
590            Ok(cache) => (cache.len(), cache.cap().into()),
591            Err(_) => (0, 0),
592        }
593    }
594}
595
596// ── Factory ───────────────────────────────────────────────────────────────
597
598/// Create an embedding provider from environment variables.
599///
600/// | Variable | Values | Default |
601/// |----------|--------|---------|
602/// | `CODEMEM_EMBED_PROVIDER` | `candle`, `ollama`, `openai` | `candle` |
603/// | `CODEMEM_EMBED_MODEL` | model name | provider default |
604/// | `CODEMEM_EMBED_URL` | base URL | provider default |
605/// | `CODEMEM_EMBED_API_KEY` | API key | also reads `OPENAI_API_KEY` |
606/// | `CODEMEM_EMBED_DIMENSIONS` | integer | `768` |
607pub fn from_env() -> Result<Box<dyn EmbeddingProvider>, CodememError> {
608    let provider = std::env::var("CODEMEM_EMBED_PROVIDER")
609        .unwrap_or_else(|_| "candle".to_string())
610        .to_lowercase();
611    let dimensions: usize = std::env::var("CODEMEM_EMBED_DIMENSIONS")
612        .ok()
613        .and_then(|s| s.parse().ok())
614        .unwrap_or(DIMENSIONS);
615
616    match provider.as_str() {
617        "ollama" => {
618            let base_url = std::env::var("CODEMEM_EMBED_URL")
619                .unwrap_or_else(|_| ollama::DEFAULT_BASE_URL.to_string());
620            let model = std::env::var("CODEMEM_EMBED_MODEL")
621                .unwrap_or_else(|_| ollama::DEFAULT_MODEL.to_string());
622            let inner = Box::new(ollama::OllamaProvider::new(&base_url, &model, dimensions));
623            Ok(Box::new(CachedProvider::new(inner, CACHE_CAPACITY)))
624        }
625        "openai" => {
626            let api_key = std::env::var("CODEMEM_EMBED_API_KEY")
627                .or_else(|_| std::env::var("OPENAI_API_KEY"))
628                .map_err(|_| {
629                    CodememError::Embedding(
630                        "CODEMEM_EMBED_API_KEY or OPENAI_API_KEY required for OpenAI embeddings"
631                            .into(),
632                    )
633                })?;
634            let model = std::env::var("CODEMEM_EMBED_MODEL")
635                .unwrap_or_else(|_| openai::DEFAULT_MODEL.to_string());
636            let base_url = std::env::var("CODEMEM_EMBED_URL").ok();
637            let inner = Box::new(openai::OpenAIProvider::new(
638                &api_key,
639                &model,
640                dimensions,
641                base_url.as_deref(),
642            ));
643            Ok(Box::new(CachedProvider::new(inner, CACHE_CAPACITY)))
644        }
645        "candle" | "" => {
646            let model_dir = EmbeddingService::default_model_dir();
647            let service = EmbeddingService::new(&model_dir)?;
648            Ok(Box::new(service))
649        }
650        other => Err(CodememError::Embedding(format!(
651            "Unknown embedding provider: '{}'. Use 'candle', 'ollama', or 'openai'.",
652            other
653        ))),
654    }
655}
656
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// A mock embedding provider for testing CachedProvider behavior.
    struct MockProvider {
        // Dimension of the vectors the mock returns.
        dims: usize,
        // Number of times `embed` was invoked (atomic: trait requires Sync).
        call_count: AtomicUsize,
    }

    impl MockProvider {
        fn new(dims: usize) -> Self {
            Self {
                dims,
                call_count: AtomicUsize::new(0),
            }
        }
    }

    impl EmbeddingProvider for MockProvider {
        fn dimensions(&self) -> usize {
            self.dims
        }

        // Returns a constant vector and counts the call; the count lets
        // cache-hit behavior be observed indirectly.
        fn embed(&self, _text: &str) -> Result<Vec<f32>, CodememError> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            Ok(vec![0.1; self.dims])
        }

        fn name(&self) -> &str {
            "mock"
        }
    }

    #[test]
    fn cached_provider_cache_hit() {
        let mock = MockProvider::new(4);
        let provider = CachedProvider::new(Box::new(mock), 100);

        // First call: cache miss
        let v1 = provider.embed("hello").unwrap();
        assert_eq!(v1.len(), 4);

        // Second call: cache hit (inner should only be called once)
        let v2 = provider.embed("hello").unwrap();
        assert_eq!(v1, v2);

        // Access inner mock through the provider trait -- call_count should be 1
        // We can check cache_stats instead
        let (size, cap) = provider.cache_stats();
        assert_eq!(size, 1);
        assert_eq!(cap, 100);
    }

    #[test]
    fn cached_provider_cache_miss() {
        let mock = MockProvider::new(4);
        let provider = CachedProvider::new(Box::new(mock), 100);

        // Two distinct texts: both miss, both get cached.
        provider.embed("hello").unwrap();
        provider.embed("world").unwrap();

        let (size, _) = provider.cache_stats();
        assert_eq!(size, 2);
    }

    #[test]
    fn cached_provider_batch_empty() {
        let mock = MockProvider::new(4);
        let provider = CachedProvider::new(Box::new(mock), 100);

        // Empty input short-circuits without touching the inner provider.
        let result = provider.embed_batch(&[]).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn cached_provider_batch_single() {
        let mock = MockProvider::new(4);
        let provider = CachedProvider::new(Box::new(mock), 100);

        let result = provider.embed_batch(&["hello"]).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].len(), 4);

        let (size, _) = provider.cache_stats();
        assert_eq!(size, 1);
    }

    #[test]
    fn cached_provider_batch_mixed_cache() {
        let mock = MockProvider::new(4);
        let provider = CachedProvider::new(Box::new(mock), 100);

        // Pre-populate cache
        provider.embed("hello").unwrap();

        // Batch with one cached and one uncached
        let result = provider.embed_batch(&["hello", "world"]).unwrap();
        assert_eq!(result.len(), 2);

        let (size, _) = provider.cache_stats();
        assert_eq!(size, 2);
    }

    #[test]
    fn cached_provider_zero_capacity() {
        // Capacity of 0 should default to 1
        let mock = MockProvider::new(4);
        let provider = CachedProvider::new(Box::new(mock), 0);

        provider.embed("a").unwrap();
        provider.embed("b").unwrap();

        let (size, cap) = provider.cache_stats();
        // Cap should be 1 (the fallback), so only 1 entry retained
        assert_eq!(cap, 1);
        assert_eq!(size, 1);
    }

    #[test]
    fn cached_provider_name_delegates() {
        let mock = MockProvider::new(4);
        let provider = CachedProvider::new(Box::new(mock), 100);
        assert_eq!(provider.name(), "mock");
    }

    #[test]
    fn cached_provider_dimensions_delegates() {
        let mock = MockProvider::new(768);
        let provider = CachedProvider::new(Box::new(mock), 100);
        assert_eq!(provider.dimensions(), 768);
    }

    #[test]
    fn from_env_unknown_provider() {
        // Set env var to trigger the error path
        // NOTE(review): mutating process env is racy under Rust's default
        // parallel test runner if any other test reads these variables;
        // currently this is the only env-touching test, but consider a
        // serialization guard (or `--test-threads=1`) if more are added.
        std::env::set_var("CODEMEM_EMBED_PROVIDER", "nonexistent_provider_xyz");
        let result = from_env();
        std::env::remove_var("CODEMEM_EMBED_PROVIDER");

        match result {
            Err(e) => {
                let err = e.to_string();
                assert!(
                    err.contains("Unknown embedding provider"),
                    "Error should mention unknown provider: {err}"
                );
            }
            Ok(_) => panic!("Expected error for unknown provider"),
        }
    }

    #[test]
    fn embedding_service_missing_model() {
        // Construction must fail fast with a helpful message when the model
        // file is absent (no download is attempted here).
        match EmbeddingService::new(Path::new("/nonexistent/path")) {
            Err(e) => {
                let err = e.to_string();
                assert!(
                    err.contains("Model not found"),
                    "Error should mention missing model: {err}"
                );
            }
            Ok(_) => panic!("Expected error for missing model"),
        }
    }

    #[test]
    fn default_model_dir_path() {
        let dir = EmbeddingService::default_model_dir();
        assert!(dir.to_string_lossy().contains(MODEL_NAME));
        assert!(dir.to_string_lossy().contains(".codemem"));
    }

    #[test]
    fn constants_are_sensible() {
        assert_eq!(DIMENSIONS, 768);
        assert_eq!(CACHE_CAPACITY, 10_000);
        assert_eq!(MODEL_NAME, "bge-base-en-v1.5");
    }
}