1use axonml_tensor::Tensor;
18use rand::Rng;
19
/// A tensor-to-tensor transformation.
///
/// Implementors must be `Send + Sync` so a transform (or a pipeline of
/// them) can be shared across data-loading worker threads.
pub trait Transform: Send + Sync {
    /// Applies the transform to `input`, returning a new tensor.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32>;
}
29
/// A sequence of transforms applied in order.
pub struct Compose {
    // Boxed trait objects so heterogeneous transform types can be chained.
    transforms: Vec<Box<dyn Transform>>,
}
38
39impl Compose {
40 #[must_use]
42 pub fn new(transforms: Vec<Box<dyn Transform>>) -> Self {
43 Self { transforms }
44 }
45
46 #[must_use]
48 pub fn empty() -> Self {
49 Self {
50 transforms: Vec::new(),
51 }
52 }
53
54 pub fn add<T: Transform + 'static>(mut self, transform: T) -> Self {
56 self.transforms.push(Box::new(transform));
57 self
58 }
59}
60
61impl Transform for Compose {
62 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
63 let mut result = input.clone();
64 for transform in &self.transforms {
65 result = transform.apply(&result);
66 }
67 result
68 }
69}
70
/// Identity transform: returns the input unchanged.
///
/// Exists for API parity with torchvision-style pipelines where a
/// `ToTensor` stage conventionally starts the chain.
pub struct ToTensor;

impl ToTensor {
    /// Creates a new identity transform.
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}

impl Default for ToTensor {
    fn default() -> Self {
        Self::new()
    }
}

impl Transform for ToTensor {
    /// Returns a copy of `input` unchanged.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        input.clone()
    }
}
97
/// Normalizes tensor values as `(x - mean) / std`.
///
/// Holds either a single mean/std pair (applied element-wise) or one pair
/// per channel (applied along the channel axis of CHW / NCHW input).
pub struct Normalize {
    mean: Vec<f32>,
    std: Vec<f32>,
}
112
113impl Normalize {
114 #[must_use]
116 pub fn new(mean: f32, std: f32) -> Self {
117 Self {
118 mean: vec![mean],
119 std: vec![std],
120 }
121 }
122
123 #[must_use]
128 pub fn per_channel(mean: Vec<f32>, std: Vec<f32>) -> Self {
129 assert_eq!(mean.len(), std.len(), "mean and std must have same length");
130 Self { mean, std }
131 }
132
133 #[must_use]
135 pub fn standard() -> Self {
136 Self::new(0.0, 1.0)
137 }
138
139 #[must_use]
141 pub fn zero_centered() -> Self {
142 Self::new(0.5, 0.5)
143 }
144
145 #[must_use]
147 pub fn imagenet() -> Self {
148 Self::per_channel(vec![0.485, 0.456, 0.406], vec![0.229, 0.224, 0.225])
149 }
150}
151
impl Transform for Normalize {
    /// Applies `(x - mean) / std`, broadcasting per-channel statistics
    /// over CHW (rank 3) or NCHW (rank 4) layouts.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        let shape = input.shape();
        let mut data = input.to_vec();

        if self.mean.len() == 1 {
            // Single statistic: normalize every element uniformly.
            let m = self.mean[0];
            let s = self.std[0];
            for x in &mut data {
                *x = (*x - m) / s;
            }
        } else {
            let num_channels = self.mean.len();

            if shape.len() == 3 && shape[0] == num_channels {
                // CHW layout: each channel is a contiguous run of H*W values.
                let spatial = shape[1] * shape[2];
                for c in 0..num_channels {
                    let offset = c * spatial;
                    let m = self.mean[c];
                    let s = self.std[c];
                    for i in 0..spatial {
                        data[offset + i] = (data[offset + i] - m) / s;
                    }
                }
            } else if shape.len() == 4 && shape[1] == num_channels {
                // NCHW layout: per-sample, per-channel contiguous runs.
                let spatial = shape[2] * shape[3];
                let sample_size = num_channels * spatial;
                for n in 0..shape[0] {
                    for c in 0..num_channels {
                        let offset = n * sample_size + c * spatial;
                        let m = self.mean[c];
                        let s = self.std[c];
                        for i in 0..spatial {
                            data[offset + i] = (data[offset + i] - m) / s;
                        }
                    }
                }
            } else {
                // NOTE(review): shape does not match the per-channel stats,
                // so this silently falls back to the first channel's
                // mean/std for all elements — consider whether an error
                // would be more appropriate here.
                let m = self.mean[0];
                let s = self.std[0];
                for x in &mut data {
                    *x = (*x - m) / s;
                }
            }
        }

        Tensor::from_vec(data, shape).unwrap()
    }
}
206
/// Adds zero-mean Gaussian noise with the given standard deviation.
pub struct RandomNoise {
    // Standard deviation of the added noise; 0.0 makes the transform a no-op.
    std: f32,
}

impl RandomNoise {
    /// Creates a Gaussian-noise transform with standard deviation `std`.
    #[must_use]
    pub fn new(std: f32) -> Self {
        Self { std }
    }
}
223
224impl Transform for RandomNoise {
225 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
226 if self.std == 0.0 {
227 return input.clone();
228 }
229
230 let mut rng = rand::thread_rng();
231 let data = input.to_vec();
232 let noisy: Vec<f32> = data
233 .iter()
234 .map(|&x| {
235 let u1: f32 = rng.r#gen();
237 let u2: f32 = rng.r#gen();
238 let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos();
239 x + z * self.std
240 })
241 .collect();
242 Tensor::from_vec(noisy, input.shape()).unwrap()
243 }
244}
245
/// Randomly crops the trailing (spatial) dimensions of a tensor to a fixed
/// size.
pub struct RandomCrop {
    // Target extent of each of the last `size.len()` dimensions.
    size: Vec<usize>,
}

impl RandomCrop {
    /// Creates a crop over the last `size.len()` dimensions of the input.
    #[must_use]
    pub fn new(size: Vec<usize>) -> Self {
        Self { size }
    }

    /// Convenience constructor for a 2-D (height, width) crop.
    #[must_use]
    pub fn new_2d(height: usize, width: usize) -> Self {
        Self::new(vec![height, width])
    }
}
268
impl Transform for RandomCrop {
    /// Crops the trailing dimensions to `self.size` at a random offset.
    ///
    /// Supported layouts: 1-D input with a 1-D crop, and 2-D / 3-D (CHW) /
    /// 4-D (NCHW) input with a 2-D crop. Any other combination — or an
    /// input with fewer dimensions than the crop — is returned unchanged.
    /// Dimensions already at or below the target size are kept whole.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        let shape = input.shape();

        if shape.len() < self.size.len() {
            // Not enough dimensions to crop.
            return input.clone();
        }

        // The crop applies to the last `size.len()` dimensions.
        let spatial_start = shape.len() - self.size.len();
        let mut rng = rand::thread_rng();

        // Pick a random start offset per cropped dimension; dimensions not
        // larger than the target are left uncropped (offset 0).
        let mut offsets = Vec::with_capacity(self.size.len());
        for (i, &target_dim) in self.size.iter().enumerate() {
            let input_dim = shape[spatial_start + i];
            if input_dim <= target_dim {
                offsets.push(0);
            } else {
                offsets.push(rng.gen_range(0..=input_dim - target_dim));
            }
        }

        // Effective crop extent: never larger than the input dimension.
        let crop_sizes: Vec<usize> = self
            .size
            .iter()
            .enumerate()
            .map(|(i, &s)| s.min(shape[spatial_start + i]))
            .collect();

        let data = input.to_vec();

        // 1-D input, 1-D crop: plain slice copy.
        if shape.len() == 1 && self.size.len() == 1 {
            let start = offsets[0];
            let end = start + crop_sizes[0];
            let cropped = data[start..end].to_vec();
            let len = cropped.len();
            return Tensor::from_vec(cropped, &[len]).unwrap();
        }

        // 2-D (H, W) input with a 2-D crop.
        if shape.len() == 2 && self.size.len() == 2 {
            let (_h, w) = (shape[0], shape[1]);
            let (crop_h, crop_w) = (crop_sizes[0], crop_sizes[1]);
            let (off_h, off_w) = (offsets[0], offsets[1]);

            let mut cropped = Vec::with_capacity(crop_h * crop_w);
            for row in off_h..off_h + crop_h {
                for col in off_w..off_w + crop_w {
                    cropped.push(data[row * w + col]);
                }
            }
            return Tensor::from_vec(cropped, &[crop_h, crop_w]).unwrap();
        }

        // 3-D (C, H, W) input: the same spatial window for every channel.
        if shape.len() == 3 && self.size.len() == 2 {
            let (c, h, w) = (shape[0], shape[1], shape[2]);
            let (crop_h, crop_w) = (crop_sizes[0], crop_sizes[1]);
            let (off_h, off_w) = (offsets[0], offsets[1]);

            let mut cropped = Vec::with_capacity(c * crop_h * crop_w);
            for channel in 0..c {
                for row in off_h..off_h + crop_h {
                    for col in off_w..off_w + crop_w {
                        cropped.push(data[channel * h * w + row * w + col]);
                    }
                }
            }
            return Tensor::from_vec(cropped, &[c, crop_h, crop_w]).unwrap();
        }

        // 4-D (N, C, H, W) input: the same window for every sample/channel.
        if shape.len() == 4 && self.size.len() == 2 {
            let (n, c, h, w) = (shape[0], shape[1], shape[2], shape[3]);
            let (crop_h, crop_w) = (crop_sizes[0], crop_sizes[1]);
            let (off_h, off_w) = (offsets[0], offsets[1]);

            let mut cropped = Vec::with_capacity(n * c * crop_h * crop_w);
            for batch in 0..n {
                for channel in 0..c {
                    for row in off_h..off_h + crop_h {
                        for col in off_w..off_w + crop_w {
                            let idx = batch * c * h * w + channel * h * w + row * w + col;
                            cropped.push(data[idx]);
                        }
                    }
                }
            }
            return Tensor::from_vec(cropped, &[n, c, crop_h, crop_w]).unwrap();
        }

        // Unsupported rank/crop combination: pass the input through.
        input.clone()
    }
}
367
/// Reverses (flips) a tensor along one dimension with a given probability.
pub struct RandomFlip {
    // Dimension index to flip along.
    dim: usize,
    // Chance of performing the flip, clamped to [0, 1].
    probability: f32,
}

impl RandomFlip {
    /// Creates a flip along `dim`; `probability` is clamped into `[0, 1]`.
    #[must_use]
    pub fn new(dim: usize, probability: f32) -> Self {
        Self {
            dim,
            probability: probability.clamp(0.0, 1.0),
        }
    }

    /// 50% flip along dimension 1 (the width axis of an HW layout).
    #[must_use]
    pub fn horizontal() -> Self {
        Self::new(1, 0.5)
    }

    /// 50% flip along dimension 0 (the height axis of an HW layout).
    #[must_use]
    pub fn vertical() -> Self {
        Self::new(0, 0.5)
    }
}
400
401impl Transform for RandomFlip {
402 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
403 let mut rng = rand::thread_rng();
404 if rng.r#gen::<f32>() > self.probability {
405 return input.clone();
406 }
407
408 let shape = input.shape();
409 if self.dim >= shape.len() {
410 return input.clone();
411 }
412
413 let data = input.to_vec();
414 let ndim = shape.len();
415
416 let total = data.len();
419 let mut flipped = vec![0.0f32; total];
420
421 let mut strides = vec![1usize; ndim];
423 for i in (0..ndim - 1).rev() {
424 strides[i] = strides[i + 1] * shape[i + 1];
425 }
426
427 let dim = self.dim;
428 let dim_size = shape[dim];
429 let dim_stride = strides[dim];
430
431 for i in 0..total {
432 let coord_in_dim = (i / dim_stride) % dim_size;
434 let flipped_coord = dim_size - 1 - coord_in_dim;
435 let diff = flipped_coord as isize - coord_in_dim as isize;
437 let src = (i as isize + diff * dim_stride as isize) as usize;
438 flipped[i] = data[src];
439 }
440
441 Tensor::from_vec(flipped, shape).unwrap()
442 }
443}
444
/// Multiplies every element by a constant factor.
pub struct Scale {
    factor: f32,
}

impl Scale {
    /// Creates a scaling transform with the given multiplier.
    #[must_use]
    pub fn new(factor: f32) -> Self {
        Self { factor }
    }
}

impl Transform for Scale {
    /// Returns `input * factor`.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        input.mul_scalar(self.factor)
    }
}
467
/// Clamps every element into the inclusive range `[min, max]`.
pub struct Clamp {
    min: f32,
    max: f32,
}
477
478impl Clamp {
479 #[must_use]
481 pub fn new(min: f32, max: f32) -> Self {
482 Self { min, max }
483 }
484
485 #[must_use]
487 pub fn zero_one() -> Self {
488 Self::new(0.0, 1.0)
489 }
490
491 #[must_use]
493 pub fn symmetric() -> Self {
494 Self::new(-1.0, 1.0)
495 }
496}
497
498impl Transform for Clamp {
499 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
500 let data = input.to_vec();
501 let clamped: Vec<f32> = data.iter().map(|&x| x.clamp(self.min, self.max)).collect();
502 Tensor::from_vec(clamped, input.shape()).unwrap()
503 }
504}
505
/// Flattens any tensor into a rank-1 tensor of the same length.
pub struct Flatten;

impl Flatten {
    /// Creates a flatten transform.
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}

impl Default for Flatten {
    fn default() -> Self {
        Self::new()
    }
}
526
527impl Transform for Flatten {
528 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
529 let data = input.to_vec();
530 Tensor::from_vec(data.clone(), &[data.len()]).unwrap()
531 }
532}
533
/// Reshapes a tensor to a fixed target shape.
pub struct Reshape {
    // Target shape; must have the same element count as the input for the
    // reshape to take effect.
    shape: Vec<usize>,
}

impl Reshape {
    /// Creates a reshape to `shape`.
    #[must_use]
    pub fn new(shape: Vec<usize>) -> Self {
        Self { shape }
    }
}
550
551impl Transform for Reshape {
552 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
553 let data = input.to_vec();
554 let expected_size: usize = self.shape.iter().product();
555
556 if data.len() != expected_size {
557 return input.clone();
559 }
560
561 Tensor::from_vec(data, &self.shape).unwrap()
562 }
563}
564
/// Randomly zeroes elements with probability `p`, scaling survivors by
/// `1 / (1 - p)` (inverted dropout). Active only while in training mode.
pub struct DropoutTransform {
    probability: f32,
    // Interior-mutable flag so a shared transform can be switched between
    // train and eval mode without `&mut self`.
    training: std::sync::atomic::AtomicBool,
}
577
578impl DropoutTransform {
579 #[must_use]
581 pub fn new(probability: f32) -> Self {
582 Self {
583 probability: probability.clamp(0.0, 1.0),
584 training: std::sync::atomic::AtomicBool::new(true),
585 }
586 }
587
588 pub fn set_training(&self, training: bool) {
590 self.training
591 .store(training, std::sync::atomic::Ordering::Relaxed);
592 }
593
594 pub fn is_training(&self) -> bool {
596 self.training.load(std::sync::atomic::Ordering::Relaxed)
597 }
598}
599
600impl Transform for DropoutTransform {
601 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
602 if !self.is_training() || self.probability == 0.0 {
604 return input.clone();
605 }
606
607 let mut rng = rand::thread_rng();
608 let scale = 1.0 / (1.0 - self.probability);
609 let data = input.to_vec();
610
611 let dropped: Vec<f32> = data
612 .iter()
613 .map(|&x| {
614 if rng.r#gen::<f32>() < self.probability {
615 0.0
616 } else {
617 x * scale
618 }
619 })
620 .collect();
621
622 Tensor::from_vec(dropped, input.shape()).unwrap()
623 }
624}
625
/// Wraps an arbitrary closure as a [`Transform`].
pub struct Lambda<F>
where
    F: Fn(&Tensor<f32>) -> Tensor<f32> + Send + Sync,
{
    func: F,
}

impl<F> Lambda<F>
where
    F: Fn(&Tensor<f32>) -> Tensor<f32> + Send + Sync,
{
    /// Creates a transform that applies `func` to each input.
    pub fn new(func: F) -> Self {
        Self { func }
    }
}

impl<F> Transform for Lambda<F>
where
    F: Fn(&Tensor<f32>) -> Tensor<f32> + Send + Sync,
{
    /// Delegates directly to the wrapped closure.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        (self.func)(input)
    }
}
656
#[cfg(test)]
mod tests {
    use super::*;

    // (x - 2.5) / 0.5 over [1, 2, 3, 4].
    #[test]
    fn test_normalize() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let normalize = Normalize::new(2.5, 0.5);

        let output = normalize.apply(&input);
        let expected = [-3.0, -1.0, 1.0, 3.0];

        let result = output.to_vec();
        for (a, b) in result.iter().zip(expected.iter()) {
            assert!((a - b).abs() < 1e-6);
        }
    }

    // CHW input (2 channels of 2x2) with distinct stats per channel.
    #[test]
    fn test_normalize_per_channel() {
        let input =
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0], &[2, 2, 2]).unwrap();
        let normalize = Normalize::per_channel(vec![0.0, 10.0], vec![1.0, 10.0]);

        let output = normalize.apply(&input);
        let result = output.to_vec();
        // Channel 0: (x - 0) / 1 — unchanged.
        assert!((result[0] - 1.0).abs() < 1e-6);
        assert!((result[3] - 4.0).abs() < 1e-6);
        // Channel 1: (x - 10) / 10.
        assert!((result[4] - 0.0).abs() < 1e-6);
        assert!((result[5] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_scale() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let scale = Scale::new(2.0);

        let output = scale.apply(&input);
        assert_eq!(output.to_vec(), vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_clamp() {
        let input = Tensor::from_vec(vec![-1.0, 0.5, 2.0], &[3]).unwrap();
        let clamp = Clamp::zero_one();

        let output = clamp.apply(&input);
        assert_eq!(output.to_vec(), vec![0.0, 0.5, 1.0]);
    }

    #[test]
    fn test_flatten() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flatten = Flatten::new();

        let output = flatten.apply(&input);
        assert_eq!(output.shape(), &[4]);
        assert_eq!(output.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_reshape() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6]).unwrap();
        let reshape = Reshape::new(vec![2, 3]);

        let output = reshape.apply(&input);
        assert_eq!(output.shape(), &[2, 3]);
    }

    // Identity normalize followed by x2 scale.
    #[test]
    fn test_compose() {
        let normalize = Normalize::new(0.0, 1.0);
        let scale = Scale::new(2.0);

        let compose = Compose::new(vec![Box::new(normalize), Box::new(scale)]);

        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let output = compose.apply(&input);

        assert_eq!(output.to_vec(), vec![2.0, 4.0, 6.0]);
    }

    // Same pipeline built with the fluent `add` API.
    #[test]
    fn test_compose_builder() {
        let compose = Compose::empty()
            .add(Normalize::new(0.0, 1.0))
            .add(Scale::new(2.0));

        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let output = compose.apply(&input);

        assert_eq!(output.to_vec(), vec![2.0, 4.0, 6.0]);
    }

    // std == 0.0 must be an exact no-op.
    #[test]
    fn test_random_noise() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let noise = RandomNoise::new(0.0);

        let output = noise.apply(&input);
        assert_eq!(output.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    // probability 1.0 makes the flip deterministic in all flip tests below.
    #[test]
    fn test_random_flip_1d() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let flip = RandomFlip::new(0, 1.0);
        let output = flip.apply(&input);
        assert_eq!(output.to_vec(), vec![4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_random_flip_2d_horizontal() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flip = RandomFlip::new(1, 1.0);
        let output = flip.apply(&input);
        assert_eq!(output.to_vec(), vec![2.0, 1.0, 4.0, 3.0]);
    }

    #[test]
    fn test_random_flip_2d_vertical() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flip = RandomFlip::new(0, 1.0);
        let output = flip.apply(&input);
        assert_eq!(output.to_vec(), vec![3.0, 4.0, 1.0, 2.0]);
    }

    #[test]
    fn test_random_flip_3d() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 2, 2]).unwrap();
        let flip = RandomFlip::new(2, 1.0);
        let output = flip.apply(&input);
        assert_eq!(output.to_vec(), vec![2.0, 1.0, 4.0, 3.0]);
        assert_eq!(output.shape(), &[1, 2, 2]);
    }

    #[test]
    fn test_random_flip_4d() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2]).unwrap();
        let flip = RandomFlip::new(2, 1.0);
        let output = flip.apply(&input);
        assert_eq!(output.to_vec(), vec![3.0, 4.0, 1.0, 2.0]);
        assert_eq!(output.shape(), &[1, 1, 2, 2]);
    }

    // Dropout must be a no-op once switched to eval mode.
    #[test]
    fn test_dropout_eval_mode() {
        let input = Tensor::from_vec(vec![1.0; 100], &[100]).unwrap();
        let dropout = DropoutTransform::new(0.5);

        let output_train = dropout.apply(&input);
        let zeros_train = output_train.to_vec().iter().filter(|&&x| x == 0.0).count();
        assert!(zeros_train > 0, "Training mode should drop elements");

        dropout.set_training(false);
        let output_eval = dropout.apply(&input);
        assert_eq!(output_eval.to_vec(), vec![1.0; 100]);
    }

    // Statistical check: ~half dropped, survivors scaled by 1/(1-p) = 2.
    #[test]
    fn test_dropout_transform() {
        let input = Tensor::from_vec(vec![1.0; 1000], &[1000]).unwrap();
        let dropout = DropoutTransform::new(0.5);

        let output = dropout.apply(&input);
        let output_vec = output.to_vec();

        let zeros = output_vec.iter().filter(|&&x| x == 0.0).count();
        assert!(
            zeros > 300 && zeros < 700,
            "Expected ~500 zeros, got {zeros}"
        );

        let nonzeros: Vec<f32> = output_vec.iter().filter(|&&x| x != 0.0).copied().collect();
        for x in nonzeros {
            assert!((x - 2.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_lambda() {
        let lambda = Lambda::new(|t: &Tensor<f32>| t.mul_scalar(3.0));

        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let output = lambda.apply(&input);

        assert_eq!(output.to_vec(), vec![3.0, 6.0, 9.0]);
    }

    #[test]
    fn test_to_tensor() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let to_tensor = ToTensor::new();

        let output = to_tensor.apply(&input);
        assert_eq!(output.to_vec(), input.to_vec());
    }

    // Named constructors must carry the documented statistics.
    #[test]
    fn test_normalize_variants() {
        let standard = Normalize::standard();
        assert_eq!(standard.mean, vec![0.0]);
        assert_eq!(standard.std, vec![1.0]);

        let zero_centered = Normalize::zero_centered();
        assert_eq!(zero_centered.mean, vec![0.5]);
        assert_eq!(zero_centered.std, vec![0.5]);
    }

    #[test]
    fn test_random_crop_1d() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap();
        let crop = RandomCrop::new(vec![3]);

        let output = crop.apply(&input);
        assert_eq!(output.shape(), &[3]);
    }

    #[test]
    fn test_random_crop_2d() {
        let input = Tensor::from_vec((1..=16).map(|x| x as f32).collect(), &[4, 4]).unwrap();
        let crop = RandomCrop::new_2d(2, 2);

        let output = crop.apply(&input);
        assert_eq!(output.shape(), &[2, 2]);
        let vals = output.to_vec();
        assert_eq!(vals.len(), 4);
    }

    // 2-D crop of a CHW tensor keeps the channel dimension intact.
    #[test]
    fn test_random_crop_3d() {
        let input = Tensor::from_vec((1..=32).map(|x| x as f32).collect(), &[2, 4, 4]).unwrap();
        let crop = RandomCrop::new_2d(2, 2);

        let output = crop.apply(&input);
        assert_eq!(output.shape(), &[2, 2, 2]);
    }
}