1//! Transforms - Data Augmentation and Preprocessing
2//!
3//! # File
4//! `crates/axonml-data/src/transforms.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr - AutomataNexus
8//!
9//! # Updated
10//! March 8, 2026
11//!
12//! # Disclaimer
13//! Use at own risk. This software is provided "as is", without warranty of any
14//! kind, express or implied. The author and AutomataNexus shall not be held
15//! liable for any damages arising from the use of this software.
16
17use axonml_tensor::Tensor;
18use rand::Rng;
19
20// =============================================================================
21// Transform Trait
22// =============================================================================
23
/// Trait for data transformations.
///
/// Implementors consume a tensor by reference and return a new tensor;
/// the input is never mutated. The `Send + Sync` bounds allow a transform
/// to be boxed and shared across data-loading threads (e.g. inside
/// [`Compose`]).
pub trait Transform: Send + Sync {
    /// Applies the transform to a tensor.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32>;
}
29
30// =============================================================================
31// Compose
32// =============================================================================
33
34/// Composes multiple transforms into a single transform.
35pub struct Compose {
36    transforms: Vec<Box<dyn Transform>>,
37}
38
39impl Compose {
40    /// Creates a new Compose from a vector of transforms.
41    #[must_use]
42    pub fn new(transforms: Vec<Box<dyn Transform>>) -> Self {
43        Self { transforms }
44    }
45
46    /// Creates an empty Compose.
47    #[must_use]
48    pub fn empty() -> Self {
49        Self {
50            transforms: Vec::new(),
51        }
52    }
53
54    /// Adds a transform to the composition.
55    pub fn add<T: Transform + 'static>(mut self, transform: T) -> Self {
56        self.transforms.push(Box::new(transform));
57        self
58    }
59}
60
61impl Transform for Compose {
62    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
63        let mut result = input.clone();
64        for transform in &self.transforms {
65            result = transform.apply(&result);
66        }
67        result
68    }
69}
70
71// =============================================================================
72// ToTensor
73// =============================================================================
74
75/// Converts input to a tensor (identity for already-tensor inputs).
76pub struct ToTensor;
77
78impl ToTensor {
79    /// Creates a new `ToTensor` transform.
80    #[must_use]
81    pub fn new() -> Self {
82        Self
83    }
84}
85
86impl Default for ToTensor {
87    fn default() -> Self {
88        Self::new()
89    }
90}
91
92impl Transform for ToTensor {
93    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
94        input.clone()
95    }
96}
97
98// =============================================================================
99// Normalize
100// =============================================================================
101
102/// Normalizes a tensor with mean and standard deviation.
103pub struct Normalize {
104    mean: f32,
105    std: f32,
106}
107
108impl Normalize {
109    /// Creates a new Normalize transform.
110    #[must_use]
111    pub fn new(mean: f32, std: f32) -> Self {
112        Self { mean, std }
113    }
114
115    /// Creates a Normalize for standard normal distribution (mean=0, std=1).
116    #[must_use]
117    pub fn standard() -> Self {
118        Self::new(0.0, 1.0)
119    }
120
121    /// Creates a Normalize for [0,1] to [-1,1] conversion (mean=0.5, std=0.5).
122    #[must_use]
123    pub fn zero_centered() -> Self {
124        Self::new(0.5, 0.5)
125    }
126}
127
128impl Transform for Normalize {
129    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
130        let data = input.to_vec();
131        let normalized: Vec<f32> = data.iter().map(|&x| (x - self.mean) / self.std).collect();
132        Tensor::from_vec(normalized, input.shape()).unwrap()
133    }
134}
135
136// =============================================================================
137// RandomNoise
138// =============================================================================
139
140/// Adds random Gaussian noise to the input.
141pub struct RandomNoise {
142    std: f32,
143}
144
145impl RandomNoise {
146    /// Creates a new `RandomNoise` transform.
147    #[must_use]
148    pub fn new(std: f32) -> Self {
149        Self { std }
150    }
151}
152
153impl Transform for RandomNoise {
154    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
155        if self.std == 0.0 {
156            return input.clone();
157        }
158
159        let mut rng = rand::thread_rng();
160        let data = input.to_vec();
161        let noisy: Vec<f32> = data
162            .iter()
163            .map(|&x| {
164                // Box-Muller transform for Gaussian noise
165                let u1: f32 = rng.r#gen();
166                let u2: f32 = rng.r#gen();
167                let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos();
168                x + z * self.std
169            })
170            .collect();
171        Tensor::from_vec(noisy, input.shape()).unwrap()
172    }
173}
174
175// =============================================================================
176// RandomCrop
177// =============================================================================
178
/// Randomly crops a portion of the input.
///
/// `size` lists the target extents of the trailing (spatial) dimensions;
/// leading dimensions (batch, channel) are left intact by `apply`.
pub struct RandomCrop {
    size: Vec<usize>,
}

impl RandomCrop {
    /// Creates a new `RandomCrop` with target size.
    #[must_use]
    pub fn new(size: Vec<usize>) -> Self {
        Self { size }
    }

    /// Creates a `RandomCrop` for 2D images.
    #[must_use]
    pub fn new_2d(height: usize, width: usize) -> Self {
        Self {
            size: vec![height, width],
        }
    }
}
197
impl Transform for RandomCrop {
    /// Crops a random window of `self.size` from the trailing (spatial)
    /// dimensions of `input`, keeping any leading batch/channel dimensions.
    ///
    /// Behavior notes visible in the code below:
    /// - If the input has fewer dimensions than `self.size`, the input is
    ///   returned unchanged.
    /// - Crop sizes are clamped to the input's extents, so a crop larger
    ///   than the input degrades to "take everything" with offset 0.
    /// - Only the rank combinations 1D/1, 2D/2, 3D/2 (CxHxW) and 4D/2
    ///   (NxCxHxW) are materialized; anything else falls through to a clone.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        let shape = input.shape();

        // Determine spatial dimensions (last N dimensions where N = size.len())
        if shape.len() < self.size.len() {
            return input.clone();
        }

        let spatial_start = shape.len() - self.size.len();
        let mut rng = rand::thread_rng();

        // Calculate random offsets for each spatial dimension.
        // When the input is not larger than the target, the only valid
        // offset is 0 (the crop size is clamped below to match).
        let mut offsets = Vec::with_capacity(self.size.len());
        for (i, &target_dim) in self.size.iter().enumerate() {
            let input_dim = shape[spatial_start + i];
            if input_dim <= target_dim {
                offsets.push(0);
            } else {
                offsets.push(rng.gen_range(0..=input_dim - target_dim));
            }
        }

        // Calculate actual crop sizes (clamped to input dimensions)
        let crop_sizes: Vec<usize> = self
            .size
            .iter()
            .enumerate()
            .map(|(i, &s)| s.min(shape[spatial_start + i]))
            .collect();

        // Row-major flat copy of the input; all cases below index into it.
        let data = input.to_vec();

        // Handle 1D case: a simple contiguous slice.
        if shape.len() == 1 && self.size.len() == 1 {
            let start = offsets[0];
            let end = start + crop_sizes[0];
            let cropped = data[start..end].to_vec();
            let len = cropped.len();
            return Tensor::from_vec(cropped, &[len]).unwrap();
        }

        // Handle 2D case (H x W): gather row-by-row from the window.
        if shape.len() == 2 && self.size.len() == 2 {
            let (_h, w) = (shape[0], shape[1]);
            let (crop_h, crop_w) = (crop_sizes[0], crop_sizes[1]);
            let (off_h, off_w) = (offsets[0], offsets[1]);

            let mut cropped = Vec::with_capacity(crop_h * crop_w);
            for row in off_h..off_h + crop_h {
                for col in off_w..off_w + crop_w {
                    cropped.push(data[row * w + col]);
                }
            }
            return Tensor::from_vec(cropped, &[crop_h, crop_w]).unwrap();
        }

        // Handle 3D case (C x H x W) - common for images.
        // The same spatial window is used for every channel.
        if shape.len() == 3 && self.size.len() == 2 {
            let (c, h, w) = (shape[0], shape[1], shape[2]);
            let (crop_h, crop_w) = (crop_sizes[0], crop_sizes[1]);
            let (off_h, off_w) = (offsets[0], offsets[1]);

            let mut cropped = Vec::with_capacity(c * crop_h * crop_w);
            for channel in 0..c {
                for row in off_h..off_h + crop_h {
                    for col in off_w..off_w + crop_w {
                        cropped.push(data[channel * h * w + row * w + col]);
                    }
                }
            }
            return Tensor::from_vec(cropped, &[c, crop_h, crop_w]).unwrap();
        }

        // Handle 4D case (N x C x H x W) - batched images.
        // One window is drawn and shared by the entire batch.
        if shape.len() == 4 && self.size.len() == 2 {
            let (n, c, h, w) = (shape[0], shape[1], shape[2], shape[3]);
            let (crop_h, crop_w) = (crop_sizes[0], crop_sizes[1]);
            let (off_h, off_w) = (offsets[0], offsets[1]);

            let mut cropped = Vec::with_capacity(n * c * crop_h * crop_w);
            for batch in 0..n {
                for channel in 0..c {
                    for row in off_h..off_h + crop_h {
                        for col in off_w..off_w + crop_w {
                            // Row-major index into the flattened NCHW layout.
                            let idx = batch * c * h * w + channel * h * w + row * w + col;
                            cropped.push(data[idx]);
                        }
                    }
                }
            }
            return Tensor::from_vec(cropped, &[n, c, crop_h, crop_w]).unwrap();
        }

        // Fallback for unsupported dimensions - shouldn't reach here in practice
        input.clone()
    }
}
296
297// =============================================================================
298// RandomFlip
299// =============================================================================
300
/// Randomly flips the input along a specified dimension.
///
/// The flip fires with the configured probability on each `apply` call.
pub struct RandomFlip {
    dim: usize,
    probability: f32,
}

impl RandomFlip {
    /// Creates a new `RandomFlip`.
    ///
    /// `probability` is clamped into [0, 1].
    #[must_use]
    pub fn new(dim: usize, probability: f32) -> Self {
        let probability = probability.clamp(0.0, 1.0);
        Self { dim, probability }
    }

    /// Creates a horizontal flip (dim=1 for `HxW` images).
    #[must_use]
    pub fn horizontal() -> Self {
        Self::new(1, 0.5)
    }

    /// Creates a vertical flip (dim=0 for `HxW` images).
    #[must_use]
    pub fn vertical() -> Self {
        Self::new(0, 0.5)
    }
}
329
330impl Transform for RandomFlip {
331    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
332        let mut rng = rand::thread_rng();
333        if rng.r#gen::<f32>() > self.probability {
334            return input.clone();
335        }
336
337        let shape = input.shape();
338        if self.dim >= shape.len() {
339            return input.clone();
340        }
341
342        // Simple 1D flip implementation
343        if shape.len() == 1 {
344            let mut data = input.to_vec();
345            data.reverse();
346            return Tensor::from_vec(data, shape).unwrap();
347        }
348
349        // For 2D, flip along the specified dimension
350        if shape.len() == 2 {
351            let data = input.to_vec();
352            let (rows, cols) = (shape[0], shape[1]);
353            let mut flipped = vec![0.0; data.len()];
354
355            if self.dim == 0 {
356                // Vertical flip
357                for r in 0..rows {
358                    for c in 0..cols {
359                        flipped[r * cols + c] = data[(rows - 1 - r) * cols + c];
360                    }
361                }
362            } else {
363                // Horizontal flip
364                for r in 0..rows {
365                    for c in 0..cols {
366                        flipped[r * cols + c] = data[r * cols + (cols - 1 - c)];
367                    }
368                }
369            }
370
371            return Tensor::from_vec(flipped, shape).unwrap();
372        }
373
374        input.clone()
375    }
376}
377
378// =============================================================================
379// Scale
380// =============================================================================
381
382/// Scales tensor values by a constant factor.
383pub struct Scale {
384    factor: f32,
385}
386
387impl Scale {
388    /// Creates a new Scale transform.
389    #[must_use]
390    pub fn new(factor: f32) -> Self {
391        Self { factor }
392    }
393}
394
395impl Transform for Scale {
396    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
397        input.mul_scalar(self.factor)
398    }
399}
400
401// =============================================================================
402// Clamp
403// =============================================================================
404
405/// Clamps tensor values to a specified range.
406pub struct Clamp {
407    min: f32,
408    max: f32,
409}
410
411impl Clamp {
412    /// Creates a new Clamp transform.
413    #[must_use]
414    pub fn new(min: f32, max: f32) -> Self {
415        Self { min, max }
416    }
417
418    /// Creates a Clamp for [0, 1] range.
419    #[must_use]
420    pub fn zero_one() -> Self {
421        Self::new(0.0, 1.0)
422    }
423
424    /// Creates a Clamp for [-1, 1] range.
425    #[must_use]
426    pub fn symmetric() -> Self {
427        Self::new(-1.0, 1.0)
428    }
429}
430
431impl Transform for Clamp {
432    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
433        let data = input.to_vec();
434        let clamped: Vec<f32> = data.iter().map(|&x| x.clamp(self.min, self.max)).collect();
435        Tensor::from_vec(clamped, input.shape()).unwrap()
436    }
437}
438
439// =============================================================================
440// Flatten
441// =============================================================================
442
443/// Flattens the tensor to 1D.
444pub struct Flatten;
445
446impl Flatten {
447    /// Creates a new Flatten transform.
448    #[must_use]
449    pub fn new() -> Self {
450        Self
451    }
452}
453
454impl Default for Flatten {
455    fn default() -> Self {
456        Self::new()
457    }
458}
459
460impl Transform for Flatten {
461    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
462        let data = input.to_vec();
463        Tensor::from_vec(data.clone(), &[data.len()]).unwrap()
464    }
465}
466
467// =============================================================================
468// Reshape
469// =============================================================================
470
471/// Reshapes the tensor to a specified shape.
472pub struct Reshape {
473    shape: Vec<usize>,
474}
475
476impl Reshape {
477    /// Creates a new Reshape transform.
478    #[must_use]
479    pub fn new(shape: Vec<usize>) -> Self {
480        Self { shape }
481    }
482}
483
484impl Transform for Reshape {
485    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
486        let data = input.to_vec();
487        let expected_size: usize = self.shape.iter().product();
488
489        if data.len() != expected_size {
490            // Size mismatch, return original
491            return input.clone();
492        }
493
494        Tensor::from_vec(data, &self.shape).unwrap()
495    }
496}
497
498// =============================================================================
499// Dropout Transform
500// =============================================================================
501
502/// Applies dropout by randomly zeroing elements during training.
503pub struct DropoutTransform {
504    probability: f32,
505}
506
507impl DropoutTransform {
508    /// Creates a new `DropoutTransform`.
509    #[must_use]
510    pub fn new(probability: f32) -> Self {
511        Self {
512            probability: probability.clamp(0.0, 1.0),
513        }
514    }
515}
516
517impl Transform for DropoutTransform {
518    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
519        if self.probability == 0.0 {
520            return input.clone();
521        }
522
523        let mut rng = rand::thread_rng();
524        let scale = 1.0 / (1.0 - self.probability);
525        let data = input.to_vec();
526
527        let dropped: Vec<f32> = data
528            .iter()
529            .map(|&x| {
530                if rng.r#gen::<f32>() < self.probability {
531                    0.0
532                } else {
533                    x * scale
534                }
535            })
536            .collect();
537
538        Tensor::from_vec(dropped, input.shape()).unwrap()
539    }
540}
541
542// =============================================================================
543// Lambda Transform
544// =============================================================================
545
/// Applies a custom function as a transform.
///
/// Wraps any `Fn(&Tensor<f32>) -> Tensor<f32>` closure so it can be used
/// wherever a [`Transform`] is expected (e.g. inside [`Compose`]).
pub struct Lambda<F>
where
    F: Fn(&Tensor<f32>) -> Tensor<f32> + Send + Sync,
{
    // The wrapped closure; invoked once per `apply` call.
    func: F,
}

impl<F> Lambda<F>
where
    F: Fn(&Tensor<f32>) -> Tensor<f32> + Send + Sync,
{
    /// Creates a new Lambda transform.
    pub fn new(func: F) -> Self {
        Self { func }
    }
}

impl<F> Transform for Lambda<F>
where
    F: Fn(&Tensor<f32>) -> Tensor<f32> + Send + Sync,
{
    // Delegates straight to the stored closure.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        (self.func)(input)
    }
}
572
573// =============================================================================
574// Tests
575// =============================================================================
576
/// Unit tests for the transforms above. Randomized transforms are tested
/// only at their deterministic extremes (probability 0 or 1, std 0) or via
/// statistical bounds, so no RNG seeding is required.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_normalize() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let normalize = Normalize::new(2.5, 0.5);

        let output = normalize.apply(&input);
        // (x - 2.5) / 0.5 for x in 1..=4.
        let expected = [-3.0, -1.0, 1.0, 3.0];

        let result = output.to_vec();
        for (a, b) in result.iter().zip(expected.iter()) {
            assert!((a - b).abs() < 1e-6);
        }
    }

    #[test]
    fn test_scale() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let scale = Scale::new(2.0);

        let output = scale.apply(&input);
        assert_eq!(output.to_vec(), vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_clamp() {
        let input = Tensor::from_vec(vec![-1.0, 0.5, 2.0], &[3]).unwrap();
        let clamp = Clamp::zero_one();

        let output = clamp.apply(&input);
        assert_eq!(output.to_vec(), vec![0.0, 0.5, 1.0]);
    }

    #[test]
    fn test_flatten() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flatten = Flatten::new();

        let output = flatten.apply(&input);
        assert_eq!(output.shape(), &[4]);
        assert_eq!(output.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_reshape() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6]).unwrap();
        let reshape = Reshape::new(vec![2, 3]);

        let output = reshape.apply(&input);
        assert_eq!(output.shape(), &[2, 3]);
    }

    #[test]
    fn test_compose() {
        let normalize = Normalize::new(0.0, 1.0);
        let scale = Scale::new(2.0);

        let compose = Compose::new(vec![Box::new(normalize), Box::new(scale)]);

        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let output = compose.apply(&input);

        // normalize(x) = x, then scale by 2
        assert_eq!(output.to_vec(), vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_compose_builder() {
        // Same pipeline as test_compose, built via the fluent API.
        let compose = Compose::empty()
            .add(Normalize::new(0.0, 1.0))
            .add(Scale::new(2.0));

        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let output = compose.apply(&input);

        assert_eq!(output.to_vec(), vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_random_noise() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let noise = RandomNoise::new(0.0);

        // With std=0, output should equal input
        let output = noise.apply(&input);
        assert_eq!(output.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_random_flip_1d() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let flip = RandomFlip::new(0, 1.0); // Always flip

        let output = flip.apply(&input);
        assert_eq!(output.to_vec(), vec![4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_random_flip_2d_horizontal() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flip = RandomFlip::new(1, 1.0); // Always flip horizontal

        let output = flip.apply(&input);
        // [[1, 2], [3, 4]] -> [[2, 1], [4, 3]]
        assert_eq!(output.to_vec(), vec![2.0, 1.0, 4.0, 3.0]);
    }

    #[test]
    fn test_random_flip_2d_vertical() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flip = RandomFlip::new(0, 1.0); // Always flip vertical

        let output = flip.apply(&input);
        // [[1, 2], [3, 4]] -> [[3, 4], [1, 2]]
        assert_eq!(output.to_vec(), vec![3.0, 4.0, 1.0, 2.0]);
    }

    #[test]
    fn test_dropout_transform() {
        let input = Tensor::from_vec(vec![1.0; 1000], &[1000]).unwrap();
        let dropout = DropoutTransform::new(0.5);

        let output = dropout.apply(&input);
        let output_vec = output.to_vec();

        // About half should be zero
        // (loose statistical bounds; ~500 expected with n=1000, p=0.5)
        let zeros = output_vec.iter().filter(|&&x| x == 0.0).count();
        assert!(
            zeros > 300 && zeros < 700,
            "Expected ~500 zeros, got {zeros}"
        );

        // Non-zeros should be scaled by 2 (1/(1-0.5))
        let nonzeros: Vec<f32> = output_vec.iter().filter(|&&x| x != 0.0).copied().collect();
        for x in nonzeros {
            assert!((x - 2.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_lambda() {
        let lambda = Lambda::new(|t: &Tensor<f32>| t.mul_scalar(3.0));

        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let output = lambda.apply(&input);

        assert_eq!(output.to_vec(), vec![3.0, 6.0, 9.0]);
    }

    #[test]
    fn test_to_tensor() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let to_tensor = ToTensor::new();

        let output = to_tensor.apply(&input);
        assert_eq!(output.to_vec(), input.to_vec());
    }

    #[test]
    fn test_normalize_variants() {
        // Private-field access is fine: this module is a child of the
        // module that declares Normalize.
        let standard = Normalize::standard();
        assert_eq!(standard.mean, 0.0);
        assert_eq!(standard.std, 1.0);

        let zero_centered = Normalize::zero_centered();
        assert_eq!(zero_centered.mean, 0.5);
        assert_eq!(zero_centered.std, 0.5);
    }

    #[test]
    fn test_random_crop_1d() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap();
        let crop = RandomCrop::new(vec![3]);

        let output = crop.apply(&input);
        assert_eq!(output.shape(), &[3]);
    }

    #[test]
    fn test_random_crop_2d() {
        // 4x4 image
        let input = Tensor::from_vec((1..=16).map(|x| x as f32).collect(), &[4, 4]).unwrap();
        let crop = RandomCrop::new_2d(2, 2);

        let output = crop.apply(&input);
        assert_eq!(output.shape(), &[2, 2]);
        // Verify values are contiguous from the original
        let vals = output.to_vec();
        assert_eq!(vals.len(), 4);
    }

    #[test]
    fn test_random_crop_3d() {
        // 2 channels x 4x4 image
        let input = Tensor::from_vec((1..=32).map(|x| x as f32).collect(), &[2, 4, 4]).unwrap();
        let crop = RandomCrop::new_2d(2, 2);

        let output = crop.apply(&input);
        assert_eq!(output.shape(), &[2, 2, 2]); // 2 channels, 2x2 spatial
    }
}