1use rand::thread_rng;
7use rand::Rng;
8
9use crate::dataset::Sample;
10use crate::transform::Transform;
11
/// Randomly mirrors a `[C, H, W]` sample left-to-right.
///
/// Applied through [`Transform::apply`]; samples whose `feature_shape` does not have
/// exactly three dimensions pass through unchanged.
#[derive(Debug, Clone)]
pub struct RandomHorizontalFlip {
    /// Probability of flipping. It is compared against a uniform draw in `[0, 1)`,
    /// so `1.0` always flips and `0.0` never does.
    pub p: f64,
}

impl RandomHorizontalFlip {
    /// Creates a horizontal flip that fires with probability `p`.
    pub fn new(p: f64) -> Self {
        Self { p }
    }
}
27
28impl Transform for RandomHorizontalFlip {
29 fn apply(&self, mut sample: Sample) -> Sample {
30 let mut rng = thread_rng();
31 if rng.gen::<f64>() >= self.p {
32 return sample;
33 }
34 let shape = &sample.feature_shape;
35 if shape.len() != 3 {
36 return sample;
37 }
38 let (c, h, w) = (shape[0], shape[1], shape[2]);
39 let mut flipped = vec![0.0; c * h * w];
40 for ch in 0..c {
41 for row in 0..h {
42 for col in 0..w {
43 let src = ch * h * w + row * w + col;
44 let dst = ch * h * w + row * w + (w - 1 - col);
45 flipped[dst] = sample.features[src];
46 }
47 }
48 }
49 sample.features = flipped;
50 sample
51 }
52}
53
/// Randomly mirrors a `[C, H, W]` sample top-to-bottom.
///
/// Applied through [`Transform::apply`]; samples whose `feature_shape` does not have
/// exactly three dimensions pass through unchanged.
#[derive(Debug, Clone)]
pub struct RandomVerticalFlip {
    /// Probability of flipping. It is compared against a uniform draw in `[0, 1)`,
    /// so `1.0` always flips and `0.0` never does.
    pub p: f64,
}

impl RandomVerticalFlip {
    /// Creates a vertical flip that fires with probability `p`.
    pub fn new(p: f64) -> Self {
        Self { p }
    }
}
69
70impl Transform for RandomVerticalFlip {
71 fn apply(&self, mut sample: Sample) -> Sample {
72 let mut rng = thread_rng();
73 if rng.gen::<f64>() >= self.p {
74 return sample;
75 }
76 let shape = &sample.feature_shape;
77 if shape.len() != 3 {
78 return sample;
79 }
80 let (c, h, w) = (shape[0], shape[1], shape[2]);
81 let mut flipped = vec![0.0; c * h * w];
82 for ch in 0..c {
83 for row in 0..h {
84 for col in 0..w {
85 let src = ch * h * w + row * w + col;
86 let dst = ch * h * w + (h - 1 - row) * w + col;
87 flipped[dst] = sample.features[src];
88 }
89 }
90 }
91 sample.features = flipped;
92 sample
93 }
94}
95
/// Zero-pads a `[C, H, W]` sample by `padding` on every side, then cuts out a
/// randomly positioned `crop_h x crop_w` window.
///
/// The output `feature_shape` is `[C, crop_h, crop_w]`. Samples whose `feature_shape`
/// does not have exactly three dimensions pass through unchanged.
#[derive(Debug, Clone)]
pub struct RandomCrop {
    /// Height of the output window.
    pub crop_h: usize,
    /// Width of the output window.
    pub crop_w: usize,
    /// Number of zero rows/columns added on each side before cropping.
    pub padding: usize,
}

impl RandomCrop {
    /// Creates a crop producing `crop_h x crop_w` windows after `padding` zero-padding.
    pub fn new(crop_h: usize, crop_w: usize, padding: usize) -> Self {
        Self {
            crop_h,
            crop_w,
            padding,
        }
    }
}
118
119impl Transform for RandomCrop {
120 fn apply(&self, mut sample: Sample) -> Sample {
121 let shape = &sample.feature_shape;
122 if shape.len() != 3 {
123 return sample;
124 }
125 let (c, h, w) = (shape[0], shape[1], shape[2]);
126 let pad = self.padding;
127 let padded_h = h + 2 * pad;
128 let padded_w = w + 2 * pad;
129
130 let mut padded = vec![0.0; c * padded_h * padded_w];
132 for ch in 0..c {
133 for row in 0..h {
134 for col in 0..w {
135 let src = ch * h * w + row * w + col;
136 let dst = ch * padded_h * padded_w + (row + pad) * padded_w + (col + pad);
137 padded[dst] = sample.features[src];
138 }
139 }
140 }
141
142 let mut rng = thread_rng();
144 let max_y = padded_h.saturating_sub(self.crop_h);
145 let max_x = padded_w.saturating_sub(self.crop_w);
146 let y0 = if max_y > 0 {
147 rng.gen_range(0..=max_y)
148 } else {
149 0
150 };
151 let x0 = if max_x > 0 {
152 rng.gen_range(0..=max_x)
153 } else {
154 0
155 };
156
157 let mut cropped = vec![0.0; c * self.crop_h * self.crop_w];
158 for ch in 0..c {
159 for row in 0..self.crop_h {
160 for col in 0..self.crop_w {
161 let src = ch * padded_h * padded_w + (y0 + row) * padded_w + (x0 + col);
162 let dst = ch * self.crop_h * self.crop_w + row * self.crop_w + col;
163 cropped[dst] = padded[src];
164 }
165 }
166 }
167
168 sample.features = cropped;
169 sample.feature_shape = vec![c, self.crop_h, self.crop_w];
170 sample
171 }
172}
173
/// Adds independent zero-mean Gaussian noise to every feature value.
///
/// The sample's shape is not inspected, so this works for any `feature_shape`.
#[derive(Debug, Clone)]
pub struct RandomNoise {
    /// Standard deviation of the added noise.
    pub std_dev: f64,
}

impl RandomNoise {
    /// Creates a noise transform with the given standard deviation.
    pub fn new(std_dev: f64) -> Self {
        Self { std_dev }
    }
}
187
188impl Transform for RandomNoise {
189 fn apply(&self, mut sample: Sample) -> Sample {
190 use rand_distr::{Distribution, Normal};
191 let normal = Normal::new(0.0, self.std_dev).unwrap();
192 let mut rng = thread_rng();
193 for v in &mut sample.features {
194 *v += normal.sample(&mut rng);
195 }
196 sample
197 }
198}
199
/// Randomly overwrites a rectangular region of a `[C, H, W]` sample with a constant.
///
/// The same rectangle is erased in every channel. Samples whose `feature_shape` does
/// not have exactly three dimensions pass through unchanged.
#[derive(Debug, Clone)]
pub struct RandomErasing {
    /// Probability that erasing is attempted for a given sample.
    pub p: f64,
    /// Value written into the erased rectangle.
    pub fill_value: f64,
    /// Lower bound of the targeted erased area, as a fraction of `H * W`.
    pub min_area_ratio: f64,
    /// Upper bound of the targeted erased area, as a fraction of `H * W`.
    pub max_area_ratio: f64,
}

impl RandomErasing {
    /// Creates an eraser that fires with probability `p`.
    ///
    /// It fills with `0.0` and targets between 2% and 33% of the image area. Change the
    /// public fields to override these defaults.
    pub fn new(p: f64) -> Self {
        Self {
            p,
            fill_value: 0.0,
            min_area_ratio: 0.02,
            max_area_ratio: 0.33,
        }
    }
}
225
226impl Transform for RandomErasing {
227 fn apply(&self, mut sample: Sample) -> Sample {
228 let mut rng = thread_rng();
229 if rng.gen::<f64>() >= self.p {
230 return sample;
231 }
232 let shape = &sample.feature_shape;
233 if shape.len() != 3 {
234 return sample;
235 }
236 let (c, h, w) = (shape[0], shape[1], shape[2]);
237 let area = (h * w) as f64;
238
239 let target_area = area * rng.gen_range(self.min_area_ratio..self.max_area_ratio);
241 let aspect = rng.gen_range(0.3f64..3.3f64);
242 let erase_h = (target_area * aspect).sqrt().round() as usize;
243 let erase_w = (target_area / aspect).sqrt().round() as usize;
244
245 if erase_h >= h || erase_w >= w {
246 return sample;
247 }
248
249 let y0 = rng.gen_range(0..h - erase_h);
250 let x0 = rng.gen_range(0..w - erase_w);
251
252 for ch in 0..c {
253 for row in y0..y0 + erase_h {
254 for col in x0..x0 + erase_w {
255 sample.features[ch * h * w + row * w + col] = self.fill_value;
256 }
257 }
258 }
259 sample
260 }
261}
262
/// Randomly perturbs the brightness and contrast of all feature values.
///
/// One brightness offset and one contrast factor are drawn per sample and applied to
/// every value, regardless of `feature_shape`.
#[derive(Debug, Clone)]
pub struct ColorJitter {
    /// Maximum absolute offset added to every value; `0.0` disables brightness jitter.
    pub brightness: f64,
    /// Maximum deviation of the contrast factor from `1.0`; `0.0` disables contrast
    /// jitter.
    pub contrast: f64,
}

impl ColorJitter {
    /// Creates a jitter with the given brightness and contrast strengths.
    pub fn new(brightness: f64, contrast: f64) -> Self {
        Self {
            brightness,
            contrast,
        }
    }
}
283
284impl Transform for ColorJitter {
285 fn apply(&self, mut sample: Sample) -> Sample {
286 let mut rng = thread_rng();
287
288 if self.brightness > 0.0 {
290 let delta = rng.gen_range(-self.brightness..self.brightness);
291 for v in &mut sample.features {
292 *v += delta;
293 }
294 }
295
296 if self.contrast > 0.0 {
298 let factor = rng.gen_range(1.0 - self.contrast..1.0 + self.contrast);
299 let mean: f64 = sample.features.iter().sum::<f64>() / sample.features.len() as f64;
300 for v in &mut sample.features {
301 *v = mean + (*v - mean) * factor;
302 }
303 }
304
305 sample
306 }
307}
308
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `[c, h, w]` sample whose features are `0.0, 1.0, 2.0, ...` in row-major
    /// order, so every value doubles as its original flat index.
    fn make_image_sample(c: usize, h: usize, w: usize) -> Sample {
        let n = c * h * w;
        Sample {
            features: (0..n).map(|i| i as f64).collect(),
            feature_shape: vec![c, h, w],
            target: vec![0.0],
            target_shape: vec![1],
        }
    }

    #[test]
    fn horizontal_flip_deterministic() {
        // p = 1.0 always flips because the draw lies in [0, 1).
        let flip = RandomHorizontalFlip::new(1.0);
        let sample = make_image_sample(1, 2, 3);
        let result = flip.apply(sample);
        // [[0, 1, 2], [3, 4, 5]] -> each row reversed.
        assert_eq!(result.features, vec![2.0, 1.0, 0.0, 5.0, 4.0, 3.0]);
    }

    #[test]
    fn horizontal_flip_twice_is_identity() {
        let flip = RandomHorizontalFlip::new(1.0);
        let sample = make_image_sample(2, 3, 4);
        let result = flip.apply(flip.apply(sample.clone()));
        assert_eq!(result.features, sample.features);
    }

    #[test]
    fn vertical_flip_deterministic() {
        let flip = RandomVerticalFlip::new(1.0);
        let sample = make_image_sample(1, 2, 3);
        let result = flip.apply(sample);
        // [[0, 1, 2], [3, 4, 5]] -> rows swapped.
        assert_eq!(result.features, vec![3.0, 4.0, 5.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn vertical_flip_per_channel() {
        let flip = RandomVerticalFlip::new(1.0);
        let sample = make_image_sample(2, 2, 2);
        let result = flip.apply(sample);
        // Rows are swapped inside each channel; channels keep their order.
        assert_eq!(result.features, vec![2.0, 3.0, 0.0, 1.0, 6.0, 7.0, 4.0, 5.0]);
    }

    #[test]
    fn flips_with_p0_are_identity() {
        let sample = make_image_sample(2, 3, 4);
        let h = RandomHorizontalFlip::new(0.0).apply(sample.clone());
        let v = RandomVerticalFlip::new(0.0).apply(sample.clone());
        assert_eq!(h.features, sample.features);
        assert_eq!(v.features, sample.features);
    }

    #[test]
    fn flip_ignores_non_image_shape() {
        let mut sample = make_image_sample(1, 2, 3);
        sample.feature_shape = vec![6];
        let result = RandomHorizontalFlip::new(1.0).apply(sample.clone());
        assert_eq!(result.features, sample.features);
        assert_eq!(result.feature_shape, vec![6]);
    }

    #[test]
    fn random_crop_no_padding_same_size() {
        // The window fills the image, so its only possible origin is (0, 0).
        let crop = RandomCrop::new(2, 3, 0);
        let sample = make_image_sample(1, 2, 3);
        let result = crop.apply(sample.clone());
        assert_eq!(result.feature_shape, vec![1, 2, 3]);
        assert_eq!(result.features, sample.features);
    }

    #[test]
    fn random_crop_without_padding_is_contiguous_window() {
        let crop = RandomCrop::new(2, 2, 0);
        let sample = make_image_sample(1, 4, 4);
        let result = crop.apply(sample);
        assert_eq!(result.feature_shape, vec![1, 2, 2]);
        // Values equal their source index, so a 2x2 window of a width-4 image must be
        // `[s, s + 1, s + 4, s + 5]` for its top-left index `s`.
        let top_left = result.features[0];
        let expected: Vec<f64> = [0.0, 1.0, 4.0, 5.0].iter().map(|d| top_left + d).collect();
        assert_eq!(result.features, expected);
    }

    #[test]
    fn random_crop_with_padding() {
        let crop = RandomCrop::new(4, 4, 1);
        let sample = make_image_sample(1, 4, 4);
        let result = crop.apply(sample);
        assert_eq!(result.feature_shape, vec![1, 4, 4]);
        assert_eq!(result.features.len(), 16);
    }

    #[test]
    fn random_noise_changes_values() {
        let noise = RandomNoise::new(1.0);
        let sample = make_image_sample(1, 2, 2);
        let result = noise.apply(sample.clone());
        let changed = result
            .features
            .iter()
            .zip(sample.features.iter())
            .any(|(a, b)| (a - b).abs() > 1e-10);
        assert!(changed);
    }

    #[test]
    fn random_erasing_p1() {
        // A fill value absent from the input plus a small area range make this
        // deterministic: every draw yields a rectangle of at least 2 pixels that fits
        // inside the 8x8 image (target area 3.2..6.4, so each side is at most 5).
        let mut erasing = RandomErasing::new(1.0);
        erasing.fill_value = -1.0;
        erasing.min_area_ratio = 0.05;
        erasing.max_area_ratio = 0.1;
        let sample = make_image_sample(1, 8, 8);
        let result = erasing.apply(sample.clone());
        let erased = result.features.iter().filter(|&&v| v == -1.0).count();
        assert!(erased >= 2, "Expected an erased rectangle, got {} erased values", erased);
        for (got, orig) in result.features.iter().zip(sample.features.iter()) {
            assert!(*got == -1.0 || got == orig, "unexpected value {}", got);
        }
    }

    #[test]
    fn random_erasing_p0_is_identity() {
        let erasing = RandomErasing::new(0.0);
        let sample = make_image_sample(1, 8, 8);
        let result = erasing.apply(sample.clone());
        assert_eq!(result.features, sample.features);
    }

    #[test]
    fn color_jitter() {
        let jitter = ColorJitter::new(0.1, 0.1);
        let sample = make_image_sample(1, 2, 2);
        let result = jitter.apply(sample.clone());
        let changed = result
            .features
            .iter()
            .zip(sample.features.iter())
            .any(|(a, b)| (a - b).abs() > 1e-10);
        assert!(changed);
    }

    #[test]
    fn color_jitter_zero_is_identity() {
        let jitter = ColorJitter::new(0.0, 0.0);
        let sample = make_image_sample(1, 2, 2);
        let result = jitter.apply(sample.clone());
        assert_eq!(result.features, sample.features);
    }
}