1use std::path::PathBuf;
2
3use candle_core::{DType, Device, Tensor};
4use candle_nn::VarBuilder;
5use candle_transformers::models::bert::{BertModel, Config, DTYPE};
6use hf_hub::{Repo, RepoType, api::sync::ApiBuilder};
7use tokenizers::{PaddingParams, Tokenizer, TruncationParams};
8use tracing::info;
9
10use crate::error::{Result, SedimentError};
11
/// Hugging Face model id used by [`Embedder::new`] when no explicit model is given.
pub const DEFAULT_MODEL_ID: &str = "sentence-transformers/all-MiniLM-L6-v2";

/// Output dimensionality of the default model (all-MiniLM-L6-v2 produces 384-d vectors).
pub const EMBEDDING_DIM: usize = 384;
17
/// Sentence embedder backed by a BERT model: tokenizes input text, runs a
/// forward pass, then mean-pools and L2-normalizes to produce fixed-size
/// embedding vectors.
pub struct Embedder {
    // Loaded BERT weights used for the forward pass.
    model: BertModel,
    // Tokenizer configured with batch-longest padding and 512-token truncation.
    tokenizer: Tokenizer,
    // Device all tensors are created on (CPU in the current implementation).
    device: Device,
}
30
31impl Embedder {
32 pub fn new() -> Result<Self> {
34 Self::with_model(DEFAULT_MODEL_ID)
35 }
36
37 pub fn with_model(model_id: &str) -> Result<Self> {
39 info!("Loading embedding model: {}", model_id);
40
41 let device = Device::Cpu;
42 let (model_path, tokenizer_path, config_path) = download_model(model_id)?;
43
44 let config_str = std::fs::read_to_string(&config_path)
46 .map_err(|e| SedimentError::ModelLoading(format!("Failed to read config: {}", e)))?;
47 let config: Config = serde_json::from_str(&config_str)
48 .map_err(|e| SedimentError::ModelLoading(format!("Failed to parse config: {}", e)))?;
49
50 let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
52 .map_err(|e| SedimentError::Tokenizer(format!("Failed to load tokenizer: {}", e)))?;
53
54 let padding = PaddingParams {
56 strategy: tokenizers::PaddingStrategy::BatchLongest,
57 ..Default::default()
58 };
59 let truncation = TruncationParams {
60 max_length: 512,
61 ..Default::default()
62 };
63 tokenizer.with_padding(Some(padding));
64 tokenizer
65 .with_truncation(Some(truncation))
66 .map_err(|e| SedimentError::Tokenizer(format!("Failed to set truncation: {}", e)))?;
67
68 let model_bytes = std::fs::read(&model_path).map_err(|e| {
73 SedimentError::ModelLoading(format!("Failed to read model weights: {}", e))
74 })?;
75 verify_bytes_hash(&model_bytes, MODEL_SHA256, "model.safetensors")?;
76 let vb = VarBuilder::from_buffered_safetensors(model_bytes, DTYPE, &device)
77 .map_err(|e| SedimentError::ModelLoading(format!("Failed to load weights: {}", e)))?;
78
79 let model = BertModel::load(vb, &config)
80 .map_err(|e| SedimentError::ModelLoading(format!("Failed to load model: {}", e)))?;
81
82 info!("Embedding model loaded successfully");
83
84 Ok(Self {
85 model,
86 tokenizer,
87 device,
88 })
89 }
90
91 pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
93 let embeddings = self.embed_batch(&[text])?;
94 embeddings.into_iter().next().ok_or_else(|| {
95 SedimentError::Embedding("embed_batch returned empty result for non-empty input".into())
96 })
97 }
98
99 pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
101 if texts.is_empty() {
102 return Ok(Vec::new());
103 }
104
105 let encodings = self
107 .tokenizer
108 .encode_batch(texts.to_vec(), true)
109 .map_err(|e| SedimentError::Tokenizer(format!("Tokenization failed: {}", e)))?;
110
111 let token_ids: Vec<Vec<u32>> = encodings.iter().map(|e| e.get_ids().to_vec()).collect();
112
113 let attention_masks: Vec<Vec<u32>> = encodings
114 .iter()
115 .map(|e| e.get_attention_mask().to_vec())
116 .collect();
117
118 let token_type_ids: Vec<Vec<u32>> = encodings
119 .iter()
120 .map(|e| e.get_type_ids().to_vec())
121 .collect();
122
123 let batch_size = texts.len();
125 let seq_len = token_ids[0].len();
126
127 let token_ids_flat: Vec<u32> = token_ids.into_iter().flatten().collect();
128 let attention_mask_flat: Vec<u32> = attention_masks.into_iter().flatten().collect();
129 let token_type_ids_flat: Vec<u32> = token_type_ids.into_iter().flatten().collect();
130
131 let token_ids_tensor =
132 Tensor::from_vec(token_ids_flat, (batch_size, seq_len), &self.device).map_err(|e| {
133 SedimentError::Embedding(format!("Failed to create token tensor: {}", e))
134 })?;
135
136 let attention_mask_tensor =
137 Tensor::from_vec(attention_mask_flat, (batch_size, seq_len), &self.device).map_err(
138 |e| SedimentError::Embedding(format!("Failed to create mask tensor: {}", e)),
139 )?;
140
141 let token_type_ids_tensor =
142 Tensor::from_vec(token_type_ids_flat, (batch_size, seq_len), &self.device).map_err(
143 |e| SedimentError::Embedding(format!("Failed to create type tensor: {}", e)),
144 )?;
145
146 let embeddings = self
148 .model
149 .forward(
150 &token_ids_tensor,
151 &token_type_ids_tensor,
152 Some(&attention_mask_tensor),
153 )
154 .map_err(|e| SedimentError::Embedding(format!("Model forward failed: {}", e)))?;
155
156 let attention_mask_f32 = attention_mask_tensor
158 .to_dtype(DType::F32)
159 .map_err(|e| SedimentError::Embedding(format!("Mask conversion failed: {}", e)))?
160 .unsqueeze(2)
161 .map_err(|e| SedimentError::Embedding(format!("Unsqueeze failed: {}", e)))?;
162
163 let masked_embeddings = embeddings
164 .broadcast_mul(&attention_mask_f32)
165 .map_err(|e| SedimentError::Embedding(format!("Broadcast mul failed: {}", e)))?;
166
167 let sum_embeddings = masked_embeddings
168 .sum(1)
169 .map_err(|e| SedimentError::Embedding(format!("Sum failed: {}", e)))?;
170
171 let sum_mask = attention_mask_f32
172 .sum(1)
173 .map_err(|e| SedimentError::Embedding(format!("Mask sum failed: {}", e)))?;
174
175 let mean_embeddings = sum_embeddings
176 .broadcast_div(&sum_mask)
177 .map_err(|e| SedimentError::Embedding(format!("Division failed: {}", e)))?;
178
179 let final_embeddings = normalize_l2(&mean_embeddings)?;
181
182 let embeddings_vec: Vec<Vec<f32>> = final_embeddings
184 .to_vec2()
185 .map_err(|e| SedimentError::Embedding(format!("Tensor to vec failed: {}", e)))?;
186
187 Ok(embeddings_vec)
188 }
189
190 pub fn dimension(&self) -> usize {
192 EMBEDDING_DIM
193 }
194}
195
196fn download_model(model_id: &str) -> Result<(PathBuf, PathBuf, PathBuf)> {
198 let api = ApiBuilder::from_env()
199 .with_progress(true)
200 .build()
201 .map_err(|e| SedimentError::ModelLoading(format!("Failed to create HF API: {}", e)))?;
202
203 let repo = api.repo(Repo::with_revision(
204 model_id.to_string(),
205 RepoType::Model,
206 "e4ce9877abf3edfe10b0d82785e83bdcb973e22e".to_string(),
207 ));
208
209 let model_path = repo
210 .get("model.safetensors")
211 .map_err(|e| SedimentError::ModelLoading(format!("Failed to download model: {}", e)))?;
212
213 let tokenizer_path = repo
214 .get("tokenizer.json")
215 .map_err(|e| SedimentError::ModelLoading(format!("Failed to download tokenizer: {}", e)))?;
216
217 let config_path = repo
218 .get("config.json")
219 .map_err(|e| SedimentError::ModelLoading(format!("Failed to download config: {}", e)))?;
220
221 verify_file_hash(&tokenizer_path, TOKENIZER_SHA256, "tokenizer.json")?;
225 verify_file_hash(&config_path, CONFIG_SHA256, "config.json")?;
226 info!("Tokenizer and config integrity verified (SHA-256)");
227
228 Ok((model_path, tokenizer_path, config_path))
229}
230
// SHA-256 digests (lowercase hex) of the default model's files at the pinned
// revision, used to detect corrupted or tampered downloads.
const MODEL_SHA256: &str = "53aa51172d142c89d9012cce15ae4d6cc0ca6895895114379cacb4fab128d9db";
const TOKENIZER_SHA256: &str = "be50c3628f2bf5bb5e3a7f17b1f74611b2561a3a27eeab05e5aa30f411572037";
const CONFIG_SHA256: &str = "953f9c0d463486b10a6871cc2fd59f223b2c70184f49815e7efbcab5d8908b41";
235
236fn verify_file_hash(path: &std::path::Path, expected: &str, file_label: &str) -> Result<()> {
238 use sha2::{Digest, Sha256};
239
240 let file_bytes = std::fs::read(path).map_err(|e| {
241 SedimentError::ModelLoading(format!(
242 "Failed to read {} for hash verification: {}",
243 file_label, e
244 ))
245 })?;
246
247 let hash = Sha256::digest(&file_bytes);
248 let hex_hash = format!("{:x}", hash);
249
250 if hex_hash != expected {
251 return Err(SedimentError::ModelLoading(format!(
252 "{} integrity check failed: expected SHA-256 {}, got {}",
253 file_label, expected, hex_hash
254 )));
255 }
256
257 Ok(())
258}
259
260fn verify_bytes_hash(data: &[u8], expected: &str, file_label: &str) -> Result<()> {
265 use sha2::{Digest, Sha256};
266
267 let hash = Sha256::digest(data);
268 let hex_hash = format!("{:x}", hash);
269
270 if hex_hash != expected {
271 return Err(SedimentError::ModelLoading(format!(
272 "{} integrity check failed: expected SHA-256 {}, got {}",
273 file_label, expected, hex_hash
274 )));
275 }
276
277 Ok(())
278}
279
280fn normalize_l2(tensor: &Tensor) -> Result<Tensor> {
282 let norm = tensor
283 .sqr()
284 .map_err(|e| SedimentError::Embedding(format!("Sqr failed: {}", e)))?
285 .sum_keepdim(1)
286 .map_err(|e| SedimentError::Embedding(format!("Sum keepdim failed: {}", e)))?
287 .sqrt()
288 .map_err(|e| SedimentError::Embedding(format!("Sqrt failed: {}", e)))?;
289
290 tensor
291 .broadcast_div(&norm)
292 .map_err(|e| SedimentError::Embedding(format!("Normalize div failed: {}", e)))
293}
294
#[cfg(test)]
mod tests {
    use super::*;

    // Requires downloading the model from the Hugging Face Hub, hence #[ignore];
    // run explicitly with `cargo test -- --ignored`.
    #[test]
    #[ignore] fn test_embedder() -> Result<()> {
        let embedder = Embedder::new()?;

        let text = "Hello, world!";
        let embedding = embedder.embed(text)?;

        assert_eq!(embedding.len(), EMBEDDING_DIM);

        // Embeddings are L2-normalized, so the norm should be ~1.
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01);

        Ok(())
    }

    // Also network-dependent, hence #[ignore].
    #[test]
    #[ignore] fn test_batch_embedding() -> Result<()> {
        let embedder = Embedder::new()?;

        let texts = vec!["Hello", "World", "Test sentence"];
        let embeddings = embedder.embed_batch(&texts)?;

        assert_eq!(embeddings.len(), 3);
        for emb in &embeddings {
            assert_eq!(emb.len(), EMBEDDING_DIM);
        }

        Ok(())
    }

    // Known SHA-256 of b"hello world".
    #[test]
    fn test_verify_bytes_hash_correct() {
        let data = b"hello world";
        let expected = "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9";
        assert!(verify_bytes_hash(data, expected, "test").is_ok());
    }

    #[test]
    fn test_verify_bytes_hash_incorrect() {
        let data = b"hello world";
        let wrong = "0000000000000000000000000000000000000000000000000000000000000000";
        let err = verify_bytes_hash(data, wrong, "test").unwrap_err();
        assert!(err.to_string().contains("integrity check failed"));
    }

    // SHA-256 of the empty input is a well-known constant.
    #[test]
    fn test_verify_bytes_hash_empty() {
        let data = b"";
        let expected = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";
        assert!(verify_bytes_hash(data, expected, "empty").is_ok());
    }
}