//! sediment/embedder.rs — text-to-vector embedding built on candle BERT models.
1use std::path::PathBuf;
2
3use candle_core::{DType, Device, Tensor};
4use candle_nn::VarBuilder;
5use candle_transformers::models::bert::{BertModel, Config, DTYPE};
6use hf_hub::{Repo, RepoType, api::sync::ApiBuilder};
7use tokenizers::{PaddingParams, Tokenizer, TruncationParams};
8use tracing::info;
9
10use crate::error::{Result, SedimentError};
11
/// Default embedding dimension (384-dim for small models).
/// Kept as a pub const for backward compatibility; prefer `Embedder::dimension()`,
/// which reports the active model's true width (e.g. 768 for bge-base-en-v1.5,
/// for which this constant is WRONG).
pub const EMBEDDING_DIM: usize = 384;
15
/// Supported embedding models.
///
/// Each variant carries model metadata: HF repo ID, pinned revision,
/// SHA-256 hashes for integrity verification, and prefix functions
/// for asymmetric query/document embedding.
///
/// Selected at startup via the `SEDIMENT_EMBEDDING_MODEL` env var
/// (see `EmbeddingModel::from_env_str` for the accepted spellings).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum EmbeddingModel {
    /// sentence-transformers/all-MiniLM-L6-v2 (default, no prefixes, 384-dim)
    #[default]
    AllMiniLmL6V2,
    /// intfloat/e5-small-v2 (query: "query: {text}", document: "passage: {text}", 384-dim)
    E5SmallV2,
    /// BAAI/bge-small-en-v1.5 (query prefix, no doc prefix, 384-dim)
    BgeSmallEnV15,
    /// BAAI/bge-base-en-v1.5 (query prefix, no doc prefix, 768-dim)
    BgeBaseEnV15,
}
33
34impl EmbeddingModel {
35    /// Embedding dimension for this model
36    pub fn embedding_dim(&self) -> usize {
37        match self {
38            Self::AllMiniLmL6V2 | Self::E5SmallV2 | Self::BgeSmallEnV15 => 384,
39            Self::BgeBaseEnV15 => 768,
40        }
41    }
42
43    /// Hugging Face model repository ID
44    pub fn model_id(&self) -> &'static str {
45        match self {
46            Self::AllMiniLmL6V2 => "sentence-transformers/all-MiniLM-L6-v2",
47            Self::E5SmallV2 => "intfloat/e5-small-v2",
48            Self::BgeSmallEnV15 => "BAAI/bge-small-en-v1.5",
49            Self::BgeBaseEnV15 => "BAAI/bge-base-en-v1.5",
50        }
51    }
52
53    /// Pinned git revision for reproducible downloads
54    pub fn revision(&self) -> &'static str {
55        match self {
56            Self::AllMiniLmL6V2 => "e4ce9877abf3edfe10b0d82785e83bdcb973e22e",
57            Self::E5SmallV2 => "ffb93f3bd4047442299a41ebb6fa998a38507c52",
58            Self::BgeSmallEnV15 => "5c38ec7c405ec4b44b94cc5a9bb96e735b38267a",
59            Self::BgeBaseEnV15 => "a5beb1e3e68b9ab74eb54cfd186867f64f240e1a",
60        }
61    }
62
63    /// Expected SHA-256 hash of model.safetensors
64    pub fn model_sha256(&self) -> &'static str {
65        match self {
66            Self::AllMiniLmL6V2 => {
67                "53aa51172d142c89d9012cce15ae4d6cc0ca6895895114379cacb4fab128d9db"
68            }
69            Self::E5SmallV2 => "45bfa60070649aae2244fbc9d508537779b93b6f353c17b0f95ceccb1c5116c1",
70            Self::BgeSmallEnV15 => {
71                "3c9f31665447c8911517620762200d2245a2518d6e7208acc78cd9db317e21ad"
72            }
73            Self::BgeBaseEnV15 => {
74                "c7c1988aae201f80cf91a5dbbd5866409503b89dcaba877ca6dba7dd0a5167d7"
75            }
76        }
77    }
78
79    /// Expected SHA-256 hash of tokenizer.json
80    pub fn tokenizer_sha256(&self) -> &'static str {
81        match self {
82            Self::AllMiniLmL6V2 => {
83                "be50c3628f2bf5bb5e3a7f17b1f74611b2561a3a27eeab05e5aa30f411572037"
84            }
85            Self::E5SmallV2 => "d241a60d5e8f04cc1b2b3e9ef7a4921b27bf526d9f6050ab90f9267a1f9e5c66",
86            Self::BgeSmallEnV15 => {
87                "d241a60d5e8f04cc1b2b3e9ef7a4921b27bf526d9f6050ab90f9267a1f9e5c66"
88            }
89            Self::BgeBaseEnV15 => {
90                "d241a60d5e8f04cc1b2b3e9ef7a4921b27bf526d9f6050ab90f9267a1f9e5c66"
91            }
92        }
93    }
94
95    /// Expected SHA-256 hash of config.json
96    pub fn config_sha256(&self) -> &'static str {
97        match self {
98            Self::AllMiniLmL6V2 => {
99                "953f9c0d463486b10a6871cc2fd59f223b2c70184f49815e7efbcab5d8908b41"
100            }
101            Self::E5SmallV2 => "5dfb0363cd0243be179c03bcaafd1542d0fbb95e8cbcf575fff3e229342adc2f",
102            Self::BgeSmallEnV15 => {
103                "094f8e891b932f2000c92cfc663bac4c62069f5d8af5b5278c4306aef3084750"
104            }
105            Self::BgeBaseEnV15 => {
106                "bc00af31a4a31b74040d73370aa83b62da34c90b75eb77bfa7db039d90abd591"
107            }
108        }
109    }
110
111    /// Apply query prefix for asymmetric search
112    pub fn prefix_query<'a>(&self, text: &'a str) -> std::borrow::Cow<'a, str> {
113        match self {
114            Self::AllMiniLmL6V2 => std::borrow::Cow::Borrowed(text),
115            Self::E5SmallV2 => std::borrow::Cow::Owned(format!("query: {text}")),
116            Self::BgeSmallEnV15 | Self::BgeBaseEnV15 => std::borrow::Cow::Owned(format!(
117                "Represent this sentence for searching relevant passages: {text}"
118            )),
119        }
120    }
121
122    /// Apply document prefix for asymmetric search
123    pub fn prefix_document<'a>(&self, text: &'a str) -> std::borrow::Cow<'a, str> {
124        match self {
125            Self::AllMiniLmL6V2 => std::borrow::Cow::Borrowed(text),
126            Self::E5SmallV2 => std::borrow::Cow::Owned(format!("passage: {text}")),
127            Self::BgeSmallEnV15 | Self::BgeBaseEnV15 => std::borrow::Cow::Borrowed(text),
128        }
129    }
130
131    /// Parse from env var value (e.g. "e5-small-v2", "bge-small-en-v1.5")
132    pub fn from_env_str(s: &str) -> Option<Self> {
133        match s {
134            "all-MiniLM-L6-v2" | "all-minilm-l6-v2" => Some(Self::AllMiniLmL6V2),
135            "e5-small-v2" => Some(Self::E5SmallV2),
136            "bge-small-en-v1.5" => Some(Self::BgeSmallEnV15),
137            "bge-base-en-v1.5" => Some(Self::BgeBaseEnV15),
138            _ => None,
139        }
140    }
141}
142
/// Embedder for converting text to vectors.
///
/// # Thread Safety
/// `Embedder` wraps `BertModel` and `Tokenizer` which are `Send + Sync`.
/// It is shared via `Arc<Embedder>` across the server. All inference runs
/// synchronously on the calling thread (via `rt.block_on`), so there are
/// no cross-thread mutation concerns.
pub struct Embedder {
    /// Loaded BERT weights (forward pass produces token embeddings).
    model: BertModel,
    /// Tokenizer configured with batch-longest padding and 512-token truncation.
    tokenizer: Tokenizer,
    /// Inference device (always `Device::Cpu` as constructed here).
    device: Device,
    /// Which model variant is active; drives prefixes and `dimension()`.
    embedding_model: EmbeddingModel,
}
156
157impl Embedder {
158    /// Create a new embedder, downloading the model if necessary.
159    ///
160    /// Reads `SEDIMENT_EMBEDDING_MODEL` env var to select the model.
161    /// Falls back to AllMiniLmL6V2 if unset or unrecognized.
162    pub fn new() -> Result<Self> {
163        let embedding_model = std::env::var("SEDIMENT_EMBEDDING_MODEL")
164            .ok()
165            .and_then(|s| EmbeddingModel::from_env_str(&s))
166            .unwrap_or_default();
167        Self::with_embedding_model(embedding_model)
168    }
169
170    /// Create an embedder with a specific embedding model variant
171    pub fn with_embedding_model(embedding_model: EmbeddingModel) -> Result<Self> {
172        let model_id = embedding_model.model_id();
173        info!("Loading embedding model: {}", model_id);
174
175        let device = Device::Cpu;
176        let (model_path, tokenizer_path, config_path) =
177            download_model(model_id, embedding_model.revision())?;
178
179        // Load config
180        let config_str = std::fs::read_to_string(&config_path)
181            .map_err(|e| SedimentError::ModelLoading(format!("Failed to read config: {}", e)))?;
182        let config: Config = serde_json::from_str(&config_str)
183            .map_err(|e| SedimentError::ModelLoading(format!("Failed to parse config: {}", e)))?;
184
185        // Load tokenizer
186        let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
187            .map_err(|e| SedimentError::Tokenizer(format!("Failed to load tokenizer: {}", e)))?;
188
189        // Configure tokenizer for batch processing
190        let padding = PaddingParams {
191            strategy: tokenizers::PaddingStrategy::BatchLongest,
192            ..Default::default()
193        };
194        let truncation = TruncationParams {
195            max_length: 512,
196            ..Default::default()
197        };
198        tokenizer.with_padding(Some(padding));
199        tokenizer
200            .with_truncation(Some(truncation))
201            .map_err(|e| SedimentError::Tokenizer(format!("Failed to set truncation: {}", e)))?;
202
203        // Verify integrity of tokenizer and config files using hardcoded SHA-256 hashes.
204        // Skip verification for models with empty (placeholder) hashes.
205        let tokenizer_hash = embedding_model.tokenizer_sha256();
206        if !tokenizer_hash.is_empty() {
207            verify_file_hash(&tokenizer_path, tokenizer_hash, "tokenizer.json")?;
208        }
209        let config_hash = embedding_model.config_sha256();
210        if !config_hash.is_empty() {
211            verify_file_hash(&config_path, config_hash, "config.json")?;
212        }
213        if !tokenizer_hash.is_empty() || !config_hash.is_empty() {
214            info!("Tokenizer and config integrity verified (SHA-256)");
215        }
216
217        // Load model weights into memory and verify integrity.
218        // Uses from_buffered_safetensors instead of unsafe from_mmaped_safetensors
219        // to eliminate the TOCTOU window between hash verification and file use.
220        // The same bytes that pass SHA-256 verification are the ones parsed.
221        let model_bytes = std::fs::read(&model_path).map_err(|e| {
222            SedimentError::ModelLoading(format!("Failed to read model weights: {}", e))
223        })?;
224        let model_hash = embedding_model.model_sha256();
225        if !model_hash.is_empty() {
226            verify_bytes_hash(&model_bytes, model_hash, "model.safetensors")?;
227        }
228        let vb = VarBuilder::from_buffered_safetensors(model_bytes, DTYPE, &device)
229            .map_err(|e| SedimentError::ModelLoading(format!("Failed to load weights: {}", e)))?;
230
231        let model = BertModel::load(vb, &config)
232            .map_err(|e| SedimentError::ModelLoading(format!("Failed to load model: {}", e)))?;
233
234        info!("Embedding model loaded successfully");
235
236        Ok(Self {
237            model,
238            tokenizer,
239            device,
240            embedding_model,
241        })
242    }
243
244    /// Embed a single text (raw, no prefix applied)
245    pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
246        let embeddings = self.embed_batch(&[text])?;
247        embeddings.into_iter().next().ok_or_else(|| {
248            SedimentError::Embedding("embed_batch returned empty result for non-empty input".into())
249        })
250    }
251
252    /// Embed multiple texts at once (raw, no prefix applied)
253    pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
254        if texts.is_empty() {
255            return Ok(Vec::new());
256        }
257
258        // Tokenize
259        let encodings = self
260            .tokenizer
261            .encode_batch(texts.to_vec(), true)
262            .map_err(|e| SedimentError::Tokenizer(format!("Tokenization failed: {}", e)))?;
263
264        let token_ids: Vec<Vec<u32>> = encodings.iter().map(|e| e.get_ids().to_vec()).collect();
265
266        let attention_masks: Vec<Vec<u32>> = encodings
267            .iter()
268            .map(|e| e.get_attention_mask().to_vec())
269            .collect();
270
271        let token_type_ids: Vec<Vec<u32>> = encodings
272            .iter()
273            .map(|e| e.get_type_ids().to_vec())
274            .collect();
275
276        // Convert to tensors
277        let batch_size = texts.len();
278        let seq_len = token_ids[0].len();
279
280        let token_ids_flat: Vec<u32> = token_ids.into_iter().flatten().collect();
281        let attention_mask_flat: Vec<u32> = attention_masks.into_iter().flatten().collect();
282        let token_type_ids_flat: Vec<u32> = token_type_ids.into_iter().flatten().collect();
283
284        let token_ids_tensor =
285            Tensor::from_vec(token_ids_flat, (batch_size, seq_len), &self.device).map_err(|e| {
286                SedimentError::Embedding(format!("Failed to create token tensor: {}", e))
287            })?;
288
289        let attention_mask_tensor =
290            Tensor::from_vec(attention_mask_flat, (batch_size, seq_len), &self.device).map_err(
291                |e| SedimentError::Embedding(format!("Failed to create mask tensor: {}", e)),
292            )?;
293
294        let token_type_ids_tensor =
295            Tensor::from_vec(token_type_ids_flat, (batch_size, seq_len), &self.device).map_err(
296                |e| SedimentError::Embedding(format!("Failed to create type tensor: {}", e)),
297            )?;
298
299        // Run model
300        let embeddings = self
301            .model
302            .forward(
303                &token_ids_tensor,
304                &token_type_ids_tensor,
305                Some(&attention_mask_tensor),
306            )
307            .map_err(|e| SedimentError::Embedding(format!("Model forward failed: {}", e)))?;
308
309        // Mean pooling with attention mask
310        let attention_mask_f32 = attention_mask_tensor
311            .to_dtype(DType::F32)
312            .map_err(|e| SedimentError::Embedding(format!("Mask conversion failed: {}", e)))?
313            .unsqueeze(2)
314            .map_err(|e| SedimentError::Embedding(format!("Unsqueeze failed: {}", e)))?;
315
316        let masked_embeddings = embeddings
317            .broadcast_mul(&attention_mask_f32)
318            .map_err(|e| SedimentError::Embedding(format!("Broadcast mul failed: {}", e)))?;
319
320        let sum_embeddings = masked_embeddings
321            .sum(1)
322            .map_err(|e| SedimentError::Embedding(format!("Sum failed: {}", e)))?;
323
324        let sum_mask = attention_mask_f32
325            .sum(1)
326            .map_err(|e| SedimentError::Embedding(format!("Mask sum failed: {}", e)))?;
327
328        let mean_embeddings = sum_embeddings
329            .broadcast_div(&sum_mask)
330            .map_err(|e| SedimentError::Embedding(format!("Division failed: {}", e)))?;
331
332        // L2 normalize embeddings
333        let final_embeddings = normalize_l2(&mean_embeddings)?;
334
335        // Convert to Vec<Vec<f32>>
336        let embeddings_vec: Vec<Vec<f32>> = final_embeddings
337            .to_vec2()
338            .map_err(|e| SedimentError::Embedding(format!("Tensor to vec failed: {}", e)))?;
339
340        Ok(embeddings_vec)
341    }
342
343    /// Embed a single query text with model-specific query prefix
344    pub fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
345        let prefixed = self.embedding_model.prefix_query(text);
346        self.embed(&prefixed)
347    }
348
349    /// Embed a single document text with model-specific document prefix
350    pub fn embed_document(&self, text: &str) -> Result<Vec<f32>> {
351        let prefixed = self.embedding_model.prefix_document(text);
352        self.embed(&prefixed)
353    }
354
355    /// Embed multiple document texts with model-specific document prefix
356    pub fn embed_document_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
357        let prefixed: Vec<String> = texts
358            .iter()
359            .map(|t| self.embedding_model.prefix_document(t).into_owned())
360            .collect();
361        let refs: Vec<&str> = prefixed.iter().map(|s| s.as_str()).collect();
362        self.embed_batch(&refs)
363    }
364
365    /// Get the embedding dimension for the active model
366    pub fn dimension(&self) -> usize {
367        self.embedding_model.embedding_dim()
368    }
369
370    /// Get the active embedding model
371    pub fn embedding_model(&self) -> EmbeddingModel {
372        self.embedding_model
373    }
374}
375
376/// Download model files from Hugging Face Hub
377fn download_model(model_id: &str, revision: &str) -> Result<(PathBuf, PathBuf, PathBuf)> {
378    let api = ApiBuilder::from_env()
379        .with_progress(true)
380        .build()
381        .map_err(|e| SedimentError::ModelLoading(format!("Failed to create HF API: {}", e)))?;
382
383    let repo = api.repo(Repo::with_revision(
384        model_id.to_string(),
385        RepoType::Model,
386        revision.to_string(),
387    ));
388
389    let model_path = repo
390        .get("model.safetensors")
391        .map_err(|e| SedimentError::ModelLoading(format!("Failed to download model: {}", e)))?;
392
393    let tokenizer_path = repo
394        .get("tokenizer.json")
395        .map_err(|e| SedimentError::ModelLoading(format!("Failed to download tokenizer: {}", e)))?;
396
397    let config_path = repo
398        .get("config.json")
399        .map_err(|e| SedimentError::ModelLoading(format!("Failed to download config: {}", e)))?;
400
401    Ok((model_path, tokenizer_path, config_path))
402}
403
404/// Verify the SHA-256 hash of a file against an expected value.
405fn verify_file_hash(path: &std::path::Path, expected: &str, file_label: &str) -> Result<()> {
406    use sha2::{Digest, Sha256};
407
408    let file_bytes = std::fs::read(path).map_err(|e| {
409        SedimentError::ModelLoading(format!(
410            "Failed to read {} for hash verification: {}",
411            file_label, e
412        ))
413    })?;
414
415    let hash = Sha256::digest(&file_bytes);
416    let hex_hash = format!("{:x}", hash);
417
418    if hex_hash != expected {
419        return Err(SedimentError::ModelLoading(format!(
420            "{} integrity check failed: expected SHA-256 {}, got {}",
421            file_label, expected, hex_hash
422        )));
423    }
424
425    Ok(())
426}
427
428/// Verify the SHA-256 hash of in-memory bytes against an expected value.
429///
430/// This is used for model weights to eliminate the TOCTOU window: the same bytes
431/// that are hash-verified are the ones passed to the safetensors parser.
432fn verify_bytes_hash(data: &[u8], expected: &str, file_label: &str) -> Result<()> {
433    use sha2::{Digest, Sha256};
434
435    let hash = Sha256::digest(data);
436    let hex_hash = format!("{:x}", hash);
437
438    if hex_hash != expected {
439        return Err(SedimentError::ModelLoading(format!(
440            "{} integrity check failed: expected SHA-256 {}, got {}",
441            file_label, expected, hex_hash
442        )));
443    }
444
445    Ok(())
446}
447
448/// L2 normalize a tensor
449fn normalize_l2(tensor: &Tensor) -> Result<Tensor> {
450    let norm = tensor
451        .sqr()
452        .map_err(|e| SedimentError::Embedding(format!("Sqr failed: {}", e)))?
453        .sum_keepdim(1)
454        .map_err(|e| SedimentError::Embedding(format!("Sum keepdim failed: {}", e)))?
455        .sqrt()
456        .map_err(|e| SedimentError::Embedding(format!("Sqrt failed: {}", e)))?;
457
458    tensor
459        .broadcast_div(&norm)
460        .map_err(|e| SedimentError::Embedding(format!("Normalize div failed: {}", e)))
461}
462
#[cfg(test)]
mod tests {
    // Tests marked #[ignore] download real model weights from the HF Hub and
    // must be run explicitly (e.g. `cargo test -- --ignored`). The rest are
    // pure-CPU unit tests over the enum metadata and hash helpers.
    use super::*;

    #[test]
    #[ignore] // Requires model download
    fn test_embedder() -> Result<()> {
        let embedder = Embedder::new()?;

        let text = "Hello, world!";
        let embedding = embedder.embed(text)?;

        assert_eq!(embedding.len(), EMBEDDING_DIM);

        // Check normalization (L2 norm should be ~1.0)
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01);

        Ok(())
    }

    #[test]
    #[ignore] // Requires model download
    fn test_batch_embedding() -> Result<()> {
        let embedder = Embedder::new()?;

        let texts = vec!["Hello", "World", "Test sentence"];
        let embeddings = embedder.embed_batch(&texts)?;

        assert_eq!(embeddings.len(), 3);
        for emb in &embeddings {
            assert_eq!(emb.len(), EMBEDDING_DIM);
        }

        Ok(())
    }

    #[test]
    #[ignore] // Requires model download
    fn test_embed_query_and_document() -> Result<()> {
        let embedder = Embedder::new()?;

        let query_emb = embedder.embed_query("What database do we use?")?;
        let doc_emb = embedder.embed_document("We use Postgres for the main database")?;

        assert_eq!(query_emb.len(), EMBEDDING_DIM);
        assert_eq!(doc_emb.len(), EMBEDDING_DIM);

        // For AllMiniLmL6V2 (no prefixes), embed_query and embed should be identical
        let raw_emb = embedder.embed("What database do we use?")?;
        assert_eq!(query_emb, raw_emb);

        Ok(())
    }

    #[test]
    #[ignore] // Requires model download
    fn test_e5_small_v2_embedder() -> Result<()> {
        let embedder = Embedder::with_embedding_model(EmbeddingModel::E5SmallV2)?;

        // Verify dimension
        let emb = embedder.embed("test")?;
        assert_eq!(emb.len(), EMBEDDING_DIM);

        // Verify L2 normalization
        let norm: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01);

        // Query and document prefixes should produce different vectors
        let query_emb = embedder.embed_query("What is the capital of France?")?;
        let doc_emb = embedder.embed_document("What is the capital of France?")?;
        assert_ne!(
            query_emb, doc_emb,
            "E5 query and document embeddings should differ due to prefixes"
        );

        Ok(())
    }

    #[test]
    #[ignore] // Requires model download
    fn test_bge_small_en_v15_embedder() -> Result<()> {
        let embedder = Embedder::with_embedding_model(EmbeddingModel::BgeSmallEnV15)?;

        // Verify dimension
        let emb = embedder.embed("test")?;
        assert_eq!(emb.len(), EMBEDDING_DIM);

        // Verify L2 normalization
        let norm: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01);

        // Query prefix should produce different vector; document has no prefix
        let query_emb = embedder.embed_query("What is the capital of France?")?;
        let doc_emb = embedder.embed_document("What is the capital of France?")?;
        let raw_emb = embedder.embed("What is the capital of France?")?;

        assert_ne!(
            query_emb, doc_emb,
            "BGE query and document embeddings should differ due to query prefix"
        );
        // Document embedding should be identical to raw embedding (no prefix)
        assert_eq!(
            doc_emb, raw_emb,
            "BGE document embedding should equal raw embedding (no prefix)"
        );

        Ok(())
    }

    #[test]
    #[ignore] // Requires model download
    fn test_bge_base_en_v15_embedder() -> Result<()> {
        let embedder = Embedder::with_embedding_model(EmbeddingModel::BgeBaseEnV15)?;

        // Verify 768-dim output
        let emb = embedder.embed("test")?;
        assert_eq!(emb.len(), 768);
        assert_eq!(embedder.dimension(), 768);

        // Verify L2 normalization
        let norm: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01);

        // Query prefix should produce different vector; document has no prefix
        let query_emb = embedder.embed_query("What is the capital of France?")?;
        let doc_emb = embedder.embed_document("What is the capital of France?")?;
        let raw_emb = embedder.embed("What is the capital of France?")?;

        assert_eq!(query_emb.len(), 768);
        assert_eq!(doc_emb.len(), 768);

        assert_ne!(
            query_emb, doc_emb,
            "BGE-base query and document embeddings should differ due to query prefix"
        );
        assert_eq!(
            doc_emb, raw_emb,
            "BGE-base document embedding should equal raw embedding (no prefix)"
        );

        Ok(())
    }

    #[test]
    fn test_embedding_model_from_env_str() {
        assert_eq!(
            EmbeddingModel::from_env_str("e5-small-v2"),
            Some(EmbeddingModel::E5SmallV2)
        );
        assert_eq!(
            EmbeddingModel::from_env_str("bge-small-en-v1.5"),
            Some(EmbeddingModel::BgeSmallEnV15)
        );
        assert_eq!(
            EmbeddingModel::from_env_str("bge-base-en-v1.5"),
            Some(EmbeddingModel::BgeBaseEnV15)
        );
        assert_eq!(
            EmbeddingModel::from_env_str("all-MiniLM-L6-v2"),
            Some(EmbeddingModel::AllMiniLmL6V2)
        );
        assert_eq!(EmbeddingModel::from_env_str("unknown-model"), None);
    }

    #[test]
    fn test_embedding_model_dimensions() {
        assert_eq!(EmbeddingModel::AllMiniLmL6V2.embedding_dim(), 384);
        assert_eq!(EmbeddingModel::E5SmallV2.embedding_dim(), 384);
        assert_eq!(EmbeddingModel::BgeSmallEnV15.embedding_dim(), 384);
        assert_eq!(EmbeddingModel::BgeBaseEnV15.embedding_dim(), 768);
    }

    #[test]
    fn test_embedding_model_prefixes() {
        // Golden prefix strings: these must match what the upstream models
        // were trained with, so they are asserted verbatim.
        let text = "hello world";

        // AllMiniLmL6V2: no prefixes
        let m = EmbeddingModel::AllMiniLmL6V2;
        assert_eq!(m.prefix_query(text).as_ref(), "hello world");
        assert_eq!(m.prefix_document(text).as_ref(), "hello world");

        // E5SmallV2: query and document prefixes
        let m = EmbeddingModel::E5SmallV2;
        assert_eq!(m.prefix_query(text).as_ref(), "query: hello world");
        assert_eq!(m.prefix_document(text).as_ref(), "passage: hello world");

        // BgeSmallEnV15: query prefix only
        let m = EmbeddingModel::BgeSmallEnV15;
        assert_eq!(
            m.prefix_query(text).as_ref(),
            "Represent this sentence for searching relevant passages: hello world"
        );
        assert_eq!(m.prefix_document(text).as_ref(), "hello world");

        // BgeBaseEnV15: same prefixes as BgeSmallEnV15
        let m = EmbeddingModel::BgeBaseEnV15;
        assert_eq!(
            m.prefix_query(text).as_ref(),
            "Represent this sentence for searching relevant passages: hello world"
        );
        assert_eq!(m.prefix_document(text).as_ref(), "hello world");
    }

    #[test]
    fn test_verify_bytes_hash_correct() {
        let data = b"hello world";
        let expected = "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9";
        assert!(verify_bytes_hash(data, expected, "test").is_ok());
    }

    #[test]
    fn test_verify_bytes_hash_incorrect() {
        let data = b"hello world";
        let wrong = "0000000000000000000000000000000000000000000000000000000000000000";
        let err = verify_bytes_hash(data, wrong, "test").unwrap_err();
        assert!(err.to_string().contains("integrity check failed"));
    }

    #[test]
    fn test_verify_bytes_hash_empty() {
        let data = b"";
        // SHA-256 of empty input
        let expected = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";
        assert!(verify_bytes_hash(data, expected, "empty").is_ok());
    }
}
689}