pineapple_neural/
preprocess.rs1use candle_core::{DType, Device, Result, Tensor};
5
6use pineapple_core::im::PineappleImage;
7
8fn to_tensor_rgb(image: &PineappleImage, device: &Device) -> Result<Tensor> {
14 let w = image.width() as usize;
15 let h = image.height() as usize;
16 let c = image.channels() as usize;
17
18 let tensor = Tensor::from_vec(image.to_f32(), (h, w, c), device)?.permute((2, 0, 1))?;
19
20 if c == 3 {
21 return Ok(tensor);
22 }
23
24 if c == 1 {
25 return Tensor::cat(&[&tensor; 3], 0);
26 }
27
28 let averaged = tensor.mean_keepdim(0).unwrap();
29
30 Tensor::cat(&[&averaged; 3], 0)
31}
32
33pub fn preprocess_imagenet(image: &PineappleImage, device: &Device) -> Result<Tensor> {
35 pub const IMAGENET_MEAN: [f32; 3] = [0.485f32, 0.456, 0.406];
36 pub const IMAGENET_STD: [f32; 3] = [0.229f32, 0.224, 0.225];
37
38 let tensor = if image.width() == 224 && image.height() == 224 {
39 to_tensor_rgb(image, device)?
40 } else {
41 to_tensor_rgb(&image.resize(224, 224).unwrap(), device)?
42 };
43
44 let mean = Tensor::new(&IMAGENET_MEAN, device)?.reshape((3, 1, 1))?;
45 let std = Tensor::new(&IMAGENET_STD, device)?.reshape((3, 1, 1))?;
46
47 (tensor.to_dtype(DType::F32)? / 255.)?
48 .broadcast_sub(&mean)?
49 .broadcast_div(&std)
50}
51
52pub fn preprocess_subcell(image: &PineappleImage, device: &Device) -> Result<Tensor> {
57 let eps: Tensor = Tensor::new(1e-6f32, device)?;
58
59 let tensor = if image.width() == 448 && image.height() == 448 {
60 to_tensor_rgb(image, device)?
61 } else {
62 to_tensor_rgb(&image.resize(448, 448).unwrap(), device)?
63 };
64
65 let min_val = tensor.min(0)?.min(0)?.min(0)?;
68 let max_val = tensor.max(0)?.max(0)?.max(0)?;
69
70 tensor
71 .broadcast_sub(&min_val)?
72 .broadcast_div(&(max_val - min_val + eps)?)
73}
74
#[cfg(test)]
mod test {
    use super::*;

    use pineapple_core::im::PineappleBuffer;

    /// Asserts that `tensor` has shape `(3, height, width)` for `image`.
    ///
    /// The test images are deliberately non-square so that a height/width mix-up in
    /// `to_tensor_rgb` cannot go unnoticed.
    fn assert_chw(tensor: &Tensor, image: &PineappleImage) {
        assert_ne!(image.width(), image.height(), "test image must be non-square");
        assert_eq!(
            tensor.dims(),
            &[3, image.height() as usize, image.width() as usize]
        );
    }

    /// Asserts that all three output channels hold identical values.
    fn assert_planes_equal(tensor: &Tensor) {
        let planes = tensor.to_vec3::<f32>().unwrap();
        assert_eq!(planes[0], planes[1]);
        assert_eq!(planes[1], planes[2]);
    }

    /// Asserts that output channel 0 equals the per-pixel mean of the interleaved samples.
    fn assert_channel_mean(tensor: &Tensor, image: &PineappleImage) {
        let channels = image.channels() as usize;
        let samples = image.to_f32();
        let plane = tensor
            .get(0)
            .unwrap()
            .flatten_all()
            .unwrap()
            .to_vec1::<f32>()
            .unwrap();
        let expected: Vec<f32> = samples
            .chunks(channels)
            .map(|px| px.iter().sum::<f32>() / channels as f32)
            .collect();

        assert_eq!(plane.len(), expected.len());
        for (got, want) in plane.iter().zip(&expected) {
            assert!((got - want).abs() < 1e-4, "{got} != {want}");
        }
    }

    #[test]
    fn test_to_tensor_rgb_1channel() {
        let buffer: Vec<u8> = (0..6).collect();
        let image = PineappleImage::U8(PineappleBuffer::new(3, 2, 1, buffer).unwrap());
        let tensor = to_tensor_rgb(&image, &Device::Cpu).unwrap();

        assert_chw(&tensor, &image);
        assert_planes_equal(&tensor);

        // The replicated plane must be the original single channel.
        let plane = tensor
            .get(0)
            .unwrap()
            .flatten_all()
            .unwrap()
            .to_vec1::<f32>()
            .unwrap();
        assert_eq!(plane, image.to_f32());
    }

    #[test]
    fn test_to_tensor_rgb_2channel() {
        let buffer: Vec<u8> = (0..12).collect();
        let image = PineappleImage::U8(PineappleBuffer::new(3, 2, 2, buffer).unwrap());
        let tensor = to_tensor_rgb(&image, &Device::Cpu).unwrap();

        assert_chw(&tensor, &image);
        assert_planes_equal(&tensor);
        assert_channel_mean(&tensor, &image);
    }

    #[test]
    fn test_to_tensor_rgb_3channel() {
        let buffer: Vec<u8> = (0..18).collect();
        let image = PineappleImage::U8(PineappleBuffer::new(3, 2, 3, buffer).unwrap());
        let tensor = to_tensor_rgb(&image, &Device::Cpu).unwrap();

        assert_chw(&tensor, &image);

        // Channel k must hold every k-th interleaved sample, in row-major pixel order.
        let samples = image.to_f32();
        let planes = tensor.to_vec3::<f32>().unwrap();
        for (k, plane) in planes.iter().enumerate() {
            let got: Vec<f32> = plane.concat();
            let want: Vec<f32> = samples.iter().skip(k).step_by(3).copied().collect();
            assert_eq!(got, want, "channel {k} does not match the interleaved samples");
        }
    }

    #[test]
    fn test_to_tensor_rgb_nchannel() {
        let buffer: Vec<u8> = (0..30).collect();
        let image = PineappleImage::U8(PineappleBuffer::new(3, 2, 5, buffer).unwrap());
        let tensor = to_tensor_rgb(&image, &Device::Cpu).unwrap();

        assert_chw(&tensor, &image);
        assert_planes_equal(&tensor);
        assert_channel_mean(&tensor, &image);
    }

    #[test]
    fn test_preprocess_imagenet_shape() {
        let buffer: Vec<u8> = (0..12).collect();
        let image = PineappleImage::U8(PineappleBuffer::new(2, 2, 3, buffer).unwrap());
        let tensor = preprocess_imagenet(&image, &Device::Cpu).unwrap();

        assert_eq!(tensor.dims(), &[3, 224, 224]);
    }

    #[test]
    fn test_preprocess_subcell_range() {
        let buffer: Vec<u8> = (0..12).collect();
        let image = PineappleImage::U8(PineappleBuffer::new(2, 2, 3, buffer).unwrap());
        let tensor = preprocess_subcell(&image, &Device::Cpu).unwrap();

        assert_eq!(tensor.dims(), &[3, 448, 448]);

        // Min-max scaling maps the global minimum to exactly 0 and the maximum to ~1.
        let flat = tensor.flatten_all().unwrap();
        let lo = flat.min(0).unwrap().to_scalar::<f32>().unwrap();
        let hi = flat.max(0).unwrap().to_scalar::<f32>().unwrap();
        assert_eq!(lo, 0.0);
        assert!(hi > 0.99 && hi <= 1.0, "max after scaling was {hi}");
    }
}