Skip to main content

axonml_vision/
transforms.rs

1//! Image Transforms - Vision-Specific Data Augmentation
2//!
3//! Provides image-specific transformations for data augmentation and preprocessing.
4//!
5//! @version 0.1.0
6//! @author `AutomataNexus` Development Team
7
8use axonml_data::Transform;
9use axonml_tensor::Tensor;
10use rand::Rng;
11
12// =============================================================================
13// Resize
14// =============================================================================
15
/// Transform that rescales an image to a fixed height and width (bilinear interpolation).
pub struct Resize {
    height: usize,
    width: usize,
}

impl Resize {
    /// Builds a transform whose output is `height` rows by `width` columns.
    #[must_use]
    pub fn new(height: usize, width: usize) -> Self {
        Resize { height, width }
    }

    /// Builds a transform whose output is `size` x `size`.
    #[must_use]
    pub fn square(size: usize) -> Self {
        Resize {
            height: size,
            width: size,
        }
    }
}
33
34impl Transform for Resize {
35    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
36        let shape = input.shape();
37
38        // Handle different input formats
39        match shape.len() {
40            2 => resize_2d(input, self.height, self.width),
41            3 => resize_3d(input, self.height, self.width),
42            4 => resize_4d(input, self.height, self.width),
43            _ => input.clone(),
44        }
45    }
46}
47
48/// Bilinear interpolation resize for 2D tensor (H x W).
49fn resize_2d(input: &Tensor<f32>, new_h: usize, new_w: usize) -> Tensor<f32> {
50    let shape = input.shape();
51    let (old_h, old_w) = (shape[0], shape[1]);
52    let data = input.to_vec();
53
54    let mut result = vec![0.0; new_h * new_w];
55
56    let scale_h = old_h as f32 / new_h as f32;
57    let scale_w = old_w as f32 / new_w as f32;
58
59    for y in 0..new_h {
60        for x in 0..new_w {
61            let src_y = y as f32 * scale_h;
62            let src_x = x as f32 * scale_w;
63
64            let y0 = (src_y.floor() as usize).min(old_h - 1);
65            let y1 = (y0 + 1).min(old_h - 1);
66            let x0 = (src_x.floor() as usize).min(old_w - 1);
67            let x1 = (x0 + 1).min(old_w - 1);
68
69            let dy = src_y - y0 as f32;
70            let dx = src_x - x0 as f32;
71
72            let v00 = data[y0 * old_w + x0];
73            let v01 = data[y0 * old_w + x1];
74            let v10 = data[y1 * old_w + x0];
75            let v11 = data[y1 * old_w + x1];
76
77            let value = v00 * (1.0 - dx) * (1.0 - dy)
78                + v01 * dx * (1.0 - dy)
79                + v10 * (1.0 - dx) * dy
80                + v11 * dx * dy;
81
82            result[y * new_w + x] = value;
83        }
84    }
85
86    Tensor::from_vec(result, &[new_h, new_w]).unwrap()
87}
88
89/// Bilinear interpolation resize for 3D tensor (C x H x W).
90fn resize_3d(input: &Tensor<f32>, new_h: usize, new_w: usize) -> Tensor<f32> {
91    let shape = input.shape();
92    let (channels, old_h, old_w) = (shape[0], shape[1], shape[2]);
93    let data = input.to_vec();
94
95    let mut result = vec![0.0; channels * new_h * new_w];
96
97    let scale_h = old_h as f32 / new_h as f32;
98    let scale_w = old_w as f32 / new_w as f32;
99
100    for c in 0..channels {
101        for y in 0..new_h {
102            for x in 0..new_w {
103                let src_y = y as f32 * scale_h;
104                let src_x = x as f32 * scale_w;
105
106                let y0 = (src_y.floor() as usize).min(old_h - 1);
107                let y1 = (y0 + 1).min(old_h - 1);
108                let x0 = (src_x.floor() as usize).min(old_w - 1);
109                let x1 = (x0 + 1).min(old_w - 1);
110
111                let dy = src_y - y0 as f32;
112                let dx = src_x - x0 as f32;
113
114                let base = c * old_h * old_w;
115                let v00 = data[base + y0 * old_w + x0];
116                let v01 = data[base + y0 * old_w + x1];
117                let v10 = data[base + y1 * old_w + x0];
118                let v11 = data[base + y1 * old_w + x1];
119
120                let value = v00 * (1.0 - dx) * (1.0 - dy)
121                    + v01 * dx * (1.0 - dy)
122                    + v10 * (1.0 - dx) * dy
123                    + v11 * dx * dy;
124
125                result[c * new_h * new_w + y * new_w + x] = value;
126            }
127        }
128    }
129
130    Tensor::from_vec(result, &[channels, new_h, new_w]).unwrap()
131}
132
133/// Bilinear interpolation resize for 4D tensor (N x C x H x W).
134fn resize_4d(input: &Tensor<f32>, new_h: usize, new_w: usize) -> Tensor<f32> {
135    let shape = input.shape();
136    let (batch, channels, old_h, old_w) = (shape[0], shape[1], shape[2], shape[3]);
137    let data = input.to_vec();
138
139    let mut result = vec![0.0; batch * channels * new_h * new_w];
140
141    let scale_h = old_h as f32 / new_h as f32;
142    let scale_w = old_w as f32 / new_w as f32;
143
144    for n in 0..batch {
145        for c in 0..channels {
146            for y in 0..new_h {
147                for x in 0..new_w {
148                    let src_y = y as f32 * scale_h;
149                    let src_x = x as f32 * scale_w;
150
151                    let y0 = (src_y.floor() as usize).min(old_h - 1);
152                    let y1 = (y0 + 1).min(old_h - 1);
153                    let x0 = (src_x.floor() as usize).min(old_w - 1);
154                    let x1 = (x0 + 1).min(old_w - 1);
155
156                    let dy = src_y - y0 as f32;
157                    let dx = src_x - x0 as f32;
158
159                    let base = n * channels * old_h * old_w + c * old_h * old_w;
160                    let v00 = data[base + y0 * old_w + x0];
161                    let v01 = data[base + y0 * old_w + x1];
162                    let v10 = data[base + y1 * old_w + x0];
163                    let v11 = data[base + y1 * old_w + x1];
164
165                    let value = v00 * (1.0 - dx) * (1.0 - dy)
166                        + v01 * dx * (1.0 - dy)
167                        + v10 * (1.0 - dx) * dy
168                        + v11 * dx * dy;
169
170                    let out_idx = n * channels * new_h * new_w + c * new_h * new_w + y * new_w + x;
171                    result[out_idx] = value;
172                }
173            }
174        }
175    }
176
177    Tensor::from_vec(result, &[batch, channels, new_h, new_w]).unwrap()
178}
179
180// =============================================================================
181// CenterCrop
182// =============================================================================
183
/// Transform that keeps only the central `height` x `width` window of an image.
pub struct CenterCrop {
    height: usize,
    width: usize,
}

impl CenterCrop {
    /// Builds a `CenterCrop` with the given window height and width.
    #[must_use]
    pub fn new(height: usize, width: usize) -> Self {
        CenterCrop { height, width }
    }

    /// Builds a `CenterCrop` with a `size` x `size` window.
    #[must_use]
    pub fn square(size: usize) -> Self {
        CenterCrop {
            height: size,
            width: size,
        }
    }
}
201
202impl Transform for CenterCrop {
203    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
204        let shape = input.shape();
205        let data = input.to_vec();
206
207        match shape.len() {
208            2 => {
209                let (h, w) = (shape[0], shape[1]);
210                let start_h = (h.saturating_sub(self.height)) / 2;
211                let start_w = (w.saturating_sub(self.width)) / 2;
212                let crop_h = self.height.min(h);
213                let crop_w = self.width.min(w);
214
215                let mut result = Vec::with_capacity(crop_h * crop_w);
216                for y in start_h..start_h + crop_h {
217                    for x in start_w..start_w + crop_w {
218                        result.push(data[y * w + x]);
219                    }
220                }
221                Tensor::from_vec(result, &[crop_h, crop_w]).unwrap()
222            }
223            3 => {
224                let (c, h, w) = (shape[0], shape[1], shape[2]);
225                let start_h = (h.saturating_sub(self.height)) / 2;
226                let start_w = (w.saturating_sub(self.width)) / 2;
227                let crop_h = self.height.min(h);
228                let crop_w = self.width.min(w);
229
230                let mut result = Vec::with_capacity(c * crop_h * crop_w);
231                for ch in 0..c {
232                    for y in start_h..start_h + crop_h {
233                        for x in start_w..start_w + crop_w {
234                            result.push(data[ch * h * w + y * w + x]);
235                        }
236                    }
237                }
238                Tensor::from_vec(result, &[c, crop_h, crop_w]).unwrap()
239            }
240            _ => input.clone(),
241        }
242    }
243}
244
245// =============================================================================
246// RandomHorizontalFlip
247// =============================================================================
248
/// Transform that mirrors an image left-to-right with a configurable probability.
pub struct RandomHorizontalFlip {
    probability: f32,
}

impl RandomHorizontalFlip {
    /// Builds a `RandomHorizontalFlip` that flips with probability 0.5.
    #[must_use]
    pub fn new() -> Self {
        RandomHorizontalFlip { probability: 0.5 }
    }

    /// Builds a `RandomHorizontalFlip` with the given probability, clamped to `[0.0, 1.0]`.
    #[must_use]
    pub fn with_probability(probability: f32) -> Self {
        let probability = probability.clamp(0.0, 1.0);
        RandomHorizontalFlip { probability }
    }
}

impl Default for RandomHorizontalFlip {
    /// Same as [`RandomHorizontalFlip::new`].
    fn default() -> Self {
        RandomHorizontalFlip::new()
    }
}
273
274impl Transform for RandomHorizontalFlip {
275    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
276        let mut rng = rand::thread_rng();
277        if rng.gen::<f32>() > self.probability {
278            return input.clone();
279        }
280
281        let shape = input.shape();
282        let data = input.to_vec();
283
284        match shape.len() {
285            2 => {
286                let (h, w) = (shape[0], shape[1]);
287                let mut result = vec![0.0; h * w];
288                for y in 0..h {
289                    for x in 0..w {
290                        result[y * w + x] = data[y * w + (w - 1 - x)];
291                    }
292                }
293                Tensor::from_vec(result, shape).unwrap()
294            }
295            3 => {
296                let (c, h, w) = (shape[0], shape[1], shape[2]);
297                let mut result = vec![0.0; c * h * w];
298                for ch in 0..c {
299                    for y in 0..h {
300                        for x in 0..w {
301                            result[ch * h * w + y * w + x] = data[ch * h * w + y * w + (w - 1 - x)];
302                        }
303                    }
304                }
305                Tensor::from_vec(result, shape).unwrap()
306            }
307            _ => input.clone(),
308        }
309    }
310}
311
312// =============================================================================
313// RandomVerticalFlip
314// =============================================================================
315
/// Transform that mirrors an image top-to-bottom with a configurable probability.
pub struct RandomVerticalFlip {
    probability: f32,
}

impl RandomVerticalFlip {
    /// Builds a `RandomVerticalFlip` that flips with probability 0.5.
    #[must_use]
    pub fn new() -> Self {
        RandomVerticalFlip { probability: 0.5 }
    }

    /// Builds a `RandomVerticalFlip` with the given probability, clamped to `[0.0, 1.0]`.
    #[must_use]
    pub fn with_probability(probability: f32) -> Self {
        let probability = probability.clamp(0.0, 1.0);
        RandomVerticalFlip { probability }
    }
}

impl Default for RandomVerticalFlip {
    /// Same as [`RandomVerticalFlip::new`].
    fn default() -> Self {
        RandomVerticalFlip::new()
    }
}
340
341impl Transform for RandomVerticalFlip {
342    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
343        let mut rng = rand::thread_rng();
344        if rng.gen::<f32>() > self.probability {
345            return input.clone();
346        }
347
348        let shape = input.shape();
349        let data = input.to_vec();
350
351        match shape.len() {
352            2 => {
353                let (h, w) = (shape[0], shape[1]);
354                let mut result = vec![0.0; h * w];
355                for y in 0..h {
356                    for x in 0..w {
357                        result[y * w + x] = data[(h - 1 - y) * w + x];
358                    }
359                }
360                Tensor::from_vec(result, shape).unwrap()
361            }
362            3 => {
363                let (c, h, w) = (shape[0], shape[1], shape[2]);
364                let mut result = vec![0.0; c * h * w];
365                for ch in 0..c {
366                    for y in 0..h {
367                        for x in 0..w {
368                            result[ch * h * w + y * w + x] = data[ch * h * w + (h - 1 - y) * w + x];
369                        }
370                    }
371                }
372                Tensor::from_vec(result, shape).unwrap()
373            }
374            _ => input.clone(),
375        }
376    }
377}
378
379// =============================================================================
380// RandomRotation
381// =============================================================================
382
/// Transform that rotates an image by a randomly chosen multiple of 90 degrees.
pub struct RandomRotation {
    /// Candidate rotations; every entry is one of 0, 90, 180 or 270 degrees.
    angles: Vec<i32>,
}

impl RandomRotation {
    /// Builds a `RandomRotation` that may pick any of 0, 90, 180 or 270 degrees.
    #[must_use]
    pub fn new() -> Self {
        let angles = vec![0, 90, 180, 270];
        RandomRotation { angles }
    }

    /// Builds a `RandomRotation` restricted to the given angles.
    ///
    /// Entries other than 0, 90, 180 and 270 are dropped; order and duplicates
    /// of the remaining entries are kept. If nothing remains, the transform
    /// falls back to the single angle 0.
    #[must_use]
    pub fn with_angles(angles: Vec<i32>) -> Self {
        let mut angles = angles;
        angles.retain(|&a| matches!(a, 0 | 90 | 180 | 270));
        if angles.is_empty() {
            angles.push(0);
        }
        RandomRotation { angles }
    }
}

impl Default for RandomRotation {
    /// Same as [`RandomRotation::new`].
    fn default() -> Self {
        RandomRotation::new()
    }
}
414
415impl Transform for RandomRotation {
416    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
417        let mut rng = rand::thread_rng();
418        let angle = self.angles[rng.gen_range(0..self.angles.len())];
419
420        if angle == 0 {
421            return input.clone();
422        }
423
424        let shape = input.shape();
425        let data = input.to_vec();
426
427        // Only handle 2D (H x W) for simplicity
428        if shape.len() != 2 {
429            return input.clone();
430        }
431
432        let (h, w) = (shape[0], shape[1]);
433
434        match angle {
435            90 => {
436                // Rotate 90 degrees clockwise: (x, y) -> (y, h-1-x)
437                let mut result = vec![0.0; h * w];
438                for y in 0..h {
439                    for x in 0..w {
440                        result[x * h + (h - 1 - y)] = data[y * w + x];
441                    }
442                }
443                Tensor::from_vec(result, &[w, h]).unwrap()
444            }
445            180 => {
446                // Rotate 180 degrees: (x, y) -> (w-1-x, h-1-y)
447                let mut result = vec![0.0; h * w];
448                for y in 0..h {
449                    for x in 0..w {
450                        result[(h - 1 - y) * w + (w - 1 - x)] = data[y * w + x];
451                    }
452                }
453                Tensor::from_vec(result, &[h, w]).unwrap()
454            }
455            270 => {
456                // Rotate 270 degrees clockwise: (x, y) -> (w-1-y, x)
457                let mut result = vec![0.0; h * w];
458                for y in 0..h {
459                    for x in 0..w {
460                        result[(w - 1 - x) * h + y] = data[y * w + x];
461                    }
462                }
463                Tensor::from_vec(result, &[w, h]).unwrap()
464            }
465            _ => input.clone(),
466        }
467    }
468}
469
470// =============================================================================
471// ColorJitter
472// =============================================================================
473
/// Transform that randomly perturbs brightness, contrast and saturation.
pub struct ColorJitter {
    brightness: f32,
    contrast: f32,
    saturation: f32,
}

impl ColorJitter {
    /// Builds a `ColorJitter` from three jitter magnitudes.
    ///
    /// The absolute value of each argument is stored, so the sign is ignored.
    #[must_use]
    pub fn new(brightness: f32, contrast: f32, saturation: f32) -> Self {
        let (brightness, contrast, saturation) =
            (brightness.abs(), contrast.abs(), saturation.abs());
        ColorJitter {
            brightness,
            contrast,
            saturation,
        }
    }
}
491
492impl Transform for ColorJitter {
493    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
494        let mut rng = rand::thread_rng();
495        let mut data = input.to_vec();
496        let shape = input.shape();
497
498        // Apply brightness adjustment
499        if self.brightness > 0.0 {
500            let factor = 1.0 + rng.gen_range(-self.brightness..self.brightness);
501            for val in &mut data {
502                *val = (*val * factor).clamp(0.0, 1.0);
503            }
504        }
505
506        // Apply contrast adjustment
507        if self.contrast > 0.0 {
508            let factor = 1.0 + rng.gen_range(-self.contrast..self.contrast);
509            let mean: f32 = data.iter().sum::<f32>() / data.len() as f32;
510            for val in &mut data {
511                *val = ((*val - mean) * factor + mean).clamp(0.0, 1.0);
512            }
513        }
514
515        // Apply saturation adjustment (simplified - works best with 3-channel images)
516        if self.saturation > 0.0 && shape.len() == 3 && shape[0] == 3 {
517            let factor = 1.0 + rng.gen_range(-self.saturation..self.saturation);
518            let (h, w) = (shape[1], shape[2]);
519
520            for y in 0..h {
521                for x in 0..w {
522                    let r = data[0 * h * w + y * w + x];
523                    let g = data[h * w + y * w + x];
524                    let b = data[2 * h * w + y * w + x];
525
526                    let gray = 0.299 * r + 0.587 * g + 0.114 * b;
527
528                    data[0 * h * w + y * w + x] = (gray + (r - gray) * factor).clamp(0.0, 1.0);
529                    data[h * w + y * w + x] = (gray + (g - gray) * factor).clamp(0.0, 1.0);
530                    data[2 * h * w + y * w + x] = (gray + (b - gray) * factor).clamp(0.0, 1.0);
531                }
532            }
533        }
534
535        Tensor::from_vec(data, shape).unwrap()
536    }
537}
538
539// =============================================================================
540// Grayscale
541// =============================================================================
542
/// Transform that converts a 3-channel RGB image to grayscale.
pub struct Grayscale {
    num_output_channels: usize,
}

impl Grayscale {
    /// Builds a `Grayscale` transform that emits a single channel.
    #[must_use]
    pub fn new() -> Self {
        Grayscale {
            num_output_channels: 1,
        }
    }

    /// Builds a `Grayscale` transform that emits `num_output_channels` channels.
    ///
    /// A request for zero channels is raised to one.
    #[must_use]
    pub fn with_channels(num_output_channels: usize) -> Self {
        let num_output_channels = num_output_channels.max(1);
        Grayscale {
            num_output_channels,
        }
    }
}

impl Default for Grayscale {
    /// Same as [`Grayscale::new`].
    fn default() -> Self {
        Grayscale::new()
    }
}
569
570impl Transform for Grayscale {
571    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
572        let shape = input.shape();
573
574        // Only works with 3-channel images (C x H x W)
575        if shape.len() != 3 || shape[0] != 3 {
576            return input.clone();
577        }
578
579        let (_, h, w) = (shape[0], shape[1], shape[2]);
580        let data = input.to_vec();
581
582        let mut gray = Vec::with_capacity(h * w);
583        for y in 0..h {
584            for x in 0..w {
585                let r = data[0 * h * w + y * w + x];
586                let g = data[h * w + y * w + x];
587                let b = data[2 * h * w + y * w + x];
588                gray.push(0.299 * r + 0.587 * g + 0.114 * b);
589            }
590        }
591
592        if self.num_output_channels == 1 {
593            Tensor::from_vec(gray, &[1, h, w]).unwrap()
594        } else {
595            // Replicate grayscale across channels
596            let mut result = Vec::with_capacity(self.num_output_channels * h * w);
597            for _ in 0..self.num_output_channels {
598                result.extend(&gray);
599            }
600            Tensor::from_vec(result, &[self.num_output_channels, h, w]).unwrap()
601        }
602    }
603}
604
605// =============================================================================
606// Normalize (Image-specific)
607// =============================================================================
608
/// Transform that standardises each channel as `(value - mean) / std`.
pub struct ImageNormalize {
    mean: Vec<f32>,
    std: Vec<f32>,
}

impl ImageNormalize {
    /// Builds an `ImageNormalize` from per-channel means and standard deviations.
    #[must_use]
    pub fn new(mean: Vec<f32>, std: Vec<f32>) -> Self {
        ImageNormalize { mean, std }
    }

    /// Preset with the three-channel statistics used for `ImageNet` pretrained models.
    #[must_use]
    pub fn imagenet() -> Self {
        let mean = vec![0.485, 0.456, 0.406];
        let std = vec![0.229, 0.224, 0.225];
        ImageNormalize { mean, std }
    }

    /// Preset with the single-channel statistics for MNIST.
    #[must_use]
    pub fn mnist() -> Self {
        let mean = vec![0.1307];
        let std = vec![0.3081];
        ImageNormalize { mean, std }
    }

    /// Preset with the three-channel statistics for CIFAR-10.
    #[must_use]
    pub fn cifar10() -> Self {
        let mean = vec![0.4914, 0.4822, 0.4465];
        let std = vec![0.2470, 0.2435, 0.2616];
        ImageNormalize { mean, std }
    }
}
636
637impl Transform for ImageNormalize {
638    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
639        let shape = input.shape();
640        let mut data = input.to_vec();
641
642        match shape.len() {
643            3 => {
644                let (c, h, w) = (shape[0], shape[1], shape[2]);
645                for ch in 0..c {
646                    let mean = self.mean.get(ch).copied().unwrap_or(0.0);
647                    let std = self.std.get(ch).copied().unwrap_or(1.0);
648                    for y in 0..h {
649                        for x in 0..w {
650                            let idx = ch * h * w + y * w + x;
651                            data[idx] = (data[idx] - mean) / std;
652                        }
653                    }
654                }
655            }
656            4 => {
657                let (n, c, h, w) = (shape[0], shape[1], shape[2], shape[3]);
658                for batch in 0..n {
659                    for ch in 0..c {
660                        let mean = self.mean.get(ch).copied().unwrap_or(0.0);
661                        let std = self.std.get(ch).copied().unwrap_or(1.0);
662                        for y in 0..h {
663                            for x in 0..w {
664                                let idx = batch * c * h * w + ch * h * w + y * w + x;
665                                data[idx] = (data[idx] - mean) / std;
666                            }
667                        }
668                    }
669                }
670            }
671            _ => {}
672        }
673
674        Tensor::from_vec(data, shape).unwrap()
675    }
676}
677
678// =============================================================================
679// Pad
680// =============================================================================
681
/// Transform that surrounds an image with a constant-valued border.
pub struct Pad {
    padding: (usize, usize, usize, usize), // (left, right, top, bottom)
    fill_value: f32,
}

impl Pad {
    /// Builds a `Pad` that adds `padding` pixels on every side, filled with 0.0.
    #[must_use]
    pub fn new(padding: usize) -> Self {
        Pad::asymmetric(padding, padding, padding, padding)
    }

    /// Builds a `Pad` with an independent border width per side, filled with 0.0.
    #[must_use]
    pub fn asymmetric(left: usize, right: usize, top: usize, bottom: usize) -> Self {
        let padding = (left, right, top, bottom);
        Pad {
            padding,
            fill_value: 0.0,
        }
    }

    /// Returns the same padding configuration with the border value replaced by `value`.
    #[must_use]
    pub fn with_fill(self, value: f32) -> Self {
        Pad {
            fill_value: value,
            ..self
        }
    }
}
711
712impl Transform for Pad {
713    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
714        let shape = input.shape();
715        let data = input.to_vec();
716        let (left, right, top, bottom) = self.padding;
717
718        match shape.len() {
719            2 => {
720                let (h, w) = (shape[0], shape[1]);
721                let new_h = h + top + bottom;
722                let new_w = w + left + right;
723
724                let mut result = vec![self.fill_value; new_h * new_w];
725                for y in 0..h {
726                    for x in 0..w {
727                        result[(y + top) * new_w + (x + left)] = data[y * w + x];
728                    }
729                }
730                Tensor::from_vec(result, &[new_h, new_w]).unwrap()
731            }
732            3 => {
733                let (c, h, w) = (shape[0], shape[1], shape[2]);
734                let new_h = h + top + bottom;
735                let new_w = w + left + right;
736
737                let mut result = vec![self.fill_value; c * new_h * new_w];
738                for ch in 0..c {
739                    for y in 0..h {
740                        for x in 0..w {
741                            result[ch * new_h * new_w + (y + top) * new_w + (x + left)] =
742                                data[ch * h * w + y * w + x];
743                        }
744                    }
745                }
746                Tensor::from_vec(result, &[c, new_h, new_w]).unwrap()
747            }
748            _ => input.clone(),
749        }
750    }
751}
752
753// =============================================================================
754// ToTensorImage
755// =============================================================================
756
/// Transform that rescales pixel values from the [0, 255] range to [0, 1].
pub struct ToTensorImage;

impl ToTensorImage {
    /// Builds a `ToTensorImage` transform.
    #[must_use]
    pub fn new() -> Self {
        ToTensorImage
    }
}

impl Default for ToTensorImage {
    /// Same as [`ToTensorImage::new`].
    fn default() -> Self {
        ToTensorImage
    }
}
772
773impl Transform for ToTensorImage {
774    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
775        let data: Vec<f32> = input.to_vec().iter().map(|&x| x / 255.0).collect();
776        Tensor::from_vec(data, input.shape()).unwrap()
777    }
778}
779
780// =============================================================================
781// Tests
782// =============================================================================
783
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used for approximate float comparisons.
    const TOLERANCE: f32 = 0.001;

    /// Builds a tensor from raw values, panicking when the shape does not fit.
    fn tensor(values: Vec<f32>, shape: &[usize]) -> Tensor<f32> {
        Tensor::from_vec(values, shape).unwrap()
    }

    /// The matrix [[1, 2], [3, 4]] shared by several tests.
    fn two_by_two() -> Tensor<f32> {
        tensor(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])
    }

    /// Asserts that every value is within `TOLERANCE` of `expected`.
    fn assert_all_close(values: &[f32], expected: f32) {
        for &value in values {
            assert!((value - expected).abs() < TOLERANCE);
        }
    }

    #[test]
    fn test_resize_2d() {
        let output = Resize::new(4, 4).apply(&two_by_two());

        assert_eq!(output.shape(), &[4, 4]);
    }

    #[test]
    fn test_resize_3d() {
        let input = tensor(vec![1.0; 3 * 8 * 8], &[3, 8, 8]);

        let output = Resize::new(4, 4).apply(&input);

        assert_eq!(output.shape(), &[3, 4, 4]);
    }

    #[test]
    fn test_center_crop() {
        let values: Vec<f32> = (1..=16).map(|x| x as f32).collect();
        let input = tensor(values, &[4, 4]);

        let output = CenterCrop::new(2, 2).apply(&input);

        assert_eq!(output.shape(), &[2, 2]);
        // The middle 2x2 window of the 4x4 grid holds 6, 7, 10, 11.
        assert_eq!(output.to_vec(), vec![6.0, 7.0, 10.0, 11.0]);
    }

    #[test]
    fn test_random_horizontal_flip() {
        let always_flip = RandomHorizontalFlip::with_probability(1.0);

        let output = always_flip.apply(&two_by_two());

        // Columns swap: [[1, 2], [3, 4]] becomes [[2, 1], [4, 3]].
        assert_eq!(output.to_vec(), vec![2.0, 1.0, 4.0, 3.0]);
    }

    #[test]
    fn test_random_vertical_flip() {
        let always_flip = RandomVerticalFlip::with_probability(1.0);

        let output = always_flip.apply(&two_by_two());

        // Rows swap: [[1, 2], [3, 4]] becomes [[3, 4], [1, 2]].
        assert_eq!(output.to_vec(), vec![3.0, 4.0, 1.0, 2.0]);
    }

    #[test]
    fn test_random_rotation_180() {
        let half_turn = RandomRotation::with_angles(vec![180]);

        let output = half_turn.apply(&two_by_two());

        // A half turn of [[1, 2], [3, 4]] is [[4, 3], [2, 1]].
        assert_eq!(output.to_vec(), vec![4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_grayscale() {
        let mut values = Vec::with_capacity(12);
        values.extend([1.0; 4]); // R channel
        values.extend([0.5; 4]); // G channel
        values.extend([0.0; 4]); // B channel
        let input = tensor(values, &[3, 2, 2]);

        let output = Grayscale::new().apply(&input);

        assert_eq!(output.shape(), &[1, 2, 2]);
        // Luma = 0.299 * 1.0 + 0.587 * 0.5 + 0.114 * 0.0 = 0.5925
        let expected = 0.299 + 0.587 * 0.5;
        assert_all_close(&output.to_vec(), expected);
    }

    #[test]
    fn test_image_normalize() {
        let input = tensor(vec![0.5; 3 * 2 * 2], &[3, 2, 2]);
        let normalize = ImageNormalize::new(vec![0.5, 0.5, 0.5], vec![0.5, 0.5, 0.5]);

        let output = normalize.apply(&input);

        // (0.5 - 0.5) / 0.5 = 0.0 for every element.
        assert_all_close(&output.to_vec(), 0.0);
    }

    #[test]
    fn test_pad() {
        let output = Pad::new(1).apply(&two_by_two());

        assert_eq!(output.shape(), &[4, 4]);
        let data = output.to_vec();

        // The four corners belong to the zero border.
        for corner in [0, 3, 12, 15] {
            assert_eq!(data[corner], 0.0);
        }
        // The original values sit in the middle 2x2 window.
        let centre = [(5, 1.0), (6, 2.0), (9, 3.0), (10, 4.0)];
        for (index, expected) in centre {
            assert_eq!(data[index], expected);
        }
    }

    #[test]
    fn test_to_tensor_image() {
        let input = tensor(vec![0.0, 127.5, 255.0], &[3]);

        let data = ToTensorImage::new().apply(&input).to_vec();

        let expected = [0.0, 0.5, 1.0];
        for (index, want) in expected.into_iter().enumerate() {
            assert!((data[index] - want).abs() < TOLERANCE);
        }
    }

    #[test]
    fn test_color_jitter() {
        let input = tensor(vec![0.5; 3 * 4 * 4], &[3, 4, 4]);

        let output = ColorJitter::new(0.1, 0.1, 0.1).apply(&input);

        assert_eq!(output.shape(), &[3, 4, 4]);
        // Every value must stay inside the valid [0, 1] range.
        assert!(output.to_vec().iter().all(|val| (0.0..=1.0).contains(val)));
    }
}
936}