burn_depth/
inference.rs

1use burn::prelude::*;
2
3use crate::model::depth_pro::{DepthPro, DepthProInference};
4
5/// Converts packed RGB bytes into a normalized tensor suitable for `DepthPro::infer`.
6///
7/// The input slice must contain `width * height * 3` bytes in row-major order.
8/// The output tensor is channel-first (`NCHW`) with values scaled to `[-1, 1]`.
9pub fn rgb_to_input_tensor<B: Backend>(
10    rgb: &[u8],
11    width: usize,
12    height: usize,
13    device: &B::Device,
14) -> Result<Tensor<B, 4>, String> {
15    let expected_len = width
16        .checked_mul(height)
17        .and_then(|pixels| pixels.checked_mul(3))
18        .ok_or_else(|| "image dimensions overflowed while preparing input".to_string())?;
19
20    if rgb.len() != expected_len {
21        return Err(format!(
22            "expected {expected_len} RGB bytes for {width}x{height}, got {}",
23            rgb.len()
24        ));
25    }
26
27    let hw = width * height;
28    let mut data = vec![0.0f32; 3 * hw];
29
30    for (idx, pixel) in rgb.chunks_exact(3).enumerate() {
31        for channel in 0..3 {
32            let value = pixel[channel] as f32 / 255.0;
33            data[channel * hw + idx] = value * 2.0 - 1.0;
34        }
35    }
36
37    Ok(
38        Tensor::<B, 1>::from_floats(data.as_slice(), device).reshape([
39            1,
40            3,
41            height as i32,
42            width as i32,
43        ]),
44    )
45}
46
47/// Runs the DepthPro model directly from packed RGB bytes.
48///
49/// This helper combines [`rgb_to_input_tensor`] and [`DepthPro::infer`], making it
50/// convenient to integrate inference in external applications without reimplementing
51/// the preprocessing pipeline.
52pub fn infer_from_rgb<B: Backend>(
53    model: &DepthPro<B>,
54    rgb: &[u8],
55    width: usize,
56    height: usize,
57    device: &B::Device,
58    focal_length_px: Option<Tensor<B, 1>>,
59) -> Result<DepthProInference<B>, String> {
60    let input = rgb_to_input_tensor::<B>(rgb, width, height, device)?;
61    Ok(model.infer(input, focal_length_px))
62}
63
#[cfg(test)]
mod tests {
    use super::*;

    type TestBackend = crate::InferenceBackend;

    /// Two stacked pixels (width 1, height 2) must land in a `[1, 3, 2, 1]`
    /// tensor with each byte remapped from `[0, 255]` to `[-1, 1]`,
    /// channel-first: R plane, G plane, then B plane.
    #[test]
    fn rgb_to_input_tensor_normalizes_channels() {
        let device = <TestBackend as Backend>::Device::default();
        let pixels: Vec<u8> = vec![0, 255, 128, 255, 0, 128];

        let tensor = rgb_to_input_tensor::<TestBackend>(&pixels, 1, 2, &device).unwrap();
        let data = tensor.into_data().convert::<f32>();
        assert_eq!(data.shape.as_slice(), &[1, 3, 2, 1]);

        let actual = data.to_vec::<f32>().unwrap();
        let expected = [-1.0f32, 1.0, 1.0, -1.0, 0.0039215689, 0.0039215689];
        assert_eq!(actual.len(), expected.len());
        // Elementwise tolerance check instead of exact float equality.
        assert!(actual
            .iter()
            .zip(expected.iter())
            .all(|(got, want)| (got - want).abs() < 1e-6));
    }

    /// A byte slice whose length is not `width * height * 3` is rejected.
    #[test]
    fn rgb_to_input_tensor_rejects_invalid_length() {
        let device = <TestBackend as Backend>::Device::default();
        let too_short = vec![0u8; 5];
        assert!(rgb_to_input_tensor::<TestBackend>(&too_short, 1, 2, &device).is_err());
    }
}
96}