//! Embedding generation for codemem: a local candle/BERT service plus Ollama and
//! OpenAI providers, all shareable behind an LRU-cached wrapper.

pub mod ollama;
pub mod openai;

use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config as BertConfig};
use codemem_core::CodememError;
use lru::LruCache;
use std::num::NonZeroUsize;
use std::path::{Path, PathBuf};
use std::sync::Mutex;
use tokenizers::{PaddingParams, PaddingStrategy};

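/// Name of the default local embedding model (BGE base English v1.5).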
pub const MODEL_NAME: &str = "bge-base-en-v1.5";

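/// Hugging Face repository the model weights, config, and tokenizer are fetched from.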
const HF_MODEL_REPO: &str = "BAAI/bge-base-en-v1.5";

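/// Dimensionality of the embedding vectors produced by the model.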
pub const DIMENSIONS: usize = 768;

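/// Maximum number of tokens per input; longer inputs are truncated by the tokenizer.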
const MAX_SEQ_LENGTH: usize = 512;

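/// Default number of entries kept in the embedding LRU cache.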
pub const CACHE_CAPACITY: usize = 10_000;

pub use codemem_core::EmbeddingProvider;

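/// Default number of texts encoded per forward pass in `embed_batch`.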
pub const DEFAULT_BATCH_SIZE: usize = 16;

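/// Select the compute device: Metal or CUDA when the corresponding feature is
/// enabled and the device initializes, otherwise fall back to the CPU.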
fn select_device() -> Device {
    #[cfg(feature = "metal")]
    {
        if let Ok(device) = Device::new_metal(0) {
            tracing::info!("Using Metal GPU for embeddings");
            return device;
        }
        tracing::warn!("Metal feature enabled but device creation failed, falling back");
    }
    #[cfg(feature = "cuda")]
    {
        if let Ok(device) = Device::new_cuda(0) {
            tracing::info!("Using CUDA GPU for embeddings");
            return device;
        }
        tracing::warn!("CUDA feature enabled but device creation failed, falling back");
    }
    tracing::info!("Using CPU for embeddings");
    Device::Cpu
}

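/// Local embedding service backed by a candle BERT model (bge-base-en-v1.5).
/// The model sits behind a `Mutex` so the service can be shared across threads.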
pub struct EmbeddingService {
    model: Mutex<BertModel>,
    tokenizer: tokenizers::Tokenizer,
    device: Device,
    batch_size: usize,
}

impl EmbeddingService {
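    /// Load the model weights, config, and tokenizer from `model_dir`.
    ///
    /// Returns an error if `model.safetensors` is missing; run `codemem init` to download it.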
    pub fn new(model_dir: &Path, batch_size: usize) -> Result<Self, CodememError> {
        let model_path = model_dir.join("model.safetensors");
        let config_path = model_dir.join("config.json");
        let tokenizer_path = model_dir.join("tokenizer.json");

        if !model_path.exists() {
            return Err(CodememError::Embedding(format!(
                "Model not found at {}. Run `codemem init` to download it.",
                model_path.display()
            )));
        }

        let device = select_device();

        let config_str = std::fs::read_to_string(&config_path)
            .map_err(|e| CodememError::Embedding(format!("Failed to read config: {e}")))?;
        let config: BertConfig = serde_json::from_str(&config_str)
            .map_err(|e| CodememError::Embedding(format!("Failed to parse config: {e}")))?;

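        // Weight names may or may not be prefixed with "bert.", so try the prefixed
        // layout first and fall back to loading from the root namespace.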
        let model = {
            let vb = unsafe {
                VarBuilder::from_mmaped_safetensors(&[&model_path], DType::F32, &device)
                    .map_err(|e| CodememError::Embedding(format!("Failed to load weights: {e}")))?
            };
            BertModel::load(vb.pp("bert"), &config)
        };
        let model = match model {
            Ok(m) => m,
            Err(_) => {
                let vb2 = unsafe {
                    VarBuilder::from_mmaped_safetensors(&[&model_path], DType::F32, &device)
                        .map_err(|e| {
                            CodememError::Embedding(format!("Failed to load weights: {e}"))
                        })?
                };
                BertModel::load(vb2, &config).map_err(|e| {
                    CodememError::Embedding(format!("Failed to load BERT model: {e}"))
                })?
            }
        };

        let mut tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| CodememError::Embedding(e.to_string()))?;

        tokenizer
            .with_truncation(Some(tokenizers::TruncationParams {
                max_length: MAX_SEQ_LENGTH,
                ..Default::default()
            }))
            .map_err(|e| CodememError::Embedding(format!("Truncation error: {e}")))?;

        Ok(Self {
            model: Mutex::new(model),
            tokenizer,
            device,
            batch_size,
        })
    }

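    /// Default model directory: `~/.codemem/models/bge-base-en-v1.5`.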
    pub fn default_model_dir() -> PathBuf {
        dirs::home_dir()
            .unwrap_or_else(|| PathBuf::from("."))
            .join(".codemem")
            .join("models")
            .join(MODEL_NAME)
    }

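    /// Download the model files from Hugging Face into `dest_dir`, skipping the
    /// download when all three files are already present.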
    pub fn download_model(dest_dir: &Path) -> Result<PathBuf, CodememError> {
        let model_dest = dest_dir.join("model.safetensors");
        let config_dest = dest_dir.join("config.json");
        let tokenizer_dest = dest_dir.join("tokenizer.json");

        if model_dest.exists() && config_dest.exists() && tokenizer_dest.exists() {
            tracing::info!("Model already downloaded at {}", dest_dir.display());
            return Ok(dest_dir.to_path_buf());
        }

        std::fs::create_dir_all(dest_dir)
            .map_err(|e| CodememError::Embedding(format!("Failed to create dir: {e}")))?;

        tracing::info!("Downloading {} from HuggingFace...", HF_MODEL_REPO);

        let api = hf_hub::api::sync::Api::new()
            .map_err(|e| CodememError::Embedding(format!("HuggingFace API error: {e}")))?;
        let repo = api.model(HF_MODEL_REPO.to_string());

        let cached_model = repo
            .get("model.safetensors")
            .map_err(|e| CodememError::Embedding(format!("Failed to download model: {e}")))?;

        let cached_config = repo
            .get("config.json")
            .map_err(|e| CodememError::Embedding(format!("Failed to download config: {e}")))?;

        let cached_tokenizer = repo
            .get("tokenizer.json")
            .map_err(|e| CodememError::Embedding(format!("Failed to download tokenizer: {e}")))?;

        std::fs::copy(&cached_model, &model_dest)
            .map_err(|e| CodememError::Embedding(format!("Failed to copy model: {e}")))?;
        std::fs::copy(&cached_config, &config_dest)
            .map_err(|e| CodememError::Embedding(format!("Failed to copy config: {e}")))?;
        std::fs::copy(&cached_tokenizer, &tokenizer_dest)
            .map_err(|e| CodememError::Embedding(format!("Failed to copy tokenizer: {e}")))?;

        tracing::info!("Model downloaded to {}", dest_dir.display());
        Ok(dest_dir.to_path_buf())
    }

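    /// Embed a single text: tokenize, run the model, mean-pool the hidden states
    /// over the attention mask, and L2-normalize the result.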
    pub fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
        let encoding = self
            .tokenizer
            .encode(text, true)
            .map_err(|e| CodememError::Embedding(e.to_string()))?;

        let input_ids: Vec<u32> = encoding.get_ids().to_vec();
        let attention_mask: Vec<u32> = encoding.get_attention_mask().to_vec();

        let input_ids_tensor = Tensor::new(&input_ids[..], &self.device)
            .and_then(|t| t.unsqueeze(0))
            .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;

        let token_type_ids = input_ids_tensor
            .zeros_like()
            .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;

        let attention_mask_tensor = Tensor::new(&attention_mask[..], &self.device)
            .and_then(|t| t.unsqueeze(0))
            .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;

        let model = self
            .model
            .lock()
            .map_err(|e| CodememError::LockPoisoned(format!("embedding model: {e}")))?;
        let hidden_states = model
            .forward(
                &input_ids_tensor,
                &token_type_ids,
                Some(&attention_mask_tensor),
            )
            .map_err(|e| CodememError::Embedding(format!("Model forward error: {e}")))?;
        drop(model);

        drop(input_ids_tensor);
        drop(token_type_ids);

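        // Mean pooling: zero out padding positions with the attention mask, sum over
        // the sequence dimension, and divide by the number of real tokens.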
        let mask = attention_mask_tensor
            .to_dtype(DType::F32)
            .and_then(|t| t.unsqueeze(2))
            .map_err(|e| CodememError::Embedding(format!("Mask error: {e}")))?;

        let sum_mask = mask
            .sum(1)
            .map_err(|e| CodememError::Embedding(format!("Sum error: {e}")))?;

        let pooled = hidden_states
            .broadcast_mul(&mask)
            .and_then(|t| t.sum(1))
            .and_then(|t| t.broadcast_div(&sum_mask))
            .map_err(|e| CodememError::Embedding(format!("Pooling error: {e}")))?;

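        // L2-normalize the pooled vector by dividing by its Euclidean norm.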
        let normalized = pooled
            .sqr()
            .and_then(|t| t.sum_keepdim(1))
            .and_then(|t| t.sqrt())
            .and_then(|norm| pooled.broadcast_div(&norm))
            .map_err(|e| CodememError::Embedding(format!("Normalize error: {e}")))?;

        let embedding: Vec<f32> = normalized
            .squeeze(0)
            .and_then(|t| t.to_vec1())
            .map_err(|e| CodememError::Embedding(format!("Extract error: {e}")))?;

        Ok(embedding)
    }

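    /// Embed multiple texts, processing them in chunks of `batch_size` with
    /// padding to the longest sequence in each chunk.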
    pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
        if texts.is_empty() {
            return Ok(vec![]);
        }

        let mut all_embeddings = Vec::with_capacity(texts.len());

        for chunk in texts.chunks(self.batch_size) {
            let mut tokenizer = self.tokenizer.clone();
            tokenizer.with_padding(Some(PaddingParams {
                strategy: PaddingStrategy::BatchLongest,
                ..Default::default()
            }));

            let encodings = tokenizer
                .encode_batch(chunk.to_vec(), true)
                .map_err(|e| CodememError::Embedding(format!("Batch encode error: {e}")))?;

            let batch_len = encodings.len();
            let seq_len = encodings[0].get_ids().len();

            let all_ids: Vec<u32> = encodings
                .iter()
                .flat_map(|e| e.get_ids())
                .copied()
                .collect();
            let all_masks: Vec<u32> = encodings
                .iter()
                .flat_map(|e| e.get_attention_mask())
                .copied()
                .collect();

            let input_ids = Tensor::new(all_ids.as_slice(), &self.device)
                .and_then(|t| t.reshape((batch_len, seq_len)))
                .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;

            let token_type_ids = input_ids
                .zeros_like()
                .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;

            let attention_mask = Tensor::new(all_masks.as_slice(), &self.device)
                .and_then(|t| t.reshape((batch_len, seq_len)))
                .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;

            let model = self
                .model
                .lock()
                .map_err(|e| CodememError::LockPoisoned(format!("embedding model: {e}")))?;
            let hidden_states = model
                .forward(&input_ids, &token_type_ids, Some(&attention_mask))
                .map_err(|e| CodememError::Embedding(format!("Forward error: {e}")))?;
            drop(model);

            drop(input_ids);
            drop(token_type_ids);

            let mask = attention_mask
                .to_dtype(DType::F32)
                .and_then(|t| t.unsqueeze(2))
                .map_err(|e| CodememError::Embedding(format!("Mask error: {e}")))?;

            let sum_mask = mask
                .sum(1)
                .map_err(|e| CodememError::Embedding(format!("Sum error: {e}")))?;

            let pooled = hidden_states
                .broadcast_mul(&mask)
                .and_then(|t| t.sum(1))
                .and_then(|t| t.broadcast_div(&sum_mask))
                .map_err(|e| CodememError::Embedding(format!("Pooling error: {e}")))?;

            let norm = pooled
                .sqr()
                .and_then(|t| t.sum_keepdim(1))
                .and_then(|t| t.sqrt())
                .map_err(|e| CodememError::Embedding(format!("Norm error: {e}")))?;

            let normalized = pooled
                .broadcast_div(&norm)
                .map_err(|e| CodememError::Embedding(format!("Normalize error: {e}")))?;

            let flat: Vec<f32> = normalized
                .flatten_all()
                .and_then(|t| t.to_vec1())
                .map_err(|e| CodememError::Embedding(format!("Extract error: {e}")))?;
            for i in 0..batch_len {
                let start = i * DIMENSIONS;
                all_embeddings.push(flat[start..start + DIMENSIONS].to_vec());
            }
        }

        Ok(all_embeddings)
    }
}

impl EmbeddingProvider for EmbeddingService {
    fn dimensions(&self) -> usize {
        DIMENSIONS
    }

    fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
        self.embed(text)
    }

    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
        self.embed_batch(texts)
    }

    fn name(&self) -> &str {
        "candle"
    }
}

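/// Wraps any [`EmbeddingProvider`] with an LRU cache keyed by the input text.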
pub struct CachedProvider {
    inner: Box<dyn EmbeddingProvider>,
    cache: Mutex<LruCache<String, Vec<f32>>>,
}

impl CachedProvider {
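    /// Create a cached wrapper around `inner`; a capacity of zero is clamped to one.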
    pub fn new(inner: Box<dyn EmbeddingProvider>, capacity: usize) -> Self {
        let cap =
            NonZeroUsize::new(capacity).unwrap_or(NonZeroUsize::new(1).expect("1 is non-zero"));
        Self {
            inner,
            cache: Mutex::new(LruCache::new(cap)),
        }
    }
}

impl EmbeddingProvider for CachedProvider {
    fn dimensions(&self) -> usize {
        self.inner.dimensions()
    }

    fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
        {
            let mut cache = self
                .cache
                .lock()
                .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
            if let Some(cached) = cache.get(text) {
                return Ok(cached.clone());
            }
        }
        let embedding = self.inner.embed(text)?;
        {
            let mut cache = self
                .cache
                .lock()
                .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
            cache.put(text.to_string(), embedding.clone());
        }
        Ok(embedding)
    }

    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
        let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
        let mut uncached = Vec::new();
        let mut uncached_idx = Vec::new();

        {
            let mut cache = self
                .cache
                .lock()
                .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
            for (i, text) in texts.iter().enumerate() {
                if let Some(cached) = cache.get(*text) {
                    results[i] = Some(cached.clone());
                } else {
                    uncached_idx.push(i);
                    uncached.push(*text);
                }
            }
        }

        if !uncached.is_empty() {
            let new_embeddings = self.inner.embed_batch(&uncached)?;
            let mut cache = self
                .cache
                .lock()
                .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
            for (idx, embedding) in uncached_idx.into_iter().zip(new_embeddings) {
                cache.put(texts[idx].to_string(), embedding.clone());
                results[idx] = Some(embedding);
            }
        }

        let expected = texts.len();
        let output: Vec<Vec<f32>> = results
            .into_iter()
            .enumerate()
            .map(|(i, opt)| {
                opt.ok_or_else(|| {
                    CodememError::Embedding(format!(
                        "Missing embedding for text at index {i} (batch size {expected})"
                    ))
                })
            })
            .collect::<Result<Vec<_>, _>>()?;
        Ok(output)
    }

    fn name(&self) -> &str {
        self.inner.name()
    }

    fn cache_stats(&self) -> (usize, usize) {
        match self.cache.lock() {
            Ok(cache) => (cache.len(), cache.cap().into()),
            Err(_) => (0, 0),
        }
    }
}

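/// Build an embedding provider from environment variables, falling back to the
/// optional `EmbeddingConfig`. `CODEMEM_EMBED_PROVIDER` selects `candle` (default),
/// `ollama`, or `openai`; the result is always wrapped in a [`CachedProvider`].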
pub fn from_env(
    config: Option<&codemem_core::EmbeddingConfig>,
) -> Result<Box<dyn EmbeddingProvider>, CodememError> {
    let provider = std::env::var("CODEMEM_EMBED_PROVIDER")
        .unwrap_or_else(|_| {
            config
                .map(|c| c.provider.clone())
                .unwrap_or_else(|| "candle".to_string())
        })
        .to_lowercase();
    let dimensions: usize = std::env::var("CODEMEM_EMBED_DIMENSIONS")
        .ok()
        .and_then(|s| s.parse().ok())
        .unwrap_or_else(|| config.map_or(DIMENSIONS, |c| c.dimensions));
    let cache_capacity = config.map_or(CACHE_CAPACITY, |c| c.cache_capacity);
    let batch_size = config.map_or(DEFAULT_BATCH_SIZE, |c| c.batch_size);

    match provider.as_str() {
        "ollama" => {
            let base_url = std::env::var("CODEMEM_EMBED_URL").unwrap_or_else(|_| {
                config
                    .filter(|c| !c.url.is_empty())
                    .map(|c| c.url.clone())
                    .unwrap_or_else(|| ollama::DEFAULT_BASE_URL.to_string())
            });
            let model = std::env::var("CODEMEM_EMBED_MODEL").unwrap_or_else(|_| {
                config
                    .filter(|c| !c.model.is_empty())
                    .map(|c| c.model.clone())
                    .unwrap_or_else(|| ollama::DEFAULT_MODEL.to_string())
            });
            let inner = Box::new(ollama::OllamaProvider::new(&base_url, &model, dimensions));
            Ok(Box::new(CachedProvider::new(inner, cache_capacity)))
        }
        "openai" => {
            let api_key = std::env::var("CODEMEM_EMBED_API_KEY")
                .or_else(|_| std::env::var("OPENAI_API_KEY"))
                .map_err(|_| {
                    CodememError::Embedding(
                        "CODEMEM_EMBED_API_KEY or OPENAI_API_KEY required for OpenAI embeddings"
                            .into(),
                    )
                })?;
            let model = std::env::var("CODEMEM_EMBED_MODEL").unwrap_or_else(|_| {
                config
                    .filter(|c| !c.model.is_empty())
                    .map(|c| c.model.clone())
                    .unwrap_or_else(|| openai::DEFAULT_MODEL.to_string())
            });
            let base_url = std::env::var("CODEMEM_EMBED_URL")
                .ok()
                .or_else(|| config.filter(|c| !c.url.is_empty()).map(|c| c.url.clone()));
            let inner = Box::new(openai::OpenAIProvider::new(
                &api_key,
                &model,
                dimensions,
                base_url.as_deref(),
            ));
            Ok(Box::new(CachedProvider::new(inner, cache_capacity)))
        }
        "candle" | "" => {
            let model_dir = EmbeddingService::default_model_dir();
            let service = EmbeddingService::new(&model_dir, batch_size)?;
            Ok(Box::new(CachedProvider::new(
                Box::new(service),
                cache_capacity,
            )))
        }
        other => Err(CodememError::Embedding(format!(
            "Unknown embedding provider: '{}'. Use 'candle', 'ollama', or 'openai'.",
            other
        ))),
    }
}

#[cfg(test)]
#[path = "tests/lib_tests.rs"]
mod tests;