1use std::path::PathBuf;
2
3use candle_core::{DType, Device, Tensor};
4use candle_nn::VarBuilder;
5use candle_transformers::models::bert::{BertModel, Config, DTYPE};
6use hf_hub::{Repo, RepoType, api::sync::ApiBuilder};
7use tokenizers::{PaddingParams, Tokenizer, TruncationParams};
8use tracing::info;
9
10use crate::error::{Result, SedimentError};
11
/// Embedding dimension of the default model (all-MiniLM-L6-v2). The other
/// 384-dim models share this size; bge-base-en-v1.5 produces 768-dim vectors
/// (see `EmbeddingModel::embedding_dim`).
pub const EMBEDDING_DIM: usize = 384;
15
/// The pretrained sentence-embedding models this crate can load.
///
/// Each variant carries its Hugging Face repo id, a pinned revision, and
/// pinned SHA-256 digests for the downloaded artifacts, plus the query /
/// document prefixing scheme the model was trained with.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum EmbeddingModel {
    /// sentence-transformers/all-MiniLM-L6-v2 — 384 dims, no prefixes. Default.
    #[default]
    AllMiniLmL6V2,
    /// intfloat/e5-small-v2 — 384 dims, "query:" / "passage:" prefixes.
    E5SmallV2,
    /// BAAI/bge-small-en-v1.5 — 384 dims, query-side prefix only.
    BgeSmallEnV15,
    /// BAAI/bge-base-en-v1.5 — 768 dims, query-side prefix only.
    BgeBaseEnV15,
}

impl EmbeddingModel {
    /// Output vector length produced by this model.
    pub fn embedding_dim(&self) -> usize {
        match self {
            Self::AllMiniLmL6V2 | Self::E5SmallV2 | Self::BgeSmallEnV15 => 384,
            Self::BgeBaseEnV15 => 768,
        }
    }

    /// Hugging Face repository id for this model.
    pub fn model_id(&self) -> &'static str {
        match self {
            Self::AllMiniLmL6V2 => "sentence-transformers/all-MiniLM-L6-v2",
            Self::E5SmallV2 => "intfloat/e5-small-v2",
            Self::BgeSmallEnV15 => "BAAI/bge-small-en-v1.5",
            Self::BgeBaseEnV15 => "BAAI/bge-base-en-v1.5",
        }
    }

    /// Pinned git revision (commit SHA) to download, so upstream repo changes
    /// cannot silently alter the weights we load.
    pub fn revision(&self) -> &'static str {
        match self {
            Self::AllMiniLmL6V2 => "e4ce9877abf3edfe10b0d82785e83bdcb973e22e",
            Self::E5SmallV2 => "ffb93f3bd4047442299a41ebb6fa998a38507c52",
            Self::BgeSmallEnV15 => "5c38ec7c405ec4b44b94cc5a9bb96e735b38267a",
            Self::BgeBaseEnV15 => "a5beb1e3e68b9ab74eb54cfd186867f64f240e1a",
        }
    }

    /// Expected SHA-256 (lowercase hex) of `model.safetensors`.
    pub fn model_sha256(&self) -> &'static str {
        match self {
            Self::AllMiniLmL6V2 => {
                "53aa51172d142c89d9012cce15ae4d6cc0ca6895895114379cacb4fab128d9db"
            }
            Self::E5SmallV2 => "45bfa60070649aae2244fbc9d508537779b93b6f353c17b0f95ceccb1c5116c1",
            Self::BgeSmallEnV15 => {
                "3c9f31665447c8911517620762200d2245a2518d6e7208acc78cd9db317e21ad"
            }
            Self::BgeBaseEnV15 => {
                "c7c1988aae201f80cf91a5dbbd5866409503b89dcaba877ca6dba7dd0a5167d7"
            }
        }
    }

    /// Expected SHA-256 (lowercase hex) of `tokenizer.json`.
    ///
    /// The E5 and both BGE models ship the same tokenizer file, hence the
    /// identical digests.
    pub fn tokenizer_sha256(&self) -> &'static str {
        match self {
            Self::AllMiniLmL6V2 => {
                "be50c3628f2bf5bb5e3a7f17b1f74611b2561a3a27eeab05e5aa30f411572037"
            }
            Self::E5SmallV2 => "d241a60d5e8f04cc1b2b3e9ef7a4921b27bf526d9f6050ab90f9267a1f9e5c66",
            Self::BgeSmallEnV15 => {
                "d241a60d5e8f04cc1b2b3e9ef7a4921b27bf526d9f6050ab90f9267a1f9e5c66"
            }
            Self::BgeBaseEnV15 => {
                "d241a60d5e8f04cc1b2b3e9ef7a4921b27bf526d9f6050ab90f9267a1f9e5c66"
            }
        }
    }

    /// Expected SHA-256 (lowercase hex) of `config.json`.
    pub fn config_sha256(&self) -> &'static str {
        match self {
            Self::AllMiniLmL6V2 => {
                "953f9c0d463486b10a6871cc2fd59f223b2c70184f49815e7efbcab5d8908b41"
            }
            Self::E5SmallV2 => "5dfb0363cd0243be179c03bcaafd1542d0fbb95e8cbcf575fff3e229342adc2f",
            Self::BgeSmallEnV15 => {
                "094f8e891b932f2000c92cfc663bac4c62069f5d8af5b5278c4306aef3084750"
            }
            Self::BgeBaseEnV15 => {
                "bc00af31a4a31b74040d73370aa83b62da34c90b75eb77bfa7db039d90abd591"
            }
        }
    }

    /// Applies this model's search-query prefix to `text`.
    ///
    /// Borrows when the model needs no prefix, allocates only when one is
    /// required (hence the `Cow` return).
    pub fn prefix_query<'a>(&self, text: &'a str) -> std::borrow::Cow<'a, str> {
        match self {
            Self::AllMiniLmL6V2 => std::borrow::Cow::Borrowed(text),
            Self::E5SmallV2 => std::borrow::Cow::Owned(format!("query: {text}")),
            Self::BgeSmallEnV15 | Self::BgeBaseEnV15 => std::borrow::Cow::Owned(format!(
                "Represent this sentence for searching relevant passages: {text}"
            )),
        }
    }

    /// Applies this model's document/passage prefix to `text`.
    ///
    /// Only E5 prefixes documents; the BGE models prefix queries only.
    pub fn prefix_document<'a>(&self, text: &'a str) -> std::borrow::Cow<'a, str> {
        match self {
            Self::AllMiniLmL6V2 => std::borrow::Cow::Borrowed(text),
            Self::E5SmallV2 => std::borrow::Cow::Owned(format!("passage: {text}")),
            Self::BgeSmallEnV15 | Self::BgeBaseEnV15 => std::borrow::Cow::Borrowed(text),
        }
    }

    /// Parses a model name as given in `SEDIMENT_EMBEDDING_MODEL`.
    ///
    /// Matching is case-insensitive and ignores surrounding whitespace
    /// (previously only an exact match — plus one hand-written lowercase
    /// alias — was accepted); all previously accepted spellings still work.
    /// Returns `None` for unrecognized names.
    pub fn from_env_str(s: &str) -> Option<Self> {
        match s.trim().to_ascii_lowercase().as_str() {
            "all-minilm-l6-v2" => Some(Self::AllMiniLmL6V2),
            "e5-small-v2" => Some(Self::E5SmallV2),
            "bge-small-en-v1.5" => Some(Self::BgeSmallEnV15),
            "bge-base-en-v1.5" => Some(Self::BgeBaseEnV15),
            _ => None,
        }
    }
}
142
/// Text embedder wrapping a BERT-family model run on CPU via candle.
pub struct Embedder {
    // Loaded BERT weights/graph used for the forward pass.
    model: BertModel,
    // Tokenizer configured (in `with_embedding_model`) with batch-longest
    // padding and 512-token truncation.
    tokenizer: Tokenizer,
    // Device tensors are placed on; always `Device::Cpu` as constructed.
    device: Device,
    // Which pretrained model this wraps (determines dims and text prefixes).
    embedding_model: EmbeddingModel,
}
156
impl Embedder {
    /// Builds an embedder for the model named by `SEDIMENT_EMBEDDING_MODEL`,
    /// falling back to the default model when the variable is unset or holds
    /// an unrecognized value.
    pub fn new() -> Result<Self> {
        let embedding_model = std::env::var("SEDIMENT_EMBEDDING_MODEL")
            .ok()
            .and_then(|s| EmbeddingModel::from_env_str(&s))
            .unwrap_or_default();
        Self::with_embedding_model(embedding_model)
    }

    /// Downloads (or loads from cache) the given model's artifacts, verifies
    /// their SHA-256 digests against the pinned values, and builds the model
    /// and tokenizer on CPU.
    ///
    /// # Errors
    /// Returns `SedimentError::ModelLoading` on download/read/parse/integrity
    /// failures and `SedimentError::Tokenizer` on tokenizer setup failures.
    pub fn with_embedding_model(embedding_model: EmbeddingModel) -> Result<Self> {
        let model_id = embedding_model.model_id();
        info!("Loading embedding model: {}", model_id);

        let device = Device::Cpu;
        let (model_path, tokenizer_path, config_path) =
            download_model(model_id, embedding_model.revision())?;

        let config_str = std::fs::read_to_string(&config_path)
            .map_err(|e| SedimentError::ModelLoading(format!("Failed to read config: {}", e)))?;
        let config: Config = serde_json::from_str(&config_str)
            .map_err(|e| SedimentError::ModelLoading(format!("Failed to parse config: {}", e)))?;

        let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| SedimentError::Tokenizer(format!("Failed to load tokenizer: {}", e)))?;

        // Pad each batch to its longest sequence so every encoding in a batch
        // has the same length (embed_batch relies on this), and cap inputs at
        // 512 tokens (the BERT-family context limit).
        let padding = PaddingParams {
            strategy: tokenizers::PaddingStrategy::BatchLongest,
            ..Default::default()
        };
        let truncation = TruncationParams {
            max_length: 512,
            ..Default::default()
        };
        tokenizer.with_padding(Some(padding));
        tokenizer
            .with_truncation(Some(truncation))
            .map_err(|e| SedimentError::Tokenizer(format!("Failed to set truncation: {}", e)))?;

        // Integrity checks: empty pinned digests mean "skip verification".
        let tokenizer_hash = embedding_model.tokenizer_sha256();
        if !tokenizer_hash.is_empty() {
            verify_file_hash(&tokenizer_path, tokenizer_hash, "tokenizer.json")?;
        }
        let config_hash = embedding_model.config_sha256();
        if !config_hash.is_empty() {
            verify_file_hash(&config_path, config_hash, "config.json")?;
        }
        if !tokenizer_hash.is_empty() || !config_hash.is_empty() {
            info!("Tokenizer and config integrity verified (SHA-256)");
        }

        // Read the weights once and verify the same bytes we hand to candle,
        // avoiding a read-verify / read-load race on the cached file.
        let model_bytes = std::fs::read(&model_path).map_err(|e| {
            SedimentError::ModelLoading(format!("Failed to read model weights: {}", e))
        })?;
        let model_hash = embedding_model.model_sha256();
        if !model_hash.is_empty() {
            verify_bytes_hash(&model_bytes, model_hash, "model.safetensors")?;
        }
        let vb = VarBuilder::from_buffered_safetensors(model_bytes, DTYPE, &device)
            .map_err(|e| SedimentError::ModelLoading(format!("Failed to load weights: {}", e)))?;

        let model = BertModel::load(vb, &config)
            .map_err(|e| SedimentError::ModelLoading(format!("Failed to load model: {}", e)))?;

        info!("Embedding model loaded successfully");

        Ok(Self {
            model,
            tokenizer,
            device,
            embedding_model,
        })
    }

    /// Embeds a single text (no query/document prefixing) and returns its
    /// L2-normalized vector.
    pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
        let embeddings = self.embed_batch(&[text])?;
        embeddings.into_iter().next().ok_or_else(|| {
            SedimentError::Embedding("embed_batch returned empty result for non-empty input".into())
        })
    }

    /// Embeds a batch of texts: tokenize, BERT forward pass, attention-masked
    /// mean pooling over the sequence dimension, then L2 normalization.
    ///
    /// Returns one vector per input text (empty input yields an empty Vec).
    pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let encodings = self
            .tokenizer
            .encode_batch(texts.to_vec(), true)
            .map_err(|e| SedimentError::Tokenizer(format!("Tokenization failed: {}", e)))?;

        let token_ids: Vec<Vec<u32>> = encodings.iter().map(|e| e.get_ids().to_vec()).collect();

        let attention_masks: Vec<Vec<u32>> = encodings
            .iter()
            .map(|e| e.get_attention_mask().to_vec())
            .collect();

        let token_type_ids: Vec<Vec<u32>> = encodings
            .iter()
            .map(|e| e.get_type_ids().to_vec())
            .collect();

        let batch_size = texts.len();
        // Safe to index [0]: texts is non-empty, and the BatchLongest padding
        // set up in `with_embedding_model` makes every row the same length.
        let seq_len = token_ids[0].len();

        // Flatten the per-row vectors into row-major buffers for tensor creation.
        let token_ids_flat: Vec<u32> = token_ids.into_iter().flatten().collect();
        let attention_mask_flat: Vec<u32> = attention_masks.into_iter().flatten().collect();
        let token_type_ids_flat: Vec<u32> = token_type_ids.into_iter().flatten().collect();

        let token_ids_tensor =
            Tensor::from_vec(token_ids_flat, (batch_size, seq_len), &self.device).map_err(|e| {
                SedimentError::Embedding(format!("Failed to create token tensor: {}", e))
            })?;

        let attention_mask_tensor =
            Tensor::from_vec(attention_mask_flat, (batch_size, seq_len), &self.device).map_err(
                |e| SedimentError::Embedding(format!("Failed to create mask tensor: {}", e)),
            )?;

        let token_type_ids_tensor =
            Tensor::from_vec(token_type_ids_flat, (batch_size, seq_len), &self.device).map_err(
                |e| SedimentError::Embedding(format!("Failed to create type tensor: {}", e)),
            )?;

        // (batch, seq, hidden) token-level embeddings.
        let embeddings = self
            .model
            .forward(
                &token_ids_tensor,
                &token_type_ids_tensor,
                Some(&attention_mask_tensor),
            )
            .map_err(|e| SedimentError::Embedding(format!("Model forward failed: {}", e)))?;

        // Mean pooling: zero out padding positions via the mask, sum over the
        // sequence axis, then divide by the number of real (unmasked) tokens.
        let attention_mask_f32 = attention_mask_tensor
            .to_dtype(DType::F32)
            .map_err(|e| SedimentError::Embedding(format!("Mask conversion failed: {}", e)))?
            .unsqueeze(2)
            .map_err(|e| SedimentError::Embedding(format!("Unsqueeze failed: {}", e)))?;

        let masked_embeddings = embeddings
            .broadcast_mul(&attention_mask_f32)
            .map_err(|e| SedimentError::Embedding(format!("Broadcast mul failed: {}", e)))?;

        let sum_embeddings = masked_embeddings
            .sum(1)
            .map_err(|e| SedimentError::Embedding(format!("Sum failed: {}", e)))?;

        let sum_mask = attention_mask_f32
            .sum(1)
            .map_err(|e| SedimentError::Embedding(format!("Mask sum failed: {}", e)))?;

        let mean_embeddings = sum_embeddings
            .broadcast_div(&sum_mask)
            .map_err(|e| SedimentError::Embedding(format!("Division failed: {}", e)))?;

        let final_embeddings = normalize_l2(&mean_embeddings)?;

        let embeddings_vec: Vec<Vec<f32>> = final_embeddings
            .to_vec2()
            .map_err(|e| SedimentError::Embedding(format!("Tensor to vec failed: {}", e)))?;

        Ok(embeddings_vec)
    }

    /// Embeds a search query, applying the model's query prefix first.
    pub fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
        let prefixed = self.embedding_model.prefix_query(text);
        self.embed(&prefixed)
    }

    /// Embeds a document/passage, applying the model's document prefix first.
    pub fn embed_document(&self, text: &str) -> Result<Vec<f32>> {
        let prefixed = self.embedding_model.prefix_document(text);
        self.embed(&prefixed)
    }

    /// Embeds a batch of documents, prefixing each one per the model's scheme.
    pub fn embed_document_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        let prefixed: Vec<String> = texts
            .iter()
            .map(|t| self.embedding_model.prefix_document(t).into_owned())
            .collect();
        let refs: Vec<&str> = prefixed.iter().map(|s| s.as_str()).collect();
        self.embed_batch(&refs)
    }

    /// Length of the vectors this embedder produces.
    pub fn dimension(&self) -> usize {
        self.embedding_model.embedding_dim()
    }

    /// Which pretrained model this embedder wraps.
    pub fn embedding_model(&self) -> EmbeddingModel {
        self.embedding_model
    }
}
375
376fn download_model(model_id: &str, revision: &str) -> Result<(PathBuf, PathBuf, PathBuf)> {
378 let api = ApiBuilder::from_env()
379 .with_progress(true)
380 .build()
381 .map_err(|e| SedimentError::ModelLoading(format!("Failed to create HF API: {}", e)))?;
382
383 let repo = api.repo(Repo::with_revision(
384 model_id.to_string(),
385 RepoType::Model,
386 revision.to_string(),
387 ));
388
389 let model_path = repo
390 .get("model.safetensors")
391 .map_err(|e| SedimentError::ModelLoading(format!("Failed to download model: {}", e)))?;
392
393 let tokenizer_path = repo
394 .get("tokenizer.json")
395 .map_err(|e| SedimentError::ModelLoading(format!("Failed to download tokenizer: {}", e)))?;
396
397 let config_path = repo
398 .get("config.json")
399 .map_err(|e| SedimentError::ModelLoading(format!("Failed to download config: {}", e)))?;
400
401 Ok((model_path, tokenizer_path, config_path))
402}
403
404fn verify_file_hash(path: &std::path::Path, expected: &str, file_label: &str) -> Result<()> {
406 use sha2::{Digest, Sha256};
407
408 let file_bytes = std::fs::read(path).map_err(|e| {
409 SedimentError::ModelLoading(format!(
410 "Failed to read {} for hash verification: {}",
411 file_label, e
412 ))
413 })?;
414
415 let hash = Sha256::digest(&file_bytes);
416 let hex_hash = format!("{:x}", hash);
417
418 if hex_hash != expected {
419 return Err(SedimentError::ModelLoading(format!(
420 "{} integrity check failed: expected SHA-256 {}, got {}",
421 file_label, expected, hex_hash
422 )));
423 }
424
425 Ok(())
426}
427
428fn verify_bytes_hash(data: &[u8], expected: &str, file_label: &str) -> Result<()> {
433 use sha2::{Digest, Sha256};
434
435 let hash = Sha256::digest(data);
436 let hex_hash = format!("{:x}", hash);
437
438 if hex_hash != expected {
439 return Err(SedimentError::ModelLoading(format!(
440 "{} integrity check failed: expected SHA-256 {}, got {}",
441 file_label, expected, hex_hash
442 )));
443 }
444
445 Ok(())
446}
447
448fn normalize_l2(tensor: &Tensor) -> Result<Tensor> {
450 let norm = tensor
451 .sqr()
452 .map_err(|e| SedimentError::Embedding(format!("Sqr failed: {}", e)))?
453 .sum_keepdim(1)
454 .map_err(|e| SedimentError::Embedding(format!("Sum keepdim failed: {}", e)))?
455 .sqrt()
456 .map_err(|e| SedimentError::Embedding(format!("Sqrt failed: {}", e)))?;
457
458 tensor
459 .broadcast_div(&norm)
460 .map_err(|e| SedimentError::Embedding(format!("Normalize div failed: {}", e)))
461}
462
#[cfg(test)]
mod tests {
    use super::*;

    // The `#[ignore]`d tests below construct a real Embedder, which downloads
    // model weights via `download_model` on first run; run them explicitly
    // with `cargo test -- --ignored`.

    // Default model: embedding has the expected dimension and unit L2 norm.
    #[test]
    #[ignore] fn test_embedder() -> Result<()> {
        let embedder = Embedder::new()?;

        let text = "Hello, world!";
        let embedding = embedder.embed(text)?;

        assert_eq!(embedding.len(), EMBEDDING_DIM);

        // normalize_l2 should leave each vector with norm ~1.
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01);

        Ok(())
    }

    // Batch embedding returns one correctly-sized vector per input.
    #[test]
    #[ignore] fn test_batch_embedding() -> Result<()> {
        let embedder = Embedder::new()?;

        let texts = vec!["Hello", "World", "Test sentence"];
        let embeddings = embedder.embed_batch(&texts)?;

        assert_eq!(embeddings.len(), 3);
        for emb in &embeddings {
            assert_eq!(emb.len(), EMBEDDING_DIM);
        }

        Ok(())
    }

    // Default model uses no prefixes, so embed_query == embed for equal text.
    #[test]
    #[ignore] fn test_embed_query_and_document() -> Result<()> {
        let embedder = Embedder::new()?;

        let query_emb = embedder.embed_query("What database do we use?")?;
        let doc_emb = embedder.embed_document("We use Postgres for the main database")?;

        assert_eq!(query_emb.len(), EMBEDDING_DIM);
        assert_eq!(doc_emb.len(), EMBEDDING_DIM);

        // all-MiniLM applies no query prefix, so these must match exactly.
        let raw_emb = embedder.embed("What database do we use?")?;
        assert_eq!(query_emb, raw_emb);

        Ok(())
    }

    // E5 prefixes both queries ("query: ") and documents ("passage: "), so
    // the two embeddings of the same text must differ.
    #[test]
    #[ignore] fn test_e5_small_v2_embedder() -> Result<()> {
        let embedder = Embedder::with_embedding_model(EmbeddingModel::E5SmallV2)?;

        let emb = embedder.embed("test")?;
        assert_eq!(emb.len(), EMBEDDING_DIM);

        let norm: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01);

        let query_emb = embedder.embed_query("What is the capital of France?")?;
        let doc_emb = embedder.embed_document("What is the capital of France?")?;
        assert_ne!(
            query_emb, doc_emb,
            "E5 query and document embeddings should differ due to prefixes"
        );

        Ok(())
    }

    // BGE-small prefixes queries only; documents embed exactly like raw text.
    #[test]
    #[ignore] fn test_bge_small_en_v15_embedder() -> Result<()> {
        let embedder = Embedder::with_embedding_model(EmbeddingModel::BgeSmallEnV15)?;

        let emb = embedder.embed("test")?;
        assert_eq!(emb.len(), EMBEDDING_DIM);

        let norm: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01);

        let query_emb = embedder.embed_query("What is the capital of France?")?;
        let doc_emb = embedder.embed_document("What is the capital of France?")?;
        let raw_emb = embedder.embed("What is the capital of France?")?;

        assert_ne!(
            query_emb, doc_emb,
            "BGE query and document embeddings should differ due to query prefix"
        );
        assert_eq!(
            doc_emb, raw_emb,
            "BGE document embedding should equal raw embedding (no prefix)"
        );

        Ok(())
    }

    // BGE-base: same prefix scheme as BGE-small, but 768-dim output.
    #[test]
    #[ignore] fn test_bge_base_en_v15_embedder() -> Result<()> {
        let embedder = Embedder::with_embedding_model(EmbeddingModel::BgeBaseEnV15)?;

        let emb = embedder.embed("test")?;
        assert_eq!(emb.len(), 768);
        assert_eq!(embedder.dimension(), 768);

        let norm: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01);

        let query_emb = embedder.embed_query("What is the capital of France?")?;
        let doc_emb = embedder.embed_document("What is the capital of France?")?;
        let raw_emb = embedder.embed("What is the capital of France?")?;

        assert_eq!(query_emb.len(), 768);
        assert_eq!(doc_emb.len(), 768);

        assert_ne!(
            query_emb, doc_emb,
            "BGE-base query and document embeddings should differ due to query prefix"
        );
        assert_eq!(
            doc_emb, raw_emb,
            "BGE-base document embedding should equal raw embedding (no prefix)"
        );

        Ok(())
    }

    // Env-var parsing: recognized names map to variants, others to None.
    #[test]
    fn test_embedding_model_from_env_str() {
        assert_eq!(
            EmbeddingModel::from_env_str("e5-small-v2"),
            Some(EmbeddingModel::E5SmallV2)
        );
        assert_eq!(
            EmbeddingModel::from_env_str("bge-small-en-v1.5"),
            Some(EmbeddingModel::BgeSmallEnV15)
        );
        assert_eq!(
            EmbeddingModel::from_env_str("bge-base-en-v1.5"),
            Some(EmbeddingModel::BgeBaseEnV15)
        );
        assert_eq!(
            EmbeddingModel::from_env_str("all-MiniLM-L6-v2"),
            Some(EmbeddingModel::AllMiniLmL6V2)
        );
        assert_eq!(EmbeddingModel::from_env_str("unknown-model"), None);
    }

    // Per-model output dimensions match the documented model sizes.
    #[test]
    fn test_embedding_model_dimensions() {
        assert_eq!(EmbeddingModel::AllMiniLmL6V2.embedding_dim(), 384);
        assert_eq!(EmbeddingModel::E5SmallV2.embedding_dim(), 384);
        assert_eq!(EmbeddingModel::BgeSmallEnV15.embedding_dim(), 384);
        assert_eq!(EmbeddingModel::BgeBaseEnV15.embedding_dim(), 768);
    }

    // Exact prefix strings for each model's query/document schemes.
    #[test]
    fn test_embedding_model_prefixes() {
        let text = "hello world";

        let m = EmbeddingModel::AllMiniLmL6V2;
        assert_eq!(m.prefix_query(text).as_ref(), "hello world");
        assert_eq!(m.prefix_document(text).as_ref(), "hello world");

        let m = EmbeddingModel::E5SmallV2;
        assert_eq!(m.prefix_query(text).as_ref(), "query: hello world");
        assert_eq!(m.prefix_document(text).as_ref(), "passage: hello world");

        let m = EmbeddingModel::BgeSmallEnV15;
        assert_eq!(
            m.prefix_query(text).as_ref(),
            "Represent this sentence for searching relevant passages: hello world"
        );
        assert_eq!(m.prefix_document(text).as_ref(), "hello world");

        let m = EmbeddingModel::BgeBaseEnV15;
        assert_eq!(
            m.prefix_query(text).as_ref(),
            "Represent this sentence for searching relevant passages: hello world"
        );
        assert_eq!(m.prefix_document(text).as_ref(), "hello world");
    }

    // Known SHA-256 of b"hello world" verifies successfully.
    #[test]
    fn test_verify_bytes_hash_correct() {
        let data = b"hello world";
        let expected = "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9";
        assert!(verify_bytes_hash(data, expected, "test").is_ok());
    }

    // A wrong digest yields an error mentioning the integrity failure.
    #[test]
    fn test_verify_bytes_hash_incorrect() {
        let data = b"hello world";
        let wrong = "0000000000000000000000000000000000000000000000000000000000000000";
        let err = verify_bytes_hash(data, wrong, "test").unwrap_err();
        assert!(err.to_string().contains("integrity check failed"));
    }

    // Empty input hashes to the well-known SHA-256 of the empty string.
    #[test]
    fn test_verify_bytes_hash_empty() {
        let data = b"";
        let expected = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";
        assert!(verify_bytes_hash(data, expected, "empty").is_ok());
    }
}