//! ONNX-based embedding provider.
//!
//! Source path: manx_cli/rag/providers/onnx.rs
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use std::path::Path;
#[cfg(feature = "onnx-embeddings")]
use std::sync::Arc;

use super::{EmbeddingProvider as ProviderTrait, ProviderInfo};
use crate::rag::model_metadata::{ModelMetadata, ModelMetadataManager};

#[cfg(feature = "onnx-embeddings")]
use ort::session::{builder::GraphOptimizationLevel, Session};
#[cfg(feature = "onnx-embeddings")]
use ort::value::Value;
#[cfg(feature = "onnx-embeddings")]
use tokenizers::Tokenizer;
/// ONNX-based embedding provider with real inference capabilities.
///
/// Wraps a local ONNX Runtime session plus its HuggingFace tokenizer. When the
/// `onnx-embeddings` feature is disabled, the type still exists (so callers
/// compile) but every operation returns an error.
pub struct OnnxProvider {
    // HuggingFace-style model identifier, e.g. "sentence-transformers/all-MiniLM-L6-v2".
    model_name: String,
    // Embedding vector dimensionality, taken from the installed model metadata.
    dimension: usize,
    // Fixed token sequence length; inputs are padded/truncated to this.
    max_length: usize,
    // ONNX Runtime session behind an async RwLock: inference takes a write
    // guard (`session.run` needs `&mut`), input introspection only a read guard.
    #[cfg(feature = "onnx-embeddings")]
    session: tokio::sync::RwLock<Session>,
    // Shared tokenizer handle.
    #[cfg(feature = "onnx-embeddings")]
    tokenizer: Arc<Tokenizer>,
    // Placeholder so the struct is well-formed when the feature is off.
    #[cfg(not(feature = "onnx-embeddings"))]
    _phantom: std::marker::PhantomData<()>,
}
29
30impl OnnxProvider {
    /// Create a new ONNX provider from an installed model with real session management.
    ///
    /// Thin wrapper over the feature-gated `new_impl`; fails when the model is
    /// not installed or when the `onnx-embeddings` feature is disabled.
    pub async fn new(model_name: &str) -> Result<Self> {
        Self::new_impl(model_name).await
    }
35
36    #[cfg(feature = "onnx-embeddings")]
37    async fn new_impl(model_name: &str) -> Result<Self> {
38        let mut metadata_manager = ModelMetadataManager::new()?;
39
40        // Get model metadata
41        let metadata = metadata_manager.get_model(model_name).ok_or_else(|| {
42            anyhow!(
43                "Model '{}' not found. Use 'manx embedding download {}' first",
44                model_name,
45                model_name
46            )
47        })?;
48
49        // Load model files
50        let model_dir = metadata
51            .model_path
52            .as_ref()
53            .ok_or_else(|| anyhow!("No model path found for {}", model_name))?;
54
55        let onnx_path = model_dir.join("model.onnx");
56        let tokenizer_path = model_dir.join("tokenizer.json");
57
58        if !onnx_path.exists() {
59            return Err(anyhow!("ONNX model file not found at {:?}", onnx_path));
60        }
61
62        if !tokenizer_path.exists() {
63            return Err(anyhow!("Tokenizer file not found at {:?}", tokenizer_path));
64        }
65
66        // Initialize ONNX Runtime session with optimizations
67        log::info!("Loading ONNX model: {:?}", onnx_path);
68        let session = Session::builder()?
69            .with_optimization_level(GraphOptimizationLevel::Level3)?
70            .with_intra_threads(4)?
71            .commit_from_file(onnx_path)?;
72
73        log::info!("ONNX session created successfully");
74
75        // Load HuggingFace tokenizer
76        log::info!("Loading tokenizer: {:?}", tokenizer_path);
77        let tokenizer = Tokenizer::from_file(&tokenizer_path)
78            .map_err(|e| anyhow!("Failed to load tokenizer: {}", e))?;
79
80        log::info!("Tokenizer loaded successfully");
81
82        let dimension = metadata.dimension;
83        let max_length = metadata.max_input_length.unwrap_or(512);
84
85        // Mark model as used
86        metadata_manager.mark_used(model_name)?;
87
88        log::info!(
89            "ONNX provider initialized: {} ({}D, max_len={})",
90            model_name,
91            dimension,
92            max_length
93        );
94
95        Ok(Self {
96            model_name: model_name.to_string(),
97            dimension,
98            max_length,
99            session: tokio::sync::RwLock::new(session),
100            tokenizer: Arc::new(tokenizer),
101        })
102    }
103
    /// Stub constructor used when compiled without `onnx-embeddings`: always errors.
    #[cfg(not(feature = "onnx-embeddings"))]
    async fn new_impl(_model_name: &str) -> Result<Self> {
        Err(anyhow!(
            "ONNX embeddings feature not enabled. Compile with --features onnx-embeddings"
        ))
    }
110
111    /// Download and install an ONNX model from HuggingFace
112    pub async fn download_model(model_name: &str, force: bool) -> Result<()> {
113        let mut metadata_manager = ModelMetadataManager::new()?;
114
115        // Check if already installed
116        if !force && metadata_manager.get_model(model_name).is_some() {
117            return Err(anyhow!(
118                "Model '{}' already installed. Use --force to reinstall",
119                model_name
120            ));
121        }
122
123        log::info!("Downloading model: {}", model_name);
124
125        // Create model directory
126        let models_dir = ModelMetadataManager::get_models_dir();
127        let model_dir = models_dir.join(model_name.replace('/', "_"));
128        std::fs::create_dir_all(&model_dir)?;
129
130        // Download files from HuggingFace
131        let files_to_download = vec![
132            ("onnx/model.onnx", "model.onnx"),
133            ("tokenizer.json", "tokenizer.json"),
134            ("config.json", "config.json"),
135        ];
136
137        let client = reqwest::Client::new();
138        let mut total_size = 0u64;
139        let mut dimension = None;
140
141        for (remote_path, local_filename) in files_to_download {
142            let url = format!(
143                "https://huggingface.co/{}/resolve/main/{}",
144                model_name, remote_path
145            );
146            let local_path = model_dir.join(local_filename);
147
148            log::info!("Downloading: {} -> {:?}", url, local_path);
149
150            let response = client.get(&url).send().await?;
151
152            if !response.status().is_success() {
153                return Err(anyhow!(
154                    "Failed to download {}: HTTP {}",
155                    url,
156                    response.status()
157                ));
158            }
159
160            let bytes = response.bytes().await?;
161            std::fs::write(&local_path, &bytes)?;
162            total_size += bytes.len() as u64;
163
164            log::info!("Downloaded {} ({} bytes)", local_filename, bytes.len());
165
166            // Try to extract dimension from config.json
167            if local_filename == "config.json" {
168                if let Ok(config_str) = std::fs::read_to_string(&local_path) {
169                    if let Ok(config) = serde_json::from_str::<serde_json::Value>(&config_str) {
170                        if let Some(hidden_size) =
171                            config.get("hidden_size").and_then(|v| v.as_u64())
172                        {
173                            dimension = Some(hidden_size as usize);
174                        }
175                    }
176                }
177            }
178        }
179
180        // If we couldn't get dimension from config, try to detect from ONNX model
181        if dimension.is_none() {
182            dimension =
183                Some(Self::detect_dimension_from_onnx(&model_dir.join("model.onnx")).await?);
184        }
185
186        let dimension = dimension
187            .ok_or_else(|| anyhow!("Could not detect dimension from model config or ONNX file"))?;
188
189        // Create metadata
190        let metadata = ModelMetadata {
191            model_name: model_name.to_string(),
192            provider_type: "onnx".to_string(),
193            dimension,
194            size_mb: total_size as f64 / 1_048_576.0, // Convert to MB
195            model_path: Some(model_dir.clone()),
196            api_endpoint: None,
197            installed_date: chrono::Utc::now(),
198            last_used: None,
199            checksum: Some(Self::calculate_model_checksum(&model_dir)?),
200            description: Some(format!("ONNX model: {}", model_name)),
201            max_input_length: Some(512), // Common default
202        };
203
204        metadata_manager.add_model(metadata)?;
205        log::info!(
206            "Successfully installed model: {} ({}D, {:.1}MB)",
207            model_name,
208            dimension,
209            total_size as f64 / 1_048_576.0
210        );
211
212        Ok(())
213    }
214
    /// Detect embedding dimension from ONNX model using introspection.
    ///
    /// NOTE(review): despite the name, no real shape introspection happens yet.
    /// If the model exposes at least one output, this returns a hard-coded 384
    /// (the all-MiniLM-L6-v2 dimension). Models with a different hidden size
    /// will be mis-registered unless `config.json` supplied `hidden_size`
    /// upstream — confirm before relying on this fallback for other models.
    async fn detect_dimension_from_onnx(_onnx_path: &Path) -> Result<usize> {
        #[cfg(feature = "onnx-embeddings")]
        {
            log::info!("Detecting dimension from ONNX model: {:?}", _onnx_path);

            // Create a temporary session to inspect the model; Level1 keeps
            // load cheap since we never run inference here.
            let session = Session::builder()?
                .with_optimization_level(GraphOptimizationLevel::Level1)? // Use basic optimization for introspection
                .commit_from_file(_onnx_path)?;

            // Get model output metadata.
            let outputs = session.outputs();
            if let Some(first_output) = outputs.first() {
                // Log name/dtype of the first output for diagnostics.
                log::info!(
                    "Output: {} - Type: {:?}",
                    first_output.name(),
                    first_output.dtype()
                );

                // For now, use a common dimension for sentence transformers as fallback.
                // Real introspection would require more complex type analysis.
                let dimension = 384; // Common dimension for all-MiniLM-L6-v2
                log::info!("Using default embedding dimension: {}", dimension);
                return Ok(dimension);
            }

            // If we can't determine from outputs, log input info as a debugging aid.
            let inputs = session.inputs();
            log::warn!(
                "Could not determine dimension from outputs, input info: {:?}",
                inputs
                    .iter()
                    .map(|i| (i.name(), i.dtype()))
                    .collect::<Vec<_>>()
            );

            Err(anyhow!(
                "Could not detect embedding dimension from ONNX model structure"
            ))
        }

        #[cfg(not(feature = "onnx-embeddings"))]
        {
            // Without the feature there is no ort available to open the file.
            log::error!("ONNX introspection requires onnx-embeddings feature");
            Err(anyhow!("ONNX embeddings feature not enabled"))
        }
    }
264
265    /// Calculate SHA256 checksum for model files to ensure integrity
266    fn calculate_model_checksum(model_dir: &Path) -> Result<String> {
267        use sha2::{Digest, Sha256};
268        use std::fs::File;
269        use std::io::Read;
270
271        let mut hasher = Sha256::new();
272
273        // Hash main model files in deterministic order
274        let files_to_hash = ["model.onnx", "tokenizer.json", "config.json"];
275
276        for filename in files_to_hash.iter() {
277            let file_path = model_dir.join(filename);
278            if file_path.exists() {
279                let mut file = File::open(&file_path)?;
280                let mut buffer = Vec::new();
281                file.read_to_end(&mut buffer)?;
282
283                // Include filename in hash to ensure different files produce different hashes
284                hasher.update(filename.as_bytes());
285                hasher.update(&buffer);
286
287                log::debug!("Hashed {} ({} bytes)", filename, buffer.len());
288            } else {
289                log::warn!("Model file not found for checksum: {:?}", file_path);
290            }
291        }
292
293        let result = hasher.finalize();
294        let checksum = format!("{:x}", result);
295        log::info!("Calculated model checksum: {}", &checksum[..16]);
296
297        Ok(checksum)
298    }
299
300    /// Verify model integrity using stored checksum
301    #[allow(dead_code)] // Utility function for future use
302    pub fn verify_model_integrity(model_dir: &Path, expected_checksum: &str) -> Result<bool> {
303        let actual_checksum = Self::calculate_model_checksum(model_dir)?;
304        let is_valid = actual_checksum == expected_checksum;
305
306        if is_valid {
307            log::info!("Model integrity verified successfully");
308        } else {
309            log::error!(
310                "Model integrity check failed: expected {}, got {}",
311                &expected_checksum[..16],
312                &actual_checksum[..16]
313            );
314        }
315
316        Ok(is_valid)
317    }
318
319    /// List available models that can be downloaded
320    pub fn list_available_models() -> Vec<&'static str> {
321        vec![
322            "sentence-transformers/all-MiniLM-L6-v2",
323            "sentence-transformers/all-mpnet-base-v2",
324            "sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
325            "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
326            "BAAI/bge-small-en-v1.5",
327            "BAAI/bge-base-en-v1.5",
328            "BAAI/bge-large-en-v1.5",
329        ]
330    }
331}
332
333#[async_trait]
334impl ProviderTrait for OnnxProvider {
335    async fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
336        if text.trim().is_empty() {
337            return Err(anyhow!("Cannot embed empty text"));
338        }
339
340        self.embed_text_impl(text).await
341    }
342
343    async fn get_dimension(&self) -> Result<usize> {
344        Ok(self.dimension)
345    }
346
347    async fn health_check(&self) -> Result<()> {
348        // Try a simple inference
349        self.embed_text("test").await.map(|_| ())
350    }
351
352    fn get_info(&self) -> ProviderInfo {
353        ProviderInfo {
354            name: "ONNX Local Model".to_string(),
355            provider_type: "onnx".to_string(),
356            model_name: Some(self.model_name.clone()),
357            description: format!("Local ONNX model: {}", self.model_name),
358            max_input_length: Some(self.max_length),
359        }
360    }
361
362    fn as_any(&self) -> &dyn std::any::Any {
363        self
364    }
365}
366
367impl OnnxProvider {
368    /// Batch embedding generation for improved performance
369    /// Processes multiple texts in a single ONNX inference pass
370    pub async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
371        if texts.is_empty() {
372            return Ok(vec![]);
373        }
374
375        // Filter out empty texts
376        let valid_texts: Vec<&str> = texts
377            .iter()
378            .filter(|t| !t.trim().is_empty())
379            .copied()
380            .collect();
381
382        if valid_texts.is_empty() {
383            return Err(anyhow!("Cannot embed batch with all empty texts"));
384        }
385
386        // For now, process sequentially (future: true batching in ONNX)
387        // This still provides benefits through shared session access pattern
388        let mut embeddings = Vec::with_capacity(valid_texts.len());
389
390        for text in valid_texts {
391            match self.embed_text_impl(text).await {
392                Ok(embedding) => embeddings.push(embedding),
393                Err(e) => {
394                    log::warn!("Failed to embed text in batch: {}", e);
395                    // Continue with other texts instead of failing entire batch
396                    continue;
397                }
398            }
399        }
400
401        if embeddings.is_empty() {
402            return Err(anyhow!("Batch embedding failed for all texts"));
403        }
404
405        Ok(embeddings)
406    }
407}
408
409impl OnnxProvider {
410    #[cfg(feature = "onnx-embeddings")]
411    async fn embed_text_impl(&self, text: &str) -> Result<Vec<f32>> {
412        // Tokenize the input text
413        let encoding = self
414            .tokenizer
415            .encode(text, true)
416            .map_err(|e| anyhow!("Tokenization failed: {}", e))?;
417
418        let mut input_ids = encoding.get_ids().to_vec();
419        let mut attention_mask = encoding.get_attention_mask().to_vec();
420
421        // Truncate or pad to max_length
422        if input_ids.len() > self.max_length {
423            input_ids.truncate(self.max_length);
424            attention_mask.truncate(self.max_length);
425        } else {
426            while input_ids.len() < self.max_length {
427                input_ids.push(0); // PAD token
428                attention_mask.push(0);
429            }
430        }
431
432        // Convert to i64 for ONNX Runtime
433        let input_ids: Vec<i64> = input_ids.iter().map(|&x| x as i64).collect();
434        let attention_mask: Vec<i64> = attention_mask.iter().map(|&x| x as i64).collect();
435
436        // Create input tensors with proper API for ort 2.0
437        let input_ids_tensor = Value::from_array(([1, self.max_length], input_ids))?;
438        let attention_mask_tensor =
439            Value::from_array(([1, self.max_length], attention_mask.clone()))?;
440
441        // Check what inputs the model expects and prepare accordingly
442        let mut inputs = vec![
443            ("input_ids", input_ids_tensor),
444            ("attention_mask", attention_mask_tensor),
445        ];
446
447        // Add token_type_ids only if the model requires it
448        {
449            let session = self.session.read().await;
450            let input_names: Vec<&str> =
451                session.inputs().iter().map(|input| input.name()).collect();
452
453            if input_names.contains(&"token_type_ids") {
454                let token_type_ids: Vec<i64> = vec![0i64; self.max_length];
455                let token_type_ids_tensor =
456                    Value::from_array(([1, self.max_length], token_type_ids))?;
457                inputs.push(("token_type_ids", token_type_ids_tensor));
458            }
459        }
460
461        // Run inference and extract data immediately
462        let (shape, data) = {
463            let mut session = self.session.write().await;
464            let outputs = session.run(inputs)?;
465
466            // Extract tensor data immediately and copy it
467            let (shape, data_slice) = outputs[0].try_extract_tensor::<f32>()?;
468            let data: Vec<f32> = data_slice.to_vec(); // Copy data to owned Vec
469            (shape.clone(), data)
470        };
471
472        log::debug!("ONNX output shape: {:?}", shape);
473
474        // Perform mean pooling with attention mask
475        let seq_len = shape[1] as usize;
476        let hidden_size = shape[2] as usize;
477
478        if hidden_size != self.dimension {
479            return Err(anyhow!(
480                "Model output dimension {} doesn't match expected {}",
481                hidden_size,
482                self.dimension
483            ));
484        }
485
486        let mut pooled = vec![0.0f32; hidden_size];
487        let mut mask_sum = 0usize;
488
489        // Mean pooling over sequence length, respecting attention mask
490        for (i, &mask_val) in attention_mask.iter().enumerate().take(seq_len) {
491            if mask_val == 1 {
492                mask_sum += 1;
493                for (j, pooled_val) in pooled.iter_mut().enumerate().take(hidden_size) {
494                    let idx = i * hidden_size + j;
495                    *pooled_val += data[idx];
496                }
497            }
498        }
499
500        // Average by the number of non-padded tokens
501        if mask_sum > 0 {
502            for val in &mut pooled {
503                *val /= mask_sum as f32;
504            }
505        }
506
507        // L2 normalization
508        let norm = pooled.iter().map(|x| x * x).sum::<f32>().sqrt();
509        if norm > 0.0 {
510            for val in &mut pooled {
511                *val /= norm;
512            }
513        }
514
515        log::debug!("Generated embedding with {} dimensions", pooled.len());
516        Ok(pooled)
517    }
518
519    #[cfg(not(feature = "onnx-embeddings"))]
520    async fn embed_text_impl(&self, _text: &str) -> Result<Vec<f32>> {
521        Err(anyhow!(
522            "ONNX embeddings feature not enabled. Compile with --features onnx-embeddings"
523        ))
524    }
525}