use crate::{CompilerState, Pass};
use leo_ast::*;
use leo_errors::Result;
/// Value returned by the `RemoveUnreachable` pass.
pub struct RemoveUnreachableOutput {
/// Final value of the visitor's `changed` flag when the pass finishes.
/// NOTE(review): presumably meant to tell callers whether the AST was altered — confirm.
pub changed: bool,
}
/// Field-less marker type that implements `Pass`; its `do_pass` rebuilds the program
/// through a `RemoveUnreachableVisitor`.
pub struct RemoveUnreachable;
impl Pass for RemoveUnreachable {
    type Input = ();
    type Output = RemoveUnreachableOutput;

    const NAME: &str = "RemoveUnreachable";

    /// Runs the visitor over the whole program and stores the rebuilt AST back in `state`.
    ///
    /// The AST is moved out of `state` while the visitor holds the mutable borrow
    /// (leaving a default value behind) and is written back once the rebuild is done.
    /// The returned output carries the visitor's `changed` flag.
    fn do_pass(_input: Self::Input, state: &mut crate::CompilerState) -> Result<Self::Output> {
        let mut tree = std::mem::take(&mut state.ast);
        let mut remover = RemoveUnreachableVisitor { state, changed: false, has_return: false };
        tree.ast = remover.reconstruct_program(tree.ast);

        // Take the visitor apart to get the state borrow and the flag back.
        let RemoveUnreachableVisitor { state, changed, .. } = remover;
        state.ast = tree;
        Ok(RemoveUnreachableOutput { changed })
    }
}
/// Visitor state used while the program is rebuilt by the `RemoveUnreachable` pass.
pub struct RemoveUnreachableVisitor<'state> {
/// Compiler state borrowed for the duration of the pass; the pass writes the
/// rebuilt AST back through this reference.
pub state: &'state mut CompilerState,
/// Copied into `RemoveUnreachableOutput::changed` when the pass finishes.
pub changed: bool,
/// Raised when a `return` statement is reconstructed. Cleared at function and
/// constructor boundaries, and saved/restored around conditionals and loops.
pub has_return: bool,
}
impl ProgramReconstructor for RemoveUnreachableVisitor<'_> {
    /// Rebuilds a function: every parameter/output type, the declared output type and
    /// the body are passed through the visitor; all other fields are moved over as-is.
    ///
    /// `has_return` is cleared before the body is visited and again afterwards, so the
    /// flag never leaks from one function into the next item.
    fn reconstruct_function(&mut self, input: Function) -> Function {
        self.has_return = false;
        let res = Function {
            annotations: input.annotations,
            variant: input.variant,
            identifier: input.identifier,
            // `input` is owned, so each list is consumed and its elements are moved:
            // the type is rebuilt and the remaining fields are taken from the same
            // element via struct-update syntax, with no cloning involved.
            const_parameters: input
                .const_parameters
                .into_iter()
                .map(|param| ConstParameter { type_: self.reconstruct_type(param.type_).0, ..param })
                .collect(),
            input: input
                .input
                .into_iter()
                .map(|param| Input { type_: self.reconstruct_type(param.type_).0, ..param })
                .collect(),
            output: input
                .output
                .into_iter()
                .map(|out| Output { type_: self.reconstruct_type(out.type_).0, ..out })
                .collect(),
            output_type: self.reconstruct_type(input.output_type).0,
            block: self.reconstruct_block(input.block).0,
            span: input.span,
            id: input.id,
        };
        self.has_return = false;
        res
    }

    /// Rebuilds a constructor by passing its body through the visitor; annotations,
    /// span and id are moved over unchanged.
    ///
    /// `has_return` is cleared before and after the body, exactly as for functions.
    fn reconstruct_constructor(&mut self, input: Constructor) -> Constructor {
        self.has_return = false;
        let res = Constructor {
            annotations: input.annotations,
            block: self.reconstruct_block(input.block).0,
            span: input.span,
            id: input.id,
        };
        self.has_return = false;
        res
    }
}
impl AstReconstructor for RemoveUnreachableVisitor<'_> {
    type AdditionalInput = ();
    type AdditionalOutput = ();

    /// Rebuilds a block, keeping statements up to and including the first one after
    /// which `has_return` is set, and dropping everything that follows.
    ///
    /// Dropped statements are not reconstructed at all. Whenever at least one statement
    /// is dropped, `changed` is raised so the pass output reflects the modification.
    fn reconstruct_block(&mut self, input: Block) -> (Block, Self::AdditionalOutput) {
        let statements = input
            .statements
            .into_iter()
            .scan(false, |return_seen, stmt| {
                if *return_seen {
                    // A previous statement of this block always returns, so this one
                    // (and every later one) is unreachable. Returning `None` ends the scan.
                    self.changed = true;
                    return None;
                }
                let rebuilt = self.reconstruct_statement(stmt).0;
                *return_seen |= self.has_return;
                Some(rebuilt)
            })
            .collect();
        (Block { statements, span: input.span, id: input.id }, Default::default())
    }

    /// Rebuilds an `if`/`else`. After the statement, `has_return` is set when it was
    /// already set beforehand, or when both the `then` branch and the `else` branch set it.
    /// A conditional without an `else` therefore never raises the flag by itself.
    fn reconstruct_conditional(&mut self, input: ConditionalStatement) -> (Statement, Self::AdditionalOutput) {
        // The condition is rebuilt first, before the flag is touched, so the merged
        // branch result computed below can never apply to it.
        let condition = self.reconstruct_expression(input.condition, &Default::default()).0;

        // Each branch is analysed starting from a cleared flag.
        let previous_has_return = core::mem::take(&mut self.has_return);
        let then = self.reconstruct_block(input.then).0;
        let then_block_has_return = core::mem::take(&mut self.has_return);

        let otherwise = input.otherwise.map(|otherwise| Box::new(self.reconstruct_statement(*otherwise).0));
        // With no `else` branch the flag is still cleared here, which reads as "does not return".
        let otherwise_block_has_return = self.has_return;

        self.has_return = previous_has_return || (then_block_has_return && otherwise_block_has_return);
        (ConditionalStatement { condition, then, otherwise, ..input }.into(), Default::default())
    }

    /// Rebuilds a loop. The body is visited with a cleared flag, and the flag is restored
    /// to its pre-loop value afterwards, so a `return` inside the body never marks the
    /// code after the loop as unreachable.
    fn reconstruct_iteration(&mut self, input: IterationStatement) -> (Statement, Self::AdditionalOutput) {
        let prior_has_return = core::mem::take(&mut self.has_return);
        let block = self.reconstruct_block(input.block).0;
        self.has_return = prior_has_return;
        (
            IterationStatement {
                type_: input.type_.map(|ty| self.reconstruct_type(ty).0),
                start: self.reconstruct_expression(input.start, &Default::default()).0,
                stop: self.reconstruct_expression(input.stop, &Default::default()).0,
                block,
                ..input
            }
            .into(),
            Default::default(),
        )
    }

    /// Rebuilds a `return` statement and raises `has_return`.
    ///
    /// The flag is raised only after the returned expression has been rebuilt, so it
    /// cannot influence anything visited while reconstructing that expression.
    fn reconstruct_return(&mut self, input: ReturnStatement) -> (Statement, Self::AdditionalOutput) {
        let expression = self.reconstruct_expression(input.expression, &Default::default()).0;
        self.has_return = true;
        (ReturnStatement { expression, ..input }.into(), Default::default())
    }
}