Skip to main content

oxicuda_ssl/augment/
rand_augment.rs

1//! RandAugment and AutoAugment augmentation policies for CHW images.
2//!
3//! # Layout convention
4//!
5//! All functions operate on a flat `[C × H × W]` row-major buffer where
6//! channel `c`, row `y`, and column `x` maps to index `c * H * W + y * W + x`.
7//! Pixel values are `f32` in `[0.0, 1.0]`.
8//!
9//! # References
10//! - Cubuk et al., "RandAugment: Practical automated data augmentation with a
11//!   reduced search space", NeurIPS 2020.
12//! - Cubuk et al., "AutoAugment: Learning Augmentation Policies from Data",
13//!   CVPR 2019.
14
15use crate::error::{SslError, SslResult};
16use crate::handle::LcgRng;
17
18// ─── Augmentation operation enum ──────────────────────────────────────────────
19
/// The 14 canonical RandAugment operations.
///
/// How the shared `magnitude` (0–30) is turned into a per-operation parameter
/// is defined by [`apply_aug_op`]; the variant docs below summarise it.
#[derive(Debug, Clone, PartialEq)]
pub enum AugOp {
    /// Pass the image through unchanged (magnitude ignored).
    Identity,
    /// Linearly stretch each channel so its minimum maps to 0 and its maximum
    /// to 1 (magnitude ignored; near-constant channels are left as-is).
    AutoContrast,
    /// Histogram-equalize each channel over 256 bins (magnitude ignored).
    Equalize,
    /// Rotate about the image centre by up to 30°; the angle in degrees equals
    /// the magnitude. Always the same direction: no random sign is applied.
    Rotate,
    /// Invert (`p → 1 − p`) pixels at or above a threshold that falls from 1
    /// toward 0 as the magnitude grows.
    Solarize,
    /// Blend between the BT.601 grayscale image and the original (saturation
    /// adjust); higher magnitude means less saturation. No-op unless `C == 3`.
    Color,
    /// Reduce effective bit depth: keep the top 8 down to 4 bits of the 8-bit
    /// value as the magnitude grows.
    Posterize,
    /// Blend each channel toward its own mean; higher magnitude means lower
    /// contrast.
    Contrast,
    /// Multiply pixels by a factor in `[0.1, 1.0]` (blend toward black). The
    /// factor *grows* with magnitude, so magnitude 0 is the darkest setting.
    Brightness,
    /// Blend between a 3×3 box-blurred copy and the original; magnitude 0 is
    /// fully blurred, magnitude 30 is the original image.
    Sharpness,
    /// Shear horizontally by a coefficient of up to 0.3 (positive direction only).
    ShearX,
    /// Shear vertically by a coefficient of up to 0.3 (positive direction only).
    ShearY,
    /// Translate horizontally by up to 33% of the width (positive direction only).
    TranslateX,
    /// Translate vertically by up to 33% of the height (positive direction only).
    TranslateY,
}
52
53/// Default set of all 14 RandAugment operations in canonical order.
54pub fn all_aug_ops() -> Vec<AugOp> {
55    vec![
56        AugOp::Identity,
57        AugOp::AutoContrast,
58        AugOp::Equalize,
59        AugOp::Rotate,
60        AugOp::Solarize,
61        AugOp::Color,
62        AugOp::Posterize,
63        AugOp::Contrast,
64        AugOp::Brightness,
65        AugOp::Sharpness,
66        AugOp::ShearX,
67        AugOp::ShearY,
68        AugOp::TranslateX,
69        AugOp::TranslateY,
70    ]
71}
72
73// ─── Configuration types ──────────────────────────────────────────────────────
74
/// Configuration for the RandAugment policy (Cubuk et al., NeurIPS 2020).
///
/// Field ranges are enforced by [`RandAugmentConfig::validate`], which
/// [`rand_augment`] calls before touching the image.
#[derive(Debug, Clone)]
pub struct RandAugmentConfig {
    /// N: number of operations to sample (with replacement) and apply per
    /// image (default: 2). `0` makes [`rand_augment`] return a copy of the input.
    pub n_ops: usize,
    /// M: shared magnitude on a 0–30 scale (default: 9.0). Must be finite and
    /// within `[0, 30]`.
    pub magnitude: f32,
    /// Fill value for geometric transforms when sampling outside the image
    /// boundary (default: 0.5). Must be finite and within `[0, 1]`.
    pub fill_value: f32,
    /// Pool of operations to sample from (default: all 14). Must be non-empty.
    pub ops: Vec<AugOp>,
}
87
88impl Default for RandAugmentConfig {
89    fn default() -> Self {
90        Self {
91            n_ops: 2,
92            magnitude: 9.0,
93            fill_value: 0.5,
94            ops: all_aug_ops(),
95        }
96    }
97}
98
99impl RandAugmentConfig {
100    /// Validate that the config is self-consistent.
101    pub fn validate(&self) -> SslResult<()> {
102        if !(self.magnitude.is_finite() && (0.0..=30.0).contains(&self.magnitude)) {
103            return Err(SslError::InvalidParameter {
104                name: "magnitude".into(),
105                reason: format!("must be in [0, 30] and finite, got {}", self.magnitude),
106            });
107        }
108        if !(self.fill_value.is_finite() && (0.0..=1.0).contains(&self.fill_value)) {
109            return Err(SslError::InvalidParameter {
110                name: "fill_value".into(),
111                reason: format!("must be in [0, 1] and finite, got {}", self.fill_value),
112            });
113        }
114        if self.ops.is_empty() {
115            return Err(SslError::InvalidParameter {
116                name: "ops".into(),
117                reason: "must contain at least one operation".into(),
118            });
119        }
120        Ok(())
121    }
122}
123
/// AutoAugment sub-policy: two sequential operations, each with its own
/// application probability and discrete magnitude level.
///
/// Each element is `(op, probability, magnitude_level)`.
///
/// - **Probability.** The operation fires when a draw from
///   `LcgRng::next_f32` is strictly below `probability`, so `0.0` never
///   fires. This presumes draws lie in `[0, 1)` — TODO confirm against `LcgRng`.
/// - **Magnitude level.** This follows the AutoAugment `[0, 10]` convention.
///   It is multiplied by 3 and clamped to 30 to reach the RandAugment 0–30
///   scale expected by [`apply_aug_op`], so any level above 10 behaves like 10.
pub type SubPolicy = ((AugOp, f32, usize), (AugOp, f32, usize));
132
/// Built-in AutoAugment dataset policies.
///
/// NOTE(review): the built-in tables are hand-transcribed. They do not appear
/// to match the reference policies distributed with the paper / torchvision,
/// which also use an `Invert` operation that [`AugOp`] does not provide.
/// Verify before relying on exact reproduction of published results.
#[derive(Debug, Clone)]
pub enum AutoAugPolicy {
    /// 25 ImageNet sub-policies (see the fidelity note above).
    ImageNet,
    /// 25 CIFAR-10 sub-policies (see the fidelity note above).
    Cifar10,
    /// User-defined collection of sub-policies. Must be non-empty:
    /// [`auto_augment`] returns `SslError::InvalidParameter` for an empty list.
    Custom(Vec<SubPolicy>),
}
143
/// Configuration for the AutoAugment policy (Cubuk et al., CVPR 2019).
#[derive(Debug, Clone)]
pub struct AutoAugmentConfig {
    /// Which policy (set of sub-policies) to use (default: `ImageNet`).
    pub policy: AutoAugPolicy,
    /// Fill value for geometric transforms (default: 0.5). [`auto_augment`]
    /// rejects values that are not finite or lie outside `[0, 1]`.
    pub fill_value: f32,
}
152
153impl Default for AutoAugmentConfig {
154    fn default() -> Self {
155        Self {
156            policy: AutoAugPolicy::ImageNet,
157            fill_value: 0.5,
158        }
159    }
160}
161
162// ─── Primitive image operations ───────────────────────────────────────────────
163
/// Flat offset of element `(c, y, x)` in a `[C × H × W]` row-major buffer.
///
/// Evaluated as `(c·H + y)·W + x`, which is algebraically `c·H·W + y·W + x`.
#[inline]
fn chw_idx(c: usize, y: usize, x: usize, height: usize, width: usize) -> usize {
    let row = c * height + y;
    row * width + x
}
169
/// Bilinearly interpolate a single `H × W` plane at the fractional location
/// `(fy, fx)`.
///
/// Locations outside the closed rectangle `[0, H-1] × [0, W-1]` yield
/// `fill_value`. On the last row or column the "next" neighbour is clamped to
/// the edge, so no out-of-bounds read happens. `height` and `width` must be
/// non-zero; callers validate this, and `height - 1` would otherwise underflow.
fn bilinear_sample(
    plane: &[f32],
    height: usize,
    width: usize,
    fy: f32,
    fx: f32,
    fill_value: f32,
) -> f32 {
    let outside = fy < 0.0 || fx < 0.0 || fy > (height - 1) as f32 || fx > (width - 1) as f32;
    if outside {
        return fill_value;
    }

    let (row0, col0) = (fy.floor() as usize, fx.floor() as usize);
    let row1 = usize::min(row0 + 1, height - 1);
    let col1 = usize::min(col0 + 1, width - 1);
    let (wy, wx) = (fy - row0 as f32, fx - col0 as f32);

    let at = |r: usize, c: usize| plane[r * width + c];
    // Interpolate along x on both rows, then along y between them.
    let upper = at(row0, col0) * (1.0 - wx) + at(row0, col1) * wx;
    let lower = at(row1, col0) * (1.0 - wx) + at(row1, col1) * wx;
    upper * (1.0 - wy) + lower * wy
}
200
201/// Apply an affine warp to all channels of a CHW image.
202///
203/// For each output pixel `(y, x)`, the source coordinate is:
204/// ```text
205///   src_y = y + dy_coeff * y + dyx_coeff * x + shift_y
206///   src_x = x + dx_coeff * x + dxy_coeff * y + shift_x
207/// ```
208/// where `dy_coeff`, `dx_coeff`, `dxy_coeff`, `dyx_coeff` encode shear, and
209/// `shift_x`, `shift_y` encode translation.  Pixels outside the source image
210/// are filled with `fill_value`.
211#[allow(clippy::too_many_arguments)]
212fn warp_affine(
213    pixels: &[f32],
214    channels: usize,
215    height: usize,
216    width: usize,
217    // Inverse affine coefficients for the source lookup
218    a00: f32, // src_x += a00 * x
219    a01: f32, // src_x += a01 * y
220    a02: f32, // src_x += a02 (translation x)
221    a10: f32, // src_y += a10 * x
222    a11: f32, // src_y += a11 * y
223    a12: f32, // src_y += a12 (translation y)
224    fill_value: f32,
225) -> Vec<f32> {
226    let plane = height * width;
227    let mut out = vec![fill_value; channels * plane];
228    for c in 0..channels {
229        let src_plane = &pixels[c * plane..(c + 1) * plane];
230        let dst_plane = &mut out[c * plane..(c + 1) * plane];
231        for y in 0..height {
232            for x in 0..width {
233                let fx = a00 * x as f32 + a01 * y as f32 + a02;
234                let fy = a10 * x as f32 + a11 * y as f32 + a12;
235                dst_plane[y * width + x] =
236                    bilinear_sample(src_plane, height, width, fy, fx, fill_value);
237            }
238        }
239    }
240    out
241}
242
/// Auto-contrast: rescale each channel linearly so its minimum becomes 0 and
/// its maximum becomes 1 (results clamped to `[0, 1]`).
///
/// Channels whose value span is below `1e-7` are copied through untouched,
/// which avoids dividing by (almost) zero.
fn op_auto_contrast(pixels: &[f32], channels: usize, height: usize, width: usize) -> Vec<f32> {
    let plane = height * width;
    let mut out = pixels.to_vec();
    for c in 0..channels {
        let span = c * plane..(c + 1) * plane;
        let src = &pixels[span.clone()];
        let (lo, hi) = src
            .iter()
            .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), &v| {
                (lo.min(v), hi.max(v))
            });
        let range = hi - lo;
        if range.abs() < 1e-7 {
            continue;
        }
        for (dst, &v) in out[span].iter_mut().zip(src) {
            *dst = ((v - lo) / range).clamp(0.0, 1.0);
        }
    }
    out
}
261
/// Histogram-equalize each channel independently.
///
/// Every value is assigned to one of 256 bins (`round(p · 255)`, capped at
/// bin 255). The cumulative histogram becomes a lookup table: it is shifted
/// so that the first occupied bin maps to 0, then normalised by
/// `pixel_count − cdf_min`. The result spreads the values over `[0, 1]`.
///
/// When every pixel of a channel falls into a single bin, the normaliser is
/// zero. That channel is then remapped through the identity table
/// `bin / 255` instead, so it is quantised but not stretched.
fn op_equalize(pixels: &[f32], channels: usize, height: usize, width: usize) -> Vec<f32> {
    const BINS: usize = 256;
    let last_bin = BINS as f32 - 1.0;
    let bin_of = |v: f32| ((v * last_bin).round() as usize).min(BINS - 1);

    let plane = height * width;
    let count = plane as u32;
    let mut out = pixels.to_vec();

    for c in 0..channels {
        let span = c * plane..(c + 1) * plane;
        let src = &pixels[span.clone()];

        let mut hist = [0u32; BINS];
        for &v in src {
            hist[bin_of(v)] += 1;
        }

        // Inclusive prefix sum of the histogram.
        let mut cdf = [0u32; BINS];
        let mut running = 0u32;
        for (slot, &n) in cdf.iter_mut().zip(hist.iter()) {
            running += n;
            *slot = running;
        }

        let cdf_min = cdf.iter().copied().find(|&v| v > 0).unwrap_or(0);
        let denom = count.saturating_sub(cdf_min);

        let mut lut = [0.0_f32; BINS];
        for (bin, entry) in lut.iter_mut().enumerate() {
            *entry = if denom == 0 {
                bin as f32 / last_bin
            } else {
                (cdf[bin].saturating_sub(cdf_min) as f32 / denom as f32).clamp(0.0, 1.0)
            };
        }

        for (dst, &v) in out[span].iter_mut().zip(src) {
            *dst = lut[bin_of(v)];
        }
    }
    out
}
303
304/// Rotation by `angle_deg` degrees with bilinear interpolation (center pivot).
305fn op_rotate(
306    pixels: &[f32],
307    channels: usize,
308    height: usize,
309    width: usize,
310    angle_deg: f32,
311    fill_value: f32,
312) -> Vec<f32> {
313    let angle_rad = angle_deg * std::f32::consts::PI / 180.0;
314    let cos_a = angle_rad.cos();
315    let sin_a = angle_rad.sin();
316    let cx = (width as f32 - 1.0) / 2.0;
317    let cy = (height as f32 - 1.0) / 2.0;
318    // Inverse rotation: given output (x, y), find source.
319    // src_x = cos_a * (x - cx) + sin_a * (y - cy) + cx
320    // src_y = -sin_a * (x - cx) + cos_a * (y - cy) + cy
321    let a00 = cos_a;
322    let a01 = sin_a;
323    let a02 = -cos_a * cx - sin_a * cy + cx;
324    let a10 = -sin_a;
325    let a11 = cos_a;
326    let a12 = sin_a * cx - cos_a * cy + cy;
327    warp_affine(
328        pixels, channels, height, width, a00, a01, a02, a10, a11, a12, fill_value,
329    )
330}
331
/// Solarize: every value `p` with `p >= threshold` becomes `1 − p`, and
/// values below the threshold are kept as they are.
fn op_solarize(pixels: &[f32], threshold: f32) -> Vec<f32> {
    let mut out = Vec::with_capacity(pixels.len());
    for &p in pixels {
        let v = if p >= threshold { 1.0 - p } else { p };
        out.push(v);
    }
    out
}
339
/// Color (saturation) adjustment for RGB images.
///
/// Each pixel is blended with its BT.601 luma `0.299·R + 0.587·G + 0.114·B`:
/// `out = alpha·v + (1 − alpha)·luma`, clamped to `[0, 1]`. `alpha = 1` keeps
/// the original and `alpha = 0` gives grayscale. Images whose channel count is
/// not 3 are returned unchanged.
fn op_color(pixels: &[f32], channels: usize, height: usize, width: usize, alpha: f32) -> Vec<f32> {
    // Saturation is only defined for RGB; other channel counts pass through.
    if channels != 3 {
        return pixels.to_vec();
    }

    let plane = height * width;
    let mix = |v: f32, luma: f32| (alpha * v + (1.0 - alpha) * luma).clamp(0.0, 1.0);

    let mut out = pixels.to_vec();
    let (red_out, rest) = out.split_at_mut(plane);
    let (green_out, blue_out) = rest.split_at_mut(plane);
    for i in 0..plane {
        let (r, g, b) = (pixels[i], pixels[plane + i], pixels[2 * plane + i]);
        let luma = 0.299 * r + 0.587 * g + 0.114 * b;
        red_out[i] = mix(r, luma);
        green_out[i] = mix(g, luma);
        blue_out[i] = mix(b, luma);
    }
    out
}
362
/// Posterize: quantise each value to 8 bits and keep only its top `k` bits.
///
/// Values are rounded to `0..=255`, the lowest `8 − k` bits are cleared, and
/// the result is scaled back to `[0, 1]`. `k >= 8` leaves the 8-bit value
/// intact and `k == 0` clears every bit. Callers pass `k` in `4..=8`, where
/// lower means more posterized.
fn op_posterize(pixels: &[f32], k: u32) -> Vec<f32> {
    // Number of low-order bits discarded from the 8-bit value.
    let drop_bits = 8u32.saturating_sub(k);
    let keep_mask: u8 = match drop_bits {
        0..=7 => 0xFF << drop_bits,
        _ => 0,
    };
    pixels
        .iter()
        .map(|&p| {
            let byte = (p * 255.0).round().clamp(0.0, 255.0) as u8;
            // A masked byte is at most 255, so the quotient is already in [0, 1].
            f32::from(byte & keep_mask) / 255.0
        })
        .collect()
}
379
/// Contrast: blend each channel toward its own mean.
///
/// `out = (1 − alpha)·mean_c + alpha·v`, clamped to `[0, 1]`. `alpha = 1`
/// keeps the original and `alpha = 0` flattens every channel to its mean.
fn op_contrast(
    pixels: &[f32],
    channels: usize,
    height: usize,
    width: usize,
    alpha: f32,
) -> Vec<f32> {
    let plane = height * width;
    let keep = 1.0 - alpha;
    let mut out = pixels.to_vec();
    for c in 0..channels {
        let span = c * plane..(c + 1) * plane;
        let src = &pixels[span.clone()];
        let mean = src.iter().sum::<f32>() / plane as f32;
        out[span]
            .iter_mut()
            .zip(src)
            .for_each(|(dst, &v)| *dst = (keep * mean + alpha * v).clamp(0.0, 1.0));
    }
    out
}
399
/// Brightness: scale every value by `strength` and clamp to `[0, 1]`.
///
/// This is equivalent to blending toward black: `strength = 0` gives an
/// all-black image and `strength = 1` keeps the original.
fn op_brightness(pixels: &[f32], strength: f32) -> Vec<f32> {
    let mut out = pixels.to_vec();
    for v in out.iter_mut() {
        *v = (*v * strength).clamp(0.0, 1.0);
    }
    out
}
407
408/// Sharpness: blend between blurred (3×3 box) and original.
409fn op_sharpness(
410    pixels: &[f32],
411    channels: usize,
412    height: usize,
413    width: usize,
414    alpha: f32,
415) -> Vec<f32> {
416    // Build a blurred version using a 3×3 box filter.
417    let plane = height * width;
418    let mut blurred = vec![0.0_f32; channels * plane];
419    for c in 0..channels {
420        for y in 0..height {
421            for x in 0..width {
422                let mut acc = 0.0_f32;
423                let mut count = 0u32;
424                for dy in 0..3usize {
425                    let ny = y + dy;
426                    if ny == 0 || ny > height {
427                        continue;
428                    }
429                    let ny = ny - 1;
430                    for dx in 0..3usize {
431                        let nx = x + dx;
432                        if nx == 0 || nx > width {
433                            continue;
434                        }
435                        let nx = nx - 1;
436                        acc += pixels[chw_idx(c, ny, nx, height, width)];
437                        count += 1;
438                    }
439                }
440                blurred[chw_idx(c, y, x, height, width)] =
441                    if count > 0 { acc / count as f32 } else { 0.0 };
442            }
443        }
444    }
445    // Blend: alpha * original + (1 - alpha) * blurred
446    pixels
447        .iter()
448        .zip(blurred.iter())
449        .map(|(&orig, &blur)| (alpha * orig + (1.0 - alpha) * blur).clamp(0.0, 1.0))
450        .collect()
451}
452
453/// Horizontal shear by `shear` radians (inverse warp).
454fn op_shear_x(
455    pixels: &[f32],
456    channels: usize,
457    height: usize,
458    width: usize,
459    shear: f32,
460    fill_value: f32,
461) -> Vec<f32> {
462    // For output (x, y): src_x = x - shear * y; src_y = y.
463    warp_affine(
464        pixels, channels, height, width, 1.0, -shear, 0.0, // src_x coefficients
465        0.0, 1.0, 0.0, // src_y coefficients
466        fill_value,
467    )
468}
469
470/// Vertical shear by `shear` radians (inverse warp).
471fn op_shear_y(
472    pixels: &[f32],
473    channels: usize,
474    height: usize,
475    width: usize,
476    shear: f32,
477    fill_value: f32,
478) -> Vec<f32> {
479    // For output (x, y): src_x = x; src_y = y - shear * x.
480    warp_affine(
481        pixels, channels, height, width, 1.0, 0.0, 0.0, // src_x coefficients
482        -shear, 1.0, 0.0, // src_y coefficients
483        fill_value,
484    )
485}
486
487/// Horizontal translation by `shift_x` pixels.
488fn op_translate_x(
489    pixels: &[f32],
490    channels: usize,
491    height: usize,
492    width: usize,
493    shift_x: f32,
494    fill_value: f32,
495) -> Vec<f32> {
496    warp_affine(
497        pixels, channels, height, width, 1.0, 0.0, -shift_x, 0.0, 1.0, 0.0, fill_value,
498    )
499}
500
501/// Vertical translation by `shift_y` pixels.
502fn op_translate_y(
503    pixels: &[f32],
504    channels: usize,
505    height: usize,
506    width: usize,
507    shift_y: f32,
508    fill_value: f32,
509) -> Vec<f32> {
510    warp_affine(
511        pixels, channels, height, width, 1.0, 0.0, 0.0, 0.0, 1.0, -shift_y, fill_value,
512    )
513}
514
515// ─── Public API ───────────────────────────────────────────────────────────────
516
517/// Apply a single augmentation operation with the given magnitude and fill value.
518///
519/// # Parameters
520/// - `pixels`     — flat `[C × H × W]` CHW input in `[0, 1]`.
521/// - `channels`   — number of channels `C`.
522/// - `height`     — image height `H`.
523/// - `width`      — image width `W`.
524/// - `op`         — which [`AugOp`] to apply.
525/// - `magnitude`  — shared magnitude in `[0, 30]`.
526/// - `fill_value` — fill for geometric OOB pixels.
527///
528/// # Errors
529/// - [`SslError::EmptyInput`] if any dimension is zero.
530/// - [`SslError::DimensionMismatch`] if `pixels.len() != C·H·W`.
531/// - [`SslError::InvalidParameter`] for invalid magnitude or fill_value.
532pub fn apply_aug_op(
533    pixels: &[f32],
534    channels: usize,
535    height: usize,
536    width: usize,
537    op: &AugOp,
538    magnitude: f32,
539    fill_value: f32,
540) -> SslResult<Vec<f32>> {
541    if channels == 0 || height == 0 || width == 0 {
542        return Err(SslError::EmptyInput);
543    }
544    let expected = channels * height * width;
545    if pixels.len() != expected {
546        return Err(SslError::DimensionMismatch {
547            expected,
548            got: pixels.len(),
549        });
550    }
551    if !(magnitude.is_finite() && (0.0..=30.0).contains(&magnitude)) {
552        return Err(SslError::InvalidParameter {
553            name: "magnitude".into(),
554            reason: format!("must be in [0, 30] and finite, got {magnitude}"),
555        });
556    }
557    if !(fill_value.is_finite() && (0.0..=1.0).contains(&fill_value)) {
558        return Err(SslError::InvalidParameter {
559            name: "fill_value".into(),
560            reason: format!("must be in [0, 1] and finite, got {fill_value}"),
561        });
562    }
563
564    let m = magnitude / 30.0; // normalised to [0, 1]
565
566    let result = match op {
567        AugOp::Identity => pixels.to_vec(),
568
569        AugOp::AutoContrast => op_auto_contrast(pixels, channels, height, width),
570
571        AugOp::Equalize => op_equalize(pixels, channels, height, width),
572
573        AugOp::Rotate => {
574            // ±30° max; use a signed direction encoded by (m >= 0.5).
575            // In RandAugment the sign is sampled externally; here we use the
576            // direct magnitude linearly in [0°, 30°] (caller picks sign).
577            let angle = m * 30.0;
578            op_rotate(pixels, channels, height, width, angle, fill_value)
579        }
580
581        AugOp::Solarize => {
582            // threshold = 1 - m: magnitude=0 → threshold=1 (nothing flipped);
583            // magnitude=30 → threshold=0 (all pixels flipped).
584            let threshold = (1.0 - m).clamp(0.0, 1.0);
585            op_solarize(pixels, threshold)
586        }
587
588        AugOp::Color => {
589            // alpha=1 at magnitude=0 (original); alpha decreases with magnitude.
590            let alpha = (1.0 - m * 0.9).clamp(0.0, 1.0);
591            op_color(pixels, channels, height, width, alpha)
592        }
593
594        AugOp::Posterize => {
595            // k = 8 - floor(m * 4): range [4, 8].
596            let k = 8 - (m * 4.0).floor() as u32;
597            let k = k.max(1);
598            op_posterize(pixels, k)
599        }
600
601        AugOp::Contrast => {
602            // alpha=1 at magnitude=0 (original); blend toward mean.
603            let alpha = (1.0 - m * 0.9).clamp(0.0, 1.0);
604            op_contrast(pixels, channels, height, width, alpha)
605        }
606
607        AugOp::Brightness => {
608            // strength = m * 0.9 + 0.1 so at m=0 strength≈0.1 (dim) and m=1 strength=1.0.
609            let strength = (m * 0.9 + 0.1).clamp(0.0, 1.0);
610            op_brightness(pixels, strength)
611        }
612
613        AugOp::Sharpness => {
614            // alpha=1 → sharp (original); alpha=0 → fully blurred.
615            let alpha = m.clamp(0.0, 1.0);
616            op_sharpness(pixels, channels, height, width, alpha)
617        }
618
619        AugOp::ShearX => {
620            let shear = m * 0.3;
621            op_shear_x(pixels, channels, height, width, shear, fill_value)
622        }
623
624        AugOp::ShearY => {
625            let shear = m * 0.3;
626            op_shear_y(pixels, channels, height, width, shear, fill_value)
627        }
628
629        AugOp::TranslateX => {
630            let shift = m * 0.33 * width as f32;
631            op_translate_x(pixels, channels, height, width, shift, fill_value)
632        }
633
634        AugOp::TranslateY => {
635            let shift = m * 0.33 * height as f32;
636            op_translate_y(pixels, channels, height, width, shift, fill_value)
637        }
638    };
639
640    Ok(result)
641}
642
643/// Apply the RandAugment policy to a CHW image.
644///
645/// Randomly samples `config.n_ops` operations (with replacement) from
646/// `config.ops` and applies each in sequence using `config.magnitude`.
647/// When `n_ops == 0` the image is returned unchanged.
648///
649/// # Errors
650/// - [`SslError::EmptyInput`] if any dimension is zero.
651/// - [`SslError::DimensionMismatch`] if slice length != `C·H·W`.
652/// - [`SslError::InvalidParameter`] if config is invalid.
653pub fn rand_augment(
654    pixels: &[f32],
655    channels: usize,
656    height: usize,
657    width: usize,
658    config: &RandAugmentConfig,
659    rng: &mut LcgRng,
660) -> SslResult<Vec<f32>> {
661    if channels == 0 || height == 0 || width == 0 {
662        return Err(SslError::EmptyInput);
663    }
664    let expected = channels * height * width;
665    if pixels.len() != expected {
666        return Err(SslError::DimensionMismatch {
667            expected,
668            got: pixels.len(),
669        });
670    }
671    config.validate()?;
672
673    if config.n_ops == 0 {
674        return Ok(pixels.to_vec());
675    }
676
677    let n_pool = config.ops.len();
678    let mut current = pixels.to_vec();
679
680    for _ in 0..config.n_ops {
681        let idx = rng.next_usize(n_pool);
682        let op = &config.ops[idx];
683        current = apply_aug_op(
684            &current,
685            channels,
686            height,
687            width,
688            op,
689            config.magnitude,
690            config.fill_value,
691        )?;
692    }
693    Ok(current)
694}
695
696// ─── AutoAugment policy tables ────────────────────────────────────────────────
697
/// Build the 25 ImageNet AutoAugment sub-policies from Cubuk et al., CVPR 2019.
///
/// Each entry is `((op, prob, mag_level), (op, prob, mag_level))` where
/// `mag_level` is in `[0, 10]` (scaled ×3 to reach the 0–30 magnitude range).
/// Levels attached to magnitude-free operations (`Equalize`, `AutoContrast`)
/// are carried along but ignored by [`apply_aug_op`].
///
/// Entry order matters, because [`auto_augment`] selects a sub-policy by
/// index from an RNG draw.
///
/// NOTE(review): from the 11th entry onward, this table does not appear to
/// match the reference ImageNet policy published in torchvision
/// (`AutoAugmentPolicy.IMAGENET`). That reference also uses `Invert`, which
/// [`AugOp`] does not provide. Verify against the paper before relying on
/// exact reproduction.
fn imagenet_sub_policies() -> Vec<SubPolicy> {
    use AugOp::*;
    vec![
        ((Posterize, 0.4, 8), (Rotate, 0.6, 9)),
        ((Solarize, 0.6, 5), (AutoContrast, 0.6, 5)),
        ((Equalize, 0.8, 8), (Equalize, 0.6, 3)),
        ((Posterize, 0.6, 7), (Posterize, 0.6, 6)),
        ((Equalize, 0.4, 7), (Solarize, 0.2, 4)),
        ((Equalize, 0.4, 4), (Rotate, 0.8, 8)),
        ((Solarize, 0.6, 3), (Equalize, 0.6, 7)),
        ((Posterize, 0.8, 5), (Equalize, 1.0, 2)),
        ((Rotate, 0.2, 3), (Solarize, 0.6, 8)),
        ((Equalize, 0.6, 8), (Posterize, 0.4, 6)),
        ((Rotate, 0.8, 8), (Color, 1.0, 2)),
        ((Rotate, 0.9, 9), (Equalize, 1.0, 2)),
        ((Equalize, 0.6, 7), (Equalize, 0.6, 3)),
        ((Equalize, 0.6, 4), (Rotate, 0.6, 4)),
        ((Solarize, 0.6, 7), (Rotate, 0.6, 3)),
        ((ShearX, 0.8, 8), (Solarize, 0.8, 4)),
        ((Color, 0.8, 3), (Color, 1.0, 7)),
        ((Color, 0.4, 1), (Rotate, 0.6, 8)),
        ((Color, 0.8, 8), (Solarize, 0.8, 8)),
        ((Equalize, 0.4, 8), (Equalize, 0.8, 3)),
        ((Posterize, 0.4, 6), (Rotate, 0.4, 3)),
        ((Equalize, 0.6, 7), (Color, 0.4, 4)),
        ((Color, 0.4, 9), (Equalize, 0.6, 3)),
        ((Color, 0.8, 8), (Contrast, 0.6, 1)),
        ((Rotate, 0.8, 8), (Contrast, 1.0, 2)),
    ]
}
732
/// Build the 25 CIFAR-10 AutoAugment sub-policies from Cubuk et al., CVPR 2019.
///
/// Same layout as the ImageNet table: `((op, prob, mag_level), (op, prob,
/// mag_level))` with `mag_level` in `[0, 10]`, scaled ×3 onto the 0–30
/// magnitude range. Entry order matters, because [`auto_augment`] picks an
/// entry by index.
///
/// NOTE(review): this table does not appear to match the reference CIFAR-10
/// policy published in torchvision (`AutoAugmentPolicy.CIFAR10`). For
/// example, the reference's first sub-policy is `(Invert, 0.1) / (Contrast,
/// 0.2, 6)`, and the reference uses `Invert`, which [`AugOp`] does not
/// provide. Verify before relying on exact reproduction.
fn cifar10_sub_policies() -> Vec<SubPolicy> {
    use AugOp::*;
    vec![
        ((Equalize, 0.1, 8), (ShearY, 0.6, 4)),
        ((Color, 0.6, 1), (Equalize, 0.6, 2)),
        ((Sharpness, 0.6, 7), (Brightness, 0.6, 6)),
        ((AutoContrast, 0.4, 0), (Equalize, 0.6, 0)),
        ((Equalize, 1.0, 9), (ShearY, 0.6, 3)),
        ((Color, 0.4, 3), (AutoContrast, 0.6, 1)),
        ((ShearX, 0.8, 5), (Color, 1.0, 3)),
        ((ShearX, 0.4, 4), (Posterize, 0.4, 7)),
        ((Color, 0.4, 3), (Brightness, 0.6, 7)),
        ((ShearY, 0.6, 4), (Color, 1.0, 9)),
        ((Equalize, 0.6, 9), (Posterize, 0.4, 6)),
        ((Solarize, 0.4, 9), (AutoContrast, 0.6, 3)),
        ((AutoContrast, 0.6, 1), (Posterize, 0.6, 9)),
        ((Equalize, 0.4, 9), (Solarize, 0.4, 5)),
        // Brightness at level 1 maps to a scale factor of about 0.19 in
        // `apply_aug_op`, i.e. a strong darkening.
        ((Brightness, 0.2, 1), (Equalize, 0.6, 2)),
        // prob 0.0: the first op never fires (assuming next_f32() >= 0).
        ((Equalize, 0.0, 0), (Equalize, 1.0, 0)),
        ((AutoContrast, 0.2, 0), (Equalize, 0.6, 0)),
        ((Equalize, 0.2, 0), (AutoContrast, 0.6, 0)),
        // Contrast at level 0 has blend factor 1.0 in `apply_aug_op`, a no-op.
        ((Contrast, 0.2, 0), (Equalize, 0.6, 0)),
        ((Brightness, 0.6, 5), (Contrast, 0.6, 6)),
        ((AutoContrast, 0.8, 5), (Rotate, 0.6, 2)),
        ((Solarize, 0.4, 3), (Brightness, 0.8, 9)),
        ((Rotate, 0.6, 6), (Color, 1.0, 1)),
        ((Equalize, 0.4, 5), (AutoContrast, 0.6, 5)),
        ((Rotate, 0.6, 6), (Posterize, 0.8, 8)),
    ]
}
764
765/// Apply the AutoAugment policy to a CHW image.
766///
767/// 1. Uniformly samples one sub-policy from the policy's list.
768/// 2. For each of the two operations in the sub-policy, applies it with the
769///    corresponding probability and magnitude level.
770///
771/// AutoAugment magnitude levels are integers in `[0, 10]`; they are scaled ×3
772/// to map into the `[0, 30]` range expected by [`apply_aug_op`].
773///
774/// # Errors
775/// - [`SslError::EmptyInput`] if any dimension is zero.
776/// - [`SslError::DimensionMismatch`] if `pixels.len() != C·H·W`.
777/// - [`SslError::InvalidParameter`] if policy has no sub-policies.
778pub fn auto_augment(
779    pixels: &[f32],
780    channels: usize,
781    height: usize,
782    width: usize,
783    config: &AutoAugmentConfig,
784    rng: &mut LcgRng,
785) -> SslResult<Vec<f32>> {
786    if channels == 0 || height == 0 || width == 0 {
787        return Err(SslError::EmptyInput);
788    }
789    let expected = channels * height * width;
790    if pixels.len() != expected {
791        return Err(SslError::DimensionMismatch {
792            expected,
793            got: pixels.len(),
794        });
795    }
796    if !(config.fill_value.is_finite() && (0.0..=1.0).contains(&config.fill_value)) {
797        return Err(SslError::InvalidParameter {
798            name: "fill_value".into(),
799            reason: format!("must be in [0, 1] and finite, got {}", config.fill_value),
800        });
801    }
802
803    let sub_policies: Vec<SubPolicy> = match &config.policy {
804        AutoAugPolicy::ImageNet => imagenet_sub_policies(),
805        AutoAugPolicy::Cifar10 => cifar10_sub_policies(),
806        AutoAugPolicy::Custom(v) => v.clone(),
807    };
808
809    if sub_policies.is_empty() {
810        return Err(SslError::InvalidParameter {
811            name: "policy".into(),
812            reason: "policy contains no sub-policies".into(),
813        });
814    }
815
816    // Sample one sub-policy.
817    let sp_idx = rng.next_usize(sub_policies.len());
818    let ((op1, prob1, mag_level1), (op2, prob2, mag_level2)) = &sub_policies[sp_idx];
819
820    // Scale magnitude level [0, 10] → [0, 30].
821    let mag1 = (*mag_level1 as f32 * 3.0).clamp(0.0, 30.0);
822    let mag2 = (*mag_level2 as f32 * 3.0).clamp(0.0, 30.0);
823
824    let mut current = pixels.to_vec();
825
826    if rng.next_f32() < *prob1 {
827        current = apply_aug_op(
828            &current,
829            channels,
830            height,
831            width,
832            op1,
833            mag1,
834            config.fill_value,
835        )?;
836    }
837    if rng.next_f32() < *prob2 {
838        current = apply_aug_op(
839            &current,
840            channels,
841            height,
842            width,
843            op2,
844            mag2,
845            config.fill_value,
846        )?;
847    }
848    Ok(current)
849}
850
851// ─── Unit tests ───────────────────────────────────────────────────────────────
852
#[cfg(test)]
mod tests {
    use super::*;

    // ── Helpers ───────────────────────────────────────────────────────────────

    /// Create a deterministic gradient CHW image.
    ///
    /// Pixel `i` of an `n`-pixel image is `i / n`, so for the small sizes used
    /// in these tests every value lies in `[0, (n - 1) / n]`, strictly below 1.
    fn gradient_image(channels: usize, height: usize, width: usize) -> Vec<f32> {
        let n = channels * height * width;
        (0..n)
            .map(|i| {
                let v = (i as f32) / (n as f32);
                v.clamp(0.0, 1.0)
            })
            .collect()
    }

    /// Assert all pixels in `[0, 1]`.
    ///
    /// NaN fails `contains`, so this also rejects non-finite pixels.
    fn assert_unit_range(pixels: &[f32], label: &str) {
        for (i, &v) in pixels.iter().enumerate() {
            assert!(
                (0.0..=1.0).contains(&v),
                "{label}: pixel[{i}] = {v} out of [0, 1]"
            );
        }
    }

    /// Return `(min, max)` of `values`; `(INFINITY, NEG_INFINITY)` when empty.
    fn min_max(values: &[f32]) -> (f32, f32) {
        values
            .iter()
            .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), &v| {
                (lo.min(v), hi.max(v))
            })
    }

    // ── Test 1: Output shape always matches input ──────────────────────────────

    #[test]
    fn output_shape_equals_input_for_all_ops() {
        let (c, h, w) = (3, 16, 16);
        let img = gradient_image(c, h, w);
        let expected_len = c * h * w;

        for op in all_aug_ops() {
            let out = apply_aug_op(&img, c, h, w, &op, 15.0, 0.5).unwrap();
            assert_eq!(out.len(), expected_len, "shape mismatch for op {op:?}");
        }
    }

    // ── Test 2: All pixels in [0, 1] after any operation ─────────────────────

    #[test]
    fn all_pixels_in_unit_range_for_all_ops() {
        let (c, h, w) = (3, 16, 16);
        let img = gradient_image(c, h, w);

        for op in all_aug_ops() {
            let out = apply_aug_op(&img, c, h, w, &op, 20.0, 0.5).unwrap();
            assert_unit_range(&out, &format!("{op:?}"));
        }
    }

    // ── Test 3: Identity op returns exact copy ────────────────────────────────

    #[test]
    fn identity_op_returns_exact_copy() {
        let (c, h, w) = (3, 8, 8);
        let img = gradient_image(c, h, w);
        let out = apply_aug_op(&img, c, h, w, &AugOp::Identity, 15.0, 0.5).unwrap();
        assert_eq!(out, img, "Identity must return exact copy");
    }

    // ── Test 4: AutoContrast stretches to [0, 1] per channel ─────────────────

    #[test]
    fn auto_contrast_stretches_to_unit() {
        // Create image with a known range per channel.
        let (c, h, w) = (3, 4, 4);
        let plane = h * w;
        let mut img = vec![0.0_f32; c * plane];
        // Channel 0: mid-grey with extremes 0.2 and 0.8.
        img[..plane].fill(0.5);
        img[0] = 0.2;
        img[plane - 1] = 0.8;
        // Channel 1: mid-grey with extremes 0.1 and 0.9.
        img[plane..2 * plane].fill(0.5);
        img[plane] = 0.1;
        img[2 * plane - 1] = 0.9;
        // Channel 2: constant → should be left alone.
        img[2 * plane..].fill(0.3);

        let out = apply_aug_op(&img, c, h, w, &AugOp::AutoContrast, 0.0, 0.5).unwrap();

        // Channels 0 and 1 are stretched independently: each min → 0, max → 1.
        // (A global stretch would leave ch0's min at 0.125, so this also pins
        // the per-channel behaviour.)
        for ch in 0..2 {
            let (lo, hi) = min_max(&out[ch * plane..(ch + 1) * plane]);
            assert!(lo.abs() < 1e-5, "ch{ch} min = {lo}");
            assert!((hi - 1.0).abs() < 1e-5, "ch{ch} max = {hi}");
        }
        // Channel 2: constant → should stay ≈ 0.3.
        for &v in &out[2 * plane..] {
            assert!((v - 0.3).abs() < 1e-5, "constant channel changed: {v}");
        }
    }

    // ── Test 5: Equalize outputs in [0, 1] ────────────────────────────────────

    #[test]
    fn equalize_output_in_unit_range() {
        let (c, h, w) = (1, 32, 32);
        let img = gradient_image(c, h, w);
        let out = apply_aug_op(&img, c, h, w, &AugOp::Equalize, 0.0, 0.5).unwrap();
        assert_unit_range(&out, "Equalize");
        assert_eq!(out.len(), c * h * w);
    }

    // ── Test 6: Rotate by 0° returns original ────────────────────────────────

    #[test]
    fn rotate_zero_degrees_approx_identity() {
        let (c, h, w) = (1, 8, 8);
        let img = gradient_image(c, h, w);
        // magnitude = 0 → angle = 0°.
        let out = apply_aug_op(&img, c, h, w, &AugOp::Rotate, 0.0, 0.5).unwrap();
        for (i, (&a, &b)) in img.iter().zip(out.iter()).enumerate() {
            assert!(
                (a - b).abs() < 1e-4,
                "rotate(0°): pixel[{i}]: input={a} output={b}"
            );
        }
    }

    // ── Test 7: Solarize with threshold=1.0 leaves all pixels unchanged ───────

    #[test]
    fn solarize_threshold_one_unchanged() {
        // magnitude=0 → threshold = 1 - 0 = 1.0.
        // `gradient_image` never reaches 1.0 (its max is (n-1)/n), so every
        // pixel is below the threshold and must pass through unchanged.
        let (c, h, w) = (3, 8, 8);
        let img = gradient_image(c, h, w);
        let out = apply_aug_op(&img, c, h, w, &AugOp::Solarize, 0.0, 0.5).unwrap();
        // Only pixels strictly below the threshold are checked; a pixel at
        // exactly 1.0 (none here) would be eligible for inversion.
        for (i, (&a, &b)) in img.iter().zip(out.iter()).enumerate() {
            if a < 1.0 {
                assert!(
                    (a - b).abs() < 1e-6,
                    "solarize(threshold=1): pixel[{i}] changed: {a}→{b}"
                );
            }
        }
    }

    // ── Test 8: RandAugment with N=0 returns unchanged image ─────────────────

    #[test]
    fn rand_augment_zero_ops_unchanged() {
        let (c, h, w) = (3, 8, 8);
        let img = gradient_image(c, h, w);
        let config = RandAugmentConfig {
            n_ops: 0,
            magnitude: 9.0,
            fill_value: 0.5,
            ops: all_aug_ops(),
        };
        let mut rng = LcgRng::new(42);
        let out = rand_augment(&img, c, h, w, &config, &mut rng).unwrap();
        assert_eq!(out, img, "n_ops=0 must return exact input copy");
    }

    // ── Test 9: RandAugment output keeps shape and stays in [0, 1] ────────────

    #[test]
    fn rand_augment_output_valid_shape_and_range() {
        let (c, h, w) = (3, 16, 16);
        let img = gradient_image(c, h, w);
        let config = RandAugmentConfig {
            n_ops: 3,
            magnitude: 15.0,
            fill_value: 0.5,
            ops: all_aug_ops(),
        };
        let mut rng = LcgRng::new(7);
        let out = rand_augment(&img, c, h, w, &config, &mut rng).unwrap();
        assert_eq!(out.len(), c * h * w);
        assert_unit_range(&out, "RandAugment(N=3)");
    }

    // ── Test 10: AutoAugment ImageNet policy: output is finite and valid ──────

    #[test]
    fn auto_augment_imagenet_output_finite_and_valid() {
        let (c, h, w) = (3, 16, 16);
        let img = gradient_image(c, h, w);
        let config = AutoAugmentConfig {
            policy: AutoAugPolicy::ImageNet,
            fill_value: 0.5,
        };
        let mut rng = LcgRng::new(13);
        let out = auto_augment(&img, c, h, w, &config, &mut rng).unwrap();
        assert_eq!(out.len(), c * h * w);
        assert_unit_range(&out, "AutoAugment(ImageNet)");
        for &v in &out {
            assert!(v.is_finite(), "non-finite pixel in AutoAugment output");
        }
    }

    // ── Test 11: Different seeds → different augmentations ───────────────────

    #[test]
    fn different_seeds_produce_different_outputs() {
        let (c, h, w) = (3, 16, 16);
        let img = gradient_image(c, h, w);
        let config = RandAugmentConfig::default();

        let mut rng_a = LcgRng::new(1);
        let mut rng_b = LcgRng::new(999);
        let out_a = rand_augment(&img, c, h, w, &config, &mut rng_a).unwrap();
        let out_b = rand_augment(&img, c, h, w, &config, &mut rng_b).unwrap();

        // It is overwhelmingly unlikely that two different random seeds produce
        // identical augmented outputs; if they do, the test catches a RNG bug.
        let identical = out_a
            .iter()
            .zip(out_b.iter())
            .all(|(a, b)| (a - b).abs() < 1e-8);
        assert!(!identical, "different seeds must produce different outputs");
    }

    // ── Test 12: Same seed → same output (deterministic) ─────────────────────

    #[test]
    fn same_seed_produces_same_output() {
        let (c, h, w) = (3, 16, 16);
        let img = gradient_image(c, h, w);
        let config = RandAugmentConfig::default();

        let mut rng_a = LcgRng::new(42);
        let mut rng_b = LcgRng::new(42);
        let out_a = rand_augment(&img, c, h, w, &config, &mut rng_a).unwrap();
        let out_b = rand_augment(&img, c, h, w, &config, &mut rng_b).unwrap();
        assert_eq!(out_a, out_b, "same seed must produce identical output");
    }

    // ── Test 13: Brightness at magnitude=0 dims image significantly ──────────

    #[test]
    fn brightness_low_magnitude_dims_image() {
        let (c, h, w) = (3, 8, 8);
        let img = vec![0.8_f32; c * h * w];
        // magnitude=0 → strength = 0*0.9 + 0.1 = 0.1.
        let out = apply_aug_op(&img, c, h, w, &AugOp::Brightness, 0.0, 0.5).unwrap();
        let mean_out: f32 = out.iter().sum::<f32>() / out.len() as f32;
        // 0.8 * 0.1 = 0.08; allow tolerance.
        assert!(
            mean_out < 0.2,
            "Brightness(mag=0) should produce near-black image, got mean={mean_out}"
        );
    }

    // ── Test 14: apply_aug_op valid for all 14 ops without panic ─────────────

    #[test]
    fn all_14_ops_run_without_error() {
        let (c, h, w) = (3, 12, 12);
        let img = gradient_image(c, h, w);
        for mag in [0.0_f32, 9.0, 15.0, 30.0] {
            for op in all_aug_ops() {
                let result = apply_aug_op(&img, c, h, w, &op, mag, 0.5);
                assert!(
                    result.is_ok(),
                    "op {op:?} at magnitude={mag} returned error: {result:?}"
                );
                assert_unit_range(&result.unwrap(), &format!("{op:?}@{mag}"));
            }
        }
    }

    // ── Test 15: AutoAugment CIFAR-10 policy ─────────────────────────────────

    #[test]
    fn auto_augment_cifar10_output_valid() {
        let (c, h, w) = (3, 32, 32);
        let img = gradient_image(c, h, w);
        let config = AutoAugmentConfig {
            policy: AutoAugPolicy::Cifar10,
            fill_value: 0.5,
        };
        let mut rng = LcgRng::new(77);
        let out = auto_augment(&img, c, h, w, &config, &mut rng).unwrap();
        assert_eq!(out.len(), c * h * w);
        assert_unit_range(&out, "AutoAugment(Cifar10)");
    }

    // ── Test 16: Custom AutoAugment policy ────────────────────────────────────

    #[test]
    fn auto_augment_custom_policy_identity_always() {
        // A custom policy with a single sub-policy: Identity at prob=1.
        let (c, h, w) = (3, 8, 8);
        let img = gradient_image(c, h, w);
        let config = AutoAugmentConfig {
            policy: AutoAugPolicy::Custom(vec![(
                (AugOp::Identity, 1.0, 0),
                (AugOp::Identity, 1.0, 0),
            )]),
            fill_value: 0.5,
        };
        let mut rng = LcgRng::new(1);
        let out = auto_augment(&img, c, h, w, &config, &mut rng).unwrap();
        assert_eq!(
            out, img,
            "custom Identity × Identity should return exact copy"
        );
    }

    // ── Test 17: Error on empty input ────────────────────────────────────────

    #[test]
    fn error_on_empty_input() {
        let result = apply_aug_op(&[], 0, 8, 8, &AugOp::Identity, 0.0, 0.5);
        assert!(matches!(result, Err(SslError::EmptyInput)));
    }

    // ── Test 18: Error on dimension mismatch ─────────────────────────────────

    #[test]
    fn error_on_dimension_mismatch() {
        let img = vec![0.5_f32; 10]; // wrong for 3×4×4=48
        let result = apply_aug_op(&img, 3, 4, 4, &AugOp::Identity, 0.0, 0.5);
        assert!(matches!(result, Err(SslError::DimensionMismatch { .. })));
    }

    // ── Test 19: Posterize at full magnitude quantizes heavily ────────────────

    #[test]
    fn posterize_full_magnitude_reduces_unique_values() {
        let (c, h, w) = (1, 16, 16);
        let img = gradient_image(c, h, w);
        // magnitude=30 → k = 8 - floor(1.0 * 4) = 4 bits.
        let out = apply_aug_op(&img, c, h, w, &AugOp::Posterize, 30.0, 0.5).unwrap();
        // With 4-bit posterization we expect at most 16 distinct values.
        let mut values: Vec<u32> = out.iter().map(|&v| (v * 255.0).round() as u32).collect();
        values.sort_unstable();
        values.dedup();
        assert!(
            values.len() <= 16,
            "expected ≤16 distinct values after 4-bit posterize, got {}",
            values.len()
        );
    }

    // ── Test 20: Sharpness at full magnitude (30 → alpha 1.0) is original ────

    #[test]
    fn sharpness_full_magnitude_is_original() {
        let (c, h, w) = (3, 8, 8);
        let img = gradient_image(c, h, w);
        // alpha = magnitude/30 = 1.0 → pure original, no blur blended in.
        let out = apply_aug_op(&img, c, h, w, &AugOp::Sharpness, 30.0, 0.5).unwrap();
        for (i, (&a, &b)) in img.iter().zip(out.iter()).enumerate() {
            assert!(
                (a - b).abs() < 1e-5,
                "Sharpness(1.0): pixel[{i}] input={a} output={b}"
            );
        }
    }
}