use scirs2_core::ndarray::{Array, Array2, Array4, Dimension, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive};
use scirs2_core::random::Rng;
use std::fmt::Debug;
use crate::error::{OptimError, Result};
use crate::regularizers::Regularizer;
/// MixUp data augmentation: blends pairs of samples in a batch, and their
/// labels, with a random mixing factor.
#[derive(Debug, Clone)]
pub struct MixUp<A>
where
    A: Float,
{
    /// Concentration parameter supplied at construction; `new` rejects values <= 0.
    // The mixing factor is currently drawn without consulting this value, so
    // outside of tests the field is never read — hence the lint allowance.
    #[allow(dead_code)]
    alpha: A,
}
impl<A: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> MixUp<A> {
    /// Creates a MixUp augmenter.
    ///
    /// # Errors
    ///
    /// Returns [`OptimError::InvalidConfig`] if `alpha` is not strictly positive.
    pub fn new(alpha: A) -> Result<Self> {
        if alpha <= A::zero() {
            return Err(OptimError::InvalidConfig(
                "Alpha must be positive".to_string(),
            ));
        }
        Ok(Self { alpha })
    }

    /// Draws the mixing factor `lambda` in `[0, 1)` from an RNG seeded with `seed`.
    ///
    /// The same `seed` always yields the same factor.
    // NOTE(review): `alpha` is not consulted — lambda is uniform rather than
    // Beta(alpha, alpha) as in the original MixUp formulation. Confirm this
    // simplification is intended.
    fn get_mixing_factor(&self, seed: u64) -> A {
        let mut rng = scirs2_core::random::Random::seed(seed);
        let x: f64 = rng.gen_range(0.0..1.0);
        A::from_f64(x).expect("unwrap failed")
    }

    /// Returns a Fisher-Yates shuffle of `0..batch_size` fully determined by `seed`.
    fn shuffled_indices(batch_size: usize, seed: u64) -> Vec<usize> {
        let mut rng = scirs2_core::random::Random::seed(seed);
        let mut indices: Vec<usize> = (0..batch_size).collect();
        for i in (1..batch_size).rev() {
            let j = rng.gen_range(0..i + 1);
            indices.swap(i, j);
        }
        indices
    }

    /// Applies MixUp to a batch.
    ///
    /// A random permutation `p` of the batch is drawn. Every row `i` with
    /// `p[i] != i` is replaced by `lambda * x[i] + (1 - lambda) * x[p[i]]`,
    /// and the label rows are blended with the same `lambda`. Rows mapped to
    /// themselves are returned unchanged.
    ///
    /// Both `lambda` and the permutation are derived from `seed`, so equal
    /// inputs and seeds give equal outputs.
    ///
    /// # Errors
    ///
    /// Returns [`OptimError::InvalidConfig`] if the batch has fewer than two
    /// rows, or if `inputs` and `labels` have different numbers of rows.
    pub fn apply_batch(
        &self,
        inputs: &Array2<A>,
        labels: &Array2<A>,
        seed: u64,
    ) -> Result<(Array2<A>, Array2<A>)> {
        let batch_size = inputs.nrows();
        if batch_size < 2 {
            return Err(OptimError::InvalidConfig(
                "Batch size must be at least 2 for MixUp".to_string(),
            ));
        }
        if labels.nrows() != batch_size {
            return Err(OptimError::InvalidConfig(
                "Number of inputs and labels must match".to_string(),
            ));
        }

        let lambda = self.get_mixing_factor(seed);
        let complement = A::one() - lambda;
        // The permutation uses a seed distinct from the one used for `lambda`.
        // `wrapping_add` avoids an overflow panic when `seed == u64::MAX`.
        let indices = Self::shuffled_indices(batch_size, seed.wrapping_add(1));

        let mut mixed_inputs = inputs.clone();
        let mut mixed_labels = labels.clone();
        for (i, &j) in indices.iter().enumerate() {
            if i == j {
                continue;
            }
            // Read from the untouched originals so that an already-mixed row
            // never feeds into another row's blend.
            for k in 0..inputs.ncols() {
                mixed_inputs[[i, k]] = lambda * inputs[[i, k]] + complement * inputs[[j, k]];
            }
            for k in 0..labels.ncols() {
                mixed_labels[[i, k]] = lambda * labels[[i, k]] + complement * labels[[j, k]];
            }
        }
        Ok((mixed_inputs, mixed_labels))
    }
}
/// CutMix data augmentation: pastes a rectangular patch from one image of a
/// batch onto another and mixes their labels by the pasted area.
#[derive(Debug, Clone)]
pub struct CutMix<A>
where
    A: Float,
{
    /// Shape parameter supplied at construction; `new` rejects values <= 0.
    // The mixing factor is currently drawn without consulting this value, so
    // outside of tests the field is never read — hence the lint allowance.
    #[allow(dead_code)]
    beta: A,
}
impl<A: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> CutMix<A> {
    /// Creates a CutMix augmenter.
    ///
    /// # Errors
    ///
    /// Returns [`OptimError::InvalidConfig`] if `beta` is not strictly positive.
    pub fn new(beta: A) -> Result<Self> {
        if beta <= A::zero() {
            return Err(OptimError::InvalidConfig(
                "Beta must be positive".to_string(),
            ));
        }
        Ok(Self { beta })
    }

    /// Returns the box `(y_min, y_max, x_min, x_max)` centred on `(cy, cx)`.
    ///
    /// The max bounds are exclusive. The nominal side lengths are
    /// `sqrt(1 - lambda)` times `height` and `width`, clamped to at least one
    /// pixel and at most the image size. Each side extends `side / 2` pixels
    /// either way from the centre and is clipped to the image, so the
    /// realised area can be smaller than nominal (it is zero when
    /// `side / 2 == 0`).
    ///
    /// Callers must pass `cy < height` and `cx < width`.
    fn generate_bbox(
        &self,
        height: usize,
        width: usize,
        lambda: A,
        cy: usize,
        cx: usize,
    ) -> (usize, usize, usize, usize) {
        // The same fraction of each spatial dimension is cut.
        let cut_ratio = A::sqrt(A::one() - lambda)
            .to_f64()
            .expect("unwrap failed");
        let cut_h = ((height as f64 * cut_ratio) as usize).max(1).min(height);
        let cut_w = ((width as f64 * cut_ratio) as usize).max(1).min(width);
        let half_h = cut_h / 2;
        let half_w = cut_w / 2;
        let y_min = cy.saturating_sub(half_h);
        let y_max = (cy + half_h).min(height);
        let x_min = cx.saturating_sub(half_w);
        let x_max = (cx + half_w).min(width);
        (y_min, y_max, x_min, x_max)
    }

    /// Draws the mixing factor `lambda` in `[0, 1)` from an RNG seeded with `seed`.
    ///
    /// The same `seed` always yields the same factor.
    // NOTE(review): `beta` is not consulted — lambda is uniform rather than
    // Beta(beta, beta). Confirm this simplification is intended.
    fn get_mixing_factor(&self, seed: u64) -> A {
        let mut rng = scirs2_core::random::Random::seed(seed);
        let x: f64 = rng.gen_range(0.0..1.0);
        A::from_f64(x).expect("unwrap failed")
    }

    /// Applies CutMix to a batch of images laid out as `(batch, channels, height, width)`.
    ///
    /// A random permutation `p` of the batch is drawn. For every image `i`
    /// with `p[i] != i`, a box is cut out of image `p[i]` and pasted onto
    /// image `i` at the same location. Label row `i` becomes
    /// `(1 - r) * labels[i] + r * labels[p[i]]`, where `r` is the pasted box
    /// area divided by the image area.
    ///
    /// `lambda`, the permutation and the box positions are all derived from
    /// `seed`, so equal inputs and seeds give equal outputs.
    ///
    /// # Errors
    ///
    /// Returns [`OptimError::InvalidConfig`] in any of these cases:
    /// - the batch has fewer than two images;
    /// - `images` and `labels` have different batch sizes;
    /// - the image height or width is zero.
    pub fn apply_batch(
        &self,
        images: &Array4<A>,
        labels: &Array2<A>,
        seed: u64,
    ) -> Result<(Array4<A>, Array2<A>)> {
        let batch_size = images.shape()[0];
        if batch_size < 2 {
            return Err(OptimError::InvalidConfig(
                "Batch size must be at least 2 for CutMix".to_string(),
            ));
        }
        if labels.shape()[0] != batch_size {
            return Err(OptimError::InvalidConfig(
                "Number of images and labels must match".to_string(),
            ));
        }
        let channels = images.shape()[1];
        let height = images.shape()[2];
        let width = images.shape()[3];
        // Box centres are sampled from `0..height` and `0..width`, which would
        // be empty ranges (and panic) for zero-sized images.
        if height == 0 || width == 0 {
            return Err(OptimError::InvalidConfig(
                "Image height and width must be non-zero for CutMix".to_string(),
            ));
        }

        let lambda = self.get_mixing_factor(seed);
        // One seeded RNG drives both the permutation and the box placement.
        // `wrapping_add` avoids an overflow panic when `seed == u64::MAX`.
        let mut rng = scirs2_core::random::Random::seed(seed.wrapping_add(1));
        let mut indices: Vec<usize> = (0..batch_size).collect();
        for i in (1..batch_size).rev() {
            let j = rng.gen_range(0..i + 1);
            indices.swap(i, j);
        }

        let image_area = (height * width) as f64;
        let mut mixed_images = images.clone();
        let mut mixed_labels = labels.clone();
        for (i, &j) in indices.iter().enumerate() {
            if i == j {
                continue;
            }
            // Any pixel can be the box centre, including the last row/column.
            let cy = rng.gen_range(0..height);
            let cx = rng.gen_range(0..width);
            let (y_min, y_max, x_min, x_max) = self.generate_bbox(height, width, lambda, cy, cx);

            // Mix labels by the area that was actually pasted, which can
            // differ from `1 - lambda` after clipping.
            let box_area = ((y_max - y_min) * (x_max - x_min)) as f64;
            let actual_lambda = A::from_f64(box_area / image_area).expect("unwrap failed");
            let keep = A::one() - actual_lambda;

            for c in 0..channels {
                for y in y_min..y_max {
                    for x in x_min..x_max {
                        mixed_images[[i, c, y, x]] = images[[j, c, y, x]];
                    }
                }
            }
            for k in 0..labels.shape()[1] {
                mixed_labels[[i, k]] = keep * labels[[i, k]] + actual_lambda * labels[[j, k]];
            }
        }
        Ok((mixed_images, mixed_labels))
    }
}
/// MixUp acts on data batches (via `MixUp::apply_batch`) rather than on model
/// parameters, so as a [`Regularizer`] it is a no-op: gradients are left
/// untouched and the penalty is always zero.
impl<A: Float + Debug + ScalarOperand + FromPrimitive, D: Dimension + Send + Sync> Regularizer<A, D>
    for MixUp<A>
{
    /// Leaves `gradients` unchanged and returns a zero penalty.
    fn apply(&self, _params: &Array<A, D>, _gradients: &mut Array<A, D>) -> Result<A> {
        Ok(A::zero())
    }

    /// Always returns zero; MixUp contributes no parameter penalty.
    fn penalty(&self, _params: &Array<A, D>) -> Result<A> {
        Ok(A::zero())
    }
}
/// CutMix acts on data batches (via `CutMix::apply_batch`) rather than on
/// model parameters, so as a [`Regularizer`] it is a no-op: gradients are left
/// untouched and the penalty is always zero.
impl<A: Float + Debug + ScalarOperand + FromPrimitive, D: Dimension + Send + Sync> Regularizer<A, D>
    for CutMix<A>
{
    /// Leaves `gradients` unchanged and returns a zero penalty.
    fn apply(&self, _params: &Array<A, D>, _gradients: &mut Array<A, D>) -> Result<A> {
        Ok(A::zero())
    }

    /// Always returns zero; CutMix contributes no parameter penalty.
    fn penalty(&self, _params: &Array<A, D>) -> Result<A> {
        Ok(A::zero())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_mixup_creation() {
        let mixup = MixUp::<f64>::new(0.2).expect("unwrap failed");
        assert_eq!(mixup.alpha, 0.2);
        assert!(MixUp::<f64>::new(0.0).is_err());
        assert!(MixUp::<f64>::new(-0.1).is_err());
    }

    #[test]
    fn test_cutmix_creation() {
        let cutmix = CutMix::<f64>::new(1.0).expect("unwrap failed");
        assert_eq!(cutmix.beta, 1.0);
        assert!(CutMix::<f64>::new(0.0).is_err());
        assert!(CutMix::<f64>::new(-0.5).is_err());
    }

    #[test]
    fn test_mixing_factor() {
        let mixup = MixUp::new(0.2).expect("unwrap failed");
        let lambda1 = mixup.get_mixing_factor(42);
        let lambda2 = mixup.get_mixing_factor(42);
        let lambda3 = mixup.get_mixing_factor(123);
        assert_eq!(lambda1, lambda2);
        assert_ne!(lambda1, lambda3);
        assert!((0.0..=1.0).contains(&lambda1));
        assert!((0.0..=1.0).contains(&lambda3));
    }

    #[test]
    fn test_mixup_batch() {
        let mixup = MixUp::new(0.5).expect("unwrap failed");
        let inputs = array![[1.0, 2.0], [3.0, 4.0]];
        let labels = array![[1.0, 0.0], [0.0, 1.0]];
        let (mixed_inputs, mixed_labels) = mixup
            .apply_batch(&inputs, &labels, 42)
            .expect("unwrap failed");
        assert_eq!(mixed_inputs.shape(), inputs.shape());
        assert_eq!(mixed_labels.shape(), labels.shape());
        let min_input_val = inputs.iter().copied().fold(f64::INFINITY, f64::min);
        let max_input_val = inputs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        for i in 0..2 {
            for j in 0..2 {
                assert!(
                    mixed_inputs[[i, j]] >= min_input_val && mixed_inputs[[i, j]] <= max_input_val
                );
            }
            for j in 0..2 {
                assert!(mixed_labels[[i, j]] >= 0.0 && mixed_labels[[i, j]] <= 1.0);
            }
            assert!((mixed_labels.row(i).sum() - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_mixup_rejects_invalid_batches() {
        let mixup = MixUp::<f64>::new(0.5).expect("unwrap failed");
        let single = array![[1.0, 2.0]];
        let single_label = array![[1.0, 0.0]];
        assert!(mixup.apply_batch(&single, &single_label, 0).is_err());
        let inputs = array![[1.0, 2.0], [3.0, 4.0]];
        let labels = array![[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]];
        assert!(mixup.apply_batch(&inputs, &labels, 0).is_err());
    }

    #[test]
    fn test_cutmix_batch() {
        let cutmix = CutMix::new(1.0).expect("unwrap failed");
        let images =
            Array4::from_shape_fn((2, 1, 5, 5), |(i, _, _, _)| if i == 0 { 1.0 } else { 2.0 });
        let labels = array![[1.0, 0.0], [0.0, 1.0]];
        let (mixed_images, mixed_labels) = cutmix
            .apply_batch(&images, &labels, 123)
            .expect("unwrap failed");
        assert_eq!(mixed_images.shape(), images.shape());
        assert_eq!(mixed_labels.shape(), labels.shape());
        let pixels_changed = images
            .index_axis(scirs2_core::ndarray::Axis(0), 0)
            .iter()
            .zip(mixed_images.index_axis(scirs2_core::ndarray::Axis(0), 0).iter())
            .any(|(a, b)| a != b);
        let labels_changed = labels
            .iter()
            .zip(mixed_labels.iter())
            .any(|(a, b)| (a - b).abs() > 1e-10);
        if !pixels_changed && !labels_changed {
            println!("Warning: CutMix algorithm may not be producing expected mixing");
        }
        for i in 0..2 {
            for j in 0..2 {
                assert!(mixed_labels[[i, j]] >= 0.0 && mixed_labels[[i, j]] <= 1.0);
            }
            assert!((mixed_labels.row(i).sum() - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_cutmix_rejects_invalid_batches() {
        let cutmix = CutMix::<f64>::new(1.0).expect("unwrap failed");
        let single = Array4::<f64>::zeros((1, 1, 4, 4));
        let single_label = array![[1.0, 0.0]];
        assert!(cutmix.apply_batch(&single, &single_label, 0).is_err());
        let images = Array4::<f64>::zeros((2, 1, 4, 4));
        let labels = array![[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]];
        assert!(cutmix.apply_batch(&images, &labels, 0).is_err());
    }

    #[test]
    fn test_mixup_regularizer_trait() {
        let mixup = MixUp::new(0.5).expect("unwrap failed");
        let params = array![[1.0, 2.0], [3.0, 4.0]];
        let mut gradients = array![[0.1, 0.2], [0.3, 0.4]];
        let original_gradients = gradients.clone();
        let penalty = mixup.apply(&params, &mut gradients).expect("unwrap failed");
        assert_eq!(penalty, 0.0);
        assert_eq!(gradients, original_gradients);
    }

    #[test]
    fn test_cutmix_regularizer_trait() {
        let cutmix = CutMix::new(1.0).expect("unwrap failed");
        let params = array![[1.0, 2.0], [3.0, 4.0]];
        let mut gradients = array![[0.1, 0.2], [0.3, 0.4]];
        let original_gradients = gradients.clone();
        let penalty = cutmix
            .apply(&params, &mut gradients)
            .expect("unwrap failed");
        assert_eq!(penalty, 0.0);
        assert_eq!(gradients, original_gradients);
    }
}