use axonml_tensor::Tensor;
use rand::Rng;

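/// A transformation applied to an `f32` tensor.
///
/// Transforms are `Send + Sync` so they can be shared across data-loading
/// threads, and they can be chained with [`Compose`].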
pub trait Transform: Send + Sync {
    /// Applies the transform to `input`, returning a new tensor.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32>;
}

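/// Chains multiple transforms and applies them in order.
///
/// A minimal usage sketch, mirroring the tests at the bottom of this file
/// (marked `ignore` because the crate paths for the imports are omitted):
///
/// ```ignore
/// let pipeline = Compose::empty()
///     .add(Normalize::new(0.0, 1.0))
///     .add(Scale::new(2.0));
/// let output = pipeline.apply(&input); // input: &Tensor<f32>
/// ```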
pub struct Compose {
    transforms: Vec<Box<dyn Transform>>,
}

impl Compose {
    /// Creates a pipeline from an ordered list of boxed transforms.
    #[must_use]
    pub fn new(transforms: Vec<Box<dyn Transform>>) -> Self {
        Self { transforms }
    }

    /// Creates an empty pipeline; transforms can be added with [`Compose::add`].
    #[must_use]
    pub fn empty() -> Self {
        Self {
            transforms: Vec::new(),
        }
    }

    /// Appends a transform to the end of the pipeline (builder style).
    pub fn add<T: Transform + 'static>(mut self, transform: T) -> Self {
        self.transforms.push(Box::new(transform));
        self
    }
}

impl Transform for Compose {
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        let mut result = input.clone();
        for transform in &self.transforms {
            result = transform.apply(&result);
        }
        result
    }
}

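/// Identity transform that returns the input unchanged.
///
/// Useful as a no-op placeholder when the data is already a tensor.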
pub struct ToTensor;

impl ToTensor {
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}

impl Default for ToTensor {
    fn default() -> Self {
        Self::new()
    }
}

impl Transform for ToTensor {
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        input.clone()
    }
}

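/// Normalizes values as `(x - mean) / std`.
///
/// `std` must be non-zero; a zero standard deviation produces non-finite
/// values.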
pub struct Normalize {
    mean: f32,
    std: f32,
}

impl Normalize {
    #[must_use]
    pub fn new(mean: f32, std: f32) -> Self {
        Self { mean, std }
    }

    /// Identity normalization (`mean = 0.0`, `std = 1.0`).
    #[must_use]
    pub fn standard() -> Self {
        Self::new(0.0, 1.0)
    }

    /// Maps values in `[0, 1]` to `[-1, 1]` (`mean = 0.5`, `std = 0.5`).
    #[must_use]
    pub fn zero_centered() -> Self {
        Self::new(0.5, 0.5)
    }
}

impl Transform for Normalize {
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        let data = input.to_vec();
        let normalized: Vec<f32> = data.iter().map(|&x| (x - self.mean) / self.std).collect();
        Tensor::from_vec(normalized, input.shape()).unwrap()
    }
}

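/// Adds Gaussian noise with zero mean and the given standard deviation to
/// every element.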
pub struct RandomNoise {
    std: f32,
}

impl RandomNoise {
    #[must_use]
    pub fn new(std: f32) -> Self {
        Self { std }
    }
}

impl Transform for RandomNoise {
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        if self.std == 0.0 {
            return input.clone();
        }

        let mut rng = rand::thread_rng();
        let data = input.to_vec();
        let noisy: Vec<f32> = data
            .iter()
            .map(|&x| {
                // Box-Muller transform: turn two uniform samples into one
                // standard-normal sample. `gen()` yields values in [0, 1),
                // so take `1.0 - u` for the logarithm to avoid ln(0).
                let u1: f32 = 1.0 - rng.gen::<f32>();
                let u2: f32 = rng.gen();
                let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos();
                x + z * self.std
            })
            .collect();
        Tensor::from_vec(noisy, input.shape()).unwrap()
    }
}

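/// Randomly crops the trailing (spatial) dimensions of a tensor to a target
/// size.
///
/// Supports 1D tensors, 2D `(H, W)`, 3D `(C, H, W)`, and 4D `(N, C, H, W)`
/// layouts; other shapes are returned unchanged.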
pub struct RandomCrop {
    size: Vec<usize>,
}

impl RandomCrop {
    #[must_use]
    pub fn new(size: Vec<usize>) -> Self {
        Self { size }
    }

    /// Convenience constructor for a 2D `[height, width]` crop.
    #[must_use]
    pub fn new_2d(height: usize, width: usize) -> Self {
        Self::new(vec![height, width])
    }
}

impl Transform for RandomCrop {
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        let shape = input.shape();

        if shape.len() < self.size.len() {
            return input.clone();
        }

        // The crop targets the trailing dimensions; leading dimensions
        // (batch, channel) are kept whole.
        let spatial_start = shape.len() - self.size.len();
        let mut rng = rand::thread_rng();

        // Pick a random offset per cropped dimension; dimensions already at
        // or below the target size start at 0.
        let mut offsets = Vec::with_capacity(self.size.len());
        for (i, &target_dim) in self.size.iter().enumerate() {
            let input_dim = shape[spatial_start + i];
            if input_dim <= target_dim {
                offsets.push(0);
            } else {
                offsets.push(rng.gen_range(0..=input_dim - target_dim));
            }
        }

        let crop_sizes: Vec<usize> = self
            .size
            .iter()
            .enumerate()
            .map(|(i, &s)| s.min(shape[spatial_start + i]))
            .collect();

        let data = input.to_vec();

        // 1D crop.
        if shape.len() == 1 && self.size.len() == 1 {
            let start = offsets[0];
            let end = start + crop_sizes[0];
            let cropped = data[start..end].to_vec();
            let len = cropped.len();
            return Tensor::from_vec(cropped, &[len]).unwrap();
        }

        // 2D (H, W) crop.
        if shape.len() == 2 && self.size.len() == 2 {
            let (_h, w) = (shape[0], shape[1]);
            let (crop_h, crop_w) = (crop_sizes[0], crop_sizes[1]);
            let (off_h, off_w) = (offsets[0], offsets[1]);

            let mut cropped = Vec::with_capacity(crop_h * crop_w);
            for row in off_h..off_h + crop_h {
                for col in off_w..off_w + crop_w {
                    cropped.push(data[row * w + col]);
                }
            }
            return Tensor::from_vec(cropped, &[crop_h, crop_w]).unwrap();
        }

        // 3D (C, H, W) crop: the same spatial window is used for every channel.
        if shape.len() == 3 && self.size.len() == 2 {
            let (c, h, w) = (shape[0], shape[1], shape[2]);
            let (crop_h, crop_w) = (crop_sizes[0], crop_sizes[1]);
            let (off_h, off_w) = (offsets[0], offsets[1]);

            let mut cropped = Vec::with_capacity(c * crop_h * crop_w);
            for channel in 0..c {
                for row in off_h..off_h + crop_h {
                    for col in off_w..off_w + crop_w {
                        cropped.push(data[channel * h * w + row * w + col]);
                    }
                }
            }
            return Tensor::from_vec(cropped, &[c, crop_h, crop_w]).unwrap();
        }

        // 4D (N, C, H, W) crop: the same spatial window is used for the whole batch.
        if shape.len() == 4 && self.size.len() == 2 {
            let (n, c, h, w) = (shape[0], shape[1], shape[2], shape[3]);
            let (crop_h, crop_w) = (crop_sizes[0], crop_sizes[1]);
            let (off_h, off_w) = (offsets[0], offsets[1]);

            let mut cropped = Vec::with_capacity(n * c * crop_h * crop_w);
            for batch in 0..n {
                for channel in 0..c {
                    for row in off_h..off_h + crop_h {
                        for col in off_w..off_w + crop_w {
                            let idx = batch * c * h * w + channel * h * w + row * w + col;
                            cropped.push(data[idx]);
                        }
                    }
                }
            }
            return Tensor::from_vec(cropped, &[n, c, crop_h, crop_w]).unwrap();
        }

        // Unsupported shape/size combination: return the input unchanged.
        input.clone()
    }
}

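/// Randomly reverses a tensor along one dimension with the given probability.
///
/// Only 1D and 2D tensors are flipped; higher-rank tensors are returned
/// unchanged.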
pub struct RandomFlip {
    dim: usize,
    probability: f32,
}

impl RandomFlip {
    /// Creates a flip along `dim`; `probability` is clamped to `[0, 1]`.
    #[must_use]
    pub fn new(dim: usize, probability: f32) -> Self {
        Self {
            dim,
            probability: probability.clamp(0.0, 1.0),
        }
    }

    /// Flip along the width dimension (dim 1) with probability 0.5.
    #[must_use]
    pub fn horizontal() -> Self {
        Self::new(1, 0.5)
    }

    /// Flip along the height dimension (dim 0) with probability 0.5.
    #[must_use]
    pub fn vertical() -> Self {
        Self::new(0, 0.5)
    }
}

impl Transform for RandomFlip {
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        let mut rng = rand::thread_rng();
        if rng.gen::<f32>() > self.probability {
            return input.clone();
        }

        let shape = input.shape();
        if self.dim >= shape.len() {
            return input.clone();
        }

        if shape.len() == 1 {
            let mut data = input.to_vec();
            data.reverse();
            return Tensor::from_vec(data, shape).unwrap();
        }

        if shape.len() == 2 {
            let data = input.to_vec();
            let (rows, cols) = (shape[0], shape[1]);
            let mut flipped = vec![0.0; data.len()];

            if self.dim == 0 {
                // Reverse the row order (vertical flip).
                for r in 0..rows {
                    for c in 0..cols {
                        flipped[r * cols + c] = data[(rows - 1 - r) * cols + c];
                    }
                }
            } else {
                // Reverse the column order (horizontal flip).
                for r in 0..rows {
                    for c in 0..cols {
                        flipped[r * cols + c] = data[r * cols + (cols - 1 - c)];
                    }
                }
            }

            return Tensor::from_vec(flipped, shape).unwrap();
        }

        // Tensors with more than two dimensions are not flipped.
        input.clone()
    }
}

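/// Multiplies every element by a constant factor.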
pub struct Scale {
    factor: f32,
}

impl Scale {
    #[must_use]
    pub fn new(factor: f32) -> Self {
        Self { factor }
    }
}

impl Transform for Scale {
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        input.mul_scalar(self.factor)
    }
}

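/// Clamps every element to the range `[min, max]`.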
pub struct Clamp {
    min: f32,
    max: f32,
}

impl Clamp {
    #[must_use]
    pub fn new(min: f32, max: f32) -> Self {
        Self { min, max }
    }

    /// Clamps to `[0.0, 1.0]`.
    #[must_use]
    pub fn zero_one() -> Self {
        Self::new(0.0, 1.0)
    }

    /// Clamps to `[-1.0, 1.0]`.
    #[must_use]
    pub fn symmetric() -> Self {
        Self::new(-1.0, 1.0)
    }
}

impl Transform for Clamp {
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        let data = input.to_vec();
        let clamped: Vec<f32> = data.iter().map(|&x| x.clamp(self.min, self.max)).collect();
        Tensor::from_vec(clamped, input.shape()).unwrap()
    }
}

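/// Flattens a tensor of any shape into a 1D tensor.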
pub struct Flatten;

impl Flatten {
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}

impl Default for Flatten {
    fn default() -> Self {
        Self::new()
    }
}

impl Transform for Flatten {
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        let data = input.to_vec();
        let len = data.len();
        Tensor::from_vec(data, &[len]).unwrap()
    }
}

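/// Reshapes a tensor to a fixed target shape.
///
/// If the element count does not match the target shape, the input is
/// returned unchanged.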
pub struct Reshape {
    shape: Vec<usize>,
}

impl Reshape {
    #[must_use]
    pub fn new(shape: Vec<usize>) -> Self {
        Self { shape }
    }
}

impl Transform for Reshape {
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        let data = input.to_vec();
        let expected_size: usize = self.shape.iter().product();

        // Element count mismatch: leave the tensor as-is rather than panic.
        if data.len() != expected_size {
            return input.clone();
        }

        Tensor::from_vec(data, &self.shape).unwrap()
    }
}

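/// Randomly zeroes elements with the given probability, scaling the survivors
/// by `1 / (1 - probability)` (inverted dropout) so the expected value is
/// preserved.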
pub struct DropoutTransform {
    probability: f32,
}

impl DropoutTransform {
    #[must_use]
    pub fn new(probability: f32) -> Self {
        Self {
            probability: probability.clamp(0.0, 1.0),
        }
    }
}

impl Transform for DropoutTransform {
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        if self.probability == 0.0 {
            return input.clone();
        }

        let mut rng = rand::thread_rng();
        // Inverted dropout: surviving elements are scaled up so the expected
        // value of each element is unchanged.
        let scale = 1.0 / (1.0 - self.probability);
        let data = input.to_vec();

        let dropped: Vec<f32> = data
            .iter()
            .map(|&x| {
                if rng.gen::<f32>() < self.probability {
                    0.0
                } else {
                    x * scale
                }
            })
            .collect();

        Tensor::from_vec(dropped, input.shape()).unwrap()
    }
}

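/// Wraps an arbitrary closure as a [`Transform`].
///
/// A short sketch, mirroring the `test_lambda` test below (marked `ignore`
/// because the crate paths for the imports are omitted):
///
/// ```ignore
/// let triple = Lambda::new(|t: &Tensor<f32>| t.mul_scalar(3.0));
/// let output = triple.apply(&input);
/// ```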
pub struct Lambda<F>
where
    F: Fn(&Tensor<f32>) -> Tensor<f32> + Send + Sync,
{
    func: F,
}

impl<F> Lambda<F>
where
    F: Fn(&Tensor<f32>) -> Tensor<f32> + Send + Sync,
{
    pub fn new(func: F) -> Self {
        Self { func }
    }
}

impl<F> Transform for Lambda<F>
where
    F: Fn(&Tensor<f32>) -> Tensor<f32> + Send + Sync,
{
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        (self.func)(input)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_normalize() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let normalize = Normalize::new(2.5, 0.5);

        let output = normalize.apply(&input);
        let expected = [-3.0, -1.0, 1.0, 3.0];

        let result = output.to_vec();
        for (a, b) in result.iter().zip(expected.iter()) {
            assert!((a - b).abs() < 1e-6);
        }
    }

    #[test]
    fn test_scale() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let scale = Scale::new(2.0);

        let output = scale.apply(&input);
        assert_eq!(output.to_vec(), vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_clamp() {
        let input = Tensor::from_vec(vec![-1.0, 0.5, 2.0], &[3]).unwrap();
        let clamp = Clamp::zero_one();

        let output = clamp.apply(&input);
        assert_eq!(output.to_vec(), vec![0.0, 0.5, 1.0]);
    }

    #[test]
    fn test_flatten() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flatten = Flatten::new();

        let output = flatten.apply(&input);
        assert_eq!(output.shape(), &[4]);
        assert_eq!(output.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_reshape() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6]).unwrap();
        let reshape = Reshape::new(vec![2, 3]);

        let output = reshape.apply(&input);
        assert_eq!(output.shape(), &[2, 3]);
    }

    #[test]
    fn test_compose() {
        let normalize = Normalize::new(0.0, 1.0);
        let scale = Scale::new(2.0);

        let compose = Compose::new(vec![Box::new(normalize), Box::new(scale)]);

        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let output = compose.apply(&input);

        assert_eq!(output.to_vec(), vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_compose_builder() {
        let compose = Compose::empty()
            .add(Normalize::new(0.0, 1.0))
            .add(Scale::new(2.0));

        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let output = compose.apply(&input);

        assert_eq!(output.to_vec(), vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_random_noise() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let noise = RandomNoise::new(0.0);

        let output = noise.apply(&input);
        assert_eq!(output.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_random_flip_1d() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let flip = RandomFlip::new(0, 1.0);

        let output = flip.apply(&input);
        assert_eq!(output.to_vec(), vec![4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_random_flip_2d_horizontal() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flip = RandomFlip::new(1, 1.0);

        let output = flip.apply(&input);
        assert_eq!(output.to_vec(), vec![2.0, 1.0, 4.0, 3.0]);
    }

    #[test]
    fn test_random_flip_2d_vertical() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flip = RandomFlip::new(0, 1.0);

        let output = flip.apply(&input);
        assert_eq!(output.to_vec(), vec![3.0, 4.0, 1.0, 2.0]);
    }

    #[test]
    fn test_dropout_transform() {
        let input = Tensor::from_vec(vec![1.0; 1000], &[1000]).unwrap();
        let dropout = DropoutTransform::new(0.5);

        let output = dropout.apply(&input);
        let output_vec = output.to_vec();

        let zeros = output_vec.iter().filter(|&&x| x == 0.0).count();
        assert!(
            zeros > 300 && zeros < 700,
            "Expected ~500 zeros, got {zeros}"
        );

        let nonzeros: Vec<f32> = output_vec.iter().filter(|&&x| x != 0.0).copied().collect();
        for x in nonzeros {
            assert!((x - 2.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_lambda() {
        let lambda = Lambda::new(|t: &Tensor<f32>| t.mul_scalar(3.0));

        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let output = lambda.apply(&input);

        assert_eq!(output.to_vec(), vec![3.0, 6.0, 9.0]);
    }

    #[test]
    fn test_to_tensor() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let to_tensor = ToTensor::new();

        let output = to_tensor.apply(&input);
        assert_eq!(output.to_vec(), input.to_vec());
    }

    #[test]
    fn test_normalize_variants() {
        let standard = Normalize::standard();
        assert_eq!(standard.mean, 0.0);
        assert_eq!(standard.std, 1.0);

        let zero_centered = Normalize::zero_centered();
        assert_eq!(zero_centered.mean, 0.5);
        assert_eq!(zero_centered.std, 0.5);
    }

    #[test]
    fn test_random_crop_1d() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap();
        let crop = RandomCrop::new(vec![3]);

        let output = crop.apply(&input);
        assert_eq!(output.shape(), &[3]);
    }

    #[test]
    fn test_random_crop_2d() {
        let input = Tensor::from_vec((1..=16).map(|x| x as f32).collect(), &[4, 4]).unwrap();
        let crop = RandomCrop::new_2d(2, 2);

        let output = crop.apply(&input);
        assert_eq!(output.shape(), &[2, 2]);
        let vals = output.to_vec();
        assert_eq!(vals.len(), 4);
    }

    #[test]
    fn test_random_crop_3d() {
        let input = Tensor::from_vec((1..=32).map(|x| x as f32).collect(), &[2, 4, 4]).unwrap();
        let crop = RandomCrop::new_2d(2, 2);

        let output = crop.apply(&input);
        assert_eq!(output.shape(), &[2, 2, 2]);
    }
}