ghostflow_data/transforms.rs

//! Data transforms and augmentations

use ghostflow_core::Tensor;
use rand::Rng;

/// Trait for data transforms
pub trait Transform: Send + Sync {
    fn apply(&self, tensor: &Tensor) -> Tensor;
}
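
// Illustrative sketch: implementing `Transform` for a custom operation only
// requires producing a new `Tensor` from the input. `Scale` is a hypothetical
// example (not one of this module's transforms); it reuses the `dims()`,
// `data_f32()`, and `Tensor::from_slice` calls used by the transforms below.
pub struct Scale {
    factor: f32,
}

impl Scale {
    pub fn new(factor: f32) -> Self {
        Scale { factor }
    }
}

impl Transform for Scale {
    fn apply(&self, tensor: &Tensor) -> Tensor {
        // Multiply every element by a constant factor, keeping the shape.
        let scaled: Vec<f32> = tensor
            .data_f32()
            .iter()
            .map(|x| *x * self.factor)
            .collect();
        Tensor::from_slice(&scaled, tensor.dims()).unwrap()
    }
}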

/// Normalize tensor with mean and std
pub struct Normalize {
    mean: Vec<f32>,
    std: Vec<f32>,
}

impl Normalize {
    pub fn new(mean: Vec<f32>, std: Vec<f32>) -> Self {
        Normalize { mean, std }
    }

    /// ImageNet normalization
    pub fn imagenet() -> Self {
        Normalize {
            mean: vec![0.485, 0.456, 0.406],
            std: vec![0.229, 0.224, 0.225],
        }
    }
}

impl Transform for Normalize {
    fn apply(&self, tensor: &Tensor) -> Tensor {
        let dims = tensor.dims();
        let data = tensor.data_f32();

        // Assume tensor is [C, H, W] or [N, C, H, W]
        let channels = if dims.len() == 3 { dims[0] } else { dims[1] };
        let spatial_size: usize = dims[dims.len() - 2..].iter().product();

        let mut result = data.clone();

        if dims.len() == 3 {
            // [C, H, W]
            for c in 0..channels {
                let start = c * spatial_size;
                let end = start + spatial_size;
                for item in &mut result[start..end] {
                    *item = (*item - self.mean[c]) / self.std[c];
                }
            }
        } else {
            // [N, C, H, W]
            let batch_size = dims[0];
            let batch_stride = channels * spatial_size;

            for b in 0..batch_size {
                for c in 0..channels {
                    let start = b * batch_stride + c * spatial_size;
                    let end = start + spatial_size;
                    for item in &mut result[start..end] {
                        *item = (*item - self.mean[c]) / self.std[c];
                    }
                }
            }
        }

        Tensor::from_slice(&result, dims).unwrap()
    }
}

/// Random horizontal flip
pub struct RandomHorizontalFlip {
    p: f32,
}

impl RandomHorizontalFlip {
    pub fn new(p: f32) -> Self {
        RandomHorizontalFlip { p }
    }
}

impl Default for RandomHorizontalFlip {
    fn default() -> Self {
        Self::new(0.5)
    }
}

impl Transform for RandomHorizontalFlip {
    fn apply(&self, tensor: &Tensor) -> Tensor {
        // Flip with probability `p`; otherwise return the input unchanged.
        if rand::thread_rng().gen::<f32>() > self.p {
            return tensor.clone();
        }

        let dims = tensor.dims();
        let data = tensor.data_f32();

        // Assume [C, H, W] or [N, C, H, W]; treat every channel (or every
        // batch-channel pair) as an H x W plane and mirror it along the width.
        let (planes, height, width) = if dims.len() == 3 {
            (dims[0], dims[1], dims[2])
        } else {
            (dims[0] * dims[1], dims[2], dims[3])
        };

        let mut result = data.clone();

        for plane in 0..planes {
            let base = plane * height * width;
            for h in 0..height {
                let row = base + h * width;
                for w in 0..width / 2 {
                    result.swap(row + w, row + (width - 1 - w));
                }
            }
        }

        Tensor::from_slice(&result, dims).unwrap()
    }
}

/// Random crop
pub struct RandomCrop {
    size: (usize, usize),
}

impl RandomCrop {
    pub fn new(height: usize, width: usize) -> Self {
        RandomCrop { size: (height, width) }
    }
}

impl Transform for RandomCrop {
    fn apply(&self, tensor: &Tensor) -> Tensor {
        let dims = tensor.dims();
        let data = tensor.data_f32();

        let (channels, in_h, in_w) = if dims.len() == 3 {
            (dims[0], dims[1], dims[2])
        } else {
            panic!("RandomCrop expects 3D tensor [C, H, W]");
        };

        let (out_h, out_w) = self.size;

        if in_h < out_h || in_w < out_w {
            panic!("Crop size larger than input");
        }

        let mut rng = rand::thread_rng();
        let top = rng.gen_range(0..=in_h - out_h);
        let left = rng.gen_range(0..=in_w - out_w);

        let mut result = Vec::with_capacity(channels * out_h * out_w);

        for c in 0..channels {
            for h in 0..out_h {
                for w in 0..out_w {
                    let src_idx = c * in_h * in_w + (top + h) * in_w + (left + w);
                    result.push(data[src_idx]);
                }
            }
        }

        Tensor::from_slice(&result, &[channels, out_h, out_w]).unwrap()
    }
}

/// Compose multiple transforms
pub struct Compose {
    transforms: Vec<Box<dyn Transform>>,
}

impl Compose {
    pub fn new(transforms: Vec<Box<dyn Transform>>) -> Self {
        Compose { transforms }
    }
}

impl Transform for Compose {
    fn apply(&self, tensor: &Tensor) -> Tensor {
        let mut result = tensor.clone();
        for t in &self.transforms {
            result = t.apply(&result);
        }
        result
    }
}
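
// Usage sketch (illustrative): a typical training pipeline chains augmentations
// before normalization. The 224x224 crop size is an arbitrary example value,
// not something defined by this module.
#[allow(dead_code)]
fn example_train_pipeline() -> Compose {
    Compose::new(vec![
        Box::new(RandomCrop::new(224, 224)),
        Box::new(RandomHorizontalFlip::default()),
        Box::new(Normalize::imagenet()),
    ])
}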

/// Convert to tensor (identity for already-tensor data)
pub struct ToTensor;

impl Transform for ToTensor {
    fn apply(&self, tensor: &Tensor) -> Tensor {
        tensor.clone()
    }
}

/// Random erasing augmentation
pub struct RandomErasing {
    p: f32,
    scale: (f32, f32),
    ratio: (f32, f32),
    value: f32,
}

impl RandomErasing {
    pub fn new(p: f32) -> Self {
        RandomErasing {
            p,
            scale: (0.02, 0.33),
            ratio: (0.3, 3.3),
            value: 0.0,
        }
    }
}

impl Default for RandomErasing {
    fn default() -> Self {
        Self::new(0.5)
    }
}

impl Transform for RandomErasing {
    fn apply(&self, tensor: &Tensor) -> Tensor {
        // Erase with probability `p`; otherwise return the input unchanged.
        if rand::thread_rng().gen::<f32>() > self.p {
            return tensor.clone();
        }

        let dims = tensor.dims();
        let mut data = tensor.data_f32();

        // Only [C, H, W] tensors are handled; anything else passes through.
        let (channels, height, width) = if dims.len() == 3 {
            (dims[0], dims[1], dims[2])
        } else {
            return tensor.clone();
        };

        let area = (height * width) as f32;
        let mut rng = rand::thread_rng();

        // Rejection-sample an erase rectangle: try up to 10 times to find a
        // patch whose area and aspect ratio fall inside the configured ranges
        // and that fits within the image, then fill it with `value`.
        for _ in 0..10 {
            let target_area = rng.gen_range(self.scale.0..self.scale.1) * area;
            let aspect_ratio = rng.gen_range(self.ratio.0..self.ratio.1);

            let h = (target_area * aspect_ratio).sqrt() as usize;
            let w = (target_area / aspect_ratio).sqrt() as usize;

            if h < height && w < width {
                // Inclusive upper bounds so the patch can touch the bottom/right edge.
                let top = rng.gen_range(0..=height - h);
                let left = rng.gen_range(0..=width - w);

                for c in 0..channels {
                    for y in top..top + h {
                        for x in left..left + w {
                            let idx = c * height * width + y * width + x;
                            data[idx] = self.value;
                        }
                    }
                }
                break;
            }
        }

        Tensor::from_slice(&data, dims).unwrap()
    }
}
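
// Minimal smoke tests (illustrative sketch): they assume `Tensor::from_slice`,
// `data_f32`, and `dims` round-trip plain f32 buffers and shapes exactly as
// they are used by the transforms above; the shapes and values are arbitrary.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn normalize_applies_per_channel_mean_std() {
        // One channel, 2x2 spatial: every element maps to (x - 1.0) / 2.0.
        let input = Tensor::from_slice(&[1.0f32, 3.0, 5.0, 7.0], &[1, 2, 2]).unwrap();
        let output = Normalize::new(vec![1.0], vec![2.0]).apply(&input);
        assert_eq!(output.data_f32(), vec![0.0f32, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn random_crop_produces_requested_shape() {
        let input = Tensor::from_slice(&vec![0.0f32; 3 * 8 * 8], &[3, 8, 8]).unwrap();
        let output = RandomCrop::new(4, 4).apply(&input);
        assert_eq!(output.dims(), &[3usize, 4, 4]);
    }
}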