use super::core::Transform;
use crate::{Result, VisionError};
use scirs2_core::random::{Random, Rng};
use scirs2_core::RngExt;
use torsh_tensor::{creation, creation::zeros_mut, Tensor};
#[derive(Debug, Clone)]
pub struct AugMix {
pub width: usize,
pub depth: usize,
pub alpha: f32,
pub mixture_weight: f32,
augmentations: Vec<AugMixOp>,
}
#[derive(Debug, Clone)]
enum AugMixOp {
AutoContrast,
Brightness(f32),
Color(f32),
Contrast(f32),
Equalize,
Posterize(u8),
Rotate(f32),
Sharpness(f32),
ShearX(f32),
ShearY(f32),
Solarize(f32),
TranslateX(f32),
TranslateY(f32),
}
impl AugMix {
pub fn new() -> Self {
Self {
width: 3,
depth: 3,
alpha: 1.0,
mixture_weight: 0.5,
augmentations: Self::default_augmentations(),
}
}
pub fn with_params(width: usize, depth: usize, alpha: f32, mixture_weight: f32) -> Self {
Self {
width,
depth,
alpha,
mixture_weight,
augmentations: Self::default_augmentations(),
}
}
fn default_augmentations() -> Vec<AugMixOp> {
vec![
AugMixOp::AutoContrast,
AugMixOp::Brightness(0.3),
AugMixOp::Color(0.3),
AugMixOp::Contrast(0.3),
AugMixOp::Equalize,
AugMixOp::Posterize(4),
AugMixOp::Rotate(30.0),
AugMixOp::Sharpness(0.3),
AugMixOp::ShearX(0.3),
AugMixOp::ShearY(0.3),
AugMixOp::Solarize(0.5),
AugMixOp::TranslateX(0.3),
AugMixOp::TranslateY(0.3),
]
}
fn apply_operation(
&self,
input: &Tensor<f32>,
op: &AugMixOp,
magnitude: f32,
) -> Result<Tensor<f32>> {
match op {
AugMixOp::AutoContrast => self.auto_contrast(input),
AugMixOp::Brightness(max_mag) => {
let factor = 1.0 + magnitude * max_mag;
self.adjust_brightness(input, factor)
}
AugMixOp::Color(max_mag) => {
let factor = 1.0 + magnitude * max_mag;
self.adjust_saturation(input, factor)
}
AugMixOp::Contrast(max_mag) => {
let factor = 1.0 + magnitude * max_mag;
self.adjust_contrast(input, factor)
}
AugMixOp::Equalize => self.equalize_histogram(input),
AugMixOp::Posterize(bits) => self.posterize(input, *bits),
AugMixOp::Rotate(max_angle) => {
let angle = magnitude * max_angle;
crate::ops::rotate(input, angle)
}
AugMixOp::Sharpness(_)
| AugMixOp::ShearX(_)
| AugMixOp::ShearY(_)
| AugMixOp::Solarize(_)
| AugMixOp::TranslateX(_)
| AugMixOp::TranslateY(_) => {
Ok(input.clone())
}
}
}
fn auto_contrast(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
let shape = input.shape();
let (channels, height, width) = (shape.dims()[0], shape.dims()[1], shape.dims()[2]);
let mut output = input.clone();
output.make_unique()?;
for c in 0..channels {
let mut min_val = f32::INFINITY;
let mut max_val = f32::NEG_INFINITY;
for y in 0..height {
for x in 0..width {
let val = input.get(&[c, y, x])?;
min_val = min_val.min(val);
max_val = max_val.max(val);
}
}
let range = max_val - min_val;
if range > 0.0 {
for y in 0..height {
for x in 0..width {
let val = input.get(&[c, y, x])?;
let normalized = (val - min_val) / range;
output.set(&[c, y, x], normalized)?;
}
}
}
}
Ok(output)
}
fn adjust_brightness(&self, input: &Tensor<f32>, factor: f32) -> Result<Tensor<f32>> {
input
.mul_scalar(factor)
.map_err(|e| VisionError::TensorError(e))
}
fn adjust_saturation(&self, input: &Tensor<f32>, _factor: f32) -> Result<Tensor<f32>> {
Ok(input.clone()) }
fn adjust_contrast(&self, input: &Tensor<f32>, factor: f32) -> Result<Tensor<f32>> {
let mean = input.mean(None, false)?.item()?;
let centered = input.sub_scalar(mean)?;
let adjusted = centered.mul_scalar(factor)?;
adjusted
.add_scalar(mean)
.map_err(|e| VisionError::TensorError(e))
}
fn equalize_histogram(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
Ok(input.clone()) }
fn posterize(&self, input: &Tensor<f32>, bits: u8) -> Result<Tensor<f32>> {
let mask = (1 << bits) - 1;
let shift = 8 - bits;
let shape = input.shape();
let (channels, height, width) = (shape.dims()[0], shape.dims()[1], shape.dims()[2]);
let mut output = input.clone();
output.make_unique()?;
for c in 0..channels {
for y in 0..height {
for x in 0..width {
let val = input.get(&[c, y, x])?;
let quantized =
((val.clamp(0.0, 1.0) * 255.0) as u8 & (mask << shift)) as f32 / 255.0;
output.set(&[c, y, x], quantized)?;
}
}
}
Ok(output)
}
}
impl Transform for AugMix {
fn forward(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
let mut rng = Random::seed(42);
let mut weights: Vec<f32> = (0..self.width).map(|_| rng.random::<f32>()).collect();
let sum: f32 = weights.iter().sum();
if sum > 0.0 {
for w in &mut weights {
*w /= sum;
}
}
let mut augmented_images = Vec::new();
for _ in 0..self.width {
let mut current = input.clone();
let chain_length = rng.gen_range(1..=self.depth);
for _ in 0..chain_length {
let op_idx = rng.gen_range(0..self.augmentations.len());
let magnitude = rng.random::<f32>();
current = self.apply_operation(¤t, &self.augmentations[op_idx], magnitude)?;
}
augmented_images.push(current);
}
let mut mixed =
creation::zeros(input.shape().dims()).expect("tensor creation should succeed");
for (i, aug_img) in augmented_images.iter().enumerate() {
let weighted = aug_img
.mul_scalar(weights[i])
.map_err(|e| VisionError::TensorError(e))?;
mixed = mixed
.add(&weighted)
.map_err(|e| VisionError::TensorError(e))?;
}
let final_weight = self.mixture_weight;
let original_weighted = input
.mul_scalar(1.0 - final_weight)
.map_err(|e| VisionError::TensorError(e))?;
let mixed_weighted = mixed
.mul_scalar(final_weight)
.map_err(|e| VisionError::TensorError(e))?;
original_weighted
.add(&mixed_weighted)
.map_err(|e| VisionError::TensorError(e))
}
fn name(&self) -> &'static str {
"AugMix"
}
fn parameters(&self) -> Vec<(&'static str, String)> {
vec![
("width", format!("{}", self.width)),
("depth", format!("{}", self.depth)),
("alpha", format!("{:.2}", self.alpha)),
("mixture_weight", format!("{:.2}", self.mixture_weight)),
]
}
fn clone_transform(&self) -> Box<dyn Transform> {
Box::new(AugMix {
width: self.width,
depth: self.depth,
alpha: self.alpha,
mixture_weight: self.mixture_weight,
augmentations: self.augmentations.clone(),
})
}
}
pub struct GridMask {
pub prob: f32,
pub d_min: usize,
pub d_max: usize,
pub r_min: f32,
pub r_max: f32,
pub rotate: bool,
pub fill_value: f32,
}
impl GridMask {
pub fn new() -> Self {
Self {
prob: 0.6,
d_min: 96,
d_max: 224,
r_min: 0.5,
r_max: 0.8,
rotate: true,
fill_value: 0.0,
}
}
pub fn with_params(
prob: f32,
d_min: usize,
d_max: usize,
r_min: f32,
r_max: f32,
rotate: bool,
fill_value: f32,
) -> Self {
Self {
prob,
d_min,
d_max,
r_min,
r_max,
rotate,
fill_value,
}
}
fn generate_grid_mask(&self, height: usize, width: usize) -> Result<Tensor<f32>> {
let mut rng = Random::seed(42);
let d = rng.gen_range(self.d_min..=self.d_max);
let r = rng.gen_range(self.r_min..=self.r_max);
let l = ((d as f32 * r) as usize).max(1);
let mask = creation::ones(&[height, width]).expect("tensor creation should succeed");
let mut y_start = rng.gen_range(0..d);
while y_start < height {
let y_end = (y_start + l).min(height);
let mut x_start = rng.gen_range(0..d);
while x_start < width {
let x_end = (x_start + l).min(width);
for y in y_start..y_end {
for x in x_start..x_end {
mask.set(&[y, x], 0.0)?;
}
}
x_start += d;
}
y_start += d;
}
Ok(mask)
}
}
impl Transform for GridMask {
fn forward(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
let mut rng = Random::seed(42);
if rng.random::<f32>() > self.prob {
return Ok(input.clone());
}
let shape = input.shape();
if shape.dims().len() != 3 {
return Err(VisionError::InvalidShape(format!(
"Expected 3D tensor, got {}D",
shape.dims().len()
)));
}
let (channels, height, width) = (shape.dims()[0], shape.dims()[1], shape.dims()[2]);
let mask = self.generate_grid_mask(height, width)?;
let mut output = input.clone();
output.make_unique()?;
for c in 0..channels {
for y in 0..height {
for x in 0..width {
let mask_val = mask.get(&[y, x])?;
if mask_val < 0.5 {
output.set(&[c, y, x], self.fill_value)?;
}
}
}
}
Ok(output)
}
fn name(&self) -> &'static str {
"GridMask"
}
fn parameters(&self) -> Vec<(&'static str, String)> {
vec![
("prob", format!("{:.2}", self.prob)),
("d_range", format!("({}, {})", self.d_min, self.d_max)),
("r_range", format!("({:.2}, {:.2})", self.r_min, self.r_max)),
]
}
fn clone_transform(&self) -> Box<dyn Transform> {
Box::new(GridMask {
prob: self.prob,
d_min: self.d_min,
d_max: self.d_max,
r_min: self.r_min,
r_max: self.r_max,
rotate: self.rotate,
fill_value: self.fill_value,
})
}
}
pub struct Mosaic {
pub prob: f32,
pub size: (usize, usize),
}
impl Mosaic {
pub fn new(size: (usize, usize)) -> Self {
Self { prob: 0.5, size }
}
pub fn with_prob(size: (usize, usize), prob: f32) -> Self {
Self { prob, size }
}
pub fn apply_batch(&self, images: &[Tensor<f32>]) -> Result<Tensor<f32>> {
if images.len() != 4 {
return Err(VisionError::InvalidArgument(
"Mosaic requires exactly 4 images".to_string(),
));
}
let (target_width, target_height) = self.size;
let mut rng = Random::seed(42);
let center_x = rng.gen_range(target_width / 4..target_width * 3 / 4);
let center_y = rng.gen_range(target_height / 4..target_height * 3 / 4);
let channels = images[0].shape().dims()[0];
let mosaic = zeros_mut(&[channels, target_height, target_width]);
let regions = [
(0, 0, center_x, center_y),
(center_x, 0, target_width, center_y),
(0, center_y, center_x, target_height),
(center_x, center_y, target_width, target_height),
];
for (i, image) in images.iter().enumerate() {
let (x_start, y_start, x_end, y_end) = regions[i];
let region_width = x_end - x_start;
let region_height = y_end - y_start;
let resized = crate::ops::resize(image, (region_width, region_height))?;
for c in 0..channels {
for y in 0..region_height {
for x in 0..region_width {
let val = resized.get(&[c, y, x])?;
mosaic.set(&[c, y_start + y, x_start + x], val)?;
}
}
}
}
Ok(mosaic)
}
}
impl Transform for Mosaic {
fn forward(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
Ok(input.clone())
}
fn name(&self) -> &'static str {
"Mosaic"
}
fn parameters(&self) -> Vec<(&'static str, String)> {
vec![
("size", format!("({}, {})", self.size.0, self.size.1)),
("probability", format!("{:.2}", self.prob)),
]
}
fn clone_transform(&self) -> Box<dyn Transform> {
Box::new(Mosaic::with_prob(self.size, self.prob))
}
}
#[cfg(test)]
mod tests {
use super::*;
use torsh_tensor::creation;
#[test]
fn test_augmix_creation() {
let augmix = AugMix::new();
assert_eq!(augmix.width, 3);
assert_eq!(augmix.depth, 3);
assert_eq!(augmix.name(), "AugMix");
}
#[test]
fn test_gridmask_creation() {
let gridmask = GridMask::new();
assert_eq!(gridmask.prob, 0.6);
assert_eq!(gridmask.name(), "GridMask");
}
#[test]
fn test_mosaic_creation() {
let mosaic = Mosaic::new((224, 224));
assert_eq!(mosaic.size, (224, 224));
assert_eq!(mosaic.name(), "Mosaic");
}
#[test]
#[ignore] fn test_transforms_forward() {
let input = creation::ones(&[3, 32, 32]).expect("creation should succeed");
let augmix = AugMix::new();
let result = augmix.forward(&input);
assert!(result.is_ok());
let gridmask = GridMask::new();
let result = gridmask.forward(&input);
assert!(result.is_ok());
let mosaic = Mosaic::new((32, 32));
let result = mosaic.forward(&input);
assert!(result.is_ok());
}
#[test]
fn test_mosaic_apply_batch() {
let images = vec![
creation::ones(&[3, 16, 16]).expect("creation should succeed"),
creation::zeros(&[3, 16, 16]).expect("creation should succeed"),
creation::full(&[3, 16, 16], 0.5).expect("creation should succeed"),
creation::full(&[3, 16, 16], 0.25).expect("creation should succeed"),
];
let mosaic = Mosaic::new((32, 32));
let result = mosaic.apply_batch(&images);
assert!(result.is_ok());
assert_eq!(
result.expect("operation should succeed").shape().dims(),
&[3, 32, 32]
);
}
#[test]
fn test_mosaic_invalid_batch_size() {
let images = vec![creation::ones(&[3, 16, 16]).expect("creation should succeed")];
let mosaic = Mosaic::new((32, 32));
let result = mosaic.apply_batch(&images);
assert!(result.is_err());
}
}