1use anyhow::{anyhow, Result};
2use async_trait::async_trait;
3use std::path::Path;
4#[cfg(feature = "onnx-embeddings")]
5use std::sync::Arc;
6
7use super::{EmbeddingProvider as ProviderTrait, ProviderInfo};
8use crate::rag::model_metadata::{ModelMetadata, ModelMetadataManager};
9
10#[cfg(feature = "onnx-embeddings")]
11use ort::session::{builder::GraphOptimizationLevel, Session};
12#[cfg(feature = "onnx-embeddings")]
13use ort::value::Value;
14#[cfg(feature = "onnx-embeddings")]
15use tokenizers::Tokenizer;
16
/// Embedding provider backed by a locally-installed ONNX model.
///
/// With the `onnx-embeddings` feature enabled this owns an ONNX Runtime
/// session plus the matching tokenizer; without it, the struct is an
/// empty shell whose constructor always returns an error.
pub struct OnnxProvider {
    // HuggingFace-style model identifier, e.g. "BAAI/bge-small-en-v1.5".
    model_name: String,
    // Embedding vector length, taken from the stored model metadata.
    dimension: usize,
    // Maximum token sequence length; inputs are padded/truncated to this.
    max_length: usize,
    // ONNX Runtime session; RwLock because inference takes a write lock
    // (Session::run needs exclusive access — see embed_text_impl).
    #[cfg(feature = "onnx-embeddings")]
    session: tokio::sync::RwLock<Session>,
    // Tokenizer loaded from the model directory's tokenizer.json.
    #[cfg(feature = "onnx-embeddings")]
    tokenizer: Arc<Tokenizer>,
    // Keeps the struct well-formed when the feature is disabled.
    #[cfg(not(feature = "onnx-embeddings"))]
    _phantom: std::marker::PhantomData<()>,
}
29
30impl OnnxProvider {
31 pub async fn new(model_name: &str) -> Result<Self> {
33 Self::new_impl(model_name).await
34 }
35
36 #[cfg(feature = "onnx-embeddings")]
37 async fn new_impl(model_name: &str) -> Result<Self> {
38 let mut metadata_manager = ModelMetadataManager::new()?;
39
40 let metadata = metadata_manager.get_model(model_name).ok_or_else(|| {
42 anyhow!(
43 "Model '{}' not found. Use 'manx embedding download {}' first",
44 model_name,
45 model_name
46 )
47 })?;
48
49 let model_dir = metadata
51 .model_path
52 .as_ref()
53 .ok_or_else(|| anyhow!("No model path found for {}", model_name))?;
54
55 let onnx_path = model_dir.join("model.onnx");
56 let tokenizer_path = model_dir.join("tokenizer.json");
57
58 if !onnx_path.exists() {
59 return Err(anyhow!("ONNX model file not found at {:?}", onnx_path));
60 }
61
62 if !tokenizer_path.exists() {
63 return Err(anyhow!("Tokenizer file not found at {:?}", tokenizer_path));
64 }
65
66 log::info!("Loading ONNX model: {:?}", onnx_path);
68 let session = Session::builder()?
69 .with_optimization_level(GraphOptimizationLevel::Level3)?
70 .with_intra_threads(4)?
71 .commit_from_file(onnx_path)?;
72
73 log::info!("ONNX session created successfully");
74
75 log::info!("Loading tokenizer: {:?}", tokenizer_path);
77 let tokenizer = Tokenizer::from_file(&tokenizer_path)
78 .map_err(|e| anyhow!("Failed to load tokenizer: {}", e))?;
79
80 log::info!("Tokenizer loaded successfully");
81
82 let dimension = metadata.dimension;
83 let max_length = metadata.max_input_length.unwrap_or(512);
84
85 metadata_manager.mark_used(model_name)?;
87
88 log::info!(
89 "ONNX provider initialized: {} ({}D, max_len={})",
90 model_name,
91 dimension,
92 max_length
93 );
94
95 Ok(Self {
96 model_name: model_name.to_string(),
97 dimension,
98 max_length,
99 session: tokio::sync::RwLock::new(session),
100 tokenizer: Arc::new(tokenizer),
101 })
102 }
103
104 #[cfg(not(feature = "onnx-embeddings"))]
105 async fn new_impl(_model_name: &str) -> Result<Self> {
106 Err(anyhow!(
107 "ONNX embeddings feature not enabled. Compile with --features onnx-embeddings"
108 ))
109 }
110
111 pub async fn download_model(model_name: &str, force: bool) -> Result<()> {
113 let mut metadata_manager = ModelMetadataManager::new()?;
114
115 if !force && metadata_manager.get_model(model_name).is_some() {
117 return Err(anyhow!(
118 "Model '{}' already installed. Use --force to reinstall",
119 model_name
120 ));
121 }
122
123 log::info!("Downloading model: {}", model_name);
124
125 let models_dir = ModelMetadataManager::get_models_dir();
127 let model_dir = models_dir.join(model_name.replace('/', "_"));
128 std::fs::create_dir_all(&model_dir)?;
129
130 let files_to_download = vec![
132 ("onnx/model.onnx", "model.onnx"),
133 ("tokenizer.json", "tokenizer.json"),
134 ("config.json", "config.json"),
135 ];
136
137 let client = reqwest::Client::new();
138 let mut total_size = 0u64;
139 let mut dimension = None;
140
141 for (remote_path, local_filename) in files_to_download {
142 let url = format!(
143 "https://huggingface.co/{}/resolve/main/{}",
144 model_name, remote_path
145 );
146 let local_path = model_dir.join(local_filename);
147
148 log::info!("Downloading: {} -> {:?}", url, local_path);
149
150 let response = client.get(&url).send().await?;
151
152 if !response.status().is_success() {
153 return Err(anyhow!(
154 "Failed to download {}: HTTP {}",
155 url,
156 response.status()
157 ));
158 }
159
160 let bytes = response.bytes().await?;
161 std::fs::write(&local_path, &bytes)?;
162 total_size += bytes.len() as u64;
163
164 log::info!("Downloaded {} ({} bytes)", local_filename, bytes.len());
165
166 if local_filename == "config.json" {
168 if let Ok(config_str) = std::fs::read_to_string(&local_path) {
169 if let Ok(config) = serde_json::from_str::<serde_json::Value>(&config_str) {
170 if let Some(hidden_size) =
171 config.get("hidden_size").and_then(|v| v.as_u64())
172 {
173 dimension = Some(hidden_size as usize);
174 }
175 }
176 }
177 }
178 }
179
180 if dimension.is_none() {
182 dimension =
183 Some(Self::detect_dimension_from_onnx(&model_dir.join("model.onnx")).await?);
184 }
185
186 let dimension = dimension
187 .ok_or_else(|| anyhow!("Could not detect dimension from model config or ONNX file"))?;
188
189 let metadata = ModelMetadata {
191 model_name: model_name.to_string(),
192 provider_type: "onnx".to_string(),
193 dimension,
194 size_mb: total_size as f64 / 1_048_576.0, model_path: Some(model_dir.clone()),
196 api_endpoint: None,
197 installed_date: chrono::Utc::now(),
198 last_used: None,
199 checksum: Some(Self::calculate_model_checksum(&model_dir)?),
200 description: Some(format!("ONNX model: {}", model_name)),
201 max_input_length: Some(512), };
203
204 metadata_manager.add_model(metadata)?;
205 log::info!(
206 "Successfully installed model: {} ({}D, {:.1}MB)",
207 model_name,
208 dimension,
209 total_size as f64 / 1_048_576.0
210 );
211
212 Ok(())
213 }
214
215 async fn detect_dimension_from_onnx(_onnx_path: &Path) -> Result<usize> {
217 #[cfg(feature = "onnx-embeddings")]
218 {
219 log::info!("Detecting dimension from ONNX model: {:?}", _onnx_path);
220
221 let session = Session::builder()?
223 .with_optimization_level(GraphOptimizationLevel::Level1)? .commit_from_file(_onnx_path)?;
225
226 let outputs = session.outputs();
228 if let Some(first_output) = outputs.first() {
229 log::info!(
231 "Output: {} - Type: {:?}",
232 first_output.name(),
233 first_output.dtype()
234 );
235
236 let dimension = 384; log::info!("Using default embedding dimension: {}", dimension);
240 return Ok(dimension);
241 }
242
243 let inputs = session.inputs();
245 log::warn!(
246 "Could not determine dimension from outputs, input info: {:?}",
247 inputs
248 .iter()
249 .map(|i| (i.name(), i.dtype()))
250 .collect::<Vec<_>>()
251 );
252
253 Err(anyhow!(
254 "Could not detect embedding dimension from ONNX model structure"
255 ))
256 }
257
258 #[cfg(not(feature = "onnx-embeddings"))]
259 {
260 log::error!("ONNX introspection requires onnx-embeddings feature");
261 Err(anyhow!("ONNX embeddings feature not enabled"))
262 }
263 }
264
265 fn calculate_model_checksum(model_dir: &Path) -> Result<String> {
267 use sha2::{Digest, Sha256};
268 use std::fs::File;
269 use std::io::Read;
270
271 let mut hasher = Sha256::new();
272
273 let files_to_hash = ["model.onnx", "tokenizer.json", "config.json"];
275
276 for filename in files_to_hash.iter() {
277 let file_path = model_dir.join(filename);
278 if file_path.exists() {
279 let mut file = File::open(&file_path)?;
280 let mut buffer = Vec::new();
281 file.read_to_end(&mut buffer)?;
282
283 hasher.update(filename.as_bytes());
285 hasher.update(&buffer);
286
287 log::debug!("Hashed {} ({} bytes)", filename, buffer.len());
288 } else {
289 log::warn!("Model file not found for checksum: {:?}", file_path);
290 }
291 }
292
293 let result = hasher.finalize();
294 let checksum = format!("{:x}", result);
295 log::info!("Calculated model checksum: {}", &checksum[..16]);
296
297 Ok(checksum)
298 }
299
300 #[allow(dead_code)] pub fn verify_model_integrity(model_dir: &Path, expected_checksum: &str) -> Result<bool> {
303 let actual_checksum = Self::calculate_model_checksum(model_dir)?;
304 let is_valid = actual_checksum == expected_checksum;
305
306 if is_valid {
307 log::info!("Model integrity verified successfully");
308 } else {
309 log::error!(
310 "Model integrity check failed: expected {}, got {}",
311 &expected_checksum[..16],
312 &actual_checksum[..16]
313 );
314 }
315
316 Ok(is_valid)
317 }
318
319 pub fn list_available_models() -> Vec<&'static str> {
321 vec![
322 "sentence-transformers/all-MiniLM-L6-v2",
323 "sentence-transformers/all-mpnet-base-v2",
324 "sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
325 "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
326 "BAAI/bge-small-en-v1.5",
327 "BAAI/bge-base-en-v1.5",
328 "BAAI/bge-large-en-v1.5",
329 ]
330 }
331}
332
333#[async_trait]
334impl ProviderTrait for OnnxProvider {
335 async fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
336 if text.trim().is_empty() {
337 return Err(anyhow!("Cannot embed empty text"));
338 }
339
340 self.embed_text_impl(text).await
341 }
342
343 async fn get_dimension(&self) -> Result<usize> {
344 Ok(self.dimension)
345 }
346
347 async fn health_check(&self) -> Result<()> {
348 self.embed_text("test").await.map(|_| ())
350 }
351
352 fn get_info(&self) -> ProviderInfo {
353 ProviderInfo {
354 name: "ONNX Local Model".to_string(),
355 provider_type: "onnx".to_string(),
356 model_name: Some(self.model_name.clone()),
357 description: format!("Local ONNX model: {}", self.model_name),
358 max_input_length: Some(self.max_length),
359 }
360 }
361
362 fn as_any(&self) -> &dyn std::any::Any {
363 self
364 }
365}
366
367impl OnnxProvider {
368 pub async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
371 if texts.is_empty() {
372 return Ok(vec![]);
373 }
374
375 let valid_texts: Vec<&str> = texts
377 .iter()
378 .filter(|t| !t.trim().is_empty())
379 .copied()
380 .collect();
381
382 if valid_texts.is_empty() {
383 return Err(anyhow!("Cannot embed batch with all empty texts"));
384 }
385
386 let mut embeddings = Vec::with_capacity(valid_texts.len());
389
390 for text in valid_texts {
391 match self.embed_text_impl(text).await {
392 Ok(embedding) => embeddings.push(embedding),
393 Err(e) => {
394 log::warn!("Failed to embed text in batch: {}", e);
395 continue;
397 }
398 }
399 }
400
401 if embeddings.is_empty() {
402 return Err(anyhow!("Batch embedding failed for all texts"));
403 }
404
405 Ok(embeddings)
406 }
407}
408
impl OnnxProvider {
    /// Full single-text inference pipeline:
    /// tokenize -> pad/truncate to max_length -> ONNX forward pass ->
    /// attention-masked mean pooling -> L2 normalization.
    #[cfg(feature = "onnx-embeddings")]
    async fn embed_text_impl(&self, text: &str) -> Result<Vec<f32>> {
        // `true` asks the tokenizer to add its special tokens.
        let encoding = self
            .tokenizer
            .encode(text, true)
            .map_err(|e| anyhow!("Tokenization failed: {}", e))?;

        let mut input_ids = encoding.get_ids().to_vec();
        let mut attention_mask = encoding.get_attention_mask().to_vec();

        // Force the sequence to exactly `max_length`: truncate long
        // inputs, pad short ones (mask 0 marks padding positions).
        if input_ids.len() > self.max_length {
            input_ids.truncate(self.max_length);
            attention_mask.truncate(self.max_length);
        } else {
            while input_ids.len() < self.max_length {
                // NOTE(review): assumes the model's pad token id is 0 —
                // true for BERT-family tokenizers; confirm for others.
                input_ids.push(0);
                attention_mask.push(0);
            }
        }

        // Transformer ONNX graphs expect i64 input tensors.
        let input_ids: Vec<i64> = input_ids.iter().map(|&x| x as i64).collect();
        let attention_mask: Vec<i64> = attention_mask.iter().map(|&x| x as i64).collect();

        // Batch size 1: shape [1, max_length]. The mask is cloned so we
        // can reuse it for pooling after the tensors are consumed.
        let input_ids_tensor = Value::from_array(([1, self.max_length], input_ids))?;
        let attention_mask_tensor =
            Value::from_array(([1, self.max_length], attention_mask.clone()))?;

        let mut inputs = vec![
            ("input_ids", input_ids_tensor),
            ("attention_mask", attention_mask_tensor),
        ];

        // BERT-style models also take token_type_ids; supply all-zeros
        // (single segment) only when the graph declares that input.
        // Read lock suffices for inspecting input metadata.
        {
            let session = self.session.read().await;
            let input_names: Vec<&str> =
                session.inputs().iter().map(|input| input.name()).collect();

            if input_names.contains(&"token_type_ids") {
                let token_type_ids: Vec<i64> = vec![0i64; self.max_length];
                let token_type_ids_tensor =
                    Value::from_array(([1, self.max_length], token_type_ids))?;
                inputs.push(("token_type_ids", token_type_ids_tensor));
            }
        }

        // Run inference under the write lock (run needs &mut Session),
        // copying shape and data out so the lock is held only briefly.
        let (shape, data) = {
            let mut session = self.session.write().await;
            let outputs = session.run(inputs)?;

            let (shape, data_slice) = outputs[0].try_extract_tensor::<f32>()?;
            let data: Vec<f32> = data_slice.to_vec();
            (shape.clone(), data)
        };

        log::debug!("ONNX output shape: {:?}", shape);

        // Expected output shape: [batch=1, seq_len, hidden_size].
        // NOTE(review): shape[1]/shape[2] will panic if the model emits
        // a lower-rank output (e.g. an already-pooled embedding).
        let seq_len = shape[1] as usize;
        let hidden_size = shape[2] as usize;

        if hidden_size != self.dimension {
            return Err(anyhow!(
                "Model output dimension {} doesn't match expected {}",
                hidden_size,
                self.dimension
            ));
        }

        // Mean pooling: sum hidden states over real (mask == 1) tokens
        // only, then divide by the count of real tokens.
        let mut pooled = vec![0.0f32; hidden_size];
        let mut mask_sum = 0usize;

        for (i, &mask_val) in attention_mask.iter().enumerate().take(seq_len) {
            if mask_val == 1 {
                mask_sum += 1;
                for (j, pooled_val) in pooled.iter_mut().enumerate().take(hidden_size) {
                    // Row-major layout: token i's vector starts at
                    // i * hidden_size.
                    let idx = i * hidden_size + j;
                    *pooled_val += data[idx];
                }
            }
        }

        if mask_sum > 0 {
            for val in &mut pooled {
                *val /= mask_sum as f32;
            }
        }

        // L2-normalize so downstream cosine similarity reduces to a dot
        // product; skip when the vector is all zeros.
        let norm = pooled.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for val in &mut pooled {
                *val /= norm;
            }
        }

        log::debug!("Generated embedding with {} dimensions", pooled.len());
        Ok(pooled)
    }

    /// Stub used when the crate is built without ONNX support.
    #[cfg(not(feature = "onnx-embeddings"))]
    async fn embed_text_impl(&self, _text: &str) -> Result<Vec<f32>> {
        Err(anyhow!(
            "ONNX embeddings feature not enabled. Compile with --features onnx-embeddings"
        ))
    }
}