use crate::error::{DatasetsError, Result};
use scirs2_core::ndarray::{s, Array1, Array2, Array3, ArrayView2};
use scirs2_core::rand_distributions::Normal;
use scirs2_core::random::Random;
use scirs2_core::{Rng, RngExt};
use std::sync::Arc;
/// Builds a freshly seeded random number generator for one transform call.
///
/// The seed combines the wall-clock time in nanoseconds with a process-wide
/// call counter. Seeding from whole seconds alone would hand the exact same
/// random stream to every generator created within the same second, which
/// makes all transforms (and all samples) in that second draw identical values.
fn create_rng() -> Random<scirs2_core::rand_prelude::StdRng> {
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    // Guarantees distinct seeds even when two calls observe the same clock value.
    static CALL_COUNTER: AtomicU64 = AtomicU64::new(0);

    // Truncating the u128 nanosecond count to u64 is intentional: only the
    // fast-changing low bits matter for seeding. A clock before the epoch
    // falls back to 0 and relies on the counter alone.
    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_nanos() as u64)
        .unwrap_or(0);
    let ticket = CALL_COUNTER.fetch_add(1, Ordering::Relaxed);

    // Spread consecutive tickets across the whole 64-bit range before mixing
    // them in, so neighbouring calls do not produce neighbouring seeds.
    let seed = nanos ^ ticket.wrapping_mul(0x9E37_79B9_7F4A_7C15);
    Random::seed(seed)
}
pub trait Transform: Send + Sync {
fn transform_2d(&self, data: &Array2<f64>) -> Result<Array2<f64>>;
fn transform_3d(&self, data: &Array3<f64>) -> Result<Array3<f64>> {
let (height, width, channels) = data.dim();
let mut result = Array3::zeros((height, width, channels));
for c in 0..channels {
let channel_2d = data.slice(s![.., .., c]).to_owned();
let transformed = self.transform_2d(&channel_2d)?;
result.slice_mut(s![.., .., c]).assign(&transformed);
}
Ok(result)
}
fn uses_gpu(&self) -> bool {
false
}
fn name(&self) -> &str;
}
/// An ordered chain of [`Transform`]s applied as a unit with a given probability.
///
/// The pipeline is configured builder-style: each `add_transform` / `with_*`
/// call consumes the pipeline and returns the updated one.
#[derive(Clone)]
pub struct AugmentationPipeline {
    /// Transforms applied in insertion order.
    transforms: Vec<Arc<dyn Transform>>,
    /// Chance in `[0, 1]` that a call to `apply_*` runs the transforms at all.
    probability: f64,
    /// Optional seed for the pipeline-level apply/skip decision.
    seed: Option<u64>,
}

impl AugmentationPipeline {
    /// Creates an empty pipeline that always applies (probability `1.0`) and has no seed.
    pub fn new() -> Self {
        Self {
            transforms: Vec::new(),
            probability: 1.0,
            seed: None,
        }
    }

    /// Appends `transform` to the end of the chain.
    pub fn add_transform(mut self, transform: Arc<dyn Transform>) -> Self {
        self.transforms.push(transform);
        self
    }

    /// Sets the probability of applying the pipeline, clamped to `[0, 1]`.
    pub fn with_probability(mut self, prob: f64) -> Self {
        self.probability = prob.clamp(0.0, 1.0);
        self
    }

    /// Sets the seed used for the pipeline-level apply/skip decision.
    ///
    /// NOTE(review): the seed is only used for that single decision. The
    /// generator is rebuilt from the seed on every `apply_*` call, so a seeded
    /// pipeline makes the same apply/skip decision each time, and the
    /// individual transforms do not receive this seed.
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }

    /// Draws the pipeline-level decision: `true` means "return the input unchanged".
    ///
    /// Uses `>=` so that the gate mirrors the `draw < probability` test used by
    /// the individual transforms: probability `0.0` never applies and `1.0`
    /// always applies.
    fn should_skip(&self) -> bool {
        let mut rng = match self.seed {
            Some(seed) => Random::seed(seed),
            None => create_rng(),
        };
        rng.random::<f64>() >= self.probability
    }

    /// Runs every transform's `transform_2d` in order on a copy of `data`.
    ///
    /// Returns an unchanged copy when the pipeline-level draw decides to skip.
    ///
    /// # Errors
    /// Stops at and returns the first error produced by a transform.
    pub fn apply_2d(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
        if self.should_skip() {
            return Ok(data.clone());
        }
        self.transforms
            .iter()
            .try_fold(data.clone(), |current, transform| {
                transform.transform_2d(&current)
            })
    }

    /// Runs every transform's `transform_3d` in order on a copy of `data`.
    ///
    /// Returns an unchanged copy when the pipeline-level draw decides to skip.
    ///
    /// # Errors
    /// Stops at and returns the first error produced by a transform.
    pub fn apply_3d(&self, data: &Array3<f64>) -> Result<Array3<f64>> {
        if self.should_skip() {
            return Ok(data.clone());
        }
        self.transforms
            .iter()
            .try_fold(data.clone(), |current, transform| {
                transform.transform_3d(&current)
            })
    }

    /// Returns `true` when at least one transform in the chain reports GPU usage.
    pub fn uses_gpu(&self) -> bool {
        self.transforms.iter().any(|t| t.uses_gpu())
    }
}
impl Default for AugmentationPipeline {
fn default() -> Self {
Self::new()
}
}
/// Transform that reverses the column order of its input with a given probability.
pub struct HorizontalFlip {
    /// Chance in `[0, 1]` that a call flips its input.
    probability: f64,
}

impl HorizontalFlip {
    /// Creates the transform; `probability` is clamped to `[0, 1]`.
    pub fn new(probability: f64) -> Self {
        let probability = probability.clamp(0.0, 1.0);
        Self { probability }
    }
}
impl Transform for HorizontalFlip {
fn transform_2d(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
let mut rng = create_rng();
if rng.random::<f64>() < self.probability {
let flipped = data.slice(s![.., ..;-1]).to_owned();
Ok(flipped)
} else {
Ok(data.clone())
}
}
fn name(&self) -> &str {
"HorizontalFlip"
}
}
/// Transform that reverses the row order of its input with a given probability.
pub struct VerticalFlip {
    /// Chance in `[0, 1]` that a call flips its input.
    probability: f64,
}

impl VerticalFlip {
    /// Creates the transform; `probability` is clamped to `[0, 1]`.
    pub fn new(probability: f64) -> Self {
        let probability = probability.clamp(0.0, 1.0);
        Self { probability }
    }
}
impl Transform for VerticalFlip {
fn transform_2d(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
let mut rng = create_rng();
if rng.random::<f64>() < self.probability {
let flipped = data.slice(s![..;-1, ..]).to_owned();
Ok(flipped)
} else {
Ok(data.clone())
}
}
fn name(&self) -> &str {
"VerticalFlip"
}
}
/// Transform that rotates its input by one to three quarter turns with a given probability.
pub struct RandomRotation90 {
    /// Chance in `[0, 1]` that a call rotates its input.
    probability: f64,
}

impl RandomRotation90 {
    /// Creates the transform; `probability` is clamped to `[0, 1]`.
    pub fn new(probability: f64) -> Self {
        let probability = probability.clamp(0.0, 1.0);
        Self { probability }
    }

    /// Returns `data` rotated by a single quarter turn.
    ///
    /// An input of shape `(rows, cols)` yields an output of shape
    /// `(cols, rows)` in which input element `(i, j)` lands at
    /// `(j, rows - 1 - i)` (a clockwise turn when row 0 is drawn at the top).
    fn rotate_90(&self, data: &Array2<f64>) -> Array2<f64> {
        let (rows, cols) = data.dim();
        // Output cell (r, c) is read from input cell (rows - 1 - c, r), the
        // inverse of the mapping described above.
        Array2::from_shape_fn((cols, rows), |(r, c)| data[[rows - 1 - c, r]])
    }
}
impl Transform for RandomRotation90 {
fn transform_2d(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
let mut rng = create_rng();
if rng.random::<f64>() < self.probability {
let rotations = (rng.random::<f64>() * 3.0).floor() as usize + 1;
let mut result = data.clone();
for _ in 0..rotations {
result = self.rotate_90(&result);
}
Ok(result)
} else {
Ok(data.clone())
}
}
fn name(&self) -> &str {
"RandomRotation90"
}
}
/// Transform that adds normally distributed noise to every element with a given probability.
pub struct GaussianNoise {
    /// Mean of the noise distribution.
    mean: f64,
    /// Standard deviation of the noise distribution.
    std: f64,
    /// Chance in `[0, 1]` that a call adds noise.
    probability: f64,
}

impl GaussianNoise {
    /// Creates the transform; `probability` is clamped to `[0, 1]`, while
    /// `mean` and `std` are stored as given.
    pub fn new(mean: f64, std: f64, probability: f64) -> Self {
        let probability = probability.clamp(0.0, 1.0);
        Self {
            mean,
            std,
            probability,
        }
    }
}
impl Transform for GaussianNoise {
    /// Adds an independent sample from `Normal(mean, std)` to every element
    /// when the random draw is below the configured probability; otherwise
    /// returns an unchanged copy.
    ///
    /// # Errors
    /// Returns `DatasetsError::ComputationError` when the normal distribution
    /// cannot be constructed from the stored parameters.
    fn transform_2d(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
        let mut rng = create_rng();
        let apply = rng.random::<f64>() < self.probability;
        if !apply {
            return Ok(data.clone());
        }

        let normal = Normal::new(self.mean, self.std).map_err(|e| {
            DatasetsError::ComputationError(format!(
                "Failed to create normal distribution: {}",
                e
            ))
        })?;

        let mut noisy = data.clone();
        // `iter_mut` walks elements in logical row-major order, so samples are
        // assigned to the same cells as a nested row/column loop would.
        for value in noisy.iter_mut() {
            *value += rng.sample(normal);
        }
        Ok(noisy)
    }

    fn name(&self) -> &str {
        "GaussianNoise"
    }
}
/// Transform that adds one random offset to every element with a given probability.
pub struct Brightness {
    /// The two endpoints between which the offset is drawn.
    delta_range: (f64, f64),
    /// Chance in `[0, 1]` that a call shifts its input.
    probability: f64,
}

impl Brightness {
    /// Creates the transform; `probability` is clamped to `[0, 1]` and
    /// `delta_range` is stored as given.
    pub fn new(delta_range: (f64, f64), probability: f64) -> Self {
        let probability = probability.clamp(0.0, 1.0);
        Self {
            delta_range,
            probability,
        }
    }
}
impl Transform for Brightness {
    /// Adds a single offset, drawn by linear interpolation between the two
    /// `delta_range` endpoints, to every element when the random draw is below
    /// the configured probability; otherwise returns an unchanged copy.
    fn transform_2d(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
        let mut rng = create_rng();
        let apply = rng.random::<f64>() < self.probability;
        if !apply {
            return Ok(data.clone());
        }

        let (low, high) = self.delta_range;
        let delta = low + rng.random::<f64>() * (high - low);
        Ok(data.mapv(|value| value + delta))
    }

    fn name(&self) -> &str {
        "Brightness"
    }
}
/// Transform that scales every element's distance from the array mean with a given probability.
pub struct Contrast {
    /// The two endpoints between which the scaling factor is drawn.
    factor_range: (f64, f64),
    /// Chance in `[0, 1]` that a call rescales its input.
    probability: f64,
}

impl Contrast {
    /// Creates the transform; `probability` is clamped to `[0, 1]` and
    /// `factor_range` is stored as given.
    pub fn new(factor_range: (f64, f64), probability: f64) -> Self {
        let probability = probability.clamp(0.0, 1.0);
        Self {
            factor_range,
            probability,
        }
    }
}
impl Transform for Contrast {
    /// Replaces every element `v` with `(v - mean) * factor + mean` when the
    /// random draw is below the configured probability; otherwise returns an
    /// unchanged copy.
    ///
    /// `factor` is drawn by linear interpolation between the two
    /// `factor_range` endpoints, and `mean` is the mean of `data` (`0.0` when
    /// the mean is unavailable).
    fn transform_2d(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
        let mut rng = create_rng();
        let apply = rng.random::<f64>() < self.probability;
        if !apply {
            return Ok(data.clone());
        }

        let (low, high) = self.factor_range;
        let factor = low + rng.random::<f64>() * (high - low);
        let mean = data.mean().unwrap_or(0.0);
        Ok(data.mapv(|value| (value - mean) * factor + mean))
    }

    fn name(&self) -> &str {
        "Contrast"
    }
}
/// Transform that multiplies randomly selected columns by a random factor.
pub struct RandomFeatureScale {
    /// The two endpoints between which each column's factor is drawn.
    scale_range: (f64, f64),
    /// Chance in `[0, 1]` that any given column is scaled.
    feature_probability: f64,
}

impl RandomFeatureScale {
    /// Creates the transform; `feature_probability` is clamped to `[0, 1]`
    /// and `scale_range` is stored as given.
    pub fn new(scale_range: (f64, f64), feature_probability: f64) -> Self {
        let feature_probability = feature_probability.clamp(0.0, 1.0);
        Self {
            scale_range,
            feature_probability,
        }
    }
}
impl Transform for RandomFeatureScale {
    /// Returns a copy of `data` in which each column is, independently with
    /// probability `feature_probability`, multiplied by its own factor drawn
    /// by linear interpolation between the two `scale_range` endpoints.
    fn transform_2d(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
        let mut rng = create_rng();
        let (low, high) = self.scale_range;
        let mut scaled = data.clone();

        for col in 0..data.ncols() {
            // One draw decides whether the column is touched; a second draw
            // (only when selected) picks the factor.
            let selected = rng.random::<f64>() < self.feature_probability;
            if !selected {
                continue;
            }
            let scale = low + rng.random::<f64>() * (high - low);
            scaled
                .column_mut(col)
                .iter_mut()
                .for_each(|value| *value *= scale);
        }
        Ok(scaled)
    }

    fn name(&self) -> &str {
        "RandomFeatureScale"
    }
}
/// Transform that blends rows of its input with other randomly chosen rows.
pub struct Mixup {
    /// Mixing-strength parameter.
    ///
    /// NOTE(review): this value is stored but not read anywhere in the visible
    /// code; the blend weight is drawn uniformly instead. Confirm whether it
    /// was meant to parameterise the weight distribution.
    alpha: f64,
    /// Chance in `[0, 1]` that a call mixes its input.
    probability: f64,
}

impl Mixup {
    /// Creates the transform; `probability` is clamped to `[0, 1]` and
    /// `alpha` is stored as given.
    pub fn new(alpha: f64, probability: f64) -> Self {
        let probability = probability.clamp(0.0, 1.0);
        Self { alpha, probability }
    }
}
impl Transform for Mixup {
    /// Blends each row with a randomly chosen partner row when the random draw
    /// is below the configured probability; otherwise returns an unchanged copy.
    ///
    /// For every row `i` a partner `j` is drawn; when `j != i` the row becomes
    /// `lambda * row_i + (1 - lambda) * row_j`, with `lambda` drawn uniformly
    /// and both rows read from the original `data`. Inputs with fewer than two
    /// rows are returned unchanged.
    ///
    /// NOTE(review): the `alpha` field does not influence `lambda` here.
    fn transform_2d(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
        let mut rng = create_rng();
        let apply = rng.random::<f64>() < self.probability;
        if !apply {
            return Ok(data.clone());
        }

        let rows = data.nrows();
        if rows < 2 {
            return Ok(data.clone());
        }

        let mut mixed = data.clone();
        for i in 0..rows {
            // The modulo keeps the partner index in range even if the scaled
            // draw rounds up to `rows`.
            let partner = (rng.random::<f64>() * rows as f64).floor() as usize % rows;
            if partner == i {
                continue;
            }
            let lambda = rng.random::<f64>();
            let other = data.row(partner);
            for (k, out) in mixed.row_mut(i).iter_mut().enumerate() {
                *out = lambda * data[[i, k]] + (1.0 - lambda) * other[k];
            }
        }
        Ok(mixed)
    }

    fn name(&self) -> &str {
        "Mixup"
    }
}
pub fn standard_image_augmentation(probability: f64) -> AugmentationPipeline {
AugmentationPipeline::new()
.add_transform(Arc::new(HorizontalFlip::new(0.5)))
.add_transform(Arc::new(RandomRotation90::new(0.3)))
.add_transform(Arc::new(Brightness::new((-0.2, 0.2), 0.4)))
.add_transform(Arc::new(Contrast::new((0.8, 1.2), 0.4)))
.add_transform(Arc::new(GaussianNoise::new(0.0, 0.01, 0.3)))
.with_probability(probability)
}
pub fn standard_tabular_augmentation(probability: f64) -> AugmentationPipeline {
AugmentationPipeline::new()
.add_transform(Arc::new(RandomFeatureScale::new((0.9, 1.1), 0.3)))
.add_transform(Arc::new(GaussianNoise::new(0.0, 0.01, 0.2)))
.add_transform(Arc::new(Mixup::new(1.0, 0.5)))
.with_probability(probability)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A flip with probability 1.0 must reverse each row and keep the shape.
    #[test]
    fn test_horizontal_flip() -> Result<()> {
        let values: Vec<f64> = (1u8..=12).map(f64::from).collect();
        let data = Array2::from_shape_vec((3, 4), values)
            .map_err(|e| DatasetsError::InvalidFormat(format!("{}", e)))?;

        // Probability 1.0 makes the flip unconditional.
        let flipped = HorizontalFlip::new(1.0).transform_2d(&data)?;

        assert_eq!(flipped[[0, 0]], 4.0);
        assert_eq!(flipped[[0, 3]], 1.0);
        assert_eq!(flipped.nrows(), 3);
        assert_eq!(flipped.ncols(), 4);
        Ok(())
    }

    /// Noise applied to an all-zero array must leave a non-zero sum and keep the shape.
    #[test]
    fn test_gaussian_noise() -> Result<()> {
        let data = Array2::zeros((10, 10));
        let transform = GaussianNoise::new(0.0, 0.1, 1.0);

        let noisy = transform.transform_2d(&data)?;

        let total = noisy.sum();
        assert!(total.abs() > 1e-10);
        assert_eq!(noisy.dim(), data.dim());
        Ok(())
    }

    /// A degenerate delta range pins the offset, so 0.5 must become roughly 0.6.
    #[test]
    fn test_brightness() -> Result<()> {
        let data = Array2::from_elem((5, 5), 0.5);

        // Both endpoints are 0.1, so the offset is always 0.1.
        let shifted = Brightness::new((0.1, 0.1), 1.0).transform_2d(&data)?;

        assert!((shifted[[0, 0]] - 0.6).abs() < 0.01);
        Ok(())
    }

    /// A two-stage pipeline applied with probability 1.0 must preserve the shape.
    #[test]
    fn test_augmentation_pipeline() -> Result<()> {
        let values: Vec<f64> = (1u8..=9).map(f64::from).collect();
        let data = Array2::from_shape_vec((3, 3), values)
            .map_err(|e| DatasetsError::InvalidFormat(format!("{}", e)))?;

        let flip = Arc::new(HorizontalFlip::new(1.0));
        let brighten = Arc::new(Brightness::new((0.1, 0.1), 1.0));
        let pipeline = AugmentationPipeline::new()
            .add_transform(flip)
            .add_transform(brighten)
            .with_probability(1.0);

        let output = pipeline.apply_2d(&data)?;

        assert_eq!(output.dim(), data.dim());
        Ok(())
    }

    /// Neither standard pipeline contains a GPU-backed transform.
    #[test]
    fn test_standard_pipelines() {
        let image = standard_image_augmentation(0.8);
        assert!(!image.uses_gpu());

        let tabular = standard_tabular_augmentation(0.8);
        assert!(!tabular.uses_gpu());
    }
}