1use super::ast::{BinaryOp, Expr, Program, Statement, UnaryOp};
10
11#[cfg(not(feature = "std"))]
12use alloc::{boxed::Box, collections::BTreeMap as HashMap, string::String, vec::Vec};
13
14#[cfg(feature = "std")]
15use std::collections::HashMap;
16
/// How much work the [`Optimizer`] is allowed to do on an expression.
///
/// Each level runs everything the previous level runs plus one more pass.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum OptLevel {
    /// No passes run; input is returned unchanged.
    None,
    /// Constant folding only.
    Basic,
    /// Constant folding followed by algebraic simplification.
    Standard,
    /// Everything in `Standard`, followed by the common-subexpression pass.
    Aggressive,
}
29
30pub struct Optimizer {
32 level: OptLevel,
33}
34
35impl Default for Optimizer {
36 fn default() -> Self {
37 Self::new(OptLevel::Standard)
38 }
39}
40
41impl Optimizer {
42 pub fn new(level: OptLevel) -> Self {
44 Self { level }
45 }
46
47 pub fn optimize_program(&self, mut program: Program) -> Program {
49 if self.level == OptLevel::None {
50 return program;
51 }
52
53 program.statements = program
54 .statements
55 .into_iter()
56 .map(|stmt| self.optimize_statement(stmt))
57 .collect();
58
59 program
60 }
61
62 pub fn optimize_statement(&self, stmt: Statement) -> Statement {
64 match stmt {
65 Statement::VariableDecl { name, value } => Statement::VariableDecl {
66 name,
67 value: Box::new(self.optimize_expr(*value)),
68 },
69 Statement::FunctionDecl { name, params, body } => Statement::FunctionDecl {
70 name,
71 params,
72 body: Box::new(self.optimize_expr(*body)),
73 },
74 Statement::Return(expr) => Statement::Return(Box::new(self.optimize_expr(*expr))),
75 Statement::Expr(expr) => Statement::Expr(Box::new(self.optimize_expr(*expr))),
76 }
77 }
78
79 pub fn optimize_expr(&self, expr: Expr) -> Expr {
81 if self.level == OptLevel::None {
82 return expr;
83 }
84
85 let mut optimized = expr;
86
87 optimized = self.constant_fold(optimized);
89
90 if matches!(self.level, OptLevel::Standard | OptLevel::Aggressive) {
92 optimized = self.algebraic_simplify(optimized);
93 }
94
95 if self.level == OptLevel::Aggressive {
97 optimized = self.eliminate_common_subexpressions(optimized);
98 }
99
100 optimized
101 }
102
103 fn constant_fold(&self, expr: Expr) -> Expr {
105 match expr {
106 Expr::Binary {
107 left,
108 op,
109 right,
110 ty,
111 } => {
112 let left_opt = self.constant_fold(*left);
113 let right_opt = self.constant_fold(*right);
114
115 if let (Expr::Number(l), Expr::Number(r)) = (&left_opt, &right_opt) {
116 if let Some(result) = self.eval_const_binary(*l, op, *r) {
117 return Expr::Number(result);
118 }
119 }
120
121 Expr::Binary {
122 left: Box::new(left_opt),
123 op,
124 right: Box::new(right_opt),
125 ty,
126 }
127 }
128 Expr::Unary {
129 op,
130 expr: inner,
131 ty,
132 } => {
133 let inner_opt = self.constant_fold(*inner);
134
135 if let Expr::Number(n) = &inner_opt {
136 if let Some(result) = self.eval_const_unary(op, *n) {
137 return Expr::Number(result);
138 }
139 }
140
141 Expr::Unary {
142 op,
143 expr: Box::new(inner_opt),
144 ty,
145 }
146 }
147 Expr::Conditional {
148 condition,
149 then_expr,
150 else_expr,
151 ty,
152 } => {
153 let cond_opt = self.constant_fold(*condition);
154
155 if let Expr::Number(n) = &cond_opt {
157 if n.abs() > f64::EPSILON {
158 return self.constant_fold(*then_expr);
159 } else {
160 return self.constant_fold(*else_expr);
161 }
162 }
163
164 Expr::Conditional {
165 condition: Box::new(cond_opt),
166 then_expr: Box::new(self.constant_fold(*then_expr)),
167 else_expr: Box::new(self.constant_fold(*else_expr)),
168 ty,
169 }
170 }
171 Expr::Call { name, args, ty } => Expr::Call {
172 name,
173 args: args
174 .into_iter()
175 .map(|arg| self.constant_fold(arg))
176 .collect(),
177 ty,
178 },
179 Expr::Block {
180 statements,
181 result,
182 ty,
183 } => Expr::Block {
184 statements: statements
185 .into_iter()
186 .map(|stmt| self.optimize_statement(stmt))
187 .collect(),
188 result: result.map(|r| Box::new(self.constant_fold(*r))),
189 ty,
190 },
191 _ => expr,
192 }
193 }
194
195 fn eval_const_binary(&self, left: f64, op: BinaryOp, right: f64) -> Option<f64> {
197 let result = match op {
198 BinaryOp::Add => left + right,
199 BinaryOp::Subtract => left - right,
200 BinaryOp::Multiply => left * right,
201 BinaryOp::Divide => {
202 if right.abs() < f64::EPSILON {
203 return None;
204 }
205 left / right
206 }
207 BinaryOp::Modulo => left % right,
208 BinaryOp::Power => left.powf(right),
209 BinaryOp::Equal => {
210 if (left - right).abs() < f64::EPSILON {
211 1.0
212 } else {
213 0.0
214 }
215 }
216 BinaryOp::NotEqual => {
217 if (left - right).abs() >= f64::EPSILON {
218 1.0
219 } else {
220 0.0
221 }
222 }
223 BinaryOp::Less => {
224 if left < right {
225 1.0
226 } else {
227 0.0
228 }
229 }
230 BinaryOp::LessEqual => {
231 if left <= right {
232 1.0
233 } else {
234 0.0
235 }
236 }
237 BinaryOp::Greater => {
238 if left > right {
239 1.0
240 } else {
241 0.0
242 }
243 }
244 BinaryOp::GreaterEqual => {
245 if left >= right {
246 1.0
247 } else {
248 0.0
249 }
250 }
251 BinaryOp::And => {
252 if left != 0.0 && right != 0.0 {
253 1.0
254 } else {
255 0.0
256 }
257 }
258 BinaryOp::Or => {
259 if left != 0.0 || right != 0.0 {
260 1.0
261 } else {
262 0.0
263 }
264 }
265 };
266
267 Some(result)
268 }
269
270 fn eval_const_unary(&self, op: UnaryOp, operand: f64) -> Option<f64> {
272 let result = match op {
273 UnaryOp::Negate => -operand,
274 UnaryOp::Plus => operand,
275 UnaryOp::Not => {
276 if operand.abs() < f64::EPSILON {
277 1.0
278 } else {
279 0.0
280 }
281 }
282 };
283
284 Some(result)
285 }
286
287 fn algebraic_simplify(&self, expr: Expr) -> Expr {
289 match expr {
290 Expr::Binary {
291 left,
292 op,
293 right,
294 ty,
295 } => {
296 let left_opt = self.algebraic_simplify(*left);
297 let right_opt = self.algebraic_simplify(*right);
298
299 if op == BinaryOp::Add {
301 if let Expr::Number(n) = &right_opt {
302 if n.abs() < f64::EPSILON {
303 return left_opt;
304 }
305 }
306 if let Expr::Number(n) = &left_opt {
307 if n.abs() < f64::EPSILON {
308 return right_opt;
309 }
310 }
311 }
312
313 if op == BinaryOp::Subtract {
315 if let Expr::Number(n) = &right_opt {
316 if n.abs() < f64::EPSILON {
317 return left_opt;
318 }
319 }
320 }
321
322 if op == BinaryOp::Multiply {
324 if let Expr::Number(n) = &right_opt {
325 if n.abs() < f64::EPSILON {
326 return Expr::Number(0.0);
327 }
328 }
329 if let Expr::Number(n) = &left_opt {
330 if n.abs() < f64::EPSILON {
331 return Expr::Number(0.0);
332 }
333 }
334 }
335
336 if op == BinaryOp::Multiply {
338 if let Expr::Number(n) = &right_opt {
339 if (n - 1.0).abs() < f64::EPSILON {
340 return left_opt;
341 }
342 }
343 if let Expr::Number(n) = &left_opt {
344 if (n - 1.0).abs() < f64::EPSILON {
345 return right_opt;
346 }
347 }
348 }
349
350 if op == BinaryOp::Divide {
352 if let Expr::Number(n) = &right_opt {
353 if (n - 1.0).abs() < f64::EPSILON {
354 return left_opt;
355 }
356 }
357 }
358
359 if op == BinaryOp::Power {
361 if let Expr::Number(n) = &right_opt {
362 if n.abs() < f64::EPSILON {
363 return Expr::Number(1.0);
364 }
365 }
366 }
367
368 if op == BinaryOp::Power {
370 if let Expr::Number(n) = &right_opt {
371 if (n - 1.0).abs() < f64::EPSILON {
372 return left_opt;
373 }
374 }
375 }
376
377 Expr::Binary {
378 left: Box::new(left_opt),
379 op,
380 right: Box::new(right_opt),
381 ty,
382 }
383 }
384 Expr::Unary {
385 op,
386 expr: inner,
387 ty,
388 } => {
389 let inner_opt = self.algebraic_simplify(*inner);
390
391 if op == UnaryOp::Negate {
393 if let Expr::Unary {
394 op: UnaryOp::Negate,
395 expr: double_neg,
396 ..
397 } = &inner_opt
398 {
399 return *double_neg.clone();
400 }
401 }
402
403 if op == UnaryOp::Plus {
405 return inner_opt;
406 }
407
408 Expr::Unary {
409 op,
410 expr: Box::new(inner_opt),
411 ty,
412 }
413 }
414 Expr::Conditional {
415 condition,
416 then_expr,
417 else_expr,
418 ty,
419 } => Expr::Conditional {
420 condition: Box::new(self.algebraic_simplify(*condition)),
421 then_expr: Box::new(self.algebraic_simplify(*then_expr)),
422 else_expr: Box::new(self.algebraic_simplify(*else_expr)),
423 ty,
424 },
425 Expr::Call { name, args, ty } => Expr::Call {
426 name,
427 args: args
428 .into_iter()
429 .map(|arg| self.algebraic_simplify(arg))
430 .collect(),
431 ty,
432 },
433 Expr::Block {
434 statements,
435 result,
436 ty,
437 } => Expr::Block {
438 statements: statements
439 .into_iter()
440 .map(|stmt| self.optimize_statement(stmt))
441 .collect(),
442 result: result.map(|r| Box::new(self.algebraic_simplify(*r))),
443 ty,
444 },
445 _ => expr,
446 }
447 }
448
449 fn eliminate_common_subexpressions(&self, expr: Expr) -> Expr {
451 let mut seen: HashMap<String, usize> = HashMap::new();
452 self.cse_pass(&expr, &mut seen);
453 expr
456 }
457
458 fn cse_pass(&self, expr: &Expr, seen: &mut HashMap<String, usize>) {
459 match expr {
460 Expr::Binary { left, right, .. } => {
461 self.cse_pass(left, seen);
462 self.cse_pass(right, seen);
463 let key = format!("{:?}", expr);
464 *seen.entry(key).or_insert(0) += 1;
465 }
466 Expr::Unary { expr: inner, .. } => {
467 self.cse_pass(inner, seen);
468 }
469 Expr::Call { args, .. } => {
470 for arg in args {
471 self.cse_pass(arg, seen);
472 }
473 }
474 Expr::Conditional {
475 condition,
476 then_expr,
477 else_expr,
478 ..
479 } => {
480 self.cse_pass(condition, seen);
481 self.cse_pass(then_expr, seen);
482 self.cse_pass(else_expr, seen);
483 }
484 _ => {}
485 }
486 }
487}
488
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::Type;

    /// Boxed numeric literal.
    fn num(value: f64) -> Box<Expr> {
        Box::new(Expr::Number(value))
    }

    /// Boxed reference to band 1.
    fn band_one() -> Box<Expr> {
        Box::new(Expr::Band(1))
    }

    /// Builds a binary node from already boxed operands.
    fn binary(left: Box<Expr>, op: BinaryOp, right: Box<Expr>, ty: Type) -> Expr {
        Expr::Binary {
            left,
            op,
            right,
            ty,
        }
    }

    /// Builds a raster-typed unary node around `inner`.
    fn raster_unary(op: UnaryOp, inner: Box<Expr>) -> Expr {
        Expr::Unary {
            op,
            expr: inner,
            ty: Type::Raster,
        }
    }

    /// Optimizes `expr` with a fresh optimizer at `level`.
    fn run(level: OptLevel, expr: Expr) -> Expr {
        Optimizer::new(level).optimize_expr(expr)
    }

    /// True when `expr` is a numeric literal within `1e-10` of `expected`.
    fn is_number_close_to(expr: &Expr, expected: f64) -> bool {
        matches!(expr, Expr::Number(n) if (*n - expected).abs() < 1e-10)
    }

    #[test]
    fn test_constant_fold_add() {
        let expr = binary(num(2.0), BinaryOp::Add, num(3.0), Type::Number);

        let result = run(OptLevel::Basic, expr);

        assert!(is_number_close_to(&result, 5.0));
    }

    #[test]
    fn test_constant_fold_nested() {
        let product = binary(num(2.0), BinaryOp::Multiply, num(3.0), Type::Number);
        let expr = binary(Box::new(product), BinaryOp::Add, num(4.0), Type::Number);

        let result = run(OptLevel::Basic, expr);

        assert!(is_number_close_to(&result, 10.0));
    }

    #[test]
    fn test_algebraic_simplify_add_zero() {
        let expr = binary(band_one(), BinaryOp::Add, num(0.0), Type::Raster);

        let result = run(OptLevel::Standard, expr);

        assert!(matches!(result, Expr::Band(1)));
    }

    #[test]
    fn test_algebraic_simplify_mul_one() {
        let expr = binary(band_one(), BinaryOp::Multiply, num(1.0), Type::Raster);

        let result = run(OptLevel::Standard, expr);

        assert!(matches!(result, Expr::Band(1)));
    }

    #[test]
    fn test_algebraic_simplify_mul_zero() {
        let expr = binary(band_one(), BinaryOp::Multiply, num(0.0), Type::Raster);

        let result = run(OptLevel::Standard, expr);

        assert!(is_number_close_to(&result, 0.0));
    }

    #[test]
    fn test_double_negation() {
        let inner = raster_unary(UnaryOp::Negate, band_one());
        let expr = raster_unary(UnaryOp::Negate, Box::new(inner));

        let result = run(OptLevel::Standard, expr);

        assert!(matches!(result, Expr::Band(1)));
    }

    #[test]
    fn test_unary_plus() {
        let expr = raster_unary(UnaryOp::Plus, band_one());

        let result = run(OptLevel::Standard, expr);

        assert!(matches!(result, Expr::Band(1)));
    }
}