1use axonml_tensor::Tensor;
9use rand::Rng;
10
/// A transformation applied to a tensor, producing a new tensor.
///
/// Implementors must be `Send + Sync` so transforms can be shared across
/// threads (e.g. in parallel data-loading pipelines).
pub trait Transform: Send + Sync {
    /// Applies the transform to `input` and returns the transformed tensor.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32>;
}
20
21pub struct Compose {
27 transforms: Vec<Box<dyn Transform>>,
28}
29
30impl Compose {
31 #[must_use]
33 pub fn new(transforms: Vec<Box<dyn Transform>>) -> Self {
34 Self { transforms }
35 }
36
37 #[must_use]
39 pub fn empty() -> Self {
40 Self {
41 transforms: Vec::new(),
42 }
43 }
44
45 pub fn add<T: Transform + 'static>(mut self, transform: T) -> Self {
47 self.transforms.push(Box::new(transform));
48 self
49 }
50}
51
52impl Transform for Compose {
53 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
54 let mut result = input.clone();
55 for transform in &self.transforms {
56 result = transform.apply(&result);
57 }
58 result
59 }
60}
61
/// Identity transform: returns an unmodified copy of the input tensor.
///
/// Kept for API parity with image-style pipelines where a "to tensor"
/// conversion step is conventional; here the data is already a tensor.
pub struct ToTensor;

impl ToTensor {
    /// Creates a new `ToTensor` transform.
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}

impl Default for ToTensor {
    fn default() -> Self {
        Self::new()
    }
}

impl Transform for ToTensor {
    /// Returns a clone of `input` unchanged.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        input.clone()
    }
}
88
/// Normalizes tensor values via `(x - mean) / std`.
pub struct Normalize {
    // Value subtracted from every element before scaling.
    mean: f32,
    // Divisor applied after centering. NOTE(review): a zero `std` produces
    // non-finite outputs (inf/NaN) — construction does not validate this.
    std: f32,
}

impl Normalize {
    /// Creates a normalizer with the given mean and standard deviation.
    #[must_use]
    pub fn new(mean: f32, std: f32) -> Self {
        Self { mean, std }
    }

    /// Identity normalization (mean 0, std 1): leaves values unchanged.
    #[must_use]
    pub fn standard() -> Self {
        Self::new(0.0, 1.0)
    }

    /// Maps values from `[0, 1]` to `[-1, 1]` (mean 0.5, std 0.5).
    #[must_use]
    pub fn zero_centered() -> Self {
        Self::new(0.5, 0.5)
    }
}
118
119impl Transform for Normalize {
120 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
121 let data = input.to_vec();
122 let normalized: Vec<f32> = data.iter().map(|&x| (x - self.mean) / self.std).collect();
123 Tensor::from_vec(normalized, input.shape()).unwrap()
124 }
125}
126
/// Adds zero-mean Gaussian noise to every element.
pub struct RandomNoise {
    // Standard deviation of the noise; 0.0 makes the transform a no-op.
    std: f32,
}

impl RandomNoise {
    /// Creates a noise transform with the given standard deviation.
    #[must_use]
    pub fn new(std: f32) -> Self {
        Self { std }
    }
}
143
144impl Transform for RandomNoise {
145 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
146 if self.std == 0.0 {
147 return input.clone();
148 }
149
150 let mut rng = rand::thread_rng();
151 let data = input.to_vec();
152 let noisy: Vec<f32> = data
153 .iter()
154 .map(|&x| {
155 let u1: f32 = rng.gen();
157 let u2: f32 = rng.gen();
158 let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos();
159 x + z * self.std
160 })
161 .collect();
162 Tensor::from_vec(noisy, input.shape()).unwrap()
163 }
164}
165
/// Randomly crops the trailing (spatial) dimensions of a tensor to a
/// target size; leading dimensions (batch, channel) are preserved.
pub struct RandomCrop {
    // Target sizes for the trailing dimensions.
    size: Vec<usize>,
}

impl RandomCrop {
    /// Creates a crop with an arbitrary number of trailing target dims.
    #[must_use]
    pub fn new(size: Vec<usize>) -> Self {
        Self { size }
    }

    /// Convenience constructor for a 2-D (height, width) crop.
    #[must_use]
    pub fn new_2d(height: usize, width: usize) -> Self {
        Self::new(vec![height, width])
    }
}
188
impl Transform for RandomCrop {
    /// Crops the trailing `self.size.len()` dimensions to the target size
    /// at a uniformly random offset.
    ///
    /// Supported combinations: 1-D input with a 1-D crop, and 2-D (H, W),
    /// 3-D (C, H, W), 4-D (N, C, H, W) inputs with a 2-D crop. Any other
    /// combination — or an input with fewer dims than the crop — is
    /// returned unchanged. A dimension already <= its target is kept whole
    /// (offset 0, crop size clamped to the input size).
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        let shape = input.shape();

        // Not enough dimensions to crop: pass through.
        if shape.len() < self.size.len() {
            return input.clone();
        }

        // The crop applies to the trailing dims; everything before
        // `spatial_start` is preserved in full.
        let spatial_start = shape.len() - self.size.len();
        let mut rng = rand::thread_rng();

        // Pick a random start offset per cropped dimension (0 when the
        // input is not larger than the target).
        let mut offsets = Vec::with_capacity(self.size.len());
        for (i, &target_dim) in self.size.iter().enumerate() {
            let input_dim = shape[spatial_start + i];
            if input_dim <= target_dim {
                offsets.push(0);
            } else {
                offsets.push(rng.gen_range(0..=input_dim - target_dim));
            }
        }

        // Effective crop size: never larger than the input dimension.
        let crop_sizes: Vec<usize> = self
            .size
            .iter()
            .enumerate()
            .map(|(i, &s)| s.min(shape[spatial_start + i]))
            .collect();

        let data = input.to_vec();

        // 1-D input, 1-D crop: a plain contiguous slice.
        if shape.len() == 1 && self.size.len() == 1 {
            let start = offsets[0];
            let end = start + crop_sizes[0];
            let cropped = data[start..end].to_vec();
            let len = cropped.len();
            return Tensor::from_vec(cropped, &[len]).unwrap();
        }

        // 2-D (H, W) input, 2-D crop.
        if shape.len() == 2 && self.size.len() == 2 {
            let (_h, w) = (shape[0], shape[1]);
            let (crop_h, crop_w) = (crop_sizes[0], crop_sizes[1]);
            let (off_h, off_w) = (offsets[0], offsets[1]);

            let mut cropped = Vec::with_capacity(crop_h * crop_w);
            for row in off_h..off_h + crop_h {
                for col in off_w..off_w + crop_w {
                    cropped.push(data[row * w + col]);
                }
            }
            return Tensor::from_vec(cropped, &[crop_h, crop_w]).unwrap();
        }

        // 3-D (C, H, W) input, 2-D crop: same window for every channel.
        if shape.len() == 3 && self.size.len() == 2 {
            let (c, h, w) = (shape[0], shape[1], shape[2]);
            let (crop_h, crop_w) = (crop_sizes[0], crop_sizes[1]);
            let (off_h, off_w) = (offsets[0], offsets[1]);

            let mut cropped = Vec::with_capacity(c * crop_h * crop_w);
            for channel in 0..c {
                for row in off_h..off_h + crop_h {
                    for col in off_w..off_w + crop_w {
                        cropped.push(data[channel * h * w + row * w + col]);
                    }
                }
            }
            return Tensor::from_vec(cropped, &[c, crop_h, crop_w]).unwrap();
        }

        // 4-D (N, C, H, W) input, 2-D crop: same window for every batch
        // element and channel.
        if shape.len() == 4 && self.size.len() == 2 {
            let (n, c, h, w) = (shape[0], shape[1], shape[2], shape[3]);
            let (crop_h, crop_w) = (crop_sizes[0], crop_sizes[1]);
            let (off_h, off_w) = (offsets[0], offsets[1]);

            let mut cropped = Vec::with_capacity(n * c * crop_h * crop_w);
            for batch in 0..n {
                for channel in 0..c {
                    for row in off_h..off_h + crop_h {
                        for col in off_w..off_w + crop_w {
                            // Row-major (N, C, H, W) flat index.
                            let idx = batch * c * h * w + channel * h * w + row * w + col;
                            cropped.push(data[idx]);
                        }
                    }
                }
            }
            return Tensor::from_vec(cropped, &[n, c, crop_h, crop_w]).unwrap();
        }

        // Unsupported rank/crop combination: pass through unchanged.
        input.clone()
    }
}
287
/// Randomly reverses a tensor along one dimension.
pub struct RandomFlip {
    // Dimension to flip (for 2-D: 0 = rows/vertical, 1 = columns/horizontal).
    dim: usize,
    // Chance of flipping, clamped to [0, 1] at construction.
    probability: f32,
}

impl RandomFlip {
    /// Creates a flip along `dim` with the given probability.
    ///
    /// The probability is clamped into `[0.0, 1.0]`.
    #[must_use]
    pub fn new(dim: usize, probability: f32) -> Self {
        Self {
            dim,
            probability: probability.clamp(0.0, 1.0),
        }
    }

    /// 50% chance of flipping along dim 1 (columns of a 2-D tensor).
    #[must_use]
    pub fn horizontal() -> Self {
        Self::new(1, 0.5)
    }

    /// 50% chance of flipping along dim 0 (rows of a 2-D tensor).
    #[must_use]
    pub fn vertical() -> Self {
        Self::new(0, 0.5)
    }
}
320
321impl Transform for RandomFlip {
322 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
323 let mut rng = rand::thread_rng();
324 if rng.gen::<f32>() > self.probability {
325 return input.clone();
326 }
327
328 let shape = input.shape();
329 if self.dim >= shape.len() {
330 return input.clone();
331 }
332
333 if shape.len() == 1 {
335 let mut data = input.to_vec();
336 data.reverse();
337 return Tensor::from_vec(data, shape).unwrap();
338 }
339
340 if shape.len() == 2 {
342 let data = input.to_vec();
343 let (rows, cols) = (shape[0], shape[1]);
344 let mut flipped = vec![0.0; data.len()];
345
346 if self.dim == 0 {
347 for r in 0..rows {
349 for c in 0..cols {
350 flipped[r * cols + c] = data[(rows - 1 - r) * cols + c];
351 }
352 }
353 } else {
354 for r in 0..rows {
356 for c in 0..cols {
357 flipped[r * cols + c] = data[r * cols + (cols - 1 - c)];
358 }
359 }
360 }
361
362 return Tensor::from_vec(flipped, shape).unwrap();
363 }
364
365 input.clone()
366 }
367}
368
/// Multiplies every element by a constant factor.
pub struct Scale {
    // Scalar multiplier applied to each element.
    factor: f32,
}

impl Scale {
    /// Creates a scaling transform with the given factor.
    #[must_use]
    pub fn new(factor: f32) -> Self {
        Self { factor }
    }
}

impl Transform for Scale {
    /// Returns `input * factor`, element-wise.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        input.mul_scalar(self.factor)
    }
}
391
/// Clamps every element into the interval `[min, max]`.
pub struct Clamp {
    // Lower bound (inclusive).
    min: f32,
    // Upper bound (inclusive). NOTE(review): `f32::clamp` panics when
    // min > max — construction does not validate the ordering.
    max: f32,
}

impl Clamp {
    /// Creates a clamp with the given bounds.
    #[must_use]
    pub fn new(min: f32, max: f32) -> Self {
        Self { min, max }
    }

    /// Clamp into `[0, 1]`.
    #[must_use]
    pub fn zero_one() -> Self {
        Self::new(0.0, 1.0)
    }

    /// Clamp into `[-1, 1]`.
    #[must_use]
    pub fn symmetric() -> Self {
        Self::new(-1.0, 1.0)
    }
}
421
422impl Transform for Clamp {
423 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
424 let data = input.to_vec();
425 let clamped: Vec<f32> = data.iter().map(|&x| x.clamp(self.min, self.max)).collect();
426 Tensor::from_vec(clamped, input.shape()).unwrap()
427 }
428}
429
/// Flattens any tensor into a 1-D tensor, preserving element order.
pub struct Flatten;

impl Flatten {
    /// Creates a new `Flatten` transform.
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}

impl Default for Flatten {
    fn default() -> Self {
        Self::new()
    }
}
450
451impl Transform for Flatten {
452 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
453 let data = input.to_vec();
454 Tensor::from_vec(data.clone(), &[data.len()]).unwrap()
455 }
456}
457
/// Reshapes a tensor to a fixed target shape.
pub struct Reshape {
    // Target shape; its element product must equal the input's length
    // for the reshape to take effect.
    shape: Vec<usize>,
}

impl Reshape {
    /// Creates a reshape transform targeting `shape`.
    #[must_use]
    pub fn new(shape: Vec<usize>) -> Self {
        Self { shape }
    }
}
474
475impl Transform for Reshape {
476 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
477 let data = input.to_vec();
478 let expected_size: usize = self.shape.iter().product();
479
480 if data.len() != expected_size {
481 return input.clone();
483 }
484
485 Tensor::from_vec(data, &self.shape).unwrap()
486 }
487}
488
/// Randomly zeroes elements (inverted dropout), rescaling survivors by
/// `1 / (1 - p)` so the expected value is preserved.
pub struct DropoutTransform {
    // Drop probability, clamped to [0, 1] at construction.
    probability: f32,
}

impl DropoutTransform {
    /// Creates a dropout transform with the given drop probability.
    ///
    /// The probability is clamped into `[0.0, 1.0]`.
    #[must_use]
    pub fn new(probability: f32) -> Self {
        Self {
            probability: probability.clamp(0.0, 1.0),
        }
    }
}
507
508impl Transform for DropoutTransform {
509 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
510 if self.probability == 0.0 {
511 return input.clone();
512 }
513
514 let mut rng = rand::thread_rng();
515 let scale = 1.0 / (1.0 - self.probability);
516 let data = input.to_vec();
517
518 let dropped: Vec<f32> = data
519 .iter()
520 .map(|&x| {
521 if rng.gen::<f32>() < self.probability {
522 0.0
523 } else {
524 x * scale
525 }
526 })
527 .collect();
528
529 Tensor::from_vec(dropped, input.shape()).unwrap()
530 }
531}
532
/// Wraps an arbitrary closure as a [`Transform`].
pub struct Lambda<F>
where
    F: Fn(&Tensor<f32>) -> Tensor<f32> + Send + Sync,
{
    // The wrapped transformation function.
    func: F,
}

impl<F> Lambda<F>
where
    F: Fn(&Tensor<f32>) -> Tensor<f32> + Send + Sync,
{
    /// Creates a transform from the given closure.
    pub fn new(func: F) -> Self {
        Self { func }
    }
}
554
impl<F> Transform for Lambda<F>
where
    F: Fn(&Tensor<f32>) -> Tensor<f32> + Send + Sync,
{
    /// Delegates directly to the wrapped closure.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        (self.func)(input)
    }
}
563
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_normalize() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let output = Normalize::new(2.5, 0.5).apply(&input);

        let expected = [-3.0, -1.0, 1.0, 3.0];
        for (a, b) in output.to_vec().iter().zip(expected.iter()) {
            assert!((a - b).abs() < 1e-6);
        }
    }

    #[test]
    fn test_scale() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let output = Scale::new(2.0).apply(&input);
        assert_eq!(output.to_vec(), vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_clamp() {
        let input = Tensor::from_vec(vec![-1.0, 0.5, 2.0], &[3]).unwrap();
        let output = Clamp::zero_one().apply(&input);
        assert_eq!(output.to_vec(), vec![0.0, 0.5, 1.0]);
    }

    #[test]
    fn test_flatten() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let output = Flatten::new().apply(&input);
        assert_eq!(output.shape(), &[4]);
        assert_eq!(output.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_reshape() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6]).unwrap();
        let output = Reshape::new(vec![2, 3]).apply(&input);
        assert_eq!(output.shape(), &[2, 3]);
    }

    #[test]
    fn test_compose() {
        let pipeline = Compose::new(vec![
            Box::new(Normalize::new(0.0, 1.0)),
            Box::new(Scale::new(2.0)),
        ]);

        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        assert_eq!(pipeline.apply(&input).to_vec(), vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_compose_builder() {
        let pipeline = Compose::empty()
            .add(Normalize::new(0.0, 1.0))
            .add(Scale::new(2.0));

        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        assert_eq!(pipeline.apply(&input).to_vec(), vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_random_noise() {
        // std == 0 must be an exact no-op.
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let output = RandomNoise::new(0.0).apply(&input);
        assert_eq!(output.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_random_flip_1d() {
        // probability 1.0 forces the flip deterministically.
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let output = RandomFlip::new(0, 1.0).apply(&input);
        assert_eq!(output.to_vec(), vec![4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_random_flip_2d_horizontal() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let output = RandomFlip::new(1, 1.0).apply(&input);
        assert_eq!(output.to_vec(), vec![2.0, 1.0, 4.0, 3.0]);
    }

    #[test]
    fn test_random_flip_2d_vertical() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let output = RandomFlip::new(0, 1.0).apply(&input);
        assert_eq!(output.to_vec(), vec![3.0, 4.0, 1.0, 2.0]);
    }

    #[test]
    fn test_dropout_transform() {
        let input = Tensor::from_vec(vec![1.0; 1000], &[1000]).unwrap();
        let output_vec = DropoutTransform::new(0.5).apply(&input).to_vec();

        // Roughly half the elements should be zeroed.
        let zeros = output_vec.iter().filter(|&&x| x == 0.0).count();
        assert!(
            zeros > 300 && zeros < 700,
            "Expected ~500 zeros, got {zeros}"
        );

        // Survivors are rescaled by 1 / (1 - 0.5) = 2.
        for x in output_vec.iter().filter(|&&x| x != 0.0) {
            assert!((x - 2.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_lambda() {
        let triple = Lambda::new(|t: &Tensor<f32>| t.mul_scalar(3.0));
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        assert_eq!(triple.apply(&input).to_vec(), vec![3.0, 6.0, 9.0]);
    }

    #[test]
    fn test_to_tensor() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let output = ToTensor::new().apply(&input);
        assert_eq!(output.to_vec(), input.to_vec());
    }

    #[test]
    fn test_normalize_variants() {
        let standard = Normalize::standard();
        assert_eq!(standard.mean, 0.0);
        assert_eq!(standard.std, 1.0);

        let zero_centered = Normalize::zero_centered();
        assert_eq!(zero_centered.mean, 0.5);
        assert_eq!(zero_centered.std, 0.5);
    }

    #[test]
    fn test_random_crop_1d() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap();
        let output = RandomCrop::new(vec![3]).apply(&input);
        assert_eq!(output.shape(), &[3]);
    }

    #[test]
    fn test_random_crop_2d() {
        let input = Tensor::from_vec((1..=16).map(|x| x as f32).collect(), &[4, 4]).unwrap();
        let output = RandomCrop::new_2d(2, 2).apply(&input);
        assert_eq!(output.shape(), &[2, 2]);
        assert_eq!(output.to_vec().len(), 4);
    }

    #[test]
    fn test_random_crop_3d() {
        let input = Tensor::from_vec((1..=32).map(|x| x as f32).collect(), &[2, 4, 4]).unwrap();
        let output = RandomCrop::new_2d(2, 2).apply(&input);
        assert_eq!(output.shape(), &[2, 2, 2]);
    }
}