use leo_ast::{
AstReconstructor,
Block,
IterationStatement,
Literal,
Node,
NodeID,
Statement,
Type,
interpreter_value::Value,
};
use leo_errors::LoopUnrollerError;
use leo_span::{Span, Symbol};
use itertools::Either;
use crate::CompilerState;
/// Visitor state for the loop-unrolling pass.
///
/// Its inherent methods (`unroll_iteration_statement`, `unroll_single_iteration`) replace an
/// `IterationStatement` with a `Block` holding one reconstructed copy of the loop body per
/// iteration. NOTE(review): the `AstReconstructor` impl (which provides `reconstruct_block`) is
/// presumably defined elsewhere — confirm.
pub struct UnrollingVisitor<'a> {
/// Mutable access to the shared compiler state; this visitor uses its symbol table, type table,
/// node builder, and error handler.
pub state: &'a mut CompilerState,
/// NOTE(review): not read in this chunk; presumably the name of the program being processed — confirm.
pub program: Symbol,
/// NOTE(review): not read in this chunk; presumably the span of a loop that could not be unrolled,
/// if any — confirm against the reconstructor impl.
pub loop_not_unrolled: Option<Span>,
/// NOTE(review): not written in this chunk; presumably records whether any loop was unrolled — confirm.
pub loop_unrolled: bool,
}
impl UnrollingVisitor<'_> {
    /// Runs `func` inside the symbol-table scope identified by `id`.
    ///
    /// Enters the scope for `id`, invokes `func` with `self`, then returns to the parent scope and
    /// yields whatever `func` produced. There is no drop guard, so if `func` panics the parent
    /// scope is not re-entered.
    pub fn in_scope<T>(&mut self, id: NodeID, func: impl FnOnce(&mut Self) -> T) -> T {
        self.state.symbol_table.enter_scope(Some(id));
        let output = func(self);
        self.state.symbol_table.enter_parent();
        output
    }

    /// Replaces `input` with a `Block` containing one unrolled copy of its body per iteration.
    ///
    /// `start` and `stop` are the loop bounds. The iteration range is `start..=stop` when
    /// `input.inclusive` is set and `start..stop` otherwise. Both bounds must be representable as
    /// `i128`. If one is not, a `value_out_of_i128_bounds` error is emitted through the handler
    /// and `Statement::dummy()` is returned instead. `stop` is not examined when `start` already
    /// failed. The resulting block gets a fresh node ID, and its iterations are unrolled inside
    /// the scope keyed by that ID.
    pub fn unroll_iteration_statement(&mut self, input: IterationStatement, start: Value, stop: Value) -> Statement {
        let loop_span = input.span();

        // Converts one bound to `i128`, reporting an error through the handler when it does not fit.
        let to_i128 = |bound: Value| -> Option<i128> {
            let converted = bound.as_i128();
            if converted.is_none() {
                self.state.handler.emit_err(LoopUnrollerError::value_out_of_i128_bounds(bound, loop_span));
            }
            converted
        };

        let Some(start) = to_i128(start) else {
            return Statement::dummy();
        };
        let Some(stop) = to_i128(stop) else {
            return Statement::dummy();
        };

        let block_id = self.state.node_builder.next_id();

        // `Either` lets the inclusive and exclusive ranges share a single iterator type.
        let iteration_counts =
            if !input.inclusive { Either::Right(start..stop) } else { Either::Left(start..=stop) };

        self.in_scope(block_id, |this| {
            let statements = iteration_counts.map(|count| this.unroll_single_iteration(&input, count)).collect();
            Block { span: input.span, statements, id: block_id }.into()
        })
    }

    /// Produces the unrolled form of a single iteration of `input`.
    ///
    /// A new node ID is allocated for an integer literal holding `iteration_count`, and the loop
    /// variable's type is recorded in the type table under that ID. A second node ID is then
    /// allocated and used both as the key of a fresh scope and as the ID of the returned block.
    /// Inside that scope the loop variable is inserted as a local constant bound to the literal.
    /// The loop body is then duplicated, reconstructed by this visitor, and wrapped in the
    /// returned block.
    ///
    /// # Panics
    ///
    /// Panics if the loop variable has no entry in the type table, or if its type is not an
    /// integer type. Type checking is expected to rule out both.
    fn unroll_single_iteration(&mut self, input: &IterationStatement, iteration_count: i128) -> Statement {
        // Node ID for the literal that stands in for the loop variable in this iteration.
        let literal_id = self.state.node_builder.next_id();
        let variable_type =
            self.state.type_table.get(&input.variable.id()).expect("guaranteed to have a type after type checking");
        self.state.type_table.insert(literal_id, variable_type.clone());

        // Node ID for the block wrapping this iteration; it also keys the iteration's scope.
        let wrapper_id = self.state.node_builder.next_id();

        let Type::Integer(integer_type) = &variable_type else {
            unreachable!("Type checking enforces that the iteration variable is of integer type");
        };
        let integer_type = *integer_type;

        self.in_scope(wrapper_id, |this| {
            // Bind the loop variable to this iteration's value as a local constant.
            let literal = Literal::integer(integer_type, iteration_count.to_string(), Default::default(), literal_id);
            this.state.symbol_table.insert_local_const(input.variable.name, literal.into());

            let body_copy =
                super::duplicate::duplicate(input.block.clone(), &mut this.state.symbol_table, &this.state.node_builder);
            let unrolled_body: Statement = this.reconstruct_block(body_copy).0.into();

            Block { statements: vec![unrolled_body], span: input.span(), id: wrapper_id }.into()
        })
    }
}