1use ghostflow_core::Tensor;
6use rand::prelude::*;
7
/// Configuration for random image augmentation.
///
/// Values are normally set through the builder-style methods of this type.
/// Only the two flip flags, `brightness_range` and `random_seed` are read by
/// `augment`; the rotation, shift and zoom settings are stored but are not
/// applied there.
#[derive(Debug, Clone, PartialEq)]
pub struct ImageAugmentation {
    /// Allow a random left-right mirror.
    pub horizontal_flip: bool,
    /// Allow a random top-bottom mirror.
    pub vertical_flip: bool,
    /// Rotation range (the builder names its argument `degrees`); stored only.
    pub rotation_range: f32,
    /// Horizontal shift range; stored only.
    pub width_shift_range: f32,
    /// Vertical shift range; stored only.
    pub height_shift_range: f32,
    /// Zoom range as `(min, max)`; stored only.
    pub zoom_range: (f32, f32),
    /// Range `(min, max)` the brightness multiplier is drawn from.
    /// `(1.0, 1.0)` disables the brightness step.
    pub brightness_range: (f32, f32),
    /// Seed for the random generator. `None` seeds from entropy; with
    /// `Some(seed)` the generator is re-created from the same seed on every
    /// `augment` call.
    pub random_seed: Option<u64>,
}
19
20impl ImageAugmentation {
21 pub fn new() -> Self {
22 ImageAugmentation {
23 horizontal_flip: false,
24 vertical_flip: false,
25 rotation_range: 0.0,
26 width_shift_range: 0.0,
27 height_shift_range: 0.0,
28 zoom_range: (1.0, 1.0),
29 brightness_range: (1.0, 1.0),
30 random_seed: None,
31 }
32 }
33
34 pub fn horizontal_flip(mut self, flip: bool) -> Self {
35 self.horizontal_flip = flip;
36 self
37 }
38
39 pub fn vertical_flip(mut self, flip: bool) -> Self {
40 self.vertical_flip = flip;
41 self
42 }
43
44 pub fn rotation_range(mut self, degrees: f32) -> Self {
45 self.rotation_range = degrees;
46 self
47 }
48
49 pub fn shift_range(mut self, width: f32, height: f32) -> Self {
50 self.width_shift_range = width;
51 self.height_shift_range = height;
52 self
53 }
54
55 pub fn zoom_range(mut self, min: f32, max: f32) -> Self {
56 self.zoom_range = (min, max);
57 self
58 }
59
60 pub fn brightness_range(mut self, min: f32, max: f32) -> Self {
61 self.brightness_range = (min, max);
62 self
63 }
64
65 pub fn augment(&self, image: &Tensor) -> Tensor {
66 let mut rng = match self.random_seed {
67 Some(seed) => StdRng::seed_from_u64(seed),
68 None => StdRng::from_entropy(),
69 };
70
71 let dims = image.dims();
72 let mut data = image.data_f32().to_vec();
73
74 if self.horizontal_flip && rng.gen::<f32>() > 0.5 {
76 data = self.flip_horizontal(&data, dims);
77 }
78
79 if self.vertical_flip && rng.gen::<f32>() > 0.5 {
81 data = self.flip_vertical(&data, dims);
82 }
83
84 if self.brightness_range.0 != 1.0 || self.brightness_range.1 != 1.0 {
86 let factor = rng.gen::<f32>() * (self.brightness_range.1 - self.brightness_range.0)
87 + self.brightness_range.0;
88 for pixel in &mut data {
89 *pixel *= factor;
90 *pixel = pixel.clamp(0.0, 1.0);
91 }
92 }
93
94 Tensor::from_slice(&data, dims).unwrap()
95 }
96
97 fn flip_horizontal(&self, data: &[f32], dims: &[usize]) -> Vec<f32> {
98 let (channels, height, width) = if dims.len() == 3 {
99 (dims[0], dims[1], dims[2])
100 } else {
101 (1, dims[0], dims[1])
102 };
103
104 let mut flipped = vec![0.0f32; data.len()];
105
106 for c in 0..channels {
107 for h in 0..height {
108 for w in 0..width {
109 let src_idx = c * height * width + h * width + w;
110 let dst_idx = c * height * width + h * width + (width - 1 - w);
111 flipped[dst_idx] = data[src_idx];
112 }
113 }
114 }
115
116 flipped
117 }
118
119 fn flip_vertical(&self, data: &[f32], dims: &[usize]) -> Vec<f32> {
120 let (channels, height, width) = if dims.len() == 3 {
121 (dims[0], dims[1], dims[2])
122 } else {
123 (1, dims[0], dims[1])
124 };
125
126 let mut flipped = vec![0.0f32; data.len()];
127
128 for c in 0..channels {
129 for h in 0..height {
130 for w in 0..width {
131 let src_idx = c * height * width + h * width + w;
132 let dst_idx = c * height * width + (height - 1 - h) * width + w;
133 flipped[dst_idx] = data[src_idx];
134 }
135 }
136 }
137
138 flipped
139 }
140}
141
142impl Default for ImageAugmentation {
143 fn default() -> Self { Self::new() }
144}
145
/// Per-channel mean / standard-deviation pair used to normalize images.
///
/// When an image has more channels than there are entries, the entries are
/// reused cyclically (the channel index is taken modulo the vector length).
#[derive(Debug, Clone, PartialEq)]
pub struct ImageNormalization {
    /// Value subtracted from each channel.
    pub mean: Vec<f32>,
    /// Value each channel is divided by.
    pub std: Vec<f32>,
}
151
152impl ImageNormalization {
153 pub fn new(mean: Vec<f32>, std: Vec<f32>) -> Self {
154 ImageNormalization { mean, std }
155 }
156
157 pub fn imagenet() -> Self {
159 ImageNormalization {
160 mean: vec![0.485, 0.456, 0.406],
161 std: vec![0.229, 0.224, 0.225],
162 }
163 }
164
165 pub fn normalize(&self, image: &Tensor) -> Tensor {
166 let dims = image.dims();
167 let data = image.data_f32();
168 let _channels = if dims.len() == 3 { dims[0] } else { 1 };
169
170 let normalized: Vec<f32> = data.iter()
171 .enumerate()
172 .map(|(i, &pixel)| {
173 let c = if dims.len() == 3 {
174 i / (dims[1] * dims[2])
175 } else {
176 0
177 };
178 (pixel - self.mean[c % self.mean.len()]) / self.std[c % self.std.len()]
179 })
180 .collect();
181
182 Tensor::from_slice(&normalized, dims).unwrap()
183 }
184
185 pub fn denormalize(&self, image: &Tensor) -> Tensor {
186 let dims = image.dims();
187 let data = image.data_f32();
188 let _channels = if dims.len() == 3 { dims[0] } else { 1 };
189
190 let denormalized: Vec<f32> = data.iter()
191 .enumerate()
192 .map(|(i, &pixel)| {
193 let c = if dims.len() == 3 {
194 i / (dims[1] * dims[2])
195 } else {
196 0
197 };
198 pixel * self.std[c % self.std.len()] + self.mean[c % self.mean.len()]
199 })
200 .collect();
201
202 Tensor::from_slice(&denormalized, dims).unwrap()
203 }
204}
205
/// Target size and sampling method for resizing an image.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ImageResize {
    /// Output size as `(width, height)`.
    pub target_size: (usize, usize),
    /// Sampling method used to compute output pixels.
    pub interpolation: Interpolation,
}

/// Sampling method used by `ImageResize`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Interpolation {
    /// Copy the source pixel at the truncated source coordinate.
    Nearest,
    /// Blend the four source pixels surrounding the source coordinate.
    Bilinear,
}
217
218impl ImageResize {
219 pub fn new(width: usize, height: usize) -> Self {
220 ImageResize {
221 target_size: (width, height),
222 interpolation: Interpolation::Bilinear,
223 }
224 }
225
226 pub fn interpolation(mut self, interp: Interpolation) -> Self {
227 self.interpolation = interp;
228 self
229 }
230
231 pub fn resize(&self, image: &Tensor) -> Tensor {
232 let dims = image.dims();
233 let data = image.data_f32();
234
235 let (channels, src_height, src_width) = if dims.len() == 3 {
236 (dims[0], dims[1], dims[2])
237 } else {
238 (1, dims[0], dims[1])
239 };
240
241 let (dst_width, dst_height) = self.target_size;
242
243 match self.interpolation {
244 Interpolation::Nearest => {
245 self.resize_nearest(&data, channels, src_height, src_width, dst_height, dst_width)
246 }
247 Interpolation::Bilinear => {
248 self.resize_bilinear(&data, channels, src_height, src_width, dst_height, dst_width)
249 }
250 }
251 }
252
253 fn resize_nearest(
254 &self,
255 data: &[f32],
256 channels: usize,
257 src_h: usize,
258 src_w: usize,
259 dst_h: usize,
260 dst_w: usize,
261 ) -> Tensor {
262 let mut resized = vec![0.0f32; channels * dst_h * dst_w];
263
264 let scale_h = src_h as f32 / dst_h as f32;
265 let scale_w = src_w as f32 / dst_w as f32;
266
267 for c in 0..channels {
268 for h in 0..dst_h {
269 for w in 0..dst_w {
270 let src_h_idx = (h as f32 * scale_h) as usize;
271 let src_w_idx = (w as f32 * scale_w) as usize;
272
273 let src_idx = c * src_h * src_w + src_h_idx * src_w + src_w_idx;
274 let dst_idx = c * dst_h * dst_w + h * dst_w + w;
275
276 resized[dst_idx] = data[src_idx];
277 }
278 }
279 }
280
281 let dims = if channels == 1 {
282 vec![dst_h, dst_w]
283 } else {
284 vec![channels, dst_h, dst_w]
285 };
286
287 Tensor::from_slice(&resized, &dims).unwrap()
288 }
289
290 fn resize_bilinear(
291 &self,
292 data: &[f32],
293 channels: usize,
294 src_h: usize,
295 src_w: usize,
296 dst_h: usize,
297 dst_w: usize,
298 ) -> Tensor {
299 let mut resized = vec![0.0f32; channels * dst_h * dst_w];
300
301 let scale_h = src_h as f32 / dst_h as f32;
302 let scale_w = src_w as f32 / dst_w as f32;
303
304 for c in 0..channels {
305 for h in 0..dst_h {
306 for w in 0..dst_w {
307 let src_h_f = h as f32 * scale_h;
308 let src_w_f = w as f32 * scale_w;
309
310 let h0 = src_h_f.floor() as usize;
311 let w0 = src_w_f.floor() as usize;
312 let h1 = (h0 + 1).min(src_h - 1);
313 let w1 = (w0 + 1).min(src_w - 1);
314
315 let dh = src_h_f - h0 as f32;
316 let dw = src_w_f - w0 as f32;
317
318 let idx00 = c * src_h * src_w + h0 * src_w + w0;
319 let idx01 = c * src_h * src_w + h0 * src_w + w1;
320 let idx10 = c * src_h * src_w + h1 * src_w + w0;
321 let idx11 = c * src_h * src_w + h1 * src_w + w1;
322
323 let val = (1.0 - dh) * (1.0 - dw) * data[idx00]
324 + (1.0 - dh) * dw * data[idx01]
325 + dh * (1.0 - dw) * data[idx10]
326 + dh * dw * data[idx11];
327
328 let dst_idx = c * dst_h * dst_w + h * dst_w + w;
329 resized[dst_idx] = val;
330 }
331 }
332 }
333
334 let dims = if channels == 1 {
335 vec![dst_h, dst_w]
336 } else {
337 vec![channels, dst_h, dst_w]
338 };
339
340 Tensor::from_slice(&resized, &dims).unwrap()
341 }
342}
343
/// Rectangular crop window in pixel coordinates.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ImageCrop {
    /// Row of the source image where the window starts.
    pub top: usize,
    /// Column of the source image where the window starts.
    pub left: usize,
    /// Window height in pixels.
    pub height: usize,
    /// Window width in pixels.
    pub width: usize,
}
351
352impl ImageCrop {
353 pub fn new(top: usize, left: usize, height: usize, width: usize) -> Self {
354 ImageCrop { top, left, height, width }
355 }
356
357 pub fn center_crop(image_height: usize, image_width: usize, crop_size: usize) -> Self {
358 let top = (image_height - crop_size) / 2;
359 let left = (image_width - crop_size) / 2;
360 ImageCrop {
361 top,
362 left,
363 height: crop_size,
364 width: crop_size,
365 }
366 }
367
368 pub fn crop(&self, image: &Tensor) -> Tensor {
369 let dims = image.dims();
370 let data = image.data_f32();
371
372 let (channels, src_height, src_width) = if dims.len() == 3 {
373 (dims[0], dims[1], dims[2])
374 } else {
375 (1, dims[0], dims[1])
376 };
377
378 let mut cropped = vec![0.0f32; channels * self.height * self.width];
379
380 for c in 0..channels {
381 for h in 0..self.height {
382 for w in 0..self.width {
383 let src_h = self.top + h;
384 let src_w = self.left + w;
385
386 if src_h < src_height && src_w < src_width {
387 let src_idx = c * src_height * src_width + src_h * src_width + src_w;
388 let dst_idx = c * self.height * self.width + h * self.width + w;
389 cropped[dst_idx] = data[src_idx];
390 }
391 }
392 }
393 }
394
395 let dims = if channels == 1 {
396 vec![self.height, self.width]
397 } else {
398 vec![channels, self.height, self.width]
399 };
400
401 Tensor::from_slice(&cropped, &dims).unwrap()
402 }
403}
404
/// Crop of a fixed size taken at a random position.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RandomCrop {
    /// Crop height in pixels.
    pub height: usize,
    /// Crop width in pixels.
    pub width: usize,
    /// Seed for the random generator. `None` seeds from entropy; with
    /// `Some(seed)` the generator is re-created from the same seed on every
    /// `crop` call.
    pub random_seed: Option<u64>,
}
411
412impl RandomCrop {
413 pub fn new(height: usize, width: usize) -> Self {
414 RandomCrop {
415 height,
416 width,
417 random_seed: None,
418 }
419 }
420
421 pub fn crop(&self, image: &Tensor) -> Tensor {
422 let dims = image.dims();
423 let (_, src_height, src_width) = if dims.len() == 3 {
424 (dims[0], dims[1], dims[2])
425 } else {
426 (1, dims[0], dims[1])
427 };
428
429 let mut rng = match self.random_seed {
430 Some(seed) => StdRng::seed_from_u64(seed),
431 None => StdRng::from_entropy(),
432 };
433
434 let max_top = src_height.saturating_sub(self.height);
435 let max_left = src_width.saturating_sub(self.width);
436
437 let top = if max_top > 0 { rng.gen_range(0..=max_top) } else { 0 };
438 let left = if max_left > 0 { rng.gen_range(0..=max_left) } else { 0 };
439
440 let crop = ImageCrop::new(top, left, self.height, self.width);
441 crop.crop(image)
442 }
443}
444
/// Strengths of the random colour perturbations.
///
/// `apply` uses `brightness` and `contrast` only; `saturation` and `hue` are
/// stored but not read there.
#[derive(Debug, Clone, PartialEq)]
pub struct ColorJitter {
    /// Maximum relative brightness change; `0.0` disables the step.
    pub brightness: f32,
    /// Maximum relative contrast change; `0.0` disables the step.
    pub contrast: f32,
    /// Saturation strength; stored only.
    pub saturation: f32,
    /// Hue strength; stored only.
    pub hue: f32,
}
452
453impl ColorJitter {
454 pub fn new() -> Self {
455 ColorJitter {
456 brightness: 0.0,
457 contrast: 0.0,
458 saturation: 0.0,
459 hue: 0.0,
460 }
461 }
462
463 pub fn brightness(mut self, factor: f32) -> Self {
464 self.brightness = factor;
465 self
466 }
467
468 pub fn contrast(mut self, factor: f32) -> Self {
469 self.contrast = factor;
470 self
471 }
472
473 pub fn apply(&self, image: &Tensor) -> Tensor {
474 let mut rng = thread_rng();
475 let data = image.data_f32();
476 let mut jittered = data.to_vec();
477
478 if self.brightness > 0.0 {
480 let factor = 1.0 + (rng.gen::<f32>() - 0.5) * 2.0 * self.brightness;
481 for pixel in &mut jittered {
482 *pixel = (*pixel * factor).clamp(0.0, 1.0);
483 }
484 }
485
486 if self.contrast > 0.0 {
488 let mean = jittered.iter().sum::<f32>() / jittered.len() as f32;
489 let factor = 1.0 + (rng.gen::<f32>() - 0.5) * 2.0 * self.contrast;
490 for pixel in &mut jittered {
491 *pixel = ((*pixel - mean) * factor + mean).clamp(0.0, 1.0);
492 }
493 }
494
495 Tensor::from_slice(&jittered, image.dims()).unwrap()
496 }
497}
498
499impl Default for ColorJitter {
500 fn default() -> Self { Self::new() }
501}
502
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `[channels, height, width]` image filled with `value`.
    fn constant_image(channels: usize, height: usize, width: usize, value: f32) -> Tensor {
        Tensor::from_slice(
            &vec![value; channels * height * width],
            &[channels, height, width],
        )
        .unwrap()
    }

    /// True when every element of `tensor` lies in `[low, high]`.
    fn all_within(tensor: &Tensor, low: f32, high: f32) -> bool {
        tensor.data_f32().iter().all(|&v| v >= low && v <= high)
    }

    #[test]
    fn test_image_augmentation() {
        let image = constant_image(3, 32, 32, 0.5);

        let aug = ImageAugmentation::new()
            .horizontal_flip(true)
            .brightness_range(0.8, 1.2);

        let augmented = aug.augment(&image);
        assert_eq!(augmented.dims(), image.dims());
        // 0.5 scaled by a factor from [0.8, 1.2] stays within [0.4, 0.6].
        assert!(all_within(&augmented, 0.39, 0.61));
    }

    #[test]
    fn test_flips() {
        let aug = ImageAugmentation::new();
        let data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];

        assert_eq!(
            aug.flip_horizontal(&data, &[2, 3]),
            vec![3.0, 2.0, 1.0, 6.0, 5.0, 4.0]
        );
        assert_eq!(
            aug.flip_vertical(&data, &[2, 3]),
            vec![4.0, 5.0, 6.0, 1.0, 2.0, 3.0]
        );
    }

    #[test]
    fn test_image_normalization() {
        let image = constant_image(3, 32, 32, 0.5);
        let norm = ImageNormalization::imagenet();

        let normalized = norm.normalize(&image);
        let denormalized = norm.denormalize(&normalized);

        assert_eq!(normalized.dims(), image.dims());
        assert_eq!(denormalized.dims(), image.dims());
        // The round trip must reproduce the input up to rounding.
        assert!(all_within(&denormalized, 0.5 - 1e-5, 0.5 + 1e-5));
    }

    #[test]
    fn test_image_resize() {
        let image = constant_image(3, 64, 64, 0.5);
        let resize = ImageResize::new(32, 32);

        let resized = resize.resize(&image);
        assert_eq!(resized.dims(), &[3, 32, 32]);
        // Interpolating a constant image yields the same constant.
        assert!(all_within(&resized, 0.5 - 1e-5, 0.5 + 1e-5));
    }

    #[test]
    fn test_image_crop() {
        let image = constant_image(3, 64, 64, 0.5);
        let crop = ImageCrop::center_crop(64, 64, 32);
        assert_eq!((crop.top, crop.left), (16, 16));

        let cropped = crop.crop(&image);
        assert_eq!(cropped.dims(), &[3, 32, 32]);
        assert!(all_within(&cropped, 0.5, 0.5));
    }

    #[test]
    fn test_color_jitter() {
        let image = constant_image(3, 32, 32, 0.5);
        let jitter = ColorJitter::new().brightness(0.2).contrast(0.2);

        let jittered = jitter.apply(&image);
        assert_eq!(jittered.dims(), image.dims());
        assert!(all_within(&jittered, 0.0, 1.0));
    }
}
558
559