1use axonml_data::Transform;
29use axonml_tensor::Tensor;
30use rand::Rng;
31
/// Resizes an image tensor to a fixed spatial size using bilinear interpolation.
///
/// The last two dimensions are treated as `(height, width)`. Rank-2 `[H, W]`,
/// rank-3 `[C, H, W]` and rank-4 `[N, C, H, W]` tensors are resized; tensors of
/// any other rank are returned unchanged.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Resize {
    /// Target height in pixels.
    height: usize,
    /// Target width in pixels.
    width: usize,
}

impl Resize {
    /// Creates a resize transform with the given target `height` and `width`.
    #[must_use]
    pub fn new(height: usize, width: usize) -> Self {
        Self { height, width }
    }

    /// Creates a resize transform with a square `size x size` target.
    #[must_use]
    pub fn square(size: usize) -> Self {
        Self::new(size, size)
    }
}
55
56impl Transform for Resize {
57 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
58 let shape = input.shape();
59
60 match shape.len() {
62 2 => resize_2d(input, self.height, self.width),
63 3 => resize_3d(input, self.height, self.width),
64 4 => resize_4d(input, self.height, self.width),
65 _ => input.clone(),
66 }
67 }
68}
69
70fn resize_2d(input: &Tensor<f32>, new_h: usize, new_w: usize) -> Tensor<f32> {
72 let shape = input.shape();
73 let (old_h, old_w) = (shape[0], shape[1]);
74 let data = input.to_vec();
75
76 let mut result = vec![0.0; new_h * new_w];
77
78 let scale_h = old_h as f32 / new_h as f32;
79 let scale_w = old_w as f32 / new_w as f32;
80
81 for y in 0..new_h {
82 for x in 0..new_w {
83 let src_y = y as f32 * scale_h;
84 let src_x = x as f32 * scale_w;
85
86 let y0 = (src_y.floor() as usize).min(old_h - 1);
87 let y1 = (y0 + 1).min(old_h - 1);
88 let x0 = (src_x.floor() as usize).min(old_w - 1);
89 let x1 = (x0 + 1).min(old_w - 1);
90
91 let dy = src_y - y0 as f32;
92 let dx = src_x - x0 as f32;
93
94 let v00 = data[y0 * old_w + x0];
95 let v01 = data[y0 * old_w + x1];
96 let v10 = data[y1 * old_w + x0];
97 let v11 = data[y1 * old_w + x1];
98
99 let value = v00 * (1.0 - dx) * (1.0 - dy)
100 + v01 * dx * (1.0 - dy)
101 + v10 * (1.0 - dx) * dy
102 + v11 * dx * dy;
103
104 result[y * new_w + x] = value;
105 }
106 }
107
108 Tensor::from_vec(result, &[new_h, new_w]).unwrap()
109}
110
111fn resize_3d(input: &Tensor<f32>, new_h: usize, new_w: usize) -> Tensor<f32> {
113 let shape = input.shape();
114 let (channels, old_h, old_w) = (shape[0], shape[1], shape[2]);
115 let data = input.to_vec();
116
117 let mut result = vec![0.0; channels * new_h * new_w];
118
119 let scale_h = old_h as f32 / new_h as f32;
120 let scale_w = old_w as f32 / new_w as f32;
121
122 for c in 0..channels {
123 for y in 0..new_h {
124 for x in 0..new_w {
125 let src_y = y as f32 * scale_h;
126 let src_x = x as f32 * scale_w;
127
128 let y0 = (src_y.floor() as usize).min(old_h - 1);
129 let y1 = (y0 + 1).min(old_h - 1);
130 let x0 = (src_x.floor() as usize).min(old_w - 1);
131 let x1 = (x0 + 1).min(old_w - 1);
132
133 let dy = src_y - y0 as f32;
134 let dx = src_x - x0 as f32;
135
136 let base = c * old_h * old_w;
137 let v00 = data[base + y0 * old_w + x0];
138 let v01 = data[base + y0 * old_w + x1];
139 let v10 = data[base + y1 * old_w + x0];
140 let v11 = data[base + y1 * old_w + x1];
141
142 let value = v00 * (1.0 - dx) * (1.0 - dy)
143 + v01 * dx * (1.0 - dy)
144 + v10 * (1.0 - dx) * dy
145 + v11 * dx * dy;
146
147 result[c * new_h * new_w + y * new_w + x] = value;
148 }
149 }
150 }
151
152 Tensor::from_vec(result, &[channels, new_h, new_w]).unwrap()
153}
154
155fn resize_4d(input: &Tensor<f32>, new_h: usize, new_w: usize) -> Tensor<f32> {
157 let shape = input.shape();
158 let (batch, channels, old_h, old_w) = (shape[0], shape[1], shape[2], shape[3]);
159 let data = input.to_vec();
160
161 let mut result = vec![0.0; batch * channels * new_h * new_w];
162
163 let scale_h = old_h as f32 / new_h as f32;
164 let scale_w = old_w as f32 / new_w as f32;
165
166 for n in 0..batch {
167 for c in 0..channels {
168 for y in 0..new_h {
169 for x in 0..new_w {
170 let src_y = y as f32 * scale_h;
171 let src_x = x as f32 * scale_w;
172
173 let y0 = (src_y.floor() as usize).min(old_h - 1);
174 let y1 = (y0 + 1).min(old_h - 1);
175 let x0 = (src_x.floor() as usize).min(old_w - 1);
176 let x1 = (x0 + 1).min(old_w - 1);
177
178 let dy = src_y - y0 as f32;
179 let dx = src_x - x0 as f32;
180
181 let base = n * channels * old_h * old_w + c * old_h * old_w;
182 let v00 = data[base + y0 * old_w + x0];
183 let v01 = data[base + y0 * old_w + x1];
184 let v10 = data[base + y1 * old_w + x0];
185 let v11 = data[base + y1 * old_w + x1];
186
187 let value = v00 * (1.0 - dx) * (1.0 - dy)
188 + v01 * dx * (1.0 - dy)
189 + v10 * (1.0 - dx) * dy
190 + v11 * dx * dy;
191
192 let out_idx = n * channels * new_h * new_w + c * new_h * new_w + y * new_w + x;
193 result[out_idx] = value;
194 }
195 }
196 }
197 }
198
199 Tensor::from_vec(result, &[batch, channels, new_h, new_w]).unwrap()
200}
201
/// Crops the central region of an image tensor.
///
/// The last two dimensions are treated as `(height, width)`. `[H, W]` and `[C, H, W]`
/// tensors are cropped, and other ranks are returned unchanged. A crop larger than the
/// input along an axis is clamped to the input size on that axis; no padding is added.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CenterCrop {
    /// Requested crop height in pixels.
    height: usize,
    /// Requested crop width in pixels.
    width: usize,
}

impl CenterCrop {
    /// Creates a center crop of `height x width` pixels.
    #[must_use]
    pub fn new(height: usize, width: usize) -> Self {
        Self { height, width }
    }

    /// Creates a square center crop of `size x size` pixels.
    #[must_use]
    pub fn square(size: usize) -> Self {
        Self::new(size, size)
    }
}
225
226impl Transform for CenterCrop {
227 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
228 let shape = input.shape();
229 let data = input.to_vec();
230
231 match shape.len() {
232 2 => {
233 let (h, w) = (shape[0], shape[1]);
234 let start_h = (h.saturating_sub(self.height)) / 2;
235 let start_w = (w.saturating_sub(self.width)) / 2;
236 let crop_h = self.height.min(h);
237 let crop_w = self.width.min(w);
238
239 let mut result = Vec::with_capacity(crop_h * crop_w);
240 for y in start_h..start_h + crop_h {
241 for x in start_w..start_w + crop_w {
242 result.push(data[y * w + x]);
243 }
244 }
245 Tensor::from_vec(result, &[crop_h, crop_w]).unwrap()
246 }
247 3 => {
248 let (c, h, w) = (shape[0], shape[1], shape[2]);
249 let start_h = (h.saturating_sub(self.height)) / 2;
250 let start_w = (w.saturating_sub(self.width)) / 2;
251 let crop_h = self.height.min(h);
252 let crop_w = self.width.min(w);
253
254 let mut result = Vec::with_capacity(c * crop_h * crop_w);
255 for ch in 0..c {
256 for y in start_h..start_h + crop_h {
257 for x in start_w..start_w + crop_w {
258 result.push(data[ch * h * w + y * w + x]);
259 }
260 }
261 }
262 Tensor::from_vec(result, &[c, crop_h, crop_w]).unwrap()
263 }
264 _ => input.clone(),
265 }
266 }
267}
268
/// Randomly mirrors an image along its width axis.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RandomHorizontalFlip {
    /// Probability in `[0, 1]` that the flip is applied.
    probability: f32,
}

impl RandomHorizontalFlip {
    /// Creates a flip that is applied with probability `0.5`.
    #[must_use]
    pub fn new() -> Self {
        Self { probability: 0.5 }
    }

    /// Creates a flip that is applied with the given `probability`.
    ///
    /// Values outside `[0, 1]` are clamped into it. `NaN` is treated as `0.0` (never flip),
    /// because a NaN probability has no meaningful interpretation.
    #[must_use]
    pub fn with_probability(probability: f32) -> Self {
        let probability = if probability.is_nan() {
            0.0
        } else {
            probability.clamp(0.0, 1.0)
        };
        Self { probability }
    }
}

impl Default for RandomHorizontalFlip {
    /// Equivalent to [`RandomHorizontalFlip::new`].
    fn default() -> Self {
        Self::new()
    }
}
299
300impl Transform for RandomHorizontalFlip {
301 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
302 let mut rng = rand::thread_rng();
303 if rng.r#gen::<f32>() > self.probability {
304 return input.clone();
305 }
306
307 let shape = input.shape();
308 let data = input.to_vec();
309
310 match shape.len() {
311 2 => {
312 let (h, w) = (shape[0], shape[1]);
313 let mut result = vec![0.0; h * w];
314 for y in 0..h {
315 for x in 0..w {
316 result[y * w + x] = data[y * w + (w - 1 - x)];
317 }
318 }
319 Tensor::from_vec(result, shape).unwrap()
320 }
321 3 => {
322 let (c, h, w) = (shape[0], shape[1], shape[2]);
323 let mut result = vec![0.0; c * h * w];
324 for ch in 0..c {
325 for y in 0..h {
326 for x in 0..w {
327 result[ch * h * w + y * w + x] = data[ch * h * w + y * w + (w - 1 - x)];
328 }
329 }
330 }
331 Tensor::from_vec(result, shape).unwrap()
332 }
333 _ => input.clone(),
334 }
335 }
336}
337
/// Randomly mirrors an image along its height axis.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RandomVerticalFlip {
    /// Probability in `[0, 1]` that the flip is applied.
    probability: f32,
}

impl RandomVerticalFlip {
    /// Creates a flip that is applied with probability `0.5`.
    #[must_use]
    pub fn new() -> Self {
        Self { probability: 0.5 }
    }

    /// Creates a flip that is applied with the given `probability`.
    ///
    /// Values outside `[0, 1]` are clamped into it. `NaN` is treated as `0.0` (never flip),
    /// because a NaN probability has no meaningful interpretation.
    #[must_use]
    pub fn with_probability(probability: f32) -> Self {
        let probability = if probability.is_nan() {
            0.0
        } else {
            probability.clamp(0.0, 1.0)
        };
        Self { probability }
    }
}

impl Default for RandomVerticalFlip {
    /// Equivalent to [`RandomVerticalFlip::new`].
    fn default() -> Self {
        Self::new()
    }
}
368
369impl Transform for RandomVerticalFlip {
370 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
371 let mut rng = rand::thread_rng();
372 if rng.r#gen::<f32>() > self.probability {
373 return input.clone();
374 }
375
376 let shape = input.shape();
377 let data = input.to_vec();
378
379 match shape.len() {
380 2 => {
381 let (h, w) = (shape[0], shape[1]);
382 let mut result = vec![0.0; h * w];
383 for y in 0..h {
384 for x in 0..w {
385 result[y * w + x] = data[(h - 1 - y) * w + x];
386 }
387 }
388 Tensor::from_vec(result, shape).unwrap()
389 }
390 3 => {
391 let (c, h, w) = (shape[0], shape[1], shape[2]);
392 let mut result = vec![0.0; c * h * w];
393 for ch in 0..c {
394 for y in 0..h {
395 for x in 0..w {
396 result[ch * h * w + y * w + x] = data[ch * h * w + (h - 1 - y) * w + x];
397 }
398 }
399 }
400 Tensor::from_vec(result, shape).unwrap()
401 }
402 _ => input.clone(),
403 }
404 }
405}
406
/// Randomly rotates an image by a multiple of 90 degrees.
///
/// Only rank-2 `[H, W]` tensors are rotated; other ranks are returned unchanged.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RandomRotation {
    /// Candidate angles in degrees, each one of 0, 90, 180 or 270. Never empty.
    /// Duplicates are kept, so repeating an angle raises its selection weight.
    angles: Vec<i32>,
}

impl RandomRotation {
    /// Creates a rotation that picks uniformly among 0, 90, 180 and 270 degrees.
    #[must_use]
    pub fn new() -> Self {
        Self {
            angles: vec![0, 90, 180, 270],
        }
    }

    /// Creates a rotation that picks uniformly among `angles`.
    ///
    /// Each angle is first reduced modulo 360 into `[0, 360)`, so `-90` becomes `270` and
    /// `450` becomes `90`. Angles that are then not one of 0, 90, 180 or 270 are discarded.
    /// If no angle survives, the transform falls back to `[0]` (identity).
    #[must_use]
    pub fn with_angles(angles: Vec<i32>) -> Self {
        let valid: Vec<i32> = angles
            .into_iter()
            .map(|angle| angle.rem_euclid(360))
            .filter(|angle| matches!(angle, 0 | 90 | 180 | 270))
            .collect();
        Self {
            angles: if valid.is_empty() { vec![0] } else { valid },
        }
    }
}

impl Default for RandomRotation {
    /// Equivalent to [`RandomRotation::new`].
    fn default() -> Self {
        Self::new()
    }
}
444
445impl Transform for RandomRotation {
446 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
447 let mut rng = rand::thread_rng();
448 let angle = self.angles[rng.gen_range(0..self.angles.len())];
449
450 if angle == 0 {
451 return input.clone();
452 }
453
454 let shape = input.shape();
455 let data = input.to_vec();
456
457 if shape.len() != 2 {
459 return input.clone();
460 }
461
462 let (h, w) = (shape[0], shape[1]);
463
464 match angle {
465 90 => {
466 let mut result = vec![0.0; h * w];
468 for y in 0..h {
469 for x in 0..w {
470 result[x * h + (h - 1 - y)] = data[y * w + x];
471 }
472 }
473 Tensor::from_vec(result, &[w, h]).unwrap()
474 }
475 180 => {
476 let mut result = vec![0.0; h * w];
478 for y in 0..h {
479 for x in 0..w {
480 result[(h - 1 - y) * w + (w - 1 - x)] = data[y * w + x];
481 }
482 }
483 Tensor::from_vec(result, &[h, w]).unwrap()
484 }
485 270 => {
486 let mut result = vec![0.0; h * w];
488 for y in 0..h {
489 for x in 0..w {
490 result[(w - 1 - x) * h + y] = data[y * w + x];
491 }
492 }
493 Tensor::from_vec(result, &[w, h]).unwrap()
494 }
495 _ => input.clone(),
496 }
497 }
498}
499
/// Randomly perturbs the brightness, contrast and saturation of an image.
///
/// Each strength is stored as its absolute value; a strength of `0.0` disables that
/// adjustment.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ColorJitter {
    /// Brightness jitter strength (non-negative).
    brightness: f32,
    /// Contrast jitter strength (non-negative).
    contrast: f32,
    /// Saturation jitter strength (non-negative); only used for `[3, H, W]` inputs.
    saturation: f32,
}

impl ColorJitter {
    /// Creates a color jitter with the given strengths. Negative values are made positive.
    #[must_use]
    pub fn new(brightness: f32, contrast: f32, saturation: f32) -> Self {
        Self {
            brightness: brightness.abs(),
            contrast: contrast.abs(),
            saturation: saturation.abs(),
        }
    }
}
522
523impl Transform for ColorJitter {
524 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
525 let mut rng = rand::thread_rng();
526 let mut data = input.to_vec();
527 let shape = input.shape();
528
529 if self.brightness > 0.0 {
531 let factor = 1.0 + rng.gen_range(-self.brightness..self.brightness);
532 for val in &mut data {
533 *val = (*val * factor).clamp(0.0, 1.0);
534 }
535 }
536
537 if self.contrast > 0.0 {
539 let factor = 1.0 + rng.gen_range(-self.contrast..self.contrast);
540 let mean: f32 = data.iter().sum::<f32>() / data.len() as f32;
541 for val in &mut data {
542 *val = ((*val - mean) * factor + mean).clamp(0.0, 1.0);
543 }
544 }
545
546 if self.saturation > 0.0 && shape.len() == 3 && shape[0] == 3 {
548 let factor = 1.0 + rng.gen_range(-self.saturation..self.saturation);
549 let (h, w) = (shape[1], shape[2]);
550
551 for y in 0..h {
552 for x in 0..w {
553 let r = data[0 * h * w + y * w + x];
554 let g = data[h * w + y * w + x];
555 let b = data[2 * h * w + y * w + x];
556
557 let gray = 0.299 * r + 0.587 * g + 0.114 * b;
558
559 data[0 * h * w + y * w + x] = (gray + (r - gray) * factor).clamp(0.0, 1.0);
560 data[h * w + y * w + x] = (gray + (g - gray) * factor).clamp(0.0, 1.0);
561 data[2 * h * w + y * w + x] = (gray + (b - gray) * factor).clamp(0.0, 1.0);
562 }
563 }
564 }
565
566 Tensor::from_vec(data, shape).unwrap()
567 }
568}
569
/// Converts a `[3, H, W]` RGB tensor to grayscale.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Grayscale {
    /// Number of (identical) grayscale channels in the output; always at least 1.
    num_output_channels: usize,
}

impl Grayscale {
    /// Creates a grayscale conversion producing a single output channel.
    #[must_use]
    pub fn new() -> Self {
        Self {
            num_output_channels: 1,
        }
    }

    /// Creates a grayscale conversion that repeats the gray plane `num_output_channels`
    /// times. A value of `0` is raised to `1`.
    #[must_use]
    pub fn with_channels(num_output_channels: usize) -> Self {
        Self {
            num_output_channels: num_output_channels.max(1),
        }
    }
}

impl Default for Grayscale {
    /// Equivalent to [`Grayscale::new`].
    fn default() -> Self {
        Self::new()
    }
}
602
603impl Transform for Grayscale {
604 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
605 let shape = input.shape();
606
607 if shape.len() != 3 || shape[0] != 3 {
609 return input.clone();
610 }
611
612 let (_, h, w) = (shape[0], shape[1], shape[2]);
613 let data = input.to_vec();
614
615 let mut gray = Vec::with_capacity(h * w);
616 for y in 0..h {
617 for x in 0..w {
618 let r = data[0 * h * w + y * w + x];
619 let g = data[h * w + y * w + x];
620 let b = data[2 * h * w + y * w + x];
621 gray.push(0.299 * r + 0.587 * g + 0.114 * b);
622 }
623 }
624
625 if self.num_output_channels == 1 {
626 Tensor::from_vec(gray, &[1, h, w]).unwrap()
627 } else {
628 let mut result = Vec::with_capacity(self.num_output_channels * h * w);
630 for _ in 0..self.num_output_channels {
631 result.extend(&gray);
632 }
633 Tensor::from_vec(result, &[self.num_output_channels, h, w]).unwrap()
634 }
635 }
636}
637
/// Normalizes each channel of an image tensor as `(x - mean[c]) / std[c]`.
///
/// Channels beyond the length of `mean`/`std` fall back to mean `0.0` and std `1.0`,
/// which leaves them unchanged.
#[derive(Debug, Clone, PartialEq)]
pub struct ImageNormalize {
    /// Per-channel means.
    mean: Vec<f32>,
    /// Per-channel standard deviations.
    std: Vec<f32>,
}

impl ImageNormalize {
    /// Creates a normalization with explicit per-channel `mean` and `std`.
    #[must_use]
    pub fn new(mean: Vec<f32>, std: Vec<f32>) -> Self {
        Self { mean, std }
    }

    /// Commonly used per-channel RGB statistics for ImageNet.
    #[must_use]
    pub fn imagenet() -> Self {
        Self::new(vec![0.485, 0.456, 0.406], vec![0.229, 0.224, 0.225])
    }

    /// Commonly used single-channel statistics for MNIST.
    #[must_use]
    pub fn mnist() -> Self {
        Self::new(vec![0.1307], vec![0.3081])
    }

    /// Commonly used per-channel RGB statistics for CIFAR-10.
    #[must_use]
    pub fn cifar10() -> Self {
        Self::new(vec![0.4914, 0.4822, 0.4465], vec![0.2470, 0.2435, 0.2616])
    }
}
673
674impl Transform for ImageNormalize {
675 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
676 let shape = input.shape();
677 let mut data = input.to_vec();
678
679 match shape.len() {
680 3 => {
681 let (c, h, w) = (shape[0], shape[1], shape[2]);
682 for ch in 0..c {
683 let mean = self.mean.get(ch).copied().unwrap_or(0.0);
684 let std = self.std.get(ch).copied().unwrap_or(1.0);
685 for y in 0..h {
686 for x in 0..w {
687 let idx = ch * h * w + y * w + x;
688 data[idx] = (data[idx] - mean) / std;
689 }
690 }
691 }
692 }
693 4 => {
694 let (n, c, h, w) = (shape[0], shape[1], shape[2], shape[3]);
695 for batch in 0..n {
696 for ch in 0..c {
697 let mean = self.mean.get(ch).copied().unwrap_or(0.0);
698 let std = self.std.get(ch).copied().unwrap_or(1.0);
699 for y in 0..h {
700 for x in 0..w {
701 let idx = batch * c * h * w + ch * h * w + y * w + x;
702 data[idx] = (data[idx] - mean) / std;
703 }
704 }
705 }
706 }
707 }
708 _ => {}
709 }
710
711 Tensor::from_vec(data, shape).unwrap()
712 }
713}
714
/// Pads the trailing `(H, W)` axes of an image tensor with a constant value.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Pad {
    /// Padding in pixels, ordered `(left, right, top, bottom)`.
    padding: (usize, usize, usize, usize),
    /// Value written into the padded border.
    fill_value: f32,
}

impl Pad {
    /// Creates a pad of `padding` pixels on every side, filled with `0.0`.
    #[must_use]
    pub fn new(padding: usize) -> Self {
        Self {
            padding: (padding, padding, padding, padding),
            fill_value: 0.0,
        }
    }

    /// Creates a pad with a separate amount per side, filled with `0.0`.
    #[must_use]
    pub fn asymmetric(left: usize, right: usize, top: usize, bottom: usize) -> Self {
        Self {
            padding: (left, right, top, bottom),
            fill_value: 0.0,
        }
    }

    /// Returns this pad with its border value replaced by `value`.
    #[must_use]
    pub fn with_fill(mut self, value: f32) -> Self {
        self.fill_value = value;
        self
    }
}
751
752impl Transform for Pad {
753 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
754 let shape = input.shape();
755 let data = input.to_vec();
756 let (left, right, top, bottom) = self.padding;
757
758 match shape.len() {
759 2 => {
760 let (h, w) = (shape[0], shape[1]);
761 let new_h = h + top + bottom;
762 let new_w = w + left + right;
763
764 let mut result = vec![self.fill_value; new_h * new_w];
765 for y in 0..h {
766 for x in 0..w {
767 result[(y + top) * new_w + (x + left)] = data[y * w + x];
768 }
769 }
770 Tensor::from_vec(result, &[new_h, new_w]).unwrap()
771 }
772 3 => {
773 let (c, h, w) = (shape[0], shape[1], shape[2]);
774 let new_h = h + top + bottom;
775 let new_w = w + left + right;
776
777 let mut result = vec![self.fill_value; c * new_h * new_w];
778 for ch in 0..c {
779 for y in 0..h {
780 for x in 0..w {
781 result[ch * new_h * new_w + (y + top) * new_w + (x + left)] =
782 data[ch * h * w + y * w + x];
783 }
784 }
785 }
786 Tensor::from_vec(result, &[c, new_h, new_w]).unwrap()
787 }
788 _ => input.clone(),
789 }
790 }
791}
792
/// Rescales raw pixel intensities by dividing every element by 255.
///
/// This maps values in `[0, 255]` to `[0, 1]`. Values outside that range are scaled the
/// same way, without clamping.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ToTensorImage;

impl ToTensorImage {
    /// Creates the transform.
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}

impl Default for ToTensorImage {
    /// Equivalent to [`ToTensorImage::new`].
    fn default() -> Self {
        Self::new()
    }
}
813
814impl Transform for ToTensorImage {
815 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
816 let data: Vec<f32> = input.to_vec().iter().map(|&x| x / 255.0).collect();
817 Tensor::from_vec(data, input.shape()).unwrap()
818 }
819}
820
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_resize_2d() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();

        let resize = Resize::new(4, 4);
        let output = resize.apply(&input);

        assert_eq!(output.shape(), &[4, 4]);
    }

    #[test]
    fn test_resize_3d() {
        let input = Tensor::from_vec(vec![1.0; 3 * 8 * 8], &[3, 8, 8]).unwrap();

        let resize = Resize::new(4, 4);
        let output = resize.apply(&input);

        assert_eq!(output.shape(), &[3, 4, 4]);
    }

    #[test]
    fn test_resize_same_size_is_identity() {
        let data: Vec<f32> = (1..=6).map(|x| x as f32).collect();
        let input = Tensor::from_vec(data.clone(), &[2, 3]).unwrap();

        let output = Resize::new(2, 3).apply(&input);

        assert_eq!(output.shape(), &[2, 3]);
        assert_eq!(output.to_vec(), data);
    }

    #[test]
    fn test_resize_4d() {
        let input = Tensor::from_vec(vec![1.0; 2 * 3 * 8 * 8], &[2, 3, 8, 8]).unwrap();

        let output = Resize::square(4).apply(&input);

        assert_eq!(output.shape(), &[2, 3, 4, 4]);
        assert!(output.to_vec().iter().all(|&v| v == 1.0));
    }

    #[test]
    fn test_resize_unsupported_rank_is_passthrough() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        let output = Resize::new(5, 5).apply(&input);

        assert_eq!(output.shape(), &[3]);
        assert_eq!(output.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_center_crop() {
        let input = Tensor::from_vec((1..=16).map(|x| x as f32).collect(), &[4, 4]).unwrap();

        let crop = CenterCrop::new(2, 2);
        let output = crop.apply(&input);

        assert_eq!(output.shape(), &[2, 2]);
        assert_eq!(output.to_vec(), vec![6.0, 7.0, 10.0, 11.0]);
    }

    #[test]
    fn test_center_crop_3d() {
        let input = Tensor::from_vec((1..=32).map(|x| x as f32).collect(), &[2, 4, 4]).unwrap();

        let output = CenterCrop::square(2).apply(&input);

        assert_eq!(output.shape(), &[2, 2, 2]);
        assert_eq!(
            output.to_vec(),
            vec![6.0, 7.0, 10.0, 11.0, 22.0, 23.0, 26.0, 27.0]
        );
    }

    #[test]
    fn test_center_crop_larger_than_input_is_clamped() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();

        let output = CenterCrop::square(8).apply(&input);

        assert_eq!(output.shape(), &[2, 2]);
        assert_eq!(output.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_random_horizontal_flip() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();

        let flip = RandomHorizontalFlip::with_probability(1.0);
        let output = flip.apply(&input);

        assert_eq!(output.to_vec(), vec![2.0, 1.0, 4.0, 3.0]);
    }

    #[test]
    fn test_random_horizontal_flip_3d() {
        let input = Tensor::from_vec((1..=6).map(|x| x as f32).collect(), &[2, 1, 3]).unwrap();

        let output = RandomHorizontalFlip::with_probability(1.0).apply(&input);

        assert_eq!(output.shape(), &[2, 1, 3]);
        assert_eq!(output.to_vec(), vec![3.0, 2.0, 1.0, 6.0, 5.0, 4.0]);
    }

    #[test]
    fn test_random_vertical_flip() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();

        let flip = RandomVerticalFlip::with_probability(1.0);
        let output = flip.apply(&input);

        assert_eq!(output.to_vec(), vec![3.0, 4.0, 1.0, 2.0]);
    }

    #[test]
    fn test_random_vertical_flip_3d() {
        let input = Tensor::from_vec((1..=8).map(|x| x as f32).collect(), &[2, 2, 2]).unwrap();

        let output = RandomVerticalFlip::with_probability(1.0).apply(&input);

        assert_eq!(output.shape(), &[2, 2, 2]);
        assert_eq!(
            output.to_vec(),
            vec![3.0, 4.0, 1.0, 2.0, 7.0, 8.0, 5.0, 6.0]
        );
    }

    #[test]
    fn test_random_rotation_180() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();

        let rotation = RandomRotation::with_angles(vec![180]);
        let output = rotation.apply(&input);

        assert_eq!(output.to_vec(), vec![4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_random_rotation_90_non_square() {
        // [[1, 2, 3],
        //  [4, 5, 6]] rotated clockwise -> [[4, 1], [5, 2], [6, 3]]
        let input = Tensor::from_vec((1..=6).map(|x| x as f32).collect(), &[2, 3]).unwrap();

        let output = RandomRotation::with_angles(vec![90]).apply(&input);

        assert_eq!(output.shape(), &[3, 2]);
        assert_eq!(output.to_vec(), vec![4.0, 1.0, 5.0, 2.0, 6.0, 3.0]);
    }

    #[test]
    fn test_random_rotation_270_non_square() {
        // [[1, 2, 3],
        //  [4, 5, 6]] rotated counter-clockwise -> [[3, 6], [2, 5], [1, 4]]
        let input = Tensor::from_vec((1..=6).map(|x| x as f32).collect(), &[2, 3]).unwrap();

        let output = RandomRotation::with_angles(vec![270]).apply(&input);

        assert_eq!(output.shape(), &[3, 2]);
        assert_eq!(output.to_vec(), vec![3.0, 6.0, 2.0, 5.0, 1.0, 4.0]);
    }

    #[test]
    fn test_random_rotation_invalid_angles_fall_back_to_identity() {
        let input = Tensor::from_vec((1..=6).map(|x| x as f32).collect(), &[2, 3]).unwrap();

        let output = RandomRotation::with_angles(vec![45]).apply(&input);

        assert_eq!(output.shape(), &[2, 3]);
        assert_eq!(output.to_vec(), input.to_vec());
    }

    #[test]
    fn test_grayscale() {
        // Channel planes: R = 1.0, G = 0.5, B = 0.0 for all four pixels.
        let input = Tensor::from_vec(
            vec![
                1.0, 1.0, 1.0, 1.0, //
                0.5, 0.5, 0.5, 0.5, //
                0.0, 0.0, 0.0, 0.0,
            ],
            &[3, 2, 2],
        )
        .unwrap();

        let gray = Grayscale::new();
        let output = gray.apply(&input);

        assert_eq!(output.shape(), &[1, 2, 2]);
        let expected = 0.299 + 0.587 * 0.5;
        for val in output.to_vec() {
            assert!((val - expected).abs() < 0.001);
        }
    }

    #[test]
    fn test_grayscale_replicates_channels() {
        let input = Tensor::from_vec(
            vec![
                1.0, 1.0, 1.0, 1.0, //
                0.5, 0.5, 0.5, 0.5, //
                0.0, 0.0, 0.0, 0.0,
            ],
            &[3, 2, 2],
        )
        .unwrap();

        let output = Grayscale::with_channels(3).apply(&input);

        assert_eq!(output.shape(), &[3, 2, 2]);
        let expected = 0.299 + 0.587 * 0.5;
        for val in output.to_vec() {
            assert!((val - expected).abs() < 0.001);
        }
    }

    #[test]
    fn test_grayscale_non_rgb_passthrough() {
        let input = Tensor::from_vec(vec![0.1, 0.2, 0.3, 0.4], &[1, 2, 2]).unwrap();

        let output = Grayscale::new().apply(&input);

        assert_eq!(output.shape(), &[1, 2, 2]);
        assert_eq!(output.to_vec(), vec![0.1, 0.2, 0.3, 0.4]);
    }

    #[test]
    fn test_image_normalize() {
        let input = Tensor::from_vec(vec![0.5; 3 * 2 * 2], &[3, 2, 2]).unwrap();

        let normalize = ImageNormalize::new(vec![0.5, 0.5, 0.5], vec![0.5, 0.5, 0.5]);
        let output = normalize.apply(&input);

        // (0.5 - 0.5) / 0.5 == 0 for every element.
        for val in output.to_vec() {
            assert!((val - 0.0).abs() < 0.001);
        }
    }

    #[test]
    fn test_image_normalize_4d_and_missing_channels() {
        // Channel 2 has no configured mean/std and must pass through unchanged.
        let input =
            Tensor::from_vec(vec![3.0, 5.0, 4.0, 3.0, 5.0, 4.0], &[2, 3, 1, 1]).unwrap();

        let output = ImageNormalize::new(vec![0.0, 1.0], vec![1.0, 2.0]).apply(&input);

        assert_eq!(output.shape(), &[2, 3, 1, 1]);
        assert_eq!(output.to_vec(), vec![3.0, 2.0, 4.0, 3.0, 2.0, 4.0]);
    }

    #[test]
    fn test_pad() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();

        let pad = Pad::new(1);
        let output = pad.apply(&input);

        assert_eq!(output.shape(), &[4, 4]);
        let data = output.to_vec();
        // Corners are padding.
        assert_eq!(data[0], 0.0);
        assert_eq!(data[3], 0.0);
        assert_eq!(data[12], 0.0);
        assert_eq!(data[15], 0.0);
        // The original image sits in the center.
        assert_eq!(data[5], 1.0);
        assert_eq!(data[6], 2.0);
        assert_eq!(data[9], 3.0);
        assert_eq!(data[10], 4.0);
    }

    #[test]
    fn test_pad_3d_with_fill() {
        let input = Tensor::from_vec(vec![7.0, 8.0], &[2, 1, 1]).unwrap();

        let output = Pad::new(1).with_fill(-1.0).apply(&input);

        assert_eq!(output.shape(), &[2, 3, 3]);
        let mut expected = vec![-1.0; 2 * 3 * 3];
        expected[4] = 7.0;
        expected[13] = 8.0;
        assert_eq!(output.to_vec(), expected);
    }

    #[test]
    fn test_to_tensor_image() {
        let input = Tensor::from_vec(vec![0.0, 127.5, 255.0], &[3]).unwrap();

        let transform = ToTensorImage::new();
        let output = transform.apply(&input);

        let data = output.to_vec();
        assert!((data[0] - 0.0).abs() < 0.001);
        assert!((data[1] - 0.5).abs() < 0.001);
        assert!((data[2] - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_color_jitter() {
        let input = Tensor::from_vec(vec![0.5; 3 * 4 * 4], &[3, 4, 4]).unwrap();

        let jitter = ColorJitter::new(0.1, 0.1, 0.1);
        let output = jitter.apply(&input);

        assert_eq!(output.shape(), &[3, 4, 4]);
        // Every adjustment clamps into [0, 1].
        for val in output.to_vec() {
            assert!((0.0..=1.0).contains(&val));
        }
    }

    #[test]
    fn test_color_jitter_zero_strength_is_identity() {
        let data = vec![0.25, 2.0, 0.5, 0.75, 0.0, 1.0];
        let input = Tensor::from_vec(data.clone(), &[3, 1, 2]).unwrap();

        let output = ColorJitter::new(0.0, 0.0, 0.0).apply(&input);

        assert_eq!(output.shape(), &[3, 1, 2]);
        assert_eq!(output.to_vec(), data);
    }
}