1use crate::error::{SslError, SslResult};
16use crate::handle::LcgRng;
17
/// A single image augmentation primitive used by RandAugment and AutoAugment.
///
/// Each variant is interpreted by [`apply_aug_op`], which maps a magnitude in
/// `[0, 30]` to an operation-specific strength. The enum carries no data, so it
/// is `Copy` and can be compared or hashed (e.g. to deduplicate an op pool).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AugOp {
    /// Return the image unchanged.
    Identity,
    /// Stretch each channel linearly so its minimum maps to 0 and its maximum to 1.
    AutoContrast,
    /// Histogram-equalize each channel using 256 bins.
    Equalize,
    /// Rotate about the image centre, filling uncovered pixels with the fill value.
    Rotate,
    /// Invert pixels at or above a threshold (`p -> 1 - p`).
    Solarize,
    /// Blend RGB pixels towards their luma (no-op for images without 3 channels).
    Color,
    /// Keep only the most significant bits of each pixel's 8-bit quantization.
    Posterize,
    /// Blend each channel towards its mean value.
    Contrast,
    /// Scale pixel intensities by a factor.
    Brightness,
    /// Blend each channel with a 3x3 box-blurred copy.
    Sharpness,
    /// Shear along the x axis (the horizontal offset grows with the row index).
    ShearX,
    /// Shear along the y axis (the vertical offset grows with the column index).
    ShearY,
    /// Shift along the x axis.
    TranslateX,
    /// Shift along the y axis.
    TranslateY,
}
52
53pub fn all_aug_ops() -> Vec<AugOp> {
55 vec![
56 AugOp::Identity,
57 AugOp::AutoContrast,
58 AugOp::Equalize,
59 AugOp::Rotate,
60 AugOp::Solarize,
61 AugOp::Color,
62 AugOp::Posterize,
63 AugOp::Contrast,
64 AugOp::Brightness,
65 AugOp::Sharpness,
66 AugOp::ShearX,
67 AugOp::ShearY,
68 AugOp::TranslateX,
69 AugOp::TranslateY,
70 ]
71}
72
73#[derive(Debug, Clone)]
77pub struct RandAugmentConfig {
78 pub n_ops: usize,
80 pub magnitude: f32,
82 pub fill_value: f32,
84 pub ops: Vec<AugOp>,
86}
87
88impl Default for RandAugmentConfig {
89 fn default() -> Self {
90 Self {
91 n_ops: 2,
92 magnitude: 9.0,
93 fill_value: 0.5,
94 ops: all_aug_ops(),
95 }
96 }
97}
98
99impl RandAugmentConfig {
100 pub fn validate(&self) -> SslResult<()> {
102 if !(self.magnitude.is_finite() && (0.0..=30.0).contains(&self.magnitude)) {
103 return Err(SslError::InvalidParameter {
104 name: "magnitude".into(),
105 reason: format!("must be in [0, 30] and finite, got {}", self.magnitude),
106 });
107 }
108 if !(self.fill_value.is_finite() && (0.0..=1.0).contains(&self.fill_value)) {
109 return Err(SslError::InvalidParameter {
110 name: "fill_value".into(),
111 reason: format!("must be in [0, 1] and finite, got {}", self.fill_value),
112 });
113 }
114 if self.ops.is_empty() {
115 return Err(SslError::InvalidParameter {
116 name: "ops".into(),
117 reason: "must contain at least one operation".into(),
118 });
119 }
120 Ok(())
121 }
122}
123
/// One AutoAugment sub-policy: two `(operation, probability, magnitude_level)` steps.
///
/// `auto_augment` applies each step when a draw from `rng.next_f32()` is below the
/// step's probability, using magnitude `level * 3` clamped to `[0, 30]`.
pub type SubPolicy = ((AugOp, f32, usize), (AugOp, f32, usize));
132
133#[derive(Debug, Clone)]
135pub enum AutoAugPolicy {
136 ImageNet,
138 Cifar10,
140 Custom(Vec<SubPolicy>),
142}
143
144#[derive(Debug, Clone)]
146pub struct AutoAugmentConfig {
147 pub policy: AutoAugPolicy,
149 pub fill_value: f32,
151}
152
153impl Default for AutoAugmentConfig {
154 fn default() -> Self {
155 Self {
156 policy: AutoAugPolicy::ImageNet,
157 fill_value: 0.5,
158 }
159 }
160}
161
/// Flat index of element `(c, y, x)` in a channel-major (CHW) buffer whose
/// planes are `height` rows of `width` values.
#[inline]
fn chw_idx(c: usize, y: usize, x: usize, height: usize, width: usize) -> usize {
    let plane_offset = c * height * width;
    let row_offset = y * width;
    plane_offset + row_offset + x
}
169
/// Bilinearly interpolates one `height x width` plane at fractional row/column
/// coordinates `(fy, fx)`.
///
/// Returns `fill_value` when the point lies strictly outside
/// `[0, height - 1] x [0, width - 1]`. Assumes non-zero dimensions and
/// `plane.len() >= height * width`.
fn bilinear_sample(
    plane: &[f32],
    height: usize,
    width: usize,
    fy: f32,
    fx: f32,
    fill_value: f32,
) -> f32 {
    // Bounds are tested in this order so `height - 1` / `width - 1` are only
    // evaluated when the cheaper negative checks pass.
    let outside = fy < 0.0 || fx < 0.0 || fy > (height - 1) as f32 || fx > (width - 1) as f32;
    if outside {
        return fill_value;
    }

    let (row0, col0) = (fy.floor() as usize, fx.floor() as usize);
    // Clamp the far corner so samples exactly on the last row/column stay in bounds.
    let row1 = (row0 + 1).min(height - 1);
    let col1 = (col0 + 1).min(width - 1);
    let (ty, tx) = (fy - row0 as f32, fx - col0 as f32);

    let at = |row: usize, col: usize| plane[row * width + col];
    let lerp = |a: f32, b: f32, t: f32| a * (1.0 - t) + b * t;

    let top = lerp(at(row0, col0), at(row0, col1), tx);
    let bottom = lerp(at(row1, col0), at(row1, col1), tx);
    lerp(top, bottom, ty)
}
200
201#[allow(clippy::too_many_arguments)]
212fn warp_affine(
213 pixels: &[f32],
214 channels: usize,
215 height: usize,
216 width: usize,
217 a00: f32, a01: f32, a02: f32, a10: f32, a11: f32, a12: f32, fill_value: f32,
225) -> Vec<f32> {
226 let plane = height * width;
227 let mut out = vec![fill_value; channels * plane];
228 for c in 0..channels {
229 let src_plane = &pixels[c * plane..(c + 1) * plane];
230 let dst_plane = &mut out[c * plane..(c + 1) * plane];
231 for y in 0..height {
232 for x in 0..width {
233 let fx = a00 * x as f32 + a01 * y as f32 + a02;
234 let fy = a10 * x as f32 + a11 * y as f32 + a12;
235 dst_plane[y * width + x] =
236 bilinear_sample(src_plane, height, width, fy, fx, fill_value);
237 }
238 }
239 }
240 out
241}
242
/// Linearly rescales each channel so its minimum becomes 0 and its maximum 1.
///
/// Channels whose value range is below `1e-7` are left unchanged. Results are
/// clamped to `[0, 1]`.
fn op_auto_contrast(pixels: &[f32], channels: usize, height: usize, width: usize) -> Vec<f32> {
    let plane = height * width;
    let mut out = pixels.to_vec();
    for c in 0..channels {
        let span = c * plane..(c + 1) * plane;
        let src = &pixels[span.clone()];
        let (lo, hi) = src
            .iter()
            .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), &v| {
                (lo.min(v), hi.max(v))
            });
        let range = hi - lo;
        if range.abs() < 1e-7 {
            // Flat channel: nothing to stretch, keep the copied values.
            continue;
        }
        for (dst, &v) in out[span].iter_mut().zip(src) {
            *dst = ((v - lo) / range).clamp(0.0, 1.0);
        }
    }
    out
}
261
/// Histogram-equalizes each channel independently.
///
/// Pixels are quantized into 256 bins (`round(p * 255)`, capped at bin 255) and
/// remapped through the normalized CDF `(cdf[i] - cdf_min) / (total - cdf_min)`.
/// When that denominator is zero (e.g. every pixel falls in one bin), each pixel
/// is instead mapped to `bin / 255`.
fn op_equalize(pixels: &[f32], channels: usize, height: usize, width: usize) -> Vec<f32> {
    const BINS: usize = 256;
    let top_bin = BINS as f32 - 1.0;
    let to_bin = |p: f32| ((p * top_bin).round() as usize).min(BINS - 1);

    let plane = height * width;
    let total = plane as u32;
    let mut out = pixels.to_vec();

    for c in 0..channels {
        let span = c * plane..(c + 1) * plane;
        let src = &pixels[span.clone()];

        let mut hist = [0u32; BINS];
        for &p in src {
            hist[to_bin(p)] += 1;
        }

        // Cumulative histogram.
        let mut cdf = [0u32; BINS];
        let mut running = 0u32;
        for (slot, &count) in cdf.iter_mut().zip(hist.iter()) {
            running += count;
            *slot = running;
        }

        let cdf_min = cdf.iter().copied().find(|&v| v > 0).unwrap_or(0);
        let denom = total.saturating_sub(cdf_min);

        let mut lut = [0.0_f32; BINS];
        for (i, entry) in lut.iter_mut().enumerate() {
            *entry = if denom == 0 {
                i as f32 / top_bin
            } else {
                (cdf[i].saturating_sub(cdf_min) as f32 / denom as f32).clamp(0.0, 1.0)
            };
        }

        for (dst, &p) in out[span].iter_mut().zip(src) {
            *dst = lut[to_bin(p)];
        }
    }
    out
}
303
304fn op_rotate(
306 pixels: &[f32],
307 channels: usize,
308 height: usize,
309 width: usize,
310 angle_deg: f32,
311 fill_value: f32,
312) -> Vec<f32> {
313 let angle_rad = angle_deg * std::f32::consts::PI / 180.0;
314 let cos_a = angle_rad.cos();
315 let sin_a = angle_rad.sin();
316 let cx = (width as f32 - 1.0) / 2.0;
317 let cy = (height as f32 - 1.0) / 2.0;
318 let a00 = cos_a;
322 let a01 = sin_a;
323 let a02 = -cos_a * cx - sin_a * cy + cx;
324 let a10 = -sin_a;
325 let a11 = cos_a;
326 let a12 = sin_a * cx - cos_a * cy + cy;
327 warp_affine(
328 pixels, channels, height, width, a00, a01, a02, a10, a11, a12, fill_value,
329 )
330}
331
/// Inverts (`p -> 1 - p`) every pixel at or above `threshold`; other pixels pass through.
fn op_solarize(pixels: &[f32], threshold: f32) -> Vec<f32> {
    let mut out = Vec::with_capacity(pixels.len());
    for &p in pixels {
        let flipped = p >= threshold;
        out.push(if flipped { 1.0 - p } else { p });
    }
    out
}
339
/// Blends each RGB pixel towards its luma `0.299 R + 0.587 G + 0.114 B`.
///
/// `alpha = 1` keeps the original colours and `alpha = 0` produces greyscale.
/// Images without exactly three channels are returned unchanged. Results are
/// clamped to `[0, 1]`.
fn op_color(pixels: &[f32], channels: usize, height: usize, width: usize, alpha: f32) -> Vec<f32> {
    let mut out = pixels.to_vec();
    if channels == 3 {
        let plane = height * width;
        let keep = 1.0 - alpha;
        for i in 0..plane {
            let rgb = [pixels[i], pixels[plane + i], pixels[2 * plane + i]];
            let luma = 0.299 * rgb[0] + 0.587 * rgb[1] + 0.114 * rgb[2];
            for (k, &v) in rgb.iter().enumerate() {
                out[k * plane + i] = (alpha * v + keep * luma).clamp(0.0, 1.0);
            }
        }
    }
    out
}
362
/// Keeps only the top `k` bits of each pixel's 8-bit quantization.
///
/// Pixels become bytes via `round(p * 255)` clamped to `[0, 255]`, are masked,
/// and are mapped back by dividing by 255. `k >= 8` leaves the byte intact and
/// `k == 0` zeroes every pixel.
fn op_posterize(pixels: &[f32], k: u32) -> Vec<f32> {
    let dropped_bits = 8u32.saturating_sub(k);
    let mask: u8 = match dropped_bits {
        0..=7 => 0xFF << dropped_bits,
        _ => 0,
    };
    pixels
        .iter()
        .map(|&p| (p * 255.0).round().clamp(0.0, 255.0) as u8)
        .map(|byte| ((byte & mask) as f32 / 255.0).clamp(0.0, 1.0))
        .collect()
}
379
/// Blends each channel towards its own mean: `(1 - alpha) * mean + alpha * p`.
///
/// `alpha = 1` is the identity and `alpha = 0` flattens every channel to its
/// mean. Results are clamped to `[0, 1]`.
fn op_contrast(
    pixels: &[f32],
    channels: usize,
    height: usize,
    width: usize,
    alpha: f32,
) -> Vec<f32> {
    let plane = height * width;
    let mut out = pixels.to_vec();
    for c in 0..channels {
        let (start, end) = (c * plane, (c + 1) * plane);
        let src = &pixels[start..end];
        let mean = src.iter().sum::<f32>() / plane as f32;
        let pull = (1.0 - alpha) * mean;
        out[start..end]
            .iter_mut()
            .zip(src)
            .for_each(|(dst, &p)| *dst = (pull + alpha * p).clamp(0.0, 1.0));
    }
    out
}
399
/// Multiplies every pixel by `strength` and clamps the result to `[0, 1]`.
fn op_brightness(pixels: &[f32], strength: f32) -> Vec<f32> {
    let mut out = pixels.to_vec();
    for p in out.iter_mut() {
        *p = (strength * *p).clamp(0.0, 1.0);
    }
    out
}
407
408fn op_sharpness(
410 pixels: &[f32],
411 channels: usize,
412 height: usize,
413 width: usize,
414 alpha: f32,
415) -> Vec<f32> {
416 let plane = height * width;
418 let mut blurred = vec![0.0_f32; channels * plane];
419 for c in 0..channels {
420 for y in 0..height {
421 for x in 0..width {
422 let mut acc = 0.0_f32;
423 let mut count = 0u32;
424 for dy in 0..3usize {
425 let ny = y + dy;
426 if ny == 0 || ny > height {
427 continue;
428 }
429 let ny = ny - 1;
430 for dx in 0..3usize {
431 let nx = x + dx;
432 if nx == 0 || nx > width {
433 continue;
434 }
435 let nx = nx - 1;
436 acc += pixels[chw_idx(c, ny, nx, height, width)];
437 count += 1;
438 }
439 }
440 blurred[chw_idx(c, y, x, height, width)] =
441 if count > 0 { acc / count as f32 } else { 0.0 };
442 }
443 }
444 }
445 pixels
447 .iter()
448 .zip(blurred.iter())
449 .map(|(&orig, &blur)| (alpha * orig + (1.0 - alpha) * blur).clamp(0.0, 1.0))
450 .collect()
451}
452
453fn op_shear_x(
455 pixels: &[f32],
456 channels: usize,
457 height: usize,
458 width: usize,
459 shear: f32,
460 fill_value: f32,
461) -> Vec<f32> {
462 warp_affine(
464 pixels, channels, height, width, 1.0, -shear, 0.0, 0.0, 1.0, 0.0, fill_value,
467 )
468}
469
470fn op_shear_y(
472 pixels: &[f32],
473 channels: usize,
474 height: usize,
475 width: usize,
476 shear: f32,
477 fill_value: f32,
478) -> Vec<f32> {
479 warp_affine(
481 pixels, channels, height, width, 1.0, 0.0, 0.0, -shear, 1.0, 0.0, fill_value,
484 )
485}
486
487fn op_translate_x(
489 pixels: &[f32],
490 channels: usize,
491 height: usize,
492 width: usize,
493 shift_x: f32,
494 fill_value: f32,
495) -> Vec<f32> {
496 warp_affine(
497 pixels, channels, height, width, 1.0, 0.0, -shift_x, 0.0, 1.0, 0.0, fill_value,
498 )
499}
500
501fn op_translate_y(
503 pixels: &[f32],
504 channels: usize,
505 height: usize,
506 width: usize,
507 shift_y: f32,
508 fill_value: f32,
509) -> Vec<f32> {
510 warp_affine(
511 pixels, channels, height, width, 1.0, 0.0, 0.0, 0.0, 1.0, -shift_y, fill_value,
512 )
513}
514
515pub fn apply_aug_op(
533 pixels: &[f32],
534 channels: usize,
535 height: usize,
536 width: usize,
537 op: &AugOp,
538 magnitude: f32,
539 fill_value: f32,
540) -> SslResult<Vec<f32>> {
541 if channels == 0 || height == 0 || width == 0 {
542 return Err(SslError::EmptyInput);
543 }
544 let expected = channels * height * width;
545 if pixels.len() != expected {
546 return Err(SslError::DimensionMismatch {
547 expected,
548 got: pixels.len(),
549 });
550 }
551 if !(magnitude.is_finite() && (0.0..=30.0).contains(&magnitude)) {
552 return Err(SslError::InvalidParameter {
553 name: "magnitude".into(),
554 reason: format!("must be in [0, 30] and finite, got {magnitude}"),
555 });
556 }
557 if !(fill_value.is_finite() && (0.0..=1.0).contains(&fill_value)) {
558 return Err(SslError::InvalidParameter {
559 name: "fill_value".into(),
560 reason: format!("must be in [0, 1] and finite, got {fill_value}"),
561 });
562 }
563
564 let m = magnitude / 30.0; let result = match op {
567 AugOp::Identity => pixels.to_vec(),
568
569 AugOp::AutoContrast => op_auto_contrast(pixels, channels, height, width),
570
571 AugOp::Equalize => op_equalize(pixels, channels, height, width),
572
573 AugOp::Rotate => {
574 let angle = m * 30.0;
578 op_rotate(pixels, channels, height, width, angle, fill_value)
579 }
580
581 AugOp::Solarize => {
582 let threshold = (1.0 - m).clamp(0.0, 1.0);
585 op_solarize(pixels, threshold)
586 }
587
588 AugOp::Color => {
589 let alpha = (1.0 - m * 0.9).clamp(0.0, 1.0);
591 op_color(pixels, channels, height, width, alpha)
592 }
593
594 AugOp::Posterize => {
595 let k = 8 - (m * 4.0).floor() as u32;
597 let k = k.max(1);
598 op_posterize(pixels, k)
599 }
600
601 AugOp::Contrast => {
602 let alpha = (1.0 - m * 0.9).clamp(0.0, 1.0);
604 op_contrast(pixels, channels, height, width, alpha)
605 }
606
607 AugOp::Brightness => {
608 let strength = (m * 0.9 + 0.1).clamp(0.0, 1.0);
610 op_brightness(pixels, strength)
611 }
612
613 AugOp::Sharpness => {
614 let alpha = m.clamp(0.0, 1.0);
616 op_sharpness(pixels, channels, height, width, alpha)
617 }
618
619 AugOp::ShearX => {
620 let shear = m * 0.3;
621 op_shear_x(pixels, channels, height, width, shear, fill_value)
622 }
623
624 AugOp::ShearY => {
625 let shear = m * 0.3;
626 op_shear_y(pixels, channels, height, width, shear, fill_value)
627 }
628
629 AugOp::TranslateX => {
630 let shift = m * 0.33 * width as f32;
631 op_translate_x(pixels, channels, height, width, shift, fill_value)
632 }
633
634 AugOp::TranslateY => {
635 let shift = m * 0.33 * height as f32;
636 op_translate_y(pixels, channels, height, width, shift, fill_value)
637 }
638 };
639
640 Ok(result)
641}
642
643pub fn rand_augment(
654 pixels: &[f32],
655 channels: usize,
656 height: usize,
657 width: usize,
658 config: &RandAugmentConfig,
659 rng: &mut LcgRng,
660) -> SslResult<Vec<f32>> {
661 if channels == 0 || height == 0 || width == 0 {
662 return Err(SslError::EmptyInput);
663 }
664 let expected = channels * height * width;
665 if pixels.len() != expected {
666 return Err(SslError::DimensionMismatch {
667 expected,
668 got: pixels.len(),
669 });
670 }
671 config.validate()?;
672
673 if config.n_ops == 0 {
674 return Ok(pixels.to_vec());
675 }
676
677 let n_pool = config.ops.len();
678 let mut current = pixels.to_vec();
679
680 for _ in 0..config.n_ops {
681 let idx = rng.next_usize(n_pool);
682 let op = &config.ops[idx];
683 current = apply_aug_op(
684 ¤t,
685 channels,
686 height,
687 width,
688 op,
689 config.magnitude,
690 config.fill_value,
691 )?;
692 }
693 Ok(current)
694}
695
696fn imagenet_sub_policies() -> Vec<SubPolicy> {
703 use AugOp::*;
704 vec![
705 ((Posterize, 0.4, 8), (Rotate, 0.6, 9)),
706 ((Solarize, 0.6, 5), (AutoContrast, 0.6, 5)),
707 ((Equalize, 0.8, 8), (Equalize, 0.6, 3)),
708 ((Posterize, 0.6, 7), (Posterize, 0.6, 6)),
709 ((Equalize, 0.4, 7), (Solarize, 0.2, 4)),
710 ((Equalize, 0.4, 4), (Rotate, 0.8, 8)),
711 ((Solarize, 0.6, 3), (Equalize, 0.6, 7)),
712 ((Posterize, 0.8, 5), (Equalize, 1.0, 2)),
713 ((Rotate, 0.2, 3), (Solarize, 0.6, 8)),
714 ((Equalize, 0.6, 8), (Posterize, 0.4, 6)),
715 ((Rotate, 0.8, 8), (Color, 1.0, 2)),
716 ((Rotate, 0.9, 9), (Equalize, 1.0, 2)),
717 ((Equalize, 0.6, 7), (Equalize, 0.6, 3)),
718 ((Equalize, 0.6, 4), (Rotate, 0.6, 4)),
719 ((Solarize, 0.6, 7), (Rotate, 0.6, 3)),
720 ((ShearX, 0.8, 8), (Solarize, 0.8, 4)),
721 ((Color, 0.8, 3), (Color, 1.0, 7)),
722 ((Color, 0.4, 1), (Rotate, 0.6, 8)),
723 ((Color, 0.8, 8), (Solarize, 0.8, 8)),
724 ((Equalize, 0.4, 8), (Equalize, 0.8, 3)),
725 ((Posterize, 0.4, 6), (Rotate, 0.4, 3)),
726 ((Equalize, 0.6, 7), (Color, 0.4, 4)),
727 ((Color, 0.4, 9), (Equalize, 0.6, 3)),
728 ((Color, 0.8, 8), (Contrast, 0.6, 1)),
729 ((Rotate, 0.8, 8), (Contrast, 1.0, 2)),
730 ]
731}
732
733fn cifar10_sub_policies() -> Vec<SubPolicy> {
735 use AugOp::*;
736 vec![
737 ((Equalize, 0.1, 8), (ShearY, 0.6, 4)),
738 ((Color, 0.6, 1), (Equalize, 0.6, 2)),
739 ((Sharpness, 0.6, 7), (Brightness, 0.6, 6)),
740 ((AutoContrast, 0.4, 0), (Equalize, 0.6, 0)),
741 ((Equalize, 1.0, 9), (ShearY, 0.6, 3)),
742 ((Color, 0.4, 3), (AutoContrast, 0.6, 1)),
743 ((ShearX, 0.8, 5), (Color, 1.0, 3)),
744 ((ShearX, 0.4, 4), (Posterize, 0.4, 7)),
745 ((Color, 0.4, 3), (Brightness, 0.6, 7)),
746 ((ShearY, 0.6, 4), (Color, 1.0, 9)),
747 ((Equalize, 0.6, 9), (Posterize, 0.4, 6)),
748 ((Solarize, 0.4, 9), (AutoContrast, 0.6, 3)),
749 ((AutoContrast, 0.6, 1), (Posterize, 0.6, 9)),
750 ((Equalize, 0.4, 9), (Solarize, 0.4, 5)),
751 ((Brightness, 0.2, 1), (Equalize, 0.6, 2)),
752 ((Equalize, 0.0, 0), (Equalize, 1.0, 0)),
753 ((AutoContrast, 0.2, 0), (Equalize, 0.6, 0)),
754 ((Equalize, 0.2, 0), (AutoContrast, 0.6, 0)),
755 ((Contrast, 0.2, 0), (Equalize, 0.6, 0)),
756 ((Brightness, 0.6, 5), (Contrast, 0.6, 6)),
757 ((AutoContrast, 0.8, 5), (Rotate, 0.6, 2)),
758 ((Solarize, 0.4, 3), (Brightness, 0.8, 9)),
759 ((Rotate, 0.6, 6), (Color, 1.0, 1)),
760 ((Equalize, 0.4, 5), (AutoContrast, 0.6, 5)),
761 ((Rotate, 0.6, 6), (Posterize, 0.8, 8)),
762 ]
763}
764
765pub fn auto_augment(
779 pixels: &[f32],
780 channels: usize,
781 height: usize,
782 width: usize,
783 config: &AutoAugmentConfig,
784 rng: &mut LcgRng,
785) -> SslResult<Vec<f32>> {
786 if channels == 0 || height == 0 || width == 0 {
787 return Err(SslError::EmptyInput);
788 }
789 let expected = channels * height * width;
790 if pixels.len() != expected {
791 return Err(SslError::DimensionMismatch {
792 expected,
793 got: pixels.len(),
794 });
795 }
796 if !(config.fill_value.is_finite() && (0.0..=1.0).contains(&config.fill_value)) {
797 return Err(SslError::InvalidParameter {
798 name: "fill_value".into(),
799 reason: format!("must be in [0, 1] and finite, got {}", config.fill_value),
800 });
801 }
802
803 let sub_policies: Vec<SubPolicy> = match &config.policy {
804 AutoAugPolicy::ImageNet => imagenet_sub_policies(),
805 AutoAugPolicy::Cifar10 => cifar10_sub_policies(),
806 AutoAugPolicy::Custom(v) => v.clone(),
807 };
808
809 if sub_policies.is_empty() {
810 return Err(SslError::InvalidParameter {
811 name: "policy".into(),
812 reason: "policy contains no sub-policies".into(),
813 });
814 }
815
816 let sp_idx = rng.next_usize(sub_policies.len());
818 let ((op1, prob1, mag_level1), (op2, prob2, mag_level2)) = &sub_policies[sp_idx];
819
820 let mag1 = (*mag_level1 as f32 * 3.0).clamp(0.0, 30.0);
822 let mag2 = (*mag_level2 as f32 * 3.0).clamp(0.0, 30.0);
823
824 let mut current = pixels.to_vec();
825
826 if rng.next_f32() < *prob1 {
827 current = apply_aug_op(
828 ¤t,
829 channels,
830 height,
831 width,
832 op1,
833 mag1,
834 config.fill_value,
835 )?;
836 }
837 if rng.next_f32() < *prob2 {
838 current = apply_aug_op(
839 ¤t,
840 channels,
841 height,
842 width,
843 op2,
844 mag2,
845 config.fill_value,
846 )?;
847 }
848 Ok(current)
849}
850
/// Unit tests for the augmentation primitives and the RandAugment / AutoAugment drivers.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a CHW buffer whose values ramp linearly over the flattened index
    /// (`i / n`), so every value lies in `[0, 1)`.
    fn gradient_image(channels: usize, height: usize, width: usize) -> Vec<f32> {
        let n = channels * height * width;
        (0..n)
            .map(|i| {
                let v = (i as f32) / (n as f32);
                v.clamp(0.0, 1.0)
            })
            .collect()
    }

    /// Asserts every pixel lies in `[0, 1]`. NaN also fails, because `contains`
    /// is false for NaN.
    fn assert_unit_range(pixels: &[f32], label: &str) {
        for (i, &v) in pixels.iter().enumerate() {
            assert!(
                (0.0..=1.0).contains(&v),
                "{label}: pixel[{i}] = {v} out of [0, 1]"
            );
        }
    }

    /// Every op preserves the buffer length (magnitude 15).
    #[test]
    fn output_shape_equals_input_for_all_ops() {
        let (c, h, w) = (3, 16, 16);
        let img = gradient_image(c, h, w);
        let expected_len = c * h * w;

        for op in all_aug_ops() {
            let out =
                apply_aug_op(&img, c, h, w, &op, 15.0, 0.5).expect("apply_aug_op should succeed");
            assert_eq!(out.len(), expected_len, "shape mismatch for op {:?}", op);
        }
    }

    /// Every op keeps values in `[0, 1]` (magnitude 20).
    #[test]
    fn all_pixels_in_unit_range_for_all_ops() {
        let (c, h, w) = (3, 16, 16);
        let img = gradient_image(c, h, w);

        for op in all_aug_ops() {
            let out =
                apply_aug_op(&img, c, h, w, &op, 20.0, 0.5).expect("apply_aug_op should succeed");
            assert_unit_range(&out, &format!("{op:?}"));
        }
    }

    /// `Identity` ignores the magnitude and returns a bit-exact copy.
    #[test]
    fn identity_op_returns_exact_copy() {
        let (c, h, w) = (3, 8, 8);
        let img = gradient_image(c, h, w);
        let out = apply_aug_op(&img, c, h, w, &AugOp::Identity, 15.0, 0.5)
            .expect("apply_aug_op should succeed");
        assert_eq!(out, img, "Identity must return exact copy");
    }

    /// Channel 0 spans [0.2, 0.8] and must be stretched to [0, 1]. Channel 2 is
    /// constant 0.3 and must be left untouched. Channel 1 is set up but not asserted.
    #[test]
    fn auto_contrast_stretches_to_unit() {
        let (c, h, w) = (3, 4, 4);
        let plane = h * w;
        let mut img = vec![0.0_f32; c * plane];
        for v in img[0..plane].iter_mut() {
            *v = 0.5;
        }
        img[0] = 0.2;
        img[plane - 1] = 0.8;
        for v in img[plane..2 * plane].iter_mut() {
            *v = 0.5;
        }
        img[plane] = 0.1;
        img[2 * plane - 1] = 0.9;
        for v in img[2 * plane..].iter_mut() {
            *v = 0.3;
        }

        let out = apply_aug_op(&img, c, h, w, &AugOp::AutoContrast, 0.0, 0.5)
            .expect("apply_aug_op should succeed");
        let ch0_min = out[..plane].iter().cloned().fold(f32::INFINITY, f32::min);
        let ch0_max = out[..plane]
            .iter()
            .cloned()
            .fold(f32::NEG_INFINITY, f32::max);
        assert!(ch0_min.abs() < 1e-5, "ch0 min = {ch0_min}");
        assert!((ch0_max - 1.0).abs() < 1e-5, "ch0 max = {ch0_max}");
        // Constant channel: range below the 1e-7 threshold, so values are copied through.
        for &v in &out[2 * plane..] {
            assert!((v - 0.3).abs() < 1e-5, "constant channel changed: {v}");
        }
    }

    /// Equalize on a single-channel gradient stays in range and preserves length.
    #[test]
    fn equalize_output_in_unit_range() {
        let (c, h, w) = (1, 32, 32);
        let img = gradient_image(c, h, w);
        let out = apply_aug_op(&img, c, h, w, &AugOp::Equalize, 0.0, 0.5)
            .expect("apply_aug_op should succeed");
        assert_unit_range(&out, "Equalize");
        assert_eq!(out.len(), c * h * w);
    }

    /// Magnitude 0 maps to a 0-degree rotation, which should reproduce the input
    /// up to float rounding.
    #[test]
    fn rotate_zero_degrees_approx_identity() {
        let (c, h, w) = (1, 8, 8);
        let img = gradient_image(c, h, w);
        let out = apply_aug_op(&img, c, h, w, &AugOp::Rotate, 0.0, 0.5)
            .expect("apply_aug_op should succeed");
        for (i, (&a, &b)) in img.iter().zip(out.iter()).enumerate() {
            assert!(
                (a - b).abs() < 1e-4,
                "rotate(0°): pixel[{i}]: input={a} output={b}"
            );
        }
    }

    /// Magnitude 0 maps to solarize threshold 1.0, so only pixels >= 1.0 would flip.
    /// The gradient image stays below 1.0.
    #[test]
    fn solarize_threshold_one_unchanged() {
        let (c, h, w) = (3, 8, 8);
        let img = gradient_image(c, h, w);
        let out = apply_aug_op(&img, c, h, w, &AugOp::Solarize, 0.0, 0.5)
            .expect("apply_aug_op should succeed");
        for (i, (&a, &b)) in img.iter().zip(out.iter()).enumerate() {
            if a < 1.0 {
                assert!(
                    (a - b).abs() < 1e-6,
                    "solarize(threshold=1): pixel[{i}] changed: {a}→{b}"
                );
            }
        }
    }

    /// `n_ops = 0` performs no augmentation and returns an exact copy.
    #[test]
    fn rand_augment_zero_ops_unchanged() {
        let (c, h, w) = (3, 8, 8);
        let img = gradient_image(c, h, w);
        let config = RandAugmentConfig {
            n_ops: 0,
            magnitude: 9.0,
            fill_value: 0.5,
            ops: all_aug_ops(),
        };
        let mut rng = LcgRng::new(42);
        let out =
            rand_augment(&img, c, h, w, &config, &mut rng).expect("rand_augment should succeed");
        assert_eq!(out, img, "n_ops=0 must return exact input copy");
    }

    /// Chaining three random ops keeps shape and the unit range.
    #[test]
    fn rand_augment_output_valid_shape_and_range() {
        let (c, h, w) = (3, 16, 16);
        let img = gradient_image(c, h, w);
        let config = RandAugmentConfig {
            n_ops: 3,
            magnitude: 15.0,
            fill_value: 0.5,
            ops: all_aug_ops(),
        };
        let mut rng = LcgRng::new(7);
        let out =
            rand_augment(&img, c, h, w, &config, &mut rng).expect("rand_augment should succeed");
        assert_eq!(out.len(), c * h * w);
        assert_unit_range(&out, "RandAugment(N=3)");
    }

    /// The ImageNet policy produces a same-sized, in-range result. The explicit
    /// finiteness loop is already implied by `assert_unit_range`.
    #[test]
    fn auto_augment_imagenet_output_finite_and_valid() {
        let (c, h, w) = (3, 16, 16);
        let img = gradient_image(c, h, w);
        let config = AutoAugmentConfig {
            policy: AutoAugPolicy::ImageNet,
            fill_value: 0.5,
        };
        let mut rng = LcgRng::new(13);
        let out =
            auto_augment(&img, c, h, w, &config, &mut rng).expect("auto_augment should succeed");
        assert_eq!(out.len(), c * h * w);
        assert_unit_range(&out, "AutoAugment(ImageNet)");
        for &v in &out {
            assert!(v.is_finite(), "non-finite pixel in AutoAugment output");
        }
    }

    /// Two different seeds should pick different ops and so produce different images.
    // NOTE(review): depends on LcgRng's sequence for seeds 1 and 999; it could become
    // fragile if the generator changes.
    #[test]
    fn different_seeds_produce_different_outputs() {
        let (c, h, w) = (3, 16, 16);
        let img = gradient_image(c, h, w);
        let config = RandAugmentConfig::default();

        let mut rng_a = LcgRng::new(1);
        let mut rng_b = LcgRng::new(999);
        let out_a =
            rand_augment(&img, c, h, w, &config, &mut rng_a).expect("rand_augment should succeed");
        let out_b =
            rand_augment(&img, c, h, w, &config, &mut rng_b).expect("rand_augment should succeed");

        let identical = out_a
            .iter()
            .zip(out_b.iter())
            .all(|(a, b)| (a - b).abs() < 1e-8);
        assert!(!identical, "different seeds must produce different outputs");
    }

    /// The pipeline is deterministic for a fixed seed.
    #[test]
    fn same_seed_produces_same_output() {
        let (c, h, w) = (3, 16, 16);
        let img = gradient_image(c, h, w);
        let config = RandAugmentConfig::default();

        let mut rng_a = LcgRng::new(42);
        let mut rng_b = LcgRng::new(42);
        let out_a =
            rand_augment(&img, c, h, w, &config, &mut rng_a).expect("rand_augment should succeed");
        let out_b =
            rand_augment(&img, c, h, w, &config, &mut rng_b).expect("rand_augment should succeed");
        assert_eq!(out_a, out_b, "same seed must produce identical output");
    }

    /// Magnitude 0 maps to brightness factor 0.1, so a flat 0.8 image drops to about 0.08.
    #[test]
    fn brightness_low_magnitude_dims_image() {
        let (c, h, w) = (3, 8, 8);
        let img = vec![0.8_f32; c * h * w];
        let out = apply_aug_op(&img, c, h, w, &AugOp::Brightness, 0.0, 0.5)
            .expect("apply_aug_op should succeed");
        let mean_out: f32 = out.iter().sum::<f32>() / out.len() as f32;
        assert!(
            mean_out < 0.2,
            "Brightness(mag=0) should produce near-black image, got mean={mean_out}"
        );
    }

    /// Sweeps all ops over the magnitude extremes and two interior values.
    #[test]
    fn all_14_ops_run_without_error() {
        let (c, h, w) = (3, 12, 12);
        let img = gradient_image(c, h, w);
        for mag in [0.0_f32, 9.0, 15.0, 30.0] {
            for op in all_aug_ops() {
                let result = apply_aug_op(&img, c, h, w, &op, mag, 0.5);
                assert!(
                    result.is_ok(),
                    "op {:?} at magnitude={mag} returned error: {:?}",
                    op,
                    result
                );
                assert_unit_range(
                    &result.expect("result should be present"),
                    &format!("{op:?}@{mag}"),
                );
            }
        }
    }

    /// The CIFAR-10 policy produces a same-sized, in-range result.
    #[test]
    fn auto_augment_cifar10_output_valid() {
        let (c, h, w) = (3, 32, 32);
        let img = gradient_image(c, h, w);
        let config = AutoAugmentConfig {
            policy: AutoAugPolicy::Cifar10,
            fill_value: 0.5,
        };
        let mut rng = LcgRng::new(77);
        let out =
            auto_augment(&img, c, h, w, &config, &mut rng).expect("auto_augment should succeed");
        assert_eq!(out.len(), c * h * w);
        assert_unit_range(&out, "AutoAugment(Cifar10)");
    }

    /// A custom policy of two Identity steps returns an exact copy. It does so whether
    /// or not each probability draw fires, since Identity is a pure copy.
    #[test]
    fn auto_augment_custom_policy_identity_always() {
        let (c, h, w) = (3, 8, 8);
        let img = gradient_image(c, h, w);
        let config = AutoAugmentConfig {
            policy: AutoAugPolicy::Custom(vec![(
                (AugOp::Identity, 1.0, 0),
                (AugOp::Identity, 1.0, 0),
            )]),
            fill_value: 0.5,
        };
        let mut rng = LcgRng::new(1);
        let out =
            auto_augment(&img, c, h, w, &config, &mut rng).expect("auto_augment should succeed");
        assert_eq!(
            out, img,
            "custom Identity × Identity should return exact copy"
        );
    }

    /// A zero channel count is rejected before any length check.
    #[test]
    fn error_on_empty_input() {
        let result = apply_aug_op(&[], 0, 8, 8, &AugOp::Identity, 0.0, 0.5);
        assert!(matches!(result, Err(SslError::EmptyInput)));
    }

    /// 10 values cannot describe a 3x4x4 (48-value) image.
    #[test]
    fn error_on_dimension_mismatch() {
        let img = vec![0.5_f32; 10];
        let result = apply_aug_op(&img, 3, 4, 4, &AugOp::Identity, 0.0, 0.5);
        assert!(matches!(result, Err(SslError::DimensionMismatch { .. })));
    }

    /// Magnitude 30 maps to 8 - floor(4 * 1) = 4 retained bits, so at most
    /// 16 distinct byte values remain.
    #[test]
    fn posterize_full_magnitude_reduces_unique_values() {
        let (c, h, w) = (1, 16, 16);
        let img = gradient_image(c, h, w);
        let out = apply_aug_op(&img, c, h, w, &AugOp::Posterize, 30.0, 0.5)
            .expect("apply_aug_op should succeed");
        let mut values: Vec<u32> = out.iter().map(|&v| (v * 255.0).round() as u32).collect();
        values.sort_unstable();
        values.dedup();
        assert!(
            values.len() <= 16,
            "expected ≤16 distinct values after 4-bit posterize, got {}",
            values.len()
        );
    }

    /// Magnitude 30 maps to sharpness blend factor 1.0, which reproduces the input.
    #[test]
    fn sharpness_full_magnitude_is_original() {
        let (c, h, w) = (3, 8, 8);
        let img = gradient_image(c, h, w);
        let out = apply_aug_op(&img, c, h, w, &AugOp::Sharpness, 30.0, 0.5)
            .expect("apply_aug_op should succeed");
        for (i, (&a, &b)) in img.iter().zip(out.iter()).enumerate() {
            assert!(
                (a - b).abs() < 1e-5,
                "Sharpness(1.0): pixel[{i}] input={a} output={b}"
            );
        }
    }
}