1use axonml_data::Transform;
18use axonml_tensor::Tensor;
19use rand::Rng;
20
21pub struct Resize {
27 height: usize,
28 width: usize,
29}
30
31impl Resize {
32 #[must_use]
34 pub fn new(height: usize, width: usize) -> Self {
35 Self { height, width }
36 }
37
38 #[must_use]
40 pub fn square(size: usize) -> Self {
41 Self::new(size, size)
42 }
43}
44
45impl Transform for Resize {
46 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
47 let shape = input.shape();
48
49 match shape.len() {
51 2 => resize_2d(input, self.height, self.width),
52 3 => resize_3d(input, self.height, self.width),
53 4 => resize_4d(input, self.height, self.width),
54 _ => input.clone(),
55 }
56 }
57}
58
59fn resize_2d(input: &Tensor<f32>, new_h: usize, new_w: usize) -> Tensor<f32> {
61 let shape = input.shape();
62 let (old_h, old_w) = (shape[0], shape[1]);
63 let data = input.to_vec();
64
65 let mut result = vec![0.0; new_h * new_w];
66
67 let scale_h = old_h as f32 / new_h as f32;
68 let scale_w = old_w as f32 / new_w as f32;
69
70 for y in 0..new_h {
71 for x in 0..new_w {
72 let src_y = y as f32 * scale_h;
73 let src_x = x as f32 * scale_w;
74
75 let y0 = (src_y.floor() as usize).min(old_h - 1);
76 let y1 = (y0 + 1).min(old_h - 1);
77 let x0 = (src_x.floor() as usize).min(old_w - 1);
78 let x1 = (x0 + 1).min(old_w - 1);
79
80 let dy = src_y - y0 as f32;
81 let dx = src_x - x0 as f32;
82
83 let v00 = data[y0 * old_w + x0];
84 let v01 = data[y0 * old_w + x1];
85 let v10 = data[y1 * old_w + x0];
86 let v11 = data[y1 * old_w + x1];
87
88 let value = v00 * (1.0 - dx) * (1.0 - dy)
89 + v01 * dx * (1.0 - dy)
90 + v10 * (1.0 - dx) * dy
91 + v11 * dx * dy;
92
93 result[y * new_w + x] = value;
94 }
95 }
96
97 Tensor::from_vec(result, &[new_h, new_w]).unwrap()
98}
99
100fn resize_3d(input: &Tensor<f32>, new_h: usize, new_w: usize) -> Tensor<f32> {
102 let shape = input.shape();
103 let (channels, old_h, old_w) = (shape[0], shape[1], shape[2]);
104 let data = input.to_vec();
105
106 let mut result = vec![0.0; channels * new_h * new_w];
107
108 let scale_h = old_h as f32 / new_h as f32;
109 let scale_w = old_w as f32 / new_w as f32;
110
111 for c in 0..channels {
112 for y in 0..new_h {
113 for x in 0..new_w {
114 let src_y = y as f32 * scale_h;
115 let src_x = x as f32 * scale_w;
116
117 let y0 = (src_y.floor() as usize).min(old_h - 1);
118 let y1 = (y0 + 1).min(old_h - 1);
119 let x0 = (src_x.floor() as usize).min(old_w - 1);
120 let x1 = (x0 + 1).min(old_w - 1);
121
122 let dy = src_y - y0 as f32;
123 let dx = src_x - x0 as f32;
124
125 let base = c * old_h * old_w;
126 let v00 = data[base + y0 * old_w + x0];
127 let v01 = data[base + y0 * old_w + x1];
128 let v10 = data[base + y1 * old_w + x0];
129 let v11 = data[base + y1 * old_w + x1];
130
131 let value = v00 * (1.0 - dx) * (1.0 - dy)
132 + v01 * dx * (1.0 - dy)
133 + v10 * (1.0 - dx) * dy
134 + v11 * dx * dy;
135
136 result[c * new_h * new_w + y * new_w + x] = value;
137 }
138 }
139 }
140
141 Tensor::from_vec(result, &[channels, new_h, new_w]).unwrap()
142}
143
144fn resize_4d(input: &Tensor<f32>, new_h: usize, new_w: usize) -> Tensor<f32> {
146 let shape = input.shape();
147 let (batch, channels, old_h, old_w) = (shape[0], shape[1], shape[2], shape[3]);
148 let data = input.to_vec();
149
150 let mut result = vec![0.0; batch * channels * new_h * new_w];
151
152 let scale_h = old_h as f32 / new_h as f32;
153 let scale_w = old_w as f32 / new_w as f32;
154
155 for n in 0..batch {
156 for c in 0..channels {
157 for y in 0..new_h {
158 for x in 0..new_w {
159 let src_y = y as f32 * scale_h;
160 let src_x = x as f32 * scale_w;
161
162 let y0 = (src_y.floor() as usize).min(old_h - 1);
163 let y1 = (y0 + 1).min(old_h - 1);
164 let x0 = (src_x.floor() as usize).min(old_w - 1);
165 let x1 = (x0 + 1).min(old_w - 1);
166
167 let dy = src_y - y0 as f32;
168 let dx = src_x - x0 as f32;
169
170 let base = n * channels * old_h * old_w + c * old_h * old_w;
171 let v00 = data[base + y0 * old_w + x0];
172 let v01 = data[base + y0 * old_w + x1];
173 let v10 = data[base + y1 * old_w + x0];
174 let v11 = data[base + y1 * old_w + x1];
175
176 let value = v00 * (1.0 - dx) * (1.0 - dy)
177 + v01 * dx * (1.0 - dy)
178 + v10 * (1.0 - dx) * dy
179 + v11 * dx * dy;
180
181 let out_idx = n * channels * new_h * new_w + c * new_h * new_w + y * new_w + x;
182 result[out_idx] = value;
183 }
184 }
185 }
186 }
187
188 Tensor::from_vec(result, &[batch, channels, new_h, new_w]).unwrap()
189}
190
191pub struct CenterCrop {
197 height: usize,
198 width: usize,
199}
200
201impl CenterCrop {
202 #[must_use]
204 pub fn new(height: usize, width: usize) -> Self {
205 Self { height, width }
206 }
207
208 #[must_use]
210 pub fn square(size: usize) -> Self {
211 Self::new(size, size)
212 }
213}
214
215impl Transform for CenterCrop {
216 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
217 let shape = input.shape();
218 let data = input.to_vec();
219
220 match shape.len() {
221 2 => {
222 let (h, w) = (shape[0], shape[1]);
223 let start_h = (h.saturating_sub(self.height)) / 2;
224 let start_w = (w.saturating_sub(self.width)) / 2;
225 let crop_h = self.height.min(h);
226 let crop_w = self.width.min(w);
227
228 let mut result = Vec::with_capacity(crop_h * crop_w);
229 for y in start_h..start_h + crop_h {
230 for x in start_w..start_w + crop_w {
231 result.push(data[y * w + x]);
232 }
233 }
234 Tensor::from_vec(result, &[crop_h, crop_w]).unwrap()
235 }
236 3 => {
237 let (c, h, w) = (shape[0], shape[1], shape[2]);
238 let start_h = (h.saturating_sub(self.height)) / 2;
239 let start_w = (w.saturating_sub(self.width)) / 2;
240 let crop_h = self.height.min(h);
241 let crop_w = self.width.min(w);
242
243 let mut result = Vec::with_capacity(c * crop_h * crop_w);
244 for ch in 0..c {
245 for y in start_h..start_h + crop_h {
246 for x in start_w..start_w + crop_w {
247 result.push(data[ch * h * w + y * w + x]);
248 }
249 }
250 }
251 Tensor::from_vec(result, &[c, crop_h, crop_w]).unwrap()
252 }
253 _ => input.clone(),
254 }
255 }
256}
257
258pub struct RandomHorizontalFlip {
264 probability: f32,
265}
266
267impl RandomHorizontalFlip {
268 #[must_use]
270 pub fn new() -> Self {
271 Self { probability: 0.5 }
272 }
273
274 #[must_use]
276 pub fn with_probability(probability: f32) -> Self {
277 Self {
278 probability: probability.clamp(0.0, 1.0),
279 }
280 }
281}
282
283impl Default for RandomHorizontalFlip {
284 fn default() -> Self {
285 Self::new()
286 }
287}
288
289impl Transform for RandomHorizontalFlip {
290 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
291 let mut rng = rand::thread_rng();
292 if rng.r#gen::<f32>() > self.probability {
293 return input.clone();
294 }
295
296 let shape = input.shape();
297 let data = input.to_vec();
298
299 match shape.len() {
300 2 => {
301 let (h, w) = (shape[0], shape[1]);
302 let mut result = vec![0.0; h * w];
303 for y in 0..h {
304 for x in 0..w {
305 result[y * w + x] = data[y * w + (w - 1 - x)];
306 }
307 }
308 Tensor::from_vec(result, shape).unwrap()
309 }
310 3 => {
311 let (c, h, w) = (shape[0], shape[1], shape[2]);
312 let mut result = vec![0.0; c * h * w];
313 for ch in 0..c {
314 for y in 0..h {
315 for x in 0..w {
316 result[ch * h * w + y * w + x] = data[ch * h * w + y * w + (w - 1 - x)];
317 }
318 }
319 }
320 Tensor::from_vec(result, shape).unwrap()
321 }
322 _ => input.clone(),
323 }
324 }
325}
326
327pub struct RandomVerticalFlip {
333 probability: f32,
334}
335
336impl RandomVerticalFlip {
337 #[must_use]
339 pub fn new() -> Self {
340 Self { probability: 0.5 }
341 }
342
343 #[must_use]
345 pub fn with_probability(probability: f32) -> Self {
346 Self {
347 probability: probability.clamp(0.0, 1.0),
348 }
349 }
350}
351
352impl Default for RandomVerticalFlip {
353 fn default() -> Self {
354 Self::new()
355 }
356}
357
358impl Transform for RandomVerticalFlip {
359 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
360 let mut rng = rand::thread_rng();
361 if rng.r#gen::<f32>() > self.probability {
362 return input.clone();
363 }
364
365 let shape = input.shape();
366 let data = input.to_vec();
367
368 match shape.len() {
369 2 => {
370 let (h, w) = (shape[0], shape[1]);
371 let mut result = vec![0.0; h * w];
372 for y in 0..h {
373 for x in 0..w {
374 result[y * w + x] = data[(h - 1 - y) * w + x];
375 }
376 }
377 Tensor::from_vec(result, shape).unwrap()
378 }
379 3 => {
380 let (c, h, w) = (shape[0], shape[1], shape[2]);
381 let mut result = vec![0.0; c * h * w];
382 for ch in 0..c {
383 for y in 0..h {
384 for x in 0..w {
385 result[ch * h * w + y * w + x] = data[ch * h * w + (h - 1 - y) * w + x];
386 }
387 }
388 }
389 Tensor::from_vec(result, shape).unwrap()
390 }
391 _ => input.clone(),
392 }
393 }
394}
395
396pub struct RandomRotation {
402 angles: Vec<i32>,
404}
405
406impl RandomRotation {
407 #[must_use]
409 pub fn new() -> Self {
410 Self {
411 angles: vec![0, 90, 180, 270],
412 }
413 }
414
415 #[must_use]
417 pub fn with_angles(angles: Vec<i32>) -> Self {
418 let valid: Vec<i32> = angles
419 .into_iter()
420 .filter(|&a| a == 0 || a == 90 || a == 180 || a == 270)
421 .collect();
422 Self {
423 angles: if valid.is_empty() { vec![0] } else { valid },
424 }
425 }
426}
427
428impl Default for RandomRotation {
429 fn default() -> Self {
430 Self::new()
431 }
432}
433
434impl Transform for RandomRotation {
435 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
436 let mut rng = rand::thread_rng();
437 let angle = self.angles[rng.gen_range(0..self.angles.len())];
438
439 if angle == 0 {
440 return input.clone();
441 }
442
443 let shape = input.shape();
444 let data = input.to_vec();
445
446 if shape.len() != 2 {
448 return input.clone();
449 }
450
451 let (h, w) = (shape[0], shape[1]);
452
453 match angle {
454 90 => {
455 let mut result = vec![0.0; h * w];
457 for y in 0..h {
458 for x in 0..w {
459 result[x * h + (h - 1 - y)] = data[y * w + x];
460 }
461 }
462 Tensor::from_vec(result, &[w, h]).unwrap()
463 }
464 180 => {
465 let mut result = vec![0.0; h * w];
467 for y in 0..h {
468 for x in 0..w {
469 result[(h - 1 - y) * w + (w - 1 - x)] = data[y * w + x];
470 }
471 }
472 Tensor::from_vec(result, &[h, w]).unwrap()
473 }
474 270 => {
475 let mut result = vec![0.0; h * w];
477 for y in 0..h {
478 for x in 0..w {
479 result[(w - 1 - x) * h + y] = data[y * w + x];
480 }
481 }
482 Tensor::from_vec(result, &[w, h]).unwrap()
483 }
484 _ => input.clone(),
485 }
486 }
487}
488
489pub struct ColorJitter {
495 brightness: f32,
496 contrast: f32,
497 saturation: f32,
498}
499
500impl ColorJitter {
501 #[must_use]
503 pub fn new(brightness: f32, contrast: f32, saturation: f32) -> Self {
504 Self {
505 brightness: brightness.abs(),
506 contrast: contrast.abs(),
507 saturation: saturation.abs(),
508 }
509 }
510}
511
512impl Transform for ColorJitter {
513 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
514 let mut rng = rand::thread_rng();
515 let mut data = input.to_vec();
516 let shape = input.shape();
517
518 if self.brightness > 0.0 {
520 let factor = 1.0 + rng.gen_range(-self.brightness..self.brightness);
521 for val in &mut data {
522 *val = (*val * factor).clamp(0.0, 1.0);
523 }
524 }
525
526 if self.contrast > 0.0 {
528 let factor = 1.0 + rng.gen_range(-self.contrast..self.contrast);
529 let mean: f32 = data.iter().sum::<f32>() / data.len() as f32;
530 for val in &mut data {
531 *val = ((*val - mean) * factor + mean).clamp(0.0, 1.0);
532 }
533 }
534
535 if self.saturation > 0.0 && shape.len() == 3 && shape[0] == 3 {
537 let factor = 1.0 + rng.gen_range(-self.saturation..self.saturation);
538 let (h, w) = (shape[1], shape[2]);
539
540 for y in 0..h {
541 for x in 0..w {
542 let r = data[0 * h * w + y * w + x];
543 let g = data[h * w + y * w + x];
544 let b = data[2 * h * w + y * w + x];
545
546 let gray = 0.299 * r + 0.587 * g + 0.114 * b;
547
548 data[0 * h * w + y * w + x] = (gray + (r - gray) * factor).clamp(0.0, 1.0);
549 data[h * w + y * w + x] = (gray + (g - gray) * factor).clamp(0.0, 1.0);
550 data[2 * h * w + y * w + x] = (gray + (b - gray) * factor).clamp(0.0, 1.0);
551 }
552 }
553 }
554
555 Tensor::from_vec(data, shape).unwrap()
556 }
557}
558
559pub struct Grayscale {
565 num_output_channels: usize,
566}
567
568impl Grayscale {
569 #[must_use]
571 pub fn new() -> Self {
572 Self {
573 num_output_channels: 1,
574 }
575 }
576
577 #[must_use]
579 pub fn with_channels(num_output_channels: usize) -> Self {
580 Self {
581 num_output_channels: num_output_channels.max(1),
582 }
583 }
584}
585
586impl Default for Grayscale {
587 fn default() -> Self {
588 Self::new()
589 }
590}
591
592impl Transform for Grayscale {
593 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
594 let shape = input.shape();
595
596 if shape.len() != 3 || shape[0] != 3 {
598 return input.clone();
599 }
600
601 let (_, h, w) = (shape[0], shape[1], shape[2]);
602 let data = input.to_vec();
603
604 let mut gray = Vec::with_capacity(h * w);
605 for y in 0..h {
606 for x in 0..w {
607 let r = data[0 * h * w + y * w + x];
608 let g = data[h * w + y * w + x];
609 let b = data[2 * h * w + y * w + x];
610 gray.push(0.299 * r + 0.587 * g + 0.114 * b);
611 }
612 }
613
614 if self.num_output_channels == 1 {
615 Tensor::from_vec(gray, &[1, h, w]).unwrap()
616 } else {
617 let mut result = Vec::with_capacity(self.num_output_channels * h * w);
619 for _ in 0..self.num_output_channels {
620 result.extend(&gray);
621 }
622 Tensor::from_vec(result, &[self.num_output_channels, h, w]).unwrap()
623 }
624 }
625}
626
627pub struct ImageNormalize {
633 mean: Vec<f32>,
634 std: Vec<f32>,
635}
636
637impl ImageNormalize {
638 #[must_use]
640 pub fn new(mean: Vec<f32>, std: Vec<f32>) -> Self {
641 Self { mean, std }
642 }
643
644 #[must_use]
646 pub fn imagenet() -> Self {
647 Self::new(vec![0.485, 0.456, 0.406], vec![0.229, 0.224, 0.225])
648 }
649
650 #[must_use]
652 pub fn mnist() -> Self {
653 Self::new(vec![0.1307], vec![0.3081])
654 }
655
656 #[must_use]
658 pub fn cifar10() -> Self {
659 Self::new(vec![0.4914, 0.4822, 0.4465], vec![0.2470, 0.2435, 0.2616])
660 }
661}
662
663impl Transform for ImageNormalize {
664 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
665 let shape = input.shape();
666 let mut data = input.to_vec();
667
668 match shape.len() {
669 3 => {
670 let (c, h, w) = (shape[0], shape[1], shape[2]);
671 for ch in 0..c {
672 let mean = self.mean.get(ch).copied().unwrap_or(0.0);
673 let std = self.std.get(ch).copied().unwrap_or(1.0);
674 for y in 0..h {
675 for x in 0..w {
676 let idx = ch * h * w + y * w + x;
677 data[idx] = (data[idx] - mean) / std;
678 }
679 }
680 }
681 }
682 4 => {
683 let (n, c, h, w) = (shape[0], shape[1], shape[2], shape[3]);
684 for batch in 0..n {
685 for ch in 0..c {
686 let mean = self.mean.get(ch).copied().unwrap_or(0.0);
687 let std = self.std.get(ch).copied().unwrap_or(1.0);
688 for y in 0..h {
689 for x in 0..w {
690 let idx = batch * c * h * w + ch * h * w + y * w + x;
691 data[idx] = (data[idx] - mean) / std;
692 }
693 }
694 }
695 }
696 }
697 _ => {}
698 }
699
700 Tensor::from_vec(data, shape).unwrap()
701 }
702}
703
704pub struct Pad {
710 padding: (usize, usize, usize, usize), fill_value: f32,
712}
713
714impl Pad {
715 #[must_use]
717 pub fn new(padding: usize) -> Self {
718 Self {
719 padding: (padding, padding, padding, padding),
720 fill_value: 0.0,
721 }
722 }
723
724 #[must_use]
726 pub fn asymmetric(left: usize, right: usize, top: usize, bottom: usize) -> Self {
727 Self {
728 padding: (left, right, top, bottom),
729 fill_value: 0.0,
730 }
731 }
732
733 #[must_use]
735 pub fn with_fill(mut self, value: f32) -> Self {
736 self.fill_value = value;
737 self
738 }
739}
740
741impl Transform for Pad {
742 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
743 let shape = input.shape();
744 let data = input.to_vec();
745 let (left, right, top, bottom) = self.padding;
746
747 match shape.len() {
748 2 => {
749 let (h, w) = (shape[0], shape[1]);
750 let new_h = h + top + bottom;
751 let new_w = w + left + right;
752
753 let mut result = vec![self.fill_value; new_h * new_w];
754 for y in 0..h {
755 for x in 0..w {
756 result[(y + top) * new_w + (x + left)] = data[y * w + x];
757 }
758 }
759 Tensor::from_vec(result, &[new_h, new_w]).unwrap()
760 }
761 3 => {
762 let (c, h, w) = (shape[0], shape[1], shape[2]);
763 let new_h = h + top + bottom;
764 let new_w = w + left + right;
765
766 let mut result = vec![self.fill_value; c * new_h * new_w];
767 for ch in 0..c {
768 for y in 0..h {
769 for x in 0..w {
770 result[ch * new_h * new_w + (y + top) * new_w + (x + left)] =
771 data[ch * h * w + y * w + x];
772 }
773 }
774 }
775 Tensor::from_vec(result, &[c, new_h, new_w]).unwrap()
776 }
777 _ => input.clone(),
778 }
779 }
780}
781
782pub struct ToTensorImage;
788
789impl ToTensorImage {
790 #[must_use]
792 pub fn new() -> Self {
793 Self
794 }
795}
796
797impl Default for ToTensorImage {
798 fn default() -> Self {
799 Self::new()
800 }
801}
802
803impl Transform for ToTensorImage {
804 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
805 let data: Vec<f32> = input.to_vec().iter().map(|&x| x / 255.0).collect();
806 Tensor::from_vec(data, input.shape()).unwrap()
807 }
808}
809
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that two float slices have equal length and agree element-wise
    /// to within `1e-5`.
    fn assert_close(actual: &[f32], expected: &[f32]) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!((a - e).abs() < 1e-5, "index {i}: got {a}, expected {e}");
        }
    }

    #[test]
    fn test_resize_2d() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();

        let resize = Resize::new(4, 4);
        let output = resize.apply(&input);

        assert_eq!(output.shape(), &[4, 4]);
        // Output pixel (y, x) samples source position (y / 2, x / 2); the
        // last row/column clamp to the source edge.
        assert_close(
            &output.to_vec(),
            &[
                1.0, 1.5, 2.0, 2.0, //
                2.0, 2.5, 3.0, 3.0, //
                3.0, 3.5, 4.0, 4.0, //
                3.0, 3.5, 4.0, 4.0,
            ],
        );
    }

    #[test]
    fn test_resize_same_size_is_identity() {
        let values: Vec<f32> = (0..12).map(|v| v as f32).collect();
        let input = Tensor::from_vec(values.clone(), &[3, 4]).unwrap();

        let output = Resize::new(3, 4).apply(&input);

        assert_eq!(output.shape(), &[3, 4]);
        assert_eq!(output.to_vec(), values);
    }

    #[test]
    fn test_resize_3d() {
        let input = Tensor::from_vec(vec![1.0; 3 * 8 * 8], &[3, 8, 8]).unwrap();

        let resize = Resize::new(4, 4);
        let output = resize.apply(&input);

        assert_eq!(output.shape(), &[3, 4, 4]);
    }

    #[test]
    fn test_resize_4d() {
        let input = Tensor::from_vec(vec![0.25; 2 * 3 * 6 * 6], &[2, 3, 6, 6]).unwrap();

        let output = Resize::square(3).apply(&input);

        assert_eq!(output.shape(), &[2, 3, 3, 3]);
        assert_close(&output.to_vec(), &[0.25; 2 * 3 * 3 * 3]);
    }

    #[test]
    fn test_center_crop() {
        let input = Tensor::from_vec((1..=16).map(|x| x as f32).collect(), &[4, 4]).unwrap();

        let crop = CenterCrop::new(2, 2);
        let output = crop.apply(&input);

        assert_eq!(output.shape(), &[2, 2]);
        assert_eq!(output.to_vec(), vec![6.0, 7.0, 10.0, 11.0]);
    }

    #[test]
    fn test_center_crop_3d_and_oversized() {
        let input = Tensor::from_vec((0..18).map(|v| v as f32).collect(), &[2, 3, 3]).unwrap();

        let output = CenterCrop::square(1).apply(&input);
        assert_eq!(output.shape(), &[2, 1, 1]);
        assert_eq!(output.to_vec(), vec![4.0, 13.0]);

        // A window taller than the image keeps every row.
        let output = CenterCrop::new(10, 2).apply(&input);
        assert_eq!(output.shape(), &[2, 3, 2]);
        assert_eq!(
            output.to_vec(),
            vec![0.0, 1.0, 3.0, 4.0, 6.0, 7.0, 9.0, 10.0, 12.0, 13.0, 15.0, 16.0]
        );
    }

    #[test]
    fn test_random_horizontal_flip() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();

        let flip = RandomHorizontalFlip::with_probability(1.0);
        let output = flip.apply(&input);

        assert_eq!(output.to_vec(), vec![2.0, 1.0, 4.0, 3.0]);
    }

    #[test]
    fn test_random_horizontal_flip_3d() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 1, 2]).unwrap();

        let output = RandomHorizontalFlip::with_probability(1.0).apply(&input);

        assert_eq!(output.shape(), &[2, 1, 2]);
        assert_eq!(output.to_vec(), vec![2.0, 1.0, 4.0, 3.0]);
    }

    #[test]
    fn test_random_vertical_flip() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();

        let flip = RandomVerticalFlip::with_probability(1.0);
        let output = flip.apply(&input);

        assert_eq!(output.to_vec(), vec![3.0, 4.0, 1.0, 2.0]);
    }

    #[test]
    fn test_random_vertical_flip_3d() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2, 1]).unwrap();

        let output = RandomVerticalFlip::with_probability(1.0).apply(&input);

        assert_eq!(output.shape(), &[2, 2, 1]);
        assert_eq!(output.to_vec(), vec![2.0, 1.0, 4.0, 3.0]);
    }

    #[test]
    fn test_flips_with_zero_probability_are_identity() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let horizontal = RandomHorizontalFlip::with_probability(0.0);
        let vertical = RandomVerticalFlip::with_probability(0.0);

        for _ in 0..32 {
            assert_eq!(horizontal.apply(&input).to_vec(), input.to_vec());
            assert_eq!(vertical.apply(&input).to_vec(), input.to_vec());
        }
    }

    #[test]
    fn test_random_rotation_180() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();

        let rotation = RandomRotation::with_angles(vec![180]);
        let output = rotation.apply(&input);

        assert_eq!(output.to_vec(), vec![4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_random_rotation_90_and_270() {
        // [[1, 2, 3],
        //  [4, 5, 6]]
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();

        let clockwise = RandomRotation::with_angles(vec![90]).apply(&input);
        assert_eq!(clockwise.shape(), &[3, 2]);
        assert_eq!(clockwise.to_vec(), vec![4.0, 1.0, 5.0, 2.0, 6.0, 3.0]);

        let counter_clockwise = RandomRotation::with_angles(vec![270]).apply(&input);
        assert_eq!(counter_clockwise.shape(), &[3, 2]);
        assert_eq!(counter_clockwise.to_vec(), vec![3.0, 6.0, 2.0, 5.0, 1.0, 4.0]);
    }

    #[test]
    fn test_random_rotation_invalid_angles_fall_back_to_identity() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();

        let output = RandomRotation::with_angles(vec![45, -90]).apply(&input);

        assert_eq!(output.shape(), &[2, 3]);
        assert_eq!(output.to_vec(), input.to_vec());
    }

    #[test]
    fn test_grayscale() {
        let input = Tensor::from_vec(
            vec![
                1.0, 1.0, 1.0, 1.0, // R
                0.5, 0.5, 0.5, 0.5, // G
                0.0, 0.0, 0.0, 0.0, // B
            ],
            &[3, 2, 2],
        )
        .unwrap();

        let gray = Grayscale::new();
        let output = gray.apply(&input);

        assert_eq!(output.shape(), &[1, 2, 2]);
        let expected = 0.299 + 0.587 * 0.5;
        for val in output.to_vec() {
            assert!((val - expected).abs() < 0.001);
        }
    }

    #[test]
    fn test_grayscale_replicates_channels() {
        let input = Tensor::from_vec(vec![0.5; 3 * 2 * 2], &[3, 2, 2]).unwrap();

        let output = Grayscale::with_channels(3).apply(&input);

        assert_eq!(output.shape(), &[3, 2, 2]);
        assert_close(&output.to_vec(), &[0.5; 12]);
    }

    #[test]
    fn test_image_normalize() {
        let input = Tensor::from_vec(vec![0.5; 3 * 2 * 2], &[3, 2, 2]).unwrap();

        let normalize = ImageNormalize::new(vec![0.5, 0.5, 0.5], vec![0.5, 0.5, 0.5]);
        let output = normalize.apply(&input);

        for val in output.to_vec() {
            assert!(val.abs() < 0.001);
        }
    }

    #[test]
    fn test_image_normalize_4d_per_channel() {
        let input = Tensor::from_vec(vec![1.0; 2 * 2], &[2, 2, 1, 1]).unwrap();

        let normalize = ImageNormalize::new(vec![0.0, 0.5], vec![1.0, 0.25]);
        let output = normalize.apply(&input);

        assert_eq!(output.shape(), &[2, 2, 1, 1]);
        assert_close(&output.to_vec(), &[1.0, 2.0, 1.0, 2.0]);
    }

    #[test]
    fn test_pad() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();

        let pad = Pad::new(1);
        let output = pad.apply(&input);

        assert_eq!(output.shape(), &[4, 4]);
        let data = output.to_vec();
        // Corners are fill.
        assert_eq!(data[0], 0.0);
        assert_eq!(data[3], 0.0);
        assert_eq!(data[12], 0.0);
        assert_eq!(data[15], 0.0);
        // The original image sits in the middle.
        assert_eq!(data[5], 1.0);
        assert_eq!(data[6], 2.0);
        assert_eq!(data[9], 3.0);
        assert_eq!(data[10], 4.0);
    }

    #[test]
    fn test_pad_asymmetric_with_fill() {
        let input = Tensor::from_vec(vec![1.0, 2.0], &[2, 1, 1]).unwrap();

        let output = Pad::asymmetric(1, 0, 0, 1).with_fill(-1.0).apply(&input);

        assert_eq!(output.shape(), &[2, 2, 2]);
        assert_eq!(output.to_vec(), vec![-1.0, 1.0, -1.0, -1.0, -1.0, 2.0, -1.0, -1.0]);
    }

    #[test]
    fn test_to_tensor_image() {
        let input = Tensor::from_vec(vec![0.0, 127.5, 255.0], &[3]).unwrap();

        let transform = ToTensorImage::new();
        let output = transform.apply(&input);

        let data = output.to_vec();
        assert!((data[0] - 0.0).abs() < 0.001);
        assert!((data[1] - 0.5).abs() < 0.001);
        assert!((data[2] - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_color_jitter() {
        let input = Tensor::from_vec(vec![0.5; 3 * 4 * 4], &[3, 4, 4]).unwrap();

        let jitter = ColorJitter::new(0.1, 0.1, 0.1);
        let output = jitter.apply(&input);

        assert_eq!(output.shape(), &[3, 4, 4]);
        for val in output.to_vec() {
            assert!((0.0..=1.0).contains(&val));
        }
    }

    #[test]
    fn test_color_jitter_zero_amounts_is_identity() {
        let input = Tensor::from_vec((0..12).map(|v| v as f32 / 12.0).collect(), &[3, 2, 2]).unwrap();

        let output = ColorJitter::new(0.0, 0.0, 0.0).apply(&input);

        assert_eq!(output.to_vec(), input.to_vec());
    }

    #[test]
    fn test_color_jitter_large_amounts_stay_in_range() {
        let input = Tensor::from_vec(vec![0.5; 3 * 2 * 2], &[3, 2, 2]).unwrap();
        let jitter = ColorJitter::new(3.0, 3.0, 3.0);

        for _ in 0..16 {
            let output = jitter.apply(&input);
            assert_eq!(output.shape(), &[3, 2, 2]);
            for val in output.to_vec() {
                assert!((0.0..=1.0).contains(&val));
            }
        }
    }
}