seekr_code/embedder/onnx.rs
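
//! ONNX-backed embedder for the quantized all-MiniLM-L6-v2 model. On first
//! use it downloads the model and tokenizer from Hugging Face, then produces
//! 384-dimensional mean-pooled, L2-normalized sentence embeddings.
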
use std::path::{Path, PathBuf};
use std::sync::Mutex;

use ort::session::Session;
use ort::session::builder::GraphOptimizationLevel;
use ort::value::TensorRef;

use crate::embedder::traits::Embedder;
use crate::error::EmbedderError;

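/// Hugging Face URL of the quantized all-MiniLM-L6-v2 ONNX model.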
const MODEL_URL: &str = "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/model_quantized.onnx";

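/// Filename the model is cached under inside the model directory.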
const MODEL_FILENAME: &str = "all-MiniLM-L6-v2-quantized.onnx";

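/// Hugging Face URL of the matching tokenizer definition.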
const TOKENIZER_URL: &str =
    "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/tokenizer.json";

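/// Filename the tokenizer is cached under inside the model directory.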
const TOKENIZER_FILENAME: &str = "tokenizer.json";

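/// Output dimensionality of all-MiniLM-L6-v2 sentence embeddings.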
const EMBEDDING_DIM: usize = 384;

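/// Inputs are truncated or zero-padded to exactly this many tokens.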
const MAX_SEQ_LENGTH: usize = 256;

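/// Text embedder backed by a locally cached ONNX model.
///
/// A minimal usage sketch (the `seekr_code` module paths are assumed from
/// this file's location in the crate):
///
/// ```no_run
/// use seekr_code::embedder::onnx::OnnxEmbedder;
/// use seekr_code::embedder::traits::Embedder;
///
/// let embedder = OnnxEmbedder::new(std::path::Path::new("./models")).unwrap();
/// let vector = embedder.embed("hello world").unwrap();
/// assert_eq!(vector.len(), embedder.dimension()); // 384
/// ```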
pub struct OnnxEmbedder {
    /// ONNX Runtime session; `Session::run` takes `&mut self`, so it is
    /// wrapped in a `Mutex` to keep the embedder usable behind `&self`.
    session: Mutex<Session>,
    /// Tokenizer matching the model's vocabulary.
    tokenizer: tokenizers::Tokenizer,
    /// Directory the model and tokenizer files are cached in.
    model_dir: PathBuf,
}

impl OnnxEmbedder {
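    /// Creates an embedder rooted at `model_dir`, downloading the model and
    /// tokenizer into it if they are not already cached.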
    pub fn new(model_dir: &Path) -> Result<Self, EmbedderError> {
        std::fs::create_dir_all(model_dir).map_err(EmbedderError::Io)?;

        let model_path = model_dir.join(MODEL_FILENAME);

        if !model_path.exists() {
            tracing::info!("Downloading ONNX model to {}...", model_path.display());
            download_file(MODEL_URL, &model_path)?;
            tracing::info!("Model downloaded successfully.");
        }

        let tokenizer_path = model_dir.join(TOKENIZER_FILENAME);
        if !tokenizer_path.exists() {
            tracing::info!("Downloading tokenizer...");
            download_file(TOKENIZER_URL, &tokenizer_path)?;
            tracing::info!("Tokenizer downloaded successfully.");
        }

        // Build the session with full graph optimization and a small
        // intra-op thread pool.
        let session = Session::builder()
            .map_err(|e| EmbedderError::OnnxError(e.to_string()))?
            .with_optimization_level(GraphOptimizationLevel::Level3)
            .map_err(|e| EmbedderError::OnnxError(e.to_string()))?
            .with_intra_threads(4)
            .map_err(|e| EmbedderError::OnnxError(e.to_string()))?
            .commit_from_file(&model_path)
            .map_err(|e| EmbedderError::OnnxError(format!("Failed to load model: {}", e)))?;

        let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| EmbedderError::OnnxError(format!("Failed to load tokenizer: {}", e)))?;

        Ok(Self {
            session: Mutex::new(session),
            tokenizer,
            model_dir: model_dir.to_path_buf(),
        })
    }

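    /// Returns the directory the model files are cached in.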
    pub fn model_dir(&self) -> &Path {
        &self.model_dir
    }

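    /// Tokenizes `text` into BERT input ids and an attention mask, truncated
    /// and zero-padded to exactly `MAX_SEQ_LENGTH` entries each.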
    fn tokenize(&self, text: &str) -> (Vec<i64>, Vec<i64>) {
        // Fall back to an empty encoding if the input cannot be tokenized;
        // encoding the empty string should not fail for a valid tokenizer.
        let encoding = self.tokenizer.encode(text, true).unwrap_or_else(|_| {
            self.tokenizer.encode("", true).unwrap()
        });

        let mut input_ids: Vec<i64> = encoding.get_ids().iter().map(|&id| id as i64).collect();
        let mut attention_mask: Vec<i64> = encoding
            .get_attention_mask()
            .iter()
            .map(|&m| m as i64)
            .collect();

        if input_ids.len() > MAX_SEQ_LENGTH {
            input_ids.truncate(MAX_SEQ_LENGTH);
            attention_mask.truncate(MAX_SEQ_LENGTH);
            // Keep the sequence well-formed by ending it with [SEP]
            // (id 102 in the BERT uncased vocabulary).
            if let Some(last) = input_ids.last_mut() {
                *last = 102;
            }
        }

        // Zero-pad ([PAD] is id 0) so every input has the same length.
        while input_ids.len() < MAX_SEQ_LENGTH {
            input_ids.push(0);
            attention_mask.push(0);
        }

        (input_ids, attention_mask)
    }

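    /// Runs the model on one tokenized input and returns a mean-pooled,
    /// L2-normalized embedding vector.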
    fn run_inference(
        &self,
        input_ids: &[i64],
        attention_mask: &[i64],
    ) -> Result<Vec<f32>, EmbedderError> {
        let seq_len = input_ids.len();

        // Shape everything as a batch of one: (1, seq_len).
        let input_ids_array = ndarray::Array2::from_shape_vec((1, seq_len), input_ids.to_vec())
            .map_err(|e| EmbedderError::OnnxError(format!("Shape error: {}", e)))?;
        let attention_mask_array =
            ndarray::Array2::from_shape_vec((1, seq_len), attention_mask.to_vec())
                .map_err(|e| EmbedderError::OnnxError(format!("Shape error: {}", e)))?;
        let token_type_ids_array = ndarray::Array2::<i64>::zeros((1, seq_len));

        let input_ids_tensor = TensorRef::from_array_view(&input_ids_array)
            .map_err(|e| EmbedderError::OnnxError(format!("Tensor creation error: {}", e)))?;
        let attention_mask_tensor = TensorRef::from_array_view(&attention_mask_array)
            .map_err(|e| EmbedderError::OnnxError(format!("Tensor creation error: {}", e)))?;
        let token_type_ids_tensor = TensorRef::from_array_view(&token_type_ids_array)
            .map_err(|e| EmbedderError::OnnxError(format!("Tensor creation error: {}", e)))?;

        // `Session::run` needs a mutable session, hence the mutex.
        let mut session = self
            .session
            .lock()
            .map_err(|e| EmbedderError::OnnxError(format!("Session lock poisoned: {}", e)))?;

        let outputs = session
            .run(ort::inputs![
                "input_ids" => input_ids_tensor,
                "attention_mask" => attention_mask_tensor,
                "token_type_ids" => token_type_ids_tensor
            ])
            .map_err(|e| EmbedderError::OnnxError(format!("Inference error: {}", e)))?;

        // Different ONNX exports name the token-level output differently;
        // fall back to the first output if neither known name is present.
        let output = if outputs.contains_key("last_hidden_state") {
            &outputs["last_hidden_state"]
        } else if outputs.contains_key("token_embeddings") {
            &outputs["token_embeddings"]
        } else {
            &outputs[0]
        };

        let tensor = output
            .try_extract_array::<f32>()
            .map_err(|e| EmbedderError::OnnxError(format!("Extract error: {}", e)))?;

        // Expect (batch = 1, seq_len, hidden_size).
        let shape = tensor.shape();
        if shape.len() != 3 {
            return Err(EmbedderError::OnnxError(format!(
                "Unexpected output shape: {:?}",
                shape
            )));
        }

        // Mean pooling: average the hidden states of non-padding tokens,
        // as selected by the attention mask.
        let hidden_size = shape[2];
        let mut pooled = vec![0.0f32; hidden_size];
        let active_tokens: f32 = attention_mask.iter().map(|&m| m as f32).sum();

        if active_tokens > 0.0 {
            for seq_idx in 0..shape[1] {
                let mask = attention_mask.get(seq_idx).copied().unwrap_or(0) as f32;
                if mask > 0.0 {
                    for dim in 0..hidden_size {
                        pooled[dim] += tensor[[0, seq_idx, dim]];
                    }
                }
            }
            for val in &mut pooled {
                *val /= active_tokens;
            }
        }

        // L2-normalize so downstream cosine similarity reduces to a dot product.
        let norm: f32 = pooled.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for x in &mut pooled {
                *x /= norm;
            }
        }

        Ok(pooled)
    }
}

impl Embedder for OnnxEmbedder {
    fn embed(&self, text: &str) -> Result<Vec<f32>, EmbedderError> {
        let (input_ids, attention_mask) = self.tokenize(text);
        self.run_inference(&input_ids, &attention_mask)
    }

    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbedderError> {
        // Embeds texts one by one; the session mutex would serialize
        // concurrent calls anyway.
        texts.iter().map(|text| self.embed(text)).collect()
    }

    fn dimension(&self) -> usize {
        EMBEDDING_DIM
    }
}

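/// Downloads `url` to `dest`, writing to a temporary file first and renaming
/// into place so a failed download never leaves a partial file behind.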
fn download_file(url: &str, dest: &Path) -> Result<(), EmbedderError> {
    let response = reqwest::blocking::get(url)
        .map_err(|e| EmbedderError::DownloadFailed(format!("HTTP request failed: {}", e)))?;

    if !response.status().is_success() {
        return Err(EmbedderError::DownloadFailed(format!(
            "HTTP {} for {}",
            response.status(),
            url
        )));
    }

    // The model is small enough (a few tens of MB) to buffer fully in memory.
    let bytes = response
        .bytes()
        .map_err(|e| EmbedderError::DownloadFailed(format!("Failed to read response: {}", e)))?;

    if bytes.is_empty() {
        return Err(EmbedderError::DownloadFailed(
            "Downloaded file is empty".to_string(),
        ));
    }

    // Write to a sibling .tmp file, then rename: on most platforms a rename
    // within the same directory is atomic, so readers never see a torn file.
    let tmp_path = dest.with_extension("tmp");
    std::fs::write(&tmp_path, &bytes).map_err(EmbedderError::Io)?;
    std::fs::rename(&tmp_path, dest).map_err(EmbedderError::Io)?;

    tracing::info!("Downloaded {} bytes to {}", bytes.len(), dest.display());

    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tokenize_output_length() {
        // Requires a model download; skip silently if construction fails
        // (e.g. no network access in the test environment).
        let model_dir = std::env::temp_dir().join("seekr_test_tokenizer");
        if let Ok(embedder) = OnnxEmbedder::new(&model_dir) {
            let (ids, mask) = embedder.tokenize("hello world");
            assert_eq!(ids.len(), MAX_SEQ_LENGTH);
            assert_eq!(mask.len(), MAX_SEQ_LENGTH);

            // BERT sequences start with [CLS] (id 101).
            assert_eq!(ids[0], 101);

            let active: i64 = mask.iter().sum();
            assert!(active > 0, "Should have at least some active tokens");
        }
    }

    #[test]
    fn test_embedding_dimension() {
        assert_eq!(EMBEDDING_DIM, 384);
    }

    #[test]
    fn test_max_seq_length() {
        assert_eq!(MAX_SEQ_LENGTH, 256);
    }
}