optirs_core/regularizers/
mixup.rs1use scirs2_core::ndarray::{Array, Array2, Array4, Dimension, ScalarOperand};
8use scirs2_core::numeric::{Float, FromPrimitive};
9use scirs2_core::random::Rng;
10use std::fmt::Debug;
12
13use crate::error::{OptimError, Result};
14use crate::regularizers::Regularizer;
15
16#[derive(Debug, Clone)]
36pub struct MixUp<A: Float> {
37 #[allow(dead_code)]
39 alpha: A,
40}
41
42impl<A: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> MixUp<A> {
43 pub fn new(alpha: A) -> Result<Self> {
53 if alpha <= A::zero() {
54 return Err(OptimError::InvalidConfig(
55 "Alpha must be positive".to_string(),
56 ));
57 }
58
59 Ok(Self { alpha })
60 }
61
62 fn get_mixing_factor(&self, seed: u64) -> A {
72 let mut rng = scirs2_core::random::Random::seed(seed);
73
74 let x: f64 = rng.gen_range(0.0..1.0);
77 A::from_f64(x).expect("unwrap failed")
78 }
79
80 pub fn apply_batch(
92 &self,
93 inputs: &Array2<A>,
94 labels: &Array2<A>,
95 seed: u64,
96 ) -> Result<(Array2<A>, Array2<A>)> {
97 let batch_size = inputs.shape()[0];
98 if batch_size < 2 {
99 return Err(OptimError::InvalidConfig(
100 "Batch size must be at least 2 for MixUp".to_string(),
101 ));
102 }
103
104 if labels.shape()[0] != batch_size {
105 return Err(OptimError::InvalidConfig(
106 "Number of inputs and labels must match".to_string(),
107 ));
108 }
109
110 let mut rng = scirs2_core::random::Random::default();
111 let lambda = self.get_mixing_factor(seed);
112
113 let mut indices: Vec<usize> = (0..batch_size).collect();
115 for i in (1..indices.len()).rev() {
116 let j = rng.gen_range(0..i + 1);
117 indices.swap(i, j);
118 }
119
120 let mut mixed_inputs = inputs.clone();
122 let mut mixed_labels = labels.clone();
123
124 for i in 0..batch_size {
125 let j = indices[i];
126 if i != j {
127 for k in 0..inputs.shape()[1] {
129 mixed_inputs[[i, k]] =
130 lambda * inputs[[i, k]] + (A::one() - lambda) * inputs[[j, k]];
131 }
132
133 for k in 0..labels.shape()[1] {
135 mixed_labels[[i, k]] =
136 lambda * labels[[i, k]] + (A::one() - lambda) * labels[[j, k]];
137 }
138 }
139 }
140
141 Ok((mixed_inputs, mixed_labels))
142 }
143}
144
145#[derive(Debug, Clone)]
166pub struct CutMix<A: Float> {
167 #[allow(dead_code)]
169 beta: A,
170}
171
172impl<A: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> CutMix<A> {
173 pub fn new(beta: A) -> Result<Self> {
183 if beta <= A::zero() {
184 return Err(OptimError::InvalidConfig(
185 "Beta must be positive".to_string(),
186 ));
187 }
188
189 Ok(Self { beta })
190 }
191
192 fn generate_bbox(
205 &self,
206 height: usize,
207 width: usize,
208 lambda: A,
209 rng: &mut scirs2_core::random::Random,
210 ) -> (usize, usize, usize, usize) {
211 let cut_ratio = A::sqrt(A::one() - lambda);
212
213 let h_ratio = cut_ratio.to_f64().expect("unwrap failed");
214 let w_ratio = cut_ratio.to_f64().expect("unwrap failed");
215
216 let cut_h = (height as f64 * h_ratio) as usize;
217 let cut_w = (width as f64 * w_ratio) as usize;
218
219 let cut_h = cut_h.max(1).min(height);
221 let cut_w = cut_w.max(1).min(width);
222
223 let cy = rng.gen_range(0..height - 1);
225 let cx = rng.gen_range(0..width - 1);
226
227 let half_h = cut_h / 2;
229 let half_w = cut_w / 2;
230
231 let y_min = cy.saturating_sub(half_h);
232 let y_max = (cy + half_h).min(height);
233 let x_min = cx.saturating_sub(half_w);
234 let x_max = (cx + half_w).min(width);
235
236 (y_min, y_max, x_min, x_max)
237 }
238
239 fn get_mixing_factor(&self, seed: u64) -> A {
249 let mut rng = scirs2_core::random::Random::seed(seed);
250
251 let x: f64 = rng.gen_range(0.0..1.0);
254 A::from_f64(x).expect("unwrap failed")
255 }
256
257 pub fn apply_batch(
269 &self,
270 images: &Array4<A>,
271 labels: &Array2<A>,
272 seed: u64,
273 ) -> Result<(Array4<A>, Array2<A>)> {
274 let batch_size = images.shape()[0];
275 if batch_size < 2 {
276 return Err(OptimError::InvalidConfig(
277 "Batch size must be at least 2 for CutMix".to_string(),
278 ));
279 }
280
281 if labels.shape()[0] != batch_size {
282 return Err(OptimError::InvalidConfig(
283 "Number of images and labels must match".to_string(),
284 ));
285 }
286
287 let mut rng = scirs2_core::random::Random::seed(seed + 1); let lambda = self.get_mixing_factor(seed);
289
290 let mut indices: Vec<usize> = (0..batch_size).collect();
292 for i in (1..indices.len()).rev() {
293 let j = rng.gen_range(0..i + 1);
294 indices.swap(i, j);
295 }
296
297 let mut bbox_rng = scirs2_core::random::Random::default();
299
300 let mut mixed_images = images.clone();
302 let mut mixed_labels = labels.clone();
303
304 let channels = images.shape()[1];
306 let height = images.shape()[2];
307 let width = images.shape()[3];
308
309 for i in 0..batch_size {
310 let j = indices[i];
311 if i != j {
312 let (y_min, y_max, x_min, x_max) =
314 self.generate_bbox(height, width, lambda, &mut bbox_rng);
315
316 let box_area = (y_max - y_min) * (x_max - x_min);
318 let image_area = height * width;
319 let actual_lambda =
320 A::from_f64(box_area as f64 / image_area as f64).expect("unwrap failed");
321
322 for c in 0..channels {
324 for y in y_min..y_max {
325 for x in x_min..x_max {
326 mixed_images[[i, c, y, x]] = images[[j, c, y, x]];
327 }
328 }
329 }
330
331 for k in 0..labels.shape()[1] {
333 mixed_labels[[i, k]] = (A::one() - actual_lambda) * labels[[i, k]]
334 + actual_lambda * labels[[j, k]];
335 }
336 }
337 }
338
339 Ok((mixed_images, mixed_labels))
340 }
341}
342
343impl<A: Float + Debug + ScalarOperand + FromPrimitive, D: Dimension + Send + Sync> Regularizer<A, D>
345 for MixUp<A>
346{
347 fn apply(&self, _params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
348 Ok(A::zero())
350 }
351
352 fn penalty(&self, params: &Array<A, D>) -> Result<A> {
353 Ok(A::zero())
355 }
356}
357
358impl<A: Float + Debug + ScalarOperand + FromPrimitive, D: Dimension + Send + Sync> Regularizer<A, D>
360 for CutMix<A>
361{
362 fn apply(&self, _params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
363 Ok(A::zero())
365 }
366
367 fn penalty(&self, params: &Array<A, D>) -> Result<A> {
368 Ok(A::zero())
370 }
371}
372
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_mixup_creation() {
        let mixup = MixUp::<f64>::new(0.2).expect("unwrap failed");
        assert_eq!(mixup.alpha, 0.2);

        // Zero and negative alphas are rejected.
        assert!(MixUp::<f64>::new(0.0).is_err());
        assert!(MixUp::<f64>::new(-0.1).is_err());
    }

    #[test]
    fn test_cutmix_creation() {
        let cutmix = CutMix::<f64>::new(1.0).expect("unwrap failed");
        assert_eq!(cutmix.beta, 1.0);

        // Zero and negative betas are rejected.
        assert!(CutMix::<f64>::new(0.0).is_err());
        assert!(CutMix::<f64>::new(-0.5).is_err());
    }

    #[test]
    fn test_mixing_factor() {
        let mixup = MixUp::<f64>::new(0.2).expect("unwrap failed");

        let lambda1 = mixup.get_mixing_factor(42);
        let lambda2 = mixup.get_mixing_factor(42);
        let lambda3 = mixup.get_mixing_factor(123);

        // Deterministic for a fixed seed...
        assert_eq!(lambda1, lambda2);
        // ...and sensitive to the seed.
        assert_ne!(lambda1, lambda3);

        assert!((0.0..=1.0).contains(&lambda1));
        assert!((0.0..=1.0).contains(&lambda3));
    }

    #[test]
    fn test_mixup_batch() {
        let mixup = MixUp::<f64>::new(0.5).expect("unwrap failed");

        let inputs = array![[1.0, 2.0], [3.0, 4.0]];
        let labels = array![[1.0, 0.0], [0.0, 1.0]];

        let (mixed_inputs, mixed_labels) = mixup
            .apply_batch(&inputs, &labels, 42)
            .expect("unwrap failed");

        assert_eq!(mixed_inputs.shape(), inputs.shape());
        assert_eq!(mixed_labels.shape(), labels.shape());

        // Convex combinations of rows stay within the range of the inputs.
        let min_input_val = inputs.iter().copied().fold(f64::INFINITY, f64::min);
        let max_input_val = inputs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        for &value in mixed_inputs.iter() {
            assert!(value >= min_input_val && value <= max_input_val);
        }

        // Mixed one-hot labels remain valid probability vectors.
        for row in mixed_labels.outer_iter() {
            for &value in row.iter() {
                assert!((0.0..=1.0).contains(&value));
            }
            assert!((row.sum() - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_cutmix_batch() {
        let cutmix = CutMix::<f64>::new(1.0).expect("unwrap failed");

        // Image 0 is all ones and image 1 is all twos, so any pasted patch is visible.
        let images =
            Array4::from_shape_fn((2, 1, 5, 5), |(i, _, _, _)| if i == 0 { 1.0 } else { 2.0 });
        let labels = array![[1.0, 0.0], [0.0, 1.0]];

        let (mixed_images, mixed_labels) = cutmix
            .apply_batch(&images, &labels, 123)
            .expect("unwrap failed");

        assert_eq!(mixed_images.shape(), images.shape());
        assert_eq!(mixed_labels.shape(), labels.shape());

        // The random pairing may map every image to itself, and then nothing is
        // mixed. That outcome is legitimate, so it only triggers a warning.
        let images_changed = images
            .iter()
            .zip(mixed_images.iter())
            .any(|(original, mixed)| original != mixed);
        let labels_changed = labels
            .iter()
            .zip(mixed_labels.iter())
            .any(|(original, mixed)| (original - mixed).abs() > 1e-10);
        if !images_changed && !labels_changed {
            println!("Warning: CutMix algorithm may not be producing expected mixing");
        }

        // Mixed one-hot labels remain valid probability vectors.
        for row in mixed_labels.outer_iter() {
            for &value in row.iter() {
                assert!((0.0..=1.0).contains(&value));
            }
            assert!((row.sum() - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_mixup_regularizer_trait() {
        let mixup = MixUp::<f64>::new(0.5).expect("unwrap failed");
        let params = array![[1.0, 2.0], [3.0, 4.0]];
        let mut gradients = array![[0.1, 0.2], [0.3, 0.4]];
        let original_gradients = gradients.clone();

        let penalty = mixup.apply(&params, &mut gradients).expect("unwrap failed");

        // As a regularizer, MixUp neither penalizes nor touches gradients.
        assert_eq!(penalty, 0.0);
        assert_eq!(gradients, original_gradients);
    }

    #[test]
    fn test_cutmix_regularizer_trait() {
        let cutmix = CutMix::<f64>::new(1.0).expect("unwrap failed");
        let params = array![[1.0, 2.0], [3.0, 4.0]];
        let mut gradients = array![[0.1, 0.2], [0.3, 0.4]];
        let original_gradients = gradients.clone();

        let penalty = cutmix
            .apply(&params, &mut gradients)
            .expect("unwrap failed");

        // As a regularizer, CutMix neither penalizes nor touches gradients.
        assert_eq!(penalty, 0.0);
        assert_eq!(gradients, original_gradients);
    }
}