1pub mod geometric;
10pub mod mixup;
11pub mod normalize;
12pub mod photometric;
13
14pub use mixup::{MixOutput, cutmix, mixup};
15
16use crate::{error::VisionResult, handle::LcgRng};
17
18use geometric::{center_crop, random_crop, random_horizontal_flip, resize_bilinear};
19use normalize::normalize_chw;
20use photometric::{color_jitter, random_grayscale};
21
/// A single augmentation step that can be run on its own via [`AugOp::apply`]
/// or chained inside a [`Pipeline`].
///
/// Each variant only stores the parameters that are forwarded to the matching
/// helper in the `geometric`, `photometric` or `normalize` submodules.
///
/// `PartialEq` is derived so that ops (and test expectations) can be compared
/// directly. Because the parameters are `f32`, an op holding a NaN parameter
/// never compares equal to itself, which is why `Eq` is not derived.
#[derive(Debug, Clone, PartialEq)]
pub enum AugOp {
    /// Crop handled by `geometric::random_crop`; `apply` reports the output as
    /// `crop_size` x `crop_size`.
    RandomCrop { crop_size: usize },

    /// Crop handled by `geometric::center_crop`; `apply` reports the output as
    /// `crop_size` x `crop_size`.
    CenterCrop { crop_size: usize },

    /// Flip handled by `geometric::random_horizontal_flip`; `prob` is forwarded
    /// as-is (presumably the probability of flipping; confirm in `geometric`).
    HorizontalFlip { prob: f32 },

    /// Resize handled by `geometric::resize_bilinear`; `apply` reports the
    /// output as `target` x `target`.
    Resize { target: usize },

    /// Jitter handled by `photometric::color_jitter`. The three strengths are
    /// forwarded unchanged; their exact interpretation is defined there.
    ColorJitter {
        brightness: f32,
        contrast: f32,
        saturation: f32,
    },

    /// Grayscale conversion handled by `photometric::random_grayscale`; `prob`
    /// is forwarded as-is (presumably the probability of converting).
    RandomGrayscale { prob: f32 },

    /// Normalisation handled by `normalize::normalize_chw`, with three-entry
    /// `mean` and `std` arrays (presumably one entry per RGB channel).
    Normalize { mean: [f32; 3], std: [f32; 3] },
}
55
56impl AugOp {
57 pub fn apply(
73 &self,
74 img: &[f32],
75 channels: usize,
76 h: usize,
77 w: usize,
78 rng: &mut LcgRng,
79 ) -> VisionResult<(Vec<f32>, usize, usize)> {
80 match self {
81 AugOp::RandomCrop { crop_size } => {
82 let out = random_crop(img, channels, h, w, *crop_size, rng)?;
83 Ok((out, *crop_size, *crop_size))
84 }
85 AugOp::CenterCrop { crop_size } => {
86 let out = center_crop(img, channels, h, w, *crop_size)?;
87 Ok((out, *crop_size, *crop_size))
88 }
89 AugOp::HorizontalFlip { prob } => {
90 let out = random_horizontal_flip(img, channels, h, w, *prob, rng);
91 Ok((out, h, w))
92 }
93 AugOp::Resize { target } => {
94 let out = resize_bilinear(img, channels, h, w, *target)?;
95 Ok((out, *target, *target))
96 }
97 AugOp::ColorJitter {
98 brightness,
99 contrast,
100 saturation,
101 } => {
102 let out = color_jitter(
103 img,
104 channels,
105 h,
106 w,
107 *brightness,
108 *contrast,
109 *saturation,
110 rng,
111 );
112 Ok((out, h, w))
113 }
114 AugOp::RandomGrayscale { prob } => {
115 let out = random_grayscale(img, channels, h, w, *prob, rng);
116 Ok((out, h, w))
117 }
118 AugOp::Normalize { mean, std } => {
119 let out = normalize_chw(img, channels, h, w, mean, std)?;
120 Ok((out, h, w))
121 }
122 }
123 }
124}
125
126#[derive(Debug, Clone, Default)]
144pub struct Pipeline {
145 pub ops: Vec<AugOp>,
147}
148
149impl Pipeline {
150 #[must_use]
152 pub fn new() -> Self {
153 Self { ops: Vec::new() }
154 }
155
156 #[must_use]
158 pub fn push(mut self, op: AugOp) -> Self {
159 self.ops.push(op);
160 self
161 }
162
163 pub fn apply(
169 &self,
170 img: &[f32],
171 channels: usize,
172 h: usize,
173 w: usize,
174 rng: &mut LcgRng,
175 ) -> VisionResult<(Vec<f32>, usize, usize)> {
176 if self.ops.is_empty() {
177 return Ok((img.to_vec(), h, w));
178 }
179
180 let (mut cur_img, mut cur_h, mut cur_w) = self.ops[0].apply(img, channels, h, w, rng)?;
182
183 for op in &self.ops[1..] {
185 let (next_img, next_h, next_w) = op.apply(&cur_img, channels, cur_h, cur_w, rng)?;
186 cur_img = next_img;
187 cur_h = next_h;
188 cur_w = next_w;
189 }
190
191 Ok((cur_img, cur_h, cur_w))
192 }
193}
194
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handle::LcgRng;
    use normalize::{IMAGENET_MEAN, IMAGENET_STD};

    /// Builds a 3-channel `h` x `w` buffer whose values rise linearly from 0
    /// towards (but never reaching) 1.
    fn ramp_rgb(h: usize, w: usize) -> Vec<f32> {
        let hw = h * w;
        (0..3 * hw).map(|i| i as f32 / (3 * hw) as f32).collect()
    }

    // ---- single-op tests: reported dims and buffer length must agree ----

    #[test]
    fn aug_op_random_crop_updates_dims() {
        let img = ramp_rgb(32, 32);
        let mut rng = LcgRng::new(1);
        let op = AugOp::RandomCrop { crop_size: 24 };
        let (out, new_h, new_w) = op.apply(&img, 3, 32, 32, &mut rng).expect("ok");
        assert_eq!((new_h, new_w), (24, 24));
        assert_eq!(out.len(), 3 * 24 * 24);
    }

    #[test]
    fn aug_op_center_crop_updates_dims() {
        let img = ramp_rgb(32, 32);
        let mut rng = LcgRng::new(2);
        let op = AugOp::CenterCrop { crop_size: 16 };
        let (out, new_h, new_w) = op.apply(&img, 3, 32, 32, &mut rng).expect("ok");
        assert_eq!((new_h, new_w), (16, 16));
        assert_eq!(out.len(), 3 * 16 * 16);
    }

    #[test]
    fn aug_op_flip_preserves_dims() {
        let img = ramp_rgb(16, 16);
        let mut rng = LcgRng::new(3);
        let op = AugOp::HorizontalFlip { prob: 0.5 };
        let (out, new_h, new_w) = op.apply(&img, 3, 16, 16, &mut rng).expect("ok");
        assert_eq!((new_h, new_w), (16, 16));
        assert_eq!(out.len(), img.len());
    }

    #[test]
    fn aug_op_resize_updates_dims() {
        let img = ramp_rgb(64, 64);
        let mut rng = LcgRng::new(4);
        let op = AugOp::Resize { target: 32 };
        let (out, new_h, new_w) = op.apply(&img, 3, 64, 64, &mut rng).expect("ok");
        assert_eq!((new_h, new_w), (32, 32));
        assert_eq!(out.len(), 3 * 32 * 32);
    }

    #[test]
    fn aug_op_color_jitter_preserves_dims() {
        let img = ramp_rgb(16, 16);
        let mut rng = LcgRng::new(5);
        let op = AugOp::ColorJitter {
            brightness: 0.2,
            contrast: 0.2,
            saturation: 0.2,
        };
        let (out, new_h, new_w) = op.apply(&img, 3, 16, 16, &mut rng).expect("ok");
        assert_eq!((new_h, new_w), (16, 16));
        assert_eq!(out.len(), img.len());
    }

    #[test]
    fn aug_op_grayscale_preserves_dims() {
        let img = ramp_rgb(16, 16);
        let mut rng = LcgRng::new(6);
        let op = AugOp::RandomGrayscale { prob: 0.5 };
        let (out, new_h, new_w) = op.apply(&img, 3, 16, 16, &mut rng).expect("ok");
        assert_eq!((new_h, new_w), (16, 16));
        assert_eq!(out.len(), img.len());
    }

    #[test]
    fn aug_op_normalize_preserves_dims() {
        let img = ramp_rgb(16, 16);
        let mut rng = LcgRng::new(7);
        let op = AugOp::Normalize {
            mean: IMAGENET_MEAN,
            std: IMAGENET_STD,
        };
        let (out, new_h, new_w) = op.apply(&img, 3, 16, 16, &mut rng).expect("ok");
        assert_eq!((new_h, new_w), (16, 16));
        assert_eq!(out.len(), img.len());
    }

    // ---- pipeline tests ----

    #[test]
    fn pipeline_empty_returns_clone() {
        let img = ramp_rgb(16, 16);
        let pipeline = Pipeline::new();
        let mut rng = LcgRng::new(8);
        let (out, new_h, new_w) = pipeline.apply(&img, 3, 16, 16, &mut rng).expect("ok");
        assert_eq!((new_h, new_w), (16, 16));
        assert_eq!(out, img);
    }

    #[test]
    fn pipeline_single_op() {
        let img = ramp_rgb(32, 32);
        let pipeline = Pipeline::new().push(AugOp::Resize { target: 16 });
        let mut rng = LcgRng::new(9);
        let (out, new_h, new_w) = pipeline.apply(&img, 3, 32, 32, &mut rng).expect("ok");
        assert_eq!((new_h, new_w), (16, 16));
        assert_eq!(out.len(), 3 * 16 * 16);
    }

    #[test]
    fn pipeline_multi_op_dims_chain() {
        let img = ramp_rgb(64, 64);
        // The crop only succeeds if it is handed the resized 32x32 dims.
        let pipeline = Pipeline::new()
            .push(AugOp::Resize { target: 32 })
            .push(AugOp::CenterCrop { crop_size: 24 });
        let mut rng = LcgRng::new(10);
        let (out, new_h, new_w) = pipeline.apply(&img, 3, 64, 64, &mut rng).expect("ok");
        assert_eq!((new_h, new_w), (24, 24));
        assert_eq!(out.len(), 3 * 24 * 24);
    }

    #[test]
    fn pipeline_full_augmentation_chain() {
        let img: Vec<f32> = (0..3 * 256 * 256)
            .map(|i| i as f32 / (3.0 * 256.0 * 256.0))
            .collect();
        let pipeline = Pipeline::new()
            .push(AugOp::Resize { target: 256 })
            .push(AugOp::RandomCrop { crop_size: 224 })
            .push(AugOp::HorizontalFlip { prob: 0.5 })
            .push(AugOp::ColorJitter {
                brightness: 0.1,
                contrast: 0.1,
                saturation: 0.1,
            })
            .push(AugOp::Normalize {
                mean: IMAGENET_MEAN,
                std: IMAGENET_STD,
            });
        let mut rng = LcgRng::new(11);
        let (out, new_h, new_w) = pipeline.apply(&img, 3, 256, 256, &mut rng).expect("ok");
        assert_eq!((new_h, new_w), (224, 224));
        assert_eq!(out.len(), 3 * 224 * 224);
        assert!(
            out.iter().all(|v| v.is_finite()),
            "pipeline output must be finite"
        );
    }

    #[test]
    fn pipeline_add_is_builder() {
        // `push` consumes and returns the pipeline, so calls chain.
        let p = Pipeline::new()
            .push(AugOp::HorizontalFlip { prob: 1.0 })
            .push(AugOp::HorizontalFlip { prob: 1.0 });
        assert_eq!(p.ops.len(), 2);

        let img = ramp_rgb(8, 8);
        let mut rng = LcgRng::new(12);
        let (out, _, _) = p.apply(&img, 3, 8, 8, &mut rng).expect("ok");
        // `zip` stops at the shorter side, so pin the length first; otherwise
        // a truncated output would pass the per-pixel comparison below.
        assert_eq!(out.len(), img.len());
        for (i, (&a, &b)) in img.iter().zip(out.iter()).enumerate() {
            assert!(
                (a - b).abs() < 1e-6,
                "pixel {i}: double-flip should be identity"
            );
        }
    }

    #[test]
    fn pipeline_clone_is_independent() {
        let p1 = Pipeline::new().push(AugOp::Resize { target: 16 });
        let p2 = p1.clone();
        assert_eq!(p1.ops.len(), p2.ops.len());

        // Growing the clone must leave the original untouched.
        let p2 = p2.push(AugOp::CenterCrop { crop_size: 8 });
        assert_eq!(p1.ops.len(), 1);
        assert_eq!(p2.ops.len(), 2);
        assert!(matches!(p1.ops[0], AugOp::Resize { target: 16 }));
    }

    #[test]
    fn aug_op_error_propagated_through_pipeline() {
        let img = ramp_rgb(16, 16);
        // A 32x32 crop is requested from a 16x16 image.
        let pipeline = Pipeline::new().push(AugOp::CenterCrop { crop_size: 32 });
        let mut rng = LcgRng::new(13);
        let r = pipeline.apply(&img, 3, 16, 16, &mut rng);
        assert!(
            r.is_err(),
            "oversized crop through pipeline should propagate error"
        );
    }
}