Skip to main content

oxigdal_algorithms/dsl/
ast.rs

//! Abstract Syntax Tree for Raster Algebra DSL
//!
//! This module defines the AST nodes for the raster algebra DSL, providing
//! a type-safe representation of parsed expressions.
5
6#![allow(missing_docs)]
7
8use serde::{Deserialize, Serialize};
9
10#[cfg(not(feature = "std"))]
11use alloc::{boxed::Box, string::String, vec::Vec};
12
13/// Type of a value in the DSL
14#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
15pub enum Type {
16    /// Floating point number
17    Number,
18    /// Boolean value
19    Bool,
20    /// Raster band
21    Raster,
22    /// Unknown type (to be inferred)
23    Unknown,
24}
25
26impl Type {
27    /// Checks if two types are compatible
28    pub fn is_compatible(&self, other: &Type) -> bool {
29        matches!(
30            (self, other),
31            (Type::Number, Type::Number)
32                | (Type::Bool, Type::Bool)
33                | (Type::Raster, Type::Raster)
34                | (Type::Unknown, _)
35                | (_, Type::Unknown)
36        )
37    }
38
39    /// Gets the common type for binary operations
40    pub fn common_type(&self, other: &Type) -> Option<Type> {
41        match (self, other) {
42            (Type::Number, Type::Number) => Some(Type::Number),
43            (Type::Bool, Type::Bool) => Some(Type::Bool),
44            (Type::Raster, Type::Raster) => Some(Type::Raster),
45            (Type::Raster, Type::Number) | (Type::Number, Type::Raster) => Some(Type::Raster),
46            (Type::Unknown, t) | (t, Type::Unknown) => Some(*t),
47            _ => None,
48        }
49    }
50}
51
52/// Program AST - top level
53#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct Program {
55    pub statements: Vec<Statement>,
56}
57
58/// Statement in the DSL
59#[derive(Debug, Clone, Serialize, Deserialize)]
60pub enum Statement {
61    /// Variable declaration: let x = expr;
62    VariableDecl { name: String, value: Box<Expr> },
63    /// Function declaration: fn name(params) = expr;
64    FunctionDecl {
65        name: String,
66        params: Vec<String>,
67        body: Box<Expr>,
68    },
69    /// Return statement: return expr;
70    Return(Box<Expr>),
71    /// Expression statement: expr;
72    Expr(Box<Expr>),
73}
74
75/// Expression node
76#[derive(Debug, Clone, Serialize, Deserialize)]
77pub enum Expr {
78    /// Number literal
79    Number(f64),
80
81    /// Band reference (e.g., B1, B2)
82    Band(usize),
83
84    /// Variable reference
85    Variable(String),
86
87    /// Binary operation
88    Binary {
89        left: Box<Expr>,
90        op: BinaryOp,
91        right: Box<Expr>,
92        ty: Type,
93    },
94
95    /// Unary operation
96    Unary {
97        op: UnaryOp,
98        expr: Box<Expr>,
99        ty: Type,
100    },
101
102    /// Function call
103    Call {
104        name: String,
105        args: Vec<Expr>,
106        ty: Type,
107    },
108
109    /// Conditional expression: if cond then expr1 else expr2
110    Conditional {
111        condition: Box<Expr>,
112        then_expr: Box<Expr>,
113        else_expr: Box<Expr>,
114        ty: Type,
115    },
116
117    /// Block expression: { stmts; expr }
118    Block {
119        statements: Vec<Statement>,
120        result: Option<Box<Expr>>,
121        ty: Type,
122    },
123
124    /// For loop (for optimization/unrolling)
125    ForLoop {
126        var: String,
127        start: Box<Expr>,
128        end: Box<Expr>,
129        body: Box<Expr>,
130        ty: Type,
131    },
132}
133
134impl Expr {
135    /// Gets the type of this expression
136    pub fn get_type(&self) -> Type {
137        match self {
138            Expr::Number(_) => Type::Number,
139            Expr::Band(_) => Type::Raster,
140            Expr::Variable(_) => Type::Unknown,
141            Expr::Binary { ty, .. }
142            | Expr::Unary { ty, .. }
143            | Expr::Call { ty, .. }
144            | Expr::Conditional { ty, .. }
145            | Expr::Block { ty, .. }
146            | Expr::ForLoop { ty, .. } => *ty,
147        }
148    }
149
150    /// Sets the type of this expression
151    pub fn set_type(&mut self, new_type: Type) {
152        match self {
153            Expr::Binary { ty, .. }
154            | Expr::Unary { ty, .. }
155            | Expr::Call { ty, .. }
156            | Expr::Conditional { ty, .. }
157            | Expr::Block { ty, .. }
158            | Expr::ForLoop { ty, .. } => *ty = new_type,
159            _ => {}
160        }
161    }
162
163    /// Checks if this expression is constant
164    pub fn is_constant(&self) -> bool {
165        matches!(self, Expr::Number(_))
166    }
167
168    /// Checks if this expression is pure (has no side effects)
169    pub fn is_pure(&self) -> bool {
170        match self {
171            Expr::Number(_) | Expr::Band(_) | Expr::Variable(_) => true,
172            Expr::Binary { left, right, .. } => left.is_pure() && right.is_pure(),
173            Expr::Unary { expr, .. } => expr.is_pure(),
174            Expr::Call { args, .. } => args.iter().all(|a| a.is_pure()),
175            Expr::Conditional {
176                condition,
177                then_expr,
178                else_expr,
179                ..
180            } => condition.is_pure() && then_expr.is_pure() && else_expr.is_pure(),
181            Expr::Block { .. } | Expr::ForLoop { .. } => false,
182        }
183    }
184}
185
186/// Binary operators
187#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
188pub enum BinaryOp {
189    // Arithmetic
190    Add,
191    Subtract,
192    Multiply,
193    Divide,
194    Modulo,
195    Power,
196
197    // Comparison
198    Equal,
199    NotEqual,
200    Less,
201    LessEqual,
202    Greater,
203    GreaterEqual,
204
205    // Logical
206    And,
207    Or,
208}
209
210impl BinaryOp {
211    /// Gets the precedence of this operator (higher = tighter binding)
212    pub fn precedence(&self) -> u8 {
213        match self {
214            BinaryOp::Or => 1,
215            BinaryOp::And => 2,
216            BinaryOp::Equal
217            | BinaryOp::NotEqual
218            | BinaryOp::Less
219            | BinaryOp::LessEqual
220            | BinaryOp::Greater
221            | BinaryOp::GreaterEqual => 3,
222            BinaryOp::Add | BinaryOp::Subtract => 4,
223            BinaryOp::Multiply | BinaryOp::Divide | BinaryOp::Modulo => 5,
224            BinaryOp::Power => 6,
225        }
226    }
227
228    /// Checks if this operator is associative
229    pub fn is_associative(&self) -> bool {
230        matches!(
231            self,
232            BinaryOp::Add | BinaryOp::Multiply | BinaryOp::And | BinaryOp::Or
233        )
234    }
235
236    /// Checks if this operator is commutative
237    pub fn is_commutative(&self) -> bool {
238        matches!(
239            self,
240            BinaryOp::Add
241                | BinaryOp::Multiply
242                | BinaryOp::Equal
243                | BinaryOp::NotEqual
244                | BinaryOp::And
245                | BinaryOp::Or
246        )
247    }
248
249    /// Gets the result type for this operation
250    pub fn result_type(&self, left: Type, right: Type) -> Option<Type> {
251        match self {
252            BinaryOp::Add
253            | BinaryOp::Subtract
254            | BinaryOp::Multiply
255            | BinaryOp::Divide
256            | BinaryOp::Modulo
257            | BinaryOp::Power => left.common_type(&right),
258            BinaryOp::Equal
259            | BinaryOp::NotEqual
260            | BinaryOp::Less
261            | BinaryOp::LessEqual
262            | BinaryOp::Greater
263            | BinaryOp::GreaterEqual => {
264                if left.is_compatible(&right) {
265                    Some(Type::Bool)
266                } else {
267                    None
268                }
269            }
270            BinaryOp::And | BinaryOp::Or => {
271                if left == Type::Bool && right == Type::Bool {
272                    Some(Type::Bool)
273                } else {
274                    None
275                }
276            }
277        }
278    }
279}
280
281/// Unary operators
282#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
283pub enum UnaryOp {
284    /// Negation (-)
285    Negate,
286    /// Logical not (!)
287    Not,
288    /// Unary plus (+)
289    Plus,
290}
291
292impl UnaryOp {
293    /// Gets the result type for this operation
294    pub fn result_type(&self, operand: Type) -> Option<Type> {
295        match self {
296            UnaryOp::Negate | UnaryOp::Plus => {
297                if matches!(operand, Type::Number | Type::Raster) {
298                    Some(operand)
299                } else {
300                    None
301                }
302            }
303            UnaryOp::Not => {
304                if operand == Type::Bool {
305                    Some(Type::Bool)
306                } else {
307                    None
308                }
309            }
310        }
311    }
312}
313
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_type_compatibility() {
        assert!(Type::Number.is_compatible(&Type::Number));
        assert!(Type::Unknown.is_compatible(&Type::Number));
        assert!(!Type::Number.is_compatible(&Type::Bool));
    }

    #[test]
    fn test_common_type() {
        assert_eq!(Type::Number.common_type(&Type::Number), Some(Type::Number));
        assert_eq!(Type::Raster.common_type(&Type::Number), Some(Type::Raster));
        assert_eq!(Type::Number.common_type(&Type::Bool), None);
        // `Unknown` defers to the other operand.
        assert_eq!(Type::Unknown.common_type(&Type::Bool), Some(Type::Bool));
        assert_eq!(Type::Number.common_type(&Type::Unknown), Some(Type::Number));
    }

    #[test]
    fn test_expr_constant() {
        let expr = Expr::Number(42.0);
        assert!(expr.is_constant());

        let expr = Expr::Band(1);
        assert!(!expr.is_constant());
    }

    #[test]
    fn test_expr_type_annotations() {
        assert_eq!(Expr::Band(1).get_type(), Type::Raster);
        assert_eq!(Expr::Variable(String::from("x")).get_type(), Type::Unknown);

        // Leaf variants have no annotation slot, so `set_type` leaves them alone.
        let mut literal = Expr::Number(1.0);
        literal.set_type(Type::Bool);
        assert_eq!(literal.get_type(), Type::Number);

        let mut negated = Expr::Unary {
            op: UnaryOp::Negate,
            expr: Box::new(Expr::Band(1)),
            ty: Type::Unknown,
        };
        negated.set_type(Type::Raster);
        assert_eq!(negated.get_type(), Type::Raster);
    }

    #[test]
    fn test_expr_purity() {
        let sum = Expr::Binary {
            left: Box::new(Expr::Band(1)),
            op: BinaryOp::Add,
            right: Box::new(Expr::Number(2.0)),
            ty: Type::Raster,
        };
        assert!(sum.is_pure());

        let block = Expr::Block {
            statements: Vec::new(),
            result: None,
            ty: Type::Unknown,
        };
        assert!(!block.is_pure());
    }

    #[test]
    fn test_binary_op_precedence() {
        assert!(BinaryOp::Multiply.precedence() > BinaryOp::Add.precedence());
        assert!(BinaryOp::Power.precedence() > BinaryOp::Multiply.precedence());
        assert!(BinaryOp::Less.precedence() > BinaryOp::And.precedence());
        assert!(BinaryOp::And.precedence() > BinaryOp::Or.precedence());
    }

    #[test]
    fn test_binary_op_properties() {
        assert!(BinaryOp::Add.is_commutative());
        assert!(BinaryOp::Add.is_associative());
        assert!(!BinaryOp::Subtract.is_commutative());
        assert!(!BinaryOp::Divide.is_associative());
    }

    #[test]
    fn test_unary_op_result_type() {
        assert_eq!(UnaryOp::Negate.result_type(Type::Raster), Some(Type::Raster));
        assert_eq!(UnaryOp::Plus.result_type(Type::Number), Some(Type::Number));
        assert_eq!(UnaryOp::Negate.result_type(Type::Bool), None);
        assert_eq!(UnaryOp::Not.result_type(Type::Bool), Some(Type::Bool));
        assert_eq!(UnaryOp::Not.result_type(Type::Number), None);
    }
}