Skip to main content

ix_embeddings/
lib.rs

1//! Pluggable embedding infrastructure for Ixchel.
2//!
3//! Supports multiple backends via feature flags:
4//! - `fastembed` (default): ONNX-based, CPU-only
5//! - `candle`: Hugging Face Candle, supports Metal/CUDA
6
7use ix_config::{EmbeddingConfig, load_shared_config};
8use std::sync::Mutex;
9use thiserror::Error;
10
11#[cfg(feature = "fastembed")]
12use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
13
14#[cfg(feature = "candle")]
15use candle_core::{Device, Tensor};
16#[cfg(feature = "candle")]
17use candle_nn::VarBuilder;
18#[cfg(feature = "candle")]
19use candle_transformers::models::bert::{BertModel, Config as BertConfig, DTYPE};
20#[cfg(feature = "candle")]
21use hf_hub::{Repo, RepoType, api::sync::Api};
22#[cfg(feature = "candle")]
23use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams, TruncationStrategy};
24
25#[derive(Debug, Error)]
26pub enum EmbeddingError {
27    #[error("Failed to initialize embedding model: {0}")]
28    InitError(String),
29
30    #[error("Failed to generate embedding: {0}")]
31    EmbedError(String),
32
33    #[error("Embedding provider unavailable: {0}")]
34    ProviderUnavailable(String),
35
36    #[error("No embedding returned for input")]
37    EmptyResult,
38
39    #[error("Unknown provider: {0}")]
40    UnknownProvider(String),
41
42    #[error("Unknown model: {0}")]
43    UnknownModel(String),
44
45    #[error(
46        "Embedding dimension mismatch for model {model}: expected {expected}, configured {configured}"
47    )]
48    DimensionMismatch {
49        model: String,
50        expected: usize,
51        configured: usize,
52    },
53
54    #[error("Provider not available: {provider} (enable the '{feature}' feature)")]
55    ProviderNotCompiled { provider: String, feature: String },
56}
57
58pub type Result<T> = std::result::Result<T, EmbeddingError>;
59
60pub trait EmbeddingProvider: Send + Sync {
61    fn embed(&self, text: &str) -> Result<Vec<f32>>;
62    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>>;
63    fn dimension(&self) -> usize;
64    fn model_name(&self) -> &str;
65    fn provider_name(&self) -> &'static str;
66    fn batch_size(&self) -> usize {
67        1
68    }
69}
70
71pub struct Embedder {
72    provider: Box<dyn EmbeddingProvider>,
73}
74
75impl Embedder {
76    pub fn new() -> Result<Self> {
77        let config = load_shared_config()
78            .map(|c| c.embedding)
79            .unwrap_or_default();
80        Self::with_config(&config)
81    }
82
83    pub fn with_config(config: &EmbeddingConfig) -> Result<Self> {
84        let provider = provider_from_config(config)?;
85        Ok(Self { provider })
86    }
87
88    pub fn from_provider(provider: Box<dyn EmbeddingProvider>) -> Self {
89        Self { provider }
90    }
91
92    pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
93        self.provider.embed(text)
94    }
95
96    pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
97        self.provider.embed_batch(texts)
98    }
99
100    #[must_use]
101    pub fn dimension(&self) -> usize {
102        self.provider.dimension()
103    }
104
105    #[must_use]
106    pub fn batch_size(&self) -> usize {
107        self.provider.batch_size()
108    }
109
110    #[must_use]
111    pub fn model_name(&self) -> &str {
112        self.provider.model_name()
113    }
114
115    #[must_use]
116    pub fn provider_name(&self) -> &'static str {
117        self.provider.provider_name()
118    }
119}
120
121// =============================================================================
122// FastEmbed Provider
123// =============================================================================
124
/// Provider backed by the ONNX-based `fastembed` crate (CPU inference).
#[cfg(feature = "fastembed")]
struct FastEmbedProvider {
    /// Loaded model; access is serialized through the mutex.
    model: Mutex<TextEmbedding>,
    /// Model code reported by fastembed's metadata for the resolved model.
    model_name: String,
    /// Vector width reported by fastembed's metadata.
    dimension: usize,
    /// Texts per `embed` call when batching (never below 1).
    batch_size: usize,
}

#[cfg(feature = "fastembed")]
impl FastEmbedProvider {
    /// Resolves `config.model`, checks any configured dimension, then loads the model.
    ///
    /// # Errors
    ///
    /// - [`EmbeddingError::UnknownModel`] if the name cannot be resolved or has no metadata.
    /// - [`EmbeddingError::DimensionMismatch`] if `config.dimension` disagrees with the model.
    /// - [`EmbeddingError::InitError`] if fastembed fails to load the model.
    fn new(config: &EmbeddingConfig) -> Result<Self> {
        let resolved = fastembed_model_from_string(&config.model)?;

        let info = TextEmbedding::get_model_info(&resolved)
            .map_err(|e| EmbeddingError::UnknownModel(format!("{}: {e}", config.model)))?;
        let native_dim = info.dim;
        let model_name = info.model_code.clone();

        // Reject the config before paying for model initialization.
        match config.dimension {
            Some(configured) if configured != native_dim => {
                return Err(EmbeddingError::DimensionMismatch {
                    model: config.model.clone(),
                    expected: native_dim,
                    configured,
                });
            }
            _ => {}
        }

        let loaded = TextEmbedding::try_new(InitOptions::new(resolved))
            .map_err(|e| EmbeddingError::InitError(e.to_string()))?;

        Ok(Self {
            model: Mutex::new(loaded),
            model_name,
            dimension: native_dim,
            batch_size: config.batch_size.max(1),
        })
    }
}
165
#[cfg(feature = "fastembed")]
impl EmbeddingProvider for FastEmbedProvider {
    /// Embeds one text and scales the result to unit L2 length.
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        // Confine the guard to this block so normalization runs without the lock.
        let outputs = {
            let guard = self.model.lock().map_err(|_| {
                EmbeddingError::ProviderUnavailable("model lock poisoned".to_string())
            })?;
            guard
                .embed(vec![text], None)
                .map_err(|e| EmbeddingError::EmbedError(e.to_string()))?
        };

        let Some(mut vector) = outputs.into_iter().next() else {
            return Err(EmbeddingError::EmptyResult);
        };
        l2_normalize(&mut vector);
        Ok(vector)
    }

    /// Embeds `texts` in chunks of `batch_size`, locking the model once per chunk.
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let mut collected = Vec::with_capacity(texts.len());
        for batch in texts.chunks(self.batch_size) {
            let guard = self.model.lock().map_err(|_| {
                EmbeddingError::ProviderUnavailable("model lock poisoned".to_string())
            })?;
            let mut produced = guard
                .embed(batch.to_vec(), None)
                .map_err(|e| EmbeddingError::EmbedError(e.to_string()))?;
            drop(guard);

            for vector in &mut produced {
                l2_normalize(vector);
            }
            collected.extend(produced);
        }

        Ok(collected)
    }

    fn dimension(&self) -> usize {
        self.dimension
    }

    fn model_name(&self) -> &str {
        &self.model_name
    }

    fn provider_name(&self) -> &'static str {
        "fastembed"
    }

    fn batch_size(&self) -> usize {
        self.batch_size
    }
}
230
/// Resolves a user-supplied model name to a fastembed [`EmbeddingModel`].
///
/// fastembed's own `FromStr` parser is tried first. Failing that, the name (and its
/// last `/`-separated segment) is compared, after [`normalize_model_token`], against
/// the identifiers of every supported model; the first model in fastembed's list
/// that matches wins.
///
/// # Errors
///
/// Returns [`EmbeddingError::UnknownModel`] (carrying the untrimmed input) for blank
/// or unresolvable names.
#[cfg(feature = "fastembed")]
fn fastembed_model_from_string(model_name: &str) -> Result<EmbeddingModel> {
    let unknown = || EmbeddingError::UnknownModel(model_name.to_string());

    let trimmed = model_name.trim();
    if trimmed.is_empty() {
        return Err(unknown());
    }

    if let Ok(parsed) = trimmed.parse::<EmbeddingModel>() {
        return Ok(parsed);
    }

    let full_key = normalize_model_token(trimmed);
    let tail = trimmed.rsplit('/').next().unwrap_or(trimmed);
    let tail_key = normalize_model_token(tail);

    TextEmbedding::list_supported_models()
        .into_iter()
        .find(|info| {
            model_identifiers(&info.model_code)
                .iter()
                .any(|id| *id == full_key || *id == tail_key)
        })
        .map(|info| info.model)
        .ok_or_else(unknown)
}
255
/// Lists the normalized keys under which a fastembed model code can be matched.
///
/// Order: the whole code, its last `/` segment, then that segment with a trailing
/// `-onnx` or `-onnx-q` removed (when present; the suffix check is case-sensitive).
#[cfg(feature = "fastembed")]
fn model_identifiers(model_code: &str) -> Vec<String> {
    let tail = model_code.rsplit('/').next().unwrap_or(model_code);
    let without_packaging = ["-onnx", "-onnx-q"]
        .into_iter()
        .filter_map(|packaging| tail.strip_suffix(packaging));

    [model_code, tail]
        .into_iter()
        .chain(without_packaging)
        .map(normalize_model_token)
        .collect()
}
273
/// Canonicalizes a model name for fuzzy comparison: keeps only ASCII letters and
/// digits, lowercased (e.g. `"BAAI/bge-small-en-v1.5"` becomes `"baaibgesmallenv15"`).
///
/// Only the fastembed name-resolution code calls this, so the dead-code lint is
/// silenced in builds without that feature (e.g. candle-only builds), where it
/// would otherwise break `-D warnings`.
#[cfg_attr(not(feature = "fastembed"), allow(dead_code))]
fn normalize_model_token(value: &str) -> String {
    value
        .chars()
        .filter(char::is_ascii_alphanumeric)
        .map(|c| c.to_ascii_lowercase())
        .collect()
}
281
/// Scales `embedding` in place to unit Euclidean (L2) length.
///
/// The sum of squares is accumulated in `f64`. Squaring in `f32` would overflow to
/// `inf` for components above ~1.8e19 (turning the whole vector into zeros) and
/// underflow to `0` for very small components (skipping normalization entirely).
///
/// Vectors whose norm is zero or non-finite (all zeros, or containing `NaN`/`inf`)
/// are left unchanged. Empty slices are a no-op.
///
/// Only the fastembed provider calls this, hence the dead-code allowance for
/// builds without that feature.
#[cfg_attr(not(feature = "fastembed"), allow(dead_code))]
#[allow(clippy::cast_possible_truncation)] // |x / norm| <= 1, so narrowing to f32 only rounds
fn l2_normalize(embedding: &mut [f32]) {
    let norm = embedding
        .iter()
        .map(|&x| f64::from(x) * f64::from(x))
        .sum::<f64>()
        .sqrt();
    if !(norm.is_finite() && norm > 0.0) {
        return;
    }

    for x in embedding {
        *x = (f64::from(*x) / norm) as f32;
    }
}
292
293// =============================================================================
294// Candle Provider
295// =============================================================================
296
/// Provider that runs a BERT-style encoder through Hugging Face Candle.
#[cfg(feature = "candle")]
struct CandleProvider {
    /// Encoder weights; forward passes are serialized through the mutex.
    model: Mutex<BertModel>,
    /// Tokenizer configured with truncation and batch-longest padding at load time.
    tokenizer: Mutex<Tokenizer>,
    /// Device the weights were loaded onto (may be the CPU fallback).
    device: Device,
    /// Hugging Face repository id the model was loaded from.
    model_name: String,
    /// Embedding width, taken from the model config's `hidden_size`.
    dimension: usize,
    /// Maximum texts per forward pass (never below 1).
    batch_size: usize,
}
306
307#[cfg(feature = "candle")]
308impl CandleProvider {
309    fn new(config: &EmbeddingConfig) -> Result<Self> {
310        let device = Self::select_device();
311        let model_id = if config.model.is_empty() {
312            "sentence-transformers/all-MiniLM-L6-v2"
313        } else {
314            &config.model
315        };
316
317        let (model, tokenizer, dimension) = Self::load_model(model_id, &device)?;
318
319        if let Some(configured_dim) = config.dimension
320            && configured_dim != dimension
321        {
322            return Err(EmbeddingError::DimensionMismatch {
323                model: model_id.to_string(),
324                expected: dimension,
325                configured: configured_dim,
326            });
327        }
328
329        Ok(Self {
330            model: Mutex::new(model),
331            tokenizer: Mutex::new(tokenizer),
332            device,
333            model_name: model_id.to_string(),
334            dimension,
335            batch_size: config.batch_size.max(1),
336        })
337    }
338
339    #[allow(clippy::missing_const_for_fn)] // Can't be const when metal/cuda features call non-const fns
340    fn select_device() -> Device {
341        #[cfg(feature = "metal")]
342        {
343            Device::new_metal(0).unwrap_or(Device::Cpu)
344        }
345        #[cfg(all(feature = "cuda", not(feature = "metal")))]
346        {
347            Device::new_cuda(0).unwrap_or(Device::Cpu)
348        }
349        #[cfg(not(any(feature = "metal", feature = "cuda")))]
350        {
351            Device::Cpu
352        }
353    }
354
355    fn load_model(model_id: &str, device: &Device) -> Result<(BertModel, Tokenizer, usize)> {
356        let api = Api::new().map_err(|e| EmbeddingError::InitError(e.to_string()))?;
357        let repo = api.repo(Repo::new(model_id.to_string(), RepoType::Model));
358
359        // Download model files
360        let config_path = repo
361            .get("config.json")
362            .map_err(|e| EmbeddingError::InitError(format!("Failed to get config: {e}")))?;
363        let tokenizer_path = repo
364            .get("tokenizer.json")
365            .map_err(|e| EmbeddingError::InitError(format!("Failed to get tokenizer: {e}")))?;
366        let weights_path = repo
367            .get("model.safetensors")
368            .or_else(|_| repo.get("pytorch_model.bin"))
369            .map_err(|e| EmbeddingError::InitError(format!("Failed to get weights: {e}")))?;
370
371        // Load config
372        let config_str = std::fs::read_to_string(&config_path)
373            .map_err(|e| EmbeddingError::InitError(format!("Failed to read config: {e}")))?;
374        let bert_config: BertConfig = serde_json::from_str(&config_str)
375            .map_err(|e| EmbeddingError::InitError(format!("Failed to parse config: {e}")))?;
376        let dimension = bert_config.hidden_size;
377
378        // Load tokenizer with padding + truncation
379        let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
380            .map_err(|e| EmbeddingError::InitError(format!("Failed to load tokenizer: {e}")))?;
381        tokenizer
382            .with_truncation(Some(TruncationParams {
383                max_length: bert_config.max_position_embeddings,
384                strategy: TruncationStrategy::LongestFirst,
385                ..Default::default()
386            }))
387            .map_err(|e| {
388                EmbeddingError::InitError(format!("Failed to configure tokenizer truncation: {e}"))
389            })?;
390        tokenizer.with_padding(Some(PaddingParams {
391            strategy: PaddingStrategy::BatchLongest,
392            ..Default::default()
393        }));
394
395        // Load model weights
396        // SAFETY: We just downloaded this file from HuggingFace Hub and trust its contents.
397        // Memory-mapping provides significant performance benefits for large model files.
398        #[allow(unsafe_code)]
399        let vb = if weights_path
400            .extension()
401            .is_some_and(|ext| ext == "safetensors")
402        {
403            unsafe {
404                VarBuilder::from_mmaped_safetensors(&[weights_path], DTYPE, device).map_err(
405                    |e| EmbeddingError::InitError(format!("Failed to load weights: {e}")),
406                )?
407            }
408        } else {
409            VarBuilder::from_pth(&weights_path, DTYPE, device)
410                .map_err(|e| EmbeddingError::InitError(format!("Failed to load weights: {e}")))?
411        };
412
413        let model = BertModel::load(vb, &bert_config)
414            .map_err(|e| EmbeddingError::InitError(format!("Failed to build model: {e}")))?;
415
416        Ok((model, tokenizer, dimension))
417    }
418
419    fn embed_tokens(&self, token_ids: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
420        let token_type_ids = token_ids
421            .zeros_like()
422            .map_err(|e| EmbeddingError::EmbedError(e.to_string()))?;
423
424        // Hold the lock only for the forward pass
425        let embeddings = self
426            .model
427            .lock()
428            .map_err(|_| EmbeddingError::ProviderUnavailable("model lock poisoned".to_string()))?
429            .forward(token_ids, &token_type_ids, Some(attention_mask))
430            .map_err(|e| EmbeddingError::EmbedError(e.to_string()))?;
431
432        // Mean pooling with attention mask
433        let mask_expanded = attention_mask
434            .unsqueeze(2)
435            .map_err(|e| EmbeddingError::EmbedError(e.to_string()))?
436            .broadcast_as(embeddings.shape())
437            .map_err(|e| EmbeddingError::EmbedError(e.to_string()))?
438            .to_dtype(embeddings.dtype())
439            .map_err(|e| EmbeddingError::EmbedError(e.to_string()))?;
440
441        let masked = embeddings
442            .mul(&mask_expanded)
443            .map_err(|e| EmbeddingError::EmbedError(e.to_string()))?;
444
445        let summed = masked
446            .sum(1)
447            .map_err(|e| EmbeddingError::EmbedError(e.to_string()))?;
448
449        let mask_sum = mask_expanded
450            .sum(1)
451            .map_err(|e| EmbeddingError::EmbedError(e.to_string()))?
452            .clamp(1e-9, f64::MAX)
453            .map_err(|e| EmbeddingError::EmbedError(e.to_string()))?;
454
455        let pooled = summed
456            .div(&mask_sum)
457            .map_err(|e| EmbeddingError::EmbedError(e.to_string()))?;
458
459        // L2 normalize
460        let norm = pooled
461            .sqr()
462            .map_err(|e| EmbeddingError::EmbedError(e.to_string()))?
463            .sum(1)
464            .map_err(|e| EmbeddingError::EmbedError(e.to_string()))?
465            .sqrt()
466            .map_err(|e| EmbeddingError::EmbedError(e.to_string()))?
467            .unsqueeze(1)
468            .map_err(|e| EmbeddingError::EmbedError(e.to_string()))?
469            .clamp(1e-9, f64::MAX)
470            .map_err(|e| EmbeddingError::EmbedError(e.to_string()))?
471            .broadcast_as(pooled.shape())
472            .map_err(|e| EmbeddingError::EmbedError(e.to_string()))?;
473
474        pooled
475            .div(&norm)
476            .map_err(|e| EmbeddingError::EmbedError(e.to_string()))
477    }
478}
479
#[cfg(feature = "candle")]
impl EmbeddingProvider for CandleProvider {
    /// Tokenizes and embeds one text as a batch of size one.
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        let to_err = |e: candle_core::Error| EmbeddingError::EmbedError(e.to_string());

        let encoding = self
            .tokenizer
            .lock()
            .map_err(|_| {
                EmbeddingError::ProviderUnavailable("tokenizer lock poisoned".to_string())
            })?
            .encode(text, true)
            .map_err(|e| EmbeddingError::EmbedError(e.to_string()))?;

        // Build (1, seq_len) inputs directly from the encoding's slices.
        let token_ids = Tensor::new(encoding.get_ids(), &self.device)
            .map_err(to_err)?
            .unsqueeze(0)
            .map_err(to_err)?;
        let attention_mask = Tensor::new(encoding.get_attention_mask(), &self.device)
            .map_err(to_err)?
            .unsqueeze(0)
            .map_err(to_err)?;

        self.embed_tokens(&token_ids, &attention_mask)?
            .squeeze(0)
            .map_err(to_err)?
            .to_vec1()
            .map_err(to_err)
    }

    /// Embeds `texts` in chunks of `batch_size`, one forward pass per chunk.
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }
        let to_err = |e: candle_core::Error| EmbeddingError::EmbedError(e.to_string());

        let mut all_embeddings = Vec::with_capacity(texts.len());

        for chunk in texts.chunks(self.batch_size) {
            let encodings = self
                .tokenizer
                .lock()
                .map_err(|_| {
                    EmbeddingError::ProviderUnavailable("tokenizer lock poisoned".to_string())
                })?
                .encode_batch(chunk.to_vec(), true)
                .map_err(|e| EmbeddingError::EmbedError(e.to_string()))?;

            // The tokenizer is configured with batch-longest padding, so every encoding
            // in the chunk has the same length and the flattened buffers are rectangular.
            let batch_len = encodings.len();
            let seq_len = encodings
                .iter()
                .map(tokenizers::Encoding::len)
                .max()
                .unwrap_or(0);

            let mut ids_flat: Vec<u32> = Vec::with_capacity(batch_len * seq_len);
            let mut mask_flat: Vec<u32> = Vec::with_capacity(batch_len * seq_len);
            for enc in &encodings {
                ids_flat.extend_from_slice(enc.get_ids());
                mask_flat.extend_from_slice(enc.get_attention_mask());
            }

            let token_ids = Tensor::new(ids_flat.as_slice(), &self.device)
                .map_err(to_err)?
                .reshape((batch_len, seq_len))
                .map_err(to_err)?;
            let attention_mask = Tensor::new(mask_flat.as_slice(), &self.device)
                .map_err(to_err)?
                .reshape((batch_len, seq_len))
                .map_err(to_err)?;

            let batch_embeddings: Vec<Vec<f32>> = self
                .embed_tokens(&token_ids, &attention_mask)?
                .to_vec2()
                .map_err(to_err)?;
            all_embeddings.extend(batch_embeddings);
        }

        Ok(all_embeddings)
    }

    fn dimension(&self) -> usize {
        self.dimension
    }

    fn model_name(&self) -> &str {
        &self.model_name
    }

    /// Reports the backend together with the device actually in use.
    ///
    /// Device selection falls back to the CPU when Metal/CUDA initialization fails,
    /// so the label is derived from the runtime device rather than from compile-time
    /// features (which would report `candle-metal`/`candle-cuda` while on the CPU).
    fn provider_name(&self) -> &'static str {
        if self.device.is_metal() {
            "candle-metal"
        } else if self.device.is_cuda() {
            "candle-cuda"
        } else {
            "candle-cpu"
        }
    }

    fn batch_size(&self) -> usize {
        self.batch_size
    }
}
595
596// =============================================================================
597// Provider Factory
598// =============================================================================
599
600fn provider_from_config(config: &EmbeddingConfig) -> Result<Box<dyn EmbeddingProvider>> {
601    let provider = config.provider.trim().to_lowercase();
602    match provider.as_str() {
603        #[cfg(feature = "fastembed")]
604        "fastembed" | "fastembed-rs" => Ok(Box::new(FastEmbedProvider::new(config)?)),
605
606        #[cfg(not(feature = "fastembed"))]
607        "fastembed" | "fastembed-rs" => Err(EmbeddingError::ProviderNotCompiled {
608            provider: "fastembed".to_string(),
609            feature: "fastembed".to_string(),
610        }),
611
612        #[cfg(feature = "candle")]
613        "candle" | "candle-rs" => Ok(Box::new(CandleProvider::new(config)?)),
614
615        #[cfg(not(feature = "candle"))]
616        "candle" | "candle-rs" => Err(EmbeddingError::ProviderNotCompiled {
617            provider: "candle".to_string(),
618            feature: "candle".to_string(),
619        }),
620
621        _ => Err(EmbeddingError::UnknownProvider(config.provider.clone())),
622    }
623}
624
625// =============================================================================
626// Tests
627// =============================================================================
628
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic in-memory provider so `Embedder` can be exercised offline,
    /// without downloading any model.
    struct StubProvider;

    impl EmbeddingProvider for StubProvider {
        fn embed(&self, text: &str) -> Result<Vec<f32>> {
            if text.is_empty() {
                Err(EmbeddingError::EmptyResult)
            } else {
                Ok(vec![1.0, 0.0])
            }
        }

        fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
            texts.iter().map(|t| self.embed(t)).collect()
        }

        fn dimension(&self) -> usize {
            2
        }

        fn model_name(&self) -> &str {
            "stub-model"
        }

        fn provider_name(&self) -> &'static str {
            "stub"
        }
    }

    #[test]
    #[cfg(feature = "fastembed")]
    fn test_fastembed_model_from_string() {
        assert!(fastembed_model_from_string("BAAI/bge-small-en-v1.5").is_ok());
        assert!(fastembed_model_from_string("bge-small-en-v1.5").is_ok());
        assert!(fastembed_model_from_string("all-MiniLM-L6-v2").is_ok());
        assert!(fastembed_model_from_string("AllMiniLML6V2").is_ok());
        assert!(fastembed_model_from_string("unknown-model").is_err());
    }

    #[test]
    fn test_embedder_delegates_to_provider() {
        let embedder = Embedder::from_provider(Box::new(StubProvider));
        assert_eq!(embedder.dimension(), 2);
        assert_eq!(embedder.model_name(), "stub-model");
        assert_eq!(embedder.provider_name(), "stub");
        // StubProvider does not override `batch_size`, so the trait default applies.
        assert_eq!(embedder.batch_size(), 1);
        assert_eq!(embedder.embed("hello").unwrap(), vec![1.0, 0.0]);
        assert!(matches!(embedder.embed(""), Err(EmbeddingError::EmptyResult)));
        assert_eq!(embedder.embed_batch(&["a", "b"]).unwrap().len(), 2);
        assert!(embedder.embed_batch(&[]).unwrap().is_empty());
    }

    #[test]
    fn test_unknown_provider_is_rejected() {
        let config = EmbeddingConfig {
            provider: "definitely-not-a-provider".to_string(),
            ..Default::default()
        };
        assert!(matches!(
            Embedder::with_config(&config),
            Err(EmbeddingError::UnknownProvider(name)) if name == "definitely-not-a-provider"
        ));
    }

    #[test]
    fn test_normalize_model_token() {
        assert_eq!(normalize_model_token("BAAI/bge-small-en-v1.5"), "baaibgesmallenv15");
        assert_eq!(normalize_model_token("All-MiniLM_L6 v2"), "allminilml6v2");
        assert_eq!(normalize_model_token(""), "");
    }

    #[test]
    fn test_l2_normalize() {
        let mut v = [3.0_f32, 4.0];
        l2_normalize(&mut v);
        assert!((v[0] - 0.6).abs() < 1e-6);
        assert!((v[1] - 0.8).abs() < 1e-6);

        let mut zeros = [0.0_f32; 4];
        l2_normalize(&mut zeros);
        assert_eq!(zeros, [0.0; 4]);
    }

    #[test]
    #[ignore = "Requires downloading model (~30MB)"]
    fn test_embed_text() {
        let embedder = Embedder::new().unwrap();
        let embedding = embedder.embed("Hello, world!").unwrap();
        assert_eq!(embedding.len(), embedder.dimension());
    }

    #[test]
    #[ignore = "Requires downloading model (~30MB)"]
    fn test_embed_batch() {
        let embedder = Embedder::new().unwrap();
        let embeddings = embedder
            .embed_batch(&["First text", "Second text", "Third text"])
            .unwrap();
        assert_eq!(embeddings.len(), 3);
        assert!(embeddings.iter().all(|e| e.len() == embedder.dimension()));
    }

    // NOTE(review): this builds the default provider, which may download a model, and
    // it passes silently when construction fails (e.g. offline). The offline coverage of
    // the empty-batch path lives in `test_embedder_delegates_to_provider`.
    #[test]
    fn test_embed_batch_empty() {
        let config = EmbeddingConfig::default();
        if let Ok(embedder) = Embedder::with_config(&config) {
            let result = embedder.embed_batch(&[]).unwrap();
            assert!(result.is_empty());
        }
    }

    #[test]
    #[cfg(feature = "candle")]
    #[ignore = "Requires downloading model (~90MB)"]
    fn test_candle_embed_text() {
        let config = EmbeddingConfig {
            provider: "candle".to_string(),
            model: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
            ..Default::default()
        };
        let embedder = Embedder::with_config(&config).unwrap();
        let embedding = embedder.embed("Hello, world!").unwrap();
        assert_eq!(embedding.len(), 384);
        assert!(embedder.provider_name().starts_with("candle"));
    }

    #[test]
    #[cfg(feature = "candle")]
    #[ignore = "Requires downloading model (~90MB)"]
    fn test_candle_embed_batch() {
        let config = EmbeddingConfig {
            provider: "candle".to_string(),
            model: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
            batch_size: 2,
            ..Default::default()
        };
        let embedder = Embedder::with_config(&config).unwrap();
        let embeddings = embedder
            .embed_batch(&["First text", "Second text", "Third text"])
            .unwrap();
        assert_eq!(embeddings.len(), 3);
        assert!(embeddings.iter().all(|e| e.len() == 384));
    }

    #[test]
    #[cfg(feature = "candle")]
    #[ignore = "Requires downloading model (~1.3GB)"]
    fn test_candle_bge_large() {
        let config = EmbeddingConfig {
            provider: "candle".to_string(),
            model: "BAAI/bge-large-en-v1.5".to_string(),
            batch_size: 8,
            ..Default::default()
        };
        let embedder = Embedder::with_config(&config).unwrap();

        // Verify model loaded correctly
        assert_eq!(embedder.dimension(), 1024);
        assert_eq!(embedder.model_name(), "BAAI/bge-large-en-v1.5");

        // Test single embedding
        let embedding = embedder.embed("Hello, world!").unwrap();
        assert_eq!(embedding.len(), 1024);

        // Test batch embedding
        let embeddings = embedder
            .embed_batch(&["First text", "Second text"])
            .unwrap();
        assert_eq!(embeddings.len(), 2);
        assert!(embeddings.iter().all(|e| e.len() == 1024));
    }
}