// seekr_code/embedder/onnx.rs

use std::path::{Path, PathBuf};
7use std::sync::Mutex;
8
9use ort::session::builder::GraphOptimizationLevel;
10use ort::session::Session;
11use ort::value::TensorRef;
12
13use crate::embedder::traits::Embedder;
14use crate::error::EmbedderError;
15
/// Quantized all-MiniLM-L6-v2 ONNX model hosted on Hugging Face.
const MODEL_URL: &str = "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/model_quantized.onnx";

/// Local filename under the model directory for the downloaded model.
const MODEL_FILENAME: &str = "all-MiniLM-L6-v2-quantized.onnx";

/// Tokenizer definition matching the model above.
const TOKENIZER_URL: &str = "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/tokenizer.json";

/// Local filename under the model directory for the downloaded tokenizer.
const TOKENIZER_FILENAME: &str = "tokenizer.json";

/// Output embedding dimensionality (all-MiniLM-L6-v2 hidden size).
const EMBEDDING_DIM: usize = 384;

/// Sequences are truncated/padded to exactly this many tokens before inference.
const MAX_SEQ_LENGTH: usize = 256;
/// Sentence embedder backed by a local ONNX Runtime session running
/// the quantized all-MiniLM-L6-v2 model.
pub struct OnnxEmbedder {
    // ort sessions require `&mut` to run; the Mutex serializes inference
    // so the embedder can be shared across threads.
    session: Mutex<Session>,
    // HuggingFace tokenizer loaded from `tokenizer.json`.
    tokenizer: tokenizers::Tokenizer,
    // Directory holding the downloaded model + tokenizer files.
    model_dir: PathBuf,
}
42
impl OnnxEmbedder {
    /// Create an embedder rooted at `model_dir`, downloading the model
    /// and tokenizer on first use (files are cached and reused afterwards).
    ///
    /// # Errors
    /// Returns `EmbedderError::Io` on filesystem failures,
    /// `EmbedderError::DownloadFailed` (via `download_file`) on network
    /// failures, and `EmbedderError::OnnxError` if the session or
    /// tokenizer cannot be constructed.
    pub fn new(model_dir: &Path) -> Result<Self, EmbedderError> {
        std::fs::create_dir_all(model_dir).map_err(EmbedderError::Io)?;

        let model_path = model_dir.join(MODEL_FILENAME);

        // Download lazily: only hit the network when the cached file is absent.
        if !model_path.exists() {
            tracing::info!("Downloading ONNX model to {}...", model_path.display());
            download_file(MODEL_URL, &model_path)?;
            tracing::info!("Model downloaded successfully.");
        }

        let tokenizer_path = model_dir.join(TOKENIZER_FILENAME);
        if !tokenizer_path.exists() {
            tracing::info!("Downloading tokenizer...");
            download_file(TOKENIZER_URL, &tokenizer_path)?;
            tracing::info!("Tokenizer downloaded successfully.");
        }

        // NOTE(review): the `unwrap_or_else(|e| e.recover())` calls assume the
        // ort builder error type exposes a `recover()` that yields a usable
        // builder when an option cannot be applied — confirm against the
        // pinned ort version; otherwise these should be `map_err(...)?`.
        let session = Session::builder()
            .map_err(|e| EmbedderError::OnnxError(e.to_string()))?
            .with_optimization_level(GraphOptimizationLevel::Level3)
            .unwrap_or_else(|e| e.recover())
            .with_intra_threads(4)
            .unwrap_or_else(|e| e.recover())
            .commit_from_file(&model_path)
            .map_err(|e| EmbedderError::OnnxError(format!("Failed to load model: {}", e)))?;

        let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| EmbedderError::OnnxError(format!("Failed to load tokenizer: {}", e)))?;

        Ok(Self {
            session: Mutex::new(session),
            tokenizer,
            model_dir: model_dir.to_path_buf(),
        })
    }

    /// Directory where the model and tokenizer files are cached.
    pub fn model_dir(&self) -> &Path {
        &self.model_dir
    }

    /// Tokenize `text` into `(input_ids, attention_mask)`, both exactly
    /// `MAX_SEQ_LENGTH` long (truncated and/or zero-padded).
    fn tokenize(&self, text: &str) -> (Vec<i64>, Vec<i64>) {
        // If encoding fails, fall back to encoding the empty string so the
        // caller always gets a valid (if trivial) sequence.
        // NOTE(review): the inner `unwrap()` assumes encoding "" cannot fail —
        // plausible for this tokenizer, but worth confirming.
        let encoding = self
            .tokenizer
            .encode(text, true)
            .unwrap_or_else(|_| {
                self.tokenizer.encode("", true).unwrap()
            });

        let mut input_ids: Vec<i64> = encoding.get_ids().iter().map(|&id| id as i64).collect();
        let mut attention_mask: Vec<i64> =
            encoding.get_attention_mask().iter().map(|&m| m as i64).collect();

        if input_ids.len() > MAX_SEQ_LENGTH {
            input_ids.truncate(MAX_SEQ_LENGTH);
            attention_mask.truncate(MAX_SEQ_LENGTH);
            // Re-terminate the truncated sequence.
            // NOTE(review): 102 assumes a BERT-style vocab where 102 is [SEP]
            // — confirm against the downloaded tokenizer.json.
            if let Some(last) = input_ids.last_mut() {
                *last = 102;
            }
        }

        // Pad with id 0 / mask 0 up to the fixed length.
        // NOTE(review): assumes [PAD] has id 0 in this vocab — confirm.
        while input_ids.len() < MAX_SEQ_LENGTH {
            input_ids.push(0);
            attention_mask.push(0);
        }

        (input_ids, attention_mask)
    }

    /// Run the model for a single pre-tokenized sequence and return a
    /// mean-pooled, L2-normalized embedding vector.
    ///
    /// `input_ids` and `attention_mask` must be the same length (as
    /// produced by `tokenize`).
    ///
    /// # Errors
    /// Returns `EmbedderError::OnnxError` on tensor construction failures,
    /// inference failures, a poisoned session lock, or an unexpected
    /// output shape.
    fn run_inference(
        &self,
        input_ids: &[i64],
        attention_mask: &[i64],
    ) -> Result<Vec<f32>, EmbedderError> {
        let seq_len = input_ids.len();

        // Build (1, seq_len) batch-of-one input arrays.
        let input_ids_array =
            ndarray::Array2::from_shape_vec((1, seq_len), input_ids.to_vec())
                .map_err(|e| EmbedderError::OnnxError(format!("Shape error: {}", e)))?;
        let attention_mask_array =
            ndarray::Array2::from_shape_vec((1, seq_len), attention_mask.to_vec())
                .map_err(|e| EmbedderError::OnnxError(format!("Shape error: {}", e)))?;
        // Single-segment input: token_type_ids are all zeros.
        let token_type_ids_array = ndarray::Array2::<i64>::zeros((1, seq_len));

        // TensorRefs borrow the arrays above, so the arrays must stay alive
        // until `session.run` returns — do not reorder these bindings.
        let input_ids_tensor = TensorRef::from_array_view(&input_ids_array)
            .map_err(|e| EmbedderError::OnnxError(format!("Tensor creation error: {}", e)))?;
        let attention_mask_tensor = TensorRef::from_array_view(&attention_mask_array)
            .map_err(|e| EmbedderError::OnnxError(format!("Tensor creation error: {}", e)))?;
        let token_type_ids_tensor = TensorRef::from_array_view(&token_type_ids_array)
            .map_err(|e| EmbedderError::OnnxError(format!("Tensor creation error: {}", e)))?;

        // Surface a poisoned lock as an error instead of panicking.
        let mut session = self.session.lock().map_err(|e| {
            EmbedderError::OnnxError(format!("Session lock poisoned: {}", e))
        })?;

        let outputs = session
            .run(ort::inputs![
                "input_ids" => input_ids_tensor,
                "attention_mask" => attention_mask_tensor,
                "token_type_ids" => token_type_ids_tensor
            ])
            .map_err(|e| EmbedderError::OnnxError(format!("Inference error: {}", e)))?;

        // Different ONNX exports name the hidden-state output differently;
        // fall back to the first output if neither known name is present.
        let output = if outputs.contains_key("last_hidden_state") {
            &outputs["last_hidden_state"]
        } else if outputs.contains_key("token_embeddings") {
            &outputs["token_embeddings"]
        } else {
            &outputs[0]
        };

        let tensor = output
            .try_extract_array::<f32>()
            .map_err(|e| EmbedderError::OnnxError(format!("Extract error: {}", e)))?;

        // Expect (batch=1, seq_len, hidden_size).
        let shape = tensor.shape();
        if shape.len() != 3 {
            return Err(EmbedderError::OnnxError(format!(
                "Unexpected output shape: {:?}",
                shape
            )));
        }

        // Mean pooling over non-padding tokens (mask == 1).
        let hidden_size = shape[2];
        let mut pooled = vec![0.0f32; hidden_size];
        let active_tokens: f32 = attention_mask.iter().map(|&m| m as f32).sum();

        // If the mask is all zeros, `pooled` stays the zero vector.
        if active_tokens > 0.0 {
            for seq_idx in 0..shape[1] {
                let mask = attention_mask.get(seq_idx).copied().unwrap_or(0) as f32;
                if mask > 0.0 {
                    for dim in 0..hidden_size {
                        pooled[dim] += tensor[[0, seq_idx, dim]];
                    }
                }
            }
            for val in &mut pooled {
                *val /= active_tokens;
            }
        }

        // L2-normalize (skip for the zero vector to avoid division by zero).
        let norm: f32 = pooled.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for x in &mut pooled {
                *x /= norm;
            }
        }

        Ok(pooled)
    }
}
220
221impl Embedder for OnnxEmbedder {
222 fn embed(&self, text: &str) -> Result<Vec<f32>, EmbedderError> {
223 let (input_ids, attention_mask) = self.tokenize(text);
224 self.run_inference(&input_ids, &attention_mask)
225 }
226
227 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbedderError> {
228 texts.iter().map(|text| self.embed(text)).collect()
231 }
232
233 fn dimension(&self) -> usize {
234 EMBEDDING_DIM
235 }
236}
237
238fn download_file(url: &str, dest: &Path) -> Result<(), EmbedderError> {
240 let response = reqwest::blocking::get(url)
241 .map_err(|e| EmbedderError::DownloadFailed(format!("HTTP request failed: {}", e)))?;
242
243 if !response.status().is_success() {
244 return Err(EmbedderError::DownloadFailed(format!(
245 "HTTP {} for {}",
246 response.status(),
247 url
248 )));
249 }
250
251 let bytes = response
252 .bytes()
253 .map_err(|e| EmbedderError::DownloadFailed(format!("Failed to read response: {}", e)))?;
254
255 if bytes.is_empty() {
257 return Err(EmbedderError::DownloadFailed(
258 "Downloaded file is empty".to_string(),
259 ));
260 }
261
262 let tmp_path = dest.with_extension("tmp");
264 std::fs::write(&tmp_path, &bytes).map_err(EmbedderError::Io)?;
265 std::fs::rename(&tmp_path, dest).map_err(EmbedderError::Io)?;
266
267 tracing::info!(
268 "Downloaded {} bytes to {}",
269 bytes.len(),
270 dest.display()
271 );
272
273 Ok(())
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): this test requires network access on first run —
    // `OnnxEmbedder::new` downloads the model and tokenizer into a temp
    // dir — and silently becomes a no-op (vacuously passes) when
    // construction fails, because of the `if let Ok(...)` guard.
    #[test]
    fn test_tokenize_output_length() {
        let model_dir = std::env::temp_dir().join("seekr_test_tokenizer");
        if let Ok(embedder) = OnnxEmbedder::new(&model_dir) {
            let (ids, mask) = embedder.tokenize("hello world");
            // tokenize pads/truncates to the fixed sequence length.
            assert_eq!(ids.len(), MAX_SEQ_LENGTH);
            assert_eq!(mask.len(), MAX_SEQ_LENGTH);

            // 101 is the BERT-style [CLS] id expected at position 0.
            assert_eq!(ids[0], 101);

            let active: i64 = mask.iter().sum();
            assert!(active > 0, "Should have at least some active tokens");
        }
    }

    // Guard against accidental changes to the model's advertised size.
    #[test]
    fn test_embedding_dimension() {
        assert_eq!(EMBEDDING_DIM, 384);
    }

    // Guard against accidental changes to the fixed sequence length.
    #[test]
    fn test_max_seq_length() {
        assert_eq!(MAX_SEQ_LENGTH, 256);
    }
}