scirs2_neural/data/
augmentation.rs1use crate::error::{NeuralError, Result};
4use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
5use scirs2_core::numeric::{Float, NumAssign};
6use scirs2_core::random::rngs::SmallRng;
7use scirs2_core::random::Distribution;
8use scirs2_core::random::{thread_rng, Rng, RngExt, SeedableRng};
9use std::fmt::Debug;
10
11pub trait Augmentation<F: Float + NumAssign + Debug + ScalarOperand> {
13 fn apply(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>>;
15
16 fn description(&self) -> String;
18}
19
20#[derive(Debug, Clone)]
22pub struct GaussianNoise<F: Float + NumAssign + Debug + ScalarOperand> {
23 std: F,
25}
26
27impl<F: Float + NumAssign + Debug + ScalarOperand> GaussianNoise<F> {
28 pub fn new(std: F) -> Self {
30 Self { std }
31 }
32}
33
34impl<F: Float + NumAssign + Debug + ScalarOperand> Augmentation<F> for GaussianNoise<F> {
35 fn apply(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
36 let mut rng = SmallRng::from_rng(&mut thread_rng());
37 let mut result = input.clone();
38
39 for item in result.iter_mut() {
40 let normal = scirs2_core::random::Normal::new(0.0, self.std.to_f64().unwrap_or(0.1))
42 .expect("Failed to create normal distribution");
43
44 let noise = F::from(rng.sample(normal)).unwrap_or(F::zero());
46 *item += noise;
47 }
48
49 Ok(result)
50 }
51
52 fn description(&self) -> String {
53 format!(
54 "GaussianNoise (std: {:.3})",
55 self.std.to_f64().unwrap_or(0.0)
56 )
57 }
58}
59
60#[derive(Debug, Clone)]
62pub struct RandomErasing<F: Float + NumAssign + Debug + ScalarOperand> {
63 probability: f64,
65 value: F,
67}
68
69impl<F: Float + NumAssign + Debug + ScalarOperand> RandomErasing<F> {
70 pub fn new(probability: f64, value: F) -> Self {
72 Self { probability, value }
73 }
74}
75
76impl<F: Float + NumAssign + Debug + ScalarOperand> Augmentation<F> for RandomErasing<F> {
77 fn apply(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
78 let mut rng = SmallRng::from_rng(&mut thread_rng());
79 let mut result = input.clone();
80
81 if rng.random::<f64>() > self.probability {
83 return Ok(result);
84 }
85
86 if result.ndim() < 3 {
88 return Ok(result);
89 }
90
91 Ok(result)
95 }
96
97 fn description(&self) -> String {
98 format!(
99 "RandomErasing (prob: {:.2}, value: {:.2})",
100 self.probability,
101 self.value.to_f64().unwrap_or(0.0)
102 )
103 }
104}
105
106#[derive(Debug, Clone)]
108pub struct RandomHorizontalFlip<F: Float + NumAssign + Debug + ScalarOperand> {
109 probability: f64,
111 _phantom: std::marker::PhantomData<F>,
113}
114
115impl<F: Float + NumAssign + Debug + ScalarOperand> RandomHorizontalFlip<F> {
116 pub fn new(probability: f64) -> Self {
118 Self {
119 probability,
120 _phantom: std::marker::PhantomData,
121 }
122 }
123}
124
125impl<F: Float + NumAssign + Debug + ScalarOperand> Augmentation<F> for RandomHorizontalFlip<F> {
126 fn apply(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
127 let mut rng = SmallRng::from_rng(&mut thread_rng());
128 let result = input.clone();
129
130 if rng.random::<f64>() > self.probability {
132 return Ok(result);
133 }
134
135 Ok(result)
139 }
140
141 fn description(&self) -> String {
142 format!("RandomHorizontalFlip (prob: {:.2})", self.probability)
143 }
144}
145
146struct DebugAugmentationWrapper<'a, F: Float + NumAssign + Debug + ScalarOperand> {
148 inner: &'a dyn Augmentation<F>,
150}
151
152impl<F: Float + NumAssign + Debug + ScalarOperand> Debug for DebugAugmentationWrapper<'_, F> {
153 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
154 write!(f, "Augmentation({})", self.inner.description())
155 }
156}
157
158pub struct ComposeAugmentation<F: Float + NumAssign + Debug + ScalarOperand> {
160 augmentations: Vec<Box<dyn Augmentation<F>>>,
162}
163
164impl<F: Float + NumAssign + Debug + ScalarOperand> Clone for ComposeAugmentation<F> {
165 fn clone(&self) -> Self {
166 Self {
170 augmentations: Vec::new(),
171 }
172 }
173}
174
175impl<F: Float + NumAssign + Debug + ScalarOperand> Debug for ComposeAugmentation<F> {
176 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
177 let mut debug_list = f.debug_list();
178 for augmentation in &self.augmentations {
179 debug_list.entry(&DebugAugmentationWrapper {
180 inner: augmentation.as_ref(),
181 });
182 }
183 debug_list.finish()
184 }
185}
186
187impl<F: Float + NumAssign + Debug + ScalarOperand> ComposeAugmentation<F> {
188 pub fn new(augmentations: Vec<Box<dyn Augmentation<F>>>) -> Self {
190 Self { augmentations }
191 }
192}
193
194impl<F: Float + NumAssign + Debug + ScalarOperand> Augmentation<F> for ComposeAugmentation<F> {
195 fn apply(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
196 let mut data = input.clone();
197 for augmentation in &self.augmentations {
198 data = augmentation.apply(&data)?;
199 }
200 Ok(data)
201 }
202
203 fn description(&self) -> String {
204 let descriptions: Vec<String> =
205 self.augmentations.iter().map(|a| a.description()).collect();
206 format!("Compose({})", descriptions.join(", "))
207 }
208}