1pub mod ollama;
9pub mod openai;
10
11use candle_core::{DType, Device, Tensor};
12use candle_nn::VarBuilder;
13use candle_transformers::models::bert::{BertModel, Config as BertConfig};
14use codemem_core::CodememError;
15use lru::LruCache;
16use std::num::NonZeroUsize;
17use std::path::{Path, PathBuf};
18use std::sync::Mutex;
19use tokenizers::{PaddingParams, PaddingStrategy};
20
/// Directory name of the default local embedding model.
pub const MODEL_NAME: &str = "bge-base-en-v1.5";

/// HuggingFace repo the default model is downloaded from.
pub const DEFAULT_HF_REPO: &str = "BAAI/bge-base-en-v1.5";

/// Embedding dimensionality assumed for remote providers when neither the
/// environment nor the config supplies one.
pub const DEFAULT_REMOTE_DIMENSIONS: usize = 768;

/// Maximum token sequence length; longer inputs are truncated at tokenization.
const MAX_SEQ_LENGTH: usize = 512;

/// Default number of entries held by [`CachedProvider`]'s LRU cache.
pub const CACHE_CAPACITY: usize = 10_000;

// Re-exported so downstream crates can name the provider trait directly.
pub use codemem_core::EmbeddingProvider;

/// Default number of texts embedded per model forward pass.
pub const DEFAULT_BATCH_SIZE: usize = 16;
46
/// Choose the compute device for embedding inference.
///
/// Preference order: Metal (when the `metal` cargo feature is enabled), then
/// CUDA (when the `cuda` feature is enabled), then CPU. Device probing is
/// wrapped in `catch_unwind` — presumably because backend initialization has
/// been observed to panic instead of returning an error on some machines.
fn select_device() -> Device {
    #[cfg(feature = "metal")]
    {
        // The closure captures nothing, so it is trivially unwind-safe.
        match std::panic::catch_unwind(|| Device::new_metal(0)) {
            Ok(Ok(device)) => {
                tracing::info!("Using Metal GPU for embeddings");
                return device;
            }
            Ok(Err(e)) => {
                tracing::warn!("Metal device creation failed: {e}, falling back to CPU");
            }
            Err(_) => {
                tracing::warn!("Metal device creation panicked, falling back to CPU");
            }
        }
    }
    #[cfg(feature = "cuda")]
    {
        match std::panic::catch_unwind(|| Device::new_cuda(0)) {
            Ok(Ok(device)) => {
                tracing::info!("Using CUDA GPU for embeddings");
                return device;
            }
            Ok(Err(e)) => {
                tracing::warn!("CUDA device creation failed: {e}, falling back to CPU");
            }
            Err(_) => {
                tracing::warn!("CUDA device creation panicked, falling back to CPU");
            }
        }
    }
    // Last resort — and the only path when no GPU feature is compiled in.
    tracing::info!("Using CPU for embeddings");
    Device::Cpu
}
86
/// Local embedding service backed by a candle BERT model.
pub struct EmbeddingService {
    // Serializes forward passes across threads. NOTE(review): confirm whether
    // BertModel::forward actually needs exclusive access, or whether the
    // mutex only exists to make the service Sync.
    model: Mutex<BertModel>,
    // Configured with truncation at MAX_SEQ_LENGTH; cloned (with padding
    // enabled) for batch encoding.
    tokenizer: tokenizers::Tokenizer,
    // Compute device chosen by select_device() at construction.
    device: Device,
    // Number of texts per forward pass in embed_batch.
    batch_size: usize,
    // Embedding dimensionality, read from the model's config.json.
    hidden_size: usize,
}
99
100impl EmbeddingService {
    /// Load a BERT embedding model and its tokenizer from `model_dir`.
    ///
    /// Expects `model.safetensors`, `config.json`, and `tokenizer.json` in
    /// the directory (see [`Self::download_model`]). The compute device is
    /// selected automatically; `dtype` controls weight precision and
    /// `batch_size` the chunking used by [`Self::embed_batch`].
    ///
    /// # Errors
    /// Returns [`CodememError::Embedding`] when any file is missing or fails
    /// to parse/load.
    pub fn new(model_dir: &Path, batch_size: usize, dtype: DType) -> Result<Self, CodememError> {
        let model_path = model_dir.join("model.safetensors");
        let config_path = model_dir.join("config.json");
        let tokenizer_path = model_dir.join("tokenizer.json");

        if !model_path.exists() {
            return Err(CodememError::Embedding(format!(
                "Model not found at {}. Run `codemem init` to download it.",
                model_path.display()
            )));
        }

        let device = select_device();

        tracing::info!(
            "Loading model from {} (dtype: {:?}, device: {:?})",
            model_dir.display(),
            dtype,
            device
        );

        let config_str = std::fs::read_to_string(&config_path)
            .map_err(|e| CodememError::Embedding(format!("Failed to read config: {e}")))?;
        let config: BertConfig = serde_json::from_str(&config_str)
            .map_err(|e| CodememError::Embedding(format!("Failed to parse config: {e}")))?;

        // Embedding dimensionality comes from the model config, not a constant.
        let hidden_size = config.hidden_size;

        // First attempt: weights nested under a "bert." prefix — presumably
        // for checkpoints exported with the encoder wrapped in a task head.
        let model = {
            // SAFETY: from_mmaped_safetensors memory-maps the weight file;
            // this is sound as long as the file is not truncated or modified
            // while mapped, which we do not guard against (standard candle
            // usage).
            let vb = unsafe {
                VarBuilder::from_mmaped_safetensors(&[&model_path], dtype, &device)
                    .map_err(|e| CodememError::Embedding(format!("Failed to load weights: {e}")))?
            };
            BertModel::load(vb.pp("bert"), &config)
        };
        // Fallback: retry without the prefix for checkpoints whose tensors
        // live at the top level. A fresh VarBuilder is built because the
        // first one was consumed by `pp`.
        let model = match model {
            Ok(m) => m,
            Err(_) => {
                // SAFETY: same mmap contract as above.
                let vb2 = unsafe {
                    VarBuilder::from_mmaped_safetensors(&[&model_path], dtype, &device).map_err(
                        |e| CodememError::Embedding(format!("Failed to load weights: {e}")),
                    )?
                };
                BertModel::load(vb2, &config).map_err(|e| {
                    CodememError::Embedding(format!("Failed to load BERT model: {e}"))
                })?
            }
        };

        let mut tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| CodememError::Embedding(e.to_string()))?;

        // Cap inputs at the model's maximum sequence length so long documents
        // are truncated instead of failing in the forward pass.
        tokenizer
            .with_truncation(Some(tokenizers::TruncationParams {
                max_length: MAX_SEQ_LENGTH,
                ..Default::default()
            }))
            .map_err(|e| CodememError::Embedding(format!("Truncation error: {e}")))?;

        Ok(Self {
            model: Mutex::new(model),
            tokenizer,
            device,
            batch_size,
            hidden_size,
        })
    }
178
179 pub fn model_dir_for(model_name: &str) -> PathBuf {
182 dirs::home_dir()
183 .unwrap_or_else(|| PathBuf::from("."))
184 .join(".codemem")
185 .join("models")
186 .join(model_name)
187 }
188
    /// Directory for the built-in default model ([`MODEL_NAME`]).
    pub fn default_model_dir() -> PathBuf {
        Self::model_dir_for(MODEL_NAME)
    }
193
194 pub fn download_model(dest_dir: &Path, hf_repo: &str) -> Result<PathBuf, CodememError> {
198 let model_dest = dest_dir.join("model.safetensors");
199 let config_dest = dest_dir.join("config.json");
200 let tokenizer_dest = dest_dir.join("tokenizer.json");
201
202 if model_dest.exists() && config_dest.exists() && tokenizer_dest.exists() {
203 tracing::info!("Model already downloaded at {}", dest_dir.display());
204 return Ok(dest_dir.to_path_buf());
205 }
206
207 std::fs::create_dir_all(dest_dir)
208 .map_err(|e| CodememError::Embedding(format!("Failed to create dir: {e}")))?;
209
210 tracing::info!("Downloading {} from HuggingFace...", hf_repo);
211
212 let api = hf_hub::api::sync::Api::new()
213 .map_err(|e| CodememError::Embedding(format!("HuggingFace API error: {e}")))?;
214 let repo = api.model(hf_repo.to_string());
215
216 let cached_model = repo
217 .get("model.safetensors")
218 .map_err(|e| CodememError::Embedding(format!("Failed to download model: {e}")))?;
219
220 let cached_config = repo
221 .get("config.json")
222 .map_err(|e| CodememError::Embedding(format!("Failed to download config: {e}")))?;
223
224 let cached_tokenizer = repo
225 .get("tokenizer.json")
226 .map_err(|e| CodememError::Embedding(format!("Failed to download tokenizer: {e}")))?;
227
228 std::fs::copy(&cached_model, &model_dest)
229 .map_err(|e| CodememError::Embedding(format!("Failed to copy model: {e}")))?;
230 std::fs::copy(&cached_config, &config_dest)
231 .map_err(|e| CodememError::Embedding(format!("Failed to copy config: {e}")))?;
232 std::fs::copy(&cached_tokenizer, &tokenizer_dest)
233 .map_err(|e| CodememError::Embedding(format!("Failed to copy tokenizer: {e}")))?;
234
235 tracing::info!("Model downloaded to {}", dest_dir.display());
236 Ok(dest_dir.to_path_buf())
237 }
238
    /// Download the default model ([`DEFAULT_HF_REPO`]) into the default
    /// model directory. See [`Self::download_model`] for details and errors.
    pub fn download_default_model() -> Result<PathBuf, CodememError> {
        Self::download_model(&Self::default_model_dir(), DEFAULT_HF_REPO)
    }
244
    /// Embed a single text into a mean-pooled, L2-normalized vector of
    /// `self.hidden_size` floats.
    ///
    /// # Errors
    /// Returns [`CodememError::Embedding`] on tokenizer/tensor failures and
    /// [`CodememError::LockPoisoned`] if the model mutex is poisoned.
    pub fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
        // `true` = add the model's special tokens (e.g. [CLS]/[SEP]).
        let encoding = self
            .tokenizer
            .encode(text, true)
            .map_err(|e| CodememError::Embedding(e.to_string()))?;

        let input_ids: Vec<u32> = encoding.get_ids().to_vec();
        let attention_mask: Vec<u32> = encoding.get_attention_mask().to_vec();

        // Shape (1, seq_len): unsqueeze adds the batch dimension.
        let input_ids_tensor = Tensor::new(&input_ids[..], &self.device)
            .and_then(|t| t.unsqueeze(0))
            .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;

        // Single-segment input: all token type ids are zero.
        let token_type_ids = input_ids_tensor
            .zeros_like()
            .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;

        let attention_mask_tensor = Tensor::new(&attention_mask[..], &self.device)
            .and_then(|t| t.unsqueeze(0))
            .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;

        // Run the forward pass under the model mutex, then drop the guard
        // immediately so other threads can embed while we post-process.
        let model = self
            .model
            .lock()
            .map_err(|e| CodememError::LockPoisoned(format!("embedding model: {e}")))?;
        let hidden_states = model
            .forward(
                &input_ids_tensor,
                &token_type_ids,
                Some(&attention_mask_tensor),
            )
            .map_err(|e| CodememError::Embedding(format!("Model forward error: {e}")))?;
        drop(model);

        // Explicitly drop input tensors — presumably to release device
        // buffers promptly; the attention mask is still needed for pooling.
        drop(input_ids_tensor);
        drop(token_type_ids);

        // Pooling math runs in f32 regardless of the model dtype.
        let hidden_states = hidden_states
            .to_dtype(DType::F32)
            .map_err(|e| CodememError::Embedding(format!("Cast error: {e}")))?;

        // (1, seq_len, 1) mask so padded positions contribute nothing.
        let mask = attention_mask_tensor
            .to_dtype(DType::F32)
            .and_then(|t| t.unsqueeze(2))
            .map_err(|e| CodememError::Embedding(format!("Mask error: {e}")))?;

        // Count of real (unmasked) tokens — the mean denominator.
        let sum_mask = mask
            .sum(1)
            .map_err(|e| CodememError::Embedding(format!("Sum error: {e}")))?;

        // Masked mean pooling over the sequence dimension.
        let pooled = hidden_states
            .broadcast_mul(&mask)
            .and_then(|t| t.sum(1))
            .and_then(|t| t.broadcast_div(&sum_mask))
            .map_err(|e| CodememError::Embedding(format!("Pooling error: {e}")))?;

        // L2-normalize so that dot product equals cosine similarity.
        let normalized = pooled
            .sqr()
            .and_then(|t| t.sum_keepdim(1))
            .and_then(|t| t.sqrt())
            .and_then(|norm| pooled.broadcast_div(&norm))
            .map_err(|e| CodememError::Embedding(format!("Normalize error: {e}")))?;

        // Drop the batch dimension and copy the vector back to host memory.
        let embedding: Vec<f32> = normalized
            .squeeze(0)
            .and_then(|t| t.to_vec1())
            .map_err(|e| CodememError::Embedding(format!("Extract error: {e}")))?;

        Ok(embedding)
    }
325
326 pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
332 if texts.is_empty() {
333 return Ok(vec![]);
334 }
335
336 let mut all_embeddings = Vec::with_capacity(texts.len());
337
338 for chunk in texts.chunks(self.batch_size) {
339 let mut tokenizer = self.tokenizer.clone();
342 tokenizer.with_padding(Some(PaddingParams {
343 strategy: PaddingStrategy::BatchLongest,
344 ..Default::default()
345 }));
346
347 let encodings = tokenizer
348 .encode_batch(chunk.to_vec(), true)
349 .map_err(|e| CodememError::Embedding(format!("Batch encode error: {e}")))?;
350
351 let batch_len = encodings.len();
352 let seq_len = encodings[0].get_ids().len();
353
354 let all_ids: Vec<u32> = encodings
356 .iter()
357 .flat_map(|e| e.get_ids())
358 .copied()
359 .collect();
360 let all_masks: Vec<u32> = encodings
361 .iter()
362 .flat_map(|e| e.get_attention_mask())
363 .copied()
364 .collect();
365
366 let input_ids = Tensor::new(all_ids.as_slice(), &self.device)
368 .and_then(|t| t.reshape((batch_len, seq_len)))
369 .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;
370
371 let token_type_ids = input_ids
372 .zeros_like()
373 .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;
374
375 let attention_mask = Tensor::new(all_masks.as_slice(), &self.device)
376 .and_then(|t| t.reshape((batch_len, seq_len)))
377 .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;
378
379 let model = self
381 .model
382 .lock()
383 .map_err(|e| CodememError::LockPoisoned(format!("embedding model: {e}")))?;
384 let hidden_states = model
385 .forward(&input_ids, &token_type_ids, Some(&attention_mask))
386 .map_err(|e| CodememError::Embedding(format!("Forward error: {e}")))?;
387 drop(model);
388
389 drop(input_ids);
391 drop(token_type_ids);
392
393 let hidden_states = hidden_states
395 .to_dtype(DType::F32)
396 .map_err(|e| CodememError::Embedding(format!("Cast error: {e}")))?;
397
398 let mask = attention_mask
400 .to_dtype(DType::F32)
401 .and_then(|t| t.unsqueeze(2))
402 .map_err(|e| CodememError::Embedding(format!("Mask error: {e}")))?;
403
404 let sum_mask = mask
405 .sum(1)
406 .map_err(|e| CodememError::Embedding(format!("Sum error: {e}")))?;
407
408 let pooled = hidden_states
409 .broadcast_mul(&mask)
410 .and_then(|t| t.sum(1))
411 .and_then(|t| t.broadcast_div(&sum_mask))
412 .map_err(|e| CodememError::Embedding(format!("Pooling error: {e}")))?;
413
414 let norm = pooled
416 .sqr()
417 .and_then(|t| t.sum_keepdim(1))
418 .and_then(|t| t.sqrt())
419 .map_err(|e| CodememError::Embedding(format!("Norm error: {e}")))?;
420
421 let normalized = pooled
422 .broadcast_div(&norm)
423 .map_err(|e| CodememError::Embedding(format!("Normalize error: {e}")))?;
424
425 let flat: Vec<f32> = normalized
428 .flatten_all()
429 .and_then(|t| t.to_vec1())
430 .map_err(|e| CodememError::Embedding(format!("Extract error: {e}")))?;
431 for i in 0..batch_len {
432 let start = i * self.hidden_size;
433 all_embeddings.push(flat[start..start + self.hidden_size].to_vec());
434 }
435 }
436
437 Ok(all_embeddings)
438 }
439}
440
/// Expose the local candle service through the shared [`EmbeddingProvider`]
/// interface by delegating to the inherent methods.
impl EmbeddingProvider for EmbeddingService {
    fn dimensions(&self) -> usize {
        self.hidden_size
    }

    fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
        // Resolves to the inherent `EmbeddingService::embed` (inherent
        // methods take precedence over trait methods), so no recursion.
        self.embed(text)
    }

    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
        // Same precedence rule: calls the inherent `embed_batch`.
        self.embed_batch(texts)
    }

    fn name(&self) -> &str {
        "candle"
    }
}
458
/// Decorator that memoizes embeddings from any [`EmbeddingProvider`] in an
/// in-memory LRU cache keyed by the exact input text.
pub struct CachedProvider {
    // The wrapped provider that actually computes embeddings on cache miss.
    inner: Box<dyn EmbeddingProvider>,
    // text -> embedding; the mutex makes lookups/inserts thread-safe.
    cache: Mutex<LruCache<String, Vec<f32>>>,
}
466
467impl CachedProvider {
468 pub fn new(inner: Box<dyn EmbeddingProvider>, capacity: usize) -> Self {
469 let cap =
471 NonZeroUsize::new(capacity).unwrap_or(NonZeroUsize::new(1).expect("1 is non-zero"));
472 Self {
473 inner,
474 cache: Mutex::new(LruCache::new(cap)),
475 }
476 }
477}
478
/// [`EmbeddingProvider`] implementation for the caching decorator.
///
/// The lock scopes here are deliberate: the cache mutex is always released
/// before calling into `inner`, so a slow (e.g. network-backed) provider
/// never blocks concurrent cache hits.
impl EmbeddingProvider for CachedProvider {
    fn dimensions(&self) -> usize {
        self.inner.dimensions()
    }

    fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
        // Fast path: serve from cache. The scoped block drops the lock
        // before the potentially slow inner call below.
        {
            let mut cache = self
                .cache
                .lock()
                .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
            if let Some(cached) = cache.get(text) {
                return Ok(cached.clone());
            }
        }
        // Miss: compute without holding the lock. Two threads may duplicate
        // work for the same text; the second `put` simply overwrites.
        let embedding = self.inner.embed(text)?;
        {
            let mut cache = self
                .cache
                .lock()
                .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
            cache.put(text.to_string(), embedding.clone());
        }
        Ok(embedding)
    }

    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
        // One slot per input; filled from the cache first, then from `inner`.
        let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
        let mut uncached = Vec::new();
        let mut uncached_idx = Vec::new();

        // Pass 1 (under the lock): record hits and collect misses.
        {
            let mut cache = self
                .cache
                .lock()
                .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
            for (i, text) in texts.iter().enumerate() {
                if let Some(cached) = cache.get(*text) {
                    results[i] = Some(cached.clone());
                } else {
                    uncached_idx.push(i);
                    uncached.push(*text);
                }
            }
        }

        // Pass 2 (lock released during the inner call): batch-embed only the
        // misses, then cache them and slot them back by original index.
        if !uncached.is_empty() {
            let new_embeddings = self.inner.embed_batch(&uncached)?;
            let mut cache = self
                .cache
                .lock()
                .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
            for (idx, embedding) in uncached_idx.into_iter().zip(new_embeddings) {
                cache.put(texts[idx].to_string(), embedding.clone());
                results[idx] = Some(embedding);
            }
        }

        // Defensive: if `inner` returned fewer embeddings than requested,
        // surface an error instead of silently dropping inputs.
        let expected = texts.len();
        let output: Vec<Vec<f32>> = results
            .into_iter()
            .enumerate()
            .map(|(i, opt)| {
                opt.ok_or_else(|| {
                    CodememError::Embedding(format!(
                        "Missing embedding for text at index {i} (batch size {expected})"
                    ))
                })
            })
            .collect::<Result<Vec<_>, _>>()?;
        Ok(output)
    }

    fn name(&self) -> &str {
        self.inner.name()
    }

    /// Returns `(current entries, capacity)`; `(0, 0)` if the cache lock is
    /// poisoned.
    fn cache_stats(&self) -> (usize, usize) {
        match self.cache.lock() {
            Ok(cache) => (cache.len(), cache.cap().into()),
            Err(_) => (0, 0),
        }
    }
}
565
566pub fn parse_dtype(s: &str) -> Result<DType, CodememError> {
572 match s.to_lowercase().as_str() {
573 "f32" | "float32" | "" => Ok(DType::F32),
574 "f16" | "float16" | "half" => Ok(DType::F16),
575 "bf16" | "bfloat16" => Ok(DType::BF16),
576 other => Err(CodememError::Embedding(format!(
577 "Unknown dtype: '{}'. Use 'f32', 'f16', or 'bf16'.",
578 other
579 ))),
580 }
581}
582
583fn resolve_model_id(model: &str) -> Result<(String, String), CodememError> {
592 if model.contains('/') {
593 let dir_name = model.rsplit('/').next().unwrap_or(model);
595 Ok((model.to_string(), dir_name.to_string()))
596 } else if model.starts_with("bge-") {
597 Ok((format!("BAAI/{model}"), model.to_string()))
599 } else {
600 Err(CodememError::Embedding(format!(
601 "Model identifier '{}' must be a full HuggingFace repo ID (e.g., 'BAAI/bge-base-en-v1.5' \
602 or 'sentence-transformers/all-MiniLM-L6-v2'). Short names are only supported for 'bge-*' models.",
603 model
604 )))
605 }
606}
607
/// Construct the configured embedding provider, wrapped in a
/// [`CachedProvider`].
///
/// Every setting resolves in the same order: `CODEMEM_EMBED_*` environment
/// variable → `config` field (when present and non-empty) → compiled-in
/// default. Recognized providers are `candle` (local, the default),
/// `ollama`, and `openai`.
///
/// # Errors
/// Returns [`CodememError::Embedding`] for an unknown provider name, a
/// missing OpenAI API key, an invalid model id or dtype, or a local model
/// that has not been downloaded yet.
pub fn from_env(
    config: Option<&codemem_core::EmbeddingConfig>,
) -> Result<Box<dyn EmbeddingProvider>, CodememError> {
    let provider = std::env::var("CODEMEM_EMBED_PROVIDER")
        .unwrap_or_else(|_| {
            config
                .map(|c| c.provider.clone())
                .unwrap_or_else(|| "candle".to_string())
        })
        .to_lowercase();
    // Only used by the remote providers; the candle path reads the true
    // dimensionality from the model's config.json instead.
    let dimensions: usize = std::env::var("CODEMEM_EMBED_DIMENSIONS")
        .ok()
        .and_then(|s| s.parse().ok())
        .unwrap_or_else(|| config.map_or(DEFAULT_REMOTE_DIMENSIONS, |c| c.dimensions));
    let cache_capacity = config.map_or(CACHE_CAPACITY, |c| c.cache_capacity);
    let batch_size: usize = std::env::var("CODEMEM_EMBED_BATCH_SIZE")
        .ok()
        .and_then(|s| s.parse().ok())
        .unwrap_or_else(|| config.map_or(DEFAULT_BATCH_SIZE, |c| c.batch_size));

    match provider.as_str() {
        "ollama" => {
            let base_url = std::env::var("CODEMEM_EMBED_URL").unwrap_or_else(|_| {
                config
                    .filter(|c| !c.url.is_empty())
                    .map(|c| c.url.clone())
                    .unwrap_or_else(|| ollama::DEFAULT_BASE_URL.to_string())
            });
            let model = std::env::var("CODEMEM_EMBED_MODEL").unwrap_or_else(|_| {
                config
                    .filter(|c| !c.model.is_empty())
                    .map(|c| c.model.clone())
                    .unwrap_or_else(|| ollama::DEFAULT_MODEL.to_string())
            });
            let inner = Box::new(ollama::OllamaProvider::new(&base_url, &model, dimensions));
            Ok(Box::new(CachedProvider::new(inner, cache_capacity)))
        }
        "openai" => {
            // The API key has no config-file fallback: secrets come from the
            // environment only.
            let api_key = std::env::var("CODEMEM_EMBED_API_KEY")
                .or_else(|_| std::env::var("OPENAI_API_KEY"))
                .map_err(|_| {
                    CodememError::Embedding(
                        "CODEMEM_EMBED_API_KEY or OPENAI_API_KEY required for OpenAI embeddings"
                            .into(),
                    )
                })?;
            let model = std::env::var("CODEMEM_EMBED_MODEL").unwrap_or_else(|_| {
                config
                    .filter(|c| !c.model.is_empty())
                    .map(|c| c.model.clone())
                    .unwrap_or_else(|| openai::DEFAULT_MODEL.to_string())
            });
            // Optional override; `None` lets the provider use its default
            // endpoint.
            let base_url = std::env::var("CODEMEM_EMBED_URL")
                .ok()
                .or_else(|| config.filter(|c| !c.url.is_empty()).map(|c| c.url.clone()));
            let inner = Box::new(openai::OpenAIProvider::new(
                &api_key,
                &model,
                dimensions,
                base_url.as_deref(),
            ));
            Ok(Box::new(CachedProvider::new(inner, cache_capacity)))
        }
        "candle" | "" => {
            let model_id = std::env::var("CODEMEM_EMBED_MODEL").unwrap_or_else(|_| {
                config
                    .filter(|c| !c.model.is_empty())
                    .map(|c| c.model.clone())
                    .unwrap_or_else(|| DEFAULT_HF_REPO.to_string())
            });
            let (hf_repo, dir_name) = resolve_model_id(&model_id)?;
            let model_dir = EmbeddingService::model_dir_for(&dir_name);

            let dtype_str = std::env::var("CODEMEM_EMBED_DTYPE").unwrap_or_else(|_| {
                config
                    .filter(|c| !c.dtype.is_empty())
                    .map(|c| c.dtype.clone())
                    .unwrap_or_else(|| "f32".to_string())
            });
            let dtype = parse_dtype(&dtype_str)?;

            // For non-default models, rewrite the "not found" error so the
            // hint shows the exact init command for *this* model.
            let service = EmbeddingService::new(&model_dir, batch_size, dtype).map_err(|e| {
                if e.to_string().contains("Model not found") && hf_repo != DEFAULT_HF_REPO {
                    CodememError::Embedding(format!(
                        "Model '{}' not found at {}. Download it with:\n \
                        CODEMEM_EMBED_MODEL={} codemem init",
                        hf_repo,
                        model_dir.display(),
                        hf_repo
                    ))
                } else {
                    e
                }
            })?;
            Ok(Box::new(CachedProvider::new(
                Box::new(service),
                cache_capacity,
            )))
        }
        other => Err(CodememError::Embedding(format!(
            "Unknown embedding provider: '{}'. Use 'candle', 'ollama', or 'openai'.",
            other
        ))),
    }
}
729
730#[cfg(test)]
731#[path = "tests/lib_tests.rs"]
732mod tests;