use torsh_core::{Result as TorshResult, TorshError};
use torsh_tensor::{
creation::{ones, rand},
Tensor,
};
/// Linearly blends two input/label pairs: `lambda * a + (1 - lambda) * b`.
///
/// `lambda` is clamped into `[0, 1]` before mixing so the result is always
/// a convex combination. Returns the blended pair `(mixed_x, mixed_y)`.
///
/// # Errors
/// Propagates any tensor-arithmetic error from the underlying ops
/// (e.g. incompatible shapes between `x1`/`x2` or `y1`/`y2`).
pub fn mixup(
    x1: &Tensor,
    x2: &Tensor,
    y1: &Tensor,
    y2: &Tensor,
    lambda: f32,
) -> TorshResult<(Tensor, Tensor)> {
    // Keep the mixing coefficient in a valid convex-combination range.
    let lam = lambda.clamp(0.0, 1.0);
    let complement = 1.0 - lam;

    let blended_inputs = x1.mul_scalar(lam)?.add(&x2.mul_scalar(complement)?)?;
    let blended_labels = y1.mul_scalar(lam)?.add(&y2.mul_scalar(complement)?)?;

    Ok((blended_inputs, blended_labels))
}
/// CutMix-style augmentation for a batch of 4-D images `[B, C, H, W]`.
///
/// Samples a random rectangular region (size driven by a uniform random
/// lambda, center uniform over the image) and blends the two batches
/// weighted by the **area fraction** of that region.
///
/// NOTE(review): no per-pixel mask is applied — the patch only determines a
/// global blend weight, which approximates the pixel-level paste of the
/// original CutMix formulation; confirm this is intended.
///
/// `_alpha` is currently unused: the mixing coefficient is drawn from a
/// uniform distribution rather than `Beta(alpha, alpha)`.
///
/// Returns `(mixed_x, mixed_y, actual_lambda)`, where `actual_lambda` is the
/// fraction of the image area covered by the sampled box — i.e. the weight
/// given to `x2`/`y2`.
///
/// # Errors
/// Returns an error if the inputs are not 4-D, if the spatial dimensions are
/// zero (which would make the area fraction NaN), or if any tensor op fails.
pub fn cutmix(
    x1: &Tensor,
    x2: &Tensor,
    y1: &Tensor,
    y2: &Tensor,
    _alpha: f32,
) -> TorshResult<(Tensor, Tensor, f32)> {
    let shape_binding = x1.shape();
    let shape = shape_binding.dims();
    if shape.len() != 4 {
        return Err(TorshError::invalid_argument_with_context(
            "Input tensors must be 4D [B, C, H, W]",
            "cutmix",
        ));
    }
    let (h, w) = (shape[2], shape[3]);
    // Guard against zero-area images: `area / (h * w)` below would
    // otherwise divide by zero and produce a NaN lambda.
    if h == 0 || w == 0 {
        return Err(TorshError::invalid_argument_with_context(
            "Spatial dimensions H and W must be non-zero",
            "cutmix",
        ));
    }

    // Uniform lambda in [0, 1) controls the cut size; sqrt(1 - lambda)
    // makes the box *area* (not side length) track 1 - lambda.
    let lambda_data = rand(&[1])?.data()?;
    let lambda = *lambda_data.first().unwrap_or(&0.5);
    let cut_ratio = (1.0_f32 - lambda).sqrt();
    let cut_w = (w as f32 * cut_ratio) as usize;
    let cut_h = (h as f32 * cut_ratio) as usize;

    // Box center, uniform over the image.
    let cx_data = rand(&[1])?.data()?;
    let cx = (*cx_data.first().unwrap_or(&0.5) * w as f32) as usize;
    let cy_data = rand(&[1])?.data()?;
    let cy = (*cy_data.first().unwrap_or(&0.5) * h as f32) as usize;

    // Clip the box to the image bounds. `saturating_sub` prevents underflow
    // when the center sits closer to the edge than half the cut size.
    let x_start = cx.saturating_sub(cut_w / 2).min(w);
    let x_end = (cx + cut_w / 2).min(w);
    let y_start = cy.saturating_sub(cut_h / 2).min(h);
    let y_end = (cy + cut_h / 2).min(h);

    // Effective mixing weight = clipped box area / total image area.
    let actual_lambda = ((x_end - x_start) * (y_end - y_start)) as f32 / (h * w) as f32;

    let mixed_x = x1
        .mul_scalar(1.0 - actual_lambda)?
        .add(&x2.mul_scalar(actual_lambda)?)?;
    let mixed_y = y1
        .mul_scalar(1.0 - actual_lambda)?
        .add(&y2.mul_scalar(actual_lambda)?)?;
    Ok((mixed_x, mixed_y, actual_lambda))
}
/// Stochastically perturbs `input` with small element-wise uniform noise.
///
/// Draws a single uniform sample; if it falls below `probability`, noise in
/// `[-0.05, 0.05)` is added element-wise, otherwise the input is returned
/// unchanged (as a clone).
///
/// # Errors
/// Propagates any error from random generation, tensor construction, or the
/// final addition.
pub fn differentiable_augment(input: &Tensor, probability: f32) -> TorshResult<Tensor> {
    // One coin flip decides whether this call perturbs the input at all.
    let coin = rand(&[1])?.data()?;
    let apply = *coin.first().unwrap_or(&0.5) < probability;
    if !apply {
        return Ok(input.clone());
    }

    // Uniform samples in [0, 1) recentered and rescaled to
    // [-NOISE_SCALE, NOISE_SCALE).
    const NOISE_SCALE: f32 = 0.05;
    let raw_noise: Tensor<f32> = rand(input.shape().dims())?;
    let rescaled: Vec<f32> = raw_noise
        .data()?
        .iter()
        .map(|&v| (v - 0.5f32) * 2.0f32 * NOISE_SCALE)
        .collect();
    let noise = Tensor::from_data(rescaled, input.shape().dims().to_vec(), input.device())?;

    input.add(&noise)
}
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::creation::randn;

    /// Mixup must preserve the shapes of both inputs and labels.
    #[test]
    fn test_mixup_basic() -> TorshResult<()> {
        let x1 = randn(&[2, 3, 4, 4])?;
        let x2 = randn(&[2, 3, 4, 4])?;
        let y1 = randn(&[2, 10])?;
        let y2 = randn(&[2, 10])?;
        let (mixed_x, mixed_y) = mixup(&x1, &x2, &y1, &y2, 0.5)?;
        assert_eq!(x1.shape().dims(), mixed_x.shape().dims());
        assert_eq!(y1.shape().dims(), mixed_y.shape().dims());
        Ok(())
    }

    /// CutMix must preserve shapes and return a lambda in [0, 1].
    #[test]
    fn test_cutmix_basic() -> TorshResult<()> {
        let x1 = randn(&[2, 3, 8, 8])?;
        let x2 = randn(&[2, 3, 8, 8])?;
        let y1 = randn(&[2, 10])?;
        let y2 = randn(&[2, 10])?;
        let (mixed_x, mixed_y, lambda) = cutmix(&x1, &x2, &y1, &y2, 1.0)?;
        assert_eq!(x1.shape().dims(), mixed_x.shape().dims());
        assert_eq!(y1.shape().dims(), mixed_y.shape().dims());
        // Idiomatic range check (clippy: manual_range_contains).
        assert!((0.0..=1.0).contains(&lambda));
        Ok(())
    }

    /// With probability 1.0 the augmentation always fires; the output must
    /// keep the input's shape.
    #[test]
    fn test_differentiable_augment() -> TorshResult<()> {
        let input = randn(&[2, 3, 4, 4])?;
        let augmented = differentiable_augment(&input, 1.0)?;
        assert_eq!(input.shape().dims(), augmented.shape().dims());
        Ok(())
    }
}