1use axonml_tensor::Tensor;
18use rand::Rng;
19
/// A composable tensor-to-tensor operation used in preprocessing pipelines.
///
/// Implementors must be `Send + Sync` so a pipeline can be shared across
/// threads (e.g. by a parallel data loader).
pub trait Transform: Send + Sync {
    /// Applies the transform to `input` and returns a new tensor;
    /// the input is never mutated.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32>;
}
29
/// Chains multiple boxed [`Transform`]s and applies them in sequence.
pub struct Compose {
    // Applied in insertion order.
    transforms: Vec<Box<dyn Transform>>,
}
38
39impl Compose {
40 #[must_use]
42 pub fn new(transforms: Vec<Box<dyn Transform>>) -> Self {
43 Self { transforms }
44 }
45
46 #[must_use]
48 pub fn empty() -> Self {
49 Self {
50 transforms: Vec::new(),
51 }
52 }
53
54 pub fn add<T: Transform + 'static>(mut self, transform: T) -> Self {
56 self.transforms.push(Box::new(transform));
57 self
58 }
59}
60
61impl Transform for Compose {
62 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
63 let mut result = input.clone();
64 for transform in &self.transforms {
65 result = transform.apply(&result);
66 }
67 result
68 }
69}
70
/// No-op transform: returns the input unchanged (a pipeline placeholder).
pub struct ToTensor;
77
78impl ToTensor {
79 #[must_use]
81 pub fn new() -> Self {
82 Self
83 }
84}
85
86impl Default for ToTensor {
87 fn default() -> Self {
88 Self::new()
89 }
90}
91
92impl Transform for ToTensor {
93 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
94 input.clone()
95 }
96}
97
/// Normalizes values as `(x - mean) / std`, globally or per channel.
pub struct Normalize {
    // One entry = global normalization; multiple entries = per-channel,
    // matched against the channel dimension in `apply`.
    mean: Vec<f32>,
    std: Vec<f32>,
}
112
113impl Normalize {
114 #[must_use]
116 pub fn new(mean: f32, std: f32) -> Self {
117 Self {
118 mean: vec![mean],
119 std: vec![std],
120 }
121 }
122
123 #[must_use]
128 pub fn per_channel(mean: Vec<f32>, std: Vec<f32>) -> Self {
129 assert_eq!(mean.len(), std.len(), "mean and std must have same length");
130 Self { mean, std }
131 }
132
133 #[must_use]
135 pub fn standard() -> Self {
136 Self::new(0.0, 1.0)
137 }
138
139 #[must_use]
141 pub fn zero_centered() -> Self {
142 Self::new(0.5, 0.5)
143 }
144
145 #[must_use]
147 pub fn imagenet() -> Self {
148 Self::per_channel(
149 vec![0.485, 0.456, 0.406],
150 vec![0.229, 0.224, 0.225],
151 )
152 }
153}
154
impl Transform for Normalize {
    /// Applies `(x - mean) / std` element-wise.
    ///
    /// Channel matching: a single mean/std pair is applied to every
    /// element; with `C` pairs, a 3-D input shaped `[C, H, W]` or a 4-D
    /// input shaped `[N, C, H, W]` is normalized per channel. Any other
    /// shape falls back to using only `mean[0]` / `std[0]`.
    ///
    /// NOTE(review): the fallback silently ignores pairs beyond index 0
    /// when the shape does not match the channel count — confirm this is
    /// intended rather than an error condition.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        let shape = input.shape();
        let mut data = input.to_vec();

        if self.mean.len() == 1 {
            // Global normalization: one pair for every element.
            let m = self.mean[0];
            let s = self.std[0];
            for x in &mut data {
                *x = (*x - m) / s;
            }
        } else {
            let num_channels = self.mean.len();

            if shape.len() == 3 && shape[0] == num_channels {
                // CHW layout: channel c occupies a contiguous run of
                // `spatial` elements starting at c * spatial.
                let spatial = shape[1] * shape[2];
                for c in 0..num_channels {
                    let offset = c * spatial;
                    let m = self.mean[c];
                    let s = self.std[c];
                    for i in 0..spatial {
                        data[offset + i] = (data[offset + i] - m) / s;
                    }
                }
            } else if shape.len() == 4 && shape[1] == num_channels {
                // NCHW layout: same per-channel runs, repeated per sample.
                let spatial = shape[2] * shape[3];
                let sample_size = num_channels * spatial;
                for n in 0..shape[0] {
                    for c in 0..num_channels {
                        let offset = n * sample_size + c * spatial;
                        let m = self.mean[c];
                        let s = self.std[c];
                        for i in 0..spatial {
                            data[offset + i] = (data[offset + i] - m) / s;
                        }
                    }
                }
            } else {
                // Shape does not match the channel count: fall back to the
                // first pair only (see NOTE above).
                let m = self.mean[0];
                let s = self.std[0];
                for x in &mut data {
                    *x = (*x - m) / s;
                }
            }
        }

        // Length and shape are unchanged, so reassembly cannot fail.
        Tensor::from_vec(data, shape).unwrap()
    }
}
209
/// Adds zero-mean Gaussian noise with the given standard deviation.
pub struct RandomNoise {
    // Standard deviation of the added noise; 0.0 disables the transform.
    std: f32,
}
218
219impl RandomNoise {
220 #[must_use]
222 pub fn new(std: f32) -> Self {
223 Self { std }
224 }
225}
226
227impl Transform for RandomNoise {
228 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
229 if self.std == 0.0 {
230 return input.clone();
231 }
232
233 let mut rng = rand::thread_rng();
234 let data = input.to_vec();
235 let noisy: Vec<f32> = data
236 .iter()
237 .map(|&x| {
238 let u1: f32 = rng.r#gen();
240 let u2: f32 = rng.r#gen();
241 let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos();
242 x + z * self.std
243 })
244 .collect();
245 Tensor::from_vec(noisy, input.shape()).unwrap()
246 }
247}
248
/// Randomly crops the trailing spatial dimensions to a target size.
pub struct RandomCrop {
    // Target sizes for the trailing dimensions, e.g. [height, width].
    size: Vec<usize>,
}
257
258impl RandomCrop {
259 #[must_use]
261 pub fn new(size: Vec<usize>) -> Self {
262 Self { size }
263 }
264
265 #[must_use]
267 pub fn new_2d(height: usize, width: usize) -> Self {
268 Self::new(vec![height, width])
269 }
270}
271
impl Transform for RandomCrop {
    /// Crops the trailing `self.size.len()` dimensions at a random offset.
    ///
    /// Supported combinations: 1-D/1-D, 2-D/2-D, 3-D input with a 2-D crop
    /// (leading channel dim preserved), and 4-D input with a 2-D crop
    /// (batch and channel dims preserved). Any other combination — or an
    /// input with fewer dims than the crop — is returned unchanged.
    /// Dimensions already at or below the target size are kept whole
    /// (offset 0, size capped at the input extent).
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        let shape = input.shape();

        // Not enough dimensions to crop: pass through.
        if shape.len() < self.size.len() {
            return input.clone();
        }

        // The crop applies to the trailing dims; leading dims are kept.
        let spatial_start = shape.len() - self.size.len();
        let mut rng = rand::thread_rng();

        // Pick a random start offset for each cropped dimension.
        let mut offsets = Vec::with_capacity(self.size.len());
        for (i, &target_dim) in self.size.iter().enumerate() {
            let input_dim = shape[spatial_start + i];
            if input_dim <= target_dim {
                offsets.push(0);
            } else {
                offsets.push(rng.gen_range(0..=input_dim - target_dim));
            }
        }

        // Effective crop sizes, capped at the input extent.
        let crop_sizes: Vec<usize> = self
            .size
            .iter()
            .enumerate()
            .map(|(i, &s)| s.min(shape[spatial_start + i]))
            .collect();

        let data = input.to_vec();

        // 1-D input, 1-D crop: a plain slice copy.
        if shape.len() == 1 && self.size.len() == 1 {
            let start = offsets[0];
            let end = start + crop_sizes[0];
            let cropped = data[start..end].to_vec();
            let len = cropped.len();
            return Tensor::from_vec(cropped, &[len]).unwrap();
        }

        // 2-D input, 2-D crop: copy the window row by row.
        if shape.len() == 2 && self.size.len() == 2 {
            let (_h, w) = (shape[0], shape[1]);
            let (crop_h, crop_w) = (crop_sizes[0], crop_sizes[1]);
            let (off_h, off_w) = (offsets[0], offsets[1]);

            let mut cropped = Vec::with_capacity(crop_h * crop_w);
            for row in off_h..off_h + crop_h {
                for col in off_w..off_w + crop_w {
                    cropped.push(data[row * w + col]);
                }
            }
            return Tensor::from_vec(cropped, &[crop_h, crop_w]).unwrap();
        }

        // 3-D input (C, H, W), 2-D crop: same window in every channel.
        if shape.len() == 3 && self.size.len() == 2 {
            let (c, h, w) = (shape[0], shape[1], shape[2]);
            let (crop_h, crop_w) = (crop_sizes[0], crop_sizes[1]);
            let (off_h, off_w) = (offsets[0], offsets[1]);

            let mut cropped = Vec::with_capacity(c * crop_h * crop_w);
            for channel in 0..c {
                for row in off_h..off_h + crop_h {
                    for col in off_w..off_w + crop_w {
                        cropped.push(data[channel * h * w + row * w + col]);
                    }
                }
            }
            return Tensor::from_vec(cropped, &[c, crop_h, crop_w]).unwrap();
        }

        // 4-D input (N, C, H, W), 2-D crop: same window for every sample
        // and channel.
        if shape.len() == 4 && self.size.len() == 2 {
            let (n, c, h, w) = (shape[0], shape[1], shape[2], shape[3]);
            let (crop_h, crop_w) = (crop_sizes[0], crop_sizes[1]);
            let (off_h, off_w) = (offsets[0], offsets[1]);

            let mut cropped = Vec::with_capacity(n * c * crop_h * crop_w);
            for batch in 0..n {
                for channel in 0..c {
                    for row in off_h..off_h + crop_h {
                        for col in off_w..off_w + crop_w {
                            let idx = batch * c * h * w + channel * h * w + row * w + col;
                            cropped.push(data[idx]);
                        }
                    }
                }
            }
            return Tensor::from_vec(cropped, &[n, c, crop_h, crop_w]).unwrap();
        }

        // Unsupported rank/crop combination: pass through unchanged.
        input.clone()
    }
}
370
/// Reverses one dimension of the tensor with a given probability.
pub struct RandomFlip {
    // Index of the dimension to reverse.
    dim: usize,
    // Flip probability, clamped to [0, 1] at construction.
    probability: f32,
}
380
381impl RandomFlip {
382 #[must_use]
384 pub fn new(dim: usize, probability: f32) -> Self {
385 Self {
386 dim,
387 probability: probability.clamp(0.0, 1.0),
388 }
389 }
390
391 #[must_use]
393 pub fn horizontal() -> Self {
394 Self::new(1, 0.5)
395 }
396
397 #[must_use]
399 pub fn vertical() -> Self {
400 Self::new(0, 0.5)
401 }
402}
403
impl Transform for RandomFlip {
    /// With probability `self.probability`, reverses dimension `self.dim`;
    /// otherwise returns the input unchanged. An out-of-range `dim` is a
    /// no-op.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        let mut rng = rand::thread_rng();
        // r#gen() is in [0, 1), so probability 1.0 always flips and 0.0
        // never does.
        if rng.r#gen::<f32>() > self.probability {
            return input.clone();
        }

        let shape = input.shape();
        if self.dim >= shape.len() {
            return input.clone();
        }

        let data = input.to_vec();
        let ndim = shape.len();

        let total = data.len();
        let mut flipped = vec![0.0f32; total];

        // Row-major strides: strides[i] = product of shape[i+1..].
        let mut strides = vec![1usize; ndim];
        for i in (0..ndim - 1).rev() {
            strides[i] = strides[i + 1] * shape[i + 1];
        }

        let dim = self.dim;
        let dim_size = shape[dim];
        let dim_stride = strides[dim];

        for i in 0..total {
            // Coordinate of flat index i along the flipped dimension.
            let coord_in_dim = (i / dim_stride) % dim_size;
            let flipped_coord = dim_size - 1 - coord_in_dim;
            // Moving along one dimension shifts the flat index by
            // (coordinate delta) * stride; all other coordinates match.
            let diff = flipped_coord as isize - coord_in_dim as isize;
            let src = (i as isize + diff * dim_stride as isize) as usize;
            flipped[i] = data[src];
        }

        Tensor::from_vec(flipped, shape).unwrap()
    }
}
447
/// Multiplies every element by a constant factor.
pub struct Scale {
    // Multiplicative factor applied in `apply`.
    factor: f32,
}
456
457impl Scale {
458 #[must_use]
460 pub fn new(factor: f32) -> Self {
461 Self { factor }
462 }
463}
464
465impl Transform for Scale {
466 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
467 input.mul_scalar(self.factor)
468 }
469}
470
/// Clamps every element into the inclusive range `[min, max]`.
pub struct Clamp {
    min: f32,
    max: f32,
}
480
481impl Clamp {
482 #[must_use]
484 pub fn new(min: f32, max: f32) -> Self {
485 Self { min, max }
486 }
487
488 #[must_use]
490 pub fn zero_one() -> Self {
491 Self::new(0.0, 1.0)
492 }
493
494 #[must_use]
496 pub fn symmetric() -> Self {
497 Self::new(-1.0, 1.0)
498 }
499}
500
501impl Transform for Clamp {
502 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
503 let data = input.to_vec();
504 let clamped: Vec<f32> = data.iter().map(|&x| x.clamp(self.min, self.max)).collect();
505 Tensor::from_vec(clamped, input.shape()).unwrap()
506 }
507}
508
/// Flattens any tensor into a rank-1 tensor with the same element count.
pub struct Flatten;
515
516impl Flatten {
517 #[must_use]
519 pub fn new() -> Self {
520 Self
521 }
522}
523
524impl Default for Flatten {
525 fn default() -> Self {
526 Self::new()
527 }
528}
529
530impl Transform for Flatten {
531 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
532 let data = input.to_vec();
533 Tensor::from_vec(data.clone(), &[data.len()]).unwrap()
534 }
535}
536
/// Reshapes tensors to a fixed target shape when element counts match.
pub struct Reshape {
    // Target shape; inputs whose element count differs pass through
    // unchanged (see `apply`).
    shape: Vec<usize>,
}
545
546impl Reshape {
547 #[must_use]
549 pub fn new(shape: Vec<usize>) -> Self {
550 Self { shape }
551 }
552}
553
554impl Transform for Reshape {
555 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
556 let data = input.to_vec();
557 let expected_size: usize = self.shape.iter().product();
558
559 if data.len() != expected_size {
560 return input.clone();
562 }
563
564 Tensor::from_vec(data, &self.shape).unwrap()
565 }
566}
567
/// Inverted dropout: zeroes elements with probability `p` and scales
/// survivors by `1 / (1 - p)` while in training mode; identity otherwise.
pub struct DropoutTransform {
    // Drop probability, clamped to [0, 1] at construction.
    probability: f32,
    // Atomic so mode can be toggled through a shared `&self` — the
    // `Transform` trait only hands out shared references.
    training: std::sync::atomic::AtomicBool,
}
580
581impl DropoutTransform {
582 #[must_use]
584 pub fn new(probability: f32) -> Self {
585 Self {
586 probability: probability.clamp(0.0, 1.0),
587 training: std::sync::atomic::AtomicBool::new(true),
588 }
589 }
590
591 pub fn set_training(&self, training: bool) {
593 self.training.store(training, std::sync::atomic::Ordering::Relaxed);
594 }
595
596 pub fn is_training(&self) -> bool {
598 self.training.load(std::sync::atomic::Ordering::Relaxed)
599 }
600}
601
602impl Transform for DropoutTransform {
603 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
604 if !self.is_training() || self.probability == 0.0 {
606 return input.clone();
607 }
608
609 let mut rng = rand::thread_rng();
610 let scale = 1.0 / (1.0 - self.probability);
611 let data = input.to_vec();
612
613 let dropped: Vec<f32> = data
614 .iter()
615 .map(|&x| {
616 if rng.r#gen::<f32>() < self.probability {
617 0.0
618 } else {
619 x * scale
620 }
621 })
622 .collect();
623
624 Tensor::from_vec(dropped, input.shape()).unwrap()
625 }
626}
627
/// Wraps an arbitrary closure as a [`Transform`].
pub struct Lambda<F>
where
    F: Fn(&Tensor<f32>) -> Tensor<f32> + Send + Sync,
{
    // The wrapped tensor-to-tensor function.
    func: F,
}
639
640impl<F> Lambda<F>
641where
642 F: Fn(&Tensor<f32>) -> Tensor<f32> + Send + Sync,
643{
644 pub fn new(func: F) -> Self {
646 Self { func }
647 }
648}
649
650impl<F> Transform for Lambda<F>
651where
652 F: Fn(&Tensor<f32>) -> Tensor<f32> + Send + Sync,
653{
654 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
655 (self.func)(input)
656 }
657}
658
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_normalize() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let normalize = Normalize::new(2.5, 0.5);

        let output = normalize.apply(&input);
        let expected = [-3.0, -1.0, 1.0, 3.0];

        let result = output.to_vec();
        for (a, b) in result.iter().zip(expected.iter()) {
            assert!((a - b).abs() < 1e-6);
        }
    }

    #[test]
    fn test_normalize_per_channel() {
        let input = Tensor::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0],
            &[2, 2, 2],
        )
        .unwrap();
        let normalize = Normalize::per_channel(vec![0.0, 10.0], vec![1.0, 10.0]);

        let output = normalize.apply(&input);
        let result = output.to_vec();
        // Channel 0: (x - 0) / 1 leaves values unchanged.
        assert!((result[0] - 1.0).abs() < 1e-6);
        assert!((result[3] - 4.0).abs() < 1e-6);
        // Channel 1: (x - 10) / 10.
        assert!((result[4] - 0.0).abs() < 1e-6);
        assert!((result[5] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_scale() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let scale = Scale::new(2.0);

        let output = scale.apply(&input);
        assert_eq!(output.to_vec(), vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_clamp() {
        let input = Tensor::from_vec(vec![-1.0, 0.5, 2.0], &[3]).unwrap();
        let clamp = Clamp::zero_one();

        let output = clamp.apply(&input);
        assert_eq!(output.to_vec(), vec![0.0, 0.5, 1.0]);
    }

    #[test]
    fn test_flatten() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flatten = Flatten::new();

        let output = flatten.apply(&input);
        assert_eq!(output.shape(), &[4]);
        assert_eq!(output.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_reshape() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6]).unwrap();
        let reshape = Reshape::new(vec![2, 3]);

        let output = reshape.apply(&input);
        assert_eq!(output.shape(), &[2, 3]);
    }

    #[test]
    fn test_compose() {
        let normalize = Normalize::new(0.0, 1.0);
        let scale = Scale::new(2.0);

        let compose = Compose::new(vec![Box::new(normalize), Box::new(scale)]);

        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let output = compose.apply(&input);

        assert_eq!(output.to_vec(), vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_compose_builder() {
        let compose = Compose::empty()
            .add(Normalize::new(0.0, 1.0))
            .add(Scale::new(2.0));

        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let output = compose.apply(&input);

        assert_eq!(output.to_vec(), vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_random_noise() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        // std == 0 must be an exact identity (no RNG involvement).
        let noise = RandomNoise::new(0.0);

        let output = noise.apply(&input);
        assert_eq!(output.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_random_flip_1d() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        // Probability 1.0 => always flips.
        let flip = RandomFlip::new(0, 1.0);
        let output = flip.apply(&input);
        assert_eq!(output.to_vec(), vec![4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_random_flip_2d_horizontal() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        // Probability 1.0 => always flips.
        let flip = RandomFlip::new(1, 1.0);
        let output = flip.apply(&input);
        assert_eq!(output.to_vec(), vec![2.0, 1.0, 4.0, 3.0]);
    }

    #[test]
    fn test_random_flip_2d_vertical() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        // Probability 1.0 => always flips.
        let flip = RandomFlip::new(0, 1.0);
        let output = flip.apply(&input);
        assert_eq!(output.to_vec(), vec![3.0, 4.0, 1.0, 2.0]);
    }

    #[test]
    fn test_random_flip_3d() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 2, 2]).unwrap();
        // Probability 1.0 => always flips.
        let flip = RandomFlip::new(2, 1.0);
        let output = flip.apply(&input);
        assert_eq!(output.to_vec(), vec![2.0, 1.0, 4.0, 3.0]);
        assert_eq!(output.shape(), &[1, 2, 2]);
    }

    #[test]
    fn test_random_flip_4d() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2]).unwrap();
        // Probability 1.0 => always flips.
        let flip = RandomFlip::new(2, 1.0);
        let output = flip.apply(&input);
        assert_eq!(output.to_vec(), vec![3.0, 4.0, 1.0, 2.0]);
        assert_eq!(output.shape(), &[1, 1, 2, 2]);
    }

    #[test]
    fn test_dropout_eval_mode() {
        let input = Tensor::from_vec(vec![1.0; 100], &[100]).unwrap();
        let dropout = DropoutTransform::new(0.5);

        let output_train = dropout.apply(&input);
        let zeros_train = output_train.to_vec().iter().filter(|&&x| x == 0.0).count();
        assert!(zeros_train > 0, "Training mode should drop elements");

        // Eval mode must be an exact identity.
        dropout.set_training(false);
        let output_eval = dropout.apply(&input);
        assert_eq!(output_eval.to_vec(), vec![1.0; 100]);
    }

    #[test]
    fn test_dropout_transform() {
        let input = Tensor::from_vec(vec![1.0; 1000], &[1000]).unwrap();
        let dropout = DropoutTransform::new(0.5);

        let output = dropout.apply(&input);
        let output_vec = output.to_vec();

        // Roughly half the elements should be dropped (loose bounds).
        let zeros = output_vec.iter().filter(|&&x| x == 0.0).count();
        assert!(
            zeros > 300 && zeros < 700,
            "Expected ~500 zeros, got {zeros}"
        );

        // Survivors are scaled by 1 / (1 - 0.5) = 2.
        let nonzeros: Vec<f32> = output_vec.iter().filter(|&&x| x != 0.0).copied().collect();
        for x in nonzeros {
            assert!((x - 2.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_lambda() {
        let lambda = Lambda::new(|t: &Tensor<f32>| t.mul_scalar(3.0));

        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let output = lambda.apply(&input);

        assert_eq!(output.to_vec(), vec![3.0, 6.0, 9.0]);
    }

    #[test]
    fn test_to_tensor() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let to_tensor = ToTensor::new();

        let output = to_tensor.apply(&input);
        assert_eq!(output.to_vec(), input.to_vec());
    }

    #[test]
    fn test_normalize_variants() {
        let standard = Normalize::standard();
        assert_eq!(standard.mean, vec![0.0]);
        assert_eq!(standard.std, vec![1.0]);

        let zero_centered = Normalize::zero_centered();
        assert_eq!(zero_centered.mean, vec![0.5]);
        assert_eq!(zero_centered.std, vec![0.5]);
    }

    #[test]
    fn test_random_crop_1d() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap();
        let crop = RandomCrop::new(vec![3]);

        let output = crop.apply(&input);
        assert_eq!(output.shape(), &[3]);
    }

    #[test]
    fn test_random_crop_2d() {
        let input = Tensor::from_vec((1..=16).map(|x| x as f32).collect(), &[4, 4]).unwrap();
        let crop = RandomCrop::new_2d(2, 2);

        let output = crop.apply(&input);
        assert_eq!(output.shape(), &[2, 2]);
        let vals = output.to_vec();
        assert_eq!(vals.len(), 4);
    }

    #[test]
    fn test_random_crop_3d() {
        let input = Tensor::from_vec((1..=32).map(|x| x as f32).collect(), &[2, 4, 4]).unwrap();
        let crop = RandomCrop::new_2d(2, 2);

        let output = crop.apply(&input);
        // Channel dimension is preserved; spatial dims are cropped.
        assert_eq!(output.shape(), &[2, 2, 2]);
    }
}