1use std::path::Path;
7
8use ndarray::{Array, ArrayD, ArrayView, IxDyn};
9use ort::session::Session;
10use ort::session::builder::GraphOptimizationLevel;
11use ort::value::TensorRef;
12use oxigdal_core::buffer::RasterBuffer;
13use oxigdal_core::types::RasterDataType;
14use serde::{Deserialize, Serialize};
15use tracing::{debug, info};
16
17use crate::error::{InferenceError, ModelError, Result};
18use crate::models::Model;
19
/// An ONNX model wrapped in a live ONNX Runtime inference session.
///
/// Created via [`OnnxModel::from_file`] / [`OnnxModel::from_file_with_config`].
/// Inference methods take `&mut self` because `Session::run` requires a
/// mutable session.
pub struct OnnxModel {
    // Live ONNX Runtime session used by `infer`.
    session: Session,
    // Names and shapes extracted from the model graph at load time.
    metadata: ModelMetadata,
    // Configuration the session was built with.
    // NOTE(review): `config.batch_size` is stored but `infer_batch` loops
    // one input at a time and never reads it — confirm whether true
    // batching is planned.
    config: SessionConfig,
}
26
/// Descriptive metadata for a loaded model, extracted from the ONNX graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetadata {
    /// Human-readable model name (defaults to `"onnx_model"` at extraction).
    pub name: String,
    /// Model version string (defaults to `"1.0.0"` at extraction).
    pub version: String,
    /// Free-form description of the model.
    pub description: String,
    /// Graph input tensor names, in model order.
    pub input_names: Vec<String>,
    /// Graph output tensor names, in model order.
    pub output_names: Vec<String>,
    /// Expected input shape as `(channels, height, width)`; dynamic
    /// (negative) dimensions are replaced with defaults `(3, 256, 256)`
    /// during metadata extraction.
    pub input_shape: (usize, usize, usize),
    /// Declared output shape as `(channels, height, width)`; dynamic
    /// dimensions default to `(1, 256, 256)`.
    pub output_shape: (usize, usize, usize),
    /// Optional class labels for classification outputs (not populated by
    /// `extract_metadata`; always `None` unless set by a caller).
    pub class_labels: Option<Vec<String>>,
}
47
/// Configuration used when building the ONNX Runtime session.
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// Which execution backend to register (CPU by default).
    pub execution_provider: ExecutionProvider,
    /// Intra-op thread count passed to `Session::builder().with_intra_threads`.
    pub num_threads: usize,
    /// When `true`, the session is built with `GraphOptimizationLevel::Level3`.
    pub graph_optimization: bool,
    /// Requested batch size.
    /// NOTE(review): currently not consulted anywhere visible in this file —
    /// `infer_batch` runs inputs sequentially one at a time.
    pub batch_size: usize,
}
60
/// Supported ONNX Runtime execution providers.
///
/// Accelerator variants exist only when the matching Cargo feature is
/// enabled, so exhaustive matches stay valid across feature combinations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ExecutionProvider {
    /// Default CPU provider; always available.
    Cpu,
    /// NVIDIA CUDA provider.
    #[cfg(feature = "gpu")]
    Cuda,
    /// NVIDIA TensorRT provider.
    /// NOTE(review): `from_file_with_config` registers no TensorRT branch —
    /// selecting this variant silently falls back to CPU. Confirm intended.
    #[cfg(feature = "gpu")]
    TensorRt,
    /// DirectML provider (Windows).
    #[cfg(feature = "directml")]
    DirectMl,
    /// CoreML provider (Apple platforms).
    #[cfg(feature = "coreml")]
    CoreMl,
}
79
80impl Default for SessionConfig {
81 fn default() -> Self {
82 Self {
83 execution_provider: ExecutionProvider::Cpu,
84 num_threads: num_cpus(),
85 graph_optimization: true,
86 batch_size: 1,
87 }
88 }
89}
90
impl OnnxModel {
    /// Loads an ONNX model from `path` using [`SessionConfig::default`].
    ///
    /// # Errors
    /// Returns [`ModelError::NotFound`] when the file does not exist and
    /// [`ModelError::LoadFailed`] for any session-construction failure.
    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
        Self::from_file_with_config(path, SessionConfig::default())
    }

    /// Loads an ONNX model from `path`, building the session according to
    /// `config` (thread count, graph optimization, execution provider).
    ///
    /// Accelerator providers are registered only when both the matching
    /// Cargo feature is compiled in *and* `config.execution_provider`
    /// selects them; otherwise ONNX Runtime falls back to CPU.
    ///
    /// # Errors
    /// [`ModelError::NotFound`] when `path` does not exist;
    /// [`ModelError::LoadFailed`] when any builder step or the final model
    /// load fails, or when metadata extraction finds no input/output tensors.
    pub fn from_file_with_config<P: AsRef<Path>>(path: P, config: SessionConfig) -> Result<Self> {
        let path = path.as_ref();
        info!("Loading ONNX model from: {}", path.display());

        // Explicit existence check so the caller gets NotFound rather than
        // an opaque runtime load error.
        if !path.exists() {
            return Err(ModelError::NotFound {
                path: path.display().to_string(),
            }
            .into());
        }

        let mut builder = Session::builder().map_err(|e| ModelError::LoadFailed {
            reason: format!("Failed to create session builder: {}", e),
        })?;

        // Intra-op parallelism (threads used within a single operator).
        builder = builder
            .with_intra_threads(config.num_threads)
            .map_err(|e| ModelError::LoadFailed {
                reason: format!("Failed to set intra threads: {}", e),
            })?;

        if config.graph_optimization {
            // Level3 = all graph optimizations enabled.
            builder = builder
                .with_optimization_level(GraphOptimizationLevel::Level3)
                .map_err(|e| ModelError::LoadFailed {
                    reason: format!("Failed to set optimization level: {}", e),
                })?;
        }

        // Provider registration is feature-gated so the enum variants and
        // the ort provider types only need to exist when compiled in.
        // NOTE(review): ExecutionProvider::TensorRt has no branch here, so
        // selecting it silently runs on CPU — confirm intended.
        #[cfg(feature = "gpu")]
        {
            use ort::execution_providers::CUDAExecutionProvider;
            if matches!(config.execution_provider, ExecutionProvider::Cuda) {
                builder = builder
                    .with_execution_providers([CUDAExecutionProvider::default().build()])
                    .map_err(|e| ModelError::LoadFailed {
                        reason: format!("Failed to set CUDA execution provider: {}", e),
                    })?;
            }
        }

        #[cfg(feature = "directml")]
        {
            use ort::execution_providers::DirectMLExecutionProvider;
            if matches!(config.execution_provider, ExecutionProvider::DirectMl) {
                builder = builder
                    .with_execution_providers([DirectMLExecutionProvider::default().build()])
                    .map_err(|e| ModelError::LoadFailed {
                        reason: format!("Failed to set DirectML execution provider: {}", e),
                    })?;
            }
        }

        #[cfg(feature = "coreml")]
        {
            use ort::execution_providers::CoreMLExecutionProvider;
            if matches!(config.execution_provider, ExecutionProvider::CoreMl) {
                builder = builder
                    .with_execution_providers([CoreMLExecutionProvider::default().build()])
                    .map_err(|e| ModelError::LoadFailed {
                        reason: format!("Failed to set CoreML execution provider: {}", e),
                    })?;
            }
        }

        let session = builder
            .commit_from_file(path)
            .map_err(|e| ModelError::LoadFailed {
                reason: format!("Failed to load ONNX model: {}", e),
            })?;

        info!("ONNX model loaded successfully");

        let metadata = Self::extract_metadata(&session)?;

        Ok(Self {
            session,
            metadata,
            config,
        })
    }

    /// Derives [`ModelMetadata`] from the session's declared graph inputs
    /// and outputs.
    ///
    /// Shapes are interpreted as NCHW (rank >= 4, batch dim ignored) or CHW
    /// (rank 3). Dynamic dimensions (negative values in the ONNX shape) are
    /// replaced with fallbacks: 3 input channels / 1 output channel and
    /// 256x256 spatial size. Name/version/description are fixed placeholder
    /// strings.
    ///
    /// NOTE(review): the input and output branches are near-duplicates of
    /// each other (only the channel fallback differs) — a shared helper
    /// would remove the duplication.
    ///
    /// # Errors
    /// [`ModelError::LoadFailed`] when the model declares no input or no
    /// output tensors.
    fn extract_metadata(session: &Session) -> Result<ModelMetadata> {
        let inputs = session.inputs();
        let outputs = session.outputs();

        debug!(
            "Extracting metadata: {} inputs, {} outputs",
            inputs.len(),
            outputs.len()
        );

        let input_names: Vec<String> = inputs.iter().map(|i| i.name().to_string()).collect();

        let input_shape = if let Some(first_input) = inputs.first() {
            if let Some(shape) = first_input.dtype().tensor_shape() {
                if shape.len() >= 4 {
                    // NCHW: skip the batch dimension at index 0.
                    let c = if shape[1] < 0 { 3 } else { shape[1] as usize };
                    let h = if shape[2] < 0 { 256 } else { shape[2] as usize };
                    let w = if shape[3] < 0 { 256 } else { shape[3] as usize };
                    (c, h, w)
                } else if shape.len() == 3 {
                    // CHW without batch dimension.
                    let c = if shape[0] < 0 { 3 } else { shape[0] as usize };
                    let h = if shape[1] < 0 { 256 } else { shape[1] as usize };
                    let w = if shape[2] < 0 { 256 } else { shape[2] as usize };
                    (c, h, w)
                } else {
                    // Unexpected rank: assume common RGB 256x256 default.
                    (3, 256, 256)
                }
            } else {
                // Non-tensor or shapeless input: same default.
                (3, 256, 256)
            }
        } else {
            return Err(ModelError::LoadFailed {
                reason: "No input tensors found in model".to_string(),
            }
            .into());
        };

        let output_names: Vec<String> = outputs.iter().map(|o| o.name().to_string()).collect();

        let output_shape = if let Some(first_output) = outputs.first() {
            if let Some(shape) = first_output.dtype().tensor_shape() {
                if shape.len() >= 4 {
                    // NCHW output: single-channel fallback (e.g. a mask).
                    let c = if shape[1] < 0 { 1 } else { shape[1] as usize };
                    let h = if shape[2] < 0 { 256 } else { shape[2] as usize };
                    let w = if shape[3] < 0 { 256 } else { shape[3] as usize };
                    (c, h, w)
                } else if shape.len() == 3 {
                    let c = if shape[0] < 0 { 1 } else { shape[0] as usize };
                    let h = if shape[1] < 0 { 256 } else { shape[1] as usize };
                    let w = if shape[2] < 0 { 256 } else { shape[2] as usize };
                    (c, h, w)
                } else {
                    (1, 256, 256)
                }
            } else {
                (1, 256, 256)
            }
        } else {
            return Err(ModelError::LoadFailed {
                reason: "No output tensors found in model".to_string(),
            }
            .into());
        };

        Ok(ModelMetadata {
            name: "onnx_model".to_string(),
            version: "1.0.0".to_string(),
            description: "ONNX Runtime model".to_string(),
            input_names,
            output_names,
            input_shape,
            output_shape,
            class_labels: None,
        })
    }

    /// Runs single-image inference: converts `input` to a `[1, C, H, W]`
    /// f32 tensor, feeds it through the first graph input, and converts the
    /// first graph output back into a Float32 [`RasterBuffer`].
    ///
    /// # Errors
    /// [`InferenceError::InvalidInputShape`] when the buffer's dimensions do
    /// not match `metadata.input_shape`; [`InferenceError::Failed`] for
    /// conversion/run failures; [`InferenceError::OutputParsingFailed`] when
    /// the expected output tensor is missing or not extractable as f32.
    pub fn infer(&mut self, input: &RasterBuffer) -> Result<RasterBuffer> {
        debug!(
            "Running inference on {}x{} buffer",
            input.width(),
            input.height()
        );

        let input_array = self.buffer_to_ndarray(input)?;

        // Only the first declared input/output is used; multi-input models
        // are not supported by this path.
        let input_name =
            self.metadata
                .input_names
                .first()
                .ok_or_else(|| InferenceError::Failed {
                    reason: "No input tensor name available".to_string(),
                })?;

        // Zero-copy view over the ndarray data for the duration of the run.
        let input_tensor =
            TensorRef::from_array_view(input_array.view()).map_err(|e| InferenceError::Failed {
                reason: format!("Failed to create input tensor: {}", e),
            })?;

        let outputs = self
            .session
            .run(ort::inputs![input_name.as_str() => input_tensor])
            .map_err(|e| InferenceError::Failed {
                reason: format!("ONNX inference failed: {}", e),
            })?;

        let output_name =
            self.metadata
                .output_names
                .first()
                .ok_or_else(|| InferenceError::Failed {
                    reason: "No output tensor name available".to_string(),
                })?;

        let output_tensor = outputs.get(output_name.as_str()).ok_or_else(|| {
            InferenceError::OutputParsingFailed {
                reason: format!("Output tensor '{}' not found", output_name),
            }
        })?;

        let output_array = output_tensor.try_extract_array::<f32>().map_err(|e| {
            InferenceError::OutputParsingFailed {
                reason: format!("Failed to extract output tensor: {}", e),
            }
        })?;

        // Copy out of the session-owned storage so `outputs` can be dropped
        // before the data is consumed.
        // NOTE(review): `ndarray_to_buffer` copies again via `iter().collect`
        // below, so this is one extra full copy of the output — could be
        // avoided by converting before dropping `outputs`.
        let output_owned = output_array.to_owned();

        drop(outputs);

        let output_view = output_owned.view().into_dyn();
        self.ndarray_to_buffer(&output_view)
    }

    /// Runs [`Self::infer`] sequentially over `inputs`, stopping at the
    /// first error. Returns an empty vec for empty input.
    ///
    /// NOTE(review): despite `SessionConfig::batch_size`, no tensor-level
    /// batching happens here — each buffer is a separate session run.
    pub fn infer_batch(&mut self, inputs: &[RasterBuffer]) -> Result<Vec<RasterBuffer>> {
        if inputs.is_empty() {
            return Ok(Vec::new());
        }

        debug!("Running batch inference on {} inputs", inputs.len());

        let mut results = Vec::with_capacity(inputs.len());
        for input in inputs {
            let output = self.infer(input)?;
            results.push(output);
        }

        Ok(results)
    }

    /// Converts a raster buffer into a `[1, bands, H, W]` f32 array,
    /// normalizing integer types (u8 -> /255, u16 -> /65535, i16/f64 cast).
    ///
    /// The band count is inferred as `data.len() / (H * W)`; a non-integral
    /// ratio leaves a remainder that makes the final `from_shape_vec` fail
    /// with a shape error rather than silently truncating.
    ///
    /// # Errors
    /// [`InferenceError::InvalidInputShape`] on spatial-size mismatch with
    /// the model's expected input; [`InferenceError::Failed`] for
    /// unsupported data types or shape/length mismatch.
    fn buffer_to_ndarray(&self, buffer: &RasterBuffer) -> Result<ArrayD<f32>> {
        let width = buffer.width() as usize;
        let height = buffer.height() as usize;

        let (channels, expected_height, expected_width) = self.metadata.input_shape;

        // Only spatial dims are validated here; the channel count is
        // inferred from the data length further down.
        if width != expected_width || height != expected_height {
            return Err(InferenceError::InvalidInputShape {
                expected: vec![channels, expected_height, expected_width],
                actual: vec![channels, height, width],
            }
            .into());
        }

        // Normalize every supported source type to f32.
        let data = match buffer.data_type() {
            RasterDataType::Float32 => {
                let slice = buffer
                    .as_slice::<f32>()
                    .map_err(crate::error::MlError::OxiGdal)?;
                slice.to_vec()
            }
            RasterDataType::UInt8 => {
                let slice = buffer
                    .as_slice::<u8>()
                    .map_err(crate::error::MlError::OxiGdal)?;
                // Scale to [0, 1].
                slice.iter().map(|&v| f32::from(v) / 255.0).collect()
            }
            RasterDataType::Int16 => {
                let slice = buffer
                    .as_slice::<i16>()
                    .map_err(crate::error::MlError::OxiGdal)?;
                // Plain cast, no normalization — values keep their raw range.
                slice.iter().map(|&v| v as f32).collect()
            }
            RasterDataType::UInt16 => {
                let slice = buffer
                    .as_slice::<u16>()
                    .map_err(crate::error::MlError::OxiGdal)?;
                // Scale to [0, 1].
                slice.iter().map(|&v| f32::from(v) / 65535.0).collect()
            }
            RasterDataType::Float64 => {
                let slice = buffer
                    .as_slice::<f64>()
                    .map_err(crate::error::MlError::OxiGdal)?;
                // Lossy narrowing cast f64 -> f32.
                slice.iter().map(|&v| v as f32).collect()
            }
            _ => {
                return Err(InferenceError::Failed {
                    reason: format!("Unsupported data type: {:?}", buffer.data_type()),
                }
                .into());
            }
        };

        // NOTE(review): integer division; if H*W were 0 this would panic,
        // though the spatial-size check above makes that unreachable for
        // the default 256x256 metadata.
        let total_pixels = height * width;
        let num_bands = data.len() / total_pixels;

        // Leading 1 is the batch dimension expected by NCHW models.
        let shape = IxDyn(&[1, num_bands, height, width]);

        Array::from_shape_vec(shape, data).map_err(|e| {
            InferenceError::Failed {
                reason: format!("Failed to create ndarray from buffer: {}", e),
            }
            .into()
        })
    }

    /// Converts a model output array (rank 2, 3, or 4, trailing dims H x W)
    /// into a Float32 [`RasterBuffer`] with no nodata value.
    ///
    /// All channels/bands are flattened into the buffer in iteration order;
    /// the spatial size is taken from the last two dimensions.
    ///
    /// # Errors
    /// [`InferenceError::OutputParsingFailed`] for ranks outside 2..=4, or
    /// a propagated oxigdal error from `RasterBuffer::new`.
    fn ndarray_to_buffer(&self, array: &ArrayView<f32, IxDyn>) -> Result<RasterBuffer> {
        let shape = array.shape();
        debug!("Converting ndarray with shape {:?} to RasterBuffer", shape);

        let (height, width) = if shape.len() == 4 {
            // NCHW
            (shape[2], shape[3])
        } else if shape.len() == 3 {
            // CHW
            (shape[1], shape[2])
        } else if shape.len() == 2 {
            // HW
            (shape[0], shape[1])
        } else {
            return Err(InferenceError::OutputParsingFailed {
                reason: format!("Unexpected output shape: {:?}", shape),
            }
            .into());
        };

        // `iter()` handles non-contiguous views correctly (logical order).
        let data: Vec<f32> = array.iter().copied().collect();

        // Serialize to little-endian bytes as RasterBuffer expects raw data.
        let bytes: Vec<u8> = data.iter().flat_map(|&f: &f32| f.to_le_bytes()).collect();

        RasterBuffer::new(
            bytes,
            width as u64,
            height as u64,
            RasterDataType::Float32,
            oxigdal_core::types::NoDataValue::None,
        )
        .map_err(crate::error::MlError::OxiGdal)
    }
}
481
impl Model for OnnxModel {
    /// Metadata extracted from the ONNX graph at load time.
    fn metadata(&self) -> &ModelMetadata {
        &self.metadata
    }

    /// Delegates to [`OnnxModel::infer`].
    fn predict(&mut self, input: &RasterBuffer) -> Result<RasterBuffer> {
        self.infer(input)
    }

    /// Delegates to [`OnnxModel::infer_batch`] (sequential, not batched).
    fn predict_batch(&mut self, inputs: &[RasterBuffer]) -> Result<Vec<RasterBuffer>> {
        self.infer_batch(inputs)
    }

    /// Expected input shape as `(channels, height, width)`.
    fn input_shape(&self) -> (usize, usize, usize) {
        self.metadata.input_shape
    }

    /// Declared output shape as `(channels, height, width)`.
    fn output_shape(&self) -> (usize, usize, usize) {
        self.metadata.output_shape
    }
}
503
/// Number of logical CPUs available to this process, falling back to 4
/// when the query fails (e.g. unsupported platform or restricted sandbox).
fn num_cpus() -> usize {
    match std::thread::available_parallelism() {
        Ok(count) => count.get(),
        Err(_) => 4,
    }
}
510
#[cfg(test)]
mod tests {
    use super::*;

    // Default config should target CPU with optimization on and batch 1.
    #[test]
    fn test_session_config_default() {
        let cfg = SessionConfig::default();
        assert_eq!(cfg.batch_size, 1);
        assert_eq!(cfg.execution_provider, ExecutionProvider::Cpu);
        assert!(cfg.graph_optimization);
    }

    // Metadata must round-trip through serde_json serialization.
    #[test]
    fn test_metadata_serialization() {
        let meta = ModelMetadata {
            name: String::from("test_model"),
            version: String::from("1.0.0"),
            description: String::from("Test model"),
            input_names: vec![String::from("input")],
            output_names: vec![String::from("output")],
            input_shape: (3, 256, 256),
            output_shape: (1, 256, 256),
            class_labels: None,
        };

        assert!(serde_json::to_string(&meta).is_ok());
    }

    // Sanity bounds on the detected CPU count.
    #[test]
    fn test_num_cpus() {
        let count = num_cpus();
        assert!(count > 0);
        assert!(count <= 256);
    }
}