use crate::error::{NeuralError, Result};
use scirs2_core::ndarray::{Array, ArrayD, Axis};
use scirs2_core::numeric::Float;
use scirs2_core::random::{Rng, RngExt};
use scirs2_core::ndarray::ArrayStatCompat;
use std::collections::HashMap;
use std::fmt::Debug;
use statrs::statistics::Statistics;
#[derive(Debug, Clone, PartialEq)]
pub enum ImageAugmentation {
RandomHorizontalFlip {
probability: f64,
},
RandomVerticalFlip {
RandomRotation {
min_angle: f64,
max_angle: f64,
fill_mode: FillMode,
RandomScale {
min_scale: f64,
max_scale: f64,
preserve_aspect_ratio: bool,
RandomCrop {
crop_height: usize,
crop_width: usize,
padding: Option<usize>,
ColorJitter {
brightness: Option<f64>,
contrast: Option<f64>,
saturation: Option<f64>,
hue: Option<f64>,
GaussianNoise {
mean: f64,
std: f64,
RandomErasing {
area_ratio_range: (f64, f64),
aspect_ratio_range: (f64, f64),
fill_value: f64,
ElasticDeformation {
alpha: f64,
sigma: f64,
}
/// How geometric transforms fill pixels that fall outside the source image.
///
/// The per-variant meanings below are the conventional ones implied by the
/// names; the rotation in this module does not read this value yet.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum FillMode {
    /// Fill with the given constant.
    Constant(f64),
    /// Presumably mirror the image at its border.
    Reflect,
    /// Presumably wrap around to the opposite edge.
    Wrap,
    /// Presumably repeat the nearest edge pixel.
    Nearest,
}
/// Text-level augmentation operations.
///
/// NOTE(review): this definition appears truncated. Variant closing braces are
/// missing and `RandomDeletion` shows no fields, so it will not compile as
/// written. The visible code only stores these values (via
/// `addtext_transform`); it never applies them.
pub enum TextAugmentation {
/// Replace `num_replacements` tokens, presumably with synonyms.
SynonymReplacement {
num_replacements: usize,
/// Insert `num_insertions` tokens.
RandomInsertion {
num_insertions: usize,
/// Randomly delete tokens (its parameters are not visible here).
RandomDeletion {
/// Perform `num_swaps` random token swaps.
RandomSwap {
num_swaps: usize,
/// Round-trip through `intermediate_language` (presumably a language code).
BackTranslation {
intermediate_language: String,
/// Paraphrase using the model named by `model_type`.
Paraphrasing {
model_type: String,
/// Audio-level augmentation operations.
///
/// NOTE(review): this definition appears truncated. Variant closing braces are
/// missing, and `TimeMask` presumably also had a mask-count field. The visible
/// code only stores these values (via `add_audio_transform`).
pub enum AudioAugmentation {
/// Time-stretch; the factor presumably comes from `stretch_factor_range`.
TimeStretch {
stretch_factor_range: (f64, f64),
/// Pitch shift; the semitone offset presumably comes from `semitone_range`.
PitchShift {
semitone_range: (f64, f64),
/// Add noise scaled by `noise_factor`.
AddNoise {
noise_factor: f64,
/// Volume change; gain presumably comes from `gain_range` (units not stated).
VolumeAdjust {
gain_range: (f64, f64),
/// Mask `num_masks` frequency bands with widths bounded by `mask_width_range`.
FrequencyMask {
num_masks: usize,
mask_width_range: (usize, usize),
/// Mask time spans with lengths bounded by `mask_length_range`.
TimeMask {
mask_length_range: (usize, usize),
/// Batch-mixing strategies.
///
/// NOTE(review): this definition appears truncated. Closing braces are missing
/// and `MixUp` shows no fields, although the pipeline builder constructs
/// `MixUp { alpha }` and `CutMix { alpha, cut_ratio_range }`.
pub enum MixAugmentation {
/// MixUp: blend pairs of samples and their labels (see `apply_mixup`).
MixUp {
/// CutMix: paste a rectangular patch from another sample; `cut_ratio_range`
/// bounds the sampled cut ratio (see `apply_cutmix`).
CutMix {
cut_ratio_range: (f64, f64),
/// AugMix with the given `severity`, `width` and `depth`.
AugMix {
severity: usize,
width: usize,
depth: usize,
/// Mix hidden representations; `layer_mix_probability` presumably controls
/// how often a layer is mixed.
ManifoldMix {
layer_mix_probability: f64,
pub struct AugmentationManager<
F: Float + Debug + 'static + scirs2_core::ndarray::ScalarOperand + scirs2_core::numeric::FromPrimitive,
> {
image_transforms: Vec<ImageAugmentation>,
text_transforms: Vec<TextAugmentation>,
audio_transforms: Vec<AudioAugmentation>,
mix_strategies: Vec<MixAugmentation>,
rng_seed: Option<u64>,
stats: AugmentationStatistics<F>,
#[derive(Debug, Clone)]
pub struct AugmentationStatistics<
pub samples_processed: usize,
pub avg_intensity: F,
pub transform_counts: HashMap<String, usize>,
pub processing_time_ms: f64,
/// Inherent methods of `AugmentationManager`: construction, registration of
/// transforms, the augmentation routines and statistics access.
impl<F: Float + Debug + 'static + scirs2_core::ndarray::ScalarOperand + scirs2_core::numeric::FromPrimitive>
AugmentationManager<F>
{
/// Create a manager with no registered transforms and zeroed statistics.
///
/// `rng_seed` is stored as given. NOTE(review): the augmentation routines draw
/// from `rng()` / `scirs2_core::random::random` and never read it, so it does
/// not currently make results reproducible.
pub fn new(rng_seed: Option<u64>) -> Self {
    // The parameter was previously named `_rngseed` while the struct literal
    // used the `rng_seed` shorthand, which referred to an unbound name.
    Self {
        image_transforms: Vec::new(),
        text_transforms: Vec::new(),
        audio_transforms: Vec::new(),
        mix_strategies: Vec::new(),
        rng_seed,
        stats: AugmentationStatistics {
            samples_processed: 0,
            avg_intensity: F::zero(),
            transform_counts: HashMap::new(),
            processing_time_ms: 0.0,
        },
    }
}
pub fn add_image_transform(&mut self, transform: ImageAugmentation) {
self.image_transforms.push(transform);
pub fn addtext_transform(&mut self, transform: TextAugmentation) {
self.text_transforms.push(transform);
pub fn add_audio_transform(&mut self, transform: AudioAugmentation) {
self.audio_transforms.push(transform);
pub fn add_mix_strategy(&mut self, strategy: MixAugmentation) {
self.mix_strategies.push(strategy);
pub fn augment_images(&mut self, images: &ArrayD<F>) -> Result<ArrayD<F>> {
let start_time = std::time::Instant::now();
let mut augmented = images.clone();
for transform in &self.image_transforms {
augmented = self.apply_image_transform(&augmented, transform)?;
let transform_name = format!("{transform:?}")
.split(' ')
.next()
.unwrap_or("unknown")
.to_string();
*self
.stats
.transform_counts
.entry(transform_name)
.or_insert(0) += 1;
self.stats.samples_processed += images.shape()[0];
self.stats.processing_time_ms += start_time.elapsed().as_secs_f64() * 1000.0;
Ok(augmented)
fn apply_image_transform(
&self,
images: &ArrayD<F>,
transform: &ImageAugmentation,
) -> Result<ArrayD<F>> {
match transform {
ImageAugmentation::RandomHorizontalFlip { probability } => {
self.random_horizontal_flip(images, *probability)
}
ImageAugmentation::RandomVerticalFlip { probability } => {
self.random_vertical_flip(images, *probability)
ImageAugmentation::RandomRotation {
min_angle,
max_angle,
fill_mode,
} => self.random_rotation(images, *min_angle, *max_angle, *fill_mode),
ImageAugmentation::RandomScale {
min_scale,
max_scale,
preserve_aspect_ratio,
} => self.random_scale(images, *min_scale, *max_scale, *preserve_aspect_ratio),
ImageAugmentation::RandomCrop {
crop_height,
crop_width,
padding,
} => self.random_crop(images, *crop_height, *crop_width, *padding),
ImageAugmentation::ColorJitter {
brightness,
contrast,
saturation,
hue,
} => self.color_jitter(images, *brightness, *contrast, *saturation, *hue),
ImageAugmentation::GaussianNoise {
mean,
std,
probability,
} => self.gaussian_noise(images, *mean, *std, *probability),
ImageAugmentation::RandomErasing {
area_ratio_range,
aspect_ratio_range,
fill_value,
} => self.random_erasing(
images,
*probability,
*area_ratio_range,
*aspect_ratio_range,
*fill_value,
),
ImageAugmentation::ElasticDeformation {
alpha,
sigma,
} => self.elastic_deformation(images, *alpha, *sigma, *probability),
fn random_horizontal_flip(&self, images: &ArrayD<F>, probability: f64) -> Result<ArrayD<F>> {
let mut result = images.clone();
let batch_size = images.shape()[0];
for i in 0..batch_size {
if scirs2_core::random::random::<f64>() < probability {
if images.ndim() >= 4 {
let width_dim = images.ndim() - 1;
let mut sample = result.slice_mut(scirs2_core::ndarray::s![i, .., .., ..]);
sample.invert_axis(Axis(width_dim - 1)); }
Ok(result)
/// Independently flip each sample along its height axis with probability
/// `probability`.
///
/// NOTE(review): most of this function (batch loop, per-sample probability
/// check, result handling, closing braces) and the start of
/// `random_rotation`'s signature appear to be missing from this span.
fn random_vertical_flip(&self, images: &ArrayD<F>, probability: f64) -> Result<ArrayD<F>> {
// For a batch-first (N, C, H, W) input, `ndim - 2` is the H axis. A
// per-sample view has no batch axis, hence the `- 1` when inverting.
let height_dim = images.ndim() - 2;
sample.invert_axis(Axis(height_dim - 1)); fn random_rotation(
// NOTE(review): `random_rotation` is a placeholder. `fill_mode` is unused,
// the sampled angle is discarded, and the visible lines never modify `result`.
_fill_mode: FillMode,
let result = images.clone();
for _i in 0..batch_size {
let _angle = rng().random_range(min_angle..=max_angle);
/// Randomly rescale images by a factor from `[min_scale, max_scale]`.
///
/// NOTE(review): placeholder. The visible lines draw a scale factor but never
/// use it, and `_preserve_aspect_ratio` is unused. The signature and body
/// appear truncated in this span.
fn random_scale(
_preserve_aspect_ratio: bool..// Simplified scaling implementation
// In practice, this would involve proper image scaling algorithms
let _scale = rng().random_range(min_scale..=max_scale);
fn random_crop(
_padding: Option<usize>..if images.ndim() < 4 {
return Err(NeuralError::InvalidArchitecture(
"Random crop requires 4D input (NCHW)".to_string(),
));
let channels = images.shape()[1];
let height = images.shape()[2];
let width = images.shape()[3];
if crop_height > height || crop_width > width {
"Crop size cannot be larger than image size".to_string(),
let mut result = Array::zeros((batch_size, channels, crop_height, crop_width));
let start_h = rng().random_range(0..=(height - crop_height));
let start_w = rng().random_range(0..=(width - crop_width));
let crop = images.slice(scirs2_core::ndarray::s![
i....,
start_h..start_h + crop_height,
start_w..start_w + crop_width
]);
result.slice_mut(scirs2_core::ndarray::s![i, .., .., ..]).assign(&crop);
Ok(result.into_dyn())
fn color_jitter(
_saturation: Option<f64>, _hue: Option<f64>,
if let Some(bright_factor) = brightness {
let factor =
F::from(1.0 + rng().random_range(-bright_factor..=bright_factor)).expect("Operation failed");
result = result * factor;
if let Some(contrast_factor) = contrast {
F::from(1.0 + rng().random_range(-contrast_factor..=contrast_factor))
.expect("Operation failed");
let mean = result.mean_or(F::zero());
result = (result - mean) * factor + mean;
result = result.mapv(|x| x.max(F::zero()).min(F::one()));
fn gaussian_noise(
if scirs2_core::random::random::<f64>() < probability {
let noise = images.mapv(|_| {
let noise_val = rng().random_range(-3.0 * std..=3.0 * std) + mean;
F::from(noise_val).unwrap_or(F::zero())
});
result = result + noise;
/// Erase a random rectangle in samples of an NCHW batch by overwriting it with
/// `fill_value`.
///
/// The rectangle's area is about `area_ratio · H · W` and its height/width
/// ratio is about `aspect_ratio`. Both are drawn with `random_range` from the
/// given ranges. Sizes are truncated and clamped to the image; a zero-sized
/// mask is skipped.
///
/// NOTE(review): the signature, 4-D check, batch loop, per-sample probability
/// check and closing braces appear to be missing from this span. The call
/// site passes `(images, probability, area_ratio_range, aspect_ratio_range,
/// fill_value)`.
fn random_erasing(
"Random erasing requires 4D input (NCHW)".to_string()..let fill_val = F::from(fill_value).unwrap_or(F::zero());
let area_ratio = rng().random_range(area_ratio_range.0..=area_ratio_range.1);
let aspect_ratio =
rng().random_range(aspect_ratio_range.0..=aspect_ratio_range.1);
let target_area = (height * width) as f64 * area_ratio;
// h = sqrt(A·r) and w = sqrt(A/r), so h·w = A and h/w = r (before truncation).
let mask_height = ((target_area * aspect_ratio).sqrt() as usize).min(height);
let mask_width = ((target_area / aspect_ratio).sqrt() as usize).min(width);
if mask_height > 0 && mask_width > 0 {
let start_h = rng().random_range(0..=(height - mask_height));
let start_w = rng().random_range(0..=(width - mask_width));
result
.slice_mut(scirs2_core::ndarray::s![
i....,
start_h..start_h + mask_height,
start_w..start_w + mask_width
])
.fill(fill_val);
/// Intended to apply an elastic deformation.
///
/// NOTE(review): despite the name, the visible lines do not move any pixels.
/// They only add uniform noise in `[-0.05, 0.05]` scaled by `0.01`. `_alpha`
/// is unused and `sigma` is not read in the lines shown. The signature,
/// probability check, `noise` construction and return appear to be missing
/// from this span.
fn elastic_deformation(
_alpha: f64, sigma: f64,
let noise_factor = F::from(0.01).expect("Failed to convert constant to float");
let noise_val = rng().random_range(-0.05..=0.05);
result = result + noise * noise_factor;
pub fn apply_mixup(
&mut self..labels: &ArrayD<F>,) -> Result<(ArrayD<F>, ArrayD<F>)> {
if batch_size < 2 {
return Ok((images.clone(), labels.clone()));
let lambda = self.sample_beta_distribution(alpha)?;
let lambda_f = F::from(lambda).unwrap_or(F::from(0.5).expect("Failed to convert constant to float"));
let mut indices: Vec<usize> = (0..batch_size).collect();
let j = rng().random_range(0..batch_size);
indices.swap(i, j);
let mut mixed_images = images.clone();
let mut mixed_labels = labels.clone();
for (i, &j) in indices.iter().enumerate().take(batch_size) {
let x_i = images.index_axis(scirs2_core::ndarray::Axis(0), i);
let x_j = images.index_axis(scirs2_core::ndarray::Axis(0), j);
let mixed = &x_i * lambda_f + &x_j * (F::one() - lambda_f);
mixed_images
.index_axis_mut(scirs2_core::ndarray::Axis(0), i)
.assign(&mixed);
let y_i = labels.index_axis(scirs2_core::ndarray::Axis(0), i);
let y_j = labels.index_axis(scirs2_core::ndarray::Axis(0), j);
let mixed_label = &y_i * lambda_f + &y_j * (F::one() - lambda_f);
mixed_labels
.assign(&mixed_label);
self.stats.samples_processed += batch_size;
*self
.stats
.transform_counts
.entry("MixUp".to_string())
.or_insert(0) += 1;
Ok((mixed_images, mixed_labels))
/// Apply CutMix to an NCHW batch.
///
/// For each sample `i`, a rectangle copied from partner `j = indices[i]` is
/// pasted at a random position. Labels are then blended by the pasted area
/// fraction, and the `"CutMix"` counter is incremented.
///
/// NOTE(review): the signature (per the tests: `images, labels, alpha,
/// cut_ratio_range`), the 4-D check, the batch/shape bindings, the
/// permutation, the loop header and several assignment targets appear to be
/// missing from this span.
pub fn apply_cutmix(
"CutMix requires 4D input (NCHW)".to_string(),
// NOTE(review): the Beta sample is drawn but unused; the cut size comes from
// `cut_ratio_range` instead.
let _lambda = self.sample_beta_distribution(alpha)?;
let cut_ratio = rng().random_range(cut_ratio_range.0..=cut_ratio_range.1);
// NOTE(review): `sqrt(height * cut_ratio)` is not `height * sqrt(cut_ratio)`,
// so the pasted area fraction (`actual_lambda` below) is usually far smaller
// than `cut_ratio`. For example, 32x32 with ratio 0.25 gives a 2x2 patch,
// about 0.4% of the area. Labels use `actual_lambda`, so they stay consistent
// with the pixels actually pasted.
let cut_height = ((height as f64 * cut_ratio).sqrt() as usize).min(height);
let cut_width = ((width as f64 * cut_ratio).sqrt() as usize).min(width);
let j = indices[i];
let start_h = rng().random_range(0..=(height - cut_height));
let start_w = rng().random_range(0..=(width - cut_width));
let patch = images.slice(scirs2_core::ndarray::s![
j..start_h..start_h + cut_height,
start_w..start_w + cut_width
.slice_mut(scirs2_core::ndarray::s![
i,
..,
start_h..start_h + cut_height,
start_w..start_w + cut_width
])
.assign(&patch);
// Fraction of the image replaced by the partner's pixels; the label weight of
// the original sample is `1 - actual_lambda`.
let actual_lambda = (cut_height * cut_width) as f64 / (height * width) as f64;
let lambda_f = F::from(1.0 - actual_lambda).unwrap_or(F::from(0.5).expect("Failed to convert constant to float"));
// Two-index slices: labels are assumed to be 2-D (batch, classes).
let y_i = labels.slice(scirs2_core::ndarray::s![i, ..]);
let y_j = labels.slice(scirs2_core::ndarray::s![j, ..]);
.slice_mut(scirs2_core::ndarray::s![i, ..])
.entry("CutMix".to_string())
fn sample_beta_distribution(&self, alpha: f64) -> Result<f64> {
if alpha <= 0.0 {
return Ok(0.5);
Ok(scirs2_core::random::random::<f64>())
pub fn get_statistics(&self) -> &AugmentationStatistics<F> {
&self.stats
pub fn reset_statistics(&mut self) {
self.stats = AugmentationStatistics {
samples_processed: 0,
avg_intensity: F::zero(),
transform_counts: HashMap::new(),
processing_time_ms: 0.0,
};
/// Preset "standard" image pipeline: horizontal flip (p = 0.5) followed by
/// transforms whose surviving field sets match `ColorJitter`, `GaussianNoise`
/// (p = 0.3) and `RandomErasing` (p = 0.25).
///
/// NOTE(review): the variant constructor lines (e.g.
/// `ImageAugmentation::ColorJitter {`) and the closing braces appear to be
/// missing from this span.
pub fn create_standard_image_pipeline() -> Vec<ImageAugmentation> {
vec![
ImageAugmentation::RandomHorizontalFlip { probability: 0.5 },
brightness: Some(0.2),
contrast: Some(0.2),
saturation: Some(0.2),
hue: Some(0.1),
mean: 0.0,
std: 0.01,
probability: 0.3,
probability: 0.25,
area_ratio_range: (0.02, 0.33),
aspect_ratio_range: (0.3, 3.3),
fill_value: 0.0,
]
/// Preset "strong" image pipeline. It adds vertical flip (p = 0.2), rotation
/// in [-30, 30], scaling in [0.8, 1.2] and elastic deformation (alpha 1.0,
/// sigma 0.1), and uses stronger colour jitter and noise than the standard
/// preset.
///
/// NOTE(review): the `vec![` opener, several constructor lines, some
/// `probability`/range values and the closing braces appear to be missing
/// from this span, so the exact per-transform probabilities cannot be
/// recovered from it.
pub fn create_strong_image_pipeline() -> Vec<ImageAugmentation> {
ImageAugmentation::RandomVerticalFlip { probability: 0.2 },
min_angle: -30.0,
max_angle: 30.0,
fill_mode: FillMode::Constant(0.0),
min_scale: 0.8,
max_scale: 1.2,
preserve_aspect_ratio: true,
brightness: Some(0.4),
contrast: Some(0.4),
saturation: Some(0.4),
hue: Some(0.2),
std: 0.02,
probability: 0.5,
area_ratio_range: (0.02, 0.4),
alpha: 1.0,
sigma: 0.1,
impl<F: Float + Debug + 'static + scirs2_core::ndarray::ScalarOperand + scirs2_core::numeric::FromPrimitive> Default
for AugmentationManager<F>
fn default() -> Self {
Self::new(None)
/// Augmentation pipeline builder for easy configuration
pub struct AugmentationPipelineBuilder<
manager: AugmentationManager<F>,
AugmentationPipelineBuilder<F>
/// Create a new pipeline builder
pub fn new() -> Self {
manager: AugmentationManager::new(None),
/// Set random seed
pub fn with_seed(mut self, seed: u64) -> Self {
self.manager.rng_seed = Some(seed);
self
/// Add standard image augmentations
pub fn with_standard_image_augmentations(mut self) -> Self {
for transform in AugmentationManager::<F>::create_standard_image_pipeline() {
self.manager.add_image_transform(transform);
/// Add strong image augmentations
pub fn with_strong_image_augmentations(mut self) -> Self {
for transform in AugmentationManager::<F>::create_strong_image_pipeline() {
/// Add MixUp augmentation
pub fn with_mixup(mut self, alpha: f64) -> Self {
self.manager
.add_mix_strategy(MixAugmentation::MixUp { alpha });
pub fn with_cutmix(mut self, alpha: f64, cut_ratiorange: (f64, f64)) -> Self {
self.manager.add_mix_strategy(MixAugmentation::CutMix {
alpha,
cut_ratio_range,
});
pub fn build(self) -> AugmentationManager<F> {
for AugmentationPipelineBuilder<F>
Self::new()
#[cfg(test)]
mod tests {
// Unit tests for the manager, individual image transforms, mixing strategies,
// preset pipelines, the builder and the statistics side effects.
//
// NOTE(review): this module appears truncated. Only the first test has a
// visible `#[test]` attribute, closing braces are missing, and several tests
// use `manager`, `input`, `images` or `labels` without a visible binding.
use super::*;
use scirs2_core::ndarray::{Array2, Array4};
#[test]
fn test_augmentation_manager_creation() {
let manager = AugmentationManager::<f64>::new(Some(42));
assert_eq!(manager.rng_seed, Some(42));
assert_eq!(manager.image_transforms.len(), 0);
// probability 1.0 forces the flip path; asserts shape preservation and the
// samples_processed side effect.
fn test_random_horizontal_flip() {
let mut manager = AugmentationManager::<f64>::new(Some(42));
manager.add_image_transform(ImageAugmentation::RandomHorizontalFlip { probability: 1.0 });
let input =
Array4::<f64>::from_shape_fn((2, 3, 4, 4), |(____)| scirs2_core::random::random()).into_dyn();
let result = manager.augment_images(&input).expect("Operation failed");
assert_eq!(result.shape(), input.shape());
assert!(manager.stats.samples_processed > 0);
// An 8x8 -> 4x4 crop keeps batch and channel dimensions.
fn test_random_crop() {
let manager = AugmentationManager::<f64>::new(None);
let input = Array4::<f64>::ones((2, 3, 8, 8)).into_dyn();
let result = manager.random_crop(&input, 4, 4, None).expect("Operation failed");
assert_eq!(result.shape(), &[2, 3, 4, 4]);
// Brightness and contrast jitter on a constant mid-grey image.
fn test_color_jitter() {
let input = Array4::<f64>::from_elem((1, 3, 4, 4), 0.5).into_dyn();
let result = manager
.color_jitter(&input, Some(0.2), Some(0.2), None, None)
.expect("Operation failed");
// probability 1.0 forces noise onto an all-zero batch.
fn test_gaussian_noise() {
let input = Array4::<f64>::zeros((2, 3, 4, 4)).into_dyn();
let result = manager.gaussian_noise(&input, 0.0, 0.1, 1.0).expect("Operation failed");
// probability 1.0 forces erasing; area 10-30%, aspect ratio 0.5-2.0.
fn test_random_erasing() {
.random_erasing(&input, 1.0, (0.1, 0.3), (0.5, 2.0), 0.0)
// MixUp preserves image/label shapes and records a "MixUp" count.
fn test_mixup() {
let images = Array4::<f64>::ones((4, 3, 8, 8)).into_dyn();
let labels = Array2::<f64>::from_elem((4, 10), 1.0).into_dyn();
let (mixed_images, mixed_labels) = manager.apply_mixup(&images, &labels, 1.0).expect("Operation failed");
assert_eq!(mixed_images.shape(), images.shape());
assert_eq!(mixed_labels.shape(), labels.shape());
assert!(manager.stats.transform_counts.contains_key("MixUp"));
// CutMix records a "CutMix" count.
fn test_cutmix() {
let (mixed_images, mixed_labels) = manager
.apply_cutmix(&images, &labels, 1.0, (0.1, 0.5))
assert!(manager.stats.transform_counts.contains_key("CutMix"));
fn test_standard_pipeline() {
let pipeline = AugmentationManager::<f64>::create_standard_image_pipeline();
assert!(!pipeline.is_empty());
assert!(pipeline.len() >= 3);
// The strong preset must contain more transforms than the standard one.
fn test_strong_pipeline() {
let pipeline = AugmentationManager::<f64>::create_strong_image_pipeline();
assert!(
pipeline.len() > AugmentationManager::<f64>::create_standard_image_pipeline().len()
);
fn test_pipeline_builder() {
let manager = AugmentationPipelineBuilder::<f64>::new()
.with_seed(42)
.with_standard_image_augmentations()
.with_mixup(1.0)
.build();
assert!(!manager.image_transforms.is_empty());
assert!(!manager.mix_strategies.is_empty());
// samples_processed equals the batch size of a single augment_images call.
fn test_augmentation_statistics() {
let mut manager = AugmentationManager::<f64>::new(None);
manager.add_image_transform(ImageAugmentation::RandomHorizontalFlip { probability: 0.5 });
let input = Array4::<f64>::ones((2, 3, 4, 4)).into_dyn();
let _ = manager.augment_images(&input).expect("Operation failed");
let stats = manager.get_statistics();
assert_eq!(stats.samples_processed, 2);
assert!(stats.processing_time_ms >= 0.0);