Skip to main content

oxicuda_ssl/augment/
rand_augment.rs

1//! RandAugment and AutoAugment augmentation policies for CHW images.
2//!
3//! # Layout convention
4//!
5//! All functions operate on a flat `[C × H × W]` row-major buffer where
6//! channel `c`, row `y`, and column `x` maps to index `c * H * W + y * W + x`.
7//! Pixel values are `f32` in `[0.0, 1.0]`.
8//!
9//! # References
10//! - Cubuk et al., "RandAugment: Practical automated data augmentation with a
11//!   reduced search space", NeurIPS 2020.
12//! - Cubuk et al., "AutoAugment: Learning Augmentation Policies from Data",
13//!   CVPR 2019.
14
15use crate::error::{SslError, SslResult};
16use crate::handle::LcgRng;
17
18// ─── Augmentation operation enum ──────────────────────────────────────────────
19
/// The 14 canonical RandAugment operations.
///
/// Operations are applied through [`apply_aug_op`], which maps a single
/// non-negative magnitude onto each operation's strength. `Identity`,
/// `AutoContrast` and `Equalize` ignore the magnitude.
///
/// `Eq` and `Hash` are derived (all variants are fieldless) so operation pools
/// can be deduplicated or used as map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AugOp {
    /// Pass the image through unchanged.
    Identity,
    /// Linearly stretch each channel's value range to `[0, 1]`.
    AutoContrast,
    /// Histogram equalization per channel (256 bins).
    Equalize,
    /// Rotate about the image centre by an angle of up to 30°, growing with
    /// magnitude. The direction is fixed: no random sign is applied.
    Rotate,
    /// Invert pixels at or above a magnitude-derived threshold.
    Solarize,
    /// Blend between grayscale and original (saturation adjust, RGB only).
    Color,
    /// Reduce effective bit depth.
    Posterize,
    /// Blend each channel toward its mean value.
    Contrast,
    /// Scale intensities by a factor of at most 1 (blend toward black).
    Brightness,
    /// Blend with a 3×3 box-blurred copy; weights above 1 on the original
    /// give unsharp-mask sharpening.
    Sharpness,
    /// Shear horizontally (fixed direction).
    ShearX,
    /// Shear vertically (fixed direction).
    ShearY,
    /// Translate horizontally (fixed direction).
    TranslateX,
    /// Translate vertically (fixed direction).
    TranslateY,
}
52
53/// Default set of all 14 RandAugment operations in canonical order.
54pub fn all_aug_ops() -> Vec<AugOp> {
55    vec![
56        AugOp::Identity,
57        AugOp::AutoContrast,
58        AugOp::Equalize,
59        AugOp::Rotate,
60        AugOp::Solarize,
61        AugOp::Color,
62        AugOp::Posterize,
63        AugOp::Contrast,
64        AugOp::Brightness,
65        AugOp::Sharpness,
66        AugOp::ShearX,
67        AugOp::ShearY,
68        AugOp::TranslateX,
69        AugOp::TranslateY,
70    ]
71}
72
73// ─── Configuration types ──────────────────────────────────────────────────────
74
75/// Configuration for the RandAugment policy (Cubuk et al., NeurIPS 2020).
76#[derive(Debug, Clone)]
77pub struct RandAugmentConfig {
78    /// N: number of operations to sample and apply per image (default: 2).
79    pub n_ops: usize,
80    /// M: shared magnitude on a 0–30 scale (default: 9.0).
81    pub magnitude: f32,
82    /// Fill value for geometric transforms when sampling outside the image boundary.
83    pub fill_value: f32,
84    /// Pool of operations to sample from (default: all 14).
85    pub ops: Vec<AugOp>,
86}
87
88impl Default for RandAugmentConfig {
89    fn default() -> Self {
90        Self {
91            n_ops: 2,
92            magnitude: 9.0,
93            fill_value: 0.5,
94            ops: all_aug_ops(),
95        }
96    }
97}
98
99impl RandAugmentConfig {
100    /// Validate that the config is self-consistent.
101    pub fn validate(&self) -> SslResult<()> {
102        if !(self.magnitude.is_finite() && (0.0..=30.0).contains(&self.magnitude)) {
103            return Err(SslError::InvalidParameter {
104                name: "magnitude".into(),
105                reason: format!("must be in [0, 30] and finite, got {}", self.magnitude),
106            });
107        }
108        if !(self.fill_value.is_finite() && (0.0..=1.0).contains(&self.fill_value)) {
109            return Err(SslError::InvalidParameter {
110                name: "fill_value".into(),
111                reason: format!("must be in [0, 1] and finite, got {}", self.fill_value),
112            });
113        }
114        if self.ops.is_empty() {
115            return Err(SslError::InvalidParameter {
116                name: "ops".into(),
117                reason: "must contain at least one operation".into(),
118            });
119        }
120        Ok(())
121    }
122}
123
124/// AutoAugment sub-policy: two sequential operations each with a probability
125/// and discrete magnitude level.
126///
127/// Each element is `(op, probability, magnitude_level)`.  Probability is in
128/// `[0.0, 1.0]`; magnitude level is an integer in `[0, 10]` (AutoAugment
129/// convention) and is internally remapped to the 0–30 RandAugment magnitude
130/// scale before being passed to [`apply_aug_op`].
131pub type SubPolicy = ((AugOp, f32, usize), (AugOp, f32, usize));
132
133/// Built-in AutoAugment dataset policies.
134#[derive(Debug, Clone)]
135pub enum AutoAugPolicy {
136    /// The 25 sub-policies from the original ImageNet AutoAugment paper.
137    ImageNet,
138    /// The 25 sub-policies from the original CIFAR-10 AutoAugment paper.
139    Cifar10,
140    /// User-defined collection of sub-policies.
141    Custom(Vec<SubPolicy>),
142}
143
144/// Configuration for the AutoAugment policy (Cubuk et al., CVPR 2019).
145#[derive(Debug, Clone)]
146pub struct AutoAugmentConfig {
147    /// Which policy (set of sub-policies) to use.
148    pub policy: AutoAugPolicy,
149    /// Fill value for geometric transforms.
150    pub fill_value: f32,
151}
152
153impl Default for AutoAugmentConfig {
154    fn default() -> Self {
155        Self {
156            policy: AutoAugPolicy::ImageNet,
157            fill_value: 0.5,
158        }
159    }
160}
161
162// ─── Primitive image operations ───────────────────────────────────────────────
163
/// Flat offset of element `(c, y, x)` in a row-major CHW buffer made of
/// `height × width` planes.
#[inline]
fn chw_idx(c: usize, y: usize, x: usize, height: usize, width: usize) -> usize {
    let row_start = (c * height + y) * width;
    row_start + x
}
169
/// Bilinearly samples a single `height × width` row-major plane at the
/// fractional position (`fy`, `fx`).
///
/// Positions outside `[0, height-1] × [0, width-1]` return `fill_value`, and
/// so do NaN coordinates. No edge clamping or reflection is done. On the last
/// row/column the missing neighbour index is clamped to the edge. This is
/// harmless because the corresponding interpolation weight is then zero.
fn bilinear_sample(
    plane: &[f32],
    height: usize,
    width: usize,
    fy: f32,
    fx: f32,
    fill_value: f32,
) -> f32 {
    let max_y = (height - 1) as f32;
    let max_x = (width - 1) as f32;
    // Positive range test: every comparison with NaN is false, so NaN
    // coordinates take the fill path instead of producing a NaN pixel.
    let inside = (0.0..=max_y).contains(&fy) && (0.0..=max_x).contains(&fx);
    if !inside {
        return fill_value;
    }
    let y0 = fy.floor() as usize;
    let x0 = fx.floor() as usize;
    let y1 = (y0 + 1).min(height - 1);
    let x1 = (x0 + 1).min(width - 1);
    let dy = fy - y0 as f32;
    let dx = fx - x0 as f32;

    let v00 = plane[y0 * width + x0];
    let v01 = plane[y0 * width + x1];
    let v10 = plane[y1 * width + x0];
    let v11 = plane[y1 * width + x1];

    let top = v00 * (1.0 - dx) + v01 * dx;
    let bot = v10 * (1.0 - dx) + v11 * dx;
    top * (1.0 - dy) + bot * dy
}

/// Resamples every channel of a CHW image through an affine map.
///
/// The coefficients describe the *inverse* (output → source) transform.
/// Output pixel `(y, x)` takes the bilinearly interpolated source value at
///
/// ```text
///   src_x = a00 * x + a01 * y + a02
///   src_y = a10 * x + a11 * y + a12
/// ```
///
/// so the identity warp is `a00 = a11 = 1` with every other coefficient `0`.
/// Source positions outside the image yield `fill_value`.
///
/// `pixels` is expected to hold at least `channels * height * width` values;
/// the public entry points validate this before calling.
#[allow(clippy::too_many_arguments)]
fn warp_affine(
    pixels: &[f32],
    channels: usize,
    height: usize,
    width: usize,
    a00: f32, // src_x coefficient of x
    a01: f32, // src_x coefficient of y
    a02: f32, // src_x offset (horizontal translation)
    a10: f32, // src_y coefficient of x
    a11: f32, // src_y coefficient of y
    a12: f32, // src_y offset (vertical translation)
    fill_value: f32,
) -> Vec<f32> {
    let plane = height * width;
    let mut out = vec![fill_value; channels * plane];
    if plane == 0 {
        // Degenerate image: nothing to sample (and chunks of size 0 are invalid).
        return out;
    }
    let src = &pixels[..channels * plane];
    for (src_plane, dst_plane) in src.chunks_exact(plane).zip(out.chunks_exact_mut(plane)) {
        for (y, dst_row) in dst_plane.chunks_exact_mut(width).enumerate() {
            for (x, dst) in dst_row.iter_mut().enumerate() {
                let fx = a00 * x as f32 + a01 * y as f32 + a02;
                let fy = a10 * x as f32 + a11 * y as f32 + a12;
                *dst = bilinear_sample(src_plane, height, width, fy, fx, fill_value);
            }
        }
    }
    out
}
242
/// Rescales each channel linearly so that its minimum becomes 0 and its
/// maximum becomes 1.
///
/// A channel whose value span is below `1e-7` is copied unchanged, which
/// avoids dividing by (almost) zero.
fn op_auto_contrast(pixels: &[f32], channels: usize, height: usize, width: usize) -> Vec<f32> {
    let plane = height * width;
    let mut out = pixels.to_vec();
    for c in 0..channels {
        let start = c * plane;
        let src = &pixels[start..start + plane];
        let (lo, hi) = src
            .iter()
            .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), &v| {
                (lo.min(v), hi.max(v))
            });
        let span = hi - lo;
        if span.abs() < 1e-7 {
            // Flat channel: nothing to stretch.
            continue;
        }
        for (dst, &v) in out[start..start + plane].iter_mut().zip(src) {
            *dst = ((v - lo) / span).clamp(0.0, 1.0);
        }
    }
    out
}
261
/// Per-channel histogram equalisation on a 256-level quantisation.
///
/// Each value falls into bin `round(v * 255)`, clamped to the last bin. The
/// bin's cumulative count, shifted so the lowest occupied bin maps to 0 and
/// divided by `total - cdf_min`, becomes the output value. A channel that
/// occupies a single bin has no spread to redistribute. Its values come back
/// quantised to `bin / 255`.
fn op_equalize(pixels: &[f32], channels: usize, height: usize, width: usize) -> Vec<f32> {
    const LEVELS: usize = 256;
    let top_level = LEVELS as f32 - 1.0;
    let bin_of = |v: f32| ((v * top_level).round() as usize).min(LEVELS - 1);

    let plane = height * width;
    let mut out = pixels.to_vec();
    for c in 0..channels {
        let start = c * plane;
        let src = &pixels[start..start + plane];

        let mut counts = [0u32; LEVELS];
        for &v in src {
            counts[bin_of(v)] += 1;
        }

        // Inclusive running total of the histogram, i.e. the CDF.
        let mut cdf = [0u32; LEVELS];
        let mut running = 0u32;
        for (slot, &n) in cdf.iter_mut().zip(counts.iter()) {
            running += n;
            *slot = running;
        }

        let cdf_min = cdf.iter().copied().find(|&n| n > 0).unwrap_or(0);
        let denom = (plane as u32).saturating_sub(cdf_min);

        let lut: Vec<f32> = cdf
            .iter()
            .enumerate()
            .map(|(i, &cum)| {
                if denom == 0 {
                    i as f32 / top_level
                } else {
                    (cum.saturating_sub(cdf_min) as f32 / denom as f32).clamp(0.0, 1.0)
                }
            })
            .collect();

        for (dst, &v) in out[start..start + plane].iter_mut().zip(src) {
            *dst = lut[bin_of(v)];
        }
    }
    out
}
303
304/// Rotation by `angle_deg` degrees with bilinear interpolation (center pivot).
305fn op_rotate(
306    pixels: &[f32],
307    channels: usize,
308    height: usize,
309    width: usize,
310    angle_deg: f32,
311    fill_value: f32,
312) -> Vec<f32> {
313    let angle_rad = angle_deg * std::f32::consts::PI / 180.0;
314    let cos_a = angle_rad.cos();
315    let sin_a = angle_rad.sin();
316    let cx = (width as f32 - 1.0) / 2.0;
317    let cy = (height as f32 - 1.0) / 2.0;
318    // Inverse rotation: given output (x, y), find source.
319    // src_x = cos_a * (x - cx) + sin_a * (y - cy) + cx
320    // src_y = -sin_a * (x - cx) + cos_a * (y - cy) + cy
321    let a00 = cos_a;
322    let a01 = sin_a;
323    let a02 = -cos_a * cx - sin_a * cy + cx;
324    let a10 = -sin_a;
325    let a11 = cos_a;
326    let a12 = sin_a * cx - cos_a * cy + cy;
327    warp_affine(
328        pixels, channels, height, width, a00, a01, a02, a10, a11, a12, fill_value,
329    )
330}
331
/// Solarize: every value `p >= threshold` becomes `1 - p`; others are kept.
fn op_solarize(pixels: &[f32], threshold: f32) -> Vec<f32> {
    let mut out = Vec::with_capacity(pixels.len());
    for &p in pixels {
        let flipped = if p >= threshold { 1.0 - p } else { p };
        out.push(flipped);
    }
    out
}
339
/// Saturation adjustment: blends each RGB pixel with its BT.601 luma.
///
/// `alpha = 0` gives grayscale and `alpha = 1` the original colours. Images
/// that do not have exactly three channels are returned unchanged.
fn op_color(pixels: &[f32], channels: usize, height: usize, width: usize, alpha: f32) -> Vec<f32> {
    let mut out = pixels.to_vec();
    if channels != 3 {
        return out;
    }
    let plane = height * width;
    let keep_luma = 1.0 - alpha;
    let blend = |v: f32, luma: f32| (alpha * v + keep_luma * luma).clamp(0.0, 1.0);

    // `out` starts as a copy of `pixels`, so each plane can be read before
    // its own element is overwritten.
    let (red, tail) = out.split_at_mut(plane);
    let (green, tail) = tail.split_at_mut(plane);
    let blue = &mut tail[..plane];
    for i in 0..plane {
        let (r, g, b) = (red[i], green[i], blue[i]);
        let luma = 0.299 * r + 0.587 * g + 0.114 * b;
        red[i] = blend(r, luma);
        green[i] = blend(g, luma);
        blue[i] = blend(b, luma);
    }
    out
}
362
/// Posterize: zero the `8 - k` least-significant bits of each value's 8-bit
/// representation, then map it back to `[0, 1]`.
///
/// `k = 8` only quantises to 8 bits. `k = 0` maps everything to 0. Any
/// `k > 8` behaves like `k = 8`.
fn op_posterize(pixels: &[f32], k: u32) -> Vec<f32> {
    let dropped_bits = 8u32.saturating_sub(k);
    let keep_mask: u8 = match dropped_bits {
        0..=7 => u8::MAX << dropped_bits,
        _ => 0,
    };
    let quantize = |p: f32| {
        let level = (p * 255.0).round().clamp(0.0, 255.0) as u8;
        (f32::from(level & keep_mask) / 255.0).clamp(0.0, 1.0)
    };
    pixels.iter().map(|&p| quantize(p)).collect()
}
379
/// Contrast: pulls every value toward its own channel's mean.
///
/// Output is `(1 - alpha) * mean + alpha * v`, clamped to `[0, 1]`, so
/// `alpha = 1` keeps the image and `alpha = 0` flattens each channel.
fn op_contrast(
    pixels: &[f32],
    channels: usize,
    height: usize,
    width: usize,
    alpha: f32,
) -> Vec<f32> {
    let plane = height * width;
    let mut out = pixels.to_vec();
    for c in 0..channels {
        let span = c * plane..(c + 1) * plane;
        let mean = pixels[span.clone()].iter().sum::<f32>() / plane as f32;
        let pull = (1.0 - alpha) * mean;
        out[span]
            .iter_mut()
            .for_each(|v| *v = (pull + alpha * *v).clamp(0.0, 1.0));
    }
    out
}
399
/// Brightness: multiplies every value by `strength`, clamped to `[0, 1]`
/// (for `strength <= 1` this is a blend toward black).
fn op_brightness(pixels: &[f32], strength: f32) -> Vec<f32> {
    let mut out = pixels.to_vec();
    for v in &mut out {
        *v = (strength * *v).clamp(0.0, 1.0);
    }
    out
}
407
408/// Sharpness: blend between blurred (3×3 box) and original.
409fn op_sharpness(
410    pixels: &[f32],
411    channels: usize,
412    height: usize,
413    width: usize,
414    alpha: f32,
415) -> Vec<f32> {
416    // Build a blurred version using a 3×3 box filter.
417    let plane = height * width;
418    let mut blurred = vec![0.0_f32; channels * plane];
419    for c in 0..channels {
420        for y in 0..height {
421            for x in 0..width {
422                let mut acc = 0.0_f32;
423                let mut count = 0u32;
424                for dy in 0..3usize {
425                    let ny = y + dy;
426                    if ny == 0 || ny > height {
427                        continue;
428                    }
429                    let ny = ny - 1;
430                    for dx in 0..3usize {
431                        let nx = x + dx;
432                        if nx == 0 || nx > width {
433                            continue;
434                        }
435                        let nx = nx - 1;
436                        acc += pixels[chw_idx(c, ny, nx, height, width)];
437                        count += 1;
438                    }
439                }
440                blurred[chw_idx(c, y, x, height, width)] =
441                    if count > 0 { acc / count as f32 } else { 0.0 };
442            }
443        }
444    }
445    // Blend: alpha * original + (1 - alpha) * blurred
446    pixels
447        .iter()
448        .zip(blurred.iter())
449        .map(|(&orig, &blur)| (alpha * orig + (1.0 - alpha) * blur).clamp(0.0, 1.0))
450        .collect()
451}
452
453/// Horizontal shear by `shear` radians (inverse warp).
454fn op_shear_x(
455    pixels: &[f32],
456    channels: usize,
457    height: usize,
458    width: usize,
459    shear: f32,
460    fill_value: f32,
461) -> Vec<f32> {
462    // For output (x, y): src_x = x - shear * y; src_y = y.
463    warp_affine(
464        pixels, channels, height, width, 1.0, -shear, 0.0, // src_x coefficients
465        0.0, 1.0, 0.0, // src_y coefficients
466        fill_value,
467    )
468}
469
470/// Vertical shear by `shear` radians (inverse warp).
471fn op_shear_y(
472    pixels: &[f32],
473    channels: usize,
474    height: usize,
475    width: usize,
476    shear: f32,
477    fill_value: f32,
478) -> Vec<f32> {
479    // For output (x, y): src_x = x; src_y = y - shear * x.
480    warp_affine(
481        pixels, channels, height, width, 1.0, 0.0, 0.0, // src_x coefficients
482        -shear, 1.0, 0.0, // src_y coefficients
483        fill_value,
484    )
485}
486
487/// Horizontal translation by `shift_x` pixels.
488fn op_translate_x(
489    pixels: &[f32],
490    channels: usize,
491    height: usize,
492    width: usize,
493    shift_x: f32,
494    fill_value: f32,
495) -> Vec<f32> {
496    warp_affine(
497        pixels, channels, height, width, 1.0, 0.0, -shift_x, 0.0, 1.0, 0.0, fill_value,
498    )
499}
500
501/// Vertical translation by `shift_y` pixels.
502fn op_translate_y(
503    pixels: &[f32],
504    channels: usize,
505    height: usize,
506    width: usize,
507    shift_y: f32,
508    fill_value: f32,
509) -> Vec<f32> {
510    warp_affine(
511        pixels, channels, height, width, 1.0, 0.0, 0.0, 0.0, 1.0, -shift_y, fill_value,
512    )
513}
514
515// ─── Public API ───────────────────────────────────────────────────────────────
516
517/// Apply a single augmentation operation with the given magnitude and fill value.
518///
519/// # Parameters
520/// - `pixels`     — flat `[C × H × W]` CHW input in `[0, 1]`.
521/// - `channels`   — number of channels `C`.
522/// - `height`     — image height `H`.
523/// - `width`      — image width `W`.
524/// - `op`         — which [`AugOp`] to apply.
525/// - `magnitude`  — shared magnitude in `[0, 30]`.
526/// - `fill_value` — fill for geometric OOB pixels.
527///
528/// # Errors
529/// - [`SslError::EmptyInput`] if any dimension is zero.
530/// - [`SslError::DimensionMismatch`] if `pixels.len() != C·H·W`.
531/// - [`SslError::InvalidParameter`] for invalid magnitude or fill_value.
532pub fn apply_aug_op(
533    pixels: &[f32],
534    channels: usize,
535    height: usize,
536    width: usize,
537    op: &AugOp,
538    magnitude: f32,
539    fill_value: f32,
540) -> SslResult<Vec<f32>> {
541    if channels == 0 || height == 0 || width == 0 {
542        return Err(SslError::EmptyInput);
543    }
544    let expected = channels * height * width;
545    if pixels.len() != expected {
546        return Err(SslError::DimensionMismatch {
547            expected,
548            got: pixels.len(),
549        });
550    }
551    if !(magnitude.is_finite() && (0.0..=30.0).contains(&magnitude)) {
552        return Err(SslError::InvalidParameter {
553            name: "magnitude".into(),
554            reason: format!("must be in [0, 30] and finite, got {magnitude}"),
555        });
556    }
557    if !(fill_value.is_finite() && (0.0..=1.0).contains(&fill_value)) {
558        return Err(SslError::InvalidParameter {
559            name: "fill_value".into(),
560            reason: format!("must be in [0, 1] and finite, got {fill_value}"),
561        });
562    }
563
564    let m = magnitude / 30.0; // normalised to [0, 1]
565
566    let result = match op {
567        AugOp::Identity => pixels.to_vec(),
568
569        AugOp::AutoContrast => op_auto_contrast(pixels, channels, height, width),
570
571        AugOp::Equalize => op_equalize(pixels, channels, height, width),
572
573        AugOp::Rotate => {
574            // ±30° max; use a signed direction encoded by (m >= 0.5).
575            // In RandAugment the sign is sampled externally; here we use the
576            // direct magnitude linearly in [0°, 30°] (caller picks sign).
577            let angle = m * 30.0;
578            op_rotate(pixels, channels, height, width, angle, fill_value)
579        }
580
581        AugOp::Solarize => {
582            // threshold = 1 - m: magnitude=0 → threshold=1 (nothing flipped);
583            // magnitude=30 → threshold=0 (all pixels flipped).
584            let threshold = (1.0 - m).clamp(0.0, 1.0);
585            op_solarize(pixels, threshold)
586        }
587
588        AugOp::Color => {
589            // alpha=1 at magnitude=0 (original); alpha decreases with magnitude.
590            let alpha = (1.0 - m * 0.9).clamp(0.0, 1.0);
591            op_color(pixels, channels, height, width, alpha)
592        }
593
594        AugOp::Posterize => {
595            // k = 8 - floor(m * 4): range [4, 8].
596            let k = 8 - (m * 4.0).floor() as u32;
597            let k = k.max(1);
598            op_posterize(pixels, k)
599        }
600
601        AugOp::Contrast => {
602            // alpha=1 at magnitude=0 (original); blend toward mean.
603            let alpha = (1.0 - m * 0.9).clamp(0.0, 1.0);
604            op_contrast(pixels, channels, height, width, alpha)
605        }
606
607        AugOp::Brightness => {
608            // strength = m * 0.9 + 0.1 so at m=0 strength≈0.1 (dim) and m=1 strength=1.0.
609            let strength = (m * 0.9 + 0.1).clamp(0.0, 1.0);
610            op_brightness(pixels, strength)
611        }
612
613        AugOp::Sharpness => {
614            // alpha=1 → sharp (original); alpha=0 → fully blurred.
615            let alpha = m.clamp(0.0, 1.0);
616            op_sharpness(pixels, channels, height, width, alpha)
617        }
618
619        AugOp::ShearX => {
620            let shear = m * 0.3;
621            op_shear_x(pixels, channels, height, width, shear, fill_value)
622        }
623
624        AugOp::ShearY => {
625            let shear = m * 0.3;
626            op_shear_y(pixels, channels, height, width, shear, fill_value)
627        }
628
629        AugOp::TranslateX => {
630            let shift = m * 0.33 * width as f32;
631            op_translate_x(pixels, channels, height, width, shift, fill_value)
632        }
633
634        AugOp::TranslateY => {
635            let shift = m * 0.33 * height as f32;
636            op_translate_y(pixels, channels, height, width, shift, fill_value)
637        }
638    };
639
640    Ok(result)
641}
642
643/// Apply the RandAugment policy to a CHW image.
644///
645/// Randomly samples `config.n_ops` operations (with replacement) from
646/// `config.ops` and applies each in sequence using `config.magnitude`.
647/// When `n_ops == 0` the image is returned unchanged.
648///
649/// # Errors
650/// - [`SslError::EmptyInput`] if any dimension is zero.
651/// - [`SslError::DimensionMismatch`] if slice length != `C·H·W`.
652/// - [`SslError::InvalidParameter`] if config is invalid.
653pub fn rand_augment(
654    pixels: &[f32],
655    channels: usize,
656    height: usize,
657    width: usize,
658    config: &RandAugmentConfig,
659    rng: &mut LcgRng,
660) -> SslResult<Vec<f32>> {
661    if channels == 0 || height == 0 || width == 0 {
662        return Err(SslError::EmptyInput);
663    }
664    let expected = channels * height * width;
665    if pixels.len() != expected {
666        return Err(SslError::DimensionMismatch {
667            expected,
668            got: pixels.len(),
669        });
670    }
671    config.validate()?;
672
673    if config.n_ops == 0 {
674        return Ok(pixels.to_vec());
675    }
676
677    let n_pool = config.ops.len();
678    let mut current = pixels.to_vec();
679
680    for _ in 0..config.n_ops {
681        let idx = rng.next_usize(n_pool);
682        let op = &config.ops[idx];
683        current = apply_aug_op(
684            &current,
685            channels,
686            height,
687            width,
688            op,
689            config.magnitude,
690            config.fill_value,
691        )?;
692    }
693    Ok(current)
694}
695
696// ─── AutoAugment policy tables ────────────────────────────────────────────────
697
698/// Build the 25 ImageNet AutoAugment sub-policies from Cubuk et al., CVPR 2019.
699///
700/// Each entry is `((op, prob, mag_level), (op, prob, mag_level))` where
701/// `mag_level` is in `[0, 10]` (scaled ×3 to reach the 0–30 magnitude range).
702fn imagenet_sub_policies() -> Vec<SubPolicy> {
703    use AugOp::*;
704    vec![
705        ((Posterize, 0.4, 8), (Rotate, 0.6, 9)),
706        ((Solarize, 0.6, 5), (AutoContrast, 0.6, 5)),
707        ((Equalize, 0.8, 8), (Equalize, 0.6, 3)),
708        ((Posterize, 0.6, 7), (Posterize, 0.6, 6)),
709        ((Equalize, 0.4, 7), (Solarize, 0.2, 4)),
710        ((Equalize, 0.4, 4), (Rotate, 0.8, 8)),
711        ((Solarize, 0.6, 3), (Equalize, 0.6, 7)),
712        ((Posterize, 0.8, 5), (Equalize, 1.0, 2)),
713        ((Rotate, 0.2, 3), (Solarize, 0.6, 8)),
714        ((Equalize, 0.6, 8), (Posterize, 0.4, 6)),
715        ((Rotate, 0.8, 8), (Color, 1.0, 2)),
716        ((Rotate, 0.9, 9), (Equalize, 1.0, 2)),
717        ((Equalize, 0.6, 7), (Equalize, 0.6, 3)),
718        ((Equalize, 0.6, 4), (Rotate, 0.6, 4)),
719        ((Solarize, 0.6, 7), (Rotate, 0.6, 3)),
720        ((ShearX, 0.8, 8), (Solarize, 0.8, 4)),
721        ((Color, 0.8, 3), (Color, 1.0, 7)),
722        ((Color, 0.4, 1), (Rotate, 0.6, 8)),
723        ((Color, 0.8, 8), (Solarize, 0.8, 8)),
724        ((Equalize, 0.4, 8), (Equalize, 0.8, 3)),
725        ((Posterize, 0.4, 6), (Rotate, 0.4, 3)),
726        ((Equalize, 0.6, 7), (Color, 0.4, 4)),
727        ((Color, 0.4, 9), (Equalize, 0.6, 3)),
728        ((Color, 0.8, 8), (Contrast, 0.6, 1)),
729        ((Rotate, 0.8, 8), (Contrast, 1.0, 2)),
730    ]
731}
732
733/// Build the 25 CIFAR-10 AutoAugment sub-policies from Cubuk et al., CVPR 2019.
734fn cifar10_sub_policies() -> Vec<SubPolicy> {
735    use AugOp::*;
736    vec![
737        ((Equalize, 0.1, 8), (ShearY, 0.6, 4)),
738        ((Color, 0.6, 1), (Equalize, 0.6, 2)),
739        ((Sharpness, 0.6, 7), (Brightness, 0.6, 6)),
740        ((AutoContrast, 0.4, 0), (Equalize, 0.6, 0)),
741        ((Equalize, 1.0, 9), (ShearY, 0.6, 3)),
742        ((Color, 0.4, 3), (AutoContrast, 0.6, 1)),
743        ((ShearX, 0.8, 5), (Color, 1.0, 3)),
744        ((ShearX, 0.4, 4), (Posterize, 0.4, 7)),
745        ((Color, 0.4, 3), (Brightness, 0.6, 7)),
746        ((ShearY, 0.6, 4), (Color, 1.0, 9)),
747        ((Equalize, 0.6, 9), (Posterize, 0.4, 6)),
748        ((Solarize, 0.4, 9), (AutoContrast, 0.6, 3)),
749        ((AutoContrast, 0.6, 1), (Posterize, 0.6, 9)),
750        ((Equalize, 0.4, 9), (Solarize, 0.4, 5)),
751        ((Brightness, 0.2, 1), (Equalize, 0.6, 2)),
752        ((Equalize, 0.0, 0), (Equalize, 1.0, 0)),
753        ((AutoContrast, 0.2, 0), (Equalize, 0.6, 0)),
754        ((Equalize, 0.2, 0), (AutoContrast, 0.6, 0)),
755        ((Contrast, 0.2, 0), (Equalize, 0.6, 0)),
756        ((Brightness, 0.6, 5), (Contrast, 0.6, 6)),
757        ((AutoContrast, 0.8, 5), (Rotate, 0.6, 2)),
758        ((Solarize, 0.4, 3), (Brightness, 0.8, 9)),
759        ((Rotate, 0.6, 6), (Color, 1.0, 1)),
760        ((Equalize, 0.4, 5), (AutoContrast, 0.6, 5)),
761        ((Rotate, 0.6, 6), (Posterize, 0.8, 8)),
762    ]
763}
764
765/// Apply the AutoAugment policy to a CHW image.
766///
767/// 1. Uniformly samples one sub-policy from the policy's list.
768/// 2. For each of the two operations in the sub-policy, applies it with the
769///    corresponding probability and magnitude level.
770///
771/// AutoAugment magnitude levels are integers in `[0, 10]`; they are scaled ×3
772/// to map into the `[0, 30]` range expected by [`apply_aug_op`].
773///
774/// # Errors
775/// - [`SslError::EmptyInput`] if any dimension is zero.
776/// - [`SslError::DimensionMismatch`] if `pixels.len() != C·H·W`.
777/// - [`SslError::InvalidParameter`] if policy has no sub-policies.
778pub fn auto_augment(
779    pixels: &[f32],
780    channels: usize,
781    height: usize,
782    width: usize,
783    config: &AutoAugmentConfig,
784    rng: &mut LcgRng,
785) -> SslResult<Vec<f32>> {
786    if channels == 0 || height == 0 || width == 0 {
787        return Err(SslError::EmptyInput);
788    }
789    let expected = channels * height * width;
790    if pixels.len() != expected {
791        return Err(SslError::DimensionMismatch {
792            expected,
793            got: pixels.len(),
794        });
795    }
796    if !(config.fill_value.is_finite() && (0.0..=1.0).contains(&config.fill_value)) {
797        return Err(SslError::InvalidParameter {
798            name: "fill_value".into(),
799            reason: format!("must be in [0, 1] and finite, got {}", config.fill_value),
800        });
801    }
802
803    let sub_policies: Vec<SubPolicy> = match &config.policy {
804        AutoAugPolicy::ImageNet => imagenet_sub_policies(),
805        AutoAugPolicy::Cifar10 => cifar10_sub_policies(),
806        AutoAugPolicy::Custom(v) => v.clone(),
807    };
808
809    if sub_policies.is_empty() {
810        return Err(SslError::InvalidParameter {
811            name: "policy".into(),
812            reason: "policy contains no sub-policies".into(),
813        });
814    }
815
816    // Sample one sub-policy.
817    let sp_idx = rng.next_usize(sub_policies.len());
818    let ((op1, prob1, mag_level1), (op2, prob2, mag_level2)) = &sub_policies[sp_idx];
819
820    // Scale magnitude level [0, 10] → [0, 30].
821    let mag1 = (*mag_level1 as f32 * 3.0).clamp(0.0, 30.0);
822    let mag2 = (*mag_level2 as f32 * 3.0).clamp(0.0, 30.0);
823
824    let mut current = pixels.to_vec();
825
826    if rng.next_f32() < *prob1 {
827        current = apply_aug_op(
828            &current,
829            channels,
830            height,
831            width,
832            op1,
833            mag1,
834            config.fill_value,
835        )?;
836    }
837    if rng.next_f32() < *prob2 {
838        current = apply_aug_op(
839            &current,
840            channels,
841            height,
842            width,
843            op2,
844            mag2,
845            config.fill_value,
846        )?;
847    }
848    Ok(current)
849}
850
851// ─── Unit tests ───────────────────────────────────────────────────────────────
852
#[cfg(test)]
mod tests {
    use super::*;

    // ── Helpers ───────────────────────────────────────────────────────────────

    /// Create a deterministic gradient CHW image.
    ///
    /// Flat pixel `i` holds `i / n` where `n = channels * height * width`, so
    /// every value lies in `[0, 1)`: the largest pixel is `(n - 1) / n`, and no
    /// pixel is ever exactly `1.0`. Returns an empty buffer when `n == 0`.
    fn gradient_image(channels: usize, height: usize, width: usize) -> Vec<f32> {
        let n = channels * height * width;
        (0..n).map(|i| (i as f32) / (n as f32)).collect()
    }

    /// Assert all pixels in `[0, 1]`.
    ///
    /// NaN is never contained in the range, so this also rejects NaN output.
    fn assert_unit_range(pixels: &[f32], label: &str) {
        for (i, &v) in pixels.iter().enumerate() {
            assert!(
                (0.0..=1.0).contains(&v),
                "{label}: pixel[{i}] = {v} out of [0, 1]"
            );
        }
    }

    /// Return `(min, max)` over `values`.
    ///
    /// Returns `(+inf, -inf)` for an empty slice.
    fn min_max(values: &[f32]) -> (f32, f32) {
        values
            .iter()
            .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), &v| {
                (lo.min(v), hi.max(v))
            })
    }

    // ── Test 1: Output shape always matches input ──────────────────────────────

    #[test]
    fn output_shape_equals_input_for_all_ops() {
        let (c, h, w) = (3, 16, 16);
        let img = gradient_image(c, h, w);
        let expected_len = c * h * w;

        for op in all_aug_ops() {
            let out =
                apply_aug_op(&img, c, h, w, &op, 15.0, 0.5).expect("apply_aug_op should succeed");
            assert_eq!(out.len(), expected_len, "shape mismatch for op {:?}", op);
        }
    }

    // ── Test 2: All pixels in [0, 1] after any operation ─────────────────────

    #[test]
    fn all_pixels_in_unit_range_for_all_ops() {
        let (c, h, w) = (3, 16, 16);
        let img = gradient_image(c, h, w);

        for op in all_aug_ops() {
            let out =
                apply_aug_op(&img, c, h, w, &op, 20.0, 0.5).expect("apply_aug_op should succeed");
            assert_unit_range(&out, &format!("{op:?}"));
        }
    }

    // ── Test 3: Identity op returns exact copy ────────────────────────────────

    #[test]
    fn identity_op_returns_exact_copy() {
        let (c, h, w) = (3, 8, 8);
        let img = gradient_image(c, h, w);
        let out = apply_aug_op(&img, c, h, w, &AugOp::Identity, 15.0, 0.5)
            .expect("apply_aug_op should succeed");
        assert_eq!(out, img, "Identity must return exact copy");
    }

    // ── Test 4: AutoContrast stretches to [0, 1] per channel ─────────────────

    #[test]
    fn auto_contrast_stretches_to_unit() {
        // Create image with known range per channel.
        let (c, h, w) = (3, 4, 4);
        let plane = h * w;
        let mut img = vec![0.0_f32; c * plane];
        // Channel 0: range [0.2, 0.8]
        img[..plane].fill(0.5);
        img[0] = 0.2;
        img[plane - 1] = 0.8;
        // Channel 1: range [0.1, 0.9]
        img[plane..2 * plane].fill(0.5);
        img[plane] = 0.1;
        img[2 * plane - 1] = 0.9;
        // Channel 2: constant → should be left alone.
        img[2 * plane..].fill(0.3);

        let out = apply_aug_op(&img, c, h, w, &AugOp::AutoContrast, 0.0, 0.5)
            .expect("apply_aug_op should succeed");
        // Channels 0 and 1 have distinct ranges; each must be stretched
        // independently so its min becomes 0 and its max becomes 1.
        for ch in 0..2 {
            let (lo, hi) = min_max(&out[ch * plane..(ch + 1) * plane]);
            assert!(lo.abs() < 1e-5, "ch{ch} min = {lo}");
            assert!((hi - 1.0).abs() < 1e-5, "ch{ch} max = {hi}");
        }
        // Channel 2: constant → should stay ≈ 0.3.
        for &v in &out[2 * plane..] {
            assert!((v - 0.3).abs() < 1e-5, "constant channel changed: {v}");
        }
    }

    // ── Test 5: Equalize outputs in [0, 1] ────────────────────────────────────

    #[test]
    fn equalize_output_in_unit_range() {
        let (c, h, w) = (1, 32, 32);
        let img = gradient_image(c, h, w);
        let out = apply_aug_op(&img, c, h, w, &AugOp::Equalize, 0.0, 0.5)
            .expect("apply_aug_op should succeed");
        assert_unit_range(&out, "Equalize");
        assert_eq!(out.len(), c * h * w);
    }

    // ── Test 6: Rotate by 0° returns original ────────────────────────────────

    #[test]
    fn rotate_zero_degrees_approx_identity() {
        let (c, h, w) = (1, 8, 8);
        let img = gradient_image(c, h, w);
        // magnitude = 0 → angle = 0°.
        let out = apply_aug_op(&img, c, h, w, &AugOp::Rotate, 0.0, 0.5)
            .expect("apply_aug_op should succeed");
        for (i, (&a, &b)) in img.iter().zip(out.iter()).enumerate() {
            assert!(
                (a - b).abs() < 1e-4,
                "rotate(0°): pixel[{i}]: input={a} output={b}"
            );
        }
    }

    // ── Test 7: Solarize with threshold=1.0 leaves all pixels unchanged ───────

    #[test]
    fn solarize_threshold_one_unchanged() {
        // magnitude=0 → threshold = 1 - 0 = 1.0.
        // `gradient_image` only produces values strictly below 1.0, so no pixel
        // reaches the threshold and nothing should be inverted.
        let (c, h, w) = (3, 8, 8);
        let img = gradient_image(c, h, w);
        let out = apply_aug_op(&img, c, h, w, &AugOp::Solarize, 0.0, 0.5)
            .expect("apply_aug_op should succeed");
        // Only pixels strictly below the threshold are checked: a pixel exactly
        // at 1.0 would sit on the threshold boundary and could be inverted.
        for (i, (&a, &b)) in img.iter().zip(out.iter()).enumerate() {
            if a < 1.0 {
                assert!(
                    (a - b).abs() < 1e-6,
                    "solarize(threshold=1): pixel[{i}] changed: {a}→{b}"
                );
            }
        }
    }

    // ── Test 8: RandAugment with N=0 returns unchanged image ─────────────────

    #[test]
    fn rand_augment_zero_ops_unchanged() {
        let (c, h, w) = (3, 8, 8);
        let img = gradient_image(c, h, w);
        let config = RandAugmentConfig {
            n_ops: 0,
            magnitude: 9.0,
            fill_value: 0.5,
            ops: all_aug_ops(),
        };
        let mut rng = LcgRng::new(42);
        let out =
            rand_augment(&img, c, h, w, &config, &mut rng).expect("rand_augment should succeed");
        assert_eq!(out, img, "n_ops=0 must return exact input copy");
    }

    // ── Test 9: RandAugment(N=3) output has valid shape and range ─────────────

    #[test]
    fn rand_augment_output_valid_shape_and_range() {
        let (c, h, w) = (3, 16, 16);
        let img = gradient_image(c, h, w);
        let config = RandAugmentConfig {
            n_ops: 3,
            magnitude: 15.0,
            fill_value: 0.5,
            ops: all_aug_ops(),
        };
        let mut rng = LcgRng::new(7);
        let out =
            rand_augment(&img, c, h, w, &config, &mut rng).expect("rand_augment should succeed");
        assert_eq!(out.len(), c * h * w);
        assert_unit_range(&out, "RandAugment(N=3)");
    }

    // ── Test 10: AutoAugment ImageNet policy: output is finite and valid ──────

    #[test]
    fn auto_augment_imagenet_output_finite_and_valid() {
        let (c, h, w) = (3, 16, 16);
        let img = gradient_image(c, h, w);
        let config = AutoAugmentConfig {
            policy: AutoAugPolicy::ImageNet,
            fill_value: 0.5,
        };
        let mut rng = LcgRng::new(13);
        let out =
            auto_augment(&img, c, h, w, &config, &mut rng).expect("auto_augment should succeed");
        assert_eq!(out.len(), c * h * w);
        assert_unit_range(&out, "AutoAugment(ImageNet)");
        for &v in &out {
            assert!(v.is_finite(), "non-finite pixel in AutoAugment output");
        }
    }

    // ── Test 11: Different seeds → different augmentations ───────────────────

    #[test]
    fn different_seeds_produce_different_outputs() {
        let (c, h, w) = (3, 16, 16);
        let img = gradient_image(c, h, w);
        let config = RandAugmentConfig::default();

        let mut rng_a = LcgRng::new(1);
        let mut rng_b = LcgRng::new(999);
        let out_a =
            rand_augment(&img, c, h, w, &config, &mut rng_a).expect("rand_augment should succeed");
        let out_b =
            rand_augment(&img, c, h, w, &config, &mut rng_b).expect("rand_augment should succeed");

        // It is overwhelmingly unlikely that two different random seeds produce
        // identical augmented outputs; if they do, the test catches a RNG bug.
        let identical = out_a
            .iter()
            .zip(out_b.iter())
            .all(|(a, b)| (a - b).abs() < 1e-8);
        assert!(!identical, "different seeds must produce different outputs");
    }

    // ── Test 12: Same seed → same output (deterministic) ─────────────────────

    #[test]
    fn same_seed_produces_same_output() {
        let (c, h, w) = (3, 16, 16);
        let img = gradient_image(c, h, w);
        let config = RandAugmentConfig::default();

        let mut rng_a = LcgRng::new(42);
        let mut rng_b = LcgRng::new(42);
        let out_a =
            rand_augment(&img, c, h, w, &config, &mut rng_a).expect("rand_augment should succeed");
        let out_b =
            rand_augment(&img, c, h, w, &config, &mut rng_b).expect("rand_augment should succeed");
        assert_eq!(out_a, out_b, "same seed must produce identical output");
    }

    // ── Test 13: Brightness at magnitude=0 dims image significantly ──────────

    #[test]
    fn brightness_low_magnitude_dims_image() {
        let (c, h, w) = (3, 8, 8);
        let img = vec![0.8_f32; c * h * w];
        // magnitude=0 → strength = 0*0.9 + 0.1 = 0.1.
        let out = apply_aug_op(&img, c, h, w, &AugOp::Brightness, 0.0, 0.5)
            .expect("apply_aug_op should succeed");
        let mean_out: f32 = out.iter().sum::<f32>() / out.len() as f32;
        // 0.8 * 0.1 = 0.08; allow tolerance.
        assert!(
            mean_out < 0.2,
            "Brightness(mag=0) should produce near-black image, got mean={mean_out}"
        );
    }

    // ── Test 14: apply_aug_op valid for all 14 ops without panic ─────────────

    #[test]
    fn all_14_ops_run_without_error() {
        let (c, h, w) = (3, 12, 12);
        let img = gradient_image(c, h, w);
        for mag in [0.0_f32, 9.0, 15.0, 30.0] {
            for op in all_aug_ops() {
                let result = apply_aug_op(&img, c, h, w, &op, mag, 0.5);
                assert!(
                    result.is_ok(),
                    "op {:?} at magnitude={mag} returned error: {:?}",
                    op,
                    result
                );
                assert_unit_range(
                    &result.expect("result should be present"),
                    &format!("{op:?}@{mag}"),
                );
            }
        }
    }

    // ── Test 15: AutoAugment CIFAR-10 policy ─────────────────────────────────

    #[test]
    fn auto_augment_cifar10_output_valid() {
        let (c, h, w) = (3, 32, 32);
        let img = gradient_image(c, h, w);
        let config = AutoAugmentConfig {
            policy: AutoAugPolicy::Cifar10,
            fill_value: 0.5,
        };
        let mut rng = LcgRng::new(77);
        let out =
            auto_augment(&img, c, h, w, &config, &mut rng).expect("auto_augment should succeed");
        assert_eq!(out.len(), c * h * w);
        assert_unit_range(&out, "AutoAugment(Cifar10)");
    }

    // ── Test 16: Custom AutoAugment policy ────────────────────────────────────

    #[test]
    fn auto_augment_custom_policy_identity_always() {
        // A custom policy with a single sub-policy: Identity at prob=1.
        let (c, h, w) = (3, 8, 8);
        let img = gradient_image(c, h, w);
        let config = AutoAugmentConfig {
            policy: AutoAugPolicy::Custom(vec![(
                (AugOp::Identity, 1.0, 0),
                (AugOp::Identity, 1.0, 0),
            )]),
            fill_value: 0.5,
        };
        let mut rng = LcgRng::new(1);
        let out =
            auto_augment(&img, c, h, w, &config, &mut rng).expect("auto_augment should succeed");
        assert_eq!(
            out, img,
            "custom Identity × Identity should return exact copy"
        );
    }

    // ── Test 17: Error on empty input ────────────────────────────────────────

    #[test]
    fn error_on_empty_input() {
        let result = apply_aug_op(&[], 0, 8, 8, &AugOp::Identity, 0.0, 0.5);
        assert!(matches!(result, Err(SslError::EmptyInput)));
    }

    // ── Test 18: Error on dimension mismatch ─────────────────────────────────

    #[test]
    fn error_on_dimension_mismatch() {
        let img = vec![0.5_f32; 10]; // wrong for 3×4×4=48
        let result = apply_aug_op(&img, 3, 4, 4, &AugOp::Identity, 0.0, 0.5);
        assert!(matches!(result, Err(SslError::DimensionMismatch { .. })));
    }

    // ── Test 19: Posterize at full magnitude quantizes heavily ────────────────

    #[test]
    fn posterize_full_magnitude_reduces_unique_values() {
        let (c, h, w) = (1, 16, 16);
        let img = gradient_image(c, h, w);
        // magnitude=30 → k = 8 - floor(1.0 * 4) = 4 bits.
        let out = apply_aug_op(&img, c, h, w, &AugOp::Posterize, 30.0, 0.5)
            .expect("apply_aug_op should succeed");
        // With 4-bit posterization we expect at most 16 distinct values.
        let mut values: Vec<u32> = out.iter().map(|&v| (v * 255.0).round() as u32).collect();
        values.sort_unstable();
        values.dedup();
        assert!(
            values.len() <= 16,
            "expected ≤16 distinct values after 4-bit posterize, got {}",
            values.len()
        );
    }

    // ── Test 20: Sharpness at full magnitude (30 → alpha = 1) returns original ─

    #[test]
    fn sharpness_full_magnitude_is_original() {
        let (c, h, w) = (3, 8, 8);
        let img = gradient_image(c, h, w);
        // alpha = magnitude/30 = 1.0 → pure original, no blur blended in.
        let out = apply_aug_op(&img, c, h, w, &AugOp::Sharpness, 30.0, 0.5)
            .expect("apply_aug_op should succeed");
        for (i, (&a, &b)) in img.iter().zip(out.iter()).enumerate() {
            assert!(
                (a - b).abs() < 1e-5,
                "Sharpness(1.0): pixel[{i}] input={a} output={b}"
            );
        }
    }
}