codemem_embeddings/lib.rs

//! codemem-embeddings: Pluggable embedding providers for Codemem.
//!
//! Supports multiple backends:
//! - **Candle** (default): Local BAAI/bge-base-en-v1.5 via pure Rust ML
//! - **Ollama**: Local Ollama server with any embedding model
//! - **OpenAI**: OpenAI API or any compatible endpoint (Together, Azure, etc.)

pub mod ollama;
pub mod openai;

use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config as BertConfig};
use codemem_core::CodememError;
use lru::LruCache;
use std::num::NonZeroUsize;
use std::path::{Path, PathBuf};
use std::sync::Mutex;
use tokenizers::{PaddingParams, PaddingStrategy, TruncationParams};

/// Default model name.
pub const MODEL_NAME: &str = "bge-base-en-v1.5";

/// HuggingFace model repo ID.
const HF_MODEL_REPO: &str = "BAAI/bge-base-en-v1.5";

/// Default embedding dimensions.
pub const DIMENSIONS: usize = 768;

/// Max sequence length for bge-base-en-v1.5.
const MAX_SEQ_LENGTH: usize = 512;

/// Default LRU cache capacity.
pub const CACHE_CAPACITY: usize = 10_000;

// ── Embedding Provider Trait ────────────────────────────────────────────────

/// Trait for pluggable embedding providers.
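///
/// # Example
///
/// A minimal sketch of a custom provider; `ConstantProvider` is illustrative,
/// not part of this crate. The `embed_batch` and `cache_stats` defaults are
/// inherited:
///
/// ```
/// use codemem_core::CodememError;
/// use codemem_embeddings::EmbeddingProvider;
///
/// struct ConstantProvider {
///     dims: usize,
/// }
///
/// impl EmbeddingProvider for ConstantProvider {
///     fn dimensions(&self) -> usize {
///         self.dims
///     }
///
///     fn embed(&self, _text: &str) -> Result<Vec<f32>, CodememError> {
///         // Every text maps to the same vector; handy as a test stub.
///         Ok(vec![0.0; self.dims])
///     }
///
///     fn name(&self) -> &str {
///         "constant"
///     }
/// }
/// ```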
pub trait EmbeddingProvider: Send + Sync {
    /// Embedding vector dimensions.
    fn dimensions(&self) -> usize;

    /// Embed a single text string.
    fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError>;

    /// Embed a batch of texts (default: sequential).
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
        texts.iter().map(|t| self.embed(t)).collect()
    }

    /// Provider name for display.
    fn name(&self) -> &str;

    /// Cache statistics: (current_size, capacity). Returns (0, 0) if no cache.
    fn cache_stats(&self) -> (usize, usize) {
        (0, 0)
    }
}

// ── Candle Embedding Service ────────────────────────────────────────────────

/// Maximum batch size for batched embedding forward passes.
const BATCH_SIZE: usize = 32;

/// Select the best available compute device.
///
/// Tries Metal (macOS GPU) first, then CUDA (NVIDIA GPU), then falls back to CPU.
/// GPU backends are only available when the corresponding feature flag is enabled.
fn select_device() -> Device {
    #[cfg(feature = "metal")]
    {
        if let Ok(device) = Device::new_metal(0) {
            tracing::info!("Using Metal GPU for embeddings");
            return device;
        }
        tracing::warn!("Metal feature enabled but device creation failed, falling back");
    }
    #[cfg(feature = "cuda")]
    {
        if let Ok(device) = Device::new_cuda(0) {
            tracing::info!("Using CUDA GPU for embeddings");
            return device;
        }
        tracing::warn!("CUDA feature enabled but device creation failed, falling back");
    }
    tracing::info!("Using CPU for embeddings");
    Device::Cpu
}

/// Embedding service with Candle inference and LRU caching.
pub struct EmbeddingService {
    model: Mutex<BertModel>,
    tokenizer: tokenizers::Tokenizer,
    device: Device,
    cache: Mutex<LruCache<String, Vec<f32>>>,
}

impl EmbeddingService {
    /// Create a new embedding service, loading the model from the given directory.
    /// Expects `model.safetensors`, `config.json`, and `tokenizer.json` in the directory.
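    ///
    /// # Example
    ///
    /// A usage sketch, assuming the model files were already fetched (e.g. via
    /// `codemem init` or [`EmbeddingService::download_model`]):
    ///
    /// ```no_run
    /// use codemem_embeddings::{EmbeddingService, DIMENSIONS};
    ///
    /// let model_dir = EmbeddingService::default_model_dir();
    /// let service = EmbeddingService::new(&model_dir)?;
    /// let vector = service.embed("fn main() {}")?;
    /// assert_eq!(vector.len(), DIMENSIONS);
    /// # Ok::<(), codemem_core::CodememError>(())
    /// ```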
    pub fn new(model_dir: &Path) -> Result<Self, CodememError> {
        let model_path = model_dir.join("model.safetensors");
        let config_path = model_dir.join("config.json");
        let tokenizer_path = model_dir.join("tokenizer.json");

        if !model_path.exists() {
            return Err(CodememError::Embedding(format!(
                "Model not found at {}. Run `codemem init` to download it.",
                model_path.display()
            )));
        }

        let device = select_device();

        // Load BERT config
        let config_str = std::fs::read_to_string(&config_path)
            .map_err(|e| CodememError::Embedding(format!("Failed to read config: {e}")))?;
        let config: BertConfig = serde_json::from_str(&config_str)
            .map_err(|e| CodememError::Embedding(format!("Failed to parse config: {e}")))?;

        // Load model weights from safetensors via memory-mapped IO
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[&model_path], DType::F32, &device)
                .map_err(|e| CodememError::Embedding(format!("Failed to load weights: {e}")))?
        };

        // Try with "bert." prefix first (standard HF BERT models), then without
        let model = BertModel::load(vb.pp("bert"), &config)
            .or_else(|_| {
                let vb2 = unsafe {
                    VarBuilder::from_mmaped_safetensors(&[&model_path], DType::F32, &device)
                        .map_err(|e| {
                            candle_core::Error::Msg(format!("Failed to load weights: {e}"))
                        })
                }?;
                BertModel::load(vb2, &config)
            })
            .map_err(|e| CodememError::Embedding(format!("Failed to load BERT model: {e}")))?;

        let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| CodememError::Embedding(e.to_string()))?;

        let cache = Mutex::new(LruCache::new(NonZeroUsize::new(CACHE_CAPACITY).unwrap()));

        Ok(Self {
            model: Mutex::new(model),
            tokenizer,
            device,
            cache,
        })
    }

    /// Get the default model directory path (~/.codemem/models/{MODEL_NAME}).
    pub fn default_model_dir() -> PathBuf {
        dirs::home_dir()
            .unwrap_or_else(|| PathBuf::from("."))
            .join(".codemem")
            .join("models")
            .join(MODEL_NAME)
    }

    /// Download the model from HuggingFace Hub to the given directory.
    /// Returns the directory path; does nothing if the model files already exist.
    pub fn download_model(dest_dir: &Path) -> Result<PathBuf, CodememError> {
        let model_dest = dest_dir.join("model.safetensors");
        let config_dest = dest_dir.join("config.json");
        let tokenizer_dest = dest_dir.join("tokenizer.json");

        if model_dest.exists() && config_dest.exists() && tokenizer_dest.exists() {
            tracing::info!("Model already downloaded at {}", dest_dir.display());
            return Ok(dest_dir.to_path_buf());
        }

        std::fs::create_dir_all(dest_dir)
            .map_err(|e| CodememError::Embedding(format!("Failed to create dir: {e}")))?;

        tracing::info!("Downloading {} from HuggingFace...", HF_MODEL_REPO);

        let api = hf_hub::api::sync::Api::new()
            .map_err(|e| CodememError::Embedding(format!("HuggingFace API error: {e}")))?;
        let repo = api.model(HF_MODEL_REPO.to_string());

        let cached_model = repo
            .get("model.safetensors")
            .map_err(|e| CodememError::Embedding(format!("Failed to download model: {e}")))?;

        let cached_config = repo
            .get("config.json")
            .map_err(|e| CodememError::Embedding(format!("Failed to download config: {e}")))?;

        let cached_tokenizer = repo
            .get("tokenizer.json")
            .map_err(|e| CodememError::Embedding(format!("Failed to download tokenizer: {e}")))?;

        std::fs::copy(&cached_model, &model_dest)
            .map_err(|e| CodememError::Embedding(format!("Failed to copy model: {e}")))?;
        std::fs::copy(&cached_config, &config_dest)
            .map_err(|e| CodememError::Embedding(format!("Failed to copy config: {e}")))?;
        std::fs::copy(&cached_tokenizer, &tokenizer_dest)
            .map_err(|e| CodememError::Embedding(format!("Failed to copy tokenizer: {e}")))?;

        tracing::info!("Model downloaded to {}", dest_dir.display());
        Ok(dest_dir.to_path_buf())
    }

    /// Embed a single text string. Returns a 768-dim L2-normalized vector.
    /// Uses an LRU cache for repeated queries.
    pub fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
        // Check cache
        {
            let mut cache = self.cache.lock().unwrap();
            if let Some(cached) = cache.get(text) {
                return Ok(cached.clone());
            }
        }

        let embedding = self.embed_uncached(text)?;

        // Store in cache
        {
            let mut cache = self.cache.lock().unwrap();
            cache.put(text.to_string(), embedding.clone());
        }

        Ok(embedding)
    }

    /// Embed without caching. Uses mean pooling with attention mask.
    fn embed_uncached(&self, text: &str) -> Result<Vec<f32>, CodememError> {
        // Tokenize with truncation
        let mut tokenizer = self.tokenizer.clone();

        tokenizer
            .with_truncation(Some(TruncationParams {
                max_length: MAX_SEQ_LENGTH,
                ..Default::default()
            }))
            .map_err(|e| CodememError::Embedding(format!("Truncation error: {e}")))?;

        let encoding = tokenizer
            .encode(text, true)
            .map_err(|e| CodememError::Embedding(e.to_string()))?;

        let input_ids: Vec<u32> = encoding.get_ids().to_vec();
        let attention_mask: Vec<u32> = encoding.get_attention_mask().to_vec();

        // Build candle tensors with shape [1, seq_len]
        let input_ids_tensor = Tensor::new(&input_ids[..], &self.device)
            .and_then(|t| t.unsqueeze(0))
            .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;

        let token_type_ids = input_ids_tensor
            .zeros_like()
            .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;

        let attention_mask_tensor = Tensor::new(&attention_mask[..], &self.device)
            .and_then(|t| t.unsqueeze(0))
            .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;

        // Forward pass -> [1, seq_len, hidden_size]
        let model = self.model.lock().unwrap();
        let hidden_states = model
            .forward(
                &input_ids_tensor,
                &token_type_ids,
                Some(&attention_mask_tensor),
            )
            .map_err(|e| CodememError::Embedding(format!("Model forward error: {e}")))?;
        drop(model);

        // Mean pooling weighted by attention mask
        // attention_mask: [1, seq_len] -> [1, seq_len, 1] for broadcasting
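        // In index form: pooled[b, h] = sum_t(hidden[b, t, h] * mask[b, t]) / sum_t(mask[b, t]),
        // so masked-out positions (padding, in the batched path) contribute nothing.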
        let mask = attention_mask_tensor
            .to_dtype(DType::F32)
            .and_then(|t| t.unsqueeze(2))
            .map_err(|e| CodememError::Embedding(format!("Mask error: {e}")))?;

        let sum_mask = mask
            .sum(1)
            .map_err(|e| CodememError::Embedding(format!("Sum error: {e}")))?;

        let pooled = hidden_states
            .broadcast_mul(&mask)
            .and_then(|t| t.sum(1))
            .and_then(|t| t.broadcast_div(&sum_mask))
            .map_err(|e| CodememError::Embedding(format!("Pooling error: {e}")))?;

        // L2 normalize
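        // After normalization ||v||_2 == 1, so cosine similarity between two
        // embeddings reduces to a plain dot product.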
        let normalized = pooled
            .sqr()
            .and_then(|t| t.sum_keepdim(1))
            .and_then(|t| t.sqrt())
            .and_then(|norm| pooled.broadcast_div(&norm))
            .map_err(|e| CodememError::Embedding(format!("Normalize error: {e}")))?;

        // Extract as Vec<f32>: shape is [1, hidden_size], squeeze to [hidden_size]
        let embedding: Vec<f32> = normalized
            .squeeze(0)
            .and_then(|t| t.to_vec1())
            .map_err(|e| CodememError::Embedding(format!("Extract error: {e}")))?;

        Ok(embedding)
    }

    /// Embed a batch of texts with cache-aware batching.
    ///
    /// Checks the LRU cache first and only runs the model on uncached texts,
    /// using a true batched forward pass (one model invocation per chunk rather
    /// than one per text).
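    ///
    /// A usage sketch:
    ///
    /// ```no_run
    /// # use codemem_embeddings::EmbeddingService;
    /// # let service = EmbeddingService::new(&EmbeddingService::default_model_dir())?;
    /// let vectors = service.embed_batch(&["fn main() {}", "struct Point;"])?;
    /// assert_eq!(vectors.len(), 2);
    /// # Ok::<(), codemem_core::CodememError>(())
    /// ```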
    pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
        if texts.is_empty() {
            return Ok(vec![]);
        }

        // Partition into cached and uncached
        let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
        let mut uncached_indices = Vec::new();
        let mut uncached_texts = Vec::new();

        {
            let mut cache = self.cache.lock().unwrap();
            for (i, text) in texts.iter().enumerate() {
                if let Some(cached) = cache.get(*text) {
                    results[i] = Some(cached.clone());
                } else {
                    uncached_indices.push(i);
                    uncached_texts.push(*text);
                }
            }
        }

        if !uncached_texts.is_empty() {
            let new_embeddings = self.embed_batch_uncached(&uncached_texts)?;

            let mut cache = self.cache.lock().unwrap();
            for (idx, embedding) in uncached_indices.into_iter().zip(new_embeddings) {
                cache.put(texts[idx].to_string(), embedding.clone());
                results[idx] = Some(embedding);
            }
        }

        Ok(results.into_iter().map(|r| r.unwrap()).collect())
    }

    /// Embed a batch of texts without caching, using a true batched forward pass.
    ///
    /// Tokenizes all texts, pads to the longest sequence in each chunk, runs a
    /// single forward pass per chunk of up to `BATCH_SIZE` texts, then performs
    /// mean pooling and L2 normalization on the batched output.
    fn embed_batch_uncached(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
        if texts.is_empty() {
            return Ok(vec![]);
        }

        let mut all_embeddings = Vec::with_capacity(texts.len());

        for chunk in texts.chunks(BATCH_SIZE) {
            let mut tokenizer = self.tokenizer.clone();

            tokenizer
                .with_truncation(Some(TruncationParams {
                    max_length: MAX_SEQ_LENGTH,
                    ..Default::default()
                }))
                .map_err(|e| CodememError::Embedding(format!("Truncation error: {e}")))?;

            // Pad all sequences in this chunk to the length of the longest
            tokenizer.with_padding(Some(PaddingParams {
                strategy: PaddingStrategy::BatchLongest,
                ..Default::default()
            }));

            let encodings = tokenizer
                .encode_batch(chunk.to_vec(), true)
                .map_err(|e| CodememError::Embedding(format!("Batch encode error: {e}")))?;

            let batch_len = encodings.len();
            let seq_len = encodings[0].get_ids().len();

            // Flatten token IDs and attention masks into contiguous arrays
            let all_ids: Vec<u32> = encodings
                .iter()
                .flat_map(|e| e.get_ids())
                .copied()
                .collect();
            let all_masks: Vec<u32> = encodings
                .iter()
                .flat_map(|e| e.get_attention_mask())
                .copied()
                .collect();

            // Build tensors with shape [batch_size, seq_len]
            let input_ids = Tensor::new(all_ids.as_slice(), &self.device)
                .and_then(|t| t.reshape((batch_len, seq_len)))
                .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;

            let token_type_ids = input_ids
                .zeros_like()
                .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;

            let attention_mask = Tensor::new(all_masks.as_slice(), &self.device)
                .and_then(|t| t.reshape((batch_len, seq_len)))
                .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;

            // Single forward pass -> [batch_size, seq_len, hidden_size]
            let model = self.model.lock().unwrap();
            let hidden_states = model
                .forward(&input_ids, &token_type_ids, Some(&attention_mask))
                .map_err(|e| CodememError::Embedding(format!("Forward error: {e}")))?;
            drop(model);

            // Mean pooling: mask [batch, seq] -> [batch, seq, 1] for broadcast
            let mask = attention_mask
                .to_dtype(DType::F32)
                .and_then(|t| t.unsqueeze(2))
                .map_err(|e| CodememError::Embedding(format!("Mask error: {e}")))?;

            let sum_mask = mask
                .sum(1)
                .map_err(|e| CodememError::Embedding(format!("Sum error: {e}")))?;

            let pooled = hidden_states
                .broadcast_mul(&mask)
                .and_then(|t| t.sum(1))
                .and_then(|t| t.broadcast_div(&sum_mask))
                .map_err(|e| CodememError::Embedding(format!("Pooling error: {e}")))?;

            // L2 normalize: [batch, hidden]
            let norm = pooled
                .sqr()
                .and_then(|t| t.sum_keepdim(1))
                .and_then(|t| t.sqrt())
                .map_err(|e| CodememError::Embedding(format!("Norm error: {e}")))?;

            let normalized = pooled
                .broadcast_div(&norm)
                .map_err(|e| CodememError::Embedding(format!("Normalize error: {e}")))?;

            // Extract each row as Vec<f32>
            for i in 0..batch_len {
                let row: Vec<f32> = normalized
                    .get(i)
                    .and_then(|t| t.to_vec1())
                    .map_err(|e| CodememError::Embedding(format!("Extract error: {e}")))?;
                all_embeddings.push(row);
            }
        }

        Ok(all_embeddings)
    }

    /// Get cache statistics: (current_size, capacity).
    pub fn cache_stats(&self) -> (usize, usize) {
        let cache = self.cache.lock().unwrap();
        (cache.len(), CACHE_CAPACITY)
    }
}

impl EmbeddingProvider for EmbeddingService {
    fn dimensions(&self) -> usize {
        DIMENSIONS
    }

    fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
        self.embed(text)
    }

    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
        self.embed_batch(texts)
    }

    fn name(&self) -> &str {
        "candle"
    }

    fn cache_stats(&self) -> (usize, usize) {
        self.cache_stats()
    }
}

// ── Cached Provider Wrapper ───────────────────────────────────────────────

/// Wraps any `EmbeddingProvider` with an LRU cache.
pub struct CachedProvider {
    inner: Box<dyn EmbeddingProvider>,
    cache: Mutex<LruCache<String, Vec<f32>>>,
}

impl CachedProvider {
    /// Wrap `inner` with an LRU cache holding up to `capacity` entries
    /// (a capacity of 0 is clamped to 1).
    pub fn new(inner: Box<dyn EmbeddingProvider>, capacity: usize) -> Self {
        Self {
            inner,
            cache: Mutex::new(LruCache::new(
                NonZeroUsize::new(capacity).unwrap_or(NonZeroUsize::new(1).unwrap()),
            )),
        }
    }
}

impl EmbeddingProvider for CachedProvider {
    fn dimensions(&self) -> usize {
        self.inner.dimensions()
    }

    fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
        {
            let mut cache = self.cache.lock().unwrap();
            if let Some(cached) = cache.get(text) {
                return Ok(cached.clone());
            }
        }
        let embedding = self.inner.embed(text)?;
        {
            let mut cache = self.cache.lock().unwrap();
            cache.put(text.to_string(), embedding.clone());
        }
        Ok(embedding)
    }

    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
        // Check cache, only forward uncached texts
        let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
        let mut uncached = Vec::new();
        let mut uncached_idx = Vec::new();

        {
            let mut cache = self.cache.lock().unwrap();
            for (i, text) in texts.iter().enumerate() {
                if let Some(cached) = cache.get(*text) {
                    results[i] = Some(cached.clone());
                } else {
                    uncached_idx.push(i);
                    uncached.push(*text);
                }
            }
        }

        if !uncached.is_empty() {
            let new_embeddings = self.inner.embed_batch(&uncached)?;
            let mut cache = self.cache.lock().unwrap();
            for (idx, embedding) in uncached_idx.into_iter().zip(new_embeddings) {
                cache.put(texts[idx].to_string(), embedding.clone());
                results[idx] = Some(embedding);
            }
        }

        Ok(results.into_iter().map(|r| r.unwrap()).collect())
    }

    fn name(&self) -> &str {
        self.inner.name()
    }

    fn cache_stats(&self) -> (usize, usize) {
        let cache = self.cache.lock().unwrap();
        (cache.len(), cache.cap().into())
    }
}

// ── Factory ───────────────────────────────────────────────────────────────

/// Create an embedding provider from environment variables.
///
/// | Variable | Values | Default |
/// |----------|--------|---------|
/// | `CODEMEM_EMBED_PROVIDER` | `candle`, `ollama`, `openai` | `candle` |
/// | `CODEMEM_EMBED_MODEL` | model name | provider default |
/// | `CODEMEM_EMBED_URL` | base URL | provider default |
/// | `CODEMEM_EMBED_API_KEY` | API key | also reads `OPENAI_API_KEY` |
/// | `CODEMEM_EMBED_DIMENSIONS` | integer | `768` |
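///
/// # Example
///
/// A usage sketch; the provider and model named in the comment are illustrative:
///
/// ```no_run
/// use codemem_embeddings::EmbeddingProvider;
///
/// // With e.g. CODEMEM_EMBED_PROVIDER=ollama and CODEMEM_EMBED_MODEL=nomic-embed-text
/// // exported in the environment; with nothing set, this falls back to Candle.
/// let provider = codemem_embeddings::from_env()?;
/// println!("{} ({} dims)", provider.name(), provider.dimensions());
/// # Ok::<(), codemem_core::CodememError>(())
/// ```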
pub fn from_env() -> Result<Box<dyn EmbeddingProvider>, CodememError> {
    let provider = std::env::var("CODEMEM_EMBED_PROVIDER")
        .unwrap_or_else(|_| "candle".to_string())
        .to_lowercase();
    let dimensions: usize = std::env::var("CODEMEM_EMBED_DIMENSIONS")
        .ok()
        .and_then(|s| s.parse().ok())
        .unwrap_or(DIMENSIONS);

    match provider.as_str() {
        "ollama" => {
            let base_url = std::env::var("CODEMEM_EMBED_URL")
                .unwrap_or_else(|_| ollama::DEFAULT_BASE_URL.to_string());
            let model = std::env::var("CODEMEM_EMBED_MODEL")
                .unwrap_or_else(|_| ollama::DEFAULT_MODEL.to_string());
            let inner = Box::new(ollama::OllamaProvider::new(&base_url, &model, dimensions));
            Ok(Box::new(CachedProvider::new(inner, CACHE_CAPACITY)))
        }
        "openai" => {
            let api_key = std::env::var("CODEMEM_EMBED_API_KEY")
                .or_else(|_| std::env::var("OPENAI_API_KEY"))
                .map_err(|_| {
                    CodememError::Embedding(
                        "CODEMEM_EMBED_API_KEY or OPENAI_API_KEY required for OpenAI embeddings"
                            .into(),
                    )
                })?;
            let model = std::env::var("CODEMEM_EMBED_MODEL")
                .unwrap_or_else(|_| openai::DEFAULT_MODEL.to_string());
            let base_url = std::env::var("CODEMEM_EMBED_URL").ok();
            let inner = Box::new(openai::OpenAIProvider::new(
                &api_key,
                &model,
                dimensions,
                base_url.as_deref(),
            ));
            Ok(Box::new(CachedProvider::new(inner, CACHE_CAPACITY)))
        }
        "candle" | "" => {
            let model_dir = EmbeddingService::default_model_dir();
            let service = EmbeddingService::new(&model_dir)?;
            Ok(Box::new(service))
        }
        other => Err(CodememError::Embedding(format!(
            "Unknown embedding provider: '{other}'. Use 'candle', 'ollama', or 'openai'."
        ))),
    }
}
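
// A minimal test sketch: `StubProvider` is illustrative and exercises the
// `CachedProvider` wrapper without loading any model.
#[cfg(test)]
mod tests {
    use super::*;

    struct StubProvider;

    impl EmbeddingProvider for StubProvider {
        fn dimensions(&self) -> usize {
            3
        }

        fn embed(&self, _text: &str) -> Result<Vec<f32>, CodememError> {
            Ok(vec![1.0, 0.0, 0.0])
        }

        fn name(&self) -> &str {
            "stub"
        }
    }

    #[test]
    fn cached_provider_reuses_cached_embeddings() {
        let provider = CachedProvider::new(Box::new(StubProvider), 16);
        assert_eq!(provider.embed("hello").unwrap(), vec![1.0, 0.0, 0.0]);
        // The second lookup is served from the LRU cache, not the inner provider.
        assert_eq!(provider.embed("hello").unwrap(), vec![1.0, 0.0, 0.0]);
        assert_eq!(provider.cache_stats(), (1, 16));
    }
}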