1use ghostflow_core::tensor::Tensor;
6use rand::Rng;
7
8pub struct RandomHorizontalFlip {
10 pub probability: f32,
11}
12
13impl RandomHorizontalFlip {
14 pub fn new(probability: f32) -> Self {
15 Self { probability }
16 }
17
18 pub fn apply(&self, image: &Tensor) -> Tensor {
19 let mut rng = rand::thread_rng();
20 if rng.gen::<f32>() < self.probability {
21 self.flip_horizontal(image)
22 } else {
23 image.clone()
24 }
25 }
26
27 fn flip_horizontal(&self, image: &Tensor) -> Tensor {
28 let shape = image.shape().dims();
29 let data = image.storage().as_slice::<f32>();
30
31 let (channels, height, width) = if shape.len() == 3 {
33 (shape[0], shape[1], shape[2])
34 } else {
35 (1, shape[0], shape[1])
36 };
37
38 let mut flipped = Vec::with_capacity(data.len());
39
40 for c in 0..channels {
41 for h in 0..height {
42 for w in (0..width).rev() {
43 let idx = c * height * width + h * width + w;
44 flipped.push(data[idx]);
45 }
46 }
47 }
48
49 Tensor::from_slice(&flipped, shape).unwrap()
50 }
51}
52
53pub struct RandomVerticalFlip {
55 pub probability: f32,
56}
57
58impl RandomVerticalFlip {
59 pub fn new(probability: f32) -> Self {
60 Self { probability }
61 }
62
63 pub fn apply(&self, image: &Tensor) -> Tensor {
64 let mut rng = rand::thread_rng();
65 if rng.gen::<f32>() < self.probability {
66 self.flip_vertical(image)
67 } else {
68 image.clone()
69 }
70 }
71
72 fn flip_vertical(&self, image: &Tensor) -> Tensor {
73 let shape = image.shape().dims();
74 let data = image.storage().as_slice::<f32>();
75
76 let (channels, height, width) = if shape.len() == 3 {
77 (shape[0], shape[1], shape[2])
78 } else {
79 (1, shape[0], shape[1])
80 };
81
82 let mut flipped = Vec::with_capacity(data.len());
83
84 for c in 0..channels {
85 for h in (0..height).rev() {
86 for w in 0..width {
87 let idx = c * height * width + h * width + w;
88 flipped.push(data[idx]);
89 }
90 }
91 }
92
93 Tensor::from_slice(&flipped, shape).unwrap()
94 }
95}
96
97pub struct RandomRotation {
99 pub max_degrees: f32,
100}
101
102impl RandomRotation {
103 pub fn new(max_degrees: f32) -> Self {
104 Self { max_degrees }
105 }
106
107 pub fn apply(&self, image: &Tensor) -> Tensor {
108 let mut rng = rand::thread_rng();
109 let degrees = rng.gen_range(-self.max_degrees..=self.max_degrees);
110 self.rotate(image, degrees)
111 }
112
113 fn rotate(&self, image: &Tensor, _degrees: f32) -> Tensor {
114 image.clone()
117 }
118}
119
120pub struct RandomCrop {
122 pub size: (usize, usize),
123}
124
125impl RandomCrop {
126 pub fn new(size: (usize, usize)) -> Self {
127 Self { size }
128 }
129
130 pub fn apply(&self, image: &Tensor) -> Tensor {
131 let shape = image.shape().dims();
132 let (channels, height, width) = if shape.len() == 3 {
133 (shape[0], shape[1], shape[2])
134 } else {
135 (1, shape[0], shape[1])
136 };
137
138 let (crop_h, crop_w) = self.size;
139
140 if crop_h > height || crop_w > width {
141 return image.clone();
142 }
143
144 let mut rng = rand::thread_rng();
145 let top = rng.gen_range(0..=(height - crop_h));
146 let left = rng.gen_range(0..=(width - crop_w));
147
148 self.crop(image, top, left, crop_h, crop_w)
149 }
150
151 fn crop(&self, image: &Tensor, top: usize, left: usize, height: usize, width: usize) -> Tensor {
152 let shape = image.shape().dims();
153 let data = image.storage().as_slice::<f32>();
154
155 let (channels, img_height, img_width) = if shape.len() == 3 {
156 (shape[0], shape[1], shape[2])
157 } else {
158 (1, shape[0], shape[1])
159 };
160
161 let mut cropped = Vec::new();
162
163 for c in 0..channels {
164 for h in top..(top + height) {
165 for w in left..(left + width) {
166 let idx = c * img_height * img_width + h * img_width + w;
167 cropped.push(data[idx]);
168 }
169 }
170 }
171
172 let new_shape = if shape.len() == 3 {
173 vec![channels, height, width]
174 } else {
175 vec![height, width]
176 };
177
178 Tensor::from_slice(&cropped, &new_shape).unwrap()
179 }
180}
181
182pub struct Normalize {
184 pub mean: Vec<f32>,
185 pub std: Vec<f32>,
186}
187
188impl Normalize {
189 pub fn new(mean: Vec<f32>, std: Vec<f32>) -> Self {
190 Self { mean, std }
191 }
192
193 pub fn imagenet() -> Self {
195 Self {
196 mean: vec![0.485, 0.456, 0.406],
197 std: vec![0.229, 0.224, 0.225],
198 }
199 }
200
201 pub fn apply(&self, image: &Tensor) -> Tensor {
202 let shape = image.shape().dims();
203 let data = image.storage().as_slice::<f32>();
204
205 let channels = if shape.len() == 3 { shape[0] } else { 1 };
206 let pixels_per_channel = data.len() / channels;
207
208 let mut normalized = Vec::with_capacity(data.len());
209
210 for c in 0..channels {
211 let mean = self.mean.get(c).copied().unwrap_or(0.0);
212 let std = self.std.get(c).copied().unwrap_or(1.0);
213
214 for i in 0..pixels_per_channel {
215 let idx = c * pixels_per_channel + i;
216 normalized.push((data[idx] - mean) / std);
217 }
218 }
219
220 Tensor::from_slice(&normalized, shape).unwrap()
221 }
222}
223
224pub struct Compose {
226 transforms: Vec<Box<dyn Fn(&Tensor) -> Tensor>>,
227}
228
229impl Compose {
230 pub fn new() -> Self {
231 Self {
232 transforms: Vec::new(),
233 }
234 }
235
236 pub fn add<F>(mut self, transform: F) -> Self
237 where
238 F: Fn(&Tensor) -> Tensor + 'static,
239 {
240 self.transforms.push(Box::new(transform));
241 self
242 }
243
244 pub fn apply(&self, image: &Tensor) -> Tensor {
245 let mut result = image.clone();
246 for transform in &self.transforms {
247 result = transform(&result);
248 }
249 result
250 }
251}
252
253impl Default for Compose {
254 fn default() -> Self {
255 Self::new()
256 }
257}
258
#[cfg(test)]
mod tests {
    use super::*;

    /// Copies a tensor's `f32` elements out so whole buffers can be compared.
    fn values(tensor: &Tensor) -> Vec<f32> {
        tensor.storage().as_slice::<f32>().to_vec()
    }

    /// Asserts two float buffers agree element-wise within `tol`.
    fn assert_close(actual: &[f32], expected: &[f32], tol: f32) {
        assert_eq!(actual.len(), expected.len(), "{:?} vs {:?}", actual, expected);
        for (a, e) in actual.iter().zip(expected) {
            assert!((a - e).abs() < tol, "{:?} != {:?}", actual, expected);
        }
    }

    #[test]
    fn test_random_horizontal_flip() {
        let image = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flipped = RandomHorizontalFlip::new(1.0).apply(&image);
        assert_eq!(flipped.shape().dims(), &[2, 2]);
        assert_eq!(values(&flipped), vec![2.0f32, 1.0, 4.0, 3.0]);
    }

    #[test]
    fn test_random_horizontal_flip_per_channel() {
        let image =
            Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 1, 3]).unwrap();
        let flipped = RandomHorizontalFlip::new(1.0).apply(&image);
        assert_eq!(flipped.shape().dims(), &[2, 1, 3]);
        assert_eq!(values(&flipped), vec![3.0f32, 2.0, 1.0, 6.0, 5.0, 4.0]);
    }

    #[test]
    fn test_random_vertical_flip() {
        let image = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2]).unwrap();
        let flipped = RandomVerticalFlip::new(1.0).apply(&image);
        assert_eq!(flipped.shape().dims(), &[3, 2]);
        assert_eq!(values(&flipped), vec![5.0f32, 6.0, 3.0, 4.0, 1.0, 2.0]);
    }

    #[test]
    fn test_zero_probability_flips_are_identity() {
        let image = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        assert_eq!(values(&RandomHorizontalFlip::new(0.0).apply(&image)), values(&image));
        assert_eq!(values(&RandomVerticalFlip::new(0.0).apply(&image)), values(&image));
    }

    #[test]
    fn test_zero_rotation_is_identity() {
        let image = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let rotated = RandomRotation::new(0.0).apply(&image);
        assert_eq!(rotated.shape().dims(), &[2, 2]);
        assert_eq!(values(&rotated), values(&image));
    }

    #[test]
    fn test_random_crop() {
        let image = Tensor::from_slice(
            &[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
            &[3, 3],
        )
        .unwrap();
        let cropped = RandomCrop::new((2, 2)).apply(&image);
        assert_eq!(cropped.shape().dims(), &[2, 2]);

        // Whatever the random offset, the result must be a contiguous 2x2 window.
        let v = values(&cropped);
        assert!([1.0f32, 2.0, 4.0, 5.0].contains(&v[0]), "unexpected window {:?}", v);
        assert_eq!(v, vec![v[0], v[0] + 1.0, v[0] + 3.0, v[0] + 4.0]);
    }

    #[test]
    fn test_random_crop_keeps_channels_aligned() {
        let data: Vec<f32> = (1..=8).map(|x| x as f32).collect();
        let image = Tensor::from_slice(&data, &[2, 2, 2]).unwrap();
        let cropped = RandomCrop::new((1, 1)).apply(&image);
        assert_eq!(cropped.shape().dims(), &[2, 1, 1]);

        // Both channels are cropped at the same spatial position.
        let v = values(&cropped);
        assert_eq!(v[1], v[0] + 4.0);
    }

    #[test]
    fn test_random_crop_full_and_oversized() {
        let image = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();

        let full = RandomCrop::new((2, 2)).apply(&image);
        assert_eq!(full.shape().dims(), &[2, 2]);
        assert_eq!(values(&full), values(&image));

        let oversized = RandomCrop::new((3, 1)).apply(&image);
        assert_eq!(oversized.shape().dims(), &[2, 2]);
        assert_eq!(values(&oversized), values(&image));
    }

    #[test]
    fn test_normalize() {
        let image = Tensor::from_slice(&[0.5f32, 0.6, 0.7], &[3]).unwrap();
        let normalized = Normalize::new(vec![0.5], vec![0.1]).apply(&image);
        assert_eq!(normalized.shape().dims(), &[3]);
        assert_close(&values(&normalized), &[0.0, 1.0, 2.0], 1e-4);
    }

    #[test]
    fn test_normalize_per_channel() {
        let image =
            Tensor::from_slice(&[2.0f32, 4.0, 6.0, 8.0, 1.0, 3.0], &[3, 1, 2]).unwrap();
        let normalize = Normalize::new(vec![0.0, 4.0, 1.0], vec![2.0, 2.0, 0.5]);
        let normalized = normalize.apply(&image);
        assert_eq!(normalized.shape().dims(), &[3, 1, 2]);
        assert_close(&values(&normalized), &[1.0, 2.0, 1.0, 2.0, 0.0, 4.0], 1e-6);
    }

    #[test]
    fn test_normalize_missing_stats_default_to_identity() {
        let image = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 1, 2]).unwrap();
        let normalized = Normalize::new(vec![1.0], vec![1.0]).apply(&image);
        assert_close(&values(&normalized), &[0.0, 1.0, 3.0, 4.0], 1e-6);
    }

    #[test]
    fn test_compose() {
        let image = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();

        let flip = RandomHorizontalFlip::new(1.0);
        let normalize = Normalize::new(vec![0.5], vec![0.1]);

        let compose = Compose::new()
            .add(move |img| flip.apply(img))
            .add(move |img| normalize.apply(img));

        let result = compose.apply(&image);
        assert_eq!(result.shape().dims(), &[2, 2]);
        // Flip runs first ([2, 1, 4, 3]), then normalization.
        assert_close(&values(&result), &[15.0, 5.0, 35.0, 25.0], 1e-3);
    }

    #[test]
    fn test_empty_compose_is_identity() {
        let image = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        assert_eq!(values(&Compose::default().apply(&image)), values(&image));
    }
}