use super::ProcessingAsyncVisitor;
use crate::{CompilerState, Replacer};
use indexmap::{IndexMap, IndexSet};
use leo_ast::{
AstReconstructor,
AstVisitor,
AsyncExpression,
Block,
CallExpression,
Expression,
Function,
Identifier,
Input,
IterationStatement,
Location,
Node,
Path,
ProgramVisitor,
Statement,
TupleAccess,
TupleExpression,
TupleType,
Type,
Variant,
};
use leo_span::{Span, Symbol};
/// Visitor that records which paths an AST subtree references.
///
/// Each recorded entry is `(absolute path, tuple index)`:
/// - `None` means the path was referenced as a whole (a bare path, or a
///   tuple-style access on a value whose type is a future).
/// - `Some(i)` means field `i` was accessed directly on the path (`x.i`).
struct SymbolAccessCollector<'a> {
    /// Compiler state; its `type_table` is consulted to detect future-typed values.
    state: &'a CompilerState,
    /// Accesses seen so far, de-duplicated and iterated in first-insertion order.
    symbol_accesses: IndexSet<(Vec<Symbol>, Option<usize>)>,
}
/// Records every path reference and every tuple access rooted directly at a path.
impl AstVisitor for SymbolAccessCollector<'_> {
    type AdditionalInput = ();
    type Output = ();

    /// A bare path reference is recorded as a whole-value access.
    fn visit_path(&mut self, input: &Path, _: &Self::AdditionalInput) -> Self::Output {
        let whole_value = (input.absolute_path(), None);
        self.symbol_accesses.insert(whole_value);
    }

    /// `x.i` with `x` a path is recorded as `(x, Some(i))`, unless `x` is a future,
    /// in which case it is recorded as the whole-value access `(x, None)`.
    /// Tuple accesses on any other expression recurse into that expression.
    fn visit_tuple_access(&mut self, input: &TupleAccess, _: &Self::AdditionalInput) -> Self::Output {
        match &input.tuple {
            Expression::Path(path) => {
                let is_future = matches!(self.state.type_table.get(&input.tuple.id()), Some(Type::Future(_)));
                let field = if is_future { None } else { Some(input.index.value()) };
                self.symbol_accesses.insert((path.absolute_path(), field));
            }
            inner => self.visit_expression(inner, &()),
        }
    }
}
// Uses every default `ProgramVisitor` method unchanged; no overrides.
impl ProgramVisitor for SymbolAccessCollector<'_> {}
/// Lifts `async { ... }` expressions out of their enclosing function.
///
/// Each async block becomes a new `Variant::AsyncFunction`, which is pushed onto
/// `self.new_async_functions`. The expression itself is replaced by a call to that
/// function, and the local variables captured from outside the block are passed
/// as arguments.
impl AstReconstructor for ProcessingAsyncVisitor<'_> {
    type AdditionalInput = ();
    type AdditionalOutput = ();

    /// Replaces an async block with a call to a freshly generated async function.
    ///
    /// Steps:
    /// 1. Collect every path / path-rooted tuple access referenced in the block.
    /// 2. Turn each captured local into function parameter(s) plus matching call
    ///    argument(s). A captured local is one that is not a global and not declared
    ///    inside the block.
    /// 3. Rewrite the block body so references to captured locals use the new
    ///    parameters.
    /// 4. Register the new function and return the call expression.
    ///
    /// If more than `self.max_inputs` parameters are needed, this emits
    /// `async_block_capturing_too_many_vars`. The function is still generated.
    fn reconstruct_async(&mut self, input: AsyncExpression, _additional: &()) -> (Expression, Self::AdditionalOutput) {
        // Unique name for the lifted function, derived from the enclosing function's name.
        let finalize_fn_name = self.state.assigner.unique_symbol(self.current_function, "$");
        // Gather all (path, optional tuple index) references inside the async block.
        let mut access_collector = SymbolAccessCollector { state: self.state, symbol_accesses: IndexSet::new() };
        access_collector.visit_async(&input, &());
        // Maps (variable name, optional tuple index) to the expression that replaces
        // that reference inside the lifted function body.
        let mut replacements: IndexMap<(Symbol, Option<usize>), Expression> = IndexMap::new();
        // Builds an identifier with a default span and a fresh node id.
        let make_identifier = |slf: &mut Self, symbol: Symbol| Identifier {
            name: symbol,
            span: Span::default(),
            id: slf.state.node_builder.next_id(),
        };
        // For one captured access, returns the (parameter, call argument) pairs it needs
        // and records the matching entries in `replacements`. Returns an empty vector if
        // this exact (symbol, index) access has already been handled.
        let mut make_inputs_and_arguments =
            |slf: &mut Self, symbol: Symbol, var_type: &Type, index_opt: Option<usize>| -> Vec<(Input, Expression)> {
                if replacements.contains_key(&(symbol, index_opt)) {
                    return vec![];
                }
                match index_opt {
                    // Only field `index` of the tuple is captured: pass just that field.
                    Some(index) => {
                        let Type::Tuple(TupleType { elements }) = var_type else {
                            panic!("Expected tuple type when accessing tuple field: {symbol}.{index}");
                        };
                        // The parameter is named with the quoted string `"symbol.index"`.
                        // NOTE(review): presumably the quotes prevent collisions with
                        // user-written identifiers — confirm.
                        let synthetic_name = format!("\"{symbol}.{index}\"");
                        let synthetic_symbol = Symbol::intern(&synthetic_name);
                        let identifier = make_identifier(slf, synthetic_symbol);
                        let input = Input {
                            identifier,
                            mode: leo_ast::Mode::None,
                            type_: elements[index].clone(),
                            span: Span::default(),
                            id: slf.state.node_builder.next_id(),
                        };
                        // Inside the body, `symbol.index` is rewritten to the new parameter...
                        replacements.insert((symbol, Some(index)), Path::from(identifier).into_absolute().into());
                        // ...and at the call site, `symbol.index` is passed as the argument.
                        vec![(
                            input,
                            TupleAccess {
                                tuple: Path::from(make_identifier(slf, symbol)).into_absolute().into(),
                                index: index.into(),
                                span: Span::default(),
                                id: slf.state.node_builder.next_id(),
                            }
                            .into(),
                        )]
                    }
                    // The value is captured as a whole.
                    None => match var_type {
                        // Tuples are passed element-wise: one parameter per element. Any element
                        // parameter already created by an earlier `symbol.i` access is reused.
                        Type::Tuple(TupleType { elements }) => {
                            let mut inputs_and_arguments = Vec::with_capacity(elements.len());
                            let mut tuple_elements = Vec::with_capacity(elements.len());
                            for (i, element_type) in elements.iter().enumerate() {
                                let key = (symbol, Some(i));
                                if let Some(existing_expr) = replacements.get(&key) {
                                    tuple_elements.push(existing_expr.clone());
                                    continue;
                                }
                                let synthetic_name = format!("\"{symbol}.{i}\"");
                                let synthetic_symbol = Symbol::intern(&synthetic_name);
                                let identifier = make_identifier(slf, synthetic_symbol);
                                let input = Input {
                                    identifier,
                                    mode: leo_ast::Mode::None,
                                    type_: element_type.clone(),
                                    span: Span::default(),
                                    id: slf.state.node_builder.next_id(),
                                };
                                let expr: Expression = Path::from(identifier).into_absolute().into();
                                replacements.insert(key, expr.clone());
                                tuple_elements.push(expr.clone());
                                inputs_and_arguments.push((
                                    input,
                                    TupleAccess {
                                        tuple: Path::from(make_identifier(slf, symbol)).into_absolute().into(),
                                        index: i.into(),
                                        span: Span::default(),
                                        id: slf.state.node_builder.next_id(),
                                    }
                                    .into(),
                                ));
                            }
                            // Inside the body, a bare reference to the tuple is rebuilt as a tuple
                            // expression over the per-element parameters.
                            replacements.insert(
                                (symbol, None),
                                Expression::Tuple(TupleExpression {
                                    elements: tuple_elements,
                                    span: Span::default(),
                                    id: slf.state.node_builder.next_id(),
                                }),
                            );
                            inputs_and_arguments
                        }
                        // Any other type is passed as one parameter named after the variable.
                        _ => {
                            let identifier = make_identifier(slf, symbol);
                            let input = Input {
                                identifier,
                                mode: leo_ast::Mode::None,
                                type_: var_type.clone(),
                                span: Span::default(),
                                id: slf.state.node_builder.next_id(),
                            };
                            replacements.insert((symbol, None), Path::from(identifier).into_absolute().into());
                            let argument = Path::from(make_identifier(slf, symbol)).into_absolute().into();
                            vec![(input, argument)]
                        }
                    },
                }
            };
        // Build the parameter list and the call-argument list from the collected accesses,
        // in the collector's first-seen order, keeping only captured locals.
        let (inputs, arguments): (Vec<_>, Vec<_>) = access_collector
            .symbol_accesses
            .iter()
            .filter_map(|(path, index)| {
                // Paths that resolve to a global of the current program are not captured.
                if self.state.symbol_table.lookup_global(&Location::new(self.current_program, path.to_vec())).is_some() {
                    return None;
                }
                let local_var_name = *path.last().expect("all paths must have at least one segment.");
                // Variables declared inside the async block (or its child scopes) are not captured.
                if self.state.symbol_table.is_local_to_or_in_child_scope(input.block.id(), local_var_name) {
                    return None;
                }
                // Names that do not resolve to a local variable are skipped.
                let var = self.state.symbol_table.lookup_local(local_var_name)?;
                Some(make_inputs_and_arguments(self, local_var_name, &var.type_, *index))
            })
            .flatten()
            .unzip();
        // Rewrites references to captured locals inside the body. A bare path is looked up
        // with index `None`, and a tuple access on a path with `Some(index)`. Anything
        // without a recorded replacement is returned as a clone.
        let replace_expr = |expr: &Expression| -> Expression {
            match expr {
                Expression::Path(path) => {
                    replacements.get(&(path.identifier().name, None)).cloned().unwrap_or_else(|| expr.clone())
                }
                Expression::TupleAccess(ta) => {
                    if let Expression::Path(path) = &ta.tuple {
                        replacements
                            .get(&(path.identifier().name, Some(ta.index.value())))
                            .cloned()
                            .unwrap_or_else(|| expr.clone())
                    } else {
                        expr.clone()
                    }
                }
                _ => expr.clone(),
            }
        };
        // Apply `replace_expr` to a copy of the async block to form the new function body.
        // NOTE(review): the meaning of the `true` flag is defined by `Replacer::new` — confirm.
        let mut replacer = Replacer::new(replace_expr, true, self.state);
        let new_block = replacer.reconstruct_block(input.block.clone()).0;
        if inputs.len() > self.max_inputs {
            self.state.handler.emit_err(leo_errors::StaticAnalyzerError::async_block_capturing_too_many_vars(
                inputs.len(),
                self.max_inputs,
                input.span,
            ));
        }
        // The lifted function has no annotations, no const parameters and no outputs.
        // It returns `Type::Unit` and its body is the rewritten block.
        let function = Function {
            annotations: vec![],
            variant: Variant::AsyncFunction,
            identifier: make_identifier(self, finalize_fn_name),
            const_parameters: vec![],
            input: inputs,
            output: vec![],
            output_type: Type::Unit,
            block: new_block,
            span: input.span,
            id: self.state.node_builder.next_id(),
        };
        self.new_async_functions.push((finalize_fn_name, function));
        // The async expression is replaced by a call to the lifted function in the
        // current program, passing the captured values.
        let call_to_finalize = CallExpression {
            function: Path::new(
                vec![],
                make_identifier(self, finalize_fn_name),
                true,
                Some(vec![finalize_fn_name]),
                Span::default(),
                self.state.node_builder.next_id(),
            ),
            const_arguments: vec![],
            arguments,
            program: Some(self.current_program),
            span: input.span,
            id: self.state.node_builder.next_id(),
        };
        // Record that this pass rewrote the AST.
        self.modified = true;
        (call_to_finalize.into(), ())
    }

    /// Reconstructs each statement of `input` while inside the block's own scope
    /// (`self.in_scope(input.id(), ..)`). Presumably this lets symbol-table lookups
    /// made while lifting nested async blocks resolve against this block — confirm.
    fn reconstruct_block(&mut self, input: Block) -> (Block, Self::AdditionalOutput) {
        self.in_scope(input.id(), |slf| {
            (
                Block {
                    statements: input.statements.into_iter().map(|s| slf.reconstruct_statement(s).0).collect(),
                    span: input.span,
                    id: input.id,
                },
                Default::default(),
            )
        })
    }

    /// Reconstructs a loop's type, bounds and body while inside the loop's own scope.
    /// All other fields are carried over unchanged via `..input`.
    fn reconstruct_iteration(&mut self, input: IterationStatement) -> (Statement, Self::AdditionalOutput) {
        self.in_scope(input.id(), |slf| {
            (
                IterationStatement {
                    type_: input.type_.map(|ty| slf.reconstruct_type(ty).0),
                    start: slf.reconstruct_expression(input.start, &()).0,
                    stop: slf.reconstruct_expression(input.stop, &()).0,
                    block: slf.reconstruct_block(input.block).0,
                    ..input
                }
                .into(),
                Default::default(),
            )
        })
    }
}