1#[cfg(feature = "onnx")]
2use crate::error::{SpatialError, SpatialResult};
3#[cfg(feature = "onnx")]
4use image::DynamicImage;
5#[cfg(feature = "onnx")]
6use ndarray::Array2;
7#[cfg(feature = "onnx")]
8use ort::session::{builder::GraphOptimizationLevel, Session};
9
/// Side length, in pixels, of the square image fed to the model: every input is resized to
/// `INPUT_SIZE` x `INPUT_SIZE` and packed as a `[1, 3, INPUT_SIZE, INPUT_SIZE]` tensor.
/// NOTE(review): presumably this must match the exported model's expected input resolution
/// (518 = 14 * 37, which suggests a ViT patch size of 14) — confirm against the model.
#[cfg(feature = "onnx")]
const INPUT_SIZE: u32 = 518;
/// Per-channel (R, G, B) mean subtracted from each pixel after scaling it to [0, 1].
#[cfg(feature = "onnx")]
const IMAGENET_MEAN: [f32; 3] = [0.485, 0.456, 0.406];
/// Per-channel (R, G, B) standard deviation that the mean-subtracted pixel is divided by.
#[cfg(feature = "onnx")]
const IMAGENET_STD: [f32; 3] = [0.229, 0.224, 0.225];
16
/// Estimates a depth map from a single image using an ONNX model run through ONNX Runtime.
///
/// Construct it with `OnnxDepthEstimator::new` and call `OnnxDepthEstimator::estimate`
/// once per image.
#[cfg(feature = "onnx")]
pub struct OnnxDepthEstimator {
    /// Inference session holding the loaded depth model.
    session: Session,
}
21
#[cfg(feature = "onnx")]
impl OnnxDepthEstimator {
    /// Loads the ONNX model at `model_path` into a new inference session.
    ///
    /// The session uses graph optimization level 3 and 4 intra-op threads.
    ///
    /// # Errors
    ///
    /// Returns [`SpatialError::ModelError`] if the session builder cannot be created or
    /// configured, or if the model file cannot be loaded.
    pub fn new(model_path: &str) -> SpatialResult<Self> {
        let session = Session::builder()
            .map_err(|e| SpatialError::ModelError(format!("Failed to create session: {}", e)))?
            .with_optimization_level(GraphOptimizationLevel::Level3)
            .map_err(|e| SpatialError::ModelError(format!("Failed to set opt level: {}", e)))?
            .with_intra_threads(4)
            .map_err(|e| SpatialError::ModelError(format!("Failed to set threads: {}", e)))?
            .commit_from_file(model_path)
            .map_err(|e| SpatialError::ModelError(format!("Failed to load ONNX model: {}", e)))?;

        Ok(Self { session })
    }

    /// Runs the model on `image` and returns a depth map at the image's original resolution.
    ///
    /// Pipeline:
    /// 1. The image is resized to `INPUT_SIZE` x `INPUT_SIZE` (Lanczos3).
    /// 2. It is normalized with the ImageNet per-channel mean/std.
    /// 3. It is fed as the model's first input, a `[1, 3, INPUT_SIZE, INPUT_SIZE]` f32 tensor.
    ///
    /// Output handling:
    /// - The model's first output is read as an f32 map whose last two dimensions are
    ///   (height, width). Leading dimensions must all be 1, so both `[1, H, W]` and
    ///   `[1, 1, H, W]` are accepted.
    /// - The map is min-max normalized to [0, 1]. A flat map (range <= 1e-6) becomes a
    ///   constant 0.5.
    /// - It is then resampled back to the original size with Lanczos3.
    ///
    /// Returns an array of shape `(image.height(), image.width())`.
    ///
    /// # Errors
    ///
    /// - [`SpatialError::TensorError`] in any of these cases:
    ///   - the image has zero width or height;
    ///   - the input tensor cannot be built;
    ///   - the output cannot be extracted as f32;
    ///   - the output shape is not a single non-empty 2-D map;
    ///   - the final reshape fails.
    /// - [`SpatialError::ModelError`] if inference fails.
    pub fn estimate(&mut self, image: &DynamicImage) -> SpatialResult<Array2<f32>> {
        let (orig_width, orig_height) = (image.width(), image.height());
        // An empty image has nothing to estimate; reject it rather than handing zero
        // dimensions to the resamplers below.
        if orig_width == 0 || orig_height == 0 {
            return Err(SpatialError::TensorError(format!(
                "Input image must have non-zero dimensions, got {}x{}",
                orig_width, orig_height
            )));
        }

        let size = INPUT_SIZE as usize;
        let input_data = Self::preprocess(image);
        let input_value = ort::value::Value::from_array(([1usize, 3, size, size], input_data))
            .map_err(|e| SpatialError::TensorError(format!("Failed to create input: {}", e)))?;

        let outputs = self
            .session
            .run(ort::inputs![input_value])
            .map_err(|e| SpatialError::ModelError(format!("Inference failed: {}", e)))?;

        let (shape, data) = outputs[0]
            .try_extract_tensor::<f32>()
            .map_err(|e| SpatialError::TensorError(format!("Failed to extract output: {}", e)))?;

        // Checked conversion: a negative (dynamic/unresolved) dim must not wrap to a huge usize.
        let dims = shape
            .iter()
            .map(|&d| usize::try_from(d))
            .collect::<Result<Vec<usize>, _>>()
            .map_err(|_| {
                SpatialError::TensorError(
                    "Depth output shape contains a negative dimension".to_string(),
                )
            })?;

        let (out_height, out_width) = Self::output_hw(&dims, data.len())?;
        let normalized = Self::min_max_normalize(data);

        // `from_raw` uses the same row-major layout (index = y * width + x) as the raw output.
        let depth_image = image::ImageBuffer::<image::Luma<f32>, Vec<f32>>::from_raw(
            out_width, out_height, normalized,
        )
        .ok_or_else(|| {
            SpatialError::TensorError("Depth buffer does not match output shape".to_string())
        })?;

        let resized_depth = image::imageops::resize(
            &depth_image,
            orig_width,
            orig_height,
            image::imageops::FilterType::Lanczos3,
        );

        Array2::from_shape_vec(
            (orig_height as usize, orig_width as usize),
            resized_depth.into_raw(),
        )
        .map_err(|e| SpatialError::TensorError(format!("Failed to reshape depth: {}", e)))
    }

    /// Resizes `image` to `INPUT_SIZE` x `INPUT_SIZE` and packs it as a planar (R plane,
    /// then G, then B) f32 buffer normalized with the ImageNet mean/std.
    fn preprocess(image: &DynamicImage) -> Vec<f32> {
        let plane = (INPUT_SIZE as usize) * (INPUT_SIZE as usize);
        let rgb = image
            .resize_exact(INPUT_SIZE, INPUT_SIZE, image::imageops::FilterType::Lanczos3)
            .to_rgb8();

        let mut input = vec![0.0f32; 3 * plane];
        for (c, channel) in input.chunks_exact_mut(plane).enumerate() {
            let (mean, std) = (IMAGENET_MEAN[c], IMAGENET_STD[c]);
            for (dst, pixel) in channel.iter_mut().zip(rgb.pixels()) {
                *dst = (f32::from(pixel[c]) / 255.0 - mean) / std;
            }
        }
        input
    }

    /// Extracts the spatial (height, width) of the model output from its shape.
    ///
    /// Uses the last two dimensions and requires that they account for every element
    /// (`height * width == len`, both non-zero). This means any leading dimensions must
    /// multiply to 1.
    fn output_hw(dims: &[usize], len: usize) -> SpatialResult<(u32, u32)> {
        let (h, w) = match dims {
            [.., h, w] => (*h, *w),
            _ => {
                return Err(SpatialError::TensorError(format!(
                    "Depth output must have at least 2 dimensions, got shape {:?}",
                    dims
                )))
            }
        };

        if h == 0 || w == 0 || h.checked_mul(w) != Some(len) {
            return Err(SpatialError::TensorError(format!(
                "Depth output shape {:?} is not a single non-empty {}x{} map ({} elements)",
                dims, h, w, len
            )));
        }

        match (u32::try_from(h), u32::try_from(w)) {
            (Ok(h), Ok(w)) => Ok((h, w)),
            _ => Err(SpatialError::TensorError(format!(
                "Depth output {}x{} exceeds the supported image size",
                h, w
            ))),
        }
    }

    /// Linearly rescales `values` to [0, 1].
    ///
    /// Returns all 0.5 when the value range is <= 1e-6, so a flat map does not divide by ~0.
    fn min_max_normalize(values: &[f32]) -> Vec<f32> {
        let min = values.iter().copied().fold(f32::INFINITY, f32::min);
        let max = values.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let range = max - min;

        if range > 1e-6 {
            values.iter().map(|&v| (v - min) / range).collect()
        } else {
            vec![0.5; values.len()]
        }
    }
}