use super::restricted::{Expr, Function, RestrictedAst, Stmt};
/// Read-only visitor over the restricted AST.
///
/// There is one method per node category (whole AST, function, statement, expression), and
/// each returns a `T`. The trait supplies no default traversal: an implementation decides for
/// itself whether, and in what order, to descend into child nodes.
pub trait Visitor<T> {
    /// Visits the root of the tree.
    fn visit_ast(&mut self, ast: &RestrictedAst) -> T;

    /// Visits a single function definition.
    fn visit_function(&mut self, function: &Function) -> T;

    /// Visits a single statement.
    fn visit_stmt(&mut self, stmt: &Stmt) -> T;

    /// Visits a single expression.
    fn visit_expr(&mut self, expr: &Expr) -> T;
}
/// Visitor over the restricted AST that receives mutable access to each node.
///
/// The methods mirror [`Visitor`] one-for-one but take `&mut` references, so an
/// implementation may rewrite nodes in place. As with `Visitor`, no default traversal is
/// provided; recursion into children is left to the implementation.
pub trait VisitorMut<T> {
    /// Visits (and may modify) the root of the tree.
    fn visit_ast_mut(&mut self, ast: &mut RestrictedAst) -> T;

    /// Visits (and may modify) a single function definition.
    fn visit_function_mut(&mut self, function: &mut Function) -> T;

    /// Visits (and may modify) a single statement.
    fn visit_stmt_mut(&mut self, stmt: &mut Stmt) -> T;

    /// Visits (and may modify) a single expression.
    fn visit_expr_mut(&mut self, expr: &mut Expr) -> T;
}
/// Runs `visitor` over `ast`, starting at the root, and returns the visitor's result.
///
/// This only dispatches to [`Visitor::visit_ast`]. Everything below the root (which
/// functions, statements and expressions get visited, and in what order) is decided by the
/// visitor implementation.
///
/// `T` carries no extra bounds because the value is produced entirely by the visitor.
/// `V` may be unsized, so a `&mut dyn Visitor<T>` can be passed as well.
pub fn walk_ast<V, T>(visitor: &mut V, ast: &RestrictedAst) -> T
where
    V: Visitor<T> + ?Sized,
{
    visitor.visit_ast(ast)
}
/// Applies `transform` to the expressions found in every function body of `ast`.
///
/// Functions and their statements are processed in their stored order. Each statement is
/// handed to `transform_stmt_exprs`, which decides which statement kinds are entered and
/// walks expression trees children-first. As a result, `transform` sees a node only after
/// the node's sub-expressions have been processed.
pub fn transform_exprs<F>(ast: &mut RestrictedAst, mut transform: F)
where
    F: FnMut(&mut Expr),
{
    ast.functions
        .iter_mut()
        .flat_map(|function| function.body.iter_mut())
        .for_each(|stmt| transform_stmt_exprs(stmt, &mut transform));
}
/// Feeds the expressions reachable from `stmt` to `transform`, using `transform_expr`.
///
/// - `let` values, expression statements and `return` values are transformed directly.
/// - For an `if`, the condition goes first. Then every statement of the `then` branch is
///   processed, followed by every statement of the optional `else` branch, recursively.
/// - A `return` with no value, and any other statement kind, is left untouched.
fn transform_stmt_exprs<F>(stmt: &mut Stmt, transform: &mut F)
where
    F: FnMut(&mut Expr),
{
    match stmt {
        Stmt::Let { value: expr, .. } | Stmt::Expr(expr) | Stmt::Return(Some(expr)) => {
            transform_expr(expr, transform);
        }
        Stmt::If {
            condition,
            then_block,
            else_block,
        } => {
            transform_expr(condition, transform);
            // `then` statements first, then the `else` statements (if the branch exists).
            let branches = then_block
                .iter_mut()
                .chain(else_block.iter_mut().flatten());
            for nested in branches {
                transform_stmt_exprs(nested, transform);
            }
        }
        // NOTE(review): this catch-all also silently skips any `Stmt` variants not listed
        // above. Confirm none of them carry expressions that should be transformed.
        _ => {}
    }
}
/// Applies `transform` to `expr` and its sub-expressions in post-order (children first).
///
/// The function descends into:
/// - binary operations (left operand, then right);
/// - unary operations (the operand);
/// - function calls (arguments, left to right);
/// - method calls (the receiver, then arguments left to right).
///
/// Every other expression kind is treated as a leaf. `transform` runs on `expr` last, so
/// children it installs on the node are not themselves revisited.
fn transform_expr<F>(expr: &mut Expr, transform: &mut F)
where
    F: FnMut(&mut Expr),
{
    match expr {
        Expr::Binary { left, right, .. } => {
            transform_expr(left, transform);
            transform_expr(right, transform);
        }
        Expr::Unary { operand, .. } => transform_expr(operand, transform),
        Expr::FunctionCall { args, .. } => {
            args.iter_mut()
                .for_each(|arg| transform_expr(arg, transform));
        }
        Expr::MethodCall { receiver, args, .. } => {
            transform_expr(receiver, transform);
            args.iter_mut()
                .for_each(|arg| transform_expr(arg, transform));
        }
        _ => {}
    }
    transform(expr);
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::restricted::{BinaryOp, Literal, Type, UnaryOp};

    /// `fn test_fn() { let x = "hello"; x; }`: two statements with one expression each.
    fn create_test_ast() -> RestrictedAst {
        RestrictedAst {
            entry_point: "test_fn".to_string(),
            functions: vec![Function {
                name: "test_fn".to_string(),
                params: vec![],
                return_type: Type::Void,
                body: vec![
                    Stmt::Let {
                        name: "x".to_string(),
                        value: Expr::Literal(Literal::Str("hello".to_string())),
                        declaration: true,
                    },
                    Stmt::Expr(Expr::Variable("x".to_string())),
                ],
            }],
        }
    }

    /// Builds an AST with a single void function named `test` whose body is `body`.
    fn ast_with_body(body: Vec<Stmt>) -> RestrictedAst {
        RestrictedAst {
            entry_point: "test".to_string(),
            functions: vec![Function {
                name: "test".to_string(),
                params: vec![],
                return_type: Type::Void,
                body,
            }],
        }
    }

    /// Runs `transform_exprs` with a callback that only counts, and returns the count.
    fn count_transforms(ast: &mut RestrictedAst) -> usize {
        let mut count = 0;
        transform_exprs(ast, |_expr| count += 1);
        count
    }

    /// Visitor that increments `count` once per node it is asked to visit.
    struct CountingVisitor {
        count: usize,
    }

    impl Visitor<()> for CountingVisitor {
        fn visit_ast(&mut self, ast: &RestrictedAst) {
            self.count += 1;
            for func in &ast.functions {
                self.visit_function(func);
            }
        }

        fn visit_function(&mut self, function: &Function) {
            self.count += 1;
            for stmt in &function.body {
                self.visit_stmt(stmt);
            }
        }

        fn visit_stmt(&mut self, stmt: &Stmt) {
            self.count += 1;
            match stmt {
                Stmt::Let { value, .. } => self.visit_expr(value),
                Stmt::Expr(expr) => self.visit_expr(expr),
                Stmt::Return(Some(expr)) => self.visit_expr(expr),
                _ => (),
            }
        }

        fn visit_expr(&mut self, _expr: &Expr) {
            self.count += 1;
        }
    }

    #[test]
    fn test_walk_ast() {
        let ast = create_test_ast();
        let mut visitor = CountingVisitor { count: 0 };
        walk_ast(&mut visitor, &ast);
        // ast + function + 2 statements + 2 expressions
        assert_eq!(visitor.count, 6);
    }

    #[test]
    fn test_transform_exprs_let() {
        let mut ast = create_test_ast();
        assert_eq!(count_transforms(&mut ast), 2);
    }

    #[test]
    fn test_transform_exprs_empty_ast() {
        let mut ast = RestrictedAst {
            entry_point: "main".to_string(),
            functions: vec![],
        };
        assert_eq!(count_transforms(&mut ast), 0);
    }

    #[test]
    fn test_transform_exprs_with_if() {
        let mut ast = ast_with_body(vec![Stmt::If {
            condition: Expr::Variable("cond".to_string()),
            then_block: vec![Stmt::Expr(Expr::Literal(Literal::Str("then".to_string())))],
            else_block: Some(vec![Stmt::Expr(Expr::Literal(Literal::Str(
                "else".to_string(),
            )))]),
        }]);
        assert_eq!(count_transforms(&mut ast), 3);
    }

    #[test]
    fn test_transform_exprs_with_return() {
        let mut ast = ast_with_body(vec![
            Stmt::Return(Some(Expr::Literal(Literal::Str("value".to_string())))),
            Stmt::Return(None),
        ]);
        assert_eq!(count_transforms(&mut ast), 1);
    }

    #[test]
    fn test_transform_expr_function_call() {
        let mut ast = ast_with_body(vec![Stmt::Expr(Expr::FunctionCall {
            name: "func".to_string(),
            args: vec![
                Expr::Literal(Literal::Str("arg1".to_string())),
                Expr::Literal(Literal::Str("arg2".to_string())),
            ],
        })]);
        assert_eq!(count_transforms(&mut ast), 3);
    }

    #[test]
    fn test_transform_expr_binary() {
        let mut ast = ast_with_body(vec![Stmt::Expr(Expr::Binary {
            op: BinaryOp::Add,
            left: Box::new(Expr::Literal(Literal::U32(1))),
            right: Box::new(Expr::Literal(Literal::U32(2))),
        })]);
        assert_eq!(count_transforms(&mut ast), 3);
    }

    #[test]
    fn test_transform_expr_unary() {
        let mut ast = ast_with_body(vec![Stmt::Expr(Expr::Unary {
            op: UnaryOp::Neg,
            operand: Box::new(Expr::Literal(Literal::U32(5))),
        })]);
        assert_eq!(count_transforms(&mut ast), 2);
    }

    #[test]
    fn test_transform_expr_method_call() {
        let mut ast = ast_with_body(vec![Stmt::Expr(Expr::MethodCall {
            receiver: Box::new(Expr::Variable("obj".to_string())),
            method: "method".to_string(),
            args: vec![Expr::Literal(Literal::Str("arg".to_string()))],
        })]);
        assert_eq!(count_transforms(&mut ast), 3);
    }

    #[test]
    fn test_transform_exprs_is_post_order() {
        let mut ast = ast_with_body(vec![Stmt::Expr(Expr::Binary {
            op: BinaryOp::Add,
            left: Box::new(Expr::Literal(Literal::U32(1))),
            right: Box::new(Expr::Literal(Literal::U32(2))),
        })]);
        let mut visited = Vec::new();
        transform_exprs(&mut ast, |expr| {
            visited.push(match expr {
                Expr::Literal(Literal::U32(n)) => n.to_string(),
                Expr::Binary { .. } => "binary".to_string(),
                _ => "other".to_string(),
            });
        });
        // Both operands must be seen before the binary node that owns them.
        assert_eq!(visited, ["1", "2", "binary"]);
    }

    #[test]
    fn test_transform_exprs_can_rewrite_in_place() {
        let mut ast = create_test_ast();
        transform_exprs(&mut ast, |expr| {
            if let Expr::Variable(name) = expr {
                name.push_str("_renamed");
            }
        });
        let body = &ast.functions[0].body;
        assert!(matches!(
            &body[1],
            Stmt::Expr(Expr::Variable(name)) if name == "x_renamed"
        ));
    }
}