//! Transforms - Data Augmentation and Preprocessing
//!
//! Provides composable transformations for data preprocessing and augmentation.
//!
//! @version 0.1.0
//! @author `AutomataNexus` Development Team

use axonml_tensor::Tensor;
use rand::Rng;
10
// =============================================================================
// Transform Trait
// =============================================================================

/// Trait for data transformations.
///
/// Implementors map an input tensor to a new output tensor without mutating
/// the input. The `Send + Sync` bounds allow a transform (or a boxed chain of
/// them) to be shared across worker threads.
pub trait Transform: Send + Sync {
    /// Applies the transform to a tensor, returning a new tensor.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32>;
}
20
// =============================================================================
// Compose
// =============================================================================

/// Composes multiple transforms into a single transform.
///
/// Transforms are applied in insertion order: the output of one stage becomes
/// the input of the next.
pub struct Compose {
    // Boxed trait objects so heterogeneous transform types can be chained.
    transforms: Vec<Box<dyn Transform>>,
}
29
30impl Compose {
31    /// Creates a new Compose from a vector of transforms.
32    #[must_use]
33    pub fn new(transforms: Vec<Box<dyn Transform>>) -> Self {
34        Self { transforms }
35    }
36
37    /// Creates an empty Compose.
38    #[must_use]
39    pub fn empty() -> Self {
40        Self {
41            transforms: Vec::new(),
42        }
43    }
44
45    /// Adds a transform to the composition.
46    pub fn add<T: Transform + 'static>(mut self, transform: T) -> Self {
47        self.transforms.push(Box::new(transform));
48        self
49    }
50}
51
52impl Transform for Compose {
53    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
54        let mut result = input.clone();
55        for transform in &self.transforms {
56            result = transform.apply(&result);
57        }
58        result
59    }
60}
61
// =============================================================================
// ToTensor
// =============================================================================

/// Converts input to a tensor (identity for already-tensor inputs).
///
/// Exists so pipelines can contain an explicit conversion stage; since this
/// crate's pipeline inputs are already `Tensor<f32>`, `apply` is a plain clone.
pub struct ToTensor;

impl ToTensor {
    /// Creates a new `ToTensor` transform.
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}

impl Default for ToTensor {
    fn default() -> Self {
        Self::new()
    }
}

impl Transform for ToTensor {
    // Identity: the input is returned unchanged (as a clone).
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        input.clone()
    }
}
88
// =============================================================================
// Normalize
// =============================================================================

/// Normalizes a tensor with mean and standard deviation.
///
/// `apply` computes `(x - mean) / std` element-wise.
pub struct Normalize {
    mean: f32,
    std: f32,
}

impl Normalize {
    /// Creates a new Normalize transform.
    #[must_use]
    pub fn new(mean: f32, std: f32) -> Self {
        Self { mean, std }
    }

    /// Creates a Normalize for standard normal distribution (mean=0, std=1).
    #[must_use]
    pub fn standard() -> Self {
        Self {
            mean: 0.0,
            std: 1.0,
        }
    }

    /// Creates a Normalize for [0,1] to [-1,1] conversion (mean=0.5, std=0.5).
    #[must_use]
    pub fn zero_centered() -> Self {
        Self {
            mean: 0.5,
            std: 0.5,
        }
    }
}
118
119impl Transform for Normalize {
120    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
121        let data = input.to_vec();
122        let normalized: Vec<f32> = data.iter().map(|&x| (x - self.mean) / self.std).collect();
123        Tensor::from_vec(normalized, input.shape()).unwrap()
124    }
125}
126
127// =============================================================================
128// RandomNoise
129// =============================================================================
130
131/// Adds random Gaussian noise to the input.
132pub struct RandomNoise {
133    std: f32,
134}
135
136impl RandomNoise {
137    /// Creates a new `RandomNoise` transform.
138    #[must_use]
139    pub fn new(std: f32) -> Self {
140        Self { std }
141    }
142}
143
144impl Transform for RandomNoise {
145    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
146        if self.std == 0.0 {
147            return input.clone();
148        }
149
150        let mut rng = rand::thread_rng();
151        let data = input.to_vec();
152        let noisy: Vec<f32> = data
153            .iter()
154            .map(|&x| {
155                // Box-Muller transform for Gaussian noise
156                let u1: f32 = rng.gen();
157                let u2: f32 = rng.gen();
158                let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos();
159                x + z * self.std
160            })
161            .collect();
162        Tensor::from_vec(noisy, input.shape()).unwrap()
163    }
164}
165
// =============================================================================
// RandomCrop
// =============================================================================

/// Randomly crops a portion of the input.
pub struct RandomCrop {
    // Target crop size; one entry per trailing (spatial) dimension.
    size: Vec<usize>,
}

impl RandomCrop {
    /// Creates a new `RandomCrop` with target size.
    #[must_use]
    pub fn new(size: Vec<usize>) -> Self {
        Self { size }
    }

    /// Creates a `RandomCrop` for 2D images (height x width).
    #[must_use]
    pub fn new_2d(height: usize, width: usize) -> Self {
        Self {
            size: vec![height, width],
        }
    }
}
188
impl Transform for RandomCrop {
    /// Crops a random window of `self.size` out of the trailing (spatial)
    /// dimensions of `input`.
    ///
    /// Supported layouts: 1D, 2D (HxW), 3D (CxHxW) and 4D (NxCxHxW) with a
    /// 2D crop size. Any other rank/size combination — or an input with fewer
    /// dimensions than the crop — is returned unchanged. Crop sizes larger
    /// than the input extent are clamped to it.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        let shape = input.shape();

        // Determine spatial dimensions (last N dimensions where N = size.len())
        if shape.len() < self.size.len() {
            return input.clone();
        }

        let spatial_start = shape.len() - self.size.len();
        let mut rng = rand::thread_rng();

        // Calculate random offsets for each spatial dimension
        let mut offsets = Vec::with_capacity(self.size.len());
        for (i, &target_dim) in self.size.iter().enumerate() {
            let input_dim = shape[spatial_start + i];
            if input_dim <= target_dim {
                // Crop covers the whole axis; no freedom to shift.
                offsets.push(0);
            } else {
                // Inclusive upper bound keeps offset + target_dim <= input_dim.
                offsets.push(rng.gen_range(0..=input_dim - target_dim));
            }
        }

        // Calculate actual crop sizes (clamped to input dimensions)
        let crop_sizes: Vec<usize> = self
            .size
            .iter()
            .enumerate()
            .map(|(i, &s)| s.min(shape[spatial_start + i]))
            .collect();

        let data = input.to_vec();

        // Handle 1D case
        if shape.len() == 1 && self.size.len() == 1 {
            let start = offsets[0];
            let end = start + crop_sizes[0];
            let cropped = data[start..end].to_vec();
            let len = cropped.len();
            return Tensor::from_vec(cropped, &[len]).unwrap();
        }

        // Handle 2D case (H x W)
        if shape.len() == 2 && self.size.len() == 2 {
            let (_h, w) = (shape[0], shape[1]);
            let (crop_h, crop_w) = (crop_sizes[0], crop_sizes[1]);
            let (off_h, off_w) = (offsets[0], offsets[1]);

            let mut cropped = Vec::with_capacity(crop_h * crop_w);
            for row in off_h..off_h + crop_h {
                for col in off_w..off_w + crop_w {
                    // Row-major indexing into the flattened input.
                    cropped.push(data[row * w + col]);
                }
            }
            return Tensor::from_vec(cropped, &[crop_h, crop_w]).unwrap();
        }

        // Handle 3D case (C x H x W) - common for images
        if shape.len() == 3 && self.size.len() == 2 {
            let (c, h, w) = (shape[0], shape[1], shape[2]);
            let (crop_h, crop_w) = (crop_sizes[0], crop_sizes[1]);
            let (off_h, off_w) = (offsets[0], offsets[1]);

            // The same random window is applied to every channel.
            let mut cropped = Vec::with_capacity(c * crop_h * crop_w);
            for channel in 0..c {
                for row in off_h..off_h + crop_h {
                    for col in off_w..off_w + crop_w {
                        cropped.push(data[channel * h * w + row * w + col]);
                    }
                }
            }
            return Tensor::from_vec(cropped, &[c, crop_h, crop_w]).unwrap();
        }

        // Handle 4D case (N x C x H x W) - batched images
        if shape.len() == 4 && self.size.len() == 2 {
            let (n, c, h, w) = (shape[0], shape[1], shape[2], shape[3]);
            let (crop_h, crop_w) = (crop_sizes[0], crop_sizes[1]);
            let (off_h, off_w) = (offsets[0], offsets[1]);

            // NOTE: one random window is shared across the entire batch.
            let mut cropped = Vec::with_capacity(n * c * crop_h * crop_w);
            for batch in 0..n {
                for channel in 0..c {
                    for row in off_h..off_h + crop_h {
                        for col in off_w..off_w + crop_w {
                            let idx = batch * c * h * w + channel * h * w + row * w + col;
                            cropped.push(data[idx]);
                        }
                    }
                }
            }
            return Tensor::from_vec(cropped, &[n, c, crop_h, crop_w]).unwrap();
        }

        // Fallback for unsupported dimensions - shouldn't reach here in practice
        input.clone()
    }
}
287
// =============================================================================
// RandomFlip
// =============================================================================

/// Randomly flips the input along a specified dimension.
pub struct RandomFlip {
    dim: usize,
    probability: f32,
}

impl RandomFlip {
    /// Creates a new `RandomFlip`; `probability` is clamped into `[0, 1]`.
    #[must_use]
    pub fn new(dim: usize, probability: f32) -> Self {
        let probability = probability.clamp(0.0, 1.0);
        Self { dim, probability }
    }

    /// Creates a horizontal flip (dim=1 for `HxW` images) with 50% chance.
    #[must_use]
    pub fn horizontal() -> Self {
        Self::new(1, 0.5)
    }

    /// Creates a vertical flip (dim=0 for `HxW` images) with 50% chance.
    #[must_use]
    pub fn vertical() -> Self {
        Self::new(0, 0.5)
    }
}
320
321impl Transform for RandomFlip {
322    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
323        let mut rng = rand::thread_rng();
324        if rng.gen::<f32>() > self.probability {
325            return input.clone();
326        }
327
328        let shape = input.shape();
329        if self.dim >= shape.len() {
330            return input.clone();
331        }
332
333        // Simple 1D flip implementation
334        if shape.len() == 1 {
335            let mut data = input.to_vec();
336            data.reverse();
337            return Tensor::from_vec(data, shape).unwrap();
338        }
339
340        // For 2D, flip along the specified dimension
341        if shape.len() == 2 {
342            let data = input.to_vec();
343            let (rows, cols) = (shape[0], shape[1]);
344            let mut flipped = vec![0.0; data.len()];
345
346            if self.dim == 0 {
347                // Vertical flip
348                for r in 0..rows {
349                    for c in 0..cols {
350                        flipped[r * cols + c] = data[(rows - 1 - r) * cols + c];
351                    }
352                }
353            } else {
354                // Horizontal flip
355                for r in 0..rows {
356                    for c in 0..cols {
357                        flipped[r * cols + c] = data[r * cols + (cols - 1 - c)];
358                    }
359                }
360            }
361
362            return Tensor::from_vec(flipped, shape).unwrap();
363        }
364
365        input.clone()
366    }
367}
368
// =============================================================================
// Scale
// =============================================================================

/// Scales tensor values by a constant factor.
pub struct Scale {
    // Multiplier applied to every element.
    factor: f32,
}

impl Scale {
    /// Creates a new Scale transform.
    #[must_use]
    pub fn new(factor: f32) -> Self {
        Self { factor }
    }
}

impl Transform for Scale {
    // Element-wise multiplication by `factor`; shape is unchanged.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        input.mul_scalar(self.factor)
    }
}
391
392// =============================================================================
393// Clamp
394// =============================================================================
395
396/// Clamps tensor values to a specified range.
397pub struct Clamp {
398    min: f32,
399    max: f32,
400}
401
402impl Clamp {
403    /// Creates a new Clamp transform.
404    #[must_use]
405    pub fn new(min: f32, max: f32) -> Self {
406        Self { min, max }
407    }
408
409    /// Creates a Clamp for [0, 1] range.
410    #[must_use]
411    pub fn zero_one() -> Self {
412        Self::new(0.0, 1.0)
413    }
414
415    /// Creates a Clamp for [-1, 1] range.
416    #[must_use]
417    pub fn symmetric() -> Self {
418        Self::new(-1.0, 1.0)
419    }
420}
421
422impl Transform for Clamp {
423    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
424        let data = input.to_vec();
425        let clamped: Vec<f32> = data.iter().map(|&x| x.clamp(self.min, self.max)).collect();
426        Tensor::from_vec(clamped, input.shape()).unwrap()
427    }
428}
429
430// =============================================================================
431// Flatten
432// =============================================================================
433
434/// Flattens the tensor to 1D.
435pub struct Flatten;
436
437impl Flatten {
438    /// Creates a new Flatten transform.
439    #[must_use]
440    pub fn new() -> Self {
441        Self
442    }
443}
444
445impl Default for Flatten {
446    fn default() -> Self {
447        Self::new()
448    }
449}
450
451impl Transform for Flatten {
452    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
453        let data = input.to_vec();
454        Tensor::from_vec(data.clone(), &[data.len()]).unwrap()
455    }
456}
457
458// =============================================================================
459// Reshape
460// =============================================================================
461
462/// Reshapes the tensor to a specified shape.
463pub struct Reshape {
464    shape: Vec<usize>,
465}
466
467impl Reshape {
468    /// Creates a new Reshape transform.
469    #[must_use]
470    pub fn new(shape: Vec<usize>) -> Self {
471        Self { shape }
472    }
473}
474
475impl Transform for Reshape {
476    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
477        let data = input.to_vec();
478        let expected_size: usize = self.shape.iter().product();
479
480        if data.len() != expected_size {
481            // Size mismatch, return original
482            return input.clone();
483        }
484
485        Tensor::from_vec(data, &self.shape).unwrap()
486    }
487}
488
489// =============================================================================
490// Dropout Transform
491// =============================================================================
492
493/// Applies dropout by randomly zeroing elements during training.
494pub struct DropoutTransform {
495    probability: f32,
496}
497
498impl DropoutTransform {
499    /// Creates a new `DropoutTransform`.
500    #[must_use]
501    pub fn new(probability: f32) -> Self {
502        Self {
503            probability: probability.clamp(0.0, 1.0),
504        }
505    }
506}
507
508impl Transform for DropoutTransform {
509    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
510        if self.probability == 0.0 {
511            return input.clone();
512        }
513
514        let mut rng = rand::thread_rng();
515        let scale = 1.0 / (1.0 - self.probability);
516        let data = input.to_vec();
517
518        let dropped: Vec<f32> = data
519            .iter()
520            .map(|&x| {
521                if rng.gen::<f32>() < self.probability {
522                    0.0
523                } else {
524                    x * scale
525                }
526            })
527            .collect();
528
529        Tensor::from_vec(dropped, input.shape()).unwrap()
530    }
531}
532
// =============================================================================
// Lambda Transform
// =============================================================================

/// Applies a custom function as a transform.
///
/// Wraps any `Fn(&Tensor<f32>) -> Tensor<f32>` closure (plus `Send + Sync`,
/// required by the [`Transform`] trait) so ad-hoc logic can be dropped into a
/// transform pipeline.
pub struct Lambda<F>
where
    F: Fn(&Tensor<f32>) -> Tensor<f32> + Send + Sync,
{
    // The wrapped user-supplied function.
    func: F,
}

impl<F> Lambda<F>
where
    F: Fn(&Tensor<f32>) -> Tensor<f32> + Send + Sync,
{
    /// Creates a new Lambda transform.
    pub fn new(func: F) -> Self {
        Self { func }
    }
}

impl<F> Transform for Lambda<F>
where
    F: Fn(&Tensor<f32>) -> Tensor<f32> + Send + Sync,
{
    // Delegates directly to the wrapped function.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        (self.func)(input)
    }
}
563
564// =============================================================================
565// Tests
566// =============================================================================
567
568#[cfg(test)]
569mod tests {
570    use super::*;
571
    #[test]
    fn test_normalize() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let normalize = Normalize::new(2.5, 0.5);

        let output = normalize.apply(&input);
        // Expected values: (x - 2.5) / 0.5 for each element.
        let expected = [-3.0, -1.0, 1.0, 3.0];

        let result = output.to_vec();
        for (a, b) in result.iter().zip(expected.iter()) {
            // Float comparison with a small absolute tolerance.
            assert!((a - b).abs() < 1e-6);
        }
    }

    #[test]
    fn test_scale() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let scale = Scale::new(2.0);

        let output = scale.apply(&input);
        assert_eq!(output.to_vec(), vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_clamp() {
        let input = Tensor::from_vec(vec![-1.0, 0.5, 2.0], &[3]).unwrap();
        let clamp = Clamp::zero_one();

        let output = clamp.apply(&input);
        assert_eq!(output.to_vec(), vec![0.0, 0.5, 1.0]);
    }

    #[test]
    fn test_flatten() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flatten = Flatten::new();

        let output = flatten.apply(&input);
        // Flattening preserves row-major element order.
        assert_eq!(output.shape(), &[4]);
        assert_eq!(output.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_reshape() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6]).unwrap();
        let reshape = Reshape::new(vec![2, 3]);

        let output = reshape.apply(&input);
        assert_eq!(output.shape(), &[2, 3]);
    }
622
    #[test]
    fn test_compose() {
        let normalize = Normalize::new(0.0, 1.0);
        let scale = Scale::new(2.0);

        let compose = Compose::new(vec![Box::new(normalize), Box::new(scale)]);

        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let output = compose.apply(&input);

        // normalize(x) = x, then scale by 2
        assert_eq!(output.to_vec(), vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_compose_builder() {
        // Same pipeline as test_compose, built via the fluent `add` API.
        let compose = Compose::empty()
            .add(Normalize::new(0.0, 1.0))
            .add(Scale::new(2.0));

        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let output = compose.apply(&input);

        assert_eq!(output.to_vec(), vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_random_noise() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let noise = RandomNoise::new(0.0);

        // With std=0, output should equal input
        let output = noise.apply(&input);
        assert_eq!(output.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_random_flip_1d() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let flip = RandomFlip::new(0, 1.0); // Always flip

        let output = flip.apply(&input);
        assert_eq!(output.to_vec(), vec![4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_random_flip_2d_horizontal() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flip = RandomFlip::new(1, 1.0); // Always flip horizontal

        let output = flip.apply(&input);
        // [[1, 2], [3, 4]] -> [[2, 1], [4, 3]]
        assert_eq!(output.to_vec(), vec![2.0, 1.0, 4.0, 3.0]);
    }

    #[test]
    fn test_random_flip_2d_vertical() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flip = RandomFlip::new(0, 1.0); // Always flip vertical

        let output = flip.apply(&input);
        // [[1, 2], [3, 4]] -> [[3, 4], [1, 2]]
        assert_eq!(output.to_vec(), vec![3.0, 4.0, 1.0, 2.0]);
    }
687
    #[test]
    fn test_dropout_transform() {
        let input = Tensor::from_vec(vec![1.0; 1000], &[1000]).unwrap();
        let dropout = DropoutTransform::new(0.5);

        let output = dropout.apply(&input);
        let output_vec = output.to_vec();

        // About half should be zero (wide statistical bounds keep the test
        // stable despite the unseeded RNG).
        let zeros = output_vec.iter().filter(|&&x| x == 0.0).count();
        assert!(
            zeros > 300 && zeros < 700,
            "Expected ~500 zeros, got {zeros}"
        );

        // Non-zeros should be scaled by 2 (1/(1-0.5))
        let nonzeros: Vec<f32> = output_vec.iter().filter(|&&x| x != 0.0).copied().collect();
        for x in nonzeros {
            assert!((x - 2.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_lambda() {
        let lambda = Lambda::new(|t: &Tensor<f32>| t.mul_scalar(3.0));

        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let output = lambda.apply(&input);

        assert_eq!(output.to_vec(), vec![3.0, 6.0, 9.0]);
    }

    #[test]
    fn test_to_tensor() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let to_tensor = ToTensor::new();

        // ToTensor is the identity transform.
        let output = to_tensor.apply(&input);
        assert_eq!(output.to_vec(), input.to_vec());
    }

    #[test]
    fn test_normalize_variants() {
        // Private fields are accessible because this module is a child of
        // the one defining Normalize.
        let standard = Normalize::standard();
        assert_eq!(standard.mean, 0.0);
        assert_eq!(standard.std, 1.0);

        let zero_centered = Normalize::zero_centered();
        assert_eq!(zero_centered.mean, 0.5);
        assert_eq!(zero_centered.std, 0.5);
    }
739
    #[test]
    fn test_random_crop_1d() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap();
        let crop = RandomCrop::new(vec![3]);

        // Offset is random, so only the output shape is asserted.
        let output = crop.apply(&input);
        assert_eq!(output.shape(), &[3]);
    }

    #[test]
    fn test_random_crop_2d() {
        // 4x4 image
        let input = Tensor::from_vec((1..=16).map(|x| x as f32).collect(), &[4, 4]).unwrap();
        let crop = RandomCrop::new_2d(2, 2);

        let output = crop.apply(&input);
        assert_eq!(output.shape(), &[2, 2]);
        // Verify values are contiguous from the original
        let vals = output.to_vec();
        assert_eq!(vals.len(), 4);
    }

    #[test]
    fn test_random_crop_3d() {
        // 2 channels x 4x4 image
        let input = Tensor::from_vec((1..=32).map(|x| x as f32).collect(), &[2, 4, 4]).unwrap();
        let crop = RandomCrop::new_2d(2, 2);

        let output = crop.apply(&input);
        assert_eq!(output.shape(), &[2, 2, 2]); // 2 channels, 2x2 spatial
    }
771}