Skip to main content

axonml_vision/
transforms.rs

1//! Image Transforms — Vision-Specific Data Augmentation Pipeline
2//!
3//! Image preprocessing and augmentation transforms implementing the `Transform`
4//! trait from axonml-data. `Resize` performs bilinear interpolation on 2D/3D/4D
5//! tensors. `CenterCrop` extracts the central region. `RandomHorizontalFlip` and
6//! `RandomVerticalFlip` apply stochastic mirroring. `RandomRotation` rotates by
7//! 90-degree increments. `ColorJitter` randomly adjusts brightness, contrast, and
8//! saturation. `Grayscale` converts RGB to single-channel via BT.601 coefficients.
9//! `ImageNormalize` applies per-channel (value - mean) / std with presets for
10//! ImageNet, MNIST, and CIFAR-10. `Pad` adds constant-value borders. `ToTensorImage`
11//! rescales [0, 255] to [0, 1].
12//!
13//! # File
14//! `crates/axonml-vision/src/transforms.rs`
15//!
16//! # Author
17//! Andrew Jewell Sr. — AutomataNexus LLC
18//! ORCID: 0009-0005-2158-7060
19//!
20//! # Updated
21//! April 16, 2026 11:15 PM EST
22//!
23//! # Disclaimer
24//! Use at own risk. This software is provided "as is", without warranty of any
25//! kind, express or implied. The author and AutomataNexus shall not be held
26//! liable for any damages arising from the use of this software.
27
28use axonml_data::Transform;
29use axonml_tensor::Tensor;
30use rand::Rng;
31
32// =============================================================================
33// Resize
34// =============================================================================
35
/// Resizes an image to a fixed spatial size using bilinear interpolation.
///
/// Only the trailing (height, width) dimensions are resampled; any leading
/// channel/batch dimensions are kept as-is.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Resize {
    /// Target height in pixels.
    height: usize,
    /// Target width in pixels.
    width: usize,
}

impl Resize {
    /// Creates a `Resize` transform producing `height x width` images.
    #[must_use]
    pub fn new(height: usize, width: usize) -> Self {
        Self { height, width }
    }

    /// Creates a `Resize` transform producing `size x size` images.
    #[must_use]
    pub fn square(size: usize) -> Self {
        Self::new(size, size)
    }
}
55
56impl Transform for Resize {
57    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
58        let shape = input.shape();
59
60        // Handle different input formats
61        match shape.len() {
62            2 => resize_2d(input, self.height, self.width),
63            3 => resize_3d(input, self.height, self.width),
64            4 => resize_4d(input, self.height, self.width),
65            _ => input.clone(),
66        }
67    }
68}
69
70/// Bilinear interpolation resize for 2D tensor (H x W).
71fn resize_2d(input: &Tensor<f32>, new_h: usize, new_w: usize) -> Tensor<f32> {
72    let shape = input.shape();
73    let (old_h, old_w) = (shape[0], shape[1]);
74    let data = input.to_vec();
75
76    let mut result = vec![0.0; new_h * new_w];
77
78    let scale_h = old_h as f32 / new_h as f32;
79    let scale_w = old_w as f32 / new_w as f32;
80
81    for y in 0..new_h {
82        for x in 0..new_w {
83            let src_y = y as f32 * scale_h;
84            let src_x = x as f32 * scale_w;
85
86            let y0 = (src_y.floor() as usize).min(old_h - 1);
87            let y1 = (y0 + 1).min(old_h - 1);
88            let x0 = (src_x.floor() as usize).min(old_w - 1);
89            let x1 = (x0 + 1).min(old_w - 1);
90
91            let dy = src_y - y0 as f32;
92            let dx = src_x - x0 as f32;
93
94            let v00 = data[y0 * old_w + x0];
95            let v01 = data[y0 * old_w + x1];
96            let v10 = data[y1 * old_w + x0];
97            let v11 = data[y1 * old_w + x1];
98
99            let value = v00 * (1.0 - dx) * (1.0 - dy)
100                + v01 * dx * (1.0 - dy)
101                + v10 * (1.0 - dx) * dy
102                + v11 * dx * dy;
103
104            result[y * new_w + x] = value;
105        }
106    }
107
108    Tensor::from_vec(result, &[new_h, new_w]).unwrap()
109}
110
111/// Bilinear interpolation resize for 3D tensor (C x H x W).
112fn resize_3d(input: &Tensor<f32>, new_h: usize, new_w: usize) -> Tensor<f32> {
113    let shape = input.shape();
114    let (channels, old_h, old_w) = (shape[0], shape[1], shape[2]);
115    let data = input.to_vec();
116
117    let mut result = vec![0.0; channels * new_h * new_w];
118
119    let scale_h = old_h as f32 / new_h as f32;
120    let scale_w = old_w as f32 / new_w as f32;
121
122    for c in 0..channels {
123        for y in 0..new_h {
124            for x in 0..new_w {
125                let src_y = y as f32 * scale_h;
126                let src_x = x as f32 * scale_w;
127
128                let y0 = (src_y.floor() as usize).min(old_h - 1);
129                let y1 = (y0 + 1).min(old_h - 1);
130                let x0 = (src_x.floor() as usize).min(old_w - 1);
131                let x1 = (x0 + 1).min(old_w - 1);
132
133                let dy = src_y - y0 as f32;
134                let dx = src_x - x0 as f32;
135
136                let base = c * old_h * old_w;
137                let v00 = data[base + y0 * old_w + x0];
138                let v01 = data[base + y0 * old_w + x1];
139                let v10 = data[base + y1 * old_w + x0];
140                let v11 = data[base + y1 * old_w + x1];
141
142                let value = v00 * (1.0 - dx) * (1.0 - dy)
143                    + v01 * dx * (1.0 - dy)
144                    + v10 * (1.0 - dx) * dy
145                    + v11 * dx * dy;
146
147                result[c * new_h * new_w + y * new_w + x] = value;
148            }
149        }
150    }
151
152    Tensor::from_vec(result, &[channels, new_h, new_w]).unwrap()
153}
154
155/// Bilinear interpolation resize for 4D tensor (N x C x H x W).
156fn resize_4d(input: &Tensor<f32>, new_h: usize, new_w: usize) -> Tensor<f32> {
157    let shape = input.shape();
158    let (batch, channels, old_h, old_w) = (shape[0], shape[1], shape[2], shape[3]);
159    let data = input.to_vec();
160
161    let mut result = vec![0.0; batch * channels * new_h * new_w];
162
163    let scale_h = old_h as f32 / new_h as f32;
164    let scale_w = old_w as f32 / new_w as f32;
165
166    for n in 0..batch {
167        for c in 0..channels {
168            for y in 0..new_h {
169                for x in 0..new_w {
170                    let src_y = y as f32 * scale_h;
171                    let src_x = x as f32 * scale_w;
172
173                    let y0 = (src_y.floor() as usize).min(old_h - 1);
174                    let y1 = (y0 + 1).min(old_h - 1);
175                    let x0 = (src_x.floor() as usize).min(old_w - 1);
176                    let x1 = (x0 + 1).min(old_w - 1);
177
178                    let dy = src_y - y0 as f32;
179                    let dx = src_x - x0 as f32;
180
181                    let base = n * channels * old_h * old_w + c * old_h * old_w;
182                    let v00 = data[base + y0 * old_w + x0];
183                    let v01 = data[base + y0 * old_w + x1];
184                    let v10 = data[base + y1 * old_w + x0];
185                    let v11 = data[base + y1 * old_w + x1];
186
187                    let value = v00 * (1.0 - dx) * (1.0 - dy)
188                        + v01 * dx * (1.0 - dy)
189                        + v10 * (1.0 - dx) * dy
190                        + v11 * dx * dy;
191
192                    let out_idx = n * channels * new_h * new_w + c * new_h * new_w + y * new_w + x;
193                    result[out_idx] = value;
194                }
195            }
196        }
197    }
198
199    Tensor::from_vec(result, &[batch, channels, new_h, new_w]).unwrap()
200}
201
202// =============================================================================
203// CenterCrop
204// =============================================================================
205
/// Crops the central region of an image to a fixed spatial size.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CenterCrop {
    /// Requested crop height in pixels (clamped to the image height when applied).
    height: usize,
    /// Requested crop width in pixels (clamped to the image width when applied).
    width: usize,
}

impl CenterCrop {
    /// Creates a `CenterCrop` transform keeping a `height x width` window.
    #[must_use]
    pub fn new(height: usize, width: usize) -> Self {
        Self { height, width }
    }

    /// Creates a `CenterCrop` transform keeping a `size x size` window.
    #[must_use]
    pub fn square(size: usize) -> Self {
        Self::new(size, size)
    }
}
225
226impl Transform for CenterCrop {
227    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
228        let shape = input.shape();
229        let data = input.to_vec();
230
231        match shape.len() {
232            2 => {
233                let (h, w) = (shape[0], shape[1]);
234                let start_h = (h.saturating_sub(self.height)) / 2;
235                let start_w = (w.saturating_sub(self.width)) / 2;
236                let crop_h = self.height.min(h);
237                let crop_w = self.width.min(w);
238
239                let mut result = Vec::with_capacity(crop_h * crop_w);
240                for y in start_h..start_h + crop_h {
241                    for x in start_w..start_w + crop_w {
242                        result.push(data[y * w + x]);
243                    }
244                }
245                Tensor::from_vec(result, &[crop_h, crop_w]).unwrap()
246            }
247            3 => {
248                let (c, h, w) = (shape[0], shape[1], shape[2]);
249                let start_h = (h.saturating_sub(self.height)) / 2;
250                let start_w = (w.saturating_sub(self.width)) / 2;
251                let crop_h = self.height.min(h);
252                let crop_w = self.width.min(w);
253
254                let mut result = Vec::with_capacity(c * crop_h * crop_w);
255                for ch in 0..c {
256                    for y in start_h..start_h + crop_h {
257                        for x in start_w..start_w + crop_w {
258                            result.push(data[ch * h * w + y * w + x]);
259                        }
260                    }
261                }
262                Tensor::from_vec(result, &[c, crop_h, crop_w]).unwrap()
263            }
264            _ => input.clone(),
265        }
266    }
267}
268
269// =============================================================================
270// RandomHorizontalFlip
271// =============================================================================
272
/// Randomly mirrors an image left-to-right with a given probability.
#[derive(Debug, Clone, PartialEq)]
pub struct RandomHorizontalFlip {
    /// Chance of flipping on each call, clamped to `[0, 1]` by the constructors.
    probability: f32,
}

impl RandomHorizontalFlip {
    /// Creates a `RandomHorizontalFlip` that flips half of the time.
    #[must_use]
    pub fn new() -> Self {
        Self { probability: 0.5 }
    }

    /// Creates a `RandomHorizontalFlip` with a custom flip probability.
    ///
    /// Values outside `[0, 1]` are clamped into that range.
    #[must_use]
    pub fn with_probability(probability: f32) -> Self {
        Self {
            probability: probability.clamp(0.0, 1.0),
        }
    }
}

impl Default for RandomHorizontalFlip {
    /// Same as [`RandomHorizontalFlip::new`] (probability 0.5).
    fn default() -> Self {
        Self::new()
    }
}
299
300impl Transform for RandomHorizontalFlip {
301    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
302        let mut rng = rand::thread_rng();
303        if rng.r#gen::<f32>() > self.probability {
304            return input.clone();
305        }
306
307        let shape = input.shape();
308        let data = input.to_vec();
309
310        match shape.len() {
311            2 => {
312                let (h, w) = (shape[0], shape[1]);
313                let mut result = vec![0.0; h * w];
314                for y in 0..h {
315                    for x in 0..w {
316                        result[y * w + x] = data[y * w + (w - 1 - x)];
317                    }
318                }
319                Tensor::from_vec(result, shape).unwrap()
320            }
321            3 => {
322                let (c, h, w) = (shape[0], shape[1], shape[2]);
323                let mut result = vec![0.0; c * h * w];
324                for ch in 0..c {
325                    for y in 0..h {
326                        for x in 0..w {
327                            result[ch * h * w + y * w + x] = data[ch * h * w + y * w + (w - 1 - x)];
328                        }
329                    }
330                }
331                Tensor::from_vec(result, shape).unwrap()
332            }
333            _ => input.clone(),
334        }
335    }
336}
337
338// =============================================================================
339// RandomVerticalFlip
340// =============================================================================
341
/// Randomly mirrors an image top-to-bottom with a given probability.
#[derive(Debug, Clone, PartialEq)]
pub struct RandomVerticalFlip {
    /// Chance of flipping on each call, clamped to `[0, 1]` by the constructors.
    probability: f32,
}

impl RandomVerticalFlip {
    /// Creates a `RandomVerticalFlip` that flips half of the time.
    #[must_use]
    pub fn new() -> Self {
        Self { probability: 0.5 }
    }

    /// Creates a `RandomVerticalFlip` with a custom flip probability.
    ///
    /// Values outside `[0, 1]` are clamped into that range.
    #[must_use]
    pub fn with_probability(probability: f32) -> Self {
        Self {
            probability: probability.clamp(0.0, 1.0),
        }
    }
}

impl Default for RandomVerticalFlip {
    /// Same as [`RandomVerticalFlip::new`] (probability 0.5).
    fn default() -> Self {
        Self::new()
    }
}
368
369impl Transform for RandomVerticalFlip {
370    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
371        let mut rng = rand::thread_rng();
372        if rng.r#gen::<f32>() > self.probability {
373            return input.clone();
374        }
375
376        let shape = input.shape();
377        let data = input.to_vec();
378
379        match shape.len() {
380            2 => {
381                let (h, w) = (shape[0], shape[1]);
382                let mut result = vec![0.0; h * w];
383                for y in 0..h {
384                    for x in 0..w {
385                        result[y * w + x] = data[(h - 1 - y) * w + x];
386                    }
387                }
388                Tensor::from_vec(result, shape).unwrap()
389            }
390            3 => {
391                let (c, h, w) = (shape[0], shape[1], shape[2]);
392                let mut result = vec![0.0; c * h * w];
393                for ch in 0..c {
394                    for y in 0..h {
395                        for x in 0..w {
396                            result[ch * h * w + y * w + x] = data[ch * h * w + (h - 1 - y) * w + x];
397                        }
398                    }
399                }
400                Tensor::from_vec(result, shape).unwrap()
401            }
402            _ => input.clone(),
403        }
404    }
405}
406
407// =============================================================================
408// RandomRotation
409// =============================================================================
410
/// Randomly rotates an image clockwise by a multiple of 90 degrees.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RandomRotation {
    /// Candidate clockwise rotations in degrees, each one of 0, 90, 180 or 270.
    /// Never empty; one entry is drawn uniformly per call, so duplicates
    /// raise that angle's probability.
    angles: Vec<i32>,
}

impl RandomRotation {
    /// Creates a `RandomRotation` choosing uniformly among 0, 90, 180 and 270 degrees.
    #[must_use]
    pub fn new() -> Self {
        Self {
            angles: vec![0, 90, 180, 270],
        }
    }

    /// Creates a `RandomRotation` restricted to the given clockwise angles (degrees).
    ///
    /// Angles are first normalized modulo 360, so `-90` means `270` and `450`
    /// means `90`. Angles that are not a multiple of 90 are discarded. If
    /// nothing remains, the transform degenerates to the identity (`[0]`).
    #[must_use]
    pub fn with_angles(angles: Vec<i32>) -> Self {
        let valid: Vec<i32> = angles
            .into_iter()
            .map(|angle| angle.rem_euclid(360))
            .filter(|angle| angle % 90 == 0)
            .collect();
        Self {
            angles: if valid.is_empty() { vec![0] } else { valid },
        }
    }
}

impl Default for RandomRotation {
    /// Same as [`RandomRotation::new`].
    fn default() -> Self {
        Self::new()
    }
}
444
445impl Transform for RandomRotation {
446    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
447        let mut rng = rand::thread_rng();
448        let angle = self.angles[rng.gen_range(0..self.angles.len())];
449
450        if angle == 0 {
451            return input.clone();
452        }
453
454        let shape = input.shape();
455        let data = input.to_vec();
456
457        // Only handle 2D (H x W) for simplicity
458        if shape.len() != 2 {
459            return input.clone();
460        }
461
462        let (h, w) = (shape[0], shape[1]);
463
464        match angle {
465            90 => {
466                // Rotate 90 degrees clockwise: (x, y) -> (y, h-1-x)
467                let mut result = vec![0.0; h * w];
468                for y in 0..h {
469                    for x in 0..w {
470                        result[x * h + (h - 1 - y)] = data[y * w + x];
471                    }
472                }
473                Tensor::from_vec(result, &[w, h]).unwrap()
474            }
475            180 => {
476                // Rotate 180 degrees: (x, y) -> (w-1-x, h-1-y)
477                let mut result = vec![0.0; h * w];
478                for y in 0..h {
479                    for x in 0..w {
480                        result[(h - 1 - y) * w + (w - 1 - x)] = data[y * w + x];
481                    }
482                }
483                Tensor::from_vec(result, &[h, w]).unwrap()
484            }
485            270 => {
486                // Rotate 270 degrees clockwise: (x, y) -> (w-1-y, x)
487                let mut result = vec![0.0; h * w];
488                for y in 0..h {
489                    for x in 0..w {
490                        result[(w - 1 - x) * h + y] = data[y * w + x];
491                    }
492                }
493                Tensor::from_vec(result, &[w, h]).unwrap()
494            }
495            _ => input.clone(),
496        }
497    }
498}
499
500// =============================================================================
501// ColorJitter
502// =============================================================================
503
/// Randomly adjusts the brightness, contrast and saturation of an image.
#[derive(Debug, Clone, PartialEq)]
pub struct ColorJitter {
    /// Brightness jitter strength (non-negative; 0 disables the adjustment).
    brightness: f32,
    /// Contrast jitter strength (non-negative; 0 disables the adjustment).
    contrast: f32,
    /// Saturation jitter strength (non-negative; 0 disables the adjustment).
    saturation: f32,
}

impl ColorJitter {
    /// Creates a `ColorJitter` with the given jitter strengths.
    ///
    /// Negative strengths are replaced by their absolute value.
    #[must_use]
    pub fn new(brightness: f32, contrast: f32, saturation: f32) -> Self {
        Self {
            brightness: brightness.abs(),
            contrast: contrast.abs(),
            saturation: saturation.abs(),
        }
    }
}
522
523impl Transform for ColorJitter {
524    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
525        let mut rng = rand::thread_rng();
526        let mut data = input.to_vec();
527        let shape = input.shape();
528
529        // Apply brightness adjustment
530        if self.brightness > 0.0 {
531            let factor = 1.0 + rng.gen_range(-self.brightness..self.brightness);
532            for val in &mut data {
533                *val = (*val * factor).clamp(0.0, 1.0);
534            }
535        }
536
537        // Apply contrast adjustment
538        if self.contrast > 0.0 {
539            let factor = 1.0 + rng.gen_range(-self.contrast..self.contrast);
540            let mean: f32 = data.iter().sum::<f32>() / data.len() as f32;
541            for val in &mut data {
542                *val = ((*val - mean) * factor + mean).clamp(0.0, 1.0);
543            }
544        }
545
546        // Apply saturation adjustment (simplified - works best with 3-channel images)
547        if self.saturation > 0.0 && shape.len() == 3 && shape[0] == 3 {
548            let factor = 1.0 + rng.gen_range(-self.saturation..self.saturation);
549            let (h, w) = (shape[1], shape[2]);
550
551            for y in 0..h {
552                for x in 0..w {
553                    let r = data[0 * h * w + y * w + x];
554                    let g = data[h * w + y * w + x];
555                    let b = data[2 * h * w + y * w + x];
556
557                    let gray = 0.299 * r + 0.587 * g + 0.114 * b;
558
559                    data[0 * h * w + y * w + x] = (gray + (r - gray) * factor).clamp(0.0, 1.0);
560                    data[h * w + y * w + x] = (gray + (g - gray) * factor).clamp(0.0, 1.0);
561                    data[2 * h * w + y * w + x] = (gray + (b - gray) * factor).clamp(0.0, 1.0);
562                }
563            }
564        }
565
566        Tensor::from_vec(data, shape).unwrap()
567    }
568}
569
570// =============================================================================
571// Grayscale
572// =============================================================================
573
/// Converts an RGB image to grayscale.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Grayscale {
    /// Number of identical luma channels emitted (at least 1 via the constructors).
    num_output_channels: usize,
}

impl Grayscale {
    /// Creates a `Grayscale` transform emitting a single channel.
    #[must_use]
    pub fn new() -> Self {
        Self {
            num_output_channels: 1,
        }
    }

    /// Creates a `Grayscale` transform that repeats the luma plane
    /// `num_output_channels` times (values below 1 are raised to 1).
    #[must_use]
    pub fn with_channels(num_output_channels: usize) -> Self {
        Self {
            num_output_channels: num_output_channels.max(1),
        }
    }
}

impl Default for Grayscale {
    /// Same as [`Grayscale::new`] (one output channel).
    fn default() -> Self {
        Self::new()
    }
}
602
603impl Transform for Grayscale {
604    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
605        let shape = input.shape();
606
607        // Only works with 3-channel images (C x H x W)
608        if shape.len() != 3 || shape[0] != 3 {
609            return input.clone();
610        }
611
612        let (_, h, w) = (shape[0], shape[1], shape[2]);
613        let data = input.to_vec();
614
615        let mut gray = Vec::with_capacity(h * w);
616        for y in 0..h {
617            for x in 0..w {
618                let r = data[0 * h * w + y * w + x];
619                let g = data[h * w + y * w + x];
620                let b = data[2 * h * w + y * w + x];
621                gray.push(0.299 * r + 0.587 * g + 0.114 * b);
622            }
623        }
624
625        if self.num_output_channels == 1 {
626            Tensor::from_vec(gray, &[1, h, w]).unwrap()
627        } else {
628            // Replicate grayscale across channels
629            let mut result = Vec::with_capacity(self.num_output_channels * h * w);
630            for _ in 0..self.num_output_channels {
631                result.extend(&gray);
632            }
633            Tensor::from_vec(result, &[self.num_output_channels, h, w]).unwrap()
634        }
635    }
636}
637
638// =============================================================================
639// Normalize (Image-specific)
640// =============================================================================
641
/// Normalizes an image channel-wise as `(value - mean[c]) / std[c]`.
#[derive(Debug, Clone, PartialEq)]
pub struct ImageNormalize {
    /// Per-channel means, indexed by channel.
    mean: Vec<f32>,
    /// Per-channel standard deviations, indexed by channel. A zero entry
    /// makes the division produce non-finite values.
    std: Vec<f32>,
}

impl ImageNormalize {
    /// Creates an `ImageNormalize` from per-channel means and standard deviations.
    #[must_use]
    pub fn new(mean: Vec<f32>, std: Vec<f32>) -> Self {
        Self { mean, std }
    }

    /// Creates the RGB normalization used by `ImageNet`-pretrained models.
    #[must_use]
    pub fn imagenet() -> Self {
        Self::new(vec![0.485, 0.456, 0.406], vec![0.229, 0.224, 0.225])
    }

    /// Creates the single-channel normalization for MNIST.
    #[must_use]
    pub fn mnist() -> Self {
        Self::new(vec![0.1307], vec![0.3081])
    }

    /// Creates the RGB normalization for CIFAR-10.
    #[must_use]
    pub fn cifar10() -> Self {
        Self::new(vec![0.4914, 0.4822, 0.4465], vec![0.2470, 0.2435, 0.2616])
    }
}
673
674impl Transform for ImageNormalize {
675    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
676        let shape = input.shape();
677        let mut data = input.to_vec();
678
679        match shape.len() {
680            3 => {
681                let (c, h, w) = (shape[0], shape[1], shape[2]);
682                for ch in 0..c {
683                    let mean = self.mean.get(ch).copied().unwrap_or(0.0);
684                    let std = self.std.get(ch).copied().unwrap_or(1.0);
685                    for y in 0..h {
686                        for x in 0..w {
687                            let idx = ch * h * w + y * w + x;
688                            data[idx] = (data[idx] - mean) / std;
689                        }
690                    }
691                }
692            }
693            4 => {
694                let (n, c, h, w) = (shape[0], shape[1], shape[2], shape[3]);
695                for batch in 0..n {
696                    for ch in 0..c {
697                        let mean = self.mean.get(ch).copied().unwrap_or(0.0);
698                        let std = self.std.get(ch).copied().unwrap_or(1.0);
699                        for y in 0..h {
700                            for x in 0..w {
701                                let idx = batch * c * h * w + ch * h * w + y * w + x;
702                                data[idx] = (data[idx] - mean) / std;
703                            }
704                        }
705                    }
706                }
707            }
708            _ => {}
709        }
710
711        Tensor::from_vec(data, shape).unwrap()
712    }
713}
714
715// =============================================================================
716// Pad
717// =============================================================================
718
/// Pads an image with a constant-valued border.
#[derive(Debug, Clone, PartialEq)]
pub struct Pad {
    /// Border widths in pixels as `(left, right, top, bottom)`.
    padding: (usize, usize, usize, usize),
    /// Value written into every border pixel.
    fill_value: f32,
}

impl Pad {
    /// Creates a `Pad` adding `padding` pixels of zeros on every side.
    #[must_use]
    pub fn new(padding: usize) -> Self {
        Self {
            padding: (padding, padding, padding, padding),
            fill_value: 0.0,
        }
    }

    /// Creates a `Pad` with independent per-side widths, filled with zeros.
    #[must_use]
    pub fn asymmetric(left: usize, right: usize, top: usize, bottom: usize) -> Self {
        Self {
            padding: (left, right, top, bottom),
            fill_value: 0.0,
        }
    }

    /// Returns this `Pad` with its border value replaced by `value`.
    #[must_use]
    pub fn with_fill(mut self, value: f32) -> Self {
        self.fill_value = value;
        self
    }
}
751
752impl Transform for Pad {
753    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
754        let shape = input.shape();
755        let data = input.to_vec();
756        let (left, right, top, bottom) = self.padding;
757
758        match shape.len() {
759            2 => {
760                let (h, w) = (shape[0], shape[1]);
761                let new_h = h + top + bottom;
762                let new_w = w + left + right;
763
764                let mut result = vec![self.fill_value; new_h * new_w];
765                for y in 0..h {
766                    for x in 0..w {
767                        result[(y + top) * new_w + (x + left)] = data[y * w + x];
768                    }
769                }
770                Tensor::from_vec(result, &[new_h, new_w]).unwrap()
771            }
772            3 => {
773                let (c, h, w) = (shape[0], shape[1], shape[2]);
774                let new_h = h + top + bottom;
775                let new_w = w + left + right;
776
777                let mut result = vec![self.fill_value; c * new_h * new_w];
778                for ch in 0..c {
779                    for y in 0..h {
780                        for x in 0..w {
781                            result[ch * new_h * new_w + (y + top) * new_w + (x + left)] =
782                                data[ch * h * w + y * w + x];
783                        }
784                    }
785                }
786                Tensor::from_vec(result, &[c, new_h, new_w]).unwrap()
787            }
788            _ => input.clone(),
789        }
790    }
791}
792
793// =============================================================================
794// ToTensorImage
795// =============================================================================
796
/// Rescales pixel intensities from `[0, 255]` to `[0, 1]`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToTensorImage;

impl ToTensorImage {
    /// Creates a new `ToTensorImage` transform.
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}

impl Default for ToTensorImage {
    /// Same as [`ToTensorImage::new`].
    fn default() -> Self {
        Self::new()
    }
}
813
814impl Transform for ToTensorImage {
815    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
816        let data: Vec<f32> = input.to_vec().iter().map(|&x| x / 255.0).collect();
817        Tensor::from_vec(data, input.shape()).unwrap()
818    }
819}
820
821// =============================================================================
822// Tests
823// =============================================================================
824
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a test tensor, panicking on a data/shape mismatch.
    fn tensor(values: Vec<f32>, shape: &[usize]) -> Tensor<f32> {
        Tensor::from_vec(values, shape).unwrap()
    }

    /// The 2x2 fixture `[[1, 2], [3, 4]]` shared by the geometric tests.
    fn two_by_two() -> Tensor<f32> {
        tensor(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])
    }

    #[test]
    fn test_resize_2d() {
        let upscaled = Resize::new(4, 4).apply(&two_by_two());
        assert_eq!(upscaled.shape(), &[4, 4]);
    }

    #[test]
    fn test_resize_3d() {
        let image = tensor(vec![1.0; 3 * 8 * 8], &[3, 8, 8]);
        let downscaled = Resize::new(4, 4).apply(&image);
        assert_eq!(downscaled.shape(), &[3, 4, 4]);
    }

    #[test]
    fn test_center_crop() {
        let grid = tensor((1..=16).map(|v| v as f32).collect(), &[4, 4]);
        let cropped = CenterCrop::new(2, 2).apply(&grid);

        assert_eq!(cropped.shape(), &[2, 2]);
        // The middle 2x2 window of a row-major 1..=16 grid.
        assert_eq!(cropped.to_vec(), vec![6.0, 7.0, 10.0, 11.0]);
    }

    #[test]
    fn test_random_horizontal_flip() {
        let flipped = RandomHorizontalFlip::with_probability(1.0).apply(&two_by_two());
        // Each row mirrored: [[1, 2], [3, 4]] -> [[2, 1], [4, 3]].
        assert_eq!(flipped.to_vec(), vec![2.0, 1.0, 4.0, 3.0]);
    }

    #[test]
    fn test_random_vertical_flip() {
        let flipped = RandomVerticalFlip::with_probability(1.0).apply(&two_by_two());
        // Row order reversed: [[1, 2], [3, 4]] -> [[3, 4], [1, 2]].
        assert_eq!(flipped.to_vec(), vec![3.0, 4.0, 1.0, 2.0]);
    }

    #[test]
    fn test_random_rotation_180() {
        let rotated = RandomRotation::with_angles(vec![180]).apply(&two_by_two());
        // Half turn: [[1, 2], [3, 4]] -> [[4, 3], [2, 1]].
        assert_eq!(rotated.to_vec(), vec![4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_grayscale() {
        let mut planes = vec![1.0; 4]; // R channel
        planes.extend([0.5; 4]); // G channel
        planes.extend([0.0; 4]); // B channel
        let rgb = tensor(planes, &[3, 2, 2]);

        let luma = Grayscale::new().apply(&rgb);

        assert_eq!(luma.shape(), &[1, 2, 2]);
        // 0.299 * 1.0 + 0.587 * 0.5 + 0.114 * 0.0 = 0.5925
        let expected = 0.299 + 0.587 * 0.5;
        for value in luma.to_vec() {
            assert!((value - expected).abs() < 0.001);
        }
    }

    #[test]
    fn test_image_normalize() {
        let image = tensor(vec![0.5; 3 * 2 * 2], &[3, 2, 2]);
        let normalized = ImageNormalize::new(vec![0.5; 3], vec![0.5; 3]).apply(&image);

        // (0.5 - 0.5) / 0.5 == 0.0 everywhere.
        for value in normalized.to_vec() {
            assert!((value - 0.0).abs() < 0.001);
        }
    }

    #[test]
    fn test_pad() {
        let padded = Pad::new(1).apply(&two_by_two());
        assert_eq!(padded.shape(), &[4, 4]);

        let data = padded.to_vec();
        // The four corners of the 4x4 result are border pixels.
        for corner in [0, 3, 12, 15] {
            assert_eq!(data[corner], 0.0);
        }
        // The original 2x2 block sits in the middle.
        for (index, expected) in [(5, 1.0), (6, 2.0), (9, 3.0), (10, 4.0)] {
            assert_eq!(data[index], expected);
        }
    }

    #[test]
    fn test_to_tensor_image() {
        let pixels = tensor(vec![0.0, 127.5, 255.0], &[3]);
        let data = ToTensorImage::new().apply(&pixels).to_vec();

        for (index, expected) in [0.0, 0.5, 1.0].into_iter().enumerate() {
            assert!((data[index] - expected).abs() < 0.001);
        }
    }

    #[test]
    fn test_color_jitter() {
        let image = tensor(vec![0.5; 3 * 4 * 4], &[3, 4, 4]);
        let jittered = ColorJitter::new(0.1, 0.1, 0.1).apply(&image);

        assert_eq!(jittered.shape(), &[3, 4, 4]);
        // Every adjustment clamps, so the output must stay in [0, 1].
        for value in jittered.to_vec() {
            assert!((0.0..=1.0).contains(&value));
        }
    }
}