//! Transforms - Data Augmentation and Preprocessing
//!
//! # File
//! `crates/axonml-data/src/transforms.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.
use axonml_tensor::Tensor;
use rand::Rng;
19
// =============================================================================
// Transform Trait
// =============================================================================

/// Trait for data transformations.
///
/// Implementors take a tensor by reference and produce a new tensor; the
/// input is never mutated. `Send + Sync` allows transforms to be shared
/// across data-loading worker threads (e.g. inside a boxed pipeline).
pub trait Transform: Send + Sync {
    /// Applies the transform to a tensor, returning a new tensor.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32>;
}
29
30// =============================================================================
31// Compose
32// =============================================================================
33
34/// Composes multiple transforms into a single transform.
35pub struct Compose {
36    transforms: Vec<Box<dyn Transform>>,
37}
38
39impl Compose {
40    /// Creates a new Compose from a vector of transforms.
41    #[must_use]
42    pub fn new(transforms: Vec<Box<dyn Transform>>) -> Self {
43        Self { transforms }
44    }
45
46    /// Creates an empty Compose.
47    #[must_use]
48    pub fn empty() -> Self {
49        Self {
50            transforms: Vec::new(),
51        }
52    }
53
54    /// Adds a transform to the composition.
55    pub fn add<T: Transform + 'static>(mut self, transform: T) -> Self {
56        self.transforms.push(Box::new(transform));
57        self
58    }
59}
60
61impl Transform for Compose {
62    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
63        let mut result = input.clone();
64        for transform in &self.transforms {
65            result = transform.apply(&result);
66        }
67        result
68    }
69}
70
71// =============================================================================
72// ToTensor
73// =============================================================================
74
75/// Converts input to a tensor (identity for already-tensor inputs).
76pub struct ToTensor;
77
78impl ToTensor {
79    /// Creates a new `ToTensor` transform.
80    #[must_use]
81    pub fn new() -> Self {
82        Self
83    }
84}
85
86impl Default for ToTensor {
87    fn default() -> Self {
88        Self::new()
89    }
90}
91
92impl Transform for ToTensor {
93    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
94        input.clone()
95    }
96}
97
// =============================================================================
// Normalize
// =============================================================================

/// Normalizes a tensor with mean and standard deviation.
///
/// Supports both scalar normalization (applied uniformly) and per-channel
/// normalization (PyTorch-style `transforms.Normalize(mean=[...], std=[...])`).
/// For per-channel mode, the tensor is expected to have shape `[C, H, W]` or
/// `[N, C, H, W]`.
pub struct Normalize {
    // Invariant: mean and std are non-empty and have equal length.
    mean: Vec<f32>,
    std: Vec<f32>,
}

impl Normalize {
    /// Creates a new scalar Normalize transform (applied uniformly to all elements).
    #[must_use]
    pub fn new(mean: f32, std: f32) -> Self {
        Self {
            mean: vec![mean],
            std: vec![std],
        }
    }

    /// Creates a per-channel Normalize transform (PyTorch-compatible).
    ///
    /// For a `[C, H, W]` tensor, `mean` and `std` must have length `C`.
    /// Each channel is normalized independently: `output[c] = (input[c] - mean[c]) / std[c]`.
    ///
    /// # Panics
    /// Panics if `mean` and `std` differ in length, or if they are empty.
    /// (Previously an empty vector deferred to an opaque index panic inside
    /// `apply`; failing here gives a clear message at construction time.)
    #[must_use]
    pub fn per_channel(mean: Vec<f32>, std: Vec<f32>) -> Self {
        assert!(!mean.is_empty(), "mean and std must be non-empty");
        assert_eq!(mean.len(), std.len(), "mean and std must have same length");
        Self { mean, std }
    }

    /// Creates a Normalize for standard normal distribution (mean=0, std=1).
    #[must_use]
    pub fn standard() -> Self {
        Self::new(0.0, 1.0)
    }

    /// Creates a Normalize for [0,1] to [-1,1] conversion (mean=0.5, std=0.5).
    #[must_use]
    pub fn zero_centered() -> Self {
        Self::new(0.5, 0.5)
    }

    /// Creates a Normalize for ImageNet (3-channel RGB).
    #[must_use]
    pub fn imagenet() -> Self {
        Self::per_channel(
            vec![0.485, 0.456, 0.406],
            vec![0.229, 0.224, 0.225],
        )
    }
}
154
155impl Transform for Normalize {
156    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
157        let shape = input.shape();
158        let mut data = input.to_vec();
159
160        if self.mean.len() == 1 {
161            // Scalar normalization — apply uniformly
162            let m = self.mean[0];
163            let s = self.std[0];
164            for x in &mut data {
165                *x = (*x - m) / s;
166            }
167        } else {
168            // Per-channel normalization
169            let num_channels = self.mean.len();
170
171            if shape.len() == 3 && shape[0] == num_channels {
172                // [C, H, W]
173                let spatial = shape[1] * shape[2];
174                for c in 0..num_channels {
175                    let offset = c * spatial;
176                    let m = self.mean[c];
177                    let s = self.std[c];
178                    for i in 0..spatial {
179                        data[offset + i] = (data[offset + i] - m) / s;
180                    }
181                }
182            } else if shape.len() == 4 && shape[1] == num_channels {
183                // [N, C, H, W]
184                let spatial = shape[2] * shape[3];
185                let sample_size = num_channels * spatial;
186                for n in 0..shape[0] {
187                    for c in 0..num_channels {
188                        let offset = n * sample_size + c * spatial;
189                        let m = self.mean[c];
190                        let s = self.std[c];
191                        for i in 0..spatial {
192                            data[offset + i] = (data[offset + i] - m) / s;
193                        }
194                    }
195                }
196            } else {
197                // Fallback: apply first channel's mean/std uniformly
198                let m = self.mean[0];
199                let s = self.std[0];
200                for x in &mut data {
201                    *x = (*x - m) / s;
202                }
203            }
204        }
205
206        Tensor::from_vec(data, shape).unwrap()
207    }
208}
209
// =============================================================================
// RandomNoise
// =============================================================================

/// Adds zero-mean Gaussian noise with the given standard deviation.
pub struct RandomNoise {
    std: f32,
}

impl RandomNoise {
    /// Creates a new `RandomNoise` transform; `std = 0.0` makes it a no-op.
    #[must_use]
    pub fn new(std: f32) -> Self {
        Self { std }
    }
}
226
227impl Transform for RandomNoise {
228    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
229        if self.std == 0.0 {
230            return input.clone();
231        }
232
233        let mut rng = rand::thread_rng();
234        let data = input.to_vec();
235        let noisy: Vec<f32> = data
236            .iter()
237            .map(|&x| {
238                // Box-Muller transform for Gaussian noise
239                let u1: f32 = rng.r#gen();
240                let u2: f32 = rng.r#gen();
241                let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos();
242                x + z * self.std
243            })
244            .collect();
245        Tensor::from_vec(noisy, input.shape()).unwrap()
246    }
247}
248
// =============================================================================
// RandomCrop
// =============================================================================

/// Randomly crops the trailing spatial dimensions of the input.
pub struct RandomCrop {
    size: Vec<usize>,
}

impl RandomCrop {
    /// Creates a `RandomCrop` with one target extent per spatial dimension.
    #[must_use]
    pub fn new(size: Vec<usize>) -> Self {
        Self { size }
    }

    /// Convenience constructor for 2D image crops (`[height, width]`).
    #[must_use]
    pub fn new_2d(height: usize, width: usize) -> Self {
        Self::new(vec![height, width])
    }
}
271
272impl Transform for RandomCrop {
273    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
274        let shape = input.shape();
275
276        // Determine spatial dimensions (last N dimensions where N = size.len())
277        if shape.len() < self.size.len() {
278            return input.clone();
279        }
280
281        let spatial_start = shape.len() - self.size.len();
282        let mut rng = rand::thread_rng();
283
284        // Calculate random offsets for each spatial dimension
285        let mut offsets = Vec::with_capacity(self.size.len());
286        for (i, &target_dim) in self.size.iter().enumerate() {
287            let input_dim = shape[spatial_start + i];
288            if input_dim <= target_dim {
289                offsets.push(0);
290            } else {
291                offsets.push(rng.gen_range(0..=input_dim - target_dim));
292            }
293        }
294
295        // Calculate actual crop sizes (clamped to input dimensions)
296        let crop_sizes: Vec<usize> = self
297            .size
298            .iter()
299            .enumerate()
300            .map(|(i, &s)| s.min(shape[spatial_start + i]))
301            .collect();
302
303        let data = input.to_vec();
304
305        // Handle 1D case
306        if shape.len() == 1 && self.size.len() == 1 {
307            let start = offsets[0];
308            let end = start + crop_sizes[0];
309            let cropped = data[start..end].to_vec();
310            let len = cropped.len();
311            return Tensor::from_vec(cropped, &[len]).unwrap();
312        }
313
314        // Handle 2D case (H x W)
315        if shape.len() == 2 && self.size.len() == 2 {
316            let (_h, w) = (shape[0], shape[1]);
317            let (crop_h, crop_w) = (crop_sizes[0], crop_sizes[1]);
318            let (off_h, off_w) = (offsets[0], offsets[1]);
319
320            let mut cropped = Vec::with_capacity(crop_h * crop_w);
321            for row in off_h..off_h + crop_h {
322                for col in off_w..off_w + crop_w {
323                    cropped.push(data[row * w + col]);
324                }
325            }
326            return Tensor::from_vec(cropped, &[crop_h, crop_w]).unwrap();
327        }
328
329        // Handle 3D case (C x H x W) - common for images
330        if shape.len() == 3 && self.size.len() == 2 {
331            let (c, h, w) = (shape[0], shape[1], shape[2]);
332            let (crop_h, crop_w) = (crop_sizes[0], crop_sizes[1]);
333            let (off_h, off_w) = (offsets[0], offsets[1]);
334
335            let mut cropped = Vec::with_capacity(c * crop_h * crop_w);
336            for channel in 0..c {
337                for row in off_h..off_h + crop_h {
338                    for col in off_w..off_w + crop_w {
339                        cropped.push(data[channel * h * w + row * w + col]);
340                    }
341                }
342            }
343            return Tensor::from_vec(cropped, &[c, crop_h, crop_w]).unwrap();
344        }
345
346        // Handle 4D case (N x C x H x W) - batched images
347        if shape.len() == 4 && self.size.len() == 2 {
348            let (n, c, h, w) = (shape[0], shape[1], shape[2], shape[3]);
349            let (crop_h, crop_w) = (crop_sizes[0], crop_sizes[1]);
350            let (off_h, off_w) = (offsets[0], offsets[1]);
351
352            let mut cropped = Vec::with_capacity(n * c * crop_h * crop_w);
353            for batch in 0..n {
354                for channel in 0..c {
355                    for row in off_h..off_h + crop_h {
356                        for col in off_w..off_w + crop_w {
357                            let idx = batch * c * h * w + channel * h * w + row * w + col;
358                            cropped.push(data[idx]);
359                        }
360                    }
361                }
362            }
363            return Tensor::from_vec(cropped, &[n, c, crop_h, crop_w]).unwrap();
364        }
365
366        // Fallback for unsupported dimensions - shouldn't reach here in practice
367        input.clone()
368    }
369}
370
// =============================================================================
// RandomFlip
// =============================================================================

/// Randomly reverses the input along one dimension with a given probability.
pub struct RandomFlip {
    dim: usize,
    probability: f32,
}

impl RandomFlip {
    /// Creates a new `RandomFlip`; `probability` is clamped into `[0, 1]`.
    #[must_use]
    pub fn new(dim: usize, probability: f32) -> Self {
        Self {
            dim,
            probability: probability.clamp(0.0, 1.0),
        }
    }

    /// Horizontal flip (dim=1 for `HxW` images), applied half the time.
    #[must_use]
    pub fn horizontal() -> Self {
        Self::new(1, 0.5)
    }

    /// Vertical flip (dim=0 for `HxW` images), applied half the time.
    #[must_use]
    pub fn vertical() -> Self {
        Self::new(0, 0.5)
    }
}
403
404impl Transform for RandomFlip {
405    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
406        let mut rng = rand::thread_rng();
407        if rng.r#gen::<f32>() > self.probability {
408            return input.clone();
409        }
410
411        let shape = input.shape();
412        if self.dim >= shape.len() {
413            return input.clone();
414        }
415
416        let data = input.to_vec();
417        let ndim = shape.len();
418
419        // Generic N-dimensional flip along self.dim:
420        // Compute strides, then for each element map the flipped index.
421        let total = data.len();
422        let mut flipped = vec![0.0f32; total];
423
424        // Compute strides (row-major)
425        let mut strides = vec![1usize; ndim];
426        for i in (0..ndim - 1).rev() {
427            strides[i] = strides[i + 1] * shape[i + 1];
428        }
429
430        let dim = self.dim;
431        let dim_size = shape[dim];
432        let dim_stride = strides[dim];
433
434        for i in 0..total {
435            // Extract the coordinate along the flip dimension
436            let coord_in_dim = (i / dim_stride) % dim_size;
437            let flipped_coord = dim_size - 1 - coord_in_dim;
438            // Compute the source index with the flipped coordinate
439            let diff = flipped_coord as isize - coord_in_dim as isize;
440            let src = (i as isize + diff * dim_stride as isize) as usize;
441            flipped[i] = data[src];
442        }
443
444        Tensor::from_vec(flipped, shape).unwrap()
445    }
446}
447
448// =============================================================================
449// Scale
450// =============================================================================
451
452/// Scales tensor values by a constant factor.
453pub struct Scale {
454    factor: f32,
455}
456
457impl Scale {
458    /// Creates a new Scale transform.
459    #[must_use]
460    pub fn new(factor: f32) -> Self {
461        Self { factor }
462    }
463}
464
465impl Transform for Scale {
466    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
467        input.mul_scalar(self.factor)
468    }
469}
470
471// =============================================================================
472// Clamp
473// =============================================================================
474
475/// Clamps tensor values to a specified range.
476pub struct Clamp {
477    min: f32,
478    max: f32,
479}
480
481impl Clamp {
482    /// Creates a new Clamp transform.
483    #[must_use]
484    pub fn new(min: f32, max: f32) -> Self {
485        Self { min, max }
486    }
487
488    /// Creates a Clamp for [0, 1] range.
489    #[must_use]
490    pub fn zero_one() -> Self {
491        Self::new(0.0, 1.0)
492    }
493
494    /// Creates a Clamp for [-1, 1] range.
495    #[must_use]
496    pub fn symmetric() -> Self {
497        Self::new(-1.0, 1.0)
498    }
499}
500
501impl Transform for Clamp {
502    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
503        let data = input.to_vec();
504        let clamped: Vec<f32> = data.iter().map(|&x| x.clamp(self.min, self.max)).collect();
505        Tensor::from_vec(clamped, input.shape()).unwrap()
506    }
507}
508
509// =============================================================================
510// Flatten
511// =============================================================================
512
513/// Flattens the tensor to 1D.
514pub struct Flatten;
515
516impl Flatten {
517    /// Creates a new Flatten transform.
518    #[must_use]
519    pub fn new() -> Self {
520        Self
521    }
522}
523
524impl Default for Flatten {
525    fn default() -> Self {
526        Self::new()
527    }
528}
529
530impl Transform for Flatten {
531    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
532        let data = input.to_vec();
533        Tensor::from_vec(data.clone(), &[data.len()]).unwrap()
534    }
535}
536
537// =============================================================================
538// Reshape
539// =============================================================================
540
541/// Reshapes the tensor to a specified shape.
542pub struct Reshape {
543    shape: Vec<usize>,
544}
545
546impl Reshape {
547    /// Creates a new Reshape transform.
548    #[must_use]
549    pub fn new(shape: Vec<usize>) -> Self {
550        Self { shape }
551    }
552}
553
554impl Transform for Reshape {
555    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
556        let data = input.to_vec();
557        let expected_size: usize = self.shape.iter().product();
558
559        if data.len() != expected_size {
560            // Size mismatch, return original
561            return input.clone();
562        }
563
564        Tensor::from_vec(data, &self.shape).unwrap()
565    }
566}
567
// =============================================================================
// Dropout Transform
// =============================================================================

/// Applies dropout by randomly zeroing elements during training.
///
/// Respects train/eval mode: dropout is only applied while `training` is
/// true. Call `set_training(false)` to make the transform an identity during
/// evaluation. The flag is an atomic so mode can be toggled through `&self`.
pub struct DropoutTransform {
    probability: f32,
    training: std::sync::atomic::AtomicBool,
}

impl DropoutTransform {
    /// Creates a new `DropoutTransform` in training mode.
    ///
    /// `probability` is clamped into `[0, 1]`.
    #[must_use]
    pub fn new(probability: f32) -> Self {
        Self {
            probability: probability.clamp(0.0, 1.0),
            training: std::sync::atomic::AtomicBool::new(true),
        }
    }

    /// Switches between training (dropout active) and eval (identity) mode.
    pub fn set_training(&self, training: bool) {
        self.training
            .store(training, std::sync::atomic::Ordering::Relaxed);
    }

    /// Returns true while dropout is active.
    pub fn is_training(&self) -> bool {
        self.training.load(std::sync::atomic::Ordering::Relaxed)
    }
}
601
602impl Transform for DropoutTransform {
603    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
604        // No dropout in eval mode or with p=0
605        if !self.is_training() || self.probability == 0.0 {
606            return input.clone();
607        }
608
609        let mut rng = rand::thread_rng();
610        let scale = 1.0 / (1.0 - self.probability);
611        let data = input.to_vec();
612
613        let dropped: Vec<f32> = data
614            .iter()
615            .map(|&x| {
616                if rng.r#gen::<f32>() < self.probability {
617                    0.0
618                } else {
619                    x * scale
620                }
621            })
622            .collect();
623
624        Tensor::from_vec(dropped, input.shape()).unwrap()
625    }
626}
627
628// =============================================================================
629// Lambda Transform
630// =============================================================================
631
632/// Applies a custom function as a transform.
633pub struct Lambda<F>
634where
635    F: Fn(&Tensor<f32>) -> Tensor<f32> + Send + Sync,
636{
637    func: F,
638}
639
640impl<F> Lambda<F>
641where
642    F: Fn(&Tensor<f32>) -> Tensor<f32> + Send + Sync,
643{
644    /// Creates a new Lambda transform.
645    pub fn new(func: F) -> Self {
646        Self { func }
647    }
648}
649
650impl<F> Transform for Lambda<F>
651where
652    F: Fn(&Tensor<f32>) -> Tensor<f32> + Send + Sync,
653{
654    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
655        (self.func)(input)
656    }
657}
658
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    // --- Deterministic transforms -------------------------------------------

    #[test]
    fn test_normalize() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let normalize = Normalize::new(2.5, 0.5);

        let output = normalize.apply(&input);
        // (x - 2.5) / 0.5 for each element.
        let expected = [-3.0, -1.0, 1.0, 3.0];

        let result = output.to_vec();
        for (a, b) in result.iter().zip(expected.iter()) {
            assert!((a - b).abs() < 1e-6);
        }
    }

    #[test]
    fn test_normalize_per_channel() {
        // 2 channels, 2x2 spatial => [2, 2, 2]
        let input = Tensor::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0],
            &[2, 2, 2],
        )
        .unwrap();
        let normalize = Normalize::per_channel(vec![0.0, 10.0], vec![1.0, 10.0]);

        let output = normalize.apply(&input);
        let result = output.to_vec();
        // Channel 0: (x - 0) / 1 = x
        assert!((result[0] - 1.0).abs() < 1e-6);
        assert!((result[3] - 4.0).abs() < 1e-6);
        // Channel 1: (x - 10) / 10
        assert!((result[4] - 0.0).abs() < 1e-6); // (10-10)/10
        assert!((result[5] - 1.0).abs() < 1e-6); // (20-10)/10
    }

    #[test]
    fn test_scale() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let scale = Scale::new(2.0);

        let output = scale.apply(&input);
        assert_eq!(output.to_vec(), vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_clamp() {
        let input = Tensor::from_vec(vec![-1.0, 0.5, 2.0], &[3]).unwrap();
        let clamp = Clamp::zero_one();

        let output = clamp.apply(&input);
        assert_eq!(output.to_vec(), vec![0.0, 0.5, 1.0]);
    }

    #[test]
    fn test_flatten() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flatten = Flatten::new();

        let output = flatten.apply(&input);
        assert_eq!(output.shape(), &[4]);
        assert_eq!(output.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_reshape() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6]).unwrap();
        let reshape = Reshape::new(vec![2, 3]);

        let output = reshape.apply(&input);
        assert_eq!(output.shape(), &[2, 3]);
    }

    // --- Composition ---------------------------------------------------------

    #[test]
    fn test_compose() {
        let normalize = Normalize::new(0.0, 1.0);
        let scale = Scale::new(2.0);

        let compose = Compose::new(vec![Box::new(normalize), Box::new(scale)]);

        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let output = compose.apply(&input);

        // normalize(x) = x, then scale by 2
        assert_eq!(output.to_vec(), vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_compose_builder() {
        let compose = Compose::empty()
            .add(Normalize::new(0.0, 1.0))
            .add(Scale::new(2.0));

        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let output = compose.apply(&input);

        assert_eq!(output.to_vec(), vec![2.0, 4.0, 6.0]);
    }

    // --- Random transforms (exercised with degenerate/forced parameters) ----

    #[test]
    fn test_random_noise() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let noise = RandomNoise::new(0.0);

        // With std=0, output should equal input
        let output = noise.apply(&input);
        assert_eq!(output.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_random_flip_1d() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let flip = RandomFlip::new(0, 1.0); // Always flip

        let output = flip.apply(&input);
        assert_eq!(output.to_vec(), vec![4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_random_flip_2d_horizontal() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flip = RandomFlip::new(1, 1.0); // Always flip horizontal

        let output = flip.apply(&input);
        // [[1, 2], [3, 4]] -> [[2, 1], [4, 3]]
        assert_eq!(output.to_vec(), vec![2.0, 1.0, 4.0, 3.0]);
    }

    #[test]
    fn test_random_flip_2d_vertical() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flip = RandomFlip::new(0, 1.0); // Always flip vertical

        let output = flip.apply(&input);
        // [[1, 2], [3, 4]] -> [[3, 4], [1, 2]]
        assert_eq!(output.to_vec(), vec![3.0, 4.0, 1.0, 2.0]);
    }

    #[test]
    fn test_random_flip_3d() {
        // C=1, H=2, W=2 — flip along dim=2 (horizontal in spatial)
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 2, 2]).unwrap();
        let flip = RandomFlip::new(2, 1.0); // Always flip along W

        let output = flip.apply(&input);
        // [[1,2],[3,4]] → [[2,1],[4,3]]
        assert_eq!(output.to_vec(), vec![2.0, 1.0, 4.0, 3.0]);
        assert_eq!(output.shape(), &[1, 2, 2]);
    }

    #[test]
    fn test_random_flip_4d() {
        // N=1, C=1, H=2, W=2 — flip along dim=2 (vertical flip of H)
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2]).unwrap();
        let flip = RandomFlip::new(2, 1.0); // Flip along H

        let output = flip.apply(&input);
        // Rows flipped: [[1,2],[3,4]] → [[3,4],[1,2]]
        assert_eq!(output.to_vec(), vec![3.0, 4.0, 1.0, 2.0]);
        assert_eq!(output.shape(), &[1, 1, 2, 2]);
    }

    #[test]
    fn test_dropout_eval_mode() {
        let input = Tensor::from_vec(vec![1.0; 100], &[100]).unwrap();
        let dropout = DropoutTransform::new(0.5);

        // In training mode, should drop elements
        // NOTE: statistically it is possible (though astronomically unlikely)
        // that no element is dropped in 100 draws at p=0.5.
        let output_train = dropout.apply(&input);
        let zeros_train = output_train.to_vec().iter().filter(|&&x| x == 0.0).count();
        assert!(zeros_train > 0, "Training mode should drop elements");

        // Switch to eval mode — should be identity
        dropout.set_training(false);
        let output_eval = dropout.apply(&input);
        assert_eq!(output_eval.to_vec(), vec![1.0; 100]);
    }

    #[test]
    fn test_dropout_transform() {
        let input = Tensor::from_vec(vec![1.0; 1000], &[1000]).unwrap();
        let dropout = DropoutTransform::new(0.5);

        let output = dropout.apply(&input);
        let output_vec = output.to_vec();

        // About half should be zero
        let zeros = output_vec.iter().filter(|&&x| x == 0.0).count();
        assert!(
            zeros > 300 && zeros < 700,
            "Expected ~500 zeros, got {zeros}"
        );

        // Non-zeros should be scaled by 2 (1/(1-0.5))
        let nonzeros: Vec<f32> = output_vec.iter().filter(|&&x| x != 0.0).copied().collect();
        for x in nonzeros {
            assert!((x - 2.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_lambda() {
        let lambda = Lambda::new(|t: &Tensor<f32>| t.mul_scalar(3.0));

        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let output = lambda.apply(&input);

        assert_eq!(output.to_vec(), vec![3.0, 6.0, 9.0]);
    }

    #[test]
    fn test_to_tensor() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let to_tensor = ToTensor::new();

        let output = to_tensor.apply(&input);
        assert_eq!(output.to_vec(), input.to_vec());
    }

    #[test]
    fn test_normalize_variants() {
        // Private fields are accessible here because child modules can see
        // their ancestors' private items.
        let standard = Normalize::standard();
        assert_eq!(standard.mean, vec![0.0]);
        assert_eq!(standard.std, vec![1.0]);

        let zero_centered = Normalize::zero_centered();
        assert_eq!(zero_centered.mean, vec![0.5]);
        assert_eq!(zero_centered.std, vec![0.5]);
    }

    #[test]
    fn test_random_crop_1d() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap();
        let crop = RandomCrop::new(vec![3]);

        let output = crop.apply(&input);
        assert_eq!(output.shape(), &[3]);
    }

    #[test]
    fn test_random_crop_2d() {
        // 4x4 image
        let input = Tensor::from_vec((1..=16).map(|x| x as f32).collect(), &[4, 4]).unwrap();
        let crop = RandomCrop::new_2d(2, 2);

        let output = crop.apply(&input);
        assert_eq!(output.shape(), &[2, 2]);
        // Verify values are contiguous from the original
        let vals = output.to_vec();
        assert_eq!(vals.len(), 4);
    }

    #[test]
    fn test_random_crop_3d() {
        // 2 channels x 4x4 image
        let input = Tensor::from_vec((1..=32).map(|x| x as f32).collect(), &[2, 4, 4]).unwrap();
        let crop = RandomCrop::new_2d(2, 2);

        let output = crop.apply(&input);
        assert_eq!(output.shape(), &[2, 2, 2]); // 2 channels, 2x2 spatial
    }
}