pub mod geometric;
pub mod mixup;
pub mod normalize;
pub mod photometric;
pub use mixup::{MixOutput, cutmix, mixup};
use crate::{error::VisionResult, handle::LcgRng};
use geometric::{center_crop, random_crop, random_horizontal_flip, resize_bilinear};
use normalize::normalize_chw;
use photometric::{color_jitter, random_grayscale};
/// One augmentation step.
///
/// [`AugOp::apply`] dispatches each variant to the matching helper imported
/// from the `geometric`, `photometric` or `normalize` submodule and reports the
/// height and width of the buffer that helper produced.
#[derive(Debug, Clone)]
pub enum AugOp {
/// Crop performed by `random_crop` (consumes the RNG); `apply` reports the
/// result as `crop_size` x `crop_size`.
RandomCrop { crop_size: usize },
/// Crop performed by `center_crop` (no RNG involved); `apply` reports the
/// result as `crop_size` x `crop_size`.
CenterCrop { crop_size: usize },
/// Forwarded to `random_horizontal_flip` together with the RNG; dimensions
/// are reported unchanged.
/// NOTE(review): `prob` is presumably the flip probability in `[0, 1]` —
/// confirm against `geometric`.
HorizontalFlip { prob: f32 },
/// Resize performed by `resize_bilinear`; `apply` reports the result as
/// `target` x `target`.
Resize { target: usize },
/// Forwarded to `color_jitter` together with the RNG; dimensions are
/// reported unchanged.
/// NOTE(review): the exact meaning/range of the three strengths is defined
/// in `photometric` — confirm there.
ColorJitter {
brightness: f32,
contrast: f32,
saturation: f32,
},
/// Forwarded to `random_grayscale` together with the RNG; dimensions are
/// reported unchanged.
/// NOTE(review): `prob` is presumably the conversion probability — confirm
/// against `photometric`.
RandomGrayscale { prob: f32 },
/// Forwarded to `normalize_chw` (no RNG involved); dimensions are reported
/// unchanged.
/// NOTE(review): `mean`/`std` presumably hold one entry per RGB channel —
/// confirm against `normalize`.
Normalize { mean: [f32; 3], std: [f32; 3] },
}
impl AugOp {
    /// Runs this single augmentation on `img` and reports the resulting size.
    ///
    /// `img`, `channels`, `h` and `w` are passed through unchanged to the
    /// helper that implements the selected variant; `rng` is handed only to
    /// the variants whose helper takes one (random crop, horizontal flip,
    /// colour jitter, random grayscale).
    ///
    /// On success the tuple is `(buffer, height, width)`:
    /// * `RandomCrop` / `CenterCrop` report `crop_size` x `crop_size`,
    /// * `Resize` reports `target` x `target`,
    /// * every other variant reports the incoming `h` and `w`.
    ///
    /// # Errors
    ///
    /// Propagates whatever error the crop, resize or normalize helper returns.
    pub fn apply(
        &self,
        img: &[f32],
        channels: usize,
        h: usize,
        w: usize,
        rng: &mut LcgRng,
    ) -> VisionResult<(Vec<f32>, usize, usize)> {
        // Each arm yields the transformed buffer plus the dimensions to report.
        let (pixels, out_h, out_w) = match self {
            Self::RandomCrop { crop_size } => {
                let side = *crop_size;
                (random_crop(img, channels, h, w, side, rng)?, side, side)
            }
            Self::CenterCrop { crop_size } => {
                let side = *crop_size;
                (center_crop(img, channels, h, w, side)?, side, side)
            }
            Self::HorizontalFlip { prob } => {
                let flipped = random_horizontal_flip(img, channels, h, w, *prob, rng);
                (flipped, h, w)
            }
            Self::Resize { target } => {
                let side = *target;
                (resize_bilinear(img, channels, h, w, side)?, side, side)
            }
            Self::ColorJitter {
                brightness,
                contrast,
                saturation,
            } => {
                let jittered = color_jitter(
                    img,
                    channels,
                    h,
                    w,
                    *brightness,
                    *contrast,
                    *saturation,
                    rng,
                );
                (jittered, h, w)
            }
            Self::RandomGrayscale { prob } => {
                let converted = random_grayscale(img, channels, h, w, *prob, rng);
                (converted, h, w)
            }
            Self::Normalize { mean, std } => {
                (normalize_chw(img, channels, h, w, mean, std)?, h, w)
            }
        };
        Ok((pixels, out_h, out_w))
    }
}
/// Ordered sequence of [`AugOp`] steps.
///
/// [`Pipeline::apply`] runs the steps front to back, feeding each step's
/// output buffer and dimensions into the next one. The derived `Default`
/// yields a pipeline with no steps.
#[derive(Debug, Clone, Default)]
pub struct Pipeline {
/// Steps in execution order; index 0 runs first.
pub ops: Vec<AugOp>,
}
impl Pipeline {
    /// Creates a pipeline that contains no operations.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `op` as the last step and returns the pipeline for chaining.
    #[must_use]
    pub fn push(self, op: AugOp) -> Self {
        let mut ops = self.ops;
        ops.push(op);
        Self { ops }
    }

    /// Runs every step in order, threading the buffer and its dimensions
    /// from one step into the next.
    ///
    /// `channels` is passed unchanged to every step; `h` and `w` describe the
    /// input and are replaced by whatever each step reports. An empty pipeline
    /// returns a copy of `img` together with the original `h` and `w`.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first error produced by any step.
    pub fn apply(
        &self,
        img: &[f32],
        channels: usize,
        h: usize,
        w: usize,
        rng: &mut LcgRng,
    ) -> VisionResult<(Vec<f32>, usize, usize)> {
        match self.ops.split_first() {
            None => Ok((img.to_vec(), h, w)),
            Some((head, tail)) => {
                // The first step reads the caller's slice directly, so the
                // input is never copied just to start the chain.
                let seed = head.apply(img, channels, h, w, rng)?;
                tail.iter().try_fold(seed, |(buf, buf_h, buf_w), op| {
                    op.apply(&buf, channels, buf_h, buf_w, rng)
                })
            }
        }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for [`AugOp::apply`] and [`Pipeline`]: each checks the
    //! reported dimensions against the returned buffer length, plus a few
    //! end-to-end pipeline behaviours.

    use super::*;
    use crate::handle::LcgRng;
    use normalize::{IMAGENET_MEAN, IMAGENET_STD};

    /// Builds a 3-channel `h` x `w` buffer whose values ramp linearly over `[0, 1)`.
    fn ramp_rgb(h: usize, w: usize) -> Vec<f32> {
        let hw = h * w;
        (0..3 * hw).map(|i| i as f32 / (3 * hw) as f32).collect()
    }

    #[test]
    fn aug_op_random_crop_updates_dims() {
        let img = ramp_rgb(32, 32);
        let mut rng = LcgRng::new(1);
        let op = AugOp::RandomCrop { crop_size: 24 };
        let (out, new_h, new_w) = op.apply(&img, 3, 32, 32, &mut rng).expect("ok");
        assert_eq!((new_h, new_w), (24, 24));
        assert_eq!(out.len(), 3 * 24 * 24);
    }

    #[test]
    fn aug_op_center_crop_updates_dims() {
        let img = ramp_rgb(32, 32);
        let mut rng = LcgRng::new(2);
        let op = AugOp::CenterCrop { crop_size: 16 };
        let (out, new_h, new_w) = op.apply(&img, 3, 32, 32, &mut rng).expect("ok");
        assert_eq!((new_h, new_w), (16, 16));
        assert_eq!(out.len(), 3 * 16 * 16);
    }

    #[test]
    fn aug_op_flip_preserves_dims() {
        let img = ramp_rgb(16, 16);
        let mut rng = LcgRng::new(3);
        let op = AugOp::HorizontalFlip { prob: 0.5 };
        let (out, new_h, new_w) = op.apply(&img, 3, 16, 16, &mut rng).expect("ok");
        assert_eq!((new_h, new_w), (16, 16));
        assert_eq!(out.len(), img.len());
    }

    #[test]
    fn aug_op_resize_updates_dims() {
        let img = ramp_rgb(64, 64);
        let mut rng = LcgRng::new(4);
        let op = AugOp::Resize { target: 32 };
        let (out, new_h, new_w) = op.apply(&img, 3, 64, 64, &mut rng).expect("ok");
        assert_eq!((new_h, new_w), (32, 32));
        assert_eq!(out.len(), 3 * 32 * 32);
    }

    #[test]
    fn aug_op_color_jitter_preserves_dims() {
        let img = ramp_rgb(16, 16);
        let mut rng = LcgRng::new(5);
        let op = AugOp::ColorJitter {
            brightness: 0.2,
            contrast: 0.2,
            saturation: 0.2,
        };
        let (out, new_h, new_w) = op.apply(&img, 3, 16, 16, &mut rng).expect("ok");
        assert_eq!((new_h, new_w), (16, 16));
        assert_eq!(out.len(), img.len());
    }

    #[test]
    fn aug_op_grayscale_preserves_dims() {
        let img = ramp_rgb(16, 16);
        let mut rng = LcgRng::new(6);
        let op = AugOp::RandomGrayscale { prob: 0.5 };
        let (out, new_h, new_w) = op.apply(&img, 3, 16, 16, &mut rng).expect("ok");
        assert_eq!((new_h, new_w), (16, 16));
        assert_eq!(out.len(), img.len());
    }

    #[test]
    fn aug_op_normalize_preserves_dims() {
        let img = ramp_rgb(16, 16);
        let mut rng = LcgRng::new(7);
        let op = AugOp::Normalize {
            mean: IMAGENET_MEAN,
            std: IMAGENET_STD,
        };
        let (out, new_h, new_w) = op.apply(&img, 3, 16, 16, &mut rng).expect("ok");
        assert_eq!((new_h, new_w), (16, 16));
        assert_eq!(out.len(), img.len());
    }

    #[test]
    fn pipeline_empty_returns_clone() {
        let img = ramp_rgb(16, 16);
        let pipeline = Pipeline::new();
        let mut rng = LcgRng::new(8);
        let (out, new_h, new_w) = pipeline.apply(&img, 3, 16, 16, &mut rng).expect("ok");
        assert_eq!((new_h, new_w), (16, 16));
        assert_eq!(out, img);
    }

    #[test]
    fn pipeline_single_op() {
        let img = ramp_rgb(32, 32);
        let pipeline = Pipeline::new().push(AugOp::Resize { target: 16 });
        let mut rng = LcgRng::new(9);
        let (out, new_h, new_w) = pipeline.apply(&img, 3, 32, 32, &mut rng).expect("ok");
        assert_eq!((new_h, new_w), (16, 16));
        assert_eq!(out.len(), 3 * 16 * 16);
    }

    #[test]
    fn pipeline_multi_op_dims_chain() {
        let img = ramp_rgb(64, 64);
        let pipeline = Pipeline::new()
            .push(AugOp::Resize { target: 32 })
            .push(AugOp::CenterCrop { crop_size: 24 });
        let mut rng = LcgRng::new(10);
        let (out, new_h, new_w) = pipeline.apply(&img, 3, 64, 64, &mut rng).expect("ok");
        assert_eq!((new_h, new_w), (24, 24));
        assert_eq!(out.len(), 3 * 24 * 24);
    }

    #[test]
    fn pipeline_full_augmentation_chain() {
        // Same ramp the other tests use, just at 256 x 256.
        let img = ramp_rgb(256, 256);
        let pipeline = Pipeline::new()
            .push(AugOp::Resize { target: 256 })
            .push(AugOp::RandomCrop { crop_size: 224 })
            .push(AugOp::HorizontalFlip { prob: 0.5 })
            .push(AugOp::ColorJitter {
                brightness: 0.1,
                contrast: 0.1,
                saturation: 0.1,
            })
            .push(AugOp::Normalize {
                mean: IMAGENET_MEAN,
                std: IMAGENET_STD,
            });
        let mut rng = LcgRng::new(11);
        let (out, new_h, new_w) = pipeline.apply(&img, 3, 256, 256, &mut rng).expect("ok");
        assert_eq!((new_h, new_w), (224, 224));
        assert_eq!(out.len(), 3 * 224 * 224);
        assert!(
            out.iter().all(|v| v.is_finite()),
            "pipeline output must be finite"
        );
    }

    #[test]
    fn pipeline_push_is_builder() {
        let p = Pipeline::new()
            .push(AugOp::HorizontalFlip { prob: 1.0 })
            .push(AugOp::HorizontalFlip { prob: 1.0 });
        assert_eq!(p.ops.len(), 2);
        let img = ramp_rgb(8, 8);
        let mut rng = LcgRng::new(12);
        let (out, _, _) = p.apply(&img, 3, 8, 8, &mut rng).expect("ok");
        // `zip` stops at the shorter side, so pin the length before comparing
        // element-wise; otherwise a truncated output would pass silently.
        assert_eq!(out.len(), img.len());
        for (i, (&a, &b)) in img.iter().zip(out.iter()).enumerate() {
            assert!(
                (a - b).abs() < 1e-6,
                "pixel {i}: double-flip should be identity"
            );
        }
    }

    #[test]
    fn pipeline_clone_is_independent() {
        let p1 = Pipeline::new().push(AugOp::Resize { target: 16 });
        // Extending the clone must leave the original untouched.
        let p2 = p1.clone().push(AugOp::HorizontalFlip { prob: 0.5 });
        assert_eq!(p1.ops.len(), 1);
        assert_eq!(p2.ops.len(), 2);
    }

    #[test]
    fn aug_op_error_propagated_through_pipeline() {
        let img = ramp_rgb(16, 16);
        // A 32-pixel crop cannot fit inside a 16 x 16 image.
        let pipeline = Pipeline::new().push(AugOp::CenterCrop { crop_size: 32 });
        let mut rng = LcgRng::new(13);
        let r = pipeline.apply(&img, 3, 16, 16, &mut rng);
        assert!(
            r.is_err(),
            "oversized crop through pipeline should propagate error"
        );
    }
}