//! codemem-embeddings: Pluggable embedding providers for Codemem.
//!
//! Supports multiple backends:
//! - **Candle** (default): Local BERT models via pure Rust ML (any HF BERT model)
//! - **Ollama**: Local Ollama server with any embedding model
//! - **OpenAI**: OpenAI API or any compatible endpoint (Together, Azure, etc.)
//! - **Gemini**: Google Generative Language API (text-embedding-004)
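//!
//! A minimal usage sketch (the default `candle` backend needs a downloaded
//! model, hence `no_run`):
//!
//! ```no_run
//! use codemem_embeddings::EmbeddingProvider;
//!
//! let provider = codemem_embeddings::from_env(None).unwrap();
//! let vec = provider.embed("hello world").unwrap();
//! assert_eq!(vec.len(), provider.dimensions());
//! ```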

pub mod gemini;
pub mod ollama;
pub mod openai;

use candle_core::{DType, Device, Tensor};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::bert::{BertModel, Config as BertConfig};
use candle_transformers::models::jina_bert::{
    BertModel as JinaBertModel, Config as JinaBertConfig,
};
use codemem_core::CodememError;
use lru::LruCache;
use std::num::NonZeroUsize;
use std::path::{Path, PathBuf};
use std::sync::Mutex;
use tokenizers::{PaddingParams, PaddingStrategy};

/// Default model name (short form used for directory naming).
pub const MODEL_NAME: &str = "bge-base-en-v1.5";

/// Default HuggingFace model repo ID.
/// Used internally and by `commands_init` for the default model download.
pub const DEFAULT_HF_REPO: &str = "BAAI/bge-base-en-v1.5";

/// Default embedding dimensions for remote providers (Ollama/OpenAI).
/// Candle reads `hidden_size` from the model's config.json instead.
pub const DEFAULT_REMOTE_DIMENSIONS: usize = 768;

/// Default max sequence length for standard BERT models (used when config doesn't specify).
const DEFAULT_MAX_SEQ_LENGTH: usize = 512;

/// Default LRU cache capacity.
pub const CACHE_CAPACITY: usize = 10_000;

// Re-export EmbeddingProvider trait from core
pub use codemem_core::EmbeddingProvider;

// ── Candle Embedding Service ────────────────────────────────────────────────

/// Default batch size for batched embedding forward passes.
/// Configurable via `EmbeddingConfig.batch_size` or `CODEMEM_EMBED_BATCH_SIZE`.
pub const DEFAULT_BATCH_SIZE: usize = 16;

/// Select the best available compute device.
///
/// Tries Metal (macOS GPU) first, then CUDA (NVIDIA GPU), then falls back to CPU.
/// GPU backends are only available when the corresponding feature flag is enabled.
fn select_device() -> Device {
    #[cfg(feature = "metal")]
    {
        // Metal init can panic on CI runners without GPU access; catch_unwind
        // contains the panic so we can fall back to CPU.
        match std::panic::catch_unwind(|| Device::new_metal(0)) {
            Ok(Ok(device)) => {
                tracing::info!("Using Metal GPU for embeddings");
                return device;
            }
            Ok(Err(e)) => {
                tracing::warn!("Metal device creation failed: {e}, falling back to CPU");
            }
            Err(_) => {
                tracing::warn!("Metal device creation panicked, falling back to CPU");
            }
        }
    }
    #[cfg(feature = "cuda")]
    {
        match std::panic::catch_unwind(|| Device::new_cuda(0)) {
            Ok(Ok(device)) => {
                tracing::info!("Using CUDA GPU for embeddings");
                return device;
            }
            Ok(Err(e)) => {
                tracing::warn!("CUDA device creation failed: {e}, falling back to CPU");
            }
            Err(_) => {
                tracing::warn!("CUDA device creation panicked, falling back to CPU");
            }
        }
    }
    tracing::info!("Using CPU for embeddings");
    Device::Cpu
}

/// Model backend enum — dispatches forward passes to the correct architecture.
enum ModelBackend {
    /// Standard BERT (absolute positional embeddings). Used by BGE, MiniLM, etc.
    Bert(BertModel),
    /// JinaBERT (ALiBi positional embeddings). Used by Jina embeddings v2.
    JinaBert(JinaBertModel),
}

/// Embedding service with Candle inference (no internal cache — use `CachedProvider` wrapper).
pub struct EmbeddingService {
    model: Mutex<ModelBackend>,
    /// Tokenizer pre-configured with truncation (no padding).
    /// Used directly for single embeds; cloned and augmented with padding for batch.
    tokenizer: tokenizers::Tokenizer,
    device: Device,
    /// Maximum texts per forward pass (GPU memory trade-off).
    batch_size: usize,
    /// Hidden size read from model config (e.g. 768 for bge-base, 384 for bge-small).
    hidden_size: usize,
    /// Max sequence length (512 for BERT, up to 8192 for JinaBERT).
    max_seq_length: usize,
}

/// Minimal struct for sniffing model architecture from config.json before full parsing.
#[derive(serde::Deserialize)]
struct ConfigProbe {
    #[serde(default)]
    position_embedding_type: Option<String>,
    hidden_size: usize,
    #[serde(default = "default_max_position_embeddings")]
    max_position_embeddings: usize,
}

fn default_max_position_embeddings() -> usize {
    DEFAULT_MAX_SEQ_LENGTH
}

impl EmbeddingService {
    /// Create a new embedding service, loading the model from the given directory.
    /// Expects `model.safetensors`, `config.json`, and `tokenizer.json` in the directory.
    ///
    /// Auto-detects the model architecture (BERT vs JinaBERT) from config.json.
    /// `dtype` controls precision: `DType::F32`, or `DType::F16` for half the memory
    /// and faster inference on Metal (the env factory in [`from_env`] defaults to f16).
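    ///
    /// A construction sketch (assumes the default model was already fetched,
    /// e.g. via [`EmbeddingService::download_default_model`]):
    /// ```no_run
    /// use candle_core::DType;
    /// use codemem_embeddings::{EmbeddingService, DEFAULT_BATCH_SIZE};
    ///
    /// let dir = EmbeddingService::default_model_dir();
    /// let svc = EmbeddingService::new(&dir, DEFAULT_BATCH_SIZE, DType::F32)?;
    /// let _ = svc.max_seq_length();
    /// # Ok::<(), codemem_core::CodememError>(())
    /// ```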
    pub fn new(model_dir: &Path, batch_size: usize, dtype: DType) -> Result<Self, CodememError> {
        let model_path = model_dir.join("model.safetensors");
        let config_path = model_dir.join("config.json");
        let tokenizer_path = model_dir.join("tokenizer.json");

        if !model_path.exists() {
            return Err(CodememError::Embedding(format!(
                "Model not found at {}. Run `codemem init` to download it.",
                model_path.display()
            )));
        }

        let device = select_device();

        tracing::info!(
            "Loading model from {} (dtype: {:?}, device: {:?})",
            model_dir.display(),
            dtype,
            device
        );

        let config_str = std::fs::read_to_string(&config_path)
            .map_err(|e| CodememError::Embedding(format!("Failed to read config: {e}")))?;

        // Probe config to detect architecture before full parsing
        let probe: ConfigProbe = serde_json::from_str(&config_str)
            .map_err(|e| CodememError::Embedding(format!("Failed to probe config: {e}")))?;
        let hidden_size = probe.hidden_size;
        let is_alibi = probe
            .position_embedding_type
            .as_deref()
            .is_some_and(|t| t == "alibi");
        // Cap at 8192 to avoid excessive memory usage even if model claims more
        let max_seq_length = probe.max_position_embeddings.min(8192);

        let (model, arch_name) = if is_alibi {
            // JinaBERT (ALiBi positional embeddings)
            let config: JinaBertConfig = serde_json::from_str(&config_str).map_err(|e| {
                CodememError::Embedding(format!("Failed to parse JinaBERT config: {e}"))
            })?;
            let vb = unsafe {
                VarBuilder::from_mmaped_safetensors(&[&model_path], dtype, &device)
                    .map_err(|e| CodememError::Embedding(format!("Failed to load weights: {e}")))?
            };
            // JinaBERT weights use "bert." prefix
            let jina_model = JinaBertModel::new(vb.pp("bert"), &config).map_err(|e| {
                CodememError::Embedding(format!("Failed to load JinaBERT model: {e}"))
            })?;
            (ModelBackend::JinaBert(jina_model), "JinaBERT (ALiBi)")
        } else {
            // Standard BERT (absolute positional embeddings)
            let config: BertConfig = serde_json::from_str(&config_str).map_err(|e| {
                CodememError::Embedding(format!("Failed to parse BERT config: {e}"))
            })?;
            // Load model weights from safetensors via memory-mapped IO.
            // Scope vb so it drops before a potential retry, avoiding two VarBuilders
            // holding materialized Metal tensors simultaneously.
            let bert_model = {
                let vb = unsafe {
                    VarBuilder::from_mmaped_safetensors(&[&model_path], dtype, &device).map_err(
                        |e| CodememError::Embedding(format!("Failed to load weights: {e}")),
                    )?
                };
                BertModel::load(vb.pp("bert"), &config)
            };
            // The attempt above used the "bert." prefix (standard HF checkpoints);
            // retry without the prefix if it failed.
            let bert_model = match bert_model {
                Ok(m) => m,
                Err(_) => {
                    let vb2 = unsafe {
                        VarBuilder::from_mmaped_safetensors(&[&model_path], dtype, &device)
                            .map_err(|e| {
                                CodememError::Embedding(format!("Failed to load weights: {e}"))
                            })?
                    };
                    BertModel::load(vb2, &config).map_err(|e| {
                        CodememError::Embedding(format!("Failed to load BERT model: {e}"))
                    })?
                }
            };
            (ModelBackend::Bert(bert_model), "BERT (absolute)")
        };

        tracing::info!(
            "Loaded {} model (hidden_size={}, max_seq_length={})",
            arch_name,
            hidden_size,
            max_seq_length
        );

        let mut tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| CodememError::Embedding(e.to_string()))?;

        // Pre-configure truncation once so we don't need to clone on every embed call.
        tokenizer
            .with_truncation(Some(tokenizers::TruncationParams {
                max_length: max_seq_length,
                ..Default::default()
            }))
            .map_err(|e| CodememError::Embedding(format!("Truncation error: {e}")))?;

        Ok(Self {
            model: Mutex::new(model),
            tokenizer,
            device,
            batch_size,
            hidden_size,
            max_seq_length,
        })
    }

    /// Maximum sequence length this model supports.
    pub fn max_seq_length(&self) -> usize {
        self.max_seq_length
    }

    /// Get the model directory path for a given model name.
    /// Resolves to `~/.codemem/models/{model_name}`, falling back to the current
    /// directory when no home directory is available.
    pub fn model_dir_for(model_name: &str) -> PathBuf {
        dirs::home_dir()
            .unwrap_or_else(|| PathBuf::from("."))
            .join(".codemem")
            .join("models")
            .join(model_name)
    }

    /// Get the default model directory path (~/.codemem/models/{MODEL_NAME}).
    pub fn default_model_dir() -> PathBuf {
        Self::model_dir_for(MODEL_NAME)
    }

    /// Download a model from HuggingFace Hub to the given directory.
    /// `hf_repo` is the full repo ID (e.g. "BAAI/bge-base-en-v1.5").
    /// Returns the directory path. No-ops if model already exists.
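    ///
    /// A usage sketch (hits the network, hence `no_run`):
    /// ```no_run
    /// use codemem_embeddings::EmbeddingService;
    ///
    /// let dir = EmbeddingService::model_dir_for("bge-small-en-v1.5");
    /// EmbeddingService::download_model(&dir, "BAAI/bge-small-en-v1.5")?;
    /// # Ok::<(), codemem_core::CodememError>(())
    /// ```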
    pub fn download_model(dest_dir: &Path, hf_repo: &str) -> Result<PathBuf, CodememError> {
        let model_dest = dest_dir.join("model.safetensors");
        let config_dest = dest_dir.join("config.json");
        let tokenizer_dest = dest_dir.join("tokenizer.json");

        if model_dest.exists() && config_dest.exists() && tokenizer_dest.exists() {
            tracing::info!("Model already downloaded at {}", dest_dir.display());
            return Ok(dest_dir.to_path_buf());
        }

        std::fs::create_dir_all(dest_dir)
            .map_err(|e| CodememError::Embedding(format!("Failed to create dir: {e}")))?;

        tracing::info!("Downloading {} from HuggingFace...", hf_repo);

        let api = hf_hub::api::sync::Api::new()
            .map_err(|e| CodememError::Embedding(format!("HuggingFace API error: {e}")))?;
        let repo = api.model(hf_repo.to_string());

        let cached_model = repo
            .get("model.safetensors")
            .map_err(|e| CodememError::Embedding(format!("Failed to download model: {e}")))?;

        let cached_config = repo
            .get("config.json")
            .map_err(|e| CodememError::Embedding(format!("Failed to download config: {e}")))?;

        let cached_tokenizer = repo
            .get("tokenizer.json")
            .map_err(|e| CodememError::Embedding(format!("Failed to download tokenizer: {e}")))?;

        std::fs::copy(&cached_model, &model_dest)
            .map_err(|e| CodememError::Embedding(format!("Failed to copy model: {e}")))?;
        std::fs::copy(&cached_config, &config_dest)
            .map_err(|e| CodememError::Embedding(format!("Failed to copy config: {e}")))?;
        std::fs::copy(&cached_tokenizer, &tokenizer_dest)
            .map_err(|e| CodememError::Embedding(format!("Failed to copy tokenizer: {e}")))?;

        tracing::info!("Model downloaded to {}", dest_dir.display());
        Ok(dest_dir.to_path_buf())
    }

    /// Download the default model (BAAI/bge-base-en-v1.5) to the default directory.
    /// Convenience wrapper for `download_model(&default_model_dir(), DEFAULT_HF_REPO)`.
    pub fn download_default_model() -> Result<PathBuf, CodememError> {
        Self::download_model(&Self::default_model_dir(), DEFAULT_HF_REPO)
    }

    /// Embed a single text string. Returns an L2-normalized vector (dimension = model's hidden_size).
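    ///
    /// A usage sketch; the assertion illustrates the L2 normalization:
    /// ```no_run
    /// # use codemem_embeddings::EmbeddingService;
    /// # fn demo(svc: &EmbeddingService) -> Result<(), codemem_core::CodememError> {
    /// let v = svc.embed("fn main() {}")?;
    /// let norm_sq: f32 = v.iter().map(|x| x * x).sum();
    /// assert!((norm_sq - 1.0).abs() < 1e-3);
    /// # Ok(())
    /// # }
    /// ```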
    pub fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
        // Tokenize using pre-configured tokenizer (truncation already set in constructor)
        let encoding = self
            .tokenizer
            .encode(text, true)
            .map_err(|e| CodememError::Embedding(e.to_string()))?;

        let input_ids: Vec<u32> = encoding.get_ids().to_vec();
        let attention_mask: Vec<u32> = encoding.get_attention_mask().to_vec();

        // Build candle tensors with shape [1, seq_len]
        let input_ids_tensor = Tensor::new(&input_ids[..], &self.device)
            .and_then(|t| t.unsqueeze(0))
            .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;

        let attention_mask_tensor = Tensor::new(&attention_mask[..], &self.device)
            .and_then(|t| t.unsqueeze(0))
            .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;

        // Forward pass -> [1, seq_len, hidden_size]
        let model = self
            .model
            .lock()
            .map_err(|e| CodememError::LockPoisoned(format!("embedding model: {e}")))?;
        let hidden_states = match &*model {
            ModelBackend::Bert(bert) => {
                let token_type_ids = input_ids_tensor
                    .zeros_like()
                    .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;
                let result = bert
                    .forward(
                        &input_ids_tensor,
                        &token_type_ids,
                        Some(&attention_mask_tensor),
                    )
                    .map_err(|e| CodememError::Embedding(format!("Model forward error: {e}")))?;
                drop(token_type_ids);
                result
            }
            ModelBackend::JinaBert(jina) => jina
                .forward(&input_ids_tensor)
                .map_err(|e| CodememError::Embedding(format!("Model forward error: {e}")))?,
        };
        drop(model);
        drop(input_ids_tensor);

        // Cast hidden states to F32 for pooling math (model may output F16/BF16)
        let hidden_states = hidden_states
            .to_dtype(DType::F32)
            .map_err(|e| CodememError::Embedding(format!("Cast error: {e}")))?;

        // Mean pooling weighted by attention mask
        // attention_mask: [1, seq_len] -> [1, seq_len, 1] for broadcasting
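        // i.e. pooled[h] = sum_t(mask[t] * hidden[t][h]) / sum_t(mask[t])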
        let mask = attention_mask_tensor
            .to_dtype(DType::F32)
            .and_then(|t| t.unsqueeze(2))
            .map_err(|e| CodememError::Embedding(format!("Mask error: {e}")))?;

        let sum_mask = mask
            .sum(1)
            .map_err(|e| CodememError::Embedding(format!("Sum error: {e}")))?;

        let pooled = hidden_states
            .broadcast_mul(&mask)
            .and_then(|t| t.sum(1))
            .and_then(|t| t.broadcast_div(&sum_mask))
            .map_err(|e| CodememError::Embedding(format!("Pooling error: {e}")))?;

        // L2 normalize
        let normalized = pooled
            .sqr()
            .and_then(|t| t.sum_keepdim(1))
            .and_then(|t| t.sqrt())
            .and_then(|norm| pooled.broadcast_div(&norm))
            .map_err(|e| CodememError::Embedding(format!("Normalize error: {e}")))?;

        // Extract as Vec<f32> — shape is [1, hidden_size], squeeze to [hidden_size]
        let embedding: Vec<f32> = normalized
            .squeeze(0)
            .and_then(|t| t.to_vec1())
            .map_err(|e| CodememError::Embedding(format!("Extract error: {e}")))?;

        Ok(embedding)
    }

    /// Embed a batch of texts using a true batched forward pass.
    ///
    /// Tokenizes all texts, pads to the longest sequence in each chunk, runs a
    /// single forward pass per chunk of up to `batch_size` texts, then performs
    /// mean pooling and L2 normalization on the batched output.
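    ///
    /// Output order matches input order. A sketch:
    /// ```no_run
    /// # use codemem_embeddings::EmbeddingService;
    /// # fn demo(svc: &EmbeddingService) -> Result<(), codemem_core::CodememError> {
    /// let embs = svc.embed_batch(&["let x = 1;", "let y = 2;"])?;
    /// assert_eq!(embs.len(), 2);
    /// # Ok(())
    /// # }
    /// ```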
    pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
        if texts.is_empty() {
            return Ok(vec![]);
        }

        let mut all_embeddings = Vec::with_capacity(texts.len());

        // Clone the tokenizer once for the batch path and enable padding; truncation
        // is already configured on self.tokenizer. `BatchLongest` pads each
        // `encode_batch` call to its own longest sequence, so one configuration
        // serves every chunk.
        let mut tokenizer = self.tokenizer.clone();
        tokenizer.with_padding(Some(PaddingParams {
            strategy: PaddingStrategy::BatchLongest,
            ..Default::default()
        }));

        for chunk in texts.chunks(self.batch_size) {
            let encodings = tokenizer
                .encode_batch(chunk.to_vec(), true)
                .map_err(|e| CodememError::Embedding(format!("Batch encode error: {e}")))?;

            let batch_len = encodings.len();
            let seq_len = encodings[0].get_ids().len();

            // Flatten token IDs and attention masks into contiguous arrays
            let all_ids: Vec<u32> = encodings
                .iter()
                .flat_map(|e| e.get_ids())
                .copied()
                .collect();
            let all_masks: Vec<u32> = encodings
                .iter()
                .flat_map(|e| e.get_attention_mask())
                .copied()
                .collect();

            // Build tensors with shape [batch_size, seq_len]
            let input_ids = Tensor::new(all_ids.as_slice(), &self.device)
                .and_then(|t| t.reshape((batch_len, seq_len)))
                .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;

            let attention_mask = Tensor::new(all_masks.as_slice(), &self.device)
                .and_then(|t| t.reshape((batch_len, seq_len)))
                .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;

            // Single forward pass -> [batch_size, seq_len, hidden_size]
            let model = self
                .model
                .lock()
                .map_err(|e| CodememError::LockPoisoned(format!("embedding model: {e}")))?;
            let hidden_states = match &*model {
                ModelBackend::Bert(bert) => {
                    let token_type_ids = input_ids
                        .zeros_like()
                        .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;
                    let result = bert
                        .forward(&input_ids, &token_type_ids, Some(&attention_mask))
                        .map_err(|e| CodememError::Embedding(format!("Forward error: {e}")))?;
                    drop(token_type_ids);
                    result
                }
                ModelBackend::JinaBert(jina) => jina
                    .forward(&input_ids)
                    .map_err(|e| CodememError::Embedding(format!("Forward error: {e}")))?,
            };
            drop(model);
            drop(input_ids);

            // Cast hidden states to F32 for pooling math (model may output F16/BF16)
            let hidden_states = hidden_states
                .to_dtype(DType::F32)
                .map_err(|e| CodememError::Embedding(format!("Cast error: {e}")))?;

            // Mean pooling: mask [batch, seq] -> [batch, seq, 1] for broadcast
            let mask = attention_mask
                .to_dtype(DType::F32)
                .and_then(|t| t.unsqueeze(2))
                .map_err(|e| CodememError::Embedding(format!("Mask error: {e}")))?;

            let sum_mask = mask
                .sum(1)
                .map_err(|e| CodememError::Embedding(format!("Sum error: {e}")))?;

            let pooled = hidden_states
                .broadcast_mul(&mask)
                .and_then(|t| t.sum(1))
                .and_then(|t| t.broadcast_div(&sum_mask))
                .map_err(|e| CodememError::Embedding(format!("Pooling error: {e}")))?;

            // L2 normalize: [batch, hidden]
            let norm = pooled
                .sqr()
                .and_then(|t| t.sum_keepdim(1))
                .and_then(|t| t.sqrt())
                .map_err(|e| CodememError::Embedding(format!("Norm error: {e}")))?;

            let normalized = pooled
                .broadcast_div(&norm)
                .map_err(|e| CodememError::Embedding(format!("Normalize error: {e}")))?;

            // Single GPU→CPU blit: flatten all rows, then slice on CPU.
            // to_vec1() implicitly syncs the GPU pipeline (data must be ready to read).
            let flat: Vec<f32> = normalized
                .flatten_all()
                .and_then(|t| t.to_vec1())
                .map_err(|e| CodememError::Embedding(format!("Extract error: {e}")))?;
            for i in 0..batch_len {
                let start = i * self.hidden_size;
                all_embeddings.push(flat[start..start + self.hidden_size].to_vec());
            }
        }

        Ok(all_embeddings)
    }
}

impl EmbeddingProvider for EmbeddingService {
    fn dimensions(&self) -> usize {
        self.hidden_size
    }

    fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
        self.embed(text)
    }

    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
        self.embed_batch(texts)
    }

    fn name(&self) -> &str {
        "candle"
    }
}

// ── Cached Provider Wrapper ───────────────────────────────────────────────

/// Wraps any `EmbeddingProvider` with an LRU cache.
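///
/// A wrapping sketch (the Ollama endpoint and model name are illustrative):
/// ```no_run
/// use codemem_embeddings::{ollama, CachedProvider, EmbeddingProvider, CACHE_CAPACITY};
///
/// let inner = Box::new(ollama::OllamaProvider::new(
///     "http://localhost:11434",
///     "nomic-embed-text",
///     768,
/// ));
/// let cached = CachedProvider::new(inner, CACHE_CAPACITY);
/// let _ = cached.embed("query"); // computed by the inner provider
/// let _ = cached.embed("query"); // served from the LRU cache
/// ```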
pub struct CachedProvider {
    inner: Box<dyn EmbeddingProvider>,
    cache: Mutex<LruCache<String, Vec<f32>>>,
}

impl CachedProvider {
    pub fn new(inner: Box<dyn EmbeddingProvider>, capacity: usize) -> Self {
        // LruCache::new requires a non-zero capacity, so clamp zero to 1.
        let cap = NonZeroUsize::new(capacity.max(1)).expect("capacity clamped to at least 1");
        Self {
            inner,
            cache: Mutex::new(LruCache::new(cap)),
        }
    }
}

impl EmbeddingProvider for CachedProvider {
    fn dimensions(&self) -> usize {
        self.inner.dimensions()
    }

    fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
        {
            let mut cache = self
                .cache
                .lock()
                .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
            if let Some(cached) = cache.get(text) {
                return Ok(cached.clone());
            }
        }
        let embedding = self.inner.embed(text)?;
        {
            let mut cache = self
                .cache
                .lock()
                .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
            cache.put(text.to_string(), embedding.clone());
        }
        Ok(embedding)
    }

    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
        // Check cache, only forward uncached texts
        let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
        let mut uncached = Vec::new();
        let mut uncached_idx = Vec::new();

        {
            let mut cache = self
                .cache
                .lock()
                .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
            for (i, text) in texts.iter().enumerate() {
                if let Some(cached) = cache.get(*text) {
                    results[i] = Some(cached.clone());
                } else {
                    uncached_idx.push(i);
                    uncached.push(*text);
                }
            }
        }

        if !uncached.is_empty() {
            let new_embeddings = self.inner.embed_batch(&uncached)?;
            let mut cache = self
                .cache
                .lock()
                .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
            for (idx, embedding) in uncached_idx.into_iter().zip(new_embeddings) {
                cache.put(texts[idx].to_string(), embedding.clone());
                results[idx] = Some(embedding);
            }
        }

        // Verify all texts got embeddings — flatten() would silently drop Nones
        let expected = texts.len();
        let output: Vec<Vec<f32>> = results
            .into_iter()
            .enumerate()
            .map(|(i, opt)| {
                opt.ok_or_else(|| {
                    CodememError::Embedding(format!(
                        "Missing embedding for text at index {i} (batch size {expected})"
                    ))
                })
            })
            .collect::<Result<Vec<_>, _>>()?;
        Ok(output)
    }

    fn name(&self) -> &str {
        self.inner.name()
    }

    fn cache_stats(&self) -> (usize, usize) {
        match self.cache.lock() {
            Ok(cache) => (cache.len(), cache.cap().into()),
            Err(_) => (0, 0),
        }
    }
}

// ── Factory ───────────────────────────────────────────────────────────────

/// Parse a dtype string into a Candle DType.
///
/// Supported values: "f16" (default, half precision — less memory, faster on Metal), "f32", "bf16".
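///
/// For example:
/// ```
/// use candle_core::DType;
/// use codemem_embeddings::parse_dtype;
///
/// assert_eq!(parse_dtype("F16").unwrap(), DType::F16);
/// assert_eq!(parse_dtype("bfloat16").unwrap(), DType::BF16);
/// assert!(parse_dtype("int8").is_err());
/// ```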
pub fn parse_dtype(s: &str) -> Result<DType, CodememError> {
    match s.to_lowercase().as_str() {
        "f16" | "float16" | "half" | "" => Ok(DType::F16),
        "f32" | "float32" => Ok(DType::F32),
        "bf16" | "bfloat16" => Ok(DType::BF16),
        other => Err(CodememError::Embedding(format!(
            "Unknown dtype: '{}'. Use 'f16', 'f32', or 'bf16'.",
            other
        ))),
    }
}

/// Resolve the HuggingFace repo ID and local directory name from a model identifier.
///
/// Accepts:
/// - Full HF repo: `"BAAI/bge-base-en-v1.5"` → repo=`"BAAI/bge-base-en-v1.5"`, dir=`"bge-base-en-v1.5"`
/// - Short name: `"bge-small-en-v1.5"` → repo=`"BAAI/bge-small-en-v1.5"`, dir=`"bge-small-en-v1.5"`
///
/// Returns `Err` if the model identifier is a bare name without an org prefix and isn't
/// a recognized `bge-*` shorthand — HuggingFace requires `org/repo` format.
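///
/// For example:
/// ```
/// use codemem_embeddings::resolve_model_id;
///
/// let (repo, dir) = resolve_model_id("bge-small-en-v1.5").unwrap();
/// assert_eq!(repo, "BAAI/bge-small-en-v1.5");
/// assert_eq!(dir, "bge-small-en-v1.5");
/// assert!(resolve_model_id("all-MiniLM-L6-v2").is_err());
/// ```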
pub fn resolve_model_id(model: &str) -> Result<(String, String), CodememError> {
    if model.contains('/') {
        // Full repo ID — directory name is the part after the slash
        let dir_name = model.rsplit('/').next().unwrap_or(model);
        Ok((model.to_string(), dir_name.to_string()))
    } else if model.starts_with("bge-") {
        // Short name — assume BAAI namespace for bge-* models
        Ok((format!("BAAI/{model}"), model.to_string()))
    } else {
        Err(CodememError::Embedding(format!(
            "Model identifier '{}' must be a full HuggingFace repo ID (e.g., 'BAAI/bge-base-en-v1.5' \
             or 'sentence-transformers/all-MiniLM-L6-v2'). Short names are only supported for 'bge-*' models.",
            model
        )))
    }
}

/// Create an embedding provider from environment variables.
///
/// When `config` is provided, its fields serve as defaults; env vars override them.
///
/// | Variable | Values | Default |
/// |----------|--------|---------|
/// | `CODEMEM_EMBED_PROVIDER` | `candle`, `ollama`, `openai`, `gemini` | `candle` |
/// | `CODEMEM_EMBED_MODEL` | model name or HF repo | `BAAI/bge-base-en-v1.5` |
/// | `CODEMEM_EMBED_URL` | base URL | provider default |
/// | `CODEMEM_EMBED_API_KEY` | API key | also reads `OPENAI_API_KEY` / `GEMINI_API_KEY` / `GOOGLE_API_KEY` |
/// | `CODEMEM_EMBED_DIMENSIONS` | integer | read from model config |
/// | `CODEMEM_EMBED_BATCH_SIZE` | integer | `16` |
/// | `CODEMEM_EMBED_DTYPE` | `f16`, `f32`, `bf16` | `f16` |
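///
/// A selection sketch (the Ollama model name is illustrative):
/// ```no_run
/// use codemem_embeddings::EmbeddingProvider;
///
/// // Assumes CODEMEM_EMBED_PROVIDER=ollama and CODEMEM_EMBED_MODEL=nomic-embed-text
/// // were set in the environment before this call.
/// let provider = codemem_embeddings::from_env(None).unwrap();
/// println!("using provider: {}", provider.name());
/// ```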
pub fn from_env(
    config: Option<&codemem_core::EmbeddingConfig>,
) -> Result<Box<dyn EmbeddingProvider>, CodememError> {
    let provider = std::env::var("CODEMEM_EMBED_PROVIDER")
        .unwrap_or_else(|_| {
            config
                .map(|c| c.provider.clone())
                .unwrap_or_else(|| "candle".to_string())
        })
        .to_lowercase();
    // For Ollama/OpenAI, dimensions must be specified explicitly (remote APIs need it).
    // For Candle, this value is ignored — hidden_size is read from the model's config.json.
    let dimensions: usize = std::env::var("CODEMEM_EMBED_DIMENSIONS")
        .ok()
        .and_then(|s| s.parse().ok())
        .unwrap_or_else(|| config.map_or(DEFAULT_REMOTE_DIMENSIONS, |c| c.dimensions));
    let cache_capacity = config.map_or(CACHE_CAPACITY, |c| c.cache_capacity);
    let batch_size: usize = std::env::var("CODEMEM_EMBED_BATCH_SIZE")
        .ok()
        .and_then(|s| s.parse().ok())
        .unwrap_or_else(|| config.map_or(DEFAULT_BATCH_SIZE, |c| c.batch_size));

    match provider.as_str() {
        "ollama" => {
            let base_url = std::env::var("CODEMEM_EMBED_URL").unwrap_or_else(|_| {
                config
                    .filter(|c| !c.url.is_empty())
                    .map(|c| c.url.clone())
                    .unwrap_or_else(|| ollama::DEFAULT_BASE_URL.to_string())
            });
            let model = std::env::var("CODEMEM_EMBED_MODEL").unwrap_or_else(|_| {
                config
                    .filter(|c| !c.model.is_empty())
                    .map(|c| c.model.clone())
                    .unwrap_or_else(|| ollama::DEFAULT_MODEL.to_string())
            });
            let inner = Box::new(ollama::OllamaProvider::new(&base_url, &model, dimensions));
            Ok(Box::new(CachedProvider::new(inner, cache_capacity)))
        }
        "openai" => {
            let api_key = std::env::var("CODEMEM_EMBED_API_KEY")
                .or_else(|_| std::env::var("OPENAI_API_KEY"))
                .map_err(|_| {
                    CodememError::Embedding(
                        "CODEMEM_EMBED_API_KEY or OPENAI_API_KEY required for OpenAI embeddings"
                            .into(),
                    )
                })?;
            let model = std::env::var("CODEMEM_EMBED_MODEL").unwrap_or_else(|_| {
                config
                    .filter(|c| !c.model.is_empty())
                    .map(|c| c.model.clone())
                    .unwrap_or_else(|| openai::DEFAULT_MODEL.to_string())
            });
            let base_url = std::env::var("CODEMEM_EMBED_URL")
                .ok()
                .or_else(|| config.filter(|c| !c.url.is_empty()).map(|c| c.url.clone()));
            let inner = Box::new(openai::OpenAIProvider::new(
                &api_key,
                &model,
                dimensions,
                base_url.as_deref(),
            ));
            Ok(Box::new(CachedProvider::new(inner, cache_capacity)))
        }
        "gemini" | "google" => {
            let api_key = std::env::var("CODEMEM_EMBED_API_KEY")
                .or_else(|_| std::env::var("GEMINI_API_KEY"))
                .or_else(|_| std::env::var("GOOGLE_API_KEY"))
                .map_err(|_| {
                    CodememError::Embedding(
                        "CODEMEM_EMBED_API_KEY, GEMINI_API_KEY, or GOOGLE_API_KEY required for Gemini embeddings"
                            .into(),
                    )
                })?;
            let model = std::env::var("CODEMEM_EMBED_MODEL").unwrap_or_else(|_| {
                config
                    .filter(|c| !c.model.is_empty())
                    .map(|c| c.model.clone())
                    .unwrap_or_else(|| gemini::DEFAULT_MODEL.to_string())
            });
            let base_url = std::env::var("CODEMEM_EMBED_URL")
                .ok()
                .or_else(|| config.filter(|c| !c.url.is_empty()).map(|c| c.url.clone()));
            let inner = Box::new(gemini::GeminiProvider::new(
                &api_key,
                &model,
                dimensions,
                base_url.as_deref(),
            ));
            Ok(Box::new(CachedProvider::new(inner, cache_capacity)))
        }
        "candle" | "" => {
            let model_id = std::env::var("CODEMEM_EMBED_MODEL").unwrap_or_else(|_| {
                config
                    .filter(|c| !c.model.is_empty())
                    .map(|c| c.model.clone())
                    .unwrap_or_else(|| DEFAULT_HF_REPO.to_string())
            });
            let (hf_repo, dir_name) = resolve_model_id(&model_id)?;
            let model_dir = EmbeddingService::model_dir_for(&dir_name);

            let dtype_str = std::env::var("CODEMEM_EMBED_DTYPE").unwrap_or_else(|_| {
                config
                    .filter(|c| !c.dtype.is_empty())
                    .map(|c| c.dtype.clone())
                    .unwrap_or_else(|| "f16".to_string())
            });
            let dtype = parse_dtype(&dtype_str)?;

            let service = EmbeddingService::new(&model_dir, batch_size, dtype).map_err(|e| {
                // Enhance error message with download hint for non-default models
                if e.to_string().contains("Model not found") && hf_repo != DEFAULT_HF_REPO {
                    CodememError::Embedding(format!(
                        "Model '{}' not found at {}. Download it with:\n  \
                         CODEMEM_EMBED_MODEL={} codemem init",
                        hf_repo,
                        model_dir.display(),
                        hf_repo
                    ))
                } else {
                    e
                }
            })?;
            Ok(Box::new(CachedProvider::new(
                Box::new(service),
                cache_capacity,
            )))
        }
        other => Err(CodememError::Embedding(format!(
            "Unknown embedding provider: '{}'. Use 'candle', 'ollama', 'openai', or 'gemini'.",
            other
        ))),
    }
}

#[cfg(test)]
#[path = "tests/lib_tests.rs"]
mod tests;