1pub mod ollama;
9pub mod openai;
10
11use candle_core::{DType, Device, Tensor};
12use candle_nn::VarBuilder;
13use candle_transformers::models::bert::{BertModel, Config as BertConfig};
14use codemem_core::CodememError;
15use lru::LruCache;
16use std::num::NonZeroUsize;
17use std::path::{Path, PathBuf};
18use std::sync::Mutex;
19use tokenizers::{PaddingParams, PaddingStrategy, TruncationParams};
20
/// Human-readable name of the bundled embedding model; also used as the
/// on-disk directory name under `~/.codemem/models/`.
pub const MODEL_NAME: &str = "bge-base-en-v1.5";

/// HuggingFace Hub repository the model files are downloaded from.
const HF_MODEL_REPO: &str = "BAAI/bge-base-en-v1.5";

/// Dimensionality of the embedding vectors this model produces.
pub const DIMENSIONS: usize = 768;

/// Maximum token sequence length; longer inputs are truncated.
const MAX_SEQ_LENGTH: usize = 512;

/// Maximum number of text -> embedding entries kept in the LRU cache.
pub const CACHE_CAPACITY: usize = 10_000;
35
/// Abstraction over embedding backends (local candle model, Ollama, OpenAI).
///
/// Implementations must be `Send + Sync` so a single provider can be shared
/// across threads.
pub trait EmbeddingProvider: Send + Sync {
    /// Length of the vectors returned by [`embed`](Self::embed).
    fn dimensions(&self) -> usize;

    /// Embed a single text into a vector of `dimensions()` floats.
    fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError>;

    /// Embed many texts. The default implementation embeds one at a time and
    /// short-circuits on the first error; backends with real batch support
    /// should override this.
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
        texts.iter().map(|t| self.embed(t)).collect()
    }

    /// Short backend identifier (e.g. "candle").
    fn name(&self) -> &str;

    /// `(current entries, capacity)` of the provider's cache; providers
    /// without a cache report `(0, 0)`.
    fn cache_stats(&self) -> (usize, usize) {
        (0, 0)
    }
}
59
60const BATCH_SIZE: usize = 32;
64
/// Pick the best available compute device for inference.
///
/// Tries Metal, then CUDA — each only when the corresponding cargo feature
/// is compiled in — and falls back to CPU when GPU initialization fails.
fn select_device() -> Device {
    #[cfg(feature = "metal")]
    {
        if let Ok(device) = Device::new_metal(0) {
            tracing::info!("Using Metal GPU for embeddings");
            return device;
        }
        tracing::warn!("Metal feature enabled but device creation failed, falling back");
    }
    #[cfg(feature = "cuda")]
    {
        if let Ok(device) = Device::new_cuda(0) {
            tracing::info!("Using CUDA GPU for embeddings");
            return device;
        }
        tracing::warn!("CUDA feature enabled but device creation failed, falling back");
    }
    tracing::info!("Using CPU for embeddings");
    Device::Cpu
}
89
/// Local BERT-based embedding service backed by candle.
pub struct EmbeddingService {
    // Wrapped in a Mutex so inference through the shared model is serialized
    // across threads.
    model: Mutex<BertModel>,
    // Cloned per call because configuring truncation/padding needs `&mut`.
    tokenizer: tokenizers::Tokenizer,
    device: Device,
    // LRU of text -> embedding, bounded by CACHE_CAPACITY.
    cache: Mutex<LruCache<String, Vec<f32>>>,
}
97
98impl EmbeddingService {
    /// Load an `EmbeddingService` from a local model directory containing
    /// `model.safetensors`, `config.json`, and `tokenizer.json`.
    ///
    /// # Errors
    /// Returns `CodememError::Embedding` when the weights file is missing,
    /// when any file fails to read or parse, or when the model fails to load.
    pub fn new(model_dir: &Path) -> Result<Self, CodememError> {
        let model_path = model_dir.join("model.safetensors");
        let config_path = model_dir.join("config.json");
        let tokenizer_path = model_dir.join("tokenizer.json");

        if !model_path.exists() {
            return Err(CodememError::Embedding(format!(
                "Model not found at {}. Run `codemem init` to download it.",
                model_path.display()
            )));
        }

        let device = select_device();

        let config_str = std::fs::read_to_string(&config_path)
            .map_err(|e| CodememError::Embedding(format!("Failed to read config: {e}")))?;
        let config: BertConfig = serde_json::from_str(&config_str)
            .map_err(|e| CodememError::Embedding(format!("Failed to parse config: {e}")))?;

        // SAFETY: memory-mapping the weights assumes the file is not modified
        // or truncated while mapped.
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[&model_path], DType::F32, &device)
                .map_err(|e| CodememError::Embedding(format!("Failed to load weights: {e}")))?
        };

        // Some checkpoints namespace weights under a "bert" prefix and some
        // don't: try the prefixed layout first, then fall back to the root.
        // NOTE(review): the fallback re-maps the weights file; reusing `vb`
        // may be possible but is left as-is.
        let model = BertModel::load(vb.pp("bert"), &config)
            .or_else(|_| {
                // SAFETY: same mmap contract as above.
                let vb2 = unsafe {
                    VarBuilder::from_mmaped_safetensors(&[&model_path], DType::F32, &device)
                        .map_err(|e| {
                            candle_core::Error::Msg(format!("Failed to load weights: {e}"))
                        })
                }?;
                BertModel::load(vb2, &config)
            })
            .map_err(|e| CodememError::Embedding(format!("Failed to load BERT model: {e}")))?;

        let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| CodememError::Embedding(e.to_string()))?;

        // CACHE_CAPACITY is a non-zero constant, so the unwrap cannot fail.
        let cache = Mutex::new(LruCache::new(NonZeroUsize::new(CACHE_CAPACITY).unwrap()));

        Ok(Self {
            model: Mutex::new(model),
            tokenizer,
            device,
            cache,
        })
    }
152
153 pub fn default_model_dir() -> PathBuf {
155 dirs::home_dir()
156 .unwrap_or_else(|| PathBuf::from("."))
157 .join(".codemem")
158 .join("models")
159 .join(MODEL_NAME)
160 }
161
162 pub fn download_model(dest_dir: &Path) -> Result<PathBuf, CodememError> {
165 let model_dest = dest_dir.join("model.safetensors");
166 let config_dest = dest_dir.join("config.json");
167 let tokenizer_dest = dest_dir.join("tokenizer.json");
168
169 if model_dest.exists() && config_dest.exists() && tokenizer_dest.exists() {
170 tracing::info!("Model already downloaded at {}", dest_dir.display());
171 return Ok(dest_dir.to_path_buf());
172 }
173
174 std::fs::create_dir_all(dest_dir)
175 .map_err(|e| CodememError::Embedding(format!("Failed to create dir: {e}")))?;
176
177 tracing::info!("Downloading {} from HuggingFace...", HF_MODEL_REPO);
178
179 let api = hf_hub::api::sync::Api::new()
180 .map_err(|e| CodememError::Embedding(format!("HuggingFace API error: {e}")))?;
181 let repo = api.model(HF_MODEL_REPO.to_string());
182
183 let cached_model = repo
184 .get("model.safetensors")
185 .map_err(|e| CodememError::Embedding(format!("Failed to download model: {e}")))?;
186
187 let cached_config = repo
188 .get("config.json")
189 .map_err(|e| CodememError::Embedding(format!("Failed to download config: {e}")))?;
190
191 let cached_tokenizer = repo
192 .get("tokenizer.json")
193 .map_err(|e| CodememError::Embedding(format!("Failed to download tokenizer: {e}")))?;
194
195 std::fs::copy(&cached_model, &model_dest)
196 .map_err(|e| CodememError::Embedding(format!("Failed to copy model: {e}")))?;
197 std::fs::copy(&cached_config, &config_dest)
198 .map_err(|e| CodememError::Embedding(format!("Failed to copy config: {e}")))?;
199 std::fs::copy(&cached_tokenizer, &tokenizer_dest)
200 .map_err(|e| CodememError::Embedding(format!("Failed to copy tokenizer: {e}")))?;
201
202 tracing::info!("Model downloaded to {}", dest_dir.display());
203 Ok(dest_dir.to_path_buf())
204 }
205
206 pub fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
209 {
211 let mut cache = self.cache.lock().unwrap();
212 if let Some(cached) = cache.get(text) {
213 return Ok(cached.clone());
214 }
215 }
216
217 let embedding = self.embed_uncached(text)?;
218
219 {
221 let mut cache = self.cache.lock().unwrap();
222 cache.put(text.to_string(), embedding.clone());
223 }
224
225 Ok(embedding)
226 }
227
    /// Compute an embedding without consulting the cache: tokenize (with
    /// truncation to `MAX_SEQ_LENGTH`), run BERT, mean-pool the hidden
    /// states over non-padding tokens, and L2-normalize the result.
    fn embed_uncached(&self, text: &str) -> Result<Vec<f32>, CodememError> {
        // Clone because `with_truncation` needs `&mut` and `self.tokenizer`
        // is shared immutably.
        let mut tokenizer = self.tokenizer.clone();

        tokenizer
            .with_truncation(Some(tokenizers::TruncationParams {
                max_length: MAX_SEQ_LENGTH,
                ..Default::default()
            }))
            .map_err(|e| CodememError::Embedding(format!("Truncation error: {e}")))?;

        // `true` = add special tokens ([CLS]/[SEP]).
        let encoding = tokenizer
            .encode(text, true)
            .map_err(|e| CodememError::Embedding(e.to_string()))?;

        let input_ids: Vec<u32> = encoding.get_ids().to_vec();
        let attention_mask: Vec<u32> = encoding.get_attention_mask().to_vec();

        // Shape (1, seq_len): unsqueeze(0) adds the batch dimension.
        let input_ids_tensor = Tensor::new(&input_ids[..], &self.device)
            .and_then(|t| t.unsqueeze(0))
            .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;

        // Single-segment input: all token type ids are zero.
        let token_type_ids = input_ids_tensor
            .zeros_like()
            .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;

        let attention_mask_tensor = Tensor::new(&attention_mask[..], &self.device)
            .and_then(|t| t.unsqueeze(0))
            .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;

        // Serialize model access; drop the guard before the pooling math so
        // other threads can start their own forward pass.
        let model = self.model.lock().unwrap();
        let hidden_states = model
            .forward(
                &input_ids_tensor,
                &token_type_ids,
                Some(&attention_mask_tensor),
            )
            .map_err(|e| CodememError::Embedding(format!("Model forward error: {e}")))?;
        drop(model);

        // Mean pooling: weight hidden states by the attention mask so
        // padding tokens contribute nothing, sum over the sequence axis, and
        // divide by the count of real tokens.
        let mask = attention_mask_tensor
            .to_dtype(DType::F32)
            .and_then(|t| t.unsqueeze(2))
            .map_err(|e| CodememError::Embedding(format!("Mask error: {e}")))?;

        let sum_mask = mask
            .sum(1)
            .map_err(|e| CodememError::Embedding(format!("Sum error: {e}")))?;

        let pooled = hidden_states
            .broadcast_mul(&mask)
            .and_then(|t| t.sum(1))
            .and_then(|t| t.broadcast_div(&sum_mask))
            .map_err(|e| CodememError::Embedding(format!("Pooling error: {e}")))?;

        // L2 normalization: divide by sqrt(sum(x^2)) so the embedding has
        // unit length (cosine similarity reduces to a dot product).
        let normalized = pooled
            .sqr()
            .and_then(|t| t.sum_keepdim(1))
            .and_then(|t| t.sqrt())
            .and_then(|norm| pooled.broadcast_div(&norm))
            .map_err(|e| CodememError::Embedding(format!("Normalize error: {e}")))?;

        // Drop the batch dimension and copy the vector to host memory.
        let embedding: Vec<f32> = normalized
            .squeeze(0)
            .and_then(|t| t.to_vec1())
            .map_err(|e| CodememError::Embedding(format!("Extract error: {e}")))?;

        Ok(embedding)
    }
304
305 pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
310 if texts.is_empty() {
311 return Ok(vec![]);
312 }
313
314 let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
316 let mut uncached_indices = Vec::new();
317 let mut uncached_texts = Vec::new();
318
319 {
320 let mut cache = self.cache.lock().unwrap();
321 for (i, text) in texts.iter().enumerate() {
322 if let Some(cached) = cache.get(*text) {
323 results[i] = Some(cached.clone());
324 } else {
325 uncached_indices.push(i);
326 uncached_texts.push(*text);
327 }
328 }
329 }
330
331 if !uncached_texts.is_empty() {
332 let new_embeddings = self.embed_batch_uncached(&uncached_texts)?;
333
334 let mut cache = self.cache.lock().unwrap();
335 for (idx, embedding) in uncached_indices.into_iter().zip(new_embeddings) {
336 cache.put(texts[idx].to_string(), embedding.clone());
337 results[idx] = Some(embedding);
338 }
339 }
340
341 Ok(results.into_iter().map(|r| r.unwrap()).collect())
342 }
343
344 fn embed_batch_uncached(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
350 if texts.is_empty() {
351 return Ok(vec![]);
352 }
353
354 let mut all_embeddings = Vec::with_capacity(texts.len());
355
356 for chunk in texts.chunks(BATCH_SIZE) {
357 let mut tokenizer = self.tokenizer.clone();
358
359 tokenizer
360 .with_truncation(Some(TruncationParams {
361 max_length: MAX_SEQ_LENGTH,
362 ..Default::default()
363 }))
364 .map_err(|e| CodememError::Embedding(format!("Truncation error: {e}")))?;
365
366 tokenizer.with_padding(Some(PaddingParams {
368 strategy: PaddingStrategy::BatchLongest,
369 ..Default::default()
370 }));
371
372 let encodings = tokenizer
373 .encode_batch(chunk.to_vec(), true)
374 .map_err(|e| CodememError::Embedding(format!("Batch encode error: {e}")))?;
375
376 let batch_len = encodings.len();
377 let seq_len = encodings[0].get_ids().len();
378
379 let all_ids: Vec<u32> = encodings
381 .iter()
382 .flat_map(|e| e.get_ids())
383 .copied()
384 .collect();
385 let all_masks: Vec<u32> = encodings
386 .iter()
387 .flat_map(|e| e.get_attention_mask())
388 .copied()
389 .collect();
390
391 let input_ids = Tensor::new(all_ids.as_slice(), &self.device)
393 .and_then(|t| t.reshape((batch_len, seq_len)))
394 .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;
395
396 let token_type_ids = input_ids
397 .zeros_like()
398 .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;
399
400 let attention_mask = Tensor::new(all_masks.as_slice(), &self.device)
401 .and_then(|t| t.reshape((batch_len, seq_len)))
402 .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;
403
404 let model = self.model.lock().unwrap();
406 let hidden_states = model
407 .forward(&input_ids, &token_type_ids, Some(&attention_mask))
408 .map_err(|e| CodememError::Embedding(format!("Forward error: {e}")))?;
409 drop(model);
410
411 let mask = attention_mask
413 .to_dtype(DType::F32)
414 .and_then(|t| t.unsqueeze(2))
415 .map_err(|e| CodememError::Embedding(format!("Mask error: {e}")))?;
416
417 let sum_mask = mask
418 .sum(1)
419 .map_err(|e| CodememError::Embedding(format!("Sum error: {e}")))?;
420
421 let pooled = hidden_states
422 .broadcast_mul(&mask)
423 .and_then(|t| t.sum(1))
424 .and_then(|t| t.broadcast_div(&sum_mask))
425 .map_err(|e| CodememError::Embedding(format!("Pooling error: {e}")))?;
426
427 let norm = pooled
429 .sqr()
430 .and_then(|t| t.sum_keepdim(1))
431 .and_then(|t| t.sqrt())
432 .map_err(|e| CodememError::Embedding(format!("Norm error: {e}")))?;
433
434 let normalized = pooled
435 .broadcast_div(&norm)
436 .map_err(|e| CodememError::Embedding(format!("Normalize error: {e}")))?;
437
438 for i in 0..batch_len {
440 let row: Vec<f32> = normalized
441 .get(i)
442 .and_then(|t| t.to_vec1())
443 .map_err(|e| CodememError::Embedding(format!("Extract error: {e}")))?;
444 all_embeddings.push(row);
445 }
446 }
447
448 Ok(all_embeddings)
449 }
450
451 pub fn cache_stats(&self) -> (usize, usize) {
453 let cache = self.cache.lock().unwrap();
454 (cache.len(), CACHE_CAPACITY)
455 }
456}
457
/// Trait implementation for the local candle backend: every method simply
/// delegates to the inherent `EmbeddingService` methods.
impl EmbeddingProvider for EmbeddingService {
    fn dimensions(&self) -> usize {
        DIMENSIONS
    }

    fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
        self.embed(text)
    }

    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
        self.embed_batch(texts)
    }

    fn name(&self) -> &str {
        "candle"
    }

    fn cache_stats(&self) -> (usize, usize) {
        self.cache_stats()
    }
}
479
/// Decorator that adds an LRU text -> embedding cache in front of any
/// `EmbeddingProvider` (used for the remote Ollama/OpenAI backends, which
/// have no cache of their own).
pub struct CachedProvider {
    inner: Box<dyn EmbeddingProvider>,
    cache: Mutex<LruCache<String, Vec<f32>>>,
}
487
488impl CachedProvider {
489 pub fn new(inner: Box<dyn EmbeddingProvider>, capacity: usize) -> Self {
490 Self {
491 inner,
492 cache: Mutex::new(LruCache::new(
493 NonZeroUsize::new(capacity).unwrap_or(NonZeroUsize::new(1).unwrap()),
494 )),
495 }
496 }
497}
498
499impl EmbeddingProvider for CachedProvider {
500 fn dimensions(&self) -> usize {
501 self.inner.dimensions()
502 }
503
504 fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
505 {
506 let mut cache = self.cache.lock().unwrap();
507 if let Some(cached) = cache.get(text) {
508 return Ok(cached.clone());
509 }
510 }
511 let embedding = self.inner.embed(text)?;
512 {
513 let mut cache = self.cache.lock().unwrap();
514 cache.put(text.to_string(), embedding.clone());
515 }
516 Ok(embedding)
517 }
518
519 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
520 let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
522 let mut uncached = Vec::new();
523 let mut uncached_idx = Vec::new();
524
525 {
526 let mut cache = self.cache.lock().unwrap();
527 for (i, text) in texts.iter().enumerate() {
528 if let Some(cached) = cache.get(*text) {
529 results[i] = Some(cached.clone());
530 } else {
531 uncached_idx.push(i);
532 uncached.push(*text);
533 }
534 }
535 }
536
537 if !uncached.is_empty() {
538 let new_embeddings = self.inner.embed_batch(&uncached)?;
539 let mut cache = self.cache.lock().unwrap();
540 for (idx, embedding) in uncached_idx.into_iter().zip(new_embeddings) {
541 cache.put(texts[idx].to_string(), embedding.clone());
542 results[idx] = Some(embedding);
543 }
544 }
545
546 Ok(results.into_iter().map(|r| r.unwrap()).collect())
547 }
548
549 fn name(&self) -> &str {
550 self.inner.name()
551 }
552
553 fn cache_stats(&self) -> (usize, usize) {
554 let cache = self.cache.lock().unwrap();
555 (cache.len(), cache.cap().into())
556 }
557}
558
/// Build an embedding provider from environment variables.
///
/// * `CODEMEM_EMBED_PROVIDER` — `"candle"` (default), `"ollama"`, or
///   `"openai"` (case-insensitive).
/// * `CODEMEM_EMBED_DIMENSIONS` — vector-size override for remote backends;
///   unset or unparsable values fall back to `DIMENSIONS`.
/// * `CODEMEM_EMBED_URL` — base URL for the ollama backend, or an optional
///   endpoint override for openai.
/// * `CODEMEM_EMBED_MODEL` — model name for the ollama/openai backends.
/// * `CODEMEM_EMBED_API_KEY` / `OPENAI_API_KEY` — required for openai.
///
/// Remote providers are wrapped in a `CachedProvider`; the local candle
/// service already caches internally.
///
/// # Errors
/// Returns `CodememError::Embedding` for an unknown provider name, a missing
/// OpenAI API key, or a failure to load the local candle model.
pub fn from_env() -> Result<Box<dyn EmbeddingProvider>, CodememError> {
    let provider = std::env::var("CODEMEM_EMBED_PROVIDER")
        .unwrap_or_else(|_| "candle".to_string())
        .to_lowercase();
    let dimensions: usize = std::env::var("CODEMEM_EMBED_DIMENSIONS")
        .ok()
        .and_then(|s| s.parse().ok())
        .unwrap_or(DIMENSIONS);

    match provider.as_str() {
        "ollama" => {
            let base_url = std::env::var("CODEMEM_EMBED_URL")
                .unwrap_or_else(|_| ollama::DEFAULT_BASE_URL.to_string());
            let model = std::env::var("CODEMEM_EMBED_MODEL")
                .unwrap_or_else(|_| ollama::DEFAULT_MODEL.to_string());
            let inner = Box::new(ollama::OllamaProvider::new(&base_url, &model, dimensions));
            Ok(Box::new(CachedProvider::new(inner, CACHE_CAPACITY)))
        }
        "openai" => {
            let api_key = std::env::var("CODEMEM_EMBED_API_KEY")
                .or_else(|_| std::env::var("OPENAI_API_KEY"))
                .map_err(|_| {
                    CodememError::Embedding(
                        "CODEMEM_EMBED_API_KEY or OPENAI_API_KEY required for OpenAI embeddings"
                            .into(),
                    )
                })?;
            let model = std::env::var("CODEMEM_EMBED_MODEL")
                .unwrap_or_else(|_| openai::DEFAULT_MODEL.to_string());
            let base_url = std::env::var("CODEMEM_EMBED_URL").ok();
            let inner = Box::new(openai::OpenAIProvider::new(
                &api_key,
                &model,
                dimensions,
                base_url.as_deref(),
            ));
            Ok(Box::new(CachedProvider::new(inner, CACHE_CAPACITY)))
        }
        // The "" arm matches when the env var is set but empty.
        "candle" | "" => {
            let model_dir = EmbeddingService::default_model_dir();
            let service = EmbeddingService::new(&model_dir)?;
            Ok(Box::new(service))
        }
        other => Err(CodememError::Embedding(format!(
            "Unknown embedding provider: '{}'. Use 'candle', 'ollama', or 'openai'.",
            other
        ))),
    }
}