1use std::cell::RefCell;
7
8use fast_image_resize::{
9 images::Image as FirImage, IntoImageView, ResizeAlg, ResizeOptions, Resizer,
10};
11use image::{imageops::FilterType, DynamicImage, GenericImageView, Rgb, RgbImage};
12use ndarray::{s, Array3, Array4};
13use thiserror::Error;
14
15#[derive(Error, Debug)]
17pub enum TransformError {
18 #[error("Invalid tensor shape: expected {expected}, got {actual:?}")]
19 InvalidShape {
20 expected: String,
21 actual: Vec<usize>,
22 },
23
24 #[error("Image operation failed: {0}")]
25 ImageError(#[from] image::ImageError),
26
27 #[error("Empty batch: cannot stack zero tensors")]
28 EmptyBatch,
29
30 #[error("Inconsistent tensor shapes in batch")]
31 InconsistentShapes,
32
33 #[error("Shape error: {0}")]
34 ShapeError(String),
35}
36
37pub type Result<T> = std::result::Result<T, TransformError>;
38
39pub fn rgb_bytes(image: &DynamicImage) -> (usize, usize, std::borrow::Cow<'_, [u8]>) {
42 match image {
43 DynamicImage::ImageRgb8(rgb) => (
44 rgb.width() as usize,
45 rgb.height() as usize,
46 std::borrow::Cow::Borrowed(rgb.as_raw()),
47 ),
48 _ => {
49 let rgb = image.to_rgb8();
50 let w = rgb.width() as usize;
51 let h = rgb.height() as usize;
52 (w, h, std::borrow::Cow::Owned(rgb.into_raw()))
53 }
54 }
55}
56
/// Splits an interleaved `RGBRGB…` byte buffer into three channel planes.
///
/// On the way, each byte `v` of channel `c` is mapped to
/// `v * scale[c] + bias[c]`.
///
/// `r_plane.len()` defines the pixel count. In debug builds the green and blue
/// planes must have exactly that length. In all builds they must be at least
/// that long, and `rgb` must hold at least three bytes per pixel; otherwise
/// this panics. Extra trailing bytes in `rgb` are ignored.
pub fn deinterleave_rgb_to_planes(
    rgb: &[u8],
    r_plane: &mut [f32],
    g_plane: &mut [f32],
    b_plane: &mut [f32],
    scale: [f32; 3],
    bias: [f32; 3],
) {
    let pixel_count = r_plane.len();
    debug_assert_eq!(pixel_count, g_plane.len());
    debug_assert_eq!(pixel_count, b_plane.len());
    debug_assert!(rgb.len() >= pixel_count * 3);

    // Re-slice every input to exactly `pixel_count` pixels up front. The zipped
    // iterators below then line up one-to-one and need no per-element bounds
    // checks.
    let src = &rgb[..pixel_count * 3];
    let greens = &mut g_plane[..pixel_count];
    let blues = &mut b_plane[..pixel_count];

    let [scale_r, scale_g, scale_b] = scale;
    let [bias_r, bias_g, bias_b] = bias;

    for (((px, r), g), b) in src
        .chunks_exact(3)
        .zip(r_plane.iter_mut())
        .zip(greens.iter_mut())
        .zip(blues.iter_mut())
    {
        *r = f32::from(px[0]) * scale_r + bias_r;
        *g = f32::from(px[1]) * scale_g + bias_g;
        *b = f32::from(px[2]) * scale_b + bias_b;
    }
}
103
104fn build_planar_tensor(
107 raw: &[u8],
108 w: usize,
109 h: usize,
110 scale: [f32; 3],
111 bias: [f32; 3],
112) -> Array3<f32> {
113 let pixels = h * w;
114 let mut data = vec![0.0f32; 3 * pixels];
115 let (r_plane, rest) = data.split_at_mut(pixels);
116 let (g_plane, b_plane) = rest.split_at_mut(pixels);
117
118 deinterleave_rgb_to_planes(raw, r_plane, g_plane, b_plane, scale, bias);
119
120 #[expect(
121 clippy::expect_used,
122 reason = "data has exactly 3*h*w elements by construction"
123 )]
124 Array3::from_shape_vec((3, h, w), data).expect("shape matches pre-allocated buffer")
125}
126
127pub fn to_tensor(image: &DynamicImage) -> Array3<f32> {
131 let (w, h, raw) = rgb_bytes(image);
132 let s = 1.0 / 255.0;
133 build_planar_tensor(&raw, w, h, [s, s, s], [0.0, 0.0, 0.0])
134}
135
/// Test-only variant of [`to_tensor`] that keeps the raw `0..=255` channel
/// values (no scaling, no offset).
#[cfg(test)]
pub fn to_tensor_no_norm(image: &DynamicImage) -> Array3<f32> {
    let (width, height, bytes) = rgb_bytes(image);
    build_planar_tensor(&bytes, width, height, [1.0; 3], [0.0; 3])
}
142
143pub fn normalize(tensor: &mut Array3<f32>, mean: &[f64; 3], std: &[f64; 3]) {
152 let [h, w] = [tensor.shape()[1], tensor.shape()[2]];
153 let pixels = h * w;
154
155 if let Some(flat) = tensor.as_slice_mut() {
156 for c in 0..3 {
158 let mean_c = mean[c] as f32;
159 let inv_std_c = 1.0 / std[c] as f32;
160 let plane = &mut flat[c * pixels..(c + 1) * pixels];
161 for v in plane.iter_mut() {
162 *v = (*v - mean_c) * inv_std_c;
163 }
164 }
165 } else {
166 for c in 0..3 {
167 let mean_c = mean[c] as f32;
168 let std_c = std[c] as f32;
169 tensor
170 .slice_mut(s![c, .., ..])
171 .mapv_inplace(|v| (v - mean_c) / std_c);
172 }
173 }
174}
175
176pub fn to_tensor_and_normalize(
181 image: &DynamicImage,
182 mean: &[f64; 3],
183 std: &[f64; 3],
184) -> Array3<f32> {
185 let (w, h, raw) = rgb_bytes(image);
186 let scale: [f32; 3] = std::array::from_fn(|c| 1.0 / (255.0 * std[c] as f32));
188 let bias: [f32; 3] = std::array::from_fn(|c| -(mean[c] as f32) / (std[c] as f32));
189 build_planar_tensor(&raw, w, h, scale, bias)
190}
191
192pub fn rescale(tensor: &mut Array3<f32>, factor: f64) {
196 let factor = factor as f32;
197 tensor.mapv_inplace(|v| v * factor);
198}
199
200fn to_fir_algorithm(filter: FilterType) -> ResizeAlg {
202 use fast_image_resize::FilterType as FirFilter;
203 match filter {
204 FilterType::Nearest => ResizeAlg::Nearest,
205 FilterType::Triangle => ResizeAlg::Convolution(FirFilter::Bilinear),
206 FilterType::CatmullRom => ResizeAlg::Convolution(FirFilter::CatmullRom),
207 FilterType::Gaussian => ResizeAlg::Convolution(FirFilter::Gaussian),
208 FilterType::Lanczos3 => ResizeAlg::Convolution(FirFilter::Lanczos3),
209 }
210}
211
212thread_local! {
213 static RESIZER: RefCell<Resizer> = RefCell::new(Resizer::new());
214}
215
216pub fn resize(image: &DynamicImage, width: u32, height: u32, filter: FilterType) -> DynamicImage {
224 let pixel_type = match image.pixel_type() {
225 Some(pt) => pt,
226 None => return image.resize_exact(width, height, filter),
227 };
228 let mut dst = FirImage::new(width, height, pixel_type);
229 let options = ResizeOptions::new().resize_alg(to_fir_algorithm(filter));
230 let ok = RESIZER.with(|r| r.borrow_mut().resize(image, &mut dst, &options).is_ok());
231 if !ok {
232 return image.resize_exact(width, height, filter);
233 }
234 fir_image_to_dynamic(dst, width, height, image, filter)
235}
236
237fn fir_image_to_dynamic(
241 img: FirImage<'_>,
242 width: u32,
243 height: u32,
244 source: &DynamicImage,
245 filter: FilterType,
246) -> DynamicImage {
247 let buf = img.into_vec();
248 match source {
249 DynamicImage::ImageRgb8(_) => {
250 RgbImage::from_raw(width, height, buf).map(DynamicImage::ImageRgb8)
251 }
252 DynamicImage::ImageRgba8(_) => {
253 image::RgbaImage::from_raw(width, height, buf).map(DynamicImage::ImageRgba8)
254 }
255 DynamicImage::ImageLuma8(_) => {
256 image::GrayImage::from_raw(width, height, buf).map(DynamicImage::ImageLuma8)
257 }
258 _ => None,
259 }
260 .unwrap_or_else(|| source.resize_exact(width, height, filter))
261}
262
263pub fn resize_to_fit(
265 image: &DynamicImage,
266 max_width: u32,
267 max_height: u32,
268 filter: FilterType,
269) -> DynamicImage {
270 let (w, h) = image.dimensions();
271 let ratio = (max_width as f64 / w as f64).min(max_height as f64 / h as f64);
272 if ratio >= 1.0 {
273 return image.clone();
274 }
275 let new_w = ((w as f64 * ratio).round() as u32).max(1);
276 let new_h = ((h as f64 * ratio).round() as u32).max(1);
277 resize(image, new_w, new_h, filter)
278}
279
280pub fn center_crop(image: &DynamicImage, crop_w: u32, crop_h: u32) -> DynamicImage {
284 let (w, h) = image.dimensions();
285 if crop_w >= w && crop_h >= h {
286 return image.clone();
287 }
288 let left = (w.saturating_sub(crop_w)) / 2;
289 let top = (h.saturating_sub(crop_h)) / 2;
290 let actual_w = crop_w.min(w);
291 let actual_h = crop_h.min(h);
292 image.crop_imm(left, top, actual_w, actual_h)
293}
294
295pub fn expand_to_square(image: &DynamicImage, background: Rgb<u8>) -> DynamicImage {
300 let (w, h) = image.dimensions();
301 match w.cmp(&h) {
302 std::cmp::Ordering::Equal => image.clone(),
303 std::cmp::Ordering::Less => {
304 let mut new_image = DynamicImage::from(RgbImage::from_pixel(h, h, background));
306 image::imageops::overlay(&mut new_image, image, ((h - w) / 2) as i64, 0);
307 new_image
308 }
309 std::cmp::Ordering::Greater => {
310 let mut new_image = DynamicImage::from(RgbImage::from_pixel(w, w, background));
312 image::imageops::overlay(&mut new_image, image, 0, ((w - h) / 2) as i64);
313 new_image
314 }
315 }
316}
317
318pub fn pad_to_size(
322 image: &DynamicImage,
323 target_w: u32,
324 target_h: u32,
325 background: Rgb<u8>,
326) -> DynamicImage {
327 let (w, h) = image.dimensions();
328 if w >= target_w && h >= target_h {
329 return image.clone();
330 }
331 let new_w = w.max(target_w);
332 let new_h = h.max(target_h);
333 let mut new_image = DynamicImage::from(RgbImage::from_pixel(new_w, new_h, background));
334 image::imageops::overlay(&mut new_image, image, 0, 0);
335 new_image
336}
337
338pub fn stack_batch(tensors: &[Array3<f32>]) -> Result<Array4<f32>> {
342 if tensors.is_empty() {
343 return Err(TransformError::EmptyBatch);
344 }
345
346 let shape = tensors[0].shape();
347 let (c, h, w) = (shape[0], shape[1], shape[2]);
348
349 for tensor in tensors.iter().skip(1) {
351 if tensor.shape() != shape {
352 return Err(TransformError::InvalidShape {
353 expected: format!("[{c}, {h}, {w}]"),
354 actual: tensor.shape().to_vec(),
355 });
356 }
357 }
358
359 let mut batch = Array4::<f32>::zeros((tensors.len(), c, h, w));
360 for (i, tensor) in tensors.iter().enumerate() {
361 batch.slice_mut(s![i, .., .., ..]).assign(tensor);
362 }
363
364 Ok(batch)
365}
366
367pub fn pil_to_filter(resampling: Option<usize>) -> FilterType {
377 match resampling {
378 Some(0) => FilterType::Nearest,
379 Some(1) => FilterType::Lanczos3,
380 Some(2) | None => FilterType::Triangle, Some(3) => FilterType::CatmullRom, Some(4) | Some(5) => FilterType::Triangle,
384 _ => FilterType::Triangle,
385 }
386}
387
388pub fn calculate_mean_color(image: &DynamicImage) -> Rgb<u8> {
390 let rgb = image.to_rgb8();
391 let (w, h) = (rgb.width() as u64, rgb.height() as u64);
392 let total_pixels = w * h;
393
394 if total_pixels == 0 {
395 return Rgb([128, 128, 128]);
396 }
397
398 let (mut r_sum, mut g_sum, mut b_sum) = (0u64, 0u64, 0u64);
399 for pixel in rgb.pixels() {
400 r_sum += pixel[0] as u64;
401 g_sum += pixel[1] as u64;
402 b_sum += pixel[2] as u64;
403 }
404
405 Rgb([
406 (r_sum / total_pixels) as u8,
407 (g_sum / total_pixels) as u8,
408 (b_sum / total_pixels) as u8,
409 ])
410}
411
412pub fn mean_to_rgb(mean: &[f64; 3]) -> Rgb<u8> {
414 Rgb([
415 (mean[0] * 255.0).round() as u8,
416 (mean[1] * 255.0).round() as u8,
417 (mean[2] * 255.0).round() as u8,
418 ])
419}
420
/// Keys cubic convolution kernel with `a = -0.5` (the Catmull-Rom spline).
///
/// With `t = |x|`:
/// - `t < 1`: `1.5 t^3 - 2.5 t^2 + 1`
/// - `1 <= t < 2`: `-0.5 t^3 + 2.5 t^2 - 4 t + 2`
/// - otherwise, including NaN: `0`
#[inline]
pub fn cubic_weight(x: f32) -> f32 {
    match x.abs() {
        t if t < 1.0 => (1.5 * t - 2.5) * t * t + 1.0,
        t if t < 2.0 => ((-0.5 * t + 2.5) * t - 4.0) * t + 2.0,
        _ => 0.0,
    }
}
436
437pub fn bicubic_interpolate(
453 tensor: &Array3<f32>,
454 c: usize,
455 src_y: f32,
456 src_x: f32,
457 h: usize,
458 w: usize,
459) -> f32 {
460 let y_int = src_y.floor() as i32;
461 let x_int = src_x.floor() as i32;
462 let y_frac = src_y - y_int as f32;
463 let x_frac = src_x - x_int as f32;
464
465 let mut result = 0.0f32;
466
467 for dy in -1..=2 {
469 let y_idx = (y_int + dy).clamp(0, h as i32 - 1) as usize;
470 let y_weight = cubic_weight(y_frac - dy as f32);
471
472 for dx in -1..=2 {
473 let x_idx = (x_int + dx).clamp(0, w as i32 - 1) as usize;
474 let x_weight = cubic_weight(x_frac - dx as f32);
475
476 result += tensor[[c, y_idx, x_idx]] * y_weight * x_weight;
477 }
478 }
479
480 result
481}
482
483pub fn bicubic_resize(tensor: &Array3<f32>, target_h: usize, target_w: usize) -> Array3<f32> {
495 let (c, h, w) = (tensor.shape()[0], tensor.shape()[1], tensor.shape()[2]);
496
497 if h == target_h && w == target_w {
498 return tensor.clone();
499 }
500
501 let mut result = Array3::<f32>::zeros((c, target_h, target_w));
502
503 let scale_h = h as f32 / target_h as f32;
505 let scale_w = w as f32 / target_w as f32;
506
507 for ch in 0..c {
508 for y in 0..target_h {
509 for x in 0..target_w {
510 let src_y = (y as f32 + 0.5) * scale_h - 0.5;
512 let src_x = (x as f32 + 0.5) * scale_w - 0.5;
513
514 result[[ch, y, x]] = bicubic_interpolate(tensor, ch, src_y, src_x, h, w);
515 }
516 }
517 }
518
519 result
520}
521
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a solid-colour RGB8 test image.
    fn create_test_image(width: u32, height: u32, color: Rgb<u8>) -> DynamicImage {
        DynamicImage::from(RgbImage::from_pixel(width, height, color))
    }

    #[test]
    fn test_to_tensor_shape() {
        let img = create_test_image(10, 20, Rgb([255, 128, 0]));
        let tensor = to_tensor(&img);
        // Channel-first layout: (channels, height, width).
        assert_eq!(tensor.shape(), &[3, 20, 10]);
    }

    #[test]
    fn test_to_tensor_values() {
        let img = create_test_image(2, 2, Rgb([255, 128, 0]));
        let tensor = to_tensor(&img);

        assert!((tensor[[0, 0, 0]] - 1.0).abs() < 1e-6); // 255 / 255
        assert!((tensor[[1, 0, 0]] - 0.502).abs() < 0.01); // 128 / 255
        assert!((tensor[[2, 0, 0]] - 0.0).abs() < 1e-6); // 0 / 255
    }

    #[test]
    fn test_to_tensor_no_norm() {
        let img = create_test_image(2, 2, Rgb([255, 128, 64]));
        let tensor = to_tensor_no_norm(&img);

        assert!((tensor[[0, 0, 0]] - 255.0).abs() < 1e-6);
        assert!((tensor[[1, 0, 0]] - 128.0).abs() < 1e-6);
        assert!((tensor[[2, 0, 0]] - 64.0).abs() < 1e-6);
    }

    #[test]
    fn test_to_tensor_and_normalize_matches_two_step() {
        let img = create_test_image(3, 2, Rgb([200, 100, 50]));
        let mean = [0.485, 0.456, 0.406];
        let std = [0.229, 0.224, 0.225];

        let mut expected = to_tensor(&img);
        normalize(&mut expected, &mean, &std);
        let fused = to_tensor_and_normalize(&img, &mean, &std);

        assert_eq!(fused.shape(), expected.shape());
        for (a, b) in fused.iter().zip(expected.iter()) {
            assert!((a - b).abs() < 1e-5);
        }
    }

    #[test]
    fn test_deinterleave_non_multiple_of_eight() {
        // 10 pixels: not a multiple of 8, so any unrolled fast path must also handle a tail.
        let rgb: Vec<u8> = (0..30u8).collect();
        let mut r = vec![0.0f32; 10];
        let mut g = vec![0.0f32; 10];
        let mut b = vec![0.0f32; 10];
        deinterleave_rgb_to_planes(
            &rgb,
            &mut r,
            &mut g,
            &mut b,
            [1.0, 2.0, 0.5],
            [0.0, 1.0, -1.0],
        );
        for i in 0..10 {
            let base = (3 * i) as f32;
            assert!((r[i] - base).abs() < 1e-6);
            assert!((g[i] - ((base + 1.0) * 2.0 + 1.0)).abs() < 1e-6);
            assert!((b[i] - ((base + 2.0) * 0.5 - 1.0)).abs() < 1e-6);
        }
    }

    #[test]
    fn test_normalize() {
        let mut tensor = Array3::<f32>::from_elem((3, 2, 2), 0.5);
        let mean = [0.5, 0.5, 0.5];
        let std = [0.5, 0.5, 0.5];

        normalize(&mut tensor, &mean, &std);

        // (0.5 - 0.5) / 0.5 == 0 everywhere.
        for val in &tensor {
            assert!(val.abs() < 1e-6);
        }
    }

    #[test]
    fn test_rescale() {
        let mut tensor = Array3::<f32>::from_elem((3, 2, 2), 255.0);
        rescale(&mut tensor, 1.0 / 255.0);

        for val in &tensor {
            assert!((val - 1.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_resize() {
        let img = create_test_image(100, 50, Rgb([128, 128, 128]));
        let resized = resize(&img, 50, 25, FilterType::Triangle);

        assert_eq!(resized.width(), 50);
        assert_eq!(resized.height(), 25);
    }

    #[test]
    fn test_center_crop() {
        let img = create_test_image(100, 100, Rgb([128, 128, 128]));
        let cropped = center_crop(&img, 50, 50);

        assert_eq!(cropped.width(), 50);
        assert_eq!(cropped.height(), 50);
    }

    #[test]
    fn test_center_crop_larger_than_image() {
        let img = create_test_image(100, 50, Rgb([128, 128, 128]));
        let cropped = center_crop(&img, 80, 80);

        // Width is cropped; height is clamped to the image's 50 rows.
        assert_eq!(cropped.width(), 80);
        assert_eq!(cropped.height(), 50);
    }

    #[test]
    fn test_expand_to_square_horizontal() {
        let img = create_test_image(100, 50, Rgb([255, 0, 0]));
        let background = Rgb([0, 0, 0]);
        let squared = expand_to_square(&img, background);

        assert_eq!(squared.width(), 100);
        assert_eq!(squared.height(), 100);
    }

    #[test]
    fn test_expand_to_square_vertical() {
        let img = create_test_image(50, 100, Rgb([255, 0, 0]));
        let background = Rgb([0, 0, 0]);
        let squared = expand_to_square(&img, background);

        assert_eq!(squared.width(), 100);
        assert_eq!(squared.height(), 100);
    }

    #[test]
    fn test_expand_to_square_already_square() {
        let img = create_test_image(100, 100, Rgb([255, 0, 0]));
        let background = Rgb([0, 0, 0]);
        let squared = expand_to_square(&img, background);

        assert_eq!(squared.width(), 100);
        assert_eq!(squared.height(), 100);
    }

    #[test]
    fn test_pad_to_size() {
        let img = create_test_image(10, 10, Rgb([9, 9, 9]));
        let padded = pad_to_size(&img, 20, 15, Rgb([1, 2, 3]));

        assert_eq!(padded.width(), 20);
        assert_eq!(padded.height(), 15);
        // Pixels outside the original 10x10 region take the background colour.
        assert_eq!(*padded.to_rgb8().get_pixel(15, 12), Rgb([1, 2, 3]));

        // Already large enough: dimensions are left alone.
        let same = pad_to_size(&img, 5, 5, Rgb([1, 2, 3]));
        assert_eq!(same.width(), 10);
        assert_eq!(same.height(), 10);
    }

    #[test]
    fn test_stack_batch() {
        let t1 = Array3::<f32>::zeros((3, 10, 10));
        let t2 = Array3::<f32>::ones((3, 10, 10));

        let batch = stack_batch(&[t1, t2]).unwrap();

        assert_eq!(batch.shape(), &[2, 3, 10, 10]);
    }

    #[test]
    fn test_stack_batch_empty() {
        let result = stack_batch(&[]);
        assert!(matches!(result, Err(TransformError::EmptyBatch)));
    }

    #[test]
    fn test_stack_batch_shape_mismatch() {
        let t1 = Array3::<f32>::zeros((3, 10, 10));
        let t2 = Array3::<f32>::zeros((3, 5, 5));

        match stack_batch(&[t1, t2]) {
            Err(TransformError::InvalidShape { actual, .. }) => {
                assert_eq!(actual, vec![3, 5, 5]);
            }
            other => panic!("expected InvalidShape, got {other:?}"),
        }
    }

    #[test]
    fn test_pil_to_filter() {
        assert!(matches!(pil_to_filter(Some(0)), FilterType::Nearest));
        assert!(matches!(pil_to_filter(Some(1)), FilterType::Lanczos3));
        assert!(matches!(pil_to_filter(Some(2)), FilterType::Triangle));
        assert!(matches!(pil_to_filter(Some(3)), FilterType::CatmullRom));
        assert!(matches!(pil_to_filter(None), FilterType::Triangle));
    }

    #[test]
    fn test_mean_to_rgb() {
        let mean = [0.5, 0.25, 1.0];
        let rgb = mean_to_rgb(&mean);

        assert_eq!(rgb[0], 128);
        assert_eq!(rgb[1], 64);
        assert_eq!(rgb[2], 255);
    }

    #[test]
    fn test_cubic_weights_sum_to_one() {
        for &frac in &[0.0f32, 0.25, 0.5, 0.8] {
            let sum: f32 = (-1..=2).map(|d| cubic_weight(frac - d as f32)).sum();
            assert!((sum - 1.0).abs() < 1e-6, "frac={frac} sum={sum}");
        }
    }

    #[test]
    fn test_bicubic_resize_preserves_constant() {
        let tensor = Array3::<f32>::from_elem((1, 4, 4), 0.5);
        let resized = bicubic_resize(&tensor, 8, 6);

        assert_eq!(resized.shape(), &[1, 8, 6]);
        for v in &resized {
            assert!((v - 0.5).abs() < 1e-5);
        }
    }
}