1use crate::error::{SslError, SslResult};
16use crate::handle::LcgRng;
17
/// A single image augmentation operation understood by [`apply_aug_op`].
///
/// The enum carries no data, so besides `Debug`/`Clone`/`PartialEq` it also
/// derives `Copy`, `Eq` and `Hash`; values can be passed by value and used as
/// set or map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AugOp {
    /// Returns the input unchanged.
    Identity,
    /// Rescales each channel so that its minimum maps to 0 and its maximum to 1.
    AutoContrast,
    /// Per-channel histogram equalisation over 256 bins.
    Equalize,
    /// Rotation about the image centre; the angle grows with magnitude.
    Rotate,
    /// Inverts pixels at or above a magnitude-dependent threshold.
    Solarize,
    /// Blends a 3-channel image with its luma (grayscale) version.
    Color,
    /// Keeps only the top bits of each pixel's 8-bit quantisation.
    Posterize,
    /// Blends each channel with its own mean value.
    Contrast,
    /// Multiplies every pixel by a magnitude-dependent factor.
    Brightness,
    /// Blends the image with a 3x3 box-blurred copy of itself.
    Sharpness,
    /// Shear along the x axis.
    ShearX,
    /// Shear along the y axis.
    ShearY,
    /// Translation along the x axis.
    TranslateX,
    /// Translation along the y axis.
    TranslateY,
}
52
53pub fn all_aug_ops() -> Vec<AugOp> {
55 vec![
56 AugOp::Identity,
57 AugOp::AutoContrast,
58 AugOp::Equalize,
59 AugOp::Rotate,
60 AugOp::Solarize,
61 AugOp::Color,
62 AugOp::Posterize,
63 AugOp::Contrast,
64 AugOp::Brightness,
65 AugOp::Sharpness,
66 AugOp::ShearX,
67 AugOp::ShearY,
68 AugOp::TranslateX,
69 AugOp::TranslateY,
70 ]
71}
72
73#[derive(Debug, Clone)]
77pub struct RandAugmentConfig {
78 pub n_ops: usize,
80 pub magnitude: f32,
82 pub fill_value: f32,
84 pub ops: Vec<AugOp>,
86}
87
88impl Default for RandAugmentConfig {
89 fn default() -> Self {
90 Self {
91 n_ops: 2,
92 magnitude: 9.0,
93 fill_value: 0.5,
94 ops: all_aug_ops(),
95 }
96 }
97}
98
99impl RandAugmentConfig {
100 pub fn validate(&self) -> SslResult<()> {
102 if !(self.magnitude.is_finite() && (0.0..=30.0).contains(&self.magnitude)) {
103 return Err(SslError::InvalidParameter {
104 name: "magnitude".into(),
105 reason: format!("must be in [0, 30] and finite, got {}", self.magnitude),
106 });
107 }
108 if !(self.fill_value.is_finite() && (0.0..=1.0).contains(&self.fill_value)) {
109 return Err(SslError::InvalidParameter {
110 name: "fill_value".into(),
111 reason: format!("must be in [0, 1] and finite, got {}", self.fill_value),
112 });
113 }
114 if self.ops.is_empty() {
115 return Err(SslError::InvalidParameter {
116 name: "ops".into(),
117 reason: "must contain at least one operation".into(),
118 });
119 }
120 Ok(())
121 }
122}
123
124pub type SubPolicy = ((AugOp, f32, usize), (AugOp, f32, usize));
132
133#[derive(Debug, Clone)]
135pub enum AutoAugPolicy {
136 ImageNet,
138 Cifar10,
140 Custom(Vec<SubPolicy>),
142}
143
144#[derive(Debug, Clone)]
146pub struct AutoAugmentConfig {
147 pub policy: AutoAugPolicy,
149 pub fill_value: f32,
151}
152
153impl Default for AutoAugmentConfig {
154 fn default() -> Self {
155 Self {
156 policy: AutoAugPolicy::ImageNet,
157 fill_value: 0.5,
158 }
159 }
160}
161
/// Flat offset of element `(c, y, x)` in a channel-major (CHW) buffer with
/// planes of `height * width` elements.
#[inline]
fn chw_idx(c: usize, y: usize, x: usize, height: usize, width: usize) -> usize {
    // Count whole rows first (all rows of earlier channels plus `y`), then
    // step `x` elements into the row.
    let rows_before = c * height + y;
    rows_before * width + x
}
169
/// Bilinearly samples one row-major `height x width` plane at `(fy, fx)`.
///
/// Returns `fill_value` when the coordinate lies outside
/// `[0, height - 1] x [0, width - 1]`, when either coordinate is NaN, or when
/// the plane has a zero dimension. Otherwise returns the bilinear blend of
/// the four surrounding pixels; neighbours past the last row or column are
/// clamped to the edge.
///
/// # Panics
///
/// Panics if `plane` holds fewer than `height * width` elements and the
/// sample touches a missing one.
fn bilinear_sample(
    plane: &[f32],
    height: usize,
    width: usize,
    fy: f32,
    fx: f32,
    fill_value: f32,
) -> f32 {
    // An empty plane has nothing to sample; this also keeps `height - 1` and
    // `width - 1` below from underflowing.
    if height == 0 || width == 0 {
        return fill_value;
    }
    let max_y = (height - 1) as f32;
    let max_x = (width - 1) as f32;
    // Written as a positive range test so that NaN coordinates (for which
    // every comparison is false) fall through to the fill value.
    let inside = (0.0..=max_y).contains(&fy) && (0.0..=max_x).contains(&fx);
    if !inside {
        return fill_value;
    }

    let y0 = fy.floor() as usize;
    let x0 = fx.floor() as usize;
    let y1 = (y0 + 1).min(height - 1);
    let x1 = (x0 + 1).min(width - 1);
    let dy = fy - y0 as f32;
    let dx = fx - x0 as f32;

    let v00 = plane[y0 * width + x0];
    let v01 = plane[y0 * width + x1];
    let v10 = plane[y1 * width + x0];
    let v11 = plane[y1 * width + x1];

    // Interpolate along x on both rows, then along y between the rows.
    let top = v00 * (1.0 - dx) + v01 * dx;
    let bot = v10 * (1.0 - dx) + v11 * dx;
    top * (1.0 - dy) + bot * dy
}
200
201#[allow(clippy::too_many_arguments)]
212fn warp_affine(
213 pixels: &[f32],
214 channels: usize,
215 height: usize,
216 width: usize,
217 a00: f32, a01: f32, a02: f32, a10: f32, a11: f32, a12: f32, fill_value: f32,
225) -> Vec<f32> {
226 let plane = height * width;
227 let mut out = vec![fill_value; channels * plane];
228 for c in 0..channels {
229 let src_plane = &pixels[c * plane..(c + 1) * plane];
230 let dst_plane = &mut out[c * plane..(c + 1) * plane];
231 for y in 0..height {
232 for x in 0..width {
233 let fx = a00 * x as f32 + a01 * y as f32 + a02;
234 let fy = a10 * x as f32 + a11 * y as f32 + a12;
235 dst_plane[y * width + x] =
236 bilinear_sample(src_plane, height, width, fy, fx, fill_value);
237 }
238 }
239 }
240 out
241}
242
/// Stretches each channel linearly so its minimum becomes 0 and its maximum 1.
///
/// A channel whose value spread is below `1e-7` is left untouched.
fn op_auto_contrast(pixels: &[f32], channels: usize, height: usize, width: usize) -> Vec<f32> {
    let plane_len = height * width;
    let mut stretched = pixels.to_vec();
    for c in 0..channels {
        let span = c * plane_len..(c + 1) * plane_len;
        // Single pass for both extrema.
        let (lo, hi) = pixels[span.clone()].iter().fold(
            (f32::INFINITY, f32::NEG_INFINITY),
            |(lo, hi), &v| (lo.min(v), hi.max(v)),
        );
        let spread = hi - lo;
        if spread.abs() < 1e-7 {
            // Effectively constant channel: nothing to stretch.
            continue;
        }
        // `stretched` starts as a copy of `pixels`, so it can be rescaled in place.
        for v in stretched[span].iter_mut() {
            *v = ((*v - lo) / spread).clamp(0.0, 1.0);
        }
    }
    stretched
}
261
/// Equalises the histogram of each channel independently using 256 bins.
///
/// Pixels are binned by `round(p * 255)`, capped at the last bin. Each bin is
/// mapped to `(cdf - cdf_min) / (pixel_count - cdf_min)`, where `cdf_min` is
/// the first non-zero cumulative count. When that denominator is zero (every
/// pixel in one bin) the bin maps to `bin / 255` instead.
fn op_equalize(pixels: &[f32], channels: usize, height: usize, width: usize) -> Vec<f32> {
    const BINS: usize = 256;
    let top_bin = BINS as f32 - 1.0;
    let bin_of = |v: f32| ((v * top_bin).round() as usize).min(BINS - 1);

    let plane_len = height * width;
    let mut equalized = pixels.to_vec();
    for c in 0..channels {
        let span = c * plane_len..(c + 1) * plane_len;

        let mut hist = [0u32; BINS];
        for &v in &pixels[span.clone()] {
            hist[bin_of(v)] += 1;
        }

        // Running sum of the histogram.
        let mut cdf = [0u32; BINS];
        let mut running = 0u32;
        for (slot, &count) in cdf.iter_mut().zip(hist.iter()) {
            running += count;
            *slot = running;
        }

        let cdf_min = cdf.iter().copied().find(|&v| v > 0).unwrap_or(0);
        let denom = (plane_len as u32).saturating_sub(cdf_min);

        let mut lut = [0.0_f32; BINS];
        for (i, entry) in lut.iter_mut().enumerate() {
            *entry = if denom == 0 {
                i as f32 / top_bin
            } else {
                (cdf[i].saturating_sub(cdf_min) as f32 / denom as f32).clamp(0.0, 1.0)
            };
        }

        // `equalized` still holds the original values for this channel.
        for v in equalized[span].iter_mut() {
            *v = lut[bin_of(*v)];
        }
    }
    equalized
}
303
304fn op_rotate(
306 pixels: &[f32],
307 channels: usize,
308 height: usize,
309 width: usize,
310 angle_deg: f32,
311 fill_value: f32,
312) -> Vec<f32> {
313 let angle_rad = angle_deg * std::f32::consts::PI / 180.0;
314 let cos_a = angle_rad.cos();
315 let sin_a = angle_rad.sin();
316 let cx = (width as f32 - 1.0) / 2.0;
317 let cy = (height as f32 - 1.0) / 2.0;
318 let a00 = cos_a;
322 let a01 = sin_a;
323 let a02 = -cos_a * cx - sin_a * cy + cx;
324 let a10 = -sin_a;
325 let a11 = cos_a;
326 let a12 = sin_a * cx - cos_a * cy + cy;
327 warp_affine(
328 pixels, channels, height, width, a00, a01, a02, a10, a11, a12, fill_value,
329 )
330}
331
/// Inverts (`1 - p`) every pixel that is at or above `threshold`; other
/// pixels pass through unchanged.
fn op_solarize(pixels: &[f32], threshold: f32) -> Vec<f32> {
    let mut solarized = Vec::with_capacity(pixels.len());
    for &value in pixels {
        let mapped = if value >= threshold { 1.0 - value } else { value };
        solarized.push(mapped);
    }
    solarized
}
339
/// Blends a 3-channel image with its luma: `alpha * channel + (1 - alpha) * luma`.
///
/// Luma is `0.299 * R + 0.587 * G + 0.114 * B`. Inputs with a channel count
/// other than 3 are returned as an unchanged copy. Results are clamped to `[0, 1]`.
fn op_color(pixels: &[f32], channels: usize, height: usize, width: usize, alpha: f32) -> Vec<f32> {
    if channels != 3 {
        return pixels.to_vec();
    }
    let plane_len = height * width;
    let blend = |value: f32, luma: f32| (alpha * value + (1.0 - alpha) * luma).clamp(0.0, 1.0);

    let mut recoloured = pixels.to_vec();
    for i in 0..plane_len {
        let (red_at, green_at, blue_at) = (i, plane_len + i, 2 * plane_len + i);
        let (r, g, b) = (pixels[red_at], pixels[green_at], pixels[blue_at]);
        let luma = 0.299 * r + 0.587 * g + 0.114 * b;
        recoloured[red_at] = blend(r, luma);
        recoloured[green_at] = blend(g, luma);
        recoloured[blue_at] = blend(b, luma);
    }
    recoloured
}
362
/// Keeps only the top `k` bits of each pixel's 8-bit quantisation.
///
/// Pixels are quantised as `round(p * 255)` clamped to `[0, 255]`, masked,
/// and scaled back to `[0, 1]`. `k >= 8` keeps every bit; `k == 0` clears all.
fn op_posterize(pixels: &[f32], k: u32) -> Vec<f32> {
    let dropped_bits = 8u32.saturating_sub(k);
    // Shifting a `u8` by 8 or more is not allowed; treat it as "no bits kept".
    let keep_mask = 0xFFu8.checked_shl(dropped_bits).unwrap_or(0);

    let mut posterized = Vec::with_capacity(pixels.len());
    for &p in pixels {
        let quantised = (p * 255.0).round().clamp(0.0, 255.0) as u8;
        let kept = f32::from(quantised & keep_mask);
        posterized.push((kept / 255.0).clamp(0.0, 1.0));
    }
    posterized
}
379
/// Blends each channel with its own mean: `(1 - alpha) * mean + alpha * p`,
/// clamped to `[0, 1]`.
///
/// `alpha = 1` reproduces the input; `alpha = 0` flattens each channel to its mean.
fn op_contrast(
    pixels: &[f32],
    channels: usize,
    height: usize,
    width: usize,
    alpha: f32,
) -> Vec<f32> {
    let plane_len = height * width;
    let mut adjusted = pixels.to_vec();
    for c in 0..channels {
        let span = c * plane_len..(c + 1) * plane_len;
        let mean = pixels[span.clone()].iter().sum::<f32>() / plane_len as f32;
        // `adjusted` is a copy of the input, so it can be blended in place.
        for v in adjusted[span].iter_mut() {
            *v = ((1.0 - alpha) * mean + alpha * *v).clamp(0.0, 1.0);
        }
    }
    adjusted
}
399
/// Multiplies every pixel by `strength` and clamps the result to `[0, 1]`.
fn op_brightness(pixels: &[f32], strength: f32) -> Vec<f32> {
    let mut scaled = Vec::with_capacity(pixels.len());
    for &p in pixels {
        scaled.push((strength * p).clamp(0.0, 1.0));
    }
    scaled
}
407
408fn op_sharpness(
410 pixels: &[f32],
411 channels: usize,
412 height: usize,
413 width: usize,
414 alpha: f32,
415) -> Vec<f32> {
416 let plane = height * width;
418 let mut blurred = vec![0.0_f32; channels * plane];
419 for c in 0..channels {
420 for y in 0..height {
421 for x in 0..width {
422 let mut acc = 0.0_f32;
423 let mut count = 0u32;
424 for dy in 0..3usize {
425 let ny = y + dy;
426 if ny == 0 || ny > height {
427 continue;
428 }
429 let ny = ny - 1;
430 for dx in 0..3usize {
431 let nx = x + dx;
432 if nx == 0 || nx > width {
433 continue;
434 }
435 let nx = nx - 1;
436 acc += pixels[chw_idx(c, ny, nx, height, width)];
437 count += 1;
438 }
439 }
440 blurred[chw_idx(c, y, x, height, width)] =
441 if count > 0 { acc / count as f32 } else { 0.0 };
442 }
443 }
444 }
445 pixels
447 .iter()
448 .zip(blurred.iter())
449 .map(|(&orig, &blur)| (alpha * orig + (1.0 - alpha) * blur).clamp(0.0, 1.0))
450 .collect()
451}
452
453fn op_shear_x(
455 pixels: &[f32],
456 channels: usize,
457 height: usize,
458 width: usize,
459 shear: f32,
460 fill_value: f32,
461) -> Vec<f32> {
462 warp_affine(
464 pixels, channels, height, width, 1.0, -shear, 0.0, 0.0, 1.0, 0.0, fill_value,
467 )
468}
469
470fn op_shear_y(
472 pixels: &[f32],
473 channels: usize,
474 height: usize,
475 width: usize,
476 shear: f32,
477 fill_value: f32,
478) -> Vec<f32> {
479 warp_affine(
481 pixels, channels, height, width, 1.0, 0.0, 0.0, -shear, 1.0, 0.0, fill_value,
484 )
485}
486
487fn op_translate_x(
489 pixels: &[f32],
490 channels: usize,
491 height: usize,
492 width: usize,
493 shift_x: f32,
494 fill_value: f32,
495) -> Vec<f32> {
496 warp_affine(
497 pixels, channels, height, width, 1.0, 0.0, -shift_x, 0.0, 1.0, 0.0, fill_value,
498 )
499}
500
501fn op_translate_y(
503 pixels: &[f32],
504 channels: usize,
505 height: usize,
506 width: usize,
507 shift_y: f32,
508 fill_value: f32,
509) -> Vec<f32> {
510 warp_affine(
511 pixels, channels, height, width, 1.0, 0.0, 0.0, 0.0, 1.0, -shift_y, fill_value,
512 )
513}
514
515pub fn apply_aug_op(
533 pixels: &[f32],
534 channels: usize,
535 height: usize,
536 width: usize,
537 op: &AugOp,
538 magnitude: f32,
539 fill_value: f32,
540) -> SslResult<Vec<f32>> {
541 if channels == 0 || height == 0 || width == 0 {
542 return Err(SslError::EmptyInput);
543 }
544 let expected = channels * height * width;
545 if pixels.len() != expected {
546 return Err(SslError::DimensionMismatch {
547 expected,
548 got: pixels.len(),
549 });
550 }
551 if !(magnitude.is_finite() && (0.0..=30.0).contains(&magnitude)) {
552 return Err(SslError::InvalidParameter {
553 name: "magnitude".into(),
554 reason: format!("must be in [0, 30] and finite, got {magnitude}"),
555 });
556 }
557 if !(fill_value.is_finite() && (0.0..=1.0).contains(&fill_value)) {
558 return Err(SslError::InvalidParameter {
559 name: "fill_value".into(),
560 reason: format!("must be in [0, 1] and finite, got {fill_value}"),
561 });
562 }
563
564 let m = magnitude / 30.0; let result = match op {
567 AugOp::Identity => pixels.to_vec(),
568
569 AugOp::AutoContrast => op_auto_contrast(pixels, channels, height, width),
570
571 AugOp::Equalize => op_equalize(pixels, channels, height, width),
572
573 AugOp::Rotate => {
574 let angle = m * 30.0;
578 op_rotate(pixels, channels, height, width, angle, fill_value)
579 }
580
581 AugOp::Solarize => {
582 let threshold = (1.0 - m).clamp(0.0, 1.0);
585 op_solarize(pixels, threshold)
586 }
587
588 AugOp::Color => {
589 let alpha = (1.0 - m * 0.9).clamp(0.0, 1.0);
591 op_color(pixels, channels, height, width, alpha)
592 }
593
594 AugOp::Posterize => {
595 let k = 8 - (m * 4.0).floor() as u32;
597 let k = k.max(1);
598 op_posterize(pixels, k)
599 }
600
601 AugOp::Contrast => {
602 let alpha = (1.0 - m * 0.9).clamp(0.0, 1.0);
604 op_contrast(pixels, channels, height, width, alpha)
605 }
606
607 AugOp::Brightness => {
608 let strength = (m * 0.9 + 0.1).clamp(0.0, 1.0);
610 op_brightness(pixels, strength)
611 }
612
613 AugOp::Sharpness => {
614 let alpha = m.clamp(0.0, 1.0);
616 op_sharpness(pixels, channels, height, width, alpha)
617 }
618
619 AugOp::ShearX => {
620 let shear = m * 0.3;
621 op_shear_x(pixels, channels, height, width, shear, fill_value)
622 }
623
624 AugOp::ShearY => {
625 let shear = m * 0.3;
626 op_shear_y(pixels, channels, height, width, shear, fill_value)
627 }
628
629 AugOp::TranslateX => {
630 let shift = m * 0.33 * width as f32;
631 op_translate_x(pixels, channels, height, width, shift, fill_value)
632 }
633
634 AugOp::TranslateY => {
635 let shift = m * 0.33 * height as f32;
636 op_translate_y(pixels, channels, height, width, shift, fill_value)
637 }
638 };
639
640 Ok(result)
641}
642
643pub fn rand_augment(
654 pixels: &[f32],
655 channels: usize,
656 height: usize,
657 width: usize,
658 config: &RandAugmentConfig,
659 rng: &mut LcgRng,
660) -> SslResult<Vec<f32>> {
661 if channels == 0 || height == 0 || width == 0 {
662 return Err(SslError::EmptyInput);
663 }
664 let expected = channels * height * width;
665 if pixels.len() != expected {
666 return Err(SslError::DimensionMismatch {
667 expected,
668 got: pixels.len(),
669 });
670 }
671 config.validate()?;
672
673 if config.n_ops == 0 {
674 return Ok(pixels.to_vec());
675 }
676
677 let n_pool = config.ops.len();
678 let mut current = pixels.to_vec();
679
680 for _ in 0..config.n_ops {
681 let idx = rng.next_usize(n_pool);
682 let op = &config.ops[idx];
683 current = apply_aug_op(
684 ¤t,
685 channels,
686 height,
687 width,
688 op,
689 config.magnitude,
690 config.fill_value,
691 )?;
692 }
693 Ok(current)
694}
695
696fn imagenet_sub_policies() -> Vec<SubPolicy> {
703 use AugOp::*;
704 vec![
705 ((Posterize, 0.4, 8), (Rotate, 0.6, 9)),
706 ((Solarize, 0.6, 5), (AutoContrast, 0.6, 5)),
707 ((Equalize, 0.8, 8), (Equalize, 0.6, 3)),
708 ((Posterize, 0.6, 7), (Posterize, 0.6, 6)),
709 ((Equalize, 0.4, 7), (Solarize, 0.2, 4)),
710 ((Equalize, 0.4, 4), (Rotate, 0.8, 8)),
711 ((Solarize, 0.6, 3), (Equalize, 0.6, 7)),
712 ((Posterize, 0.8, 5), (Equalize, 1.0, 2)),
713 ((Rotate, 0.2, 3), (Solarize, 0.6, 8)),
714 ((Equalize, 0.6, 8), (Posterize, 0.4, 6)),
715 ((Rotate, 0.8, 8), (Color, 1.0, 2)),
716 ((Rotate, 0.9, 9), (Equalize, 1.0, 2)),
717 ((Equalize, 0.6, 7), (Equalize, 0.6, 3)),
718 ((Equalize, 0.6, 4), (Rotate, 0.6, 4)),
719 ((Solarize, 0.6, 7), (Rotate, 0.6, 3)),
720 ((ShearX, 0.8, 8), (Solarize, 0.8, 4)),
721 ((Color, 0.8, 3), (Color, 1.0, 7)),
722 ((Color, 0.4, 1), (Rotate, 0.6, 8)),
723 ((Color, 0.8, 8), (Solarize, 0.8, 8)),
724 ((Equalize, 0.4, 8), (Equalize, 0.8, 3)),
725 ((Posterize, 0.4, 6), (Rotate, 0.4, 3)),
726 ((Equalize, 0.6, 7), (Color, 0.4, 4)),
727 ((Color, 0.4, 9), (Equalize, 0.6, 3)),
728 ((Color, 0.8, 8), (Contrast, 0.6, 1)),
729 ((Rotate, 0.8, 8), (Contrast, 1.0, 2)),
730 ]
731}
732
733fn cifar10_sub_policies() -> Vec<SubPolicy> {
735 use AugOp::*;
736 vec![
737 ((Equalize, 0.1, 8), (ShearY, 0.6, 4)),
738 ((Color, 0.6, 1), (Equalize, 0.6, 2)),
739 ((Sharpness, 0.6, 7), (Brightness, 0.6, 6)),
740 ((AutoContrast, 0.4, 0), (Equalize, 0.6, 0)),
741 ((Equalize, 1.0, 9), (ShearY, 0.6, 3)),
742 ((Color, 0.4, 3), (AutoContrast, 0.6, 1)),
743 ((ShearX, 0.8, 5), (Color, 1.0, 3)),
744 ((ShearX, 0.4, 4), (Posterize, 0.4, 7)),
745 ((Color, 0.4, 3), (Brightness, 0.6, 7)),
746 ((ShearY, 0.6, 4), (Color, 1.0, 9)),
747 ((Equalize, 0.6, 9), (Posterize, 0.4, 6)),
748 ((Solarize, 0.4, 9), (AutoContrast, 0.6, 3)),
749 ((AutoContrast, 0.6, 1), (Posterize, 0.6, 9)),
750 ((Equalize, 0.4, 9), (Solarize, 0.4, 5)),
751 ((Brightness, 0.2, 1), (Equalize, 0.6, 2)),
752 ((Equalize, 0.0, 0), (Equalize, 1.0, 0)),
753 ((AutoContrast, 0.2, 0), (Equalize, 0.6, 0)),
754 ((Equalize, 0.2, 0), (AutoContrast, 0.6, 0)),
755 ((Contrast, 0.2, 0), (Equalize, 0.6, 0)),
756 ((Brightness, 0.6, 5), (Contrast, 0.6, 6)),
757 ((AutoContrast, 0.8, 5), (Rotate, 0.6, 2)),
758 ((Solarize, 0.4, 3), (Brightness, 0.8, 9)),
759 ((Rotate, 0.6, 6), (Color, 1.0, 1)),
760 ((Equalize, 0.4, 5), (AutoContrast, 0.6, 5)),
761 ((Rotate, 0.6, 6), (Posterize, 0.8, 8)),
762 ]
763}
764
765pub fn auto_augment(
779 pixels: &[f32],
780 channels: usize,
781 height: usize,
782 width: usize,
783 config: &AutoAugmentConfig,
784 rng: &mut LcgRng,
785) -> SslResult<Vec<f32>> {
786 if channels == 0 || height == 0 || width == 0 {
787 return Err(SslError::EmptyInput);
788 }
789 let expected = channels * height * width;
790 if pixels.len() != expected {
791 return Err(SslError::DimensionMismatch {
792 expected,
793 got: pixels.len(),
794 });
795 }
796 if !(config.fill_value.is_finite() && (0.0..=1.0).contains(&config.fill_value)) {
797 return Err(SslError::InvalidParameter {
798 name: "fill_value".into(),
799 reason: format!("must be in [0, 1] and finite, got {}", config.fill_value),
800 });
801 }
802
803 let sub_policies: Vec<SubPolicy> = match &config.policy {
804 AutoAugPolicy::ImageNet => imagenet_sub_policies(),
805 AutoAugPolicy::Cifar10 => cifar10_sub_policies(),
806 AutoAugPolicy::Custom(v) => v.clone(),
807 };
808
809 if sub_policies.is_empty() {
810 return Err(SslError::InvalidParameter {
811 name: "policy".into(),
812 reason: "policy contains no sub-policies".into(),
813 });
814 }
815
816 let sp_idx = rng.next_usize(sub_policies.len());
818 let ((op1, prob1, mag_level1), (op2, prob2, mag_level2)) = &sub_policies[sp_idx];
819
820 let mag1 = (*mag_level1 as f32 * 3.0).clamp(0.0, 30.0);
822 let mag2 = (*mag_level2 as f32 * 3.0).clamp(0.0, 30.0);
823
824 let mut current = pixels.to_vec();
825
826 if rng.next_f32() < *prob1 {
827 current = apply_aug_op(
828 ¤t,
829 channels,
830 height,
831 width,
832 op1,
833 mag1,
834 config.fill_value,
835 )?;
836 }
837 if rng.next_f32() < *prob2 {
838 current = apply_aug_op(
839 ¤t,
840 channels,
841 height,
842 width,
843 op2,
844 mag2,
845 config.fill_value,
846 )?;
847 }
848 Ok(current)
849}
850
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a CHW image whose flat buffer ramps linearly over `[0, 1)`.
    fn gradient_image(channels: usize, height: usize, width: usize) -> Vec<f32> {
        let total = channels * height * width;
        let mut ramp = Vec::with_capacity(total);
        for i in 0..total {
            let v = (i as f32) / (total as f32);
            ramp.push(v.clamp(0.0, 1.0));
        }
        ramp
    }

    /// Fails the test if any pixel lies outside `[0, 1]`.
    fn assert_unit_range(pixels: &[f32], label: &str) {
        for (i, v) in pixels.iter().copied().enumerate() {
            assert!(
                (0.0..=1.0).contains(&v),
                "{label}: pixel[{i}] = {v} out of [0, 1]"
            );
        }
    }

    /// Every op must return a buffer of the input's length.
    #[test]
    fn output_shape_equals_input_for_all_ops() {
        let (c, h, w) = (3, 16, 16);
        let source = gradient_image(c, h, w);

        for op in all_aug_ops() {
            let out = apply_aug_op(&source, c, h, w, &op, 15.0, 0.5).unwrap();
            assert_eq!(out.len(), c * h * w, "shape mismatch for op {:?}", op);
        }
    }

    /// Every op must keep pixel values inside `[0, 1]`.
    #[test]
    fn all_pixels_in_unit_range_for_all_ops() {
        let (c, h, w) = (3, 16, 16);
        let source = gradient_image(c, h, w);

        for op in all_aug_ops() {
            let out = apply_aug_op(&source, c, h, w, &op, 20.0, 0.5).unwrap();
            let label = format!("{op:?}");
            assert_unit_range(&out, &label);
        }
    }

    #[test]
    fn identity_op_returns_exact_copy() {
        let (c, h, w) = (3, 8, 8);
        let source = gradient_image(c, h, w);
        let out = apply_aug_op(&source, c, h, w, &AugOp::Identity, 15.0, 0.5).unwrap();
        assert_eq!(out, source, "Identity must return exact copy");
    }

    /// Channel 0 spans [0.2, 0.8] and must stretch to [0, 1]; channel 2 is
    /// constant and must stay as it is.
    #[test]
    fn auto_contrast_stretches_to_unit() {
        let (c, h, w) = (3, 4, 4);
        let plane = h * w;
        let mut source = vec![0.0_f32; c * plane];

        source[..plane].fill(0.5);
        source[0] = 0.2;
        source[plane - 1] = 0.8;

        source[plane..2 * plane].fill(0.5);
        source[plane] = 0.1;
        source[2 * plane - 1] = 0.9;

        source[2 * plane..].fill(0.3);

        let out = apply_aug_op(&source, c, h, w, &AugOp::AutoContrast, 0.0, 0.5).unwrap();

        let first_channel = &out[..plane];
        let ch0_min = first_channel.iter().cloned().fold(f32::INFINITY, f32::min);
        let ch0_max = first_channel
            .iter()
            .cloned()
            .fold(f32::NEG_INFINITY, f32::max);
        assert!(ch0_min.abs() < 1e-5, "ch0 min = {ch0_min}");
        assert!((ch0_max - 1.0).abs() < 1e-5, "ch0 max = {ch0_max}");

        for v in out[2 * plane..].iter().copied() {
            assert!((v - 0.3).abs() < 1e-5, "constant channel changed: {v}");
        }
    }

    #[test]
    fn equalize_output_in_unit_range() {
        let (c, h, w) = (1, 32, 32);
        let source = gradient_image(c, h, w);
        let out = apply_aug_op(&source, c, h, w, &AugOp::Equalize, 0.0, 0.5).unwrap();
        assert_unit_range(&out, "Equalize");
        assert_eq!(out.len(), c * h * w);
    }

    /// Magnitude 0 maps to a 0 degree rotation, which must reproduce the input.
    #[test]
    fn rotate_zero_degrees_approx_identity() {
        let (c, h, w) = (1, 8, 8);
        let source = gradient_image(c, h, w);
        let out = apply_aug_op(&source, c, h, w, &AugOp::Rotate, 0.0, 0.5).unwrap();
        for (i, (a, b)) in source.iter().copied().zip(out.iter().copied()).enumerate() {
            assert!(
                (a - b).abs() < 1e-4,
                "rotate(0°): pixel[{i}]: input={a} output={b}"
            );
        }
    }

    /// Magnitude 0 maps to threshold 1, so only pixels equal to 1 may change.
    #[test]
    fn solarize_threshold_one_unchanged() {
        let (c, h, w) = (3, 8, 8);
        let source = gradient_image(c, h, w);
        let out = apply_aug_op(&source, c, h, w, &AugOp::Solarize, 0.0, 0.5).unwrap();
        for (i, (a, b)) in source.iter().copied().zip(out.iter().copied()).enumerate() {
            if a < 1.0 {
                assert!(
                    (a - b).abs() < 1e-6,
                    "solarize(threshold=1): pixel[{i}] changed: {a}→{b}"
                );
            }
        }
    }

    #[test]
    fn rand_augment_zero_ops_unchanged() {
        let (c, h, w) = (3, 8, 8);
        let source = gradient_image(c, h, w);
        let config = RandAugmentConfig {
            n_ops: 0,
            magnitude: 9.0,
            fill_value: 0.5,
            ops: all_aug_ops(),
        };
        let mut rng = LcgRng::new(42);
        let out = rand_augment(&source, c, h, w, &config, &mut rng).unwrap();
        assert_eq!(out, source, "n_ops=0 must return exact input copy");
    }

    #[test]
    fn rand_augment_output_valid_shape_and_range() {
        let (c, h, w) = (3, 16, 16);
        let source = gradient_image(c, h, w);
        let config = RandAugmentConfig {
            n_ops: 3,
            magnitude: 15.0,
            fill_value: 0.5,
            ops: all_aug_ops(),
        };
        let mut rng = LcgRng::new(7);
        let out = rand_augment(&source, c, h, w, &config, &mut rng).unwrap();
        assert_eq!(out.len(), c * h * w);
        assert_unit_range(&out, "RandAugment(N=3)");
    }

    #[test]
    fn auto_augment_imagenet_output_finite_and_valid() {
        let (c, h, w) = (3, 16, 16);
        let source = gradient_image(c, h, w);
        let config = AutoAugmentConfig {
            policy: AutoAugPolicy::ImageNet,
            fill_value: 0.5,
        };
        let mut rng = LcgRng::new(13);
        let out = auto_augment(&source, c, h, w, &config, &mut rng).unwrap();
        assert_eq!(out.len(), c * h * w);
        assert_unit_range(&out, "AutoAugment(ImageNet)");
        for v in out.iter().copied() {
            assert!(v.is_finite(), "non-finite pixel in AutoAugment output");
        }
    }

    #[test]
    fn different_seeds_produce_different_outputs() {
        let (c, h, w) = (3, 16, 16);
        let source = gradient_image(c, h, w);
        let config = RandAugmentConfig::default();

        let mut first_rng = LcgRng::new(1);
        let mut second_rng = LcgRng::new(999);
        let first = rand_augment(&source, c, h, w, &config, &mut first_rng).unwrap();
        let second = rand_augment(&source, c, h, w, &config, &mut second_rng).unwrap();

        let identical = first
            .iter()
            .zip(second.iter())
            .all(|(a, b)| (a - b).abs() < 1e-8);
        assert!(!identical, "different seeds must produce different outputs");
    }

    #[test]
    fn same_seed_produces_same_output() {
        let (c, h, w) = (3, 16, 16);
        let source = gradient_image(c, h, w);
        let config = RandAugmentConfig::default();

        let mut first_rng = LcgRng::new(42);
        let mut second_rng = LcgRng::new(42);
        let first = rand_augment(&source, c, h, w, &config, &mut first_rng).unwrap();
        let second = rand_augment(&source, c, h, w, &config, &mut second_rng).unwrap();
        assert_eq!(first, second, "same seed must produce identical output");
    }

    /// Magnitude 0 maps to a brightness factor of 0.1, which dims the image.
    #[test]
    fn brightness_low_magnitude_dims_image() {
        let (c, h, w) = (3, 8, 8);
        let source = vec![0.8_f32; c * h * w];
        let out = apply_aug_op(&source, c, h, w, &AugOp::Brightness, 0.0, 0.5).unwrap();
        let mean_out: f32 = out.iter().sum::<f32>() / out.len() as f32;
        assert!(
            mean_out < 0.2,
            "Brightness(mag=0) should produce near-black image, got mean={mean_out}"
        );
    }

    #[test]
    fn all_14_ops_run_without_error() {
        let (c, h, w) = (3, 12, 12);
        let source = gradient_image(c, h, w);
        for mag in [0.0_f32, 9.0, 15.0, 30.0] {
            for op in all_aug_ops() {
                let result = apply_aug_op(&source, c, h, w, &op, mag, 0.5);
                assert!(
                    result.is_ok(),
                    "op {:?} at magnitude={mag} returned error: {:?}",
                    op,
                    result
                );
                let label = format!("{op:?}@{mag}");
                assert_unit_range(&result.unwrap(), &label);
            }
        }
    }

    #[test]
    fn auto_augment_cifar10_output_valid() {
        let (c, h, w) = (3, 32, 32);
        let source = gradient_image(c, h, w);
        let config = AutoAugmentConfig {
            policy: AutoAugPolicy::Cifar10,
            fill_value: 0.5,
        };
        let mut rng = LcgRng::new(77);
        let out = auto_augment(&source, c, h, w, &config, &mut rng).unwrap();
        assert_eq!(out.len(), c * h * w);
        assert_unit_range(&out, "AutoAugment(Cifar10)");
    }

    #[test]
    fn auto_augment_custom_policy_identity_always() {
        let (c, h, w) = (3, 8, 8);
        let source = gradient_image(c, h, w);
        let always_identity: SubPolicy = ((AugOp::Identity, 1.0, 0), (AugOp::Identity, 1.0, 0));
        let config = AutoAugmentConfig {
            policy: AutoAugPolicy::Custom(vec![always_identity]),
            fill_value: 0.5,
        };
        let mut rng = LcgRng::new(1);
        let out = auto_augment(&source, c, h, w, &config, &mut rng).unwrap();
        assert_eq!(
            out, source,
            "custom Identity × Identity should return exact copy"
        );
    }

    #[test]
    fn error_on_empty_input() {
        let outcome = apply_aug_op(&[], 0, 8, 8, &AugOp::Identity, 0.0, 0.5);
        assert!(matches!(outcome, Err(SslError::EmptyInput)));
    }

    #[test]
    fn error_on_dimension_mismatch() {
        // 10 values cannot form a 3 x 4 x 4 image.
        let too_short = vec![0.5_f32; 10];
        let outcome = apply_aug_op(&too_short, 3, 4, 4, &AugOp::Identity, 0.0, 0.5);
        assert!(matches!(outcome, Err(SslError::DimensionMismatch { .. })));
    }

    /// Magnitude 30 keeps 4 bits, so at most 16 distinct 8-bit levels remain.
    #[test]
    fn posterize_full_magnitude_reduces_unique_values() {
        let (c, h, w) = (1, 16, 16);
        let source = gradient_image(c, h, w);
        let out = apply_aug_op(&source, c, h, w, &AugOp::Posterize, 30.0, 0.5).unwrap();
        let distinct: std::collections::BTreeSet<u32> =
            out.iter().map(|&v| (v * 255.0).round() as u32).collect();
        assert!(
            distinct.len() <= 16,
            "expected ≤16 distinct values after 4-bit posterize, got {}",
            distinct.len()
        );
    }

    /// Magnitude 30 maps to a blend factor of 1, which is the original image.
    #[test]
    fn sharpness_full_magnitude_is_original() {
        let (c, h, w) = (3, 8, 8);
        let source = gradient_image(c, h, w);
        let out = apply_aug_op(&source, c, h, w, &AugOp::Sharpness, 30.0, 0.5).unwrap();
        for (i, (a, b)) in source.iter().copied().zip(out.iter().copied()).enumerate() {
            assert!(
                (a - b).abs() < 1e-5,
                "Sharpness(1.0): pixel[{i}] input={a} output={b}"
            );
        }
    }
}