use crate::{CompilerState, Replacer, SymbolAccessCollector};
use leo_ast::{
AstReconstructor,
AstVisitor,
Block,
CallExpression,
Expression,
Function,
Identifier,
Input,
Location,
Node,
Path,
TupleAccess,
TupleExpression,
TupleType,
Type,
Variant,
};
use leo_span::{Span, Symbol};
use indexmap::IndexMap;
/// Extracts a [`Block`] into a standalone [`Function`] and produces a call expression
/// that invokes it (see `rewrite_block`).
pub struct BlockToFunctionRewriter<'a> {
    /// Shared compiler state: supplies fresh node IDs and the symbol table used to
    /// resolve the variables a block captures.
    state: &'a mut CompilerState,
    /// Program that owns the generated function; used as the location of the global
    /// path in the generated call.
    current_program: Symbol,
}

impl<'a> BlockToFunctionRewriter<'a> {
    /// Builds a rewriter that holds `state` mutably for its whole lifetime and places
    /// generated calls in `current_program`.
    pub fn new(state: &'a mut CompilerState, current_program: Symbol) -> Self {
        BlockToFunctionRewriter { current_program, state }
    }
}
impl BlockToFunctionRewriter<'_> {
    /// Extracts `input` into a standalone function and builds a call expression that invokes it.
    ///
    /// The access collector reports local variables that the block accesses. Each one that
    /// `is_local_to_or_in_child_scope` does not attribute to the block's own scope becomes a
    /// parameter of the new function. The call passes a local path to the original variable
    /// as the matching argument. Inside the new function body, those accesses are rewritten
    /// to refer to the parameters instead.
    ///
    /// Tuple-typed variables are never passed as a single parameter. Each accessed element
    /// becomes its own parameter, named with the synthetic string `"<var>.<index>"` (the
    /// double quotes are part of the name). The call passes `<var>.<index>` as the argument.
    ///
    /// # Parameters
    /// - `input`: the block to extract. Its span is reused for the function and the call.
    /// - `function_name`: name of the generated function. The call targets it through a
    ///   global path at `Location::new(current_program, [function_name])`.
    /// - `function_variant`: variant assigned to the generated function.
    ///
    /// # Returns
    /// `(function, call)`:
    /// - `function` is the generated [`Function`], with no annotations, no const
    ///   parameters and a `Type::Unit` output.
    /// - `call` is an [`Expression`] that calls it, with arguments in the same order as
    ///   the function's inputs.
    ///
    /// # Panics
    /// Panics with `"must be known by now"` if a captured local variable's type is still unset.
    ///
    /// # Notes
    /// Fresh node IDs are drawn from `self.state.node_builder` in statement order. The exact
    /// ordering of the constructions below therefore determines the IDs of the produced nodes.
    pub fn rewrite_block(
        &mut self,
        input: &Block,
        function_name: Symbol,
        function_variant: Variant,
    ) -> (Function, Expression) {
        // Gather every (path, optional tuple index) accessed anywhere in the block.
        let mut access_collector = SymbolAccessCollector::new(self.state);
        access_collector.visit_block(input);
        // What each captured access is rewritten to inside the new function body.
        // The key is (variable, None) for a whole-variable access, or (variable, Some(i))
        // for an access to tuple element `i`.
        let mut replacements: IndexMap<(Symbol, Option<usize>), Expression> = IndexMap::new();
        // Builds an identifier with a default span and a freshly allocated node ID.
        let make_identifier = |slf: &mut Self, symbol: Symbol| Identifier {
            name: symbol,
            span: Span::default(),
            id: slf.state.node_builder.next_id(),
        };
        // For one captured access, creates the function parameter(s) and the matching call
        // argument(s), and records the body-side rewrite in `replacements`.
        // Returns an empty vector in these cases:
        // - the access was already handled;
        // - a tuple index is requested on a non-tuple type;
        // - the tuple index is out of range.
        let mut make_inputs_and_arguments =
            |slf: &mut Self, symbol: Symbol, var_type: &Type, index_opt: Option<usize>| -> Vec<(Input, Expression)> {
                // Already captured, so there is nothing new to add. This happens when the same
                // access repeats, or when this element was covered by an earlier whole-tuple
                // capture.
                if replacements.contains_key(&(symbol, index_opt)) {
                    return vec![];
                }
                match index_opt {
                    // Access to a single tuple element `symbol.index`.
                    Some(index) => {
                        let Type::Tuple(TupleType { elements }) = var_type else {
                            return vec![];
                        };
                        if index >= elements.len() {
                            return vec![];
                        }
                        // Synthetic parameter name. The literal double quotes are part of the
                        // name. NOTE(review): presumably chosen so it cannot collide with a
                        // user-written identifier — confirm.
                        let synthetic_name = format!("\"{symbol}.{index}\"");
                        let synthetic_symbol = Symbol::intern(&synthetic_name);
                        let identifier = make_identifier(slf, synthetic_symbol);
                        let input = Input {
                            identifier,
                            mode: leo_ast::Mode::None,
                            type_: elements[index].clone(),
                            span: Span::default(),
                            id: slf.state.node_builder.next_id(),
                        };
                        // Inside the body, `symbol.index` becomes a reference to the new parameter.
                        replacements.insert((symbol, Some(index)), Path::from(identifier).to_local().into());
                        // At the call site, the argument is `symbol.index` on the original variable.
                        vec![(
                            input,
                            TupleAccess {
                                tuple: Path::from(make_identifier(slf, symbol)).to_local().into(),
                                index: index.into(),
                                span: Span::default(),
                                id: slf.state.node_builder.next_id(),
                            }
                            .into(),
                        )]
                    }
                    None => match var_type {
                        // Whole tuple: flatten it into one parameter per element. Element
                        // parameters already created by earlier indexed accesses are reused,
                        // so those elements produce no new parameter or argument.
                        Type::Tuple(TupleType { elements }) => {
                            let mut inputs_and_arguments = Vec::with_capacity(elements.len());
                            let mut tuple_elements = Vec::with_capacity(elements.len());
                            for (i, element_type) in elements.iter().enumerate() {
                                let key = (symbol, Some(i));
                                if let Some(existing_expr) = replacements.get(&key) {
                                    tuple_elements.push(existing_expr.clone());
                                    continue;
                                }
                                let synthetic_name = format!("\"{symbol}.{i}\"");
                                let synthetic_symbol = Symbol::intern(&synthetic_name);
                                let identifier = make_identifier(slf, synthetic_symbol);
                                let input = Input {
                                    identifier,
                                    mode: leo_ast::Mode::None,
                                    type_: element_type.clone(),
                                    span: Span::default(),
                                    id: slf.state.node_builder.next_id(),
                                };
                                let expr: Expression = Path::from(identifier).to_local().into();
                                // Record the element rewrite too, so a later `symbol.i` access is
                                // recognised as already captured.
                                replacements.insert(key, expr.clone());
                                tuple_elements.push(expr.clone());
                                inputs_and_arguments.push((
                                    input,
                                    TupleAccess {
                                        tuple: Path::from(make_identifier(slf, symbol)).to_local().into(),
                                        index: i.into(),
                                        span: Span::default(),
                                        id: slf.state.node_builder.next_id(),
                                    }
                                    .into(),
                                ));
                            }
                            // Inside the body, the whole variable becomes a tuple literal built
                            // from the per-element parameters.
                            replacements.insert(
                                (symbol, None),
                                Expression::Tuple(TupleExpression {
                                    elements: tuple_elements,
                                    span: Span::default(),
                                    id: slf.state.node_builder.next_id(),
                                }),
                            );
                            inputs_and_arguments
                        }
                        // Any non-tuple type: pass the variable as a single parameter that reuses
                        // the variable's own name.
                        _ => {
                            let identifier = make_identifier(slf, symbol);
                            let input = Input {
                                identifier,
                                mode: leo_ast::Mode::None,
                                type_: var_type.clone(),
                                span: Span::default(),
                                id: slf.state.node_builder.next_id(),
                            };
                            replacements.insert((symbol, None), Path::from(identifier).to_local().into());
                            let argument = Path::from(make_identifier(slf, symbol)).to_local().into();
                            vec![(input, argument)]
                        }
                    },
                }
            };
        // Turn each captured access into parameter/argument pairs. These are skipped:
        // - global paths;
        // - names that `is_local_to_or_in_child_scope` attributes to the block's own scope;
        // - names with no local symbol-table entry.
        // Parameter order follows the iteration order of `symbol_accesses`.
        let (inputs, arguments): (Vec<_>, Vec<_>) = access_collector
            .symbol_accesses
            .iter()
            .filter_map(|(path, index)| {
                if path.is_global() {
                    return None;
                }
                let local_var_name = path.expect_local_symbol();
                if self.state.symbol_table.is_local_to_or_in_child_scope(input.id(), local_var_name) {
                    return None;
                }
                let var = self.state.symbol_table.lookup_local(local_var_name)?;
                Some(make_inputs_and_arguments(self, local_var_name, &var.type_.expect("must be known by now"), *index))
            })
            .flatten()
            .unzip();
        // Substitution used by the `Replacer`. A bare path, or a `path.index` tuple access, is
        // swapped for its recorded replacement. Any other expression, or one with no recorded
        // replacement, is returned unchanged.
        // NOTE(review): paths are matched by identifier name only, without checking
        // `is_global()`. This assumes no global path shares a name with a captured local — confirm.
        let replace_expr = |expr: &Expression| -> Expression {
            match expr {
                Expression::Path(path) => {
                    replacements.get(&(path.identifier().name, None)).cloned().unwrap_or_else(|| expr.clone())
                }
                Expression::TupleAccess(ta) => {
                    if let Expression::Path(path) = &ta.tuple {
                        replacements
                            .get(&(path.identifier().name, Some(ta.index.value())))
                            .cloned()
                            .unwrap_or_else(|| expr.clone())
                    } else {
                        expr.clone()
                    }
                }
                _ => expr.clone(),
            }
        };
        // NOTE(review): `Replacer::new` (not visible here) defines what the `true` flag means — confirm.
        let mut replacer = Replacer::new(replace_expr, true, self.state);
        let new_block = replacer.reconstruct_block(input.clone()).0;
        // The generated function: the captured values are its only inputs, and it returns unit.
        let function = Function {
            annotations: vec![],
            variant: function_variant,
            identifier: make_identifier(self, function_name),
            const_parameters: vec![],
            input: inputs,
            output: vec![],
            output_type: Type::Unit,
            block: new_block,
            span: input.span,
            id: self.state.node_builder.next_id(),
        };
        // Call site: a global path to `function_name` in the current program. The arguments
        // are in the same order as the function's inputs, because both come from the same
        // `unzip` above.
        let call_to_function = CallExpression {
            function: Path::from(make_identifier(self, function_name))
                .to_global(Location::new(self.current_program, vec![function_name])),
            const_arguments: vec![],
            arguments,
            span: input.span,
            id: self.state.node_builder.next_id(),
        };
        (function, call_to_function.into())
    }
}