Skip to main content

axonml_data/
transforms.rs

1//! Transforms - Data Augmentation and Preprocessing
2//!
3//! # File
4//! `crates/axonml-data/src/transforms.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr - AutomataNexus
8//!
9//! # Updated
10//! March 8, 2026
11//!
12//! # Disclaimer
13//! Use at own risk. This software is provided "as is", without warranty of any
14//! kind, express or implied. The author and AutomataNexus shall not be held
15//! liable for any damages arising from the use of this software.
16
17use axonml_tensor::Tensor;
18use rand::Rng;
19
20// =============================================================================
21// Transform Trait
22// =============================================================================
23
/// Trait for data transformations.
///
/// Implementors map an input tensor to a new output tensor; the input is
/// taken by shared reference and is never mutated. The `Send + Sync` bounds
/// allow a transform (e.g. inside a `Compose`) to be shared between
/// data-loading threads.
pub trait Transform: Send + Sync {
    /// Applies the transform to a tensor, returning the transformed result.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32>;
}
29
30// =============================================================================
31// Compose
32// =============================================================================
33
34/// Composes multiple transforms into a single transform.
35pub struct Compose {
36    transforms: Vec<Box<dyn Transform>>,
37}
38
39impl Compose {
40    /// Creates a new Compose from a vector of transforms.
41    #[must_use]
42    pub fn new(transforms: Vec<Box<dyn Transform>>) -> Self {
43        Self { transforms }
44    }
45
46    /// Creates an empty Compose.
47    #[must_use]
48    pub fn empty() -> Self {
49        Self {
50            transforms: Vec::new(),
51        }
52    }
53
54    /// Adds a transform to the composition.
55    pub fn add<T: Transform + 'static>(mut self, transform: T) -> Self {
56        self.transforms.push(Box::new(transform));
57        self
58    }
59}
60
61impl Transform for Compose {
62    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
63        let mut result = input.clone();
64        for transform in &self.transforms {
65            result = transform.apply(&result);
66        }
67        result
68    }
69}
70
71// =============================================================================
72// ToTensor
73// =============================================================================
74
/// Converts input to a tensor (identity for already-tensor inputs).
pub struct ToTensor;

impl ToTensor {
    /// Creates a new `ToTensor` transform.
    #[must_use]
    pub fn new() -> Self {
        ToTensor
    }
}

impl Default for ToTensor {
    fn default() -> Self {
        ToTensor::new()
    }
}
91
impl Transform for ToTensor {
    /// Identity: returns a clone of the input tensor unchanged.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        input.clone()
    }
}
97
98// =============================================================================
99// Normalize
100// =============================================================================
101
/// Normalizes a tensor with mean and standard deviation.
///
/// Supports both scalar normalization (applied uniformly) and per-channel
/// normalization (PyTorch-style `transforms.Normalize(mean=[...], std=[...])`).
/// For per-channel mode, the tensor is expected to have shape `[C, H, W]` or
/// `[N, C, H, W]`.
pub struct Normalize {
    mean: Vec<f32>,
    std: Vec<f32>,
}

impl Normalize {
    /// Creates a new scalar Normalize transform (applied uniformly to all elements).
    #[must_use]
    pub fn new(mean: f32, std: f32) -> Self {
        Self {
            mean: vec![mean],
            std: vec![std],
        }
    }

    /// Creates a per-channel Normalize transform (PyTorch-compatible).
    ///
    /// For a `[C, H, W]` tensor, `mean` and `std` must have length `C`.
    /// Each channel is normalized independently: `output[c] = (input[c] - mean[c]) / std[c]`.
    ///
    /// # Panics
    /// Panics if `mean` and `std` differ in length, or if they are empty.
    /// An empty spec would otherwise only fail later, inside `apply`, with an
    /// opaque index-out-of-bounds panic — validating here fails fast with a
    /// clear message.
    #[must_use]
    pub fn per_channel(mean: Vec<f32>, std: Vec<f32>) -> Self {
        assert_eq!(mean.len(), std.len(), "mean and std must have same length");
        assert!(!mean.is_empty(), "mean and std must be non-empty");
        Self { mean, std }
    }

    /// Creates a Normalize for standard normal distribution (mean=0, std=1).
    #[must_use]
    pub fn standard() -> Self {
        Self::new(0.0, 1.0)
    }

    /// Creates a Normalize for [0,1] to [-1,1] conversion (mean=0.5, std=0.5).
    #[must_use]
    pub fn zero_centered() -> Self {
        Self::new(0.5, 0.5)
    }

    /// Creates a Normalize for ImageNet (3-channel RGB).
    #[must_use]
    pub fn imagenet() -> Self {
        Self::per_channel(vec![0.485, 0.456, 0.406], vec![0.229, 0.224, 0.225])
    }
}
151
152impl Transform for Normalize {
153    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
154        let shape = input.shape();
155        let mut data = input.to_vec();
156
157        if self.mean.len() == 1 {
158            // Scalar normalization — apply uniformly
159            let m = self.mean[0];
160            let s = self.std[0];
161            for x in &mut data {
162                *x = (*x - m) / s;
163            }
164        } else {
165            // Per-channel normalization
166            let num_channels = self.mean.len();
167
168            if shape.len() == 3 && shape[0] == num_channels {
169                // [C, H, W]
170                let spatial = shape[1] * shape[2];
171                for c in 0..num_channels {
172                    let offset = c * spatial;
173                    let m = self.mean[c];
174                    let s = self.std[c];
175                    for i in 0..spatial {
176                        data[offset + i] = (data[offset + i] - m) / s;
177                    }
178                }
179            } else if shape.len() == 4 && shape[1] == num_channels {
180                // [N, C, H, W]
181                let spatial = shape[2] * shape[3];
182                let sample_size = num_channels * spatial;
183                for n in 0..shape[0] {
184                    for c in 0..num_channels {
185                        let offset = n * sample_size + c * spatial;
186                        let m = self.mean[c];
187                        let s = self.std[c];
188                        for i in 0..spatial {
189                            data[offset + i] = (data[offset + i] - m) / s;
190                        }
191                    }
192                }
193            } else {
194                // Fallback: apply first channel's mean/std uniformly
195                let m = self.mean[0];
196                let s = self.std[0];
197                for x in &mut data {
198                    *x = (*x - m) / s;
199                }
200            }
201        }
202
203        Tensor::from_vec(data, shape).unwrap()
204    }
205}
206
207// =============================================================================
208// RandomNoise
209// =============================================================================
210
/// Adds random Gaussian noise to the input.
pub struct RandomNoise {
    std: f32,
}

impl RandomNoise {
    /// Creates a new `RandomNoise` transform with the given noise standard
    /// deviation. A `std` of 0 makes the transform an identity.
    #[must_use]
    pub fn new(std: f32) -> Self {
        RandomNoise { std }
    }
}
223
224impl Transform for RandomNoise {
225    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
226        if self.std == 0.0 {
227            return input.clone();
228        }
229
230        let mut rng = rand::thread_rng();
231        let data = input.to_vec();
232        let noisy: Vec<f32> = data
233            .iter()
234            .map(|&x| {
235                // Box-Muller transform for Gaussian noise
236                let u1: f32 = rng.r#gen();
237                let u2: f32 = rng.r#gen();
238                let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos();
239                x + z * self.std
240            })
241            .collect();
242        Tensor::from_vec(noisy, input.shape()).unwrap()
243    }
244}
245
246// =============================================================================
247// RandomCrop
248// =============================================================================
249
/// Randomly crops a portion of the input.
pub struct RandomCrop {
    size: Vec<usize>,
}

impl RandomCrop {
    /// Creates a new `RandomCrop` with target size.
    #[must_use]
    pub fn new(size: Vec<usize>) -> Self {
        RandomCrop { size }
    }

    /// Convenience constructor for 2-D (height x width) crops.
    #[must_use]
    pub fn new_2d(height: usize, width: usize) -> Self {
        RandomCrop {
            size: vec![height, width],
        }
    }
}
268
269impl Transform for RandomCrop {
270    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
271        let shape = input.shape();
272
273        // Determine spatial dimensions (last N dimensions where N = size.len())
274        if shape.len() < self.size.len() {
275            return input.clone();
276        }
277
278        let spatial_start = shape.len() - self.size.len();
279        let mut rng = rand::thread_rng();
280
281        // Calculate random offsets for each spatial dimension
282        let mut offsets = Vec::with_capacity(self.size.len());
283        for (i, &target_dim) in self.size.iter().enumerate() {
284            let input_dim = shape[spatial_start + i];
285            if input_dim <= target_dim {
286                offsets.push(0);
287            } else {
288                offsets.push(rng.gen_range(0..=input_dim - target_dim));
289            }
290        }
291
292        // Calculate actual crop sizes (clamped to input dimensions)
293        let crop_sizes: Vec<usize> = self
294            .size
295            .iter()
296            .enumerate()
297            .map(|(i, &s)| s.min(shape[spatial_start + i]))
298            .collect();
299
300        let data = input.to_vec();
301
302        // Handle 1D case
303        if shape.len() == 1 && self.size.len() == 1 {
304            let start = offsets[0];
305            let end = start + crop_sizes[0];
306            let cropped = data[start..end].to_vec();
307            let len = cropped.len();
308            return Tensor::from_vec(cropped, &[len]).unwrap();
309        }
310
311        // Handle 2D case (H x W)
312        if shape.len() == 2 && self.size.len() == 2 {
313            let (_h, w) = (shape[0], shape[1]);
314            let (crop_h, crop_w) = (crop_sizes[0], crop_sizes[1]);
315            let (off_h, off_w) = (offsets[0], offsets[1]);
316
317            let mut cropped = Vec::with_capacity(crop_h * crop_w);
318            for row in off_h..off_h + crop_h {
319                for col in off_w..off_w + crop_w {
320                    cropped.push(data[row * w + col]);
321                }
322            }
323            return Tensor::from_vec(cropped, &[crop_h, crop_w]).unwrap();
324        }
325
326        // Handle 3D case (C x H x W) - common for images
327        if shape.len() == 3 && self.size.len() == 2 {
328            let (c, h, w) = (shape[0], shape[1], shape[2]);
329            let (crop_h, crop_w) = (crop_sizes[0], crop_sizes[1]);
330            let (off_h, off_w) = (offsets[0], offsets[1]);
331
332            let mut cropped = Vec::with_capacity(c * crop_h * crop_w);
333            for channel in 0..c {
334                for row in off_h..off_h + crop_h {
335                    for col in off_w..off_w + crop_w {
336                        cropped.push(data[channel * h * w + row * w + col]);
337                    }
338                }
339            }
340            return Tensor::from_vec(cropped, &[c, crop_h, crop_w]).unwrap();
341        }
342
343        // Handle 4D case (N x C x H x W) - batched images
344        if shape.len() == 4 && self.size.len() == 2 {
345            let (n, c, h, w) = (shape[0], shape[1], shape[2], shape[3]);
346            let (crop_h, crop_w) = (crop_sizes[0], crop_sizes[1]);
347            let (off_h, off_w) = (offsets[0], offsets[1]);
348
349            let mut cropped = Vec::with_capacity(n * c * crop_h * crop_w);
350            for batch in 0..n {
351                for channel in 0..c {
352                    for row in off_h..off_h + crop_h {
353                        for col in off_w..off_w + crop_w {
354                            let idx = batch * c * h * w + channel * h * w + row * w + col;
355                            cropped.push(data[idx]);
356                        }
357                    }
358                }
359            }
360            return Tensor::from_vec(cropped, &[n, c, crop_h, crop_w]).unwrap();
361        }
362
363        // Fallback for unsupported dimensions - shouldn't reach here in practice
364        input.clone()
365    }
366}
367
368// =============================================================================
369// RandomFlip
370// =============================================================================
371
/// Randomly flips the input along a specified dimension.
pub struct RandomFlip {
    dim: usize,
    probability: f32,
}

impl RandomFlip {
    /// Creates a new `RandomFlip`; `probability` is clamped into `[0, 1]`.
    #[must_use]
    pub fn new(dim: usize, probability: f32) -> Self {
        let probability = probability.clamp(0.0, 1.0);
        Self { dim, probability }
    }

    /// Horizontal flip (dim=1 for `HxW` images), applied half the time.
    #[must_use]
    pub fn horizontal() -> Self {
        RandomFlip::new(1, 0.5)
    }

    /// Vertical flip (dim=0 for `HxW` images), applied half the time.
    #[must_use]
    pub fn vertical() -> Self {
        RandomFlip::new(0, 0.5)
    }
}
400
401impl Transform for RandomFlip {
402    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
403        let mut rng = rand::thread_rng();
404        if rng.r#gen::<f32>() > self.probability {
405            return input.clone();
406        }
407
408        let shape = input.shape();
409        if self.dim >= shape.len() {
410            return input.clone();
411        }
412
413        let data = input.to_vec();
414        let ndim = shape.len();
415
416        // Generic N-dimensional flip along self.dim:
417        // Compute strides, then for each element map the flipped index.
418        let total = data.len();
419        let mut flipped = vec![0.0f32; total];
420
421        // Compute strides (row-major)
422        let mut strides = vec![1usize; ndim];
423        for i in (0..ndim - 1).rev() {
424            strides[i] = strides[i + 1] * shape[i + 1];
425        }
426
427        let dim = self.dim;
428        let dim_size = shape[dim];
429        let dim_stride = strides[dim];
430
431        for i in 0..total {
432            // Extract the coordinate along the flip dimension
433            let coord_in_dim = (i / dim_stride) % dim_size;
434            let flipped_coord = dim_size - 1 - coord_in_dim;
435            // Compute the source index with the flipped coordinate
436            let diff = flipped_coord as isize - coord_in_dim as isize;
437            let src = (i as isize + diff * dim_stride as isize) as usize;
438            flipped[i] = data[src];
439        }
440
441        Tensor::from_vec(flipped, shape).unwrap()
442    }
443}
444
445// =============================================================================
446// Scale
447// =============================================================================
448
/// Scales tensor values by a constant factor.
pub struct Scale {
    factor: f32,
}

impl Scale {
    /// Creates a new Scale transform with the given multiplier.
    #[must_use]
    pub fn new(factor: f32) -> Self {
        Scale { factor }
    }
}
461
impl Transform for Scale {
    /// Multiplies every element by `self.factor`.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        input.mul_scalar(self.factor)
    }
}
467
468// =============================================================================
469// Clamp
470// =============================================================================
471
/// Clamps tensor values to a specified range.
pub struct Clamp {
    min: f32,
    max: f32,
}

impl Clamp {
    /// Creates a new Clamp transform.
    ///
    /// # Panics
    /// Panics if `min > max` (or either bound is NaN). `f32::clamp` would
    /// otherwise panic per-element at apply time, so validating here fails
    /// fast at construction with a clear message.
    #[must_use]
    pub fn new(min: f32, max: f32) -> Self {
        assert!(min <= max, "Clamp requires min <= max (got min={min}, max={max})");
        Self { min, max }
    }

    /// Creates a Clamp for [0, 1] range.
    #[must_use]
    pub fn zero_one() -> Self {
        Self::new(0.0, 1.0)
    }

    /// Creates a Clamp for [-1, 1] range.
    #[must_use]
    pub fn symmetric() -> Self {
        Self::new(-1.0, 1.0)
    }
}
497
498impl Transform for Clamp {
499    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
500        let data = input.to_vec();
501        let clamped: Vec<f32> = data.iter().map(|&x| x.clamp(self.min, self.max)).collect();
502        Tensor::from_vec(clamped, input.shape()).unwrap()
503    }
504}
505
506// =============================================================================
507// Flatten
508// =============================================================================
509
/// Flattens the tensor to 1D.
pub struct Flatten;

impl Flatten {
    /// Creates a new Flatten transform.
    #[must_use]
    pub fn new() -> Self {
        Flatten
    }
}

impl Default for Flatten {
    fn default() -> Self {
        Flatten
    }
}
526
527impl Transform for Flatten {
528    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
529        let data = input.to_vec();
530        Tensor::from_vec(data.clone(), &[data.len()]).unwrap()
531    }
532}
533
534// =============================================================================
535// Reshape
536// =============================================================================
537
/// Reshapes the tensor to a specified shape.
pub struct Reshape {
    shape: Vec<usize>,
}

impl Reshape {
    /// Creates a new Reshape transform targeting `shape`.
    #[must_use]
    pub fn new(shape: Vec<usize>) -> Self {
        Reshape { shape }
    }
}
550
551impl Transform for Reshape {
552    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
553        let data = input.to_vec();
554        let expected_size: usize = self.shape.iter().product();
555
556        if data.len() != expected_size {
557            // Size mismatch, return original
558            return input.clone();
559        }
560
561        Tensor::from_vec(data, &self.shape).unwrap()
562    }
563}
564
565// =============================================================================
566// Dropout Transform
567// =============================================================================
568
/// Applies dropout by randomly zeroing elements during training.
///
/// Respects train/eval mode: dropout is applied only while `training` is
/// true. Call `set_training(false)` to make the transform an identity during
/// evaluation.
pub struct DropoutTransform {
    probability: f32,
    training: std::sync::atomic::AtomicBool,
}

impl DropoutTransform {
    /// Creates a new `DropoutTransform` in training mode.
    ///
    /// `probability` is clamped into `[0, 1]`.
    #[must_use]
    pub fn new(probability: f32) -> Self {
        let probability = probability.clamp(0.0, 1.0);
        Self {
            probability,
            training: std::sync::atomic::AtomicBool::new(true),
        }
    }

    /// Switches between training (dropout active) and eval (identity) mode.
    pub fn set_training(&self, training: bool) {
        self.training
            .store(training, std::sync::atomic::Ordering::Relaxed);
    }

    /// Returns whether this transform is currently in training mode.
    pub fn is_training(&self) -> bool {
        self.training.load(std::sync::atomic::Ordering::Relaxed)
    }
}
599
600impl Transform for DropoutTransform {
601    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
602        // No dropout in eval mode or with p=0
603        if !self.is_training() || self.probability == 0.0 {
604            return input.clone();
605        }
606
607        let mut rng = rand::thread_rng();
608        let scale = 1.0 / (1.0 - self.probability);
609        let data = input.to_vec();
610
611        let dropped: Vec<f32> = data
612            .iter()
613            .map(|&x| {
614                if rng.r#gen::<f32>() < self.probability {
615                    0.0
616                } else {
617                    x * scale
618                }
619            })
620            .collect();
621
622        Tensor::from_vec(dropped, input.shape()).unwrap()
623    }
624}
625
626// =============================================================================
627// Lambda Transform
628// =============================================================================
629
630/// Applies a custom function as a transform.
631pub struct Lambda<F>
632where
633    F: Fn(&Tensor<f32>) -> Tensor<f32> + Send + Sync,
634{
635    func: F,
636}
637
638impl<F> Lambda<F>
639where
640    F: Fn(&Tensor<f32>) -> Tensor<f32> + Send + Sync,
641{
642    /// Creates a new Lambda transform.
643    pub fn new(func: F) -> Self {
644        Self { func }
645    }
646}
647
648impl<F> Transform for Lambda<F>
649where
650    F: Fn(&Tensor<f32>) -> Tensor<f32> + Send + Sync,
651{
652    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
653        (self.func)(input)
654    }
655}
656
657// =============================================================================
658// Tests
659// =============================================================================
660
#[cfg(test)]
mod tests {
    // Unit tests for every transform in this module. Random transforms are
    // made deterministic by forcing their parameters (probability = 1.0 or
    // std = 0.0), or are checked statistically over large inputs.
    use super::*;

    #[test]
    fn test_normalize() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let normalize = Normalize::new(2.5, 0.5);

        let output = normalize.apply(&input);
        let expected = [-3.0, -1.0, 1.0, 3.0];

        let result = output.to_vec();
        for (a, b) in result.iter().zip(expected.iter()) {
            assert!((a - b).abs() < 1e-6);
        }
    }

    #[test]
    fn test_normalize_per_channel() {
        // 2 channels, 2x2 spatial => [2, 2, 2]
        let input =
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0], &[2, 2, 2]).unwrap();
        let normalize = Normalize::per_channel(vec![0.0, 10.0], vec![1.0, 10.0]);

        let output = normalize.apply(&input);
        let result = output.to_vec();
        // Channel 0: (x - 0) / 1 = x
        assert!((result[0] - 1.0).abs() < 1e-6);
        assert!((result[3] - 4.0).abs() < 1e-6);
        // Channel 1: (x - 10) / 10
        assert!((result[4] - 0.0).abs() < 1e-6); // (10-10)/10
        assert!((result[5] - 1.0).abs() < 1e-6); // (20-10)/10
    }

    #[test]
    fn test_scale() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let scale = Scale::new(2.0);

        let output = scale.apply(&input);
        assert_eq!(output.to_vec(), vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_clamp() {
        let input = Tensor::from_vec(vec![-1.0, 0.5, 2.0], &[3]).unwrap();
        let clamp = Clamp::zero_one();

        let output = clamp.apply(&input);
        assert_eq!(output.to_vec(), vec![0.0, 0.5, 1.0]);
    }

    #[test]
    fn test_flatten() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flatten = Flatten::new();

        let output = flatten.apply(&input);
        assert_eq!(output.shape(), &[4]);
        assert_eq!(output.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_reshape() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6]).unwrap();
        let reshape = Reshape::new(vec![2, 3]);

        let output = reshape.apply(&input);
        assert_eq!(output.shape(), &[2, 3]);
    }

    #[test]
    fn test_compose() {
        let normalize = Normalize::new(0.0, 1.0);
        let scale = Scale::new(2.0);

        let compose = Compose::new(vec![Box::new(normalize), Box::new(scale)]);

        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let output = compose.apply(&input);

        // normalize(x) = x, then scale by 2
        assert_eq!(output.to_vec(), vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_compose_builder() {
        let compose = Compose::empty()
            .add(Normalize::new(0.0, 1.0))
            .add(Scale::new(2.0));

        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let output = compose.apply(&input);

        assert_eq!(output.to_vec(), vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_random_noise() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let noise = RandomNoise::new(0.0);

        // With std=0, output should equal input
        let output = noise.apply(&input);
        assert_eq!(output.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_random_flip_1d() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let flip = RandomFlip::new(0, 1.0); // Always flip

        let output = flip.apply(&input);
        assert_eq!(output.to_vec(), vec![4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_random_flip_2d_horizontal() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flip = RandomFlip::new(1, 1.0); // Always flip horizontal

        let output = flip.apply(&input);
        // [[1, 2], [3, 4]] -> [[2, 1], [4, 3]]
        assert_eq!(output.to_vec(), vec![2.0, 1.0, 4.0, 3.0]);
    }

    #[test]
    fn test_random_flip_2d_vertical() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flip = RandomFlip::new(0, 1.0); // Always flip vertical

        let output = flip.apply(&input);
        // [[1, 2], [3, 4]] -> [[3, 4], [1, 2]]
        assert_eq!(output.to_vec(), vec![3.0, 4.0, 1.0, 2.0]);
    }

    #[test]
    fn test_random_flip_3d() {
        // C=1, H=2, W=2 — flip along dim=2 (horizontal in spatial)
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 2, 2]).unwrap();
        let flip = RandomFlip::new(2, 1.0); // Always flip along W

        let output = flip.apply(&input);
        // [[1,2],[3,4]] → [[2,1],[4,3]]
        assert_eq!(output.to_vec(), vec![2.0, 1.0, 4.0, 3.0]);
        assert_eq!(output.shape(), &[1, 2, 2]);
    }

    #[test]
    fn test_random_flip_4d() {
        // N=1, C=1, H=2, W=2 — flip along dim=2 (vertical flip of H)
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2]).unwrap();
        let flip = RandomFlip::new(2, 1.0); // Flip along H

        let output = flip.apply(&input);
        // Rows flipped: [[1,2],[3,4]] → [[3,4],[1,2]]
        assert_eq!(output.to_vec(), vec![3.0, 4.0, 1.0, 2.0]);
        assert_eq!(output.shape(), &[1, 1, 2, 2]);
    }

    #[test]
    fn test_dropout_eval_mode() {
        let input = Tensor::from_vec(vec![1.0; 100], &[100]).unwrap();
        let dropout = DropoutTransform::new(0.5);

        // In training mode, should drop elements
        let output_train = dropout.apply(&input);
        let zeros_train = output_train.to_vec().iter().filter(|&&x| x == 0.0).count();
        assert!(zeros_train > 0, "Training mode should drop elements");

        // Switch to eval mode — should be identity
        dropout.set_training(false);
        let output_eval = dropout.apply(&input);
        assert_eq!(output_eval.to_vec(), vec![1.0; 100]);
    }

    #[test]
    fn test_dropout_transform() {
        let input = Tensor::from_vec(vec![1.0; 1000], &[1000]).unwrap();
        let dropout = DropoutTransform::new(0.5);

        let output = dropout.apply(&input);
        let output_vec = output.to_vec();

        // About half should be zero
        let zeros = output_vec.iter().filter(|&&x| x == 0.0).count();
        assert!(
            zeros > 300 && zeros < 700,
            "Expected ~500 zeros, got {zeros}"
        );

        // Non-zeros should be scaled by 2 (1/(1-0.5))
        let nonzeros: Vec<f32> = output_vec.iter().filter(|&&x| x != 0.0).copied().collect();
        for x in nonzeros {
            assert!((x - 2.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_lambda() {
        let lambda = Lambda::new(|t: &Tensor<f32>| t.mul_scalar(3.0));

        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let output = lambda.apply(&input);

        assert_eq!(output.to_vec(), vec![3.0, 6.0, 9.0]);
    }

    #[test]
    fn test_to_tensor() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let to_tensor = ToTensor::new();

        let output = to_tensor.apply(&input);
        assert_eq!(output.to_vec(), input.to_vec());
    }

    #[test]
    fn test_normalize_variants() {
        let standard = Normalize::standard();
        assert_eq!(standard.mean, vec![0.0]);
        assert_eq!(standard.std, vec![1.0]);

        let zero_centered = Normalize::zero_centered();
        assert_eq!(zero_centered.mean, vec![0.5]);
        assert_eq!(zero_centered.std, vec![0.5]);
    }

    #[test]
    fn test_random_crop_1d() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap();
        let crop = RandomCrop::new(vec![3]);

        let output = crop.apply(&input);
        assert_eq!(output.shape(), &[3]);
    }

    #[test]
    fn test_random_crop_2d() {
        // 4x4 image
        let input = Tensor::from_vec((1..=16).map(|x| x as f32).collect(), &[4, 4]).unwrap();
        let crop = RandomCrop::new_2d(2, 2);

        let output = crop.apply(&input);
        assert_eq!(output.shape(), &[2, 2]);
        // Verify values are contiguous from the original
        let vals = output.to_vec();
        assert_eq!(vals.len(), 4);
    }

    #[test]
    fn test_random_crop_3d() {
        // 2 channels x 4x4 image
        let input = Tensor::from_vec((1..=32).map(|x| x as f32).collect(), &[2, 4, 4]).unwrap();
        let crop = RandomCrop::new_2d(2, 2);

        let output = crop.apply(&input);
        assert_eq!(output.shape(), &[2, 2, 2]); // 2 channels, 2x2 spatial
    }
}
921}