1use super::ast::{BinaryOp, Expr, Program, Statement, UnaryOp};
7use super::functions::FunctionRegistry;
8use super::variables::{BandContext, Environment, Value};
9use crate::error::{AlgorithmError, Result};
10use oxigdal_core::buffer::RasterBuffer;
11use oxigdal_core::types::RasterDataType;
12
13#[cfg(not(feature = "std"))]
14use alloc::{boxed::Box, string::String, vec::Vec};
15
16pub struct CompiledProgram {
18 program: Program,
19 func_registry: FunctionRegistry,
20}
21
22impl CompiledProgram {
23 pub fn new(program: Program) -> Self {
25 Self {
26 program,
27 func_registry: FunctionRegistry::new(),
28 }
29 }
30
31 pub fn execute(&self, bands: &[RasterBuffer]) -> Result<RasterBuffer> {
33 if bands.is_empty() {
34 return Err(AlgorithmError::EmptyInput {
35 operation: "execute",
36 });
37 }
38
39 let width = bands[0].width();
40 let height = bands[0].height();
41
42 for band in bands.iter().skip(1) {
44 if band.width() != width || band.height() != height {
45 return Err(AlgorithmError::InvalidDimensions {
46 message: "All bands must have same dimensions",
47 actual: band.width() as usize,
48 expected: width as usize,
49 });
50 }
51 }
52
53 let mut env = Environment::new();
54 let band_ctx = BandContext::new(bands);
55 let mut executor = Executor::new(&self.func_registry);
56
57 for stmt in &self.program.statements {
59 executor.execute_statement(stmt, &mut env, &band_ctx)?;
60 }
61
62 if let Some(Statement::Expr(expr)) = self.program.statements.last() {
64 let result = executor.evaluate_expr(expr, &env, &band_ctx)?;
65
66 match result {
67 Value::Raster(r) => Ok(*r),
68 Value::Number(n) => {
69 let mut raster = RasterBuffer::zeros(width, height, RasterDataType::Float32);
70 for y in 0..height {
71 for x in 0..width {
72 raster.set_pixel(x, y, n).map_err(AlgorithmError::Core)?;
73 }
74 }
75 Ok(raster)
76 }
77 Value::Bool(b) => {
78 let val = if b { 1.0 } else { 0.0 };
79 let mut raster = RasterBuffer::zeros(width, height, RasterDataType::Float32);
80 for y in 0..height {
81 for x in 0..width {
82 raster.set_pixel(x, y, val).map_err(AlgorithmError::Core)?;
83 }
84 }
85 Ok(raster)
86 }
87 _ => Err(AlgorithmError::InvalidParameter {
88 parameter: "result",
89 message: "Program must return a raster or scalar".to_string(),
90 }),
91 }
92 } else {
93 Err(AlgorithmError::InvalidParameter {
94 parameter: "program",
95 message: "Program has no expression to evaluate".to_string(),
96 })
97 }
98 }
99
100 pub fn execute_expr(&self, expr: &Expr, bands: &[RasterBuffer]) -> Result<RasterBuffer> {
102 if bands.is_empty() {
103 return Err(AlgorithmError::EmptyInput {
104 operation: "execute_expr",
105 });
106 }
107
108 let width = bands[0].width();
109 let height = bands[0].height();
110
111 let env = Environment::new();
112 let band_ctx = BandContext::new(bands);
113 let mut executor = Executor::new(&self.func_registry);
114
115 let result = executor.evaluate_expr(expr, &env, &band_ctx)?;
116
117 match result {
118 Value::Raster(r) => Ok(*r),
119 Value::Number(n) => {
120 let mut raster = RasterBuffer::zeros(width, height, RasterDataType::Float32);
121 for y in 0..height {
122 for x in 0..width {
123 raster.set_pixel(x, y, n).map_err(AlgorithmError::Core)?;
124 }
125 }
126 Ok(raster)
127 }
128 Value::Bool(b) => {
129 let val = if b { 1.0 } else { 0.0 };
130 let mut raster = RasterBuffer::zeros(width, height, RasterDataType::Float32);
131 for y in 0..height {
132 for x in 0..width {
133 raster.set_pixel(x, y, val).map_err(AlgorithmError::Core)?;
134 }
135 }
136 Ok(raster)
137 }
138 _ => Err(AlgorithmError::InvalidParameter {
139 parameter: "result",
140 message: "Expression must return a raster or scalar".to_string(),
141 }),
142 }
143 }
144}
145
146struct Executor<'a> {
148 func_registry: &'a FunctionRegistry,
149}
150
151impl<'a> Executor<'a> {
152 fn new(func_registry: &'a FunctionRegistry) -> Self {
153 Self { func_registry }
154 }
155
156 fn execute_statement(
157 &mut self,
158 stmt: &Statement,
159 env: &mut Environment,
160 band_ctx: &BandContext,
161 ) -> Result<()> {
162 match stmt {
163 Statement::VariableDecl { name, value } => {
164 let val = self.evaluate_expr(value, env, band_ctx)?;
165 env.define(name.clone(), val);
166 Ok(())
167 }
168 Statement::FunctionDecl { name, params, body } => {
169 let func_val = Value::Function {
170 params: params.clone(),
171 body: body.clone(),
172 env: env.clone(),
173 };
174 env.define(name.clone(), func_val);
175 Ok(())
176 }
177 Statement::Return(_) => Err(AlgorithmError::InvalidParameter {
178 parameter: "return",
179 message: "Return statements not supported in top-level".to_string(),
180 }),
181 Statement::Expr(expr) => {
182 let _ = self.evaluate_expr(expr, env, band_ctx)?;
183 Ok(())
184 }
185 }
186 }
187
188 fn evaluate_expr(
189 &mut self,
190 expr: &Expr,
191 env: &Environment,
192 band_ctx: &BandContext,
193 ) -> Result<Value> {
194 match expr {
195 Expr::Number(n) => Ok(Value::Number(*n)),
196 Expr::Band(b) => {
197 let band = band_ctx.get_band(*b)?;
198 Ok(Value::Raster(Box::new(band.clone())))
199 }
200 Expr::Variable(name) => env.lookup(name).cloned(),
201 Expr::Binary {
202 left, op, right, ..
203 } => self.evaluate_binary(left, *op, right, env, band_ctx),
204 Expr::Unary {
205 op, expr: inner, ..
206 } => self.evaluate_unary(*op, inner, env, band_ctx),
207 Expr::Call { name, args, .. } => self.evaluate_call(name, args, env, band_ctx),
208 Expr::Conditional {
209 condition,
210 then_expr,
211 else_expr,
212 ..
213 } => self.evaluate_conditional(condition, then_expr, else_expr, env, band_ctx),
214 Expr::Block {
215 statements, result, ..
216 } => self.evaluate_block(statements, result.as_deref(), env, band_ctx),
217 Expr::ForLoop { .. } => Err(AlgorithmError::InvalidParameter {
218 parameter: "for",
219 message: "For loops not yet implemented".to_string(),
220 }),
221 }
222 }
223
224 fn evaluate_binary(
225 &mut self,
226 left: &Expr,
227 op: BinaryOp,
228 right: &Expr,
229 env: &Environment,
230 band_ctx: &BandContext,
231 ) -> Result<Value> {
232 let left_val = self.evaluate_expr(left, env, band_ctx)?;
233 let right_val = self.evaluate_expr(right, env, band_ctx)?;
234
235 match (left_val, right_val) {
236 (Value::Number(l), Value::Number(r)) => {
237 let result = match op {
238 BinaryOp::Add => l + r,
239 BinaryOp::Subtract => l - r,
240 BinaryOp::Multiply => l * r,
241 BinaryOp::Divide => {
242 if r.abs() < f64::EPSILON {
243 f64::NAN
244 } else {
245 l / r
246 }
247 }
248 BinaryOp::Modulo => l % r,
249 BinaryOp::Power => l.powf(r),
250 BinaryOp::Equal => return Ok(Value::Bool((l - r).abs() < f64::EPSILON)),
251 BinaryOp::NotEqual => return Ok(Value::Bool((l - r).abs() >= f64::EPSILON)),
252 BinaryOp::Less => return Ok(Value::Bool(l < r)),
253 BinaryOp::LessEqual => return Ok(Value::Bool(l <= r)),
254 BinaryOp::Greater => return Ok(Value::Bool(l > r)),
255 BinaryOp::GreaterEqual => return Ok(Value::Bool(l >= r)),
256 BinaryOp::And | BinaryOp::Or => {
257 return Err(AlgorithmError::InvalidParameter {
258 parameter: "operator",
259 message: "Logical operators require boolean operands".to_string(),
260 });
261 }
262 };
263 Ok(Value::Number(result))
264 }
265 (Value::Bool(l), Value::Bool(r)) => {
266 let result = match op {
267 BinaryOp::And => l && r,
268 BinaryOp::Or => l || r,
269 BinaryOp::Equal => l == r,
270 BinaryOp::NotEqual => l != r,
271 _ => {
272 return Err(AlgorithmError::InvalidParameter {
273 parameter: "operator",
274 message: format!("Operator {:?} not supported for booleans", op),
275 });
276 }
277 };
278 Ok(Value::Bool(result))
279 }
280 (Value::Raster(l), Value::Raster(r)) => self.evaluate_raster_binary(&l, op, &r),
281 (Value::Raster(l), Value::Number(r)) => {
282 self.evaluate_raster_scalar_binary(&l, op, r, false)
283 }
284 (Value::Number(l), Value::Raster(r)) => {
285 self.evaluate_raster_scalar_binary(&r, op, l, true)
286 }
287 _ => Err(AlgorithmError::InvalidParameter {
288 parameter: "operands",
289 message: "Incompatible operand types".to_string(),
290 }),
291 }
292 }
293
294 fn evaluate_raster_binary(
295 &self,
296 left: &RasterBuffer,
297 op: BinaryOp,
298 right: &RasterBuffer,
299 ) -> Result<Value> {
300 if left.width() != right.width() || left.height() != right.height() {
301 return Err(AlgorithmError::InvalidDimensions {
302 message: "Rasters must have same dimensions",
303 actual: right.width() as usize,
304 expected: left.width() as usize,
305 });
306 }
307
308 let mut result = RasterBuffer::zeros(left.width(), left.height(), left.data_type());
309
310 for y in 0..left.height() {
311 for x in 0..left.width() {
312 let l = left.get_pixel(x, y).map_err(AlgorithmError::Core)?;
313 let r = right.get_pixel(x, y).map_err(AlgorithmError::Core)?;
314
315 let val = match op {
316 BinaryOp::Add => l + r,
317 BinaryOp::Subtract => l - r,
318 BinaryOp::Multiply => l * r,
319 BinaryOp::Divide => {
320 if r.abs() < f64::EPSILON {
321 f64::NAN
322 } else {
323 l / r
324 }
325 }
326 BinaryOp::Modulo => l % r,
327 BinaryOp::Power => l.powf(r),
328 BinaryOp::Equal => {
329 if (l - r).abs() < f64::EPSILON {
330 1.0
331 } else {
332 0.0
333 }
334 }
335 BinaryOp::NotEqual => {
336 if (l - r).abs() >= f64::EPSILON {
337 1.0
338 } else {
339 0.0
340 }
341 }
342 BinaryOp::Less => {
343 if l < r {
344 1.0
345 } else {
346 0.0
347 }
348 }
349 BinaryOp::LessEqual => {
350 if l <= r {
351 1.0
352 } else {
353 0.0
354 }
355 }
356 BinaryOp::Greater => {
357 if l > r {
358 1.0
359 } else {
360 0.0
361 }
362 }
363 BinaryOp::GreaterEqual => {
364 if l >= r {
365 1.0
366 } else {
367 0.0
368 }
369 }
370 BinaryOp::And => {
371 let l_bool = l.abs() > f64::EPSILON;
373 let r_bool = r.abs() > f64::EPSILON;
374 if l_bool && r_bool { 1.0 } else { 0.0 }
375 }
376 BinaryOp::Or => {
377 let l_bool = l.abs() > f64::EPSILON;
379 let r_bool = r.abs() > f64::EPSILON;
380 if l_bool || r_bool { 1.0 } else { 0.0 }
381 }
382 };
383
384 result.set_pixel(x, y, val).map_err(AlgorithmError::Core)?;
385 }
386 }
387
388 Ok(Value::Raster(Box::new(result)))
389 }
390
391 fn evaluate_raster_scalar_binary(
392 &self,
393 raster: &RasterBuffer,
394 op: BinaryOp,
395 scalar: f64,
396 scalar_left: bool,
397 ) -> Result<Value> {
398 let mut result = RasterBuffer::zeros(raster.width(), raster.height(), raster.data_type());
399
400 for y in 0..raster.height() {
401 for x in 0..raster.width() {
402 let r = raster.get_pixel(x, y).map_err(AlgorithmError::Core)?;
403
404 let val = if scalar_left {
405 match op {
406 BinaryOp::Add => scalar + r,
407 BinaryOp::Subtract => scalar - r,
408 BinaryOp::Multiply => scalar * r,
409 BinaryOp::Divide => {
410 if r.abs() < f64::EPSILON {
411 f64::NAN
412 } else {
413 scalar / r
414 }
415 }
416 BinaryOp::Modulo => scalar % r,
417 BinaryOp::Power => scalar.powf(r),
418 BinaryOp::Equal => {
419 if (scalar - r).abs() < f64::EPSILON {
420 1.0
421 } else {
422 0.0
423 }
424 }
425 BinaryOp::NotEqual => {
426 if (scalar - r).abs() >= f64::EPSILON {
427 1.0
428 } else {
429 0.0
430 }
431 }
432 BinaryOp::Less => {
433 if scalar < r {
434 1.0
435 } else {
436 0.0
437 }
438 }
439 BinaryOp::LessEqual => {
440 if scalar <= r {
441 1.0
442 } else {
443 0.0
444 }
445 }
446 BinaryOp::Greater => {
447 if scalar > r {
448 1.0
449 } else {
450 0.0
451 }
452 }
453 BinaryOp::GreaterEqual => {
454 if scalar >= r {
455 1.0
456 } else {
457 0.0
458 }
459 }
460 BinaryOp::And | BinaryOp::Or => {
461 return Err(AlgorithmError::InvalidParameter {
462 parameter: "operator",
463 message: "Logical operators require boolean operands".to_string(),
464 });
465 }
466 }
467 } else {
468 match op {
469 BinaryOp::Add => r + scalar,
470 BinaryOp::Subtract => r - scalar,
471 BinaryOp::Multiply => r * scalar,
472 BinaryOp::Divide => {
473 if scalar.abs() < f64::EPSILON {
474 f64::NAN
475 } else {
476 r / scalar
477 }
478 }
479 BinaryOp::Modulo => r % scalar,
480 BinaryOp::Power => r.powf(scalar),
481 BinaryOp::Equal => {
482 if (r - scalar).abs() < f64::EPSILON {
483 1.0
484 } else {
485 0.0
486 }
487 }
488 BinaryOp::NotEqual => {
489 if (r - scalar).abs() >= f64::EPSILON {
490 1.0
491 } else {
492 0.0
493 }
494 }
495 BinaryOp::Less => {
496 if r < scalar {
497 1.0
498 } else {
499 0.0
500 }
501 }
502 BinaryOp::LessEqual => {
503 if r <= scalar {
504 1.0
505 } else {
506 0.0
507 }
508 }
509 BinaryOp::Greater => {
510 if r > scalar {
511 1.0
512 } else {
513 0.0
514 }
515 }
516 BinaryOp::GreaterEqual => {
517 if r >= scalar {
518 1.0
519 } else {
520 0.0
521 }
522 }
523 BinaryOp::And | BinaryOp::Or => {
524 return Err(AlgorithmError::InvalidParameter {
525 parameter: "operator",
526 message: "Logical operators require boolean operands".to_string(),
527 });
528 }
529 }
530 };
531
532 result.set_pixel(x, y, val).map_err(AlgorithmError::Core)?;
533 }
534 }
535
536 Ok(Value::Raster(Box::new(result)))
537 }
538
539 fn evaluate_unary(
540 &mut self,
541 op: UnaryOp,
542 expr: &Expr,
543 env: &Environment,
544 band_ctx: &BandContext,
545 ) -> Result<Value> {
546 let val = self.evaluate_expr(expr, env, band_ctx)?;
547
548 match val {
549 Value::Number(n) => {
550 let result = match op {
551 UnaryOp::Negate => -n,
552 UnaryOp::Plus => n,
553 UnaryOp::Not => {
554 return Err(AlgorithmError::InvalidParameter {
555 parameter: "operator",
556 message: "Not operator requires boolean".to_string(),
557 });
558 }
559 };
560 Ok(Value::Number(result))
561 }
562 Value::Bool(b) => match op {
563 UnaryOp::Not => Ok(Value::Bool(!b)),
564 _ => Err(AlgorithmError::InvalidParameter {
565 parameter: "operator",
566 message: "Operator not supported for booleans".to_string(),
567 }),
568 },
569 Value::Raster(raster) => {
570 let mut result =
571 RasterBuffer::zeros(raster.width(), raster.height(), raster.data_type());
572
573 for y in 0..raster.height() {
574 for x in 0..raster.width() {
575 let val = raster.get_pixel(x, y).map_err(AlgorithmError::Core)?;
576 let new_val = match op {
577 UnaryOp::Negate => -val,
578 UnaryOp::Plus => val,
579 UnaryOp::Not => {
580 return Err(AlgorithmError::InvalidParameter {
581 parameter: "operator",
582 message: "Not operator requires boolean operands".to_string(),
583 });
584 }
585 };
586 result
587 .set_pixel(x, y, new_val)
588 .map_err(AlgorithmError::Core)?;
589 }
590 }
591
592 Ok(Value::Raster(Box::new(result)))
593 }
594 _ => Err(AlgorithmError::InvalidParameter {
595 parameter: "operand",
596 message: "Incompatible operand type for unary operator".to_string(),
597 }),
598 }
599 }
600
601 fn evaluate_call(
602 &mut self,
603 name: &str,
604 args: &[Expr],
605 env: &Environment,
606 band_ctx: &BandContext,
607 ) -> Result<Value> {
608 if let Ok(func_val) = env.lookup(name) {
610 if let Value::Function {
611 params,
612 body,
613 env: func_env,
614 } = func_val
615 {
616 if params.len() != args.len() {
617 return Err(AlgorithmError::InvalidParameter {
618 parameter: "arguments",
619 message: format!("Expected {} arguments, got {}", params.len(), args.len()),
620 });
621 }
622
623 let mut new_env = Environment::with_parent(func_env.clone());
625 for (param, arg) in params.iter().zip(args.iter()) {
626 let arg_val = self.evaluate_expr(arg, env, band_ctx)?;
627 new_env.define(param.clone(), arg_val);
628 }
629
630 return self.evaluate_expr(body, &new_env, band_ctx);
631 }
632 }
633
634 if let Some((func, arity)) = self.func_registry.lookup(name) {
636 if arity > 0 && args.len() != arity {
637 return Err(AlgorithmError::InvalidParameter {
638 parameter: "arguments",
639 message: format!("Expected {arity} arguments, got {}", args.len()),
640 });
641 }
642
643 let arg_vals: Result<Vec<Value>> = args
644 .iter()
645 .map(|arg| self.evaluate_expr(arg, env, band_ctx))
646 .collect();
647
648 func(&arg_vals?)
649 } else {
650 Err(AlgorithmError::InvalidParameter {
651 parameter: "function",
652 message: format!("Unknown function: {name}"),
653 })
654 }
655 }
656
657 fn evaluate_conditional(
658 &mut self,
659 condition: &Expr,
660 then_expr: &Expr,
661 else_expr: &Expr,
662 env: &Environment,
663 band_ctx: &BandContext,
664 ) -> Result<Value> {
665 let cond_val = self.evaluate_expr(condition, env, band_ctx)?;
666
667 match cond_val {
668 Value::Bool(b) => {
669 if b {
670 self.evaluate_expr(then_expr, env, band_ctx)
671 } else {
672 self.evaluate_expr(else_expr, env, band_ctx)
673 }
674 }
675 Value::Number(n) => {
676 if n.abs() > f64::EPSILON {
677 self.evaluate_expr(then_expr, env, band_ctx)
678 } else {
679 self.evaluate_expr(else_expr, env, band_ctx)
680 }
681 }
682 Value::Raster(cond_raster) => {
683 let then_val = self.evaluate_expr(then_expr, env, band_ctx)?;
685 let else_val = self.evaluate_expr(else_expr, env, band_ctx)?;
686
687 let width = cond_raster.width();
688 let height = cond_raster.height();
689 let mut result = RasterBuffer::zeros(width, height, cond_raster.data_type());
690
691 for y in 0..height {
692 for x in 0..width {
693 let cond = cond_raster.get_pixel(x, y).map_err(AlgorithmError::Core)?;
694 let is_true = cond.abs() > f64::EPSILON;
695
696 let val = if is_true {
697 match &then_val {
698 Value::Raster(r) => {
699 r.get_pixel(x, y).map_err(AlgorithmError::Core)?
700 }
701 Value::Number(n) => *n,
702 Value::Bool(b) => {
703 if *b {
704 1.0
705 } else {
706 0.0
707 }
708 }
709 _ => {
710 return Err(AlgorithmError::InvalidParameter {
711 parameter: "then_expr",
712 message: "Then expression must be raster or scalar"
713 .to_string(),
714 });
715 }
716 }
717 } else {
718 match &else_val {
719 Value::Raster(r) => {
720 r.get_pixel(x, y).map_err(AlgorithmError::Core)?
721 }
722 Value::Number(n) => *n,
723 Value::Bool(b) => {
724 if *b {
725 1.0
726 } else {
727 0.0
728 }
729 }
730 _ => {
731 return Err(AlgorithmError::InvalidParameter {
732 parameter: "else_expr",
733 message: "Else expression must be raster or scalar"
734 .to_string(),
735 });
736 }
737 }
738 };
739
740 result.set_pixel(x, y, val).map_err(AlgorithmError::Core)?;
741 }
742 }
743
744 Ok(Value::Raster(Box::new(result)))
745 }
746 _ => Err(AlgorithmError::InvalidParameter {
747 parameter: "condition",
748 message: "Condition must be boolean, number, or raster".to_string(),
749 }),
750 }
751 }
752
753 fn evaluate_block(
754 &mut self,
755 statements: &[Statement],
756 result: Option<&Expr>,
757 env: &Environment,
758 band_ctx: &BandContext,
759 ) -> Result<Value> {
760 let mut block_env = Environment::with_parent(env.clone());
761
762 for stmt in statements {
763 self.execute_statement(stmt, &mut block_env, band_ctx)?;
764 }
765
766 if let Some(expr) = result {
767 self.evaluate_expr(expr, &block_env, band_ctx)
768 } else {
769 Ok(Value::Number(0.0))
770 }
771 }
772}
773
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::parser::parse_expression;
    use oxigdal_core::types::RasterDataType;

    /// Parses `source` as one expression and wraps it in a one-statement program.
    fn compile(source: &str) -> CompiledProgram {
        let expr = parse_expression(source).expect("Should parse");
        CompiledProgram::new(Program {
            statements: vec![Statement::Expr(Box::new(expr))],
        })
    }

    /// A 10x10 `Float32` band whose pixel (0, 0) is `corner`; all others are 0.
    fn band_with_corner(corner: f64) -> RasterBuffer {
        let mut band = RasterBuffer::zeros(10, 10, RasterDataType::Float32);
        band.set_pixel(0, 0, corner).expect("pixel in bounds");
        band
    }

    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-6,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn test_compile_number() {
        let bands = vec![RasterBuffer::zeros(10, 10, RasterDataType::Float32)];
        let result = compile("42")
            .execute(&bands)
            .expect("scalar result should broadcast");

        assert_eq!(result.width(), 10);
        assert_eq!(result.height(), 10);
        assert_close(result.get_pixel(3, 7).expect("pixel in bounds"), 42.0);
    }

    #[test]
    fn test_compile_band() {
        let bands = vec![band_with_corner(5.0)];
        let result = compile("B1").execute(&bands).expect("band result");

        assert_close(result.get_pixel(0, 0).expect("pixel in bounds"), 5.0);
        assert_close(result.get_pixel(1, 1).expect("pixel in bounds"), 0.0);
    }

    #[test]
    fn test_compile_arithmetic() {
        let bands = vec![band_with_corner(3.0), band_with_corner(4.0)];
        let result = compile("B1 + B2").execute(&bands).expect("sum of bands");

        assert_close(result.get_pixel(0, 0).expect("pixel in bounds"), 7.0);
        assert_close(result.get_pixel(9, 9).expect("pixel in bounds"), 0.0);
    }

    #[test]
    fn test_comparison_produces_mask() {
        let bands = vec![band_with_corner(3.0)];
        let result = compile("B1 > 2").execute(&bands).expect("mask");

        assert_close(result.get_pixel(0, 0).expect("pixel in bounds"), 1.0);
        assert_close(result.get_pixel(1, 1).expect("pixel in bounds"), 0.0);
    }

    #[test]
    fn test_execute_expr_scales_band() {
        let expr = parse_expression("B1 * 2").expect("Should parse");
        let compiled = CompiledProgram::new(Program {
            statements: Vec::new(),
        });
        let result = compiled
            .execute_expr(&expr, &[band_with_corner(1.5)])
            .expect("scaled band");

        assert_close(result.get_pixel(0, 0).expect("pixel in bounds"), 3.0);
    }

    #[test]
    fn test_execute_rejects_empty_bands() {
        assert!(compile("42").execute(&[]).is_err());
    }

    #[test]
    fn test_execute_rejects_mismatched_bands() {
        let bands = vec![
            RasterBuffer::zeros(10, 10, RasterDataType::Float32),
            RasterBuffer::zeros(5, 5, RasterDataType::Float32),
        ];
        assert!(compile("B1").execute(&bands).is_err());
    }

    #[test]
    fn test_execute_rejects_program_without_expression() {
        let compiled = CompiledProgram::new(Program {
            statements: Vec::new(),
        });
        let bands = vec![RasterBuffer::zeros(2, 2, RasterDataType::Float32)];
        assert!(compiled.execute(&bands).is_err());
    }
}