ghostflow_ml/
vision.rs

1//! Computer Vision - Image Processing and Augmentation
2//!
3//! This module provides image processing utilities for computer vision tasks.
4
5use ghostflow_core::Tensor;
6use rand::prelude::*;
7
/// Configuration for random image data augmentation.
///
/// Build one with [`ImageAugmentation::new`] and the chained setter methods,
/// then call `augment` on an image tensor.
#[derive(Debug, Clone, PartialEq)]
pub struct ImageAugmentation {
    /// Whether `augment` may mirror the image left-to-right.
    pub horizontal_flip: bool,
    /// Whether `augment` may mirror the image top-to-bottom.
    pub vertical_flip: bool,
    /// Rotation range in degrees. Stored only; `augment` does not read it.
    pub rotation_range: f32,
    /// Horizontal shift range. Stored only; `augment` does not read it.
    pub width_shift_range: f32,
    /// Vertical shift range. Stored only; `augment` does not read it.
    pub height_shift_range: f32,
    /// `(min, max)` zoom range. Stored only; `augment` does not read it.
    pub zoom_range: (f32, f32),
    /// `(min, max)` range from which the brightness multiplier is drawn.
    /// The value `(1.0, 1.0)` disables the brightness step.
    pub brightness_range: (f32, f32),
    /// Seed for the random generator; `None` seeds from system entropy.
    pub random_seed: Option<u64>,
}
19
20impl ImageAugmentation {
21    pub fn new() -> Self {
22        ImageAugmentation {
23            horizontal_flip: false,
24            vertical_flip: false,
25            rotation_range: 0.0,
26            width_shift_range: 0.0,
27            height_shift_range: 0.0,
28            zoom_range: (1.0, 1.0),
29            brightness_range: (1.0, 1.0),
30            random_seed: None,
31        }
32    }
33
34    pub fn horizontal_flip(mut self, flip: bool) -> Self {
35        self.horizontal_flip = flip;
36        self
37    }
38
39    pub fn vertical_flip(mut self, flip: bool) -> Self {
40        self.vertical_flip = flip;
41        self
42    }
43
44    pub fn rotation_range(mut self, degrees: f32) -> Self {
45        self.rotation_range = degrees;
46        self
47    }
48
49    pub fn shift_range(mut self, width: f32, height: f32) -> Self {
50        self.width_shift_range = width;
51        self.height_shift_range = height;
52        self
53    }
54
55    pub fn zoom_range(mut self, min: f32, max: f32) -> Self {
56        self.zoom_range = (min, max);
57        self
58    }
59
60    pub fn brightness_range(mut self, min: f32, max: f32) -> Self {
61        self.brightness_range = (min, max);
62        self
63    }
64
65    pub fn augment(&self, image: &Tensor) -> Tensor {
66        let mut rng = match self.random_seed {
67            Some(seed) => StdRng::seed_from_u64(seed),
68            None => StdRng::from_entropy(),
69        };
70
71        let dims = image.dims();
72        let mut data = image.data_f32().to_vec();
73
74        // Horizontal flip
75        if self.horizontal_flip && rng.gen::<f32>() > 0.5 {
76            data = self.flip_horizontal(&data, dims);
77        }
78
79        // Vertical flip
80        if self.vertical_flip && rng.gen::<f32>() > 0.5 {
81            data = self.flip_vertical(&data, dims);
82        }
83
84        // Brightness adjustment
85        if self.brightness_range.0 != 1.0 || self.brightness_range.1 != 1.0 {
86            let factor = rng.gen::<f32>() * (self.brightness_range.1 - self.brightness_range.0) 
87                + self.brightness_range.0;
88            for pixel in &mut data {
89                *pixel *= factor;
90                *pixel = pixel.clamp(0.0, 1.0);
91            }
92        }
93
94        Tensor::from_slice(&data, dims).unwrap()
95    }
96
97    fn flip_horizontal(&self, data: &[f32], dims: &[usize]) -> Vec<f32> {
98        let (channels, height, width) = if dims.len() == 3 {
99            (dims[0], dims[1], dims[2])
100        } else {
101            (1, dims[0], dims[1])
102        };
103
104        let mut flipped = vec![0.0f32; data.len()];
105
106        for c in 0..channels {
107            for h in 0..height {
108                for w in 0..width {
109                    let src_idx = c * height * width + h * width + w;
110                    let dst_idx = c * height * width + h * width + (width - 1 - w);
111                    flipped[dst_idx] = data[src_idx];
112                }
113            }
114        }
115
116        flipped
117    }
118
119    fn flip_vertical(&self, data: &[f32], dims: &[usize]) -> Vec<f32> {
120        let (channels, height, width) = if dims.len() == 3 {
121            (dims[0], dims[1], dims[2])
122        } else {
123            (1, dims[0], dims[1])
124        };
125
126        let mut flipped = vec![0.0f32; data.len()];
127
128        for c in 0..channels {
129            for h in 0..height {
130                for w in 0..width {
131                    let src_idx = c * height * width + h * width + w;
132                    let dst_idx = c * height * width + (height - 1 - h) * width + w;
133                    flipped[dst_idx] = data[src_idx];
134                }
135            }
136        }
137
138        flipped
139    }
140}
141
142impl Default for ImageAugmentation {
143    fn default() -> Self { Self::new() }
144}
145
/// Per-channel mean / standard-deviation normalization parameters.
#[derive(Debug, Clone, PartialEq)]
pub struct ImageNormalization {
    /// Value subtracted from each pixel, indexed by channel. Channels beyond
    /// the end of the vector wrap around (index taken modulo the length).
    pub mean: Vec<f32>,
    /// Value each centered pixel is divided by, indexed by channel with the
    /// same wrap-around rule as `mean`.
    pub std: Vec<f32>,
}
151
152impl ImageNormalization {
153    pub fn new(mean: Vec<f32>, std: Vec<f32>) -> Self {
154        ImageNormalization { mean, std }
155    }
156
157    /// ImageNet normalization
158    pub fn imagenet() -> Self {
159        ImageNormalization {
160            mean: vec![0.485, 0.456, 0.406],
161            std: vec![0.229, 0.224, 0.225],
162        }
163    }
164
165    pub fn normalize(&self, image: &Tensor) -> Tensor {
166        let dims = image.dims();
167        let data = image.data_f32();
168        let _channels = if dims.len() == 3 { dims[0] } else { 1 };
169
170        let normalized: Vec<f32> = data.iter()
171            .enumerate()
172            .map(|(i, &pixel)| {
173                let c = if dims.len() == 3 {
174                    i / (dims[1] * dims[2])
175                } else {
176                    0
177                };
178                (pixel - self.mean[c % self.mean.len()]) / self.std[c % self.std.len()]
179            })
180            .collect();
181
182        Tensor::from_slice(&normalized, dims).unwrap()
183    }
184
185    pub fn denormalize(&self, image: &Tensor) -> Tensor {
186        let dims = image.dims();
187        let data = image.data_f32();
188        let _channels = if dims.len() == 3 { dims[0] } else { 1 };
189
190        let denormalized: Vec<f32> = data.iter()
191            .enumerate()
192            .map(|(i, &pixel)| {
193                let c = if dims.len() == 3 {
194                    i / (dims[1] * dims[2])
195                } else {
196                    0
197                };
198                pixel * self.std[c % self.std.len()] + self.mean[c % self.mean.len()]
199            })
200            .collect();
201
202        Tensor::from_slice(&denormalized, dims).unwrap()
203    }
204}
205
/// Configuration for resizing an image to a fixed size.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ImageResize {
    /// Target size as `(width, height)`.
    pub target_size: (usize, usize),
    /// Sampling method used to compute the output pixels.
    pub interpolation: Interpolation,
}

/// Sampling method used by [`ImageResize`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Interpolation {
    /// Each output pixel copies a single source pixel.
    Nearest,
    /// Each output pixel is a weighted blend of up to four source pixels.
    Bilinear,
}
217
218impl ImageResize {
219    pub fn new(width: usize, height: usize) -> Self {
220        ImageResize {
221            target_size: (width, height),
222            interpolation: Interpolation::Bilinear,
223        }
224    }
225
226    pub fn interpolation(mut self, interp: Interpolation) -> Self {
227        self.interpolation = interp;
228        self
229    }
230
231    pub fn resize(&self, image: &Tensor) -> Tensor {
232        let dims = image.dims();
233        let data = image.data_f32();
234
235        let (channels, src_height, src_width) = if dims.len() == 3 {
236            (dims[0], dims[1], dims[2])
237        } else {
238            (1, dims[0], dims[1])
239        };
240
241        let (dst_width, dst_height) = self.target_size;
242
243        match self.interpolation {
244            Interpolation::Nearest => {
245                self.resize_nearest(&data, channels, src_height, src_width, dst_height, dst_width)
246            }
247            Interpolation::Bilinear => {
248                self.resize_bilinear(&data, channels, src_height, src_width, dst_height, dst_width)
249            }
250        }
251    }
252
253    fn resize_nearest(
254        &self,
255        data: &[f32],
256        channels: usize,
257        src_h: usize,
258        src_w: usize,
259        dst_h: usize,
260        dst_w: usize,
261    ) -> Tensor {
262        let mut resized = vec![0.0f32; channels * dst_h * dst_w];
263
264        let scale_h = src_h as f32 / dst_h as f32;
265        let scale_w = src_w as f32 / dst_w as f32;
266
267        for c in 0..channels {
268            for h in 0..dst_h {
269                for w in 0..dst_w {
270                    let src_h_idx = (h as f32 * scale_h) as usize;
271                    let src_w_idx = (w as f32 * scale_w) as usize;
272
273                    let src_idx = c * src_h * src_w + src_h_idx * src_w + src_w_idx;
274                    let dst_idx = c * dst_h * dst_w + h * dst_w + w;
275
276                    resized[dst_idx] = data[src_idx];
277                }
278            }
279        }
280
281        let dims = if channels == 1 {
282            vec![dst_h, dst_w]
283        } else {
284            vec![channels, dst_h, dst_w]
285        };
286
287        Tensor::from_slice(&resized, &dims).unwrap()
288    }
289
290    fn resize_bilinear(
291        &self,
292        data: &[f32],
293        channels: usize,
294        src_h: usize,
295        src_w: usize,
296        dst_h: usize,
297        dst_w: usize,
298    ) -> Tensor {
299        let mut resized = vec![0.0f32; channels * dst_h * dst_w];
300
301        let scale_h = src_h as f32 / dst_h as f32;
302        let scale_w = src_w as f32 / dst_w as f32;
303
304        for c in 0..channels {
305            for h in 0..dst_h {
306                for w in 0..dst_w {
307                    let src_h_f = h as f32 * scale_h;
308                    let src_w_f = w as f32 * scale_w;
309
310                    let h0 = src_h_f.floor() as usize;
311                    let w0 = src_w_f.floor() as usize;
312                    let h1 = (h0 + 1).min(src_h - 1);
313                    let w1 = (w0 + 1).min(src_w - 1);
314
315                    let dh = src_h_f - h0 as f32;
316                    let dw = src_w_f - w0 as f32;
317
318                    let idx00 = c * src_h * src_w + h0 * src_w + w0;
319                    let idx01 = c * src_h * src_w + h0 * src_w + w1;
320                    let idx10 = c * src_h * src_w + h1 * src_w + w0;
321                    let idx11 = c * src_h * src_w + h1 * src_w + w1;
322
323                    let val = (1.0 - dh) * (1.0 - dw) * data[idx00]
324                        + (1.0 - dh) * dw * data[idx01]
325                        + dh * (1.0 - dw) * data[idx10]
326                        + dh * dw * data[idx11];
327
328                    let dst_idx = c * dst_h * dst_w + h * dst_w + w;
329                    resized[dst_idx] = val;
330                }
331            }
332        }
333
334        let dims = if channels == 1 {
335            vec![dst_h, dst_w]
336        } else {
337            vec![channels, dst_h, dst_w]
338        };
339
340        Tensor::from_slice(&resized, &dims).unwrap()
341    }
342}
343
/// A rectangular crop window, in pixels.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ImageCrop {
    /// Row of the source image where the window starts.
    pub top: usize,
    /// Column of the source image where the window starts.
    pub left: usize,
    /// Number of rows in the window (and in the cropped output).
    pub height: usize,
    /// Number of columns in the window (and in the cropped output).
    pub width: usize,
}
351
352impl ImageCrop {
353    pub fn new(top: usize, left: usize, height: usize, width: usize) -> Self {
354        ImageCrop { top, left, height, width }
355    }
356
357    pub fn center_crop(image_height: usize, image_width: usize, crop_size: usize) -> Self {
358        let top = (image_height - crop_size) / 2;
359        let left = (image_width - crop_size) / 2;
360        ImageCrop {
361            top,
362            left,
363            height: crop_size,
364            width: crop_size,
365        }
366    }
367
368    pub fn crop(&self, image: &Tensor) -> Tensor {
369        let dims = image.dims();
370        let data = image.data_f32();
371
372        let (channels, src_height, src_width) = if dims.len() == 3 {
373            (dims[0], dims[1], dims[2])
374        } else {
375            (1, dims[0], dims[1])
376        };
377
378        let mut cropped = vec![0.0f32; channels * self.height * self.width];
379
380        for c in 0..channels {
381            for h in 0..self.height {
382                for w in 0..self.width {
383                    let src_h = self.top + h;
384                    let src_w = self.left + w;
385
386                    if src_h < src_height && src_w < src_width {
387                        let src_idx = c * src_height * src_width + src_h * src_width + src_w;
388                        let dst_idx = c * self.height * self.width + h * self.width + w;
389                        cropped[dst_idx] = data[src_idx];
390                    }
391                }
392            }
393        }
394
395        let dims = if channels == 1 {
396            vec![self.height, self.width]
397        } else {
398            vec![channels, self.height, self.width]
399        };
400
401        Tensor::from_slice(&cropped, &dims).unwrap()
402    }
403}
404
/// Configuration for cropping a fixed-size window at a random position.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RandomCrop {
    /// Number of rows in the cropped output.
    pub height: usize,
    /// Number of columns in the cropped output.
    pub width: usize,
    /// Seed for the random generator; `None` seeds from system entropy.
    pub random_seed: Option<u64>,
}
411
412impl RandomCrop {
413    pub fn new(height: usize, width: usize) -> Self {
414        RandomCrop {
415            height,
416            width,
417            random_seed: None,
418        }
419    }
420
421    pub fn crop(&self, image: &Tensor) -> Tensor {
422        let dims = image.dims();
423        let (_, src_height, src_width) = if dims.len() == 3 {
424            (dims[0], dims[1], dims[2])
425        } else {
426            (1, dims[0], dims[1])
427        };
428
429        let mut rng = match self.random_seed {
430            Some(seed) => StdRng::seed_from_u64(seed),
431            None => StdRng::from_entropy(),
432        };
433
434        let max_top = src_height.saturating_sub(self.height);
435        let max_left = src_width.saturating_sub(self.width);
436
437        let top = if max_top > 0 { rng.gen_range(0..=max_top) } else { 0 };
438        let left = if max_left > 0 { rng.gen_range(0..=max_left) } else { 0 };
439
440        let crop = ImageCrop::new(top, left, self.height, self.width);
441        crop.crop(image)
442    }
443}
444
/// Configuration for random color jitter.
///
/// Each strength is a non-negative amount; `0.0` disables the corresponding
/// adjustment.
#[derive(Debug, Clone, PartialEq)]
pub struct ColorJitter {
    /// Brightness strength `b`: pixels are scaled by a random factor in
    /// `[1 - b, 1 + b)`.
    pub brightness: f32,
    /// Contrast strength `c`: the distance of each pixel from the image mean
    /// is scaled by a random factor in `[1 - c, 1 + c)`.
    pub contrast: f32,
    /// Saturation strength. Stored only; `apply` does not read it.
    pub saturation: f32,
    /// Hue strength. Stored only; `apply` does not read it.
    pub hue: f32,
}
452
453impl ColorJitter {
454    pub fn new() -> Self {
455        ColorJitter {
456            brightness: 0.0,
457            contrast: 0.0,
458            saturation: 0.0,
459            hue: 0.0,
460        }
461    }
462
463    pub fn brightness(mut self, factor: f32) -> Self {
464        self.brightness = factor;
465        self
466    }
467
468    pub fn contrast(mut self, factor: f32) -> Self {
469        self.contrast = factor;
470        self
471    }
472
473    pub fn apply(&self, image: &Tensor) -> Tensor {
474        let mut rng = thread_rng();
475        let data = image.data_f32();
476        let mut jittered = data.to_vec();
477
478        // Brightness
479        if self.brightness > 0.0 {
480            let factor = 1.0 + (rng.gen::<f32>() - 0.5) * 2.0 * self.brightness;
481            for pixel in &mut jittered {
482                *pixel = (*pixel * factor).clamp(0.0, 1.0);
483            }
484        }
485
486        // Contrast
487        if self.contrast > 0.0 {
488            let mean = jittered.iter().sum::<f32>() / jittered.len() as f32;
489            let factor = 1.0 + (rng.gen::<f32>() - 0.5) * 2.0 * self.contrast;
490            for pixel in &mut jittered {
491                *pixel = ((*pixel - mean) * factor + mean).clamp(0.0, 1.0);
492            }
493        }
494
495        Tensor::from_slice(&jittered, image.dims()).unwrap()
496    }
497}
498
499impl Default for ColorJitter {
500    fn default() -> Self { Self::new() }
501}
502
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` when every element of `tensor` lies in `[lo, hi]`.
    fn all_within(tensor: &Tensor, lo: f32, hi: f32) -> bool {
        let values = tensor.data_f32().to_vec();
        values.iter().all(|&v| v >= lo && v <= hi)
    }

    #[test]
    fn test_image_augmentation() {
        let image = Tensor::from_slice(&vec![0.5f32; 3 * 32 * 32], &[3, 32, 32]).unwrap();

        let aug = ImageAugmentation::new()
            .horizontal_flip(true)
            .brightness_range(0.8, 1.2);

        let augmented = aug.augment(&image);
        assert_eq!(augmented.dims(), image.dims());
        // Flipping a constant image changes nothing, so only the brightness
        // factor in [0.8, 1.2) can move the values.
        assert!(all_within(&augmented, 0.399, 0.601));
    }

    #[test]
    fn test_image_normalization() {
        let image = Tensor::from_slice(&vec![0.5f32; 3 * 32 * 32], &[3, 32, 32]).unwrap();
        let norm = ImageNormalization::imagenet();

        let normalized = norm.normalize(&image);
        let denormalized = norm.denormalize(&normalized);

        assert_eq!(normalized.dims(), image.dims());
        assert_eq!(denormalized.dims(), image.dims());

        // First element belongs to channel 0, last element to channel 2.
        let values = normalized.data_f32().to_vec();
        let first = values[0];
        let last = values[values.len() - 1];
        assert!((first - (0.5 - 0.485) / 0.229).abs() < 1e-5);
        assert!((last - (0.5 - 0.406) / 0.225).abs() < 1e-5);

        // Normalizing and denormalizing is a round trip.
        assert!(all_within(&denormalized, 0.5 - 1e-5, 0.5 + 1e-5));
    }

    #[test]
    fn test_image_resize() {
        let image = Tensor::from_slice(&vec![0.5f32; 3 * 64 * 64], &[3, 64, 64]).unwrap();
        let resize = ImageResize::new(32, 32);

        let resized = resize.resize(&image);
        assert_eq!(resized.dims(), &[3, 32, 32]);
        // Interpolating a constant image must reproduce the constant.
        assert!(all_within(&resized, 0.5 - 1e-5, 0.5 + 1e-5));
    }

    #[test]
    fn test_image_resize_nearest() {
        let image = Tensor::from_slice(&vec![0.5f32; 3 * 64 * 64], &[3, 64, 64]).unwrap();
        let resize = ImageResize::new(16, 8).interpolation(Interpolation::Nearest);

        let resized = resize.resize(&image);
        // `new` takes (width, height); the tensor layout is [C, H, W].
        assert_eq!(resized.dims(), &[3, 8, 16]);
        assert!(all_within(&resized, 0.5, 0.5));
    }

    #[test]
    fn test_image_crop() {
        let image = Tensor::from_slice(&vec![0.5f32; 3 * 64 * 64], &[3, 64, 64]).unwrap();
        let crop = ImageCrop::center_crop(64, 64, 32);

        let cropped = crop.crop(&image);
        assert_eq!(cropped.dims(), &[3, 32, 32]);
        assert!(all_within(&cropped, 0.5, 0.5));
    }

    #[test]
    fn test_image_crop_values_and_padding() {
        let pixels: Vec<f32> = (0..16).map(|v| v as f32).collect();
        let image = Tensor::from_slice(&pixels, &[4, 4]).unwrap();

        let inner = ImageCrop::new(1, 1, 2, 2).crop(&image);
        assert_eq!(inner.dims(), &[2, 2]);
        assert_eq!(inner.data_f32().to_vec(), vec![5.0, 6.0, 9.0, 10.0]);

        // A window hanging over the bottom-right corner is zero-padded.
        let overhang = ImageCrop::new(3, 3, 2, 2).crop(&image);
        assert_eq!(overhang.dims(), &[2, 2]);
        assert_eq!(overhang.data_f32().to_vec(), vec![15.0, 0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_random_crop() {
        let image = Tensor::from_slice(&vec![0.5f32; 3 * 32 * 32], &[3, 32, 32]).unwrap();
        let crop = RandomCrop::new(16, 16);

        let cropped = crop.crop(&image);
        assert_eq!(cropped.dims(), &[3, 16, 16]);
        assert!(all_within(&cropped, 0.5, 0.5));
    }

    #[test]
    fn test_color_jitter() {
        let image = Tensor::from_slice(&vec![0.5f32; 3 * 32 * 32], &[3, 32, 32]).unwrap();
        let jitter = ColorJitter::new().brightness(0.2).contrast(0.2);

        let jittered = jitter.apply(&image);
        assert_eq!(jittered.dims(), image.dims());
        // Brightness moves 0.5 into [0.4, 0.6); contrast leaves a constant
        // image (almost) unchanged because every pixel equals the mean.
        assert!(all_within(&jittered, 0.39, 0.61));
    }
}
558
559