1use crate::error::GpuError;
8
/// Per-pixel operation applied by [`GpuAlgebra::execute`].
///
/// Binary variants combine a pixel `a` from `band_a` with the matching pixel `b` from
/// `band_b` (or with `0.0` when `execute` is called without `band_b`); unary variants
/// only transform `a`.
#[derive(Debug, Clone, PartialEq)]
pub enum AlgebraOp {
    /// `a + b`.
    Add,
    /// `a - b`.
    Subtract,
    /// `a * b`.
    Multiply,
    /// `a / b`; divisors whose magnitude is `<= 1e-10` yield the nodata value instead.
    Divide,
    /// The smaller of `a` and `b` (`f32::min` semantics).
    Min,
    /// The larger of `a` and `b` (`f32::max` semantics).
    Max,
    /// Square root of `a`, with negative inputs clamped to `0.0` first.
    Sqrt,
    /// Absolute value of `a`.
    Abs,
    /// `a` raised to the given exponent via `f32::powf`.
    Power(f32),
    /// `a` clamped into `[min, max]`.
    ///
    /// `min` must not exceed `max`, and neither bound may be NaN.
    Clamp {
        /// Lower bound (inclusive).
        min: f32,
        /// Upper bound (inclusive).
        max: f32,
    },
    /// Linear remap of `a` from `[src_min, src_max]` onto `[dst_min, dst_max]`.
    ///
    /// When `src_max - src_min` is within `1e-10` of zero, every pixel maps to `dst_min`.
    Normalize {
        /// Start of the input range.
        src_min: f32,
        /// End of the input range.
        src_max: f32,
        /// Value that `src_min` maps to.
        dst_min: f32,
        /// Value that `src_max` maps to.
        dst_max: f32,
    },
}
40
41pub struct GpuAlgebra;
43
44impl GpuAlgebra {
45 pub fn execute(
55 band_a: &[f32],
56 band_b: Option<&[f32]>,
57 op: AlgebraOp,
58 nodata: Option<f32>,
59 ) -> Result<Vec<f32>, GpuError> {
60 if band_a.is_empty() {
61 return Err(GpuError::invalid_kernel_params("band_a must not be empty"));
62 }
63
64 let nodata_val = nodata.unwrap_or(f32::NAN);
65 let mut output = Vec::with_capacity(band_a.len());
66
67 for (i, &a) in band_a.iter().enumerate() {
68 if nodata.is_some() && Self::is_nodata(a, nodata_val) {
70 output.push(nodata_val);
71 continue;
72 }
73
74 let b = band_b.and_then(|bb| bb.get(i)).copied().unwrap_or(0.0_f32);
75
76 if nodata.is_some() && band_b.is_some() && Self::is_nodata(b, nodata_val) {
78 output.push(nodata_val);
79 continue;
80 }
81
82 let result = match &op {
83 AlgebraOp::Add => a + b,
84 AlgebraOp::Subtract => a - b,
85 AlgebraOp::Multiply => a * b,
86 AlgebraOp::Divide => {
87 if b.abs() > 1e-10 {
88 a / b
89 } else {
90 nodata_val
91 }
92 }
93 AlgebraOp::Min => a.min(b),
94 AlgebraOp::Max => a.max(b),
95 AlgebraOp::Sqrt => a.max(0.0).sqrt(),
96 AlgebraOp::Abs => a.abs(),
97 AlgebraOp::Power(exp) => a.powf(*exp),
98 AlgebraOp::Clamp { min, max } => a.clamp(*min, *max),
99 AlgebraOp::Normalize {
100 src_min,
101 src_max,
102 dst_min,
103 dst_max,
104 } => {
105 let range = src_max - src_min;
106 if range.abs() < 1e-10 {
107 *dst_min
108 } else {
109 (a - src_min) / range * (dst_max - dst_min) + dst_min
110 }
111 }
112 };
113
114 output.push(result);
115 }
116
117 Ok(output)
118 }
119
120 pub fn evaluate_expression(
131 bands: &[&[f32]],
132 expression: &BandExpression,
133 nodata: Option<f32>,
134 ) -> Result<Vec<f32>, GpuError> {
135 if bands.is_empty() {
136 return Err(GpuError::invalid_kernel_params("no bands provided"));
137 }
138
139 let len = bands[0].len();
140 let nodata_val = nodata.unwrap_or(f32::NAN);
141
142 let mut output = Vec::with_capacity(len);
143 for i in 0..len {
144 let has_nodata = nodata.is_some()
146 && bands.iter().any(|b| {
147 b.get(i)
148 .map(|v| Self::is_nodata(*v, nodata_val))
149 .unwrap_or(false)
150 });
151
152 if has_nodata {
153 output.push(nodata_val);
154 continue;
155 }
156
157 let vals: Vec<f32> = bands
158 .iter()
159 .map(|b| b.get(i).copied().unwrap_or(0.0))
160 .collect();
161 output.push(expression.evaluate(&vals)?);
162 }
163
164 Ok(output)
165 }
166
167 #[inline]
168 fn is_nodata(value: f32, nodata: f32) -> bool {
169 (value - nodata).abs() < 1e-6
170 }
171}
172
/// Arithmetic expression tree evaluated per pixel over a set of band values.
///
/// Each node is evaluated by [`BandExpression::evaluate`] against a slice holding one value
/// per input band.
#[derive(Debug, Clone, PartialEq)]
pub enum BandExpression {
    /// Value of the band at this index; an out-of-range index is an evaluation error.
    Band(usize),
    /// A literal value.
    Constant(f32),
    /// Sum of the two sub-expressions.
    Add(Box<BandExpression>, Box<BandExpression>),
    /// First sub-expression minus the second.
    Sub(Box<BandExpression>, Box<BandExpression>),
    /// Product of the two sub-expressions.
    Mul(Box<BandExpression>, Box<BandExpression>),
    /// First sub-expression divided by the second; a denominator whose magnitude is below
    /// `1e-10` is an evaluation error.
    Div(Box<BandExpression>, Box<BandExpression>),
    /// Square root of the sub-expression, with negative values clamped to `0.0` first.
    Sqrt(Box<BandExpression>),
    /// Absolute value of the sub-expression.
    Abs(Box<BandExpression>),
    /// Negation of the sub-expression.
    Neg(Box<BandExpression>),
}
198
199impl BandExpression {
200 pub fn evaluate(&self, bands: &[f32]) -> Result<f32, GpuError> {
207 match self {
208 BandExpression::Band(idx) => bands.get(*idx).copied().ok_or_else(|| {
209 GpuError::invalid_kernel_params(format!(
210 "band index {} out of range (have {} bands)",
211 idx,
212 bands.len()
213 ))
214 }),
215 BandExpression::Constant(v) => Ok(*v),
216 BandExpression::Add(a, b) => Ok(a.evaluate(bands)? + b.evaluate(bands)?),
217 BandExpression::Sub(a, b) => Ok(a.evaluate(bands)? - b.evaluate(bands)?),
218 BandExpression::Mul(a, b) => Ok(a.evaluate(bands)? * b.evaluate(bands)?),
219 BandExpression::Div(a, b) => {
220 let denom = b.evaluate(bands)?;
221 if denom.abs() < 1e-10 {
222 Err(GpuError::invalid_kernel_params(
223 "division by zero in BandExpression",
224 ))
225 } else {
226 Ok(a.evaluate(bands)? / denom)
227 }
228 }
229 BandExpression::Sqrt(a) => Ok(a.evaluate(bands)?.max(0.0).sqrt()),
230 BandExpression::Abs(a) => Ok(a.evaluate(bands)?.abs()),
231 BandExpression::Neg(a) => Ok(-a.evaluate(bands)?),
232 }
233 }
234}
235
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_execute_empty_band_a() {
        let result = GpuAlgebra::execute(&[], None, AlgebraOp::Add, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_execute_add() {
        let a = vec![1.0_f32, 2.0, 3.0];
        let b = vec![4.0_f32, 5.0, 6.0];
        let out = GpuAlgebra::execute(&a, Some(&b), AlgebraOp::Add, None).expect("execute failed");
        assert_eq!(out, vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_execute_nodata_in_either_band_is_propagated() {
        let a = vec![1.0_f32, -9999.0, 3.0];
        let b = vec![10.0_f32, 20.0, -9999.0];
        let out = GpuAlgebra::execute(&a, Some(&b), AlgebraOp::Add, Some(-9999.0))
            .expect("execute failed");
        assert_eq!(out, vec![11.0, -9999.0, -9999.0]);
    }

    #[test]
    fn test_execute_divide_by_zero_yields_nodata() {
        let a = vec![6.0_f32, 1.0];
        let b = vec![3.0_f32, 0.0];

        let out = GpuAlgebra::execute(&a, Some(&b), AlgebraOp::Divide, Some(-1.0))
            .expect("execute failed");
        assert_eq!(out, vec![2.0, -1.0]);

        // Without an explicit nodata value, undefined divisions become NaN.
        let out =
            GpuAlgebra::execute(&a, Some(&b), AlgebraOp::Divide, None).expect("execute failed");
        assert_eq!(out[0], 2.0);
        assert!(out[1].is_nan());
    }

    #[test]
    fn test_execute_unary_ops_without_band_b() {
        let a = vec![-4.0_f32, 9.0];
        let out = GpuAlgebra::execute(&a, None, AlgebraOp::Sqrt, None).expect("execute failed");
        assert_eq!(out, vec![0.0, 3.0]);
        let out = GpuAlgebra::execute(&a, None, AlgebraOp::Abs, None).expect("execute failed");
        assert_eq!(out, vec![4.0, 9.0]);
    }

    #[test]
    fn test_execute_clamp() {
        let a = vec![-1.0_f32, 0.5, 2.0];
        let op = AlgebraOp::Clamp { min: 0.0, max: 1.0 };
        let out = GpuAlgebra::execute(&a, None, op, None).expect("execute failed");
        assert_eq!(out, vec![0.0, 0.5, 1.0]);
    }

    #[test]
    fn test_execute_normalize() {
        let a = vec![0.0_f32, 5.0, 10.0];
        let op = AlgebraOp::Normalize {
            src_min: 0.0,
            src_max: 10.0,
            dst_min: 0.0,
            dst_max: 100.0,
        };
        let out = GpuAlgebra::execute(&a, None, op, None).expect("execute failed");
        assert_eq!(out, vec![0.0, 50.0, 100.0]);
    }

    #[test]
    fn test_execute_normalize_degenerate_range() {
        let op = AlgebraOp::Normalize {
            src_min: 5.0,
            src_max: 5.0,
            dst_min: 7.0,
            dst_max: 9.0,
        };
        let out = GpuAlgebra::execute(&[1.0_f32, 2.0], None, op, None).expect("execute failed");
        assert_eq!(out, vec![7.0, 7.0]);
    }

    #[test]
    fn test_evaluate_expression_normalized_difference() {
        let nir = [3.0_f32, 2.0];
        let red = [1.0_f32, 2.0];
        let bands: [&[f32]; 2] = [&nir, &red];
        let expr = BandExpression::Div(
            Box::new(BandExpression::Sub(
                Box::new(BandExpression::Band(0)),
                Box::new(BandExpression::Band(1)),
            )),
            Box::new(BandExpression::Add(
                Box::new(BandExpression::Band(0)),
                Box::new(BandExpression::Band(1)),
            )),
        );
        let out =
            GpuAlgebra::evaluate_expression(&bands, &expr, None).expect("evaluation failed");
        assert_eq!(out, vec![0.5, 0.0]);
    }

    #[test]
    fn test_evaluate_expression_nodata_in_any_band() {
        let b0 = [1.0_f32, -1.0];
        let b1 = [2.0_f32, 3.0];
        let bands: [&[f32]; 2] = [&b0, &b1];
        let expr = BandExpression::Add(
            Box::new(BandExpression::Band(0)),
            Box::new(BandExpression::Band(1)),
        );
        let out = GpuAlgebra::evaluate_expression(&bands, &expr, Some(-1.0))
            .expect("evaluation failed");
        assert_eq!(out, vec![3.0, -1.0]);
    }

    #[test]
    fn test_evaluate_expression_no_bands() {
        let result = GpuAlgebra::evaluate_expression(&[], &BandExpression::Constant(1.0), None);
        assert!(result.is_err());
    }

    #[test]
    fn test_expression_band_out_of_range() {
        let expr = BandExpression::Band(5);
        assert!(expr.evaluate(&[1.0, 2.0]).is_err());
    }

    #[test]
    fn test_expression_div_by_zero() {
        let expr = BandExpression::Div(
            Box::new(BandExpression::Band(0)),
            Box::new(BandExpression::Constant(0.0)),
        );
        assert!(expr.evaluate(&[1.0]).is_err());
    }

    #[test]
    fn test_expression_unary_nodes() {
        let neg = BandExpression::Neg(Box::new(BandExpression::Constant(2.0)));
        assert_eq!(neg.evaluate(&[]).expect("neg failed"), -2.0);

        let root = BandExpression::Sqrt(Box::new(BandExpression::Band(0)));
        assert_eq!(root.evaluate(&[-4.0]).expect("sqrt failed"), 0.0);
        assert_eq!(root.evaluate(&[16.0]).expect("sqrt failed"), 4.0);

        let abs = BandExpression::Abs(Box::new(BandExpression::Band(0)));
        assert_eq!(abs.evaluate(&[-3.5]).expect("abs failed"), 3.5);
    }
}