use crate::error::{NeuralError, Result};
use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use scirs2_core::random::rngs::SmallRng;
use scirs2_core::random::Distribution;
use scirs2_core::random::{thread_rng, Rng, RngExt, SeedableRng};
use std::fmt::Debug;
pub trait Augmentation<F: Float + NumAssign + Debug + ScalarOperand> {
fn apply(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>>;
fn description(&self) -> String;
}
#[derive(Debug, Clone)]
pub struct GaussianNoise<F: Float + NumAssign + Debug + ScalarOperand> {
std: F,
}
impl<F: Float + NumAssign + Debug + ScalarOperand> GaussianNoise<F> {
pub fn new(std: F) -> Self {
Self { std }
}
}
impl<F: Float + NumAssign + Debug + ScalarOperand> Augmentation<F> for GaussianNoise<F> {
fn apply(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
let mut rng = SmallRng::from_rng(&mut thread_rng());
let mut result = input.clone();
for item in result.iter_mut() {
let normal = scirs2_core::random::Normal::new(0.0, self.std.to_f64().unwrap_or(0.1))
.expect("Failed to create normal distribution");
let noise = F::from(rng.sample(normal)).unwrap_or(F::zero());
*item += noise;
}
Ok(result)
}
fn description(&self) -> String {
format!(
"GaussianNoise (std: {:.3})",
self.std.to_f64().unwrap_or(0.0)
)
}
}
#[derive(Debug, Clone)]
pub struct RandomErasing<F: Float + NumAssign + Debug + ScalarOperand> {
probability: f64,
value: F,
}
impl<F: Float + NumAssign + Debug + ScalarOperand> RandomErasing<F> {
pub fn new(probability: f64, value: F) -> Self {
Self { probability, value }
}
}
impl<F: Float + NumAssign + Debug + ScalarOperand> Augmentation<F> for RandomErasing<F> {
fn apply(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
let mut rng = SmallRng::from_rng(&mut thread_rng());
let mut result = input.clone();
if rng.random::<f64>() > self.probability {
return Ok(result);
}
if result.ndim() < 3 {
return Ok(result);
}
Ok(result)
}
fn description(&self) -> String {
format!(
"RandomErasing (prob: {:.2}, value: {:.2})",
self.probability,
self.value.to_f64().unwrap_or(0.0)
)
}
}
#[derive(Debug, Clone)]
pub struct RandomHorizontalFlip<F: Float + NumAssign + Debug + ScalarOperand> {
probability: f64,
_phantom: std::marker::PhantomData<F>,
}
impl<F: Float + NumAssign + Debug + ScalarOperand> RandomHorizontalFlip<F> {
pub fn new(probability: f64) -> Self {
Self {
probability,
_phantom: std::marker::PhantomData,
}
}
}
impl<F: Float + NumAssign + Debug + ScalarOperand> Augmentation<F> for RandomHorizontalFlip<F> {
fn apply(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
let mut rng = SmallRng::from_rng(&mut thread_rng());
let result = input.clone();
if rng.random::<f64>() > self.probability {
return Ok(result);
}
Ok(result)
}
fn description(&self) -> String {
format!("RandomHorizontalFlip (prob: {:.2})", self.probability)
}
}
struct DebugAugmentationWrapper<'a, F: Float + NumAssign + Debug + ScalarOperand> {
inner: &'a dyn Augmentation<F>,
}
impl<F: Float + NumAssign + Debug + ScalarOperand> Debug for DebugAugmentationWrapper<'_, F> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Augmentation({})", self.inner.description())
}
}
pub struct ComposeAugmentation<F: Float + NumAssign + Debug + ScalarOperand> {
augmentations: Vec<Box<dyn Augmentation<F>>>,
}
impl<F: Float + NumAssign + Debug + ScalarOperand> Clone for ComposeAugmentation<F> {
fn clone(&self) -> Self {
Self {
augmentations: Vec::new(),
}
}
}
impl<F: Float + NumAssign + Debug + ScalarOperand> Debug for ComposeAugmentation<F> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut debug_list = f.debug_list();
for augmentation in &self.augmentations {
debug_list.entry(&DebugAugmentationWrapper {
inner: augmentation.as_ref(),
});
}
debug_list.finish()
}
}
impl<F: Float + NumAssign + Debug + ScalarOperand> ComposeAugmentation<F> {
pub fn new(augmentations: Vec<Box<dyn Augmentation<F>>>) -> Self {
Self { augmentations }
}
}
impl<F: Float + NumAssign + Debug + ScalarOperand> Augmentation<F> for ComposeAugmentation<F> {
fn apply(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
let mut data = input.clone();
for augmentation in &self.augmentations {
data = augmentation.apply(&data)?;
}
Ok(data)
}
fn description(&self) -> String {
let descriptions: Vec<String> =
self.augmentations.iter().map(|a| a.description()).collect();
format!("Compose({})", descriptions.join(", "))
}
}