Skip to main content

oxigdal_ml/
augmentation.rs

//! Data augmentation for geospatial imagery
//!
//! This module provides data augmentation transforms (flips, rotations, crops,
//! brightness/contrast adjustment, noise, and blur) designed for geospatial and
//! satellite imagery.
5
6use crate::error::{PreprocessingError, Result};
7use oxigdal_core::buffer::RasterBuffer;
8// use oxigdal_core::types::RasterDataType;
9use scirs2_core::random::prelude::{StdRng, seeded_rng};
10use tracing::debug;
11
12/// Generates a Gaussian random number using Box-Muller transform
13fn gaussian_random(rng: &mut StdRng, mean: f64, std_dev: f64) -> f64 {
14    use std::f64::consts::PI;
15
16    // Box-Muller transform
17    let u1: f64 = rng.random();
18    let u2: f64 = rng.random();
19
20    let z0: f64 = (-2.0_f64 * u1.ln()).sqrt() * (2.0_f64 * PI * u2).cos();
21    mean + std_dev * z0
22}
23
/// Declarative description of which augmented variants to produce.
///
/// Every field defaults to "off" (`false`, an empty list, or `None`). Use
/// [`AugmentationConfig::standard`], [`AugmentationConfig::aggressive`] or
/// [`AugmentationConfig::builder`] for preset or incremental construction.
///
/// Note: `apply_augmentation` in this module reads the flip, rotation,
/// brightness, contrast, Gaussian-noise and blur settings only. The crop,
/// saturation and salt-and-pepper fields are stored here but not applied there.
#[derive(Clone, Debug, Default)]
pub struct AugmentationConfig {
    /// Produce a left–right mirrored copy.
    pub horizontal_flip: bool,
    /// Produce a top–bottom mirrored copy.
    pub vertical_flip: bool,
    /// Rotation angles in degrees; one rotated copy per entry.
    pub rotation_angles: Vec<f32>,
    /// Whether random cropping is requested.
    pub random_crop: bool,
    /// Crop size in pixels, used when `random_crop` is set
    /// (presumably `(width, height)` — confirm against consumers).
    pub crop_size: Option<(u64, u64)>,
    /// Additive brightness deltas as `(min, max)`.
    pub brightness_range: Option<(f32, f32)>,
    /// Contrast factors `(min, max)`, applied as a stretch around the image mean.
    pub contrast_range: Option<(f32, f32)>,
    /// Saturation factors `(min, max)`, intended for RGB imagery.
    pub saturation_range: Option<(f32, f32)>,
    /// Standard deviation of additive Gaussian noise.
    pub gaussian_noise: Option<f32>,
    /// Salt-and-pepper noise probability.
    pub salt_pepper_noise: Option<f32>,
    /// Gaussian blur kernel size (the blur routine requires an odd value >= 3).
    pub blur_kernel: Option<usize>,
}
50
51impl AugmentationConfig {
52    /// Creates a builder for augmentation configuration
53    #[must_use]
54    pub fn builder() -> AugmentationConfigBuilder {
55        AugmentationConfigBuilder::default()
56    }
57
58    /// Creates a standard augmentation configuration
59    #[must_use]
60    pub fn standard() -> Self {
61        Self {
62            horizontal_flip: true,
63            vertical_flip: true,
64            rotation_angles: vec![90.0, 180.0, 270.0],
65            random_crop: false,
66            crop_size: None,
67            brightness_range: Some((-0.2, 0.2)),
68            contrast_range: Some((0.8, 1.2)),
69            saturation_range: Some((0.8, 1.2)),
70            gaussian_noise: Some(0.01),
71            salt_pepper_noise: None,
72            blur_kernel: None,
73        }
74    }
75
76    /// Creates an aggressive augmentation configuration
77    #[must_use]
78    pub fn aggressive() -> Self {
79        Self {
80            horizontal_flip: true,
81            vertical_flip: true,
82            rotation_angles: vec![45.0, 90.0, 135.0, 180.0, 225.0, 270.0, 315.0],
83            random_crop: true,
84            crop_size: Some((256, 256)),
85            brightness_range: Some((-0.3, 0.3)),
86            contrast_range: Some((0.7, 1.3)),
87            saturation_range: Some((0.7, 1.3)),
88            gaussian_noise: Some(0.02),
89            salt_pepper_noise: Some(0.01),
90            blur_kernel: Some(3),
91        }
92    }
93}
94
/// Incremental builder for [`AugmentationConfig`].
///
/// Mirrors every configuration field. All options start disabled and each
/// setter consumes and returns the builder so calls can be chained.
#[derive(Default, Debug)]
pub struct AugmentationConfigBuilder {
    // Flip toggles.
    horizontal_flip: bool,
    vertical_flip: bool,
    // Rotation angles in degrees.
    rotation_angles: Vec<f32>,
    // Random crop toggle and its size.
    random_crop: bool,
    crop_size: Option<(u64, u64)>,
    // Radiometric ranges as (min, max).
    brightness_range: Option<(f32, f32)>,
    contrast_range: Option<(f32, f32)>,
    saturation_range: Option<(f32, f32)>,
    // Noise parameters: Gaussian std-dev and salt-and-pepper probability.
    gaussian_noise: Option<f32>,
    salt_pepper_noise: Option<f32>,
    // Blur kernel size.
    blur_kernel: Option<usize>,
}
110
111impl AugmentationConfigBuilder {
112    /// Enables horizontal flip
113    #[must_use]
114    pub fn horizontal_flip(mut self, enable: bool) -> Self {
115        self.horizontal_flip = enable;
116        self
117    }
118
119    /// Enables vertical flip
120    #[must_use]
121    pub fn vertical_flip(mut self, enable: bool) -> Self {
122        self.vertical_flip = enable;
123        self
124    }
125
126    /// Sets rotation angles
127    #[must_use]
128    pub fn rotation_angles(mut self, angles: Vec<f32>) -> Self {
129        self.rotation_angles = angles;
130        self
131    }
132
133    /// Enables random cropping
134    #[must_use]
135    pub fn random_crop(mut self, enable: bool, size: Option<(u64, u64)>) -> Self {
136        self.random_crop = enable;
137        self.crop_size = size;
138        self
139    }
140
141    /// Sets brightness adjustment range
142    #[must_use]
143    pub fn brightness_range(mut self, min: f32, max: f32) -> Self {
144        self.brightness_range = Some((min, max));
145        self
146    }
147
148    /// Sets contrast adjustment range
149    #[must_use]
150    pub fn contrast_range(mut self, min: f32, max: f32) -> Self {
151        self.contrast_range = Some((min, max));
152        self
153    }
154
155    /// Sets saturation adjustment range
156    #[must_use]
157    pub fn saturation_range(mut self, min: f32, max: f32) -> Self {
158        self.saturation_range = Some((min, max));
159        self
160    }
161
162    /// Sets Gaussian noise standard deviation
163    #[must_use]
164    pub fn gaussian_noise(mut self, std_dev: f32) -> Self {
165        self.gaussian_noise = Some(std_dev);
166        self
167    }
168
169    /// Sets salt and pepper noise probability
170    #[must_use]
171    pub fn salt_pepper_noise(mut self, prob: f32) -> Self {
172        self.salt_pepper_noise = Some(prob);
173        self
174    }
175
176    /// Sets blur kernel size
177    #[must_use]
178    pub fn blur_kernel(mut self, size: usize) -> Self {
179        self.blur_kernel = Some(size);
180        self
181    }
182
183    /// Builds the configuration
184    #[must_use]
185    pub fn build(self) -> AugmentationConfig {
186        AugmentationConfig {
187            horizontal_flip: self.horizontal_flip,
188            vertical_flip: self.vertical_flip,
189            rotation_angles: self.rotation_angles,
190            random_crop: self.random_crop,
191            crop_size: self.crop_size,
192            brightness_range: self.brightness_range,
193            contrast_range: self.contrast_range,
194            saturation_range: self.saturation_range,
195            gaussian_noise: self.gaussian_noise,
196            salt_pepper_noise: self.salt_pepper_noise,
197            blur_kernel: self.blur_kernel,
198        }
199    }
200}
201
202/// Applies horizontal flip augmentation
203///
204/// # Errors
205/// Returns an error if the operation fails
206pub fn horizontal_flip(input: &RasterBuffer) -> Result<RasterBuffer> {
207    debug!("Applying horizontal flip");
208    let mut output = RasterBuffer::zeros(input.width(), input.height(), input.data_type());
209
210    for y in 0..input.height() {
211        for x in 0..input.width() {
212            let flipped_x = input.width() - 1 - x;
213            let value =
214                input
215                    .get_pixel(x, y)
216                    .map_err(|e| PreprocessingError::AugmentationFailed {
217                        reason: format!("Failed to read pixel: {}", e),
218                    })?;
219            output.set_pixel(flipped_x, y, value).map_err(|e| {
220                PreprocessingError::AugmentationFailed {
221                    reason: format!("Failed to write pixel: {}", e),
222                }
223            })?;
224        }
225    }
226
227    Ok(output)
228}
229
230/// Applies vertical flip augmentation
231///
232/// # Errors
233/// Returns an error if the operation fails
234pub fn vertical_flip(input: &RasterBuffer) -> Result<RasterBuffer> {
235    debug!("Applying vertical flip");
236    let mut output = RasterBuffer::zeros(input.width(), input.height(), input.data_type());
237
238    for y in 0..input.height() {
239        for x in 0..input.width() {
240            let flipped_y = input.height() - 1 - y;
241            let value =
242                input
243                    .get_pixel(x, y)
244                    .map_err(|e| PreprocessingError::AugmentationFailed {
245                        reason: format!("Failed to read pixel: {}", e),
246                    })?;
247            output.set_pixel(x, flipped_y, value).map_err(|e| {
248                PreprocessingError::AugmentationFailed {
249                    reason: format!("Failed to write pixel: {}", e),
250                }
251            })?;
252        }
253    }
254
255    Ok(output)
256}
257
258/// Applies rotation augmentation
259///
260/// # Errors
261/// Returns an error if the operation fails
262pub fn rotate(input: &RasterBuffer, angle_degrees: f32) -> Result<RasterBuffer> {
263    debug!("Applying rotation: {} degrees", angle_degrees);
264
265    // Simple rotation for 90-degree multiples
266    if (angle_degrees % 90.0).abs() < 0.1 {
267        let times = ((angle_degrees / 90.0).round() as i32).rem_euclid(4);
268        return rotate_90_times(input, times as usize);
269    }
270
271    // General rotation with bilinear interpolation
272    rotate_general(input, angle_degrees)
273}
274
275/// Applies general rotation with bilinear interpolation
276///
277/// Uses affine transformation and bilinear interpolation for arbitrary angles.
278/// Out-of-bounds pixels are filled with zeros.
279///
280/// # Errors
281/// Returns an error if pixel access fails
282fn rotate_general(input: &RasterBuffer, angle_degrees: f32) -> Result<RasterBuffer> {
283    let width = input.width() as f64;
284    let height = input.height() as f64;
285
286    // Convert angle to radians
287    let angle = angle_degrees as f64 * std::f64::consts::PI / 180.0;
288    let cos_angle = angle.cos();
289    let sin_angle = angle.sin();
290
291    // Calculate output dimensions to fit the entire rotated image
292    let corners = [(0.0, 0.0), (width, 0.0), (0.0, height), (width, height)];
293
294    let mut min_x = f64::INFINITY;
295    let mut max_x = f64::NEG_INFINITY;
296    let mut min_y = f64::INFINITY;
297    let mut max_y = f64::NEG_INFINITY;
298
299    // Find bounding box of rotated image
300    for (x, y) in &corners {
301        let rotated_x = x * cos_angle - y * sin_angle;
302        let rotated_y = x * sin_angle + y * cos_angle;
303        min_x = min_x.min(rotated_x);
304        max_x = max_x.max(rotated_x);
305        min_y = min_y.min(rotated_y);
306        max_y = max_y.max(rotated_y);
307    }
308
309    let out_width = (max_x - min_x).ceil() as u64;
310    let out_height = (max_y - min_y).ceil() as u64;
311
312    let mut output = RasterBuffer::zeros(out_width, out_height, input.data_type());
313
314    // Center of input image
315    let cx = width / 2.0;
316    let cy = height / 2.0;
317
318    // Center of output image
319    let out_cx = (max_x - min_x) / 2.0;
320    let out_cy = (max_y - min_y) / 2.0;
321
322    // For each output pixel, find corresponding input pixel
323    for out_y in 0..out_height {
324        for out_x in 0..out_width {
325            // Translate to center
326            let x = out_x as f64 - out_cx;
327            let y = out_y as f64 - out_cy;
328
329            // Inverse rotation
330            let src_x = x * cos_angle + y * sin_angle + cx;
331            let src_y = -x * sin_angle + y * cos_angle + cy;
332
333            // Apply bilinear interpolation
334            let value = bilinear_interpolate(input, src_x, src_y)?;
335
336            output.set_pixel(out_x, out_y, value).map_err(|e| {
337                PreprocessingError::AugmentationFailed {
338                    reason: format!("Failed to write pixel: {}", e),
339                }
340            })?;
341        }
342    }
343
344    Ok(output)
345}
346
347/// Applies bilinear interpolation at the given coordinates
348///
349/// Returns 0.0 for out-of-bounds coordinates
350///
351/// # Errors
352/// Returns an error if pixel access fails
353fn bilinear_interpolate(input: &RasterBuffer, x: f64, y: f64) -> Result<f64> {
354    // Check bounds
355    if x < 0.0 || y < 0.0 || x >= input.width() as f64 - 1.0 || y >= input.height() as f64 - 1.0 {
356        return Ok(0.0); // Fill with zeros for out-of-bounds
357    }
358
359    // Get the four surrounding pixels
360    let x0 = x.floor() as u64;
361    let y0 = y.floor() as u64;
362    let x1 = x0 + 1;
363    let y1 = y0 + 1;
364
365    // Fractional parts
366    let dx = x - x0 as f64;
367    let dy = y - y0 as f64;
368
369    // Get pixel values
370    let p00 = input
371        .get_pixel(x0, y0)
372        .map_err(|e| PreprocessingError::AugmentationFailed {
373            reason: format!("Failed to read pixel: {}", e),
374        })?;
375    let p10 = input
376        .get_pixel(x1, y0)
377        .map_err(|e| PreprocessingError::AugmentationFailed {
378            reason: format!("Failed to read pixel: {}", e),
379        })?;
380    let p01 = input
381        .get_pixel(x0, y1)
382        .map_err(|e| PreprocessingError::AugmentationFailed {
383            reason: format!("Failed to read pixel: {}", e),
384        })?;
385    let p11 = input
386        .get_pixel(x1, y1)
387        .map_err(|e| PreprocessingError::AugmentationFailed {
388            reason: format!("Failed to read pixel: {}", e),
389        })?;
390
391    // Bilinear interpolation formula:
392    // value = (1-dx)(1-dy)·p00 + dx(1-dy)·p10 + (1-dx)dy·p01 + dx·dy·p11
393    let value = (1.0 - dx) * (1.0 - dy) * p00
394        + dx * (1.0 - dy) * p10
395        + (1.0 - dx) * dy * p01
396        + dx * dy * p11;
397
398    Ok(value)
399}
400
401/// Rotates image by 90 degrees multiple times
402fn rotate_90_times(input: &RasterBuffer, times: usize) -> Result<RasterBuffer> {
403    let mut current = input.clone();
404
405    for _ in 0..times {
406        current = rotate_90(&current)?;
407    }
408
409    Ok(current)
410}
411
412/// Rotates image by 90 degrees clockwise
413fn rotate_90(input: &RasterBuffer) -> Result<RasterBuffer> {
414    let mut output = RasterBuffer::zeros(input.height(), input.width(), input.data_type());
415
416    for y in 0..input.height() {
417        for x in 0..input.width() {
418            let new_x = y;
419            let new_y = input.width() - 1 - x;
420            let value =
421                input
422                    .get_pixel(x, y)
423                    .map_err(|e| PreprocessingError::AugmentationFailed {
424                        reason: format!("Failed to read pixel: {}", e),
425                    })?;
426            output.set_pixel(new_x, new_y, value).map_err(|e| {
427                PreprocessingError::AugmentationFailed {
428                    reason: format!("Failed to write pixel: {}", e),
429                }
430            })?;
431        }
432    }
433
434    Ok(output)
435}
436
437/// Applies random crop augmentation
438///
439/// # Errors
440/// Returns an error if the operation fails
441pub fn random_crop(
442    input: &RasterBuffer,
443    crop_width: u64,
444    crop_height: u64,
445) -> Result<RasterBuffer> {
446    debug!("Applying random crop: {}x{}", crop_width, crop_height);
447
448    if crop_width > input.width() || crop_height > input.height() {
449        return Err(PreprocessingError::AugmentationFailed {
450            reason: format!(
451                "Crop size ({}x{}) larger than input ({}x{})",
452                crop_width,
453                crop_height,
454                input.width(),
455                input.height()
456            ),
457        }
458        .into());
459    }
460
461    // Use SciRS2-Core RNG for random offset
462    let mut rng = seeded_rng(
463        std::time::SystemTime::now()
464            .duration_since(std::time::UNIX_EPOCH)
465            .map(|d| d.as_secs())
466            .unwrap_or(0),
467    );
468    let max_x_offset = input.width() - crop_width;
469    let max_y_offset = input.height() - crop_height;
470
471    let x_offset = if max_x_offset > 0 {
472        let random_val: f64 = rng.random();
473        (random_val * (max_x_offset + 1) as f64) as u64
474    } else {
475        0
476    };
477    let y_offset = if max_y_offset > 0 {
478        let random_val: f64 = rng.random();
479        (random_val * (max_y_offset + 1) as f64) as u64
480    } else {
481        0
482    };
483
484    let mut output = RasterBuffer::zeros(crop_width, crop_height, input.data_type());
485
486    for y in 0..crop_height {
487        for x in 0..crop_width {
488            let value = input.get_pixel(x + x_offset, y + y_offset).map_err(|e| {
489                PreprocessingError::AugmentationFailed {
490                    reason: format!("Failed to read pixel: {}", e),
491                }
492            })?;
493            output
494                .set_pixel(x, y, value)
495                .map_err(|e| PreprocessingError::AugmentationFailed {
496                    reason: format!("Failed to write pixel: {}", e),
497                })?;
498        }
499    }
500
501    Ok(output)
502}
503
504/// Adjusts brightness
505///
506/// # Errors
507/// Returns an error if the operation fails
508pub fn adjust_brightness(input: &RasterBuffer, delta: f32) -> Result<RasterBuffer> {
509    debug!("Adjusting brightness by {}", delta);
510    let mut output = input.clone();
511
512    for y in 0..input.height() {
513        for x in 0..input.width() {
514            let value =
515                input
516                    .get_pixel(x, y)
517                    .map_err(|e| PreprocessingError::AugmentationFailed {
518                        reason: format!("Failed to read pixel: {}", e),
519                    })?;
520            let adjusted = (value as f32 + delta).clamp(0.0, 1.0) as f64;
521            output.set_pixel(x, y, adjusted).map_err(|e| {
522                PreprocessingError::AugmentationFailed {
523                    reason: format!("Failed to write pixel: {}", e),
524                }
525            })?;
526        }
527    }
528
529    Ok(output)
530}
531
532/// Adjusts contrast
533///
534/// # Errors
535/// Returns an error if the operation fails
536pub fn adjust_contrast(input: &RasterBuffer, factor: f32) -> Result<RasterBuffer> {
537    debug!("Adjusting contrast by factor {}", factor);
538    let mut output = input.clone();
539
540    // Calculate mean
541    let mut sum = 0.0;
542    let mut count = 0u64;
543    for y in 0..input.height() {
544        for x in 0..input.width() {
545            let value =
546                input
547                    .get_pixel(x, y)
548                    .map_err(|e| PreprocessingError::AugmentationFailed {
549                        reason: format!("Failed to read pixel: {}", e),
550                    })?;
551            sum += value;
552            count += 1;
553        }
554    }
555    let mean = sum / count as f64;
556
557    // Adjust contrast
558    for y in 0..input.height() {
559        for x in 0..input.width() {
560            let value =
561                input
562                    .get_pixel(x, y)
563                    .map_err(|e| PreprocessingError::AugmentationFailed {
564                        reason: format!("Failed to read pixel: {}", e),
565                    })?;
566            let adjusted = (mean + (value - mean) * factor as f64).clamp(0.0, 1.0);
567            output.set_pixel(x, y, adjusted).map_err(|e| {
568                PreprocessingError::AugmentationFailed {
569                    reason: format!("Failed to write pixel: {}", e),
570                }
571            })?;
572        }
573    }
574
575    Ok(output)
576}
577
578/// Adds Gaussian noise
579///
580/// # Errors
581/// Returns an error if the operation fails
582pub fn add_gaussian_noise(input: &RasterBuffer, std_dev: f32) -> Result<RasterBuffer> {
583    debug!("Adding Gaussian noise with std_dev={}", std_dev);
584    let mut output = input.clone();
585
586    // Use SciRS2-Core RNG with Box-Muller transform for Gaussian noise
587    let mut rng = seeded_rng(
588        std::time::SystemTime::now()
589            .duration_since(std::time::UNIX_EPOCH)
590            .map(|d| d.as_secs())
591            .unwrap_or(0),
592    );
593
594    for y in 0..input.height() {
595        for x in 0..input.width() {
596            let value =
597                input
598                    .get_pixel(x, y)
599                    .map_err(|e| PreprocessingError::AugmentationFailed {
600                        reason: format!("Failed to read pixel: {}", e),
601                    })?;
602
603            // Generate Gaussian noise using Box-Muller transform
604            let noise = gaussian_random(&mut rng, 0.0, std_dev as f64);
605            let noisy = (value + noise).clamp(0.0, 1.0);
606
607            output
608                .set_pixel(x, y, noisy)
609                .map_err(|e| PreprocessingError::AugmentationFailed {
610                    reason: format!("Failed to write pixel: {}", e),
611                })?;
612        }
613    }
614
615    Ok(output)
616}
617
618/// Applies Gaussian blur
619///
620/// Uses separable Gaussian convolution for efficiency (O(n·k) instead of O(n·k²)).
621/// Applies horizontal blur pass followed by vertical blur pass.
622/// Edges are handled with mirror padding.
623///
624/// # Errors
625/// Returns an error if the operation fails
626pub fn gaussian_blur(input: &RasterBuffer, kernel_size: usize) -> Result<RasterBuffer> {
627    debug!("Applying Gaussian blur with kernel size {}", kernel_size);
628
629    if kernel_size % 2 == 0 {
630        return Err(PreprocessingError::AugmentationFailed {
631            reason: "Kernel size must be odd".to_string(),
632        }
633        .into());
634    }
635
636    if kernel_size < 3 {
637        return Err(PreprocessingError::AugmentationFailed {
638            reason: "Kernel size must be at least 3".to_string(),
639        }
640        .into());
641    }
642
643    // Calculate sigma from kernel size using standard formula
644    // sigma = 0.3 * ((kernel_size - 1) * 0.5 - 1) + 0.8
645    let sigma = 0.3 * ((kernel_size as f64 - 1.0) * 0.5 - 1.0) + 0.8;
646
647    // Generate 1D Gaussian kernel
648    let kernel = generate_gaussian_kernel(kernel_size, sigma)?;
649
650    // Apply horizontal blur
651    let horizontal = apply_horizontal_blur(input, &kernel)?;
652
653    // Apply vertical blur
654    apply_vertical_blur(&horizontal, &kernel)
655}
656
657/// Generates a 1D Gaussian kernel
658///
659/// Uses the formula: G(x) = exp(-x²/(2σ²)) / √(2πσ²)
660/// Normalizes the kernel to sum to 1.0
661///
662/// # Errors
663/// Returns an error if sigma is invalid
664fn generate_gaussian_kernel(size: usize, sigma: f64) -> Result<Vec<f64>> {
665    if sigma <= 0.0 {
666        return Err(PreprocessingError::AugmentationFailed {
667            reason: "Sigma must be positive".to_string(),
668        }
669        .into());
670    }
671
672    let radius = (size / 2) as i32;
673    let mut kernel = Vec::with_capacity(size);
674    let mut sum = 0.0;
675
676    // Generate Gaussian values
677    for i in -radius..=radius {
678        let x = i as f64;
679        let value = (-x * x / (2.0 * sigma * sigma)).exp();
680        kernel.push(value);
681        sum += value;
682    }
683
684    // Normalize to sum to 1.0
685    for value in &mut kernel {
686        *value /= sum;
687    }
688
689    Ok(kernel)
690}
691
692/// Applies horizontal blur using the given kernel
693///
694/// Uses mirror padding at image edges
695///
696/// # Errors
697/// Returns an error if pixel access fails
698fn apply_horizontal_blur(input: &RasterBuffer, kernel: &[f64]) -> Result<RasterBuffer> {
699    let width = input.width();
700    let height = input.height();
701    let radius = (kernel.len() / 2) as i64;
702
703    let mut output = RasterBuffer::zeros(width, height, input.data_type());
704
705    for y in 0..height {
706        for x in 0..width {
707            let mut sum = 0.0;
708
709            // Apply convolution
710            for (k_idx, &k_val) in kernel.iter().enumerate() {
711                let offset = k_idx as i64 - radius;
712                let src_x = x as i64 + offset;
713
714                // Mirror padding for edges
715                let safe_x = if src_x < 0 {
716                    (-src_x) as u64
717                } else if src_x >= width as i64 {
718                    (2 * width as i64 - src_x - 2) as u64
719                } else {
720                    src_x as u64
721                };
722
723                // Clamp to valid range
724                let clamped_x = safe_x.min(width - 1);
725
726                let pixel = input.get_pixel(clamped_x, y).map_err(|e| {
727                    PreprocessingError::AugmentationFailed {
728                        reason: format!("Failed to read pixel: {}", e),
729                    }
730                })?;
731
732                sum += pixel * k_val;
733            }
734
735            output
736                .set_pixel(x, y, sum)
737                .map_err(|e| PreprocessingError::AugmentationFailed {
738                    reason: format!("Failed to write pixel: {}", e),
739                })?;
740        }
741    }
742
743    Ok(output)
744}
745
746/// Applies vertical blur using the given kernel
747///
748/// Uses mirror padding at image edges
749///
750/// # Errors
751/// Returns an error if pixel access fails
752fn apply_vertical_blur(input: &RasterBuffer, kernel: &[f64]) -> Result<RasterBuffer> {
753    let width = input.width();
754    let height = input.height();
755    let radius = (kernel.len() / 2) as i64;
756
757    let mut output = RasterBuffer::zeros(width, height, input.data_type());
758
759    for y in 0..height {
760        for x in 0..width {
761            let mut sum = 0.0;
762
763            // Apply convolution
764            for (k_idx, &k_val) in kernel.iter().enumerate() {
765                let offset = k_idx as i64 - radius;
766                let src_y = y as i64 + offset;
767
768                // Mirror padding for edges
769                let safe_y = if src_y < 0 {
770                    (-src_y) as u64
771                } else if src_y >= height as i64 {
772                    (2 * height as i64 - src_y - 2) as u64
773                } else {
774                    src_y as u64
775                };
776
777                // Clamp to valid range
778                let clamped_y = safe_y.min(height - 1);
779
780                let pixel = input.get_pixel(x, clamped_y).map_err(|e| {
781                    PreprocessingError::AugmentationFailed {
782                        reason: format!("Failed to read pixel: {}", e),
783                    }
784                })?;
785
786                sum += pixel * k_val;
787            }
788
789            output
790                .set_pixel(x, y, sum)
791                .map_err(|e| PreprocessingError::AugmentationFailed {
792                    reason: format!("Failed to write pixel: {}", e),
793                })?;
794        }
795    }
796
797    Ok(output)
798}
799
800/// Applies a sequence of augmentations according to configuration
801///
802/// # Errors
803/// Returns an error if any augmentation fails
804pub fn apply_augmentation(
805    input: &RasterBuffer,
806    config: &AugmentationConfig,
807) -> Result<Vec<RasterBuffer>> {
808    let mut augmented = vec![input.clone()]; // Original
809
810    // Horizontal flip
811    if config.horizontal_flip {
812        augmented.push(horizontal_flip(input)?);
813    }
814
815    // Vertical flip
816    if config.vertical_flip {
817        augmented.push(vertical_flip(input)?);
818    }
819
820    // Rotations
821    for angle in &config.rotation_angles {
822        augmented.push(rotate(input, *angle)?);
823    }
824
825    // Brightness adjustments
826    if let Some((min, max)) = config.brightness_range {
827        augmented.push(adjust_brightness(input, min)?);
828        augmented.push(adjust_brightness(input, max)?);
829    }
830
831    // Contrast adjustments
832    if let Some((min, max)) = config.contrast_range {
833        augmented.push(adjust_contrast(input, min)?);
834        augmented.push(adjust_contrast(input, max)?);
835    }
836
837    // Noise
838    if let Some(std_dev) = config.gaussian_noise {
839        augmented.push(add_gaussian_noise(input, std_dev)?);
840    }
841
842    // Blur
843    if let Some(kernel_size) = config.blur_kernel {
844        augmented.push(gaussian_blur(input, kernel_size)?);
845    }
846
847    Ok(augmented)
848}
849
850#[cfg(test)]
851mod tests {
852    use super::*;
853    use oxigdal_core::types::RasterDataType;
854
855    #[test]
856    fn test_augmentation_config_builder() {
857        let config = AugmentationConfig::builder()
858            .horizontal_flip(true)
859            .vertical_flip(true)
860            .rotation_angles(vec![90.0, 180.0])
861            .brightness_range(-0.2, 0.2)
862            .build();
863
864        assert!(config.horizontal_flip);
865        assert!(config.vertical_flip);
866        assert_eq!(config.rotation_angles.len(), 2);
867    }
868
869    #[test]
870    fn test_standard_config() {
871        let config = AugmentationConfig::standard();
872        assert!(config.horizontal_flip);
873        assert!(config.brightness_range.is_some());
874    }
875
876    #[test]
877    fn test_aggressive_config() {
878        let config = AugmentationConfig::aggressive();
879        assert!(config.random_crop);
880        assert!(config.salt_pepper_noise.is_some());
881        assert!(config.rotation_angles.len() > 3);
882    }
883
884    #[test]
885    fn test_horizontal_flip() {
886        let input = RasterBuffer::zeros(4, 4, RasterDataType::Float32);
887        let result = horizontal_flip(&input);
888        assert!(result.is_ok());
889
890        let output = result.expect("Should succeed");
891        assert_eq!(output.width(), input.width());
892        assert_eq!(output.height(), input.height());
893    }
894
895    #[test]
896    fn test_vertical_flip() {
897        let input = RasterBuffer::zeros(4, 4, RasterDataType::Float32);
898        let result = vertical_flip(&input);
899        assert!(result.is_ok());
900    }
901
902    #[test]
903    fn test_rotate_90() {
904        let input = RasterBuffer::zeros(4, 4, RasterDataType::Float32);
905        let result = rotate(&input, 90.0);
906        assert!(result.is_ok());
907
908        let output = result.expect("Should succeed");
909        assert_eq!(output.width(), input.height());
910        assert_eq!(output.height(), input.width());
911    }
912
913    #[test]
914    fn test_random_crop() {
915        let input = RasterBuffer::zeros(100, 100, RasterDataType::Float32);
916        let result = random_crop(&input, 64, 64);
917        assert!(result.is_ok());
918
919        let output = result.expect("Should succeed");
920        assert_eq!(output.width(), 64);
921        assert_eq!(output.height(), 64);
922    }
923
924    #[test]
925    fn test_random_crop_invalid_size() {
926        let input = RasterBuffer::zeros(50, 50, RasterDataType::Float32);
927        let result = random_crop(&input, 100, 100);
928        assert!(result.is_err());
929    }
930
931    #[test]
932    fn test_rotate_45_degrees() {
933        // Test rotation at 45 degrees
934        let mut input = RasterBuffer::zeros(10, 10, RasterDataType::Float32);
935
936        // Set a specific pixel to test rotation
937        let _ = input.set_pixel(5, 5, 1.0);
938
939        let result = rotate(&input, 45.0);
940        assert!(result.is_ok());
941
942        let output = result.expect("Should succeed");
943        // Output dimensions should be larger to accommodate rotation
944        assert!(output.width() > input.width() || output.height() > input.height());
945    }
946
947    #[test]
948    fn test_rotate_180_degrees() {
949        // Test rotation at 180 degrees (should use optimized path)
950        let mut input = RasterBuffer::zeros(8, 8, RasterDataType::Float32);
951
952        // Set corner pixels
953        let _ = input.set_pixel(0, 0, 1.0);
954        let _ = input.set_pixel(7, 7, 0.5);
955
956        let result = rotate(&input, 180.0);
957        assert!(result.is_ok());
958
959        let output = result.expect("Should succeed");
960        assert_eq!(output.width(), input.width());
961        assert_eq!(output.height(), input.height());
962
963        // Check that corners are swapped
964        let top_left = output.get_pixel(0, 0).expect("Should succeed");
965        let bottom_right = output.get_pixel(7, 7).expect("Should succeed");
966
967        assert!((top_left - 0.5).abs() < 0.01, "Top-left should be ~0.5");
968        assert!(
969            (bottom_right - 1.0).abs() < 0.01,
970            "Bottom-right should be ~1.0"
971        );
972    }
973
974    #[test]
975    fn test_rotate_arbitrary_angle() {
976        // Test rotation at arbitrary angle
977        let input = RasterBuffer::zeros(16, 16, RasterDataType::Float32);
978
979        let result = rotate(&input, 30.0);
980        assert!(result.is_ok());
981
982        let output = result.expect("Should succeed");
983        // Verify output is created (dimensions may vary)
984        assert!(output.width() > 0);
985        assert!(output.height() > 0);
986    }
987
988    #[test]
989    fn test_bilinear_interpolation() {
990        // Test bilinear interpolation function
991        let mut input = RasterBuffer::zeros(4, 4, RasterDataType::Float32);
992
993        // Set up a gradient
994        for y in 0..4 {
995            for x in 0..4 {
996                let value = (x + y) as f64 * 0.1;
997                let _ = input.set_pixel(x, y, value);
998            }
999        }
1000
1001        // Test interpolation at fractional coordinates
1002        let result = bilinear_interpolate(&input, 1.5, 1.5);
1003        assert!(result.is_ok());
1004
1005        let value = result.expect("Should succeed");
1006        // Value should be between surrounding pixels
1007        assert!(
1008            (0.2..=0.4).contains(&value),
1009            "Value {} not in expected range",
1010            value
1011        );
1012    }
1013
1014    #[test]
1015    fn test_bilinear_interpolation_out_of_bounds() {
1016        let input = RasterBuffer::zeros(4, 4, RasterDataType::Float32);
1017
1018        // Test out-of-bounds coordinates
1019        let result = bilinear_interpolate(&input, -1.0, 2.0);
1020        assert!(result.is_ok());
1021
1022        let value = result.expect("Should succeed");
1023        assert_eq!(value, 0.0, "Out-of-bounds should return 0.0");
1024    }
1025
1026    #[test]
1027    fn test_gaussian_blur_basic() {
1028        // Test basic Gaussian blur
1029        let input = RasterBuffer::zeros(16, 16, RasterDataType::Float32);
1030
1031        let result = gaussian_blur(&input, 3);
1032        assert!(result.is_ok());
1033
1034        let output = result.expect("Should succeed");
1035        assert_eq!(output.width(), input.width());
1036        assert_eq!(output.height(), input.height());
1037    }
1038
1039    #[test]
1040    fn test_gaussian_blur_larger_kernel() {
1041        // Test with larger kernel
1042        let input = RasterBuffer::zeros(20, 20, RasterDataType::Float32);
1043
1044        let result = gaussian_blur(&input, 5);
1045        assert!(result.is_ok());
1046
1047        let output = result.expect("Should succeed");
1048        assert_eq!(output.width(), input.width());
1049        assert_eq!(output.height(), input.height());
1050    }
1051
1052    #[test]
1053    fn test_gaussian_blur_even_kernel_fails() {
1054        // Test that even kernel size fails
1055        let input = RasterBuffer::zeros(10, 10, RasterDataType::Float32);
1056
1057        let result = gaussian_blur(&input, 4);
1058        assert!(result.is_err());
1059    }
1060
1061    #[test]
1062    fn test_gaussian_blur_too_small_kernel_fails() {
1063        // Test that kernel size < 3 fails
1064        let input = RasterBuffer::zeros(10, 10, RasterDataType::Float32);
1065
1066        let result = gaussian_blur(&input, 1);
1067        assert!(result.is_err());
1068    }
1069
1070    #[test]
1071    fn test_gaussian_blur_smoothing() {
1072        // Test that blur actually smooths the image
1073        let mut input = RasterBuffer::zeros(10, 10, RasterDataType::Float32);
1074
1075        // Create a sharp edge
1076        for y in 0..10 {
1077            for x in 0..5 {
1078                let _ = input.set_pixel(x, y, 1.0);
1079            }
1080        }
1081
1082        let result = gaussian_blur(&input, 3);
1083        assert!(result.is_ok());
1084
1085        let output = result.expect("Should succeed");
1086
1087        // Check that the edge is smoothed (pixel at boundary should be between 0 and 1)
1088        let edge_value = output.get_pixel(4, 5).expect("Should succeed");
1089        assert!(
1090            edge_value > 0.4 && edge_value < 0.85,
1091            "Edge value {} should be smoothed",
1092            edge_value
1093        );
1094
1095        // Check that the transition area has smoothing
1096        let inside_value = output.get_pixel(2, 5).expect("Should succeed");
1097        let outside_value = output.get_pixel(7, 5).expect("Should succeed");
1098
1099        // Inside should be close to 1, outside should be close to 0
1100        assert!(
1101            inside_value > 0.9,
1102            "Inside value {} should be high",
1103            inside_value
1104        );
1105        assert!(
1106            outside_value < 0.15,
1107            "Outside value {} should be low",
1108            outside_value
1109        );
1110    }
1111
1112    #[test]
1113    fn test_generate_gaussian_kernel() {
1114        // Test kernel generation
1115        let result = generate_gaussian_kernel(5, 1.0);
1116        assert!(result.is_ok());
1117
1118        let kernel = result.expect("Should succeed");
1119        assert_eq!(kernel.len(), 5);
1120
1121        // Kernel should sum to approximately 1.0
1122        let sum: f64 = kernel.iter().sum();
1123        assert!(
1124            (sum - 1.0).abs() < 1e-10,
1125            "Kernel sum {} should be ~1.0",
1126            sum
1127        );
1128
1129        // Kernel should be symmetric
1130        assert!((kernel[0] - kernel[4]).abs() < 1e-10);
1131        assert!((kernel[1] - kernel[3]).abs() < 1e-10);
1132
1133        // Center should be largest value
1134        assert!(kernel[2] > kernel[1]);
1135        assert!(kernel[2] > kernel[0]);
1136    }
1137
1138    #[test]
1139    fn test_generate_gaussian_kernel_invalid_sigma() {
1140        // Test that invalid sigma fails
1141        let result = generate_gaussian_kernel(5, 0.0);
1142        assert!(result.is_err());
1143
1144        let result = generate_gaussian_kernel(5, -1.0);
1145        assert!(result.is_err());
1146    }
1147
1148    #[test]
1149    fn test_horizontal_blur() {
1150        // Test horizontal blur function
1151        let mut input = RasterBuffer::zeros(10, 10, RasterDataType::Float32);
1152
1153        // Set middle row to 1.0
1154        for x in 0..10 {
1155            let _ = input.set_pixel(x, 5, 1.0);
1156        }
1157
1158        let kernel = vec![0.25, 0.5, 0.25]; // Simple box-like kernel
1159        let result = apply_horizontal_blur(&input, &kernel);
1160        assert!(result.is_ok());
1161
1162        let output = result.expect("Should succeed");
1163
1164        // Middle row should still be high, but slightly smoothed
1165        let value = output.get_pixel(5, 5).expect("Should succeed");
1166        assert!(value > 0.4, "Value {} should be high", value);
1167    }
1168
1169    #[test]
1170    fn test_vertical_blur() {
1171        // Test vertical blur function
1172        let mut input = RasterBuffer::zeros(10, 10, RasterDataType::Float32);
1173
1174        // Set middle column to 1.0
1175        for y in 0..10 {
1176            let _ = input.set_pixel(5, y, 1.0);
1177        }
1178
1179        let kernel = vec![0.25, 0.5, 0.25]; // Simple box-like kernel
1180        let result = apply_vertical_blur(&input, &kernel);
1181        assert!(result.is_ok());
1182
1183        let output = result.expect("Should succeed");
1184
1185        // Middle column should still be high, but slightly smoothed
1186        let value = output.get_pixel(5, 5).expect("Should succeed");
1187        assert!(value > 0.4, "Value {} should be high", value);
1188    }
1189
1190    #[test]
1191    fn test_blur_edge_handling() {
1192        // Test that edge handling works correctly
1193        let mut input = RasterBuffer::zeros(5, 5, RasterDataType::Float32);
1194
1195        // Set corner pixel
1196        let _ = input.set_pixel(0, 0, 1.0);
1197
1198        let result = gaussian_blur(&input, 3);
1199        assert!(result.is_ok());
1200
1201        let output = result.expect("Should succeed");
1202
1203        // Corner should be smoothed but not zero (due to mirror padding)
1204        let corner = output.get_pixel(0, 0).expect("Should succeed");
1205        assert!(
1206            corner > 0.0,
1207            "Corner should be non-zero due to mirror padding"
1208        );
1209    }
1210}