Skip to main content

llm_multimodal/vision/
transforms.rs

1//! Image transformation functions for vision preprocessing.
2//!
3//! This module provides composable transforms that match HuggingFace image processor
4//! behavior, enabling pure Rust preprocessing without Python dependencies.
5
6use std::cell::RefCell;
7
8use fast_image_resize::{
9    images::Image as FirImage, IntoImageView, ResizeAlg, ResizeOptions, Resizer,
10};
11use image::{imageops::FilterType, DynamicImage, GenericImageView, Rgb, RgbImage};
12use ndarray::{s, Array3, Array4};
13use thiserror::Error;
14
15/// Errors that can occur during image transformations.
16#[derive(Error, Debug)]
17pub enum TransformError {
18    #[error("Invalid tensor shape: expected {expected}, got {actual:?}")]
19    InvalidShape {
20        expected: String,
21        actual: Vec<usize>,
22    },
23
24    #[error("Image operation failed: {0}")]
25    ImageError(#[from] image::ImageError),
26
27    #[error("Empty batch: cannot stack zero tensors")]
28    EmptyBatch,
29
30    #[error("Inconsistent tensor shapes in batch")]
31    InconsistentShapes,
32
33    #[error("Shape error: {0}")]
34    ShapeError(String),
35}
36
37pub type Result<T> = std::result::Result<T, TransformError>;
38
39/// Extract RGB pixel data from a DynamicImage, avoiding a copy when already RGB8.
40/// Returns (width, height, raw_bytes) where raw_bytes is interleaved R,G,B,R,G,B,...
41pub fn rgb_bytes(image: &DynamicImage) -> (usize, usize, std::borrow::Cow<'_, [u8]>) {
42    match image {
43        DynamicImage::ImageRgb8(rgb) => (
44            rgb.width() as usize,
45            rgb.height() as usize,
46            std::borrow::Cow::Borrowed(rgb.as_raw()),
47        ),
48        _ => {
49            let rgb = image.to_rgb8();
50            let w = rgb.width() as usize;
51            let h = rgb.height() as usize;
52            (w, h, std::borrow::Cow::Owned(rgb.into_raw()))
53        }
54    }
55}
56
/// Split interleaved RGB bytes into three f32 channel planes, applying a per-channel
/// affine map: `plane[c][i] = rgb[i * 3 + c] * scale[c] + bias[c]`.
///
/// The pixel count is `r_plane.len()`; `g_plane` and `b_plane` must be the same length
/// and `rgb` must hold at least three bytes per pixel (extra bytes are ignored).
/// Mismatches trip a `debug_assert!` in debug builds and a slice panic otherwise.
///
/// Pixels are handled in fixed blocks of 8 so the inner loop has a constant trip count
/// the compiler can unroll and auto-vectorize; the final `len % 8` pixels are done
/// one by one.
pub fn deinterleave_rgb_to_planes(
    rgb: &[u8],
    r_plane: &mut [f32],
    g_plane: &mut [f32],
    b_plane: &mut [f32],
    scale: [f32; 3],
    bias: [f32; 3],
) {
    const LANES: usize = 8;

    let count = r_plane.len();
    debug_assert_eq!(count, g_plane.len());
    debug_assert_eq!(count, b_plane.len());
    debug_assert!(rgb.len() >= count * 3);

    // Trim every buffer to exactly `count` pixels. A buffer that is too short makes the
    // slice panic, just as out-of-bounds indexing would.
    let src = &rgb[..count * 3];
    let reds = &mut r_plane[..count];
    let greens = &mut g_plane[..count];
    let blues = &mut b_plane[..count];

    let mut src_blocks = src.chunks_exact(LANES * 3);
    let mut red_blocks = reds.chunks_exact_mut(LANES);
    let mut green_blocks = greens.chunks_exact_mut(LANES);
    let mut blue_blocks = blues.chunks_exact_mut(LANES);

    // Main body: whole blocks of LANES pixels.
    for (((px, rd), gd), bd) in (&mut src_blocks)
        .zip(&mut red_blocks)
        .zip(&mut green_blocks)
        .zip(&mut blue_blocks)
    {
        for lane in 0..LANES {
            let at = lane * 3;
            rd[lane] = px[at] as f32 * scale[0] + bias[0];
            gd[lane] = px[at + 1] as f32 * scale[1] + bias[1];
            bd[lane] = px[at + 2] as f32 * scale[2] + bias[2];
        }
    }

    // Tail: the last `count % LANES` pixels that did not fill a whole block.
    let tail = src_blocks
        .remainder()
        .chunks_exact(3)
        .zip(red_blocks.into_remainder().iter_mut())
        .zip(green_blocks.into_remainder().iter_mut())
        .zip(blue_blocks.into_remainder().iter_mut());
    for (((px, r), g), b) in tail {
        *r = px[0] as f32 * scale[0] + bias[0];
        *g = px[1] as f32 * scale[1] + bias[1];
        *b = px[2] as f32 * scale[2] + bias[2];
    }
}
103
104/// Build a [C, H, W] f32 tensor from interleaved RGB bytes with per-channel
105/// `scale` and `bias`: `output[c][i] = raw[i*3 + c] * scale[c] + bias[c]`.
106fn build_planar_tensor(
107    raw: &[u8],
108    w: usize,
109    h: usize,
110    scale: [f32; 3],
111    bias: [f32; 3],
112) -> Array3<f32> {
113    let pixels = h * w;
114    let mut data = vec![0.0f32; 3 * pixels];
115    let (r_plane, rest) = data.split_at_mut(pixels);
116    let (g_plane, b_plane) = rest.split_at_mut(pixels);
117
118    deinterleave_rgb_to_planes(raw, r_plane, g_plane, b_plane, scale, bias);
119
120    #[expect(
121        clippy::expect_used,
122        reason = "data has exactly 3*h*w elements by construction"
123    )]
124    Array3::from_shape_vec((3, h, w), data).expect("shape matches pre-allocated buffer")
125}
126
127/// Convert image to tensor [C, H, W] normalized to [0, 1].
128///
129/// This matches the default behavior of `torchvision.transforms.ToTensor()`.
130pub fn to_tensor(image: &DynamicImage) -> Array3<f32> {
131    let (w, h, raw) = rgb_bytes(image);
132    let s = 1.0 / 255.0;
133    build_planar_tensor(&raw, w, h, [s, s, s], [0.0, 0.0, 0.0])
134}
135
/// Convert `image` to an RGB `[C, H, W]` f32 tensor that keeps the raw byte values
/// (range `[0, 255]`).
#[cfg(test)]
pub fn to_tensor_no_norm(image: &DynamicImage) -> Array3<f32> {
    let (width, height, pixels) = rgb_bytes(image);
    build_planar_tensor(&pixels, width, height, [1.0; 3], [0.0; 3])
}
142
143/// Normalize tensor per channel: (x - mean) / std.
144///
145/// This matches `torchvision.transforms.Normalize(mean, std)`.
146///
147/// # Arguments
148/// * `tensor` - Input tensor of shape [C, H, W]
149/// * `mean` - Per-channel mean values
150/// * `std` - Per-channel standard deviation values
151pub fn normalize(tensor: &mut Array3<f32>, mean: &[f64; 3], std: &[f64; 3]) {
152    let [h, w] = [tensor.shape()[1], tensor.shape()[2]];
153    let pixels = h * w;
154
155    if let Some(flat) = tensor.as_slice_mut() {
156        // Fast path: contiguous memory, process channel planes directly
157        for c in 0..3 {
158            let mean_c = mean[c] as f32;
159            let inv_std_c = 1.0 / std[c] as f32;
160            let plane = &mut flat[c * pixels..(c + 1) * pixels];
161            for v in plane.iter_mut() {
162                *v = (*v - mean_c) * inv_std_c;
163            }
164        }
165    } else {
166        for c in 0..3 {
167            let mean_c = mean[c] as f32;
168            let std_c = std[c] as f32;
169            tensor
170                .slice_mut(s![c, .., ..])
171                .mapv_inplace(|v| (v - mean_c) / std_c);
172        }
173    }
174}
175
176/// Convert image to tensor and normalize in a single pass.
177///
178/// Fuses `to_tensor` (u8→f32 with /255) and `normalize` ((x-mean)/std)
179/// into one loop to avoid an extra pass over the data.
180pub fn to_tensor_and_normalize(
181    image: &DynamicImage,
182    mean: &[f64; 3],
183    std: &[f64; 3],
184) -> Array3<f32> {
185    let (w, h, raw) = rgb_bytes(image);
186    // Fused: (pixel/255 - mean) / std = pixel * (1/(255*std)) - mean/std
187    let scale: [f32; 3] = std::array::from_fn(|c| 1.0 / (255.0 * std[c] as f32));
188    let bias: [f32; 3] = std::array::from_fn(|c| -(mean[c] as f32) / (std[c] as f32));
189    build_planar_tensor(&raw, w, h, scale, bias)
190}
191
192/// Rescale tensor by a constant factor.
193///
194/// Used when `do_rescale=True` in HuggingFace configs (typically 1/255).
195pub fn rescale(tensor: &mut Array3<f32>, factor: f64) {
196    let factor = factor as f32;
197    tensor.mapv_inplace(|v| v * factor);
198}
199
200/// Map `image` crate filter types to `fast_image_resize` algorithm.
201fn to_fir_algorithm(filter: FilterType) -> ResizeAlg {
202    use fast_image_resize::FilterType as FirFilter;
203    match filter {
204        FilterType::Nearest => ResizeAlg::Nearest,
205        FilterType::Triangle => ResizeAlg::Convolution(FirFilter::Bilinear),
206        FilterType::CatmullRom => ResizeAlg::Convolution(FirFilter::CatmullRom),
207        FilterType::Gaussian => ResizeAlg::Convolution(FirFilter::Gaussian),
208        FilterType::Lanczos3 => ResizeAlg::Convolution(FirFilter::Lanczos3),
209    }
210}
211
212thread_local! {
213    static RESIZER: RefCell<Resizer> = RefCell::new(Resizer::new());
214}
215
216/// Resize image to exact dimensions using SIMD-accelerated resizer.
217///
218/// # Arguments
219/// * `image` - Input image
220/// * `width` - Target width
221/// * `height` - Target height
222/// * `filter` - Interpolation filter (Nearest, Triangle/Bilinear, CatmullRom/Bicubic, Lanczos3)
223pub fn resize(image: &DynamicImage, width: u32, height: u32, filter: FilterType) -> DynamicImage {
224    let pixel_type = match image.pixel_type() {
225        Some(pt) => pt,
226        None => return image.resize_exact(width, height, filter),
227    };
228    let mut dst = FirImage::new(width, height, pixel_type);
229    let options = ResizeOptions::new().resize_alg(to_fir_algorithm(filter));
230    let ok = RESIZER.with(|r| r.borrow_mut().resize(image, &mut dst, &options).is_ok());
231    if !ok {
232        return image.resize_exact(width, height, filter);
233    }
234    fir_image_to_dynamic(dst, width, height, image, filter)
235}
236
237/// Convert a `fast_image_resize::Image` back to a `DynamicImage`.
238///
239/// Falls back to the `image` crate resize for unhandled pixel formats.
240fn fir_image_to_dynamic(
241    img: FirImage<'_>,
242    width: u32,
243    height: u32,
244    source: &DynamicImage,
245    filter: FilterType,
246) -> DynamicImage {
247    let buf = img.into_vec();
248    match source {
249        DynamicImage::ImageRgb8(_) => {
250            RgbImage::from_raw(width, height, buf).map(DynamicImage::ImageRgb8)
251        }
252        DynamicImage::ImageRgba8(_) => {
253            image::RgbaImage::from_raw(width, height, buf).map(DynamicImage::ImageRgba8)
254        }
255        DynamicImage::ImageLuma8(_) => {
256            image::GrayImage::from_raw(width, height, buf).map(DynamicImage::ImageLuma8)
257        }
258        _ => None,
259    }
260    .unwrap_or_else(|| source.resize_exact(width, height, filter))
261}
262
263/// Resize image preserving aspect ratio, fitting within max dimensions.
264pub fn resize_to_fit(
265    image: &DynamicImage,
266    max_width: u32,
267    max_height: u32,
268    filter: FilterType,
269) -> DynamicImage {
270    let (w, h) = image.dimensions();
271    let ratio = (max_width as f64 / w as f64).min(max_height as f64 / h as f64);
272    if ratio >= 1.0 {
273        return image.clone();
274    }
275    let new_w = ((w as f64 * ratio).round() as u32).max(1);
276    let new_h = ((h as f64 * ratio).round() as u32).max(1);
277    resize(image, new_w, new_h, filter)
278}
279
280/// Center crop image to specified dimensions.
281///
282/// If the crop size is larger than the image, the image is returned unchanged.
283pub fn center_crop(image: &DynamicImage, crop_w: u32, crop_h: u32) -> DynamicImage {
284    let (w, h) = image.dimensions();
285    if crop_w >= w && crop_h >= h {
286        return image.clone();
287    }
288    let left = (w.saturating_sub(crop_w)) / 2;
289    let top = (h.saturating_sub(crop_h)) / 2;
290    let actual_w = crop_w.min(w);
291    let actual_h = crop_h.min(h);
292    image.crop_imm(left, top, actual_w, actual_h)
293}
294
295/// Expand image to square by padding with background color.
296///
297/// This is used by LLaVA models which expect square inputs. The image is
298/// centered and padded with the mean color on the shorter dimension.
299pub fn expand_to_square(image: &DynamicImage, background: Rgb<u8>) -> DynamicImage {
300    let (w, h) = image.dimensions();
301    match w.cmp(&h) {
302        std::cmp::Ordering::Equal => image.clone(),
303        std::cmp::Ordering::Less => {
304            // Height > Width: pad horizontally
305            let mut new_image = DynamicImage::from(RgbImage::from_pixel(h, h, background));
306            image::imageops::overlay(&mut new_image, image, ((h - w) / 2) as i64, 0);
307            new_image
308        }
309        std::cmp::Ordering::Greater => {
310            // Width > Height: pad vertically
311            let mut new_image = DynamicImage::from(RgbImage::from_pixel(w, w, background));
312            image::imageops::overlay(&mut new_image, image, 0, ((w - h) / 2) as i64);
313            new_image
314        }
315    }
316}
317
318/// Pad image to specified dimensions with background color.
319///
320/// Image is placed at top-left corner.
321pub fn pad_to_size(
322    image: &DynamicImage,
323    target_w: u32,
324    target_h: u32,
325    background: Rgb<u8>,
326) -> DynamicImage {
327    let (w, h) = image.dimensions();
328    if w >= target_w && h >= target_h {
329        return image.clone();
330    }
331    let new_w = w.max(target_w);
332    let new_h = h.max(target_h);
333    let mut new_image = DynamicImage::from(RgbImage::from_pixel(new_w, new_h, background));
334    image::imageops::overlay(&mut new_image, image, 0, 0);
335    new_image
336}
337
338/// Stack multiple [C, H, W] tensors into [B, C, H, W].
339///
340/// All tensors must have the same shape.
341pub fn stack_batch(tensors: &[Array3<f32>]) -> Result<Array4<f32>> {
342    if tensors.is_empty() {
343        return Err(TransformError::EmptyBatch);
344    }
345
346    let shape = tensors[0].shape();
347    let (c, h, w) = (shape[0], shape[1], shape[2]);
348
349    // Verify all tensors have the same shape
350    for tensor in tensors.iter().skip(1) {
351        if tensor.shape() != shape {
352            return Err(TransformError::InvalidShape {
353                expected: format!("[{c}, {h}, {w}]"),
354                actual: tensor.shape().to_vec(),
355            });
356        }
357    }
358
359    let mut batch = Array4::<f32>::zeros((tensors.len(), c, h, w));
360    for (i, tensor) in tensors.iter().enumerate() {
361        batch.slice_mut(s![i, .., .., ..]).assign(tensor);
362    }
363
364    Ok(batch)
365}
366
367/// Convert PIL/HuggingFace resampling enum to image crate filter.
368///
369/// PIL resampling constants:
370/// - 0: NEAREST
371/// - 1: LANCZOS (also ANTIALIAS)
372/// - 2: BILINEAR
373/// - 3: BICUBIC
374/// - 4: BOX
375/// - 5: HAMMING
376pub fn pil_to_filter(resampling: Option<usize>) -> FilterType {
377    match resampling {
378        Some(0) => FilterType::Nearest,
379        Some(1) => FilterType::Lanczos3,
380        Some(2) | None => FilterType::Triangle, // Bilinear (default)
381        Some(3) => FilterType::CatmullRom,      // Bicubic
382        // Box and Hamming don't have direct equivalents, use Triangle
383        Some(4) | Some(5) => FilterType::Triangle,
384        _ => FilterType::Triangle,
385    }
386}
387
388/// Calculate mean color of an image as RGB.
389pub fn calculate_mean_color(image: &DynamicImage) -> Rgb<u8> {
390    let rgb = image.to_rgb8();
391    let (w, h) = (rgb.width() as u64, rgb.height() as u64);
392    let total_pixels = w * h;
393
394    if total_pixels == 0 {
395        return Rgb([128, 128, 128]);
396    }
397
398    let (mut r_sum, mut g_sum, mut b_sum) = (0u64, 0u64, 0u64);
399    for pixel in rgb.pixels() {
400        r_sum += pixel[0] as u64;
401        g_sum += pixel[1] as u64;
402        b_sum += pixel[2] as u64;
403    }
404
405    Rgb([
406        (r_sum / total_pixels) as u8,
407        (g_sum / total_pixels) as u8,
408        (b_sum / total_pixels) as u8,
409    ])
410}
411
412/// Convert normalized mean values [0, 1] to RGB bytes.
413pub fn mean_to_rgb(mean: &[f64; 3]) -> Rgb<u8> {
414    Rgb([
415        (mean[0] * 255.0).round() as u8,
416        (mean[1] * 255.0).round() as u8,
417        (mean[2] * 255.0).round() as u8,
418    ])
419}
420
/// Cubic convolution kernel (Keys) with coefficient `a = -0.5`.
///
/// For `t = |x|`:
/// * `t < 1`:     `1.5 t³ - 2.5 t² + 1`        (i.e. `(a + 2) t³ - (a + 3) t² + 1`)
/// * `1 <= t < 2`: `-0.5 t³ + 2.5 t² - 4 t + 2` (i.e. `a t³ - 5a t² + 8a t - 4a`)
/// * otherwise (including NaN): `0`
///
/// NOTE(review): PyTorch's default (non-antialiased)
/// `torch.nn.functional.interpolate(mode='bicubic')` uses `a = -0.75`, not `-0.5`.
/// Outputs built on this kernel are therefore close to, but not identical with, that
/// function. Confirm which reference callers need before relying on exact parity.
#[inline]
pub fn cubic_weight(x: f32) -> f32 {
    let x = x.abs();
    if x < 1.0 {
        (1.5 * x - 2.5) * x * x + 1.0
    } else if x < 2.0 {
        ((-0.5 * x + 2.5) * x - 4.0) * x + 2.0
    } else {
        0.0
    }
}
436
437/// Perform bicubic interpolation at a single point in a tensor.
438///
439/// Uses a 4x4 kernel with Keys bicubic weights (a=-0.5) to match PyTorch's
440/// `torch.nn.functional.interpolate(mode='bicubic')`.
441///
442/// # Arguments
443/// * `tensor` - Input tensor of shape [C, H, W]
444/// * `c` - Channel index
445/// * `src_y` - Source Y coordinate (can be fractional)
446/// * `src_x` - Source X coordinate (can be fractional)
447/// * `h` - Height of the tensor
448/// * `w` - Width of the tensor
449///
450/// # Returns
451/// The interpolated value at the specified position.
452pub fn bicubic_interpolate(
453    tensor: &Array3<f32>,
454    c: usize,
455    src_y: f32,
456    src_x: f32,
457    h: usize,
458    w: usize,
459) -> f32 {
460    let y_int = src_y.floor() as i32;
461    let x_int = src_x.floor() as i32;
462    let y_frac = src_y - y_int as f32;
463    let x_frac = src_x - x_int as f32;
464
465    let mut result = 0.0f32;
466
467    // Sample 4x4 neighborhood
468    for dy in -1..=2 {
469        let y_idx = (y_int + dy).clamp(0, h as i32 - 1) as usize;
470        let y_weight = cubic_weight(y_frac - dy as f32);
471
472        for dx in -1..=2 {
473            let x_idx = (x_int + dx).clamp(0, w as i32 - 1) as usize;
474            let x_weight = cubic_weight(x_frac - dx as f32);
475
476            result += tensor[[c, y_idx, x_idx]] * y_weight * x_weight;
477        }
478    }
479
480    result
481}
482
483/// Resize a tensor using bicubic interpolation.
484///
485/// This matches PyTorch's `torch.nn.functional.interpolate(mode='bicubic', align_corners=False)`.
486///
487/// # Arguments
488/// * `tensor` - Input tensor of shape [C, H, W]
489/// * `target_h` - Target height
490/// * `target_w` - Target width
491///
492/// # Returns
493/// Resized tensor of shape [C, target_h, target_w].
494pub fn bicubic_resize(tensor: &Array3<f32>, target_h: usize, target_w: usize) -> Array3<f32> {
495    let (c, h, w) = (tensor.shape()[0], tensor.shape()[1], tensor.shape()[2]);
496
497    if h == target_h && w == target_w {
498        return tensor.clone();
499    }
500
501    let mut result = Array3::<f32>::zeros((c, target_h, target_w));
502
503    // PyTorch align_corners=False coordinate mapping
504    let scale_h = h as f32 / target_h as f32;
505    let scale_w = w as f32 / target_w as f32;
506
507    for ch in 0..c {
508        for y in 0..target_h {
509            for x in 0..target_w {
510                // PyTorch align_corners=False: src = (dst + 0.5) * scale - 0.5
511                let src_y = (y as f32 + 0.5) * scale_h - 0.5;
512                let src_x = (x as f32 + 0.5) * scale_w - 0.5;
513
514                result[[ch, y, x]] = bicubic_interpolate(tensor, ch, src_y, src_x, h, w);
515            }
516        }
517    }
518
519    result
520}
521
#[cfg(test)]
mod tests {
    use super::*;

    fn create_test_image(width: u32, height: u32, color: Rgb<u8>) -> DynamicImage {
        DynamicImage::from(RgbImage::from_pixel(width, height, color))
    }

    #[test]
    fn test_to_tensor_shape() {
        let img = create_test_image(10, 20, Rgb([255, 128, 0]));
        let tensor = to_tensor(&img);
        assert_eq!(tensor.shape(), &[3, 20, 10]); // [C, H, W]
    }

    #[test]
    fn test_to_tensor_values() {
        let img = create_test_image(2, 2, Rgb([255, 128, 0]));
        let tensor = to_tensor(&img);

        // Check normalization to [0, 1]
        assert!((tensor[[0, 0, 0]] - 1.0).abs() < 1e-6); // R=255 -> 1.0
        assert!((tensor[[1, 0, 0]] - 0.502).abs() < 0.01); // G=128 -> ~0.5
        assert!((tensor[[2, 0, 0]] - 0.0).abs() < 1e-6); // B=0 -> 0.0
    }

    #[test]
    fn test_to_tensor_no_norm() {
        let img = create_test_image(2, 2, Rgb([255, 128, 64]));
        let tensor = to_tensor_no_norm(&img);

        assert!((tensor[[0, 0, 0]] - 255.0).abs() < 1e-6);
        assert!((tensor[[1, 0, 0]] - 128.0).abs() < 1e-6);
        assert!((tensor[[2, 0, 0]] - 64.0).abs() < 1e-6);
    }

    #[test]
    fn test_normalize() {
        let mut tensor = Array3::<f32>::from_elem((3, 2, 2), 0.5);
        let mean = [0.5, 0.5, 0.5];
        let std = [0.5, 0.5, 0.5];

        normalize(&mut tensor, &mean, &std);

        // (0.5 - 0.5) / 0.5 = 0.0
        for val in &tensor {
            assert!(val.abs() < 1e-6);
        }
    }

    #[test]
    fn test_rescale() {
        let mut tensor = Array3::<f32>::from_elem((3, 2, 2), 255.0);
        rescale(&mut tensor, 1.0 / 255.0);

        for val in &tensor {
            assert!((val - 1.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_resize() {
        let img = create_test_image(100, 50, Rgb([128, 128, 128]));
        let resized = resize(&img, 50, 25, FilterType::Triangle);

        assert_eq!(resized.width(), 50);
        assert_eq!(resized.height(), 25);
    }

    #[test]
    fn test_center_crop() {
        let img = create_test_image(100, 100, Rgb([128, 128, 128]));
        let cropped = center_crop(&img, 50, 50);

        assert_eq!(cropped.width(), 50);
        assert_eq!(cropped.height(), 50);
    }

    #[test]
    fn test_expand_to_square_horizontal() {
        let img = create_test_image(100, 50, Rgb([255, 0, 0]));
        let background = Rgb([0, 0, 0]);
        let squared = expand_to_square(&img, background);

        assert_eq!(squared.width(), 100);
        assert_eq!(squared.height(), 100);
    }

    #[test]
    fn test_expand_to_square_vertical() {
        let img = create_test_image(50, 100, Rgb([255, 0, 0]));
        let background = Rgb([0, 0, 0]);
        let squared = expand_to_square(&img, background);

        assert_eq!(squared.width(), 100);
        assert_eq!(squared.height(), 100);
    }

    #[test]
    fn test_expand_to_square_already_square() {
        let img = create_test_image(100, 100, Rgb([255, 0, 0]));
        let background = Rgb([0, 0, 0]);
        let squared = expand_to_square(&img, background);

        assert_eq!(squared.width(), 100);
        assert_eq!(squared.height(), 100);
    }

    #[test]
    fn test_stack_batch() {
        let t1 = Array3::<f32>::zeros((3, 10, 10));
        let t2 = Array3::<f32>::ones((3, 10, 10));

        let batch = stack_batch(&[t1, t2]).unwrap();

        assert_eq!(batch.shape(), &[2, 3, 10, 10]);
    }

    #[test]
    fn test_stack_batch_empty() {
        let result = stack_batch(&[]);
        assert!(matches!(result, Err(TransformError::EmptyBatch)));
    }

    #[test]
    fn test_pil_to_filter() {
        assert!(matches!(pil_to_filter(Some(0)), FilterType::Nearest));
        assert!(matches!(pil_to_filter(Some(1)), FilterType::Lanczos3));
        assert!(matches!(pil_to_filter(Some(2)), FilterType::Triangle));
        assert!(matches!(pil_to_filter(Some(3)), FilterType::CatmullRom));
        assert!(matches!(pil_to_filter(None), FilterType::Triangle));
    }

    #[test]
    fn test_mean_to_rgb() {
        let mean = [0.5, 0.25, 1.0];
        let rgb = mean_to_rgb(&mean);

        assert_eq!(rgb[0], 128);
        assert_eq!(rgb[1], 64);
        assert_eq!(rgb[2], 255);
    }

    #[test]
    fn test_to_tensor_and_normalize_matches_two_pass() {
        let img = create_test_image(3, 2, Rgb([200, 17, 90]));
        let mean = [0.485, 0.456, 0.406];
        let std = [0.229, 0.224, 0.225];

        let fused = to_tensor_and_normalize(&img, &mean, &std);
        let mut two_pass = to_tensor(&img);
        normalize(&mut two_pass, &mean, &std);

        assert_eq!(fused.shape(), two_pass.shape());
        for (a, b) in fused.iter().zip(two_pass.iter()) {
            assert!((a - b).abs() < 1e-5);
        }
    }

    #[test]
    fn test_deinterleave_handles_partial_block() {
        // 11 pixels = one full 8-pixel block plus a 3-pixel tail.
        let pixels = 11;
        let rgb: Vec<u8> = (0..(pixels * 3) as u8).collect();
        let mut r = vec![0.0f32; pixels];
        let mut g = vec![0.0f32; pixels];
        let mut b = vec![0.0f32; pixels];

        deinterleave_rgb_to_planes(&rgb, &mut r, &mut g, &mut b, [2.0, 1.0, 0.5], [0.0, 1.0, -1.0]);

        for i in 0..pixels {
            assert_eq!(r[i], rgb[i * 3] as f32 * 2.0);
            assert_eq!(g[i], rgb[i * 3 + 1] as f32 + 1.0);
            assert_eq!(b[i], rgb[i * 3 + 2] as f32 * 0.5 - 1.0);
        }
    }

    #[test]
    fn test_stack_batch_shape_mismatch() {
        let t1 = Array3::<f32>::zeros((3, 4, 4));
        let t2 = Array3::<f32>::zeros((3, 4, 5));
        let result = stack_batch(&[t1, t2]);
        assert!(matches!(result, Err(TransformError::InvalidShape { .. })));
    }

    #[test]
    fn test_center_crop_larger_in_one_dimension() {
        let img = create_test_image(40, 100, Rgb([1, 2, 3]));

        let cropped = center_crop(&img, 60, 50);
        assert_eq!((cropped.width(), cropped.height()), (40, 50));

        let unchanged = center_crop(&img, 40, 100);
        assert_eq!((unchanged.width(), unchanged.height()), (40, 100));
    }

    #[test]
    fn test_expand_to_square_centers_content() {
        let img = create_test_image(100, 50, Rgb([255, 0, 0]));
        let squared = expand_to_square(&img, Rgb([0, 0, 0]));

        // Padded vertically by (100 - 50) / 2 = 25 rows on top.
        assert_eq!(squared.get_pixel(0, 24).0, [0, 0, 0, 255]);
        assert_eq!(squared.get_pixel(0, 25).0, [255, 0, 0, 255]);
        assert_eq!(squared.get_pixel(99, 74).0, [255, 0, 0, 255]);
        assert_eq!(squared.get_pixel(99, 75).0, [0, 0, 0, 255]);
    }

    #[test]
    fn test_pad_to_size() {
        let img = create_test_image(10, 20, Rgb([9, 8, 7]));

        let padded = pad_to_size(&img, 30, 15, Rgb([1, 2, 3]));
        assert_eq!((padded.width(), padded.height()), (30, 20));
        assert_eq!(padded.get_pixel(0, 0).0, [9, 8, 7, 255]);
        assert_eq!(padded.get_pixel(25, 5).0, [1, 2, 3, 255]);

        let unchanged = pad_to_size(&img, 5, 5, Rgb([1, 2, 3]));
        assert_eq!((unchanged.width(), unchanged.height()), (10, 20));
    }

    #[test]
    fn test_resize_keeps_supported_variant() {
        let luma = DynamicImage::ImageLuma8(image::GrayImage::from_pixel(8, 6, image::Luma([77])));
        let out = resize(&luma, 4, 3, FilterType::Triangle);

        assert!(matches!(out, DynamicImage::ImageLuma8(_)));
        assert_eq!((out.width(), out.height()), (4, 3));
    }

    #[test]
    fn test_resize_fallback_for_unsupported_variant() {
        let gray_alpha = DynamicImage::ImageLumaA8(image::GrayAlphaImage::from_pixel(
            8,
            6,
            image::LumaA([10, 255]),
        ));
        let out = resize(&gray_alpha, 4, 3, FilterType::Triangle);

        assert_eq!((out.width(), out.height()), (4, 3));
    }

    #[test]
    fn test_calculate_mean_color() {
        let mut buf = RgbImage::from_pixel(2, 1, Rgb([0, 10, 255]));
        buf.put_pixel(1, 0, Rgb([255, 11, 255]));

        // Means are truncated: (0 + 255) / 2 = 127, (10 + 11) / 2 = 10.
        assert_eq!(calculate_mean_color(&DynamicImage::from(buf)), Rgb([127, 10, 255]));

        let empty = DynamicImage::from(RgbImage::new(0, 0));
        assert_eq!(calculate_mean_color(&empty), Rgb([128, 128, 128]));
    }

    #[test]
    fn test_cubic_weight_values() {
        assert_eq!(cubic_weight(0.0), 1.0);
        assert_eq!(cubic_weight(0.5), 0.5625);
        assert_eq!(cubic_weight(1.0), 0.0);
        assert_eq!(cubic_weight(1.5), -0.0625);
        assert_eq!(cubic_weight(2.5), 0.0);
    }

    #[test]
    fn test_bicubic_resize_preserves_constant() {
        let tensor = Array3::<f32>::from_elem((2, 3, 5), 0.25);
        let resized = bicubic_resize(&tensor, 7, 4);

        assert_eq!(resized.shape(), &[2, 7, 4]);
        for v in &resized {
            assert!((v - 0.25).abs() < 1e-5);
        }
    }
}