1use crate::{
24 error::{VisionError, VisionResult},
25 handle::LcgRng,
26};
27
/// Result of [`mixup`] or [`cutmix`] applied to one batch.
///
/// All buffers are indexed by sample `i` in `0..batch`.
#[derive(Debug, Clone, PartialEq)]
pub struct MixOutput {
    /// Mixed images, same length and layout as the input buffer: `batch` contiguous
    /// samples, each `channels` planes of `h * w` row-major values (NCHW).
    pub images: Vec<f32>,
    /// Mixed labels, `batch` rows of `n_classes` values.
    pub labels: Vec<f32>,
    /// Per-sample weight of the sample's own data; the partner contributes
    /// `1 - lambdas[i]`. For CutMix this is the fraction of pixels not overwritten.
    pub lambdas: Vec<f32>,
    /// Index of the sample each row `i` was mixed with (may equal `i`).
    pub partners: Vec<usize>,
}
40
41fn sample_gamma(shape: f32, rng: &mut LcgRng) -> f32 {
46 if shape < 1.0 {
47 let g = sample_gamma(shape + 1.0, rng);
49 let u = rng.next_f32().max(1e-12);
50 return g * u.powf(1.0 / shape);
51 }
52 let d = shape - 1.0 / 3.0;
53 let c = 1.0 / (9.0 * d).sqrt();
54 loop {
55 let (z, _) = rng.next_normal_pair();
57 let v0 = 1.0 + c * z;
58 if v0 <= 0.0 {
59 continue;
60 }
61 let v = v0 * v0 * v0;
62 let u = rng.next_f32().max(1e-12);
63 if u < 1.0 - 0.0331 * z * z * z * z {
64 return d * v;
65 }
66 if u.ln() < 0.5 * z * z + d * (1.0 - v + v.ln()) {
67 return d * v;
68 }
69 }
70}
71
72fn sample_beta_symmetric(alpha: f32, rng: &mut LcgRng) -> f32 {
75 if !alpha.is_finite() || alpha <= 0.0 {
76 return 1.0;
77 }
78 let x = sample_gamma(alpha, rng);
79 let y = sample_gamma(alpha, rng);
80 let s = x + y;
81 if s <= 1e-12 { 0.5 } else { x / s }
82}
83
84#[inline]
85fn validate_batch(
86 images: &[f32],
87 labels: &[f32],
88 batch: usize,
89 channels: usize,
90 h: usize,
91 w: usize,
92 n_classes: usize,
93) -> VisionResult<()> {
94 if batch == 0 {
95 return Err(VisionError::EmptyInput("mixup batch"));
96 }
97 if channels == 0 || h == 0 || w == 0 {
98 return Err(VisionError::InvalidImageSize {
99 height: h,
100 width: w,
101 channels,
102 });
103 }
104 if n_classes == 0 {
105 return Err(VisionError::InvalidNumClasses(n_classes));
106 }
107 let img_expected = batch * channels * h * w;
108 if images.len() != img_expected {
109 return Err(VisionError::DimensionMismatch {
110 expected: img_expected,
111 got: images.len(),
112 });
113 }
114 let lbl_expected = batch * n_classes;
115 if labels.len() != lbl_expected {
116 return Err(VisionError::DimensionMismatch {
117 expected: lbl_expected,
118 got: labels.len(),
119 });
120 }
121 Ok(())
122}
123
124fn random_partners(batch: usize, rng: &mut LcgRng) -> Vec<usize> {
127 let mut perm: Vec<usize> = (0..batch).collect();
128 rng.shuffle(&mut perm);
129 perm
130}
131
/// Writes the convex combination of label rows `i` and `j` into row `i` of `out`.
///
/// Row `k` of a label buffer spans `k * n_classes..(k + 1) * n_classes`. After the
/// call, `out[i][c] = lambda * labels[i][c] + (1 - lambda) * labels[j][c]` for every
/// class `c`. All other rows of `out` are left untouched.
///
/// # Panics
/// Panics if row `i` or `j` lies outside `labels`, or row `i` lies outside `out`
/// (never when `n_classes == 0`).
fn mix_labels_into(
    out: &mut [f32],
    labels: &[f32],
    i: usize,
    j: usize,
    n_classes: usize,
    lambda: f32,
) {
    let row = |k: usize| k * n_classes..(k + 1) * n_classes;
    // Slice each row once so the inner loop runs without per-element bounds checks.
    let own = &labels[row(i)];
    let other = &labels[row(j)];
    for ((dst, &a), &b) in out[row(i)].iter_mut().zip(own).zip(other) {
        *dst = lambda * a + (1.0 - lambda) * b;
    }
}
148
149pub fn mixup(
154 images: &[f32],
155 labels: &[f32],
156 batch: usize,
157 channels: usize,
158 h: usize,
159 w: usize,
160 n_classes: usize,
161 alpha: f32,
162 rng: &mut LcgRng,
163) -> VisionResult<MixOutput> {
164 validate_batch(images, labels, batch, channels, h, w, n_classes)?;
165 let chw = channels * h * w;
166 let partners = random_partners(batch, rng);
167
168 let mut out_images = vec![0.0_f32; images.len()];
169 let mut out_labels = vec![0.0_f32; labels.len()];
170 let mut lambdas = vec![0.0_f32; batch];
171
172 for i in 0..batch {
173 let j = partners[i];
174 let lambda = sample_beta_symmetric(alpha, rng);
175 lambdas[i] = lambda;
176 let bi = i * chw;
177 let bj = j * chw;
178 for p in 0..chw {
179 out_images[bi + p] = lambda * images[bi + p] + (1.0 - lambda) * images[bj + p];
180 }
181 mix_labels_into(&mut out_labels, labels, i, j, n_classes, lambda);
182 }
183
184 Ok(MixOutput {
185 images: out_images,
186 labels: out_labels,
187 lambdas,
188 partners,
189 })
190}
191
192fn cutmix_bbox(h: usize, w: usize, lambda: f32, rng: &mut LcgRng) -> (usize, usize, usize, usize) {
198 let cut_ratio = (1.0 - lambda).max(0.0).sqrt();
199 let cut_h = ((h as f32) * cut_ratio).round() as usize;
200 let cut_w = ((w as f32) * cut_ratio).round() as usize;
201 let cy = rng.next_usize(h);
202 let cx = rng.next_usize(w);
203 let y1 = cy.saturating_sub(cut_h / 2);
204 let x1 = cx.saturating_sub(cut_w / 2);
205 let y2 = (cy + cut_h.div_ceil(2)).min(h);
206 let x2 = (cx + cut_w.div_ceil(2)).min(w);
207 (x1, y1, x2, y2)
208}
209
210pub fn cutmix(
219 images: &[f32],
220 labels: &[f32],
221 batch: usize,
222 channels: usize,
223 h: usize,
224 w: usize,
225 n_classes: usize,
226 alpha: f32,
227 rng: &mut LcgRng,
228) -> VisionResult<MixOutput> {
229 validate_batch(images, labels, batch, channels, h, w, n_classes)?;
230 let chw = channels * h * w;
231 let partners = random_partners(batch, rng);
232 let area = (h * w) as f32;
233
234 let mut out_images = images.to_vec();
235 let mut out_labels = vec![0.0_f32; labels.len()];
236 let mut lambdas = vec![0.0_f32; batch];
237
238 for i in 0..batch {
239 let j = partners[i];
240 let lambda0 = sample_beta_symmetric(alpha, rng);
241 let (x1, y1, x2, y2) = cutmix_bbox(h, w, lambda0, rng);
242 let patch_area = ((x2 - x1) * (y2 - y1)) as f32;
243 let lambda = 1.0 - patch_area / area;
245 lambdas[i] = lambda;
246
247 let bi = i * chw;
248 let bj = j * chw;
249 for c in 0..channels {
250 let ci = bi + c * h * w;
251 let cj = bj + c * h * w;
252 for y in y1..y2 {
253 for x in x1..x2 {
254 out_images[ci + y * w + x] = images[cj + y * w + x];
255 }
256 }
257 }
258 mix_labels_into(&mut out_labels, labels, i, j, n_classes, lambda);
259 }
260
261 Ok(MixOutput {
262 images: out_images,
263 labels: out_labels,
264 lambdas,
265 partners,
266 })
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a `batch x n_classes` one-hot label buffer where sample `i` is assigned
    // class `i % n_classes`.
    fn one_hot_batch(batch: usize, n_classes: usize) -> Vec<f32> {
        let mut labels = vec![0.0_f32; batch * n_classes];
        for i in 0..batch {
            labels[i * n_classes + (i % n_classes)] = 1.0;
        }
        labels
    }

    // Beta(alpha, alpha) samples must be valid mixing weights.
    #[test]
    fn beta_symmetric_in_unit_interval() {
        let mut rng = LcgRng::new(1);
        for _ in 0..1000 {
            let l = sample_beta_symmetric(0.4, &mut rng);
            assert!((0.0..=1.0).contains(&l), "beta sample out of [0,1]: {l}");
        }
    }

    // Non-positive alpha disables mixing (lambda == 1).
    #[test]
    fn beta_alpha_nonpositive_is_one() {
        let mut rng = LcgRng::new(2);
        assert_eq!(sample_beta_symmetric(0.0, &mut rng), 1.0);
        assert_eq!(sample_beta_symmetric(-1.0, &mut rng), 1.0);
    }

    // Covers both the shape < 1 boost path (0.3) and the direct path (>= 1).
    #[test]
    fn gamma_samples_positive() {
        let mut rng = LcgRng::new(3);
        for a in [0.3_f32, 1.0, 2.5, 5.0] {
            for _ in 0..200 {
                let g = sample_gamma(a, &mut rng);
                assert!(g > 0.0 && g.is_finite(), "gamma({a})={g}");
            }
        }
    }

    #[test]
    fn mixup_output_shapes() {
        let batch = 4;
        let (c, h, w, k) = (3, 8, 8, 5);
        let images = vec![0.5_f32; batch * c * h * w];
        let labels = one_hot_batch(batch, k);
        let mut rng = LcgRng::new(4);
        let out = mixup(&images, &labels, batch, c, h, w, k, 0.4, &mut rng).expect("ok");
        assert_eq!(out.images.len(), batch * c * h * w);
        assert_eq!(out.labels.len(), batch * k);
        assert_eq!(out.lambdas.len(), batch);
        assert_eq!(out.partners.len(), batch);
    }

    // A convex combination of two one-hot rows still sums to 1.
    #[test]
    fn mixup_labels_sum_preserved() {
        let batch = 6;
        let (c, h, w, k) = (1, 4, 4, 4);
        let images = vec![0.3_f32; batch * c * h * w];
        let labels = one_hot_batch(batch, k);
        let mut rng = LcgRng::new(5);
        let out = mixup(&images, &labels, batch, c, h, w, k, 0.5, &mut rng).expect("ok");
        for i in 0..batch {
            let s: f32 = out.labels[i * k..(i + 1) * k].iter().sum();
            assert!((s - 1.0).abs() < 1e-5, "row {i} label sum {s} != 1");
        }
    }

    // Every input pixel is 0.5, so any convex combination must stay 0.5.
    #[test]
    fn mixup_constant_images_value_preserved() {
        let batch = 3;
        let (c, h, w, k) = (3, 4, 4, 3);
        let images = vec![0.5_f32; batch * c * h * w];
        let labels = one_hot_batch(batch, k);
        let mut rng = LcgRng::new(6);
        let out = mixup(&images, &labels, batch, c, h, w, k, 0.4, &mut rng).expect("ok");
        assert!(out.images.iter().all(|&v| (v - 0.5).abs() < 1e-5));
    }

    #[test]
    fn mixup_output_finite() {
        let batch = 4;
        let (c, h, w, k) = (3, 8, 8, 10);
        let mut rng = LcgRng::new(7);
        let mut images = vec![0.0_f32; batch * c * h * w];
        rng.fill_normal(&mut images);
        let labels = one_hot_batch(batch, k);
        let out = mixup(&images, &labels, batch, c, h, w, k, 0.2, &mut rng).expect("ok");
        assert!(out.images.iter().all(|v| v.is_finite()));
        assert!(out.labels.iter().all(|v| v.is_finite()));
    }

    // Two RNGs with the same seed must produce identical partners, lambdas and images.
    #[test]
    fn mixup_deterministic_with_seed() {
        let batch = 5;
        let (c, h, w, k) = (3, 8, 8, 4);
        let images = vec![0.4_f32; batch * c * h * w];
        let labels = one_hot_batch(batch, k);
        let mut r1 = LcgRng::new(123);
        let mut r2 = LcgRng::new(123);
        let o1 = mixup(&images, &labels, batch, c, h, w, k, 0.5, &mut r1).expect("ok");
        let o2 = mixup(&images, &labels, batch, c, h, w, k, 0.5, &mut r2).expect("ok");
        assert_eq!(o1.partners, o2.partners);
        assert_eq!(o1.lambdas, o2.lambdas);
        assert_eq!(o1.images, o2.images);
    }

    #[test]
    fn mixup_empty_batch_errors() {
        let mut rng = LcgRng::new(8);
        let r = mixup(&[], &[], 0, 3, 8, 8, 5, 0.4, &mut rng);
        assert!(matches!(r, Err(VisionError::EmptyInput(_))));
    }

    // Labels are sized for 4 classes, but n_classes = 5 is passed.
    #[test]
    fn mixup_label_size_mismatch_errors() {
        let batch = 4;
        let images = vec![0.5_f32; batch * 3 * 8 * 8];
        let labels = vec![0.0_f32; batch * 4];
        let mut rng = LcgRng::new(9);
        let r = mixup(&images, &labels, batch, 3, 8, 8, 5, 0.4, &mut rng);
        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
    }

    #[test]
    fn cutmix_output_shapes() {
        let batch = 4;
        let (c, h, w, k) = (3, 16, 16, 5);
        let images = vec![0.5_f32; batch * c * h * w];
        let labels = one_hot_batch(batch, k);
        let mut rng = LcgRng::new(10);
        let out = cutmix(&images, &labels, batch, c, h, w, k, 1.0, &mut rng).expect("ok");
        assert_eq!(out.images.len(), batch * c * h * w);
        assert_eq!(out.labels.len(), batch * k);
        assert_eq!(out.lambdas.len(), batch);
    }

    #[test]
    fn cutmix_labels_sum_to_one() {
        let batch = 6;
        let (c, h, w, k) = (3, 16, 16, 4);
        let images = vec![0.5_f32; batch * c * h * w];
        let labels = one_hot_batch(batch, k);
        let mut rng = LcgRng::new(11);
        let out = cutmix(&images, &labels, batch, c, h, w, k, 1.0, &mut rng).expect("ok");
        for i in 0..batch {
            let s: f32 = out.labels[i * k..(i + 1) * k].iter().sum();
            assert!((s - 1.0).abs() < 1e-5, "row {i} sum {s}");
        }
    }

    // Sample `i` is filled with the constant `i`. Pixels equal to the partner's value
    // must therefore have been pasted, and their count must match 1 - lambda.
    #[test]
    fn cutmix_lambda_matches_area() {
        let batch = 4;
        let (c, h, w, k) = (1, 16, 16, 4);
        let images: Vec<f32> = (0..batch).flat_map(|i| vec![i as f32; c * h * w]).collect();
        let labels = one_hot_batch(batch, k);
        let mut rng = LcgRng::new(12);
        let out = cutmix(&images, &labels, batch, c, h, w, k, 1.0, &mut rng).expect("ok");
        let area = (h * w) as f32;
        for i in 0..batch {
            let j = out.partners[i];
            let vi = i as f32;
            let vj = j as f32;
            // Self-paired samples: pasted pixels are indistinguishable, so skip them.
            if (vi - vj).abs() < 1e-6 {
                continue;
            }
            let base = i * c * h * w;
            let changed = (0..h * w)
                .filter(|&p| (out.images[base + p] - vj).abs() < 1e-5)
                .count() as f32;
            let observed_lambda = 1.0 - changed / area;
            assert!(
                (observed_lambda - out.lambdas[i]).abs() < 1e-4,
                "sample {i}: observed λ {observed_lambda} vs reported {}",
                out.lambdas[i]
            );
        }
    }

    #[test]
    fn cutmix_lambda_in_unit_range() {
        let batch = 5;
        let (c, h, w, k) = (3, 16, 16, 4);
        let images = vec![0.5_f32; batch * c * h * w];
        let labels = one_hot_batch(batch, k);
        let mut rng = LcgRng::new(13);
        let out = cutmix(&images, &labels, batch, c, h, w, k, 0.5, &mut rng).expect("ok");
        for &l in &out.lambdas {
            assert!((0.0..=1.0).contains(&l), "cutmix λ out of range: {l}");
        }
    }

    // With batch = 1 the only permutation of 0..1 pairs sample 0 with itself,
    // so pasting must leave the image unchanged.
    #[test]
    fn cutmix_self_paste_identity_when_partner_equal() {
        let batch = 1;
        let (c, h, w, k) = (3, 16, 16, 3);
        let mut rng = LcgRng::new(14);
        let mut images = vec![0.0_f32; batch * c * h * w];
        rng.fill_normal(&mut images);
        let labels = one_hot_batch(batch, k);
        let out = cutmix(&images, &labels, batch, c, h, w, k, 1.0, &mut rng).expect("ok");
        assert_eq!(out.images, images, "self-paste must be identity");
    }

    #[test]
    fn cutmix_output_finite() {
        let batch = 4;
        let (c, h, w, k) = (3, 16, 16, 10);
        let mut rng = LcgRng::new(15);
        let mut images = vec![0.0_f32; batch * c * h * w];
        rng.fill_normal(&mut images);
        let labels = one_hot_batch(batch, k);
        let out = cutmix(&images, &labels, batch, c, h, w, k, 0.3, &mut rng).expect("ok");
        assert!(out.images.iter().all(|v| v.is_finite()));
        assert!(out.labels.iter().all(|v| v.is_finite()));
    }

    // Boxes are half-open and clipped to the image, so coordinates never exceed 16.
    #[test]
    fn cutmix_bbox_within_bounds() {
        let mut rng = LcgRng::new(16);
        for _ in 0..200 {
            let (x1, y1, x2, y2) = cutmix_bbox(16, 16, 0.3, &mut rng);
            assert!(x1 <= x2 && y1 <= y2);
            assert!(x2 <= 16 && y2 <= 16);
        }
    }

    #[test]
    fn cutmix_empty_errors() {
        let mut rng = LcgRng::new(17);
        let r = cutmix(&[], &[], 0, 3, 8, 8, 5, 0.4, &mut rng);
        assert!(matches!(r, Err(VisionError::EmptyInput(_))));
    }
}