//! sediment/embedder.rs
//!
//! Text embedding for sediment: a BERT sentence-transformer loaded via
//! candle, with model files fetched from the Hugging Face Hub.
use std::path::PathBuf;

use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config, DTYPE};
use hf_hub::{Repo, RepoType, api::sync::ApiBuilder};
use tokenizers::{PaddingParams, Tokenizer, TruncationParams};
use tracing::info;

use crate::error::{Result, SedimentError};

/// Default model to use for embeddings.
///
/// A small sentence-transformer BERT; its files are fetched from the
/// Hugging Face Hub on first use (see `download_model`).
pub const DEFAULT_MODEL_ID: &str = "sentence-transformers/all-MiniLM-L6-v2";

/// Embedding dimension for the default model (all-MiniLM-L6-v2 → 384).
pub const EMBEDDING_DIM: usize = 384;

/// Embedder for converting text to vectors.
///
/// Wraps a BERT encoder and its tokenizer; `embed`/`embed_batch` produce
/// attention-masked mean-pooled embeddings, L2-normalized when `normalize`
/// is set (it is set to `true` by `with_model`).
pub struct Embedder {
    model: BertModel,     // BERT encoder loaded from safetensors weights
    tokenizer: Tokenizer, // configured for batch padding + 512-token truncation
    device: Device,       // inference device (CPU — see `with_model`)
    normalize: bool,      // whether to L2-normalize pooled embeddings
}
25
26impl Embedder {
27    /// Create a new embedder, downloading the model if necessary
28    pub fn new() -> Result<Self> {
29        Self::with_model(DEFAULT_MODEL_ID)
30    }
31
32    /// Create an embedder with a specific model
33    pub fn with_model(model_id: &str) -> Result<Self> {
34        info!("Loading embedding model: {}", model_id);
35
36        let device = Device::Cpu;
37        let (model_path, tokenizer_path, config_path) = download_model(model_id)?;
38
39        // Load config
40        let config_str = std::fs::read_to_string(&config_path)
41            .map_err(|e| SedimentError::ModelLoading(format!("Failed to read config: {}", e)))?;
42        let config: Config = serde_json::from_str(&config_str)
43            .map_err(|e| SedimentError::ModelLoading(format!("Failed to parse config: {}", e)))?;
44
45        // Load tokenizer
46        let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
47            .map_err(|e| SedimentError::Tokenizer(format!("Failed to load tokenizer: {}", e)))?;
48
49        // Configure tokenizer for batch processing
50        let padding = PaddingParams {
51            strategy: tokenizers::PaddingStrategy::BatchLongest,
52            ..Default::default()
53        };
54        let truncation = TruncationParams {
55            max_length: 512,
56            ..Default::default()
57        };
58        tokenizer.with_padding(Some(padding));
59        tokenizer
60            .with_truncation(Some(truncation))
61            .map_err(|e| SedimentError::Tokenizer(format!("Failed to set truncation: {}", e)))?;
62
63        // Load model weights
64        let vb = unsafe {
65            VarBuilder::from_mmaped_safetensors(&[model_path], DTYPE, &device).map_err(|e| {
66                SedimentError::ModelLoading(format!("Failed to load weights: {}", e))
67            })?
68        };
69
70        let model = BertModel::load(vb, &config)
71            .map_err(|e| SedimentError::ModelLoading(format!("Failed to load model: {}", e)))?;
72
73        info!("Embedding model loaded successfully");
74
75        Ok(Self {
76            model,
77            tokenizer,
78            device,
79            normalize: true,
80        })
81    }
82
83    /// Embed a single text
84    pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
85        let embeddings = self.embed_batch(&[text])?;
86        Ok(embeddings.into_iter().next().unwrap())
87    }
88
89    /// Embed multiple texts at once
90    pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
91        if texts.is_empty() {
92            return Ok(Vec::new());
93        }
94
95        // Tokenize
96        let encodings = self
97            .tokenizer
98            .encode_batch(texts.to_vec(), true)
99            .map_err(|e| SedimentError::Tokenizer(format!("Tokenization failed: {}", e)))?;
100
101        let token_ids: Vec<Vec<u32>> = encodings.iter().map(|e| e.get_ids().to_vec()).collect();
102
103        let attention_masks: Vec<Vec<u32>> = encodings
104            .iter()
105            .map(|e| e.get_attention_mask().to_vec())
106            .collect();
107
108        let token_type_ids: Vec<Vec<u32>> = encodings
109            .iter()
110            .map(|e| e.get_type_ids().to_vec())
111            .collect();
112
113        // Convert to tensors
114        let batch_size = texts.len();
115        let seq_len = token_ids[0].len();
116
117        let token_ids_flat: Vec<u32> = token_ids.into_iter().flatten().collect();
118        let attention_mask_flat: Vec<u32> = attention_masks.into_iter().flatten().collect();
119        let token_type_ids_flat: Vec<u32> = token_type_ids.into_iter().flatten().collect();
120
121        let token_ids_tensor =
122            Tensor::from_vec(token_ids_flat, (batch_size, seq_len), &self.device).map_err(|e| {
123                SedimentError::Embedding(format!("Failed to create token tensor: {}", e))
124            })?;
125
126        let attention_mask_tensor =
127            Tensor::from_vec(attention_mask_flat, (batch_size, seq_len), &self.device).map_err(
128                |e| SedimentError::Embedding(format!("Failed to create mask tensor: {}", e)),
129            )?;
130
131        let token_type_ids_tensor =
132            Tensor::from_vec(token_type_ids_flat, (batch_size, seq_len), &self.device).map_err(
133                |e| SedimentError::Embedding(format!("Failed to create type tensor: {}", e)),
134            )?;
135
136        // Run model
137        let embeddings = self
138            .model
139            .forward(
140                &token_ids_tensor,
141                &token_type_ids_tensor,
142                Some(&attention_mask_tensor),
143            )
144            .map_err(|e| SedimentError::Embedding(format!("Model forward failed: {}", e)))?;
145
146        // Mean pooling with attention mask
147        let attention_mask_f32 = attention_mask_tensor
148            .to_dtype(DType::F32)
149            .map_err(|e| SedimentError::Embedding(format!("Mask conversion failed: {}", e)))?
150            .unsqueeze(2)
151            .map_err(|e| SedimentError::Embedding(format!("Unsqueeze failed: {}", e)))?;
152
153        let masked_embeddings = embeddings
154            .broadcast_mul(&attention_mask_f32)
155            .map_err(|e| SedimentError::Embedding(format!("Broadcast mul failed: {}", e)))?;
156
157        let sum_embeddings = masked_embeddings
158            .sum(1)
159            .map_err(|e| SedimentError::Embedding(format!("Sum failed: {}", e)))?;
160
161        let sum_mask = attention_mask_f32
162            .sum(1)
163            .map_err(|e| SedimentError::Embedding(format!("Mask sum failed: {}", e)))?;
164
165        let mean_embeddings = sum_embeddings
166            .broadcast_div(&sum_mask)
167            .map_err(|e| SedimentError::Embedding(format!("Division failed: {}", e)))?;
168
169        // Normalize if requested
170        let final_embeddings = if self.normalize {
171            normalize_l2(&mean_embeddings)?
172        } else {
173            mean_embeddings
174        };
175
176        // Convert to Vec<Vec<f32>>
177        let embeddings_vec: Vec<Vec<f32>> = final_embeddings
178            .to_vec2()
179            .map_err(|e| SedimentError::Embedding(format!("Tensor to vec failed: {}", e)))?;
180
181        Ok(embeddings_vec)
182    }
183
184    /// Get the embedding dimension
185    pub fn dimension(&self) -> usize {
186        EMBEDDING_DIM
187    }
188}
189
190/// Download model files from Hugging Face Hub
191fn download_model(model_id: &str) -> Result<(PathBuf, PathBuf, PathBuf)> {
192    let api = ApiBuilder::from_env()
193        .with_progress(true)
194        .build()
195        .map_err(|e| SedimentError::ModelLoading(format!("Failed to create HF API: {}", e)))?;
196
197    let repo = api.repo(Repo::new(model_id.to_string(), RepoType::Model));
198
199    let model_path = repo
200        .get("model.safetensors")
201        .map_err(|e| SedimentError::ModelLoading(format!("Failed to download model: {}", e)))?;
202
203    let tokenizer_path = repo
204        .get("tokenizer.json")
205        .map_err(|e| SedimentError::ModelLoading(format!("Failed to download tokenizer: {}", e)))?;
206
207    let config_path = repo
208        .get("config.json")
209        .map_err(|e| SedimentError::ModelLoading(format!("Failed to download config: {}", e)))?;
210
211    Ok((model_path, tokenizer_path, config_path))
212}
213
214/// L2 normalize a tensor
215fn normalize_l2(tensor: &Tensor) -> Result<Tensor> {
216    let norm = tensor
217        .sqr()
218        .map_err(|e| SedimentError::Embedding(format!("Sqr failed: {}", e)))?
219        .sum_keepdim(1)
220        .map_err(|e| SedimentError::Embedding(format!("Sum keepdim failed: {}", e)))?
221        .sqrt()
222        .map_err(|e| SedimentError::Embedding(format!("Sqrt failed: {}", e)))?;
223
224    tensor
225        .broadcast_div(&norm)
226        .map_err(|e| SedimentError::Embedding(format!("Normalize div failed: {}", e)))
227}
228
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[ignore] // Requires model download
    fn test_embedder() -> Result<()> {
        let embedder = Embedder::new()?;

        let embedding = embedder.embed("Hello, world!")?;
        assert_eq!(embedding.len(), EMBEDDING_DIM);

        // Embeddings are normalized by default, so the L2 norm must be ~1.
        let norm = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01);

        Ok(())
    }

    #[test]
    #[ignore] // Requires model download
    fn test_batch_embedding() -> Result<()> {
        let embedder = Embedder::new()?;

        let embeddings = embedder.embed_batch(&["Hello", "World", "Test sentence"])?;

        // One embedding per input, each of the expected dimension.
        assert_eq!(embeddings.len(), 3);
        assert!(embeddings.iter().all(|emb| emb.len() == EMBEDDING_DIM));

        Ok(())
    }
}