1use std::path::PathBuf;
2
3use candle_core::{DType, Device, Tensor};
4use candle_nn::VarBuilder;
5use candle_transformers::models::bert::{BertModel, Config, DTYPE};
6use hf_hub::{Repo, RepoType, api::sync::ApiBuilder};
7use tokenizers::{PaddingParams, Tokenizer, TruncationParams};
8use tracing::info;
9
10use crate::error::{Result, SedimentError};
11
/// Hugging Face Hub model id loaded by [`Embedder::new`].
pub const DEFAULT_MODEL_ID: &str = "sentence-transformers/all-MiniLM-L6-v2";

/// Width of the vectors produced by the default model
/// (all-MiniLM-L6-v2 emits 384-dimensional sentence embeddings).
pub const EMBEDDING_DIM: usize = 384;

/// Sentence-embedding engine: a BERT encoder plus its tokenizer.
///
/// Produces mean-pooled token embeddings, optionally L2-normalized.
pub struct Embedder {
    /// BERT encoder loaded from safetensors weights.
    model: BertModel,
    /// Tokenizer configured (in `with_model`) with batch-longest padding
    /// and 512-token truncation.
    tokenizer: Tokenizer,
    /// Device tensors are created on (CPU, set in `with_model`).
    device: Device,
    /// When true (the default set by `with_model`), output vectors are
    /// L2-normalized before being returned.
    normalize: bool,
}
25
26impl Embedder {
27 pub fn new() -> Result<Self> {
29 Self::with_model(DEFAULT_MODEL_ID)
30 }
31
32 pub fn with_model(model_id: &str) -> Result<Self> {
34 info!("Loading embedding model: {}", model_id);
35
36 let device = Device::Cpu;
37 let (model_path, tokenizer_path, config_path) = download_model(model_id)?;
38
39 let config_str = std::fs::read_to_string(&config_path)
41 .map_err(|e| SedimentError::ModelLoading(format!("Failed to read config: {}", e)))?;
42 let config: Config = serde_json::from_str(&config_str)
43 .map_err(|e| SedimentError::ModelLoading(format!("Failed to parse config: {}", e)))?;
44
45 let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
47 .map_err(|e| SedimentError::Tokenizer(format!("Failed to load tokenizer: {}", e)))?;
48
49 let padding = PaddingParams {
51 strategy: tokenizers::PaddingStrategy::BatchLongest,
52 ..Default::default()
53 };
54 let truncation = TruncationParams {
55 max_length: 512,
56 ..Default::default()
57 };
58 tokenizer.with_padding(Some(padding));
59 tokenizer
60 .with_truncation(Some(truncation))
61 .map_err(|e| SedimentError::Tokenizer(format!("Failed to set truncation: {}", e)))?;
62
63 let vb = unsafe {
65 VarBuilder::from_mmaped_safetensors(&[model_path], DTYPE, &device).map_err(|e| {
66 SedimentError::ModelLoading(format!("Failed to load weights: {}", e))
67 })?
68 };
69
70 let model = BertModel::load(vb, &config)
71 .map_err(|e| SedimentError::ModelLoading(format!("Failed to load model: {}", e)))?;
72
73 info!("Embedding model loaded successfully");
74
75 Ok(Self {
76 model,
77 tokenizer,
78 device,
79 normalize: true,
80 })
81 }
82
83 pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
85 let embeddings = self.embed_batch(&[text])?;
86 Ok(embeddings.into_iter().next().unwrap())
87 }
88
89 pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
91 if texts.is_empty() {
92 return Ok(Vec::new());
93 }
94
95 let encodings = self
97 .tokenizer
98 .encode_batch(texts.to_vec(), true)
99 .map_err(|e| SedimentError::Tokenizer(format!("Tokenization failed: {}", e)))?;
100
101 let token_ids: Vec<Vec<u32>> = encodings.iter().map(|e| e.get_ids().to_vec()).collect();
102
103 let attention_masks: Vec<Vec<u32>> = encodings
104 .iter()
105 .map(|e| e.get_attention_mask().to_vec())
106 .collect();
107
108 let token_type_ids: Vec<Vec<u32>> = encodings
109 .iter()
110 .map(|e| e.get_type_ids().to_vec())
111 .collect();
112
113 let batch_size = texts.len();
115 let seq_len = token_ids[0].len();
116
117 let token_ids_flat: Vec<u32> = token_ids.into_iter().flatten().collect();
118 let attention_mask_flat: Vec<u32> = attention_masks.into_iter().flatten().collect();
119 let token_type_ids_flat: Vec<u32> = token_type_ids.into_iter().flatten().collect();
120
121 let token_ids_tensor =
122 Tensor::from_vec(token_ids_flat, (batch_size, seq_len), &self.device).map_err(|e| {
123 SedimentError::Embedding(format!("Failed to create token tensor: {}", e))
124 })?;
125
126 let attention_mask_tensor =
127 Tensor::from_vec(attention_mask_flat, (batch_size, seq_len), &self.device).map_err(
128 |e| SedimentError::Embedding(format!("Failed to create mask tensor: {}", e)),
129 )?;
130
131 let token_type_ids_tensor =
132 Tensor::from_vec(token_type_ids_flat, (batch_size, seq_len), &self.device).map_err(
133 |e| SedimentError::Embedding(format!("Failed to create type tensor: {}", e)),
134 )?;
135
136 let embeddings = self
138 .model
139 .forward(
140 &token_ids_tensor,
141 &token_type_ids_tensor,
142 Some(&attention_mask_tensor),
143 )
144 .map_err(|e| SedimentError::Embedding(format!("Model forward failed: {}", e)))?;
145
146 let attention_mask_f32 = attention_mask_tensor
148 .to_dtype(DType::F32)
149 .map_err(|e| SedimentError::Embedding(format!("Mask conversion failed: {}", e)))?
150 .unsqueeze(2)
151 .map_err(|e| SedimentError::Embedding(format!("Unsqueeze failed: {}", e)))?;
152
153 let masked_embeddings = embeddings
154 .broadcast_mul(&attention_mask_f32)
155 .map_err(|e| SedimentError::Embedding(format!("Broadcast mul failed: {}", e)))?;
156
157 let sum_embeddings = masked_embeddings
158 .sum(1)
159 .map_err(|e| SedimentError::Embedding(format!("Sum failed: {}", e)))?;
160
161 let sum_mask = attention_mask_f32
162 .sum(1)
163 .map_err(|e| SedimentError::Embedding(format!("Mask sum failed: {}", e)))?;
164
165 let mean_embeddings = sum_embeddings
166 .broadcast_div(&sum_mask)
167 .map_err(|e| SedimentError::Embedding(format!("Division failed: {}", e)))?;
168
169 let final_embeddings = if self.normalize {
171 normalize_l2(&mean_embeddings)?
172 } else {
173 mean_embeddings
174 };
175
176 let embeddings_vec: Vec<Vec<f32>> = final_embeddings
178 .to_vec2()
179 .map_err(|e| SedimentError::Embedding(format!("Tensor to vec failed: {}", e)))?;
180
181 Ok(embeddings_vec)
182 }
183
184 pub fn dimension(&self) -> usize {
186 EMBEDDING_DIM
187 }
188}
189
190fn download_model(model_id: &str) -> Result<(PathBuf, PathBuf, PathBuf)> {
192 let api = ApiBuilder::from_env()
193 .with_progress(true)
194 .build()
195 .map_err(|e| SedimentError::ModelLoading(format!("Failed to create HF API: {}", e)))?;
196
197 let repo = api.repo(Repo::new(model_id.to_string(), RepoType::Model));
198
199 let model_path = repo
200 .get("model.safetensors")
201 .map_err(|e| SedimentError::ModelLoading(format!("Failed to download model: {}", e)))?;
202
203 let tokenizer_path = repo
204 .get("tokenizer.json")
205 .map_err(|e| SedimentError::ModelLoading(format!("Failed to download tokenizer: {}", e)))?;
206
207 let config_path = repo
208 .get("config.json")
209 .map_err(|e| SedimentError::ModelLoading(format!("Failed to download config: {}", e)))?;
210
211 Ok((model_path, tokenizer_path, config_path))
212}
213
214fn normalize_l2(tensor: &Tensor) -> Result<Tensor> {
216 let norm = tensor
217 .sqr()
218 .map_err(|e| SedimentError::Embedding(format!("Sqr failed: {}", e)))?
219 .sum_keepdim(1)
220 .map_err(|e| SedimentError::Embedding(format!("Sum keepdim failed: {}", e)))?
221 .sqrt()
222 .map_err(|e| SedimentError::Embedding(format!("Sqrt failed: {}", e)))?;
223
224 tensor
225 .broadcast_div(&norm)
226 .map_err(|e| SedimentError::Embedding(format!("Normalize div failed: {}", e)))
227}
228
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end smoke test: embeds one string, checks the vector width and
    /// that it is (approximately) unit length, since `with_model` sets
    /// `normalize: true`. `#[ignore]`d because `Embedder::new` downloads the
    /// model from the Hugging Face Hub (network + disk); run explicitly with
    /// `cargo test -- --ignored`.
    #[test]
    #[ignore] fn test_embedder() -> Result<()> {
        let embedder = Embedder::new()?;

        let text = "Hello, world!";
        let embedding = embedder.embed(text)?;

        assert_eq!(embedding.len(), EMBEDDING_DIM);

        // L2 norm of a normalized embedding should be ~1.
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01);

        Ok(())
    }

    /// Batch path: three inputs must yield three embeddings, each of the
    /// expected width. `#[ignore]`d for the same model-download reason.
    #[test]
    #[ignore] fn test_batch_embedding() -> Result<()> {
        let embedder = Embedder::new()?;

        let texts = vec!["Hello", "World", "Test sentence"];
        let embeddings = embedder.embed_batch(&texts)?;

        assert_eq!(embeddings.len(), 3);
        for emb in &embeddings {
            assert_eq!(emb.len(), EMBEDDING_DIM);
        }

        Ok(())
    }
}