use leo_ast::{
Block,
DeclarationType,
DefinitionStatement,
Expression,
IntegerType,
IterationStatement,
Literal,
Statement,
StatementReconstructor,
Type,
Value,
};
use std::cell::RefCell;
use leo_errors::emitter::Handler;
use crate::{Clusivity, LoopBound, RangeIterator, SymbolTable};
/// State carried while unrolling iteration statements.
pub struct Unroller<'a> {
/// The symbol table that is currently active.
/// `enter_scope` replaces it with one of its child scopes and `exit_scope` puts the parent back.
pub(crate) symbol_table: RefCell<SymbolTable>,
/// The scope index reported by `current_scope_index` when not unrolling.
/// It is reset to `0` by `enter_scope` and set to `index + 1` by `exit_scope`.
pub(crate) scope_index: usize,
/// The handler through which errors found while unrolling are emitted.
pub(crate) handler: &'a Handler,
/// Set to `true` while the body of a single loop iteration is being reconstructed.
/// When set, `current_scope_index` inserts a new block into the symbol table instead of returning `scope_index`.
pub(crate) is_unrolling: bool,
}
impl<'a> Unroller<'a> {
/// Builds an `Unroller` over `symbol_table` that emits errors through `handler`.
///
/// The scope index starts at `0` and the unroller is initially not in unrolling mode.
pub(crate) fn new(symbol_table: SymbolTable, handler: &'a Handler) -> Self {
    let symbol_table = RefCell::new(symbol_table);
    Self {
        symbol_table,
        scope_index: 0,
        handler,
        is_unrolling: false,
    }
}
/// Returns the index of the scope to enter next.
///
/// Outside of unrolling this is simply `scope_index`. While `is_unrolling` is set, a new
/// block is inserted into the active symbol table and the index returned by that insertion
/// is used instead.
pub(crate) fn current_scope_index(&mut self) -> usize {
    if !self.is_unrolling {
        return self.scope_index;
    }
    self.symbol_table.borrow_mut().insert_block()
}
/// Makes the child scope found at `index` of the active symbol table the new active table.
///
/// The previously active table is stored as the `parent` of the new active table and
/// `scope_index` is reset to `0`.
///
/// Returns the value `scope_index` held before the call.
///
/// # Panics
///
/// Panics if `lookup_scope_by_index(index)` on the active table fails.
pub(crate) fn enter_scope(&mut self, index: usize) -> usize {
    // Move the active table out; `self.symbol_table` holds a default value for the moment.
    let parent_table = std::mem::take(&mut self.symbol_table);
    {
        let parent_ref = parent_table.borrow();
        let child_slot = parent_ref.lookup_scope_by_index(index).unwrap();
        // Exchange that default value with the child table stored in the parent.
        self.symbol_table.swap(child_slot);
    }
    // The shared borrow has ended, so the parent can be moved underneath the child.
    self.symbol_table.borrow_mut().parent = Some(Box::new(parent_table.into_inner()));
    std::mem::replace(&mut self.scope_index, 0)
}
/// Leaves the active scope and reinstates its parent as the active symbol table.
///
/// The active table is swapped into the parent's scope slot at `index` before the parent
/// takes its place, and `scope_index` becomes `index + 1`.
///
/// # Panics
///
/// Panics if the active table has no `parent`, or if `lookup_scope_by_index(index)` on the
/// parent fails.
pub(crate) fn exit_scope(&mut self, index: usize) {
    // Detach the parent from the active table and take it out of its box.
    let parent_box = self.symbol_table.borrow_mut().parent.take().unwrap();
    let parent_table = *parent_box;

    // Exchange the active table with the contents of the parent's slot at `index`.
    let slot = parent_table.lookup_scope_by_index(index).unwrap();
    self.symbol_table.swap(slot);

    self.symbol_table = RefCell::new(parent_table);
    self.scope_index = index + 1;
}
/// Replaces the iteration statement `input` with a block holding one nested block per iteration.
///
/// `start` and `stop` are first converted to the loop-bound type `I`. If a conversion fails,
/// the error is emitted through the handler and `Statement::dummy(input.span)` is returned;
/// `stop` is not converted when `start` already failed.
///
/// `input.inclusive` selects between `Clusivity::Inclusive` and `Clusivity::Exclusive` for
/// the `RangeIterator` that drives the unrolling.
pub(crate) fn unroll_iteration_statement<I: LoopBound>(
    &mut self,
    input: IterationStatement,
    start: Value,
    stop: Value,
) -> Statement {
    // Converts a bound to `I`; on failure reports the error and yields a dummy statement.
    let cast_to_number = |bound: Value| -> Result<I, Statement> {
        let converted: Result<I, _> = bound.try_into();
        converted.map_err(|err| {
            self.handler.emit_err(err);
            Statement::dummy(input.span)
        })
    };

    // Convert the bounds in order, stopping at the first failure.
    let bounds = cast_to_number(start).and_then(|lower| cast_to_number(stop).map(|upper| (lower, upper)));
    let (start, stop) = match bounds {
        Ok(pair) => pair,
        Err(dummy) => return dummy,
    };

    // Enter the scope of the loop body.
    let scope_index = self.current_scope_index();
    let previous_scope_index = self.enter_scope(scope_index);

    // Reset the loop body's symbol table before the iterations are reconstructed.
    {
        let mut table = self.symbol_table.borrow_mut();
        table.variables.clear();
        table.scopes.clear();
        table.scope_index = 0;
    }

    let clusivity = if input.inclusive { Clusivity::Inclusive } else { Clusivity::Exclusive };

    // One block per iteration, all collected inside a single outer block.
    let iter_blocks = Statement::Block(Block {
        span: input.span,
        statements: RangeIterator::new(start, stop, clusivity)
            .map(|iteration_count| self.unroll_single_iteration(&input, iteration_count))
            .collect(),
    });

    // Leave the scope of the loop body.
    self.exit_scope(previous_scope_index);

    iter_blocks
}
fn unroll_single_iteration<I: LoopBound>(&mut self, input: &IterationStatement, iteration_count: I) -> Statement {
let scope_index = self.symbol_table.borrow_mut().insert_block();
let previous_scope_index = self.enter_scope(scope_index);
let prior_is_unrolling = self.is_unrolling;
self.is_unrolling = true;
let value = match input.type_ {
Type::Integer(IntegerType::I8) => {
Literal::Integer(IntegerType::I8, iteration_count.to_string(), Default::default())
}
Type::Integer(IntegerType::I16) => {
Literal::Integer(IntegerType::I16, iteration_count.to_string(), Default::default())
}
Type::Integer(IntegerType::I32) => {
Literal::Integer(IntegerType::I32, iteration_count.to_string(), Default::default())
}
Type::Integer(IntegerType::I64) => {
Literal::Integer(IntegerType::I64, iteration_count.to_string(), Default::default())
}
Type::Integer(IntegerType::I128) => {
Literal::Integer(IntegerType::I128, iteration_count.to_string(), Default::default())
}
Type::Integer(IntegerType::U8) => {
Literal::Integer(IntegerType::U8, iteration_count.to_string(), Default::default())
}
Type::Integer(IntegerType::U16) => {
Literal::Integer(IntegerType::U16, iteration_count.to_string(), Default::default())
}
Type::Integer(IntegerType::U32) => {
Literal::Integer(IntegerType::U32, iteration_count.to_string(), Default::default())
}
Type::Integer(IntegerType::U64) => {
Literal::Integer(IntegerType::U64, iteration_count.to_string(), Default::default())
}
Type::Integer(IntegerType::U128) => {
Literal::Integer(IntegerType::U128, iteration_count.to_string(), Default::default())
}
_ => unreachable!(
"The iteration variable must be an integer type. This should be enforced by type checking."
),
};
let mut statements = vec![
self.reconstruct_definition(DefinitionStatement {
declaration_type: DeclarationType::Const,
type_: input.type_.clone(),
value: Expression::Literal(value),
span: Default::default(),
place: Expression::Identifier(input.variable),
})
.0,
];
input.block.statements.clone().into_iter().for_each(|s| {
statements.push(self.reconstruct_statement(s).0);
});
let block = Statement::Block(Block { statements, span: input.block.span });
self.is_unrolling = prior_is_unrolling;
self.exit_scope(previous_scope_index);
block
}
}