Skip to main content

codemem_embeddings/
lib.rs

//! codemem-embeddings: Pluggable embedding providers for Codemem.
//!
//! Supports multiple backends:
//! - **Candle** (default): local BAAI/bge-base-en-v1.5 inference in pure Rust via Candle
//! - **Ollama**: a local Ollama server running any embedding model
//! - **OpenAI**: the OpenAI API or any compatible endpoint (Together, Azure, etc.)
7
8pub mod ollama;
9pub mod openai;
10
11use candle_core::{DType, Device, Tensor};
12use candle_nn::VarBuilder;
13use candle_transformers::models::bert::{BertModel, Config as BertConfig};
14use codemem_core::CodememError;
15use lru::LruCache;
16use std::num::NonZeroUsize;
17use std::path::{Path, PathBuf};
18use std::sync::Mutex;
19use tokenizers::{PaddingParams, PaddingStrategy};
20
/// Short name of the default Candle model; it is also the directory name used by
/// `EmbeddingService::default_model_dir`.
pub const MODEL_NAME: &str = "bge-base-en-v1.5";

/// HuggingFace Hub repository the default model is fetched from.
const HF_MODEL_REPO: &str = "BAAI/bge-base-en-v1.5";

/// Width of the vectors produced by the default model.
pub const DIMENSIONS: usize = 768;

/// Token limit inputs are truncated to for bge-base-en-v1.5.
const MAX_SEQ_LENGTH: usize = 512;

/// Number of entries the embedding LRU cache holds unless configured otherwise.
pub const CACHE_CAPACITY: usize = 10_000;
35
36// Re-export EmbeddingProvider trait from core
37pub use codemem_core::EmbeddingProvider;
38
// ── Candle Embedding Service ────────────────────────────────────────────────

/// Texts per batched forward pass when no `EmbeddingConfig` supplies `batch_size`.
pub const DEFAULT_BATCH_SIZE: usize = 16;
44
45/// Select the best available compute device.
46///
47/// Tries Metal (macOS GPU) first, then CUDA (NVIDIA GPU), then falls back to CPU.
48/// GPU backends are only available when the corresponding feature flag is enabled.
49fn select_device() -> Device {
50    #[cfg(feature = "metal")]
51    {
52        if let Ok(device) = Device::new_metal(0) {
53            tracing::info!("Using Metal GPU for embeddings");
54            return device;
55        }
56        tracing::warn!("Metal feature enabled but device creation failed, falling back");
57    }
58    #[cfg(feature = "cuda")]
59    {
60        if let Ok(device) = Device::new_cuda(0) {
61            tracing::info!("Using CUDA GPU for embeddings");
62            return device;
63        }
64        tracing::warn!("CUDA feature enabled but device creation failed, falling back");
65    }
66    tracing::info!("Using CPU for embeddings");
67    Device::Cpu
68}
69
70/// Embedding service with Candle inference (no internal cache — use `CachedProvider` wrapper).
71pub struct EmbeddingService {
72    model: Mutex<BertModel>,
73    /// Tokenizer pre-configured with truncation (no padding).
74    /// Used directly for single embeds; cloned and augmented with padding for batch.
75    tokenizer: tokenizers::Tokenizer,
76    device: Device,
77    /// Maximum texts per forward pass (GPU memory trade-off).
78    batch_size: usize,
79}
80
81impl EmbeddingService {
82    /// Create a new embedding service, loading model from the given directory.
83    /// Expects `model.safetensors`, `config.json`, and `tokenizer.json` in the directory.
84    pub fn new(model_dir: &Path, batch_size: usize) -> Result<Self, CodememError> {
85        let model_path = model_dir.join("model.safetensors");
86        let config_path = model_dir.join("config.json");
87        let tokenizer_path = model_dir.join("tokenizer.json");
88
89        if !model_path.exists() {
90            return Err(CodememError::Embedding(format!(
91                "Model not found at {}. Run `codemem init` to download it.",
92                model_path.display()
93            )));
94        }
95
96        let device = select_device();
97
98        // Load BERT config
99        let config_str = std::fs::read_to_string(&config_path)
100            .map_err(|e| CodememError::Embedding(format!("Failed to read config: {e}")))?;
101        let config: BertConfig = serde_json::from_str(&config_str)
102            .map_err(|e| CodememError::Embedding(format!("Failed to parse config: {e}")))?;
103
104        // Load model weights from safetensors via memory-mapped IO.
105        // Scope vb so it drops before a potential retry, avoiding two VarBuilders
106        // holding materialized Metal tensors simultaneously.
107        let model = {
108            let vb = unsafe {
109                VarBuilder::from_mmaped_safetensors(&[&model_path], DType::F32, &device)
110                    .map_err(|e| CodememError::Embedding(format!("Failed to load weights: {e}")))?
111            };
112            BertModel::load(vb.pp("bert"), &config)
113        };
114        // Try with "bert." prefix first (standard HF BERT models), then without
115        let model = match model {
116            Ok(m) => m,
117            Err(_) => {
118                let vb2 = unsafe {
119                    VarBuilder::from_mmaped_safetensors(&[&model_path], DType::F32, &device)
120                        .map_err(|e| {
121                            CodememError::Embedding(format!("Failed to load weights: {e}"))
122                        })?
123                };
124                BertModel::load(vb2, &config).map_err(|e| {
125                    CodememError::Embedding(format!("Failed to load BERT model: {e}"))
126                })?
127            }
128        };
129
130        let mut tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
131            .map_err(|e| CodememError::Embedding(e.to_string()))?;
132
133        // Pre-configure truncation once so we don't need to clone on every embed call.
134        tokenizer
135            .with_truncation(Some(tokenizers::TruncationParams {
136                max_length: MAX_SEQ_LENGTH,
137                ..Default::default()
138            }))
139            .map_err(|e| CodememError::Embedding(format!("Truncation error: {e}")))?;
140
141        Ok(Self {
142            model: Mutex::new(model),
143            tokenizer,
144            device,
145            batch_size,
146        })
147    }
148
149    /// Get the default model directory path (~/.codemem/models/{MODEL_NAME}).
150    pub fn default_model_dir() -> PathBuf {
151        dirs::home_dir()
152            .unwrap_or_else(|| PathBuf::from("."))
153            .join(".codemem")
154            .join("models")
155            .join(MODEL_NAME)
156    }
157
158    /// Download the model from HuggingFace Hub to the given directory.
159    /// Returns the directory path. No-ops if model already exists.
160    pub fn download_model(dest_dir: &Path) -> Result<PathBuf, CodememError> {
161        let model_dest = dest_dir.join("model.safetensors");
162        let config_dest = dest_dir.join("config.json");
163        let tokenizer_dest = dest_dir.join("tokenizer.json");
164
165        if model_dest.exists() && config_dest.exists() && tokenizer_dest.exists() {
166            tracing::info!("Model already downloaded at {}", dest_dir.display());
167            return Ok(dest_dir.to_path_buf());
168        }
169
170        std::fs::create_dir_all(dest_dir)
171            .map_err(|e| CodememError::Embedding(format!("Failed to create dir: {e}")))?;
172
173        tracing::info!("Downloading {} from HuggingFace...", HF_MODEL_REPO);
174
175        let api = hf_hub::api::sync::Api::new()
176            .map_err(|e| CodememError::Embedding(format!("HuggingFace API error: {e}")))?;
177        let repo = api.model(HF_MODEL_REPO.to_string());
178
179        let cached_model = repo
180            .get("model.safetensors")
181            .map_err(|e| CodememError::Embedding(format!("Failed to download model: {e}")))?;
182
183        let cached_config = repo
184            .get("config.json")
185            .map_err(|e| CodememError::Embedding(format!("Failed to download config: {e}")))?;
186
187        let cached_tokenizer = repo
188            .get("tokenizer.json")
189            .map_err(|e| CodememError::Embedding(format!("Failed to download tokenizer: {e}")))?;
190
191        std::fs::copy(&cached_model, &model_dest)
192            .map_err(|e| CodememError::Embedding(format!("Failed to copy model: {e}")))?;
193        std::fs::copy(&cached_config, &config_dest)
194            .map_err(|e| CodememError::Embedding(format!("Failed to copy config: {e}")))?;
195        std::fs::copy(&cached_tokenizer, &tokenizer_dest)
196            .map_err(|e| CodememError::Embedding(format!("Failed to copy tokenizer: {e}")))?;
197
198        tracing::info!("Model downloaded to {}", dest_dir.display());
199        Ok(dest_dir.to_path_buf())
200    }
201
202    /// Embed a single text string. Returns a 768-dim L2-normalized vector.
203    pub fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
204        // Tokenize using pre-configured tokenizer (truncation already set in constructor)
205        let encoding = self
206            .tokenizer
207            .encode(text, true)
208            .map_err(|e| CodememError::Embedding(e.to_string()))?;
209
210        let input_ids: Vec<u32> = encoding.get_ids().to_vec();
211        let attention_mask: Vec<u32> = encoding.get_attention_mask().to_vec();
212
213        // Build candle tensors with shape [1, seq_len]
214        let input_ids_tensor = Tensor::new(&input_ids[..], &self.device)
215            .and_then(|t| t.unsqueeze(0))
216            .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;
217
218        let token_type_ids = input_ids_tensor
219            .zeros_like()
220            .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;
221
222        let attention_mask_tensor = Tensor::new(&attention_mask[..], &self.device)
223            .and_then(|t| t.unsqueeze(0))
224            .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;
225
226        // Forward pass -> [1, seq_len, hidden_size]
227        let model = self
228            .model
229            .lock()
230            .map_err(|e| CodememError::LockPoisoned(format!("embedding model: {e}")))?;
231        let hidden_states = model
232            .forward(
233                &input_ids_tensor,
234                &token_type_ids,
235                Some(&attention_mask_tensor),
236            )
237            .map_err(|e| CodememError::Embedding(format!("Model forward error: {e}")))?;
238        drop(model);
239
240        // Drop intermediate GPU tensors before pooling
241        drop(input_ids_tensor);
242        drop(token_type_ids);
243
244        // Mean pooling weighted by attention mask
245        // attention_mask: [1, seq_len] -> [1, seq_len, 1] for broadcasting
246        let mask = attention_mask_tensor
247            .to_dtype(DType::F32)
248            .and_then(|t| t.unsqueeze(2))
249            .map_err(|e| CodememError::Embedding(format!("Mask error: {e}")))?;
250
251        let sum_mask = mask
252            .sum(1)
253            .map_err(|e| CodememError::Embedding(format!("Sum error: {e}")))?;
254
255        let pooled = hidden_states
256            .broadcast_mul(&mask)
257            .and_then(|t| t.sum(1))
258            .and_then(|t| t.broadcast_div(&sum_mask))
259            .map_err(|e| CodememError::Embedding(format!("Pooling error: {e}")))?;
260
261        // L2 normalize
262        let normalized = pooled
263            .sqr()
264            .and_then(|t| t.sum_keepdim(1))
265            .and_then(|t| t.sqrt())
266            .and_then(|norm| pooled.broadcast_div(&norm))
267            .map_err(|e| CodememError::Embedding(format!("Normalize error: {e}")))?;
268
269        // Extract as Vec<f32> — shape is [1, hidden_size], squeeze to [hidden_size]
270        let embedding: Vec<f32> = normalized
271            .squeeze(0)
272            .and_then(|t| t.to_vec1())
273            .map_err(|e| CodememError::Embedding(format!("Extract error: {e}")))?;
274
275        Ok(embedding)
276    }
277
278    /// Embed a batch of texts using a true batched forward pass.
279    ///
280    /// Tokenizes all texts, pads to the longest sequence in each chunk, runs a
281    /// single forward pass per chunk of up to `batch_size` texts, then performs
282    /// mean pooling and L2 normalization on the batched output.
283    pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
284        if texts.is_empty() {
285            return Ok(vec![]);
286        }
287
288        let mut all_embeddings = Vec::with_capacity(texts.len());
289
290        for chunk in texts.chunks(self.batch_size) {
291            // Clone tokenizer only for batch path — needs per-chunk padding config.
292            // Truncation is already configured on self.tokenizer.
293            let mut tokenizer = self.tokenizer.clone();
294            tokenizer.with_padding(Some(PaddingParams {
295                strategy: PaddingStrategy::BatchLongest,
296                ..Default::default()
297            }));
298
299            let encodings = tokenizer
300                .encode_batch(chunk.to_vec(), true)
301                .map_err(|e| CodememError::Embedding(format!("Batch encode error: {e}")))?;
302
303            let batch_len = encodings.len();
304            let seq_len = encodings[0].get_ids().len();
305
306            // Flatten token IDs and attention masks into contiguous arrays
307            let all_ids: Vec<u32> = encodings
308                .iter()
309                .flat_map(|e| e.get_ids())
310                .copied()
311                .collect();
312            let all_masks: Vec<u32> = encodings
313                .iter()
314                .flat_map(|e| e.get_attention_mask())
315                .copied()
316                .collect();
317
318            // Build tensors with shape [batch_size, seq_len]
319            let input_ids = Tensor::new(all_ids.as_slice(), &self.device)
320                .and_then(|t| t.reshape((batch_len, seq_len)))
321                .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;
322
323            let token_type_ids = input_ids
324                .zeros_like()
325                .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;
326
327            let attention_mask = Tensor::new(all_masks.as_slice(), &self.device)
328                .and_then(|t| t.reshape((batch_len, seq_len)))
329                .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;
330
331            // Single forward pass -> [batch_size, seq_len, hidden_size]
332            let model = self
333                .model
334                .lock()
335                .map_err(|e| CodememError::LockPoisoned(format!("embedding model: {e}")))?;
336            let hidden_states = model
337                .forward(&input_ids, &token_type_ids, Some(&attention_mask))
338                .map_err(|e| CodememError::Embedding(format!("Forward error: {e}")))?;
339            drop(model);
340
341            // Drop intermediate GPU tensors before pooling
342            drop(input_ids);
343            drop(token_type_ids);
344
345            // Mean pooling: mask [batch, seq] -> [batch, seq, 1] for broadcast
346            let mask = attention_mask
347                .to_dtype(DType::F32)
348                .and_then(|t| t.unsqueeze(2))
349                .map_err(|e| CodememError::Embedding(format!("Mask error: {e}")))?;
350
351            let sum_mask = mask
352                .sum(1)
353                .map_err(|e| CodememError::Embedding(format!("Sum error: {e}")))?;
354
355            let pooled = hidden_states
356                .broadcast_mul(&mask)
357                .and_then(|t| t.sum(1))
358                .and_then(|t| t.broadcast_div(&sum_mask))
359                .map_err(|e| CodememError::Embedding(format!("Pooling error: {e}")))?;
360
361            // L2 normalize: [batch, hidden]
362            let norm = pooled
363                .sqr()
364                .and_then(|t| t.sum_keepdim(1))
365                .and_then(|t| t.sqrt())
366                .map_err(|e| CodememError::Embedding(format!("Norm error: {e}")))?;
367
368            let normalized = pooled
369                .broadcast_div(&norm)
370                .map_err(|e| CodememError::Embedding(format!("Normalize error: {e}")))?;
371
372            // Single GPU→CPU blit: flatten all rows, then slice on CPU.
373            // to_vec1() implicitly syncs the GPU pipeline (data must be ready to read).
374            let flat: Vec<f32> = normalized
375                .flatten_all()
376                .and_then(|t| t.to_vec1())
377                .map_err(|e| CodememError::Embedding(format!("Extract error: {e}")))?;
378            for i in 0..batch_len {
379                let start = i * DIMENSIONS;
380                all_embeddings.push(flat[start..start + DIMENSIONS].to_vec());
381            }
382        }
383
384        Ok(all_embeddings)
385    }
386}
387
388impl EmbeddingProvider for EmbeddingService {
389    fn dimensions(&self) -> usize {
390        DIMENSIONS
391    }
392
393    fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
394        self.embed(text)
395    }
396
397    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
398        self.embed_batch(texts)
399    }
400
401    fn name(&self) -> &str {
402        "candle"
403    }
404}
405
406// ── Cached Provider Wrapper ───────────────────────────────────────────────
407
408/// Wraps any `EmbeddingProvider` with an LRU cache.
409pub struct CachedProvider {
410    inner: Box<dyn EmbeddingProvider>,
411    cache: Mutex<LruCache<String, Vec<f32>>>,
412}
413
414impl CachedProvider {
415    pub fn new(inner: Box<dyn EmbeddingProvider>, capacity: usize) -> Self {
416        // SAFETY: 1 is non-zero, so the inner expect is infallible
417        let cap =
418            NonZeroUsize::new(capacity).unwrap_or(NonZeroUsize::new(1).expect("1 is non-zero"));
419        Self {
420            inner,
421            cache: Mutex::new(LruCache::new(cap)),
422        }
423    }
424}
425
426impl EmbeddingProvider for CachedProvider {
427    fn dimensions(&self) -> usize {
428        self.inner.dimensions()
429    }
430
431    fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
432        {
433            let mut cache = self
434                .cache
435                .lock()
436                .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
437            if let Some(cached) = cache.get(text) {
438                return Ok(cached.clone());
439            }
440        }
441        let embedding = self.inner.embed(text)?;
442        {
443            let mut cache = self
444                .cache
445                .lock()
446                .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
447            cache.put(text.to_string(), embedding.clone());
448        }
449        Ok(embedding)
450    }
451
452    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
453        // Check cache, only forward uncached texts
454        let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
455        let mut uncached = Vec::new();
456        let mut uncached_idx = Vec::new();
457
458        {
459            let mut cache = self
460                .cache
461                .lock()
462                .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
463            for (i, text) in texts.iter().enumerate() {
464                if let Some(cached) = cache.get(*text) {
465                    results[i] = Some(cached.clone());
466                } else {
467                    uncached_idx.push(i);
468                    uncached.push(*text);
469                }
470            }
471        }
472
473        if !uncached.is_empty() {
474            let new_embeddings = self.inner.embed_batch(&uncached)?;
475            let mut cache = self
476                .cache
477                .lock()
478                .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
479            for (idx, embedding) in uncached_idx.into_iter().zip(new_embeddings) {
480                cache.put(texts[idx].to_string(), embedding.clone());
481                results[idx] = Some(embedding);
482            }
483        }
484
485        // Verify all texts got embeddings — flatten() would silently drop Nones
486        let expected = texts.len();
487        let output: Vec<Vec<f32>> = results
488            .into_iter()
489            .enumerate()
490            .map(|(i, opt)| {
491                opt.ok_or_else(|| {
492                    CodememError::Embedding(format!(
493                        "Missing embedding for text at index {i} (batch size {expected})"
494                    ))
495                })
496            })
497            .collect::<Result<Vec<_>, _>>()?;
498        Ok(output)
499    }
500
501    fn name(&self) -> &str {
502        self.inner.name()
503    }
504
505    fn cache_stats(&self) -> (usize, usize) {
506        match self.cache.lock() {
507            Ok(cache) => (cache.len(), cache.cap().into()),
508            Err(_) => (0, 0),
509        }
510    }
511}
512
513// ── Factory ───────────────────────────────────────────────────────────────
514
515/// Create an embedding provider from environment variables.
516///
517/// When `config` is provided, its fields serve as defaults; env vars override them.
518///
519/// | Variable | Values | Default |
520/// |----------|--------|---------|
521/// | `CODEMEM_EMBED_PROVIDER` | `candle`, `ollama`, `openai` | `candle` |
522/// | `CODEMEM_EMBED_MODEL` | model name | provider default |
523/// | `CODEMEM_EMBED_URL` | base URL | provider default |
524/// | `CODEMEM_EMBED_API_KEY` | API key | also reads `OPENAI_API_KEY` |
525/// | `CODEMEM_EMBED_DIMENSIONS` | integer | `768` |
526pub fn from_env(
527    config: Option<&codemem_core::EmbeddingConfig>,
528) -> Result<Box<dyn EmbeddingProvider>, CodememError> {
529    let provider = std::env::var("CODEMEM_EMBED_PROVIDER")
530        .unwrap_or_else(|_| {
531            config
532                .map(|c| c.provider.clone())
533                .unwrap_or_else(|| "candle".to_string())
534        })
535        .to_lowercase();
536    let dimensions: usize = std::env::var("CODEMEM_EMBED_DIMENSIONS")
537        .ok()
538        .and_then(|s| s.parse().ok())
539        .unwrap_or_else(|| config.map_or(DIMENSIONS, |c| c.dimensions));
540    let cache_capacity = config.map_or(CACHE_CAPACITY, |c| c.cache_capacity);
541    let batch_size = config.map_or(DEFAULT_BATCH_SIZE, |c| c.batch_size);
542
543    match provider.as_str() {
544        "ollama" => {
545            let base_url = std::env::var("CODEMEM_EMBED_URL").unwrap_or_else(|_| {
546                config
547                    .filter(|c| !c.url.is_empty())
548                    .map(|c| c.url.clone())
549                    .unwrap_or_else(|| ollama::DEFAULT_BASE_URL.to_string())
550            });
551            let model = std::env::var("CODEMEM_EMBED_MODEL").unwrap_or_else(|_| {
552                config
553                    .filter(|c| !c.model.is_empty())
554                    .map(|c| c.model.clone())
555                    .unwrap_or_else(|| ollama::DEFAULT_MODEL.to_string())
556            });
557            let inner = Box::new(ollama::OllamaProvider::new(&base_url, &model, dimensions));
558            Ok(Box::new(CachedProvider::new(inner, cache_capacity)))
559        }
560        "openai" => {
561            let api_key = std::env::var("CODEMEM_EMBED_API_KEY")
562                .or_else(|_| std::env::var("OPENAI_API_KEY"))
563                .map_err(|_| {
564                    CodememError::Embedding(
565                        "CODEMEM_EMBED_API_KEY or OPENAI_API_KEY required for OpenAI embeddings"
566                            .into(),
567                    )
568                })?;
569            let model = std::env::var("CODEMEM_EMBED_MODEL").unwrap_or_else(|_| {
570                config
571                    .filter(|c| !c.model.is_empty())
572                    .map(|c| c.model.clone())
573                    .unwrap_or_else(|| openai::DEFAULT_MODEL.to_string())
574            });
575            let base_url = std::env::var("CODEMEM_EMBED_URL")
576                .ok()
577                .or_else(|| config.filter(|c| !c.url.is_empty()).map(|c| c.url.clone()));
578            let inner = Box::new(openai::OpenAIProvider::new(
579                &api_key,
580                &model,
581                dimensions,
582                base_url.as_deref(),
583            ));
584            Ok(Box::new(CachedProvider::new(inner, cache_capacity)))
585        }
586        "candle" | "" => {
587            let model_dir = EmbeddingService::default_model_dir();
588            let service = EmbeddingService::new(&model_dir, batch_size)?;
589            Ok(Box::new(CachedProvider::new(
590                Box::new(service),
591                cache_capacity,
592            )))
593        }
594        other => Err(CodememError::Embedding(format!(
595            "Unknown embedding provider: '{}'. Use 'candle', 'ollama', or 'openai'.",
596            other
597        ))),
598    }
599}
600
601#[cfg(test)]
602#[path = "tests/lib_tests.rs"]
603mod tests;