use crate::error::{SslError, SslResult};
use crate::handle::LcgRng;
/// The image operations available to RandAugment / AutoAugment.
///
/// The enum is fieldless, so it is cheaply `Copy` and also `Eq` + `Hash`, which lets it
/// be used as a set or map key. How each variant interprets the shared `[0, 30]`
/// magnitude scale is defined in [`apply_aug_op`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AugOp {
    /// Returns the image unchanged.
    Identity,
    /// Per-channel min/max stretch to `[0, 1]`.
    AutoContrast,
    /// Per-channel histogram equalization.
    Equalize,
    /// Rotation about the image centre.
    Rotate,
    /// Inverts pixels at or above a threshold.
    Solarize,
    /// Blends RGB toward its luma (no-op unless the image has 3 channels).
    Color,
    /// Reduces the number of bits kept per 8-bit pixel value.
    Posterize,
    /// Blends each channel toward its mean.
    Contrast,
    /// Scales pixel values by a factor.
    Brightness,
    /// Blends the image with a 3x3 box-blurred copy.
    Sharpness,
    /// Horizontal shear.
    ShearX,
    /// Vertical shear.
    ShearY,
    /// Horizontal translation.
    TranslateX,
    /// Vertical translation.
    TranslateY,
}
pub fn all_aug_ops() -> Vec<AugOp> {
vec![
AugOp::Identity,
AugOp::AutoContrast,
AugOp::Equalize,
AugOp::Rotate,
AugOp::Solarize,
AugOp::Color,
AugOp::Posterize,
AugOp::Contrast,
AugOp::Brightness,
AugOp::Sharpness,
AugOp::ShearX,
AugOp::ShearY,
AugOp::TranslateX,
AugOp::TranslateY,
]
}
/// Parameters for [`rand_augment`]; checked by [`RandAugmentConfig::validate`].
#[derive(Debug, Clone)]
pub struct RandAugmentConfig {
/// Number of operations applied in sequence per call (`0` returns an unchanged copy).
pub n_ops: usize,
/// Magnitude shared by every applied op, on the `[0, 30]` scale.
pub magnitude: f32,
/// Value in `[0, 1]` written where geometric ops sample outside the image.
pub fill_value: f32,
/// Pool of operations; each step picks one index from it via the RNG. Must be non-empty.
pub ops: Vec<AugOp>,
}
impl Default for RandAugmentConfig {
    /// Two ops per image at magnitude 9, mid-gray fill, drawing from every [`AugOp`].
    fn default() -> Self {
        let ops = all_aug_ops();
        Self {
            ops,
            n_ops: 2,
            magnitude: 9.0,
            fill_value: 0.5,
        }
    }
}
impl RandAugmentConfig {
    /// Checks that the configuration can be used by [`rand_augment`].
    ///
    /// # Errors
    /// Returns [`SslError::InvalidParameter`] for the first failing check, in this order:
    /// - `magnitude` is not finite or lies outside `[0, 30]`;
    /// - `fill_value` is not finite or lies outside `[0, 1]`;
    /// - `ops` is empty.
    ///
    /// `n_ops` is not checked; any value is accepted.
    pub fn validate(&self) -> SslResult<()> {
        let magnitude = self.magnitude;
        if !magnitude.is_finite() || !(0.0..=30.0).contains(&magnitude) {
            return Err(SslError::InvalidParameter {
                name: "magnitude".into(),
                reason: format!("must be in [0, 30] and finite, got {}", magnitude),
            });
        }

        let fill_value = self.fill_value;
        if !fill_value.is_finite() || !(0.0..=1.0).contains(&fill_value) {
            return Err(SslError::InvalidParameter {
                name: "fill_value".into(),
                reason: format!("must be in [0, 1] and finite, got {}", fill_value),
            });
        }

        match self.ops.is_empty() {
            true => Err(SslError::InvalidParameter {
                name: "ops".into(),
                reason: "must contain at least one operation".into(),
            }),
            false => Ok(()),
        }
    }
}
/// One AutoAugment sub-policy: two `(operation, probability, magnitude level)` triples
/// applied in order. [`auto_augment`] maps each level to a magnitude of `level * 3`,
/// clamped to `[0, 30]`, and applies an op when an RNG draw falls below its probability.
pub type SubPolicy = ((AugOp, f32, usize), (AugOp, f32, usize));
/// Which sub-policy table [`auto_augment`] samples from.
#[derive(Debug, Clone)]
pub enum AutoAugPolicy {
/// Built-in table returned by `imagenet_sub_policies`.
ImageNet,
/// Built-in table returned by `cifar10_sub_policies`.
Cifar10,
/// Caller-supplied table; must be non-empty or `auto_augment` returns an error.
Custom(Vec<SubPolicy>),
}
/// Parameters for [`auto_augment`].
#[derive(Debug, Clone)]
pub struct AutoAugmentConfig {
/// Sub-policy table to sample from.
pub policy: AutoAugPolicy,
/// Value in `[0, 1]` written where geometric ops sample outside the image.
pub fill_value: f32,
}
impl Default for AutoAugmentConfig {
    /// The built-in ImageNet table with mid-gray (`0.5`) fill.
    fn default() -> Self {
        let policy = AutoAugPolicy::ImageNet;
        let fill_value = 0.5;
        Self { policy, fill_value }
    }
}
/// Flat offset of element `(c, y, x)` in a channel-major (CHW) buffer whose planes are
/// `height x width` and stored row-major.
#[inline]
fn chw_idx(c: usize, y: usize, x: usize, height: usize, width: usize) -> usize {
    (c * height + y) * width + x
}
/// Bilinearly interpolates a single row-major `height x width` plane at the fractional
/// position (`fy`, `fx`).
///
/// Positions outside `[0, height - 1] x [0, width - 1]` yield `fill_value`. On the last
/// row or column, the "next" neighbour is clamped to the edge.
fn bilinear_sample(
    plane: &[f32],
    height: usize,
    width: usize,
    fy: f32,
    fx: f32,
    fill_value: f32,
) -> f32 {
    // Short-circuit order matters: `height - 1` / `width - 1` are only evaluated once
    // the cheaper negative checks have passed.
    let outside = fy < 0.0 || fx < 0.0 || fy > (height - 1) as f32 || fx > (width - 1) as f32;
    if outside {
        return fill_value;
    }

    let (row0, col0) = (fy.floor() as usize, fx.floor() as usize);
    let row1 = (row0 + 1).min(height - 1);
    let col1 = (col0 + 1).min(width - 1);
    let (ty, tx) = (fy - row0 as f32, fx - col0 as f32);

    let at = |r: usize, c: usize| plane[r * width + c];
    let lerp = |a: f32, b: f32, t: f32| a * (1.0 - t) + b * t;

    let upper = lerp(at(row0, col0), at(row0, col1), tx);
    let lower = lerp(at(row1, col0), at(row1, col1), tx);
    lerp(upper, lower, ty)
}
/// Applies an inverse affine warp to every channel of a CHW image.
///
/// Output pixel `(x, y)` takes the bilinear sample of the source plane at
/// `(a00*x + a01*y + a02, a10*x + a11*y + a12)`. Samples falling outside the image
/// become `fill_value`.
// The six affine coefficients are passed individually, hence the clippy allowance.
#[allow(clippy::too_many_arguments)]
fn warp_affine(
    pixels: &[f32],
    channels: usize,
    height: usize,
    width: usize,
    a00: f32,
    a01: f32,
    a02: f32,
    a10: f32,
    a11: f32,
    a12: f32,
    fill_value: f32,
) -> Vec<f32> {
    let plane_len = height * width;
    let mut warped = Vec::with_capacity(channels * plane_len);
    for c in 0..channels {
        let source = &pixels[c * plane_len..(c + 1) * plane_len];
        // Pushing in (y, x) order reproduces the row-major layout of each plane.
        for y in 0..height {
            for x in 0..width {
                let src_x = a00 * x as f32 + a01 * y as f32 + a02;
                let src_y = a10 * x as f32 + a11 * y as f32 + a12;
                warped.push(bilinear_sample(source, height, width, src_y, src_x, fill_value));
            }
        }
    }
    warped
}
/// Linearly stretches each channel so its minimum maps to 0 and its maximum to 1.
///
/// Channels whose value range is below `1e-7` (effectively constant) are copied
/// through untouched. Stretched values are clamped to `[0, 1]`.
fn op_auto_contrast(pixels: &[f32], channels: usize, height: usize, width: usize) -> Vec<f32> {
    let plane_len = height * width;
    let mut stretched = pixels.to_vec();
    for c in 0..channels {
        let span = c * plane_len..(c + 1) * plane_len;
        let source = &pixels[span.clone()];
        let (lo, hi) = source.iter().fold(
            (f32::INFINITY, f32::NEG_INFINITY),
            |(lo, hi), &v| (lo.min(v), hi.max(v)),
        );
        let range = hi - lo;
        if range.abs() < 1e-7 {
            continue;
        }
        for (dst, &src) in stretched[span].iter_mut().zip(source) {
            *dst = ((src - lo) / range).clamp(0.0, 1.0);
        }
    }
    stretched
}
/// Histogram-equalizes each channel independently using a 256-bin histogram.
///
/// Pixels are expected in `[0, 1]`. Each is quantized to one of 256 bins (values
/// outside the range saturate into the first or last bin). The channel's cumulative
/// histogram is rescaled to `[0, 1]`, and every pixel is replaced by the rescaled CDF
/// value of its bin.
///
/// A channel whose pixels all fall into a single bin has no spread to redistribute, so
/// it is returned unchanged. This matches `op_auto_contrast`'s treatment of constant
/// channels and reference equalize implementations, instead of quantizing the channel
/// to the nearest 1/255 step.
fn op_equalize(pixels: &[f32], channels: usize, height: usize, width: usize) -> Vec<f32> {
    const BINS: usize = 256;
    let max_bin = BINS as f32 - 1.0;
    let to_bin = |p: f32| ((p * max_bin).round() as usize).min(BINS - 1);

    let plane = height * width;
    let mut out = pixels.to_vec();
    for c in 0..channels {
        let span = c * plane..(c + 1) * plane;
        let ch = &pixels[span.clone()];

        let mut hist = [0u32; BINS];
        for &p in ch {
            hist[to_bin(p)] += 1;
        }

        // Cumulative histogram; `running` ends as the channel's pixel count, which
        // avoids a truncating `plane as u32` cast.
        let mut cdf = [0u32; BINS];
        let mut running = 0u32;
        for (dst, &count) in cdf.iter_mut().zip(hist.iter()) {
            running += count;
            *dst = running;
        }

        let cdf_min = cdf.iter().copied().find(|&v| v > 0).unwrap_or(0);
        let denom = running.saturating_sub(cdf_min);
        if denom == 0 {
            // Single occupied bin (or empty channel): leave the original values.
            continue;
        }

        let mut lut = [0.0_f32; BINS];
        for (lut_v, &cum) in lut.iter_mut().zip(cdf.iter()) {
            *lut_v = (cum.saturating_sub(cdf_min) as f32 / denom as f32).clamp(0.0, 1.0);
        }
        for (dst, &src) in out[span].iter_mut().zip(ch) {
            *dst = lut[to_bin(src)];
        }
    }
    out
}
/// Rotates every channel by `angle_deg` degrees about the image centre.
///
/// Each output pixel is inverse-mapped through the rotation and bilinearly sampled from
/// the source. Samples outside the image become `fill_value`.
fn op_rotate(
    pixels: &[f32],
    channels: usize,
    height: usize,
    width: usize,
    angle_deg: f32,
    fill_value: f32,
) -> Vec<f32> {
    let theta = angle_deg * std::f32::consts::PI / 180.0;
    let (sin_t, cos_t) = theta.sin_cos();
    let centre_x = (width as f32 - 1.0) / 2.0;
    let centre_y = (height as f32 - 1.0) / 2.0;

    // Output (x, y) samples the source at
    //   ( cos*(x-cx) + sin*(y-cy) + cx,  -sin*(x-cx) + cos*(y-cy) + cy );
    // the constant terms are folded into the two offsets below.
    let offset_x = -cos_t * centre_x - sin_t * centre_y + centre_x;
    let offset_y = sin_t * centre_x - cos_t * centre_y + centre_y;

    warp_affine(
        pixels, channels, height, width, cos_t, sin_t, offset_x, -sin_t, cos_t, offset_y,
        fill_value,
    )
}
/// Inverts (`1 - p`) every pixel at or above `threshold`; lower pixels pass through.
fn op_solarize(pixels: &[f32], threshold: f32) -> Vec<f32> {
    let mut solarized = Vec::with_capacity(pixels.len());
    for &value in pixels {
        let mapped = if value >= threshold { 1.0 - value } else { value };
        solarized.push(mapped);
    }
    solarized
}
/// Adjusts colour saturation of a 3-channel (RGB-planar) image.
///
/// Each pixel becomes `alpha * c + (1 - alpha) * luma` per channel `c`, where
/// `luma = 0.299 R + 0.587 G + 0.114 B`. So `alpha = 1` keeps the colours and
/// `alpha = 0` yields grayscale. Results are clamped to `[0, 1]`. Images without exactly
/// 3 channels are returned unchanged.
fn op_color(pixels: &[f32], channels: usize, height: usize, width: usize, alpha: f32) -> Vec<f32> {
    let mut blended = pixels.to_vec();
    if channels == 3 {
        let plane_len = height * width;
        let gray_weight = 1.0 - alpha;
        for i in 0..plane_len {
            let rgb = [pixels[i], pixels[plane_len + i], pixels[2 * plane_len + i]];
            let luma = 0.299 * rgb[0] + 0.587 * rgb[1] + 0.114 * rgb[2];
            for (k, &value) in rgb.iter().enumerate() {
                blended[k * plane_len + i] = (alpha * value + gray_weight * luma).clamp(0.0, 1.0);
            }
        }
    }
    blended
}
/// Quantizes each pixel to 8 bits (clamped to `[0, 255]`) and keeps only its `k` most
/// significant bits, then maps back to `[0, 1]`.
///
/// `k >= 8` keeps every bit; `k == 0` zeroes the image.
fn op_posterize(pixels: &[f32], k: u32) -> Vec<f32> {
    let dropped_bits = 8u32.saturating_sub(k);
    // Shifting a u8 by 8 is out of range; `checked_shl` reports that as `None`,
    // which corresponds to keeping no bits at all.
    let keep_mask: u8 = 0xFFu8.checked_shl(dropped_bits).unwrap_or(0);
    pixels
        .iter()
        .map(|&p| {
            let byte = (p * 255.0).round().clamp(0.0, 255.0) as u8;
            let kept = byte & keep_mask;
            (kept as f32 / 255.0).clamp(0.0, 1.0)
        })
        .collect()
}
/// Blends each channel toward its own mean: `(1 - alpha) * mean + alpha * p`.
///
/// `alpha = 1` is the identity and `alpha = 0` flattens each channel to its mean.
/// Results are clamped to `[0, 1]`.
fn op_contrast(
    pixels: &[f32],
    channels: usize,
    height: usize,
    width: usize,
    alpha: f32,
) -> Vec<f32> {
    let plane_len = height * width;
    let mut blended = pixels.to_vec();
    for c in 0..channels {
        let span = c * plane_len..(c + 1) * plane_len;
        let mean = pixels[span.clone()].iter().sum::<f32>() / plane_len as f32;
        let mean_term = (1.0 - alpha) * mean;
        for (dst, &src) in blended[span.clone()].iter_mut().zip(&pixels[span]) {
            *dst = (mean_term + alpha * src).clamp(0.0, 1.0);
        }
    }
    blended
}
/// Multiplies every pixel by `strength`, clamping the result to `[0, 1]`.
fn op_brightness(pixels: &[f32], strength: f32) -> Vec<f32> {
    let mut scaled = Vec::with_capacity(pixels.len());
    scaled.extend(pixels.iter().map(|&p| (strength * p).clamp(0.0, 1.0)));
    scaled
}
/// Blends the image with a 3x3 box-blurred copy: `alpha * p + (1 - alpha) * blur(p)`.
///
/// The blur averages the in-bounds pixels of each 3x3 neighbourhood, so edge and corner
/// pixels use fewer taps. `alpha = 1` returns the original; smaller values blur.
/// Results are clamped to `[0, 1]`.
fn op_sharpness(
    pixels: &[f32],
    channels: usize,
    height: usize,
    width: usize,
    alpha: f32,
) -> Vec<f32> {
    let plane = height * width;
    let mut blurred = vec![0.0_f32; channels * plane];
    for c in 0..channels {
        for y in 0..height {
            // In-bounds rows y-1..=y+1, visited top to bottom.
            let rows = y.saturating_sub(1)..=(y + 1).min(height - 1);
            for x in 0..width {
                let cols = x.saturating_sub(1)..=(x + 1).min(width - 1);
                let mut sum = 0.0_f32;
                let mut taps = 0u32;
                for ny in rows.clone() {
                    for nx in cols.clone() {
                        sum += pixels[chw_idx(c, ny, nx, height, width)];
                        taps += 1;
                    }
                }
                blurred[chw_idx(c, y, x, height, width)] =
                    if taps > 0 { sum / taps as f32 } else { 0.0 };
            }
        }
    }

    let blur_weight = 1.0 - alpha;
    pixels
        .iter()
        .zip(&blurred)
        .map(|(&orig, &blur)| (alpha * orig + blur_weight * blur).clamp(0.0, 1.0))
        .collect()
}
/// Horizontal shear anchored at the top-left corner: output `(x, y)` samples the source
/// at `(x - shear * y, y)`. Out-of-image samples become `fill_value`.
fn op_shear_x(
    pixels: &[f32],
    channels: usize,
    height: usize,
    width: usize,
    shear: f32,
    fill_value: f32,
) -> Vec<f32> {
    let (sx_from_x, sx_from_y, sx_offset) = (1.0, -shear, 0.0);
    let (sy_from_x, sy_from_y, sy_offset) = (0.0, 1.0, 0.0);
    warp_affine(
        pixels, channels, height, width, sx_from_x, sx_from_y, sx_offset, sy_from_x, sy_from_y,
        sy_offset, fill_value,
    )
}
/// Vertical shear anchored at the top-left corner: output `(x, y)` samples the source
/// at `(x, y - shear * x)`. Out-of-image samples become `fill_value`.
fn op_shear_y(
    pixels: &[f32],
    channels: usize,
    height: usize,
    width: usize,
    shear: f32,
    fill_value: f32,
) -> Vec<f32> {
    let (sx_from_x, sx_from_y, sx_offset) = (1.0, 0.0, 0.0);
    let (sy_from_x, sy_from_y, sy_offset) = (-shear, 1.0, 0.0);
    warp_affine(
        pixels, channels, height, width, sx_from_x, sx_from_y, sx_offset, sy_from_x, sy_from_y,
        sy_offset, fill_value,
    )
}
/// Shifts the image content by `shift_x` pixels along x: output `(x, y)` samples the
/// source at `(x - shift_x, y)`. Uncovered pixels become `fill_value`.
fn op_translate_x(
    pixels: &[f32],
    channels: usize,
    height: usize,
    width: usize,
    shift_x: f32,
    fill_value: f32,
) -> Vec<f32> {
    let (sx_from_x, sx_from_y, sx_offset) = (1.0, 0.0, -shift_x);
    let (sy_from_x, sy_from_y, sy_offset) = (0.0, 1.0, 0.0);
    warp_affine(
        pixels, channels, height, width, sx_from_x, sx_from_y, sx_offset, sy_from_x, sy_from_y,
        sy_offset, fill_value,
    )
}
/// Shifts the image content by `shift_y` pixels along y: output `(x, y)` samples the
/// source at `(x, y - shift_y)`. Uncovered pixels become `fill_value`.
fn op_translate_y(
    pixels: &[f32],
    channels: usize,
    height: usize,
    width: usize,
    shift_y: f32,
    fill_value: f32,
) -> Vec<f32> {
    let (sx_from_x, sx_from_y, sx_offset) = (1.0, 0.0, 0.0);
    let (sy_from_x, sy_from_y, sy_offset) = (0.0, 1.0, -shift_y);
    warp_affine(
        pixels, channels, height, width, sx_from_x, sx_from_y, sx_offset, sy_from_x, sy_from_y,
        sy_offset, fill_value,
    )
}
/// Applies a single augmentation `op` to a CHW image whose values are expected in
/// `[0, 1]`.
///
/// `magnitude` is on the `[0, 30]` scale and is normalized to `level = magnitude / 30`
/// before being mapped onto each op's own parameter:
/// - rotation: `level * 30` degrees;
/// - shear: `level * 0.3`;
/// - translation: `level * 0.33` of the image width or height;
/// - posterize: 8 bits at level 0 down to 4 at level 1;
/// - brightness: factor `0.1 + 0.9 * level`.
///
/// `Identity`, `AutoContrast` and `Equalize` ignore the magnitude. `fill_value` is
/// written where geometric ops sample outside the image.
///
/// # Errors
/// - [`SslError::EmptyInput`] if `channels`, `height` or `width` is zero.
/// - [`SslError::DimensionMismatch`] if `pixels.len() != channels * height * width`.
/// - [`SslError::InvalidParameter`] if `magnitude` is not finite in `[0, 30]`, or
///   `fill_value` is not finite in `[0, 1]`.
pub fn apply_aug_op(
    pixels: &[f32],
    channels: usize,
    height: usize,
    width: usize,
    op: &AugOp,
    magnitude: f32,
    fill_value: f32,
) -> SslResult<Vec<f32>> {
    if [channels, height, width].contains(&0) {
        return Err(SslError::EmptyInput);
    }
    let expected = channels * height * width;
    let got = pixels.len();
    if got != expected {
        return Err(SslError::DimensionMismatch { expected, got });
    }
    if !magnitude.is_finite() || !(0.0..=30.0).contains(&magnitude) {
        return Err(SslError::InvalidParameter {
            name: "magnitude".into(),
            reason: format!("must be in [0, 30] and finite, got {}", magnitude),
        });
    }
    if !fill_value.is_finite() || !(0.0..=1.0).contains(&fill_value) {
        return Err(SslError::InvalidParameter {
            name: "fill_value".into(),
            reason: format!("must be in [0, 1] and finite, got {}", fill_value),
        });
    }

    let level = magnitude / 30.0;
    let (c, h, w) = (channels, height, width);
    let augmented = match op {
        AugOp::Identity => pixels.to_vec(),
        AugOp::AutoContrast => op_auto_contrast(pixels, c, h, w),
        AugOp::Equalize => op_equalize(pixels, c, h, w),
        AugOp::Rotate => op_rotate(pixels, c, h, w, level * 30.0, fill_value),
        AugOp::Solarize => op_solarize(pixels, (1.0 - level).clamp(0.0, 1.0)),
        AugOp::Color => op_color(pixels, c, h, w, (1.0 - level * 0.9).clamp(0.0, 1.0)),
        AugOp::Posterize => {
            // floor(level * 4) is 0..=4, so this keeps between 8 and 4 bits.
            let bits = (8 - (level * 4.0).floor() as u32).max(1);
            op_posterize(pixels, bits)
        }
        AugOp::Contrast => op_contrast(pixels, c, h, w, (1.0 - level * 0.9).clamp(0.0, 1.0)),
        AugOp::Brightness => op_brightness(pixels, (level * 0.9 + 0.1).clamp(0.0, 1.0)),
        AugOp::Sharpness => op_sharpness(pixels, c, h, w, level.clamp(0.0, 1.0)),
        AugOp::ShearX => op_shear_x(pixels, c, h, w, level * 0.3, fill_value),
        AugOp::ShearY => op_shear_y(pixels, c, h, w, level * 0.3, fill_value),
        AugOp::TranslateX => {
            op_translate_x(pixels, c, h, w, level * 0.33 * w as f32, fill_value)
        }
        AugOp::TranslateY => {
            op_translate_y(pixels, c, h, w, level * 0.33 * h as f32, fill_value)
        }
    };
    Ok(augmented)
}
/// Applies RandAugment to a CHW image.
///
/// Performs `config.n_ops` steps. Each step picks an operation from `config.ops` with
/// `rng.next_usize(config.ops.len())`, independently per step, so repeats are possible.
/// It then applies the op via [`apply_aug_op`] at the fixed `config.magnitude`. With
/// `n_ops == 0` an unchanged copy of `pixels` is returned.
///
/// # Errors
/// - [`SslError::EmptyInput`] if any dimension is zero.
/// - [`SslError::DimensionMismatch`] if `pixels.len() != channels * height * width`.
/// - Any error from [`RandAugmentConfig::validate`] or [`apply_aug_op`].
pub fn rand_augment(
    pixels: &[f32],
    channels: usize,
    height: usize,
    width: usize,
    config: &RandAugmentConfig,
    rng: &mut LcgRng,
) -> SslResult<Vec<f32>> {
    if channels == 0 || height == 0 || width == 0 {
        return Err(SslError::EmptyInput);
    }
    let expected = channels * height * width;
    if pixels.len() != expected {
        return Err(SslError::DimensionMismatch {
            expected,
            got: pixels.len(),
        });
    }
    // Also guarantees `config.ops` is non-empty, so the index draw below is valid.
    config.validate()?;

    let n_pool = config.ops.len();
    let mut current = pixels.to_vec();
    for _ in 0..config.n_ops {
        let op = &config.ops[rng.next_usize(n_pool)];
        current = apply_aug_op(
            &current,
            channels,
            height,
            width,
            op,
            config.magnitude,
            config.fill_value,
        )?;
    }
    Ok(current)
}
/// Built-in sub-policy table used for `AutoAugPolicy::ImageNet`.
///
/// Each entry is `((op, probability, level), (op, probability, level))`; levels are
/// turned into magnitudes by `auto_augment` as `level * 3` clamped to `[0, 30]`.
/// Levels on `Equalize` / `AutoContrast` have no effect because `apply_aug_op` ignores
/// the magnitude for those ops.
/// NOTE(review): `AugOp` has no `Invert` variant, so this table cannot reproduce a
/// published AutoAugment ImageNet policy entry-for-entry — confirm the intended source.
fn imagenet_sub_policies() -> Vec<SubPolicy> {
use AugOp::*;
vec![
((Posterize, 0.4, 8), (Rotate, 0.6, 9)),
((Solarize, 0.6, 5), (AutoContrast, 0.6, 5)),
((Equalize, 0.8, 8), (Equalize, 0.6, 3)),
((Posterize, 0.6, 7), (Posterize, 0.6, 6)),
((Equalize, 0.4, 7), (Solarize, 0.2, 4)),
((Equalize, 0.4, 4), (Rotate, 0.8, 8)),
((Solarize, 0.6, 3), (Equalize, 0.6, 7)),
((Posterize, 0.8, 5), (Equalize, 1.0, 2)),
((Rotate, 0.2, 3), (Solarize, 0.6, 8)),
((Equalize, 0.6, 8), (Posterize, 0.4, 6)),
((Rotate, 0.8, 8), (Color, 1.0, 2)),
((Rotate, 0.9, 9), (Equalize, 1.0, 2)),
((Equalize, 0.6, 7), (Equalize, 0.6, 3)),
((Equalize, 0.6, 4), (Rotate, 0.6, 4)),
((Solarize, 0.6, 7), (Rotate, 0.6, 3)),
((ShearX, 0.8, 8), (Solarize, 0.8, 4)),
((Color, 0.8, 3), (Color, 1.0, 7)),
((Color, 0.4, 1), (Rotate, 0.6, 8)),
((Color, 0.8, 8), (Solarize, 0.8, 8)),
((Equalize, 0.4, 8), (Equalize, 0.8, 3)),
((Posterize, 0.4, 6), (Rotate, 0.4, 3)),
((Equalize, 0.6, 7), (Color, 0.4, 4)),
((Color, 0.4, 9), (Equalize, 0.6, 3)),
((Color, 0.8, 8), (Contrast, 0.6, 1)),
((Rotate, 0.8, 8), (Contrast, 1.0, 2)),
]
}
/// Built-in sub-policy table used for `AutoAugPolicy::Cifar10`.
///
/// Same entry format as the ImageNet table: `((op, probability, level), (op, probability, level))`.
/// `auto_augment` maps levels to magnitudes as `level * 3` clamped to `[0, 30]`.
/// Entries with probability `0.0` (e.g. the first op of `((Equalize, 0.0, 0), ...)`)
/// can never fire, because `auto_augment` applies an op only when `rng.next_f32() < probability`.
/// NOTE(review): confirm the table against the intended published CIFAR-10 policy.
fn cifar10_sub_policies() -> Vec<SubPolicy> {
use AugOp::*;
vec![
((Equalize, 0.1, 8), (ShearY, 0.6, 4)),
((Color, 0.6, 1), (Equalize, 0.6, 2)),
((Sharpness, 0.6, 7), (Brightness, 0.6, 6)),
((AutoContrast, 0.4, 0), (Equalize, 0.6, 0)),
((Equalize, 1.0, 9), (ShearY, 0.6, 3)),
((Color, 0.4, 3), (AutoContrast, 0.6, 1)),
((ShearX, 0.8, 5), (Color, 1.0, 3)),
((ShearX, 0.4, 4), (Posterize, 0.4, 7)),
((Color, 0.4, 3), (Brightness, 0.6, 7)),
((ShearY, 0.6, 4), (Color, 1.0, 9)),
((Equalize, 0.6, 9), (Posterize, 0.4, 6)),
((Solarize, 0.4, 9), (AutoContrast, 0.6, 3)),
((AutoContrast, 0.6, 1), (Posterize, 0.6, 9)),
((Equalize, 0.4, 9), (Solarize, 0.4, 5)),
((Brightness, 0.2, 1), (Equalize, 0.6, 2)),
((Equalize, 0.0, 0), (Equalize, 1.0, 0)),
((AutoContrast, 0.2, 0), (Equalize, 0.6, 0)),
((Equalize, 0.2, 0), (AutoContrast, 0.6, 0)),
((Contrast, 0.2, 0), (Equalize, 0.6, 0)),
((Brightness, 0.6, 5), (Contrast, 0.6, 6)),
((AutoContrast, 0.8, 5), (Rotate, 0.6, 2)),
((Solarize, 0.4, 3), (Brightness, 0.8, 9)),
((Rotate, 0.6, 6), (Color, 1.0, 1)),
((Equalize, 0.4, 5), (AutoContrast, 0.6, 5)),
((Rotate, 0.6, 6), (Posterize, 0.8, 8)),
]
}
/// Applies one AutoAugment sub-policy to a CHW image.
///
/// A sub-policy is chosen with `rng.next_usize(len)` from the configured table. Its two
/// operations are then considered in order: each is applied via [`apply_aug_op`] when a
/// fresh `rng.next_f32()` draw is below its probability. A magnitude level `l` becomes a
/// magnitude of `l * 3`, clamped to `[0, 30]`.
///
/// # Errors
/// - [`SslError::EmptyInput`] if any dimension is zero.
/// - [`SslError::DimensionMismatch`] if `pixels.len() != channels * height * width`.
/// - [`SslError::InvalidParameter`] if `config.fill_value` is not finite in `[0, 1]`,
///   or the selected policy has no sub-policies.
/// - Any error propagated from [`apply_aug_op`].
pub fn auto_augment(
    pixels: &[f32],
    channels: usize,
    height: usize,
    width: usize,
    config: &AutoAugmentConfig,
    rng: &mut LcgRng,
) -> SslResult<Vec<f32>> {
    if channels == 0 || height == 0 || width == 0 {
        return Err(SslError::EmptyInput);
    }
    let expected = channels * height * width;
    if pixels.len() != expected {
        return Err(SslError::DimensionMismatch {
            expected,
            got: pixels.len(),
        });
    }
    if !(config.fill_value.is_finite() && (0.0..=1.0).contains(&config.fill_value)) {
        return Err(SslError::InvalidParameter {
            name: "fill_value".into(),
            reason: format!("must be in [0, 1] and finite, got {}", config.fill_value),
        });
    }

    // Borrow a custom table in place instead of cloning it. Built-in tables are
    // materialized into `builtin`, which outlives the borrow.
    let builtin: Vec<SubPolicy>;
    let sub_policies: &[SubPolicy] = match &config.policy {
        AutoAugPolicy::ImageNet => {
            builtin = imagenet_sub_policies();
            builtin.as_slice()
        }
        AutoAugPolicy::Cifar10 => {
            builtin = cifar10_sub_policies();
            builtin.as_slice()
        }
        AutoAugPolicy::Custom(v) => v.as_slice(),
    };
    if sub_policies.is_empty() {
        return Err(SslError::InvalidParameter {
            name: "policy".into(),
            reason: "policy contains no sub-policies".into(),
        });
    }

    let sp_idx = rng.next_usize(sub_policies.len());
    let (first, second) = &sub_policies[sp_idx];

    // RNG draw order is unchanged: one sub-policy index, then one probability draw per
    // op, with the first op applied before the second draw.
    let mut current = pixels.to_vec();
    for (op, prob, level) in [first, second] {
        let magnitude = (*level as f32 * 3.0).clamp(0.0, 30.0);
        if rng.next_f32() < *prob {
            current = apply_aug_op(
                &current,
                channels,
                height,
                width,
                op,
                magnitude,
                config.fill_value,
            )?;
        }
    }
    Ok(current)
}
// Unit tests for the individual ops (shape, value range, special cases) and for the
// RandAugment / AutoAugment drivers (determinism, custom policies, error paths).
#[cfg(test)]
mod tests {
use super::*;
// Deterministic ramp image: element i holds i / n, so every value lies in [0, 1).
fn gradient_image(channels: usize, height: usize, width: usize) -> Vec<f32> {
let n = channels * height * width;
(0..n)
.map(|i| {
let v = (i as f32) / (n as f32);
v.clamp(0.0, 1.0)
})
.collect()
}
// Panics, reporting the label and offending index, if any pixel lies outside [0, 1].
fn assert_unit_range(pixels: &[f32], label: &str) {
for (i, &v) in pixels.iter().enumerate() {
assert!(
(0.0..=1.0).contains(&v),
"{label}: pixel[{i}] = {v} out of [0, 1]"
);
}
}
// Every op must preserve the CHW element count.
#[test]
fn output_shape_equals_input_for_all_ops() {
let (c, h, w) = (3, 16, 16);
let img = gradient_image(c, h, w);
let expected_len = c * h * w;
for op in all_aug_ops() {
let out =
apply_aug_op(&img, c, h, w, &op, 15.0, 0.5).expect("apply_aug_op should succeed");
assert_eq!(out.len(), expected_len, "shape mismatch for op {:?}", op);
}
}
// Every op must keep pixels inside [0, 1] for a [0, 1) input.
#[test]
fn all_pixels_in_unit_range_for_all_ops() {
let (c, h, w) = (3, 16, 16);
let img = gradient_image(c, h, w);
for op in all_aug_ops() {
let out =
apply_aug_op(&img, c, h, w, &op, 20.0, 0.5).expect("apply_aug_op should succeed");
assert_unit_range(&out, &format!("{op:?}"));
}
}
#[test]
fn identity_op_returns_exact_copy() {
let (c, h, w) = (3, 8, 8);
let img = gradient_image(c, h, w);
let out = apply_aug_op(&img, c, h, w, &AugOp::Identity, 15.0, 0.5)
.expect("apply_aug_op should succeed");
assert_eq!(out, img, "Identity must return exact copy");
}
// Channel 0 spans [0.2, 0.8] and must be stretched to [0, 1]; channel 2 is constant
// (0.3) and must be left untouched.
#[test]
fn auto_contrast_stretches_to_unit() {
let (c, h, w) = (3, 4, 4);
let plane = h * w;
let mut img = vec![0.0_f32; c * plane];
for v in img[0..plane].iter_mut() {
*v = 0.5;
}
img[0] = 0.2;
img[plane - 1] = 0.8;
for v in img[plane..2 * plane].iter_mut() {
*v = 0.5;
}
img[plane] = 0.1;
img[2 * plane - 1] = 0.9;
for v in img[2 * plane..].iter_mut() {
*v = 0.3;
}
let out = apply_aug_op(&img, c, h, w, &AugOp::AutoContrast, 0.0, 0.5)
.expect("apply_aug_op should succeed");
let ch0_min = out[..plane].iter().cloned().fold(f32::INFINITY, f32::min);
let ch0_max = out[..plane]
.iter()
.cloned()
.fold(f32::NEG_INFINITY, f32::max);
assert!(ch0_min.abs() < 1e-5, "ch0 min = {ch0_min}");
assert!((ch0_max - 1.0).abs() < 1e-5, "ch0 max = {ch0_max}");
for &v in &out[2 * plane..] {
assert!((v - 0.3).abs() < 1e-5, "constant channel changed: {v}");
}
}
#[test]
fn equalize_output_in_unit_range() {
let (c, h, w) = (1, 32, 32);
let img = gradient_image(c, h, w);
let out = apply_aug_op(&img, c, h, w, &AugOp::Equalize, 0.0, 0.5)
.expect("apply_aug_op should succeed");
assert_unit_range(&out, "Equalize");
assert_eq!(out.len(), c * h * w);
}
// Magnitude 0 gives a 0-degree rotation, which should reproduce the input.
#[test]
fn rotate_zero_degrees_approx_identity() {
let (c, h, w) = (1, 8, 8);
let img = gradient_image(c, h, w);
let out = apply_aug_op(&img, c, h, w, &AugOp::Rotate, 0.0, 0.5)
.expect("apply_aug_op should succeed");
for (i, (&a, &b)) in img.iter().zip(out.iter()).enumerate() {
assert!(
(a - b).abs() < 1e-4,
"rotate(0°): pixel[{i}]: input={a} output={b}"
);
}
}
// Magnitude 0 gives threshold 1.0, so every pixel below 1.0 must pass through.
#[test]
fn solarize_threshold_one_unchanged() {
let (c, h, w) = (3, 8, 8);
let img = gradient_image(c, h, w);
let out = apply_aug_op(&img, c, h, w, &AugOp::Solarize, 0.0, 0.5)
.expect("apply_aug_op should succeed");
for (i, (&a, &b)) in img.iter().zip(out.iter()).enumerate() {
if a < 1.0 {
assert!(
(a - b).abs() < 1e-6,
"solarize(threshold=1): pixel[{i}] changed: {a}→{b}"
);
}
}
}
#[test]
fn rand_augment_zero_ops_unchanged() {
let (c, h, w) = (3, 8, 8);
let img = gradient_image(c, h, w);
let config = RandAugmentConfig {
n_ops: 0,
magnitude: 9.0,
fill_value: 0.5,
ops: all_aug_ops(),
};
let mut rng = LcgRng::new(42);
let out =
rand_augment(&img, c, h, w, &config, &mut rng).expect("rand_augment should succeed");
assert_eq!(out, img, "n_ops=0 must return exact input copy");
}
#[test]
fn rand_augment_output_valid_shape_and_range() {
let (c, h, w) = (3, 16, 16);
let img = gradient_image(c, h, w);
let config = RandAugmentConfig {
n_ops: 3,
magnitude: 15.0,
fill_value: 0.5,
ops: all_aug_ops(),
};
let mut rng = LcgRng::new(7);
let out =
rand_augment(&img, c, h, w, &config, &mut rng).expect("rand_augment should succeed");
assert_eq!(out.len(), c * h * w);
assert_unit_range(&out, "RandAugment(N=3)");
}
#[test]
fn auto_augment_imagenet_output_finite_and_valid() {
let (c, h, w) = (3, 16, 16);
let img = gradient_image(c, h, w);
let config = AutoAugmentConfig {
policy: AutoAugPolicy::ImageNet,
fill_value: 0.5,
};
let mut rng = LcgRng::new(13);
let out =
auto_augment(&img, c, h, w, &config, &mut rng).expect("auto_augment should succeed");
assert_eq!(out.len(), c * h * w);
assert_unit_range(&out, "AutoAugment(ImageNet)");
for &v in &out {
assert!(v.is_finite(), "non-finite pixel in AutoAugment output");
}
}
// NOTE(review): relies on seeds 1 and 999 producing different op draws from LcgRng;
// revisit if the RNG implementation changes.
#[test]
fn different_seeds_produce_different_outputs() {
let (c, h, w) = (3, 16, 16);
let img = gradient_image(c, h, w);
let config = RandAugmentConfig::default();
let mut rng_a = LcgRng::new(1);
let mut rng_b = LcgRng::new(999);
let out_a =
rand_augment(&img, c, h, w, &config, &mut rng_a).expect("rand_augment should succeed");
let out_b =
rand_augment(&img, c, h, w, &config, &mut rng_b).expect("rand_augment should succeed");
let identical = out_a
.iter()
.zip(out_b.iter())
.all(|(a, b)| (a - b).abs() < 1e-8);
assert!(!identical, "different seeds must produce different outputs");
}
#[test]
fn same_seed_produces_same_output() {
let (c, h, w) = (3, 16, 16);
let img = gradient_image(c, h, w);
let config = RandAugmentConfig::default();
let mut rng_a = LcgRng::new(42);
let mut rng_b = LcgRng::new(42);
let out_a =
rand_augment(&img, c, h, w, &config, &mut rng_a).expect("rand_augment should succeed");
let out_b =
rand_augment(&img, c, h, w, &config, &mut rng_b).expect("rand_augment should succeed");
assert_eq!(out_a, out_b, "same seed must produce identical output");
}
// Magnitude 0 maps to brightness factor 0.1, so a 0.8 image darkens to about 0.08.
#[test]
fn brightness_low_magnitude_dims_image() {
let (c, h, w) = (3, 8, 8);
let img = vec![0.8_f32; c * h * w];
let out = apply_aug_op(&img, c, h, w, &AugOp::Brightness, 0.0, 0.5)
.expect("apply_aug_op should succeed");
let mean_out: f32 = out.iter().sum::<f32>() / out.len() as f32;
assert!(
mean_out < 0.2,
"Brightness(mag=0) should produce near-black image, got mean={mean_out}"
);
}
// Sweeps every op across the magnitude range, including both endpoints.
#[test]
fn all_14_ops_run_without_error() {
let (c, h, w) = (3, 12, 12);
let img = gradient_image(c, h, w);
for mag in [0.0_f32, 9.0, 15.0, 30.0] {
for op in all_aug_ops() {
let result = apply_aug_op(&img, c, h, w, &op, mag, 0.5);
assert!(
result.is_ok(),
"op {:?} at magnitude={mag} returned error: {:?}",
op,
result
);
assert_unit_range(
&result.expect("result should be present"),
&format!("{op:?}@{mag}"),
);
}
}
}
#[test]
fn auto_augment_cifar10_output_valid() {
let (c, h, w) = (3, 32, 32);
let img = gradient_image(c, h, w);
let config = AutoAugmentConfig {
policy: AutoAugPolicy::Cifar10,
fill_value: 0.5,
};
let mut rng = LcgRng::new(77);
let out =
auto_augment(&img, c, h, w, &config, &mut rng).expect("auto_augment should succeed");
assert_eq!(out.len(), c * h * w);
assert_unit_range(&out, "AutoAugment(Cifar10)");
}
// A single Identity x Identity sub-policy with probability 1 must return an exact copy.
#[test]
fn auto_augment_custom_policy_identity_always() {
let (c, h, w) = (3, 8, 8);
let img = gradient_image(c, h, w);
let config = AutoAugmentConfig {
policy: AutoAugPolicy::Custom(vec![(
(AugOp::Identity, 1.0, 0),
(AugOp::Identity, 1.0, 0),
)]),
fill_value: 0.5,
};
let mut rng = LcgRng::new(1);
let out =
auto_augment(&img, c, h, w, &config, &mut rng).expect("auto_augment should succeed");
assert_eq!(
out, img,
"custom Identity × Identity should return exact copy"
);
}
#[test]
fn error_on_empty_input() {
let result = apply_aug_op(&[], 0, 8, 8, &AugOp::Identity, 0.0, 0.5);
assert!(matches!(result, Err(SslError::EmptyInput)));
}
// 10 elements supplied for a 3x4x4 (48-element) image.
#[test]
fn error_on_dimension_mismatch() {
let img = vec![0.5_f32; 10]; let result = apply_aug_op(&img, 3, 4, 4, &AugOp::Identity, 0.0, 0.5);
assert!(matches!(result, Err(SslError::DimensionMismatch { .. })));
}
// Magnitude 30 keeps 4 bits per value, leaving at most 16 distinct levels.
#[test]
fn posterize_full_magnitude_reduces_unique_values() {
let (c, h, w) = (1, 16, 16);
let img = gradient_image(c, h, w);
let out = apply_aug_op(&img, c, h, w, &AugOp::Posterize, 30.0, 0.5)
.expect("apply_aug_op should succeed");
let mut values: Vec<u32> = out.iter().map(|&v| (v * 255.0).round() as u32).collect();
values.sort_unstable();
values.dedup();
assert!(
values.len() <= 16,
"expected ≤16 distinct values after 4-bit posterize, got {}",
values.len()
);
}
// Magnitude 30 gives blend weight alpha = 1.0, i.e. no blur mixed in.
#[test]
fn sharpness_full_magnitude_is_original() {
let (c, h, w) = (3, 8, 8);
let img = gradient_image(c, h, w);
let out = apply_aug_op(&img, c, h, w, &AugOp::Sharpness, 30.0, 0.5)
.expect("apply_aug_op should succeed");
for (i, (&a, &b)) in img.iter().zip(out.iter()).enumerate() {
assert!(
(a - b).abs() < 1e-5,
"Sharpness(1.0): pixel[{i}] input={a} output={b}"
);
}
}
}