1use axonml_tensor::Tensor;
19use rand::Rng;
20
/// A tensor-to-tensor transformation, e.g. a preprocessing or augmentation
/// step in a data pipeline. Implementations must be `Send + Sync` so a
/// single transform can be shared across data-loading worker threads.
pub trait Transform: Send + Sync {
    /// Applies the transform to `input` and returns the resulting tensor.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32>;
}
30
/// Chains multiple transforms into a single pipeline, applied in order.
pub struct Compose {
    /// Transforms applied in insertion order.
    transforms: Vec<Box<dyn Transform>>,
}
39
40impl Compose {
41 #[must_use]
43 pub fn new(transforms: Vec<Box<dyn Transform>>) -> Self {
44 Self { transforms }
45 }
46
47 #[must_use]
49 pub fn empty() -> Self {
50 Self {
51 transforms: Vec::new(),
52 }
53 }
54
55 pub fn add<T: Transform + 'static>(mut self, transform: T) -> Self {
57 self.transforms.push(Box::new(transform));
58 self
59 }
60}
61
62impl Transform for Compose {
63 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
64 let mut result = input.clone();
65 for transform in &self.transforms {
66 result = transform.apply(&result);
67 }
68 result
69 }
70}
71
/// Identity transform: returns a copy of the input tensor unchanged.
pub struct ToTensor;
78
impl ToTensor {
    /// Creates the identity transform.
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}
86
87impl Default for ToTensor {
88 fn default() -> Self {
89 Self::new()
90 }
91}
92
impl Transform for ToTensor {
    /// Returns an unmodified copy of the input.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        input.clone()
    }
}
98
/// Normalizes tensor values as `(x - mean) / std`, either globally or
/// per channel.
pub struct Normalize {
    /// One global mean, or one mean per channel.
    mean: Vec<f32>,
    /// One global std, or one std per channel (same length as `mean`).
    std: Vec<f32>,
}
113
114impl Normalize {
115 #[must_use]
117 pub fn new(mean: f32, std: f32) -> Self {
118 Self {
119 mean: vec![mean],
120 std: vec![std],
121 }
122 }
123
124 #[must_use]
129 pub fn per_channel(mean: Vec<f32>, std: Vec<f32>) -> Self {
130 assert_eq!(mean.len(), std.len(), "mean and std must have same length");
131 Self { mean, std }
132 }
133
134 #[must_use]
136 pub fn standard() -> Self {
137 Self::new(0.0, 1.0)
138 }
139
140 #[must_use]
142 pub fn zero_centered() -> Self {
143 Self::new(0.5, 0.5)
144 }
145
146 #[must_use]
148 pub fn imagenet() -> Self {
149 Self::per_channel(vec![0.485, 0.456, 0.406], vec![0.229, 0.224, 0.225])
150 }
151}
152
153impl Transform for Normalize {
154 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
155 let shape = input.shape();
156 let mut data = input.to_vec();
157
158 if self.mean.len() == 1 {
159 let m = self.mean[0];
161 let s = self.std[0];
162 for x in &mut data {
163 *x = (*x - m) / s;
164 }
165 } else {
166 let num_channels = self.mean.len();
168
169 if shape.len() == 3 && shape[0] == num_channels {
170 let spatial = shape[1] * shape[2];
172 for c in 0..num_channels {
173 let offset = c * spatial;
174 let m = self.mean[c];
175 let s = self.std[c];
176 for i in 0..spatial {
177 data[offset + i] = (data[offset + i] - m) / s;
178 }
179 }
180 } else if shape.len() == 4 && shape[1] == num_channels {
181 let spatial = shape[2] * shape[3];
183 let sample_size = num_channels * spatial;
184 for n in 0..shape[0] {
185 for c in 0..num_channels {
186 let offset = n * sample_size + c * spatial;
187 let m = self.mean[c];
188 let s = self.std[c];
189 for i in 0..spatial {
190 data[offset + i] = (data[offset + i] - m) / s;
191 }
192 }
193 }
194 } else {
195 let m = self.mean[0];
197 let s = self.std[0];
198 for x in &mut data {
199 *x = (*x - m) / s;
200 }
201 }
202 }
203
204 Tensor::from_vec(data, shape).unwrap()
205 }
206}
207
/// Adds zero-mean Gaussian noise (sampled via Box-Muller) to every element.
pub struct RandomNoise {
    /// Standard deviation of the added noise; 0 disables the transform.
    std: f32,
}
216
impl RandomNoise {
    /// Creates a noise transform; `std == 0.0` makes `apply` a no-op.
    #[must_use]
    pub fn new(std: f32) -> Self {
        Self { std }
    }
}
224
225impl Transform for RandomNoise {
226 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
227 if self.std == 0.0 {
228 return input.clone();
229 }
230
231 let mut rng = rand::thread_rng();
232 let data = input.to_vec();
233 let noisy: Vec<f32> = data
234 .iter()
235 .map(|&x| {
236 let u1: f32 = rng.r#gen();
238 let u2: f32 = rng.r#gen();
239 let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos();
240 x + z * self.std
241 })
242 .collect();
243 Tensor::from_vec(noisy, input.shape()).unwrap()
244 }
245}
246
/// Randomly crops the trailing dimensions of a tensor to a target size.
pub struct RandomCrop {
    /// Target extents of the trailing (spatial) dimensions.
    size: Vec<usize>,
}
255
256impl RandomCrop {
257 #[must_use]
259 pub fn new(size: Vec<usize>) -> Self {
260 Self { size }
261 }
262
263 #[must_use]
265 pub fn new_2d(height: usize, width: usize) -> Self {
266 Self::new(vec![height, width])
267 }
268}
269
270impl Transform for RandomCrop {
271 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
272 let shape = input.shape();
273
274 if shape.len() < self.size.len() {
276 return input.clone();
277 }
278
279 let spatial_start = shape.len() - self.size.len();
280 let mut rng = rand::thread_rng();
281
282 let mut offsets = Vec::with_capacity(self.size.len());
284 for (i, &target_dim) in self.size.iter().enumerate() {
285 let input_dim = shape[spatial_start + i];
286 if input_dim <= target_dim {
287 offsets.push(0);
288 } else {
289 offsets.push(rng.gen_range(0..=input_dim - target_dim));
290 }
291 }
292
293 let crop_sizes: Vec<usize> = self
295 .size
296 .iter()
297 .enumerate()
298 .map(|(i, &s)| s.min(shape[spatial_start + i]))
299 .collect();
300
301 let data = input.to_vec();
302
303 if shape.len() == 1 && self.size.len() == 1 {
305 let start = offsets[0];
306 let end = start + crop_sizes[0];
307 let cropped = data[start..end].to_vec();
308 let len = cropped.len();
309 return Tensor::from_vec(cropped, &[len]).unwrap();
310 }
311
312 if shape.len() == 2 && self.size.len() == 2 {
314 let (_h, w) = (shape[0], shape[1]);
315 let (crop_h, crop_w) = (crop_sizes[0], crop_sizes[1]);
316 let (off_h, off_w) = (offsets[0], offsets[1]);
317
318 let mut cropped = Vec::with_capacity(crop_h * crop_w);
319 for row in off_h..off_h + crop_h {
320 for col in off_w..off_w + crop_w {
321 cropped.push(data[row * w + col]);
322 }
323 }
324 return Tensor::from_vec(cropped, &[crop_h, crop_w]).unwrap();
325 }
326
327 if shape.len() == 3 && self.size.len() == 2 {
329 let (c, h, w) = (shape[0], shape[1], shape[2]);
330 let (crop_h, crop_w) = (crop_sizes[0], crop_sizes[1]);
331 let (off_h, off_w) = (offsets[0], offsets[1]);
332
333 let mut cropped = Vec::with_capacity(c * crop_h * crop_w);
334 for channel in 0..c {
335 for row in off_h..off_h + crop_h {
336 for col in off_w..off_w + crop_w {
337 cropped.push(data[channel * h * w + row * w + col]);
338 }
339 }
340 }
341 return Tensor::from_vec(cropped, &[c, crop_h, crop_w]).unwrap();
342 }
343
344 if shape.len() == 4 && self.size.len() == 2 {
346 let (n, c, h, w) = (shape[0], shape[1], shape[2], shape[3]);
347 let (crop_h, crop_w) = (crop_sizes[0], crop_sizes[1]);
348 let (off_h, off_w) = (offsets[0], offsets[1]);
349
350 let mut cropped = Vec::with_capacity(n * c * crop_h * crop_w);
351 for batch in 0..n {
352 for channel in 0..c {
353 for row in off_h..off_h + crop_h {
354 for col in off_w..off_w + crop_w {
355 let idx = batch * c * h * w + channel * h * w + row * w + col;
356 cropped.push(data[idx]);
357 }
358 }
359 }
360 }
361 return Tensor::from_vec(cropped, &[n, c, crop_h, crop_w]).unwrap();
362 }
363
364 input.clone()
366 }
367}
368
/// Randomly reverses a tensor along one dimension with a given probability.
pub struct RandomFlip {
    /// Index of the dimension to flip.
    dim: usize,
    /// Probability in [0, 1] that a flip happens on a given `apply` call.
    probability: f32,
}
378
379impl RandomFlip {
380 #[must_use]
382 pub fn new(dim: usize, probability: f32) -> Self {
383 Self {
384 dim,
385 probability: probability.clamp(0.0, 1.0),
386 }
387 }
388
389 #[must_use]
391 pub fn horizontal() -> Self {
392 Self::new(1, 0.5)
393 }
394
395 #[must_use]
397 pub fn vertical() -> Self {
398 Self::new(0, 0.5)
399 }
400}
401
402impl Transform for RandomFlip {
403 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
404 let mut rng = rand::thread_rng();
405 if rng.r#gen::<f32>() > self.probability {
406 return input.clone();
407 }
408
409 let shape = input.shape();
410 if self.dim >= shape.len() {
411 return input.clone();
412 }
413
414 let data = input.to_vec();
415 let ndim = shape.len();
416
417 let total = data.len();
420 let mut flipped = vec![0.0f32; total];
421
422 let mut strides = vec![1usize; ndim];
424 for i in (0..ndim - 1).rev() {
425 strides[i] = strides[i + 1] * shape[i + 1];
426 }
427
428 let dim = self.dim;
429 let dim_size = shape[dim];
430 let dim_stride = strides[dim];
431
432 for i in 0..total {
433 let coord_in_dim = (i / dim_stride) % dim_size;
435 let flipped_coord = dim_size - 1 - coord_in_dim;
436 let diff = flipped_coord as isize - coord_in_dim as isize;
438 let src = (i as isize + diff * dim_stride as isize) as usize;
439 flipped[i] = data[src];
440 }
441
442 Tensor::from_vec(flipped, shape).unwrap()
443 }
444}
445
/// Multiplies every element by a constant factor.
pub struct Scale {
    /// The scalar multiplier.
    factor: f32,
}
454
impl Scale {
    /// Creates a scaling transform with the given multiplier.
    #[must_use]
    pub fn new(factor: f32) -> Self {
        Self { factor }
    }
}
462
impl Transform for Scale {
    /// Returns `input * factor`, element-wise.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        input.mul_scalar(self.factor)
    }
}
468
/// Clamps every element into the `[min, max]` range.
pub struct Clamp {
    /// Lower bound.
    min: f32,
    /// Upper bound; expected to be >= `min`.
    max: f32,
}
478
479impl Clamp {
480 #[must_use]
482 pub fn new(min: f32, max: f32) -> Self {
483 Self { min, max }
484 }
485
486 #[must_use]
488 pub fn zero_one() -> Self {
489 Self::new(0.0, 1.0)
490 }
491
492 #[must_use]
494 pub fn symmetric() -> Self {
495 Self::new(-1.0, 1.0)
496 }
497}
498
499impl Transform for Clamp {
500 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
501 let data = input.to_vec();
502 let clamped: Vec<f32> = data.iter().map(|&x| x.clamp(self.min, self.max)).collect();
503 Tensor::from_vec(clamped, input.shape()).unwrap()
504 }
505}
506
/// Flattens any tensor to rank 1, preserving element order.
pub struct Flatten;
513
impl Flatten {
    /// Creates the flatten transform.
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}
521
522impl Default for Flatten {
523 fn default() -> Self {
524 Self::new()
525 }
526}
527
528impl Transform for Flatten {
529 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
530 let data = input.to_vec();
531 Tensor::from_vec(data.clone(), &[data.len()]).unwrap()
532 }
533}
534
/// Reshapes a tensor to a fixed target shape when element counts match.
pub struct Reshape {
    /// Target shape; must have the same element count as the input for the
    /// reshape to take effect.
    shape: Vec<usize>,
}
543
impl Reshape {
    /// Creates a reshape to the given target shape.
    #[must_use]
    pub fn new(shape: Vec<usize>) -> Self {
        Self { shape }
    }
}
551
552impl Transform for Reshape {
553 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
554 let data = input.to_vec();
555 let expected_size: usize = self.shape.iter().product();
556
557 if data.len() != expected_size {
558 return input.clone();
560 }
561
562 Tensor::from_vec(data, &self.shape).unwrap()
563 }
564}
565
/// Inverted dropout as a data transform: in training mode each element is
/// zeroed with probability `p` and survivors are scaled by `1 / (1 - p)`.
/// In evaluation mode the input passes through unchanged.
pub struct DropoutTransform {
    /// Drop probability, clamped into [0, 1] at construction.
    probability: f32,
    /// Training flag, toggled at runtime via `set_training`; atomic so the
    /// transform stays usable through a shared `&self` (`Send + Sync`).
    training: std::sync::atomic::AtomicBool,
}
578
579impl DropoutTransform {
580 #[must_use]
582 pub fn new(probability: f32) -> Self {
583 Self {
584 probability: probability.clamp(0.0, 1.0),
585 training: std::sync::atomic::AtomicBool::new(true),
586 }
587 }
588
589 pub fn set_training(&self, training: bool) {
591 self.training
592 .store(training, std::sync::atomic::Ordering::Relaxed);
593 }
594
595 pub fn is_training(&self) -> bool {
597 self.training.load(std::sync::atomic::Ordering::Relaxed)
598 }
599}
600
601impl Transform for DropoutTransform {
602 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
603 if !self.is_training() || self.probability == 0.0 {
605 return input.clone();
606 }
607
608 let mut rng = rand::thread_rng();
609 let scale = 1.0 / (1.0 - self.probability);
610 let data = input.to_vec();
611
612 let dropped: Vec<f32> = data
613 .iter()
614 .map(|&x| {
615 if rng.r#gen::<f32>() < self.probability {
616 0.0
617 } else {
618 x * scale
619 }
620 })
621 .collect();
622
623 Tensor::from_vec(dropped, input.shape()).unwrap()
624 }
625}
626
/// Adapts an arbitrary closure into a [`Transform`].
pub struct Lambda<F>
where
    F: Fn(&Tensor<f32>) -> Tensor<f32> + Send + Sync,
{
    /// The wrapped transformation function.
    func: F,
}
638
639impl<F> Lambda<F>
640where
641 F: Fn(&Tensor<f32>) -> Tensor<f32> + Send + Sync,
642{
643 pub fn new(func: F) -> Self {
645 Self { func }
646 }
647}
648
impl<F> Transform for Lambda<F>
where
    F: Fn(&Tensor<f32>) -> Tensor<f32> + Send + Sync,
{
    /// Delegates to the wrapped closure.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        (self.func)(input)
    }
}
657
#[cfg(test)]
mod tests {
    use super::*;

    /// Test fixture: builds a tensor or panics.
    fn t(data: Vec<f32>, shape: &[usize]) -> Tensor<f32> {
        Tensor::from_vec(data, shape).unwrap()
    }

    #[test]
    fn test_normalize() {
        let out = Normalize::new(2.5, 0.5).apply(&t(vec![1.0, 2.0, 3.0, 4.0], &[4]));
        let expected = [-3.0, -1.0, 1.0, 3.0];
        for (got, want) in out.to_vec().iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-6);
        }
    }

    #[test]
    fn test_normalize_per_channel() {
        let input = t(vec![1.0, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0], &[2, 2, 2]);
        let out = Normalize::per_channel(vec![0.0, 10.0], vec![1.0, 10.0]).apply(&input);
        let vals = out.to_vec();
        assert!((vals[0] - 1.0).abs() < 1e-6);
        assert!((vals[3] - 4.0).abs() < 1e-6);
        assert!((vals[4] - 0.0).abs() < 1e-6);
        assert!((vals[5] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_scale() {
        let out = Scale::new(2.0).apply(&t(vec![1.0, 2.0, 3.0], &[3]));
        assert_eq!(out.to_vec(), vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_clamp() {
        let out = Clamp::zero_one().apply(&t(vec![-1.0, 0.5, 2.0], &[3]));
        assert_eq!(out.to_vec(), vec![0.0, 0.5, 1.0]);
    }

    #[test]
    fn test_flatten() {
        let out = Flatten::new().apply(&t(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]));
        assert_eq!(out.shape(), &[4]);
        assert_eq!(out.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_reshape() {
        let out = Reshape::new(vec![2, 3]).apply(&t(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6]));
        assert_eq!(out.shape(), &[2, 3]);
    }

    #[test]
    fn test_compose() {
        let pipeline = Compose::new(vec![
            Box::new(Normalize::new(0.0, 1.0)),
            Box::new(Scale::new(2.0)),
        ]);
        let out = pipeline.apply(&t(vec![1.0, 2.0, 3.0], &[3]));
        assert_eq!(out.to_vec(), vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_compose_builder() {
        let pipeline = Compose::empty()
            .add(Normalize::new(0.0, 1.0))
            .add(Scale::new(2.0));
        let out = pipeline.apply(&t(vec![1.0, 2.0, 3.0], &[3]));
        assert_eq!(out.to_vec(), vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_random_noise() {
        // std == 0 must be an exact pass-through.
        let out = RandomNoise::new(0.0).apply(&t(vec![1.0, 2.0, 3.0], &[3]));
        assert_eq!(out.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_random_flip_1d() {
        // probability 1.0 => the flip always happens
        let out = RandomFlip::new(0, 1.0).apply(&t(vec![1.0, 2.0, 3.0, 4.0], &[4]));
        assert_eq!(out.to_vec(), vec![4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_random_flip_2d_horizontal() {
        let out = RandomFlip::new(1, 1.0).apply(&t(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]));
        assert_eq!(out.to_vec(), vec![2.0, 1.0, 4.0, 3.0]);
    }

    #[test]
    fn test_random_flip_2d_vertical() {
        let out = RandomFlip::new(0, 1.0).apply(&t(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]));
        assert_eq!(out.to_vec(), vec![3.0, 4.0, 1.0, 2.0]);
    }

    #[test]
    fn test_random_flip_3d() {
        let out = RandomFlip::new(2, 1.0).apply(&t(vec![1.0, 2.0, 3.0, 4.0], &[1, 2, 2]));
        assert_eq!(out.to_vec(), vec![2.0, 1.0, 4.0, 3.0]);
        assert_eq!(out.shape(), &[1, 2, 2]);
    }

    #[test]
    fn test_random_flip_4d() {
        let out = RandomFlip::new(2, 1.0).apply(&t(vec![1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2]));
        assert_eq!(out.to_vec(), vec![3.0, 4.0, 1.0, 2.0]);
        assert_eq!(out.shape(), &[1, 1, 2, 2]);
    }

    #[test]
    fn test_dropout_eval_mode() {
        let input = t(vec![1.0; 100], &[100]);
        let dropout = DropoutTransform::new(0.5);

        let trained = dropout.apply(&input);
        let dropped = trained.to_vec().iter().filter(|&&x| x == 0.0).count();
        assert!(dropped > 0, "Training mode should drop elements");

        dropout.set_training(false);
        assert_eq!(dropout.apply(&input).to_vec(), vec![1.0; 100]);
    }

    #[test]
    fn test_dropout_transform() {
        let out = DropoutTransform::new(0.5).apply(&t(vec![1.0; 1000], &[1000]));
        let vals = out.to_vec();

        // Roughly half the elements should be dropped.
        let zeros = vals.iter().filter(|&&x| x == 0.0).count();
        assert!(
            zeros > 300 && zeros < 700,
            "Expected ~500 zeros, got {zeros}"
        );

        // Survivors are scaled by 1 / (1 - 0.5) = 2.
        for &x in vals.iter().filter(|&&x| x != 0.0) {
            assert!((x - 2.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_lambda() {
        let triple = Lambda::new(|t: &Tensor<f32>| t.mul_scalar(3.0));
        let out = triple.apply(&t(vec![1.0, 2.0, 3.0], &[3]));
        assert_eq!(out.to_vec(), vec![3.0, 6.0, 9.0]);
    }

    #[test]
    fn test_to_tensor() {
        let input = t(vec![1.0, 2.0, 3.0], &[3]);
        let out = ToTensor::new().apply(&input);
        assert_eq!(out.to_vec(), input.to_vec());
    }

    #[test]
    fn test_normalize_variants() {
        let standard = Normalize::standard();
        assert_eq!(standard.mean, vec![0.0]);
        assert_eq!(standard.std, vec![1.0]);

        let zero_centered = Normalize::zero_centered();
        assert_eq!(zero_centered.mean, vec![0.5]);
        assert_eq!(zero_centered.std, vec![0.5]);
    }

    #[test]
    fn test_random_crop_1d() {
        let out = RandomCrop::new(vec![3]).apply(&t(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5]));
        assert_eq!(out.shape(), &[3]);
    }

    #[test]
    fn test_random_crop_2d() {
        let input = t((1..=16).map(|x| x as f32).collect(), &[4, 4]);
        let out = RandomCrop::new_2d(2, 2).apply(&input);
        assert_eq!(out.shape(), &[2, 2]);
        assert_eq!(out.to_vec().len(), 4);
    }

    #[test]
    fn test_random_crop_3d() {
        let input = t((1..=32).map(|x| x as f32).collect(), &[2, 4, 4]);
        let out = RandomCrop::new_2d(2, 2).apply(&input);
        assert_eq!(out.shape(), &[2, 2, 2]);
    }
}