1pub mod gemini;
10pub mod ollama;
11pub mod openai;
12
13use candle_core::{DType, Device, Tensor};
14use candle_nn::{Module, VarBuilder};
15use candle_transformers::models::bert::{BertModel, Config as BertConfig};
16use candle_transformers::models::jina_bert::{
17 BertModel as JinaBertModel, Config as JinaBertConfig,
18};
19use codemem_core::CodememError;
20use lru::LruCache;
21use std::num::NonZeroUsize;
22use std::path::{Path, PathBuf};
23use std::sync::Mutex;
24use tokenizers::{PaddingParams, PaddingStrategy};
25
/// Directory name of the default local embedding model.
pub const MODEL_NAME: &str = "bge-base-en-v1.5";

/// HuggingFace repository the default model is downloaded from.
pub const DEFAULT_HF_REPO: &str = "BAAI/bge-base-en-v1.5";

/// Fallback embedding dimensionality used by `from_env` when neither the
/// environment nor the config specifies one.
pub const DEFAULT_REMOTE_DIMENSIONS: usize = 768;

/// Fallback context length when `config.json` omits `max_position_embeddings`.
const DEFAULT_MAX_SEQ_LENGTH: usize = 512;

/// Default capacity (entries) of the LRU embedding cache.
pub const CACHE_CAPACITY: usize = 10_000;

// Re-exported so callers can name the trait without importing codemem-core.
pub use codemem_core::EmbeddingProvider;

/// Default number of texts embedded per model forward pass.
pub const DEFAULT_BATCH_SIZE: usize = 16;
51
52fn select_device() -> Device {
57 #[cfg(feature = "metal")]
58 {
59 match std::panic::catch_unwind(|| Device::new_metal(0)) {
61 Ok(Ok(device)) => {
62 tracing::info!("Using Metal GPU for embeddings");
63 return device;
64 }
65 Ok(Err(e)) => {
66 tracing::warn!("Metal device creation failed: {e}, falling back to CPU");
67 }
68 Err(_) => {
69 tracing::warn!("Metal device creation panicked, falling back to CPU");
70 }
71 }
72 }
73 #[cfg(feature = "cuda")]
74 {
75 match std::panic::catch_unwind(|| Device::new_cuda(0)) {
76 Ok(Ok(device)) => {
77 tracing::info!("Using CUDA GPU for embeddings");
78 return device;
79 }
80 Ok(Err(e)) => {
81 tracing::warn!("CUDA device creation failed: {e}, falling back to CPU");
82 }
83 Err(_) => {
84 tracing::warn!("CUDA device creation panicked, falling back to CPU");
85 }
86 }
87 }
88 tracing::info!("Using CPU for embeddings");
89 Device::Cpu
90}
91
/// The concrete transformer architecture backing an `EmbeddingService`.
enum ModelBackend {
    /// Classic BERT; its forward pass takes ids, token-type ids, and a mask.
    Bert(BertModel),
    /// JinaBERT (ALiBi position embeddings); its forward pass takes ids only.
    JinaBert(JinaBertModel),
}
99
/// Local (candle-based) embedding model plus its tokenizer and device.
pub struct EmbeddingService {
    /// Model weights; the mutex serializes forward passes so the service can
    /// be shared across threads.
    model: Mutex<ModelBackend>,
    /// Tokenizer configured in `new` to truncate at `max_seq_length`.
    tokenizer: tokenizers::Tokenizer,
    /// Device the weights were loaded on (CPU or GPU, see `select_device`).
    device: Device,
    /// Number of texts processed per forward pass in `embed_batch`.
    batch_size: usize,
    /// Transformer hidden size; equals the output embedding dimensionality.
    hidden_size: usize,
    /// Maximum tokens per input; longer texts are truncated.
    max_seq_length: usize,
}
114
/// Minimal view of `config.json`, parsed before the architecture-specific
/// config to decide which model backend to load.
#[derive(serde::Deserialize)]
struct ConfigProbe {
    /// `"alibi"` selects the JinaBERT backend; anything else (or absent) BERT.
    #[serde(default)]
    position_embedding_type: Option<String>,
    /// Transformer hidden size; doubles as the embedding dimensionality.
    hidden_size: usize,
    /// Model context length; defaults to `DEFAULT_MAX_SEQ_LENGTH` when absent.
    #[serde(default = "default_max_position_embeddings")]
    max_position_embeddings: usize,
}

/// Serde default for `ConfigProbe::max_position_embeddings`.
fn default_max_position_embeddings() -> usize {
    DEFAULT_MAX_SEQ_LENGTH
}
128
129impl EmbeddingService {
130 pub fn new(model_dir: &Path, batch_size: usize, dtype: DType) -> Result<Self, CodememError> {
136 let model_path = model_dir.join("model.safetensors");
137 let config_path = model_dir.join("config.json");
138 let tokenizer_path = model_dir.join("tokenizer.json");
139
140 if !model_path.exists() {
141 return Err(CodememError::Embedding(format!(
142 "Model not found at {}. Run `codemem init` to download it.",
143 model_path.display()
144 )));
145 }
146
147 let device = select_device();
148
149 tracing::info!(
150 "Loading model from {} (dtype: {:?}, device: {:?})",
151 model_dir.display(),
152 dtype,
153 device
154 );
155
156 let config_str = std::fs::read_to_string(&config_path)
157 .map_err(|e| CodememError::Embedding(format!("Failed to read config: {e}")))?;
158
159 let probe: ConfigProbe = serde_json::from_str(&config_str)
161 .map_err(|e| CodememError::Embedding(format!("Failed to probe config: {e}")))?;
162 let hidden_size = probe.hidden_size;
163 let is_alibi = probe
164 .position_embedding_type
165 .as_deref()
166 .is_some_and(|t| t == "alibi");
167 let max_seq_length = probe.max_position_embeddings.min(8192);
169
170 let (model, arch_name) = if is_alibi {
171 let config: JinaBertConfig = serde_json::from_str(&config_str).map_err(|e| {
173 CodememError::Embedding(format!("Failed to parse JinaBERT config: {e}"))
174 })?;
175 let vb = unsafe {
176 VarBuilder::from_mmaped_safetensors(&[&model_path], dtype, &device)
177 .map_err(|e| CodememError::Embedding(format!("Failed to load weights: {e}")))?
178 };
179 let jina_model = JinaBertModel::new(vb.pp("bert"), &config).map_err(|e| {
181 CodememError::Embedding(format!("Failed to load JinaBERT model: {e}"))
182 })?;
183 (ModelBackend::JinaBert(jina_model), "JinaBERT (ALiBi)")
184 } else {
185 let config: BertConfig = serde_json::from_str(&config_str).map_err(|e| {
187 CodememError::Embedding(format!("Failed to parse BERT config: {e}"))
188 })?;
189 let bert_model = {
193 let vb = unsafe {
194 VarBuilder::from_mmaped_safetensors(&[&model_path], dtype, &device).map_err(
195 |e| CodememError::Embedding(format!("Failed to load weights: {e}")),
196 )?
197 };
198 BertModel::load(vb.pp("bert"), &config)
199 };
200 let bert_model = match bert_model {
202 Ok(m) => m,
203 Err(_) => {
204 let vb2 = unsafe {
205 VarBuilder::from_mmaped_safetensors(&[&model_path], dtype, &device)
206 .map_err(|e| {
207 CodememError::Embedding(format!("Failed to load weights: {e}"))
208 })?
209 };
210 BertModel::load(vb2, &config).map_err(|e| {
211 CodememError::Embedding(format!("Failed to load BERT model: {e}"))
212 })?
213 }
214 };
215 (ModelBackend::Bert(bert_model), "BERT (absolute)")
216 };
217
218 tracing::info!(
219 "Loaded {} model (hidden_size={}, max_seq_length={})",
220 arch_name,
221 hidden_size,
222 max_seq_length
223 );
224
225 let mut tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
226 .map_err(|e| CodememError::Embedding(e.to_string()))?;
227
228 tokenizer
230 .with_truncation(Some(tokenizers::TruncationParams {
231 max_length: max_seq_length,
232 ..Default::default()
233 }))
234 .map_err(|e| CodememError::Embedding(format!("Truncation error: {e}")))?;
235
236 Ok(Self {
237 model: Mutex::new(model),
238 tokenizer,
239 device,
240 batch_size,
241 hidden_size,
242 max_seq_length,
243 })
244 }
245
246 pub fn max_seq_length(&self) -> usize {
248 self.max_seq_length
249 }
250
251 pub fn model_dir_for(model_name: &str) -> PathBuf {
254 dirs::home_dir()
255 .unwrap_or_else(|| PathBuf::from("."))
256 .join(".codemem")
257 .join("models")
258 .join(model_name)
259 }
260
261 pub fn default_model_dir() -> PathBuf {
263 Self::model_dir_for(MODEL_NAME)
264 }
265
266 pub fn download_model(dest_dir: &Path, hf_repo: &str) -> Result<PathBuf, CodememError> {
270 let model_dest = dest_dir.join("model.safetensors");
271 let config_dest = dest_dir.join("config.json");
272 let tokenizer_dest = dest_dir.join("tokenizer.json");
273
274 if model_dest.exists() && config_dest.exists() && tokenizer_dest.exists() {
275 tracing::info!("Model already downloaded at {}", dest_dir.display());
276 return Ok(dest_dir.to_path_buf());
277 }
278
279 std::fs::create_dir_all(dest_dir)
280 .map_err(|e| CodememError::Embedding(format!("Failed to create dir: {e}")))?;
281
282 tracing::info!("Downloading {} from HuggingFace...", hf_repo);
283
284 let api = hf_hub::api::sync::Api::new()
285 .map_err(|e| CodememError::Embedding(format!("HuggingFace API error: {e}")))?;
286 let repo = api.model(hf_repo.to_string());
287
288 let cached_model = repo
289 .get("model.safetensors")
290 .map_err(|e| CodememError::Embedding(format!("Failed to download model: {e}")))?;
291
292 let cached_config = repo
293 .get("config.json")
294 .map_err(|e| CodememError::Embedding(format!("Failed to download config: {e}")))?;
295
296 let cached_tokenizer = repo
297 .get("tokenizer.json")
298 .map_err(|e| CodememError::Embedding(format!("Failed to download tokenizer: {e}")))?;
299
300 std::fs::copy(&cached_model, &model_dest)
301 .map_err(|e| CodememError::Embedding(format!("Failed to copy model: {e}")))?;
302 std::fs::copy(&cached_config, &config_dest)
303 .map_err(|e| CodememError::Embedding(format!("Failed to copy config: {e}")))?;
304 std::fs::copy(&cached_tokenizer, &tokenizer_dest)
305 .map_err(|e| CodememError::Embedding(format!("Failed to copy tokenizer: {e}")))?;
306
307 tracing::info!("Model downloaded to {}", dest_dir.display());
308 Ok(dest_dir.to_path_buf())
309 }
310
311 pub fn download_default_model() -> Result<PathBuf, CodememError> {
314 Self::download_model(&Self::default_model_dir(), DEFAULT_HF_REPO)
315 }
316
317 pub fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
319 let encoding = self
321 .tokenizer
322 .encode(text, true)
323 .map_err(|e| CodememError::Embedding(e.to_string()))?;
324
325 let input_ids: Vec<u32> = encoding.get_ids().to_vec();
326 let attention_mask: Vec<u32> = encoding.get_attention_mask().to_vec();
327
328 let input_ids_tensor = Tensor::new(&input_ids[..], &self.device)
330 .and_then(|t| t.unsqueeze(0))
331 .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;
332
333 let attention_mask_tensor = Tensor::new(&attention_mask[..], &self.device)
334 .and_then(|t| t.unsqueeze(0))
335 .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;
336
337 let model = self
339 .model
340 .lock()
341 .map_err(|e| CodememError::LockPoisoned(format!("embedding model: {e}")))?;
342 let hidden_states = match &*model {
343 ModelBackend::Bert(bert) => {
344 let token_type_ids = input_ids_tensor
345 .zeros_like()
346 .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;
347 let result = bert
348 .forward(
349 &input_ids_tensor,
350 &token_type_ids,
351 Some(&attention_mask_tensor),
352 )
353 .map_err(|e| CodememError::Embedding(format!("Model forward error: {e}")))?;
354 drop(token_type_ids);
355 result
356 }
357 ModelBackend::JinaBert(jina) => jina
358 .forward(&input_ids_tensor)
359 .map_err(|e| CodememError::Embedding(format!("Model forward error: {e}")))?,
360 };
361 drop(model);
362 drop(input_ids_tensor);
363
364 let hidden_states = hidden_states
366 .to_dtype(DType::F32)
367 .map_err(|e| CodememError::Embedding(format!("Cast error: {e}")))?;
368
369 let mask = attention_mask_tensor
372 .to_dtype(DType::F32)
373 .and_then(|t| t.unsqueeze(2))
374 .map_err(|e| CodememError::Embedding(format!("Mask error: {e}")))?;
375
376 let sum_mask = mask
377 .sum(1)
378 .map_err(|e| CodememError::Embedding(format!("Sum error: {e}")))?;
379
380 let pooled = hidden_states
381 .broadcast_mul(&mask)
382 .and_then(|t| t.sum(1))
383 .and_then(|t| t.broadcast_div(&sum_mask))
384 .map_err(|e| CodememError::Embedding(format!("Pooling error: {e}")))?;
385
386 let normalized = pooled
388 .sqr()
389 .and_then(|t| t.sum_keepdim(1))
390 .and_then(|t| t.sqrt())
391 .and_then(|norm| pooled.broadcast_div(&norm))
392 .map_err(|e| CodememError::Embedding(format!("Normalize error: {e}")))?;
393
394 let embedding: Vec<f32> = normalized
396 .squeeze(0)
397 .and_then(|t| t.to_vec1())
398 .map_err(|e| CodememError::Embedding(format!("Extract error: {e}")))?;
399
400 Ok(embedding)
401 }
402
403 pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
409 if texts.is_empty() {
410 return Ok(vec![]);
411 }
412
413 let mut all_embeddings = Vec::with_capacity(texts.len());
414
415 for chunk in texts.chunks(self.batch_size) {
416 let mut tokenizer = self.tokenizer.clone();
419 tokenizer.with_padding(Some(PaddingParams {
420 strategy: PaddingStrategy::BatchLongest,
421 ..Default::default()
422 }));
423
424 let encodings = tokenizer
425 .encode_batch(chunk.to_vec(), true)
426 .map_err(|e| CodememError::Embedding(format!("Batch encode error: {e}")))?;
427
428 let batch_len = encodings.len();
429 let seq_len = encodings[0].get_ids().len();
430
431 let all_ids: Vec<u32> = encodings
433 .iter()
434 .flat_map(|e| e.get_ids())
435 .copied()
436 .collect();
437 let all_masks: Vec<u32> = encodings
438 .iter()
439 .flat_map(|e| e.get_attention_mask())
440 .copied()
441 .collect();
442
443 let input_ids = Tensor::new(all_ids.as_slice(), &self.device)
445 .and_then(|t| t.reshape((batch_len, seq_len)))
446 .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;
447
448 let attention_mask = Tensor::new(all_masks.as_slice(), &self.device)
449 .and_then(|t| t.reshape((batch_len, seq_len)))
450 .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;
451
452 let model = self
454 .model
455 .lock()
456 .map_err(|e| CodememError::LockPoisoned(format!("embedding model: {e}")))?;
457 let hidden_states = match &*model {
458 ModelBackend::Bert(bert) => {
459 let token_type_ids = input_ids
460 .zeros_like()
461 .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;
462 let result = bert
463 .forward(&input_ids, &token_type_ids, Some(&attention_mask))
464 .map_err(|e| CodememError::Embedding(format!("Forward error: {e}")))?;
465 drop(token_type_ids);
466 result
467 }
468 ModelBackend::JinaBert(jina) => jina
469 .forward(&input_ids)
470 .map_err(|e| CodememError::Embedding(format!("Forward error: {e}")))?,
471 };
472 drop(model);
473 drop(input_ids);
474
475 let hidden_states = hidden_states
477 .to_dtype(DType::F32)
478 .map_err(|e| CodememError::Embedding(format!("Cast error: {e}")))?;
479
480 let mask = attention_mask
482 .to_dtype(DType::F32)
483 .and_then(|t| t.unsqueeze(2))
484 .map_err(|e| CodememError::Embedding(format!("Mask error: {e}")))?;
485
486 let sum_mask = mask
487 .sum(1)
488 .map_err(|e| CodememError::Embedding(format!("Sum error: {e}")))?;
489
490 let pooled = hidden_states
491 .broadcast_mul(&mask)
492 .and_then(|t| t.sum(1))
493 .and_then(|t| t.broadcast_div(&sum_mask))
494 .map_err(|e| CodememError::Embedding(format!("Pooling error: {e}")))?;
495
496 let norm = pooled
498 .sqr()
499 .and_then(|t| t.sum_keepdim(1))
500 .and_then(|t| t.sqrt())
501 .map_err(|e| CodememError::Embedding(format!("Norm error: {e}")))?;
502
503 let normalized = pooled
504 .broadcast_div(&norm)
505 .map_err(|e| CodememError::Embedding(format!("Normalize error: {e}")))?;
506
507 let flat: Vec<f32> = normalized
510 .flatten_all()
511 .and_then(|t| t.to_vec1())
512 .map_err(|e| CodememError::Embedding(format!("Extract error: {e}")))?;
513 for i in 0..batch_len {
514 let start = i * self.hidden_size;
515 all_embeddings.push(flat[start..start + self.hidden_size].to_vec());
516 }
517 }
518
519 Ok(all_embeddings)
520 }
521}
522
523impl EmbeddingProvider for EmbeddingService {
524 fn dimensions(&self) -> usize {
525 self.hidden_size
526 }
527
528 fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
529 self.embed(text)
530 }
531
532 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
533 self.embed_batch(texts)
534 }
535
536 fn name(&self) -> &str {
537 "candle"
538 }
539}
540
/// Decorator that memoizes embeddings from an inner provider in an LRU cache
/// keyed by the exact input text.
pub struct CachedProvider {
    /// Provider that actually computes embeddings on a cache miss.
    inner: Box<dyn EmbeddingProvider>,
    /// text -> embedding LRU cache; the mutex provides interior mutability.
    cache: Mutex<LruCache<String, Vec<f32>>>,
}
548
549impl CachedProvider {
550 pub fn new(inner: Box<dyn EmbeddingProvider>, capacity: usize) -> Self {
551 let cap =
553 NonZeroUsize::new(capacity).unwrap_or(NonZeroUsize::new(1).expect("1 is non-zero"));
554 Self {
555 inner,
556 cache: Mutex::new(LruCache::new(cap)),
557 }
558 }
559}
560
561impl EmbeddingProvider for CachedProvider {
562 fn dimensions(&self) -> usize {
563 self.inner.dimensions()
564 }
565
566 fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
567 {
568 let mut cache = self
569 .cache
570 .lock()
571 .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
572 if let Some(cached) = cache.get(text) {
573 return Ok(cached.clone());
574 }
575 }
576 let embedding = self.inner.embed(text)?;
577 {
578 let mut cache = self
579 .cache
580 .lock()
581 .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
582 cache.put(text.to_string(), embedding.clone());
583 }
584 Ok(embedding)
585 }
586
587 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
588 let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
590 let mut uncached = Vec::new();
591 let mut uncached_idx = Vec::new();
592
593 {
594 let mut cache = self
595 .cache
596 .lock()
597 .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
598 for (i, text) in texts.iter().enumerate() {
599 if let Some(cached) = cache.get(*text) {
600 results[i] = Some(cached.clone());
601 } else {
602 uncached_idx.push(i);
603 uncached.push(*text);
604 }
605 }
606 }
607
608 if !uncached.is_empty() {
609 let new_embeddings = self.inner.embed_batch(&uncached)?;
610 let mut cache = self
611 .cache
612 .lock()
613 .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
614 for (idx, embedding) in uncached_idx.into_iter().zip(new_embeddings) {
615 cache.put(texts[idx].to_string(), embedding.clone());
616 results[idx] = Some(embedding);
617 }
618 }
619
620 let expected = texts.len();
622 let output: Vec<Vec<f32>> = results
623 .into_iter()
624 .enumerate()
625 .map(|(i, opt)| {
626 opt.ok_or_else(|| {
627 CodememError::Embedding(format!(
628 "Missing embedding for text at index {i} (batch size {expected})"
629 ))
630 })
631 })
632 .collect::<Result<Vec<_>, _>>()?;
633 Ok(output)
634 }
635
636 fn name(&self) -> &str {
637 self.inner.name()
638 }
639
640 fn cache_stats(&self) -> (usize, usize) {
641 match self.cache.lock() {
642 Ok(cache) => (cache.len(), cache.cap().into()),
643 Err(_) => (0, 0),
644 }
645 }
646}
647
648pub fn parse_dtype(s: &str) -> Result<DType, CodememError> {
654 match s.to_lowercase().as_str() {
655 "f16" | "float16" | "half" | "" => Ok(DType::F16),
656 "f32" | "float32" => Ok(DType::F32),
657 "bf16" | "bfloat16" => Ok(DType::BF16),
658 other => Err(CodememError::Embedding(format!(
659 "Unknown dtype: '{}'. Use 'f16', 'f32', or 'bf16'.",
660 other
661 ))),
662 }
663}
664
665pub fn resolve_model_id(model: &str) -> Result<(String, String), CodememError> {
674 if model.contains('/') {
675 let dir_name = model.rsplit('/').next().unwrap_or(model);
677 Ok((model.to_string(), dir_name.to_string()))
678 } else if model.starts_with("bge-") {
679 Ok((format!("BAAI/{model}"), model.to_string()))
681 } else {
682 Err(CodememError::Embedding(format!(
683 "Model identifier '{}' must be a full HuggingFace repo ID (e.g., 'BAAI/bge-base-en-v1.5' \
684 or 'sentence-transformers/all-MiniLM-L6-v2'). Short names are only supported for 'bge-*' models.",
685 model
686 )))
687 }
688}
689
690pub fn from_env(
704 config: Option<&codemem_core::EmbeddingConfig>,
705) -> Result<Box<dyn EmbeddingProvider>, CodememError> {
706 let provider = std::env::var("CODEMEM_EMBED_PROVIDER")
707 .unwrap_or_else(|_| {
708 config
709 .map(|c| c.provider.clone())
710 .unwrap_or_else(|| "candle".to_string())
711 })
712 .to_lowercase();
713 let dimensions: usize = std::env::var("CODEMEM_EMBED_DIMENSIONS")
716 .ok()
717 .and_then(|s| s.parse().ok())
718 .unwrap_or_else(|| config.map_or(DEFAULT_REMOTE_DIMENSIONS, |c| c.dimensions));
719 let cache_capacity = config.map_or(CACHE_CAPACITY, |c| c.cache_capacity);
720 let batch_size: usize = std::env::var("CODEMEM_EMBED_BATCH_SIZE")
721 .ok()
722 .and_then(|s| s.parse().ok())
723 .unwrap_or_else(|| config.map_or(DEFAULT_BATCH_SIZE, |c| c.batch_size));
724
725 match provider.as_str() {
726 "ollama" => {
727 let base_url = std::env::var("CODEMEM_EMBED_URL").unwrap_or_else(|_| {
728 config
729 .filter(|c| !c.url.is_empty())
730 .map(|c| c.url.clone())
731 .unwrap_or_else(|| ollama::DEFAULT_BASE_URL.to_string())
732 });
733 let model = std::env::var("CODEMEM_EMBED_MODEL").unwrap_or_else(|_| {
734 config
735 .filter(|c| !c.model.is_empty())
736 .map(|c| c.model.clone())
737 .unwrap_or_else(|| ollama::DEFAULT_MODEL.to_string())
738 });
739 let inner = Box::new(ollama::OllamaProvider::new(&base_url, &model, dimensions));
740 Ok(Box::new(CachedProvider::new(inner, cache_capacity)))
741 }
742 "openai" => {
743 let api_key = std::env::var("CODEMEM_EMBED_API_KEY")
744 .or_else(|_| std::env::var("OPENAI_API_KEY"))
745 .map_err(|_| {
746 CodememError::Embedding(
747 "CODEMEM_EMBED_API_KEY or OPENAI_API_KEY required for OpenAI embeddings"
748 .into(),
749 )
750 })?;
751 let model = std::env::var("CODEMEM_EMBED_MODEL").unwrap_or_else(|_| {
752 config
753 .filter(|c| !c.model.is_empty())
754 .map(|c| c.model.clone())
755 .unwrap_or_else(|| openai::DEFAULT_MODEL.to_string())
756 });
757 let base_url = std::env::var("CODEMEM_EMBED_URL")
758 .ok()
759 .or_else(|| config.filter(|c| !c.url.is_empty()).map(|c| c.url.clone()));
760 let inner = Box::new(openai::OpenAIProvider::new(
761 &api_key,
762 &model,
763 dimensions,
764 base_url.as_deref(),
765 ));
766 Ok(Box::new(CachedProvider::new(inner, cache_capacity)))
767 }
768 "gemini" | "google" => {
769 let api_key = std::env::var("CODEMEM_EMBED_API_KEY")
770 .or_else(|_| std::env::var("GEMINI_API_KEY"))
771 .or_else(|_| std::env::var("GOOGLE_API_KEY"))
772 .map_err(|_| {
773 CodememError::Embedding(
774 "CODEMEM_EMBED_API_KEY, GEMINI_API_KEY, or GOOGLE_API_KEY required for Gemini embeddings"
775 .into(),
776 )
777 })?;
778 let model = std::env::var("CODEMEM_EMBED_MODEL").unwrap_or_else(|_| {
779 config
780 .filter(|c| !c.model.is_empty())
781 .map(|c| c.model.clone())
782 .unwrap_or_else(|| gemini::DEFAULT_MODEL.to_string())
783 });
784 let base_url = std::env::var("CODEMEM_EMBED_URL")
785 .ok()
786 .or_else(|| config.filter(|c| !c.url.is_empty()).map(|c| c.url.clone()));
787 let inner = Box::new(gemini::GeminiProvider::new(
788 &api_key,
789 &model,
790 dimensions,
791 base_url.as_deref(),
792 ));
793 Ok(Box::new(CachedProvider::new(inner, cache_capacity)))
794 }
795 "candle" | "" => {
796 let model_id = std::env::var("CODEMEM_EMBED_MODEL").unwrap_or_else(|_| {
797 config
798 .filter(|c| !c.model.is_empty())
799 .map(|c| c.model.clone())
800 .unwrap_or_else(|| DEFAULT_HF_REPO.to_string())
801 });
802 let (hf_repo, dir_name) = resolve_model_id(&model_id)?;
803 let model_dir = EmbeddingService::model_dir_for(&dir_name);
804
805 let dtype_str = std::env::var("CODEMEM_EMBED_DTYPE").unwrap_or_else(|_| {
806 config
807 .filter(|c| !c.dtype.is_empty())
808 .map(|c| c.dtype.clone())
809 .unwrap_or_else(|| "f16".to_string())
810 });
811 let dtype = parse_dtype(&dtype_str)?;
812
813 let service = EmbeddingService::new(&model_dir, batch_size, dtype).map_err(|e| {
814 if e.to_string().contains("Model not found") && hf_repo != DEFAULT_HF_REPO {
816 CodememError::Embedding(format!(
817 "Model '{}' not found at {}. Download it with:\n \
818 CODEMEM_EMBED_MODEL={} codemem init",
819 hf_repo,
820 model_dir.display(),
821 hf_repo
822 ))
823 } else {
824 e
825 }
826 })?;
827 Ok(Box::new(CachedProvider::new(
828 Box::new(service),
829 cache_capacity,
830 )))
831 }
832 other => Err(CodememError::Embedding(format!(
833 "Unknown embedding provider: '{}'. Use 'candle', 'ollama', 'openai', or 'gemini'.",
834 other
835 ))),
836 }
837}
838
839#[cfg(test)]
840#[path = "tests/lib_tests.rs"]
841mod tests;