1use ix_config::{EmbeddingConfig, load_shared_config};
8use std::sync::Mutex;
9use thiserror::Error;
10
11#[cfg(feature = "fastembed")]
12use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
13
14#[cfg(feature = "candle")]
15use candle_core::{Device, Tensor};
16#[cfg(feature = "candle")]
17use candle_nn::VarBuilder;
18#[cfg(feature = "candle")]
19use candle_transformers::models::bert::{BertModel, Config as BertConfig, DTYPE};
20#[cfg(feature = "candle")]
21use hf_hub::{Repo, RepoType, api::sync::Api};
22#[cfg(feature = "candle")]
23use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams, TruncationStrategy};
24
25#[derive(Debug, Error)]
26pub enum EmbeddingError {
27 #[error("Failed to initialize embedding model: {0}")]
28 InitError(String),
29
30 #[error("Failed to generate embedding: {0}")]
31 EmbedError(String),
32
33 #[error("Embedding provider unavailable: {0}")]
34 ProviderUnavailable(String),
35
36 #[error("No embedding returned for input")]
37 EmptyResult,
38
39 #[error("Unknown provider: {0}")]
40 UnknownProvider(String),
41
42 #[error("Unknown model: {0}")]
43 UnknownModel(String),
44
45 #[error(
46 "Embedding dimension mismatch for model {model}: expected {expected}, configured {configured}"
47 )]
48 DimensionMismatch {
49 model: String,
50 expected: usize,
51 configured: usize,
52 },
53
54 #[error("Provider not available: {provider} (enable the '{feature}' feature)")]
55 ProviderNotCompiled { provider: String, feature: String },
56}
57
58pub type Result<T> = std::result::Result<T, EmbeddingError>;
59
60pub trait EmbeddingProvider: Send + Sync {
61 fn embed(&self, text: &str) -> Result<Vec<f32>>;
62 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>>;
63 fn dimension(&self) -> usize;
64 fn model_name(&self) -> &str;
65 fn provider_name(&self) -> &'static str;
66 fn batch_size(&self) -> usize {
67 1
68 }
69}
70
71pub struct Embedder {
72 provider: Box<dyn EmbeddingProvider>,
73}
74
75impl Embedder {
76 pub fn new() -> Result<Self> {
77 let config = load_shared_config()
78 .map(|c| c.embedding)
79 .unwrap_or_default();
80 Self::with_config(&config)
81 }
82
83 pub fn with_config(config: &EmbeddingConfig) -> Result<Self> {
84 let provider = provider_from_config(config)?;
85 Ok(Self { provider })
86 }
87
88 pub fn from_provider(provider: Box<dyn EmbeddingProvider>) -> Self {
89 Self { provider }
90 }
91
92 pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
93 self.provider.embed(text)
94 }
95
96 pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
97 self.provider.embed_batch(texts)
98 }
99
100 #[must_use]
101 pub fn dimension(&self) -> usize {
102 self.provider.dimension()
103 }
104
105 #[must_use]
106 pub fn batch_size(&self) -> usize {
107 self.provider.batch_size()
108 }
109
110 #[must_use]
111 pub fn model_name(&self) -> &str {
112 self.provider.model_name()
113 }
114
115 #[must_use]
116 pub fn provider_name(&self) -> &'static str {
117 self.provider.provider_name()
118 }
119}
120
/// Provider backed by the `fastembed` ONNX runtime.
#[cfg(feature = "fastembed")]
struct FastEmbedProvider {
    /// Loaded model; every embedding call takes this lock before running inference.
    model: Mutex<TextEmbedding>,
    /// fastembed model code of the resolved model.
    model_name: String,
    /// Vector length reported by fastembed's model metadata.
    dimension: usize,
    /// Number of texts sent to fastembed per call (always at least one).
    batch_size: usize,
}
132
#[cfg(feature = "fastembed")]
impl FastEmbedProvider {
    /// Resolves `config.model`, validates any pinned dimension, and loads the model.
    ///
    /// # Errors
    ///
    /// * [`EmbeddingError::UnknownModel`] if the name cannot be resolved or has no metadata.
    /// * [`EmbeddingError::DimensionMismatch`] if `config.dimension` disagrees with the model.
    /// * [`EmbeddingError::InitError`] if fastembed fails to load the model.
    fn new(config: &EmbeddingConfig) -> Result<Self> {
        let resolved = fastembed_model_from_string(&config.model)?;

        // Copy what we need out of the metadata so the borrow of `resolved` ends here,
        // before `resolved` is moved into `InitOptions` below.
        let (model_name, dimension) = TextEmbedding::get_model_info(&resolved)
            .map(|info| (info.model_code.clone(), info.dim))
            .map_err(|e| EmbeddingError::UnknownModel(format!("{}: {e}", config.model)))?;

        if let Some(configured) = config.dimension.filter(|&wanted| wanted != dimension) {
            return Err(EmbeddingError::DimensionMismatch {
                model: config.model.clone(),
                expected: dimension,
                configured,
            });
        }

        let loaded = TextEmbedding::try_new(InitOptions::new(resolved))
            .map_err(|e| EmbeddingError::InitError(e.to_string()))?;

        Ok(Self {
            model: Mutex::new(loaded),
            model_name,
            dimension,
            // `slice::chunks` panics on zero, so clamp to at least one.
            batch_size: config.batch_size.max(1),
        })
    }
}
165
#[cfg(feature = "fastembed")]
impl EmbeddingProvider for FastEmbedProvider {
    fn provider_name(&self) -> &'static str {
        "fastembed"
    }

    fn model_name(&self) -> &str {
        &self.model_name
    }

    fn dimension(&self) -> usize {
        self.dimension
    }

    fn batch_size(&self) -> usize {
        self.batch_size
    }

    /// Embeds one text and L2-normalises the result.
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        // The guard lives only inside this block, so the lock is released before normalising.
        let vectors = {
            let guard = self.model.lock().map_err(|_| {
                EmbeddingError::ProviderUnavailable("model lock poisoned".to_string())
            })?;
            guard
                .embed(vec![text], None)
                .map_err(|e| EmbeddingError::EmbedError(e.to_string()))?
        };

        let Some(mut vector) = vectors.into_iter().next() else {
            return Err(EmbeddingError::EmptyResult);
        };
        l2_normalize(&mut vector);
        Ok(vector)
    }

    /// Embeds `texts` in chunks of `batch_size`, taking the model lock once per chunk,
    /// then L2-normalises every vector.
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let mut vectors = Vec::with_capacity(texts.len());
        for batch in texts.chunks(self.batch_size) {
            let guard = self.model.lock().map_err(|_| {
                EmbeddingError::ProviderUnavailable("model lock poisoned".to_string())
            })?;
            let batch_vectors = guard
                .embed(batch.to_vec(), None)
                .map_err(|e| EmbeddingError::EmbedError(e.to_string()))?;
            drop(guard);
            vectors.extend(batch_vectors);
        }

        vectors.iter_mut().for_each(|vector| l2_normalize(vector));
        Ok(vectors)
    }
}
230
/// Resolves a user-supplied model name to a fastembed [`EmbeddingModel`].
///
/// fastembed's own `FromStr` parser is tried first. If it fails, the name is compared
/// against each supported model's identifiers after reducing both sides to lowercase
/// alphanumerics. Both the whole name and its last `/`-separated segment are tried.
/// The first supported model (in fastembed's listing order) that matches wins.
///
/// # Errors
///
/// [`EmbeddingError::UnknownModel`] (carrying the untrimmed input) for blank or
/// unresolvable names.
#[cfg(feature = "fastembed")]
fn fastembed_model_from_string(model_name: &str) -> Result<EmbeddingModel> {
    let query = model_name.trim();
    if query.is_empty() {
        return Err(EmbeddingError::UnknownModel(model_name.to_string()));
    }

    if let Ok(parsed) = query.parse::<EmbeddingModel>() {
        return Ok(parsed);
    }

    let whole = normalize_model_token(query);
    let tail = normalize_model_token(query.rsplit('/').next().unwrap_or(query));

    TextEmbedding::list_supported_models()
        .into_iter()
        .find(|info| {
            model_identifiers(&info.model_code)
                .iter()
                .any(|id| *id == whole || *id == tail)
        })
        .map(|info| info.model)
        .ok_or_else(|| EmbeddingError::UnknownModel(model_name.to_string()))
}
255
/// Produces the normalised spellings under which a fastembed model code may be matched.
///
/// The candidates come in this order:
/// 1. the full code;
/// 2. its last `/` segment;
/// 3. that segment without a trailing `-onnx`, if present;
/// 4. that segment without a trailing `-onnx-q`, if present.
///
/// NOTE(review): `strip_suffix` is case-sensitive, so a code ending in e.g. `-onnx-Q`
/// gets neither alias — confirm whether that is intended.
#[cfg(feature = "fastembed")]
fn model_identifiers(model_code: &str) -> Vec<String> {
    let tail = model_code.rsplit('/').next().unwrap_or(model_code);
    let stripped = ["-onnx", "-onnx-q"]
        .into_iter()
        .filter_map(|suffix| tail.strip_suffix(suffix));

    [model_code, tail]
        .into_iter()
        .chain(stripped)
        .map(normalize_model_token)
        .collect()
}
273
/// Reduces a model identifier to its lowercase ASCII alphanumeric characters.
///
/// This makes spellings such as `BAAI/bge-small-en-v1.5` and `baaibgesmallenv15`
/// compare equal. Every non-ASCII-alphanumeric character (punctuation, whitespace,
/// non-ASCII letters) is dropped.
// Within this file only the fastembed model resolver calls this helper. Without that
// feature it would trigger `dead_code` (and fail `-D warnings`), so silence the lint in
// that configuration instead of gating the function away.
#[cfg_attr(not(feature = "fastembed"), allow(dead_code))]
fn normalize_model_token(value: &str) -> String {
    value
        .chars()
        .filter(char::is_ascii_alphanumeric)
        .map(|c| c.to_ascii_lowercase())
        .collect()
}
281
/// Scales `embedding` in place to unit Euclidean length.
///
/// A vector whose norm is zero (including an empty slice) is left untouched, so no
/// division by zero occurs.
// Within this file only the fastembed provider calls this helper. Without that feature
// it would trigger `dead_code` (and fail `-D warnings`), so silence the lint in that
// configuration instead of gating the function away.
#[cfg_attr(not(feature = "fastembed"), allow(dead_code))]
fn l2_normalize(embedding: &mut [f32]) {
    let norm = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm <= 0.0 {
        return;
    }

    embedding.iter_mut().for_each(|x| *x /= norm);
}
292
/// Provider that runs a BERT-family model through Candle.
#[cfg(feature = "candle")]
struct CandleProvider {
    /// Encoder weights; locked for the duration of each forward pass.
    model: Mutex<BertModel>,
    /// Tokenizer configured for truncation and batch-longest padding; locked per encode call.
    tokenizer: Mutex<Tokenizer>,
    /// Device the tensors are created on.
    device: Device,
    /// Hugging Face repository id the model was loaded from.
    model_name: String,
    /// Hidden size of the model, i.e. the embedding length.
    dimension: usize,
    /// Number of texts encoded per forward pass (always at least one).
    batch_size: usize,
}
306
#[cfg(feature = "candle")]
impl CandleProvider {
    /// Hugging Face repository loaded when `config.model` is blank.
    const DEFAULT_MODEL_ID: &'static str = "sentence-transformers/all-MiniLM-L6-v2";

    /// Loads a BERT-style model from the Hugging Face Hub and wraps it as a provider.
    ///
    /// The model id comes from `config.model` with surrounding whitespace ignored, as the
    /// fastembed resolver also trims. A blank value selects [`Self::DEFAULT_MODEL_ID`]. The
    /// embedding dimension is the `hidden_size` from the model's `config.json`.
    ///
    /// # Errors
    ///
    /// * [`EmbeddingError::InitError`] if any model file cannot be fetched, read, or parsed.
    /// * [`EmbeddingError::DimensionMismatch`] if `config.dimension` is set and disagrees
    ///   with the model's hidden size.
    fn new(config: &EmbeddingConfig) -> Result<Self> {
        let device = Self::select_device();
        // Without trimming, a whitespace-only value would be sent to the hub as a repo id
        // instead of selecting the default model.
        let model_id = match config.model.trim() {
            "" => Self::DEFAULT_MODEL_ID,
            trimmed => trimmed,
        };

        let (model, tokenizer, dimension) = Self::load_model(model_id, &device)?;

        if let Some(configured) = config.dimension.filter(|&wanted| wanted != dimension) {
            return Err(EmbeddingError::DimensionMismatch {
                model: model_id.to_string(),
                expected: dimension,
                configured,
            });
        }

        Ok(Self {
            model: Mutex::new(model),
            tokenizer: Mutex::new(tokenizer),
            device,
            model_name: model_id.to_string(),
            dimension,
            // `slice::chunks` panics on zero, so clamp to at least one.
            batch_size: config.batch_size.max(1),
        })
    }

    /// Picks the compute device for the enabled GPU feature (Metal takes precedence over
    /// CUDA), falling back to the CPU when that device cannot be initialised.
    // In the CPU-only build the body is const-evaluable, but the GPU constructors are not
    // `const`. The lint is silenced so the signature does not vary with the feature set.
    #[allow(clippy::missing_const_for_fn)]
    fn select_device() -> Device {
        #[cfg(feature = "metal")]
        {
            Device::new_metal(0).unwrap_or(Device::Cpu)
        }
        #[cfg(all(feature = "cuda", not(feature = "metal")))]
        {
            Device::new_cuda(0).unwrap_or(Device::Cpu)
        }
        #[cfg(not(any(feature = "metal", feature = "cuda")))]
        {
            Device::Cpu
        }
    }

    /// Fetches `config.json`, `tokenizer.json`, and the weights for `model_id` through the
    /// hf-hub API, then builds the model and tokenizer.
    ///
    /// Weights are taken from `model.safetensors` if available, else `pytorch_model.bin`.
    /// The returned tokenizer truncates to the model's `max_position_embeddings` and pads
    /// each batch to its longest member. The returned `usize` is the model's `hidden_size`.
    fn load_model(model_id: &str, device: &Device) -> Result<(BertModel, Tokenizer, usize)> {
        let api = Api::new().map_err(|e| EmbeddingError::InitError(e.to_string()))?;
        let repo = api.repo(Repo::new(model_id.to_string(), RepoType::Model));

        let config_path = repo
            .get("config.json")
            .map_err(|e| EmbeddingError::InitError(format!("Failed to get config: {e}")))?;
        let tokenizer_path = repo
            .get("tokenizer.json")
            .map_err(|e| EmbeddingError::InitError(format!("Failed to get tokenizer: {e}")))?;
        let weights_path = repo
            .get("model.safetensors")
            .or_else(|_| repo.get("pytorch_model.bin"))
            .map_err(|e| EmbeddingError::InitError(format!("Failed to get weights: {e}")))?;

        let config_str = std::fs::read_to_string(&config_path)
            .map_err(|e| EmbeddingError::InitError(format!("Failed to read config: {e}")))?;
        let bert_config: BertConfig = serde_json::from_str(&config_str)
            .map_err(|e| EmbeddingError::InitError(format!("Failed to parse config: {e}")))?;
        let dimension = bert_config.hidden_size;

        let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| EmbeddingError::InitError(format!("Failed to load tokenizer: {e}")))?;
        tokenizer
            .with_truncation(Some(TruncationParams {
                max_length: bert_config.max_position_embeddings,
                strategy: TruncationStrategy::LongestFirst,
                ..Default::default()
            }))
            .map_err(|e| {
                EmbeddingError::InitError(format!("Failed to configure tokenizer truncation: {e}"))
            })?;
        // Batch-longest padding gives every encoding in a batch the same length, which
        // `embed_batch` relies on when reshaping the flattened ids into a 2-D tensor.
        tokenizer.with_padding(Some(PaddingParams {
            strategy: PaddingStrategy::BatchLongest,
            ..Default::default()
        }));

        #[allow(unsafe_code)]
        let vb = if weights_path
            .extension()
            .is_some_and(|ext| ext == "safetensors")
        {
            // SAFETY: `from_mmaped_safetensors` memory-maps the weights file. That is only
            // sound while nothing modifies or truncates the file during the mapping. The path
            // points into the hf-hub cache, which this code only reads. NOTE(review): a
            // concurrent writer to the same cache entry (e.g. another process re-downloading
            // it) would break this assumption.
            unsafe {
                VarBuilder::from_mmaped_safetensors(&[weights_path], DTYPE, device).map_err(
                    |e| EmbeddingError::InitError(format!("Failed to load weights: {e}")),
                )?
            }
        } else {
            VarBuilder::from_pth(&weights_path, DTYPE, device)
                .map_err(|e| EmbeddingError::InitError(format!("Failed to load weights: {e}")))?
        };

        let model = BertModel::load(vb, &bert_config)
            .map_err(|e| EmbeddingError::InitError(format!("Failed to build model: {e}")))?;

        Ok((model, tokenizer, dimension))
    }

    /// Runs the encoder and mean-pools the token states into L2-normalised sentence vectors.
    ///
    /// Both callers pass `(batch, seq_len)` id and mask tensors. The encoder output is
    /// assumed to be `(batch, seq_len, hidden)` (TODO confirm against candle's `BertModel`).
    /// The returned tensor is `(batch, hidden)`.
    fn embed_tokens(&self, token_ids: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        // Capture-free closures are `Copy`, so this one mapper serves every tensor op below.
        let embed_err = |e: candle_core::Error| EmbeddingError::EmbedError(e.to_string());

        // Single-segment input: every token is assigned segment id 0.
        let token_type_ids = token_ids.zeros_like().map_err(embed_err)?;

        // The model guard is a temporary, so the lock is held only for the forward pass.
        let embeddings = self
            .model
            .lock()
            .map_err(|_| EmbeddingError::ProviderUnavailable("model lock poisoned".to_string()))?
            .forward(token_ids, &token_type_ids, Some(attention_mask))
            .map_err(embed_err)?;

        // Mean pooling: zero out padded positions, then divide by the count of real tokens.
        let mask_expanded = attention_mask
            .unsqueeze(2)
            .and_then(|mask| mask.broadcast_as(embeddings.shape()))
            .and_then(|mask| mask.to_dtype(embeddings.dtype()))
            .map_err(embed_err)?;

        let summed = embeddings
            .mul(&mask_expanded)
            .and_then(|masked| masked.sum(1))
            .map_err(embed_err)?;

        // The clamp keeps a fully masked row from dividing by zero.
        let mask_sum = mask_expanded
            .sum(1)
            .and_then(|count| count.clamp(1e-9, f64::MAX))
            .map_err(embed_err)?;

        let pooled = summed.div(&mask_sum).map_err(embed_err)?;

        // Row-wise L2 norm, clamped away from zero and broadcast back to the pooled shape.
        let norm = pooled
            .sqr()
            .and_then(|t| t.sum(1))
            .and_then(|t| t.sqrt())
            .and_then(|t| t.unsqueeze(1))
            .and_then(|t| t.clamp(1e-9, f64::MAX))
            .and_then(|t| t.broadcast_as(pooled.shape()))
            .map_err(embed_err)?;

        pooled.div(&norm).map_err(embed_err)
    }
}
479
#[cfg(feature = "candle")]
impl EmbeddingProvider for CandleProvider {
    /// Tokenizes one text, runs it as a batch of one, and returns its pooled,
    /// L2-normalised vector.
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        let encoding = self
            .tokenizer
            .lock()
            .map_err(|_| {
                EmbeddingError::ProviderUnavailable("tokenizer lock poisoned".to_string())
            })?
            .encode(text, true)
            .map_err(|e| EmbeddingError::EmbedError(e.to_string()))?;

        let ids: Vec<u32> = encoding.get_ids().to_vec();
        let mask: Vec<u32> = encoding.get_attention_mask().to_vec();

        // Add a leading batch axis so the shapes match what `embed_tokens` expects.
        let token_ids = Tensor::new(&ids[..], &self.device)
            .map_err(|e| EmbeddingError::EmbedError(e.to_string()))?
            .unsqueeze(0)
            .map_err(|e| EmbeddingError::EmbedError(e.to_string()))?;

        let attention_mask = Tensor::new(&mask[..], &self.device)
            .map_err(|e| EmbeddingError::EmbedError(e.to_string()))?
            .unsqueeze(0)
            .map_err(|e| EmbeddingError::EmbedError(e.to_string()))?;

        let embeddings = self.embed_tokens(&token_ids, &attention_mask)?;

        embeddings
            .squeeze(0)
            .map_err(|e| EmbeddingError::EmbedError(e.to_string()))?
            .to_vec1()
            .map_err(|e| EmbeddingError::EmbedError(e.to_string()))
    }

    /// Embeds `texts` in chunks of `batch_size`, one forward pass per chunk, preserving
    /// input order.
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let mut all_embeddings = Vec::with_capacity(texts.len());

        for chunk in texts.chunks(self.batch_size) {
            let encodings = self
                .tokenizer
                .lock()
                .map_err(|_| {
                    EmbeddingError::ProviderUnavailable("tokenizer lock poisoned".to_string())
                })?
                .encode_batch(chunk.to_vec(), true)
                .map_err(|e| EmbeddingError::EmbedError(e.to_string()))?;

            let batch_len = encodings.len();
            // The tokenizer pads to the batch's longest member, so the maximum is also the
            // length of every encoding; that keeps the reshape below exact.
            let seq_len = encodings
                .iter()
                .map(tokenizers::Encoding::len)
                .max()
                .unwrap_or(0);

            let mut ids_flat: Vec<u32> = Vec::with_capacity(batch_len * seq_len);
            let mut mask_flat: Vec<u32> = Vec::with_capacity(batch_len * seq_len);

            for enc in &encodings {
                ids_flat.extend(enc.get_ids());
                mask_flat.extend(enc.get_attention_mask());
            }

            let token_ids = Tensor::new(&ids_flat[..], &self.device)
                .map_err(|e| EmbeddingError::EmbedError(e.to_string()))?
                .reshape((batch_len, seq_len))
                .map_err(|e| EmbeddingError::EmbedError(e.to_string()))?;

            let attention_mask = Tensor::new(&mask_flat[..], &self.device)
                .map_err(|e| EmbeddingError::EmbedError(e.to_string()))?
                .reshape((batch_len, seq_len))
                .map_err(|e| EmbeddingError::EmbedError(e.to_string()))?;

            let embeddings = self.embed_tokens(&token_ids, &attention_mask)?;

            let batch_embeddings: Vec<Vec<f32>> = embeddings
                .to_vec2()
                .map_err(|e| EmbeddingError::EmbedError(e.to_string()))?;

            all_embeddings.extend(batch_embeddings);
        }

        Ok(all_embeddings)
    }

    fn dimension(&self) -> usize {
        self.dimension
    }

    fn model_name(&self) -> &str {
        &self.model_name
    }

    /// Reports the backend that is actually in use.
    ///
    /// `CandleProvider::select_device` falls back to the CPU when the requested GPU cannot
    /// be initialised. The name therefore comes from the stored device rather than from
    /// the enabled cargo features. Otherwise a failed Metal/CUDA init would still be
    /// reported as a GPU backend.
    fn provider_name(&self) -> &'static str {
        match &self.device {
            Device::Cpu => "candle-cpu",
            Device::Cuda(_) => "candle-cuda",
            Device::Metal(_) => "candle-metal",
        }
    }

    fn batch_size(&self) -> usize {
        self.batch_size
    }
}
595
596fn provider_from_config(config: &EmbeddingConfig) -> Result<Box<dyn EmbeddingProvider>> {
601 let provider = config.provider.trim().to_lowercase();
602 match provider.as_str() {
603 #[cfg(feature = "fastembed")]
604 "fastembed" | "fastembed-rs" => Ok(Box::new(FastEmbedProvider::new(config)?)),
605
606 #[cfg(not(feature = "fastembed"))]
607 "fastembed" | "fastembed-rs" => Err(EmbeddingError::ProviderNotCompiled {
608 provider: "fastembed".to_string(),
609 feature: "fastembed".to_string(),
610 }),
611
612 #[cfg(feature = "candle")]
613 "candle" | "candle-rs" => Ok(Box::new(CandleProvider::new(config)?)),
614
615 #[cfg(not(feature = "candle"))]
616 "candle" | "candle-rs" => Err(EmbeddingError::ProviderNotCompiled {
617 provider: "candle".to_string(),
618 feature: "candle".to_string(),
619 }),
620
621 _ => Err(EmbeddingError::UnknownProvider(config.provider.clone())),
622 }
623}
624
#[cfg(test)]
mod tests {
    use super::*;

    /// Candle config for `model` with every other field at its default.
    #[cfg(feature = "candle")]
    fn candle_config(model: &str) -> EmbeddingConfig {
        EmbeddingConfig {
            provider: "candle".to_string(),
            model: model.to_string(),
            ..Default::default()
        }
    }

    #[test]
    #[cfg(feature = "fastembed")]
    fn test_fastembed_model_from_string() {
        // Full HF id, bare repo name, hyphenated name, and a CamelCase spelling.
        for accepted in [
            "BAAI/bge-small-en-v1.5",
            "bge-small-en-v1.5",
            "all-MiniLM-L6-v2",
            "AllMiniLML6V2",
        ] {
            assert!(fastembed_model_from_string(accepted).is_ok());
        }
        assert!(fastembed_model_from_string("unknown-model").is_err());
    }

    #[test]
    #[ignore = "Requires downloading model (~30MB)"]
    fn test_embed_text() {
        let embedder = Embedder::new().unwrap();
        let vector = embedder.embed("Hello, world!").unwrap();
        assert_eq!(vector.len(), embedder.dimension());
    }

    #[test]
    #[ignore = "Requires downloading model (~30MB)"]
    fn test_embed_batch() {
        let embedder = Embedder::new().unwrap();
        let vectors = embedder
            .embed_batch(&["First text", "Second text", "Third text"])
            .unwrap();
        assert_eq!(vectors.len(), 3);
        assert!(vectors.iter().all(|v| v.len() == embedder.dimension()));
    }

    #[test]
    fn test_embed_batch_empty() {
        // Building the default provider may fail (feature disabled, model unavailable);
        // only the empty-input fast path is checked when construction succeeds.
        let Ok(embedder) = Embedder::with_config(&EmbeddingConfig::default()) else {
            return;
        };
        assert!(embedder.embed_batch(&[]).unwrap().is_empty());
    }

    #[test]
    #[cfg(feature = "candle")]
    #[ignore = "Requires downloading model (~90MB)"]
    fn test_candle_embed_text() {
        let embedder =
            Embedder::with_config(&candle_config("sentence-transformers/all-MiniLM-L6-v2"))
                .unwrap();
        let vector = embedder.embed("Hello, world!").unwrap();
        assert_eq!(vector.len(), 384);
        assert!(embedder.provider_name().starts_with("candle"));
    }

    #[test]
    #[cfg(feature = "candle")]
    #[ignore = "Requires downloading model (~90MB)"]
    fn test_candle_embed_batch() {
        let config = EmbeddingConfig {
            batch_size: 2,
            ..candle_config("sentence-transformers/all-MiniLM-L6-v2")
        };
        let embedder = Embedder::with_config(&config).unwrap();
        let vectors = embedder
            .embed_batch(&["First text", "Second text", "Third text"])
            .unwrap();
        assert_eq!(vectors.len(), 3);
        assert!(vectors.iter().all(|v| v.len() == 384));
    }

    #[test]
    #[cfg(feature = "candle")]
    #[ignore = "Requires downloading model (~1.3GB)"]
    fn test_candle_bge_large() {
        let config = EmbeddingConfig {
            batch_size: 8,
            ..candle_config("BAAI/bge-large-en-v1.5")
        };
        let embedder = Embedder::with_config(&config).unwrap();

        assert_eq!(embedder.dimension(), 1024);
        assert_eq!(embedder.model_name(), "BAAI/bge-large-en-v1.5");

        let vector = embedder.embed("Hello, world!").unwrap();
        assert_eq!(vector.len(), 1024);

        let vectors = embedder
            .embed_batch(&["First text", "Second text"])
            .unwrap();
        assert_eq!(vectors.len(), 2);
        assert!(vectors.iter().all(|v| v.len() == 1024));
    }
}