use std::cell::RefCell;
use fast_image_resize::{
images::Image as FirImage, IntoImageView, ResizeAlg, ResizeOptions, Resizer,
};
use image::{imageops::FilterType, DynamicImage, GenericImageView, Rgb, RgbImage};
use ndarray::{s, Array3, Array4};
use thiserror::Error;
/// Errors reported by the image and tensor transform helpers in this file.
#[derive(Error, Debug)]
pub enum TransformError {
/// A tensor's shape differs from the expected one. `expected` is a
/// human-readable description and `actual` holds the offending dimensions.
/// `stack_batch` returns this when a tensor does not match the first one.
#[error("Invalid tensor shape: expected {expected}, got {actual:?}")]
InvalidShape {
expected: String,
actual: Vec<usize>,
},
/// Wraps an error from the `image` crate; `#[from]` lets `?` convert it.
#[error("Image operation failed: {0}")]
ImageError(#[from] image::ImageError),
/// Returned by `stack_batch` when it is given an empty slice.
#[error("Empty batch: cannot stack zero tensors")]
EmptyBatch,
/// Tensors in a batch do not share one shape.
/// NOTE(review): not constructed anywhere in the visible code — confirm it is still used.
#[error("Inconsistent tensor shapes in batch")]
InconsistentShapes,
/// Free-form shape error carrying its own message.
/// NOTE(review): not constructed anywhere in the visible code — confirm it is still used.
#[error("Shape error: {0}")]
ShapeError(String),
}
/// Shorthand for `std::result::Result` with [`TransformError`] as the error type.
pub type Result<T> = std::result::Result<T, TransformError>;
/// Returns `(width, height, bytes)` where `bytes` is the raw 8-bit RGB buffer
/// of `image`.
///
/// An `ImageRgb8` input is borrowed without copying. Any other variant is
/// converted with `to_rgb8` and the converted buffer is returned as an owned
/// `Cow`.
pub fn rgb_bytes(image: &DynamicImage) -> (usize, usize, std::borrow::Cow<'_, [u8]>) {
    use std::borrow::Cow;

    // Fast path: the image already stores 8-bit RGB, so just borrow it.
    if let DynamicImage::ImageRgb8(buffer) = image {
        let width = buffer.width() as usize;
        let height = buffer.height() as usize;
        return (width, height, Cow::Borrowed(buffer.as_raw().as_slice()));
    }

    // Every other pixel layout is converted first, then handed over by value.
    let converted = image.to_rgb8();
    let width = converted.width() as usize;
    let height = converted.height() as usize;
    (width, height, Cow::Owned(converted.into_raw()))
}
/// Splits interleaved `R, G, B` bytes into three `f32` planes, applying an
/// affine transform per channel.
///
/// For pixel `i`, `r_plane[i] = rgb[3 * i] as f32 * scale[0] + bias[0]`, and
/// likewise for the green and blue planes with indices 1 and 2.
///
/// The number of pixels processed is `r_plane.len()`.
///
/// # Panics
///
/// Panics if `g_plane` or `b_plane` is shorter than `r_plane`, or if `rgb`
/// holds fewer than `3 * r_plane.len()` bytes. Debug builds additionally
/// assert that all three planes have exactly the same length.
pub fn deinterleave_rgb_to_planes(
    rgb: &[u8],
    r_plane: &mut [f32],
    g_plane: &mut [f32],
    b_plane: &mut [f32],
    scale: [f32; 3],
    bias: [f32; 3],
) {
    let count = r_plane.len();
    debug_assert_eq!(count, g_plane.len());
    debug_assert_eq!(count, b_plane.len());
    debug_assert!(rgb.len() >= count * 3);

    // Trim every buffer to the pixel count up front. This reproduces the
    // out-of-bounds panic for undersized buffers and lets the zipped loop
    // below run without per-element bounds checks.
    let interleaved = &rgb[..count * 3];
    let greens = &mut g_plane[..count];
    let blues = &mut b_plane[..count];

    let [r_scale, g_scale, b_scale] = scale;
    let [r_bias, g_bias, b_bias] = bias;

    let outputs = r_plane
        .iter_mut()
        .zip(greens.iter_mut())
        .zip(blues.iter_mut());
    for (px, ((r, g), b)) in interleaved.chunks_exact(3).zip(outputs) {
        *r = px[0] as f32 * r_scale + r_bias;
        *g = px[1] as f32 * g_scale + g_bias;
        *b = px[2] as f32 * b_scale + b_bias;
    }
}
/// Builds a `(3, h, w)` tensor from interleaved RGB bytes.
///
/// One buffer of `3 * h * w` floats is allocated, split into red, green and
/// blue planes, and filled by `deinterleave_rgb_to_planes` using the given
/// per-channel `scale` and `bias`.
fn build_planar_tensor(
    raw: &[u8],
    w: usize,
    h: usize,
    scale: [f32; 3],
    bias: [f32; 3],
) -> Array3<f32> {
    let plane_len = h * w;
    let mut planes = vec![0.0f32; 3 * plane_len];

    // Carve the single allocation into three consecutive channel planes.
    let (red, green_and_blue) = planes.split_at_mut(plane_len);
    let (green, blue) = green_and_blue.split_at_mut(plane_len);
    deinterleave_rgb_to_planes(raw, red, green, blue, scale, bias);

    #[expect(
        clippy::expect_used,
        reason = "data has exactly 3*h*w elements by construction"
    )]
    Array3::from_shape_vec((3, h, w), planes).expect("shape matches pre-allocated buffer")
}
/// Converts `image` into a `(3, height, width)` `f32` tensor, with each
/// 8-bit channel value multiplied by `1 / 255`.
pub fn to_tensor(image: &DynamicImage) -> Array3<f32> {
    const UNIT_SCALE: f32 = 1.0 / 255.0;

    let (width, height, bytes) = rgb_bytes(image);
    build_planar_tensor(&bytes, width, height, [UNIT_SCALE; 3], [0.0; 3])
}
/// Test-only variant of `to_tensor` that keeps the raw channel values
/// (scale 1, bias 0) instead of multiplying by `1 / 255`.
#[cfg(test)]
pub fn to_tensor_no_norm(image: &DynamicImage) -> Array3<f32> {
    let (width, height, bytes) = rgb_bytes(image);
    let identity_scale = [1.0f32; 3];
    let zero_bias = [0.0f32; 3];
    build_planar_tensor(&bytes, width, height, identity_scale, zero_bias)
}
/// Normalizes the first three channels of `tensor` in place.
///
/// Every element `v` of channel `c` becomes
/// `(v - mean[c]) * (1.0 / std[c])`, computed in `f32`.
///
/// The same arithmetic is used whatever the tensor's memory layout.
/// Previously the contiguous path multiplied by the reciprocal while the
/// fallback path divided, so the rounding of the result depended on layout.
/// The contiguous (common) path's arithmetic is the one kept here.
///
/// Channels beyond the third are left untouched.
///
/// # Panics
///
/// Panics if `tensor` has fewer than three channels along axis 0, because
/// the per-channel slice index is then out of bounds.
pub fn normalize(tensor: &mut Array3<f32>, mean: &[f64; 3], std: &[f64; 3]) {
    for (channel, (&mean_c, &std_c)) in mean.iter().zip(std.iter()).enumerate() {
        let mean_c = mean_c as f32;
        // Hoist the reciprocal so the per-element work is a subtract and a multiply.
        let inv_std_c = 1.0 / (std_c as f32);
        tensor
            .slice_mut(s![channel, .., ..])
            .mapv_inplace(|v| (v - mean_c) * inv_std_c);
    }
}
/// Converts `image` to a `(3, height, width)` tensor and normalizes it in a
/// single pass.
///
/// The `1 / 255` scaling and the `(x - mean) / std` normalization are folded
/// into one affine transform per channel: `scale = 1 / (255 * std[c])` and
/// `bias = -mean[c] / std[c]`.
pub fn to_tensor_and_normalize(
    image: &DynamicImage,
    mean: &[f64; 3],
    std: &[f64; 3],
) -> Array3<f32> {
    let (width, height, bytes) = rgb_bytes(image);

    let mut scale = [0.0f32; 3];
    let mut bias = [0.0f32; 3];
    for channel in 0..3 {
        let std_c = std[channel] as f32;
        scale[channel] = 1.0 / (255.0 * std_c);
        bias[channel] = -(mean[channel] as f32) / std_c;
    }

    build_planar_tensor(&bytes, width, height, scale, bias)
}
/// Multiplies every element of `tensor` in place by `factor`, after casting
/// `factor` to `f32`.
pub fn rescale(tensor: &mut Array3<f32>, factor: f64) {
    let multiplier = factor as f32;
    for value in tensor.iter_mut() {
        *value *= multiplier;
    }
}
/// Translates an `image` crate filter into the matching `fast_image_resize`
/// algorithm.
///
/// `Nearest` maps to `ResizeAlg::Nearest`. Every other filter becomes a
/// convolution with the corresponding kernel, and `Triangle` is mapped to
/// the `Bilinear` kernel.
fn to_fir_algorithm(filter: FilterType) -> ResizeAlg {
    use fast_image_resize::FilterType as FirFilter;

    let kernel = match filter {
        // Nearest-neighbour is its own algorithm, not a convolution kernel.
        FilterType::Nearest => return ResizeAlg::Nearest,
        FilterType::Triangle => FirFilter::Bilinear,
        FilterType::CatmullRom => FirFilter::CatmullRom,
        FilterType::Gaussian => FirFilter::Gaussian,
        FilterType::Lanczos3 => FirFilter::Lanczos3,
    };
    ResizeAlg::Convolution(kernel)
}
// One `Resizer` per thread, created the first time that thread touches it.
// `resize` borrows it mutably through the `RefCell`, so repeated calls on the
// same thread share a single instance instead of constructing a new one.
thread_local! {
static RESIZER: RefCell<Resizer> = RefCell::new(Resizer::new());
}
/// Resizes `image` to exactly `width` x `height` using `filter`.
///
/// The thread-local `fast_image_resize` resizer is tried first. It falls
/// back to `DynamicImage::resize_exact` when the image reports no pixel
/// type, when the fast resize returns an error, or when the fast result
/// cannot be turned back into a `DynamicImage`.
pub fn resize(image: &DynamicImage, width: u32, height: u32, filter: FilterType) -> DynamicImage {
    let Some(pixel_type) = image.pixel_type() else {
        return image.resize_exact(width, height, filter);
    };

    let mut target = FirImage::new(width, height, pixel_type);
    let options = ResizeOptions::new().resize_alg(to_fir_algorithm(filter));
    let resized = RESIZER.with(|cell| {
        cell.borrow_mut()
            .resize(image, &mut target, &options)
            .is_ok()
    });

    if resized {
        fir_image_to_dynamic(target, width, height, image, filter)
    } else {
        image.resize_exact(width, height, filter)
    }
}
/// Wraps the pixel buffer of a resized `FirImage` in a `DynamicImage` of the
/// same variant as `source`.
///
/// Only `ImageRgb8`, `ImageRgba8` and `ImageLuma8` sources are rebuilt from
/// the buffer. For any other variant, or if the buffer is rejected by
/// `from_raw`, `source` is resized again with `resize_exact`.
fn fir_image_to_dynamic(
    img: FirImage<'_>,
    width: u32,
    height: u32,
    source: &DynamicImage,
    filter: FilterType,
) -> DynamicImage {
    let pixels = img.into_vec();

    let rebuilt = match source {
        DynamicImage::ImageRgb8(_) => {
            RgbImage::from_raw(width, height, pixels).map(DynamicImage::ImageRgb8)
        }
        DynamicImage::ImageRgba8(_) => {
            image::RgbaImage::from_raw(width, height, pixels).map(DynamicImage::ImageRgba8)
        }
        DynamicImage::ImageLuma8(_) => {
            image::GrayImage::from_raw(width, height, pixels).map(DynamicImage::ImageLuma8)
        }
        _ => None,
    };

    match rebuilt {
        Some(converted) => converted,
        None => source.resize_exact(width, height, filter),
    }
}
/// Shrinks `image` so it fits inside `max_width` x `max_height` while
/// keeping its aspect ratio.
///
/// If the image already fits (scale ratio `>= 1`), a clone is returned
/// unchanged; images are never enlarged. Each output side is rounded and
/// clamped to at least one pixel.
pub fn resize_to_fit(
    image: &DynamicImage,
    max_width: u32,
    max_height: u32,
    filter: FilterType,
) -> DynamicImage {
    let (src_w, src_h) = image.dimensions();

    let width_ratio = max_width as f64 / src_w as f64;
    let height_ratio = max_height as f64 / src_h as f64;
    let ratio = width_ratio.min(height_ratio);
    if ratio >= 1.0 {
        return image.clone();
    }

    let scaled = |side: u32| ((side as f64 * ratio).round() as u32).max(1);
    resize(image, scaled(src_w), scaled(src_h), filter)
}
/// Crops a centered `crop_w` x `crop_h` region out of `image`.
///
/// A requested side larger than the image is clamped to the image side. If
/// both sides are clamped, a clone of the input is returned.
pub fn center_crop(image: &DynamicImage, crop_w: u32, crop_h: u32) -> DynamicImage {
    let (w, h) = image.dimensions();
    let out_w = crop_w.min(w);
    let out_h = crop_h.min(h);

    // Nothing to remove when the clamped crop covers the whole image.
    if out_w == w && out_h == h {
        return image.clone();
    }

    let left = (w - out_w) / 2;
    let top = (h - out_h) / 2;
    image.crop_imm(left, top, out_w, out_h)
}
/// Pads `image` to a square whose side is the longer of its two dimensions.
///
/// The image is centered along its shorter axis on an RGB canvas filled with
/// `background`. An image that is already square is returned as a clone.
pub fn expand_to_square(image: &DynamicImage, background: Rgb<u8>) -> DynamicImage {
    let (w, h) = image.dimensions();
    if w == h {
        return image.clone();
    }

    let side = w.max(h);
    let mut canvas = DynamicImage::from(RgbImage::from_pixel(side, side, background));

    // Exactly one of these offsets is non-zero: the one along the shorter axis.
    let x_offset = ((side - w) / 2) as i64;
    let y_offset = ((side - h) / 2) as i64;
    image::imageops::overlay(&mut canvas, image, x_offset, y_offset);
    canvas
}
/// Pads `image` to at least `target_w` x `target_h`.
///
/// The image is placed in the top-left corner of an RGB canvas filled with
/// `background`. The canvas side along each axis is the larger of the image
/// side and the target side. If the image already meets both targets, a
/// clone is returned.
pub fn pad_to_size(
    image: &DynamicImage,
    target_w: u32,
    target_h: u32,
    background: Rgb<u8>,
) -> DynamicImage {
    let (w, h) = image.dimensions();
    let canvas_w = w.max(target_w);
    let canvas_h = h.max(target_h);

    // No padding needed when neither axis has to grow.
    if canvas_w == w && canvas_h == h {
        return image.clone();
    }

    let mut canvas = DynamicImage::from(RgbImage::from_pixel(canvas_w, canvas_h, background));
    image::imageops::overlay(&mut canvas, image, 0, 0);
    canvas
}
/// Stacks equally shaped `(c, h, w)` tensors into one `(n, c, h, w)` batch.
///
/// # Errors
///
/// * `TransformError::EmptyBatch` if `tensors` is empty.
/// * `TransformError::InvalidShape` if any tensor's shape differs from the
///   first tensor's. The error reports the first mismatching shape.
pub fn stack_batch(tensors: &[Array3<f32>]) -> Result<Array4<f32>> {
    let Some((first, others)) = tensors.split_first() else {
        return Err(TransformError::EmptyBatch);
    };

    let reference = first.shape();
    let (c, h, w) = first.dim();
    if let Some(mismatch) = others.iter().find(|t| t.shape() != reference) {
        return Err(TransformError::InvalidShape {
            expected: format!("[{c}, {h}, {w}]"),
            actual: mismatch.shape().to_vec(),
        });
    }

    let mut batch = Array4::<f32>::zeros((tensors.len(), c, h, w));
    for (index, tensor) in tensors.iter().enumerate() {
        batch.slice_mut(s![index, .., .., ..]).assign(tensor);
    }
    Ok(batch)
}
/// Maps an optional numeric resampling code to an `image` crate filter.
///
/// | code            | filter       |
/// |-----------------|--------------|
/// | `0`             | `Nearest`    |
/// | `1`             | `Lanczos3`   |
/// | `3`             | `CatmullRom` |
/// | anything else   | `Triangle`   |
///
/// "Anything else" covers `None`, `2`, `4`, `5` and unrecognised codes.
/// NOTE(review): the codes presumably follow Pillow's `Resampling` constants
/// — confirm against the caller.
pub fn pil_to_filter(resampling: Option<usize>) -> FilterType {
    let Some(code) = resampling else {
        return FilterType::Triangle;
    };
    match code {
        0 => FilterType::Nearest,
        1 => FilterType::Lanczos3,
        3 => FilterType::CatmullRom,
        // 2, 4, 5 and any unknown code all resolve to Triangle.
        _ => FilterType::Triangle,
    }
}
pub fn calculate_mean_color(image: &DynamicImage) -> Rgb<u8> {
let rgb = image.to_rgb8();
let (w, h) = (rgb.width() as u64, rgb.height() as u64);
let total_pixels = w * h;
if total_pixels == 0 {
return Rgb([128, 128, 128]);
}
let (mut r_sum, mut g_sum, mut b_sum) = (0u64, 0u64, 0u64);
for pixel in rgb.pixels() {
r_sum += pixel[0] as u64;
g_sum += pixel[1] as u64;
b_sum += pixel[2] as u64;
}
Rgb([
(r_sum / total_pixels) as u8,
(g_sum / total_pixels) as u8,
(b_sum / total_pixels) as u8,
])
}
/// Converts a per-channel mean into an 8-bit RGB color.
///
/// Each component is multiplied by 255, rounded, and cast to `u8`. The cast
/// saturates, so out-of-range products clamp to `0` or `255`.
pub fn mean_to_rgb(mean: &[f64; 3]) -> Rgb<u8> {
    let to_byte = |component: f64| (component * 255.0).round() as u8;
    let [r, g, b] = *mean;
    Rgb([to_byte(r), to_byte(g), to_byte(b)])
}
/// Evaluates the piecewise cubic interpolation kernel at distance `x`.
///
/// With `t = |x|`:
/// * `t < 1`: `(1.5 t - 2.5) t^2 + 1`
/// * `1 <= t < 2`: `((-0.5 t + 2.5) t - 4) t + 2`
/// * otherwise (including NaN): `0`
#[inline]
pub fn cubic_weight(x: f32) -> f32 {
    match x.abs() {
        t if t < 1.0 => (1.5 * t - 2.5) * t * t + 1.0,
        t if t < 2.0 => ((-0.5 * t + 2.5) * t - 4.0) * t + 2.0,
        _ => 0.0,
    }
}
/// Samples channel `c` of `tensor` at the fractional position
/// `(src_y, src_x)` with a 4x4 bicubic kernel.
///
/// The sixteen neighbours lie at offsets `-1..=2` around the floored
/// position. Indices outside `[0, h - 1]` / `[0, w - 1]` are clamped to the
/// border. Each neighbour is weighted by the product of `cubic_weight` for
/// its vertical and horizontal distance.
///
/// `h` and `w` are the bounds used for clamping; callers pass the tensor's
/// height and width.
pub fn bicubic_interpolate(
    tensor: &Array3<f32>,
    c: usize,
    src_y: f32,
    src_x: f32,
    h: usize,
    w: usize,
) -> f32 {
    const TAPS: [i32; 4] = [-1, 0, 1, 2];

    let base_y = src_y.floor() as i32;
    let base_x = src_x.floor() as i32;
    let frac_y = src_y - base_y as f32;
    let frac_x = src_x - base_x as f32;
    let last_row = h as i32 - 1;
    let last_col = w as i32 - 1;

    let mut acc = 0.0f32;
    for dy in TAPS {
        let row = (base_y + dy).clamp(0, last_row) as usize;
        let wy = cubic_weight(frac_y - dy as f32);
        for dx in TAPS {
            let col = (base_x + dx).clamp(0, last_col) as usize;
            let wx = cubic_weight(frac_x - dx as f32);
            acc += tensor[[c, row, col]] * wy * wx;
        }
    }
    acc
}
/// Resizes a `(c, h, w)` tensor to `(c, target_h, target_w)` with bicubic
/// interpolation.
///
/// Output pixel centers are mapped to source coordinates with
/// `(i + 0.5) * scale - 0.5`, where `scale` is the source side divided by
/// the target side. If the size is already the target size, a clone is
/// returned.
pub fn bicubic_resize(tensor: &Array3<f32>, target_h: usize, target_w: usize) -> Array3<f32> {
    let (c, h, w) = tensor.dim();
    if (h, w) == (target_h, target_w) {
        return tensor.clone();
    }

    let scale_h = h as f32 / target_h as f32;
    let scale_w = w as f32 / target_w as f32;

    Array3::from_shape_fn((c, target_h, target_w), |(ch, y, x)| {
        let src_y = (y as f32 + 0.5) * scale_h - 0.5;
        let src_x = (x as f32 + 0.5) * scale_w - 0.5;
        bicubic_interpolate(tensor, ch, src_y, src_x, h, w)
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a solid-color RGB image of the given size.
    fn create_test_image(width: u32, height: u32, color: Rgb<u8>) -> DynamicImage {
        DynamicImage::from(RgbImage::from_pixel(width, height, color))
    }

    /// Asserts that `actual` lies strictly within `tolerance` of `expected`.
    fn assert_close(actual: f32, expected: f32, tolerance: f32) {
        assert!((actual - expected).abs() < tolerance);
    }

    #[test]
    fn test_to_tensor_shape() {
        // A 10x20 image yields channels-first (3, height, width).
        let tensor = to_tensor(&create_test_image(10, 20, Rgb([255, 128, 0])));
        assert_eq!(tensor.shape(), &[3, 20, 10]);
    }

    #[test]
    fn test_to_tensor_values() {
        let tensor = to_tensor(&create_test_image(2, 2, Rgb([255, 128, 0])));
        assert_close(tensor[[0, 0, 0]], 1.0, 1e-6);
        assert_close(tensor[[1, 0, 0]], 0.502, 0.01);
        assert_close(tensor[[2, 0, 0]], 0.0, 1e-6);
    }

    #[test]
    fn test_to_tensor_no_norm() {
        let tensor = to_tensor_no_norm(&create_test_image(2, 2, Rgb([255, 128, 64])));
        let expected = [255.0, 128.0, 64.0];
        for (channel, want) in expected.into_iter().enumerate() {
            assert_close(tensor[[channel, 0, 0]], want, 1e-6);
        }
    }

    #[test]
    fn test_normalize() {
        let mut tensor = Array3::<f32>::from_elem((3, 2, 2), 0.5);
        normalize(&mut tensor, &[0.5, 0.5, 0.5], &[0.5, 0.5, 0.5]);
        for &value in &tensor {
            assert_close(value, 0.0, 1e-6);
        }
    }

    #[test]
    fn test_rescale() {
        let mut tensor = Array3::<f32>::from_elem((3, 2, 2), 255.0);
        rescale(&mut tensor, 1.0 / 255.0);
        for &value in &tensor {
            assert_close(value, 1.0, 1e-6);
        }
    }

    #[test]
    fn test_resize() {
        let source = create_test_image(100, 50, Rgb([128, 128, 128]));
        let resized = resize(&source, 50, 25, FilterType::Triangle);
        assert_eq!(resized.width(), 50);
        assert_eq!(resized.height(), 25);
    }

    #[test]
    fn test_center_crop() {
        let source = create_test_image(100, 100, Rgb([128, 128, 128]));
        let cropped = center_crop(&source, 50, 50);
        assert_eq!(cropped.width(), 50);
        assert_eq!(cropped.height(), 50);
    }

    #[test]
    fn test_expand_to_square_horizontal() {
        let wide = create_test_image(100, 50, Rgb([255, 0, 0]));
        let squared = expand_to_square(&wide, Rgb([0, 0, 0]));
        assert_eq!(squared.width(), 100);
        assert_eq!(squared.height(), 100);
    }

    #[test]
    fn test_expand_to_square_vertical() {
        let tall = create_test_image(50, 100, Rgb([255, 0, 0]));
        let squared = expand_to_square(&tall, Rgb([0, 0, 0]));
        assert_eq!(squared.width(), 100);
        assert_eq!(squared.height(), 100);
    }

    #[test]
    fn test_expand_to_square_already_square() {
        let square = create_test_image(100, 100, Rgb([255, 0, 0]));
        let squared = expand_to_square(&square, Rgb([0, 0, 0]));
        assert_eq!(squared.width(), 100);
        assert_eq!(squared.height(), 100);
    }

    #[test]
    fn test_stack_batch() {
        let zeros = Array3::<f32>::zeros((3, 10, 10));
        let ones = Array3::<f32>::ones((3, 10, 10));
        let batch = stack_batch(&[zeros, ones]).unwrap();
        assert_eq!(batch.shape(), &[2, 3, 10, 10]);
    }

    #[test]
    fn test_stack_batch_empty() {
        let outcome = stack_batch(&[]);
        assert!(matches!(outcome, Err(TransformError::EmptyBatch)));
    }

    #[test]
    fn test_pil_to_filter() {
        assert!(matches!(pil_to_filter(Some(0)), FilterType::Nearest));
        assert!(matches!(pil_to_filter(Some(1)), FilterType::Lanczos3));
        assert!(matches!(pil_to_filter(Some(2)), FilterType::Triangle));
        assert!(matches!(pil_to_filter(Some(3)), FilterType::CatmullRom));
        assert!(matches!(pil_to_filter(None), FilterType::Triangle));
    }

    #[test]
    fn test_mean_to_rgb() {
        let Rgb([r, g, b]) = mean_to_rgb(&[0.5, 0.25, 1.0]);
        assert_eq!(r, 128);
        assert_eq!(g, 64);
        assert_eq!(b, 255);
    }
}