codemem-embeddings 0.16.0

Pluggable embedding providers for Codemem: local Candle inference (default model BAAI/bge-base-en-v1.5) plus Ollama, OpenAI, and Gemini backends
//! codemem-embeddings: Pluggable embedding providers for Codemem.
//!
//! Supports multiple backends:
//! - **Candle** (default): local inference for any HuggingFace BERT-family model, in pure Rust
//! - **Ollama**: Local Ollama server with any embedding model
//! - **OpenAI**: OpenAI API or any compatible endpoint (Together, Azure, etc.)
//! - **Gemini**: Google Generative Language API (text-embedding-004)
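//!
//! Example: a minimal sketch of the default path, marked `no_run` since it
//! loads a local model from disk:
//!
//! ```no_run
//! use codemem_embeddings::{from_env, EmbeddingProvider};
//!
//! // Defaults to the Candle backend with BAAI/bge-base-en-v1.5, wrapped in an LRU cache.
//! let provider = from_env(None)?;
//! let vectors = provider.embed_batch(&["alpha", "beta"])?;
//! assert_eq!(vectors.len(), 2);
//! assert_eq!(vectors[0].len(), provider.dimensions());
//! # Ok::<(), codemem_core::CodememError>(())
//! ```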

pub mod gemini;
pub mod ollama;
pub mod openai;

use candle_core::{DType, Device, Tensor};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::bert::{BertModel, Config as BertConfig};
use candle_transformers::models::jina_bert::{
    BertModel as JinaBertModel, Config as JinaBertConfig,
};
use codemem_core::CodememError;
use lru::LruCache;
use std::num::NonZeroUsize;
use std::path::{Path, PathBuf};
use std::sync::Mutex;
use tokenizers::{PaddingParams, PaddingStrategy};

/// Default model name (short form used for directory naming).
pub const MODEL_NAME: &str = "bge-base-en-v1.5";

/// Default HuggingFace model repo ID.
/// Used internally and by `commands_init` for the default model download.
pub const DEFAULT_HF_REPO: &str = "BAAI/bge-base-en-v1.5";

/// Default embedding dimensions for remote providers (Ollama/OpenAI).
/// Candle reads `hidden_size` from the model's config.json instead.
pub const DEFAULT_REMOTE_DIMENSIONS: usize = 768;

/// Default max sequence length for standard BERT models (used when config doesn't specify).
const DEFAULT_MAX_SEQ_LENGTH: usize = 512;

/// Default LRU cache capacity.
pub const CACHE_CAPACITY: usize = 10_000;

// Re-export EmbeddingProvider trait from core
pub use codemem_core::EmbeddingProvider;

// ── Candle Embedding Service ────────────────────────────────────────────────

/// Default batch size for batched embedding forward passes.
/// Configurable via `EmbeddingConfig.batch_size` or `CODEMEM_EMBED_BATCH_SIZE`.
pub const DEFAULT_BATCH_SIZE: usize = 16;

/// Select the best available compute device.
///
/// Tries Metal (macOS GPU) first, then CUDA (NVIDIA GPU), then falls back to CPU.
/// GPU backends are only available when the corresponding feature flag is enabled.
fn select_device() -> Device {
    #[cfg(feature = "metal")]
    {
        // catch_unwind contains panics raised during Metal device probing on CI
        // runners without GPU access (it cannot catch hard faults like SIGBUS).
        match std::panic::catch_unwind(|| Device::new_metal(0)) {
            Ok(Ok(device)) => {
                tracing::info!("Using Metal GPU for embeddings");
                return device;
            }
            Ok(Err(e)) => {
                tracing::warn!("Metal device creation failed: {e}, falling back to CPU");
            }
            Err(_) => {
                tracing::warn!("Metal device creation panicked, falling back to CPU");
            }
        }
    }
    #[cfg(feature = "cuda")]
    {
        match std::panic::catch_unwind(|| Device::new_cuda(0)) {
            Ok(Ok(device)) => {
                tracing::info!("Using CUDA GPU for embeddings");
                return device;
            }
            Ok(Err(e)) => {
                tracing::warn!("CUDA device creation failed: {e}, falling back to CPU");
            }
            Err(_) => {
                tracing::warn!("CUDA device creation panicked, falling back to CPU");
            }
        }
    }
    tracing::info!("Using CPU for embeddings");
    Device::Cpu
}

/// Model backend enum — dispatches forward passes to the correct architecture.
enum ModelBackend {
    /// Standard BERT (absolute positional embeddings). Used by BGE, MiniLM, etc.
    Bert(BertModel),
    /// JinaBERT (ALiBi positional embeddings). Used by Jina embeddings v2.
    JinaBert(JinaBertModel),
}

/// Embedding service with Candle inference (no internal cache — use `CachedProvider` wrapper).
pub struct EmbeddingService {
    model: Mutex<ModelBackend>,
    /// Tokenizer pre-configured with truncation (no padding).
    /// Used directly for single embeds; cloned and augmented with padding for batch.
    tokenizer: tokenizers::Tokenizer,
    device: Device,
    /// Maximum texts per forward pass (GPU memory trade-off).
    batch_size: usize,
    /// Hidden size read from model config (e.g. 768 for bge-base, 384 for bge-small).
    hidden_size: usize,
    /// Max sequence length (512 for BERT, up to 8192 for JinaBERT).
    max_seq_length: usize,
}

/// Minimal struct for sniffing model architecture from config.json before full parsing.
#[derive(serde::Deserialize)]
struct ConfigProbe {
    #[serde(default)]
    position_embedding_type: Option<String>,
    hidden_size: usize,
    #[serde(default = "default_max_position_embeddings")]
    max_position_embeddings: usize,
}

fn default_max_position_embeddings() -> usize {
    DEFAULT_MAX_SEQ_LENGTH
}

impl EmbeddingService {
    /// Create a new embedding service, loading model from the given directory.
    /// Expects `model.safetensors`, `config.json`, and `tokenizer.json` in the directory.
    ///
    /// Auto-detects model architecture (BERT vs JinaBERT) from config.json.
    /// `dtype` controls precision: `DType::F32` (default) or `DType::F16` (half memory, faster on Metal).
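    ///
    /// Example (sketch, `no_run`: it needs a previously downloaded model):
    ///
    /// ```no_run
    /// use candle_core::DType;
    /// use codemem_embeddings::{EmbeddingService, DEFAULT_BATCH_SIZE};
    ///
    /// let dir = EmbeddingService::default_model_dir();
    /// let service = EmbeddingService::new(&dir, DEFAULT_BATCH_SIZE, DType::F32)?;
    /// println!("max_seq_length = {}", service.max_seq_length());
    /// # Ok::<(), codemem_core::CodememError>(())
    /// ```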
    pub fn new(model_dir: &Path, batch_size: usize, dtype: DType) -> Result<Self, CodememError> {
        let model_path = model_dir.join("model.safetensors");
        let config_path = model_dir.join("config.json");
        let tokenizer_path = model_dir.join("tokenizer.json");

        if !model_path.exists() {
            return Err(CodememError::Embedding(format!(
                "Model not found at {}. Run `codemem init` to download it.",
                model_path.display()
            )));
        }

        let device = select_device();

        tracing::info!(
            "Loading model from {} (dtype: {:?}, device: {:?})",
            model_dir.display(),
            dtype,
            device
        );

        let config_str = std::fs::read_to_string(&config_path)
            .map_err(|e| CodememError::Embedding(format!("Failed to read config: {e}")))?;

        // Probe config to detect architecture before full parsing
        let probe: ConfigProbe = serde_json::from_str(&config_str)
            .map_err(|e| CodememError::Embedding(format!("Failed to probe config: {e}")))?;
        let hidden_size = probe.hidden_size;
        let is_alibi = probe
            .position_embedding_type
            .as_deref()
            .is_some_and(|t| t == "alibi");
        // Cap at 8192 to avoid excessive memory usage even if model claims more
        let max_seq_length = probe.max_position_embeddings.min(8192);

        let (model, arch_name) = if is_alibi {
            // JinaBERT (ALiBi positional embeddings)
            let config: JinaBertConfig = serde_json::from_str(&config_str).map_err(|e| {
                CodememError::Embedding(format!("Failed to parse JinaBERT config: {e}"))
            })?;
            let vb = unsafe {
                VarBuilder::from_mmaped_safetensors(&[&model_path], dtype, &device)
                    .map_err(|e| CodememError::Embedding(format!("Failed to load weights: {e}")))?
            };
            // JinaBERT weights use "bert." prefix
            let jina_model = JinaBertModel::new(vb.pp("bert"), &config).map_err(|e| {
                CodememError::Embedding(format!("Failed to load JinaBERT model: {e}"))
            })?;
            (ModelBackend::JinaBert(jina_model), "JinaBERT (ALiBi)")
        } else {
            // Standard BERT (absolute positional embeddings)
            let config: BertConfig = serde_json::from_str(&config_str).map_err(|e| {
                CodememError::Embedding(format!("Failed to parse BERT config: {e}"))
            })?;
            // Load model weights from safetensors via memory-mapped IO.
            // Scope vb so it drops before a potential retry, avoiding two VarBuilders
            // holding materialized Metal tensors simultaneously.
            let bert_model = {
                let vb = unsafe {
                    VarBuilder::from_mmaped_safetensors(&[&model_path], dtype, &device).map_err(
                        |e| CodememError::Embedding(format!("Failed to load weights: {e}")),
                    )?
                };
                BertModel::load(vb.pp("bert"), &config)
            };
            // The "bert."-prefixed load above matches standard HF BERT layouts;
            // if it failed, retry without the prefix.
            let bert_model = match bert_model {
                Ok(m) => m,
                Err(_) => {
                    let vb2 = unsafe {
                        VarBuilder::from_mmaped_safetensors(&[&model_path], dtype, &device)
                            .map_err(|e| {
                                CodememError::Embedding(format!("Failed to load weights: {e}"))
                            })?
                    };
                    BertModel::load(vb2, &config).map_err(|e| {
                        CodememError::Embedding(format!("Failed to load BERT model: {e}"))
                    })?
                }
            };
            (ModelBackend::Bert(bert_model), "BERT (absolute)")
        };

        tracing::info!(
            "Loaded {} model (hidden_size={}, max_seq_length={})",
            arch_name,
            hidden_size,
            max_seq_length
        );

        let mut tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| CodememError::Embedding(e.to_string()))?;

        // Pre-configure truncation once so we don't need to clone on every embed call.
        tokenizer
            .with_truncation(Some(tokenizers::TruncationParams {
                max_length: max_seq_length,
                ..Default::default()
            }))
            .map_err(|e| CodememError::Embedding(format!("Truncation error: {e}")))?;

        Ok(Self {
            model: Mutex::new(model),
            tokenizer,
            device,
            batch_size,
            hidden_size,
            max_seq_length,
        })
    }

    /// Maximum sequence length this model supports.
    pub fn max_seq_length(&self) -> usize {
        self.max_seq_length
    }

    /// Get the model directory path for a given model name.
    /// Falls back to `~/.codemem/models/{model_name}`.
    pub fn model_dir_for(model_name: &str) -> PathBuf {
        dirs::home_dir()
            .unwrap_or_else(|| PathBuf::from("."))
            .join(".codemem")
            .join("models")
            .join(model_name)
    }

    /// Get the default model directory path (~/.codemem/models/{MODEL_NAME}).
    pub fn default_model_dir() -> PathBuf {
        Self::model_dir_for(MODEL_NAME)
    }

    /// Download a model from HuggingFace Hub to the given directory.
    /// `hf_repo` is the full repo ID (e.g. "BAAI/bge-base-en-v1.5").
    /// Returns the directory path. No-ops if model already exists.
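    ///
    /// Example (sketch, `no_run` since it downloads over the network; the
    /// `bge-small` repo is just an illustration):
    ///
    /// ```no_run
    /// use codemem_embeddings::EmbeddingService;
    ///
    /// let dir = EmbeddingService::model_dir_for("bge-small-en-v1.5");
    /// EmbeddingService::download_model(&dir, "BAAI/bge-small-en-v1.5")?;
    /// # Ok::<(), codemem_core::CodememError>(())
    /// ```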
    pub fn download_model(dest_dir: &Path, hf_repo: &str) -> Result<PathBuf, CodememError> {
        let model_dest = dest_dir.join("model.safetensors");
        let config_dest = dest_dir.join("config.json");
        let tokenizer_dest = dest_dir.join("tokenizer.json");

        if model_dest.exists() && config_dest.exists() && tokenizer_dest.exists() {
            tracing::info!("Model already downloaded at {}", dest_dir.display());
            return Ok(dest_dir.to_path_buf());
        }

        std::fs::create_dir_all(dest_dir)
            .map_err(|e| CodememError::Embedding(format!("Failed to create dir: {e}")))?;

        tracing::info!("Downloading {} from HuggingFace...", hf_repo);

        let api = hf_hub::api::sync::Api::new()
            .map_err(|e| CodememError::Embedding(format!("HuggingFace API error: {e}")))?;
        let repo = api.model(hf_repo.to_string());

        let cached_model = repo
            .get("model.safetensors")
            .map_err(|e| CodememError::Embedding(format!("Failed to download model: {e}")))?;

        let cached_config = repo
            .get("config.json")
            .map_err(|e| CodememError::Embedding(format!("Failed to download config: {e}")))?;

        let cached_tokenizer = repo
            .get("tokenizer.json")
            .map_err(|e| CodememError::Embedding(format!("Failed to download tokenizer: {e}")))?;

        std::fs::copy(&cached_model, &model_dest)
            .map_err(|e| CodememError::Embedding(format!("Failed to copy model: {e}")))?;
        std::fs::copy(&cached_config, &config_dest)
            .map_err(|e| CodememError::Embedding(format!("Failed to copy config: {e}")))?;
        std::fs::copy(&cached_tokenizer, &tokenizer_dest)
            .map_err(|e| CodememError::Embedding(format!("Failed to copy tokenizer: {e}")))?;

        tracing::info!("Model downloaded to {}", dest_dir.display());
        Ok(dest_dir.to_path_buf())
    }

    /// Download the default model (BAAI/bge-base-en-v1.5) to the default directory.
    /// Convenience wrapper for `download_model(&default_model_dir(), DEFAULT_HF_REPO)`.
    pub fn download_default_model() -> Result<PathBuf, CodememError> {
        Self::download_model(&Self::default_model_dir(), DEFAULT_HF_REPO)
    }

    /// Embed a single text string. Returns an L2-normalized vector (dimension = model's hidden_size).
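    ///
    /// Example (sketch, `no_run`; assumes the default model was downloaded via
    /// `codemem init`):
    ///
    /// ```no_run
    /// # use candle_core::DType;
    /// # use codemem_embeddings::EmbeddingService;
    /// let dir = EmbeddingService::default_model_dir();
    /// let service = EmbeddingService::new(&dir, 16, DType::F32)?;
    /// let v = service.embed("fn main() {}")?;
    /// assert_eq!(v.len(), 768); // hidden_size of bge-base-en-v1.5
    /// # Ok::<(), codemem_core::CodememError>(())
    /// ```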
    pub fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
        // Tokenize using pre-configured tokenizer (truncation already set in constructor)
        let encoding = self
            .tokenizer
            .encode(text, true)
            .map_err(|e| CodememError::Embedding(e.to_string()))?;

        let input_ids: Vec<u32> = encoding.get_ids().to_vec();
        let attention_mask: Vec<u32> = encoding.get_attention_mask().to_vec();

        // Build candle tensors with shape [1, seq_len]
        let input_ids_tensor = Tensor::new(&input_ids[..], &self.device)
            .and_then(|t| t.unsqueeze(0))
            .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;

        let attention_mask_tensor = Tensor::new(&attention_mask[..], &self.device)
            .and_then(|t| t.unsqueeze(0))
            .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;

        // Forward pass -> [1, seq_len, hidden_size]
        let model = self
            .model
            .lock()
            .map_err(|e| CodememError::LockPoisoned(format!("embedding model: {e}")))?;
        let hidden_states = match &*model {
            ModelBackend::Bert(bert) => {
                let token_type_ids = input_ids_tensor
                    .zeros_like()
                    .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;
                let result = bert
                    .forward(
                        &input_ids_tensor,
                        &token_type_ids,
                        Some(&attention_mask_tensor),
                    )
                    .map_err(|e| CodememError::Embedding(format!("Model forward error: {e}")))?;
                drop(token_type_ids);
                result
            }
            ModelBackend::JinaBert(jina) => jina
                .forward(&input_ids_tensor)
                .map_err(|e| CodememError::Embedding(format!("Model forward error: {e}")))?,
        };
        drop(model);
        drop(input_ids_tensor);

        // Cast hidden states to F32 for pooling math (model may output F16/BF16)
        let hidden_states = hidden_states
            .to_dtype(DType::F32)
            .map_err(|e| CodememError::Embedding(format!("Cast error: {e}")))?;

        // Mean pooling weighted by attention mask
        // attention_mask: [1, seq_len] -> [1, seq_len, 1] for broadcasting
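        // i.e. pooled[h] = sum_t(mask[t] * hidden[t][h]) / sum_t(mask[t])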
        let mask = attention_mask_tensor
            .to_dtype(DType::F32)
            .and_then(|t| t.unsqueeze(2))
            .map_err(|e| CodememError::Embedding(format!("Mask error: {e}")))?;

        let sum_mask = mask
            .sum(1)
            .map_err(|e| CodememError::Embedding(format!("Sum error: {e}")))?;

        let pooled = hidden_states
            .broadcast_mul(&mask)
            .and_then(|t| t.sum(1))
            .and_then(|t| t.broadcast_div(&sum_mask))
            .map_err(|e| CodememError::Embedding(format!("Pooling error: {e}")))?;

        // L2 normalize
        let normalized = pooled
            .sqr()
            .and_then(|t| t.sum_keepdim(1))
            .and_then(|t| t.sqrt())
            .and_then(|norm| pooled.broadcast_div(&norm))
            .map_err(|e| CodememError::Embedding(format!("Normalize error: {e}")))?;

        // Extract as Vec<f32> — shape is [1, hidden_size], squeeze to [hidden_size]
        let embedding: Vec<f32> = normalized
            .squeeze(0)
            .and_then(|t| t.to_vec1())
            .map_err(|e| CodememError::Embedding(format!("Extract error: {e}")))?;

        Ok(embedding)
    }

    /// Embed a batch of texts using a true batched forward pass.
    ///
    /// Tokenizes all texts, pads to the longest sequence in each chunk, runs a
    /// single forward pass per chunk of up to `batch_size` texts, then performs
    /// mean pooling and L2 normalization on the batched output.
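    ///
    /// Example (sketch, `no_run`; same setup as `embed`):
    ///
    /// ```no_run
    /// # use candle_core::DType;
    /// # use codemem_embeddings::EmbeddingService;
    /// # let dir = EmbeddingService::default_model_dir();
    /// # let service = EmbeddingService::new(&dir, 16, DType::F32)?;
    /// let embeddings = service.embed_batch(&["first text", "second text"])?;
    /// assert_eq!(embeddings.len(), 2);
    /// # Ok::<(), codemem_core::CodememError>(())
    /// ```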
    pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
        if texts.is_empty() {
            return Ok(vec![]);
        }

        let mut all_embeddings = Vec::with_capacity(texts.len());

        for chunk in texts.chunks(self.batch_size) {
            // Clone tokenizer only for batch path — needs per-chunk padding config.
            // Truncation is already configured on self.tokenizer.
            let mut tokenizer = self.tokenizer.clone();
            tokenizer.with_padding(Some(PaddingParams {
                strategy: PaddingStrategy::BatchLongest,
                ..Default::default()
            }));

            let encodings = tokenizer
                .encode_batch(chunk.to_vec(), true)
                .map_err(|e| CodememError::Embedding(format!("Batch encode error: {e}")))?;

            let batch_len = encodings.len();
            let seq_len = encodings[0].get_ids().len();

            // Flatten token IDs and attention masks into contiguous arrays
            let all_ids: Vec<u32> = encodings
                .iter()
                .flat_map(|e| e.get_ids())
                .copied()
                .collect();
            let all_masks: Vec<u32> = encodings
                .iter()
                .flat_map(|e| e.get_attention_mask())
                .copied()
                .collect();

            // Build tensors with shape [batch_size, seq_len]
            let input_ids = Tensor::new(all_ids.as_slice(), &self.device)
                .and_then(|t| t.reshape((batch_len, seq_len)))
                .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;

            let attention_mask = Tensor::new(all_masks.as_slice(), &self.device)
                .and_then(|t| t.reshape((batch_len, seq_len)))
                .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;

            // Single forward pass -> [batch_size, seq_len, hidden_size]
            let model = self
                .model
                .lock()
                .map_err(|e| CodememError::LockPoisoned(format!("embedding model: {e}")))?;
            let hidden_states = match &*model {
                ModelBackend::Bert(bert) => {
                    let token_type_ids = input_ids
                        .zeros_like()
                        .map_err(|e| CodememError::Embedding(format!("Tensor error: {e}")))?;
                    let result = bert
                        .forward(&input_ids, &token_type_ids, Some(&attention_mask))
                        .map_err(|e| CodememError::Embedding(format!("Forward error: {e}")))?;
                    drop(token_type_ids);
                    result
                }
                ModelBackend::JinaBert(jina) => jina
                    .forward(&input_ids)
                    .map_err(|e| CodememError::Embedding(format!("Forward error: {e}")))?,
            };
            drop(model);
            drop(input_ids);

            // Cast hidden states to F32 for pooling math (model may output F16/BF16)
            let hidden_states = hidden_states
                .to_dtype(DType::F32)
                .map_err(|e| CodememError::Embedding(format!("Cast error: {e}")))?;

            // Mean pooling: mask [batch, seq] -> [batch, seq, 1] for broadcast
            let mask = attention_mask
                .to_dtype(DType::F32)
                .and_then(|t| t.unsqueeze(2))
                .map_err(|e| CodememError::Embedding(format!("Mask error: {e}")))?;

            let sum_mask = mask
                .sum(1)
                .map_err(|e| CodememError::Embedding(format!("Sum error: {e}")))?;

            let pooled = hidden_states
                .broadcast_mul(&mask)
                .and_then(|t| t.sum(1))
                .and_then(|t| t.broadcast_div(&sum_mask))
                .map_err(|e| CodememError::Embedding(format!("Pooling error: {e}")))?;

            // L2 normalize: [batch, hidden]
            let norm = pooled
                .sqr()
                .and_then(|t| t.sum_keepdim(1))
                .and_then(|t| t.sqrt())
                .map_err(|e| CodememError::Embedding(format!("Norm error: {e}")))?;

            let normalized = pooled
                .broadcast_div(&norm)
                .map_err(|e| CodememError::Embedding(format!("Normalize error: {e}")))?;

            // Single GPU→CPU blit: flatten all rows, then slice on CPU.
            // to_vec1() implicitly syncs the GPU pipeline (data must be ready to read).
            let flat: Vec<f32> = normalized
                .flatten_all()
                .and_then(|t| t.to_vec1())
                .map_err(|e| CodememError::Embedding(format!("Extract error: {e}")))?;
            for i in 0..batch_len {
                let start = i * self.hidden_size;
                all_embeddings.push(flat[start..start + self.hidden_size].to_vec());
            }
        }

        Ok(all_embeddings)
    }
}

impl EmbeddingProvider for EmbeddingService {
    fn dimensions(&self) -> usize {
        self.hidden_size
    }

    fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
        self.embed(text)
    }

    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
        self.embed_batch(texts)
    }

    fn name(&self) -> &str {
        "candle"
    }
}

// ── Cached Provider Wrapper ───────────────────────────────────────────────

/// Wraps any `EmbeddingProvider` with an LRU cache.
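///
/// Example (sketch, `no_run`; wraps a Candle service in the default-size cache):
///
/// ```no_run
/// # use candle_core::DType;
/// use codemem_embeddings::{CachedProvider, EmbeddingProvider, EmbeddingService, CACHE_CAPACITY};
///
/// # let service = EmbeddingService::new(&EmbeddingService::default_model_dir(), 16, DType::F32)?;
/// let provider = CachedProvider::new(Box::new(service), CACHE_CAPACITY);
/// provider.embed("repeated query")?; // miss: runs the model
/// provider.embed("repeated query")?; // hit: served from the LRU cache
/// let (len, cap) = provider.cache_stats();
/// assert_eq!((len, cap), (1, CACHE_CAPACITY));
/// # Ok::<(), codemem_core::CodememError>(())
/// ```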
pub struct CachedProvider {
    inner: Box<dyn EmbeddingProvider>,
    cache: Mutex<LruCache<String, Vec<f32>>>,
}

impl CachedProvider {
    pub fn new(inner: Box<dyn EmbeddingProvider>, capacity: usize) -> Self {
        // A capacity of 0 falls back to 1; NonZeroUsize::new(1) is infallible.
        let cap =
            NonZeroUsize::new(capacity).unwrap_or(NonZeroUsize::new(1).expect("1 is non-zero"));
        Self {
            inner,
            cache: Mutex::new(LruCache::new(cap)),
        }
    }
}

impl EmbeddingProvider for CachedProvider {
    fn dimensions(&self) -> usize {
        self.inner.dimensions()
    }

    fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
        {
            let mut cache = self
                .cache
                .lock()
                .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
            if let Some(cached) = cache.get(text) {
                return Ok(cached.clone());
            }
        }
        let embedding = self.inner.embed(text)?;
        {
            let mut cache = self
                .cache
                .lock()
                .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
            cache.put(text.to_string(), embedding.clone());
        }
        Ok(embedding)
    }

    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
        // Check cache, only forward uncached texts
        let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
        let mut uncached = Vec::new();
        let mut uncached_idx = Vec::new();

        {
            let mut cache = self
                .cache
                .lock()
                .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
            for (i, text) in texts.iter().enumerate() {
                if let Some(cached) = cache.get(*text) {
                    results[i] = Some(cached.clone());
                } else {
                    uncached_idx.push(i);
                    uncached.push(*text);
                }
            }
        }

        if !uncached.is_empty() {
            let new_embeddings = self.inner.embed_batch(&uncached)?;
            let mut cache = self
                .cache
                .lock()
                .map_err(|e| CodememError::LockPoisoned(format!("cached provider: {e}")))?;
            for (idx, embedding) in uncached_idx.into_iter().zip(new_embeddings) {
                cache.put(texts[idx].to_string(), embedding.clone());
                results[idx] = Some(embedding);
            }
        }

        // Verify all texts got embeddings — flatten() would silently drop Nones
        let expected = texts.len();
        let output: Vec<Vec<f32>> = results
            .into_iter()
            .enumerate()
            .map(|(i, opt)| {
                opt.ok_or_else(|| {
                    CodememError::Embedding(format!(
                        "Missing embedding for text at index {i} (batch size {expected})"
                    ))
                })
            })
            .collect::<Result<Vec<_>, _>>()?;
        Ok(output)
    }

    fn name(&self) -> &str {
        self.inner.name()
    }

    fn cache_stats(&self) -> (usize, usize) {
        match self.cache.lock() {
            Ok(cache) => (cache.len(), cache.cap().into()),
            Err(_) => (0, 0),
        }
    }
}

// ── Factory ───────────────────────────────────────────────────────────────

/// Parse a dtype string into a Candle DType.
///
/// Supported values: "f16"/"float16"/"half" (default; half the memory, faster on Metal),
/// "f32"/"float32", and "bf16"/"bfloat16". An empty string selects the default.
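///
/// Example:
///
/// ```
/// use candle_core::DType;
/// use codemem_embeddings::parse_dtype;
///
/// assert_eq!(parse_dtype("F16")?, DType::F16);
/// assert_eq!(parse_dtype("")?, DType::F16); // empty string selects the default
/// assert!(parse_dtype("int8").is_err());
/// # Ok::<(), codemem_core::CodememError>(())
/// ```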
pub fn parse_dtype(s: &str) -> Result<DType, CodememError> {
    match s.to_lowercase().as_str() {
        "f16" | "float16" | "half" | "" => Ok(DType::F16),
        "f32" | "float32" => Ok(DType::F32),
        "bf16" | "bfloat16" => Ok(DType::BF16),
        other => Err(CodememError::Embedding(format!(
            "Unknown dtype: '{}'. Use 'f16', 'f32', or 'bf16'.",
            other
        ))),
    }
}

/// Resolve the HuggingFace repo ID and local directory name from a model identifier.
///
/// Accepts:
/// - Full HF repo: `"BAAI/bge-base-en-v1.5"` → repo=`"BAAI/bge-base-en-v1.5"`, dir=`"bge-base-en-v1.5"`
/// - Short name: `"bge-small-en-v1.5"` → repo=`"BAAI/bge-small-en-v1.5"`, dir=`"bge-small-en-v1.5"`
///
/// Returns `Err` if the model identifier is a bare name without an org prefix and isn't
/// a recognized `bge-*` shorthand — HuggingFace requires `org/repo` format.
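///
/// Example:
///
/// ```
/// use codemem_embeddings::resolve_model_id;
///
/// let (repo, dir) = resolve_model_id("bge-small-en-v1.5")?;
/// assert_eq!(repo, "BAAI/bge-small-en-v1.5");
/// assert_eq!(dir, "bge-small-en-v1.5");
/// assert!(resolve_model_id("all-MiniLM-L6-v2").is_err()); // bare name, no org prefix
/// # Ok::<(), codemem_core::CodememError>(())
/// ```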
pub fn resolve_model_id(model: &str) -> Result<(String, String), CodememError> {
    if model.contains('/') {
        // Full repo ID — directory name is the part after the slash
        let dir_name = model.rsplit('/').next().unwrap_or(model);
        Ok((model.to_string(), dir_name.to_string()))
    } else if model.starts_with("bge-") {
        // Short name — assume BAAI namespace for bge-* models
        Ok((format!("BAAI/{model}"), model.to_string()))
    } else {
        Err(CodememError::Embedding(format!(
            "Model identifier '{}' must be a full HuggingFace repo ID (e.g., 'BAAI/bge-base-en-v1.5' \
             or 'sentence-transformers/all-MiniLM-L6-v2'). Short names are only supported for 'bge-*' models.",
            model
        )))
    }
}

/// Create an embedding provider from environment variables.
///
/// When `config` is provided, its fields serve as defaults; env vars override them.
///
/// | Variable | Values | Default |
/// |----------|--------|---------|
/// | `CODEMEM_EMBED_PROVIDER` | `candle`, `ollama`, `openai`, `gemini` | `candle` |
/// | `CODEMEM_EMBED_MODEL` | model name or HF repo | `BAAI/bge-base-en-v1.5` |
/// | `CODEMEM_EMBED_URL` | base URL | provider default |
/// | `CODEMEM_EMBED_API_KEY` | API key | also reads `OPENAI_API_KEY` / `GEMINI_API_KEY` / `GOOGLE_API_KEY` |
/// | `CODEMEM_EMBED_DIMENSIONS` | integer | read from model config |
/// | `CODEMEM_EMBED_BATCH_SIZE` | integer | `16` |
/// | `CODEMEM_EMBED_DTYPE` | `f16`, `f32`, `bf16` | `f16` |
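///
/// Example (sketch, `no_run` since construction may load a local model or
/// require API keys; `my-app` is a hypothetical binary name):
///
/// ```no_run
/// use codemem_embeddings::{from_env, EmbeddingProvider};
///
/// // Typically configured from the shell, e.g.:
/// //   CODEMEM_EMBED_PROVIDER=ollama CODEMEM_EMBED_DIMENSIONS=768 my-app
/// let provider = from_env(None)?;
/// println!("provider={} dims={}", provider.name(), provider.dimensions());
/// # Ok::<(), codemem_core::CodememError>(())
/// ```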
pub fn from_env(
    config: Option<&codemem_core::EmbeddingConfig>,
) -> Result<Box<dyn EmbeddingProvider>, CodememError> {
    let provider = std::env::var("CODEMEM_EMBED_PROVIDER")
        .unwrap_or_else(|_| {
            config
                .map(|c| c.provider.clone())
                .unwrap_or_else(|| "candle".to_string())
        })
        .to_lowercase();
    // For remote providers (Ollama/OpenAI/Gemini), dimensions must be specified explicitly.
    // For Candle, this value is ignored — hidden_size is read from the model's config.json.
    let dimensions: usize = std::env::var("CODEMEM_EMBED_DIMENSIONS")
        .ok()
        .and_then(|s| s.parse().ok())
        .unwrap_or_else(|| config.map_or(DEFAULT_REMOTE_DIMENSIONS, |c| c.dimensions));
    let cache_capacity = config.map_or(CACHE_CAPACITY, |c| c.cache_capacity);
    let batch_size: usize = std::env::var("CODEMEM_EMBED_BATCH_SIZE")
        .ok()
        .and_then(|s| s.parse().ok())
        .unwrap_or_else(|| config.map_or(DEFAULT_BATCH_SIZE, |c| c.batch_size));

    match provider.as_str() {
        "ollama" => {
            let base_url = std::env::var("CODEMEM_EMBED_URL").unwrap_or_else(|_| {
                config
                    .filter(|c| !c.url.is_empty())
                    .map(|c| c.url.clone())
                    .unwrap_or_else(|| ollama::DEFAULT_BASE_URL.to_string())
            });
            let model = std::env::var("CODEMEM_EMBED_MODEL").unwrap_or_else(|_| {
                config
                    .filter(|c| !c.model.is_empty())
                    .map(|c| c.model.clone())
                    .unwrap_or_else(|| ollama::DEFAULT_MODEL.to_string())
            });
            let inner = Box::new(ollama::OllamaProvider::new(&base_url, &model, dimensions));
            Ok(Box::new(CachedProvider::new(inner, cache_capacity)))
        }
        "openai" => {
            let api_key = std::env::var("CODEMEM_EMBED_API_KEY")
                .or_else(|_| std::env::var("OPENAI_API_KEY"))
                .map_err(|_| {
                    CodememError::Embedding(
                        "CODEMEM_EMBED_API_KEY or OPENAI_API_KEY required for OpenAI embeddings"
                            .into(),
                    )
                })?;
            let model = std::env::var("CODEMEM_EMBED_MODEL").unwrap_or_else(|_| {
                config
                    .filter(|c| !c.model.is_empty())
                    .map(|c| c.model.clone())
                    .unwrap_or_else(|| openai::DEFAULT_MODEL.to_string())
            });
            let base_url = std::env::var("CODEMEM_EMBED_URL")
                .ok()
                .or_else(|| config.filter(|c| !c.url.is_empty()).map(|c| c.url.clone()));
            let inner = Box::new(openai::OpenAIProvider::new(
                &api_key,
                &model,
                dimensions,
                base_url.as_deref(),
            ));
            Ok(Box::new(CachedProvider::new(inner, cache_capacity)))
        }
        "gemini" | "google" => {
            let api_key = std::env::var("CODEMEM_EMBED_API_KEY")
                .or_else(|_| std::env::var("GEMINI_API_KEY"))
                .or_else(|_| std::env::var("GOOGLE_API_KEY"))
                .map_err(|_| {
                    CodememError::Embedding(
                        "CODEMEM_EMBED_API_KEY, GEMINI_API_KEY, or GOOGLE_API_KEY required for Gemini embeddings"
                            .into(),
                    )
                })?;
            let model = std::env::var("CODEMEM_EMBED_MODEL").unwrap_or_else(|_| {
                config
                    .filter(|c| !c.model.is_empty())
                    .map(|c| c.model.clone())
                    .unwrap_or_else(|| gemini::DEFAULT_MODEL.to_string())
            });
            let base_url = std::env::var("CODEMEM_EMBED_URL")
                .ok()
                .or_else(|| config.filter(|c| !c.url.is_empty()).map(|c| c.url.clone()));
            let inner = Box::new(gemini::GeminiProvider::new(
                &api_key,
                &model,
                dimensions,
                base_url.as_deref(),
            ));
            Ok(Box::new(CachedProvider::new(inner, cache_capacity)))
        }
        "candle" | "" => {
            let model_id = std::env::var("CODEMEM_EMBED_MODEL").unwrap_or_else(|_| {
                config
                    .filter(|c| !c.model.is_empty())
                    .map(|c| c.model.clone())
                    .unwrap_or_else(|| DEFAULT_HF_REPO.to_string())
            });
            let (hf_repo, dir_name) = resolve_model_id(&model_id)?;
            let model_dir = EmbeddingService::model_dir_for(&dir_name);

            let dtype_str = std::env::var("CODEMEM_EMBED_DTYPE").unwrap_or_else(|_| {
                config
                    .filter(|c| !c.dtype.is_empty())
                    .map(|c| c.dtype.clone())
                    .unwrap_or_else(|| "f16".to_string())
            });
            let dtype = parse_dtype(&dtype_str)?;

            let service = EmbeddingService::new(&model_dir, batch_size, dtype).map_err(|e| {
                // Enhance error message with download hint for non-default models
                if e.to_string().contains("Model not found") && hf_repo != DEFAULT_HF_REPO {
                    CodememError::Embedding(format!(
                        "Model '{}' not found at {}. Download it with:\n  \
                         CODEMEM_EMBED_MODEL={} codemem init",
                        hf_repo,
                        model_dir.display(),
                        hf_repo
                    ))
                } else {
                    e
                }
            })?;
            Ok(Box::new(CachedProvider::new(
                Box::new(service),
                cache_capacity,
            )))
        }
        other => Err(CodememError::Embedding(format!(
            "Unknown embedding provider: '{}'. Use 'candle', 'ollama', 'openai', or 'gemini'.",
            other
        ))),
    }
}

#[cfg(test)]
#[path = "tests/lib_tests.rs"]
mod tests;