1use axonml_tensor::Tensor;
18use rand::Rng;
19
/// A transformation applied to an `f32` tensor, producing a new tensor.
///
/// Implementors must be `Send + Sync` so a transform (or a [`Compose`]
/// pipeline of them) can be shared across threads.
pub trait Transform: Send + Sync {
    /// Applies the transform to `input` and returns the result as a new
    /// tensor; the input is only borrowed and never modified.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32>;
}
29
/// Chains multiple transforms and applies them in insertion order.
pub struct Compose {
    // Boxed trait objects so heterogeneous transform types can live in one list.
    transforms: Vec<Box<dyn Transform>>,
}
38
39impl Compose {
40 #[must_use]
42 pub fn new(transforms: Vec<Box<dyn Transform>>) -> Self {
43 Self { transforms }
44 }
45
46 #[must_use]
48 pub fn empty() -> Self {
49 Self {
50 transforms: Vec::new(),
51 }
52 }
53
54 pub fn add<T: Transform + 'static>(mut self, transform: T) -> Self {
56 self.transforms.push(Box::new(transform));
57 self
58 }
59}
60
61impl Transform for Compose {
62 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
63 let mut result = input.clone();
64 for transform in &self.transforms {
65 result = transform.apply(&result);
66 }
67 result
68 }
69}
70
71pub struct ToTensor;
77
impl ToTensor {
    /// Creates a new `ToTensor`.
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}

impl Default for ToTensor {
    // Delegates to `new`; the type is a unit struct with no configuration.
    fn default() -> Self {
        Self::new()
    }
}
91
impl Transform for ToTensor {
    /// Returns a clone of the input; no values are modified.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        input.clone()
    }
}
97
/// Element-wise normalization: `(x - mean) / std`.
pub struct Normalize {
    mean: f32, // subtracted from every element
    std: f32,  // divisor; a value of 0.0 would yield inf/NaN results
}
107
impl Normalize {
    /// Creates a normalizer with the given mean and standard deviation.
    #[must_use]
    pub fn new(mean: f32, std: f32) -> Self {
        Self { mean, std }
    }

    /// Identity normalization (mean 0, std 1): leaves values unchanged.
    #[must_use]
    pub fn standard() -> Self {
        Self::new(0.0, 1.0)
    }

    /// Maps values from `[0, 1]` to `[-1, 1]` (mean 0.5, std 0.5).
    #[must_use]
    pub fn zero_centered() -> Self {
        Self::new(0.5, 0.5)
    }
}
127
128impl Transform for Normalize {
129 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
130 let data = input.to_vec();
131 let normalized: Vec<f32> = data.iter().map(|&x| (x - self.mean) / self.std).collect();
132 Tensor::from_vec(normalized, input.shape()).unwrap()
133 }
134}
135
/// Adds zero-mean Gaussian noise to every element.
pub struct RandomNoise {
    std: f32, // noise standard deviation; 0.0 disables the transform
}

impl RandomNoise {
    /// Creates a noise transform with the given standard deviation.
    #[must_use]
    pub fn new(std: f32) -> Self {
        Self { std }
    }
}
152
153impl Transform for RandomNoise {
154 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
155 if self.std == 0.0 {
156 return input.clone();
157 }
158
159 let mut rng = rand::thread_rng();
160 let data = input.to_vec();
161 let noisy: Vec<f32> = data
162 .iter()
163 .map(|&x| {
164 let u1: f32 = rng.r#gen();
166 let u2: f32 = rng.r#gen();
167 let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos();
168 x + z * self.std
169 })
170 .collect();
171 Tensor::from_vec(noisy, input.shape()).unwrap()
172 }
173}
174
/// Crops a randomly positioned window from the trailing (spatial) dimensions.
pub struct RandomCrop {
    size: Vec<usize>, // target extent of each cropped trailing dimension
}

impl RandomCrop {
    /// Creates a crop with one target extent per trailing dimension.
    #[must_use]
    pub fn new(size: Vec<usize>) -> Self {
        Self { size }
    }

    /// Convenience constructor for a 2-D (height, width) crop.
    #[must_use]
    pub fn new_2d(height: usize, width: usize) -> Self {
        Self::new(vec![height, width])
    }
}
197
impl Transform for RandomCrop {
    /// Crops a randomly positioned window of `self.size` from the trailing
    /// dimensions of `input`.
    ///
    /// Supported layouts: 1-D input with 1-D size, and 2-D / 3-D / 4-D inputs
    /// with a 2-D size (interpreted as (H, W), (C, H, W), (N, C, H, W)).
    /// Any other combination — or an input with fewer dims than the crop
    /// size — is returned unchanged. A dimension already at or below its
    /// target extent is kept whole (offset 0, full length).
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        let shape = input.shape();

        // Cannot crop when the input has fewer dims than the crop spec.
        if shape.len() < self.size.len() {
            return input.clone();
        }

        // The crop targets the trailing `self.size.len()` dimensions.
        let spatial_start = shape.len() - self.size.len();
        let mut rng = rand::thread_rng();

        // Random start offset per cropped dimension (0 when the input
        // dimension is not larger than the target).
        let mut offsets = Vec::with_capacity(self.size.len());
        for (i, &target_dim) in self.size.iter().enumerate() {
            let input_dim = shape[spatial_start + i];
            if input_dim <= target_dim {
                offsets.push(0);
            } else {
                offsets.push(rng.gen_range(0..=input_dim - target_dim));
            }
        }

        // Effective crop extents, capped at the input dimension.
        let crop_sizes: Vec<usize> = self
            .size
            .iter()
            .enumerate()
            .map(|(i, &s)| s.min(shape[spatial_start + i]))
            .collect();

        let data = input.to_vec();

        // 1-D input, 1-D crop: a contiguous slice.
        if shape.len() == 1 && self.size.len() == 1 {
            let start = offsets[0];
            let end = start + crop_sizes[0];
            let cropped = data[start..end].to_vec();
            let len = cropped.len();
            return Tensor::from_vec(cropped, &[len]).unwrap();
        }

        // 2-D (H, W) input, 2-D crop: row-major gather of the window.
        if shape.len() == 2 && self.size.len() == 2 {
            let (_h, w) = (shape[0], shape[1]);
            let (crop_h, crop_w) = (crop_sizes[0], crop_sizes[1]);
            let (off_h, off_w) = (offsets[0], offsets[1]);

            let mut cropped = Vec::with_capacity(crop_h * crop_w);
            for row in off_h..off_h + crop_h {
                for col in off_w..off_w + crop_w {
                    cropped.push(data[row * w + col]);
                }
            }
            return Tensor::from_vec(cropped, &[crop_h, crop_w]).unwrap();
        }

        // 3-D (C, H, W) input, 2-D crop: same window for every channel.
        if shape.len() == 3 && self.size.len() == 2 {
            let (c, h, w) = (shape[0], shape[1], shape[2]);
            let (crop_h, crop_w) = (crop_sizes[0], crop_sizes[1]);
            let (off_h, off_w) = (offsets[0], offsets[1]);

            let mut cropped = Vec::with_capacity(c * crop_h * crop_w);
            for channel in 0..c {
                for row in off_h..off_h + crop_h {
                    for col in off_w..off_w + crop_w {
                        cropped.push(data[channel * h * w + row * w + col]);
                    }
                }
            }
            return Tensor::from_vec(cropped, &[c, crop_h, crop_w]).unwrap();
        }

        // 4-D (N, C, H, W) input, 2-D crop: same window across batch and channels.
        if shape.len() == 4 && self.size.len() == 2 {
            let (n, c, h, w) = (shape[0], shape[1], shape[2], shape[3]);
            let (crop_h, crop_w) = (crop_sizes[0], crop_sizes[1]);
            let (off_h, off_w) = (offsets[0], offsets[1]);

            let mut cropped = Vec::with_capacity(n * c * crop_h * crop_w);
            for batch in 0..n {
                for channel in 0..c {
                    for row in off_h..off_h + crop_h {
                        for col in off_w..off_w + crop_w {
                            let idx = batch * c * h * w + channel * h * w + row * w + col;
                            cropped.push(data[idx]);
                        }
                    }
                }
            }
            return Tensor::from_vec(cropped, &[n, c, crop_h, crop_w]).unwrap();
        }

        // Unsupported rank/size combination: pass through unchanged.
        input.clone()
    }
}
296
/// Reverses one dimension of the tensor with a given probability.
pub struct RandomFlip {
    dim: usize,       // dimension to reverse
    probability: f32, // chance of flipping, clamped to [0, 1]
}

impl RandomFlip {
    /// Creates a flip over `dim` that triggers with `probability`
    /// (clamped into `[0, 1]`).
    #[must_use]
    pub fn new(dim: usize, probability: f32) -> Self {
        Self {
            dim,
            probability: probability.clamp(0.0, 1.0),
        }
    }

    /// 50% flip over dim 1 (the width axis of an (H, W) tensor).
    #[must_use]
    pub fn horizontal() -> Self {
        Self::new(1, 0.5)
    }

    /// 50% flip over dim 0 (the height axis of an (H, W) tensor).
    #[must_use]
    pub fn vertical() -> Self {
        Self::new(0, 0.5)
    }
}
329
impl Transform for RandomFlip {
    /// With probability `self.probability`, reverses dimension `self.dim`.
    ///
    /// Only 1-D and 2-D tensors are flipped; higher-rank inputs, and inputs
    /// where `dim` is out of range, are returned unchanged.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        let mut rng = rand::thread_rng();
        // Draw once; skip the flip when the sample exceeds the threshold.
        if rng.r#gen::<f32>() > self.probability {
            return input.clone();
        }

        let shape = input.shape();
        if self.dim >= shape.len() {
            return input.clone();
        }

        // 1-D: dim must be 0 here (checked above), so a full reverse.
        if shape.len() == 1 {
            let mut data = input.to_vec();
            data.reverse();
            return Tensor::from_vec(data, shape).unwrap();
        }

        // 2-D: mirror rows (dim 0) or columns (dim 1) in row-major order.
        if shape.len() == 2 {
            let data = input.to_vec();
            let (rows, cols) = (shape[0], shape[1]);
            let mut flipped = vec![0.0; data.len()];

            if self.dim == 0 {
                // Vertical flip: row r takes values from row (rows - 1 - r).
                for r in 0..rows {
                    for c in 0..cols {
                        flipped[r * cols + c] = data[(rows - 1 - r) * cols + c];
                    }
                }
            } else {
                // Horizontal flip: column c takes values from (cols - 1 - c).
                for r in 0..rows {
                    for c in 0..cols {
                        flipped[r * cols + c] = data[r * cols + (cols - 1 - c)];
                    }
                }
            }

            return Tensor::from_vec(flipped, shape).unwrap();
        }

        // Rank > 2 is not supported: pass through unchanged.
        input.clone()
    }
}
377
/// Multiplies every element by a constant factor.
pub struct Scale {
    factor: f32, // multiplier applied element-wise
}

impl Scale {
    /// Creates a scaling transform with the given factor.
    #[must_use]
    pub fn new(factor: f32) -> Self {
        Self { factor }
    }
}
394
impl Transform for Scale {
    /// Returns `input * self.factor`, element-wise.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        input.mul_scalar(self.factor)
    }
}
400
/// Clamps every element into the range `[min, max]`.
pub struct Clamp {
    min: f32, // lower bound (inclusive)
    max: f32, // upper bound (inclusive)
}

impl Clamp {
    /// Creates a clamp with the given bounds.
    #[must_use]
    pub fn new(min: f32, max: f32) -> Self {
        Self { min, max }
    }

    /// Clamps into `[0, 1]`.
    #[must_use]
    pub fn zero_one() -> Self {
        Self::new(0.0, 1.0)
    }

    /// Clamps into `[-1, 1]`.
    #[must_use]
    pub fn symmetric() -> Self {
        Self::new(-1.0, 1.0)
    }
}
430
431impl Transform for Clamp {
432 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
433 let data = input.to_vec();
434 let clamped: Vec<f32> = data.iter().map(|&x| x.clamp(self.min, self.max)).collect();
435 Tensor::from_vec(clamped, input.shape()).unwrap()
436 }
437}
438
/// Flattens any tensor to 1-D, preserving element order.
pub struct Flatten;

impl Flatten {
    /// Creates a new `Flatten`.
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}

impl Default for Flatten {
    // Delegates to `new`; the type is a unit struct with no configuration.
    fn default() -> Self {
        Self::new()
    }
}
459
460impl Transform for Flatten {
461 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
462 let data = input.to_vec();
463 Tensor::from_vec(data.clone(), &[data.len()]).unwrap()
464 }
465}
466
/// Reinterprets the tensor with a fixed target shape.
pub struct Reshape {
    shape: Vec<usize>, // target shape; element count must match the input's
}

impl Reshape {
    /// Creates a reshape to the given target shape.
    #[must_use]
    pub fn new(shape: Vec<usize>) -> Self {
        Self { shape }
    }
}
483
484impl Transform for Reshape {
485 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
486 let data = input.to_vec();
487 let expected_size: usize = self.shape.iter().product();
488
489 if data.len() != expected_size {
490 return input.clone();
492 }
493
494 Tensor::from_vec(data, &self.shape).unwrap()
495 }
496}
497
/// Randomly zeroes elements (inverted dropout) with a given probability.
pub struct DropoutTransform {
    probability: f32, // drop probability, clamped to [0, 1]
}

impl DropoutTransform {
    /// Creates a dropout transform; `probability` is clamped into `[0, 1]`.
    #[must_use]
    pub fn new(probability: f32) -> Self {
        Self {
            probability: probability.clamp(0.0, 1.0),
        }
    }
}
516
impl Transform for DropoutTransform {
    /// Zeroes each element independently with probability `self.probability`
    /// and scales survivors by `1 / (1 - p)` (inverted dropout), so the
    /// expected value of each element is preserved.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        if self.probability == 0.0 {
            // Nothing to drop: return the input unchanged.
            return input.clone();
        }

        let mut rng = rand::thread_rng();
        // NOTE(review): when probability == 1.0 this scale is infinite, but
        // `gen::<f32>()` is always < 1.0 so every element takes the zero
        // branch and the infinite scale is never applied.
        let scale = 1.0 / (1.0 - self.probability);
        let data = input.to_vec();

        let dropped: Vec<f32> = data
            .iter()
            .map(|&x| {
                if rng.r#gen::<f32>() < self.probability {
                    0.0
                } else {
                    x * scale
                }
            })
            .collect();

        Tensor::from_vec(dropped, input.shape()).unwrap()
    }
}
541
/// Wraps an arbitrary closure as a [`Transform`].
pub struct Lambda<F>
where
    F: Fn(&Tensor<f32>) -> Tensor<f32> + Send + Sync,
{
    // The user-supplied transformation function.
    func: F,
}
553
554impl<F> Lambda<F>
555where
556 F: Fn(&Tensor<f32>) -> Tensor<f32> + Send + Sync,
557{
558 pub fn new(func: F) -> Self {
560 Self { func }
561 }
562}
563
impl<F> Transform for Lambda<F>
where
    F: Fn(&Tensor<f32>) -> Tensor<f32> + Send + Sync,
{
    /// Delegates directly to the wrapped closure.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        (self.func)(input)
    }
}
572
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_normalize() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let normalize = Normalize::new(2.5, 0.5);

        let output = normalize.apply(&input);
        let expected = [-3.0, -1.0, 1.0, 3.0];

        let result = output.to_vec();
        for (a, b) in result.iter().zip(expected.iter()) {
            assert!((a - b).abs() < 1e-6);
        }
    }

    #[test]
    fn test_scale() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let scale = Scale::new(2.0);

        let output = scale.apply(&input);
        assert_eq!(output.to_vec(), vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_clamp() {
        let input = Tensor::from_vec(vec![-1.0, 0.5, 2.0], &[3]).unwrap();
        let clamp = Clamp::zero_one();

        let output = clamp.apply(&input);
        assert_eq!(output.to_vec(), vec![0.0, 0.5, 1.0]);
    }

    #[test]
    fn test_flatten() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flatten = Flatten::new();

        let output = flatten.apply(&input);
        assert_eq!(output.shape(), &[4]);
        assert_eq!(output.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_reshape() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6]).unwrap();
        let reshape = Reshape::new(vec![2, 3]);

        let output = reshape.apply(&input);
        assert_eq!(output.shape(), &[2, 3]);
    }

    #[test]
    fn test_compose() {
        let normalize = Normalize::new(0.0, 1.0);
        let scale = Scale::new(2.0);

        let compose = Compose::new(vec![Box::new(normalize), Box::new(scale)]);

        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let output = compose.apply(&input);

        assert_eq!(output.to_vec(), vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_compose_builder() {
        let compose = Compose::empty()
            .add(Normalize::new(0.0, 1.0))
            .add(Scale::new(2.0));

        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let output = compose.apply(&input);

        assert_eq!(output.to_vec(), vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_random_noise() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        // std == 0.0 disables the noise, so the output is deterministic.
        let noise = RandomNoise::new(0.0);

        let output = noise.apply(&input);
        assert_eq!(output.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_random_flip_1d() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        // Probability 1.0 makes the flip deterministic.
        let flip = RandomFlip::new(0, 1.0);
        let output = flip.apply(&input);
        assert_eq!(output.to_vec(), vec![4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_random_flip_2d_horizontal() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flip = RandomFlip::new(1, 1.0);
        let output = flip.apply(&input);
        assert_eq!(output.to_vec(), vec![2.0, 1.0, 4.0, 3.0]);
    }

    #[test]
    fn test_random_flip_2d_vertical() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flip = RandomFlip::new(0, 1.0);
        let output = flip.apply(&input);
        assert_eq!(output.to_vec(), vec![3.0, 4.0, 1.0, 2.0]);
    }

    #[test]
    fn test_dropout_transform() {
        let input = Tensor::from_vec(vec![1.0; 1000], &[1000]).unwrap();
        let dropout = DropoutTransform::new(0.5);

        let output = dropout.apply(&input);
        let output_vec = output.to_vec();

        // Roughly half the elements should be zeroed (loose statistical bound).
        let zeros = output_vec.iter().filter(|&&x| x == 0.0).count();
        assert!(
            zeros > 300 && zeros < 700,
            "Expected ~500 zeros, got {zeros}"
        );

        // Survivors are scaled by 1 / (1 - 0.5) = 2.
        let nonzeros: Vec<f32> = output_vec.iter().filter(|&&x| x != 0.0).copied().collect();
        for x in nonzeros {
            assert!((x - 2.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_lambda() {
        let lambda = Lambda::new(|t: &Tensor<f32>| t.mul_scalar(3.0));

        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let output = lambda.apply(&input);

        assert_eq!(output.to_vec(), vec![3.0, 6.0, 9.0]);
    }

    #[test]
    fn test_to_tensor() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let to_tensor = ToTensor::new();

        let output = to_tensor.apply(&input);
        assert_eq!(output.to_vec(), input.to_vec());
    }

    #[test]
    fn test_normalize_variants() {
        let standard = Normalize::standard();
        assert_eq!(standard.mean, 0.0);
        assert_eq!(standard.std, 1.0);

        let zero_centered = Normalize::zero_centered();
        assert_eq!(zero_centered.mean, 0.5);
        assert_eq!(zero_centered.std, 0.5);
    }

    #[test]
    fn test_random_crop_1d() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap();
        let crop = RandomCrop::new(vec![3]);

        let output = crop.apply(&input);
        assert_eq!(output.shape(), &[3]);
    }

    #[test]
    fn test_random_crop_2d() {
        let input = Tensor::from_vec((1..=16).map(|x| x as f32).collect(), &[4, 4]).unwrap();
        let crop = RandomCrop::new_2d(2, 2);

        let output = crop.apply(&input);
        assert_eq!(output.shape(), &[2, 2]);
        let vals = output.to_vec();
        assert_eq!(vals.len(), 4);
    }

    #[test]
    fn test_random_crop_3d() {
        let input = Tensor::from_vec((1..=32).map(|x| x as f32).collect(), &[2, 4, 4]).unwrap();
        let crop = RandomCrop::new_2d(2, 2);

        let output = crop.apply(&input);
        // The channel dimension is preserved; spatial dims are cropped.
        assert_eq!(output.shape(), &[2, 2, 2]);
    }
}