aether/
optimizer.rs

1// src/optimizer.rs
2//! 代码优化器 - 包含尾递归优化、常量折叠等
3
4use crate::ast::{BinOp, Expr, Program, Stmt, UnaryOp};
5
/// Configuration for the AST optimizer: one on/off switch per pass.
///
/// The switches are independent of each other; `optimize_program` consults
/// them to decide which passes to run.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Optimizer {
    /// Enables rewriting of tail-recursive functions into loops.
    pub tail_recursion: bool,
    /// Enables evaluation of constant sub-expressions ahead of time.
    pub constant_folding: bool,
    /// Enables removal of code that can never execute.
    pub dead_code_elimination: bool,
}
15
16impl Optimizer {
17    /// 创建新的优化器,所有优化默认启用
18    pub fn new() -> Self {
19        Optimizer {
20            tail_recursion: true,
21            constant_folding: true,
22            dead_code_elimination: true,
23        }
24    }
25
26    /// 优化整个程序
27    pub fn optimize_program(&self, program: &Program) -> Program {
28        let mut optimized = program.clone();
29
30        // 常量折叠
31        if self.constant_folding {
32            optimized = self.fold_constants(optimized);
33        }
34
35        // 死代码消除
36        if self.dead_code_elimination {
37            optimized = self.eliminate_dead_code(optimized);
38        }
39
40        // 尾递归优化
41        if self.tail_recursion {
42            optimized = self.optimize_tail_recursion(optimized);
43        }
44
45        optimized
46    }
47
48    /// 常量折叠优化
49    fn fold_constants(&self, program: Program) -> Program {
50        program
51            .into_iter()
52            .map(|stmt| self.fold_stmt(stmt))
53            .collect()
54    }
55
56    /// 折叠语句中的常量
57    fn fold_stmt(&self, stmt: Stmt) -> Stmt {
58        match stmt {
59            Stmt::Set { name, value } => Stmt::Set {
60                name,
61                value: self.fold_expr(value),
62            },
63            Stmt::FuncDef { name, params, body } => Stmt::FuncDef {
64                name,
65                params,
66                body: body.into_iter().map(|s| self.fold_stmt(s)).collect(),
67            },
68            Stmt::GeneratorDef { name, params, body } => Stmt::GeneratorDef {
69                name,
70                params,
71                body: body.into_iter().map(|s| self.fold_stmt(s)).collect(),
72            },
73            Stmt::Return(expr) => Stmt::Return(self.fold_expr(expr)),
74            Stmt::Yield(expr) => Stmt::Yield(self.fold_expr(expr)),
75            Stmt::While { condition, body } => Stmt::While {
76                condition: self.fold_expr(condition),
77                body: body.into_iter().map(|s| self.fold_stmt(s)).collect(),
78            },
79            Stmt::For {
80                var,
81                iterable,
82                body,
83            } => Stmt::For {
84                var,
85                iterable: self.fold_expr(iterable),
86                body: body.into_iter().map(|s| self.fold_stmt(s)).collect(),
87            },
88            Stmt::ForIndexed {
89                index_var,
90                value_var,
91                iterable,
92                body,
93            } => Stmt::ForIndexed {
94                index_var,
95                value_var,
96                iterable: self.fold_expr(iterable),
97                body: body.into_iter().map(|s| self.fold_stmt(s)).collect(),
98            },
99            Stmt::Expression(expr) => Stmt::Expression(self.fold_expr(expr)),
100            other => other,
101        }
102    }
103
104    /// 折叠表达式中的常量
105    fn fold_expr(&self, expr: Expr) -> Expr {
106        match expr {
107            // 二元运算常量折叠
108            Expr::Binary { left, op, right } => {
109                let left = self.fold_expr(*left);
110                let right = self.fold_expr(*right);
111
112                // 如果两边都是常量,直接计算结果
113                if let (Expr::Number(l), Expr::Number(r)) = (&left, &right) {
114                    if let Some(result) = Self::eval_const_binary(*l, &op, *r) {
115                        return Expr::Number(result);
116                    }
117                }
118
119                Expr::Binary {
120                    left: Box::new(left),
121                    op,
122                    right: Box::new(right),
123                }
124            }
125
126            // 一元运算常量折叠
127            Expr::Unary { op, expr } => {
128                let expr = self.fold_expr(*expr);
129
130                if let Expr::Number(n) = expr {
131                    match op {
132                        UnaryOp::Minus => return Expr::Number(-n),
133                        UnaryOp::Not => return Expr::Boolean(n == 0.0),
134                    }
135                }
136
137                if let (UnaryOp::Not, Expr::Boolean(b)) = (&op, &expr) {
138                    return Expr::Boolean(!b);
139                }
140
141                Expr::Unary {
142                    op,
143                    expr: Box::new(expr),
144                }
145            }
146
147            // 递归处理其他表达式
148            Expr::Call { func, args } => Expr::Call {
149                func: Box::new(self.fold_expr(*func)),
150                args: args.into_iter().map(|e| self.fold_expr(e)).collect(),
151            },
152
153            Expr::Array(elements) => {
154                Expr::Array(elements.into_iter().map(|e| self.fold_expr(e)).collect())
155            }
156
157            Expr::Index { object, index } => Expr::Index {
158                object: Box::new(self.fold_expr(*object)),
159                index: Box::new(self.fold_expr(*index)),
160            },
161
162            other => other,
163        }
164    }
165
166    /// 计算常量二元运算
167    fn eval_const_binary(left: f64, op: &BinOp, right: f64) -> Option<f64> {
168        match op {
169            BinOp::Add => Some(left + right),
170            BinOp::Subtract => Some(left - right),
171            BinOp::Multiply => Some(left * right),
172            BinOp::Divide if right != 0.0 => Some(left / right),
173            BinOp::Modulo if right != 0.0 => Some(left % right),
174            _ => None,
175        }
176    }
177
178    /// 死代码消除
179    fn eliminate_dead_code(&self, program: Program) -> Program {
180        program
181            .into_iter()
182            .filter_map(|stmt| self.eliminate_dead_stmt(stmt))
183            .collect()
184    }
185
186    /// 消除死语句
187    fn eliminate_dead_stmt(&self, stmt: Stmt) -> Option<Stmt> {
188        match stmt {
189            // While循环的常量条件
190            Stmt::While { condition, body } => {
191                if let Expr::Boolean(false) = condition {
192                    // 永远不执行的循环可以删除
193                    return None;
194                }
195
196                Some(Stmt::While {
197                    condition,
198                    body: body
199                        .into_iter()
200                        .filter_map(|s| self.eliminate_dead_stmt(s))
201                        .collect(),
202                })
203            }
204
205            // 函数定义递归处理
206            Stmt::FuncDef { name, params, body } => Some(Stmt::FuncDef {
207                name,
208                params,
209                body: body
210                    .into_iter()
211                    .filter_map(|s| self.eliminate_dead_stmt(s))
212                    .collect(),
213            }),
214
215            Stmt::GeneratorDef { name, params, body } => Some(Stmt::GeneratorDef {
216                name,
217                params,
218                body: body
219                    .into_iter()
220                    .filter_map(|s| self.eliminate_dead_stmt(s))
221                    .collect(),
222            }),
223
224            // 表达式语句中可能包含If表达式
225            Stmt::Expression(expr) => Some(Stmt::Expression(self.eliminate_dead_expr(expr))),
226
227            other => Some(other),
228        }
229    }
230
231    /// 消除表达式中的死代码
232    fn eliminate_dead_expr(&self, expr: Expr) -> Expr {
233        match expr {
234            Expr::If {
235                condition,
236                then_branch,
237                elif_branches,
238                else_branch,
239            } => {
240                if let Expr::Boolean(true) = *condition {
241                    // 条件永远为真,简化为then分支
242                    return Expr::If {
243                        condition: Box::new(Expr::Boolean(true)),
244                        then_branch,
245                        elif_branches: vec![],
246                        else_branch: None,
247                    };
248                }
249
250                if let Expr::Boolean(false) = *condition {
251                    // 条件永远为假,检查elif或else
252                    if let Some(else_body) = else_branch {
253                        // 简化为else块
254                        return Expr::If {
255                            condition: Box::new(Expr::Boolean(true)),
256                            then_branch: else_body,
257                            elif_branches: vec![],
258                            else_branch: None,
259                        };
260                    }
261                    // 没有else,返回null
262                    return Expr::Null;
263                }
264
265                // 递归处理分支
266                Expr::If {
267                    condition,
268                    then_branch: then_branch
269                        .into_iter()
270                        .filter_map(|s| self.eliminate_dead_stmt(s))
271                        .collect(),
272                    elif_branches: elif_branches
273                        .into_iter()
274                        .map(|(c, b)| {
275                            (
276                                self.eliminate_dead_expr(c),
277                                b.into_iter()
278                                    .filter_map(|s| self.eliminate_dead_stmt(s))
279                                    .collect(),
280                            )
281                        })
282                        .collect(),
283                    else_branch: else_branch.map(|b| {
284                        b.into_iter()
285                            .filter_map(|s| self.eliminate_dead_stmt(s))
286                            .collect()
287                    }),
288                }
289            }
290            other => other,
291        }
292    }
293
294    /// 尾递归优化
295    fn optimize_tail_recursion(&self, program: Program) -> Program {
296        program
297            .into_iter()
298            .map(|stmt| self.optimize_tail_recursive_stmt(stmt))
299            .collect()
300    }
301
302    /// 优化尾递归语句
303    fn optimize_tail_recursive_stmt(&self, stmt: Stmt) -> Stmt {
304        match stmt {
305            Stmt::FuncDef { name, params, body } => {
306                // 检查函数体是否包含尾递归
307                if self.is_tail_recursive(&name, &body) {
308                    // 转换为迭代形式
309                    Stmt::FuncDef {
310                        name: name.clone(),
311                        params: params.clone(),
312                        body: self.convert_tail_recursion_to_loop(&name, &params, body),
313                    }
314                } else {
315                    Stmt::FuncDef { name, params, body }
316                }
317            }
318            other => other,
319        }
320    }
321
322    /// 检查是否为尾递归
323    fn is_tail_recursive(&self, func_name: &str, body: &[Stmt]) -> bool {
324        if body.is_empty() {
325            return false;
326        }
327
328        // 递归检查所有可能的返回路径
329        self.has_tail_recursion_in_body(func_name, body)
330    }
331
332    /// 检查函数体中是否包含尾递归
333    fn has_tail_recursion_in_body(&self, func_name: &str, body: &[Stmt]) -> bool {
334        // 至少需要有一条return语句包含尾递归调用
335        body.iter()
336            .any(|stmt| self.stmt_has_tail_recursion(func_name, stmt))
337    }
338
339    /// 检查语句是否包含尾递归
340    fn stmt_has_tail_recursion(&self, func_name: &str, stmt: &Stmt) -> bool {
341        match stmt {
342            Stmt::Return(expr) => self.is_tail_call(func_name, expr),
343            Stmt::Expression(expr) => self.expr_has_tail_recursion(func_name, expr),
344            Stmt::While { body, .. } => self.has_tail_recursion_in_body(func_name, body),
345            Stmt::For { body, .. } => self.has_tail_recursion_in_body(func_name, body),
346            Stmt::ForIndexed { body, .. } => self.has_tail_recursion_in_body(func_name, body),
347            _ => false,
348        }
349    }
350
351    /// 检查表达式是否包含尾递归
352    fn expr_has_tail_recursion(&self, func_name: &str, expr: &Expr) -> bool {
353        match expr {
354            Expr::If {
355                then_branch,
356                elif_branches,
357                else_branch,
358                ..
359            } => {
360                // 检查所有分支
361                let then_tail = self.has_tail_recursion_in_body(func_name, then_branch);
362                let elif_tail = elif_branches
363                    .iter()
364                    .any(|(_, body)| self.has_tail_recursion_in_body(func_name, body));
365                let else_tail = else_branch
366                    .as_ref()
367                    .map(|body| self.has_tail_recursion_in_body(func_name, body))
368                    .unwrap_or(false);
369
370                then_tail || elif_tail || else_tail
371            }
372            _ => false,
373        }
374    }
375
376    /// 检查表达式是否为尾调用(增强版)
377    fn is_tail_call(&self, func_name: &str, expr: &Expr) -> bool {
378        match expr {
379            // 直接的递归调用
380            Expr::Call { func, .. } => {
381                if let Expr::Identifier(name) = &**func {
382                    name == func_name
383                } else {
384                    false
385                }
386            }
387            // 条件表达式中的尾调用
388            Expr::If {
389                then_branch,
390                elif_branches,
391                else_branch,
392                ..
393            } => {
394                // 所有分支都必须是尾调用或没有返回值
395                let then_is_tail = self.branch_ends_with_tail_call(func_name, then_branch);
396
397                let elif_all_tail = elif_branches
398                    .iter()
399                    .all(|(_, body)| self.branch_ends_with_tail_call(func_name, body));
400
401                let else_is_tail = else_branch
402                    .as_ref()
403                    .map(|body| self.branch_ends_with_tail_call(func_name, body))
404                    .unwrap_or(true);
405
406                then_is_tail && elif_all_tail && else_is_tail
407            }
408            _ => false,
409        }
410    }
411
412    /// 检查分支是否以尾调用结束
413    fn branch_ends_with_tail_call(&self, func_name: &str, branch: &[Stmt]) -> bool {
414        if let Some(last_stmt) = branch.last() {
415            match last_stmt {
416                Stmt::Return(expr) => self.is_tail_call(func_name, expr),
417                Stmt::Expression(expr) => {
418                    // 表达式可能是If表达式
419                    self.is_tail_call(func_name, expr)
420                }
421                _ => false,
422            }
423        } else {
424            false
425        }
426    }
427
428    /// 将尾递归转换为循环 (完整实现)
429    fn convert_tail_recursion_to_loop(
430        &self,
431        func_name: &str,
432        params: &[String],
433        body: Vec<Stmt>,
434    ) -> Vec<Stmt> {
435        // 步骤1: 为每个参数创建临时变量
436        let mut new_body = Vec::new();
437
438        // 初始化临时变量
439        for param in params {
440            new_body.push(Stmt::Set {
441                name: format!("_loop_{}", param),
442                value: Expr::Identifier(param.clone()),
443            });
444        }
445
446        // 步骤2: 创建循环标志
447        new_body.push(Stmt::Set {
448            name: "_loop_continue".to_string(),
449            value: Expr::Boolean(true),
450        });
451
452        // 步骤3: 转换函数体为while循环
453        let loop_body = self.transform_body_to_loop(func_name, params, body);
454
455        // 步骤4: 创建while循环
456        new_body.push(Stmt::While {
457            condition: Expr::Identifier("_loop_continue".to_string()),
458            body: loop_body,
459        });
460
461        new_body
462    }
463
464    /// 转换函数体为循环体
465    fn transform_body_to_loop(
466        &self,
467        func_name: &str,
468        params: &[String],
469        body: Vec<Stmt>,
470    ) -> Vec<Stmt> {
471        let mut loop_body = Vec::new();
472
473        for stmt in body {
474            match stmt {
475                Stmt::Return(expr) => {
476                    // 检查是否为尾递归调用
477                    if let Some(new_args) = self.extract_tail_call_args(func_name, &expr) {
478                        // 这是尾递归调用,转换为参数更新
479                        for (i, param) in params.iter().enumerate() {
480                            if let Some(arg) = new_args.get(i) {
481                                loop_body.push(Stmt::Set {
482                                    name: format!("_loop_{}", param),
483                                    value: arg.clone(),
484                                });
485                            }
486                        }
487
488                        // 更新参数值
489                        for param in params {
490                            loop_body.push(Stmt::Set {
491                                name: param.clone(),
492                                value: Expr::Identifier(format!("_loop_{}", param)),
493                            });
494                        }
495
496                        // 继续循环
497                    } else {
498                        // 这不是尾递归调用,正常返回
499                        loop_body.push(Stmt::Set {
500                            name: "_loop_continue".to_string(),
501                            value: Expr::Boolean(false),
502                        });
503                        loop_body.push(Stmt::Return(expr));
504                    }
505                }
506                _ => {
507                    // 其他语句递归转换
508                    loop_body.push(self.transform_stmt_for_loop(func_name, params, stmt));
509                }
510            }
511        }
512
513        loop_body
514    }
515
516    /// 提取尾调用的参数
517    fn extract_tail_call_args(&self, func_name: &str, expr: &Expr) -> Option<Vec<Expr>> {
518        match expr {
519            Expr::Call { func, args } => {
520                if let Expr::Identifier(name) = &**func {
521                    if name == func_name {
522                        return Some(args.clone());
523                    }
524                }
525                None
526            }
527            _ => None,
528        }
529    }
530
531    /// 转换语句以适应循环结构
532    fn transform_stmt_for_loop(&self, func_name: &str, params: &[String], stmt: Stmt) -> Stmt {
533        match stmt {
534            Stmt::Expression(expr) => {
535                // 处理If表达式
536                Stmt::Expression(self.transform_expr_for_loop(func_name, params, expr))
537            }
538            Stmt::While { condition, body } => Stmt::While {
539                condition,
540                body: self.transform_body_to_loop(func_name, params, body),
541            },
542            Stmt::For {
543                var,
544                iterable,
545                body,
546            } => Stmt::For {
547                var,
548                iterable,
549                body: self.transform_body_to_loop(func_name, params, body),
550            },
551            Stmt::ForIndexed {
552                index_var,
553                value_var,
554                iterable,
555                body,
556            } => Stmt::ForIndexed {
557                index_var,
558                value_var,
559                iterable,
560                body: self.transform_body_to_loop(func_name, params, body),
561            },
562            other => other,
563        }
564    }
565
566    /// 转换表达式以适应循环结构
567    fn transform_expr_for_loop(&self, func_name: &str, params: &[String], expr: Expr) -> Expr {
568        match expr {
569            Expr::If {
570                condition,
571                then_branch,
572                elif_branches,
573                else_branch,
574            } => Expr::If {
575                condition,
576                then_branch: self.transform_body_to_loop(func_name, params, then_branch),
577                elif_branches: elif_branches
578                    .into_iter()
579                    .map(|(cond, body)| {
580                        (cond, self.transform_body_to_loop(func_name, params, body))
581                    })
582                    .collect(),
583                else_branch: else_branch
584                    .map(|body| self.transform_body_to_loop(func_name, params, body)),
585            },
586            other => other,
587        }
588    }
589}
590
591impl Default for Optimizer {
592    fn default() -> Self {
593        Self::new()
594    }
595}
596
#[cfg(test)]
mod tests {
    use super::*;

    /// Numeric literal node.
    fn num(value: f64) -> Expr {
        Expr::Number(value)
    }

    /// Identifier node.
    fn ident(name: &str) -> Expr {
        Expr::Identifier(name.to_string())
    }

    /// Binary operation node.
    fn binary(left: Expr, op: BinOp, right: Expr) -> Expr {
        Expr::Binary {
            left: Box::new(left),
            op,
            right: Box::new(right),
        }
    }

    /// Call of the function named `func_name`.
    fn call(func_name: &str, args: Vec<Expr>) -> Expr {
        Expr::Call {
            func: Box::new(ident(func_name)),
            args,
        }
    }

    /// `factorial(n - 1, acc * n)`, the tail call used by several tests.
    fn factorial_tail_call() -> Expr {
        call(
            "factorial",
            vec![
                binary(ident("n"), BinOp::Subtract, num(1.0)),
                binary(ident("acc"), BinOp::Multiply, ident("n")),
            ],
        )
    }

    #[test]
    fn test_constant_folding() {
        let optimizer = Optimizer::new();

        // 2 + 3 must fold to 5.
        let folded = optimizer.fold_expr(binary(num(2.0), BinOp::Add, num(3.0)));
        assert_eq!(folded, num(5.0));
    }

    #[test]
    fn test_dead_code_elimination() {
        let optimizer = Optimizer::new();

        // A `while false` loop must be removed entirely.
        let stmt = Stmt::While {
            condition: Expr::Boolean(false),
            body: vec![Stmt::Set {
                name: "x".to_string(),
                value: num(10.0),
            }],
        };

        assert!(optimizer.eliminate_dead_stmt(stmt).is_none());
    }

    #[test]
    fn test_optimize_program() {
        let optimizer = Optimizer::new();

        let program: Program = vec![Stmt::Set {
            name: "x".to_string(),
            value: binary(num(2.0), BinOp::Add, num(3.0)),
        }];

        let optimized = optimizer.optimize_program(&program);
        assert_eq!(optimized.len(), 1);

        // Fail loudly if the statement changed shape, instead of silently
        // skipping the assertion.
        match optimized.first() {
            Some(Stmt::Set { value, .. }) => assert_eq!(*value, num(5.0)),
            _ => panic!("Expected the optimized program to start with a Set statement"),
        }
    }

    #[test]
    fn test_tail_recursion_detection() {
        let optimizer = Optimizer::new();

        // return factorial(n - 1, acc * n)
        let body = vec![Stmt::Return(factorial_tail_call())];

        assert!(optimizer.is_tail_recursive("factorial", &body));
    }

    #[test]
    fn test_non_tail_recursion_detection() {
        let optimizer = Optimizer::new();

        // return n * factorial(n - 1): work remains after the recursive call.
        let body = vec![Stmt::Return(binary(
            ident("n"),
            BinOp::Multiply,
            call(
                "factorial",
                vec![binary(ident("n"), BinOp::Subtract, num(1.0))],
            ),
        ))];

        assert!(!optimizer.is_tail_recursive("factorial", &body));
    }

    #[test]
    fn test_tail_recursion_in_if() {
        let optimizer = Optimizer::new();

        // if n <= 0 { return acc } else { return sum(n - 1, acc + n) }
        // The `If` is an expression, so it appears as an expression statement.
        let body = vec![Stmt::Expression(Expr::If {
            condition: Box::new(binary(ident("n"), BinOp::LessEqual, num(0.0))),
            then_branch: vec![Stmt::Return(ident("acc"))],
            elif_branches: vec![],
            else_branch: Some(vec![Stmt::Return(call(
                "sum",
                vec![
                    binary(ident("n"), BinOp::Subtract, num(1.0)),
                    binary(ident("acc"), BinOp::Add, ident("n")),
                ],
            ))]),
        })];

        assert!(optimizer.is_tail_recursive("sum", &body));
    }

    #[test]
    fn test_tail_recursion_optimization_transform() {
        let optimizer = Optimizer::new();

        let func_def = Stmt::FuncDef {
            name: "factorial".to_string(),
            params: vec!["n".to_string(), "acc".to_string()],
            body: vec![Stmt::Return(factorial_tail_call())],
        };

        let body = match optimizer.optimize_tail_recursive_stmt(func_def) {
            Stmt::FuncDef { body, .. } => body,
            _ => panic!("Expected FuncDef"),
        };

        // Two staging variables + the loop flag + the while loop.
        assert_eq!(
            body.len(),
            4,
            "Expected exactly 4 statements, got {}",
            body.len()
        );

        match body.last() {
            Some(Stmt::While { condition, .. }) => {
                assert_eq!(*condition, ident("_loop_continue"));
            }
            _ => panic!("Expected While loop at the end of optimized function body"),
        }
    }
}