// Source: oxigdal_ml/models/onnx.rs
1//! ONNX Runtime integration for OxiGDAL
2//!
3//! This module provides integration with ONNX Runtime for running ML models
4//! on geospatial data.
5
6use std::path::Path;
7
8use ndarray::{Array, ArrayD, ArrayView, IxDyn};
9use ort::session::Session;
10use ort::session::builder::GraphOptimizationLevel;
11use ort::value::TensorRef;
12use oxigdal_core::buffer::RasterBuffer;
13use oxigdal_core::types::RasterDataType;
14use serde::{Deserialize, Serialize};
15use tracing::{debug, info};
16
17use crate::error::{InferenceError, ModelError, Result};
18use crate::models::Model;
19
/// ONNX model with ONNX Runtime backend
///
/// Wraps an `ort` [`Session`] together with the tensor metadata extracted
/// from the loaded graph and the configuration the session was built with.
/// Running inference mutates the session, so [`OnnxModel::infer`] takes
/// `&mut self`.
pub struct OnnxModel {
    // The underlying ONNX Runtime session; created in `from_file_with_config`.
    session: Session,
    // Tensor names/shapes extracted from the graph by `extract_metadata`.
    metadata: ModelMetadata,
    // Configuration used when building the session (kept for reference).
    config: SessionConfig,
}
26
/// Model metadata
///
/// Describes the tensors of a loaded model. Shapes are stored without the
/// batch dimension; dynamic (negative) dimensions reported by the model are
/// replaced with fixed fallbacks by `OnnxModel::extract_metadata`
/// (3 or 1 channels, 256x256 spatial).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetadata {
    /// Model name
    pub name: String,
    /// Model version
    pub version: String,
    /// Model description
    pub description: String,
    /// Input tensor names
    pub input_names: Vec<String>,
    /// Output tensor names
    pub output_names: Vec<String>,
    /// Input shape (channels, height, width)
    pub input_shape: (usize, usize, usize),
    /// Output shape (channels, height, width)
    pub output_shape: (usize, usize, usize),
    /// Class labels (if classification model)
    pub class_labels: Option<Vec<String>>,
}
47
/// Session configuration for ONNX Runtime
///
/// Passed to [`OnnxModel::from_file_with_config`] to control how the `ort`
/// session is built. See [`SessionConfig::default`] for the defaults.
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// Execution provider
    pub execution_provider: ExecutionProvider,
    /// Number of threads for CPU inference (applied as intra-op threads)
    pub num_threads: usize,
    /// Enable graph optimization (Level3 when true)
    pub graph_optimization: bool,
    /// Batch size
    // NOTE(review): this field is stored but never read anywhere in this
    // module — `infer_batch` processes inputs one at a time. Confirm whether
    // it is consumed elsewhere or is dead configuration.
    pub batch_size: usize,
}
60
/// Execution provider for ONNX Runtime
///
/// Hardware backends are gated behind cargo features so that the default
/// build has no GPU/OS-specific dependencies.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ExecutionProvider {
    /// CPU execution
    Cpu,
    /// CUDA GPU execution (requires 'gpu' feature)
    #[cfg(feature = "gpu")]
    Cuda,
    /// TensorRT execution (requires 'gpu' feature)
    // NOTE(review): no branch in `from_file_with_config` registers a TensorRT
    // provider, so selecting this variant currently behaves like CPU —
    // confirm intended.
    #[cfg(feature = "gpu")]
    TensorRt,
    /// DirectML execution (requires 'directml' feature, Windows only)
    #[cfg(feature = "directml")]
    DirectMl,
    /// CoreML execution (requires 'coreml' feature, macOS/iOS only)
    #[cfg(feature = "coreml")]
    CoreMl,
}
79
80impl Default for SessionConfig {
81    fn default() -> Self {
82        Self {
83            execution_provider: ExecutionProvider::Cpu,
84            num_threads: num_cpus(),
85            graph_optimization: true,
86            batch_size: 1,
87        }
88    }
89}
90
impl OnnxModel {
    /// Loads an ONNX model from a file using the default [`SessionConfig`]
    /// (CPU provider, all cores, Level3 optimization).
    ///
    /// # Errors
    /// Returns an error if the model cannot be loaded
    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
        Self::from_file_with_config(path, SessionConfig::default())
    }

    /// Loads an ONNX model from a file with custom configuration
    ///
    /// Builds an `ort` session (intra-op threads, optimization level,
    /// execution provider), commits the model file, then extracts tensor
    /// metadata from the loaded graph.
    ///
    /// # Errors
    /// Returns an error if the model cannot be loaded
    pub fn from_file_with_config<P: AsRef<Path>>(path: P, config: SessionConfig) -> Result<Self> {
        let path = path.as_ref();
        info!("Loading ONNX model from: {}", path.display());

        // Fail fast with a dedicated NotFound error instead of letting ort
        // surface a generic load failure.
        if !path.exists() {
            return Err(ModelError::NotFound {
                path: path.display().to_string(),
            }
            .into());
        }

        // Create SessionBuilder with configuration
        let mut builder = Session::builder().map_err(|e| ModelError::LoadFailed {
            reason: format!("Failed to create session builder: {}", e),
        })?;

        // Configure number of threads (intra-op parallelism within operators)
        builder = builder
            .with_intra_threads(config.num_threads)
            .map_err(|e| ModelError::LoadFailed {
                reason: format!("Failed to set intra threads: {}", e),
            })?;

        // Configure graph optimization (Level3 enables all optimizations)
        if config.graph_optimization {
            builder = builder
                .with_optimization_level(GraphOptimizationLevel::Level3)
                .map_err(|e| ModelError::LoadFailed {
                    reason: format!("Failed to set optimization level: {}", e),
                })?;
        }

        // Configure execution provider. Each provider is compiled in only
        // when its cargo feature is enabled; when no branch matches, ort uses
        // its default (CPU) providers.
        // NOTE(review): `ExecutionProvider::TensorRt` has no branch below, so
        // selecting it falls through to the default providers — confirm
        // whether TensorRT registration was intended here.
        #[cfg(feature = "gpu")]
        {
            use ort::execution_providers::CUDAExecutionProvider;
            if matches!(config.execution_provider, ExecutionProvider::Cuda) {
                builder = builder
                    .with_execution_providers([CUDAExecutionProvider::default().build()])
                    .map_err(|e| ModelError::LoadFailed {
                        reason: format!("Failed to set CUDA execution provider: {}", e),
                    })?;
            }
        }

        #[cfg(feature = "directml")]
        {
            use ort::execution_providers::DirectMLExecutionProvider;
            if matches!(config.execution_provider, ExecutionProvider::DirectMl) {
                builder = builder
                    .with_execution_providers([DirectMLExecutionProvider::default().build()])
                    .map_err(|e| ModelError::LoadFailed {
                        reason: format!("Failed to set DirectML execution provider: {}", e),
                    })?;
            }
        }

        #[cfg(feature = "coreml")]
        {
            use ort::execution_providers::CoreMLExecutionProvider;
            if matches!(config.execution_provider, ExecutionProvider::CoreMl) {
                builder = builder
                    .with_execution_providers([CoreMLExecutionProvider::default().build()])
                    .map_err(|e| ModelError::LoadFailed {
                        reason: format!("Failed to set CoreML execution provider: {}", e),
                    })?;
            }
        }

        // Load the model
        let session = builder
            .commit_from_file(path)
            .map_err(|e| ModelError::LoadFailed {
                reason: format!("Failed to load ONNX model: {}", e),
            })?;

        info!("ONNX model loaded successfully");

        // Extract metadata from the loaded session
        let metadata = Self::extract_metadata(&session)?;

        // NOTE(review): `config.batch_size` is stored here but never read in
        // this module; `infer_batch` runs inputs one at a time.
        Ok(Self {
            session,
            metadata,
            config,
        })
    }

    /// Extracts metadata from an ONNX session
    ///
    /// Reads the first input/output tensor shapes, assuming NCHW (or CHW)
    /// layout. Dynamic dimensions (reported as negative) are replaced with
    /// fixed fallbacks: 3 input / 1 output channels, 256x256 spatial.
    /// Name, version and description are placeholders, not read from the
    /// model file.
    fn extract_metadata(session: &Session) -> Result<ModelMetadata> {
        // Get input metadata using accessor methods
        let inputs = session.inputs();
        let outputs = session.outputs();

        debug!(
            "Extracting metadata: {} inputs, {} outputs",
            inputs.len(),
            outputs.len()
        );

        // Extract input names and shape
        let input_names: Vec<String> = inputs.iter().map(|i| i.name().to_string()).collect();

        // Get first input shape (assuming batch, channels, height, width)
        let input_shape = if let Some(first_input) = inputs.first() {
            if let Some(shape) = first_input.dtype().tensor_shape() {
                // Assume NCHW format: [batch, channels, height, width]
                // Extract C, H, W (skip batch dimension)
                // shape derefs to &[i64]; negative dims mean "dynamic"
                if shape.len() >= 4 {
                    let c = if shape[1] < 0 { 3 } else { shape[1] as usize };
                    let h = if shape[2] < 0 { 256 } else { shape[2] as usize };
                    let w = if shape[3] < 0 { 256 } else { shape[3] as usize };
                    (c, h, w)
                } else if shape.len() == 3 {
                    // CHW without a batch dimension
                    let c = if shape[0] < 0 { 3 } else { shape[0] as usize };
                    let h = if shape[1] < 0 { 256 } else { shape[1] as usize };
                    let w = if shape[2] < 0 { 256 } else { shape[2] as usize };
                    (c, h, w)
                } else {
                    (3, 256, 256) // Default fallback
                }
            } else {
                (3, 256, 256) // Default fallback (non-tensor input type)
            }
        } else {
            return Err(ModelError::LoadFailed {
                reason: "No input tensors found in model".to_string(),
            }
            .into());
        };

        // Extract output names and shape
        let output_names: Vec<String> = outputs.iter().map(|o| o.name().to_string()).collect();

        let output_shape = if let Some(first_output) = outputs.first() {
            if let Some(shape) = first_output.dtype().tensor_shape() {
                // Assume NCHW format: [batch, channels, height, width]
                if shape.len() >= 4 {
                    let c = if shape[1] < 0 { 1 } else { shape[1] as usize };
                    let h = if shape[2] < 0 { 256 } else { shape[2] as usize };
                    let w = if shape[3] < 0 { 256 } else { shape[3] as usize };
                    (c, h, w)
                } else if shape.len() == 3 {
                    let c = if shape[0] < 0 { 1 } else { shape[0] as usize };
                    let h = if shape[1] < 0 { 256 } else { shape[1] as usize };
                    let w = if shape[2] < 0 { 256 } else { shape[2] as usize };
                    (c, h, w)
                } else {
                    (1, 256, 256) // Default fallback
                }
            } else {
                (1, 256, 256) // Default fallback (non-tensor output type)
            }
        } else {
            return Err(ModelError::LoadFailed {
                reason: "No output tensors found in model".to_string(),
            }
            .into());
        };

        Ok(ModelMetadata {
            name: "onnx_model".to_string(),
            version: "1.0.0".to_string(),
            description: "ONNX Runtime model".to_string(),
            input_names,
            output_names,
            input_shape,
            output_shape,
            class_labels: None,
        })
    }

    /// Runs inference on a raster buffer
    ///
    /// Converts the buffer to an f32 NCHW ndarray (normalizing integer
    /// types), feeds it to the first input tensor, and converts the first
    /// output tensor back to a `Float32` [`RasterBuffer`]. The buffer's
    /// height/width must match `metadata.input_shape`.
    ///
    /// # Errors
    /// Returns an error if inference fails
    pub fn infer(&mut self, input: &RasterBuffer) -> Result<RasterBuffer> {
        debug!(
            "Running inference on {}x{} buffer",
            input.width(),
            input.height()
        );

        // Convert RasterBuffer to ndarray
        let input_array = self.buffer_to_ndarray(input)?;

        // Get input name (first tensor name discovered at load time)
        let input_name =
            self.metadata
                .input_names
                .first()
                .ok_or_else(|| InferenceError::Failed {
                    reason: "No input tensor name available".to_string(),
                })?;

        // Create TensorRef from ndarray view (borrows input_array; no copy)
        let input_tensor =
            TensorRef::from_array_view(input_array.view()).map_err(|e| InferenceError::Failed {
                reason: format!("Failed to create input tensor: {}", e),
            })?;

        // Run inference using ort 2.0 API with inputs! macro
        let outputs = self
            .session
            .run(ort::inputs![input_name.as_str() => input_tensor])
            .map_err(|e| InferenceError::Failed {
                reason: format!("ONNX inference failed: {}", e),
            })?;

        // Get output name
        let output_name =
            self.metadata
                .output_names
                .first()
                .ok_or_else(|| InferenceError::Failed {
                    reason: "No output tensor name available".to_string(),
                })?;

        // Extract output tensor
        let output_tensor = outputs.get(output_name.as_str()).ok_or_else(|| {
            InferenceError::OutputParsingFailed {
                reason: format!("Output tensor '{}' not found", output_name),
            }
        })?;

        // Extract array from tensor (ort 2.0 API)
        // try_extract_array directly returns ArrayViewD
        let output_array = output_tensor.try_extract_array::<f32>().map_err(|e| {
            InferenceError::OutputParsingFailed {
                reason: format!("Failed to extract output tensor: {}", e),
            }
        })?;

        // Convert to owned array to avoid borrow checker issues: the view
        // borrows `outputs`, which in turn borrows `self.session`.
        let output_owned = output_array.to_owned();

        // Drop outputs to release the borrow of self.session
        drop(outputs);

        // Convert back to RasterBuffer
        let output_view = output_owned.view().into_dyn();
        self.ndarray_to_buffer(&output_view)
    }

    /// Runs batch inference
    ///
    /// Currently runs each input sequentially through [`OnnxModel::infer`]
    /// and stops at the first failure (already-produced results are
    /// discarded). `SessionConfig::batch_size` is not consulted here.
    ///
    /// # Errors
    /// Returns an error if inference fails
    pub fn infer_batch(&mut self, inputs: &[RasterBuffer]) -> Result<Vec<RasterBuffer>> {
        if inputs.is_empty() {
            return Ok(Vec::new());
        }

        debug!("Running batch inference on {} inputs", inputs.len());

        // Process each input individually (ONNX Runtime handles batching internally)
        let mut results = Vec::with_capacity(inputs.len());
        for input in inputs {
            let output = self.infer(input)?;
            results.push(output);
        }

        Ok(results)
    }

    /// Converts RasterBuffer to ndarray
    ///
    /// Produces a `[1, bands, height, width]` f32 array. `UInt8`/`UInt16`
    /// samples are normalized to [0, 1]; other numeric types are cast as-is.
    /// The band count is inferred from the buffer length, not taken from
    /// model metadata.
    fn buffer_to_ndarray(&self, buffer: &RasterBuffer) -> Result<ArrayD<f32>> {
        let width = buffer.width() as usize;
        let height = buffer.height() as usize;

        // Get expected input shape from metadata
        let (channels, expected_height, expected_width) = self.metadata.input_shape;

        // Validate dimensions
        // NOTE(review): only height/width are checked; the buffer's actual
        // band count is never compared against `channels` — confirm whether a
        // channel-count mismatch should also be rejected here.
        if width != expected_width || height != expected_height {
            return Err(InferenceError::InvalidInputShape {
                expected: vec![channels, expected_height, expected_width],
                actual: vec![channels, height, width],
            }
            .into());
        }

        // Convert buffer data to f32
        let data = match buffer.data_type() {
            RasterDataType::Float32 => {
                // Already f32 — copy straight through without rescaling.
                let slice = buffer
                    .as_slice::<f32>()
                    .map_err(crate::error::MlError::OxiGdal)?;
                slice.to_vec()
            }
            RasterDataType::UInt8 => {
                // Normalize 0..=255 to 0.0..=1.0
                let slice = buffer
                    .as_slice::<u8>()
                    .map_err(crate::error::MlError::OxiGdal)?;
                slice.iter().map(|&v| f32::from(v) / 255.0).collect()
            }
            RasterDataType::Int16 => {
                // Cast without normalization (signed data keeps its range)
                let slice = buffer
                    .as_slice::<i16>()
                    .map_err(crate::error::MlError::OxiGdal)?;
                slice.iter().map(|&v| v as f32).collect()
            }
            RasterDataType::UInt16 => {
                // Normalize 0..=65535 to 0.0..=1.0
                let slice = buffer
                    .as_slice::<u16>()
                    .map_err(crate::error::MlError::OxiGdal)?;
                slice.iter().map(|&v| f32::from(v) / 65535.0).collect()
            }
            RasterDataType::Float64 => {
                // Narrowing cast f64 -> f32 (may lose precision)
                let slice = buffer
                    .as_slice::<f64>()
                    .map_err(crate::error::MlError::OxiGdal)?;
                slice.iter().map(|&v| v as f32).collect()
            }
            _ => {
                return Err(InferenceError::Failed {
                    reason: format!("Unsupported data type: {:?}", buffer.data_type()),
                }
                .into());
            }
        };

        // Calculate expected total size
        // NOTE(review): integer division silently truncates if data.len() is
        // not an exact multiple of height*width, and divides by zero for a
        // 0x0 buffer — confirm RasterBuffer guarantees rule these out.
        let total_pixels = height * width;
        let num_bands = data.len() / total_pixels;

        // Create array with shape [batch=1, channels, height, width]
        let shape = IxDyn(&[1, num_bands, height, width]);

        Array::from_shape_vec(shape, data).map_err(|e| {
            InferenceError::Failed {
                reason: format!("Failed to create ndarray from buffer: {}", e),
            }
            .into()
        })
    }

    /// Converts ndarray to RasterBuffer
    ///
    /// Accepts NCHW, CHW, or HW shapes and emits a little-endian `Float32`
    /// buffer with no nodata value.
    /// NOTE(review): for multi-channel outputs all channels are serialized
    /// into a buffer declared as height x width — confirm RasterBuffer::new
    /// accepts (or should reject) the extra bytes.
    fn ndarray_to_buffer(&self, array: &ArrayView<f32, IxDyn>) -> Result<RasterBuffer> {
        let shape = array.shape();
        debug!("Converting ndarray with shape {:?} to RasterBuffer", shape);

        // Expect shape [batch, channels, height, width] or [channels, height, width]
        let (height, width) = if shape.len() == 4 {
            // Shape: [batch, channels, height, width]
            (shape[2], shape[3])
        } else if shape.len() == 3 {
            // Shape: [channels, height, width]
            (shape[1], shape[2])
        } else if shape.len() == 2 {
            // Shape: [height, width]
            (shape[0], shape[1])
        } else {
            return Err(InferenceError::OutputParsingFailed {
                reason: format!("Unexpected output shape: {:?}", shape),
            }
            .into());
        };

        // Convert to contiguous vec (handles non-contiguous views too)
        let data: Vec<f32> = array.iter().copied().collect();

        // Convert to bytes (little-endian f32)
        let bytes: Vec<u8> = data.iter().flat_map(|&f: &f32| f.to_le_bytes()).collect();

        // Create RasterBuffer
        RasterBuffer::new(
            bytes,
            width as u64,
            height as u64,
            RasterDataType::Float32,
            oxigdal_core::types::NoDataValue::None,
        )
        .map_err(crate::error::MlError::OxiGdal)
    }
}
481
impl Model for OnnxModel {
    /// Returns the metadata extracted when the model was loaded.
    fn metadata(&self) -> &ModelMetadata {
        &self.metadata
    }

    /// Single-input inference; delegates to [`OnnxModel::infer`].
    fn predict(&mut self, input: &RasterBuffer) -> Result<RasterBuffer> {
        self.infer(input)
    }

    /// Multi-input inference; delegates to [`OnnxModel::infer_batch`].
    fn predict_batch(&mut self, inputs: &[RasterBuffer]) -> Result<Vec<RasterBuffer>> {
        self.infer_batch(inputs)
    }

    /// Expected input shape as (channels, height, width), batch excluded.
    fn input_shape(&self) -> (usize, usize, usize) {
        self.metadata.input_shape
    }

    /// Produced output shape as (channels, height, width), batch excluded.
    fn output_shape(&self) -> (usize, usize, usize) {
        self.metadata.output_shape
    }
}
503
/// Returns the number of logical CPUs available to this process,
/// falling back to 4 when the platform cannot report a count.
fn num_cpus() -> usize {
    match std::thread::available_parallelism() {
        Ok(count) => count.get(),
        Err(_) => 4,
    }
}
510
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config targets the CPU provider with graph optimization
    /// on and a batch size of one.
    #[test]
    fn test_session_config_default() {
        let defaults = SessionConfig::default();
        assert!(matches!(defaults.execution_provider, ExecutionProvider::Cpu));
        assert!(defaults.graph_optimization);
        assert_eq!(defaults.batch_size, 1);
    }

    /// Metadata must serialize to JSON without error.
    #[test]
    fn test_metadata_serialization() {
        let sample = ModelMetadata {
            name: "test_model".to_string(),
            version: "1.0.0".to_string(),
            description: "Test model".to_string(),
            input_names: vec!["input".to_string()],
            output_names: vec!["output".to_string()],
            input_shape: (3, 256, 256),
            output_shape: (1, 256, 256),
            class_labels: None,
        };

        let serialized = serde_json::to_string(&sample);
        assert!(serialized.is_ok());
    }

    /// CPU detection must yield a positive, plausible count.
    #[test]
    fn test_num_cpus() {
        let detected = num_cpus();
        assert!(detected > 0);
        assert!(detected <= 256); // Reasonable upper bound
    }
}