1pub mod ollama;
9pub mod openai;
10
11use candle_core::{DType, Device, Tensor};
12use candle_nn::VarBuilder;
13use candle_transformers::models::bert::{BertModel, Config as BertConfig};
14use codemem_core::CodememError;
15use lru::LruCache;
16use std::num::NonZeroUsize;
17use std::path::{Path, PathBuf};
18use std::sync::Mutex;
19use tokenizers::{PaddingParams, PaddingStrategy, TruncationParams};
20
/// Local embedding model identifier (also the on-disk directory name).
pub const MODEL_NAME: &str = "bge-base-en-v1.5";

/// HuggingFace repository the model artifacts are downloaded from.
const HF_MODEL_REPO: &str = "BAAI/bge-base-en-v1.5";

/// Output dimensionality of the bge-base-en-v1.5 embeddings.
pub const DIMENSIONS: usize = 768;

/// Maximum token sequence length fed to the BERT encoder; longer inputs
/// are truncated at tokenization time.
const MAX_SEQ_LENGTH: usize = 512;

/// Number of entries held by the LRU text -> embedding cache.
pub const CACHE_CAPACITY: usize = 10_000;
/// Common interface over embedding backends (local candle model, Ollama,
/// OpenAI). Implementations must be shareable across threads.
pub trait EmbeddingProvider: Send + Sync {
    /// Dimensionality of the vectors this provider produces.
    fn dimensions(&self) -> usize;

    /// Embed a single text into a vector of `dimensions()` floats.
    fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError>;

    /// Embed many texts. The default implementation calls `embed` per text
    /// and short-circuits on the first error; backends with true batch
    /// support override this.
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
        texts.iter().map(|t| self.embed(t)).collect()
    }

    /// Short human-readable backend name (e.g. "candle", "ollama").
    fn name(&self) -> &str;

    /// `(current_entries, capacity)` of the provider's cache; providers
    /// without a cache report `(0, 0)`.
    fn cache_stats(&self) -> (usize, usize) {
        (0, 0)
    }
}
59
/// Texts per forward pass when batch-embedding; bounds peak tensor size.
const BATCH_SIZE: usize = 32;
64
/// Pick the best available compute device: Metal, then CUDA (each only when
/// the corresponding cargo feature is compiled in), falling back to CPU.
///
/// A runtime device-creation failure logs a warning and falls through to the
/// next option, so this function never errors.
fn select_device() -> Device {
    #[cfg(feature = "metal")]
    {
        if let Ok(device) = Device::new_metal(0) {
            tracing::info!("Using Metal GPU for embeddings");
            return device;
        }
        tracing::warn!("Metal feature enabled but device creation failed, falling back");
    }
    #[cfg(feature = "cuda")]
    {
        if let Ok(device) = Device::new_cuda(0) {
            tracing::info!("Using CUDA GPU for embeddings");
            return device;
        }
        tracing::warn!("CUDA feature enabled but device creation failed, falling back");
    }
    tracing::info!("Using CPU for embeddings");
    Device::Cpu
}
89
/// Local BERT-based embedding service backed by candle.
pub struct EmbeddingService {
    // NOTE(review): the Mutex serializes forward passes — presumably required
    // for exclusive model access on the chosen device; confirm against candle.
    model: Mutex<BertModel>,
    tokenizer: tokenizers::Tokenizer,
    device: Device,
    // LRU of text -> embedding to avoid recomputing frequent inputs.
    cache: Mutex<LruCache<String, Vec<f32>>>,
}
97
impl EmbeddingService {
    /// Load model weights, config, and tokenizer from `model_dir`.
    ///
    /// # Errors
    /// Returns `CodememError::Embedding` when the weights file is missing
    /// (the user must run `codemem init` first) or when any artifact fails
    /// to load or parse.
    pub fn new(model_dir: &Path) -> Result<Self, CodememError> {
        let model_path = model_dir.join("model.safetensors");
        let config_path = model_dir.join("config.json");
        let tokenizer_path = model_dir.join("tokenizer.json");

        if !model_path.exists() {
            return Err(CodememError::Embedding(format!(
                "Model not found at {}. Run `codemem init` to download it.",
                model_path.display()
            )));
        }

        let device = select_device();

        let config_str = std::fs::read_to_string(&config_path)
            .map_err(|e| CodememError::Embedding(format!("Failed to read config: {e}")))?;
        let config: BertConfig = serde_json::from_str(&config_str)
            .map_err(|e| CodememError::Embedding(format!("Failed to parse config: {e}")))?;

        // SAFETY: the safetensors file is memory-mapped; this is sound as long
        // as the file is not modified while the mapping is alive.
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[&model_path], DType::F32, &device)
                .map_err(|e| CodememError::Embedding(format!("Failed to load weights: {e}")))?
        };

        // Try weights nested under a "bert" prefix first; some checkpoints
        // store tensors at the top level instead, so retry without the prefix
        // (the VarBuilder must be recreated because `pp` consumed the first).
        let model = BertModel::load(vb.pp("bert"), &config)
            .or_else(|_| {
                // SAFETY: same mmap soundness requirement as above.
                let vb2 = unsafe {
                    VarBuilder::from_mmaped_safetensors(&[&model_path], DType::F32, &device)
                        .map_err(|e| {
                            candle_core::Error::Msg(format!("Failed to load weights: {e}"))
                        })
                }?;
                BertModel::load(vb2, &config)
            })
            .map_err(|e| CodememError::Embedding(format!("Failed to load BERT model: {e}")))?;

        let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| CodememError::Embedding(e.to_string()))?;

        let cache = Mutex::new(LruCache::new(
            NonZeroUsize::new(CACHE_CAPACITY).expect("CACHE_CAPACITY is non-zero"),
        ));

        Ok(Self {
            model: Mutex::new(model),
            tokenizer,
            device,
            cache,
        })
    }
154
155 pub fn default_model_dir() -> PathBuf {
157 dirs::home_dir()
158 .unwrap_or_else(|| PathBuf::from("."))
159 .join(".codemem")
160 .join("models")
161 .join(MODEL_NAME)
162 }
163
164 pub fn download_model(dest_dir: &Path) -> Result<PathBuf, CodememError> {
167 let model_dest = dest_dir.join("model.safetensors");
168 let config_dest = dest_dir.join("config.json");
169 let tokenizer_dest = dest_dir.join("tokenizer.json");
170
171 if model_dest.exists() && config_dest.exists() && tokenizer_dest.exists() {
172 tracing::info!("Model already downloaded at {}", dest_dir.display());
173 return Ok(dest_dir.to_path_buf());
174 }
175
176 std::fs::create_dir_all(dest_dir)
177 .map_err(|e| CodememError::Embedding(format!("Failed to create dir: {e}")))?;
178
179 tracing::info!("Downloading {} from HuggingFace...", HF_MODEL_REPO);
180
181 let api = hf_hub::api::sync::Api::new()
182 .map_err(|e| CodememError::Embedding(format!("HuggingFace API error: {e}")))?;
183 let repo = api.model(HF_MODEL_REPO.to_string());
184
185 let cached_model = repo
186 .get("model.safetensors")
187 .map_err(|e| CodememError::Embedding(format!("Failed to download model: {e}")))?;
188
189 let cached_config = repo
190 .get("config.json")
191 .map_err(|e| CodememError::Embedding(format!("Failed to download config: {e}")))?;
192
193 let cached_tokenizer = repo
194 .get("tokenizer.json")
195 .map_err(|e| CodememError::Embedding(format!("Failed to download tokenizer: {e}")))?;
196
197 std::fs::copy(&cached_model, &model_dest)
198 .map_err(|e| CodememError::Embedding(format!("Failed to copy model: {e}")))?;
199 std::fs::copy(&cached_config, &config_dest)
200 .map_err(|e| CodememError::Embedding(format!("Failed to copy config: {e}")))?;
201 std::fs::copy(&cached_tokenizer, &tokenizer_dest)
202 .map_err(|e| CodememError::Embedding(format!("Failed to copy tokenizer: {e}")))?;
203
204 tracing::info!("Model downloaded to {}", dest_dir.display());
205 Ok(dest_dir.to_path_buf())
206 }
207
208 pub fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
211 {
213 let mut cache = self
214 .cache
215 .lock()
216 .map_err(|e| CodememError::LockPoisoned(format!("embedding cache: {e}")))?;
217 if let Some(cached) = cache.get(text) {
218 return Ok(cached.clone());
219 }
220 }
221
222 let embedding = self.embed_uncached(text)?;
223
224 {
226 let mut cache = self
227 .cache
228 .lock()
229 .map_err(|e| CodememError::LockPoisoned(format!("embedding cache: {e}")))?;
230 cache.put(text.to_string(), embedding.clone());
231 }
232
233 Ok(embedding)
234 }
235
236 fn embed_uncached(&self, text: &str) -> Result<Vec<f32>, CodememError> {
238 let mut tokenizer = self.tokenizer.clone();
240
241 tokenizer
242 .with_truncation(Some(tokenizers::TruncationParams {
243 max_length: MAX_SEQ_LENGTH,
244 ..Default::default()
245 }))
246 .map_err(|e| CodememError::Embedding(format!("Truncation error: {e}")))?;
247
248 let encoding = tokenizer
249 .encode(text, true)
250 .map_err(|e| CodememError::Embedding(e.to_string()))?;
251
252 let input_ids: Vec<u32> = encoding.get_ids().to_vec();
253 let attention_mask: Vec<u32> = encoding.get_attention_mask().to_vec();
254
255 let input_ids_tensor = Tensor::new(&input_ids[..], &self.device)
257 .and_then(|t| t.unsqueeze(0))
258 .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;
259
260 let token_type_ids = input_ids_tensor
261 .zeros_like()
262 .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;
263
264 let attention_mask_tensor = Tensor::new(&attention_mask[..], &self.device)
265 .and_then(|t| t.unsqueeze(0))
266 .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;
267
268 let model = self
270 .model
271 .lock()
272 .map_err(|e| CodememError::LockPoisoned(format!("embedding model: {e}")))?;
273 let hidden_states = model
274 .forward(
275 &input_ids_tensor,
276 &token_type_ids,
277 Some(&attention_mask_tensor),
278 )
279 .map_err(|e| CodememError::Embedding(format!("Model forward error: {e}")))?;
280 drop(model);
281
282 let mask = attention_mask_tensor
285 .to_dtype(DType::F32)
286 .and_then(|t| t.unsqueeze(2))
287 .map_err(|e| CodememError::Embedding(format!("Mask error: {e}")))?;
288
289 let sum_mask = mask
290 .sum(1)
291 .map_err(|e| CodememError::Embedding(format!("Sum error: {e}")))?;
292
293 let pooled = hidden_states
294 .broadcast_mul(&mask)
295 .and_then(|t| t.sum(1))
296 .and_then(|t| t.broadcast_div(&sum_mask))
297 .map_err(|e| CodememError::Embedding(format!("Pooling error: {e}")))?;
298
299 let normalized = pooled
301 .sqr()
302 .and_then(|t| t.sum_keepdim(1))
303 .and_then(|t| t.sqrt())
304 .and_then(|norm| pooled.broadcast_div(&norm))
305 .map_err(|e| CodememError::Embedding(format!("Normalize error: {e}")))?;
306
307 let embedding: Vec<f32> = normalized
309 .squeeze(0)
310 .and_then(|t| t.to_vec1())
311 .map_err(|e| CodememError::Embedding(format!("Extract error: {e}")))?;
312
313 Ok(embedding)
314 }
315
316 pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
321 if texts.is_empty() {
322 return Ok(vec![]);
323 }
324
325 let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
327 let mut uncached_indices = Vec::new();
328 let mut uncached_texts = Vec::new();
329
330 {
331 let mut cache = self
332 .cache
333 .lock()
334 .map_err(|e| CodememError::LockPoisoned(format!("embedding cache: {e}")))?;
335 for (i, text) in texts.iter().enumerate() {
336 if let Some(cached) = cache.get(*text) {
337 results[i] = Some(cached.clone());
338 } else {
339 uncached_indices.push(i);
340 uncached_texts.push(*text);
341 }
342 }
343 }
344
345 if !uncached_texts.is_empty() {
346 let new_embeddings = self.embed_batch_uncached(&uncached_texts)?;
347
348 let mut cache = self
349 .cache
350 .lock()
351 .map_err(|e| CodememError::LockPoisoned(format!("embedding cache: {e}")))?;
352 for (idx, embedding) in uncached_indices.into_iter().zip(new_embeddings) {
353 cache.put(texts[idx].to_string(), embedding.clone());
354 results[idx] = Some(embedding);
355 }
356 }
357
358 Ok(results.into_iter().flatten().collect())
359 }
360
361 fn embed_batch_uncached(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
367 if texts.is_empty() {
368 return Ok(vec![]);
369 }
370
371 let mut all_embeddings = Vec::with_capacity(texts.len());
372
373 for chunk in texts.chunks(BATCH_SIZE) {
374 let mut tokenizer = self.tokenizer.clone();
375
376 tokenizer
377 .with_truncation(Some(TruncationParams {
378 max_length: MAX_SEQ_LENGTH,
379 ..Default::default()
380 }))
381 .map_err(|e| CodememError::Embedding(format!("Truncation error: {e}")))?;
382
383 tokenizer.with_padding(Some(PaddingParams {
385 strategy: PaddingStrategy::BatchLongest,
386 ..Default::default()
387 }));
388
389 let encodings = tokenizer
390 .encode_batch(chunk.to_vec(), true)
391 .map_err(|e| CodememError::Embedding(format!("Batch encode error: {e}")))?;
392
393 let batch_len = encodings.len();
394 let seq_len = encodings[0].get_ids().len();
395
396 let all_ids: Vec<u32> = encodings
398 .iter()
399 .flat_map(|e| e.get_ids())
400 .copied()
401 .collect();
402 let all_masks: Vec<u32> = encodings
403 .iter()
404 .flat_map(|e| e.get_attention_mask())
405 .copied()
406 .collect();
407
408 let input_ids = Tensor::new(all_ids.as_slice(), &self.device)
410 .and_then(|t| t.reshape((batch_len, seq_len)))
411 .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;
412
413 let token_type_ids = input_ids
414 .zeros_like()
415 .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;
416
417 let attention_mask = Tensor::new(all_masks.as_slice(), &self.device)
418 .and_then(|t| t.reshape((batch_len, seq_len)))
419 .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;
420
421 let model = self
423 .model
424 .lock()
425 .map_err(|e| CodememError::LockPoisoned(format!("embedding model: {e}")))?;
426 let hidden_states = model
427 .forward(&input_ids, &token_type_ids, Some(&attention_mask))
428 .map_err(|e| CodememError::Embedding(format!("Forward error: {e}")))?;
429 drop(model);
430
431 let mask = attention_mask
433 .to_dtype(DType::F32)
434 .and_then(|t| t.unsqueeze(2))
435 .map_err(|e| CodememError::Embedding(format!("Mask error: {e}")))?;
436
437 let sum_mask = mask
438 .sum(1)
439 .map_err(|e| CodememError::Embedding(format!("Sum error: {e}")))?;
440
441 let pooled = hidden_states
442 .broadcast_mul(&mask)
443 .and_then(|t| t.sum(1))
444 .and_then(|t| t.broadcast_div(&sum_mask))
445 .map_err(|e| CodememError::Embedding(format!("Pooling error: {e}")))?;
446
447 let norm = pooled
449 .sqr()
450 .and_then(|t| t.sum_keepdim(1))
451 .and_then(|t| t.sqrt())
452 .map_err(|e| CodememError::Embedding(format!("Norm error: {e}")))?;
453
454 let normalized = pooled
455 .broadcast_div(&norm)
456 .map_err(|e| CodememError::Embedding(format!("Normalize error: {e}")))?;
457
458 for i in 0..batch_len {
460 let row: Vec<f32> = normalized
461 .get(i)
462 .and_then(|t| t.to_vec1())
463 .map_err(|e| CodememError::Embedding(format!("Extract error: {e}")))?;
464 all_embeddings.push(row);
465 }
466 }
467
468 Ok(all_embeddings)
469 }
470
471 pub fn cache_stats(&self) -> (usize, usize) {
473 match self.cache.lock() {
474 Ok(cache) => (cache.len(), CACHE_CAPACITY),
475 Err(_) => (0, CACHE_CAPACITY),
476 }
477 }
478}
479
/// Exposes the candle-backed service through the provider trait; every
/// method delegates to the inherent implementation (inherent methods take
/// precedence over trait methods in resolution, so `self.embed(...)` below
/// is not recursive).
impl EmbeddingProvider for EmbeddingService {
    fn dimensions(&self) -> usize {
        DIMENSIONS
    }

    fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
        self.embed(text)
    }

    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
        self.embed_batch(texts)
    }

    fn name(&self) -> &str {
        "candle"
    }

    fn cache_stats(&self) -> (usize, usize) {
        self.cache_stats()
    }
}
501
/// Decorator adding an LRU cache in front of any `EmbeddingProvider`;
/// used by `from_env` to wrap the remote ollama/openai backends.
pub struct CachedProvider {
    inner: Box<dyn EmbeddingProvider>,
    cache: Mutex<LruCache<String, Vec<f32>>>,
}
509
510impl CachedProvider {
511 pub fn new(inner: Box<dyn EmbeddingProvider>, capacity: usize) -> Self {
512 let cap =
514 NonZeroUsize::new(capacity).unwrap_or(NonZeroUsize::new(1).expect("1 is non-zero"));
515 Self {
516 inner,
517 cache: Mutex::new(LruCache::new(cap)),
518 }
519 }
520}
521
522impl EmbeddingProvider for CachedProvider {
523 fn dimensions(&self) -> usize {
524 self.inner.dimensions()
525 }
526
527 fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
528 {
529 let mut cache = self
530 .cache
531 .lock()
532 .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
533 if let Some(cached) = cache.get(text) {
534 return Ok(cached.clone());
535 }
536 }
537 let embedding = self.inner.embed(text)?;
538 {
539 let mut cache = self
540 .cache
541 .lock()
542 .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
543 cache.put(text.to_string(), embedding.clone());
544 }
545 Ok(embedding)
546 }
547
548 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
549 let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
551 let mut uncached = Vec::new();
552 let mut uncached_idx = Vec::new();
553
554 {
555 let mut cache = self
556 .cache
557 .lock()
558 .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
559 for (i, text) in texts.iter().enumerate() {
560 if let Some(cached) = cache.get(*text) {
561 results[i] = Some(cached.clone());
562 } else {
563 uncached_idx.push(i);
564 uncached.push(*text);
565 }
566 }
567 }
568
569 if !uncached.is_empty() {
570 let new_embeddings = self.inner.embed_batch(&uncached)?;
571 let mut cache = self
572 .cache
573 .lock()
574 .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
575 for (idx, embedding) in uncached_idx.into_iter().zip(new_embeddings) {
576 cache.put(texts[idx].to_string(), embedding.clone());
577 results[idx] = Some(embedding);
578 }
579 }
580
581 Ok(results.into_iter().flatten().collect())
582 }
583
584 fn name(&self) -> &str {
585 self.inner.name()
586 }
587
588 fn cache_stats(&self) -> (usize, usize) {
589 match self.cache.lock() {
590 Ok(cache) => (cache.len(), cache.cap().into()),
591 Err(_) => (0, 0),
592 }
593 }
594}
595
596pub fn from_env() -> Result<Box<dyn EmbeddingProvider>, CodememError> {
608 let provider = std::env::var("CODEMEM_EMBED_PROVIDER")
609 .unwrap_or_else(|_| "candle".to_string())
610 .to_lowercase();
611 let dimensions: usize = std::env::var("CODEMEM_EMBED_DIMENSIONS")
612 .ok()
613 .and_then(|s| s.parse().ok())
614 .unwrap_or(DIMENSIONS);
615
616 match provider.as_str() {
617 "ollama" => {
618 let base_url = std::env::var("CODEMEM_EMBED_URL")
619 .unwrap_or_else(|_| ollama::DEFAULT_BASE_URL.to_string());
620 let model = std::env::var("CODEMEM_EMBED_MODEL")
621 .unwrap_or_else(|_| ollama::DEFAULT_MODEL.to_string());
622 let inner = Box::new(ollama::OllamaProvider::new(&base_url, &model, dimensions));
623 Ok(Box::new(CachedProvider::new(inner, CACHE_CAPACITY)))
624 }
625 "openai" => {
626 let api_key = std::env::var("CODEMEM_EMBED_API_KEY")
627 .or_else(|_| std::env::var("OPENAI_API_KEY"))
628 .map_err(|_| {
629 CodememError::Embedding(
630 "CODEMEM_EMBED_API_KEY or OPENAI_API_KEY required for OpenAI embeddings"
631 .into(),
632 )
633 })?;
634 let model = std::env::var("CODEMEM_EMBED_MODEL")
635 .unwrap_or_else(|_| openai::DEFAULT_MODEL.to_string());
636 let base_url = std::env::var("CODEMEM_EMBED_URL").ok();
637 let inner = Box::new(openai::OpenAIProvider::new(
638 &api_key,
639 &model,
640 dimensions,
641 base_url.as_deref(),
642 ));
643 Ok(Box::new(CachedProvider::new(inner, CACHE_CAPACITY)))
644 }
645 "candle" | "" => {
646 let model_dir = EmbeddingService::default_model_dir();
647 let service = EmbeddingService::new(&model_dir)?;
648 Ok(Box::new(service))
649 }
650 other => Err(CodememError::Embedding(format!(
651 "Unknown embedding provider: '{}'. Use 'candle', 'ollama', or 'openai'.",
652 other
653 ))),
654 }
655}
656
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    // Serializes tests that mutate process-wide environment variables, so
    // parallel test threads don't race on set_var/remove_var.
    static ENV_MUTEX: std::sync::Mutex<()> = std::sync::Mutex::new(());

    // Deterministic provider that counts embed calls; lets tests observe
    // CachedProvider hit/miss behavior without a real model.
    struct MockProvider {
        dims: usize,
        call_count: AtomicUsize,
    }

    impl MockProvider {
        fn new(dims: usize) -> Self {
            Self {
                dims,
                call_count: AtomicUsize::new(0),
            }
        }
    }

    impl EmbeddingProvider for MockProvider {
        fn dimensions(&self) -> usize {
            self.dims
        }

        fn embed(&self, _text: &str) -> Result<Vec<f32>, CodememError> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            Ok(vec![0.1; self.dims])
        }

        fn name(&self) -> &str {
            "mock"
        }
    }

    // Second embed of the same text must be served from the cache.
    #[test]
    fn cached_provider_cache_hit() {
        let mock = MockProvider::new(4);
        let provider = CachedProvider::new(Box::new(mock), 100);

        let v1 = provider.embed("hello").unwrap();
        assert_eq!(v1.len(), 4);

        let v2 = provider.embed("hello").unwrap();
        assert_eq!(v1, v2);

        let (size, cap) = provider.cache_stats();
        assert_eq!(size, 1);
        assert_eq!(cap, 100);
    }

    // Distinct texts occupy distinct cache entries.
    #[test]
    fn cached_provider_cache_miss() {
        let mock = MockProvider::new(4);
        let provider = CachedProvider::new(Box::new(mock), 100);

        provider.embed("hello").unwrap();
        provider.embed("world").unwrap();

        let (size, _) = provider.cache_stats();
        assert_eq!(size, 2);
    }

    // Empty batch short-circuits to an empty result.
    #[test]
    fn cached_provider_batch_empty() {
        let mock = MockProvider::new(4);
        let provider = CachedProvider::new(Box::new(mock), 100);

        let result = provider.embed_batch(&[]).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn cached_provider_batch_single() {
        let mock = MockProvider::new(4);
        let provider = CachedProvider::new(Box::new(mock), 100);

        let result = provider.embed_batch(&["hello"]).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].len(), 4);

        let (size, _) = provider.cache_stats();
        assert_eq!(size, 1);
    }

    // Batch with one cached and one new text returns both, caching the new one.
    #[test]
    fn cached_provider_batch_mixed_cache() {
        let mock = MockProvider::new(4);
        let provider = CachedProvider::new(Box::new(mock), 100);

        provider.embed("hello").unwrap();

        let result = provider.embed_batch(&["hello", "world"]).unwrap();
        assert_eq!(result.len(), 2);

        let (size, _) = provider.cache_stats();
        assert_eq!(size, 2);
    }

    // A zero capacity is clamped to 1, so only the most recent entry survives.
    #[test]
    fn cached_provider_zero_capacity() {
        let mock = MockProvider::new(4);
        let provider = CachedProvider::new(Box::new(mock), 0);

        provider.embed("a").unwrap();
        provider.embed("b").unwrap();

        let (size, cap) = provider.cache_stats();
        assert_eq!(cap, 1);
        assert_eq!(size, 1);
    }

    #[test]
    fn cached_provider_name_delegates() {
        let mock = MockProvider::new(4);
        let provider = CachedProvider::new(Box::new(mock), 100);
        assert_eq!(provider.name(), "mock");
    }

    #[test]
    fn cached_provider_dimensions_delegates() {
        let mock = MockProvider::new(768);
        let provider = CachedProvider::new(Box::new(mock), 100);
        assert_eq!(provider.dimensions(), 768);
    }

    #[test]
    fn from_env_unknown_provider() {
        // into_inner() recovers the guard even if a prior test panicked
        // while holding the mutex.
        let _lock = ENV_MUTEX.lock().unwrap_or_else(|e| e.into_inner());
        std::env::set_var("CODEMEM_EMBED_PROVIDER", "nonexistent_provider_xyz");
        let result = from_env();
        std::env::remove_var("CODEMEM_EMBED_PROVIDER");

        match result {
            Err(e) => {
                let err = e.to_string();
                assert!(
                    err.contains("Unknown embedding provider"),
                    "Error should mention unknown provider: {err}"
                );
            }
            Ok(_) => panic!("Expected error for unknown provider"),
        }
    }

    #[test]
    fn embedding_service_missing_model() {
        match EmbeddingService::new(Path::new("/nonexistent/path")) {
            Err(e) => {
                let err = e.to_string();
                assert!(
                    err.contains("Model not found"),
                    "Error should mention missing model: {err}"
                );
            }
            Ok(_) => panic!("Expected error for missing model"),
        }
    }

    #[test]
    fn default_model_dir_path() {
        let dir = EmbeddingService::default_model_dir();
        assert!(dir.to_string_lossy().contains(MODEL_NAME));
        assert!(dir.to_string_lossy().contains(".codemem"));
    }

    #[test]
    fn constants_are_sensible() {
        assert_eq!(DIMENSIONS, 768);
        assert_eq!(CACHE_CAPACITY, 10_000);
        assert_eq!(MODEL_NAME, "bge-base-en-v1.5");
    }

    // Ollama provider construction requires no network; only wiring is tested.
    #[test]
    fn from_env_ollama_provider() {
        let _lock = ENV_MUTEX.lock().unwrap_or_else(|e| e.into_inner());
        std::env::set_var("CODEMEM_EMBED_PROVIDER", "ollama");
        let result = from_env();
        std::env::remove_var("CODEMEM_EMBED_PROVIDER");

        let provider = result.expect("from_env should succeed for ollama");
        assert_eq!(provider.name(), "ollama");
    }

    #[test]
    fn from_env_openai_provider() {
        let _lock = ENV_MUTEX.lock().unwrap_or_else(|e| e.into_inner());
        std::env::set_var("CODEMEM_EMBED_PROVIDER", "openai");
        std::env::set_var("OPENAI_API_KEY", "test-key-123");
        let result = from_env();
        std::env::remove_var("CODEMEM_EMBED_PROVIDER");
        std::env::remove_var("OPENAI_API_KEY");

        let provider = result.expect("from_env should succeed for openai");
        assert_eq!(provider.name(), "openai");
    }
}
865}