1use oxigdal_core::buffer::RasterBuffer;
7use serde::{Deserialize, Serialize};
10use tracing::debug;
11
12use crate::error::{PreprocessingError, Result};
13
/// Statistics used by [`normalize`] to rescale pixels as `(pixel - mean) / std`.
///
/// Both vectors are intended to hold one entry per channel, but [`normalize`]
/// currently reads only the first entry of each.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NormalizationParams {
    /// Value subtracted from each pixel (intended as one entry per channel).
    pub mean: Vec<f64>,
    /// Divisor applied after subtracting the mean (intended as one entry per
    /// channel). [`normalize`] rejects an empty vector or one containing zero.
    pub std: Vec<f64>,
}
22
23impl NormalizationParams {
24 #[must_use]
26 pub fn imagenet() -> Self {
27 Self {
28 mean: vec![0.485, 0.456, 0.406],
29 std: vec![0.229, 0.224, 0.225],
30 }
31 }
32
33 #[must_use]
35 pub fn from_range(min: f64, max: f64) -> Self {
36 let mean = (min + max) / 2.0;
37 let std = (max - min) / 2.0;
38 Self {
39 mean: vec![mean],
40 std: vec![std],
41 }
42 }
43
44 #[must_use]
46 pub fn zero_mean_unit_variance() -> Self {
47 Self {
48 mean: vec![0.0],
49 std: vec![1.0],
50 }
51 }
52}
53
/// How [`tile_raster`] fills the part of an edge tile that lies outside the raster.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PaddingStrategy {
    /// Leave the padded area at zero.
    Zero,
    /// Repeat the nearest valid edge pixel.
    Replicate,
    /// Mirror the valid pixels across the edge, starting with the edge pixel itself.
    Reflect,
    /// Repeat the valid region periodically, as if the tile wrapped around.
    Wrap,
}
66
/// Parameters for [`tile_raster`].
#[derive(Debug, Clone)]
pub struct TileConfig {
    /// Width of every output tile, in pixels; must be non-zero.
    pub tile_width: usize,
    /// Height of every output tile, in pixels; must be non-zero.
    pub tile_height: usize,
    /// Pixels shared by neighbouring tiles. Tile origins advance by
    /// `tile_width - overlap` / `tile_height - overlap`, so this must be smaller
    /// than both tile dimensions.
    pub overlap: usize,
    /// How partial tiles at the right/bottom raster edge are filled.
    pub padding: PaddingStrategy,
}
79
80impl Default for TileConfig {
81 fn default() -> Self {
82 Self {
83 tile_width: 256,
84 tile_height: 256,
85 overlap: 32,
86 padding: PaddingStrategy::Replicate,
87 }
88 }
89}
90
/// One tile produced by [`tile_raster`].
#[derive(Debug, Clone)]
pub struct Tile {
    /// Tile pixels, always sized to the configured `tile_width` x `tile_height`;
    /// any area beyond the source raster is filled per the padding strategy.
    pub buffer: RasterBuffer,
    /// Column of the tile's top-left pixel in the source raster.
    pub x_offset: u64,
    /// Row of the tile's top-left pixel in the source raster.
    pub y_offset: u64,
    /// Width of the source raster the tile was cut from.
    pub original_width: u64,
    /// Height of the source raster the tile was cut from.
    pub original_height: u64,
}
105
106pub fn normalize(buffer: &RasterBuffer, params: &NormalizationParams) -> Result<RasterBuffer> {
111 if params.mean.is_empty() || params.std.is_empty() {
112 return Err(PreprocessingError::InvalidNormalization {
113 message: "Mean and std must not be empty".to_string(),
114 }
115 .into());
116 }
117
118 if params.std.contains(&0.0) {
119 return Err(PreprocessingError::InvalidNormalization {
120 message: "Standard deviation cannot be zero".to_string(),
121 }
122 .into());
123 }
124
125 let mut result = buffer.clone();
126
127 for y in 0..buffer.height() {
129 for x in 0..buffer.width() {
130 let pixel =
131 buffer
132 .get_pixel(x, y)
133 .map_err(|e| PreprocessingError::InvalidNormalization {
134 message: format!("Failed to get pixel: {}", e),
135 })?;
136
137 let channel_idx = 0;
139 let mean = params.mean[channel_idx];
140 let std = params.std[channel_idx];
141
142 let normalized = (pixel - mean) / std;
143
144 result.set_pixel(x, y, normalized).map_err(|e| {
145 PreprocessingError::InvalidNormalization {
146 message: format!("Failed to set pixel: {}", e),
147 }
148 })?;
149 }
150 }
151
152 Ok(result)
153}
154
155pub fn tile_raster(buffer: &RasterBuffer, config: &TileConfig) -> Result<Vec<Tile>> {
160 if config.tile_width == 0 || config.tile_height == 0 {
161 return Err(PreprocessingError::InvalidTileSize {
162 width: config.tile_width,
163 height: config.tile_height,
164 }
165 .into());
166 }
167
168 let width = buffer.width();
169 let height = buffer.height();
170
171 debug!(
172 "Tiling {}x{} raster into {}x{} tiles with {} overlap",
173 width, height, config.tile_width, config.tile_height, config.overlap
174 );
175
176 let mut tiles = Vec::new();
177
178 let stride_x = config.tile_width.saturating_sub(config.overlap);
179 let stride_y = config.tile_height.saturating_sub(config.overlap);
180
181 if stride_x == 0 || stride_y == 0 {
182 return Err(PreprocessingError::TilingFailed {
183 reason: "Overlap is too large for the tile size".to_string(),
184 }
185 .into());
186 }
187
188 let mut y = 0u64;
189 while y < height {
190 let mut x = 0u64;
191 while x < width {
192 let tile_width = (width - x).min(config.tile_width as u64);
193 let tile_height = (height - y).min(config.tile_height as u64);
194
195 let tile_buffer = extract_tile(buffer, x, y, tile_width, tile_height, config)?;
196
197 tiles.push(Tile {
198 buffer: tile_buffer,
199 x_offset: x,
200 y_offset: y,
201 original_width: width,
202 original_height: height,
203 });
204
205 x = x.saturating_add(stride_x as u64);
206 if x >= width {
207 break;
208 }
209 }
210
211 y = y.saturating_add(stride_y as u64);
212 if y >= height {
213 break;
214 }
215 }
216
217 debug!("Created {} tiles", tiles.len());
218
219 Ok(tiles)
220}
221
222fn extract_tile(
224 buffer: &RasterBuffer,
225 x: u64,
226 y: u64,
227 width: u64,
228 height: u64,
229 config: &TileConfig,
230) -> Result<RasterBuffer> {
231 let mut tile = RasterBuffer::zeros(
232 config.tile_width as u64,
233 config.tile_height as u64,
234 buffer.data_type(),
235 );
236
237 for ty in 0..height {
239 for tx in 0..width {
240 let src_x = x + tx;
241 let src_y = y + ty;
242
243 let pixel =
244 buffer
245 .get_pixel(src_x, src_y)
246 .map_err(|e| PreprocessingError::TilingFailed {
247 reason: format!("Failed to get pixel: {}", e),
248 })?;
249
250 tile.set_pixel(tx, ty, pixel)
251 .map_err(|e| PreprocessingError::TilingFailed {
252 reason: format!("Failed to set pixel: {}", e),
253 })?;
254 }
255 }
256
257 if width < config.tile_width as u64 || height < config.tile_height as u64 {
259 apply_padding(&mut tile, width, height, config.padding)?;
260 }
261
262 Ok(tile)
263}
264
265fn apply_padding(
267 tile: &mut RasterBuffer,
268 valid_width: u64,
269 valid_height: u64,
270 strategy: PaddingStrategy,
271) -> Result<()> {
272 let tile_width = tile.width();
273 let tile_height = tile.height();
274
275 match strategy {
276 PaddingStrategy::Zero => {
277 Ok(())
279 }
280 PaddingStrategy::Replicate => {
281 if valid_width < tile_width {
283 let edge_x = valid_width.saturating_sub(1);
284 for y in 0..valid_height {
285 let edge_value = tile.get_pixel(edge_x, y).map_err(|e| {
286 PreprocessingError::PaddingFailed {
287 reason: format!("Failed to get edge pixel: {}", e),
288 }
289 })?;
290 for x in valid_width..tile_width {
291 tile.set_pixel(x, y, edge_value).map_err(|e| {
292 PreprocessingError::PaddingFailed {
293 reason: format!("Failed to set pixel: {}", e),
294 }
295 })?;
296 }
297 }
298 }
299
300 if valid_height < tile_height {
302 let edge_y = valid_height.saturating_sub(1);
303 for x in 0..tile_width {
304 let edge_value = tile.get_pixel(x, edge_y).map_err(|e| {
305 PreprocessingError::PaddingFailed {
306 reason: format!("Failed to get edge pixel: {}", e),
307 }
308 })?;
309 for y in valid_height..tile_height {
310 tile.set_pixel(x, y, edge_value).map_err(|e| {
311 PreprocessingError::PaddingFailed {
312 reason: format!("Failed to set pixel: {}", e),
313 }
314 })?;
315 }
316 }
317 }
318
319 Ok(())
320 }
321 PaddingStrategy::Reflect => {
322 if valid_width < tile_width {
324 for y in 0..valid_height {
325 for x in valid_width..tile_width {
326 let reflect_x =
327 valid_width.saturating_sub((x - valid_width + 1).min(valid_width));
328 let value = tile.get_pixel(reflect_x, y).map_err(|e| {
329 PreprocessingError::PaddingFailed {
330 reason: format!("Failed to get reflected pixel: {}", e),
331 }
332 })?;
333 tile.set_pixel(x, y, value).map_err(|e| {
334 PreprocessingError::PaddingFailed {
335 reason: format!("Failed to set pixel: {}", e),
336 }
337 })?;
338 }
339 }
340 }
341
342 if valid_height < tile_height {
343 for x in 0..tile_width {
344 for y in valid_height..tile_height {
345 let reflect_y =
346 valid_height.saturating_sub((y - valid_height + 1).min(valid_height));
347 let value = tile.get_pixel(x, reflect_y).map_err(|e| {
348 PreprocessingError::PaddingFailed {
349 reason: format!("Failed to get reflected pixel: {}", e),
350 }
351 })?;
352 tile.set_pixel(x, y, value).map_err(|e| {
353 PreprocessingError::PaddingFailed {
354 reason: format!("Failed to set pixel: {}", e),
355 }
356 })?;
357 }
358 }
359 }
360
361 Ok(())
362 }
363 PaddingStrategy::Wrap => {
364 if valid_width < tile_width && valid_width > 0 {
366 for y in 0..valid_height {
367 for x in valid_width..tile_width {
368 let wrap_x = (x - valid_width) % valid_width;
369 let value = tile.get_pixel(wrap_x, y).map_err(|e| {
370 PreprocessingError::PaddingFailed {
371 reason: format!("Failed to get wrapped pixel: {}", e),
372 }
373 })?;
374 tile.set_pixel(x, y, value).map_err(|e| {
375 PreprocessingError::PaddingFailed {
376 reason: format!("Failed to set pixel: {}", e),
377 }
378 })?;
379 }
380 }
381 }
382
383 if valid_height < tile_height && valid_height > 0 {
384 for x in 0..tile_width {
385 for y in valid_height..tile_height {
386 let wrap_y = (y - valid_height) % valid_height;
387 let value = tile.get_pixel(x, wrap_y).map_err(|e| {
388 PreprocessingError::PaddingFailed {
389 reason: format!("Failed to get wrapped pixel: {}", e),
390 }
391 })?;
392 tile.set_pixel(x, y, value).map_err(|e| {
393 PreprocessingError::PaddingFailed {
394 reason: format!("Failed to set pixel: {}", e),
395 }
396 })?;
397 }
398 }
399 }
400
401 Ok(())
402 }
403 }
404}
405
406pub fn resize_nearest(
411 buffer: &RasterBuffer,
412 new_width: u64,
413 new_height: u64,
414) -> Result<RasterBuffer> {
415 let mut result = RasterBuffer::zeros(new_width, new_height, buffer.data_type());
416
417 let x_ratio = buffer.width() as f64 / new_width as f64;
418 let y_ratio = buffer.height() as f64 / new_height as f64;
419
420 for y in 0..new_height {
421 for x in 0..new_width {
422 let src_x = (x as f64 * x_ratio) as u64;
423 let src_y = (y as f64 * y_ratio) as u64;
424
425 let pixel = buffer.get_pixel(src_x, src_y).map_err(|e| {
426 PreprocessingError::InvalidNormalization {
427 message: format!("Failed to get pixel during resize: {}", e),
428 }
429 })?;
430
431 result.set_pixel(x, y, pixel).map_err(|e| {
432 PreprocessingError::InvalidNormalization {
433 message: format!("Failed to set pixel during resize: {}", e),
434 }
435 })?;
436 }
437 }
438
439 Ok(result)
440}
441
#[cfg(test)]
mod tests {
    use super::*;
    use oxigdal_core::types::RasterDataType;

    #[test]
    fn test_normalization_params() {
        let params = NormalizationParams::imagenet();
        assert_eq!(params.mean.len(), 3);
        assert_eq!(params.std.len(), 3);

        let params = NormalizationParams::from_range(0.0, 255.0);
        assert!((params.mean[0] - 127.5).abs() < f64::EPSILON);
    }

    #[test]
    fn test_normalize() {
        let buffer = RasterBuffer::zeros(10, 10, RasterDataType::Float32);
        let params = NormalizationParams::zero_mean_unit_variance();

        let result = normalize(&buffer, &params);
        assert!(result.is_ok());
    }

    #[test]
    fn test_normalize_maps_range_minimum_to_minus_one() {
        let buffer = RasterBuffer::zeros(2, 2, RasterDataType::Float32);
        let params = NormalizationParams::from_range(0.0, 255.0);

        let value = normalize(&buffer, &params)
            .ok()
            .and_then(|result| result.get_pixel(0, 0).ok())
            .unwrap_or(f64::NAN);
        assert!((value + 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_normalize_rejects_invalid_params() {
        let buffer = RasterBuffer::zeros(2, 2, RasterDataType::Float32);

        let zero_std = NormalizationParams {
            mean: vec![0.0],
            std: vec![0.0],
        };
        assert!(normalize(&buffer, &zero_std).is_err());

        let empty_mean = NormalizationParams {
            mean: Vec::new(),
            std: vec![1.0],
        };
        assert!(normalize(&buffer, &empty_mean).is_err());
    }

    #[test]
    fn test_tile_config_default() {
        let config = TileConfig::default();
        assert_eq!(config.tile_width, 256);
        assert_eq!(config.tile_height, 256);
        assert_eq!(config.overlap, 32);
        assert_eq!(config.padding, PaddingStrategy::Replicate);
    }

    #[test]
    fn test_tiling() {
        let buffer = RasterBuffer::zeros(512, 512, RasterDataType::Float32);
        let config = TileConfig::default();

        let tiles = tile_raster(&buffer, &config);
        assert!(tiles.is_ok());
        let tiles = tiles.ok().unwrap_or_default();

        // Stride is 256 - 32 = 224, so origins are 0, 224 and 448 on each axis.
        assert_eq!(tiles.len(), 9);
        for tile in &tiles {
            assert_eq!(tile.buffer.width(), 256);
            assert_eq!(tile.buffer.height(), 256);
            assert_eq!(tile.original_width, 512);
            assert_eq!(tile.original_height, 512);
        }
        assert_eq!(
            tiles.last().map(|t| (t.x_offset, t.y_offset)),
            Some((448, 448))
        );
    }

    #[test]
    fn test_tiling_rejects_invalid_config() {
        let buffer = RasterBuffer::zeros(16, 16, RasterDataType::Float32);

        let zero_width = TileConfig {
            tile_width: 0,
            ..TileConfig::default()
        };
        assert!(tile_raster(&buffer, &zero_width).is_err());

        let full_overlap = TileConfig {
            tile_width: 8,
            tile_height: 8,
            overlap: 8,
            padding: PaddingStrategy::Zero,
        };
        assert!(tile_raster(&buffer, &full_overlap).is_err());
    }

    #[test]
    fn test_replicate_padding_fills_partial_tile() {
        let mut buffer = RasterBuffer::zeros(3, 1, RasterDataType::Float32);
        for (x, value) in [1.0, 2.0, 3.0].iter().enumerate() {
            assert!(buffer.set_pixel(x as u64, 0, *value).is_ok());
        }
        let config = TileConfig {
            tile_width: 4,
            tile_height: 2,
            overlap: 0,
            padding: PaddingStrategy::Replicate,
        };

        let tiles = tile_raster(&buffer, &config).ok().unwrap_or_default();
        assert_eq!(tiles.len(), 1);

        let expected = [1.0, 2.0, 3.0, 3.0];
        if let Some(tile) = tiles.first() {
            for y in 0..2 {
                for (x, want) in expected.iter().enumerate() {
                    assert_eq!(tile.buffer.get_pixel(x as u64, y).ok(), Some(*want));
                }
            }
        }
    }

    #[test]
    fn test_resize_nearest() {
        let buffer = RasterBuffer::zeros(100, 100, RasterDataType::Float32);
        let resized = resize_nearest(&buffer, 50, 50);
        assert!(resized.is_ok());
        let resized = resized
            .ok()
            .unwrap_or_else(|| RasterBuffer::zeros(1, 1, RasterDataType::Float32));
        assert_eq!(resized.width(), 50);
        assert_eq!(resized.height(), 50);
    }
}