1use burn::prelude::*;
2
3use crate::model::depth_pro::{DepthPro, DepthProInference};
4
5pub fn rgb_to_input_tensor<B: Backend>(
10 rgb: &[u8],
11 width: usize,
12 height: usize,
13 device: &B::Device,
14) -> Result<Tensor<B, 4>, String> {
15 let expected_len = width
16 .checked_mul(height)
17 .and_then(|pixels| pixels.checked_mul(3))
18 .ok_or_else(|| "image dimensions overflowed while preparing input".to_string())?;
19
20 if rgb.len() != expected_len {
21 return Err(format!(
22 "expected {expected_len} RGB bytes for {width}x{height}, got {}",
23 rgb.len()
24 ));
25 }
26
27 let hw = width * height;
28 let mut data = vec![0.0f32; 3 * hw];
29
30 for (idx, pixel) in rgb.chunks_exact(3).enumerate() {
31 for channel in 0..3 {
32 let value = pixel[channel] as f32 / 255.0;
33 data[channel * hw + idx] = value * 2.0 - 1.0;
34 }
35 }
36
37 Ok(
38 Tensor::<B, 1>::from_floats(data.as_slice(), device).reshape([
39 1,
40 3,
41 height as i32,
42 width as i32,
43 ]),
44 )
45}
46
47pub fn infer_from_rgb<B: Backend>(
53 model: &DepthPro<B>,
54 rgb: &[u8],
55 width: usize,
56 height: usize,
57 device: &B::Device,
58 focal_length_px: Option<Tensor<B, 1>>,
59) -> Result<DepthProInference<B>, String> {
60 let input = rgb_to_input_tensor::<B>(rgb, width, height, device)?;
61 Ok(model.infer(input, focal_length_px))
62}
63
#[cfg(test)]
mod tests {
    use super::*;

    type TestBackend = crate::InferenceBackend;

    /// Default device of the backend used by these tests.
    fn test_device() -> <TestBackend as Backend>::Device {
        <TestBackend as Backend>::Device::default()
    }

    #[test]
    fn rgb_to_input_tensor_normalizes_channels() {
        // A 1-wide, 2-tall image: pixel 0 is (0, 255, 128), pixel 1 is (255, 0, 128).
        let pixels: [u8; 6] = [0, 255, 128, 255, 0, 128];
        let data = rgb_to_input_tensor::<TestBackend>(&pixels, 1, 2, &test_device())
            .unwrap()
            .into_data()
            .convert::<f32>();
        assert_eq!(data.shape.as_slice(), &[1, 3, 2, 1]);

        // Planar order: R plane, G plane, B plane. A byte of 128 maps to 128 / 255 * 2 - 1.
        let want = [-1.0f32, 1.0f32, 1.0f32, -1.0f32, 0.0039215689, 0.0039215689];
        let got = data.to_vec::<f32>().unwrap();
        assert_eq!(got.len(), want.len());
        for (g, w) in got.iter().zip(&want) {
            assert!((g - w).abs() < 1e-6);
        }
    }

    #[test]
    fn rgb_to_input_tensor_rejects_invalid_length() {
        // A 1x2 RGB image needs 6 bytes, so a 5-byte buffer must be rejected.
        let short = [0u8; 5];
        assert!(rgb_to_input_tensor::<TestBackend>(&short, 1, 2, &test_device()).is_err());
    }
}