1use super::variables::Value;
11use crate::error::{AlgorithmError, Result};
12use crate::raster::{gaussian_blur, median_filter};
13
14#[cfg(not(feature = "std"))]
15use alloc::{boxed::Box, string::String, vec::Vec};
16
17pub type BuiltinFn = fn(&[Value]) -> Result<Value>;
19
20pub struct FunctionRegistry {
22 functions: Vec<(&'static str, BuiltinFn, usize)>, }
24
25impl Default for FunctionRegistry {
26 fn default() -> Self {
27 Self::new()
28 }
29}
30
31impl FunctionRegistry {
32 pub fn new() -> Self {
34 let mut registry = Self {
35 functions: Vec::new(),
36 };
37
38 registry.register("sqrt", fn_sqrt, 1);
40 registry.register("abs", fn_abs, 1);
41 registry.register("floor", fn_floor, 1);
42 registry.register("ceil", fn_ceil, 1);
43 registry.register("round", fn_round, 1);
44 registry.register("log", fn_log, 1);
45 registry.register("log10", fn_log10, 1);
46 registry.register("log2", fn_log2, 1);
47 registry.register("exp", fn_exp, 1);
48 registry.register("sin", fn_sin, 1);
49 registry.register("cos", fn_cos, 1);
50 registry.register("tan", fn_tan, 1);
51 registry.register("asin", fn_asin, 1);
52 registry.register("acos", fn_acos, 1);
53 registry.register("atan", fn_atan, 1);
54 registry.register("sinh", fn_sinh, 1);
55 registry.register("cosh", fn_cosh, 1);
56 registry.register("tanh", fn_tanh, 1);
57
58 registry.register("atan2", fn_atan2, 2);
60 registry.register("pow", fn_pow, 2);
61 registry.register("hypot", fn_hypot, 2);
62
63 registry.register("min", fn_min, 0);
65 registry.register("max", fn_max, 0);
66
67 registry.register("mean", fn_mean, 1);
69 registry.register("median", fn_median, 1);
70 registry.register("mode", fn_mode, 1);
71 registry.register("stddev", fn_stddev, 1);
72 registry.register("variance", fn_variance, 1);
73 registry.register("sum", fn_sum, 1);
74 registry.register("product", fn_product, 1);
75
76 registry.register("percentile", fn_percentile, 2);
78
79 registry.register("gaussian", fn_gaussian, 2);
81 registry.register("median_filter", fn_median_filt, 2);
82
83 registry.register("and", fn_and, 2);
85 registry.register("or", fn_or, 2);
86 registry.register("not", fn_not, 1);
87 registry.register("xor", fn_xor, 2);
88
89 registry.register("eq", fn_eq, 2);
91 registry.register("ne", fn_ne, 2);
92 registry.register("lt", fn_lt, 2);
93 registry.register("le", fn_le, 2);
94 registry.register("gt", fn_gt, 2);
95 registry.register("ge", fn_ge, 2);
96
97 registry.register("to_number", fn_to_number, 1);
99 registry.register("to_bool", fn_to_bool, 1);
100
101 registry.register("clamp", fn_clamp, 3);
103 registry.register("select", fn_select, 3);
104
105 registry
106 }
107
108 pub fn register(&mut self, name: &'static str, func: BuiltinFn, arity: usize) {
110 self.functions.push((name, func, arity));
111 }
112
113 pub fn lookup(&self, name: &str) -> Option<(BuiltinFn, usize)> {
115 self.functions
116 .iter()
117 .find(|(n, _, _)| *n == name)
118 .map(|(_, f, a)| (*f, *a))
119 }
120
121 pub fn exists(&self, name: &str) -> bool {
123 self.functions.iter().any(|(n, _, _)| *n == name)
124 }
125
126 pub fn function_names(&self) -> Vec<&'static str> {
128 self.functions.iter().map(|(n, _, _)| *n).collect()
129 }
130}
131
132fn apply_unary_fn<F>(value: &Value, f: F) -> Result<Value>
136where
137 F: Fn(f64) -> f64,
138{
139 match value {
140 Value::Number(x) => Ok(Value::Number(f(*x))),
141 Value::Raster(raster) => {
142 use oxigdal_core::types::RasterDataType;
143 let width = raster.width();
144 let height = raster.height();
145 let mut result =
146 oxigdal_core::buffer::RasterBuffer::zeros(width, height, RasterDataType::Float32);
147
148 for y in 0..height {
149 for x in 0..width {
150 let pixel = raster
151 .get_pixel(x, y)
152 .map_err(crate::error::AlgorithmError::Core)?;
153 let new_val = f(pixel);
154 result
155 .set_pixel(x, y, new_val)
156 .map_err(crate::error::AlgorithmError::Core)?;
157 }
158 }
159
160 Ok(Value::Raster(Box::new(result)))
161 }
162 _ => Err(AlgorithmError::InvalidParameter {
163 parameter: "value",
164 message: "Expected number or raster".to_string(),
165 }),
166 }
167}
168
169fn apply_binary_fn<F>(left: &Value, right: &Value, f: F) -> Result<Value>
171where
172 F: Fn(f64, f64) -> f64,
173{
174 match (left, right) {
175 (Value::Number(l), Value::Number(r)) => Ok(Value::Number(f(*l, *r))),
176 (Value::Raster(raster), Value::Number(scalar))
177 | (Value::Number(scalar), Value::Raster(raster)) => {
178 use oxigdal_core::types::RasterDataType;
179 let width = raster.width();
180 let height = raster.height();
181 let mut result =
182 oxigdal_core::buffer::RasterBuffer::zeros(width, height, RasterDataType::Float32);
183
184 for y in 0..height {
185 for x in 0..width {
186 let pixel = raster
187 .get_pixel(x, y)
188 .map_err(crate::error::AlgorithmError::Core)?;
189 let new_val = f(pixel, *scalar);
190 result
191 .set_pixel(x, y, new_val)
192 .map_err(crate::error::AlgorithmError::Core)?;
193 }
194 }
195
196 Ok(Value::Raster(Box::new(result)))
197 }
198 (Value::Raster(left_raster), Value::Raster(right_raster)) => {
199 use oxigdal_core::types::RasterDataType;
200 let width = left_raster.width();
201 let height = left_raster.height();
202
203 if right_raster.width() != width || right_raster.height() != height {
204 return Err(AlgorithmError::InvalidDimensions {
205 message: "Rasters must have same dimensions",
206 actual: right_raster.width() as usize,
207 expected: width as usize,
208 });
209 }
210
211 let mut result =
212 oxigdal_core::buffer::RasterBuffer::zeros(width, height, RasterDataType::Float32);
213
214 for y in 0..height {
215 for x in 0..width {
216 let left_pixel = left_raster
217 .get_pixel(x, y)
218 .map_err(crate::error::AlgorithmError::Core)?;
219 let right_pixel = right_raster
220 .get_pixel(x, y)
221 .map_err(crate::error::AlgorithmError::Core)?;
222 let new_val = f(left_pixel, right_pixel);
223 result
224 .set_pixel(x, y, new_val)
225 .map_err(crate::error::AlgorithmError::Core)?;
226 }
227 }
228
229 Ok(Value::Raster(Box::new(result)))
230 }
231 _ => Err(AlgorithmError::InvalidParameter {
232 parameter: "value",
233 message: "Expected number or raster".to_string(),
234 }),
235 }
236}
237
238fn fn_sqrt(args: &[Value]) -> Result<Value> {
239 apply_unary_fn(&args[0], |x| x.sqrt())
240}
241
242fn fn_abs(args: &[Value]) -> Result<Value> {
243 apply_unary_fn(&args[0], |x| x.abs())
244}
245
246fn fn_floor(args: &[Value]) -> Result<Value> {
247 apply_unary_fn(&args[0], |x| x.floor())
248}
249
250fn fn_ceil(args: &[Value]) -> Result<Value> {
251 apply_unary_fn(&args[0], |x| x.ceil())
252}
253
254fn fn_round(args: &[Value]) -> Result<Value> {
255 apply_unary_fn(&args[0], |x| x.round())
256}
257
258fn fn_log(args: &[Value]) -> Result<Value> {
259 apply_unary_fn(&args[0], |x| x.ln())
260}
261
262fn fn_log10(args: &[Value]) -> Result<Value> {
263 apply_unary_fn(&args[0], |x| x.log10())
264}
265
266fn fn_log2(args: &[Value]) -> Result<Value> {
267 apply_unary_fn(&args[0], |x| x.log2())
268}
269
270fn fn_exp(args: &[Value]) -> Result<Value> {
271 apply_unary_fn(&args[0], |x| x.exp())
272}
273
274fn fn_sin(args: &[Value]) -> Result<Value> {
275 apply_unary_fn(&args[0], |x| x.sin())
276}
277
278fn fn_cos(args: &[Value]) -> Result<Value> {
279 apply_unary_fn(&args[0], |x| x.cos())
280}
281
282fn fn_tan(args: &[Value]) -> Result<Value> {
283 apply_unary_fn(&args[0], |x| x.tan())
284}
285
286fn fn_asin(args: &[Value]) -> Result<Value> {
287 apply_unary_fn(&args[0], |x| x.asin())
288}
289
290fn fn_acos(args: &[Value]) -> Result<Value> {
291 apply_unary_fn(&args[0], |x| x.acos())
292}
293
294fn fn_atan(args: &[Value]) -> Result<Value> {
295 apply_unary_fn(&args[0], |x| x.atan())
296}
297
298fn fn_sinh(args: &[Value]) -> Result<Value> {
299 apply_unary_fn(&args[0], |x| x.sinh())
300}
301
302fn fn_cosh(args: &[Value]) -> Result<Value> {
303 apply_unary_fn(&args[0], |x| x.cosh())
304}
305
306fn fn_tanh(args: &[Value]) -> Result<Value> {
307 apply_unary_fn(&args[0], |x| x.tanh())
308}
309
310fn fn_atan2(args: &[Value]) -> Result<Value> {
311 apply_binary_fn(&args[0], &args[1], |y, x| y.atan2(x))
312}
313
314fn fn_pow(args: &[Value]) -> Result<Value> {
315 apply_binary_fn(&args[0], &args[1], |base, exp| base.powf(exp))
316}
317
318fn fn_hypot(args: &[Value]) -> Result<Value> {
319 apply_binary_fn(&args[0], &args[1], |x, y| x.hypot(y))
320}
321
322fn fn_min(args: &[Value]) -> Result<Value> {
323 if args.is_empty() {
324 return Err(AlgorithmError::InvalidParameter {
325 parameter: "min",
326 message: "Expected at least 1 argument".to_string(),
327 });
328 }
329
330 let mut min_val = args[0].as_number()?;
331 for arg in &args[1..] {
332 let val = arg.as_number()?;
333 if val < min_val {
334 min_val = val;
335 }
336 }
337 Ok(Value::Number(min_val))
338}
339
340fn fn_max(args: &[Value]) -> Result<Value> {
341 if args.is_empty() {
342 return Err(AlgorithmError::InvalidParameter {
343 parameter: "max",
344 message: "Expected at least 1 argument".to_string(),
345 });
346 }
347
348 let mut max_val = args[0].as_number()?;
349 for arg in &args[1..] {
350 let val = arg.as_number()?;
351 if val > max_val {
352 max_val = val;
353 }
354 }
355 Ok(Value::Number(max_val))
356}
357
358fn fn_mean(args: &[Value]) -> Result<Value> {
361 let raster = args[0].as_raster()?;
362 let mut sum = 0.0;
363 let mut count = 0u64;
364
365 for y in 0..raster.height() {
366 for x in 0..raster.width() {
367 if let Ok(val) = raster.get_pixel(x, y) {
368 if val.is_finite() {
369 sum += val;
370 count += 1;
371 }
372 }
373 }
374 }
375
376 if count == 0 {
377 return Err(AlgorithmError::EmptyInput { operation: "mean" });
378 }
379
380 Ok(Value::Number(sum / count as f64))
381}
382
383fn fn_median(args: &[Value]) -> Result<Value> {
384 let raster = args[0].as_raster()?;
385 let mut values: Vec<f64> = Vec::with_capacity((raster.width() * raster.height()) as usize);
386
387 for y in 0..raster.height() {
388 for x in 0..raster.width() {
389 if let Ok(val) = raster.get_pixel(x, y) {
390 if val.is_finite() {
391 values.push(val);
392 }
393 }
394 }
395 }
396
397 if values.is_empty() {
398 return Err(AlgorithmError::EmptyInput {
399 operation: "median",
400 });
401 }
402
403 values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
405
406 let mid = values.len() / 2;
407 let median = if values.len() % 2 == 0 {
408 (values[mid - 1] + values[mid]) / 2.0
409 } else {
410 values[mid]
411 };
412
413 Ok(Value::Number(median))
414}
415
416fn fn_mode(args: &[Value]) -> Result<Value> {
417 let raster = args[0].as_raster()?;
418
419 use std::collections::HashMap;
422 let mut freq: HashMap<u64, (f64, u64)> = HashMap::new();
423
424 for y in 0..raster.height() {
425 for x in 0..raster.width() {
426 if let Ok(val) = raster.get_pixel(x, y) {
427 if val.is_finite() {
428 let key = val.to_bits();
429 let entry = freq.entry(key).or_insert((val, 0));
430 entry.1 += 1;
431 }
432 }
433 }
434 }
435
436 if freq.is_empty() {
437 return Err(AlgorithmError::EmptyInput { operation: "mode" });
438 }
439
440 let mode = freq
442 .values()
443 .max_by(|a, b| {
444 a.1.cmp(&b.1)
445 .then_with(|| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal))
446 })
447 .map(|(val, _)| *val)
448 .ok_or(AlgorithmError::EmptyInput { operation: "mode" })?;
449
450 Ok(Value::Number(mode))
451}
452
453fn fn_stddev(args: &[Value]) -> Result<Value> {
454 let raster = args[0].as_raster()?;
455 let mut sum = 0.0;
456 let mut sum_sq = 0.0;
457 let mut count = 0u64;
458
459 for y in 0..raster.height() {
460 for x in 0..raster.width() {
461 if let Ok(val) = raster.get_pixel(x, y) {
462 if val.is_finite() {
463 sum += val;
464 sum_sq += val * val;
465 count += 1;
466 }
467 }
468 }
469 }
470
471 if count == 0 {
472 return Err(AlgorithmError::EmptyInput {
473 operation: "stddev",
474 });
475 }
476
477 let mean = sum / count as f64;
478 let variance = (sum_sq / count as f64) - (mean * mean);
479 Ok(Value::Number(variance.sqrt()))
480}
481
482fn fn_variance(args: &[Value]) -> Result<Value> {
483 let raster = args[0].as_raster()?;
484 let mut sum = 0.0;
485 let mut sum_sq = 0.0;
486 let mut count = 0u64;
487
488 for y in 0..raster.height() {
489 for x in 0..raster.width() {
490 if let Ok(val) = raster.get_pixel(x, y) {
491 if val.is_finite() {
492 sum += val;
493 sum_sq += val * val;
494 count += 1;
495 }
496 }
497 }
498 }
499
500 if count == 0 {
501 return Err(AlgorithmError::EmptyInput {
502 operation: "variance",
503 });
504 }
505
506 let mean = sum / count as f64;
507 let variance = (sum_sq / count as f64) - (mean * mean);
508 Ok(Value::Number(variance))
509}
510
511fn fn_sum(args: &[Value]) -> Result<Value> {
512 let raster = args[0].as_raster()?;
513 let mut sum = 0.0;
514
515 for y in 0..raster.height() {
516 for x in 0..raster.width() {
517 if let Ok(val) = raster.get_pixel(x, y) {
518 if val.is_finite() {
519 sum += val;
520 }
521 }
522 }
523 }
524
525 Ok(Value::Number(sum))
526}
527
528fn fn_product(args: &[Value]) -> Result<Value> {
529 let raster = args[0].as_raster()?;
530 let mut product = 1.0;
531
532 for y in 0..raster.height() {
533 for x in 0..raster.width() {
534 if let Ok(val) = raster.get_pixel(x, y) {
535 if val.is_finite() {
536 product *= val;
537 }
538 }
539 }
540 }
541
542 Ok(Value::Number(product))
543}
544
545fn fn_percentile(args: &[Value]) -> Result<Value> {
546 let raster = args[0].as_raster()?;
547 let p = args[1].as_number()?;
548
549 if !(0.0..=100.0).contains(&p) {
550 return Err(AlgorithmError::InvalidParameter {
551 parameter: "percentile",
552 message: format!("Percentile must be in [0, 100], got {p}"),
553 });
554 }
555
556 let mut values: Vec<f64> = Vec::with_capacity((raster.width() * raster.height()) as usize);
557
558 for y in 0..raster.height() {
559 for x in 0..raster.width() {
560 if let Ok(val) = raster.get_pixel(x, y) {
561 if val.is_finite() {
562 values.push(val);
563 }
564 }
565 }
566 }
567
568 if values.is_empty() {
569 return Err(AlgorithmError::EmptyInput {
570 operation: "percentile",
571 });
572 }
573
574 values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
575
576 let n = values.len();
577 if n == 1 {
578 return Ok(Value::Number(values[0]));
579 }
580
581 let rank = p / 100.0 * (n - 1) as f64;
583 let lower = rank.floor() as usize;
584 let upper = (lower + 1).min(n - 1);
585 let frac = rank - lower as f64;
586 let result = values[lower] + frac * (values[upper] - values[lower]);
587
588 Ok(Value::Number(result))
589}
590
591fn fn_gaussian(args: &[Value]) -> Result<Value> {
594 let raster = args[0].as_raster()?;
595 let sigma = args[1].as_number()?;
596
597 let result = gaussian_blur(raster, sigma, None)?;
598 Ok(Value::Raster(Box::new(result)))
599}
600
601fn fn_median_filt(args: &[Value]) -> Result<Value> {
602 let raster = args[0].as_raster()?;
603 let radius = args[1].as_number()? as usize;
604
605 let result = median_filter(raster, radius)?;
606 Ok(Value::Raster(Box::new(result)))
607}
608
609fn fn_and(args: &[Value]) -> Result<Value> {
612 let a = args[0].as_bool()?;
613 let b = args[1].as_bool()?;
614 Ok(Value::Bool(a && b))
615}
616
617fn fn_or(args: &[Value]) -> Result<Value> {
618 let a = args[0].as_bool()?;
619 let b = args[1].as_bool()?;
620 Ok(Value::Bool(a || b))
621}
622
623fn fn_not(args: &[Value]) -> Result<Value> {
624 let a = args[0].as_bool()?;
625 Ok(Value::Bool(!a))
626}
627
628fn fn_xor(args: &[Value]) -> Result<Value> {
629 let a = args[0].as_bool()?;
630 let b = args[1].as_bool()?;
631 Ok(Value::Bool(a ^ b))
632}
633
634fn fn_eq(args: &[Value]) -> Result<Value> {
637 let a = args[0].as_number()?;
638 let b = args[1].as_number()?;
639 Ok(Value::Bool((a - b).abs() < f64::EPSILON))
640}
641
642fn fn_ne(args: &[Value]) -> Result<Value> {
643 let a = args[0].as_number()?;
644 let b = args[1].as_number()?;
645 Ok(Value::Bool((a - b).abs() >= f64::EPSILON))
646}
647
648fn fn_lt(args: &[Value]) -> Result<Value> {
649 let a = args[0].as_number()?;
650 let b = args[1].as_number()?;
651 Ok(Value::Bool(a < b))
652}
653
654fn fn_le(args: &[Value]) -> Result<Value> {
655 let a = args[0].as_number()?;
656 let b = args[1].as_number()?;
657 Ok(Value::Bool(a <= b))
658}
659
660fn fn_gt(args: &[Value]) -> Result<Value> {
661 let a = args[0].as_number()?;
662 let b = args[1].as_number()?;
663 Ok(Value::Bool(a > b))
664}
665
666fn fn_ge(args: &[Value]) -> Result<Value> {
667 let a = args[0].as_number()?;
668 let b = args[1].as_number()?;
669 Ok(Value::Bool(a >= b))
670}
671
672fn fn_to_number(args: &[Value]) -> Result<Value> {
675 args[0].as_number().map(Value::Number)
676}
677
678fn fn_to_bool(args: &[Value]) -> Result<Value> {
679 args[0].as_bool().map(Value::Bool)
680}
681
682fn fn_clamp(args: &[Value]) -> Result<Value> {
685 let value = args[0].as_number()?;
686 let min = args[1].as_number()?;
687 let max = args[2].as_number()?;
688
689 let clamped = if value < min {
690 min
691 } else if value > max {
692 max
693 } else {
694 value
695 };
696
697 Ok(Value::Number(clamped))
698}
699
700fn fn_select(args: &[Value]) -> Result<Value> {
701 let cond = args[0].as_bool()?;
702 if cond {
703 Ok(args[1].clone())
704 } else {
705 Ok(args[2].clone())
706 }
707}
708
#[cfg(test)]
#[allow(clippy::panic)] // `expect_number` panics to fail a test on a wrong variant.
mod tests {
    use super::*;
    use oxigdal_core::buffer::RasterBuffer;
    use oxigdal_core::types::RasterDataType;

    /// Extracts the number from a `Value`, failing the test on any other variant.
    fn expect_number(value: Value) -> f64 {
        match value {
            Value::Number(n) => n,
            _ => panic!("Expected number"),
        }
    }

    #[test]
    fn test_function_registry() {
        let registry = FunctionRegistry::new();
        assert!(registry.exists("sqrt"));
        assert!(registry.exists("sin"));
        assert!(registry.exists("mean"));
        assert!(!registry.exists("nonexistent"));
    }

    #[test]
    fn test_math_functions() {
        let result = fn_sqrt(&[Value::Number(16.0)]).expect("Should work");
        assert!((expect_number(result) - 4.0).abs() < 1e-10);
    }

    #[test]
    fn test_min_max() {
        let args = vec![
            Value::Number(3.0),
            Value::Number(1.0),
            Value::Number(4.0),
            Value::Number(1.0),
            Value::Number(5.0),
        ];

        let min = expect_number(fn_min(&args).expect("Should work"));
        assert!((min - 1.0).abs() < 1e-10);

        let max = expect_number(fn_max(&args).expect("Should work"));
        assert!((max - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_mean() {
        let mut raster = RasterBuffer::zeros(10, 10, RasterDataType::Float32);
        for y in 0..10 {
            for x in 0..10 {
                raster
                    .set_pixel(x, y, (x + y) as f64)
                    .expect("pixel should be in bounds");
            }
        }

        let args = vec![Value::Raster(Box::new(raster))];
        let mean = expect_number(fn_mean(&args).expect("Should work"));
        // x and y each average 4.5 over 0..10, so x + y averages 9.0.
        assert!((mean - 9.0).abs() < 1e-10);
    }

    #[test]
    fn test_logical_functions() {
        let args_true = vec![Value::Bool(true), Value::Bool(true)];
        let result = fn_and(&args_true).expect("Should work");
        assert!(matches!(result, Value::Bool(true)));

        let args_false = vec![Value::Bool(true), Value::Bool(false)];
        let result = fn_and(&args_false).expect("Should work");
        assert!(matches!(result, Value::Bool(false)));
    }

    #[test]
    fn test_clamp() {
        let args = vec![Value::Number(15.0), Value::Number(0.0), Value::Number(10.0)];
        let clamped = expect_number(fn_clamp(&args).expect("Should work"));
        assert!((clamped - 10.0).abs() < 1e-10);
    }
}