Skip to main content

oxicuda_vision/augment/
mixup.rs

1//! MixUp and CutMix regularising augmentations for image classification.
2//!
3//! Both operate on a *batch* of CHW images and their one-hot (or soft) label
4//! vectors, producing a convex / spatial mixture of two samples together with
5//! the correspondingly mixed targets. They are among the strongest
6//! regularisers for modern image classifiers.
7//!
8//! * **MixUp** (Zhang et al. 2018): pixel-wise convex combination
9//!   `x̃ = λ·x_i + (1−λ)·x_j`, label `ỹ = λ·y_i + (1−λ)·y_j`.
10//! * **CutMix** (Yun et al. 2019): a rectangular patch of `x_j` is pasted into
11//!   `x_i`; the label mixing coefficient is the *area fraction* of the surviving
12//!   region of `x_i`, `λ = 1 − (patch_area / image_area)`.
13//!
14//! Each image is paired with another sample drawn from the same batch via a
15//! random permutation (a sample may be paired with itself, in which case the
16//! mixture is the identity — exactly as in the reference implementations).
17//!
18//! The mixing coefficient `λ` is drawn from a `Beta(α, α)` distribution. Beta
19//! sampling is performed without external crates via two Gamma(α) draws
20//! (Marsaglia–Tsang for α ≥ 1, with the `Gamma(α+1)·U^{1/α}` boost for α < 1) using the
21//! crate-local [`LcgRng`].
22
23use crate::{
24    error::{VisionError, VisionResult},
25    handle::LcgRng,
26};
27
/// Result of a batch mixing operation ([`mixup`] or [`cutmix`]): the mixed
/// batch together with the per-sample bookkeeping that produced it.
#[derive(Debug, Clone)]
pub struct MixOutput {
    /// Mixed images as one flat, row-major buffer of shape
    /// `[batch × channels × h × w]`.
    pub images: Vec<f32>,
    /// Mixed (soft) targets as one flat, row-major buffer of shape
    /// `[batch × n_classes]`.
    pub labels: Vec<f32>,
    /// Label mixing coefficient `λ` applied to each sample (`batch` entries).
    pub lambdas: Vec<f32>,
    /// For each sample, the index in the original batch of the partner it was
    /// mixed with (`batch` entries).
    pub partners: Vec<usize>,
}
40
41// ─── Beta / Gamma sampling ──────────────────────────────────────────────────
42
43/// Sample `Gamma(shape, 1)` using Marsaglia–Tsang (shape ≥ 1) with the
44/// Johnk boosting transform for `shape < 1`.
45fn sample_gamma(shape: f32, rng: &mut LcgRng) -> f32 {
46    if shape < 1.0 {
47        // Boost: Gamma(a) = Gamma(a+1) · U^{1/a}.
48        let g = sample_gamma(shape + 1.0, rng);
49        let u = rng.next_f32().max(1e-12);
50        return g * u.powf(1.0 / shape);
51    }
52    let d = shape - 1.0 / 3.0;
53    let c = 1.0 / (9.0 * d).sqrt();
54    loop {
55        // Standard normal via Box–Muller.
56        let (z, _) = rng.next_normal_pair();
57        let v0 = 1.0 + c * z;
58        if v0 <= 0.0 {
59            continue;
60        }
61        let v = v0 * v0 * v0;
62        let u = rng.next_f32().max(1e-12);
63        if u < 1.0 - 0.0331 * z * z * z * z {
64            return d * v;
65        }
66        if u.ln() < 0.5 * z * z + d * (1.0 - v + v.ln()) {
67            return d * v;
68        }
69    }
70}
71
72/// Sample `λ ~ Beta(alpha, alpha)`. For `alpha <= 0` returns `1.0` (no mixing,
73/// matching the common convention where α=0 disables MixUp).
74fn sample_beta_symmetric(alpha: f32, rng: &mut LcgRng) -> f32 {
75    if !alpha.is_finite() || alpha <= 0.0 {
76        return 1.0;
77    }
78    let x = sample_gamma(alpha, rng);
79    let y = sample_gamma(alpha, rng);
80    let s = x + y;
81    if s <= 1e-12 { 0.5 } else { x / s }
82}
83
84#[inline]
85fn validate_batch(
86    images: &[f32],
87    labels: &[f32],
88    batch: usize,
89    channels: usize,
90    h: usize,
91    w: usize,
92    n_classes: usize,
93) -> VisionResult<()> {
94    if batch == 0 {
95        return Err(VisionError::EmptyInput("mixup batch"));
96    }
97    if channels == 0 || h == 0 || w == 0 {
98        return Err(VisionError::InvalidImageSize {
99            height: h,
100            width: w,
101            channels,
102        });
103    }
104    if n_classes == 0 {
105        return Err(VisionError::InvalidNumClasses(n_classes));
106    }
107    let img_expected = batch * channels * h * w;
108    if images.len() != img_expected {
109        return Err(VisionError::DimensionMismatch {
110            expected: img_expected,
111            got: images.len(),
112        });
113    }
114    let lbl_expected = batch * n_classes;
115    if labels.len() != lbl_expected {
116        return Err(VisionError::DimensionMismatch {
117            expected: lbl_expected,
118            got: labels.len(),
119        });
120    }
121    Ok(())
122}
123
124/// Build a random partner permutation (a derangement is *not* required; a
125/// sample may pair with itself, exactly as in the reference code).
126fn random_partners(batch: usize, rng: &mut LcgRng) -> Vec<usize> {
127    let mut perm: Vec<usize> = (0..batch).collect();
128    rng.shuffle(&mut perm);
129    perm
130}
131
/// Write the convex label blend `λ·y_i + (1−λ)·y_j` into row `i` of `out`.
///
/// `out` and `labels` are flat row-major `[rows × n_classes]` matrices; only
/// row `i` of `out` is touched. Panics if row `i` or `j` is out of bounds.
fn mix_labels_into(
    out: &mut [f32],
    labels: &[f32],
    i: usize,
    j: usize,
    n_classes: usize,
    lambda: f32,
) {
    let span = |row: usize| row * n_classes..(row + 1) * n_classes;
    let own = &labels[span(i)];
    let other = &labels[span(j)];
    let complement = 1.0 - lambda;
    for ((dst, &a), &b) in out[span(i)].iter_mut().zip(own).zip(other) {
        *dst = lambda * a + complement * b;
    }
}
148
149/// Apply **MixUp** to a batch.
150///
151/// # Errors
152/// Returns [`VisionError`] for empty / mismatched inputs.
153pub fn mixup(
154    images: &[f32],
155    labels: &[f32],
156    batch: usize,
157    channels: usize,
158    h: usize,
159    w: usize,
160    n_classes: usize,
161    alpha: f32,
162    rng: &mut LcgRng,
163) -> VisionResult<MixOutput> {
164    validate_batch(images, labels, batch, channels, h, w, n_classes)?;
165    let chw = channels * h * w;
166    let partners = random_partners(batch, rng);
167
168    let mut out_images = vec![0.0_f32; images.len()];
169    let mut out_labels = vec![0.0_f32; labels.len()];
170    let mut lambdas = vec![0.0_f32; batch];
171
172    for i in 0..batch {
173        let j = partners[i];
174        let lambda = sample_beta_symmetric(alpha, rng);
175        lambdas[i] = lambda;
176        let bi = i * chw;
177        let bj = j * chw;
178        for p in 0..chw {
179            out_images[bi + p] = lambda * images[bi + p] + (1.0 - lambda) * images[bj + p];
180        }
181        mix_labels_into(&mut out_labels, labels, i, j, n_classes, lambda);
182    }
183
184    Ok(MixOutput {
185        images: out_images,
186        labels: out_labels,
187        lambdas,
188        partners,
189    })
190}
191
192/// Sample a CutMix bounding box for a `λ` area ratio.
193///
194/// The patch has side `√(1−λ)` of the image; its centre is uniform. Returns
195/// `(x1, y1, x2, y2)` clamped to the image bounds, plus the *corrected* area
196/// fraction actually pasted (the box may be clipped at the border).
197fn cutmix_bbox(h: usize, w: usize, lambda: f32, rng: &mut LcgRng) -> (usize, usize, usize, usize) {
198    let cut_ratio = (1.0 - lambda).max(0.0).sqrt();
199    let cut_h = ((h as f32) * cut_ratio).round() as usize;
200    let cut_w = ((w as f32) * cut_ratio).round() as usize;
201    let cy = rng.next_usize(h);
202    let cx = rng.next_usize(w);
203    let y1 = cy.saturating_sub(cut_h / 2);
204    let x1 = cx.saturating_sub(cut_w / 2);
205    let y2 = (cy + cut_h.div_ceil(2)).min(h);
206    let x2 = (cx + cut_w.div_ceil(2)).min(w);
207    (x1, y1, x2, y2)
208}
209
210/// Apply **CutMix** to a batch.
211///
212/// A rectangular patch from the partner image is pasted over each image; the
213/// label coefficient `λ` is corrected to the *true* surviving area fraction
214/// `1 − (patch_area / image_area)` after border clipping, as in the reference.
215///
216/// # Errors
217/// Returns [`VisionError`] for empty / mismatched inputs.
218pub fn cutmix(
219    images: &[f32],
220    labels: &[f32],
221    batch: usize,
222    channels: usize,
223    h: usize,
224    w: usize,
225    n_classes: usize,
226    alpha: f32,
227    rng: &mut LcgRng,
228) -> VisionResult<MixOutput> {
229    validate_batch(images, labels, batch, channels, h, w, n_classes)?;
230    let chw = channels * h * w;
231    let partners = random_partners(batch, rng);
232    let area = (h * w) as f32;
233
234    let mut out_images = images.to_vec();
235    let mut out_labels = vec![0.0_f32; labels.len()];
236    let mut lambdas = vec![0.0_f32; batch];
237
238    for i in 0..batch {
239        let j = partners[i];
240        let lambda0 = sample_beta_symmetric(alpha, rng);
241        let (x1, y1, x2, y2) = cutmix_bbox(h, w, lambda0, rng);
242        let patch_area = ((x2 - x1) * (y2 - y1)) as f32;
243        // True coefficient after border clipping.
244        let lambda = 1.0 - patch_area / area;
245        lambdas[i] = lambda;
246
247        let bi = i * chw;
248        let bj = j * chw;
249        for c in 0..channels {
250            let ci = bi + c * h * w;
251            let cj = bj + c * h * w;
252            for y in y1..y2 {
253                for x in x1..x2 {
254                    out_images[ci + y * w + x] = images[cj + y * w + x];
255                }
256            }
257        }
258        mix_labels_into(&mut out_labels, labels, i, j, n_classes, lambda);
259    }
260
261    Ok(MixOutput {
262        images: out_images,
263        labels: out_labels,
264        lambdas,
265        partners,
266    })
267}
268
269// ─── Tests ───────────────────────────────────────────────────────────────────
270
#[cfg(test)]
mod tests {
    use super::*;

    fn one_hot_batch(batch: usize, n_classes: usize) -> Vec<f32> {
        let mut labels = vec![0.0_f32; batch * n_classes];
        for i in 0..batch {
            labels[i * n_classes + (i % n_classes)] = 1.0;
        }
        labels
    }

    #[test]
    fn beta_symmetric_in_unit_interval() {
        let mut rng = LcgRng::new(1);
        for _ in 0..1000 {
            let l = sample_beta_symmetric(0.4, &mut rng);
            assert!((0.0..=1.0).contains(&l), "beta sample out of [0,1]: {l}");
        }
    }

    #[test]
    fn beta_small_alpha_in_unit_interval() {
        // Small α pushes the Gamma draws towards underflow; λ must stay in [0, 1].
        let mut rng = LcgRng::new(18);
        for alpha in [0.05_f32, 0.1] {
            for _ in 0..1000 {
                let l = sample_beta_symmetric(alpha, &mut rng);
                assert!(
                    (0.0..=1.0).contains(&l),
                    "beta({alpha}) sample out of [0,1]: {l}"
                );
            }
        }
    }

    #[test]
    fn beta_alpha_nonpositive_is_one() {
        let mut rng = LcgRng::new(2);
        assert_eq!(sample_beta_symmetric(0.0, &mut rng), 1.0);
        assert_eq!(sample_beta_symmetric(-1.0, &mut rng), 1.0);
    }

    #[test]
    fn gamma_samples_positive() {
        let mut rng = LcgRng::new(3);
        for a in [0.3_f32, 1.0, 2.5, 5.0] {
            for _ in 0..200 {
                let g = sample_gamma(a, &mut rng);
                assert!(g > 0.0 && g.is_finite(), "gamma({a})={g}");
            }
        }
    }

    #[test]
    fn mix_labels_into_blends_rows() {
        let labels = [1.0_f32, 0.0, 0.0, 1.0];
        let mut out = [0.0_f32; 4];
        mix_labels_into(&mut out, &labels, 0, 1, 2, 0.25);
        assert_eq!(out, [0.25, 0.75, 0.0, 0.0]);
    }

    #[test]
    fn mixup_output_shapes() {
        let batch = 4;
        let (c, h, w, k) = (3, 8, 8, 5);
        let images = vec![0.5_f32; batch * c * h * w];
        let labels = one_hot_batch(batch, k);
        let mut rng = LcgRng::new(4);
        let out = mixup(&images, &labels, batch, c, h, w, k, 0.4, &mut rng).expect("ok");
        assert_eq!(out.images.len(), batch * c * h * w);
        assert_eq!(out.labels.len(), batch * k);
        assert_eq!(out.lambdas.len(), batch);
        assert_eq!(out.partners.len(), batch);
    }

    #[test]
    fn mixup_labels_sum_preserved() {
        // One-hot labels each sum to 1 → any convex mix also sums to 1.
        let batch = 6;
        let (c, h, w, k) = (1, 4, 4, 4);
        let images = vec![0.3_f32; batch * c * h * w];
        let labels = one_hot_batch(batch, k);
        let mut rng = LcgRng::new(5);
        let out = mixup(&images, &labels, batch, c, h, w, k, 0.5, &mut rng).expect("ok");
        for i in 0..batch {
            let s: f32 = out.labels[i * k..(i + 1) * k].iter().sum();
            assert!((s - 1.0).abs() < 1e-5, "row {i} label sum {s} != 1");
        }
    }

    #[test]
    fn mixup_constant_images_value_preserved() {
        // Both images constant 0.5 → any λ leaves the value 0.5.
        let batch = 3;
        let (c, h, w, k) = (3, 4, 4, 3);
        let images = vec![0.5_f32; batch * c * h * w];
        let labels = one_hot_batch(batch, k);
        let mut rng = LcgRng::new(6);
        let out = mixup(&images, &labels, batch, c, h, w, k, 0.4, &mut rng).expect("ok");
        assert!(out.images.iter().all(|&v| (v - 0.5).abs() < 1e-5));
    }

    #[test]
    fn mixup_alpha_zero_is_identity() {
        // α = 0 disables mixing: λ = 1 for every sample, output == input.
        let batch = 4;
        let (c, h, w, k) = (2, 4, 4, 3);
        let mut rng = LcgRng::new(19);
        let mut images = vec![0.0_f32; batch * c * h * w];
        rng.fill_normal(&mut images);
        let labels = one_hot_batch(batch, k);
        let out = mixup(&images, &labels, batch, c, h, w, k, 0.0, &mut rng).expect("ok");
        assert!(out.lambdas.iter().all(|&l| l == 1.0));
        assert_eq!(out.images, images);
        assert_eq!(out.labels, labels);
    }

    #[test]
    fn mixup_output_finite() {
        let batch = 4;
        let (c, h, w, k) = (3, 8, 8, 10);
        let mut rng = LcgRng::new(7);
        let mut images = vec![0.0_f32; batch * c * h * w];
        rng.fill_normal(&mut images);
        let labels = one_hot_batch(batch, k);
        let out = mixup(&images, &labels, batch, c, h, w, k, 0.2, &mut rng).expect("ok");
        assert!(out.images.iter().all(|v| v.is_finite()));
        assert!(out.labels.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn mixup_deterministic_with_seed() {
        let batch = 5;
        let (c, h, w, k) = (3, 8, 8, 4);
        let images = vec![0.4_f32; batch * c * h * w];
        let labels = one_hot_batch(batch, k);
        let mut r1 = LcgRng::new(123);
        let mut r2 = LcgRng::new(123);
        let o1 = mixup(&images, &labels, batch, c, h, w, k, 0.5, &mut r1).expect("ok");
        let o2 = mixup(&images, &labels, batch, c, h, w, k, 0.5, &mut r2).expect("ok");
        assert_eq!(o1.partners, o2.partners);
        assert_eq!(o1.lambdas, o2.lambdas);
        assert_eq!(o1.images, o2.images);
    }

    #[test]
    fn mixup_empty_batch_errors() {
        let mut rng = LcgRng::new(8);
        let r = mixup(&[], &[], 0, 3, 8, 8, 5, 0.4, &mut rng);
        assert!(matches!(r, Err(VisionError::EmptyInput(_))));
    }

    #[test]
    fn mixup_label_size_mismatch_errors() {
        let batch = 4;
        let images = vec![0.5_f32; batch * 3 * 8 * 8];
        let labels = vec![0.0_f32; batch * 4]; // claim k=5 below
        let mut rng = LcgRng::new(9);
        let r = mixup(&images, &labels, batch, 3, 8, 8, 5, 0.4, &mut rng);
        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
    }

    #[test]
    fn cutmix_output_shapes() {
        let batch = 4;
        let (c, h, w, k) = (3, 16, 16, 5);
        let images = vec![0.5_f32; batch * c * h * w];
        let labels = one_hot_batch(batch, k);
        let mut rng = LcgRng::new(10);
        let out = cutmix(&images, &labels, batch, c, h, w, k, 1.0, &mut rng).expect("ok");
        assert_eq!(out.images.len(), batch * c * h * w);
        assert_eq!(out.labels.len(), batch * k);
        assert_eq!(out.lambdas.len(), batch);
    }

    #[test]
    fn cutmix_labels_sum_to_one() {
        let batch = 6;
        let (c, h, w, k) = (3, 16, 16, 4);
        let images = vec![0.5_f32; batch * c * h * w];
        let labels = one_hot_batch(batch, k);
        let mut rng = LcgRng::new(11);
        let out = cutmix(&images, &labels, batch, c, h, w, k, 1.0, &mut rng).expect("ok");
        for i in 0..batch {
            let s: f32 = out.labels[i * k..(i + 1) * k].iter().sum();
            assert!((s - 1.0).abs() < 1e-5, "row {i} sum {s}");
        }
    }

    #[test]
    fn cutmix_lambda_matches_area() {
        // Sample `i` is filled with the constant value `i`, so every pixel
        // pasted from a partner `j != i` is recognisable. The permutation
        // cannot be forced, but for every non-self pairing the number of
        // pasted pixels must agree with the reported λ.
        let batch = 4;
        let (c, h, w, k) = (1, 16, 16, 4);
        let images: Vec<f32> = (0..batch).flat_map(|i| vec![i as f32; c * h * w]).collect();
        let labels = one_hot_batch(batch, k);
        let mut rng = LcgRng::new(12);
        let out = cutmix(&images, &labels, batch, c, h, w, k, 1.0, &mut rng).expect("ok");
        let area = (h * w) as f32;
        for i in 0..batch {
            let j = out.partners[i];
            if j == i {
                continue; // self-paste — undetectable
            }
            let vj = j as f32;
            // Count pixels that changed to the partner's constant value.
            let base = i * c * h * w;
            let changed = (0..h * w)
                .filter(|&p| (out.images[base + p] - vj).abs() < 1e-5)
                .count() as f32;
            let observed_lambda = 1.0 - changed / area;
            assert!(
                (observed_lambda - out.lambdas[i]).abs() < 1e-4,
                "sample {i}: observed λ {observed_lambda} vs reported {}",
                out.lambdas[i]
            );
        }
    }

    #[test]
    fn cutmix_lambda_in_unit_range() {
        let batch = 5;
        let (c, h, w, k) = (3, 16, 16, 4);
        let images = vec![0.5_f32; batch * c * h * w];
        let labels = one_hot_batch(batch, k);
        let mut rng = LcgRng::new(13);
        let out = cutmix(&images, &labels, batch, c, h, w, k, 0.5, &mut rng).expect("ok");
        for &l in &out.lambdas {
            assert!((0.0..=1.0).contains(&l), "cutmix λ out of range: {l}");
        }
    }

    #[test]
    fn cutmix_self_paste_identity_when_partner_equal() {
        // Single-sample batch: the only partner is itself → output == input.
        let batch = 1;
        let (c, h, w, k) = (3, 16, 16, 3);
        let mut rng = LcgRng::new(14);
        let mut images = vec![0.0_f32; batch * c * h * w];
        rng.fill_normal(&mut images);
        let labels = one_hot_batch(batch, k);
        let out = cutmix(&images, &labels, batch, c, h, w, k, 1.0, &mut rng).expect("ok");
        assert_eq!(out.images, images, "self-paste must be identity");
    }

    #[test]
    fn cutmix_alpha_zero_is_identity() {
        // α = 0 → λ₀ = 1 → empty patch: images untouched and λ = 1 everywhere.
        let batch = 4;
        let (c, h, w, k) = (3, 8, 8, 4);
        let mut rng = LcgRng::new(20);
        let mut images = vec![0.0_f32; batch * c * h * w];
        rng.fill_normal(&mut images);
        let labels = one_hot_batch(batch, k);
        let out = cutmix(&images, &labels, batch, c, h, w, k, 0.0, &mut rng).expect("ok");
        assert!(out.lambdas.iter().all(|&l| l == 1.0));
        assert_eq!(out.images, images);
    }

    #[test]
    fn cutmix_output_finite() {
        let batch = 4;
        let (c, h, w, k) = (3, 16, 16, 10);
        let mut rng = LcgRng::new(15);
        let mut images = vec![0.0_f32; batch * c * h * w];
        rng.fill_normal(&mut images);
        let labels = one_hot_batch(batch, k);
        let out = cutmix(&images, &labels, batch, c, h, w, k, 0.3, &mut rng).expect("ok");
        assert!(out.images.iter().all(|v| v.is_finite()));
        assert!(out.labels.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn cutmix_bbox_within_bounds() {
        let mut rng = LcgRng::new(16);
        for _ in 0..200 {
            let (x1, y1, x2, y2) = cutmix_bbox(16, 16, 0.3, &mut rng);
            assert!(x1 <= x2 && y1 <= y2);
            assert!(x2 <= 16 && y2 <= 16);
        }
    }

    #[test]
    fn cutmix_empty_errors() {
        let mut rng = LcgRng::new(17);
        let r = cutmix(&[], &[], 0, 3, 8, 8, 5, 0.4, &mut rng);
        assert!(matches!(r, Err(VisionError::EmptyInput(_))));
    }
}