Skip to main content

axonml_vision/
transforms.rs

//! Image Transforms - Vision-Specific Data Augmentation
//!
//! # File
//! `crates/axonml-vision/src/transforms.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at your own risk. This software is provided "as is", without warranty of
//! any kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.
16
17use axonml_data::Transform;
18use axonml_tensor::Tensor;
19use rand::Rng;
20
21// =============================================================================
22// Resize
23// =============================================================================
24
25/// Resizes an image to the specified size using bilinear interpolation.
26pub struct Resize {
27    height: usize,
28    width: usize,
29}
30
31impl Resize {
32    /// Creates a new Resize transform.
33    #[must_use]
34    pub fn new(height: usize, width: usize) -> Self {
35        Self { height, width }
36    }
37
38    /// Creates a square Resize transform.
39    #[must_use]
40    pub fn square(size: usize) -> Self {
41        Self::new(size, size)
42    }
43}
44
45impl Transform for Resize {
46    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
47        let shape = input.shape();
48
49        // Handle different input formats
50        match shape.len() {
51            2 => resize_2d(input, self.height, self.width),
52            3 => resize_3d(input, self.height, self.width),
53            4 => resize_4d(input, self.height, self.width),
54            _ => input.clone(),
55        }
56    }
57}
58
59/// Bilinear interpolation resize for 2D tensor (H x W).
60fn resize_2d(input: &Tensor<f32>, new_h: usize, new_w: usize) -> Tensor<f32> {
61    let shape = input.shape();
62    let (old_h, old_w) = (shape[0], shape[1]);
63    let data = input.to_vec();
64
65    let mut result = vec![0.0; new_h * new_w];
66
67    let scale_h = old_h as f32 / new_h as f32;
68    let scale_w = old_w as f32 / new_w as f32;
69
70    for y in 0..new_h {
71        for x in 0..new_w {
72            let src_y = y as f32 * scale_h;
73            let src_x = x as f32 * scale_w;
74
75            let y0 = (src_y.floor() as usize).min(old_h - 1);
76            let y1 = (y0 + 1).min(old_h - 1);
77            let x0 = (src_x.floor() as usize).min(old_w - 1);
78            let x1 = (x0 + 1).min(old_w - 1);
79
80            let dy = src_y - y0 as f32;
81            let dx = src_x - x0 as f32;
82
83            let v00 = data[y0 * old_w + x0];
84            let v01 = data[y0 * old_w + x1];
85            let v10 = data[y1 * old_w + x0];
86            let v11 = data[y1 * old_w + x1];
87
88            let value = v00 * (1.0 - dx) * (1.0 - dy)
89                + v01 * dx * (1.0 - dy)
90                + v10 * (1.0 - dx) * dy
91                + v11 * dx * dy;
92
93            result[y * new_w + x] = value;
94        }
95    }
96
97    Tensor::from_vec(result, &[new_h, new_w]).unwrap()
98}
99
100/// Bilinear interpolation resize for 3D tensor (C x H x W).
101fn resize_3d(input: &Tensor<f32>, new_h: usize, new_w: usize) -> Tensor<f32> {
102    let shape = input.shape();
103    let (channels, old_h, old_w) = (shape[0], shape[1], shape[2]);
104    let data = input.to_vec();
105
106    let mut result = vec![0.0; channels * new_h * new_w];
107
108    let scale_h = old_h as f32 / new_h as f32;
109    let scale_w = old_w as f32 / new_w as f32;
110
111    for c in 0..channels {
112        for y in 0..new_h {
113            for x in 0..new_w {
114                let src_y = y as f32 * scale_h;
115                let src_x = x as f32 * scale_w;
116
117                let y0 = (src_y.floor() as usize).min(old_h - 1);
118                let y1 = (y0 + 1).min(old_h - 1);
119                let x0 = (src_x.floor() as usize).min(old_w - 1);
120                let x1 = (x0 + 1).min(old_w - 1);
121
122                let dy = src_y - y0 as f32;
123                let dx = src_x - x0 as f32;
124
125                let base = c * old_h * old_w;
126                let v00 = data[base + y0 * old_w + x0];
127                let v01 = data[base + y0 * old_w + x1];
128                let v10 = data[base + y1 * old_w + x0];
129                let v11 = data[base + y1 * old_w + x1];
130
131                let value = v00 * (1.0 - dx) * (1.0 - dy)
132                    + v01 * dx * (1.0 - dy)
133                    + v10 * (1.0 - dx) * dy
134                    + v11 * dx * dy;
135
136                result[c * new_h * new_w + y * new_w + x] = value;
137            }
138        }
139    }
140
141    Tensor::from_vec(result, &[channels, new_h, new_w]).unwrap()
142}
143
144/// Bilinear interpolation resize for 4D tensor (N x C x H x W).
145fn resize_4d(input: &Tensor<f32>, new_h: usize, new_w: usize) -> Tensor<f32> {
146    let shape = input.shape();
147    let (batch, channels, old_h, old_w) = (shape[0], shape[1], shape[2], shape[3]);
148    let data = input.to_vec();
149
150    let mut result = vec![0.0; batch * channels * new_h * new_w];
151
152    let scale_h = old_h as f32 / new_h as f32;
153    let scale_w = old_w as f32 / new_w as f32;
154
155    for n in 0..batch {
156        for c in 0..channels {
157            for y in 0..new_h {
158                for x in 0..new_w {
159                    let src_y = y as f32 * scale_h;
160                    let src_x = x as f32 * scale_w;
161
162                    let y0 = (src_y.floor() as usize).min(old_h - 1);
163                    let y1 = (y0 + 1).min(old_h - 1);
164                    let x0 = (src_x.floor() as usize).min(old_w - 1);
165                    let x1 = (x0 + 1).min(old_w - 1);
166
167                    let dy = src_y - y0 as f32;
168                    let dx = src_x - x0 as f32;
169
170                    let base = n * channels * old_h * old_w + c * old_h * old_w;
171                    let v00 = data[base + y0 * old_w + x0];
172                    let v01 = data[base + y0 * old_w + x1];
173                    let v10 = data[base + y1 * old_w + x0];
174                    let v11 = data[base + y1 * old_w + x1];
175
176                    let value = v00 * (1.0 - dx) * (1.0 - dy)
177                        + v01 * dx * (1.0 - dy)
178                        + v10 * (1.0 - dx) * dy
179                        + v11 * dx * dy;
180
181                    let out_idx = n * channels * new_h * new_w + c * new_h * new_w + y * new_w + x;
182                    result[out_idx] = value;
183                }
184            }
185        }
186    }
187
188    Tensor::from_vec(result, &[batch, channels, new_h, new_w]).unwrap()
189}
190
191// =============================================================================
192// CenterCrop
193// =============================================================================
194
195/// Crops the center of an image to the specified size.
196pub struct CenterCrop {
197    height: usize,
198    width: usize,
199}
200
201impl CenterCrop {
202    /// Creates a new `CenterCrop` transform.
203    #[must_use]
204    pub fn new(height: usize, width: usize) -> Self {
205        Self { height, width }
206    }
207
208    /// Creates a square `CenterCrop` transform.
209    #[must_use]
210    pub fn square(size: usize) -> Self {
211        Self::new(size, size)
212    }
213}
214
215impl Transform for CenterCrop {
216    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
217        let shape = input.shape();
218        let data = input.to_vec();
219
220        match shape.len() {
221            2 => {
222                let (h, w) = (shape[0], shape[1]);
223                let start_h = (h.saturating_sub(self.height)) / 2;
224                let start_w = (w.saturating_sub(self.width)) / 2;
225                let crop_h = self.height.min(h);
226                let crop_w = self.width.min(w);
227
228                let mut result = Vec::with_capacity(crop_h * crop_w);
229                for y in start_h..start_h + crop_h {
230                    for x in start_w..start_w + crop_w {
231                        result.push(data[y * w + x]);
232                    }
233                }
234                Tensor::from_vec(result, &[crop_h, crop_w]).unwrap()
235            }
236            3 => {
237                let (c, h, w) = (shape[0], shape[1], shape[2]);
238                let start_h = (h.saturating_sub(self.height)) / 2;
239                let start_w = (w.saturating_sub(self.width)) / 2;
240                let crop_h = self.height.min(h);
241                let crop_w = self.width.min(w);
242
243                let mut result = Vec::with_capacity(c * crop_h * crop_w);
244                for ch in 0..c {
245                    for y in start_h..start_h + crop_h {
246                        for x in start_w..start_w + crop_w {
247                            result.push(data[ch * h * w + y * w + x]);
248                        }
249                    }
250                }
251                Tensor::from_vec(result, &[c, crop_h, crop_w]).unwrap()
252            }
253            _ => input.clone(),
254        }
255    }
256}
257
258// =============================================================================
259// RandomHorizontalFlip
260// =============================================================================
261
262/// Randomly flips an image horizontally with given probability.
263pub struct RandomHorizontalFlip {
264    probability: f32,
265}
266
267impl RandomHorizontalFlip {
268    /// Creates a new `RandomHorizontalFlip` with probability 0.5.
269    #[must_use]
270    pub fn new() -> Self {
271        Self { probability: 0.5 }
272    }
273
274    /// Creates a `RandomHorizontalFlip` with custom probability.
275    #[must_use]
276    pub fn with_probability(probability: f32) -> Self {
277        Self {
278            probability: probability.clamp(0.0, 1.0),
279        }
280    }
281}
282
283impl Default for RandomHorizontalFlip {
284    fn default() -> Self {
285        Self::new()
286    }
287}
288
289impl Transform for RandomHorizontalFlip {
290    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
291        let mut rng = rand::thread_rng();
292        if rng.r#gen::<f32>() > self.probability {
293            return input.clone();
294        }
295
296        let shape = input.shape();
297        let data = input.to_vec();
298
299        match shape.len() {
300            2 => {
301                let (h, w) = (shape[0], shape[1]);
302                let mut result = vec![0.0; h * w];
303                for y in 0..h {
304                    for x in 0..w {
305                        result[y * w + x] = data[y * w + (w - 1 - x)];
306                    }
307                }
308                Tensor::from_vec(result, shape).unwrap()
309            }
310            3 => {
311                let (c, h, w) = (shape[0], shape[1], shape[2]);
312                let mut result = vec![0.0; c * h * w];
313                for ch in 0..c {
314                    for y in 0..h {
315                        for x in 0..w {
316                            result[ch * h * w + y * w + x] = data[ch * h * w + y * w + (w - 1 - x)];
317                        }
318                    }
319                }
320                Tensor::from_vec(result, shape).unwrap()
321            }
322            _ => input.clone(),
323        }
324    }
325}
326
327// =============================================================================
328// RandomVerticalFlip
329// =============================================================================
330
331/// Randomly flips an image vertically with given probability.
332pub struct RandomVerticalFlip {
333    probability: f32,
334}
335
336impl RandomVerticalFlip {
337    /// Creates a new `RandomVerticalFlip` with probability 0.5.
338    #[must_use]
339    pub fn new() -> Self {
340        Self { probability: 0.5 }
341    }
342
343    /// Creates a `RandomVerticalFlip` with custom probability.
344    #[must_use]
345    pub fn with_probability(probability: f32) -> Self {
346        Self {
347            probability: probability.clamp(0.0, 1.0),
348        }
349    }
350}
351
352impl Default for RandomVerticalFlip {
353    fn default() -> Self {
354        Self::new()
355    }
356}
357
358impl Transform for RandomVerticalFlip {
359    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
360        let mut rng = rand::thread_rng();
361        if rng.r#gen::<f32>() > self.probability {
362            return input.clone();
363        }
364
365        let shape = input.shape();
366        let data = input.to_vec();
367
368        match shape.len() {
369            2 => {
370                let (h, w) = (shape[0], shape[1]);
371                let mut result = vec![0.0; h * w];
372                for y in 0..h {
373                    for x in 0..w {
374                        result[y * w + x] = data[(h - 1 - y) * w + x];
375                    }
376                }
377                Tensor::from_vec(result, shape).unwrap()
378            }
379            3 => {
380                let (c, h, w) = (shape[0], shape[1], shape[2]);
381                let mut result = vec![0.0; c * h * w];
382                for ch in 0..c {
383                    for y in 0..h {
384                        for x in 0..w {
385                            result[ch * h * w + y * w + x] = data[ch * h * w + (h - 1 - y) * w + x];
386                        }
387                    }
388                }
389                Tensor::from_vec(result, shape).unwrap()
390            }
391            _ => input.clone(),
392        }
393    }
394}
395
396// =============================================================================
397// RandomRotation
398// =============================================================================
399
400/// Randomly rotates an image by 90-degree increments.
401pub struct RandomRotation {
402    /// Allowed rotations: 0, 90, 180, 270 degrees
403    angles: Vec<i32>,
404}
405
406impl RandomRotation {
407    /// Creates a `RandomRotation` that can rotate by any 90-degree increment.
408    #[must_use]
409    pub fn new() -> Self {
410        Self {
411            angles: vec![0, 90, 180, 270],
412        }
413    }
414
415    /// Creates a `RandomRotation` with specific allowed angles.
416    #[must_use]
417    pub fn with_angles(angles: Vec<i32>) -> Self {
418        let valid: Vec<i32> = angles
419            .into_iter()
420            .filter(|&a| a == 0 || a == 90 || a == 180 || a == 270)
421            .collect();
422        Self {
423            angles: if valid.is_empty() { vec![0] } else { valid },
424        }
425    }
426}
427
428impl Default for RandomRotation {
429    fn default() -> Self {
430        Self::new()
431    }
432}
433
434impl Transform for RandomRotation {
435    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
436        let mut rng = rand::thread_rng();
437        let angle = self.angles[rng.gen_range(0..self.angles.len())];
438
439        if angle == 0 {
440            return input.clone();
441        }
442
443        let shape = input.shape();
444        let data = input.to_vec();
445
446        // Only handle 2D (H x W) for simplicity
447        if shape.len() != 2 {
448            return input.clone();
449        }
450
451        let (h, w) = (shape[0], shape[1]);
452
453        match angle {
454            90 => {
455                // Rotate 90 degrees clockwise: (x, y) -> (y, h-1-x)
456                let mut result = vec![0.0; h * w];
457                for y in 0..h {
458                    for x in 0..w {
459                        result[x * h + (h - 1 - y)] = data[y * w + x];
460                    }
461                }
462                Tensor::from_vec(result, &[w, h]).unwrap()
463            }
464            180 => {
465                // Rotate 180 degrees: (x, y) -> (w-1-x, h-1-y)
466                let mut result = vec![0.0; h * w];
467                for y in 0..h {
468                    for x in 0..w {
469                        result[(h - 1 - y) * w + (w - 1 - x)] = data[y * w + x];
470                    }
471                }
472                Tensor::from_vec(result, &[h, w]).unwrap()
473            }
474            270 => {
475                // Rotate 270 degrees clockwise: (x, y) -> (w-1-y, x)
476                let mut result = vec![0.0; h * w];
477                for y in 0..h {
478                    for x in 0..w {
479                        result[(w - 1 - x) * h + y] = data[y * w + x];
480                    }
481                }
482                Tensor::from_vec(result, &[w, h]).unwrap()
483            }
484            _ => input.clone(),
485        }
486    }
487}
488
489// =============================================================================
490// ColorJitter
491// =============================================================================
492
493/// Randomly adjusts brightness, contrast, saturation of an image.
494pub struct ColorJitter {
495    brightness: f32,
496    contrast: f32,
497    saturation: f32,
498}
499
500impl ColorJitter {
501    /// Creates a new `ColorJitter` with specified ranges.
502    #[must_use]
503    pub fn new(brightness: f32, contrast: f32, saturation: f32) -> Self {
504        Self {
505            brightness: brightness.abs(),
506            contrast: contrast.abs(),
507            saturation: saturation.abs(),
508        }
509    }
510}
511
512impl Transform for ColorJitter {
513    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
514        let mut rng = rand::thread_rng();
515        let mut data = input.to_vec();
516        let shape = input.shape();
517
518        // Apply brightness adjustment
519        if self.brightness > 0.0 {
520            let factor = 1.0 + rng.gen_range(-self.brightness..self.brightness);
521            for val in &mut data {
522                *val = (*val * factor).clamp(0.0, 1.0);
523            }
524        }
525
526        // Apply contrast adjustment
527        if self.contrast > 0.0 {
528            let factor = 1.0 + rng.gen_range(-self.contrast..self.contrast);
529            let mean: f32 = data.iter().sum::<f32>() / data.len() as f32;
530            for val in &mut data {
531                *val = ((*val - mean) * factor + mean).clamp(0.0, 1.0);
532            }
533        }
534
535        // Apply saturation adjustment (simplified - works best with 3-channel images)
536        if self.saturation > 0.0 && shape.len() == 3 && shape[0] == 3 {
537            let factor = 1.0 + rng.gen_range(-self.saturation..self.saturation);
538            let (h, w) = (shape[1], shape[2]);
539
540            for y in 0..h {
541                for x in 0..w {
542                    let r = data[0 * h * w + y * w + x];
543                    let g = data[h * w + y * w + x];
544                    let b = data[2 * h * w + y * w + x];
545
546                    let gray = 0.299 * r + 0.587 * g + 0.114 * b;
547
548                    data[0 * h * w + y * w + x] = (gray + (r - gray) * factor).clamp(0.0, 1.0);
549                    data[h * w + y * w + x] = (gray + (g - gray) * factor).clamp(0.0, 1.0);
550                    data[2 * h * w + y * w + x] = (gray + (b - gray) * factor).clamp(0.0, 1.0);
551                }
552            }
553        }
554
555        Tensor::from_vec(data, shape).unwrap()
556    }
557}
558
559// =============================================================================
560// Grayscale
561// =============================================================================
562
563/// Converts an RGB image to grayscale.
564pub struct Grayscale {
565    num_output_channels: usize,
566}
567
568impl Grayscale {
569    /// Creates a Grayscale transform with 1 output channel.
570    #[must_use]
571    pub fn new() -> Self {
572        Self {
573            num_output_channels: 1,
574        }
575    }
576
577    /// Creates a Grayscale transform with specified output channels.
578    #[must_use]
579    pub fn with_channels(num_output_channels: usize) -> Self {
580        Self {
581            num_output_channels: num_output_channels.max(1),
582        }
583    }
584}
585
586impl Default for Grayscale {
587    fn default() -> Self {
588        Self::new()
589    }
590}
591
592impl Transform for Grayscale {
593    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
594        let shape = input.shape();
595
596        // Only works with 3-channel images (C x H x W)
597        if shape.len() != 3 || shape[0] != 3 {
598            return input.clone();
599        }
600
601        let (_, h, w) = (shape[0], shape[1], shape[2]);
602        let data = input.to_vec();
603
604        let mut gray = Vec::with_capacity(h * w);
605        for y in 0..h {
606            for x in 0..w {
607                let r = data[0 * h * w + y * w + x];
608                let g = data[h * w + y * w + x];
609                let b = data[2 * h * w + y * w + x];
610                gray.push(0.299 * r + 0.587 * g + 0.114 * b);
611            }
612        }
613
614        if self.num_output_channels == 1 {
615            Tensor::from_vec(gray, &[1, h, w]).unwrap()
616        } else {
617            // Replicate grayscale across channels
618            let mut result = Vec::with_capacity(self.num_output_channels * h * w);
619            for _ in 0..self.num_output_channels {
620                result.extend(&gray);
621            }
622            Tensor::from_vec(result, &[self.num_output_channels, h, w]).unwrap()
623        }
624    }
625}
626
627// =============================================================================
628// Normalize (Image-specific)
629// =============================================================================
630
631/// Normalizes an image with per-channel mean and std.
632pub struct ImageNormalize {
633    mean: Vec<f32>,
634    std: Vec<f32>,
635}
636
637impl ImageNormalize {
638    /// Creates a new `ImageNormalize` with per-channel mean and std.
639    #[must_use]
640    pub fn new(mean: Vec<f32>, std: Vec<f32>) -> Self {
641        Self { mean, std }
642    }
643
644    /// Creates normalization for `ImageNet` pretrained models.
645    #[must_use]
646    pub fn imagenet() -> Self {
647        Self::new(vec![0.485, 0.456, 0.406], vec![0.229, 0.224, 0.225])
648    }
649
650    /// Creates normalization for MNIST (single channel).
651    #[must_use]
652    pub fn mnist() -> Self {
653        Self::new(vec![0.1307], vec![0.3081])
654    }
655
656    /// Creates normalization for CIFAR-10.
657    #[must_use]
658    pub fn cifar10() -> Self {
659        Self::new(vec![0.4914, 0.4822, 0.4465], vec![0.2470, 0.2435, 0.2616])
660    }
661}
662
663impl Transform for ImageNormalize {
664    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
665        let shape = input.shape();
666        let mut data = input.to_vec();
667
668        match shape.len() {
669            3 => {
670                let (c, h, w) = (shape[0], shape[1], shape[2]);
671                for ch in 0..c {
672                    let mean = self.mean.get(ch).copied().unwrap_or(0.0);
673                    let std = self.std.get(ch).copied().unwrap_or(1.0);
674                    for y in 0..h {
675                        for x in 0..w {
676                            let idx = ch * h * w + y * w + x;
677                            data[idx] = (data[idx] - mean) / std;
678                        }
679                    }
680                }
681            }
682            4 => {
683                let (n, c, h, w) = (shape[0], shape[1], shape[2], shape[3]);
684                for batch in 0..n {
685                    for ch in 0..c {
686                        let mean = self.mean.get(ch).copied().unwrap_or(0.0);
687                        let std = self.std.get(ch).copied().unwrap_or(1.0);
688                        for y in 0..h {
689                            for x in 0..w {
690                                let idx = batch * c * h * w + ch * h * w + y * w + x;
691                                data[idx] = (data[idx] - mean) / std;
692                            }
693                        }
694                    }
695                }
696            }
697            _ => {}
698        }
699
700        Tensor::from_vec(data, shape).unwrap()
701    }
702}
703
704// =============================================================================
705// Pad
706// =============================================================================
707
708/// Pads an image with a constant value.
709pub struct Pad {
710    padding: (usize, usize, usize, usize), // (left, right, top, bottom)
711    fill_value: f32,
712}
713
714impl Pad {
715    /// Creates a new Pad with uniform padding.
716    #[must_use]
717    pub fn new(padding: usize) -> Self {
718        Self {
719            padding: (padding, padding, padding, padding),
720            fill_value: 0.0,
721        }
722    }
723
724    /// Creates a Pad with asymmetric padding.
725    #[must_use]
726    pub fn asymmetric(left: usize, right: usize, top: usize, bottom: usize) -> Self {
727        Self {
728            padding: (left, right, top, bottom),
729            fill_value: 0.0,
730        }
731    }
732
733    /// Sets the fill value.
734    #[must_use]
735    pub fn with_fill(mut self, value: f32) -> Self {
736        self.fill_value = value;
737        self
738    }
739}
740
741impl Transform for Pad {
742    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
743        let shape = input.shape();
744        let data = input.to_vec();
745        let (left, right, top, bottom) = self.padding;
746
747        match shape.len() {
748            2 => {
749                let (h, w) = (shape[0], shape[1]);
750                let new_h = h + top + bottom;
751                let new_w = w + left + right;
752
753                let mut result = vec![self.fill_value; new_h * new_w];
754                for y in 0..h {
755                    for x in 0..w {
756                        result[(y + top) * new_w + (x + left)] = data[y * w + x];
757                    }
758                }
759                Tensor::from_vec(result, &[new_h, new_w]).unwrap()
760            }
761            3 => {
762                let (c, h, w) = (shape[0], shape[1], shape[2]);
763                let new_h = h + top + bottom;
764                let new_w = w + left + right;
765
766                let mut result = vec![self.fill_value; c * new_h * new_w];
767                for ch in 0..c {
768                    for y in 0..h {
769                        for x in 0..w {
770                            result[ch * new_h * new_w + (y + top) * new_w + (x + left)] =
771                                data[ch * h * w + y * w + x];
772                        }
773                    }
774                }
775                Tensor::from_vec(result, &[c, new_h, new_w]).unwrap()
776            }
777            _ => input.clone(),
778        }
779    }
780}
781
782// =============================================================================
783// ToTensorImage
784// =============================================================================
785
786/// Converts image data from [0, 255] to [0, 1] range.
787pub struct ToTensorImage;
788
789impl ToTensorImage {
790    /// Creates a new `ToTensorImage` transform.
791    #[must_use]
792    pub fn new() -> Self {
793        Self
794    }
795}
796
797impl Default for ToTensorImage {
798    fn default() -> Self {
799        Self::new()
800    }
801}
802
803impl Transform for ToTensorImage {
804    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
805        let data: Vec<f32> = input.to_vec().iter().map(|&x| x / 255.0).collect();
806        Tensor::from_vec(data, input.shape()).unwrap()
807    }
808}
809
810// =============================================================================
811// Tests
812// =============================================================================
813
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tensor holding `0.0, 1.0, 2.0, ...` with the given shape.
    fn ramp(shape: &[usize]) -> Tensor<f32> {
        let len: usize = shape.iter().product();
        Tensor::from_vec((0..len).map(|v| v as f32).collect(), shape).unwrap()
    }

    #[test]
    fn test_resize_2d() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();

        let resize = Resize::new(4, 4);
        let output = resize.apply(&input);

        assert_eq!(output.shape(), &[4, 4]);
    }

    #[test]
    fn test_resize_3d() {
        let input = Tensor::from_vec(vec![1.0; 3 * 8 * 8], &[3, 8, 8]).unwrap();

        let resize = Resize::new(4, 4);
        let output = resize.apply(&input);

        assert_eq!(output.shape(), &[3, 4, 4]);
    }

    #[test]
    fn test_resize_4d_shape() {
        let output = Resize::new(2, 3).apply(&ramp(&[2, 3, 4, 4]));
        assert_eq!(output.shape(), &[2, 3, 2, 3]);
    }

    #[test]
    fn test_resize_same_size_is_identity() {
        let input = ramp(&[2, 3]);
        let output = Resize::new(2, 3).apply(&input);
        assert_eq!(output.to_vec(), input.to_vec());
    }

    #[test]
    fn test_resize_constant_image_stays_constant() {
        let input = Tensor::from_vec(vec![1.0; 3 * 8 * 8], &[3, 8, 8]).unwrap();
        let output = Resize::new(4, 4).apply(&input);
        for val in output.to_vec() {
            assert!((val - 1.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_center_crop() {
        let input = Tensor::from_vec((1..=16).map(|x| x as f32).collect(), &[4, 4]).unwrap();

        let crop = CenterCrop::new(2, 2);
        let output = crop.apply(&input);

        assert_eq!(output.shape(), &[2, 2]);
        // Center 2x2 of 4x4: values 6, 7, 10, 11
        assert_eq!(output.to_vec(), vec![6.0, 7.0, 10.0, 11.0]);
    }

    #[test]
    fn test_center_crop_3d() {
        let output = CenterCrop::square(1).apply(&ramp(&[2, 3, 3]));
        assert_eq!(output.shape(), &[2, 1, 1]);
        // Centre pixel of each 3x3 channel.
        assert_eq!(output.to_vec(), vec![4.0, 13.0]);
    }

    #[test]
    fn test_center_crop_larger_than_input() {
        // Height 5 > 2 keeps both rows; width 2 < 3 crops from the left.
        let output = CenterCrop::new(5, 2).apply(&ramp(&[2, 3]));
        assert_eq!(output.shape(), &[2, 2]);
        assert_eq!(output.to_vec(), vec![0.0, 1.0, 3.0, 4.0]);
    }

    #[test]
    fn test_random_horizontal_flip() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();

        let flip = RandomHorizontalFlip::with_probability(1.0);
        let output = flip.apply(&input);

        // [[1, 2], [3, 4]] -> [[2, 1], [4, 3]]
        assert_eq!(output.to_vec(), vec![2.0, 1.0, 4.0, 3.0]);
    }

    #[test]
    fn test_random_horizontal_flip_3d() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 1, 2]).unwrap();
        let output = RandomHorizontalFlip::with_probability(1.0).apply(&input);
        assert_eq!(output.shape(), &[2, 1, 2]);
        assert_eq!(output.to_vec(), vec![2.0, 1.0, 4.0, 3.0]);
    }

    #[test]
    fn test_random_vertical_flip() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();

        let flip = RandomVerticalFlip::with_probability(1.0);
        let output = flip.apply(&input);

        // [[1, 2], [3, 4]] -> [[3, 4], [1, 2]]
        assert_eq!(output.to_vec(), vec![3.0, 4.0, 1.0, 2.0]);
    }

    #[test]
    fn test_random_vertical_flip_3d() {
        // Two channels, each a 2x1 column.
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2, 1]).unwrap();
        let output = RandomVerticalFlip::with_probability(1.0).apply(&input);
        assert_eq!(output.shape(), &[2, 2, 1]);
        assert_eq!(output.to_vec(), vec![2.0, 1.0, 4.0, 3.0]);
    }

    #[test]
    fn test_random_rotation_180() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();

        let rotation = RandomRotation::with_angles(vec![180]);
        let output = rotation.apply(&input);

        // [[1, 2], [3, 4]] rotated 180 -> [[4, 3], [2, 1]]
        assert_eq!(output.to_vec(), vec![4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_random_rotation_90() {
        // [[0, 1, 2], [3, 4, 5]] clockwise -> [[3, 0], [4, 1], [5, 2]]
        let output = RandomRotation::with_angles(vec![90]).apply(&ramp(&[2, 3]));
        assert_eq!(output.shape(), &[3, 2]);
        assert_eq!(output.to_vec(), vec![3.0, 0.0, 4.0, 1.0, 5.0, 2.0]);
    }

    #[test]
    fn test_random_rotation_270() {
        // [[0, 1, 2], [3, 4, 5]] counter-clockwise -> [[2, 5], [1, 4], [0, 3]]
        let output = RandomRotation::with_angles(vec![270]).apply(&ramp(&[2, 3]));
        assert_eq!(output.shape(), &[3, 2]);
        assert_eq!(output.to_vec(), vec![2.0, 5.0, 1.0, 4.0, 0.0, 3.0]);
    }

    #[test]
    fn test_random_rotation_invalid_angle_is_identity() {
        let input = ramp(&[2, 3]);
        let output = RandomRotation::with_angles(vec![45]).apply(&input);
        assert_eq!(output.shape(), &[2, 3]);
        assert_eq!(output.to_vec(), input.to_vec());
    }

    #[test]
    fn test_grayscale() {
        let input = Tensor::from_vec(
            vec![
                1.0, 1.0, 1.0, 1.0, // R channel
                0.5, 0.5, 0.5, 0.5, // G channel
                0.0, 0.0, 0.0, 0.0, // B channel
            ],
            &[3, 2, 2],
        )
        .unwrap();

        let gray = Grayscale::new();
        let output = gray.apply(&input);

        assert_eq!(output.shape(), &[1, 2, 2]);
        // Gray = 0.299 * 1.0 + 0.587 * 0.5 + 0.114 * 0.0 = 0.5925
        let expected = 0.299 + 0.587 * 0.5;
        for val in output.to_vec() {
            assert!((val - expected).abs() < 0.001);
        }
    }

    #[test]
    fn test_grayscale_replicates_channels() {
        let input = Tensor::from_vec(
            vec![1.0, 1.0, 1.0, 1.0, 0.5, 0.5, 0.5, 0.5, 0.0, 0.0, 0.0, 0.0],
            &[3, 2, 2],
        )
        .unwrap();
        let output = Grayscale::with_channels(3).apply(&input);
        assert_eq!(output.shape(), &[3, 2, 2]);
        let expected = 0.299 + 0.587 * 0.5;
        for val in output.to_vec() {
            assert!((val - expected).abs() < 0.001);
        }
    }

    #[test]
    fn test_grayscale_ignores_non_rgb() {
        let input = ramp(&[1, 2, 2]);
        let output = Grayscale::new().apply(&input);
        assert_eq!(output.shape(), &[1, 2, 2]);
        assert_eq!(output.to_vec(), input.to_vec());
    }

    #[test]
    fn test_image_normalize() {
        let input = Tensor::from_vec(vec![0.5; 3 * 2 * 2], &[3, 2, 2]).unwrap();

        let normalize = ImageNormalize::new(vec![0.5, 0.5, 0.5], vec![0.5, 0.5, 0.5]);
        let output = normalize.apply(&input);

        // (0.5 - 0.5) / 0.5 = 0.0
        for val in output.to_vec() {
            assert!((val - 0.0).abs() < 0.001);
        }
    }

    #[test]
    fn test_image_normalize_4d() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2, 1, 1]).unwrap();
        let output = ImageNormalize::new(vec![1.0, 2.0], vec![1.0, 2.0]).apply(&input);
        assert_eq!(output.to_vec(), vec![0.0, 0.0, 2.0, 1.0]);
    }

    #[test]
    fn test_pad() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();

        let pad = Pad::new(1);
        let output = pad.apply(&input);

        assert_eq!(output.shape(), &[4, 4]);
        // Check corners are zero
        let data = output.to_vec();
        assert_eq!(data[0], 0.0);
        assert_eq!(data[3], 0.0);
        assert_eq!(data[12], 0.0);
        assert_eq!(data[15], 0.0);
        // Check center values
        assert_eq!(data[5], 1.0);
        assert_eq!(data[6], 2.0);
        assert_eq!(data[9], 3.0);
        assert_eq!(data[10], 4.0);
    }

    #[test]
    fn test_pad_3d_asymmetric_fill() {
        let input = Tensor::from_vec(vec![5.0, 6.0], &[1, 1, 2]).unwrap();
        let output = Pad::asymmetric(1, 0, 0, 2).with_fill(-1.0).apply(&input);
        assert_eq!(output.shape(), &[1, 3, 3]);
        assert_eq!(
            output.to_vec(),
            vec![-1.0, 5.0, 6.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0]
        );
    }

    #[test]
    fn test_to_tensor_image() {
        let input = Tensor::from_vec(vec![0.0, 127.5, 255.0], &[3]).unwrap();

        let transform = ToTensorImage::new();
        let output = transform.apply(&input);

        let data = output.to_vec();
        assert!((data[0] - 0.0).abs() < 0.001);
        assert!((data[1] - 0.5).abs() < 0.001);
        assert!((data[2] - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_to_tensor_image_preserves_shape() {
        let input = Tensor::from_vec(vec![0.0, 51.0, 102.0, 255.0], &[2, 2]).unwrap();
        let output = ToTensorImage::new().apply(&input);
        assert_eq!(output.shape(), &[2, 2]);
    }

    #[test]
    fn test_color_jitter() {
        let input = Tensor::from_vec(vec![0.5; 3 * 4 * 4], &[3, 4, 4]).unwrap();

        let jitter = ColorJitter::new(0.1, 0.1, 0.1);
        let output = jitter.apply(&input);

        assert_eq!(output.shape(), &[3, 4, 4]);
        // Values should be in valid range
        for val in output.to_vec() {
            assert!((0.0..=1.0).contains(&val));
        }
    }

    #[test]
    fn test_color_jitter_zero_strength_is_identity() {
        let values = vec![0.1, 0.5, 0.9, 0.0, 1.0, 0.25];
        let input = Tensor::from_vec(values.clone(), &[3, 1, 2]).unwrap();
        let output = ColorJitter::new(0.0, 0.0, 0.0).apply(&input);
        assert_eq!(output.to_vec(), values);
    }
}