use crate::ast::{
BinaryOp, Block, Declaration, Evo, Expr, Gen, Literal, MatchArm, Pattern, Rule, Statement,
Stmt, System, Trait, TypeExpr, UnaryOp,
};
/// Read-only traversal over the AST.
///
/// Every hook has a default body. Hooks for composite nodes delegate to the
/// matching `walk_*` helper, which visits the node's children. The remaining
/// hooks do nothing by default: `visit_evolution`, `visit_statement`,
/// `visit_type_expr`, `visit_literal`, `visit_identifier`, `visit_binary_op`
/// and `visit_unary_op`. An override that still wants the children visited
/// must invoke the matching `walk_*` helper itself.
pub trait Visitor {
    /// Visits a top-level declaration; by default dispatches on its variant.
    fn visit_declaration(&mut self, decl: &Declaration) {
        walk_declaration(self, decl)
    }

    /// Visits a gene declaration; by default hands each of its statements to
    /// `visit_statement`.
    fn visit_gene(&mut self, gene: &Gen) {
        walk_gene(self, gene)
    }

    /// Visits a trait declaration; by default hands each of its statements to
    /// `visit_statement`.
    fn visit_trait(&mut self, tr: &Trait) {
        walk_trait(self, tr)
    }

    /// Visits a constraint (`Rule`); by default hands each of its statements
    /// to `visit_statement`.
    fn visit_constraint(&mut self, c: &Rule) {
        walk_constraint(self, c)
    }

    /// Visits a system declaration; by default hands each of its statements
    /// to `visit_statement`.
    fn visit_system(&mut self, sys: &System) {
        walk_system(self, sys)
    }

    /// Visits an evolution declaration. No-op by default.
    fn visit_evolution(&mut self, _evo: &Evo) {}

    /// Visits a function declaration; by default visits every body statement.
    fn visit_function_decl(&mut self, func: &crate::ast::FunctionDecl) {
        walk_function_decl(self, func)
    }

    /// Visits a declaration-level `Statement`. No-op by default.
    fn visit_statement(&mut self, _stmt: &Statement) {}

    /// Visits an executable statement; by default descends into it.
    fn visit_stmt(&mut self, stmt: &Stmt) {
        walk_stmt(self, stmt)
    }

    /// Visits an expression; by default descends into its children.
    fn visit_expr(&mut self, expr: &Expr) {
        walk_expr(self, expr)
    }

    /// Visits a type expression. No-op by default.
    fn visit_type_expr(&mut self, _ty: &TypeExpr) {}

    /// Visits a literal leaf. No-op by default.
    fn visit_literal(&mut self, _lit: &Literal) {}

    /// Visits an identifier leaf. No-op by default.
    fn visit_identifier(&mut self, _name: &str) {}

    /// Visits the operator of a binary expression. No-op by default.
    fn visit_binary_op(&mut self, _op: &BinaryOp) {}

    /// Visits the operator of a unary expression. No-op by default.
    fn visit_unary_op(&mut self, _op: &UnaryOp) {}

    /// Visits a pattern; by default descends into its sub-patterns.
    fn visit_pattern(&mut self, pattern: &Pattern) {
        walk_pattern(self, pattern)
    }

    /// Visits one `match` arm: the pattern, then the guard when present,
    /// then the arm body.
    fn visit_match_arm(&mut self, arm: &MatchArm) {
        self.visit_pattern(&arm.pattern);
        if let Some(guard) = &arm.guard {
            self.visit_expr(guard);
        }
        self.visit_expr(&arm.body);
    }
}
/// In-place traversal over the AST, allowing nodes to be rewritten.
///
/// Unlike `Visitor`, this trait has no hooks for literals, identifiers or
/// operators. The only leaf hooks are `visit_evolution`, `visit_statement`
/// and `visit_type_expr`, all of which do nothing by default. Composite hooks
/// delegate to the matching `walk_*_mut` helper. An override that still wants
/// the children visited must invoke that helper itself.
pub trait MutVisitor {
    /// Visits a top-level declaration; by default dispatches on its variant.
    fn visit_declaration(&mut self, decl: &mut Declaration) {
        walk_declaration_mut(self, decl)
    }

    /// Visits a gene declaration; by default hands each of its statements to
    /// `visit_statement`.
    fn visit_gene(&mut self, gene: &mut Gen) {
        walk_gene_mut(self, gene)
    }

    /// Visits a trait declaration; by default hands each of its statements to
    /// `visit_statement`.
    fn visit_trait(&mut self, tr: &mut Trait) {
        walk_trait_mut(self, tr)
    }

    /// Visits a constraint (`Rule`); by default hands each of its statements
    /// to `visit_statement`.
    fn visit_constraint(&mut self, c: &mut Rule) {
        walk_constraint_mut(self, c)
    }

    /// Visits a system declaration; by default hands each of its statements
    /// to `visit_statement`.
    fn visit_system(&mut self, sys: &mut System) {
        walk_system_mut(self, sys)
    }

    /// Visits an evolution declaration. No-op by default.
    fn visit_evolution(&mut self, _evo: &mut Evo) {}

    /// Visits a function declaration; by default visits every body statement.
    fn visit_function_decl(&mut self, func: &mut crate::ast::FunctionDecl) {
        walk_function_decl_mut(self, func)
    }

    /// Visits a declaration-level `Statement`. No-op by default.
    fn visit_statement(&mut self, _stmt: &mut Statement) {}

    /// Visits an executable statement; by default descends into it.
    fn visit_stmt(&mut self, stmt: &mut Stmt) {
        walk_stmt_mut(self, stmt)
    }

    /// Visits an expression; by default descends into its children.
    fn visit_expr(&mut self, expr: &mut Expr) {
        walk_expr_mut(self, expr)
    }

    /// Visits a type expression. No-op by default.
    fn visit_type_expr(&mut self, _ty: &mut TypeExpr) {}

    /// Visits a pattern; by default descends into its sub-patterns.
    fn visit_pattern(&mut self, pattern: &mut Pattern) {
        walk_pattern_mut(self, pattern)
    }

    /// Visits one `match` arm: the pattern, then the guard when present,
    /// then the arm body.
    fn visit_match_arm(&mut self, arm: &mut MatchArm) {
        self.visit_pattern(&mut arm.pattern);
        if let Some(guard) = &mut arm.guard {
            self.visit_expr(guard);
        }
        self.visit_expr(&mut arm.body);
    }
}
/// Default traversal for a top-level [`Declaration`]: dispatches to the
/// visitor hook that matches the declaration's variant.
///
/// `Const` and `SexVar` declarations have no dedicated hook and are skipped.
///
/// Public so that a [`Visitor`] overriding [`Visitor::visit_declaration`]
/// can still fall back to the default dispatch.
pub fn walk_declaration<V: Visitor + ?Sized>(v: &mut V, decl: &Declaration) {
    match decl {
        Declaration::Gene(gene) => v.visit_gene(gene),
        Declaration::Trait(tr) => v.visit_trait(tr),
        Declaration::Constraint(c) => v.visit_constraint(c),
        Declaration::System(sys) => v.visit_system(sys),
        Declaration::Evolution(evo) => v.visit_evolution(evo),
        Declaration::Function(func) => v.visit_function_decl(func),
        // No hook exists for these variants, so they are intentionally not walked.
        Declaration::Const(_) | Declaration::SexVar(_) => {}
    }
}
/// Default traversal for a function declaration: passes every statement of
/// `func.body` to [`Visitor::visit_stmt`].
///
/// Public so that a [`Visitor`] overriding [`Visitor::visit_function_decl`]
/// can still descend into the body.
pub fn walk_function_decl<V: Visitor + ?Sized>(v: &mut V, func: &crate::ast::FunctionDecl) {
    for stmt in &func.body {
        v.visit_stmt(stmt);
    }
}
/// Default traversal for a gene: passes each entry of `gene.statements` to
/// [`Visitor::visit_statement`].
///
/// The default `visit_statement` does nothing, so the gene body is only
/// inspected when that hook is overridden.
///
/// Public so that a [`Visitor`] overriding [`Visitor::visit_gene`] can still
/// descend into the statements.
pub fn walk_gene<V: Visitor + ?Sized>(v: &mut V, gene: &Gen) {
    for stmt in &gene.statements {
        v.visit_statement(stmt);
    }
}
/// Default traversal for a trait declaration: passes each entry of
/// `tr.statements` to [`Visitor::visit_statement`].
///
/// The default `visit_statement` does nothing, so the body is only inspected
/// when that hook is overridden.
///
/// Public so that a [`Visitor`] overriding [`Visitor::visit_trait`] can still
/// descend into the statements.
pub fn walk_trait<V: Visitor + ?Sized>(v: &mut V, tr: &Trait) {
    for stmt in &tr.statements {
        v.visit_statement(stmt);
    }
}
/// Default traversal for a constraint (`Rule`): passes each entry of
/// `c.statements` to [`Visitor::visit_statement`].
///
/// The default `visit_statement` does nothing, so the body is only inspected
/// when that hook is overridden.
///
/// Public so that a [`Visitor`] overriding [`Visitor::visit_constraint`] can
/// still descend into the statements.
pub fn walk_constraint<V: Visitor + ?Sized>(v: &mut V, c: &Rule) {
    for stmt in &c.statements {
        v.visit_statement(stmt);
    }
}
/// Default traversal for a system declaration: passes each entry of
/// `sys.statements` to [`Visitor::visit_statement`].
///
/// The default `visit_statement` does nothing, so the body is only inspected
/// when that hook is overridden.
///
/// Public so that a [`Visitor`] overriding [`Visitor::visit_system`] can
/// still descend into the statements.
pub fn walk_system<V: Visitor + ?Sized>(v: &mut V, sys: &System) {
    for stmt in &sys.statements {
        v.visit_statement(stmt);
    }
}
/// Default traversal for an executable [`Stmt`].
///
/// What each variant visits:
/// - `Let`: its `value` (other fields are not visited).
/// - `Assign`: `target`, then `value`.
/// - `Expr`: the wrapped expression.
/// - `Return(Some(_))`: the returned expression.
/// - `For`: `iterable`, then each body statement.
/// - `While`: `condition`, then each body statement.
/// - `Loop`: each body statement.
///
/// Every other variant, including `Return(None)`, is ignored.
///
/// Public so that a [`Visitor`] overriding [`Visitor::visit_stmt`] can still
/// descend into the statement.
pub fn walk_stmt<V: Visitor + ?Sized>(v: &mut V, stmt: &Stmt) {
    match stmt {
        Stmt::Let { value, .. } => {
            v.visit_expr(value);
        }
        Stmt::Assign { target, value } => {
            v.visit_expr(target);
            v.visit_expr(value);
        }
        Stmt::Expr(expr) => v.visit_expr(expr),
        Stmt::Return(Some(expr)) => v.visit_expr(expr),
        Stmt::For { iterable, body, .. } => {
            v.visit_expr(iterable);
            for s in body {
                v.visit_stmt(s);
            }
        }
        Stmt::While { condition, body } => {
            v.visit_expr(condition);
            for s in body {
                v.visit_stmt(s);
            }
        }
        Stmt::Loop { body } => {
            for s in body {
                v.visit_stmt(s);
            }
        }
        // Remaining variants carry nothing this walker descends into.
        _ => {}
    }
}
/// Default traversal for an [`Expr`]: visits each direct child of `expr`.
///
/// Leaves and operators are reported through their dedicated hooks
/// (`visit_literal`, `visit_identifier`, `visit_binary_op`,
/// `visit_unary_op`). Embedded types are reported through `visit_type_expr`:
/// the `Reflect` payload and the `type_` field of `Forall`/`Exists`.
///
/// Only the fields named in each arm are visited. For example, `Lambda`
/// visits only its body, `Cast` only its operand, and `Member` only its
/// object.
///
/// Public so that a [`Visitor`] overriding [`Visitor::visit_expr`] can call
/// it to continue into the children.
pub fn walk_expr<V: Visitor + ?Sized>(v: &mut V, expr: &Expr) {
    match expr {
        Expr::Literal(lit) => v.visit_literal(lit),
        Expr::Identifier(name) => v.visit_identifier(name),
        Expr::Binary { left, op, right } => {
            v.visit_expr(left);
            v.visit_binary_op(op);
            v.visit_expr(right);
        }
        Expr::Unary { op, operand } => {
            v.visit_unary_op(op);
            v.visit_expr(operand);
        }
        Expr::Call { callee, args } => {
            v.visit_expr(callee);
            for arg in args {
                v.visit_expr(arg);
            }
        }
        Expr::Member { object, .. } => {
            v.visit_expr(object);
        }
        Expr::Lambda { body, .. } => {
            v.visit_expr(body);
        }
        Expr::If {
            condition,
            then_branch,
            else_branch,
        } => {
            v.visit_expr(condition);
            v.visit_expr(then_branch);
            if let Some(else_expr) = else_branch {
                v.visit_expr(else_expr);
            }
        }
        Expr::Match { scrutinee, arms } => {
            v.visit_expr(scrutinee);
            for arm in arms {
                v.visit_match_arm(arm);
            }
        }
        // Both variants carry the same `Block` payload and are walked identically.
        Expr::Block(Block {
            statements,
            final_expr,
            ..
        })
        | Expr::SexBlock(Block {
            statements,
            final_expr,
            ..
        }) => {
            for stmt in statements {
                v.visit_stmt(stmt);
            }
            if let Some(tail) = final_expr {
                v.visit_expr(tail);
            }
        }
        Expr::Quote(inner) => v.visit_expr(inner),
        Expr::Unquote(inner) => v.visit_expr(inner),
        Expr::QuasiQuote(inner) => v.visit_expr(inner),
        Expr::Eval(inner) => v.visit_expr(inner),
        Expr::Reflect(ty) => v.visit_type_expr(ty),
        Expr::IdiomBracket { func, args } => {
            v.visit_expr(func);
            for arg in args {
                v.visit_expr(arg);
            }
        }
        Expr::Forall(forall_expr) => {
            v.visit_type_expr(&forall_expr.type_);
            v.visit_expr(&forall_expr.body);
        }
        Expr::Exists(exists_expr) => {
            v.visit_type_expr(&exists_expr.type_);
            v.visit_expr(&exists_expr.body);
        }
        Expr::Implies { left, right, .. } => {
            v.visit_expr(left);
            v.visit_expr(right);
        }
        Expr::List(elements) => {
            for elem in elements {
                v.visit_expr(elem);
            }
        }
        Expr::Tuple(elements) => {
            for elem in elements {
                v.visit_expr(elem);
            }
        }
        Expr::Cast { expr, .. } => {
            v.visit_expr(expr);
        }
        Expr::StructLiteral { fields, .. } => {
            for (_, expr) in fields {
                v.visit_expr(expr);
            }
        }
        Expr::Try(inner) => {
            v.visit_expr(inner);
        }
        Expr::This => {}
    }
}
/// Default traversal for a [`Pattern`]: recursively visits sub-patterns.
///
/// Literal patterns are reported through `visit_literal`. Identifier and
/// wildcard patterns are leaves and trigger no hook; in particular,
/// `visit_identifier` is not called for bindings.
///
/// Public so that a [`Visitor`] overriding [`Visitor::visit_pattern`] can
/// still descend into the sub-patterns.
pub fn walk_pattern<V: Visitor + ?Sized>(v: &mut V, pattern: &Pattern) {
    match pattern {
        Pattern::Literal(lit) => v.visit_literal(lit),
        Pattern::Identifier(_) => {}
        Pattern::Wildcard => {}
        Pattern::Constructor { fields, .. } => {
            for p in fields {
                v.visit_pattern(p);
            }
        }
        Pattern::Tuple(patterns) => {
            for p in patterns {
                v.visit_pattern(p);
            }
        }
        Pattern::Or(patterns) => {
            for p in patterns {
                v.visit_pattern(p);
            }
        }
    }
}
/// Default in-place traversal for a top-level [`Declaration`]: dispatches to
/// the [`MutVisitor`] hook that matches the declaration's variant.
///
/// `Const` and `SexVar` declarations have no dedicated hook and are skipped.
///
/// Public so that a [`MutVisitor`] overriding
/// [`MutVisitor::visit_declaration`] can still fall back to the default
/// dispatch.
pub fn walk_declaration_mut<V: MutVisitor + ?Sized>(v: &mut V, decl: &mut Declaration) {
    match decl {
        Declaration::Gene(gene) => v.visit_gene(gene),
        Declaration::Trait(tr) => v.visit_trait(tr),
        Declaration::Constraint(c) => v.visit_constraint(c),
        Declaration::System(sys) => v.visit_system(sys),
        Declaration::Evolution(evo) => v.visit_evolution(evo),
        Declaration::Function(func) => v.visit_function_decl(func),
        // No hook exists for these variants, so they are intentionally not walked.
        Declaration::Const(_) | Declaration::SexVar(_) => {}
    }
}
/// Default in-place traversal for a function declaration: passes every
/// statement of `func.body` to [`MutVisitor::visit_stmt`].
///
/// Public so that a [`MutVisitor`] overriding
/// [`MutVisitor::visit_function_decl`] can still descend into the body.
pub fn walk_function_decl_mut<V: MutVisitor + ?Sized>(v: &mut V, func: &mut crate::ast::FunctionDecl) {
    for stmt in &mut func.body {
        v.visit_stmt(stmt);
    }
}
/// Default in-place traversal for a gene: passes each entry of
/// `gene.statements` to [`MutVisitor::visit_statement`].
///
/// The default `visit_statement` does nothing, so the gene body is only
/// touched when that hook is overridden.
///
/// Public so that a [`MutVisitor`] overriding [`MutVisitor::visit_gene`] can
/// still descend into the statements.
pub fn walk_gene_mut<V: MutVisitor + ?Sized>(v: &mut V, gene: &mut Gen) {
    for stmt in &mut gene.statements {
        v.visit_statement(stmt);
    }
}
/// Default in-place traversal for a trait declaration: passes each entry of
/// `tr.statements` to [`MutVisitor::visit_statement`].
///
/// The default `visit_statement` does nothing, so the body is only touched
/// when that hook is overridden.
///
/// Public so that a [`MutVisitor`] overriding [`MutVisitor::visit_trait`]
/// can still descend into the statements.
pub fn walk_trait_mut<V: MutVisitor + ?Sized>(v: &mut V, tr: &mut Trait) {
    for stmt in &mut tr.statements {
        v.visit_statement(stmt);
    }
}
/// Default in-place traversal for a constraint (`Rule`): passes each entry
/// of `c.statements` to [`MutVisitor::visit_statement`].
///
/// The default `visit_statement` does nothing, so the body is only touched
/// when that hook is overridden.
///
/// Public so that a [`MutVisitor`] overriding
/// [`MutVisitor::visit_constraint`] can still descend into the statements.
pub fn walk_constraint_mut<V: MutVisitor + ?Sized>(v: &mut V, c: &mut Rule) {
    for stmt in &mut c.statements {
        v.visit_statement(stmt);
    }
}
/// Default in-place traversal for a system declaration: passes each entry of
/// `sys.statements` to [`MutVisitor::visit_statement`].
///
/// The default `visit_statement` does nothing, so the body is only touched
/// when that hook is overridden.
///
/// Public so that a [`MutVisitor`] overriding [`MutVisitor::visit_system`]
/// can still descend into the statements.
pub fn walk_system_mut<V: MutVisitor + ?Sized>(v: &mut V, sys: &mut System) {
    for stmt in &mut sys.statements {
        v.visit_statement(stmt);
    }
}
/// Default in-place traversal for an executable [`Stmt`].
///
/// What each variant visits:
/// - `Let`: its `value` (other fields are not visited).
/// - `Assign`: `target`, then `value`.
/// - `Expr`: the wrapped expression.
/// - `Return(Some(_))`: the returned expression.
/// - `For`: `iterable`, then each body statement.
/// - `While`: `condition`, then each body statement.
/// - `Loop`: each body statement.
///
/// Every other variant, including `Return(None)`, is ignored.
///
/// Public so that a [`MutVisitor`] overriding [`MutVisitor::visit_stmt`] can
/// still descend into the statement.
pub fn walk_stmt_mut<V: MutVisitor + ?Sized>(v: &mut V, stmt: &mut Stmt) {
    match stmt {
        Stmt::Let { value, .. } => {
            v.visit_expr(value);
        }
        Stmt::Assign { target, value } => {
            v.visit_expr(target);
            v.visit_expr(value);
        }
        Stmt::Expr(expr) => v.visit_expr(expr),
        Stmt::Return(Some(expr)) => v.visit_expr(expr),
        Stmt::For { iterable, body, .. } => {
            v.visit_expr(iterable);
            for s in body {
                v.visit_stmt(s);
            }
        }
        Stmt::While { condition, body } => {
            v.visit_expr(condition);
            for s in body {
                v.visit_stmt(s);
            }
        }
        Stmt::Loop { body } => {
            for s in body {
                v.visit_stmt(s);
            }
        }
        // Remaining variants carry nothing this walker descends into.
        _ => {}
    }
}
/// Default in-place traversal for an [`Expr`]: visits each direct child of
/// `expr` mutably.
///
/// [`MutVisitor`] has no literal, identifier or operator hooks. `Literal` and
/// `Identifier` are therefore treated as leaves, and the operators of
/// `Binary`/`Unary` are not visited. Embedded types are reported through
/// `visit_type_expr`: the `Reflect` payload and the `type_` field of
/// `Forall`/`Exists`.
///
/// Only the fields named in each arm are visited. For example, `Lambda`
/// visits only its body, `Cast` only its operand, and `Member` only its
/// object.
///
/// Public so that a [`MutVisitor`] overriding [`MutVisitor::visit_expr`] can
/// call it to continue into the children (as the `LiteralDoubler` test does).
pub fn walk_expr_mut<V: MutVisitor + ?Sized>(v: &mut V, expr: &mut Expr) {
    match expr {
        Expr::Binary { left, right, .. } => {
            v.visit_expr(left);
            v.visit_expr(right);
        }
        Expr::Unary { operand, .. } => {
            v.visit_expr(operand);
        }
        Expr::Call { callee, args } => {
            v.visit_expr(callee);
            for arg in args {
                v.visit_expr(arg);
            }
        }
        Expr::Member { object, .. } => {
            v.visit_expr(object);
        }
        Expr::Lambda { body, .. } => {
            v.visit_expr(body);
        }
        Expr::If {
            condition,
            then_branch,
            else_branch,
        } => {
            v.visit_expr(condition);
            v.visit_expr(then_branch);
            if let Some(else_expr) = else_branch {
                v.visit_expr(else_expr);
            }
        }
        Expr::Match { scrutinee, arms } => {
            v.visit_expr(scrutinee);
            for arm in arms {
                v.visit_match_arm(arm);
            }
        }
        // Both variants carry the same `Block` payload and are walked identically.
        Expr::Block(Block {
            statements,
            final_expr,
            ..
        })
        | Expr::SexBlock(Block {
            statements,
            final_expr,
            ..
        }) => {
            for stmt in statements {
                v.visit_stmt(stmt);
            }
            if let Some(tail) = final_expr {
                v.visit_expr(tail);
            }
        }
        Expr::Quote(inner) => v.visit_expr(inner),
        Expr::Unquote(inner) => v.visit_expr(inner),
        Expr::QuasiQuote(inner) => v.visit_expr(inner),
        Expr::Eval(inner) => v.visit_expr(inner),
        Expr::Reflect(ty) => v.visit_type_expr(ty),
        Expr::IdiomBracket { func, args } => {
            v.visit_expr(func);
            for arg in args {
                v.visit_expr(arg);
            }
        }
        Expr::Forall(forall_expr) => {
            v.visit_type_expr(&mut forall_expr.type_);
            v.visit_expr(&mut forall_expr.body);
        }
        Expr::Exists(exists_expr) => {
            v.visit_type_expr(&mut exists_expr.type_);
            v.visit_expr(&mut exists_expr.body);
        }
        Expr::Implies { left, right, .. } => {
            v.visit_expr(left);
            v.visit_expr(right);
        }
        // Leaves: MutVisitor exposes no hook for literals or identifiers.
        Expr::Literal(_) | Expr::Identifier(_) => {}
        Expr::List(elements) => {
            for elem in elements {
                v.visit_expr(elem);
            }
        }
        Expr::Tuple(elements) => {
            for elem in elements {
                v.visit_expr(elem);
            }
        }
        Expr::Cast { expr, .. } => {
            v.visit_expr(expr);
        }
        Expr::StructLiteral { fields, .. } => {
            for (_, expr) in fields {
                v.visit_expr(expr);
            }
        }
        Expr::Try(inner) => {
            v.visit_expr(inner);
        }
        Expr::This => {}
    }
}
/// Default in-place traversal for a [`Pattern`]: recursively visits
/// sub-patterns mutably.
///
/// Literal, identifier and wildcard patterns are leaves. [`MutVisitor`] has
/// no hook for them, so they are not reported.
///
/// Public so that a [`MutVisitor`] overriding [`MutVisitor::visit_pattern`]
/// can still descend into the sub-patterns.
pub fn walk_pattern_mut<V: MutVisitor + ?Sized>(v: &mut V, pattern: &mut Pattern) {
    match pattern {
        Pattern::Constructor { fields, .. } => {
            for p in fields {
                v.visit_pattern(p);
            }
        }
        Pattern::Tuple(patterns) => {
            for p in patterns {
                v.visit_pattern(p);
            }
        }
        Pattern::Or(patterns) => {
            for p in patterns {
                v.visit_pattern(p);
            }
        }
        Pattern::Literal(_) | Pattern::Identifier(_) | Pattern::Wildcard => {}
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Counts every `Expr` node reached during a read-only traversal.
    struct NodeTally {
        count: usize,
    }

    impl Visitor for NodeTally {
        fn visit_expr(&mut self, expr: &Expr) {
            self.count += 1;
            // Hand control back to the default walker so children are counted too.
            walk_expr(self, expr);
        }
    }

    #[test]
    fn test_visitor_counts_expressions() {
        // Shape: 1 + (2 * 3) -> two Binary nodes plus three Literal leaves.
        let product = Expr::Binary {
            left: Box::new(Expr::Literal(Literal::Int(2))),
            op: BinaryOp::Mul,
            right: Box::new(Expr::Literal(Literal::Int(3))),
        };
        let sum = Expr::Binary {
            left: Box::new(Expr::Literal(Literal::Int(1))),
            op: BinaryOp::Add,
            right: Box::new(product),
        };

        let mut tally = NodeTally { count: 0 };
        tally.visit_expr(&sum);

        assert_eq!(tally.count, 5);
    }

    /// Rewrites every integer literal in place to twice its value.
    struct IntDoubler;

    impl MutVisitor for IntDoubler {
        fn visit_expr(&mut self, expr: &mut Expr) {
            if let Expr::Literal(Literal::Int(value)) = expr {
                *value *= 2;
            }
            walk_expr_mut(self, expr);
        }
    }

    #[test]
    fn test_mut_visitor_transforms() {
        let five = Box::new(Expr::Literal(Literal::Int(5)));
        let ten = Box::new(Expr::Literal(Literal::Int(10)));
        let mut expr = Expr::Binary {
            left: five,
            op: BinaryOp::Add,
            right: ten,
        };

        IntDoubler.visit_expr(&mut expr);

        if let Expr::Binary { left, right, .. } = expr {
            assert_eq!(*left, Expr::Literal(Literal::Int(10)));
            assert_eq!(*right, Expr::Literal(Literal::Int(20)));
        } else {
            panic!("Expected binary expression");
        }
    }
}