use leo_ast::{AstReconstructor, Block, Expression, IterationStatement, Node as _, ProgramReconstructor, Statement};
use crate::CompilerState;
/// An AST reconstructor that substitutes expressions using a caller-supplied closure.
///
/// The closure is consulted for every expression that is reconstructed. Returning an
/// expression whose id equals the id of its argument signals "no replacement"; any
/// other result is taken as the substitute.
pub struct Replacer<'a, F: Fn(&Expression) -> Expression> {
    /// Compiler state whose type table and node id builder are read and updated.
    state: &'a mut CompilerState,
    /// When `true`, each reconstructed expression is given a freshly allocated id.
    refresh_expr_ids: bool,
    /// Substitution closure applied to each expression.
    replace: F,
}
impl<'a, F> Replacer<'a, F>
where
F: Fn(&Expression) -> Expression,
{
pub fn new(replace: F, refresh_expr_ids: bool, state: &'a mut CompilerState) -> Self {
Self { replace, refresh_expr_ids, state }
}
}
/// Reconstructs the AST, substituting expressions through `replace` and keeping the
/// type table in sync with the ids of the resulting expressions.
impl<F> AstReconstructor for Replacer<'_, F>
where
F: Fn(&Expression) -> Expression,
{
type AdditionalInput = ();
type AdditionalOutput = ();
/// Applies `replace` to `input` and reconstructs the result.
///
/// If the closure returns an expression with the same id as `input`, no replacement
/// is assumed. `input` is then reconstructed variant-by-variant, so nested
/// expressions are still visited.
///
/// Otherwise the returned expression is used verbatim. Its sub-expressions are NOT
/// passed through this method again.
///
/// When `refresh_expr_ids` is set, the resulting expression gets a new id from
/// `node_builder`. If `input` had an entry in the type table, that type is recorded
/// under the resulting expression's id.
fn reconstruct_expression(&mut self, input: Expression, _additional: &()) -> (Expression, Self::AdditionalOutput) {
// Capture the type now: `input` is moved into the `match` below.
let opt_old_type = self.state.type_table.get(&input.id());
let replaced_expr = (self.replace)(&input);
// An unchanged id is the signal that the closure declined to replace `input`.
// NOTE(review): a closure returning a *different* expression that reuses the same id
// would have its result ignored here. Confirm callers never do that.
let (mut new_expr, additional) = if replaced_expr.id() == input.id() {
// No replacement: dispatch to the per-variant reconstruction so children are visited.
match input {
Expression::Intrinsic(intr) => self.reconstruct_intrinsic(*intr, &()),
Expression::Async(async_) => self.reconstruct_async(async_, &()),
Expression::Array(array) => self.reconstruct_array(array, &()),
Expression::ArrayAccess(access) => self.reconstruct_array_access(*access, &()),
Expression::Binary(binary) => self.reconstruct_binary(*binary, &()),
Expression::Call(call) => self.reconstruct_call(*call, &()),
Expression::Cast(cast) => self.reconstruct_cast(*cast, &()),
Expression::Composite(composite) => self.reconstruct_composite_init(composite, &()),
Expression::Err(err) => self.reconstruct_err(err, &()),
Expression::Path(path) => self.reconstruct_path(path, &()),
Expression::Literal(value) => self.reconstruct_literal(value, &()),
Expression::MemberAccess(access) => self.reconstruct_member_access(*access, &()),
Expression::Repeat(repeat) => self.reconstruct_repeat(*repeat, &()),
Expression::Ternary(ternary) => self.reconstruct_ternary(*ternary, &()),
Expression::Tuple(tuple) => self.reconstruct_tuple(tuple, &()),
Expression::TupleAccess(access) => self.reconstruct_tuple_access(*access, &()),
Expression::Unary(unary) => self.reconstruct_unary(*unary, &()),
Expression::Unit(unit) => self.reconstruct_unit(unit, &()),
}
} else {
// Replacement happened: use the closure's result as-is.
(replaced_expr, Default::default())
};
// Only the top-level resulting expression is re-id'd here. Children were already
// handled by the recursive calls above (or not at all, for a replacement).
if self.refresh_expr_ids {
new_expr.set_id(self.state.node_builder.next_id());
}
// Carry `input`'s type over to the (possibly new) id.
// NOTE(review): if the replacement's id already has a type, this presumably
// overwrites it with `input`'s type. Confirm that is intended.
if let Some(old_type) = opt_old_type {
self.state.type_table.insert(new_expr.id(), old_type);
}
(new_expr, additional)
}
/// Reconstructs every statement of the block and always assigns the block a new id,
/// independent of `refresh_expr_ids`.
///
/// Struct-literal fields are evaluated in the order written. The statements'
/// reconstruction therefore draws ids before the block's own id is allocated.
fn reconstruct_block(&mut self, input: Block) -> (Block, Self::AdditionalOutput) {
(
Block {
statements: input.statements.into_iter().map(|s| self.reconstruct_statement(s).0).collect(),
span: input.span,
id: self.state.node_builder.next_id(),
},
Default::default(),
)
}
/// Reconstructs a loop's optional type annotation, its bounds and its body, then
/// assigns the loop a new id.
///
/// All remaining fields are copied unchanged from `input` via `..input`. The id is
/// allocated after the fields above are rebuilt, because of field evaluation order.
fn reconstruct_iteration(&mut self, input: IterationStatement) -> (Statement, Self::AdditionalOutput) {
(
IterationStatement {
type_: input.type_.map(|ty| self.reconstruct_type(ty).0),
start: self.reconstruct_expression(input.start, &()).0,
stop: self.reconstruct_expression(input.stop, &()).0,
block: self.reconstruct_block(input.block).0,
id: self.state.node_builder.next_id(),
..input
}
.into(),
Default::default(),
)
}
}
/// Program-level reconstruction for `Replacer`.
///
/// No methods are overridden here; the trait's default implementations are used as-is.
impl<F> ProgramReconstructor for Replacer<'_, F>
where
    F: Fn(&Expression) -> Expression,
{
}