//! Sentence embeddings backed by a BERT model (all-MiniLM-L6-v2 by default),
//! loaded with candle and verified against pinned SHA-256 hashes.

use std::path::PathBuf;

use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config, DTYPE};
use hf_hub::{Repo, RepoType, api::sync::ApiBuilder};
use tokenizers::{PaddingParams, Tokenizer, TruncationParams};
use tracing::info;

use crate::error::{Result, SedimentError};

/// Default sentence-transformer model downloaded from the Hugging Face Hub.
pub const DEFAULT_MODEL_ID: &str = "sentence-transformers/all-MiniLM-L6-v2";

/// Dimension of the embeddings produced by the default model.
pub const EMBEDDING_DIM: usize = 384;

/// A CPU-only text embedder: a BERT model plus its tokenizer, with optional
/// L2 normalization of the pooled output.
pub struct Embedder {
    model: BertModel,
    tokenizer: Tokenizer,
    device: Device,
    normalize: bool,
}

impl Embedder {
    /// Builds an embedder for [`DEFAULT_MODEL_ID`].
    pub fn new() -> Result<Self> {
        Self::with_model(DEFAULT_MODEL_ID)
    }

    /// Builds an embedder for the given model id. Note that the pinned
    /// revision and SHA-256 hashes in this module correspond to the default
    /// model, so other ids will fail the integrity checks.
    pub fn with_model(model_id: &str) -> Result<Self> {
        info!("Loading embedding model: {}", model_id);

        let device = Device::Cpu;
        let (model_path, tokenizer_path, config_path) = download_model(model_id)?;

        let config_str = std::fs::read_to_string(&config_path)
            .map_err(|e| SedimentError::ModelLoading(format!("Failed to read config: {}", e)))?;
        let config: Config = serde_json::from_str(&config_str)
            .map_err(|e| SedimentError::ModelLoading(format!("Failed to parse config: {}", e)))?;

        let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| SedimentError::Tokenizer(format!("Failed to load tokenizer: {}", e)))?;

        // Pad to the longest sequence in each batch and truncate at BERT's
        // 512-token limit.
        let padding = PaddingParams {
            strategy: tokenizers::PaddingStrategy::BatchLongest,
            ..Default::default()
        };
        let truncation = TruncationParams {
            max_length: 512,
            ..Default::default()
        };
        tokenizer.with_padding(Some(padding));
        tokenizer
            .with_truncation(Some(truncation))
            .map_err(|e| SedimentError::Tokenizer(format!("Failed to set truncation: {}", e)))?;

        // Read the weights into memory and verify their hash before handing
        // the bytes to the safetensors parser.
        let model_bytes = std::fs::read(&model_path).map_err(|e| {
            SedimentError::ModelLoading(format!("Failed to read model weights: {}", e))
        })?;
        verify_bytes_hash(&model_bytes, MODEL_SHA256, "model.safetensors")?;
        let vb = VarBuilder::from_buffered_safetensors(model_bytes, DTYPE, &device)
            .map_err(|e| SedimentError::ModelLoading(format!("Failed to load weights: {}", e)))?;

        let model = BertModel::load(vb, &config)
            .map_err(|e| SedimentError::ModelLoading(format!("Failed to load model: {}", e)))?;

        info!("Embedding model loaded successfully");

        Ok(Self {
            model,
            tokenizer,
            device,
            normalize: true,
        })
    }

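    /// Embeds a single text into a vector of length [`EMBEDDING_DIM`].
    ///
    /// A minimal usage sketch; marked `ignore` because running it downloads
    /// the model from the Hugging Face Hub on first use:
    ///
    /// ```ignore
    /// let embedder = Embedder::new()?;
    /// let vector = embedder.embed("Hello, world!")?;
    /// assert_eq!(vector.len(), EMBEDDING_DIM);
    /// ```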
    pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
        let embeddings = self.embed_batch(&[text])?;
        embeddings.into_iter().next().ok_or_else(|| {
            SedimentError::Embedding("embed_batch returned empty result for non-empty input".into())
        })
    }

    /// Embeds a batch of texts, returning one vector per input text.
    pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let encodings = self
            .tokenizer
            .encode_batch(texts.to_vec(), true)
            .map_err(|e| SedimentError::Tokenizer(format!("Tokenization failed: {}", e)))?;

        let token_ids: Vec<Vec<u32>> = encodings.iter().map(|e| e.get_ids().to_vec()).collect();

        let attention_masks: Vec<Vec<u32>> = encodings
            .iter()
            .map(|e| e.get_attention_mask().to_vec())
            .collect();

        let token_type_ids: Vec<Vec<u32>> = encodings
            .iter()
            .map(|e| e.get_type_ids().to_vec())
            .collect();

        // BatchLongest padding guarantees every encoding has the same length,
        // so the first row's length is the batch's sequence length.
        let batch_size = texts.len();
        let seq_len = token_ids[0].len();

        // Flatten the per-text rows so they can be reshaped into
        // (batch_size, seq_len) tensors.
        let token_ids_flat: Vec<u32> = token_ids.into_iter().flatten().collect();
        let attention_mask_flat: Vec<u32> = attention_masks.into_iter().flatten().collect();
        let token_type_ids_flat: Vec<u32> = token_type_ids.into_iter().flatten().collect();

        let token_ids_tensor =
            Tensor::from_vec(token_ids_flat, (batch_size, seq_len), &self.device).map_err(|e| {
                SedimentError::Embedding(format!("Failed to create token tensor: {}", e))
            })?;

        let attention_mask_tensor =
            Tensor::from_vec(attention_mask_flat, (batch_size, seq_len), &self.device).map_err(
                |e| SedimentError::Embedding(format!("Failed to create mask tensor: {}", e)),
            )?;

        let token_type_ids_tensor =
            Tensor::from_vec(token_type_ids_flat, (batch_size, seq_len), &self.device).map_err(
                |e| SedimentError::Embedding(format!("Failed to create type tensor: {}", e)),
            )?;

        // The forward pass yields per-token hidden states of shape
        // (batch_size, seq_len, hidden_size).
        let embeddings = self
            .model
            .forward(
                &token_ids_tensor,
                &token_type_ids_tensor,
                Some(&attention_mask_tensor),
            )
            .map_err(|e| SedimentError::Embedding(format!("Model forward failed: {}", e)))?;

        // Mean pooling: zero out padding positions with the attention mask,
        // sum over the sequence dimension, and divide by the number of real
        // tokens in each row.
        let attention_mask_f32 = attention_mask_tensor
            .to_dtype(DType::F32)
            .map_err(|e| SedimentError::Embedding(format!("Mask conversion failed: {}", e)))?
            .unsqueeze(2)
            .map_err(|e| SedimentError::Embedding(format!("Unsqueeze failed: {}", e)))?;

        let masked_embeddings = embeddings
            .broadcast_mul(&attention_mask_f32)
            .map_err(|e| SedimentError::Embedding(format!("Broadcast mul failed: {}", e)))?;

        let sum_embeddings = masked_embeddings
            .sum(1)
            .map_err(|e| SedimentError::Embedding(format!("Sum failed: {}", e)))?;

        let sum_mask = attention_mask_f32
            .sum(1)
            .map_err(|e| SedimentError::Embedding(format!("Mask sum failed: {}", e)))?;

        let mean_embeddings = sum_embeddings
            .broadcast_div(&sum_mask)
            .map_err(|e| SedimentError::Embedding(format!("Division failed: {}", e)))?;

        let final_embeddings = if self.normalize {
            normalize_l2(&mean_embeddings)?
        } else {
            mean_embeddings
        };

        let embeddings_vec: Vec<Vec<f32>> = final_embeddings
            .to_vec2()
            .map_err(|e| SedimentError::Embedding(format!("Tensor to vec failed: {}", e)))?;

        Ok(embeddings_vec)
    }

    /// Returns the dimension of vectors produced by [`embed`](Self::embed).
    pub fn dimension(&self) -> usize {
        EMBEDDING_DIM
    }
}

/// Downloads (or reads from the local HF cache) the model weights, tokenizer,
/// and config, pinned to a specific repo revision, and verifies the tokenizer
/// and config hashes. The model weights are verified later, in memory.
fn download_model(model_id: &str) -> Result<(PathBuf, PathBuf, PathBuf)> {
    let api = ApiBuilder::from_env()
        .with_progress(true)
        .build()
        .map_err(|e| SedimentError::ModelLoading(format!("Failed to create HF API: {}", e)))?;

    let repo = api.repo(Repo::with_revision(
        model_id.to_string(),
        RepoType::Model,
        "e4ce9877abf3edfe10b0d82785e83bdcb973e22e".to_string(),
    ));

    let model_path = repo
        .get("model.safetensors")
        .map_err(|e| SedimentError::ModelLoading(format!("Failed to download model: {}", e)))?;

    let tokenizer_path = repo
        .get("tokenizer.json")
        .map_err(|e| SedimentError::ModelLoading(format!("Failed to download tokenizer: {}", e)))?;

    let config_path = repo
        .get("config.json")
        .map_err(|e| SedimentError::ModelLoading(format!("Failed to download config: {}", e)))?;

    verify_file_hash(&tokenizer_path, TOKENIZER_SHA256, "tokenizer.json")?;
    verify_file_hash(&config_path, CONFIG_SHA256, "config.json")?;
    info!("Tokenizer and config integrity verified (SHA-256)");

    Ok((model_path, tokenizer_path, config_path))
}

// Expected SHA-256 digests for the pinned revision of the default model.
const MODEL_SHA256: &str = "53aa51172d142c89d9012cce15ae4d6cc0ca6895895114379cacb4fab128d9db";
const TOKENIZER_SHA256: &str = "be50c3628f2bf5bb5e3a7f17b1f74611b2561a3a27eeab05e5aa30f411572037";
const CONFIG_SHA256: &str = "953f9c0d463486b10a6871cc2fd59f223b2c70184f49815e7efbcab5d8908b41";

/// Reads `path` and checks its SHA-256 digest against `expected`.
fn verify_file_hash(path: &std::path::Path, expected: &str, file_label: &str) -> Result<()> {
    let file_bytes = std::fs::read(path).map_err(|e| {
        SedimentError::ModelLoading(format!(
            "Failed to read {} for hash verification: {}",
            file_label, e
        ))
    })?;

    verify_bytes_hash(&file_bytes, expected, file_label)
}

/// Checks the SHA-256 digest of `data` against the `expected` hex string.
fn verify_bytes_hash(data: &[u8], expected: &str, file_label: &str) -> Result<()> {
    use sha2::{Digest, Sha256};

    let hash = Sha256::digest(data);
    let hex_hash = format!("{:x}", hash);

    if hex_hash != expected {
        return Err(SedimentError::ModelLoading(format!(
            "{} integrity check failed: expected SHA-256 {}, got {}",
            file_label, expected, hex_hash
        )));
    }

    Ok(())
}

/// Divides each row of a 2-D tensor by its L2 norm.
fn normalize_l2(tensor: &Tensor) -> Result<Tensor> {
    let norm = tensor
        .sqr()
        .map_err(|e| SedimentError::Embedding(format!("Sqr failed: {}", e)))?
        .sum_keepdim(1)
        .map_err(|e| SedimentError::Embedding(format!("Sum keepdim failed: {}", e)))?
        .sqrt()
        .map_err(|e| SedimentError::Embedding(format!("Sqrt failed: {}", e)))?;

    tensor
        .broadcast_div(&norm)
        .map_err(|e| SedimentError::Embedding(format!("Normalize div failed: {}", e)))
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[ignore = "downloads the embedding model"]
    fn test_embedder() -> Result<()> {
        let embedder = Embedder::new()?;

        let text = "Hello, world!";
        let embedding = embedder.embed(text)?;

        assert_eq!(embedding.len(), EMBEDDING_DIM);

        // Normalization is on by default, so the vector should be unit-length.
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01);

        Ok(())
    }

    #[test]
    #[ignore = "downloads the embedding model"]
    fn test_batch_embedding() -> Result<()> {
        let embedder = Embedder::new()?;

        let texts = vec!["Hello", "World", "Test sentence"];
        let embeddings = embedder.embed_batch(&texts)?;

        assert_eq!(embeddings.len(), 3);
        for emb in &embeddings {
            assert_eq!(emb.len(), EMBEDDING_DIM);
        }

        Ok(())
    }

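    // A hedged consistency sketch: padding positions are masked out of the
    // mean pooling, so a batched embedding should match the single-text
    // embedding for the same input up to float tolerance. The tolerance and
    // the companion sentence are arbitrary choices for illustration.
    #[test]
    #[ignore = "downloads the embedding model"]
    fn test_batch_matches_single() -> Result<()> {
        let embedder = Embedder::new()?;

        let single = embedder.embed("Hello")?;
        let batched =
            embedder.embed_batch(&["Hello", "A much longer sentence that forces padding"])?;

        for (a, b) in single.iter().zip(&batched[0]) {
            assert!((a - b).abs() < 1e-4);
        }

        Ok(())
    }
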
    #[test]
    fn test_verify_bytes_hash_correct() {
        let data = b"hello world";
        let expected = "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9";
        assert!(verify_bytes_hash(data, expected, "test").is_ok());
    }

    #[test]
    fn test_verify_bytes_hash_incorrect() {
        let data = b"hello world";
        let wrong = "0000000000000000000000000000000000000000000000000000000000000000";
        let err = verify_bytes_hash(data, wrong, "test").unwrap_err();
        assert!(err.to_string().contains("integrity check failed"));
    }

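    // A small round-trip sketch for the file-based checker using only std;
    // the temp-file name is an arbitrary choice.
    #[test]
    fn test_verify_file_hash_roundtrip() -> Result<()> {
        let path = std::env::temp_dir().join("sediment_verify_file_hash_test.bin");
        std::fs::write(&path, b"hello world")
            .map_err(|e| SedimentError::ModelLoading(format!("write failed: {}", e)))?;

        // Same digest as in test_verify_bytes_hash_correct.
        let expected = "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9";
        let result = verify_file_hash(&path, expected, "temp file");

        let _ = std::fs::remove_file(&path);
        result
    }
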
    #[test]
    fn test_verify_bytes_hash_empty() {
        // SHA-256 of the empty byte string.
        let data = b"";
        let expected = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";
        assert!(verify_bytes_hash(data, expected, "empty").is_ok());
    }
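
    // A minimal unit sketch for the row-wise L2 normalization: the row
    // [3, 4, 0] has norm 5, so the normalized row is [0.6, 0.8, 0.0].
    #[test]
    fn test_normalize_l2_unit_norm() -> Result<()> {
        let device = Device::Cpu;
        let t = Tensor::from_vec(vec![3.0f32, 4.0, 0.0], (1, 3), &device)
            .map_err(|e| SedimentError::Embedding(format!("tensor failed: {}", e)))?;

        let rows: Vec<Vec<f32>> = normalize_l2(&t)?
            .to_vec2()
            .map_err(|e| SedimentError::Embedding(format!("to_vec2 failed: {}", e)))?;

        let norm: f32 = rows[0].iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
        assert!((rows[0][0] - 0.6).abs() < 1e-5);

        Ok(())
    }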
}