optirs_core/regularizers/
mixup.rs1use scirs2_core::ndarray::{Array, Array2, Array4, Dimension, ScalarOperand};
8use scirs2_core::numeric::{Float, FromPrimitive};
9use scirs2_core::random::Rng;
10use std::fmt::Debug;
12
13use crate::error::{OptimError, Result};
14use crate::regularizers::Regularizer;
15
16#[derive(Debug, Clone)]
36pub struct MixUp<A: Float> {
37 #[allow(dead_code)]
39 alpha: A,
40}
41
42impl<A: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> MixUp<A> {
43 pub fn new(alpha: A) -> Result<Self> {
53 if alpha <= A::zero() {
54 return Err(OptimError::InvalidConfig(
55 "Alpha must be positive".to_string(),
56 ));
57 }
58
59 Ok(Self { alpha })
60 }
61
62 fn get_mixing_factor(&self, seed: u64) -> A {
72 let mut rng = scirs2_core::random::Random::seed(seed);
73
74 let x: f64 = rng.gen_range(0.0..1.0);
77 A::from_f64(x).unwrap()
78 }
79
80 pub fn apply_batch(
92 &self,
93 inputs: &Array2<A>,
94 labels: &Array2<A>,
95 seed: u64,
96 ) -> Result<(Array2<A>, Array2<A>)> {
97 let batch_size = inputs.shape()[0];
98 if batch_size < 2 {
99 return Err(OptimError::InvalidConfig(
100 "Batch size must be at least 2 for MixUp".to_string(),
101 ));
102 }
103
104 if labels.shape()[0] != batch_size {
105 return Err(OptimError::InvalidConfig(
106 "Number of inputs and labels must match".to_string(),
107 ));
108 }
109
110 let mut rng = scirs2_core::random::Random::default();
111 let lambda = self.get_mixing_factor(seed);
112
113 let mut indices: Vec<usize> = (0..batch_size).collect();
115 for i in (1..indices.len()).rev() {
116 let j = rng.gen_range(0..i + 1);
117 indices.swap(i, j);
118 }
119
120 let mut mixed_inputs = inputs.clone();
122 let mut mixed_labels = labels.clone();
123
124 for i in 0..batch_size {
125 let j = indices[i];
126 if i != j {
127 for k in 0..inputs.shape()[1] {
129 mixed_inputs[[i, k]] =
130 lambda * inputs[[i, k]] + (A::one() - lambda) * inputs[[j, k]];
131 }
132
133 for k in 0..labels.shape()[1] {
135 mixed_labels[[i, k]] =
136 lambda * labels[[i, k]] + (A::one() - lambda) * labels[[j, k]];
137 }
138 }
139 }
140
141 Ok((mixed_inputs, mixed_labels))
142 }
143}
144
145#[derive(Debug, Clone)]
166pub struct CutMix<A: Float> {
167 #[allow(dead_code)]
169 beta: A,
170}
171
172impl<A: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> CutMix<A> {
173 pub fn new(beta: A) -> Result<Self> {
183 if beta <= A::zero() {
184 return Err(OptimError::InvalidConfig(
185 "Beta must be positive".to_string(),
186 ));
187 }
188
189 Ok(Self { beta })
190 }
191
192 fn generate_bbox(
205 &self,
206 height: usize,
207 width: usize,
208 lambda: A,
209 rng: &mut scirs2_core::random::Random,
210 ) -> (usize, usize, usize, usize) {
211 let cut_ratio = A::sqrt(A::one() - lambda);
212
213 let h_ratio = cut_ratio.to_f64().unwrap();
214 let w_ratio = cut_ratio.to_f64().unwrap();
215
216 let cut_h = (height as f64 * h_ratio) as usize;
217 let cut_w = (width as f64 * w_ratio) as usize;
218
219 let cut_h = cut_h.max(1).min(height);
221 let cut_w = cut_w.max(1).min(width);
222
223 let cy = rng.gen_range(0..height - 1);
225 let cx = rng.gen_range(0..width - 1);
226
227 let half_h = cut_h / 2;
229 let half_w = cut_w / 2;
230
231 let y_min = cy.saturating_sub(half_h);
232 let y_max = (cy + half_h).min(height);
233 let x_min = cx.saturating_sub(half_w);
234 let x_max = (cx + half_w).min(width);
235
236 (y_min, y_max, x_min, x_max)
237 }
238
239 fn get_mixing_factor(&self, seed: u64) -> A {
249 let mut rng = scirs2_core::random::Random::seed(seed);
250
251 let x: f64 = rng.gen_range(0.0..1.0);
254 A::from_f64(x).unwrap()
255 }
256
257 pub fn apply_batch(
269 &self,
270 images: &Array4<A>,
271 labels: &Array2<A>,
272 seed: u64,
273 ) -> Result<(Array4<A>, Array2<A>)> {
274 let batch_size = images.shape()[0];
275 if batch_size < 2 {
276 return Err(OptimError::InvalidConfig(
277 "Batch size must be at least 2 for CutMix".to_string(),
278 ));
279 }
280
281 if labels.shape()[0] != batch_size {
282 return Err(OptimError::InvalidConfig(
283 "Number of images and labels must match".to_string(),
284 ));
285 }
286
287 let mut rng = scirs2_core::random::Random::seed(seed + 1); let lambda = self.get_mixing_factor(seed);
289
290 let mut indices: Vec<usize> = (0..batch_size).collect();
292 for i in (1..indices.len()).rev() {
293 let j = rng.gen_range(0..i + 1);
294 indices.swap(i, j);
295 }
296
297 let mut bbox_rng = scirs2_core::random::Random::default();
299
300 let mut mixed_images = images.clone();
302 let mut mixed_labels = labels.clone();
303
304 let channels = images.shape()[1];
306 let height = images.shape()[2];
307 let width = images.shape()[3];
308
309 for i in 0..batch_size {
310 let j = indices[i];
311 if i != j {
312 let (y_min, y_max, x_min, x_max) =
314 self.generate_bbox(height, width, lambda, &mut bbox_rng);
315
316 let box_area = (y_max - y_min) * (x_max - x_min);
318 let image_area = height * width;
319 let actual_lambda = A::from_f64(box_area as f64 / image_area as f64).unwrap();
320
321 for c in 0..channels {
323 for y in y_min..y_max {
324 for x in x_min..x_max {
325 mixed_images[[i, c, y, x]] = images[[j, c, y, x]];
326 }
327 }
328 }
329
330 for k in 0..labels.shape()[1] {
332 mixed_labels[[i, k]] = (A::one() - actual_lambda) * labels[[i, k]]
333 + actual_lambda * labels[[j, k]];
334 }
335 }
336 }
337
338 Ok((mixed_images, mixed_labels))
339 }
340}
341
342impl<A: Float + Debug + ScalarOperand + FromPrimitive, D: Dimension + Send + Sync> Regularizer<A, D>
344 for MixUp<A>
345{
346 fn apply(&self, _params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
347 Ok(A::zero())
349 }
350
351 fn penalty(&self, params: &Array<A, D>) -> Result<A> {
352 Ok(A::zero())
354 }
355}
356
357impl<A: Float + Debug + ScalarOperand + FromPrimitive, D: Dimension + Send + Sync> Regularizer<A, D>
359 for CutMix<A>
360{
361 fn apply(&self, _params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
362 Ok(A::zero())
364 }
365
366 fn penalty(&self, params: &Array<A, D>) -> Result<A> {
367 Ok(A::zero())
369 }
370}
371
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// `MixUp::new` accepts a positive alpha and rejects zero or negative values.
    #[test]
    fn test_mixup_creation() {
        let mixup = MixUp::<f64>::new(0.2).unwrap();
        assert_eq!(mixup.alpha, 0.2);

        assert!(MixUp::<f64>::new(0.0).is_err());
        assert!(MixUp::<f64>::new(-0.1).is_err());
    }

    /// `CutMix::new` accepts a positive beta and rejects zero or negative values.
    #[test]
    fn test_cutmix_creation() {
        let cutmix = CutMix::<f64>::new(1.0).unwrap();
        assert_eq!(cutmix.beta, 1.0);

        assert!(CutMix::<f64>::new(0.0).is_err());
        assert!(CutMix::<f64>::new(-0.5).is_err());
    }

    /// The mixing factor is a deterministic function of the seed and lies in `[0, 1]`.
    #[test]
    fn test_mixing_factor() {
        let mixup = MixUp::<f64>::new(0.2).unwrap();

        let lambda1 = mixup.get_mixing_factor(42);
        let lambda2 = mixup.get_mixing_factor(42);
        let lambda3 = mixup.get_mixing_factor(123);

        // Same seed, same factor.
        assert_eq!(lambda1, lambda2);

        // Different seeds should give different factors.
        assert_ne!(lambda1, lambda3);

        assert!((0.0..=1.0).contains(&lambda1));
        assert!((0.0..=1.0).contains(&lambda3));
    }

    /// MixUp keeps shapes, stays inside the input value range and yields
    /// label rows that still sum to one.
    #[test]
    fn test_mixup_batch() {
        let mixup = MixUp::<f64>::new(0.5).unwrap();

        let inputs: Array2<f64> = array![[1.0, 2.0], [3.0, 4.0]];
        let labels: Array2<f64> = array![[1.0, 0.0], [0.0, 1.0]];

        let (mixed_inputs, mixed_labels) = mixup.apply_batch(&inputs, &labels, 42).unwrap();

        assert_eq!(mixed_inputs.shape(), inputs.shape());
        assert_eq!(mixed_labels.shape(), labels.shape());

        // A convex combination can never leave the range of the originals.
        let min_input_val = inputs.iter().copied().fold(f64::INFINITY, f64::min);
        let max_input_val = inputs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        assert!(mixed_inputs
            .iter()
            .all(|&v| v >= min_input_val && v <= max_input_val));

        assert!(mixed_labels.iter().all(|v| (0.0..=1.0).contains(v)));

        for i in 0..2 {
            assert!((mixed_labels.row(i).sum() - 1.0).abs() < 1e-10);
        }
    }

    /// CutMix keeps shapes and yields label rows that are valid distributions.
    #[test]
    fn test_cutmix_batch() {
        let cutmix = CutMix::<f64>::new(1.0).unwrap();

        // Two constant single-channel 5x5 images: all ones and all twos.
        let images: Array4<f64> =
            Array4::from_shape_fn((2, 1, 5, 5), |(i, _, _, _)| if i == 0 { 1.0 } else { 2.0 });
        let labels: Array2<f64> = array![[1.0, 0.0], [0.0, 1.0]];

        let (mixed_images, mixed_labels) = cutmix.apply_batch(&images, &labels, 123).unwrap();

        assert_eq!(mixed_images.shape(), images.shape());
        assert_eq!(mixed_labels.shape(), labels.shape());

        // Mixing is not guaranteed (the shuffle may leave every sample paired
        // with itself), so its absence is only reported, never asserted.
        let pixels_changed =
            (0..5).any(|y| (0..5).any(|x| images[[0, 0, y, x]] != mixed_images[[0, 0, y, x]]));
        let labels_changed = labels
            .iter()
            .zip(mixed_labels.iter())
            .any(|(&before, &after)| (before - after).abs() > 1e-10);

        if !(pixels_changed || labels_changed) {
            println!("Warning: CutMix algorithm may not be producing expected mixing");
        }

        assert!(mixed_labels.iter().all(|v| (0.0..=1.0).contains(v)));

        for i in 0..2 {
            assert!((mixed_labels.row(i).sum() - 1.0).abs() < 1e-10);
        }
    }

    /// As a `Regularizer`, MixUp reports zero penalty and leaves gradients alone.
    #[test]
    fn test_mixup_regularizer_trait() {
        let mixup = MixUp::<f64>::new(0.5).unwrap();
        let params: Array2<f64> = array![[1.0, 2.0], [3.0, 4.0]];
        let mut gradients: Array2<f64> = array![[0.1, 0.2], [0.3, 0.4]];
        let original_gradients = gradients.clone();

        let penalty = mixup.apply(&params, &mut gradients).unwrap();

        assert_eq!(penalty, 0.0);
        assert_eq!(gradients, original_gradients);
    }

    /// As a `Regularizer`, CutMix reports zero penalty and leaves gradients alone.
    #[test]
    fn test_cutmix_regularizer_trait() {
        let cutmix = CutMix::<f64>::new(1.0).unwrap();
        let params: Array2<f64> = array![[1.0, 2.0], [3.0, 4.0]];
        let mut gradients: Array2<f64> = array![[0.1, 0.2], [0.3, 0.4]];
        let original_gradients = gradients.clone();

        let penalty = cutmix.apply(&params, &mut gradients).unwrap();

        assert_eq!(penalty, 0.0);
        assert_eq!(gradients, original_gradients);
    }
}