// axonml_data/transforms.rs
//! Transforms - Data Augmentation and Preprocessing
//!
//! # File
//! `crates/axonml-data/src/transforms.rs`
//!
//! # Author
//! Andrew Jewell Sr. — AutomataNexus LLC
//! ORCID: 0009-0005-2158-7060
//!
//! # Updated
//! April 14, 2026 11:15 PM EST
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.
use axonml_tensor::Tensor;
use rand::Rng;
20
// =============================================================================
// Transform Trait
// =============================================================================

/// Trait for data transformations.
///
/// Implementors map an input tensor to a new output tensor. The `Send + Sync`
/// bounds allow transform pipelines to be shared across data-loading threads.
pub trait Transform: Send + Sync {
    /// Applies the transform to a tensor, returning a new tensor.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32>;
}
30
// =============================================================================
// Compose
// =============================================================================

/// Composes multiple transforms into a single transform.
///
/// Transforms are applied in insertion order when `apply` is called.
pub struct Compose {
    // Boxed trait objects so heterogeneous transform types can share one pipeline.
    transforms: Vec<Box<dyn Transform>>,
}
39
40impl Compose {
41    /// Creates a new Compose from a vector of transforms.
42    #[must_use]
43    pub fn new(transforms: Vec<Box<dyn Transform>>) -> Self {
44        Self { transforms }
45    }
46
47    /// Creates an empty Compose.
48    #[must_use]
49    pub fn empty() -> Self {
50        Self {
51            transforms: Vec::new(),
52        }
53    }
54
55    /// Adds a transform to the composition.
56    pub fn add<T: Transform + 'static>(mut self, transform: T) -> Self {
57        self.transforms.push(Box::new(transform));
58        self
59    }
60}
61
62impl Transform for Compose {
63    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
64        let mut result = input.clone();
65        for transform in &self.transforms {
66            result = transform.apply(&result);
67        }
68        result
69    }
70}
71
// =============================================================================
// ToTensor
// =============================================================================

/// Converts input to a tensor (identity for already-tensor inputs).
pub struct ToTensor;

impl ToTensor {
    /// Creates a new `ToTensor` transform.
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}

impl Default for ToTensor {
    fn default() -> Self {
        Self::new()
    }
}

impl Transform for ToTensor {
    // The input is already a tensor here, so this is a plain copy.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        input.clone()
    }
}
98
// =============================================================================
// Normalize
// =============================================================================

/// Normalizes a tensor with mean and standard deviation.
///
/// Supports both scalar normalization (applied uniformly) and per-channel
/// normalization (PyTorch-style `transforms.Normalize(mean=[...], std=[...])`).
/// For per-channel mode, the tensor is expected to have shape `[C, H, W]` or
/// `[N, C, H, W]`.
pub struct Normalize {
    // Length 1 => scalar mode; length C => per-channel mode.
    mean: Vec<f32>,
    std: Vec<f32>,
}
113
114impl Normalize {
115    /// Creates a new scalar Normalize transform (applied uniformly to all elements).
116    #[must_use]
117    pub fn new(mean: f32, std: f32) -> Self {
118        Self {
119            mean: vec![mean],
120            std: vec![std],
121        }
122    }
123
124    /// Creates a per-channel Normalize transform (PyTorch-compatible).
125    ///
126    /// For a `[C, H, W]` tensor, `mean` and `std` must have length `C`.
127    /// Each channel is normalized independently: `output[c] = (input[c] - mean[c]) / std[c]`.
128    #[must_use]
129    pub fn per_channel(mean: Vec<f32>, std: Vec<f32>) -> Self {
130        assert_eq!(mean.len(), std.len(), "mean and std must have same length");
131        Self { mean, std }
132    }
133
134    /// Creates a Normalize for standard normal distribution (mean=0, std=1).
135    #[must_use]
136    pub fn standard() -> Self {
137        Self::new(0.0, 1.0)
138    }
139
140    /// Creates a Normalize for [0,1] to [-1,1] conversion (mean=0.5, std=0.5).
141    #[must_use]
142    pub fn zero_centered() -> Self {
143        Self::new(0.5, 0.5)
144    }
145
146    /// Creates a Normalize for ImageNet (3-channel RGB).
147    #[must_use]
148    pub fn imagenet() -> Self {
149        Self::per_channel(vec![0.485, 0.456, 0.406], vec![0.229, 0.224, 0.225])
150    }
151}
152
153impl Transform for Normalize {
154    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
155        let shape = input.shape();
156        let mut data = input.to_vec();
157
158        if self.mean.len() == 1 {
159            // Scalar normalization — apply uniformly
160            let m = self.mean[0];
161            let s = self.std[0];
162            for x in &mut data {
163                *x = (*x - m) / s;
164            }
165        } else {
166            // Per-channel normalization
167            let num_channels = self.mean.len();
168
169            if shape.len() == 3 && shape[0] == num_channels {
170                // [C, H, W]
171                let spatial = shape[1] * shape[2];
172                for c in 0..num_channels {
173                    let offset = c * spatial;
174                    let m = self.mean[c];
175                    let s = self.std[c];
176                    for i in 0..spatial {
177                        data[offset + i] = (data[offset + i] - m) / s;
178                    }
179                }
180            } else if shape.len() == 4 && shape[1] == num_channels {
181                // [N, C, H, W]
182                let spatial = shape[2] * shape[3];
183                let sample_size = num_channels * spatial;
184                for n in 0..shape[0] {
185                    for c in 0..num_channels {
186                        let offset = n * sample_size + c * spatial;
187                        let m = self.mean[c];
188                        let s = self.std[c];
189                        for i in 0..spatial {
190                            data[offset + i] = (data[offset + i] - m) / s;
191                        }
192                    }
193                }
194            } else {
195                // Fallback: apply first channel's mean/std uniformly
196                let m = self.mean[0];
197                let s = self.std[0];
198                for x in &mut data {
199                    *x = (*x - m) / s;
200                }
201            }
202        }
203
204        Tensor::from_vec(data, shape).unwrap()
205    }
206}
207
// =============================================================================
// RandomNoise
// =============================================================================

/// Adds random Gaussian noise to the input.
pub struct RandomNoise {
    // Standard deviation of the zero-mean noise; 0.0 disables the transform.
    std: f32,
}

impl RandomNoise {
    /// Creates a new `RandomNoise` transform with the given noise std-dev.
    #[must_use]
    pub fn new(std: f32) -> Self {
        Self { std }
    }
}
224
225impl Transform for RandomNoise {
226    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
227        if self.std == 0.0 {
228            return input.clone();
229        }
230
231        let mut rng = rand::thread_rng();
232        let data = input.to_vec();
233        let noisy: Vec<f32> = data
234            .iter()
235            .map(|&x| {
236                // Box-Muller transform for Gaussian noise
237                let u1: f32 = rng.r#gen();
238                let u2: f32 = rng.r#gen();
239                let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos();
240                x + z * self.std
241            })
242            .collect();
243        Tensor::from_vec(noisy, input.shape()).unwrap()
244    }
245}
246
// =============================================================================
// RandomCrop
// =============================================================================

/// Randomly crops a portion of the input.
pub struct RandomCrop {
    // Target size of the trailing spatial dimensions (e.g. [H, W]).
    size: Vec<usize>,
}

impl RandomCrop {
    /// Creates a new `RandomCrop` with target size.
    #[must_use]
    pub fn new(size: Vec<usize>) -> Self {
        Self { size }
    }

    /// Creates a `RandomCrop` for 2D images (height x width).
    #[must_use]
    pub fn new_2d(height: usize, width: usize) -> Self {
        Self::new(vec![height, width])
    }
}
269
270impl Transform for RandomCrop {
271    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
272        let shape = input.shape();
273
274        // Determine spatial dimensions (last N dimensions where N = size.len())
275        if shape.len() < self.size.len() {
276            return input.clone();
277        }
278
279        let spatial_start = shape.len() - self.size.len();
280        let mut rng = rand::thread_rng();
281
282        // Calculate random offsets for each spatial dimension
283        let mut offsets = Vec::with_capacity(self.size.len());
284        for (i, &target_dim) in self.size.iter().enumerate() {
285            let input_dim = shape[spatial_start + i];
286            if input_dim <= target_dim {
287                offsets.push(0);
288            } else {
289                offsets.push(rng.gen_range(0..=input_dim - target_dim));
290            }
291        }
292
293        // Calculate actual crop sizes (clamped to input dimensions)
294        let crop_sizes: Vec<usize> = self
295            .size
296            .iter()
297            .enumerate()
298            .map(|(i, &s)| s.min(shape[spatial_start + i]))
299            .collect();
300
301        let data = input.to_vec();
302
303        // Handle 1D case
304        if shape.len() == 1 && self.size.len() == 1 {
305            let start = offsets[0];
306            let end = start + crop_sizes[0];
307            let cropped = data[start..end].to_vec();
308            let len = cropped.len();
309            return Tensor::from_vec(cropped, &[len]).unwrap();
310        }
311
312        // Handle 2D case (H x W)
313        if shape.len() == 2 && self.size.len() == 2 {
314            let (_h, w) = (shape[0], shape[1]);
315            let (crop_h, crop_w) = (crop_sizes[0], crop_sizes[1]);
316            let (off_h, off_w) = (offsets[0], offsets[1]);
317
318            let mut cropped = Vec::with_capacity(crop_h * crop_w);
319            for row in off_h..off_h + crop_h {
320                for col in off_w..off_w + crop_w {
321                    cropped.push(data[row * w + col]);
322                }
323            }
324            return Tensor::from_vec(cropped, &[crop_h, crop_w]).unwrap();
325        }
326
327        // Handle 3D case (C x H x W) - common for images
328        if shape.len() == 3 && self.size.len() == 2 {
329            let (c, h, w) = (shape[0], shape[1], shape[2]);
330            let (crop_h, crop_w) = (crop_sizes[0], crop_sizes[1]);
331            let (off_h, off_w) = (offsets[0], offsets[1]);
332
333            let mut cropped = Vec::with_capacity(c * crop_h * crop_w);
334            for channel in 0..c {
335                for row in off_h..off_h + crop_h {
336                    for col in off_w..off_w + crop_w {
337                        cropped.push(data[channel * h * w + row * w + col]);
338                    }
339                }
340            }
341            return Tensor::from_vec(cropped, &[c, crop_h, crop_w]).unwrap();
342        }
343
344        // Handle 4D case (N x C x H x W) - batched images
345        if shape.len() == 4 && self.size.len() == 2 {
346            let (n, c, h, w) = (shape[0], shape[1], shape[2], shape[3]);
347            let (crop_h, crop_w) = (crop_sizes[0], crop_sizes[1]);
348            let (off_h, off_w) = (offsets[0], offsets[1]);
349
350            let mut cropped = Vec::with_capacity(n * c * crop_h * crop_w);
351            for batch in 0..n {
352                for channel in 0..c {
353                    for row in off_h..off_h + crop_h {
354                        for col in off_w..off_w + crop_w {
355                            let idx = batch * c * h * w + channel * h * w + row * w + col;
356                            cropped.push(data[idx]);
357                        }
358                    }
359                }
360            }
361            return Tensor::from_vec(cropped, &[n, c, crop_h, crop_w]).unwrap();
362        }
363
364        // Fallback for unsupported dimensions - shouldn't reach here in practice
365        input.clone()
366    }
367}
368
// =============================================================================
// RandomFlip
// =============================================================================

/// Randomly flips the input along a specified dimension.
pub struct RandomFlip {
    // Dimension index to mirror.
    dim: usize,
    // Chance in [0, 1] that a given apply() performs the flip.
    probability: f32,
}
378
379impl RandomFlip {
380    /// Creates a new `RandomFlip`.
381    #[must_use]
382    pub fn new(dim: usize, probability: f32) -> Self {
383        Self {
384            dim,
385            probability: probability.clamp(0.0, 1.0),
386        }
387    }
388
389    /// Creates a horizontal flip (dim=1 for `HxW` images).
390    #[must_use]
391    pub fn horizontal() -> Self {
392        Self::new(1, 0.5)
393    }
394
395    /// Creates a vertical flip (dim=0 for `HxW` images).
396    #[must_use]
397    pub fn vertical() -> Self {
398        Self::new(0, 0.5)
399    }
400}
401
402impl Transform for RandomFlip {
403    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
404        let mut rng = rand::thread_rng();
405        if rng.r#gen::<f32>() > self.probability {
406            return input.clone();
407        }
408
409        let shape = input.shape();
410        if self.dim >= shape.len() {
411            return input.clone();
412        }
413
414        let data = input.to_vec();
415        let ndim = shape.len();
416
417        // Generic N-dimensional flip along self.dim:
418        // Compute strides, then for each element map the flipped index.
419        let total = data.len();
420        let mut flipped = vec![0.0f32; total];
421
422        // Compute strides (row-major)
423        let mut strides = vec![1usize; ndim];
424        for i in (0..ndim - 1).rev() {
425            strides[i] = strides[i + 1] * shape[i + 1];
426        }
427
428        let dim = self.dim;
429        let dim_size = shape[dim];
430        let dim_stride = strides[dim];
431
432        for i in 0..total {
433            // Extract the coordinate along the flip dimension
434            let coord_in_dim = (i / dim_stride) % dim_size;
435            let flipped_coord = dim_size - 1 - coord_in_dim;
436            // Compute the source index with the flipped coordinate
437            let diff = flipped_coord as isize - coord_in_dim as isize;
438            let src = (i as isize + diff * dim_stride as isize) as usize;
439            flipped[i] = data[src];
440        }
441
442        Tensor::from_vec(flipped, shape).unwrap()
443    }
444}
445
// =============================================================================
// Scale
// =============================================================================

/// Scales tensor values by a constant factor.
pub struct Scale {
    factor: f32,
}

impl Scale {
    /// Creates a new Scale transform with the given multiplier.
    #[must_use]
    pub fn new(factor: f32) -> Self {
        Self { factor }
    }
}

impl Transform for Scale {
    // Element-wise multiplication by the constant factor.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        input.mul_scalar(self.factor)
    }
}
468
// =============================================================================
// Clamp
// =============================================================================

/// Clamps tensor values to a specified range.
pub struct Clamp {
    min: f32,
    max: f32,
}

impl Clamp {
    /// Creates a new Clamp transform restricting values to `[min, max]`.
    #[must_use]
    pub fn new(min: f32, max: f32) -> Self {
        Self { min, max }
    }

    /// Creates a Clamp for the unit interval [0, 1].
    #[must_use]
    pub fn zero_one() -> Self {
        Self::new(0.0, 1.0)
    }

    /// Creates a Clamp for the symmetric interval [-1, 1].
    #[must_use]
    pub fn symmetric() -> Self {
        Self::new(-1.0, 1.0)
    }
}
498
499impl Transform for Clamp {
500    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
501        let data = input.to_vec();
502        let clamped: Vec<f32> = data.iter().map(|&x| x.clamp(self.min, self.max)).collect();
503        Tensor::from_vec(clamped, input.shape()).unwrap()
504    }
505}
506
// =============================================================================
// Flatten
// =============================================================================

/// Flattens the tensor to 1D.
pub struct Flatten;

impl Flatten {
    /// Creates a new Flatten transform.
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}

impl Default for Flatten {
    fn default() -> Self {
        Self::new()
    }
}
527
528impl Transform for Flatten {
529    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
530        let data = input.to_vec();
531        Tensor::from_vec(data.clone(), &[data.len()]).unwrap()
532    }
533}
534
// =============================================================================
// Reshape
// =============================================================================

/// Reshapes the tensor to a specified shape.
pub struct Reshape {
    // Target shape; must have the same element count as the input.
    shape: Vec<usize>,
}

impl Reshape {
    /// Creates a new Reshape transform targeting `shape`.
    #[must_use]
    pub fn new(shape: Vec<usize>) -> Self {
        Self { shape }
    }
}
551
552impl Transform for Reshape {
553    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
554        let data = input.to_vec();
555        let expected_size: usize = self.shape.iter().product();
556
557        if data.len() != expected_size {
558            // Size mismatch, return original
559            return input.clone();
560        }
561
562        Tensor::from_vec(data, &self.shape).unwrap()
563    }
564}
565
// =============================================================================
// Dropout Transform
// =============================================================================

/// Applies dropout by randomly zeroing elements during training.
///
/// Respects train/eval mode: dropout is only applied when `training` is true.
/// Use `set_training(false)` to disable dropout during evaluation.
pub struct DropoutTransform {
    // Drop probability in [0, 1].
    probability: f32,
    // Atomic so the mode can be toggled through a shared (&self) reference.
    training: std::sync::atomic::AtomicBool,
}
578
579impl DropoutTransform {
580    /// Creates a new `DropoutTransform` in training mode.
581    #[must_use]
582    pub fn new(probability: f32) -> Self {
583        Self {
584            probability: probability.clamp(0.0, 1.0),
585            training: std::sync::atomic::AtomicBool::new(true),
586        }
587    }
588
589    /// Sets whether this transform is in training mode.
590    pub fn set_training(&self, training: bool) {
591        self.training
592            .store(training, std::sync::atomic::Ordering::Relaxed);
593    }
594
595    /// Returns whether this transform is in training mode.
596    pub fn is_training(&self) -> bool {
597        self.training.load(std::sync::atomic::Ordering::Relaxed)
598    }
599}
600
601impl Transform for DropoutTransform {
602    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
603        // No dropout in eval mode or with p=0
604        if !self.is_training() || self.probability == 0.0 {
605            return input.clone();
606        }
607
608        let mut rng = rand::thread_rng();
609        let scale = 1.0 / (1.0 - self.probability);
610        let data = input.to_vec();
611
612        let dropped: Vec<f32> = data
613            .iter()
614            .map(|&x| {
615                if rng.r#gen::<f32>() < self.probability {
616                    0.0
617                } else {
618                    x * scale
619                }
620            })
621            .collect();
622
623        Tensor::from_vec(dropped, input.shape()).unwrap()
624    }
625}
626
// =============================================================================
// Lambda Transform
// =============================================================================

/// Applies a custom function as a transform.
pub struct Lambda<F>
where
    F: Fn(&Tensor<f32>) -> Tensor<f32> + Send + Sync,
{
    func: F,
}

impl<F> Lambda<F>
where
    F: Fn(&Tensor<f32>) -> Tensor<f32> + Send + Sync,
{
    /// Creates a new Lambda transform wrapping `func`.
    pub fn new(func: F) -> Self {
        Self { func }
    }
}

impl<F> Transform for Lambda<F>
where
    F: Fn(&Tensor<f32>) -> Tensor<f32> + Send + Sync,
{
    // Delegates directly to the wrapped closure.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        (self.func)(input)
    }
}
657
// =============================================================================
// Tests
// =============================================================================

/// Unit tests for the transform implementations above.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_normalize() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let normalize = Normalize::new(2.5, 0.5);

        let output = normalize.apply(&input);
        let expected = [-3.0, -1.0, 1.0, 3.0];

        let result = output.to_vec();
        for (a, b) in result.iter().zip(expected.iter()) {
            assert!((a - b).abs() < 1e-6);
        }
    }

    #[test]
    fn test_normalize_per_channel() {
        // 2 channels, 2x2 spatial => [2, 2, 2]
        let input =
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0], &[2, 2, 2]).unwrap();
        let normalize = Normalize::per_channel(vec![0.0, 10.0], vec![1.0, 10.0]);

        let output = normalize.apply(&input);
        let result = output.to_vec();
        // Channel 0: (x - 0) / 1 = x
        assert!((result[0] - 1.0).abs() < 1e-6);
        assert!((result[3] - 4.0).abs() < 1e-6);
        // Channel 1: (x - 10) / 10
        assert!((result[4] - 0.0).abs() < 1e-6); // (10-10)/10
        assert!((result[5] - 1.0).abs() < 1e-6); // (20-10)/10
    }

    #[test]
    fn test_scale() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let scale = Scale::new(2.0);

        let output = scale.apply(&input);
        assert_eq!(output.to_vec(), vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_clamp() {
        let input = Tensor::from_vec(vec![-1.0, 0.5, 2.0], &[3]).unwrap();
        let clamp = Clamp::zero_one();

        let output = clamp.apply(&input);
        assert_eq!(output.to_vec(), vec![0.0, 0.5, 1.0]);
    }

    #[test]
    fn test_flatten() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flatten = Flatten::new();

        let output = flatten.apply(&input);
        assert_eq!(output.shape(), &[4]);
        assert_eq!(output.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_reshape() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6]).unwrap();
        let reshape = Reshape::new(vec![2, 3]);

        let output = reshape.apply(&input);
        assert_eq!(output.shape(), &[2, 3]);
    }

    #[test]
    fn test_compose() {
        let normalize = Normalize::new(0.0, 1.0);
        let scale = Scale::new(2.0);

        let compose = Compose::new(vec![Box::new(normalize), Box::new(scale)]);

        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let output = compose.apply(&input);

        // normalize(x) = x, then scale by 2
        assert_eq!(output.to_vec(), vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_compose_builder() {
        let compose = Compose::empty()
            .add(Normalize::new(0.0, 1.0))
            .add(Scale::new(2.0));

        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let output = compose.apply(&input);

        assert_eq!(output.to_vec(), vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_random_noise() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let noise = RandomNoise::new(0.0);

        // With std=0, output should equal input
        let output = noise.apply(&input);
        assert_eq!(output.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_random_flip_1d() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let flip = RandomFlip::new(0, 1.0); // Always flip

        let output = flip.apply(&input);
        assert_eq!(output.to_vec(), vec![4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_random_flip_2d_horizontal() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flip = RandomFlip::new(1, 1.0); // Always flip horizontal

        let output = flip.apply(&input);
        // [[1, 2], [3, 4]] -> [[2, 1], [4, 3]]
        assert_eq!(output.to_vec(), vec![2.0, 1.0, 4.0, 3.0]);
    }

    #[test]
    fn test_random_flip_2d_vertical() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flip = RandomFlip::new(0, 1.0); // Always flip vertical

        let output = flip.apply(&input);
        // [[1, 2], [3, 4]] -> [[3, 4], [1, 2]]
        assert_eq!(output.to_vec(), vec![3.0, 4.0, 1.0, 2.0]);
    }

    #[test]
    fn test_random_flip_3d() {
        // C=1, H=2, W=2 — flip along dim=2 (horizontal in spatial)
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 2, 2]).unwrap();
        let flip = RandomFlip::new(2, 1.0); // Always flip along W

        let output = flip.apply(&input);
        // [[1,2],[3,4]] → [[2,1],[4,3]]
        assert_eq!(output.to_vec(), vec![2.0, 1.0, 4.0, 3.0]);
        assert_eq!(output.shape(), &[1, 2, 2]);
    }

    #[test]
    fn test_random_flip_4d() {
        // N=1, C=1, H=2, W=2 — flip along dim=2 (vertical flip of H)
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2]).unwrap();
        let flip = RandomFlip::new(2, 1.0); // Flip along H

        let output = flip.apply(&input);
        // Rows flipped: [[1,2],[3,4]] → [[3,4],[1,2]]
        assert_eq!(output.to_vec(), vec![3.0, 4.0, 1.0, 2.0]);
        assert_eq!(output.shape(), &[1, 1, 2, 2]);
    }

    #[test]
    fn test_dropout_eval_mode() {
        let input = Tensor::from_vec(vec![1.0; 100], &[100]).unwrap();
        let dropout = DropoutTransform::new(0.5);

        // In training mode, should drop elements
        let output_train = dropout.apply(&input);
        let zeros_train = output_train.to_vec().iter().filter(|&&x| x == 0.0).count();
        assert!(zeros_train > 0, "Training mode should drop elements");

        // Switch to eval mode — should be identity
        dropout.set_training(false);
        let output_eval = dropout.apply(&input);
        assert_eq!(output_eval.to_vec(), vec![1.0; 100]);
    }

    #[test]
    fn test_dropout_transform() {
        let input = Tensor::from_vec(vec![1.0; 1000], &[1000]).unwrap();
        let dropout = DropoutTransform::new(0.5);

        let output = dropout.apply(&input);
        let output_vec = output.to_vec();

        // About half should be zero
        let zeros = output_vec.iter().filter(|&&x| x == 0.0).count();
        assert!(
            zeros > 300 && zeros < 700,
            "Expected ~500 zeros, got {zeros}"
        );

        // Non-zeros should be scaled by 2 (1/(1-0.5))
        let nonzeros: Vec<f32> = output_vec.iter().filter(|&&x| x != 0.0).copied().collect();
        for x in nonzeros {
            assert!((x - 2.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_lambda() {
        let lambda = Lambda::new(|t: &Tensor<f32>| t.mul_scalar(3.0));

        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let output = lambda.apply(&input);

        assert_eq!(output.to_vec(), vec![3.0, 6.0, 9.0]);
    }

    #[test]
    fn test_to_tensor() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let to_tensor = ToTensor::new();

        let output = to_tensor.apply(&input);
        assert_eq!(output.to_vec(), input.to_vec());
    }

    #[test]
    fn test_normalize_variants() {
        let standard = Normalize::standard();
        assert_eq!(standard.mean, vec![0.0]);
        assert_eq!(standard.std, vec![1.0]);

        let zero_centered = Normalize::zero_centered();
        assert_eq!(zero_centered.mean, vec![0.5]);
        assert_eq!(zero_centered.std, vec![0.5]);
    }

    #[test]
    fn test_random_crop_1d() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap();
        let crop = RandomCrop::new(vec![3]);

        let output = crop.apply(&input);
        assert_eq!(output.shape(), &[3]);
    }

    #[test]
    fn test_random_crop_2d() {
        // 4x4 image
        let input = Tensor::from_vec((1..=16).map(|x| x as f32).collect(), &[4, 4]).unwrap();
        let crop = RandomCrop::new_2d(2, 2);

        let output = crop.apply(&input);
        assert_eq!(output.shape(), &[2, 2]);
        // Verify values are contiguous from the original
        let vals = output.to_vec();
        assert_eq!(vals.len(), 4);
    }

    #[test]
    fn test_random_crop_3d() {
        // 2 channels x 4x4 image
        let input = Tensor::from_vec((1..=32).map(|x| x as f32).collect(), &[2, 4, 4]).unwrap();
        let crop = RandomCrop::new_2d(2, 2);

        let output = crop.apply(&input);
        assert_eq!(output.shape(), &[2, 2, 2]); // 2 channels, 2x2 spatial
    }
}