1pub mod ollama;
9pub mod openai;
10
11use candle_core::{DType, Device, Tensor};
12use candle_nn::VarBuilder;
13use candle_transformers::models::bert::{BertModel, Config as BertConfig};
14use codemem_core::CodememError;
15use lru::LruCache;
16use std::num::NonZeroUsize;
17use std::path::{Path, PathBuf};
18use std::sync::Mutex;
19use tokenizers::{PaddingParams, PaddingStrategy};
20
/// Human-readable model identifier; also the on-disk directory name used by
/// `EmbeddingService::default_model_dir`.
pub const MODEL_NAME: &str = "bge-base-en-v1.5";

/// HuggingFace repository the weights, config, and tokenizer are fetched from.
const HF_MODEL_REPO: &str = "BAAI/bge-base-en-v1.5";

/// Dimensionality of the embedding vectors produced by this model.
pub const DIMENSIONS: usize = 768;

/// Tokenizer truncation limit, in tokens per input.
const MAX_SEQ_LENGTH: usize = 512;

/// Maximum number of entries held in each text -> embedding LRU cache.
pub const CACHE_CAPACITY: usize = 10_000;

/// Re-exported so downstream crates can name the provider trait from this crate.
pub use codemem_core::EmbeddingProvider;

/// Number of texts run through the model per forward pass in batch embedding.
const BATCH_SIZE: usize = 32;
44
45fn select_device() -> Device {
50 #[cfg(feature = "metal")]
51 {
52 if let Ok(device) = Device::new_metal(0) {
53 tracing::info!("Using Metal GPU for embeddings");
54 return device;
55 }
56 tracing::warn!("Metal feature enabled but device creation failed, falling back");
57 }
58 #[cfg(feature = "cuda")]
59 {
60 if let Ok(device) = Device::new_cuda(0) {
61 tracing::info!("Using CUDA GPU for embeddings");
62 return device;
63 }
64 tracing::warn!("CUDA feature enabled but device creation failed, falling back");
65 }
66 tracing::info!("Using CPU for embeddings");
67 Device::Cpu
68}
69
/// Embedding service backed by a locally stored BERT model
/// (bge-base-en-v1.5) executed with candle.
pub struct EmbeddingService {
    /// The BERT model, guarded by a mutex so `&self` methods can be called
    /// from multiple threads; the lock is held only for a forward pass.
    model: Mutex<BertModel>,
    /// Tokenizer configured at construction with truncation at `MAX_SEQ_LENGTH`.
    tokenizer: tokenizers::Tokenizer,
    /// Device chosen by `select_device` (Metal/CUDA when available, else CPU).
    device: Device,
    /// LRU cache of text -> embedding, capped at `CACHE_CAPACITY` entries.
    cache: Mutex<LruCache<String, Vec<f32>>>,
}
79
80impl EmbeddingService {
    /// Loads the model, config, and tokenizer from `model_dir` and builds a
    /// ready-to-use service with an empty embedding cache.
    ///
    /// # Errors
    /// Returns `CodememError::Embedding` when any of the three files is
    /// missing or unreadable, when the config fails to parse, or when the
    /// BERT weights cannot be loaded.
    pub fn new(model_dir: &Path) -> Result<Self, CodememError> {
        let model_path = model_dir.join("model.safetensors");
        let config_path = model_dir.join("config.json");
        let tokenizer_path = model_dir.join("tokenizer.json");

        if !model_path.exists() {
            return Err(CodememError::Embedding(format!(
                "Model not found at {}. Run `codemem init` to download it.",
                model_path.display()
            )));
        }

        let device = select_device();

        let config_str = std::fs::read_to_string(&config_path)
            .map_err(|e| CodememError::Embedding(format!("Failed to read config: {e}")))?;
        let config: BertConfig = serde_json::from_str(&config_str)
            .map_err(|e| CodememError::Embedding(format!("Failed to parse config: {e}")))?;

        // First attempt: load with tensor names under a "bert." prefix.
        let model = {
            // SAFETY: the safetensors file is memory-mapped; this is sound as
            // long as the file is not modified or truncated while mapped.
            let vb = unsafe {
                VarBuilder::from_mmaped_safetensors(&[&model_path], DType::F32, &device)
                    .map_err(|e| CodememError::Embedding(format!("Failed to load weights: {e}")))?
            };
            BertModel::load(vb.pp("bert"), &config)
        };
        // Fallback: retry without the "bert" prefix — presumably handles
        // checkpoints whose tensor names are not namespaced (TODO confirm
        // which upstream exports need which layout). The first error is
        // intentionally discarded; only the second attempt's error surfaces.
        let model = match model {
            Ok(m) => m,
            Err(_) => {
                // SAFETY: same mmap soundness condition as above.
                let vb2 = unsafe {
                    VarBuilder::from_mmaped_safetensors(&[&model_path], DType::F32, &device)
                        .map_err(|e| {
                            CodememError::Embedding(format!("Failed to load weights: {e}"))
                        })?
                };
                BertModel::load(vb2, &config).map_err(|e| {
                    CodememError::Embedding(format!("Failed to load BERT model: {e}"))
                })?
            }
        };

        let mut tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| CodememError::Embedding(e.to_string()))?;

        // Cap inputs at MAX_SEQ_LENGTH tokens; longer texts are truncated.
        tokenizer
            .with_truncation(Some(tokenizers::TruncationParams {
                max_length: MAX_SEQ_LENGTH,
                ..Default::default()
            }))
            .map_err(|e| CodememError::Embedding(format!("Truncation error: {e}")))?;

        let cache = Mutex::new(LruCache::new(
            NonZeroUsize::new(CACHE_CAPACITY).expect("CACHE_CAPACITY is non-zero"),
        ));

        Ok(Self {
            model: Mutex::new(model),
            tokenizer,
            device,
            cache,
        })
    }
151
152 pub fn default_model_dir() -> PathBuf {
154 dirs::home_dir()
155 .unwrap_or_else(|| PathBuf::from("."))
156 .join(".codemem")
157 .join("models")
158 .join(MODEL_NAME)
159 }
160
    /// Downloads the model, config, and tokenizer from `HF_MODEL_REPO` into
    /// `dest_dir`, returning `dest_dir` on success.
    ///
    /// Skips the download entirely when all three files are already present.
    /// Files are first fetched into the hf-hub cache and then copied into
    /// `dest_dir`. NOTE(review): the copies are not atomic — a crash midway
    /// can leave a partial set, but the existence check above requires all
    /// three files, so a rerun will re-copy.
    ///
    /// # Errors
    /// Returns `CodememError::Embedding` on directory-creation, network,
    /// or copy failures.
    pub fn download_model(dest_dir: &Path) -> Result<PathBuf, CodememError> {
        let model_dest = dest_dir.join("model.safetensors");
        let config_dest = dest_dir.join("config.json");
        let tokenizer_dest = dest_dir.join("tokenizer.json");

        // Idempotent: only short-circuit when every required file exists.
        if model_dest.exists() && config_dest.exists() && tokenizer_dest.exists() {
            tracing::info!("Model already downloaded at {}", dest_dir.display());
            return Ok(dest_dir.to_path_buf());
        }

        std::fs::create_dir_all(dest_dir)
            .map_err(|e| CodememError::Embedding(format!("Failed to create dir: {e}")))?;

        tracing::info!("Downloading {} from HuggingFace...", HF_MODEL_REPO);

        let api = hf_hub::api::sync::Api::new()
            .map_err(|e| CodememError::Embedding(format!("HuggingFace API error: {e}")))?;
        let repo = api.model(HF_MODEL_REPO.to_string());

        // `get` returns the path of the file in the local hf-hub cache,
        // downloading it first if necessary.
        let cached_model = repo
            .get("model.safetensors")
            .map_err(|e| CodememError::Embedding(format!("Failed to download model: {e}")))?;

        let cached_config = repo
            .get("config.json")
            .map_err(|e| CodememError::Embedding(format!("Failed to download config: {e}")))?;

        let cached_tokenizer = repo
            .get("tokenizer.json")
            .map_err(|e| CodememError::Embedding(format!("Failed to download tokenizer: {e}")))?;

        // Copy out of the hf-hub cache so the service has a stable,
        // self-contained model directory.
        std::fs::copy(&cached_model, &model_dest)
            .map_err(|e| CodememError::Embedding(format!("Failed to copy model: {e}")))?;
        std::fs::copy(&cached_config, &config_dest)
            .map_err(|e| CodememError::Embedding(format!("Failed to copy config: {e}")))?;
        std::fs::copy(&cached_tokenizer, &tokenizer_dest)
            .map_err(|e| CodememError::Embedding(format!("Failed to copy tokenizer: {e}")))?;

        tracing::info!("Model downloaded to {}", dest_dir.display());
        Ok(dest_dir.to_path_buf())
    }
204
205 pub fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
208 {
210 let mut cache = self
211 .cache
212 .lock()
213 .map_err(|e| CodememError::LockPoisoned(format!("embedding cache: {e}")))?;
214 if let Some(cached) = cache.get(text) {
215 return Ok(cached.clone());
216 }
217 }
218
219 let embedding = self.embed_uncached(text)?;
220
221 {
223 let mut cache = self
224 .cache
225 .lock()
226 .map_err(|e| CodememError::LockPoisoned(format!("embedding cache: {e}")))?;
227 cache.put(text.to_string(), embedding.clone());
228 }
229
230 Ok(embedding)
231 }
232
    /// Runs the full embedding pipeline for a single text, bypassing the cache:
    /// tokenize -> BERT forward -> attention-masked mean pooling -> L2 norm.
    fn embed_uncached(&self, text: &str) -> Result<Vec<f32>, CodememError> {
        // Tokenize with special tokens ([CLS]/[SEP]); truncation was
        // configured on the tokenizer at construction time.
        let encoding = self
            .tokenizer
            .encode(text, true)
            .map_err(|e| CodememError::Embedding(e.to_string()))?;

        let input_ids: Vec<u32> = encoding.get_ids().to_vec();
        let attention_mask: Vec<u32> = encoding.get_attention_mask().to_vec();

        // unsqueeze(0) adds the batch dimension: (seq) -> (1, seq).
        let input_ids_tensor = Tensor::new(&input_ids[..], &self.device)
            .and_then(|t| t.unsqueeze(0))
            .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;

        // Single-segment input: all token type ids are zero.
        let token_type_ids = input_ids_tensor
            .zeros_like()
            .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;

        let attention_mask_tensor = Tensor::new(&attention_mask[..], &self.device)
            .and_then(|t| t.unsqueeze(0))
            .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;

        // Hold the model lock only for the forward pass.
        let model = self
            .model
            .lock()
            .map_err(|e| CodememError::LockPoisoned(format!("embedding model: {e}")))?;
        let hidden_states = model
            .forward(
                &input_ids_tensor,
                &token_type_ids,
                Some(&attention_mask_tensor),
            )
            .map_err(|e| CodememError::Embedding(format!("Model forward error: {e}")))?;
        drop(model);

        // Explicitly release the input tensors before the pooling
        // allocations below — presumably to reduce peak device memory.
        drop(input_ids_tensor);
        drop(token_type_ids);

        // Mask as f32 with a trailing axis so it broadcasts over the hidden
        // dimension: (1, seq) -> (1, seq, 1).
        let mask = attention_mask_tensor
            .to_dtype(DType::F32)
            .and_then(|t| t.unsqueeze(2))
            .map_err(|e| CodememError::Embedding(format!("Mask error: {e}")))?;

        // Number of real (non-padding) tokens, for the mean denominator.
        let sum_mask = mask
            .sum(1)
            .map_err(|e| CodememError::Embedding(format!("Sum error: {e}")))?;

        // Mean pooling: zero out padding positions, sum over the sequence
        // axis, divide by the token count.
        let pooled = hidden_states
            .broadcast_mul(&mask)
            .and_then(|t| t.sum(1))
            .and_then(|t| t.broadcast_div(&sum_mask))
            .map_err(|e| CodememError::Embedding(format!("Pooling error: {e}")))?;

        // L2 normalization: divide by sqrt(sum of squares).
        let normalized = pooled
            .sqr()
            .and_then(|t| t.sum_keepdim(1))
            .and_then(|t| t.sqrt())
            .and_then(|norm| pooled.broadcast_div(&norm))
            .map_err(|e| CodememError::Embedding(format!("Normalize error: {e}")))?;

        // Drop the batch dimension and copy back to host memory.
        let embedding: Vec<f32> = normalized
            .squeeze(0)
            .and_then(|t| t.to_vec1())
            .map_err(|e| CodememError::Embedding(format!("Extract error: {e}")))?;

        Ok(embedding)
    }
308
309 pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
314 if texts.is_empty() {
315 return Ok(vec![]);
316 }
317
318 let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
320 let mut uncached_indices = Vec::new();
321 let mut uncached_texts = Vec::new();
322
323 {
324 let mut cache = self
325 .cache
326 .lock()
327 .map_err(|e| CodememError::LockPoisoned(format!("embedding cache: {e}")))?;
328 for (i, text) in texts.iter().enumerate() {
329 if let Some(cached) = cache.get(*text) {
330 results[i] = Some(cached.clone());
331 } else {
332 uncached_indices.push(i);
333 uncached_texts.push(*text);
334 }
335 }
336 }
337
338 if !uncached_texts.is_empty() {
339 let new_embeddings = self.embed_batch_uncached(&uncached_texts)?;
340
341 let mut cache = self
342 .cache
343 .lock()
344 .map_err(|e| CodememError::LockPoisoned(format!("embedding cache: {e}")))?;
345 for (idx, embedding) in uncached_indices.into_iter().zip(new_embeddings) {
346 cache.put(texts[idx].to_string(), embedding.clone());
347 results[idx] = Some(embedding);
348 }
349 }
350
351 let expected = texts.len();
353 let output: Vec<Vec<f32>> = results
354 .into_iter()
355 .enumerate()
356 .map(|(i, opt)| {
357 opt.ok_or_else(|| {
358 CodememError::Embedding(format!(
359 "Missing embedding for text at index {i} (batch size {expected})"
360 ))
361 })
362 })
363 .collect::<Result<Vec<_>, _>>()?;
364 Ok(output)
365 }
366
367 fn embed_batch_uncached(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
373 if texts.is_empty() {
374 return Ok(vec![]);
375 }
376
377 let mut all_embeddings = Vec::with_capacity(texts.len());
378
379 for chunk in texts.chunks(BATCH_SIZE) {
380 let mut tokenizer = self.tokenizer.clone();
383 tokenizer.with_padding(Some(PaddingParams {
384 strategy: PaddingStrategy::BatchLongest,
385 ..Default::default()
386 }));
387
388 let encodings = tokenizer
389 .encode_batch(chunk.to_vec(), true)
390 .map_err(|e| CodememError::Embedding(format!("Batch encode error: {e}")))?;
391
392 let batch_len = encodings.len();
393 let seq_len = encodings[0].get_ids().len();
394
395 let all_ids: Vec<u32> = encodings
397 .iter()
398 .flat_map(|e| e.get_ids())
399 .copied()
400 .collect();
401 let all_masks: Vec<u32> = encodings
402 .iter()
403 .flat_map(|e| e.get_attention_mask())
404 .copied()
405 .collect();
406
407 let input_ids = Tensor::new(all_ids.as_slice(), &self.device)
409 .and_then(|t| t.reshape((batch_len, seq_len)))
410 .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;
411
412 let token_type_ids = input_ids
413 .zeros_like()
414 .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;
415
416 let attention_mask = Tensor::new(all_masks.as_slice(), &self.device)
417 .and_then(|t| t.reshape((batch_len, seq_len)))
418 .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;
419
420 let model = self
422 .model
423 .lock()
424 .map_err(|e| CodememError::LockPoisoned(format!("embedding model: {e}")))?;
425 let hidden_states = model
426 .forward(&input_ids, &token_type_ids, Some(&attention_mask))
427 .map_err(|e| CodememError::Embedding(format!("Forward error: {e}")))?;
428 drop(model);
429
430 drop(input_ids);
432 drop(token_type_ids);
433
434 let mask = attention_mask
436 .to_dtype(DType::F32)
437 .and_then(|t| t.unsqueeze(2))
438 .map_err(|e| CodememError::Embedding(format!("Mask error: {e}")))?;
439
440 let sum_mask = mask
441 .sum(1)
442 .map_err(|e| CodememError::Embedding(format!("Sum error: {e}")))?;
443
444 let pooled = hidden_states
445 .broadcast_mul(&mask)
446 .and_then(|t| t.sum(1))
447 .and_then(|t| t.broadcast_div(&sum_mask))
448 .map_err(|e| CodememError::Embedding(format!("Pooling error: {e}")))?;
449
450 let norm = pooled
452 .sqr()
453 .and_then(|t| t.sum_keepdim(1))
454 .and_then(|t| t.sqrt())
455 .map_err(|e| CodememError::Embedding(format!("Norm error: {e}")))?;
456
457 let normalized = pooled
458 .broadcast_div(&norm)
459 .map_err(|e| CodememError::Embedding(format!("Normalize error: {e}")))?;
460
461 let flat: Vec<f32> = normalized
464 .flatten_all()
465 .and_then(|t| t.to_vec1())
466 .map_err(|e| CodememError::Embedding(format!("Extract error: {e}")))?;
467 for i in 0..batch_len {
468 let start = i * DIMENSIONS;
469 all_embeddings.push(flat[start..start + DIMENSIONS].to_vec());
470 }
471 }
472
473 Ok(all_embeddings)
474 }
475
476 pub fn cache_stats(&self) -> (usize, usize) {
478 match self.cache.lock() {
479 Ok(cache) => (cache.len(), CACHE_CAPACITY),
480 Err(_) => (0, CACHE_CAPACITY),
481 }
482 }
483}
484
/// `EmbeddingProvider` implementation that forwards to the inherent methods,
/// letting `EmbeddingService` be used behind `Box<dyn EmbeddingProvider>`.
impl EmbeddingProvider for EmbeddingService {
    fn dimensions(&self) -> usize {
        DIMENSIONS
    }

    fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
        // Resolves to the inherent `embed` (inherent methods take precedence
        // over trait methods), so this does not recurse.
        self.embed(text)
    }

    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
        // Likewise resolves to the inherent `embed_batch`.
        self.embed_batch(texts)
    }

    fn name(&self) -> &str {
        "candle"
    }

    fn cache_stats(&self) -> (usize, usize) {
        self.cache_stats()
    }
}
506
/// Decorator that adds an LRU text -> embedding cache in front of any
/// `EmbeddingProvider` (used for the remote ollama/openai providers).
pub struct CachedProvider {
    /// The wrapped provider that actually computes embeddings on a miss.
    inner: Box<dyn EmbeddingProvider>,
    /// LRU cache of text -> embedding; capacity fixed at construction.
    cache: Mutex<LruCache<String, Vec<f32>>>,
}
514
515impl CachedProvider {
516 pub fn new(inner: Box<dyn EmbeddingProvider>, capacity: usize) -> Self {
517 let cap =
519 NonZeroUsize::new(capacity).unwrap_or(NonZeroUsize::new(1).expect("1 is non-zero"));
520 Self {
521 inner,
522 cache: Mutex::new(LruCache::new(cap)),
523 }
524 }
525}
526
527impl EmbeddingProvider for CachedProvider {
528 fn dimensions(&self) -> usize {
529 self.inner.dimensions()
530 }
531
532 fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
533 {
534 let mut cache = self
535 .cache
536 .lock()
537 .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
538 if let Some(cached) = cache.get(text) {
539 return Ok(cached.clone());
540 }
541 }
542 let embedding = self.inner.embed(text)?;
543 {
544 let mut cache = self
545 .cache
546 .lock()
547 .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
548 cache.put(text.to_string(), embedding.clone());
549 }
550 Ok(embedding)
551 }
552
553 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
554 let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
556 let mut uncached = Vec::new();
557 let mut uncached_idx = Vec::new();
558
559 {
560 let mut cache = self
561 .cache
562 .lock()
563 .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
564 for (i, text) in texts.iter().enumerate() {
565 if let Some(cached) = cache.get(*text) {
566 results[i] = Some(cached.clone());
567 } else {
568 uncached_idx.push(i);
569 uncached.push(*text);
570 }
571 }
572 }
573
574 if !uncached.is_empty() {
575 let new_embeddings = self.inner.embed_batch(&uncached)?;
576 let mut cache = self
577 .cache
578 .lock()
579 .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
580 for (idx, embedding) in uncached_idx.into_iter().zip(new_embeddings) {
581 cache.put(texts[idx].to_string(), embedding.clone());
582 results[idx] = Some(embedding);
583 }
584 }
585
586 let expected = texts.len();
588 let output: Vec<Vec<f32>> = results
589 .into_iter()
590 .enumerate()
591 .map(|(i, opt)| {
592 opt.ok_or_else(|| {
593 CodememError::Embedding(format!(
594 "Missing embedding for text at index {i} (batch size {expected})"
595 ))
596 })
597 })
598 .collect::<Result<Vec<_>, _>>()?;
599 Ok(output)
600 }
601
602 fn name(&self) -> &str {
603 self.inner.name()
604 }
605
606 fn cache_stats(&self) -> (usize, usize) {
607 match self.cache.lock() {
608 Ok(cache) => (cache.len(), cache.cap().into()),
609 Err(_) => (0, 0),
610 }
611 }
612}
613
614pub fn from_env() -> Result<Box<dyn EmbeddingProvider>, CodememError> {
626 let provider = std::env::var("CODEMEM_EMBED_PROVIDER")
627 .unwrap_or_else(|_| "candle".to_string())
628 .to_lowercase();
629 let dimensions: usize = std::env::var("CODEMEM_EMBED_DIMENSIONS")
630 .ok()
631 .and_then(|s| s.parse().ok())
632 .unwrap_or(DIMENSIONS);
633
634 match provider.as_str() {
635 "ollama" => {
636 let base_url = std::env::var("CODEMEM_EMBED_URL")
637 .unwrap_or_else(|_| ollama::DEFAULT_BASE_URL.to_string());
638 let model = std::env::var("CODEMEM_EMBED_MODEL")
639 .unwrap_or_else(|_| ollama::DEFAULT_MODEL.to_string());
640 let inner = Box::new(ollama::OllamaProvider::new(&base_url, &model, dimensions));
641 Ok(Box::new(CachedProvider::new(inner, CACHE_CAPACITY)))
642 }
643 "openai" => {
644 let api_key = std::env::var("CODEMEM_EMBED_API_KEY")
645 .or_else(|_| std::env::var("OPENAI_API_KEY"))
646 .map_err(|_| {
647 CodememError::Embedding(
648 "CODEMEM_EMBED_API_KEY or OPENAI_API_KEY required for OpenAI embeddings"
649 .into(),
650 )
651 })?;
652 let model = std::env::var("CODEMEM_EMBED_MODEL")
653 .unwrap_or_else(|_| openai::DEFAULT_MODEL.to_string());
654 let base_url = std::env::var("CODEMEM_EMBED_URL").ok();
655 let inner = Box::new(openai::OpenAIProvider::new(
656 &api_key,
657 &model,
658 dimensions,
659 base_url.as_deref(),
660 ));
661 Ok(Box::new(CachedProvider::new(inner, CACHE_CAPACITY)))
662 }
663 "candle" | "" => {
664 let model_dir = EmbeddingService::default_model_dir();
665 let service = EmbeddingService::new(&model_dir)?;
666 Ok(Box::new(service))
667 }
668 other => Err(CodememError::Embedding(format!(
669 "Unknown embedding provider: '{}'. Use 'candle', 'ollama', or 'openai'.",
670 other
671 ))),
672 }
673}
674
675#[cfg(test)]
676#[path = "tests/lib_tests.rs"]
677mod tests;