Skip to main content

oxigdal_gpu/
reprojection.rs

1//! GPU-accelerated raster reprojection using wgpu compute shaders.
2//!
3//! This module provides both a GPU-backed reprojection pipeline and a CPU
4//! fallback implementation for environments where GPU is unavailable.
5
6use crate::error::GpuError;
7
/// Resampling method for reprojection.
///
/// A plain fieldless enum, so it is cheap to copy and can be used as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ResampleMethod {
    /// Nearest-neighbor sampling (fastest, blocky).
    NearestNeighbor,
    /// Bilinear interpolation (smoother, moderate cost).
    Bilinear,
}
16
17/// Configuration for a reprojection operation.
18#[derive(Debug, Clone)]
19pub struct ReprojectionConfig {
20    /// Source raster width in pixels.
21    pub src_width: u32,
22    /// Source raster height in pixels.
23    pub src_height: u32,
24    /// Destination raster width in pixels.
25    pub dst_width: u32,
26    /// Destination raster height in pixels.
27    pub dst_height: u32,
28    /// Source geotransform \[a, b, c, d, e, f\] where:
29    /// `x_geo = c + col * a + row * b`
30    /// `y_geo = f + col * d + row * e`
31    pub src_geotransform: [f32; 6],
32    /// Destination inverse geotransform (maps geo → pixel).
33    pub dst_inv_geotransform: [f32; 6],
34    /// Pixel resampling strategy.
35    pub resample_method: ResampleMethod,
36    /// Optional nodata sentinel value.
37    pub nodata: Option<f32>,
38}
39
40impl ReprojectionConfig {
41    /// Validate that the configuration is internally consistent.
42    ///
43    /// # Errors
44    ///
45    /// Returns [`GpuError::InvalidKernelParams`] if dimensions are zero.
46    pub fn validate(&self) -> Result<(), GpuError> {
47        if self.src_width == 0 || self.src_height == 0 {
48            return Err(GpuError::invalid_kernel_params(
49                "source dimensions must be greater than zero",
50            ));
51        }
52        if self.dst_width == 0 || self.dst_height == 0 {
53            return Err(GpuError::invalid_kernel_params(
54                "destination dimensions must be greater than zero",
55            ));
56        }
57        Ok(())
58    }
59}
60
61/// GPU-based raster reprojector.
62///
63/// On platforms with GPU support this will eventually dispatch to a wgpu
64/// compute shader.  Until that path is fully wired up, [`reproject_cpu`]
65/// provides a correct CPU-based fallback.
66///
67/// [`reproject_cpu`]: GpuReprojector::reproject_cpu
68pub struct GpuReprojector {
69    config: ReprojectionConfig,
70}
71
72impl GpuReprojector {
73    /// Construct a new reprojector from the given configuration.
74    pub fn new(config: ReprojectionConfig) -> Self {
75        Self { config }
76    }
77
78    /// Return a reference to the reprojection configuration.
79    pub fn config(&self) -> &ReprojectionConfig {
80        &self.config
81    }
82
83    /// Reproject `src_data` to the destination grid using a pure-CPU path.
84    ///
85    /// The implementation maps each destination pixel back to source
86    /// coordinates via the supplied geotransforms and samples the source
87    /// raster.  Out-of-bounds source pixels are filled with the nodata
88    /// value (or `0.0` when nodata is not configured).
89    ///
90    /// # Errors
91    ///
92    /// Returns [`GpuError::InvalidKernelParams`] if the configuration is
93    /// invalid or the source data length does not match the declared
94    /// source dimensions.
95    pub fn reproject_cpu(&self, src_data: &[f32]) -> Result<Vec<f32>, GpuError> {
96        self.config.validate()?;
97
98        let expected_src = (self.config.src_width as usize) * (self.config.src_height as usize);
99        if src_data.len() != expected_src {
100            return Err(GpuError::invalid_kernel_params(format!(
101                "src_data length {} does not match declared source dimensions {}x{} ({})",
102                src_data.len(),
103                self.config.src_width,
104                self.config.src_height,
105                expected_src
106            )));
107        }
108
109        let nodata_fill = self.config.nodata.unwrap_or(0.0);
110        let dst_size = (self.config.dst_width as usize) * (self.config.dst_height as usize);
111        let mut dst = vec![nodata_fill; dst_size];
112
113        let gt = &self.config.src_geotransform;
114        let inv_gt = &self.config.dst_inv_geotransform;
115
116        // Determinant of the source geotransform's 2×2 linear part
117        // used to invert the forward transform: pixel → geo → src pixel.
118        let det = gt[0] * gt[4] - gt[1] * gt[3];
119        let src_gt_invertible = det.abs() > f32::EPSILON;
120
121        for row in 0..self.config.dst_height {
122            for col in 0..self.config.dst_width {
123                // Centre of destination pixel in pixel space.
124                let dst_x = col as f32 + 0.5_f32;
125                let dst_y = row as f32 + 0.5_f32;
126
127                // Destination pixel → destination geo coordinates.
128                let geo_x = inv_gt[0] + dst_x * inv_gt[1] + dst_y * inv_gt[2];
129                let geo_y = inv_gt[3] + dst_x * inv_gt[4] + dst_y * inv_gt[5];
130
131                // Destination geo → source pixel coordinates.
132                let (src_col_f, src_row_f) = if src_gt_invertible {
133                    let dx = geo_x - gt[2];
134                    let dy = geo_y - gt[5];
135                    let sc = (gt[4] * dx - gt[1] * dy) / det;
136                    let sr = (gt[0] * dy - gt[3] * dx) / det;
137                    (sc, sr)
138                } else {
139                    // Fallback: treat inv_gt as direct pixel scaling.
140                    (
141                        col as f32 * self.config.src_width as f32 / self.config.dst_width as f32,
142                        row as f32 * self.config.src_height as f32 / self.config.dst_height as f32,
143                    )
144                };
145
146                let dst_idx = row as usize * self.config.dst_width as usize + col as usize;
147
148                match self.config.resample_method {
149                    ResampleMethod::NearestNeighbor => {
150                        let src_c = src_col_f as i64;
151                        let src_r = src_row_f as i64;
152
153                        if src_c < 0
154                            || src_r < 0
155                            || src_c >= self.config.src_width as i64
156                            || src_r >= self.config.src_height as i64
157                        {
158                            continue;
159                        }
160
161                        let src_idx =
162                            src_r as usize * self.config.src_width as usize + src_c as usize;
163                        if src_idx < src_data.len() {
164                            dst[dst_idx] = src_data[src_idx];
165                        }
166                    }
167                    ResampleMethod::Bilinear => {
168                        let x0 = src_col_f.floor() as i64;
169                        let y0 = src_row_f.floor() as i64;
170                        let x1 = x0 + 1;
171                        let y1 = y0 + 1;
172
173                        let tx = src_col_f - src_col_f.floor();
174                        let ty = src_row_f - src_row_f.floor();
175
176                        let w = self.config.src_width as i64;
177                        let h = self.config.src_height as i64;
178
179                        let sample = |c: i64, r: i64| -> f32 {
180                            if c < 0 || r < 0 || c >= w || r >= h {
181                                return nodata_fill;
182                            }
183                            let idx = r as usize * self.config.src_width as usize + c as usize;
184                            src_data.get(idx).copied().unwrap_or(nodata_fill)
185                        };
186
187                        let v00 = sample(x0, y0);
188                        let v10 = sample(x1, y0);
189                        let v01 = sample(x0, y1);
190                        let v11 = sample(x1, y1);
191
192                        let v0 = v00 + (v10 - v00) * tx;
193                        let v1 = v01 + (v11 - v01) * tx;
194                        dst[dst_idx] = v0 + (v1 - v0) * ty;
195                    }
196                }
197            }
198        }
199
200        Ok(dst)
201    }
202}
203
#[cfg(test)]
mod tests {
    use super::*;

    fn identity_config(size: u32) -> ReprojectionConfig {
        // src_gt: origin (0,0), pixel size 1x1
        // dst_inv_gt: maps dst pixel → geo coord with 1:1 scale
        ReprojectionConfig {
            src_width: size,
            src_height: size,
            dst_width: size,
            dst_height: size,
            src_geotransform: [1.0, 0.0, 0.0, 0.0, 1.0, 0.0],
            dst_inv_geotransform: [0.0, 1.0, 0.0, 0.0, 0.0, 1.0],
            resample_method: ResampleMethod::NearestNeighbor,
            nodata: None,
        }
    }

    /// Source raster whose value equals its row-major index.
    fn ramp(len: u32) -> Vec<f32> {
        (0..len).map(|i| i as f32).collect()
    }

    #[test]
    fn test_new_and_config() {
        let cfg = identity_config(4);
        let r = GpuReprojector::new(cfg.clone());
        assert_eq!(r.config().src_width, 4);
        assert_eq!(r.config().dst_width, 4);
    }

    #[test]
    fn test_validate_ok() {
        assert!(identity_config(4).validate().is_ok());
    }

    #[test]
    fn test_validate_zero_src_dims() {
        let mut cfg = identity_config(4);
        cfg.src_width = 0;
        assert!(cfg.validate().is_err());

        let mut cfg = identity_config(4);
        cfg.src_height = 0;
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_validate_zero_dst_dims() {
        let mut cfg = identity_config(4);
        cfg.dst_width = 0;
        assert!(cfg.validate().is_err());

        let mut cfg = identity_config(4);
        cfg.dst_height = 0;
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_reproject_cpu_wrong_len() {
        let cfg = identity_config(4);
        let r = GpuReprojector::new(cfg);
        assert!(r.reproject_cpu(&[1.0, 2.0]).is_err());
        assert!(r.reproject_cpu(&ramp(17)).is_err());
    }

    #[test]
    fn test_reproject_cpu_identity() {
        let size = 4u32;
        let src = ramp(size * size);
        let r = GpuReprojector::new(identity_config(size));
        let dst = r.reproject_cpu(&src).expect("reproject_cpu failed");
        // An identity mapping must reproduce the source exactly, not just its size.
        assert_eq!(dst, src);
    }

    #[test]
    fn test_reproject_cpu_out_of_bounds_uses_nodata() {
        let mut cfg = identity_config(4);
        // Shift the destination far to the right of the source footprint.
        cfg.dst_inv_geotransform = [10.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        cfg.nodata = Some(-1.0);
        let r = GpuReprojector::new(cfg);
        let dst = r.reproject_cpu(&ramp(16)).expect("reproject_cpu failed");
        assert_eq!(dst, vec![-1.0; 16]);
    }

    #[test]
    fn test_reproject_cpu_out_of_bounds_default_fill_is_zero() {
        let mut cfg = identity_config(2);
        cfg.dst_inv_geotransform = [10.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        let r = GpuReprojector::new(cfg);
        let src: Vec<f32> = ramp(4).into_iter().map(|v| v + 1.0).collect();
        let dst = r.reproject_cpu(&src).expect("reproject_cpu failed");
        assert_eq!(dst, vec![0.0; 4]);
    }

    #[test]
    fn test_reproject_cpu_nearest_downsample() {
        let mut cfg = identity_config(4);
        cfg.dst_width = 2;
        cfg.dst_height = 2;
        // Each destination pixel covers 2x2 source pixels.
        cfg.dst_inv_geotransform = [0.0, 2.0, 0.0, 0.0, 0.0, 2.0];
        let r = GpuReprojector::new(cfg);
        let dst = r.reproject_cpu(&ramp(16)).expect("reproject_cpu failed");
        assert_eq!(dst, vec![5.0, 7.0, 13.0, 15.0]);
    }
}