Skip to main content

oxigdal_ml/
preprocessing.rs

//! Data preprocessing for ML workflows
//!
//! This module provides preprocessing operations for geospatial data
//! before ML inference.
5
6use oxigdal_core::buffer::RasterBuffer;
7// use oxigdal_core::types::RasterDataType;
8// use rayon::prelude::*;
9use serde::{Deserialize, Serialize};
10use tracing::debug;
11
12use crate::error::{PreprocessingError, Result};
13
14/// Normalization parameters
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct NormalizationParams {
17    /// Per-channel mean values
18    pub mean: Vec<f64>,
19    /// Per-channel standard deviation values
20    pub std: Vec<f64>,
21}
22
23impl NormalizationParams {
24    /// Creates ImageNet normalization parameters
25    #[must_use]
26    pub fn imagenet() -> Self {
27        Self {
28            mean: vec![0.485, 0.456, 0.406],
29            std: vec![0.229, 0.224, 0.225],
30        }
31    }
32
33    /// Creates normalization parameters for a given range
34    #[must_use]
35    pub fn from_range(min: f64, max: f64) -> Self {
36        let mean = (min + max) / 2.0;
37        let std = (max - min) / 2.0;
38        Self {
39            mean: vec![mean],
40            std: vec![std],
41        }
42    }
43
44    /// Creates zero-mean unit-variance normalization
45    #[must_use]
46    pub fn zero_mean_unit_variance() -> Self {
47        Self {
48            mean: vec![0.0],
49            std: vec![1.0],
50        }
51    }
52}
53
54/// Padding strategy for tiles
55#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
56pub enum PaddingStrategy {
57    /// Zero padding
58    Zero,
59    /// Replicate edge values
60    Replicate,
61    /// Reflect values at boundaries
62    Reflect,
63    /// Wrap around to opposite edge
64    Wrap,
65}
66
67/// Tile configuration
68#[derive(Debug, Clone)]
69pub struct TileConfig {
70    /// Tile width
71    pub tile_width: usize,
72    /// Tile height
73    pub tile_height: usize,
74    /// Overlap between tiles (in pixels)
75    pub overlap: usize,
76    /// Padding strategy
77    pub padding: PaddingStrategy,
78}
79
80impl Default for TileConfig {
81    fn default() -> Self {
82        Self {
83            tile_width: 256,
84            tile_height: 256,
85            overlap: 32,
86            padding: PaddingStrategy::Replicate,
87        }
88    }
89}
90
91/// A single tile from a raster
92#[derive(Debug, Clone)]
93pub struct Tile {
94    /// The tile buffer
95    pub buffer: RasterBuffer,
96    /// X offset in the original raster
97    pub x_offset: u64,
98    /// Y offset in the original raster
99    pub y_offset: u64,
100    /// Original raster width
101    pub original_width: u64,
102    /// Original raster height
103    pub original_height: u64,
104}
105
106/// Normalizes a raster buffer using the given parameters
107///
108/// # Errors
109/// Returns an error if normalization fails
110pub fn normalize(buffer: &RasterBuffer, params: &NormalizationParams) -> Result<RasterBuffer> {
111    if params.mean.is_empty() || params.std.is_empty() {
112        return Err(PreprocessingError::InvalidNormalization {
113            message: "Mean and std must not be empty".to_string(),
114        }
115        .into());
116    }
117
118    if params.std.contains(&0.0) {
119        return Err(PreprocessingError::InvalidNormalization {
120            message: "Standard deviation cannot be zero".to_string(),
121        }
122        .into());
123    }
124
125    let mut result = buffer.clone();
126
127    // Normalize each pixel
128    for y in 0..buffer.height() {
129        for x in 0..buffer.width() {
130            let pixel =
131                buffer
132                    .get_pixel(x, y)
133                    .map_err(|e| PreprocessingError::InvalidNormalization {
134                        message: format!("Failed to get pixel: {}", e),
135                    })?;
136
137            // Use first channel params if only one set is provided
138            let channel_idx = 0;
139            let mean = params.mean[channel_idx];
140            let std = params.std[channel_idx];
141
142            let normalized = (pixel - mean) / std;
143
144            result.set_pixel(x, y, normalized).map_err(|e| {
145                PreprocessingError::InvalidNormalization {
146                    message: format!("Failed to set pixel: {}", e),
147                }
148            })?;
149        }
150    }
151
152    Ok(result)
153}
154
155/// Tiles a raster buffer into smaller tiles
156///
157/// # Errors
158/// Returns an error if tiling fails
159pub fn tile_raster(buffer: &RasterBuffer, config: &TileConfig) -> Result<Vec<Tile>> {
160    if config.tile_width == 0 || config.tile_height == 0 {
161        return Err(PreprocessingError::InvalidTileSize {
162            width: config.tile_width,
163            height: config.tile_height,
164        }
165        .into());
166    }
167
168    let width = buffer.width();
169    let height = buffer.height();
170
171    debug!(
172        "Tiling {}x{} raster into {}x{} tiles with {} overlap",
173        width, height, config.tile_width, config.tile_height, config.overlap
174    );
175
176    let mut tiles = Vec::new();
177
178    let stride_x = config.tile_width.saturating_sub(config.overlap);
179    let stride_y = config.tile_height.saturating_sub(config.overlap);
180
181    if stride_x == 0 || stride_y == 0 {
182        return Err(PreprocessingError::TilingFailed {
183            reason: "Overlap is too large for the tile size".to_string(),
184        }
185        .into());
186    }
187
188    let mut y = 0u64;
189    while y < height {
190        let mut x = 0u64;
191        while x < width {
192            let tile_width = (width - x).min(config.tile_width as u64);
193            let tile_height = (height - y).min(config.tile_height as u64);
194
195            let tile_buffer = extract_tile(buffer, x, y, tile_width, tile_height, config)?;
196
197            tiles.push(Tile {
198                buffer: tile_buffer,
199                x_offset: x,
200                y_offset: y,
201                original_width: width,
202                original_height: height,
203            });
204
205            x = x.saturating_add(stride_x as u64);
206            if x >= width {
207                break;
208            }
209        }
210
211        y = y.saturating_add(stride_y as u64);
212        if y >= height {
213            break;
214        }
215    }
216
217    debug!("Created {} tiles", tiles.len());
218
219    Ok(tiles)
220}
221
222/// Extracts a tile from a raster buffer
223fn extract_tile(
224    buffer: &RasterBuffer,
225    x: u64,
226    y: u64,
227    width: u64,
228    height: u64,
229    config: &TileConfig,
230) -> Result<RasterBuffer> {
231    let mut tile = RasterBuffer::zeros(
232        config.tile_width as u64,
233        config.tile_height as u64,
234        buffer.data_type(),
235    );
236
237    // Copy pixels from source to tile
238    for ty in 0..height {
239        for tx in 0..width {
240            let src_x = x + tx;
241            let src_y = y + ty;
242
243            let pixel =
244                buffer
245                    .get_pixel(src_x, src_y)
246                    .map_err(|e| PreprocessingError::TilingFailed {
247                        reason: format!("Failed to get pixel: {}", e),
248                    })?;
249
250            tile.set_pixel(tx, ty, pixel)
251                .map_err(|e| PreprocessingError::TilingFailed {
252                    reason: format!("Failed to set pixel: {}", e),
253                })?;
254        }
255    }
256
257    // Apply padding if tile is smaller than requested size
258    if width < config.tile_width as u64 || height < config.tile_height as u64 {
259        apply_padding(&mut tile, width, height, config.padding)?;
260    }
261
262    Ok(tile)
263}
264
265/// Applies padding to a tile
266fn apply_padding(
267    tile: &mut RasterBuffer,
268    valid_width: u64,
269    valid_height: u64,
270    strategy: PaddingStrategy,
271) -> Result<()> {
272    let tile_width = tile.width();
273    let tile_height = tile.height();
274
275    match strategy {
276        PaddingStrategy::Zero => {
277            // Zeros are already filled by RasterBuffer::zeros
278            Ok(())
279        }
280        PaddingStrategy::Replicate => {
281            // Replicate right edge
282            if valid_width < tile_width {
283                let edge_x = valid_width.saturating_sub(1);
284                for y in 0..valid_height {
285                    let edge_value = tile.get_pixel(edge_x, y).map_err(|e| {
286                        PreprocessingError::PaddingFailed {
287                            reason: format!("Failed to get edge pixel: {}", e),
288                        }
289                    })?;
290                    for x in valid_width..tile_width {
291                        tile.set_pixel(x, y, edge_value).map_err(|e| {
292                            PreprocessingError::PaddingFailed {
293                                reason: format!("Failed to set pixel: {}", e),
294                            }
295                        })?;
296                    }
297                }
298            }
299
300            // Replicate bottom edge
301            if valid_height < tile_height {
302                let edge_y = valid_height.saturating_sub(1);
303                for x in 0..tile_width {
304                    let edge_value = tile.get_pixel(x, edge_y).map_err(|e| {
305                        PreprocessingError::PaddingFailed {
306                            reason: format!("Failed to get edge pixel: {}", e),
307                        }
308                    })?;
309                    for y in valid_height..tile_height {
310                        tile.set_pixel(x, y, edge_value).map_err(|e| {
311                            PreprocessingError::PaddingFailed {
312                                reason: format!("Failed to set pixel: {}", e),
313                            }
314                        })?;
315                    }
316                }
317            }
318
319            Ok(())
320        }
321        PaddingStrategy::Reflect => {
322            // Simplified reflection padding
323            if valid_width < tile_width {
324                for y in 0..valid_height {
325                    for x in valid_width..tile_width {
326                        let reflect_x =
327                            valid_width.saturating_sub((x - valid_width + 1).min(valid_width));
328                        let value = tile.get_pixel(reflect_x, y).map_err(|e| {
329                            PreprocessingError::PaddingFailed {
330                                reason: format!("Failed to get reflected pixel: {}", e),
331                            }
332                        })?;
333                        tile.set_pixel(x, y, value).map_err(|e| {
334                            PreprocessingError::PaddingFailed {
335                                reason: format!("Failed to set pixel: {}", e),
336                            }
337                        })?;
338                    }
339                }
340            }
341
342            if valid_height < tile_height {
343                for x in 0..tile_width {
344                    for y in valid_height..tile_height {
345                        let reflect_y =
346                            valid_height.saturating_sub((y - valid_height + 1).min(valid_height));
347                        let value = tile.get_pixel(x, reflect_y).map_err(|e| {
348                            PreprocessingError::PaddingFailed {
349                                reason: format!("Failed to get reflected pixel: {}", e),
350                            }
351                        })?;
352                        tile.set_pixel(x, y, value).map_err(|e| {
353                            PreprocessingError::PaddingFailed {
354                                reason: format!("Failed to set pixel: {}", e),
355                            }
356                        })?;
357                    }
358                }
359            }
360
361            Ok(())
362        }
363        PaddingStrategy::Wrap => {
364            // Wrap around to opposite edge
365            if valid_width < tile_width && valid_width > 0 {
366                for y in 0..valid_height {
367                    for x in valid_width..tile_width {
368                        let wrap_x = (x - valid_width) % valid_width;
369                        let value = tile.get_pixel(wrap_x, y).map_err(|e| {
370                            PreprocessingError::PaddingFailed {
371                                reason: format!("Failed to get wrapped pixel: {}", e),
372                            }
373                        })?;
374                        tile.set_pixel(x, y, value).map_err(|e| {
375                            PreprocessingError::PaddingFailed {
376                                reason: format!("Failed to set pixel: {}", e),
377                            }
378                        })?;
379                    }
380                }
381            }
382
383            if valid_height < tile_height && valid_height > 0 {
384                for x in 0..tile_width {
385                    for y in valid_height..tile_height {
386                        let wrap_y = (y - valid_height) % valid_height;
387                        let value = tile.get_pixel(x, wrap_y).map_err(|e| {
388                            PreprocessingError::PaddingFailed {
389                                reason: format!("Failed to get wrapped pixel: {}", e),
390                            }
391                        })?;
392                        tile.set_pixel(x, y, value).map_err(|e| {
393                            PreprocessingError::PaddingFailed {
394                                reason: format!("Failed to set pixel: {}", e),
395                            }
396                        })?;
397                    }
398                }
399            }
400
401            Ok(())
402        }
403    }
404}
405
406/// Resizes a raster buffer using nearest neighbor interpolation
407///
408/// # Errors
409/// Returns an error if resizing fails
410pub fn resize_nearest(
411    buffer: &RasterBuffer,
412    new_width: u64,
413    new_height: u64,
414) -> Result<RasterBuffer> {
415    let mut result = RasterBuffer::zeros(new_width, new_height, buffer.data_type());
416
417    let x_ratio = buffer.width() as f64 / new_width as f64;
418    let y_ratio = buffer.height() as f64 / new_height as f64;
419
420    for y in 0..new_height {
421        for x in 0..new_width {
422            let src_x = (x as f64 * x_ratio) as u64;
423            let src_y = (y as f64 * y_ratio) as u64;
424
425            let pixel = buffer.get_pixel(src_x, src_y).map_err(|e| {
426                PreprocessingError::InvalidNormalization {
427                    message: format!("Failed to get pixel during resize: {}", e),
428                }
429            })?;
430
431            result.set_pixel(x, y, pixel).map_err(|e| {
432                PreprocessingError::InvalidNormalization {
433                    message: format!("Failed to set pixel during resize: {}", e),
434                }
435            })?;
436        }
437    }
438
439    Ok(result)
440}
441
#[cfg(test)]
mod tests {
    use super::*;
    use oxigdal_core::types::RasterDataType;

    /// Compares two pixel values with a small absolute tolerance
    fn approx_eq(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < 1e-6
    }

    #[test]
    fn test_normalization_params() {
        let params = NormalizationParams::imagenet();
        assert_eq!(params.mean.len(), 3);
        assert_eq!(params.std.len(), 3);

        let params = NormalizationParams::from_range(0.0, 255.0);
        assert!((params.mean[0] - 127.5).abs() < f64::EPSILON);
        assert!((params.std[0] - 127.5).abs() < f64::EPSILON);

        let params = NormalizationParams::zero_mean_unit_variance();
        assert_eq!(params.mean, vec![0.0]);
        assert_eq!(params.std, vec![1.0]);
    }

    #[test]
    fn test_normalize() {
        let buffer = RasterBuffer::zeros(10, 10, RasterDataType::Float32);
        let params = NormalizationParams::zero_mean_unit_variance();

        let result = normalize(&buffer, &params);
        assert!(result.is_ok());
    }

    #[test]
    fn test_normalize_values() {
        let mut buffer = RasterBuffer::zeros(2, 1, RasterDataType::Float32);
        assert!(buffer.set_pixel(0, 0, 10.0).is_ok());
        assert!(buffer.set_pixel(1, 0, 2.0).is_ok());

        let params = NormalizationParams {
            mean: vec![2.0],
            std: vec![4.0],
        };

        let normalized = normalize(&buffer, &params);
        assert!(normalized.is_ok());
        let normalized = normalized
            .ok()
            .unwrap_or_else(|| RasterBuffer::zeros(2, 1, RasterDataType::Float32));

        // (10 - 2) / 4 and (2 - 2) / 4
        let first = normalized.get_pixel(0, 0).ok();
        let second = normalized.get_pixel(1, 0).ok();
        assert!(first.map_or(false, |v| approx_eq(v, 2.0)));
        assert!(second.map_or(false, |v| approx_eq(v, 0.0)));
    }

    #[test]
    fn test_normalize_rejects_invalid_params() {
        let buffer = RasterBuffer::zeros(2, 2, RasterDataType::Float32);

        let empty = NormalizationParams {
            mean: Vec::new(),
            std: Vec::new(),
        };
        assert!(normalize(&buffer, &empty).is_err());

        let zero_std = NormalizationParams {
            mean: vec![0.0],
            std: vec![0.0],
        };
        assert!(normalize(&buffer, &zero_std).is_err());
    }

    #[test]
    fn test_tile_config_default() {
        let config = TileConfig::default();
        assert_eq!(config.tile_width, 256);
        assert_eq!(config.tile_height, 256);
        assert_eq!(config.overlap, 32);
        assert_eq!(config.padding, PaddingStrategy::Replicate);
    }

    #[test]
    fn test_tiling() {
        let buffer = RasterBuffer::zeros(512, 512, RasterDataType::Float32);
        let config = TileConfig::default();

        let tiles = tile_raster(&buffer, &config);
        assert!(tiles.is_ok());
        let tiles = tiles.ok().unwrap_or_default();

        // A stride of 256 - 32 = 224 places tile origins at 0, 224 and 448 on each axis.
        assert_eq!(tiles.len(), 9);
        for tile in &tiles {
            assert_eq!(tile.buffer.width(), 256);
            assert_eq!(tile.buffer.height(), 256);
            assert_eq!(tile.original_width, 512);
            assert_eq!(tile.original_height, 512);
        }
        assert_eq!(tiles.first().map(|t| (t.x_offset, t.y_offset)), Some((0, 0)));
        assert_eq!(
            tiles.last().map(|t| (t.x_offset, t.y_offset)),
            Some((448, 448))
        );
    }

    #[test]
    fn test_tiling_rejects_invalid_config() {
        let buffer = RasterBuffer::zeros(16, 16, RasterDataType::Float32);

        let zero_width = TileConfig {
            tile_width: 0,
            ..TileConfig::default()
        };
        assert!(tile_raster(&buffer, &zero_width).is_err());

        let no_stride = TileConfig {
            tile_width: 8,
            tile_height: 8,
            overlap: 8,
            padding: PaddingStrategy::Zero,
        };
        assert!(tile_raster(&buffer, &no_stride).is_err());
    }

    #[test]
    fn test_resize_nearest() {
        let buffer = RasterBuffer::zeros(100, 100, RasterDataType::Float32);
        let resized = resize_nearest(&buffer, 50, 50);
        assert!(resized.is_ok());
        let resized = resized
            .ok()
            .unwrap_or_else(|| RasterBuffer::zeros(1, 1, RasterDataType::Float32));
        assert_eq!(resized.width(), 50);
        assert_eq!(resized.height(), 50);
    }

    #[test]
    fn test_resize_nearest_values() {
        let mut buffer = RasterBuffer::zeros(4, 1, RasterDataType::Float32);
        for x in 0..4u64 {
            assert!(buffer.set_pixel(x, 0, x as f64).is_ok());
        }

        let resized = resize_nearest(&buffer, 2, 1);
        assert!(resized.is_ok());
        let resized = resized
            .ok()
            .unwrap_or_else(|| RasterBuffer::zeros(2, 1, RasterDataType::Float32));

        // Halving the width keeps source columns 0 and 2.
        let first = resized.get_pixel(0, 0).ok();
        let second = resized.get_pixel(1, 0).ok();
        assert!(first.map_or(false, |v| approx_eq(v, 0.0)));
        assert!(second.map_or(false, |v| approx_eq(v, 2.0)));
    }
}