Skip to main content

oxicuda_vision/augment/
mod.rs

1//! Image augmentation pipeline for CHW tensors.
2//!
3//! Provides an enum-dispatched set of operations (`AugOp`) and a composable
4//! `Pipeline` that applies them in sequence.  All operations work on flat
5//! `[channels × h × w]` row-major `f32` buffers; dimensions are tracked
6//! as `(channels, h, w)` tuples so that spatial-modifying operations (crop,
7//! resize) can propagate updated dimensions to later stages.
8
9pub mod geometric;
10pub mod mixup;
11pub mod normalize;
12pub mod photometric;
13
14pub use mixup::{MixOutput, cutmix, mixup};
15
16use crate::{error::VisionResult, handle::LcgRng};
17
18use geometric::{center_crop, random_crop, random_horizontal_flip, resize_bilinear};
19use normalize::normalize_chw;
20use photometric::{color_jitter, random_grayscale};
21
22// ─── AugOp ───────────────────────────────────────────────────────────────────
23
/// Enum-dispatched augmentation operations — no `dyn Trait`, no heap boxing.
///
/// Each variant carries the hyperparameters it needs.  The stochastic variants
/// (`RandomCrop`, `HorizontalFlip`, `ColorJitter`, `RandomGrayscale`) receive a
/// mutable `LcgRng` reference at call time via [`AugOp::apply`]; the
/// deterministic variants (`CenterCrop`, `Resize`, `Normalize`) never use it.
///
/// `PartialEq` is derived so that operations (and configurations built from
/// them) can be compared directly, e.g. in tests.
#[derive(Debug, Clone, PartialEq)]
pub enum AugOp {
    /// Randomly crop a `crop_size × crop_size` window; the output is
    /// `[channels, crop_size, crop_size]`.
    RandomCrop { crop_size: usize },

    /// Deterministic centre crop to `[channels, crop_size, crop_size]`.
    CenterCrop { crop_size: usize },

    /// Randomly flip the image horizontally with the given probability.
    HorizontalFlip { prob: f32 },

    /// Bilinear resize to `[channels, target, target]`.
    ///
    /// Both sides become `target`, so the aspect ratio is not preserved.
    Resize { target: usize },

    /// Colour jitter: brightness, contrast, saturation perturbation magnitudes.
    ColorJitter {
        brightness: f32,
        contrast: f32,
        saturation: f32,
    },

    /// Convert to grayscale with the given probability (RGB images only).
    RandomGrayscale { prob: f32 },

    /// Channel-wise normalisation: `(x - mean[c]) / std[c]`.
    ///
    /// `mean` and `std` are fixed at three entries, so this variant targets
    /// 3-channel images.  NOTE(review): how other channel counts are handled is
    /// decided by `normalize::normalize_chw` — confirm there.
    Normalize { mean: [f32; 3], std: [f32; 3] },
}
55
56impl AugOp {
57    /// Apply this augmentation to a CHW image.
58    ///
59    /// # Parameters
60    /// - `img`: flat `[channels × h × w]` input buffer.
61    /// - `channels`: number of channels (e.g., 3 for RGB).
62    /// - `h`, `w`: spatial height and width of the input image.
63    /// - `rng`: mutable RNG for stochastic operations; deterministic ops ignore it.
64    ///
65    /// # Returns
66    /// `(new_img, new_h, new_w)` — the transformed image and its (possibly
67    /// updated) spatial dimensions.  `channels` is never changed.
68    ///
69    /// # Errors
70    /// Propagates errors from the underlying operation functions (invalid
71    /// dimensions, mismatched buffers, non-positive std, etc.).
72    pub fn apply(
73        &self,
74        img: &[f32],
75        channels: usize,
76        h: usize,
77        w: usize,
78        rng: &mut LcgRng,
79    ) -> VisionResult<(Vec<f32>, usize, usize)> {
80        match self {
81            AugOp::RandomCrop { crop_size } => {
82                let out = random_crop(img, channels, h, w, *crop_size, rng)?;
83                Ok((out, *crop_size, *crop_size))
84            }
85            AugOp::CenterCrop { crop_size } => {
86                let out = center_crop(img, channels, h, w, *crop_size)?;
87                Ok((out, *crop_size, *crop_size))
88            }
89            AugOp::HorizontalFlip { prob } => {
90                let out = random_horizontal_flip(img, channels, h, w, *prob, rng);
91                Ok((out, h, w))
92            }
93            AugOp::Resize { target } => {
94                let out = resize_bilinear(img, channels, h, w, *target)?;
95                Ok((out, *target, *target))
96            }
97            AugOp::ColorJitter {
98                brightness,
99                contrast,
100                saturation,
101            } => {
102                let out = color_jitter(
103                    img,
104                    channels,
105                    h,
106                    w,
107                    *brightness,
108                    *contrast,
109                    *saturation,
110                    rng,
111                );
112                Ok((out, h, w))
113            }
114            AugOp::RandomGrayscale { prob } => {
115                let out = random_grayscale(img, channels, h, w, *prob, rng);
116                Ok((out, h, w))
117            }
118            AugOp::Normalize { mean, std } => {
119                let out = normalize_chw(img, channels, h, w, mean, std)?;
120                Ok((out, h, w))
121            }
122        }
123    }
124}
125
126// ─── Pipeline ────────────────────────────────────────────────────────────────
127
128/// A sequence of augmentation operations applied in order.
129///
130/// `Pipeline` owns a `Vec<AugOp>` and threads `(img, channels, h, w)` through
131/// each operation, updating the spatial dimensions as needed (e.g., after a
132/// crop or resize).
133///
134/// # Example
135/// ```rust,ignore
136/// let pipeline = Pipeline::new()
137///     .push(AugOp::Resize { target: 256 })
138///     .push(AugOp::RandomCrop { crop_size: 224 })
139///     .push(AugOp::HorizontalFlip { prob: 0.5 })
140///     .push(AugOp::Normalize { mean: IMAGENET_MEAN, std: IMAGENET_STD });
141/// ```
142// Note: method is named `push` (not `add`) to avoid confusion with std::ops::Add::add.
143#[derive(Debug, Clone, Default)]
144pub struct Pipeline {
145    /// Ordered list of augmentation operations.
146    pub ops: Vec<AugOp>,
147}
148
149impl Pipeline {
150    /// Create an empty pipeline.
151    #[must_use]
152    pub fn new() -> Self {
153        Self { ops: Vec::new() }
154    }
155
156    /// Append an operation to the pipeline (builder pattern).
157    #[must_use]
158    pub fn push(mut self, op: AugOp) -> Self {
159        self.ops.push(op);
160        self
161    }
162
163    /// Apply all operations in sequence, threading the output through.
164    ///
165    /// Returns the final `(image, h, w)` after all augmentations, or the
166    /// first error encountered.  If the pipeline is empty the image and
167    /// dimensions are returned unchanged (cloning the input slice).
168    pub fn apply(
169        &self,
170        img: &[f32],
171        channels: usize,
172        h: usize,
173        w: usize,
174        rng: &mut LcgRng,
175    ) -> VisionResult<(Vec<f32>, usize, usize)> {
176        if self.ops.is_empty() {
177            return Ok((img.to_vec(), h, w));
178        }
179
180        // Apply first op to the original input.
181        let (mut cur_img, mut cur_h, mut cur_w) = self.ops[0].apply(img, channels, h, w, rng)?;
182
183        // Apply subsequent ops to the evolving (cur_img, cur_h, cur_w).
184        for op in &self.ops[1..] {
185            let (next_img, next_h, next_w) = op.apply(&cur_img, channels, cur_h, cur_w, rng)?;
186            cur_img = next_img;
187            cur_h = next_h;
188            cur_w = next_w;
189        }
190
191        Ok((cur_img, cur_h, cur_w))
192    }
193}
194
195// ─── Tests ───────────────────────────────────────────────────────────────────
196
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handle::LcgRng;
    use normalize::{IMAGENET_MEAN, IMAGENET_STD};

    /// Deterministic 3-channel ramp in `[0, 1)`, laid out as `[3 × h × w]`.
    fn ramp_rgb(h: usize, w: usize) -> Vec<f32> {
        let hw = h * w;
        (0..3 * hw).map(|i| i as f32 / (3 * hw) as f32).collect()
    }

    // ── AugOp::RandomCrop ────────────────────────────────────────────────────

    #[test]
    fn aug_op_random_crop_updates_dims() {
        let img = ramp_rgb(32, 32);
        let mut rng = LcgRng::new(1);
        let op = AugOp::RandomCrop { crop_size: 24 };
        let (out, new_h, new_w) = op.apply(&img, 3, 32, 32, &mut rng).expect("ok");
        assert_eq!((new_h, new_w), (24, 24));
        assert_eq!(out.len(), 3 * 24 * 24);
    }

    // ── AugOp::CenterCrop ────────────────────────────────────────────────────

    #[test]
    fn aug_op_center_crop_updates_dims() {
        let img = ramp_rgb(32, 32);
        let mut rng = LcgRng::new(2);
        let op = AugOp::CenterCrop { crop_size: 16 };
        let (out, new_h, new_w) = op.apply(&img, 3, 32, 32, &mut rng).expect("ok");
        assert_eq!((new_h, new_w), (16, 16));
        assert_eq!(out.len(), 3 * 16 * 16);
    }

    // ── AugOp::HorizontalFlip ────────────────────────────────────────────────

    #[test]
    fn aug_op_flip_preserves_dims() {
        let img = ramp_rgb(16, 16);
        let mut rng = LcgRng::new(3);
        let op = AugOp::HorizontalFlip { prob: 0.5 };
        let (out, new_h, new_w) = op.apply(&img, 3, 16, 16, &mut rng).expect("ok");
        assert_eq!((new_h, new_w), (16, 16));
        assert_eq!(out.len(), img.len());
    }

    // ── AugOp::Resize ────────────────────────────────────────────────────────

    #[test]
    fn aug_op_resize_updates_dims() {
        let img = ramp_rgb(64, 64);
        let mut rng = LcgRng::new(4);
        let op = AugOp::Resize { target: 32 };
        let (out, new_h, new_w) = op.apply(&img, 3, 64, 64, &mut rng).expect("ok");
        assert_eq!((new_h, new_w), (32, 32));
        assert_eq!(out.len(), 3 * 32 * 32);
    }

    // ── AugOp::ColorJitter ───────────────────────────────────────────────────

    #[test]
    fn aug_op_color_jitter_preserves_dims() {
        let img = ramp_rgb(16, 16);
        let mut rng = LcgRng::new(5);
        let op = AugOp::ColorJitter {
            brightness: 0.2,
            contrast: 0.2,
            saturation: 0.2,
        };
        let (out, new_h, new_w) = op.apply(&img, 3, 16, 16, &mut rng).expect("ok");
        assert_eq!((new_h, new_w), (16, 16));
        assert_eq!(out.len(), img.len());
    }

    // ── AugOp::RandomGrayscale ───────────────────────────────────────────────

    #[test]
    fn aug_op_grayscale_preserves_dims() {
        let img = ramp_rgb(16, 16);
        let mut rng = LcgRng::new(6);
        let op = AugOp::RandomGrayscale { prob: 0.5 };
        let (out, new_h, new_w) = op.apply(&img, 3, 16, 16, &mut rng).expect("ok");
        assert_eq!((new_h, new_w), (16, 16));
        assert_eq!(out.len(), img.len());
    }

    // ── AugOp::Normalize ─────────────────────────────────────────────────────

    #[test]
    fn aug_op_normalize_preserves_dims() {
        let img = ramp_rgb(16, 16);
        let mut rng = LcgRng::new(7);
        let op = AugOp::Normalize {
            mean: IMAGENET_MEAN,
            std: IMAGENET_STD,
        };
        let (out, new_h, new_w) = op.apply(&img, 3, 16, 16, &mut rng).expect("ok");
        assert_eq!((new_h, new_w), (16, 16));
        assert_eq!(out.len(), img.len());
    }

    // ── Pipeline ─────────────────────────────────────────────────────────────

    #[test]
    fn pipeline_empty_returns_clone() {
        let img = ramp_rgb(16, 16);
        let pipeline = Pipeline::new();
        let mut rng = LcgRng::new(8);
        let (out, new_h, new_w) = pipeline.apply(&img, 3, 16, 16, &mut rng).expect("ok");
        assert_eq!((new_h, new_w), (16, 16));
        assert_eq!(out, img);
    }

    #[test]
    fn pipeline_single_op() {
        let img = ramp_rgb(32, 32);
        let pipeline = Pipeline::new().push(AugOp::Resize { target: 16 });
        let mut rng = LcgRng::new(9);
        let (out, new_h, new_w) = pipeline.apply(&img, 3, 32, 32, &mut rng).expect("ok");
        assert_eq!((new_h, new_w), (16, 16));
        assert_eq!(out.len(), 3 * 16 * 16);
    }

    #[test]
    fn pipeline_multi_op_dims_chain() {
        // Resize 64→32, then CenterCrop 32→24.
        let img = ramp_rgb(64, 64);
        let pipeline = Pipeline::new()
            .push(AugOp::Resize { target: 32 })
            .push(AugOp::CenterCrop { crop_size: 24 });
        let mut rng = LcgRng::new(10);
        let (out, new_h, new_w) = pipeline.apply(&img, 3, 64, 64, &mut rng).expect("ok");
        assert_eq!((new_h, new_w), (24, 24));
        assert_eq!(out.len(), 3 * 24 * 24);
    }

    #[test]
    fn pipeline_full_augmentation_chain() {
        // Typical training augmentation: resize → random_crop → flip → jitter → normalize.
        let img: Vec<f32> = (0..3 * 256 * 256)
            .map(|i| i as f32 / (3.0 * 256.0 * 256.0))
            .collect();
        let pipeline = Pipeline::new()
            .push(AugOp::Resize { target: 256 })
            .push(AugOp::RandomCrop { crop_size: 224 })
            .push(AugOp::HorizontalFlip { prob: 0.5 })
            .push(AugOp::ColorJitter {
                brightness: 0.1,
                contrast: 0.1,
                saturation: 0.1,
            })
            .push(AugOp::Normalize {
                mean: IMAGENET_MEAN,
                std: IMAGENET_STD,
            });
        let mut rng = LcgRng::new(11);
        let (out, new_h, new_w) = pipeline.apply(&img, 3, 256, 256, &mut rng).expect("ok");
        assert_eq!((new_h, new_w), (224, 224));
        assert_eq!(out.len(), 3 * 224 * 224);
        assert!(
            out.iter().all(|v| v.is_finite()),
            "pipeline output must be finite"
        );
    }

    #[test]
    fn pipeline_push_is_builder() {
        // The `push` builder should accumulate ops correctly.
        let p = Pipeline::new()
            .push(AugOp::HorizontalFlip { prob: 1.0 })
            .push(AugOp::HorizontalFlip { prob: 1.0 });
        assert_eq!(p.ops.len(), 2);

        // Two horizontal flips at prob=1 should recover original.
        let img = ramp_rgb(8, 8);
        let mut rng = LcgRng::new(12);
        let (out, _, _) = p.apply(&img, 3, 8, 8, &mut rng).expect("ok");
        for (i, (&a, &b)) in img.iter().zip(out.iter()).enumerate() {
            assert!(
                (a - b).abs() < 1e-6,
                "pixel {i}: double-flip should be identity"
            );
        }
    }

    #[test]
    fn pipeline_clone_is_independent() {
        let p1 = Pipeline::new().push(AugOp::Resize { target: 16 });
        let mut p2 = p1.clone();
        assert_eq!(p1.ops.len(), p2.ops.len());

        // Growing the clone must leave the original untouched: `ops` is deep-copied.
        p2.ops.push(AugOp::HorizontalFlip { prob: 1.0 });
        assert_eq!(p1.ops.len(), 1);
        assert_eq!(p2.ops.len(), 2);
    }

    #[test]
    fn aug_op_error_propagated_through_pipeline() {
        // A crop larger than the image should produce an error.
        let img = ramp_rgb(16, 16);
        let pipeline = Pipeline::new().push(AugOp::CenterCrop { crop_size: 32 }); // 32 > 16
        let mut rng = LcgRng::new(13);
        let r = pipeline.apply(&img, 3, 16, 16, &mut rng);
        assert!(
            r.is_err(),
            "oversized crop through pipeline should propagate error"
        );
    }

    #[test]
    fn pipeline_propagates_error_from_later_op() {
        // Resize 64→16 succeeds, then a 24-pixel centre crop no longer fits the
        // 16×16 intermediate image: the error from the second op must surface.
        let img = ramp_rgb(64, 64);
        let pipeline = Pipeline::new()
            .push(AugOp::Resize { target: 16 })
            .push(AugOp::CenterCrop { crop_size: 24 });
        let mut rng = LcgRng::new(14);
        let r = pipeline.apply(&img, 3, 64, 64, &mut rng);
        assert!(
            r.is_err(),
            "error from a non-first op should propagate out of the pipeline"
        );
    }
}