1use axonml_data::Transform;
9use axonml_tensor::Tensor;
10use rand::Rng;
11
/// Resizes the trailing `[H, W]` axes of an image tensor to a fixed `height x width`
/// using bilinear interpolation (see the `Transform` impl).
pub struct Resize {
    /// Target number of output rows.
    height: usize,
    /// Target number of output columns.
    width: usize,
}

impl Resize {
    /// Creates a resize to `height` rows by `width` columns.
    #[must_use]
    pub fn new(height: usize, width: usize) -> Self {
        Self { height, width }
    }

    /// Creates a resize to a `size x size` square.
    #[must_use]
    pub fn square(size: usize) -> Self {
        Self {
            height: size,
            width: size,
        }
    }
}
33
34impl Transform for Resize {
35 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
36 let shape = input.shape();
37
38 match shape.len() {
40 2 => resize_2d(input, self.height, self.width),
41 3 => resize_3d(input, self.height, self.width),
42 4 => resize_4d(input, self.height, self.width),
43 _ => input.clone(),
44 }
45 }
46}
47
48fn resize_2d(input: &Tensor<f32>, new_h: usize, new_w: usize) -> Tensor<f32> {
50 let shape = input.shape();
51 let (old_h, old_w) = (shape[0], shape[1]);
52 let data = input.to_vec();
53
54 let mut result = vec![0.0; new_h * new_w];
55
56 let scale_h = old_h as f32 / new_h as f32;
57 let scale_w = old_w as f32 / new_w as f32;
58
59 for y in 0..new_h {
60 for x in 0..new_w {
61 let src_y = y as f32 * scale_h;
62 let src_x = x as f32 * scale_w;
63
64 let y0 = (src_y.floor() as usize).min(old_h - 1);
65 let y1 = (y0 + 1).min(old_h - 1);
66 let x0 = (src_x.floor() as usize).min(old_w - 1);
67 let x1 = (x0 + 1).min(old_w - 1);
68
69 let dy = src_y - y0 as f32;
70 let dx = src_x - x0 as f32;
71
72 let v00 = data[y0 * old_w + x0];
73 let v01 = data[y0 * old_w + x1];
74 let v10 = data[y1 * old_w + x0];
75 let v11 = data[y1 * old_w + x1];
76
77 let value = v00 * (1.0 - dx) * (1.0 - dy)
78 + v01 * dx * (1.0 - dy)
79 + v10 * (1.0 - dx) * dy
80 + v11 * dx * dy;
81
82 result[y * new_w + x] = value;
83 }
84 }
85
86 Tensor::from_vec(result, &[new_h, new_w]).unwrap()
87}
88
89fn resize_3d(input: &Tensor<f32>, new_h: usize, new_w: usize) -> Tensor<f32> {
91 let shape = input.shape();
92 let (channels, old_h, old_w) = (shape[0], shape[1], shape[2]);
93 let data = input.to_vec();
94
95 let mut result = vec![0.0; channels * new_h * new_w];
96
97 let scale_h = old_h as f32 / new_h as f32;
98 let scale_w = old_w as f32 / new_w as f32;
99
100 for c in 0..channels {
101 for y in 0..new_h {
102 for x in 0..new_w {
103 let src_y = y as f32 * scale_h;
104 let src_x = x as f32 * scale_w;
105
106 let y0 = (src_y.floor() as usize).min(old_h - 1);
107 let y1 = (y0 + 1).min(old_h - 1);
108 let x0 = (src_x.floor() as usize).min(old_w - 1);
109 let x1 = (x0 + 1).min(old_w - 1);
110
111 let dy = src_y - y0 as f32;
112 let dx = src_x - x0 as f32;
113
114 let base = c * old_h * old_w;
115 let v00 = data[base + y0 * old_w + x0];
116 let v01 = data[base + y0 * old_w + x1];
117 let v10 = data[base + y1 * old_w + x0];
118 let v11 = data[base + y1 * old_w + x1];
119
120 let value = v00 * (1.0 - dx) * (1.0 - dy)
121 + v01 * dx * (1.0 - dy)
122 + v10 * (1.0 - dx) * dy
123 + v11 * dx * dy;
124
125 result[c * new_h * new_w + y * new_w + x] = value;
126 }
127 }
128 }
129
130 Tensor::from_vec(result, &[channels, new_h, new_w]).unwrap()
131}
132
133fn resize_4d(input: &Tensor<f32>, new_h: usize, new_w: usize) -> Tensor<f32> {
135 let shape = input.shape();
136 let (batch, channels, old_h, old_w) = (shape[0], shape[1], shape[2], shape[3]);
137 let data = input.to_vec();
138
139 let mut result = vec![0.0; batch * channels * new_h * new_w];
140
141 let scale_h = old_h as f32 / new_h as f32;
142 let scale_w = old_w as f32 / new_w as f32;
143
144 for n in 0..batch {
145 for c in 0..channels {
146 for y in 0..new_h {
147 for x in 0..new_w {
148 let src_y = y as f32 * scale_h;
149 let src_x = x as f32 * scale_w;
150
151 let y0 = (src_y.floor() as usize).min(old_h - 1);
152 let y1 = (y0 + 1).min(old_h - 1);
153 let x0 = (src_x.floor() as usize).min(old_w - 1);
154 let x1 = (x0 + 1).min(old_w - 1);
155
156 let dy = src_y - y0 as f32;
157 let dx = src_x - x0 as f32;
158
159 let base = n * channels * old_h * old_w + c * old_h * old_w;
160 let v00 = data[base + y0 * old_w + x0];
161 let v01 = data[base + y0 * old_w + x1];
162 let v10 = data[base + y1 * old_w + x0];
163 let v11 = data[base + y1 * old_w + x1];
164
165 let value = v00 * (1.0 - dx) * (1.0 - dy)
166 + v01 * dx * (1.0 - dy)
167 + v10 * (1.0 - dx) * dy
168 + v11 * dx * dy;
169
170 let out_idx = n * channels * new_h * new_w + c * new_h * new_w + y * new_w + x;
171 result[out_idx] = value;
172 }
173 }
174 }
175 }
176
177 Tensor::from_vec(result, &[batch, channels, new_h, new_w]).unwrap()
178}
179
/// Crops the central `height x width` window out of a 2-D (`[H, W]`) or 3-D (`[C, H, W]`) tensor.
pub struct CenterCrop {
    /// Requested window height (capped at the image height when applied).
    height: usize,
    /// Requested window width (capped at the image width when applied).
    width: usize,
}

impl CenterCrop {
    /// Creates a crop of `height` rows by `width` columns.
    #[must_use]
    pub fn new(height: usize, width: usize) -> Self {
        Self { height, width }
    }

    /// Creates a square `size x size` crop.
    #[must_use]
    pub fn square(size: usize) -> Self {
        Self {
            height: size,
            width: size,
        }
    }
}
201
202impl Transform for CenterCrop {
203 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
204 let shape = input.shape();
205 let data = input.to_vec();
206
207 match shape.len() {
208 2 => {
209 let (h, w) = (shape[0], shape[1]);
210 let start_h = (h.saturating_sub(self.height)) / 2;
211 let start_w = (w.saturating_sub(self.width)) / 2;
212 let crop_h = self.height.min(h);
213 let crop_w = self.width.min(w);
214
215 let mut result = Vec::with_capacity(crop_h * crop_w);
216 for y in start_h..start_h + crop_h {
217 for x in start_w..start_w + crop_w {
218 result.push(data[y * w + x]);
219 }
220 }
221 Tensor::from_vec(result, &[crop_h, crop_w]).unwrap()
222 }
223 3 => {
224 let (c, h, w) = (shape[0], shape[1], shape[2]);
225 let start_h = (h.saturating_sub(self.height)) / 2;
226 let start_w = (w.saturating_sub(self.width)) / 2;
227 let crop_h = self.height.min(h);
228 let crop_w = self.width.min(w);
229
230 let mut result = Vec::with_capacity(c * crop_h * crop_w);
231 for ch in 0..c {
232 for y in start_h..start_h + crop_h {
233 for x in start_w..start_w + crop_w {
234 result.push(data[ch * h * w + y * w + x]);
235 }
236 }
237 }
238 Tensor::from_vec(result, &[c, crop_h, crop_w]).unwrap()
239 }
240 _ => input.clone(),
241 }
242 }
243}
244
/// Mirrors an image left-to-right with a configurable probability.
pub struct RandomHorizontalFlip {
    /// Chance that a call to `apply` flips its input; the constructors clamp it into `[0, 1]`.
    probability: f32,
}

impl RandomHorizontalFlip {
    /// Creates a flip that fires half of the time.
    #[must_use]
    pub fn new() -> Self {
        Self::with_probability(0.5)
    }

    /// Creates a flip that fires with `probability`, clamped into `[0, 1]`.
    #[must_use]
    pub fn with_probability(probability: f32) -> Self {
        let probability = probability.clamp(0.0, 1.0);
        Self { probability }
    }
}

impl Default for RandomHorizontalFlip {
    /// Equivalent to [`RandomHorizontalFlip::new`].
    fn default() -> Self {
        Self::new()
    }
}
273
274impl Transform for RandomHorizontalFlip {
275 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
276 let mut rng = rand::thread_rng();
277 if rng.gen::<f32>() > self.probability {
278 return input.clone();
279 }
280
281 let shape = input.shape();
282 let data = input.to_vec();
283
284 match shape.len() {
285 2 => {
286 let (h, w) = (shape[0], shape[1]);
287 let mut result = vec![0.0; h * w];
288 for y in 0..h {
289 for x in 0..w {
290 result[y * w + x] = data[y * w + (w - 1 - x)];
291 }
292 }
293 Tensor::from_vec(result, shape).unwrap()
294 }
295 3 => {
296 let (c, h, w) = (shape[0], shape[1], shape[2]);
297 let mut result = vec![0.0; c * h * w];
298 for ch in 0..c {
299 for y in 0..h {
300 for x in 0..w {
301 result[ch * h * w + y * w + x] = data[ch * h * w + y * w + (w - 1 - x)];
302 }
303 }
304 }
305 Tensor::from_vec(result, shape).unwrap()
306 }
307 _ => input.clone(),
308 }
309 }
310}
311
/// Mirrors an image top-to-bottom with a configurable probability.
pub struct RandomVerticalFlip {
    /// Chance that a call to `apply` flips its input; the constructors clamp it into `[0, 1]`.
    probability: f32,
}

impl RandomVerticalFlip {
    /// Creates a flip that fires half of the time.
    #[must_use]
    pub fn new() -> Self {
        Self::with_probability(0.5)
    }

    /// Creates a flip that fires with `probability`, clamped into `[0, 1]`.
    #[must_use]
    pub fn with_probability(probability: f32) -> Self {
        let probability = probability.clamp(0.0, 1.0);
        Self { probability }
    }
}

impl Default for RandomVerticalFlip {
    /// Equivalent to [`RandomVerticalFlip::new`].
    fn default() -> Self {
        Self::new()
    }
}
340
341impl Transform for RandomVerticalFlip {
342 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
343 let mut rng = rand::thread_rng();
344 if rng.gen::<f32>() > self.probability {
345 return input.clone();
346 }
347
348 let shape = input.shape();
349 let data = input.to_vec();
350
351 match shape.len() {
352 2 => {
353 let (h, w) = (shape[0], shape[1]);
354 let mut result = vec![0.0; h * w];
355 for y in 0..h {
356 for x in 0..w {
357 result[y * w + x] = data[(h - 1 - y) * w + x];
358 }
359 }
360 Tensor::from_vec(result, shape).unwrap()
361 }
362 3 => {
363 let (c, h, w) = (shape[0], shape[1], shape[2]);
364 let mut result = vec![0.0; c * h * w];
365 for ch in 0..c {
366 for y in 0..h {
367 for x in 0..w {
368 result[ch * h * w + y * w + x] = data[ch * h * w + (h - 1 - y) * w + x];
369 }
370 }
371 }
372 Tensor::from_vec(result, shape).unwrap()
373 }
374 _ => input.clone(),
375 }
376 }
377}
378
/// Rotates a 2-D image by a multiple of 90 degrees chosen at random on every call.
pub struct RandomRotation {
    /// Candidate angles in degrees. The constructors keep this non-empty and restricted to
    /// {0, 90, 180, 270}.
    angles: Vec<i32>,
}

impl RandomRotation {
    /// Chooses uniformly among 0, 90, 180 and 270 degrees.
    #[must_use]
    pub fn new() -> Self {
        Self {
            angles: vec![0, 90, 180, 270],
        }
    }

    /// Chooses uniformly among the entries of `angles`; repeating an angle makes it more likely.
    ///
    /// Entries other than 0, 90, 180 and 270 are discarded. If none remain, the transform always
    /// uses 0 (identity).
    #[must_use]
    pub fn with_angles(mut angles: Vec<i32>) -> Self {
        angles.retain(|angle| matches!(*angle, 0 | 90 | 180 | 270));
        if angles.is_empty() {
            angles.push(0);
        }
        Self { angles }
    }
}

impl Default for RandomRotation {
    /// Equivalent to [`RandomRotation::new`].
    fn default() -> Self {
        Self::new()
    }
}
414
415impl Transform for RandomRotation {
416 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
417 let mut rng = rand::thread_rng();
418 let angle = self.angles[rng.gen_range(0..self.angles.len())];
419
420 if angle == 0 {
421 return input.clone();
422 }
423
424 let shape = input.shape();
425 let data = input.to_vec();
426
427 if shape.len() != 2 {
429 return input.clone();
430 }
431
432 let (h, w) = (shape[0], shape[1]);
433
434 match angle {
435 90 => {
436 let mut result = vec![0.0; h * w];
438 for y in 0..h {
439 for x in 0..w {
440 result[x * h + (h - 1 - y)] = data[y * w + x];
441 }
442 }
443 Tensor::from_vec(result, &[w, h]).unwrap()
444 }
445 180 => {
446 let mut result = vec![0.0; h * w];
448 for y in 0..h {
449 for x in 0..w {
450 result[(h - 1 - y) * w + (w - 1 - x)] = data[y * w + x];
451 }
452 }
453 Tensor::from_vec(result, &[h, w]).unwrap()
454 }
455 270 => {
456 let mut result = vec![0.0; h * w];
458 for y in 0..h {
459 for x in 0..w {
460 result[(w - 1 - x) * h + y] = data[y * w + x];
461 }
462 }
463 Tensor::from_vec(result, &[w, h]).unwrap()
464 }
465 _ => input.clone(),
466 }
467 }
468}
469
/// Random brightness, contrast and saturation perturbation; results are clamped to `[0, 1]`.
pub struct ColorJitter {
    /// Maximum relative brightness change (non-negative).
    brightness: f32,
    /// Maximum relative contrast change (non-negative).
    contrast: f32,
    /// Maximum relative saturation change (non-negative; only used for `[3, H, W]` input).
    saturation: f32,
}

impl ColorJitter {
    /// Creates a jitter with the given strengths; a negative strength is taken by magnitude.
    #[must_use]
    pub fn new(brightness: f32, contrast: f32, saturation: f32) -> Self {
        let [brightness, contrast, saturation] = [brightness, contrast, saturation].map(f32::abs);
        Self {
            brightness,
            contrast,
            saturation,
        }
    }
}
491
492impl Transform for ColorJitter {
493 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
494 let mut rng = rand::thread_rng();
495 let mut data = input.to_vec();
496 let shape = input.shape();
497
498 if self.brightness > 0.0 {
500 let factor = 1.0 + rng.gen_range(-self.brightness..self.brightness);
501 for val in &mut data {
502 *val = (*val * factor).clamp(0.0, 1.0);
503 }
504 }
505
506 if self.contrast > 0.0 {
508 let factor = 1.0 + rng.gen_range(-self.contrast..self.contrast);
509 let mean: f32 = data.iter().sum::<f32>() / data.len() as f32;
510 for val in &mut data {
511 *val = ((*val - mean) * factor + mean).clamp(0.0, 1.0);
512 }
513 }
514
515 if self.saturation > 0.0 && shape.len() == 3 && shape[0] == 3 {
517 let factor = 1.0 + rng.gen_range(-self.saturation..self.saturation);
518 let (h, w) = (shape[1], shape[2]);
519
520 for y in 0..h {
521 for x in 0..w {
522 let r = data[0 * h * w + y * w + x];
523 let g = data[h * w + y * w + x];
524 let b = data[2 * h * w + y * w + x];
525
526 let gray = 0.299 * r + 0.587 * g + 0.114 * b;
527
528 data[0 * h * w + y * w + x] = (gray + (r - gray) * factor).clamp(0.0, 1.0);
529 data[h * w + y * w + x] = (gray + (g - gray) * factor).clamp(0.0, 1.0);
530 data[2 * h * w + y * w + x] = (gray + (b - gray) * factor).clamp(0.0, 1.0);
531 }
532 }
533 }
534
535 Tensor::from_vec(data, shape).unwrap()
536 }
537}
538
/// Converts 3-channel RGB images to grayscale (luma).
pub struct Grayscale {
    /// How many identical luma channels the output carries; always at least 1.
    num_output_channels: usize,
}

impl Grayscale {
    /// Produces a single-channel `[1, H, W]` output.
    #[must_use]
    pub fn new() -> Self {
        Self::with_channels(1)
    }

    /// Produces `num_output_channels` copies of the luma plane; 0 is raised to 1.
    #[must_use]
    pub fn with_channels(num_output_channels: usize) -> Self {
        Self {
            num_output_channels: usize::max(num_output_channels, 1),
        }
    }
}

impl Default for Grayscale {
    /// Equivalent to [`Grayscale::new`].
    fn default() -> Self {
        Self::new()
    }
}
569
570impl Transform for Grayscale {
571 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
572 let shape = input.shape();
573
574 if shape.len() != 3 || shape[0] != 3 {
576 return input.clone();
577 }
578
579 let (_, h, w) = (shape[0], shape[1], shape[2]);
580 let data = input.to_vec();
581
582 let mut gray = Vec::with_capacity(h * w);
583 for y in 0..h {
584 for x in 0..w {
585 let r = data[0 * h * w + y * w + x];
586 let g = data[h * w + y * w + x];
587 let b = data[2 * h * w + y * w + x];
588 gray.push(0.299 * r + 0.587 * g + 0.114 * b);
589 }
590 }
591
592 if self.num_output_channels == 1 {
593 Tensor::from_vec(gray, &[1, h, w]).unwrap()
594 } else {
595 let mut result = Vec::with_capacity(self.num_output_channels * h * w);
597 for _ in 0..self.num_output_channels {
598 result.extend(&gray);
599 }
600 Tensor::from_vec(result, &[self.num_output_channels, h, w]).unwrap()
601 }
602 }
603}
604
/// Per-channel standardization `(x - mean[c]) / std[c]` for `[C, H, W]` and `[N, C, H, W]` tensors.
pub struct ImageNormalize {
    /// Per-channel means, indexed by channel.
    mean: Vec<f32>,
    /// Per-channel standard deviations, indexed by channel.
    std: Vec<f32>,
}

impl ImageNormalize {
    /// Normalizes with caller-supplied per-channel `mean` and `std`.
    ///
    /// Lengths are not checked against the image's channel count. When the transform is
    /// applied, channels without an entry fall back to mean 0 and std 1. A zero `std` entry
    /// is not rejected and will divide by zero.
    #[must_use]
    pub fn new(mean: Vec<f32>, std: Vec<f32>) -> Self {
        Self { mean, std }
    }

    /// Three-channel preset named for ImageNet.
    #[must_use]
    pub fn imagenet() -> Self {
        Self {
            mean: vec![0.485, 0.456, 0.406],
            std: vec![0.229, 0.224, 0.225],
        }
    }

    /// Single-channel preset named for MNIST.
    #[must_use]
    pub fn mnist() -> Self {
        Self {
            mean: vec![0.1307],
            std: vec![0.3081],
        }
    }

    /// Three-channel preset named for CIFAR-10.
    #[must_use]
    pub fn cifar10() -> Self {
        Self {
            mean: vec![0.4914, 0.4822, 0.4465],
            std: vec![0.2470, 0.2435, 0.2616],
        }
    }
}
636
637impl Transform for ImageNormalize {
638 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
639 let shape = input.shape();
640 let mut data = input.to_vec();
641
642 match shape.len() {
643 3 => {
644 let (c, h, w) = (shape[0], shape[1], shape[2]);
645 for ch in 0..c {
646 let mean = self.mean.get(ch).copied().unwrap_or(0.0);
647 let std = self.std.get(ch).copied().unwrap_or(1.0);
648 for y in 0..h {
649 for x in 0..w {
650 let idx = ch * h * w + y * w + x;
651 data[idx] = (data[idx] - mean) / std;
652 }
653 }
654 }
655 }
656 4 => {
657 let (n, c, h, w) = (shape[0], shape[1], shape[2], shape[3]);
658 for batch in 0..n {
659 for ch in 0..c {
660 let mean = self.mean.get(ch).copied().unwrap_or(0.0);
661 let std = self.std.get(ch).copied().unwrap_or(1.0);
662 for y in 0..h {
663 for x in 0..w {
664 let idx = batch * c * h * w + ch * h * w + y * w + x;
665 data[idx] = (data[idx] - mean) / std;
666 }
667 }
668 }
669 }
670 }
671 _ => {}
672 }
673
674 Tensor::from_vec(data, shape).unwrap()
675 }
676}
677
/// Constant-value padding around the height/width axes of a 2-D (`[H, W]`) or 3-D
/// (`[C, H, W]`) tensor.
pub struct Pad {
    /// Border sizes as `(left, right, top, bottom)`.
    padding: (usize, usize, usize, usize),
    /// Value written into the border cells.
    fill_value: f32,
}

impl Pad {
    /// Pads every side by `padding`, filling with `0.0`.
    #[must_use]
    pub fn new(padding: usize) -> Self {
        Self::asymmetric(padding, padding, padding, padding)
    }

    /// Pads each side by its own amount, filling with `0.0`.
    #[must_use]
    pub fn asymmetric(left: usize, right: usize, top: usize, bottom: usize) -> Self {
        Self {
            padding: (left, right, top, bottom),
            fill_value: 0.0,
        }
    }

    /// Returns this padding with the border filled by `value` instead.
    #[must_use]
    pub fn with_fill(self, value: f32) -> Self {
        Self {
            fill_value: value,
            ..self
        }
    }
}
711
712impl Transform for Pad {
713 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
714 let shape = input.shape();
715 let data = input.to_vec();
716 let (left, right, top, bottom) = self.padding;
717
718 match shape.len() {
719 2 => {
720 let (h, w) = (shape[0], shape[1]);
721 let new_h = h + top + bottom;
722 let new_w = w + left + right;
723
724 let mut result = vec![self.fill_value; new_h * new_w];
725 for y in 0..h {
726 for x in 0..w {
727 result[(y + top) * new_w + (x + left)] = data[y * w + x];
728 }
729 }
730 Tensor::from_vec(result, &[new_h, new_w]).unwrap()
731 }
732 3 => {
733 let (c, h, w) = (shape[0], shape[1], shape[2]);
734 let new_h = h + top + bottom;
735 let new_w = w + left + right;
736
737 let mut result = vec![self.fill_value; c * new_h * new_w];
738 for ch in 0..c {
739 for y in 0..h {
740 for x in 0..w {
741 result[ch * new_h * new_w + (y + top) * new_w + (x + left)] =
742 data[ch * h * w + y * w + x];
743 }
744 }
745 }
746 Tensor::from_vec(result, &[c, new_h, new_w]).unwrap()
747 }
748 _ => input.clone(),
749 }
750 }
751}
752
/// Rescales raw 0-255 pixel intensities into the unit range by dividing every value by 255.
pub struct ToTensorImage;

impl ToTensorImage {
    /// Creates the (stateless) transform.
    #[must_use]
    pub fn new() -> Self {
        ToTensorImage
    }
}

impl Default for ToTensorImage {
    /// Equivalent to [`ToTensorImage::new`].
    fn default() -> Self {
        Self::new()
    }
}
772
773impl Transform for ToTensorImage {
774 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
775 let data: Vec<f32> = input.to_vec().iter().map(|&x| x / 255.0).collect();
776 Tensor::from_vec(data, input.shape()).unwrap()
777 }
778}
779
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts two float slices agree element-wise within a small tolerance.
    fn assert_close(actual: &[f32], expected: &[f32]) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!((a - e).abs() < 1e-5, "index {i}: got {a}, expected {e}");
        }
    }

    #[test]
    fn test_resize_2d() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();

        let resize = Resize::new(4, 4);
        let output = resize.apply(&input);

        assert_eq!(output.shape(), &[4, 4]);
    }

    #[test]
    fn test_resize_2d_bilinear_values() {
        // Destination (y, x) samples source (y / 2, x / 2); samples past the last
        // row/column clamp to the edge.
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let output = Resize::new(4, 4).apply(&input);
        assert_close(
            &output.to_vec(),
            &[
                1.0, 1.5, 2.0, 2.0, //
                2.0, 2.5, 3.0, 3.0, //
                3.0, 3.5, 4.0, 4.0, //
                3.0, 3.5, 4.0, 4.0,
            ],
        );
    }

    #[test]
    fn test_resize_same_size_is_identity() {
        let values: Vec<f32> = (0..9).map(|x| x as f32).collect();
        let input = Tensor::from_vec(values.clone(), &[3, 3]).unwrap();
        let output = Resize::square(3).apply(&input);
        assert_close(&output.to_vec(), &values);
    }

    #[test]
    fn test_resize_3d() {
        let input = Tensor::from_vec(vec![1.0; 3 * 8 * 8], &[3, 8, 8]).unwrap();

        let resize = Resize::new(4, 4);
        let output = resize.apply(&input);

        assert_eq!(output.shape(), &[3, 4, 4]);
    }

    #[test]
    fn test_resize_4d() {
        let input = Tensor::from_vec(vec![1.0; 2 * 3 * 8 * 8], &[2, 3, 8, 8]).unwrap();
        let output = Resize::new(4, 4).apply(&input);
        assert_eq!(output.shape(), &[2, 3, 4, 4]);
        assert_close(&output.to_vec(), &[1.0; 2 * 3 * 4 * 4]);
    }

    #[test]
    fn test_resize_unsupported_rank_passes_through() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let output = Resize::new(4, 4).apply(&input);
        assert_eq!(output.shape(), &[3]);
        assert_eq!(output.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_center_crop() {
        let input = Tensor::from_vec((1..=16).map(|x| x as f32).collect(), &[4, 4]).unwrap();

        let crop = CenterCrop::new(2, 2);
        let output = crop.apply(&input);

        assert_eq!(output.shape(), &[2, 2]);
        assert_eq!(output.to_vec(), vec![6.0, 7.0, 10.0, 11.0]);
    }

    #[test]
    fn test_center_crop_3d() {
        let input = Tensor::from_vec((1..=32).map(|x| x as f32).collect(), &[2, 4, 4]).unwrap();
        let output = CenterCrop::square(2).apply(&input);
        assert_eq!(output.shape(), &[2, 2, 2]);
        assert_eq!(
            output.to_vec(),
            vec![6.0, 7.0, 10.0, 11.0, 22.0, 23.0, 26.0, 27.0]
        );
    }

    #[test]
    fn test_center_crop_larger_than_input_keeps_whole_axis() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let output = CenterCrop::square(4).apply(&input);
        assert_eq!(output.shape(), &[2, 2]);
        assert_eq!(output.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_random_horizontal_flip() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();

        let flip = RandomHorizontalFlip::with_probability(1.0);
        let output = flip.apply(&input);

        assert_eq!(output.to_vec(), vec![2.0, 1.0, 4.0, 3.0]);
    }

    #[test]
    fn test_random_horizontal_flip_3d() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 1, 3]).unwrap();
        let output = RandomHorizontalFlip::with_probability(1.0).apply(&input);
        assert_eq!(output.shape(), &[2, 1, 3]);
        assert_eq!(output.to_vec(), vec![3.0, 2.0, 1.0, 6.0, 5.0, 4.0]);
    }

    #[test]
    fn test_random_vertical_flip() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();

        let flip = RandomVerticalFlip::with_probability(1.0);
        let output = flip.apply(&input);

        assert_eq!(output.to_vec(), vec![3.0, 4.0, 1.0, 2.0]);
    }

    #[test]
    fn test_random_vertical_flip_3d() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2, 1]).unwrap();
        let output = RandomVerticalFlip::with_probability(1.0).apply(&input);
        assert_eq!(output.shape(), &[2, 2, 1]);
        assert_eq!(output.to_vec(), vec![2.0, 1.0, 4.0, 3.0]);
    }

    #[test]
    fn test_flips_with_zero_probability_are_identity() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let horizontal = RandomHorizontalFlip::with_probability(0.0);
        let vertical = RandomVerticalFlip::with_probability(0.0);
        for _ in 0..16 {
            assert_eq!(horizontal.apply(&input).to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
            assert_eq!(vertical.apply(&input).to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
        }
    }

    #[test]
    fn test_random_rotation_180() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();

        let rotation = RandomRotation::with_angles(vec![180]);
        let output = rotation.apply(&input);

        assert_eq!(output.to_vec(), vec![4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_random_rotation_quarter_turns_transpose_shape() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();

        let quarter = RandomRotation::with_angles(vec![90]).apply(&input);
        assert_eq!(quarter.shape(), &[3, 2]);
        assert_eq!(quarter.to_vec(), vec![4.0, 1.0, 5.0, 2.0, 6.0, 3.0]);

        let three_quarter = RandomRotation::with_angles(vec![270]).apply(&input);
        assert_eq!(three_quarter.shape(), &[3, 2]);
        assert_eq!(three_quarter.to_vec(), vec![3.0, 6.0, 2.0, 5.0, 1.0, 4.0]);
    }

    #[test]
    fn test_random_rotation_passes_through_non_2d_and_invalid_angles() {
        let image = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 2, 2]).unwrap();
        let output = RandomRotation::with_angles(vec![90]).apply(&image);
        assert_eq!(output.shape(), &[1, 2, 2]);
        assert_eq!(output.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);

        let plane = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let output = RandomRotation::with_angles(vec![45]).apply(&plane);
        assert_eq!(output.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_grayscale() {
        let input = Tensor::from_vec(
            vec![
                1.0, 1.0, 1.0, 1.0, // R
                0.5, 0.5, 0.5, 0.5, // G
                0.0, 0.0, 0.0, 0.0, // B
            ],
            &[3, 2, 2],
        )
        .unwrap();

        let gray = Grayscale::new();
        let output = gray.apply(&input);

        assert_eq!(output.shape(), &[1, 2, 2]);
        let expected = 0.299 + 0.587 * 0.5;
        for val in output.to_vec() {
            assert!((val - expected).abs() < 0.001);
        }
    }

    #[test]
    fn test_grayscale_multiple_output_channels() {
        let input = Tensor::from_vec(
            vec![1.0, 1.0, 1.0, 1.0, 0.5, 0.5, 0.5, 0.5, 0.0, 0.0, 0.0, 0.0],
            &[3, 2, 2],
        )
        .unwrap();
        let output = Grayscale::with_channels(3).apply(&input);
        assert_eq!(output.shape(), &[3, 2, 2]);
        let expected = 0.299 + 0.587 * 0.5;
        for val in output.to_vec() {
            assert!((val - expected).abs() < 0.001);
        }
    }

    #[test]
    fn test_grayscale_passes_through_non_rgb() {
        let input = Tensor::from_vec(vec![0.1, 0.2, 0.3, 0.4], &[1, 2, 2]).unwrap();
        let output = Grayscale::new().apply(&input);
        assert_eq!(output.shape(), &[1, 2, 2]);
        assert_eq!(output.to_vec(), vec![0.1, 0.2, 0.3, 0.4]);
    }

    #[test]
    fn test_image_normalize() {
        let input = Tensor::from_vec(vec![0.5; 3 * 2 * 2], &[3, 2, 2]).unwrap();

        let normalize = ImageNormalize::new(vec![0.5, 0.5, 0.5], vec![0.5, 0.5, 0.5]);
        let output = normalize.apply(&input);

        for val in output.to_vec() {
            assert!((val - 0.0).abs() < 0.001);
        }
    }

    #[test]
    fn test_image_normalize_4d() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2, 1, 1]).unwrap();
        let output = ImageNormalize::new(vec![1.0, 2.0], vec![1.0, 2.0]).apply(&input);
        assert_eq!(output.shape(), &[2, 2, 1, 1]);
        assert_close(&output.to_vec(), &[0.0, 0.0, 2.0, 1.0]);
    }

    #[test]
    fn test_pad() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();

        let pad = Pad::new(1);
        let output = pad.apply(&input);

        assert_eq!(output.shape(), &[4, 4]);
        let data = output.to_vec();
        // Corners are fill.
        assert_eq!(data[0], 0.0);
        assert_eq!(data[3], 0.0);
        assert_eq!(data[12], 0.0);
        assert_eq!(data[15], 0.0);
        // The original 2x2 sits in the centre.
        assert_eq!(data[5], 1.0);
        assert_eq!(data[6], 2.0);
        assert_eq!(data[9], 3.0);
        assert_eq!(data[10], 4.0);
    }

    #[test]
    fn test_pad_3d() {
        let input = Tensor::from_vec(vec![1.0, 2.0], &[2, 1, 1]).unwrap();
        let output = Pad::new(1).apply(&input);
        assert_eq!(output.shape(), &[2, 3, 3]);
        assert_eq!(
            output.to_vec(),
            vec![
                0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, //
                0.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 0.0,
            ]
        );
    }

    #[test]
    fn test_pad_asymmetric_with_fill() {
        let input = Tensor::from_vec(vec![5.0], &[1, 1]).unwrap();
        // left = 1, right = 0, top = 0, bottom = 1.
        let output = Pad::asymmetric(1, 0, 0, 1).with_fill(-1.0).apply(&input);
        assert_eq!(output.shape(), &[2, 2]);
        assert_eq!(output.to_vec(), vec![-1.0, 5.0, -1.0, -1.0]);
    }

    #[test]
    fn test_to_tensor_image() {
        let input = Tensor::from_vec(vec![0.0, 127.5, 255.0], &[3]).unwrap();

        let transform = ToTensorImage::new();
        let output = transform.apply(&input);

        let data = output.to_vec();
        assert!((data[0] - 0.0).abs() < 0.001);
        assert!((data[1] - 0.5).abs() < 0.001);
        assert!((data[2] - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_color_jitter() {
        let input = Tensor::from_vec(vec![0.5; 3 * 4 * 4], &[3, 4, 4]).unwrap();

        let jitter = ColorJitter::new(0.1, 0.1, 0.1);
        let output = jitter.apply(&input);

        assert_eq!(output.shape(), &[3, 4, 4]);
        for val in output.to_vec() {
            assert!((0.0..=1.0).contains(&val));
        }
    }

    #[test]
    fn test_color_jitter_zero_strength_is_identity() {
        let values: Vec<f32> = (0..12).map(|x| x as f32 / 12.0).collect();
        let input = Tensor::from_vec(values.clone(), &[3, 2, 2]).unwrap();
        let output = ColorJitter::new(0.0, 0.0, 0.0).apply(&input);
        assert_eq!(output.to_vec(), values);
    }
}