aether/
optimizer.rs

// src/optimizer.rs
//! Code optimizer: tail-recursion optimization, constant folding and dead-code elimination.

use crate::ast::{BinOp, Expr, Program, Stmt, UnaryOp};

/// AST optimizer configuration: one switch per optimization pass.
///
/// [`Optimizer::new`] returns an optimizer with every pass enabled.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Optimizer {
    /// Enables tail-recursion optimization (self tail calls rewritten into a loop).
    pub tail_recursion: bool,
    /// Enables constant folding.
    pub constant_folding: bool,
    /// Enables dead-code elimination.
    pub dead_code_elimination: bool,
}

16impl Optimizer {
17    /// 创建新的优化器,所有优化默认启用
18    pub fn new() -> Self {
19        Optimizer {
20            tail_recursion: true,
21            constant_folding: true,
22            dead_code_elimination: true,
23        }
24    }
25
26    /// 优化整个程序
27    pub fn optimize_program(&self, program: &Program) -> Program {
28        let mut optimized = program.clone();
29
30        // 常量折叠
31        if self.constant_folding {
32            optimized = self.fold_constants(optimized);
33        }
34
35        // 死代码消除
36        if self.dead_code_elimination {
37            optimized = self.eliminate_dead_code(optimized);
38        }
39
40        // 尾递归优化
41        if self.tail_recursion {
42            optimized = self.optimize_tail_recursion(optimized);
43        }
44
45        optimized
46    }
47
48    /// 常量折叠优化
49    fn fold_constants(&self, program: Program) -> Program {
50        program
51            .into_iter()
52            .map(|stmt| self.fold_stmt(stmt))
53            .collect()
54    }
55
56    /// 折叠语句中的常量
57    fn fold_stmt(&self, stmt: Stmt) -> Stmt {
58        match stmt {
59            Stmt::Set { name, value } => Stmt::Set {
60                name,
61                value: self.fold_expr(value),
62            },
63            Stmt::FuncDef { name, params, body } => Stmt::FuncDef {
64                name,
65                params,
66                body: body.into_iter().map(|s| self.fold_stmt(s)).collect(),
67            },
68            Stmt::GeneratorDef { name, params, body } => Stmt::GeneratorDef {
69                name,
70                params,
71                body: body.into_iter().map(|s| self.fold_stmt(s)).collect(),
72            },
73            Stmt::Return(expr) => Stmt::Return(self.fold_expr(expr)),
74            Stmt::Yield(expr) => Stmt::Yield(self.fold_expr(expr)),
75            Stmt::While { condition, body } => Stmt::While {
76                condition: self.fold_expr(condition),
77                body: body.into_iter().map(|s| self.fold_stmt(s)).collect(),
78            },
79            Stmt::For {
80                var,
81                iterable,
82                body,
83            } => Stmt::For {
84                var,
85                iterable: self.fold_expr(iterable),
86                body: body.into_iter().map(|s| self.fold_stmt(s)).collect(),
87            },
88            Stmt::ForIndexed {
89                index_var,
90                value_var,
91                iterable,
92                body,
93            } => Stmt::ForIndexed {
94                index_var,
95                value_var,
96                iterable: self.fold_expr(iterable),
97                body: body.into_iter().map(|s| self.fold_stmt(s)).collect(),
98            },
99            Stmt::Expression(expr) => Stmt::Expression(self.fold_expr(expr)),
100            other => other,
101        }
102    }
103
104    /// 折叠表达式中的常量
105    #[allow(clippy::only_used_in_recursion)]
106    fn fold_expr(&self, expr: Expr) -> Expr {
107        match expr {
108            // 二元运算常量折叠
109            Expr::Binary { left, op, right } => {
110                let left = self.fold_expr(*left);
111                let right = self.fold_expr(*right);
112
113                // 如果两边都是常量,直接计算结果
114                if let (Expr::Number(l), Expr::Number(r)) = (&left, &right)
115                    && let Some(result) = Self::eval_const_binary(*l, &op, *r)
116                {
117                    return Expr::Number(result);
118                }
119
120                Expr::Binary {
121                    left: Box::new(left),
122                    op,
123                    right: Box::new(right),
124                }
125            }
126
127            // 一元运算常量折叠
128            Expr::Unary { op, expr } => {
129                let expr = self.fold_expr(*expr);
130
131                if let Expr::Number(n) = expr {
132                    match op {
133                        UnaryOp::Minus => return Expr::Number(-n),
134                        UnaryOp::Not => return Expr::Boolean(n == 0.0),
135                    }
136                }
137
138                if let (UnaryOp::Not, Expr::Boolean(b)) = (&op, &expr) {
139                    return Expr::Boolean(!b);
140                }
141
142                Expr::Unary {
143                    op,
144                    expr: Box::new(expr),
145                }
146            }
147
148            // 递归处理其他表达式
149            Expr::Call { func, args } => Expr::Call {
150                func: Box::new(self.fold_expr(*func)),
151                args: args.into_iter().map(|e| self.fold_expr(e)).collect(),
152            },
153
154            Expr::Array(elements) => {
155                Expr::Array(elements.into_iter().map(|e| self.fold_expr(e)).collect())
156            }
157
158            Expr::Index { object, index } => Expr::Index {
159                object: Box::new(self.fold_expr(*object)),
160                index: Box::new(self.fold_expr(*index)),
161            },
162
163            other => other,
164        }
165    }
166
167    /// 计算常量二元运算
168    fn eval_const_binary(left: f64, op: &BinOp, right: f64) -> Option<f64> {
169        match op {
170            BinOp::Add => Some(left + right),
171            BinOp::Subtract => Some(left - right),
172            BinOp::Multiply => Some(left * right),
173            BinOp::Divide if right != 0.0 => Some(left / right),
174            BinOp::Modulo if right != 0.0 => Some(left % right),
175            _ => None,
176        }
177    }
178
179    /// 死代码消除
180    fn eliminate_dead_code(&self, program: Program) -> Program {
181        program
182            .into_iter()
183            .filter_map(|stmt| self.eliminate_dead_stmt(stmt))
184            .collect()
185    }
186
187    /// 消除死语句
188    fn eliminate_dead_stmt(&self, stmt: Stmt) -> Option<Stmt> {
189        match stmt {
190            // While循环的常量条件
191            Stmt::While { condition, body } => {
192                if let Expr::Boolean(false) = condition {
193                    // 永远不执行的循环可以删除
194                    return None;
195                }
196
197                Some(Stmt::While {
198                    condition,
199                    body: body
200                        .into_iter()
201                        .filter_map(|s| self.eliminate_dead_stmt(s))
202                        .collect(),
203                })
204            }
205
206            // 函数定义递归处理
207            Stmt::FuncDef { name, params, body } => Some(Stmt::FuncDef {
208                name,
209                params,
210                body: body
211                    .into_iter()
212                    .filter_map(|s| self.eliminate_dead_stmt(s))
213                    .collect(),
214            }),
215
216            Stmt::GeneratorDef { name, params, body } => Some(Stmt::GeneratorDef {
217                name,
218                params,
219                body: body
220                    .into_iter()
221                    .filter_map(|s| self.eliminate_dead_stmt(s))
222                    .collect(),
223            }),
224
225            // 表达式语句中可能包含If表达式
226            Stmt::Expression(expr) => Some(Stmt::Expression(self.eliminate_dead_expr(expr))),
227
228            other => Some(other),
229        }
230    }
231
232    /// 消除表达式中的死代码
233    fn eliminate_dead_expr(&self, expr: Expr) -> Expr {
234        match expr {
235            Expr::If {
236                condition,
237                then_branch,
238                elif_branches,
239                else_branch,
240            } => {
241                if let Expr::Boolean(true) = *condition {
242                    // 条件永远为真,简化为then分支
243                    return Expr::If {
244                        condition: Box::new(Expr::Boolean(true)),
245                        then_branch,
246                        elif_branches: vec![],
247                        else_branch: None,
248                    };
249                }
250
251                if let Expr::Boolean(false) = *condition {
252                    // 条件永远为假,检查elif或else
253                    if let Some(else_body) = else_branch {
254                        // 简化为else块
255                        return Expr::If {
256                            condition: Box::new(Expr::Boolean(true)),
257                            then_branch: else_body,
258                            elif_branches: vec![],
259                            else_branch: None,
260                        };
261                    }
262                    // 没有else,返回null
263                    return Expr::Null;
264                }
265
266                // 递归处理分支
267                Expr::If {
268                    condition,
269                    then_branch: then_branch
270                        .into_iter()
271                        .filter_map(|s| self.eliminate_dead_stmt(s))
272                        .collect(),
273                    elif_branches: elif_branches
274                        .into_iter()
275                        .map(|(c, b)| {
276                            (
277                                self.eliminate_dead_expr(c),
278                                b.into_iter()
279                                    .filter_map(|s| self.eliminate_dead_stmt(s))
280                                    .collect(),
281                            )
282                        })
283                        .collect(),
284                    else_branch: else_branch.map(|b| {
285                        b.into_iter()
286                            .filter_map(|s| self.eliminate_dead_stmt(s))
287                            .collect()
288                    }),
289                }
290            }
291            other => other,
292        }
293    }
294
295    /// 尾递归优化
296    fn optimize_tail_recursion(&self, program: Program) -> Program {
297        program
298            .into_iter()
299            .map(|stmt| self.optimize_tail_recursive_stmt(stmt))
300            .collect()
301    }
302
303    /// 优化尾递归语句
304    fn optimize_tail_recursive_stmt(&self, stmt: Stmt) -> Stmt {
305        match stmt {
306            Stmt::FuncDef { name, params, body } => {
307                // 检查函数体是否包含尾递归
308                if self.is_tail_recursive(&name, &body) {
309                    // 转换为迭代形式
310                    Stmt::FuncDef {
311                        name: name.clone(),
312                        params: params.clone(),
313                        body: self.convert_tail_recursion_to_loop(&name, &params, body),
314                    }
315                } else {
316                    Stmt::FuncDef { name, params, body }
317                }
318            }
319            other => other,
320        }
321    }
322
323    /// 检查是否为尾递归
324    fn is_tail_recursive(&self, func_name: &str, body: &[Stmt]) -> bool {
325        if body.is_empty() {
326            return false;
327        }
328
329        // 递归检查所有可能的返回路径
330        self.has_tail_recursion_in_body(func_name, body)
331    }
332
333    /// 检查函数体中是否包含尾递归
334    fn has_tail_recursion_in_body(&self, func_name: &str, body: &[Stmt]) -> bool {
335        // 至少需要有一条return语句包含尾递归调用
336        body.iter()
337            .any(|stmt| self.stmt_has_tail_recursion(func_name, stmt))
338    }
339
340    /// 检查语句是否包含尾递归
341    fn stmt_has_tail_recursion(&self, func_name: &str, stmt: &Stmt) -> bool {
342        match stmt {
343            Stmt::Return(expr) => self.is_tail_call(func_name, expr),
344            Stmt::Expression(expr) => self.expr_has_tail_recursion(func_name, expr),
345            Stmt::While { body, .. } => self.has_tail_recursion_in_body(func_name, body),
346            Stmt::For { body, .. } => self.has_tail_recursion_in_body(func_name, body),
347            Stmt::ForIndexed { body, .. } => self.has_tail_recursion_in_body(func_name, body),
348            _ => false,
349        }
350    }
351
352    /// 检查表达式是否包含尾递归
353    fn expr_has_tail_recursion(&self, func_name: &str, expr: &Expr) -> bool {
354        match expr {
355            Expr::If {
356                then_branch,
357                elif_branches,
358                else_branch,
359                ..
360            } => {
361                // 检查所有分支
362                let then_tail = self.has_tail_recursion_in_body(func_name, then_branch);
363                let elif_tail = elif_branches
364                    .iter()
365                    .any(|(_, body)| self.has_tail_recursion_in_body(func_name, body));
366                let else_tail = else_branch
367                    .as_ref()
368                    .map(|body| self.has_tail_recursion_in_body(func_name, body))
369                    .unwrap_or(false);
370
371                then_tail || elif_tail || else_tail
372            }
373            _ => false,
374        }
375    }
376
377    /// 检查表达式是否为尾调用(增强版)
378    fn is_tail_call(&self, func_name: &str, expr: &Expr) -> bool {
379        match expr {
380            // 直接的递归调用
381            Expr::Call { func, .. } => {
382                if let Expr::Identifier(name) = &**func {
383                    name == func_name
384                } else {
385                    false
386                }
387            }
388            // 条件表达式中的尾调用
389            Expr::If {
390                then_branch,
391                elif_branches,
392                else_branch,
393                ..
394            } => {
395                // 所有分支都必须是尾调用或没有返回值
396                let then_is_tail = self.branch_ends_with_tail_call(func_name, then_branch);
397
398                let elif_all_tail = elif_branches
399                    .iter()
400                    .all(|(_, body)| self.branch_ends_with_tail_call(func_name, body));
401
402                let else_is_tail = else_branch
403                    .as_ref()
404                    .map(|body| self.branch_ends_with_tail_call(func_name, body))
405                    .unwrap_or(true);
406
407                then_is_tail && elif_all_tail && else_is_tail
408            }
409            _ => false,
410        }
411    }
412
413    /// 检查分支是否以尾调用结束
414    fn branch_ends_with_tail_call(&self, func_name: &str, branch: &[Stmt]) -> bool {
415        if let Some(last_stmt) = branch.last() {
416            match last_stmt {
417                Stmt::Return(expr) => self.is_tail_call(func_name, expr),
418                Stmt::Expression(expr) => {
419                    // 表达式可能是If表达式
420                    self.is_tail_call(func_name, expr)
421                }
422                _ => false,
423            }
424        } else {
425            false
426        }
427    }
428
429    /// 将尾递归转换为循环 (完整实现)
430    fn convert_tail_recursion_to_loop(
431        &self,
432        func_name: &str,
433        params: &[String],
434        body: Vec<Stmt>,
435    ) -> Vec<Stmt> {
436        // 步骤1: 为每个参数创建临时变量
437        let mut new_body = Vec::new();
438
439        // 初始化临时变量
440        for param in params {
441            new_body.push(Stmt::Set {
442                name: format!("_loop_{}", param),
443                value: Expr::Identifier(param.clone()),
444            });
445        }
446
447        // 步骤2: 创建循环标志
448        new_body.push(Stmt::Set {
449            name: "_loop_continue".to_string(),
450            value: Expr::Boolean(true),
451        });
452
453        // 步骤3: 转换函数体为while循环
454        let loop_body = self.transform_body_to_loop(func_name, params, body);
455
456        // 步骤4: 创建while循环
457        new_body.push(Stmt::While {
458            condition: Expr::Identifier("_loop_continue".to_string()),
459            body: loop_body,
460        });
461
462        new_body
463    }
464
465    /// 转换函数体为循环体
466    fn transform_body_to_loop(
467        &self,
468        func_name: &str,
469        params: &[String],
470        body: Vec<Stmt>,
471    ) -> Vec<Stmt> {
472        let mut loop_body = Vec::new();
473
474        for stmt in body {
475            match stmt {
476                Stmt::Return(expr) => {
477                    // 检查是否为尾递归调用
478                    if let Some(new_args) = self.extract_tail_call_args(func_name, &expr) {
479                        // 这是尾递归调用,转换为参数更新
480                        for (i, param) in params.iter().enumerate() {
481                            if let Some(arg) = new_args.get(i) {
482                                loop_body.push(Stmt::Set {
483                                    name: format!("_loop_{}", param),
484                                    value: arg.clone(),
485                                });
486                            }
487                        }
488
489                        // 更新参数值
490                        for param in params {
491                            loop_body.push(Stmt::Set {
492                                name: param.clone(),
493                                value: Expr::Identifier(format!("_loop_{}", param)),
494                            });
495                        }
496
497                        // 继续循环
498                    } else {
499                        // 这不是尾递归调用,正常返回
500                        loop_body.push(Stmt::Set {
501                            name: "_loop_continue".to_string(),
502                            value: Expr::Boolean(false),
503                        });
504                        loop_body.push(Stmt::Return(expr));
505                    }
506                }
507                _ => {
508                    // 其他语句递归转换
509                    loop_body.push(self.transform_stmt_for_loop(func_name, params, stmt));
510                }
511            }
512        }
513
514        loop_body
515    }
516
517    /// 提取尾调用的参数
518    fn extract_tail_call_args(&self, func_name: &str, expr: &Expr) -> Option<Vec<Expr>> {
519        match expr {
520            Expr::Call { func, args } => {
521                if let Expr::Identifier(name) = &**func
522                    && name == func_name
523                {
524                    return Some(args.clone());
525                }
526                None
527            }
528            _ => None,
529        }
530    }
531
532    /// 转换语句以适应循环结构
533    fn transform_stmt_for_loop(&self, func_name: &str, params: &[String], stmt: Stmt) -> Stmt {
534        match stmt {
535            Stmt::Expression(expr) => {
536                // 处理If表达式
537                Stmt::Expression(self.transform_expr_for_loop(func_name, params, expr))
538            }
539            Stmt::While { condition, body } => Stmt::While {
540                condition,
541                body: self.transform_body_to_loop(func_name, params, body),
542            },
543            Stmt::For {
544                var,
545                iterable,
546                body,
547            } => Stmt::For {
548                var,
549                iterable,
550                body: self.transform_body_to_loop(func_name, params, body),
551            },
552            Stmt::ForIndexed {
553                index_var,
554                value_var,
555                iterable,
556                body,
557            } => Stmt::ForIndexed {
558                index_var,
559                value_var,
560                iterable,
561                body: self.transform_body_to_loop(func_name, params, body),
562            },
563            other => other,
564        }
565    }
566
567    /// 转换表达式以适应循环结构
568    fn transform_expr_for_loop(&self, func_name: &str, params: &[String], expr: Expr) -> Expr {
569        match expr {
570            Expr::If {
571                condition,
572                then_branch,
573                elif_branches,
574                else_branch,
575            } => Expr::If {
576                condition,
577                then_branch: self.transform_body_to_loop(func_name, params, then_branch),
578                elif_branches: elif_branches
579                    .into_iter()
580                    .map(|(cond, body)| {
581                        (cond, self.transform_body_to_loop(func_name, params, body))
582                    })
583                    .collect(),
584                else_branch: else_branch
585                    .map(|body| self.transform_body_to_loop(func_name, params, body)),
586            },
587            other => other,
588        }
589    }
590}
591
592impl Default for Optimizer {
593    fn default() -> Self {
594        Self::new()
595    }
596}
597
// The tests kept in this file are the ones that exercise private methods.
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for `Expr::Identifier`.
    fn ident(name: &str) -> Expr {
        Expr::Identifier(name.to_string())
    }

    /// Shorthand for a binary expression with boxed operands.
    fn binary(left: Expr, op: BinOp, right: Expr) -> Expr {
        Expr::Binary {
            left: Box::new(left),
            op,
            right: Box::new(right),
        }
    }

    /// Shorthand for a call of the function named `callee`.
    fn call(callee: &str, args: Vec<Expr>) -> Expr {
        Expr::Call {
            func: Box::new(ident(callee)),
            args,
        }
    }

    #[test]
    fn test_constant_folding() {
        // 2 + 3 folds to the literal 5.
        let sum = binary(Expr::Number(2.0), BinOp::Add, Expr::Number(3.0));
        assert_eq!(Optimizer::new().fold_expr(sum), Expr::Number(5.0));
    }

    #[test]
    fn test_dead_code_elimination() {
        // `while false { x = 10 }` never runs, so the whole statement is removed.
        let never_runs = Stmt::While {
            condition: Expr::Boolean(false),
            body: vec![Stmt::Set {
                name: "x".to_string(),
                value: Expr::Number(10.0),
            }],
        };
        assert!(Optimizer::new().eliminate_dead_stmt(never_runs).is_none());
    }

    #[test]
    fn test_tail_recursion_detection() {
        // return factorial(n - 1, acc * n)
        let body = vec![Stmt::Return(call(
            "factorial",
            vec![
                binary(ident("n"), BinOp::Subtract, Expr::Number(1.0)),
                binary(ident("acc"), BinOp::Multiply, ident("n")),
            ],
        ))];
        assert!(Optimizer::new().is_tail_recursive("factorial", &body));
    }

    #[test]
    fn test_non_tail_recursion_detection() {
        // return n * factorial(n - 1): the multiplication still runs after the call.
        let body = vec![Stmt::Return(binary(
            ident("n"),
            BinOp::Multiply,
            call(
                "factorial",
                vec![binary(ident("n"), BinOp::Subtract, Expr::Number(1.0))],
            ),
        ))];
        assert!(!Optimizer::new().is_tail_recursive("factorial", &body));
    }

    #[test]
    fn test_tail_recursion_in_if() {
        // `if` is an expression in Aether; here it is used as an expression
        // statement: `if n <= 0 { return acc } else { return sum(n - 1, acc + n) }`.
        // The recursive call is the last action of the else branch.
        let body = vec![Stmt::Expression(Expr::If {
            condition: Box::new(binary(ident("n"), BinOp::LessEqual, Expr::Number(0.0))),
            then_branch: vec![Stmt::Return(ident("acc"))],
            elif_branches: vec![],
            else_branch: Some(vec![Stmt::Return(call(
                "sum",
                vec![
                    binary(ident("n"), BinOp::Subtract, Expr::Number(1.0)),
                    binary(ident("acc"), BinOp::Add, ident("n")),
                ],
            ))]),
        })];
        assert!(Optimizer::new().is_tail_recursive("sum", &body));
    }

    #[test]
    fn test_tail_recursion_optimization_transform() {
        // func factorial(n, acc) { return factorial(n - 1, acc * n) }
        let func_def = Stmt::FuncDef {
            name: "factorial".to_string(),
            params: vec!["n".to_string(), "acc".to_string()],
            body: vec![Stmt::Return(call(
                "factorial",
                vec![
                    binary(ident("n"), BinOp::Subtract, Expr::Number(1.0)),
                    binary(ident("acc"), BinOp::Multiply, ident("n")),
                ],
            ))],
        };

        let Stmt::FuncDef { body, .. } = Optimizer::new().optimize_tail_recursive_stmt(func_def)
        else {
            panic!("Expected FuncDef");
        };

        // Expected shape: one temporary per parameter (2), the loop flag (1) and the
        // while loop (1), i.e. 4 statements; only a lower bound is asserted.
        assert!(
            body.len() >= 3,
            "Expected at least 3 statements, got {}",
            body.len()
        );
        assert!(
            matches!(body.last(), Some(Stmt::While { .. })),
            "Expected While loop at the end of optimized function body"
        );
    }
}