Skip to main content

oxigdal_gpu/
algebra.rs

1//! GPU-accelerated raster algebra operations.
2//!
3//! Provides element-wise band math operations (`AlgebraOp`), a structured
4//! expression tree (`BandExpression`) for composing multi-band formulas, and
5//! the top-level `GpuAlgebra` driver that dispatches to GPU (future) or CPU.
6
7use crate::error::GpuError;
8
/// Element-wise raster algebra operation.
///
/// Binary operations (`Add`, `Subtract`, `Multiply`, `Divide`, `Min`, `Max`)
/// combine band A with band B pixel by pixel; every other variant transforms
/// band A alone.
#[derive(Debug, Clone, PartialEq)]
pub enum AlgebraOp {
    /// `A + B`
    Add,
    /// `A - B`
    Subtract,
    /// `A * B`
    Multiply,
    /// `A / B`; outputs nodata (NaN when no nodata value is set) when
    /// `|B| <= 1e-10`.
    Divide,
    /// `min(A, B)`
    Min,
    /// `max(A, B)`
    Max,
    /// `sqrt(max(0, A))` — negative inputs produce `0`.
    Sqrt,
    /// `|A|`
    Abs,
    /// `A ^ exp`, computed with `f32::powf` (a negative `A` raised to a
    /// non-integer `exp` yields NaN).
    Power(f32),
    /// `clamp(A, min, max)`. `min` must not exceed `max`, and neither bound
    /// may be NaN.
    Clamp {
        /// Lower bound of the output range.
        min: f32,
        /// Upper bound of the output range.
        max: f32,
    },
    /// Linear stretch: maps `[src_min, src_max]` → `[dst_min, dst_max]`.
    ///
    /// Inputs outside the source range are extrapolated along the same line,
    /// not clipped. When the source range is degenerate
    /// (`|src_max - src_min| < 1e-10`) every pixel maps to `dst_min`.
    Normalize {
        /// Input value mapped to `dst_min`.
        src_min: f32,
        /// Input value mapped to `dst_max`.
        src_max: f32,
        /// Output value for `src_min`.
        dst_min: f32,
        /// Output value for `src_max`.
        dst_max: f32,
    },
}
40
/// Stateless raster algebra executor.
///
/// All functionality is exposed through associated functions. The current
/// implementation evaluates every pixel on the CPU.
#[derive(Debug, Clone, Copy, Default)]
pub struct GpuAlgebra;
43
44impl GpuAlgebra {
45    /// Execute an algebra operation pixel-by-pixel (CPU fallback).
46    ///
47    /// `band_b` is required for binary operations (`Add`, `Subtract`,
48    /// `Multiply`, `Divide`, `Min`, `Max`).  For unary operations it is
49    /// ignored.
50    ///
51    /// # Errors
52    ///
53    /// Returns [`GpuError::InvalidKernelParams`] if `band_a` is empty.
54    pub fn execute(
55        band_a: &[f32],
56        band_b: Option<&[f32]>,
57        op: AlgebraOp,
58        nodata: Option<f32>,
59    ) -> Result<Vec<f32>, GpuError> {
60        if band_a.is_empty() {
61            return Err(GpuError::invalid_kernel_params("band_a must not be empty"));
62        }
63
64        let nodata_val = nodata.unwrap_or(f32::NAN);
65        let mut output = Vec::with_capacity(band_a.len());
66
67        for (i, &a) in band_a.iter().enumerate() {
68            // Nodata check for band A.
69            if nodata.is_some() && Self::is_nodata(a, nodata_val) {
70                output.push(nodata_val);
71                continue;
72            }
73
74            let b = band_b.and_then(|bb| bb.get(i)).copied().unwrap_or(0.0_f32);
75
76            // Nodata check for band B.
77            if nodata.is_some() && band_b.is_some() && Self::is_nodata(b, nodata_val) {
78                output.push(nodata_val);
79                continue;
80            }
81
82            let result = match &op {
83                AlgebraOp::Add => a + b,
84                AlgebraOp::Subtract => a - b,
85                AlgebraOp::Multiply => a * b,
86                AlgebraOp::Divide => {
87                    if b.abs() > 1e-10 {
88                        a / b
89                    } else {
90                        nodata_val
91                    }
92                }
93                AlgebraOp::Min => a.min(b),
94                AlgebraOp::Max => a.max(b),
95                AlgebraOp::Sqrt => a.max(0.0).sqrt(),
96                AlgebraOp::Abs => a.abs(),
97                AlgebraOp::Power(exp) => a.powf(*exp),
98                AlgebraOp::Clamp { min, max } => a.clamp(*min, *max),
99                AlgebraOp::Normalize {
100                    src_min,
101                    src_max,
102                    dst_min,
103                    dst_max,
104                } => {
105                    let range = src_max - src_min;
106                    if range.abs() < 1e-10 {
107                        *dst_min
108                    } else {
109                        (a - src_min) / range * (dst_max - dst_min) + dst_min
110                    }
111                }
112            };
113
114            output.push(result);
115        }
116
117        Ok(output)
118    }
119
120    /// Evaluate a multi-band expression for every pixel.
121    ///
122    /// All bands in `bands` must have the same length.  Pixels where any
123    /// band holds the nodata value are written as nodata without evaluating
124    /// the expression.
125    ///
126    /// # Errors
127    ///
128    /// Returns [`GpuError::InvalidKernelParams`] if no bands are provided.
129    /// Propagates any error from `expression.evaluate`.
130    pub fn evaluate_expression(
131        bands: &[&[f32]],
132        expression: &BandExpression,
133        nodata: Option<f32>,
134    ) -> Result<Vec<f32>, GpuError> {
135        if bands.is_empty() {
136            return Err(GpuError::invalid_kernel_params("no bands provided"));
137        }
138
139        let len = bands[0].len();
140        let nodata_val = nodata.unwrap_or(f32::NAN);
141
142        let mut output = Vec::with_capacity(len);
143        for i in 0..len {
144            // Check nodata across all bands.
145            let has_nodata = nodata.is_some()
146                && bands.iter().any(|b| {
147                    b.get(i)
148                        .map(|v| Self::is_nodata(*v, nodata_val))
149                        .unwrap_or(false)
150                });
151
152            if has_nodata {
153                output.push(nodata_val);
154                continue;
155            }
156
157            let vals: Vec<f32> = bands
158                .iter()
159                .map(|b| b.get(i).copied().unwrap_or(0.0))
160                .collect();
161            output.push(expression.evaluate(&vals)?);
162        }
163
164        Ok(output)
165    }
166
167    #[inline]
168    fn is_nodata(value: f32, nodata: f32) -> bool {
169        (value - nodata).abs() < 1e-6
170    }
171}
172
/// A composable expression tree for multi-band raster math.
///
/// Leaf nodes are either a `Band` index or a scalar `Constant`. Interior
/// nodes are arithmetic operators over boxed sub-expressions. For example,
/// `(band1 - band0) / (band1 + band0)` is
/// `Div(Box::new(Sub(Box::new(Band(1)), Box::new(Band(0)))),
///      Box::new(Add(Box::new(Band(1)), Box::new(Band(0)))))`.
#[derive(Debug, Clone, PartialEq)]
pub enum BandExpression {
    /// Value of the band at the given index in the per-pixel band values.
    Band(usize),
    /// Scalar constant.
    Constant(f32),
    /// Addition: `A + B`
    Add(Box<BandExpression>, Box<BandExpression>),
    /// Subtraction: `A - B`
    Sub(Box<BandExpression>, Box<BandExpression>),
    /// Multiplication: `A * B`
    Mul(Box<BandExpression>, Box<BandExpression>),
    /// Division: `A / B`. Evaluation fails when `|B| < 1e-10`.
    Div(Box<BandExpression>, Box<BandExpression>),
    /// Square root: `sqrt(max(0, A))`
    Sqrt(Box<BandExpression>),
    /// Absolute value: `|A|`
    Abs(Box<BandExpression>),
    /// Negation: `-A`
    Neg(Box<BandExpression>),
}
198
199impl BandExpression {
200    /// Evaluate the expression for one pixel given per-band values.
201    ///
202    /// # Errors
203    ///
204    /// Returns [`GpuError::InvalidKernelParams`] when a `Band` index is out
205    /// of range or a `Div` node encounters a zero denominator.
206    pub fn evaluate(&self, bands: &[f32]) -> Result<f32, GpuError> {
207        match self {
208            BandExpression::Band(idx) => bands.get(*idx).copied().ok_or_else(|| {
209                GpuError::invalid_kernel_params(format!(
210                    "band index {} out of range (have {} bands)",
211                    idx,
212                    bands.len()
213                ))
214            }),
215            BandExpression::Constant(v) => Ok(*v),
216            BandExpression::Add(a, b) => Ok(a.evaluate(bands)? + b.evaluate(bands)?),
217            BandExpression::Sub(a, b) => Ok(a.evaluate(bands)? - b.evaluate(bands)?),
218            BandExpression::Mul(a, b) => Ok(a.evaluate(bands)? * b.evaluate(bands)?),
219            BandExpression::Div(a, b) => {
220                let denom = b.evaluate(bands)?;
221                if denom.abs() < 1e-10 {
222                    Err(GpuError::invalid_kernel_params(
223                        "division by zero in BandExpression",
224                    ))
225                } else {
226                    Ok(a.evaluate(bands)? / denom)
227                }
228            }
229            BandExpression::Sqrt(a) => Ok(a.evaluate(bands)?.max(0.0).sqrt()),
230            BandExpression::Abs(a) => Ok(a.evaluate(bands)?.abs()),
231            BandExpression::Neg(a) => Ok(-a.evaluate(bands)?),
232        }
233    }
234}
235
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for boxing a sub-expression.
    fn bx(e: BandExpression) -> Box<BandExpression> {
        Box::new(e)
    }

    #[test]
    fn test_execute_empty_band_a() {
        let result = GpuAlgebra::execute(&[], None, AlgebraOp::Add, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_execute_add() {
        let a = vec![1.0_f32, 2.0, 3.0];
        let b = vec![4.0_f32, 5.0, 6.0];
        let out = GpuAlgebra::execute(&a, Some(&b), AlgebraOp::Add, None).expect("execute failed");
        assert_eq!(out, vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_execute_nodata_propagates_from_either_band() {
        let a = [1.0_f32, -9999.0, 3.0];
        let b = [1.0_f32, 1.0, -9999.0];
        let out = GpuAlgebra::execute(&a, Some(&b), AlgebraOp::Add, Some(-9999.0))
            .expect("execute failed");
        assert_eq!(out, vec![2.0, -9999.0, -9999.0]);
    }

    #[test]
    fn test_execute_divide_by_zero_yields_nodata() {
        let a = [1.0_f32, 4.0];
        let b = [0.0_f32, 2.0];
        let out = GpuAlgebra::execute(&a, Some(&b), AlgebraOp::Divide, Some(-9999.0))
            .expect("execute failed");
        assert_eq!(out, vec![-9999.0, 2.0]);
    }

    #[test]
    fn test_execute_divide_by_zero_without_nodata_is_nan() {
        let a = [1.0_f32, 4.0];
        let b = [0.0_f32, 2.0];
        let out =
            GpuAlgebra::execute(&a, Some(&b), AlgebraOp::Divide, None).expect("execute failed");
        assert!(out[0].is_nan());
        assert_eq!(out[1], 2.0);
    }

    #[test]
    fn test_execute_sqrt_clamps_negative_to_zero() {
        let out = GpuAlgebra::execute(&[-4.0, 9.0], None, AlgebraOp::Sqrt, None)
            .expect("execute failed");
        assert_eq!(out, vec![0.0, 3.0]);
    }

    #[test]
    fn test_execute_clamp() {
        let out = GpuAlgebra::execute(
            &[-1.0, 0.5, 2.0],
            None,
            AlgebraOp::Clamp { min: 0.0, max: 1.0 },
            None,
        )
        .expect("execute failed");
        assert_eq!(out, vec![0.0, 0.5, 1.0]);
    }

    #[test]
    fn test_execute_normalize() {
        let op = AlgebraOp::Normalize {
            src_min: 0.0,
            src_max: 10.0,
            dst_min: 0.0,
            dst_max: 100.0,
        };
        let out = GpuAlgebra::execute(&[0.0, 5.0, 10.0], None, op, None).expect("execute failed");
        assert_eq!(out, vec![0.0, 50.0, 100.0]);
    }

    #[test]
    fn test_execute_normalize_degenerate_range_maps_to_dst_min() {
        let op = AlgebraOp::Normalize {
            src_min: 3.0,
            src_max: 3.0,
            dst_min: 7.0,
            dst_max: 9.0,
        };
        let out = GpuAlgebra::execute(&[1.0, 3.0], None, op, None).expect("execute failed");
        assert_eq!(out, vec![7.0, 7.0]);
    }

    #[test]
    fn test_evaluate_expression_ratio() {
        let red = [1.0_f32, 2.0];
        let nir = [3.0_f32, 2.0];
        let bands: [&[f32]; 2] = [&red, &nir];
        let ndvi = BandExpression::Div(
            bx(BandExpression::Sub(bx(BandExpression::Band(1)), bx(BandExpression::Band(0)))),
            bx(BandExpression::Add(bx(BandExpression::Band(1)), bx(BandExpression::Band(0)))),
        );
        let out = GpuAlgebra::evaluate_expression(&bands, &ndvi, None).expect("evaluate failed");
        assert_eq!(out, vec![0.5, 0.0]);
    }

    #[test]
    fn test_evaluate_expression_nodata() {
        let red = [1.0_f32, -9999.0];
        let nir = [3.0_f32, 3.0];
        let bands: [&[f32]; 2] = [&red, &nir];
        let ndvi = BandExpression::Div(
            bx(BandExpression::Sub(bx(BandExpression::Band(1)), bx(BandExpression::Band(0)))),
            bx(BandExpression::Add(bx(BandExpression::Band(1)), bx(BandExpression::Band(0)))),
        );
        let out = GpuAlgebra::evaluate_expression(&bands, &ndvi, Some(-9999.0))
            .expect("evaluate failed");
        assert_eq!(out, vec![0.5, -9999.0]);
    }

    #[test]
    fn test_evaluate_expression_no_bands() {
        let result = GpuAlgebra::evaluate_expression(&[], &BandExpression::Constant(1.0), None);
        assert!(result.is_err());
    }

    #[test]
    fn test_expression_unary_nodes() {
        let abs_neg = BandExpression::Abs(bx(BandExpression::Neg(bx(BandExpression::Constant(2.0)))));
        assert_eq!(abs_neg.evaluate(&[]).expect("evaluate failed"), 2.0);
        let sqrt_neg = BandExpression::Sqrt(bx(BandExpression::Constant(-1.0)));
        assert_eq!(sqrt_neg.evaluate(&[]).expect("evaluate failed"), 0.0);
    }

    #[test]
    fn test_expression_band_out_of_range() {
        let expr = BandExpression::Band(5);
        assert!(expr.evaluate(&[1.0, 2.0]).is_err());
    }

    #[test]
    fn test_expression_div_by_zero() {
        let expr = BandExpression::Div(
            Box::new(BandExpression::Band(0)),
            Box::new(BandExpression::Constant(0.0)),
        );
        assert!(expr.evaluate(&[1.0]).is_err());
    }
}
268}