axonml_data/
transforms.rs

1//! Transforms - Data Augmentation and Preprocessing
2//!
3//! Provides composable transformations for data preprocessing and augmentation.
4//!
5//! @version 0.1.0
6//! @author `AutomataNexus` Development Team
7
8use axonml_tensor::Tensor;
9use rand::Rng;
10
11// =============================================================================
12// Transform Trait
13// =============================================================================
14
15/// Trait for data transformations.
16pub trait Transform: Send + Sync {
17    /// Applies the transform to a tensor.
18    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32>;
19}
20
21// =============================================================================
22// Compose
23// =============================================================================
24
25/// Composes multiple transforms into a single transform.
26pub struct Compose {
27    transforms: Vec<Box<dyn Transform>>,
28}
29
30impl Compose {
31    /// Creates a new Compose from a vector of transforms.
32    #[must_use] pub fn new(transforms: Vec<Box<dyn Transform>>) -> Self {
33        Self { transforms }
34    }
35
36    /// Creates an empty Compose.
37    #[must_use] pub fn empty() -> Self {
38        Self {
39            transforms: Vec::new(),
40        }
41    }
42
43    /// Adds a transform to the composition.
44    pub fn add<T: Transform + 'static>(mut self, transform: T) -> Self {
45        self.transforms.push(Box::new(transform));
46        self
47    }
48}
49
50impl Transform for Compose {
51    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
52        let mut result = input.clone();
53        for transform in &self.transforms {
54            result = transform.apply(&result);
55        }
56        result
57    }
58}
59
60// =============================================================================
61// ToTensor
62// =============================================================================
63
/// Converts input to a tensor.
///
/// Since [`Transform`] already operates on tensors, this is an identity step.
#[derive(Default)]
pub struct ToTensor;

impl ToTensor {
    /// Returns the stateless `ToTensor` transform.
    #[must_use]
    pub fn new() -> Self {
        ToTensor
    }
}
79
80impl Transform for ToTensor {
81    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
82        input.clone()
83    }
84}
85
86// =============================================================================
87// Normalize
88// =============================================================================
89
/// Element-wise affine normalization `(x - mean) / std` using one scalar mean
/// and one scalar standard deviation for the whole tensor.
pub struct Normalize {
    /// Value subtracted from every element.
    mean: f32,
    /// Divisor applied after centering; `0.0` produces non-finite outputs.
    std: f32,
}

impl Normalize {
    /// Builds a normalizer from an explicit `mean` and `std`.
    #[must_use]
    pub fn new(mean: f32, std: f32) -> Self {
        Normalize { mean, std }
    }

    /// Identity normalizer: `mean = 0`, `std = 1`.
    #[must_use]
    pub fn standard() -> Self {
        Self { mean: 0.0, std: 1.0 }
    }

    /// Maps values in `[0, 1]` onto `[-1, 1]`: `mean = 0.5`, `std = 0.5`.
    #[must_use]
    pub fn zero_centered() -> Self {
        Self { mean: 0.5, std: 0.5 }
    }
}
112
113impl Transform for Normalize {
114    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
115        let data = input.to_vec();
116        let normalized: Vec<f32> = data.iter().map(|&x| (x - self.mean) / self.std).collect();
117        Tensor::from_vec(normalized, input.shape()).unwrap()
118    }
119}
120
121// =============================================================================
122// RandomNoise
123// =============================================================================
124
/// Adds zero-mean Gaussian noise with a configurable standard deviation.
pub struct RandomNoise {
    /// Standard deviation of the added noise; `0.0` disables the transform.
    std: f32,
}

impl RandomNoise {
    /// Creates a noise transform with standard deviation `std`.
    #[must_use]
    pub fn new(std: f32) -> Self {
        RandomNoise { std }
    }
}
136
137impl Transform for RandomNoise {
138    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
139        if self.std == 0.0 {
140            return input.clone();
141        }
142
143        let mut rng = rand::thread_rng();
144        let data = input.to_vec();
145        let noisy: Vec<f32> = data
146            .iter()
147            .map(|&x| {
148                // Box-Muller transform for Gaussian noise
149                let u1: f32 = rng.gen();
150                let u2: f32 = rng.gen();
151                let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos();
152                x + z * self.std
153            })
154            .collect();
155        Tensor::from_vec(noisy, input.shape()).unwrap()
156    }
157}
158
159// =============================================================================
160// RandomCrop
161// =============================================================================
162
/// Crops a randomly placed window out of the trailing dimensions of a tensor.
///
/// `size` holds the window extent for each of the last `size.len()` dimensions.
pub struct RandomCrop {
    size: Vec<usize>,
}

impl RandomCrop {
    /// Creates a crop whose window extents are `size`, one per trailing dimension.
    #[must_use]
    pub fn new(size: Vec<usize>) -> Self {
        RandomCrop { size }
    }

    /// Convenience constructor for a `height` x `width` window over the last two dimensions.
    #[must_use]
    pub fn new_2d(height: usize, width: usize) -> Self {
        let size = vec![height, width];
        Self::new(size)
    }
}
179
180impl Transform for RandomCrop {
181    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
182        let shape = input.shape();
183
184        // Determine spatial dimensions (last N dimensions where N = size.len())
185        if shape.len() < self.size.len() {
186            return input.clone();
187        }
188
189        let spatial_start = shape.len() - self.size.len();
190        let mut rng = rand::thread_rng();
191
192        // Calculate random offsets for each spatial dimension
193        let mut offsets = Vec::with_capacity(self.size.len());
194        for (i, &target_dim) in self.size.iter().enumerate() {
195            let input_dim = shape[spatial_start + i];
196            if input_dim <= target_dim {
197                offsets.push(0);
198            } else {
199                offsets.push(rng.gen_range(0..=input_dim - target_dim));
200            }
201        }
202
203        // Calculate actual crop sizes (clamped to input dimensions)
204        let crop_sizes: Vec<usize> = self
205            .size
206            .iter()
207            .enumerate()
208            .map(|(i, &s)| s.min(shape[spatial_start + i]))
209            .collect();
210
211        let data = input.to_vec();
212
213        // Handle 1D case
214        if shape.len() == 1 && self.size.len() == 1 {
215            let start = offsets[0];
216            let end = start + crop_sizes[0];
217            let cropped = data[start..end].to_vec();
218            let len = cropped.len();
219            return Tensor::from_vec(cropped, &[len]).unwrap();
220        }
221
222        // Handle 2D case (H x W)
223        if shape.len() == 2 && self.size.len() == 2 {
224            let (_h, w) = (shape[0], shape[1]);
225            let (crop_h, crop_w) = (crop_sizes[0], crop_sizes[1]);
226            let (off_h, off_w) = (offsets[0], offsets[1]);
227
228            let mut cropped = Vec::with_capacity(crop_h * crop_w);
229            for row in off_h..off_h + crop_h {
230                for col in off_w..off_w + crop_w {
231                    cropped.push(data[row * w + col]);
232                }
233            }
234            return Tensor::from_vec(cropped, &[crop_h, crop_w]).unwrap();
235        }
236
237        // Handle 3D case (C x H x W) - common for images
238        if shape.len() == 3 && self.size.len() == 2 {
239            let (c, h, w) = (shape[0], shape[1], shape[2]);
240            let (crop_h, crop_w) = (crop_sizes[0], crop_sizes[1]);
241            let (off_h, off_w) = (offsets[0], offsets[1]);
242
243            let mut cropped = Vec::with_capacity(c * crop_h * crop_w);
244            for channel in 0..c {
245                for row in off_h..off_h + crop_h {
246                    for col in off_w..off_w + crop_w {
247                        cropped.push(data[channel * h * w + row * w + col]);
248                    }
249                }
250            }
251            return Tensor::from_vec(cropped, &[c, crop_h, crop_w]).unwrap();
252        }
253
254        // Handle 4D case (N x C x H x W) - batched images
255        if shape.len() == 4 && self.size.len() == 2 {
256            let (n, c, h, w) = (shape[0], shape[1], shape[2], shape[3]);
257            let (crop_h, crop_w) = (crop_sizes[0], crop_sizes[1]);
258            let (off_h, off_w) = (offsets[0], offsets[1]);
259
260            let mut cropped = Vec::with_capacity(n * c * crop_h * crop_w);
261            for batch in 0..n {
262                for channel in 0..c {
263                    for row in off_h..off_h + crop_h {
264                        for col in off_w..off_w + crop_w {
265                            let idx = batch * c * h * w + channel * h * w + row * w + col;
266                            cropped.push(data[idx]);
267                        }
268                    }
269                }
270            }
271            return Tensor::from_vec(cropped, &[n, c, crop_h, crop_w]).unwrap();
272        }
273
274        // Fallback for unsupported dimensions - shouldn't reach here in practice
275        input.clone()
276    }
277}
278
279// =============================================================================
280// RandomFlip
281// =============================================================================
282
/// Randomly reverses the input along one chosen dimension.
pub struct RandomFlip {
    /// Index of the dimension to reverse.
    dim: usize,
    /// Chance of flipping on each call, clamped into `[0, 1]` by `new`.
    probability: f32,
}

impl RandomFlip {
    /// Creates a flip along `dim` that fires with `probability`, clamped into `[0, 1]`.
    #[must_use]
    pub fn new(dim: usize, probability: f32) -> Self {
        let probability = probability.clamp(0.0, 1.0);
        RandomFlip { dim, probability }
    }

    /// 50% left-right flip for `HxW` images (reverses dim 1).
    #[must_use]
    pub fn horizontal() -> Self {
        Self::new(1, 0.5)
    }

    /// 50% top-bottom flip for `HxW` images (reverses dim 0).
    #[must_use]
    pub fn vertical() -> Self {
        Self::new(0, 0.5)
    }
}
308
309impl Transform for RandomFlip {
310    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
311        let mut rng = rand::thread_rng();
312        if rng.gen::<f32>() > self.probability {
313            return input.clone();
314        }
315
316        let shape = input.shape();
317        if self.dim >= shape.len() {
318            return input.clone();
319        }
320
321        // Simple 1D flip implementation
322        if shape.len() == 1 {
323            let mut data = input.to_vec();
324            data.reverse();
325            return Tensor::from_vec(data, shape).unwrap();
326        }
327
328        // For 2D, flip along the specified dimension
329        if shape.len() == 2 {
330            let data = input.to_vec();
331            let (rows, cols) = (shape[0], shape[1]);
332            let mut flipped = vec![0.0; data.len()];
333
334            if self.dim == 0 {
335                // Vertical flip
336                for r in 0..rows {
337                    for c in 0..cols {
338                        flipped[r * cols + c] = data[(rows - 1 - r) * cols + c];
339                    }
340                }
341            } else {
342                // Horizontal flip
343                for r in 0..rows {
344                    for c in 0..cols {
345                        flipped[r * cols + c] = data[r * cols + (cols - 1 - c)];
346                    }
347                }
348            }
349
350            return Tensor::from_vec(flipped, shape).unwrap();
351        }
352
353        input.clone()
354    }
355}
356
357// =============================================================================
358// Scale
359// =============================================================================
360
/// Multiplies every element by a fixed scalar factor.
pub struct Scale {
    factor: f32,
}

impl Scale {
    /// Creates a transform that multiplies by `factor`.
    #[must_use]
    pub fn new(factor: f32) -> Self {
        Scale { factor }
    }
}
372
373impl Transform for Scale {
374    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
375        input.mul_scalar(self.factor)
376    }
377}
378
379// =============================================================================
380// Clamp
381// =============================================================================
382
/// Clamps tensor values element-wise into the closed range `[min, max]`.
pub struct Clamp {
    min: f32,
    max: f32,
}

impl Clamp {
    /// Creates a transform that clamps every element into `[min, max]`.
    ///
    /// # Panics
    ///
    /// Panics if `min > max` or either bound is NaN. Such bounds would make
    /// `f32::clamp` panic on the first element during `apply` anyway; checking
    /// here reports the mistake where the transform is configured.
    #[must_use]
    pub fn new(min: f32, max: f32) -> Self {
        assert!(
            min <= max,
            "Clamp requires min <= max and non-NaN bounds (got min = {min}, max = {max})"
        );
        Self { min, max }
    }

    /// Creates a Clamp for the `[0, 1]` range.
    #[must_use]
    pub fn zero_one() -> Self {
        Self::new(0.0, 1.0)
    }

    /// Creates a Clamp for the `[-1, 1]` range.
    #[must_use]
    pub fn symmetric() -> Self {
        Self::new(-1.0, 1.0)
    }
}
405
406impl Transform for Clamp {
407    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
408        let data = input.to_vec();
409        let clamped: Vec<f32> = data.iter().map(|&x| x.clamp(self.min, self.max)).collect();
410        Tensor::from_vec(clamped, input.shape()).unwrap()
411    }
412}
413
414// =============================================================================
415// Flatten
416// =============================================================================
417
/// Collapses a tensor of any shape into a single dimension.
#[derive(Default)]
pub struct Flatten;

impl Flatten {
    /// Returns the stateless `Flatten` transform.
    #[must_use]
    pub fn new() -> Self {
        Flatten
    }
}
433
434impl Transform for Flatten {
435    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
436        let data = input.to_vec();
437        Tensor::from_vec(data.clone(), &[data.len()]).unwrap()
438    }
439}
440
441// =============================================================================
442// Reshape
443// =============================================================================
444
/// Reinterprets a tensor's elements under a new, fixed shape.
pub struct Reshape {
    /// Target shape; the input's element count must equal its product for the reshape to apply.
    shape: Vec<usize>,
}

impl Reshape {
    /// Creates a reshape to `shape`.
    #[must_use]
    pub fn new(shape: Vec<usize>) -> Self {
        Reshape { shape }
    }
}
456
457impl Transform for Reshape {
458    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
459        let data = input.to_vec();
460        let expected_size: usize = self.shape.iter().product();
461
462        if data.len() != expected_size {
463            // Size mismatch, return original
464            return input.clone();
465        }
466
467        Tensor::from_vec(data, &self.shape).unwrap()
468    }
469}
470
471// =============================================================================
472// Dropout Transform
473// =============================================================================
474
/// Randomly zeroes elements with a fixed probability `p` and rescales the
/// survivors by `1 / (1 - p)`, so for `p < 1` each element keeps its expected value.
pub struct DropoutTransform {
    /// Drop probability `p`, clamped into `[0, 1]` by `new`.
    probability: f32,
}

impl DropoutTransform {
    /// Creates a dropout transform; `probability` is clamped into `[0, 1]`.
    #[must_use]
    pub fn new(probability: f32) -> Self {
        let probability = probability.clamp(0.0, 1.0);
        DropoutTransform { probability }
    }
}
488
489impl Transform for DropoutTransform {
490    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
491        if self.probability == 0.0 {
492            return input.clone();
493        }
494
495        let mut rng = rand::thread_rng();
496        let scale = 1.0 / (1.0 - self.probability);
497        let data = input.to_vec();
498
499        let dropped: Vec<f32> = data
500            .iter()
501            .map(|&x| {
502                if rng.gen::<f32>() < self.probability {
503                    0.0
504                } else {
505                    x * scale
506                }
507            })
508            .collect();
509
510        Tensor::from_vec(dropped, input.shape()).unwrap()
511    }
512}
513
514// =============================================================================
515// Lambda Transform
516// =============================================================================
517
518/// Applies a custom function as a transform.
519pub struct Lambda<F>
520where
521    F: Fn(&Tensor<f32>) -> Tensor<f32> + Send + Sync,
522{
523    func: F,
524}
525
526impl<F> Lambda<F>
527where
528    F: Fn(&Tensor<f32>) -> Tensor<f32> + Send + Sync,
529{
530    /// Creates a new Lambda transform.
531    pub fn new(func: F) -> Self {
532        Self { func }
533    }
534}
535
536impl<F> Transform for Lambda<F>
537where
538    F: Fn(&Tensor<f32>) -> Tensor<f32> + Send + Sync,
539{
540    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
541        (self.func)(input)
542    }
543}
544
545// =============================================================================
546// Tests
547// =============================================================================
548
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_normalize() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let normalize = Normalize::new(2.5, 0.5);

        let output = normalize.apply(&input);
        let expected = [-3.0, -1.0, 1.0, 3.0];

        let result = output.to_vec();
        // `zip` stops at the shorter side, so pin the length explicitly as well.
        assert_eq!(result.len(), expected.len());
        for (a, b) in result.iter().zip(expected.iter()) {
            assert!((a - b).abs() < 1e-6);
        }
    }

    #[test]
    fn test_scale() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let scale = Scale::new(2.0);

        let output = scale.apply(&input);
        assert_eq!(output.to_vec(), vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_clamp() {
        let input = Tensor::from_vec(vec![-1.0, 0.5, 2.0], &[3]).unwrap();
        let clamp = Clamp::zero_one();

        let output = clamp.apply(&input);
        assert_eq!(output.to_vec(), vec![0.0, 0.5, 1.0]);
    }

    #[test]
    fn test_flatten() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flatten = Flatten::new();

        let output = flatten.apply(&input);
        assert_eq!(output.shape(), &[4]);
        assert_eq!(output.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_reshape() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6]).unwrap();
        let reshape = Reshape::new(vec![2, 3]);

        let output = reshape.apply(&input);
        assert_eq!(output.shape(), &[2, 3]);
        assert_eq!(output.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_reshape_size_mismatch_returns_input() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6]).unwrap();
        let reshape = Reshape::new(vec![4, 4]);

        let output = reshape.apply(&input);
        assert_eq!(output.shape(), &[6]);
        assert_eq!(output.to_vec(), input.to_vec());
    }

    #[test]
    fn test_compose() {
        let normalize = Normalize::new(0.0, 1.0);
        let scale = Scale::new(2.0);

        let compose = Compose::new(vec![Box::new(normalize), Box::new(scale)]);

        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let output = compose.apply(&input);

        // normalize(x) = x, then scale by 2
        assert_eq!(output.to_vec(), vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_compose_builder() {
        let compose = Compose::empty()
            .add(Normalize::new(0.0, 1.0))
            .add(Scale::new(2.0));

        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let output = compose.apply(&input);

        assert_eq!(output.to_vec(), vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_compose_empty_is_identity() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let output = Compose::empty().apply(&input);
        assert_eq!(output.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_random_noise() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let noise = RandomNoise::new(0.0);

        // With std=0, output should equal input
        let output = noise.apply(&input);
        assert_eq!(output.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_random_noise_preserves_shape() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let output = RandomNoise::new(0.1).apply(&input);

        assert_eq!(output.shape(), &[3]);
        assert!(output.to_vec().iter().all(|x| x.is_finite()));
    }

    #[test]
    fn test_random_flip_1d() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let flip = RandomFlip::new(0, 1.0); // Always flip

        let output = flip.apply(&input);
        assert_eq!(output.to_vec(), vec![4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_random_flip_never_with_zero_probability() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let flip = RandomFlip::new(0, 0.0);

        for _ in 0..32 {
            assert_eq!(flip.apply(&input).to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
        }
    }

    #[test]
    fn test_random_flip_2d_horizontal() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flip = RandomFlip::new(1, 1.0); // Always flip horizontal

        let output = flip.apply(&input);
        // [[1, 2], [3, 4]] -> [[2, 1], [4, 3]]
        assert_eq!(output.to_vec(), vec![2.0, 1.0, 4.0, 3.0]);
    }

    #[test]
    fn test_random_flip_2d_vertical() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flip = RandomFlip::new(0, 1.0); // Always flip vertical

        let output = flip.apply(&input);
        // [[1, 2], [3, 4]] -> [[3, 4], [1, 2]]
        assert_eq!(output.to_vec(), vec![3.0, 4.0, 1.0, 2.0]);
    }

    #[test]
    fn test_dropout_transform() {
        let input = Tensor::from_vec(vec![1.0; 1000], &[1000]).unwrap();
        let dropout = DropoutTransform::new(0.5);

        let output = dropout.apply(&input);
        let output_vec = output.to_vec();

        // About half should be zero
        let zeros = output_vec.iter().filter(|&&x| x == 0.0).count();
        assert!(
            zeros > 300 && zeros < 700,
            "Expected ~500 zeros, got {zeros}"
        );

        // Non-zeros should be scaled by 2 (1/(1-0.5))
        let nonzeros: Vec<f32> = output_vec.iter().filter(|&&x| x != 0.0).copied().collect();
        for x in nonzeros {
            assert!((x - 2.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_dropout_zero_probability_is_identity() {
        let input = Tensor::from_vec(vec![1.0, -2.0, 3.5], &[3]).unwrap();
        let output = DropoutTransform::new(0.0).apply(&input);
        assert_eq!(output.to_vec(), vec![1.0, -2.0, 3.5]);
    }

    #[test]
    fn test_lambda() {
        let lambda = Lambda::new(|t: &Tensor<f32>| t.mul_scalar(3.0));

        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let output = lambda.apply(&input);

        assert_eq!(output.to_vec(), vec![3.0, 6.0, 9.0]);
    }

    #[test]
    fn test_to_tensor() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let to_tensor = ToTensor::new();

        let output = to_tensor.apply(&input);
        assert_eq!(output.to_vec(), input.to_vec());
    }

    #[test]
    fn test_normalize_variants() {
        let standard = Normalize::standard();
        assert_eq!(standard.mean, 0.0);
        assert_eq!(standard.std, 1.0);

        let zero_centered = Normalize::zero_centered();
        assert_eq!(zero_centered.mean, 0.5);
        assert_eq!(zero_centered.std, 0.5);
    }

    #[test]
    fn test_random_crop_1d() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap();
        let crop = RandomCrop::new(vec![3]);

        let output = crop.apply(&input);
        assert_eq!(output.shape(), &[3]);
        // A contiguous run of three consecutive values.
        let vals = output.to_vec();
        let a = vals[0];
        assert_eq!(vals, vec![a, a + 1.0, a + 2.0]);
    }

    #[test]
    fn test_random_crop_2d() {
        // 4x4 image holding 1..=16 in row-major order
        let input = Tensor::from_vec((1..=16).map(|x| x as f32).collect(), &[4, 4]).unwrap();
        let crop = RandomCrop::new_2d(2, 2);

        let output = crop.apply(&input);
        assert_eq!(output.shape(), &[2, 2]);
        // A 2x2 window of a row-major 4x4 grid is [a, a+1, a+4, a+5].
        let vals = output.to_vec();
        assert_eq!(vals.len(), 4);
        let a = vals[0];
        assert_eq!(vals, vec![a, a + 1.0, a + 4.0, a + 5.0]);
    }

    #[test]
    fn test_random_crop_3d() {
        // 2 channels x 4x4 image holding 1..=32 in row-major order
        let input = Tensor::from_vec((1..=32).map(|x| x as f32).collect(), &[2, 4, 4]).unwrap();
        let crop = RandomCrop::new_2d(2, 2);

        let output = crop.apply(&input);
        assert_eq!(output.shape(), &[2, 2, 2]); // 2 channels, 2x2 spatial
        // Both channels use the same window; channel 1 is channel 0 shifted by 16.
        let vals = output.to_vec();
        let a = vals[0];
        assert_eq!(
            vals,
            vec![a, a + 1.0, a + 4.0, a + 5.0, a + 16.0, a + 17.0, a + 20.0, a + 21.0]
        );
    }

    #[test]
    fn test_random_crop_larger_than_input_keeps_everything() {
        let input = Tensor::from_vec((1..=16).map(|x| x as f32).collect(), &[4, 4]).unwrap();
        let crop = RandomCrop::new_2d(10, 10);

        let output = crop.apply(&input);
        assert_eq!(output.shape(), &[4, 4]);
        assert_eq!(output.to_vec(), input.to_vec());
    }
}