Skip to main content

oxigdal_algorithms/dsl/
optimizer.rs

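//! Expression optimizer for the DSL.
//!
//! This module provides the following optimization passes:
//! - Constant folding
//! - Algebraic simplifications
//! - A common-subexpression analysis pass (it currently counts repeated binary
//!   subexpressions but does not rewrite the expression)
//!
//! Dead code elimination is not implemented yet.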
8
9use super::ast::{BinaryOp, Expr, Program, Statement, UnaryOp};
10
11#[cfg(not(feature = "std"))]
12use alloc::{boxed::Box, collections::BTreeMap as HashMap, string::String, vec::Vec};
13
14#[cfg(feature = "std")]
15use std::collections::HashMap;
16
/// How much work the [`Optimizer`] performs.
///
/// Each level runs every pass of the level before it, plus its own.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum OptLevel {
    /// Leave expressions exactly as written.
    None,
    /// Fold operations whose operands are all numeric literals.
    Basic,
    /// `Basic` plus algebraic identities such as `x + 0 => x` and `x * 1 => x`.
    Standard,
    /// `Standard` plus the common-subexpression pass.
    ///
    /// NOTE(review): in this module that pass is analysis-only (it does not
    /// rewrite the expression), and no dead code elimination is performed.
    Aggressive,
}
29
30/// Optimizer for DSL programs
31pub struct Optimizer {
32    level: OptLevel,
33}
34
35impl Default for Optimizer {
36    fn default() -> Self {
37        Self::new(OptLevel::Standard)
38    }
39}
40
41impl Optimizer {
42    /// Creates a new optimizer with the given optimization level
43    pub fn new(level: OptLevel) -> Self {
44        Self { level }
45    }
46
47    /// Optimizes a program
48    pub fn optimize_program(&self, mut program: Program) -> Program {
49        if self.level == OptLevel::None {
50            return program;
51        }
52
53        program.statements = program
54            .statements
55            .into_iter()
56            .map(|stmt| self.optimize_statement(stmt))
57            .collect();
58
59        program
60    }
61
62    /// Optimizes a single statement
63    pub fn optimize_statement(&self, stmt: Statement) -> Statement {
64        match stmt {
65            Statement::VariableDecl { name, value } => Statement::VariableDecl {
66                name,
67                value: Box::new(self.optimize_expr(*value)),
68            },
69            Statement::FunctionDecl { name, params, body } => Statement::FunctionDecl {
70                name,
71                params,
72                body: Box::new(self.optimize_expr(*body)),
73            },
74            Statement::Return(expr) => Statement::Return(Box::new(self.optimize_expr(*expr))),
75            Statement::Expr(expr) => Statement::Expr(Box::new(self.optimize_expr(*expr))),
76        }
77    }
78
79    /// Optimizes an expression
80    pub fn optimize_expr(&self, expr: Expr) -> Expr {
81        if self.level == OptLevel::None {
82            return expr;
83        }
84
85        let mut optimized = expr;
86
87        // Apply constant folding
88        optimized = self.constant_fold(optimized);
89
90        // Apply algebraic simplifications
91        if matches!(self.level, OptLevel::Standard | OptLevel::Aggressive) {
92            optimized = self.algebraic_simplify(optimized);
93        }
94
95        // Apply common subexpression elimination
96        if self.level == OptLevel::Aggressive {
97            optimized = self.eliminate_common_subexpressions(optimized);
98        }
99
100        optimized
101    }
102
103    /// Performs constant folding
104    fn constant_fold(&self, expr: Expr) -> Expr {
105        match expr {
106            Expr::Binary {
107                left,
108                op,
109                right,
110                ty,
111            } => {
112                let left_opt = self.constant_fold(*left);
113                let right_opt = self.constant_fold(*right);
114
115                if let (Expr::Number(l), Expr::Number(r)) = (&left_opt, &right_opt) {
116                    if let Some(result) = self.eval_const_binary(*l, op, *r) {
117                        return Expr::Number(result);
118                    }
119                }
120
121                Expr::Binary {
122                    left: Box::new(left_opt),
123                    op,
124                    right: Box::new(right_opt),
125                    ty,
126                }
127            }
128            Expr::Unary {
129                op,
130                expr: inner,
131                ty,
132            } => {
133                let inner_opt = self.constant_fold(*inner);
134
135                if let Expr::Number(n) = &inner_opt {
136                    if let Some(result) = self.eval_const_unary(op, *n) {
137                        return Expr::Number(result);
138                    }
139                }
140
141                Expr::Unary {
142                    op,
143                    expr: Box::new(inner_opt),
144                    ty,
145                }
146            }
147            Expr::Conditional {
148                condition,
149                then_expr,
150                else_expr,
151                ty,
152            } => {
153                let cond_opt = self.constant_fold(*condition);
154
155                // If condition is constant, return only the taken branch
156                if let Expr::Number(n) = &cond_opt {
157                    if n.abs() > f64::EPSILON {
158                        return self.constant_fold(*then_expr);
159                    } else {
160                        return self.constant_fold(*else_expr);
161                    }
162                }
163
164                Expr::Conditional {
165                    condition: Box::new(cond_opt),
166                    then_expr: Box::new(self.constant_fold(*then_expr)),
167                    else_expr: Box::new(self.constant_fold(*else_expr)),
168                    ty,
169                }
170            }
171            Expr::Call { name, args, ty } => Expr::Call {
172                name,
173                args: args
174                    .into_iter()
175                    .map(|arg| self.constant_fold(arg))
176                    .collect(),
177                ty,
178            },
179            Expr::Block {
180                statements,
181                result,
182                ty,
183            } => Expr::Block {
184                statements: statements
185                    .into_iter()
186                    .map(|stmt| self.optimize_statement(stmt))
187                    .collect(),
188                result: result.map(|r| Box::new(self.constant_fold(*r))),
189                ty,
190            },
191            _ => expr,
192        }
193    }
194
195    /// Evaluates a constant binary operation
196    fn eval_const_binary(&self, left: f64, op: BinaryOp, right: f64) -> Option<f64> {
197        let result = match op {
198            BinaryOp::Add => left + right,
199            BinaryOp::Subtract => left - right,
200            BinaryOp::Multiply => left * right,
201            BinaryOp::Divide => {
202                if right.abs() < f64::EPSILON {
203                    return None;
204                }
205                left / right
206            }
207            BinaryOp::Modulo => left % right,
208            BinaryOp::Power => left.powf(right),
209            BinaryOp::Equal => {
210                if (left - right).abs() < f64::EPSILON {
211                    1.0
212                } else {
213                    0.0
214                }
215            }
216            BinaryOp::NotEqual => {
217                if (left - right).abs() >= f64::EPSILON {
218                    1.0
219                } else {
220                    0.0
221                }
222            }
223            BinaryOp::Less => {
224                if left < right {
225                    1.0
226                } else {
227                    0.0
228                }
229            }
230            BinaryOp::LessEqual => {
231                if left <= right {
232                    1.0
233                } else {
234                    0.0
235                }
236            }
237            BinaryOp::Greater => {
238                if left > right {
239                    1.0
240                } else {
241                    0.0
242                }
243            }
244            BinaryOp::GreaterEqual => {
245                if left >= right {
246                    1.0
247                } else {
248                    0.0
249                }
250            }
251            BinaryOp::And => {
252                if left != 0.0 && right != 0.0 {
253                    1.0
254                } else {
255                    0.0
256                }
257            }
258            BinaryOp::Or => {
259                if left != 0.0 || right != 0.0 {
260                    1.0
261                } else {
262                    0.0
263                }
264            }
265        };
266
267        Some(result)
268    }
269
270    /// Evaluates a constant unary operation
271    fn eval_const_unary(&self, op: UnaryOp, operand: f64) -> Option<f64> {
272        let result = match op {
273            UnaryOp::Negate => -operand,
274            UnaryOp::Plus => operand,
275            UnaryOp::Not => {
276                if operand.abs() < f64::EPSILON {
277                    1.0
278                } else {
279                    0.0
280                }
281            }
282        };
283
284        Some(result)
285    }
286
287    /// Performs algebraic simplifications
288    fn algebraic_simplify(&self, expr: Expr) -> Expr {
289        match expr {
290            Expr::Binary {
291                left,
292                op,
293                right,
294                ty,
295            } => {
296                let left_opt = self.algebraic_simplify(*left);
297                let right_opt = self.algebraic_simplify(*right);
298
299                // x + 0 = x
300                if op == BinaryOp::Add {
301                    if let Expr::Number(n) = &right_opt {
302                        if n.abs() < f64::EPSILON {
303                            return left_opt;
304                        }
305                    }
306                    if let Expr::Number(n) = &left_opt {
307                        if n.abs() < f64::EPSILON {
308                            return right_opt;
309                        }
310                    }
311                }
312
313                // x - 0 = x
314                if op == BinaryOp::Subtract {
315                    if let Expr::Number(n) = &right_opt {
316                        if n.abs() < f64::EPSILON {
317                            return left_opt;
318                        }
319                    }
320                }
321
322                // x * 0 = 0
323                if op == BinaryOp::Multiply {
324                    if let Expr::Number(n) = &right_opt {
325                        if n.abs() < f64::EPSILON {
326                            return Expr::Number(0.0);
327                        }
328                    }
329                    if let Expr::Number(n) = &left_opt {
330                        if n.abs() < f64::EPSILON {
331                            return Expr::Number(0.0);
332                        }
333                    }
334                }
335
336                // x * 1 = x
337                if op == BinaryOp::Multiply {
338                    if let Expr::Number(n) = &right_opt {
339                        if (n - 1.0).abs() < f64::EPSILON {
340                            return left_opt;
341                        }
342                    }
343                    if let Expr::Number(n) = &left_opt {
344                        if (n - 1.0).abs() < f64::EPSILON {
345                            return right_opt;
346                        }
347                    }
348                }
349
350                // x / 1 = x
351                if op == BinaryOp::Divide {
352                    if let Expr::Number(n) = &right_opt {
353                        if (n - 1.0).abs() < f64::EPSILON {
354                            return left_opt;
355                        }
356                    }
357                }
358
359                // x ^ 0 = 1
360                if op == BinaryOp::Power {
361                    if let Expr::Number(n) = &right_opt {
362                        if n.abs() < f64::EPSILON {
363                            return Expr::Number(1.0);
364                        }
365                    }
366                }
367
368                // x ^ 1 = x
369                if op == BinaryOp::Power {
370                    if let Expr::Number(n) = &right_opt {
371                        if (n - 1.0).abs() < f64::EPSILON {
372                            return left_opt;
373                        }
374                    }
375                }
376
377                Expr::Binary {
378                    left: Box::new(left_opt),
379                    op,
380                    right: Box::new(right_opt),
381                    ty,
382                }
383            }
384            Expr::Unary {
385                op,
386                expr: inner,
387                ty,
388            } => {
389                let inner_opt = self.algebraic_simplify(*inner);
390
391                // --x = x
392                if op == UnaryOp::Negate {
393                    if let Expr::Unary {
394                        op: UnaryOp::Negate,
395                        expr: double_neg,
396                        ..
397                    } = &inner_opt
398                    {
399                        return *double_neg.clone();
400                    }
401                }
402
403                // +x = x
404                if op == UnaryOp::Plus {
405                    return inner_opt;
406                }
407
408                Expr::Unary {
409                    op,
410                    expr: Box::new(inner_opt),
411                    ty,
412                }
413            }
414            Expr::Conditional {
415                condition,
416                then_expr,
417                else_expr,
418                ty,
419            } => Expr::Conditional {
420                condition: Box::new(self.algebraic_simplify(*condition)),
421                then_expr: Box::new(self.algebraic_simplify(*then_expr)),
422                else_expr: Box::new(self.algebraic_simplify(*else_expr)),
423                ty,
424            },
425            Expr::Call { name, args, ty } => Expr::Call {
426                name,
427                args: args
428                    .into_iter()
429                    .map(|arg| self.algebraic_simplify(arg))
430                    .collect(),
431                ty,
432            },
433            Expr::Block {
434                statements,
435                result,
436                ty,
437            } => Expr::Block {
438                statements: statements
439                    .into_iter()
440                    .map(|stmt| self.optimize_statement(stmt))
441                    .collect(),
442                result: result.map(|r| Box::new(self.algebraic_simplify(*r))),
443                ty,
444            },
445            _ => expr,
446        }
447    }
448
449    /// Eliminates common subexpressions
450    fn eliminate_common_subexpressions(&self, expr: Expr) -> Expr {
451        let mut seen: HashMap<String, usize> = HashMap::new();
452        self.cse_pass(&expr, &mut seen);
453        // Note: Full CSE implementation would require more complex analysis
454        // This is a simplified version that just counts occurrences
455        expr
456    }
457
458    fn cse_pass(&self, expr: &Expr, seen: &mut HashMap<String, usize>) {
459        match expr {
460            Expr::Binary { left, right, .. } => {
461                self.cse_pass(left, seen);
462                self.cse_pass(right, seen);
463                let key = format!("{:?}", expr);
464                *seen.entry(key).or_insert(0) += 1;
465            }
466            Expr::Unary { expr: inner, .. } => {
467                self.cse_pass(inner, seen);
468            }
469            Expr::Call { args, .. } => {
470                for arg in args {
471                    self.cse_pass(arg, seen);
472                }
473            }
474            Expr::Conditional {
475                condition,
476                then_expr,
477                else_expr,
478                ..
479            } => {
480                self.cse_pass(condition, seen);
481                self.cse_pass(then_expr, seen);
482                self.cse_pass(else_expr, seen);
483            }
484            _ => {}
485        }
486    }
487}
488
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::Type;

    /// Builds the binary expression `left <op> right` with result type `ty`.
    fn binary(left: Expr, op: BinaryOp, right: Expr, ty: Type) -> Expr {
        Expr::Binary {
            left: Box::new(left),
            op,
            right: Box::new(right),
            ty,
        }
    }

    /// Builds the unary expression `<op> inner` with result type `ty`.
    fn unary(op: UnaryOp, inner: Expr, ty: Type) -> Expr {
        Expr::Unary {
            op,
            expr: Box::new(inner),
            ty,
        }
    }

    /// Builds `if <condition> then band 1 else band 2` with a literal condition.
    fn band_select(condition: f64) -> Expr {
        Expr::Conditional {
            condition: Box::new(Expr::Number(condition)),
            then_expr: Box::new(Expr::Band(1)),
            else_expr: Box::new(Expr::Band(2)),
            ty: Type::Raster,
        }
    }

    /// Optimizes `expr` with a fresh optimizer at `level`.
    fn optimize(level: OptLevel, expr: Expr) -> Expr {
        Optimizer::new(level).optimize_expr(expr)
    }

    /// Returns `true` when `expr` is the numeric literal `expected` (within 1e-10).
    fn is_number(expr: &Expr, expected: f64) -> bool {
        matches!(expr, Expr::Number(n) if (n - expected).abs() < 1e-10)
    }

    #[test]
    fn test_constant_fold_add() {
        let expr = binary(Expr::Number(2.0), BinaryOp::Add, Expr::Number(3.0), Type::Number);
        assert!(is_number(&optimize(OptLevel::Basic, expr), 5.0));
    }

    #[test]
    fn test_constant_fold_nested() {
        // (2 * 3) + 4
        let product = binary(Expr::Number(2.0), BinaryOp::Multiply, Expr::Number(3.0), Type::Number);
        let expr = binary(product, BinaryOp::Add, Expr::Number(4.0), Type::Number);
        assert!(is_number(&optimize(OptLevel::Basic, expr), 10.0));
    }

    #[test]
    fn test_none_level_leaves_expression_untouched() {
        let expr = binary(Expr::Number(2.0), BinaryOp::Add, Expr::Number(3.0), Type::Number);
        let result = optimize(OptLevel::None, expr);
        assert!(matches!(result, Expr::Binary { op: BinaryOp::Add, .. }));
    }

    #[test]
    fn test_basic_level_skips_algebraic_simplification() {
        let expr = binary(Expr::Band(1), BinaryOp::Add, Expr::Number(0.0), Type::Raster);
        let result = optimize(OptLevel::Basic, expr);
        assert!(matches!(result, Expr::Binary { op: BinaryOp::Add, .. }));
    }

    #[test]
    fn test_divide_by_zero_is_not_folded() {
        let expr = binary(Expr::Number(1.0), BinaryOp::Divide, Expr::Number(0.0), Type::Number);
        let result = optimize(OptLevel::Basic, expr);
        assert!(matches!(result, Expr::Binary { op: BinaryOp::Divide, .. }));
    }

    #[test]
    fn test_constant_condition_selects_branch() {
        assert!(matches!(optimize(OptLevel::Basic, band_select(1.0)), Expr::Band(1)));
        assert!(matches!(optimize(OptLevel::Basic, band_select(0.0)), Expr::Band(2)));
    }

    #[test]
    fn test_algebraic_simplify_add_zero() {
        let expr = binary(Expr::Band(1), BinaryOp::Add, Expr::Number(0.0), Type::Raster);
        assert!(matches!(optimize(OptLevel::Standard, expr), Expr::Band(1)));
    }

    #[test]
    fn test_algebraic_simplify_sub_zero() {
        let expr = binary(Expr::Band(1), BinaryOp::Subtract, Expr::Number(0.0), Type::Raster);
        assert!(matches!(optimize(OptLevel::Standard, expr), Expr::Band(1)));
    }

    #[test]
    fn test_algebraic_simplify_mul_one() {
        let expr = binary(Expr::Band(1), BinaryOp::Multiply, Expr::Number(1.0), Type::Raster);
        assert!(matches!(optimize(OptLevel::Standard, expr), Expr::Band(1)));
    }

    #[test]
    fn test_algebraic_simplify_mul_zero() {
        let expr = binary(Expr::Band(1), BinaryOp::Multiply, Expr::Number(0.0), Type::Raster);
        assert!(is_number(&optimize(OptLevel::Standard, expr), 0.0));
    }

    #[test]
    fn test_algebraic_simplify_div_one() {
        let expr = binary(Expr::Band(1), BinaryOp::Divide, Expr::Number(1.0), Type::Raster);
        assert!(matches!(optimize(OptLevel::Standard, expr), Expr::Band(1)));
    }

    #[test]
    fn test_algebraic_simplify_pow_zero() {
        let expr = binary(Expr::Band(1), BinaryOp::Power, Expr::Number(0.0), Type::Raster);
        assert!(is_number(&optimize(OptLevel::Standard, expr), 1.0));
    }

    #[test]
    fn test_algebraic_simplify_pow_one() {
        let expr = binary(Expr::Band(1), BinaryOp::Power, Expr::Number(1.0), Type::Raster);
        assert!(matches!(optimize(OptLevel::Standard, expr), Expr::Band(1)));
    }

    #[test]
    fn test_double_negation() {
        let inner = unary(UnaryOp::Negate, Expr::Band(1), Type::Raster);
        let expr = unary(UnaryOp::Negate, inner, Type::Raster);
        assert!(matches!(optimize(OptLevel::Standard, expr), Expr::Band(1)));
    }

    #[test]
    fn test_unary_plus() {
        let expr = unary(UnaryOp::Plus, Expr::Band(1), Type::Raster);
        assert!(matches!(optimize(OptLevel::Standard, expr), Expr::Band(1)));
    }

    #[test]
    fn test_default_uses_standard_level() {
        let expr = binary(Expr::Band(1), BinaryOp::Multiply, Expr::Number(1.0), Type::Raster);
        assert!(matches!(Optimizer::default().optimize_expr(expr), Expr::Band(1)));
    }

    #[test]
    fn test_optimize_statement_folds_expression() {
        let expr = binary(Expr::Number(2.0), BinaryOp::Add, Expr::Number(3.0), Type::Number);
        let stmt = Statement::Expr(Box::new(expr));
        match Optimizer::new(OptLevel::Basic).optimize_statement(stmt) {
            Statement::Expr(folded) => assert!(is_number(&folded, 5.0)),
            _ => panic!("optimize_statement must preserve the statement kind"),
        }
    }
}