Skip to main content

oxigdal_algorithms/dsl/
functions.rs

1//! Built-in functions for the Raster Algebra DSL
2//!
3//! This module provides a comprehensive library of built-in functions including:
4//! - Mathematical functions (sin, cos, sqrt, etc.)
5//! - Statistical functions (mean, median, percentile, etc.)
6//! - Spatial functions (focal operations)
7//! - Logical functions
8//! - Type conversion functions
9
10use super::variables::Value;
11use crate::error::{AlgorithmError, Result};
12use crate::raster::{gaussian_blur, median_filter};
13
14#[cfg(not(feature = "std"))]
15use alloc::{boxed::Box, string::String, vec::Vec};
16
17/// Built-in function type
18pub type BuiltinFn = fn(&[Value]) -> Result<Value>;
19
20/// Registry of built-in functions
21pub struct FunctionRegistry {
22    functions: Vec<(&'static str, BuiltinFn, usize)>, // (name, function, arity)
23}
24
25impl Default for FunctionRegistry {
26    fn default() -> Self {
27        Self::new()
28    }
29}
30
31impl FunctionRegistry {
32    /// Creates a new function registry with all built-in functions
33    pub fn new() -> Self {
34        let mut registry = Self {
35            functions: Vec::new(),
36        };
37
38        // Mathematical functions (1 arg)
39        registry.register("sqrt", fn_sqrt, 1);
40        registry.register("abs", fn_abs, 1);
41        registry.register("floor", fn_floor, 1);
42        registry.register("ceil", fn_ceil, 1);
43        registry.register("round", fn_round, 1);
44        registry.register("log", fn_log, 1);
45        registry.register("log10", fn_log10, 1);
46        registry.register("log2", fn_log2, 1);
47        registry.register("exp", fn_exp, 1);
48        registry.register("sin", fn_sin, 1);
49        registry.register("cos", fn_cos, 1);
50        registry.register("tan", fn_tan, 1);
51        registry.register("asin", fn_asin, 1);
52        registry.register("acos", fn_acos, 1);
53        registry.register("atan", fn_atan, 1);
54        registry.register("sinh", fn_sinh, 1);
55        registry.register("cosh", fn_cosh, 1);
56        registry.register("tanh", fn_tanh, 1);
57
58        // Mathematical functions (2 args)
59        registry.register("atan2", fn_atan2, 2);
60        registry.register("pow", fn_pow, 2);
61        registry.register("hypot", fn_hypot, 2);
62
63        // Min/Max (variable args)
64        registry.register("min", fn_min, 0);
65        registry.register("max", fn_max, 0);
66
67        // Statistical functions (1 arg - raster)
68        registry.register("mean", fn_mean, 1);
69        registry.register("median", fn_median, 1);
70        registry.register("mode", fn_mode, 1);
71        registry.register("stddev", fn_stddev, 1);
72        registry.register("variance", fn_variance, 1);
73        registry.register("sum", fn_sum, 1);
74        registry.register("product", fn_product, 1);
75
76        // Percentile functions
77        registry.register("percentile", fn_percentile, 2);
78
79        // Spatial filters
80        registry.register("gaussian", fn_gaussian, 2);
81        registry.register("median_filter", fn_median_filt, 2);
82
83        // Logical functions
84        registry.register("and", fn_and, 2);
85        registry.register("or", fn_or, 2);
86        registry.register("not", fn_not, 1);
87        registry.register("xor", fn_xor, 2);
88
89        // Comparison functions
90        registry.register("eq", fn_eq, 2);
91        registry.register("ne", fn_ne, 2);
92        registry.register("lt", fn_lt, 2);
93        registry.register("le", fn_le, 2);
94        registry.register("gt", fn_gt, 2);
95        registry.register("ge", fn_ge, 2);
96
97        // Type conversion
98        registry.register("to_number", fn_to_number, 1);
99        registry.register("to_bool", fn_to_bool, 1);
100
101        // Utility functions
102        registry.register("clamp", fn_clamp, 3);
103        registry.register("select", fn_select, 3);
104
105        registry
106    }
107
108    /// Registers a function
109    pub fn register(&mut self, name: &'static str, func: BuiltinFn, arity: usize) {
110        self.functions.push((name, func, arity));
111    }
112
113    /// Looks up a function by name
114    pub fn lookup(&self, name: &str) -> Option<(BuiltinFn, usize)> {
115        self.functions
116            .iter()
117            .find(|(n, _, _)| *n == name)
118            .map(|(_, f, a)| (*f, *a))
119    }
120
121    /// Checks if a function exists
122    pub fn exists(&self, name: &str) -> bool {
123        self.functions.iter().any(|(n, _, _)| *n == name)
124    }
125
126    /// Gets all function names
127    pub fn function_names(&self) -> Vec<&'static str> {
128        self.functions.iter().map(|(n, _, _)| *n).collect()
129    }
130}
131
132// Mathematical functions
133
134/// Helper to apply a unary function to either a scalar or raster
135fn apply_unary_fn<F>(value: &Value, f: F) -> Result<Value>
136where
137    F: Fn(f64) -> f64,
138{
139    match value {
140        Value::Number(x) => Ok(Value::Number(f(*x))),
141        Value::Raster(raster) => {
142            use oxigdal_core::types::RasterDataType;
143            let width = raster.width();
144            let height = raster.height();
145            let mut result =
146                oxigdal_core::buffer::RasterBuffer::zeros(width, height, RasterDataType::Float32);
147
148            for y in 0..height {
149                for x in 0..width {
150                    let pixel = raster
151                        .get_pixel(x, y)
152                        .map_err(crate::error::AlgorithmError::Core)?;
153                    let new_val = f(pixel);
154                    result
155                        .set_pixel(x, y, new_val)
156                        .map_err(crate::error::AlgorithmError::Core)?;
157                }
158            }
159
160            Ok(Value::Raster(Box::new(result)))
161        }
162        _ => Err(AlgorithmError::InvalidParameter {
163            parameter: "value",
164            message: "Expected number or raster".to_string(),
165        }),
166    }
167}
168
169/// Helper to apply a binary function to scalars or rasters
170fn apply_binary_fn<F>(left: &Value, right: &Value, f: F) -> Result<Value>
171where
172    F: Fn(f64, f64) -> f64,
173{
174    match (left, right) {
175        (Value::Number(l), Value::Number(r)) => Ok(Value::Number(f(*l, *r))),
176        (Value::Raster(raster), Value::Number(scalar))
177        | (Value::Number(scalar), Value::Raster(raster)) => {
178            use oxigdal_core::types::RasterDataType;
179            let width = raster.width();
180            let height = raster.height();
181            let mut result =
182                oxigdal_core::buffer::RasterBuffer::zeros(width, height, RasterDataType::Float32);
183
184            for y in 0..height {
185                for x in 0..width {
186                    let pixel = raster
187                        .get_pixel(x, y)
188                        .map_err(crate::error::AlgorithmError::Core)?;
189                    let new_val = f(pixel, *scalar);
190                    result
191                        .set_pixel(x, y, new_val)
192                        .map_err(crate::error::AlgorithmError::Core)?;
193                }
194            }
195
196            Ok(Value::Raster(Box::new(result)))
197        }
198        (Value::Raster(left_raster), Value::Raster(right_raster)) => {
199            use oxigdal_core::types::RasterDataType;
200            let width = left_raster.width();
201            let height = left_raster.height();
202
203            if right_raster.width() != width || right_raster.height() != height {
204                return Err(AlgorithmError::InvalidDimensions {
205                    message: "Rasters must have same dimensions",
206                    actual: right_raster.width() as usize,
207                    expected: width as usize,
208                });
209            }
210
211            let mut result =
212                oxigdal_core::buffer::RasterBuffer::zeros(width, height, RasterDataType::Float32);
213
214            for y in 0..height {
215                for x in 0..width {
216                    let left_pixel = left_raster
217                        .get_pixel(x, y)
218                        .map_err(crate::error::AlgorithmError::Core)?;
219                    let right_pixel = right_raster
220                        .get_pixel(x, y)
221                        .map_err(crate::error::AlgorithmError::Core)?;
222                    let new_val = f(left_pixel, right_pixel);
223                    result
224                        .set_pixel(x, y, new_val)
225                        .map_err(crate::error::AlgorithmError::Core)?;
226                }
227            }
228
229            Ok(Value::Raster(Box::new(result)))
230        }
231        _ => Err(AlgorithmError::InvalidParameter {
232            parameter: "value",
233            message: "Expected number or raster".to_string(),
234        }),
235    }
236}
237
238fn fn_sqrt(args: &[Value]) -> Result<Value> {
239    apply_unary_fn(&args[0], |x| x.sqrt())
240}
241
242fn fn_abs(args: &[Value]) -> Result<Value> {
243    apply_unary_fn(&args[0], |x| x.abs())
244}
245
246fn fn_floor(args: &[Value]) -> Result<Value> {
247    apply_unary_fn(&args[0], |x| x.floor())
248}
249
250fn fn_ceil(args: &[Value]) -> Result<Value> {
251    apply_unary_fn(&args[0], |x| x.ceil())
252}
253
254fn fn_round(args: &[Value]) -> Result<Value> {
255    apply_unary_fn(&args[0], |x| x.round())
256}
257
258fn fn_log(args: &[Value]) -> Result<Value> {
259    apply_unary_fn(&args[0], |x| x.ln())
260}
261
262fn fn_log10(args: &[Value]) -> Result<Value> {
263    apply_unary_fn(&args[0], |x| x.log10())
264}
265
266fn fn_log2(args: &[Value]) -> Result<Value> {
267    apply_unary_fn(&args[0], |x| x.log2())
268}
269
270fn fn_exp(args: &[Value]) -> Result<Value> {
271    apply_unary_fn(&args[0], |x| x.exp())
272}
273
274fn fn_sin(args: &[Value]) -> Result<Value> {
275    apply_unary_fn(&args[0], |x| x.sin())
276}
277
278fn fn_cos(args: &[Value]) -> Result<Value> {
279    apply_unary_fn(&args[0], |x| x.cos())
280}
281
282fn fn_tan(args: &[Value]) -> Result<Value> {
283    apply_unary_fn(&args[0], |x| x.tan())
284}
285
286fn fn_asin(args: &[Value]) -> Result<Value> {
287    apply_unary_fn(&args[0], |x| x.asin())
288}
289
290fn fn_acos(args: &[Value]) -> Result<Value> {
291    apply_unary_fn(&args[0], |x| x.acos())
292}
293
294fn fn_atan(args: &[Value]) -> Result<Value> {
295    apply_unary_fn(&args[0], |x| x.atan())
296}
297
298fn fn_sinh(args: &[Value]) -> Result<Value> {
299    apply_unary_fn(&args[0], |x| x.sinh())
300}
301
302fn fn_cosh(args: &[Value]) -> Result<Value> {
303    apply_unary_fn(&args[0], |x| x.cosh())
304}
305
306fn fn_tanh(args: &[Value]) -> Result<Value> {
307    apply_unary_fn(&args[0], |x| x.tanh())
308}
309
310fn fn_atan2(args: &[Value]) -> Result<Value> {
311    apply_binary_fn(&args[0], &args[1], |y, x| y.atan2(x))
312}
313
314fn fn_pow(args: &[Value]) -> Result<Value> {
315    apply_binary_fn(&args[0], &args[1], |base, exp| base.powf(exp))
316}
317
318fn fn_hypot(args: &[Value]) -> Result<Value> {
319    apply_binary_fn(&args[0], &args[1], |x, y| x.hypot(y))
320}
321
322fn fn_min(args: &[Value]) -> Result<Value> {
323    if args.is_empty() {
324        return Err(AlgorithmError::InvalidParameter {
325            parameter: "min",
326            message: "Expected at least 1 argument".to_string(),
327        });
328    }
329
330    let mut min_val = args[0].as_number()?;
331    for arg in &args[1..] {
332        let val = arg.as_number()?;
333        if val < min_val {
334            min_val = val;
335        }
336    }
337    Ok(Value::Number(min_val))
338}
339
340fn fn_max(args: &[Value]) -> Result<Value> {
341    if args.is_empty() {
342        return Err(AlgorithmError::InvalidParameter {
343            parameter: "max",
344            message: "Expected at least 1 argument".to_string(),
345        });
346    }
347
348    let mut max_val = args[0].as_number()?;
349    for arg in &args[1..] {
350        let val = arg.as_number()?;
351        if val > max_val {
352            max_val = val;
353        }
354    }
355    Ok(Value::Number(max_val))
356}
357
358// Statistical functions
359
360fn fn_mean(args: &[Value]) -> Result<Value> {
361    let raster = args[0].as_raster()?;
362    let mut sum = 0.0;
363    let mut count = 0u64;
364
365    for y in 0..raster.height() {
366        for x in 0..raster.width() {
367            if let Ok(val) = raster.get_pixel(x, y) {
368                if val.is_finite() {
369                    sum += val;
370                    count += 1;
371                }
372            }
373        }
374    }
375
376    if count == 0 {
377        return Err(AlgorithmError::EmptyInput { operation: "mean" });
378    }
379
380    Ok(Value::Number(sum / count as f64))
381}
382
383fn fn_median(_args: &[Value]) -> Result<Value> {
384    // Simplified version - full implementation would collect all values and sort
385    Err(AlgorithmError::InvalidParameter {
386        parameter: "median",
387        message: "Not yet implemented".to_string(),
388    })
389}
390
391fn fn_mode(_args: &[Value]) -> Result<Value> {
392    // Simplified version
393    Err(AlgorithmError::InvalidParameter {
394        parameter: "mode",
395        message: "Not yet implemented".to_string(),
396    })
397}
398
399fn fn_stddev(args: &[Value]) -> Result<Value> {
400    let raster = args[0].as_raster()?;
401    let mut sum = 0.0;
402    let mut sum_sq = 0.0;
403    let mut count = 0u64;
404
405    for y in 0..raster.height() {
406        for x in 0..raster.width() {
407            if let Ok(val) = raster.get_pixel(x, y) {
408                if val.is_finite() {
409                    sum += val;
410                    sum_sq += val * val;
411                    count += 1;
412                }
413            }
414        }
415    }
416
417    if count == 0 {
418        return Err(AlgorithmError::EmptyInput {
419            operation: "stddev",
420        });
421    }
422
423    let mean = sum / count as f64;
424    let variance = (sum_sq / count as f64) - (mean * mean);
425    Ok(Value::Number(variance.sqrt()))
426}
427
428fn fn_variance(args: &[Value]) -> Result<Value> {
429    let raster = args[0].as_raster()?;
430    let mut sum = 0.0;
431    let mut sum_sq = 0.0;
432    let mut count = 0u64;
433
434    for y in 0..raster.height() {
435        for x in 0..raster.width() {
436            if let Ok(val) = raster.get_pixel(x, y) {
437                if val.is_finite() {
438                    sum += val;
439                    sum_sq += val * val;
440                    count += 1;
441                }
442            }
443        }
444    }
445
446    if count == 0 {
447        return Err(AlgorithmError::EmptyInput {
448            operation: "variance",
449        });
450    }
451
452    let mean = sum / count as f64;
453    let variance = (sum_sq / count as f64) - (mean * mean);
454    Ok(Value::Number(variance))
455}
456
457fn fn_sum(args: &[Value]) -> Result<Value> {
458    let raster = args[0].as_raster()?;
459    let mut sum = 0.0;
460
461    for y in 0..raster.height() {
462        for x in 0..raster.width() {
463            if let Ok(val) = raster.get_pixel(x, y) {
464                if val.is_finite() {
465                    sum += val;
466                }
467            }
468        }
469    }
470
471    Ok(Value::Number(sum))
472}
473
474fn fn_product(args: &[Value]) -> Result<Value> {
475    let raster = args[0].as_raster()?;
476    let mut product = 1.0;
477
478    for y in 0..raster.height() {
479        for x in 0..raster.width() {
480            if let Ok(val) = raster.get_pixel(x, y) {
481                if val.is_finite() {
482                    product *= val;
483                }
484            }
485        }
486    }
487
488    Ok(Value::Number(product))
489}
490
491fn fn_percentile(_args: &[Value]) -> Result<Value> {
492    // Simplified version
493    Err(AlgorithmError::InvalidParameter {
494        parameter: "percentile",
495        message: "Not yet implemented".to_string(),
496    })
497}
498
499// Spatial filters
500
501fn fn_gaussian(args: &[Value]) -> Result<Value> {
502    let raster = args[0].as_raster()?;
503    let sigma = args[1].as_number()?;
504
505    let result = gaussian_blur(raster, sigma, None)?;
506    Ok(Value::Raster(Box::new(result)))
507}
508
509fn fn_median_filt(args: &[Value]) -> Result<Value> {
510    let raster = args[0].as_raster()?;
511    let radius = args[1].as_number()? as usize;
512
513    let result = median_filter(raster, radius)?;
514    Ok(Value::Raster(Box::new(result)))
515}
516
517// Logical functions
518
519fn fn_and(args: &[Value]) -> Result<Value> {
520    let a = args[0].as_bool()?;
521    let b = args[1].as_bool()?;
522    Ok(Value::Bool(a && b))
523}
524
525fn fn_or(args: &[Value]) -> Result<Value> {
526    let a = args[0].as_bool()?;
527    let b = args[1].as_bool()?;
528    Ok(Value::Bool(a || b))
529}
530
531fn fn_not(args: &[Value]) -> Result<Value> {
532    let a = args[0].as_bool()?;
533    Ok(Value::Bool(!a))
534}
535
536fn fn_xor(args: &[Value]) -> Result<Value> {
537    let a = args[0].as_bool()?;
538    let b = args[1].as_bool()?;
539    Ok(Value::Bool(a ^ b))
540}
541
542// Comparison functions
543
544fn fn_eq(args: &[Value]) -> Result<Value> {
545    let a = args[0].as_number()?;
546    let b = args[1].as_number()?;
547    Ok(Value::Bool((a - b).abs() < f64::EPSILON))
548}
549
550fn fn_ne(args: &[Value]) -> Result<Value> {
551    let a = args[0].as_number()?;
552    let b = args[1].as_number()?;
553    Ok(Value::Bool((a - b).abs() >= f64::EPSILON))
554}
555
556fn fn_lt(args: &[Value]) -> Result<Value> {
557    let a = args[0].as_number()?;
558    let b = args[1].as_number()?;
559    Ok(Value::Bool(a < b))
560}
561
562fn fn_le(args: &[Value]) -> Result<Value> {
563    let a = args[0].as_number()?;
564    let b = args[1].as_number()?;
565    Ok(Value::Bool(a <= b))
566}
567
568fn fn_gt(args: &[Value]) -> Result<Value> {
569    let a = args[0].as_number()?;
570    let b = args[1].as_number()?;
571    Ok(Value::Bool(a > b))
572}
573
574fn fn_ge(args: &[Value]) -> Result<Value> {
575    let a = args[0].as_number()?;
576    let b = args[1].as_number()?;
577    Ok(Value::Bool(a >= b))
578}
579
580// Type conversion
581
582fn fn_to_number(args: &[Value]) -> Result<Value> {
583    args[0].as_number().map(Value::Number)
584}
585
586fn fn_to_bool(args: &[Value]) -> Result<Value> {
587    args[0].as_bool().map(Value::Bool)
588}
589
590// Utility functions
591
592fn fn_clamp(args: &[Value]) -> Result<Value> {
593    let value = args[0].as_number()?;
594    let min = args[1].as_number()?;
595    let max = args[2].as_number()?;
596
597    let clamped = if value < min {
598        min
599    } else if value > max {
600        max
601    } else {
602        value
603    };
604
605    Ok(Value::Number(clamped))
606}
607
608fn fn_select(args: &[Value]) -> Result<Value> {
609    let cond = args[0].as_bool()?;
610    if cond {
611        Ok(args[1].clone())
612    } else {
613        Ok(args[2].clone())
614    }
615}
616
#[cfg(test)]
// Tests report an unexpected `Value` variant via `panic!` with a clear message.
#[allow(clippy::panic)]
mod tests {
    use super::*;
    use oxigdal_core::buffer::RasterBuffer;
    use oxigdal_core::types::RasterDataType;

    /// Unwraps a `Value::Number`, failing the test for any other variant.
    fn expect_number(value: Value) -> f64 {
        match value {
            Value::Number(n) => n,
            _ => panic!("Expected number"),
        }
    }

    /// 10x10 raster whose pixel at `(x, y)` holds `x + y` (values 0..=18).
    fn ramp_raster() -> RasterBuffer {
        let mut raster = RasterBuffer::zeros(10, 10, RasterDataType::Float32);
        for y in 0..10 {
            for x in 0..10 {
                raster
                    .set_pixel(x, y, (x + y) as f64)
                    .expect("in-bounds write should succeed");
            }
        }
        raster
    }

    #[test]
    fn test_function_registry() {
        let registry = FunctionRegistry::new();
        assert!(registry.exists("sqrt"));
        assert!(registry.exists("sin"));
        assert!(registry.exists("mean"));
        assert!(!registry.exists("nonexistent"));

        assert_eq!(registry.lookup("pow").map(|(_, arity)| arity), Some(2));
        assert_eq!(registry.lookup("clamp").map(|(_, arity)| arity), Some(3));
        // Variadic functions are registered with arity 0.
        assert_eq!(registry.lookup("min").map(|(_, arity)| arity), Some(0));
        assert!(registry.lookup("nonexistent").is_none());
    }

    #[test]
    fn test_math_functions() {
        let args = vec![Value::Number(16.0)];
        let n = expect_number(fn_sqrt(&args).expect("Should work"));
        assert!((n - 4.0).abs() < 1e-10);
    }

    #[test]
    fn test_min_max() {
        let args = vec![
            Value::Number(3.0),
            Value::Number(1.0),
            Value::Number(4.0),
            Value::Number(1.0),
            Value::Number(5.0),
        ];

        let min = expect_number(fn_min(&args).expect("Should work"));
        assert!((min - 1.0).abs() < 1e-10);

        let max = expect_number(fn_max(&args).expect("Should work"));
        assert!((max - 5.0).abs() < 1e-10);

        assert!(fn_min(&[]).is_err());
        assert!(fn_max(&[]).is_err());
    }

    #[test]
    fn test_mean() {
        // Sum of (x + y) over a 10x10 grid is 900, so the mean is 9.
        let args = vec![Value::Raster(Box::new(ramp_raster()))];
        let mean = expect_number(fn_mean(&args).expect("Should work"));
        assert!((mean - 9.0).abs() < 1e-10);
    }

    #[test]
    fn test_sum() {
        let args = vec![Value::Raster(Box::new(ramp_raster()))];
        let sum = expect_number(fn_sum(&args).expect("Should work"));
        assert!((sum - 900.0).abs() < 1e-10);
    }

    #[test]
    fn test_logical_functions() {
        let args_true = vec![Value::Bool(true), Value::Bool(true)];
        let result = fn_and(&args_true).expect("Should work");
        assert!(matches!(result, Value::Bool(true)));

        let args_false = vec![Value::Bool(true), Value::Bool(false)];
        let result = fn_and(&args_false).expect("Should work");
        assert!(matches!(result, Value::Bool(false)));

        let result = fn_xor(&args_false).expect("Should work");
        assert!(matches!(result, Value::Bool(true)));
    }

    #[test]
    fn test_clamp() {
        let clamp = |v: f64| {
            let args = vec![Value::Number(v), Value::Number(0.0), Value::Number(10.0)];
            expect_number(fn_clamp(&args).expect("Should work"))
        };
        assert!((clamp(15.0) - 10.0).abs() < 1e-10);
        assert!((clamp(-5.0) - 0.0).abs() < 1e-10);
        assert!((clamp(5.0) - 5.0).abs() < 1e-10);
    }
}