use super::core::Transform;
use crate::{Result, VisionError};
use scirs2_core::random::Random;
use torsh_tensor::{creation, creation::zeros_mut, Tensor};
#[derive(Debug, Clone)]
pub struct MixUp {
alpha: f32,
}
impl MixUp {
pub fn new(alpha: f32) -> Self {
assert!(alpha >= 0.0, "Alpha must be non-negative");
Self { alpha }
}
pub fn alpha(&self) -> f32 {
self.alpha
}
pub fn apply_pair(
&self,
input1: &Tensor<f32>,
input2: &Tensor<f32>,
label1: usize,
label2: usize,
num_classes: usize,
) -> Result<(Tensor<f32>, Tensor<f32>)> {
if label1 >= num_classes || label2 >= num_classes {
return Err(VisionError::InvalidArgument(format!(
"Labels ({}, {}) must be less than num_classes ({})",
label1, label2, num_classes
)));
}
let mut rng = Random::seed(42);
let lambda = if self.alpha > 0.0 {
rng.gen_range(0.0..=1.0)
} else {
0.5
};
let mixed_image = input1
.mul_scalar(lambda)?
.add(&input2.mul_scalar(1.0 - lambda)?)?;
let mixed_labels = zeros_mut(&[num_classes]);
if label1 == label2 {
mixed_labels.set(&[label1], 1.0)?;
} else {
mixed_labels.set(&[label1], lambda)?;
mixed_labels.set(&[label2], 1.0 - lambda)?;
}
Ok((mixed_image, mixed_labels))
}
pub fn apply_pair_with_lambda(
&self,
input1: &Tensor<f32>,
input2: &Tensor<f32>,
label1: usize,
label2: usize,
num_classes: usize,
lambda: f32,
) -> Result<(Tensor<f32>, Tensor<f32>)> {
assert!(
(0.0..=1.0).contains(&lambda),
"Lambda must be between 0.0 and 1.0"
);
if label1 >= num_classes || label2 >= num_classes {
return Err(VisionError::InvalidArgument(format!(
"Labels ({}, {}) must be less than num_classes ({})",
label1, label2, num_classes
)));
}
let mixed_image = input1
.mul_scalar(lambda)?
.add(&input2.mul_scalar(1.0 - lambda)?)?;
let mixed_labels = zeros_mut(&[num_classes]);
if label1 == label2 {
mixed_labels.set(&[label1], 1.0)?;
} else {
mixed_labels.set(&[label1], lambda)?;
mixed_labels.set(&[label2], 1.0 - lambda)?;
}
Ok((mixed_image, mixed_labels))
}
}
impl Transform for MixUp {
fn forward(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
Ok(input.clone())
}
fn name(&self) -> &'static str {
"MixUp"
}
fn parameters(&self) -> Vec<(&'static str, String)> {
vec![("alpha", format!("{:.2}", self.alpha))]
}
fn clone_transform(&self) -> Box<dyn Transform> {
Box::new(MixUp::new(self.alpha))
}
}
#[derive(Debug, Clone)]
pub struct CutMix {
alpha: f32,
}
impl CutMix {
pub fn new(alpha: f32) -> Self {
assert!(alpha >= 0.0, "Alpha must be non-negative");
Self { alpha }
}
pub fn alpha(&self) -> f32 {
self.alpha
}
pub fn apply_pair(
&self,
input1: &Tensor<f32>,
input2: &Tensor<f32>,
label1: usize,
label2: usize,
num_classes: usize,
) -> Result<(Tensor<f32>, Tensor<f32>)> {
if label1 >= num_classes || label2 >= num_classes {
return Err(VisionError::InvalidArgument(format!(
"Labels ({}, {}) must be less than num_classes ({})",
label1, label2, num_classes
)));
}
let mut rng = Random::seed(42);
let shape = input1.shape();
if shape.dims().len() != 3 {
return Err(VisionError::InvalidShape(format!(
"Expected 3D tensor (C, H, W), got {}D",
shape.dims().len()
)));
}
let (channels, height, width) = (shape.dims()[0], shape.dims()[1], shape.dims()[2]);
let lambda = if self.alpha > 0.0 {
rng.gen_range(0.0..=1.0)
} else {
0.5
};
let cut_ratio = (1.0f32 - lambda).sqrt();
let cut_w = (width as f32 * cut_ratio) as usize;
let cut_h = (height as f32 * cut_ratio) as usize;
let cx = rng.gen_range(0..width);
let cy = rng.gen_range(0..height);
let x1 = (cx as i32 - cut_w as i32 / 2).max(0) as usize;
let y1 = (cy as i32 - cut_h as i32 / 2).max(0) as usize;
let x2 = (x1 + cut_w).min(width);
let y2 = (y1 + cut_h).min(height);
let mixed_image = input1.clone();
for c in 0..channels {
for y in y1..y2 {
for x in x1..x2 {
let pixel_val = input2.get(&[c, y, x])?;
mixed_image.set(&[c, y, x], pixel_val)?;
}
}
}
let cut_area = (x2 - x1) * (y2 - y1);
let total_area = width * height;
let actual_lambda = 1.0 - (cut_area as f32 / total_area as f32);
let mixed_labels = zeros_mut(&[num_classes]);
mixed_labels.set(&[label1], actual_lambda)?;
mixed_labels.set(&[label2], 1.0 - actual_lambda)?;
Ok((mixed_image, mixed_labels))
}
pub fn apply_pair_with_bbox(
&self,
input1: &Tensor<f32>,
input2: &Tensor<f32>,
label1: usize,
label2: usize,
num_classes: usize,
x1: usize,
y1: usize,
x2: usize,
y2: usize,
) -> Result<(Tensor<f32>, Tensor<f32>)> {
if label1 >= num_classes || label2 >= num_classes {
return Err(VisionError::InvalidArgument(format!(
"Labels ({}, {}) must be less than num_classes ({})",
label1, label2, num_classes
)));
}
let shape = input1.shape();
if shape.dims().len() != 3 {
return Err(VisionError::InvalidShape(format!(
"Expected 3D tensor (C, H, W), got {}D",
shape.dims().len()
)));
}
let (channels, height, width) = (shape.dims()[0], shape.dims()[1], shape.dims()[2]);
if x2 <= x1 || y2 <= y1 || x2 > width || y2 > height {
return Err(VisionError::InvalidArgument(format!(
"Invalid bounding box: ({}, {}, {}, {}) for image size {}x{}",
x1, y1, x2, y2, width, height
)));
}
let mixed_image = input1.clone();
for c in 0..channels {
for y in y1..y2 {
for x in x1..x2 {
let pixel_val = input2.get(&[c, y, x])?;
mixed_image.set(&[c, y, x], pixel_val)?;
}
}
}
let cut_area = (x2 - x1) * (y2 - y1);
let total_area = width * height;
let lambda = 1.0 - (cut_area as f32 / total_area as f32);
let mixed_labels = zeros_mut(&[num_classes]);
mixed_labels.set(&[label1], lambda)?;
mixed_labels.set(&[label2], 1.0 - lambda)?;
Ok((mixed_image, mixed_labels))
}
}
impl Transform for CutMix {
fn forward(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
Ok(input.clone())
}
fn name(&self) -> &'static str {
"CutMix"
}
fn parameters(&self) -> Vec<(&'static str, String)> {
vec![("alpha", format!("{:.2}", self.alpha))]
}
fn clone_transform(&self) -> Box<dyn Transform> {
Box::new(CutMix::new(self.alpha))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::creation;

    /// Absolute tolerance used for all floating-point comparisons.
    const EPS: f32 = 1e-6;

    /// Reads element `index` of `tensor`, panicking if the index is invalid.
    fn value_at(tensor: &Tensor<f32>, index: &[usize]) -> f32 {
        tensor
            .get(index)
            .expect("element retrieval should succeed for valid index")
    }

    /// Asserts that element `index` of `tensor` is within `EPS` of `expected`.
    fn assert_close(tensor: &Tensor<f32>, index: &[usize], expected: f32) {
        let actual = value_at(tensor, index);
        assert!(
            (actual - expected).abs() < EPS,
            "value at {:?} was {}, expected {}",
            index,
            actual,
            expected
        );
    }

    /// Returns an `(all ones, all zeros)` image pair of the given shape.
    fn image_pair(shape: &[usize]) -> (Tensor<f32>, Tensor<f32>) {
        (
            creation::ones(shape).expect("creation should succeed"),
            creation::zeros(shape).expect("creation should succeed"),
        )
    }

    #[test]
    fn test_mixup_creation() {
        let mixup = MixUp::new(1.0);
        assert_eq!(mixup.alpha(), 1.0);
        assert_eq!(mixup.name(), "MixUp");
        let params = mixup.parameters();
        assert_eq!(params.len(), 1);
        assert_eq!(params[0].0, "alpha");
        assert_eq!(params[0].1, "1.00");
    }

    #[test]
    #[should_panic(expected = "Alpha must be non-negative")]
    fn test_mixup_negative_alpha() {
        MixUp::new(-0.1);
    }

    #[test]
    fn test_mixup_apply_pair() {
        let mixup = MixUp::new(0.0);
        let (input1, input2) = image_pair(&[3, 4, 4]);
        let (mixed_image, mixed_labels) = mixup
            .apply_pair(&input1, &input2, 0, 1, 5)
            .expect("operation should succeed");
        assert_eq!(mixed_image.shape().dims(), &[3, 4, 4]);
        assert_eq!(mixed_labels.shape().dims(), &[5]);
        assert_close(&mixed_labels, &[0], 0.5);
        assert_close(&mixed_labels, &[1], 0.5);
        // alpha == 0 fixes λ at 0.5, so ones and zeros blend to 0.5 everywhere.
        assert_close(&mixed_image, &[0, 0, 0], 0.5);
        assert_close(&mixed_image, &[2, 3, 3], 0.5);
    }

    #[test]
    fn test_mixup_apply_pair_with_lambda() {
        let mixup = MixUp::new(1.0);
        let (input1, input2) = image_pair(&[3, 4, 4]);
        let (mixed_image, mixed_labels) = mixup
            .apply_pair_with_lambda(&input1, &input2, 0, 2, 5, 0.3)
            .expect("operation should succeed");
        assert_close(&mixed_labels, &[0], 0.3);
        assert_close(&mixed_labels, &[2], 0.7);
        // 0.3 * 1 + 0.7 * 0
        assert_close(&mixed_image, &[1, 2, 3], 0.3);
    }

    #[test]
    #[should_panic(expected = "Lambda must be between 0.0 and 1.0")]
    fn test_mixup_invalid_lambda() {
        let mixup = MixUp::new(1.0);
        let (input1, input2) = image_pair(&[3, 4, 4]);
        mixup
            .apply_pair_with_lambda(&input1, &input2, 0, 1, 5, 1.5)
            .expect("operation should succeed");
    }

    #[test]
    fn test_mixup_invalid_labels() {
        let mixup = MixUp::new(1.0);
        let (input1, input2) = image_pair(&[3, 4, 4]);
        let result = mixup.apply_pair(&input1, &input2, 5, 1, 5);
        assert!(matches!(result, Err(VisionError::InvalidArgument(_))));
    }

    #[test]
    fn test_cutmix_creation() {
        let cutmix = CutMix::new(1.0);
        assert_eq!(cutmix.alpha(), 1.0);
        assert_eq!(cutmix.name(), "CutMix");
        let params = cutmix.parameters();
        assert_eq!(params.len(), 1);
        assert_eq!(params[0].0, "alpha");
        assert_eq!(params[0].1, "1.00");
    }

    #[test]
    #[should_panic(expected = "Alpha must be non-negative")]
    fn test_cutmix_negative_alpha() {
        CutMix::new(-0.5);
    }

    #[test]
    fn test_cutmix_apply_pair() {
        let cutmix = CutMix::new(1.0);
        let (input1, input2) = image_pair(&[3, 8, 8]);
        let (mixed_image, mixed_labels) = cutmix
            .apply_pair(&input1, &input2, 0, 1, 5)
            .expect("operation should succeed");
        assert_eq!(mixed_image.shape().dims(), &[3, 8, 8]);
        assert_eq!(mixed_labels.shape().dims(), &[5]);
        let weight0 = value_at(&mixed_labels, &[0]);
        let weight1 = value_at(&mixed_labels, &[1]);
        assert!((0.0..=1.0).contains(&weight0));
        assert!((0.0..=1.0).contains(&weight1));
        assert!((weight0 + weight1 - 1.0).abs() < EPS);
    }

    #[test]
    fn test_cutmix_apply_pair_with_bbox() {
        let cutmix = CutMix::new(1.0);
        let (input1, input2) = image_pair(&[3, 8, 8]);
        let (mixed_image, mixed_labels) = cutmix
            .apply_pair_with_bbox(&input1, &input2, 0, 1, 5, 2, 2, 4, 4)
            .expect("operation should succeed");
        for c in 0..3 {
            // Inside the half-open box [2, 4) × [2, 4): pixels come from input2.
            assert_close(&mixed_image, &[c, 2, 2], 0.0);
            assert_close(&mixed_image, &[c, 3, 3], 0.0);
            // Outside the box, including its exclusive end: pixels stay from input1.
            assert_close(&mixed_image, &[c, 0, 0], 1.0);
            assert_close(&mixed_image, &[c, 4, 4], 1.0);
        }
        let expected_lambda = 1.0 - (4.0 / 64.0);
        assert_close(&mixed_labels, &[0], expected_lambda);
        assert_close(&mixed_labels, &[1], 1.0 - expected_lambda);
    }

    #[test]
    fn test_cutmix_invalid_bbox() {
        let cutmix = CutMix::new(1.0);
        let (input1, input2) = image_pair(&[3, 8, 8]);
        // Empty box (x2 == x1).
        let result = cutmix.apply_pair_with_bbox(&input1, &input2, 0, 1, 5, 4, 2, 4, 4);
        assert!(matches!(result, Err(VisionError::InvalidArgument(_))));
        // Box extends past the right edge.
        let result = cutmix.apply_pair_with_bbox(&input1, &input2, 0, 1, 5, 6, 2, 10, 4);
        assert!(matches!(result, Err(VisionError::InvalidArgument(_))));
    }

    #[test]
    fn test_cutmix_invalid_shape() {
        let cutmix = CutMix::new(1.0);
        let (input1, input2) = image_pair(&[8, 8]);
        let result = cutmix.apply_pair(&input1, &input2, 0, 1, 5);
        assert!(matches!(result, Err(VisionError::InvalidShape(_))));
    }

    #[test]
    fn test_cutmix_invalid_labels() {
        let cutmix = CutMix::new(1.0);
        let (input1, input2) = image_pair(&[3, 4, 4]);
        let result = cutmix.apply_pair(&input1, &input2, 0, 5, 5);
        assert!(matches!(result, Err(VisionError::InvalidArgument(_))));
    }

    #[test]
    fn test_transforms_forward() {
        let mixup = MixUp::new(1.0);
        let cutmix = CutMix::new(1.0);
        let input = creation::ones(&[3, 8, 8]).expect("creation should succeed");
        let mixup_result = mixup.forward(&input).expect("forward pass should succeed");
        let cutmix_result = cutmix.forward(&input).expect("forward pass should succeed");
        assert_eq!(value_at(&mixup_result, &[0, 0, 0]), 1.0);
        assert_eq!(value_at(&cutmix_result, &[0, 0, 0]), 1.0);
    }

    #[test]
    fn test_clone_transforms() {
        let mixup = MixUp::new(0.8);
        assert_eq!(mixup.clone_transform().name(), "MixUp");
        let cutmix = CutMix::new(1.2);
        assert_eq!(cutmix.clone_transform().name(), "CutMix");
    }

    #[test]
    fn test_edge_cases() {
        let mixup = MixUp::new(0.0);
        let cutmix = CutMix::new(0.0);
        assert_eq!(mixup.alpha(), 0.0);
        assert_eq!(cutmix.alpha(), 0.0);
        let (input1, input2) = image_pair(&[1, 1, 1]);
        assert!(mixup.apply_pair(&input1, &input2, 0, 1, 2).is_ok());
        assert!(cutmix.apply_pair(&input1, &input2, 0, 1, 2).is_ok());
    }

    #[test]
    fn test_same_labels() {
        let mixup = MixUp::new(1.0);
        let (input1, input2) = image_pair(&[3, 4, 4]);
        let (_, mixed_labels) = mixup
            .apply_pair(&input1, &input2, 2, 2, 5)
            .expect("operation should succeed");
        assert_close(&mixed_labels, &[2], 1.0);
        for class in (0..5).filter(|&class| class != 2) {
            assert_close(&mixed_labels, &[class], 0.0);
        }
    }
}