use candle_core::{DType, Device, Result, Tensor};
use pineapple_core::im::PineappleImage;
fn to_tensor_rgb(image: &PineappleImage, device: &Device) -> Result<Tensor> {
let w = image.width() as usize;
let h = image.height() as usize;
let c = image.channels() as usize;
let tensor = Tensor::from_vec(image.to_f32(), (h, w, c), device)?.permute((2, 0, 1))?;
if c == 3 {
return Ok(tensor);
}
if c == 1 {
return Tensor::cat(&[&tensor; 3], 0);
}
let averaged = tensor.mean_keepdim(0).unwrap();
Tensor::cat(&[&averaged; 3], 0)
}
pub fn preprocess_imagenet(image: &PineappleImage, device: &Device) -> Result<Tensor> {
pub const IMAGENET_MEAN: [f32; 3] = [0.485f32, 0.456, 0.406];
pub const IMAGENET_STD: [f32; 3] = [0.229f32, 0.224, 0.225];
let tensor = if image.width() == 224 && image.height() == 224 {
to_tensor_rgb(image, device)?
} else {
to_tensor_rgb(&image.resize(224, 224).unwrap(), device)?
};
let mean = Tensor::new(&IMAGENET_MEAN, device)?.reshape((3, 1, 1))?;
let std = Tensor::new(&IMAGENET_STD, device)?.reshape((3, 1, 1))?;
(tensor.to_dtype(DType::F32)? / 255.)?
.broadcast_sub(&mean)?
.broadcast_div(&std)
}
pub fn preprocess_subcell(image: &PineappleImage, device: &Device) -> Result<Tensor> {
let eps: Tensor = Tensor::new(1e-6f32, device)?;
let tensor = if image.width() == 448 && image.height() == 448 {
to_tensor_rgb(image, device)?
} else {
to_tensor_rgb(&image.resize(448, 448).unwrap(), device)?
};
let min_val = tensor.min(0)?.min(0)?.min(0)?;
let max_val = tensor.max(0)?.max(0)?.max(0)?;
tensor
.broadcast_sub(&min_val)?
.broadcast_div(&(max_val - min_val + eps)?)
}
#[cfg(test)]
mod test {
    use super::*;
    use pineapple_core::im::PineappleBuffer;

    /// Converts a 2x2 test image and checks the leading dimensions are (3, 2, 2), i.e. CHW RGB.
    fn assert_rgb_2x2(image: PineappleImage) {
        let tensor = to_tensor_rgb(&image, &Device::Cpu).unwrap();
        let dims = tensor.dims();
        assert_eq!(&dims[..3], &[3, 2, 2]);
    }

    #[test]
    fn test_to_tensor_rgb_1channel() {
        let pixels: Vec<u8> = (0..4).collect();
        assert_rgb_2x2(PineappleImage::U8(PineappleBuffer::new(2, 2, 1, pixels).unwrap()));
    }

    #[test]
    fn test_to_tensor_rgb_2channel() {
        let pixels: Vec<u8> = (0..8).collect();
        assert_rgb_2x2(PineappleImage::U8(PineappleBuffer::new(2, 2, 2, pixels).unwrap()));
    }

    #[test]
    fn test_to_tensor_rgb_3channel() {
        let pixels: Vec<u8> = (0..12).collect();
        assert_rgb_2x2(PineappleImage::U8(PineappleBuffer::new(2, 2, 3, pixels).unwrap()));
    }

    #[test]
    fn test_to_tensor_rgb_nchannel() {
        let pixels: Vec<u8> = (0..20).collect();
        assert_rgb_2x2(PineappleImage::U8(PineappleBuffer::new(2, 2, 5, pixels).unwrap()));
    }
}