1use super::variables::Value;
11use crate::error::{AlgorithmError, Result};
12use crate::raster::{gaussian_blur, median_filter};
13
14#[cfg(not(feature = "std"))]
15use alloc::{boxed::Box, string::String, vec::Vec};
16
/// Signature shared by every built-in function: it receives the already
/// evaluated arguments as a slice and returns a single [`Value`] or an error.
pub type BuiltinFn = fn(&[Value]) -> Result<Value>;
19
/// Table of the functions that can be called by name.
pub struct FunctionRegistry {
    /// One `(name, implementation, arity)` entry per registered function,
    /// kept in registration order.
    functions: Vec<(&'static str, BuiltinFn, usize)>,
}
24
25impl Default for FunctionRegistry {
26 fn default() -> Self {
27 Self::new()
28 }
29}
30
31impl FunctionRegistry {
32 pub fn new() -> Self {
34 let mut registry = Self {
35 functions: Vec::new(),
36 };
37
38 registry.register("sqrt", fn_sqrt, 1);
40 registry.register("abs", fn_abs, 1);
41 registry.register("floor", fn_floor, 1);
42 registry.register("ceil", fn_ceil, 1);
43 registry.register("round", fn_round, 1);
44 registry.register("log", fn_log, 1);
45 registry.register("log10", fn_log10, 1);
46 registry.register("log2", fn_log2, 1);
47 registry.register("exp", fn_exp, 1);
48 registry.register("sin", fn_sin, 1);
49 registry.register("cos", fn_cos, 1);
50 registry.register("tan", fn_tan, 1);
51 registry.register("asin", fn_asin, 1);
52 registry.register("acos", fn_acos, 1);
53 registry.register("atan", fn_atan, 1);
54 registry.register("sinh", fn_sinh, 1);
55 registry.register("cosh", fn_cosh, 1);
56 registry.register("tanh", fn_tanh, 1);
57
58 registry.register("atan2", fn_atan2, 2);
60 registry.register("pow", fn_pow, 2);
61 registry.register("hypot", fn_hypot, 2);
62
63 registry.register("min", fn_min, 0);
65 registry.register("max", fn_max, 0);
66
67 registry.register("mean", fn_mean, 1);
69 registry.register("median", fn_median, 1);
70 registry.register("mode", fn_mode, 1);
71 registry.register("stddev", fn_stddev, 1);
72 registry.register("variance", fn_variance, 1);
73 registry.register("sum", fn_sum, 1);
74 registry.register("product", fn_product, 1);
75
76 registry.register("percentile", fn_percentile, 2);
78
79 registry.register("gaussian", fn_gaussian, 2);
81 registry.register("median_filter", fn_median_filt, 2);
82
83 registry.register("and", fn_and, 2);
85 registry.register("or", fn_or, 2);
86 registry.register("not", fn_not, 1);
87 registry.register("xor", fn_xor, 2);
88
89 registry.register("eq", fn_eq, 2);
91 registry.register("ne", fn_ne, 2);
92 registry.register("lt", fn_lt, 2);
93 registry.register("le", fn_le, 2);
94 registry.register("gt", fn_gt, 2);
95 registry.register("ge", fn_ge, 2);
96
97 registry.register("to_number", fn_to_number, 1);
99 registry.register("to_bool", fn_to_bool, 1);
100
101 registry.register("clamp", fn_clamp, 3);
103 registry.register("select", fn_select, 3);
104
105 registry
106 }
107
108 pub fn register(&mut self, name: &'static str, func: BuiltinFn, arity: usize) {
110 self.functions.push((name, func, arity));
111 }
112
113 pub fn lookup(&self, name: &str) -> Option<(BuiltinFn, usize)> {
115 self.functions
116 .iter()
117 .find(|(n, _, _)| *n == name)
118 .map(|(_, f, a)| (*f, *a))
119 }
120
121 pub fn exists(&self, name: &str) -> bool {
123 self.functions.iter().any(|(n, _, _)| *n == name)
124 }
125
126 pub fn function_names(&self) -> Vec<&'static str> {
128 self.functions.iter().map(|(n, _, _)| *n).collect()
129 }
130}
131
132fn apply_unary_fn<F>(value: &Value, f: F) -> Result<Value>
136where
137 F: Fn(f64) -> f64,
138{
139 match value {
140 Value::Number(x) => Ok(Value::Number(f(*x))),
141 Value::Raster(raster) => {
142 use oxigdal_core::types::RasterDataType;
143 let width = raster.width();
144 let height = raster.height();
145 let mut result =
146 oxigdal_core::buffer::RasterBuffer::zeros(width, height, RasterDataType::Float32);
147
148 for y in 0..height {
149 for x in 0..width {
150 let pixel = raster
151 .get_pixel(x, y)
152 .map_err(crate::error::AlgorithmError::Core)?;
153 let new_val = f(pixel);
154 result
155 .set_pixel(x, y, new_val)
156 .map_err(crate::error::AlgorithmError::Core)?;
157 }
158 }
159
160 Ok(Value::Raster(Box::new(result)))
161 }
162 _ => Err(AlgorithmError::InvalidParameter {
163 parameter: "value",
164 message: "Expected number or raster".to_string(),
165 }),
166 }
167}
168
169fn apply_binary_fn<F>(left: &Value, right: &Value, f: F) -> Result<Value>
171where
172 F: Fn(f64, f64) -> f64,
173{
174 match (left, right) {
175 (Value::Number(l), Value::Number(r)) => Ok(Value::Number(f(*l, *r))),
176 (Value::Raster(raster), Value::Number(scalar))
177 | (Value::Number(scalar), Value::Raster(raster)) => {
178 use oxigdal_core::types::RasterDataType;
179 let width = raster.width();
180 let height = raster.height();
181 let mut result =
182 oxigdal_core::buffer::RasterBuffer::zeros(width, height, RasterDataType::Float32);
183
184 for y in 0..height {
185 for x in 0..width {
186 let pixel = raster
187 .get_pixel(x, y)
188 .map_err(crate::error::AlgorithmError::Core)?;
189 let new_val = f(pixel, *scalar);
190 result
191 .set_pixel(x, y, new_val)
192 .map_err(crate::error::AlgorithmError::Core)?;
193 }
194 }
195
196 Ok(Value::Raster(Box::new(result)))
197 }
198 (Value::Raster(left_raster), Value::Raster(right_raster)) => {
199 use oxigdal_core::types::RasterDataType;
200 let width = left_raster.width();
201 let height = left_raster.height();
202
203 if right_raster.width() != width || right_raster.height() != height {
204 return Err(AlgorithmError::InvalidDimensions {
205 message: "Rasters must have same dimensions",
206 actual: right_raster.width() as usize,
207 expected: width as usize,
208 });
209 }
210
211 let mut result =
212 oxigdal_core::buffer::RasterBuffer::zeros(width, height, RasterDataType::Float32);
213
214 for y in 0..height {
215 for x in 0..width {
216 let left_pixel = left_raster
217 .get_pixel(x, y)
218 .map_err(crate::error::AlgorithmError::Core)?;
219 let right_pixel = right_raster
220 .get_pixel(x, y)
221 .map_err(crate::error::AlgorithmError::Core)?;
222 let new_val = f(left_pixel, right_pixel);
223 result
224 .set_pixel(x, y, new_val)
225 .map_err(crate::error::AlgorithmError::Core)?;
226 }
227 }
228
229 Ok(Value::Raster(Box::new(result)))
230 }
231 _ => Err(AlgorithmError::InvalidParameter {
232 parameter: "value",
233 message: "Expected number or raster".to_string(),
234 }),
235 }
236}
237
238fn fn_sqrt(args: &[Value]) -> Result<Value> {
239 apply_unary_fn(&args[0], |x| x.sqrt())
240}
241
242fn fn_abs(args: &[Value]) -> Result<Value> {
243 apply_unary_fn(&args[0], |x| x.abs())
244}
245
246fn fn_floor(args: &[Value]) -> Result<Value> {
247 apply_unary_fn(&args[0], |x| x.floor())
248}
249
250fn fn_ceil(args: &[Value]) -> Result<Value> {
251 apply_unary_fn(&args[0], |x| x.ceil())
252}
253
254fn fn_round(args: &[Value]) -> Result<Value> {
255 apply_unary_fn(&args[0], |x| x.round())
256}
257
258fn fn_log(args: &[Value]) -> Result<Value> {
259 apply_unary_fn(&args[0], |x| x.ln())
260}
261
262fn fn_log10(args: &[Value]) -> Result<Value> {
263 apply_unary_fn(&args[0], |x| x.log10())
264}
265
266fn fn_log2(args: &[Value]) -> Result<Value> {
267 apply_unary_fn(&args[0], |x| x.log2())
268}
269
270fn fn_exp(args: &[Value]) -> Result<Value> {
271 apply_unary_fn(&args[0], |x| x.exp())
272}
273
274fn fn_sin(args: &[Value]) -> Result<Value> {
275 apply_unary_fn(&args[0], |x| x.sin())
276}
277
278fn fn_cos(args: &[Value]) -> Result<Value> {
279 apply_unary_fn(&args[0], |x| x.cos())
280}
281
282fn fn_tan(args: &[Value]) -> Result<Value> {
283 apply_unary_fn(&args[0], |x| x.tan())
284}
285
286fn fn_asin(args: &[Value]) -> Result<Value> {
287 apply_unary_fn(&args[0], |x| x.asin())
288}
289
290fn fn_acos(args: &[Value]) -> Result<Value> {
291 apply_unary_fn(&args[0], |x| x.acos())
292}
293
294fn fn_atan(args: &[Value]) -> Result<Value> {
295 apply_unary_fn(&args[0], |x| x.atan())
296}
297
298fn fn_sinh(args: &[Value]) -> Result<Value> {
299 apply_unary_fn(&args[0], |x| x.sinh())
300}
301
302fn fn_cosh(args: &[Value]) -> Result<Value> {
303 apply_unary_fn(&args[0], |x| x.cosh())
304}
305
306fn fn_tanh(args: &[Value]) -> Result<Value> {
307 apply_unary_fn(&args[0], |x| x.tanh())
308}
309
310fn fn_atan2(args: &[Value]) -> Result<Value> {
311 apply_binary_fn(&args[0], &args[1], |y, x| y.atan2(x))
312}
313
314fn fn_pow(args: &[Value]) -> Result<Value> {
315 apply_binary_fn(&args[0], &args[1], |base, exp| base.powf(exp))
316}
317
318fn fn_hypot(args: &[Value]) -> Result<Value> {
319 apply_binary_fn(&args[0], &args[1], |x, y| x.hypot(y))
320}
321
322fn fn_min(args: &[Value]) -> Result<Value> {
323 if args.is_empty() {
324 return Err(AlgorithmError::InvalidParameter {
325 parameter: "min",
326 message: "Expected at least 1 argument".to_string(),
327 });
328 }
329
330 let mut min_val = args[0].as_number()?;
331 for arg in &args[1..] {
332 let val = arg.as_number()?;
333 if val < min_val {
334 min_val = val;
335 }
336 }
337 Ok(Value::Number(min_val))
338}
339
340fn fn_max(args: &[Value]) -> Result<Value> {
341 if args.is_empty() {
342 return Err(AlgorithmError::InvalidParameter {
343 parameter: "max",
344 message: "Expected at least 1 argument".to_string(),
345 });
346 }
347
348 let mut max_val = args[0].as_number()?;
349 for arg in &args[1..] {
350 let val = arg.as_number()?;
351 if val > max_val {
352 max_val = val;
353 }
354 }
355 Ok(Value::Number(max_val))
356}
357
358fn fn_mean(args: &[Value]) -> Result<Value> {
361 let raster = args[0].as_raster()?;
362 let mut sum = 0.0;
363 let mut count = 0u64;
364
365 for y in 0..raster.height() {
366 for x in 0..raster.width() {
367 if let Ok(val) = raster.get_pixel(x, y) {
368 if val.is_finite() {
369 sum += val;
370 count += 1;
371 }
372 }
373 }
374 }
375
376 if count == 0 {
377 return Err(AlgorithmError::EmptyInput { operation: "mean" });
378 }
379
380 Ok(Value::Number(sum / count as f64))
381}
382
383fn fn_median(_args: &[Value]) -> Result<Value> {
384 Err(AlgorithmError::InvalidParameter {
386 parameter: "median",
387 message: "Not yet implemented".to_string(),
388 })
389}
390
391fn fn_mode(_args: &[Value]) -> Result<Value> {
392 Err(AlgorithmError::InvalidParameter {
394 parameter: "mode",
395 message: "Not yet implemented".to_string(),
396 })
397}
398
399fn fn_stddev(args: &[Value]) -> Result<Value> {
400 let raster = args[0].as_raster()?;
401 let mut sum = 0.0;
402 let mut sum_sq = 0.0;
403 let mut count = 0u64;
404
405 for y in 0..raster.height() {
406 for x in 0..raster.width() {
407 if let Ok(val) = raster.get_pixel(x, y) {
408 if val.is_finite() {
409 sum += val;
410 sum_sq += val * val;
411 count += 1;
412 }
413 }
414 }
415 }
416
417 if count == 0 {
418 return Err(AlgorithmError::EmptyInput {
419 operation: "stddev",
420 });
421 }
422
423 let mean = sum / count as f64;
424 let variance = (sum_sq / count as f64) - (mean * mean);
425 Ok(Value::Number(variance.sqrt()))
426}
427
428fn fn_variance(args: &[Value]) -> Result<Value> {
429 let raster = args[0].as_raster()?;
430 let mut sum = 0.0;
431 let mut sum_sq = 0.0;
432 let mut count = 0u64;
433
434 for y in 0..raster.height() {
435 for x in 0..raster.width() {
436 if let Ok(val) = raster.get_pixel(x, y) {
437 if val.is_finite() {
438 sum += val;
439 sum_sq += val * val;
440 count += 1;
441 }
442 }
443 }
444 }
445
446 if count == 0 {
447 return Err(AlgorithmError::EmptyInput {
448 operation: "variance",
449 });
450 }
451
452 let mean = sum / count as f64;
453 let variance = (sum_sq / count as f64) - (mean * mean);
454 Ok(Value::Number(variance))
455}
456
457fn fn_sum(args: &[Value]) -> Result<Value> {
458 let raster = args[0].as_raster()?;
459 let mut sum = 0.0;
460
461 for y in 0..raster.height() {
462 for x in 0..raster.width() {
463 if let Ok(val) = raster.get_pixel(x, y) {
464 if val.is_finite() {
465 sum += val;
466 }
467 }
468 }
469 }
470
471 Ok(Value::Number(sum))
472}
473
474fn fn_product(args: &[Value]) -> Result<Value> {
475 let raster = args[0].as_raster()?;
476 let mut product = 1.0;
477
478 for y in 0..raster.height() {
479 for x in 0..raster.width() {
480 if let Ok(val) = raster.get_pixel(x, y) {
481 if val.is_finite() {
482 product *= val;
483 }
484 }
485 }
486 }
487
488 Ok(Value::Number(product))
489}
490
491fn fn_percentile(_args: &[Value]) -> Result<Value> {
492 Err(AlgorithmError::InvalidParameter {
494 parameter: "percentile",
495 message: "Not yet implemented".to_string(),
496 })
497}
498
499fn fn_gaussian(args: &[Value]) -> Result<Value> {
502 let raster = args[0].as_raster()?;
503 let sigma = args[1].as_number()?;
504
505 let result = gaussian_blur(raster, sigma, None)?;
506 Ok(Value::Raster(Box::new(result)))
507}
508
509fn fn_median_filt(args: &[Value]) -> Result<Value> {
510 let raster = args[0].as_raster()?;
511 let radius = args[1].as_number()? as usize;
512
513 let result = median_filter(raster, radius)?;
514 Ok(Value::Raster(Box::new(result)))
515}
516
517fn fn_and(args: &[Value]) -> Result<Value> {
520 let a = args[0].as_bool()?;
521 let b = args[1].as_bool()?;
522 Ok(Value::Bool(a && b))
523}
524
525fn fn_or(args: &[Value]) -> Result<Value> {
526 let a = args[0].as_bool()?;
527 let b = args[1].as_bool()?;
528 Ok(Value::Bool(a || b))
529}
530
531fn fn_not(args: &[Value]) -> Result<Value> {
532 let a = args[0].as_bool()?;
533 Ok(Value::Bool(!a))
534}
535
536fn fn_xor(args: &[Value]) -> Result<Value> {
537 let a = args[0].as_bool()?;
538 let b = args[1].as_bool()?;
539 Ok(Value::Bool(a ^ b))
540}
541
542fn fn_eq(args: &[Value]) -> Result<Value> {
545 let a = args[0].as_number()?;
546 let b = args[1].as_number()?;
547 Ok(Value::Bool((a - b).abs() < f64::EPSILON))
548}
549
550fn fn_ne(args: &[Value]) -> Result<Value> {
551 let a = args[0].as_number()?;
552 let b = args[1].as_number()?;
553 Ok(Value::Bool((a - b).abs() >= f64::EPSILON))
554}
555
556fn fn_lt(args: &[Value]) -> Result<Value> {
557 let a = args[0].as_number()?;
558 let b = args[1].as_number()?;
559 Ok(Value::Bool(a < b))
560}
561
562fn fn_le(args: &[Value]) -> Result<Value> {
563 let a = args[0].as_number()?;
564 let b = args[1].as_number()?;
565 Ok(Value::Bool(a <= b))
566}
567
568fn fn_gt(args: &[Value]) -> Result<Value> {
569 let a = args[0].as_number()?;
570 let b = args[1].as_number()?;
571 Ok(Value::Bool(a > b))
572}
573
574fn fn_ge(args: &[Value]) -> Result<Value> {
575 let a = args[0].as_number()?;
576 let b = args[1].as_number()?;
577 Ok(Value::Bool(a >= b))
578}
579
580fn fn_to_number(args: &[Value]) -> Result<Value> {
583 args[0].as_number().map(Value::Number)
584}
585
586fn fn_to_bool(args: &[Value]) -> Result<Value> {
587 args[0].as_bool().map(Value::Bool)
588}
589
590fn fn_clamp(args: &[Value]) -> Result<Value> {
593 let value = args[0].as_number()?;
594 let min = args[1].as_number()?;
595 let max = args[2].as_number()?;
596
597 let clamped = if value < min {
598 min
599 } else if value > max {
600 max
601 } else {
602 value
603 };
604
605 Ok(Value::Number(clamped))
606}
607
608fn fn_select(args: &[Value]) -> Result<Value> {
609 let cond = args[0].as_bool()?;
610 if cond {
611 Ok(args[1].clone())
612 } else {
613 Ok(args[2].clone())
614 }
615}
616
#[cfg(test)]
// Tests panic on purpose when a result has an unexpected variant.
#[allow(clippy::panic)]
mod tests {
    use super::*;
    use oxigdal_core::buffer::RasterBuffer;
    use oxigdal_core::types::RasterDataType;

    /// Unwraps a `Value::Number`, failing the test for any other variant so
    /// that an assertion can never be skipped silently.
    fn number_of(value: Value) -> f64 {
        match value {
            Value::Number(n) => n,
            _ => panic!("Expected number"),
        }
    }

    #[test]
    fn test_function_registry() {
        let registry = FunctionRegistry::new();
        assert!(registry.exists("sqrt"));
        assert!(registry.exists("sin"));
        assert!(registry.exists("mean"));
        assert!(!registry.exists("nonexistent"));

        // Lookup reports the arity declared at registration time.
        assert_eq!(registry.lookup("sqrt").map(|(_, arity)| arity), Some(1));
        assert_eq!(registry.lookup("pow").map(|(_, arity)| arity), Some(2));
        assert!(registry.lookup("nonexistent").is_none());
    }

    #[test]
    fn test_math_functions() {
        let args = vec![Value::Number(16.0)];
        let result = fn_sqrt(&args).expect("Should work");
        assert!((number_of(result) - 4.0).abs() < 1e-10);
    }

    #[test]
    fn test_min_max() {
        let args = vec![
            Value::Number(3.0),
            Value::Number(1.0),
            Value::Number(4.0),
            Value::Number(1.0),
            Value::Number(5.0),
        ];

        let min_result = fn_min(&args).expect("Should work");
        assert!((number_of(min_result) - 1.0).abs() < 1e-10);

        let max_result = fn_max(&args).expect("Should work");
        assert!((number_of(max_result) - 5.0).abs() < 1e-10);

        // Both functions reject an empty argument list.
        assert!(fn_min(&[]).is_err());
        assert!(fn_max(&[]).is_err());
    }

    #[test]
    fn test_mean() {
        let mut raster = RasterBuffer::zeros(10, 10, RasterDataType::Float32);
        for y in 0..10 {
            for x in 0..10 {
                assert!(raster.set_pixel(x, y, (x + y) as f64).is_ok());
            }
        }

        let args = vec![Value::Raster(Box::new(raster))];
        let result = fn_mean(&args).expect("Should work");
        // x and y each average 4.5 over 0..10, so x + y averages 9.
        assert!((number_of(result) - 9.0).abs() < 1e-6);
    }

    #[test]
    fn test_logical_functions() {
        let args_true = vec![Value::Bool(true), Value::Bool(true)];
        let result = fn_and(&args_true).expect("Should work");
        assert!(matches!(result, Value::Bool(true)));

        let args_false = vec![Value::Bool(true), Value::Bool(false)];
        let result = fn_and(&args_false).expect("Should work");
        assert!(matches!(result, Value::Bool(false)));
    }

    #[test]
    fn test_clamp() {
        // Above the range: clamps to the upper bound.
        let args = vec![Value::Number(15.0), Value::Number(0.0), Value::Number(10.0)];
        let result = fn_clamp(&args).expect("Should work");
        assert!((number_of(result) - 10.0).abs() < 1e-10);

        // Below the range: clamps to the lower bound.
        let args = vec![Value::Number(-5.0), Value::Number(0.0), Value::Number(10.0)];
        let result = fn_clamp(&args).expect("Should work");
        assert!((number_of(result) - 0.0).abs() < 1e-10);

        // Inside the range: returned unchanged.
        let args = vec![Value::Number(7.0), Value::Number(0.0), Value::Number(10.0)];
        let result = fn_clamp(&args).expect("Should work");
        assert!((number_of(result) - 7.0).abs() < 1e-10);
    }
}