ghostflow_data/
augmentation.rs

1//! Data Augmentation
2//!
3//! Common data augmentation techniques for training.
4
5use ghostflow_core::tensor::Tensor;
6use rand::Rng;
7
8/// Random horizontal flip
9pub struct RandomHorizontalFlip {
10    pub probability: f32,
11}
12
13impl RandomHorizontalFlip {
14    pub fn new(probability: f32) -> Self {
15        Self { probability }
16    }
17
18    pub fn apply(&self, image: &Tensor) -> Tensor {
19        let mut rng = rand::thread_rng();
20        if rng.gen::<f32>() < self.probability {
21            self.flip_horizontal(image)
22        } else {
23            image.clone()
24        }
25    }
26
27    fn flip_horizontal(&self, image: &Tensor) -> Tensor {
28        let shape = image.shape().dims();
29        let data = image.storage().as_slice::<f32>();
30        
31        // Assume image is [C, H, W] or [H, W]
32        let (channels, height, width) = if shape.len() == 3 {
33            (shape[0], shape[1], shape[2])
34        } else {
35            (1, shape[0], shape[1])
36        };
37
38        let mut flipped = Vec::with_capacity(data.len());
39
40        for c in 0..channels {
41            for h in 0..height {
42                for w in (0..width).rev() {
43                    let idx = c * height * width + h * width + w;
44                    flipped.push(data[idx]);
45                }
46            }
47        }
48
49        Tensor::from_slice(&flipped, shape).unwrap()
50    }
51}
52
53/// Random vertical flip
54pub struct RandomVerticalFlip {
55    pub probability: f32,
56}
57
58impl RandomVerticalFlip {
59    pub fn new(probability: f32) -> Self {
60        Self { probability }
61    }
62
63    pub fn apply(&self, image: &Tensor) -> Tensor {
64        let mut rng = rand::thread_rng();
65        if rng.gen::<f32>() < self.probability {
66            self.flip_vertical(image)
67        } else {
68            image.clone()
69        }
70    }
71
72    fn flip_vertical(&self, image: &Tensor) -> Tensor {
73        let shape = image.shape().dims();
74        let data = image.storage().as_slice::<f32>();
75        
76        let (channels, height, width) = if shape.len() == 3 {
77            (shape[0], shape[1], shape[2])
78        } else {
79            (1, shape[0], shape[1])
80        };
81
82        let mut flipped = Vec::with_capacity(data.len());
83
84        for c in 0..channels {
85            for h in (0..height).rev() {
86                for w in 0..width {
87                    let idx = c * height * width + h * width + w;
88                    flipped.push(data[idx]);
89                }
90            }
91        }
92
93        Tensor::from_slice(&flipped, shape).unwrap()
94    }
95}
96
97/// Random rotation
98pub struct RandomRotation {
99    pub max_degrees: f32,
100}
101
102impl RandomRotation {
103    pub fn new(max_degrees: f32) -> Self {
104        Self { max_degrees }
105    }
106
107    pub fn apply(&self, image: &Tensor) -> Tensor {
108        let mut rng = rand::thread_rng();
109        let degrees = rng.gen_range(-self.max_degrees..=self.max_degrees);
110        self.rotate(image, degrees)
111    }
112
113    fn rotate(&self, image: &Tensor, _degrees: f32) -> Tensor {
114        // Simplified rotation - in practice would use proper image rotation
115        // For now, just return the image
116        image.clone()
117    }
118}
119
120/// Random crop
121pub struct RandomCrop {
122    pub size: (usize, usize),
123}
124
125impl RandomCrop {
126    pub fn new(size: (usize, usize)) -> Self {
127        Self { size }
128    }
129
130    pub fn apply(&self, image: &Tensor) -> Tensor {
131        let shape = image.shape().dims();
132        let (channels, height, width) = if shape.len() == 3 {
133            (shape[0], shape[1], shape[2])
134        } else {
135            (1, shape[0], shape[1])
136        };
137
138        let (crop_h, crop_w) = self.size;
139        
140        if crop_h > height || crop_w > width {
141            return image.clone();
142        }
143
144        let mut rng = rand::thread_rng();
145        let top = rng.gen_range(0..=(height - crop_h));
146        let left = rng.gen_range(0..=(width - crop_w));
147
148        self.crop(image, top, left, crop_h, crop_w)
149    }
150
151    fn crop(&self, image: &Tensor, top: usize, left: usize, height: usize, width: usize) -> Tensor {
152        let shape = image.shape().dims();
153        let data = image.storage().as_slice::<f32>();
154        
155        let (channels, img_height, img_width) = if shape.len() == 3 {
156            (shape[0], shape[1], shape[2])
157        } else {
158            (1, shape[0], shape[1])
159        };
160
161        let mut cropped = Vec::new();
162
163        for c in 0..channels {
164            for h in top..(top + height) {
165                for w in left..(left + width) {
166                    let idx = c * img_height * img_width + h * img_width + w;
167                    cropped.push(data[idx]);
168                }
169            }
170        }
171
172        let new_shape = if shape.len() == 3 {
173            vec![channels, height, width]
174        } else {
175            vec![height, width]
176        };
177
178        Tensor::from_slice(&cropped, &new_shape).unwrap()
179    }
180}
181
182/// Normalize
183pub struct Normalize {
184    pub mean: Vec<f32>,
185    pub std: Vec<f32>,
186}
187
188impl Normalize {
189    pub fn new(mean: Vec<f32>, std: Vec<f32>) -> Self {
190        Self { mean, std }
191    }
192
193    /// ImageNet normalization
194    pub fn imagenet() -> Self {
195        Self {
196            mean: vec![0.485, 0.456, 0.406],
197            std: vec![0.229, 0.224, 0.225],
198        }
199    }
200
201    pub fn apply(&self, image: &Tensor) -> Tensor {
202        let shape = image.shape().dims();
203        let data = image.storage().as_slice::<f32>();
204        
205        let channels = if shape.len() == 3 { shape[0] } else { 1 };
206        let pixels_per_channel = data.len() / channels;
207
208        let mut normalized = Vec::with_capacity(data.len());
209
210        for c in 0..channels {
211            let mean = self.mean.get(c).copied().unwrap_or(0.0);
212            let std = self.std.get(c).copied().unwrap_or(1.0);
213            
214            for i in 0..pixels_per_channel {
215                let idx = c * pixels_per_channel + i;
216                normalized.push((data[idx] - mean) / std);
217            }
218        }
219
220        Tensor::from_slice(&normalized, shape).unwrap()
221    }
222}
223
224/// Compose multiple augmentations
225pub struct Compose {
226    transforms: Vec<Box<dyn Fn(&Tensor) -> Tensor>>,
227}
228
229impl Compose {
230    pub fn new() -> Self {
231        Self {
232            transforms: Vec::new(),
233        }
234    }
235
236    pub fn add<F>(mut self, transform: F) -> Self
237    where
238        F: Fn(&Tensor) -> Tensor + 'static,
239    {
240        self.transforms.push(Box::new(transform));
241        self
242    }
243
244    pub fn apply(&self, image: &Tensor) -> Tensor {
245        let mut result = image.clone();
246        for transform in &self.transforms {
247            result = transform(&result);
248        }
249        result
250    }
251}
252
253impl Default for Compose {
254    fn default() -> Self {
255        Self::new()
256    }
257}
258
#[cfg(test)]
mod tests {
    //! Unit tests for the augmentations. Each test checks the produced values
    //! as well as the shape, so a transform that silently does nothing fails.

    use super::*;

    #[test]
    fn test_random_horizontal_flip() {
        let image = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flip = RandomHorizontalFlip::new(1.0); // Always flip

        let flipped = flip.apply(&image);
        assert_eq!(flipped.shape().dims(), &[2, 2]);

        // Every row is reversed.
        let data = flipped.storage().as_slice::<f32>();
        assert_eq!([data[0], data[1], data[2], data[3]], [2.0, 1.0, 4.0, 3.0]);
    }

    #[test]
    fn test_horizontal_flip_zero_probability_is_identity() {
        let image = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flip = RandomHorizontalFlip::new(0.0); // Never flip

        let untouched = flip.apply(&image);
        assert_eq!(untouched.shape().dims(), &[2, 2]);

        let data = untouched.storage().as_slice::<f32>();
        assert_eq!([data[0], data[1], data[2], data[3]], [1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_random_vertical_flip() {
        let image = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flip = RandomVerticalFlip::new(1.0); // Always flip

        let flipped = flip.apply(&image);
        assert_eq!(flipped.shape().dims(), &[2, 2]);

        // The two rows swap places.
        let data = flipped.storage().as_slice::<f32>();
        assert_eq!([data[0], data[1], data[2], data[3]], [3.0, 4.0, 1.0, 2.0]);
    }

    #[test]
    fn test_random_crop() {
        let image = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0], &[3, 3]).unwrap();
        let crop = RandomCrop::new((2, 2));

        let cropped = crop.apply(&image);
        assert_eq!(cropped.shape().dims(), &[2, 2]);

        // The window must be a contiguous 2x2 patch of the 3x3 grid: its
        // top-left value pins down the other three.
        let data = cropped.storage().as_slice::<f32>();
        let first = data[0];
        assert!([1.0, 2.0, 4.0, 5.0].contains(&first));
        assert_eq!(
            [data[1], data[2], data[3]],
            [first + 1.0, first + 3.0, first + 4.0]
        );
    }

    #[test]
    fn test_random_crop_larger_than_image_is_identity() {
        let image = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let crop = RandomCrop::new((3, 3));

        let untouched = crop.apply(&image);
        assert_eq!(untouched.shape().dims(), &[2, 2]);

        let data = untouched.storage().as_slice::<f32>();
        assert_eq!([data[0], data[1], data[2], data[3]], [1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_normalize() {
        let image = Tensor::from_slice(&[0.5f32, 0.6, 0.7], &[3]).unwrap();
        let normalize = Normalize::new(vec![0.5], vec![0.1]);

        let normalized = normalize.apply(&image);
        assert_eq!(normalized.shape().dims(), &[3]);

        // (x - 0.5) / 0.1 for x in 0.5, 0.6, 0.7
        let data = normalized.storage().as_slice::<f32>();
        let expected = [0.0f32, 1.0, 2.0];
        for i in 0..expected.len() {
            assert!((data[i] - expected[i]).abs() < 1e-3);
        }
    }

    #[test]
    fn test_compose() {
        let image = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();

        let flip = RandomHorizontalFlip::new(1.0);
        let normalize = Normalize::new(vec![0.5], vec![0.1]);

        let compose = Compose::new()
            .add(move |img| flip.apply(img))
            .add(move |img| normalize.apply(img));

        let result = compose.apply(&image);
        assert_eq!(result.shape().dims(), &[2, 2]);

        // Flip first ([2, 1, 4, 3]), then (x - 0.5) / 0.1.
        let data = result.storage().as_slice::<f32>();
        let expected = [15.0f32, 5.0, 35.0, 25.0];
        for i in 0..expected.len() {
            assert!((data[i] - expected[i]).abs() < 1e-3);
        }
    }
}
307}