Skip to main content

pineapple_neural/
preprocess.rs

1// Copyright (c) 2025, Tom Ouellette
2// Licensed under the BSD 3-Clause License
3
4use candle_core::{DType, Device, Result, Tensor};
5
6use pineapple_core::im::PineappleImage;
7
8/// Convert a PineappleImage to a 3-channel "RGB" tensor
9///
10/// Any 1-channel image is simply repeated three times to generate
11/// a 3-channel image. Multi-channel images that aren't 1 or 3 channels
12/// will be averaged then stacked to 3 channels.
13fn to_tensor_rgb(image: &PineappleImage, device: &Device) -> Result<Tensor> {
14    let w = image.width() as usize;
15    let h = image.height() as usize;
16    let c = image.channels() as usize;
17
18    let tensor = Tensor::from_vec(image.to_f32(), (h, w, c), device)?.permute((2, 0, 1))?;
19
20    if c == 3 {
21        return Ok(tensor);
22    }
23
24    if c == 1 {
25        return Tensor::cat(&[&tensor; 3], 0);
26    }
27
28    let averaged = tensor.mean_keepdim(0).unwrap();
29
30    Tensor::cat(&[&averaged; 3], 0)
31}
32
33/// Perform imagenet standardization on an input PineappleImage
34pub fn preprocess_imagenet(image: &PineappleImage, device: &Device) -> Result<Tensor> {
35    pub const IMAGENET_MEAN: [f32; 3] = [0.485f32, 0.456, 0.406];
36    pub const IMAGENET_STD: [f32; 3] = [0.229f32, 0.224, 0.225];
37
38    let tensor = if image.width() == 224 && image.height() == 224 {
39        to_tensor_rgb(image, device)?
40    } else {
41        to_tensor_rgb(&image.resize(224, 224).unwrap(), device)?
42    };
43
44    let mean = Tensor::new(&IMAGENET_MEAN, device)?.reshape((3, 1, 1))?;
45    let std = Tensor::new(&IMAGENET_STD, device)?.reshape((3, 1, 1))?;
46
47    (tensor.to_dtype(DType::F32)? / 255.)?
48        .broadcast_sub(&mean)?
49        .broadcast_div(&std)
50}
51
52/// Perform subcell standardization on an input PineappleImage
53///
54/// Note that subcell used min-max normalization for some reason
55/// https://github.com/CellProfiling/SubCellPortable/blob/main/inference.py#L76C1-L81C14
56pub fn preprocess_subcell(image: &PineappleImage, device: &Device) -> Result<Tensor> {
57    let eps: Tensor = Tensor::new(1e-6f32, device)?;
58
59    let tensor = if image.width() == 448 && image.height() == 448 {
60        to_tensor_rgb(image, device)?
61    } else {
62        to_tensor_rgb(&image.resize(448, 448).unwrap(), device)?
63    };
64
65    // Not sure if there's an implementation to take min over
66    // multiple dimensions in candle - need to re-check docs
67    let min_val = tensor.min(0)?.min(0)?.min(0)?;
68    let max_val = tensor.max(0)?.max(0)?.max(0)?;
69
70    tensor
71        .broadcast_sub(&min_val)?
72        .broadcast_div(&(max_val - min_val + eps)?)
73}
74
#[cfg(test)]
mod tests {
    use super::*;

    use pineapple_core::im::PineappleBuffer;

    /// Assert that `tensor`, flattened in row-major order, matches `expected`
    /// element-wise within a small tolerance for float rounding.
    fn assert_values(tensor: &Tensor, expected: &[f32]) {
        let actual = tensor.flatten_all().unwrap().to_vec1::<f32>().unwrap();
        assert_eq!(actual.len(), expected.len());
        for (a, e) in actual.iter().zip(expected) {
            assert!((a - e).abs() < 1e-5, "expected {e}, got {a}");
        }
    }

    #[test]
    fn test_to_tensor_rgb_1channel() {
        let buffer: Vec<u8> = vec![0, 1, 2, 3];
        let image = PineappleImage::U8(PineappleBuffer::new(2, 2, 1, buffer).unwrap());
        let tensor = to_tensor_rgb(&image, &Device::Cpu).unwrap();

        assert_eq!(tensor.dims(), &[3, 2, 2]);
        // The lone channel is copied into each of the three output channels.
        assert_values(&tensor, &[0.0f32, 1.0, 2.0, 3.0].repeat(3));
    }

    #[test]
    fn test_to_tensor_rgb_2channel() {
        let buffer: Vec<u8> = vec![0, 1, 2, 3, 4, 5, 6, 7];
        let image = PineappleImage::U8(PineappleBuffer::new(2, 2, 2, buffer).unwrap());
        let tensor = to_tensor_rgb(&image, &Device::Cpu).unwrap();

        assert_eq!(tensor.dims(), &[3, 2, 2]);
        // Each pixel's two channels are averaged, then repeated three times.
        assert_values(&tensor, &[0.5f32, 2.5, 4.5, 6.5].repeat(3));
    }

    #[test]
    fn test_to_tensor_rgb_3channel() {
        let buffer: Vec<u8> = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11];
        let image = PineappleImage::U8(PineappleBuffer::new(2, 2, 3, buffer).unwrap());
        let tensor = to_tensor_rgb(&image, &Device::Cpu).unwrap();

        assert_eq!(tensor.dims(), &[3, 2, 2]);
        // Interleaved HWC input is reordered to channel-first CHW.
        assert_values(
            &tensor,
            &[0.0f32, 3.0, 6.0, 9.0, 1.0, 4.0, 7.0, 10.0, 2.0, 5.0, 8.0, 11.0],
        );
    }

    #[test]
    fn test_to_tensor_rgb_nchannel() {
        let buffer: Vec<u8> = (0..20).collect();
        let image = PineappleImage::U8(PineappleBuffer::new(2, 2, 5, buffer).unwrap());
        let tensor = to_tensor_rgb(&image, &Device::Cpu).unwrap();

        assert_eq!(tensor.dims(), &[3, 2, 2]);
        // Pixel p holds 5p..5p+4, whose mean is 5p + 2.
        assert_values(&tensor, &[2.0f32, 7.0, 12.0, 17.0].repeat(3));
    }

    #[test]
    fn test_preprocess_imagenet_zero_image() {
        // Already 224x224, so no resize is performed.
        let buffer: Vec<u8> = vec![0; 224 * 224 * 3];
        let image = PineappleImage::U8(PineappleBuffer::new(224, 224, 3, buffer).unwrap());
        let tensor = preprocess_imagenet(&image, &Device::Cpu).unwrap();

        assert_eq!(tensor.dims(), &[3, 224, 224]);
        // A zero pixel standardizes to -mean / std in every channel.
        let expected = [-0.485f32 / 0.229, -0.456 / 0.224, -0.406 / 0.225];
        let values = tensor.to_vec3::<f32>().unwrap();
        for (channel, want) in values.iter().zip(expected) {
            for row in channel {
                for &v in row {
                    assert!((v - want).abs() < 1e-4, "expected {want}, got {v}");
                }
            }
        }
    }

    #[test]
    fn test_preprocess_subcell_range() {
        // Already 448x448, so no resize is performed.
        let buffer: Vec<u8> = (0..448 * 448).map(|i| (i % 256) as u8).collect();
        let image = PineappleImage::U8(PineappleBuffer::new(448, 448, 1, buffer).unwrap());
        let tensor = preprocess_subcell(&image, &Device::Cpu).unwrap();

        assert_eq!(tensor.dims(), &[3, 448, 448]);
        // Min-max normalization maps the darkest pixel to 0 and the brightest to ~1.
        let flat = tensor.flatten_all().unwrap();
        let min = flat.min(0).unwrap().to_scalar::<f32>().unwrap();
        let max = flat.max(0).unwrap().to_scalar::<f32>().unwrap();
        assert!(min.abs() < 1e-6, "min was {min}");
        assert!((max - 1.0).abs() < 1e-4, "max was {max}");
    }
}