use crate::CompilerState;
use leo_ast::{
AssignStatement,
Expression,
ExpressionVisitor,
Identifier,
Location,
Node as _,
Program,
StatementVisitor,
Type,
};
use leo_span::Symbol;
use indexmap::IndexMap;
pub struct WriteTransformingVisitor<'a> {
pub state: &'a mut CompilerState,
pub struct_members: IndexMap<Symbol, IndexMap<Symbol, Identifier>>,
pub array_members: IndexMap<Symbol, Vec<Identifier>>,
pub program: Symbol,
}
impl<'a> WriteTransformingVisitor<'a> {
pub fn new(state: &'a mut CompilerState, program: &Program) -> Self {
let visitor = WriteTransformingVisitor {
state,
struct_members: Default::default(),
array_members: Default::default(),
program: Symbol::intern(""),
};
let mut wtf = WriteTransformingFiller(visitor);
wtf.fill(program);
wtf.0
}
}
struct WriteTransformingFiller<'a>(WriteTransformingVisitor<'a>);
impl ExpressionVisitor for WriteTransformingFiller<'_> {
type AdditionalInput = ();
type Output = ();
fn visit_expression(&mut self, _input: &Expression, _additional: &Self::AdditionalInput) -> Self::Output {}
}
impl StatementVisitor for WriteTransformingFiller<'_> {
fn visit_assign(&mut self, input: &AssignStatement) {
self.access_recurse(&input.place);
}
}
impl WriteTransformingFiller<'_> {
fn fill(&mut self, program: &Program) {
for (_, scope) in program.program_scopes.iter() {
for (_, function) in scope.functions.iter() {
self.0.program = scope.program_id.name.name;
self.visit_block(&function.block);
}
}
}
fn access_recurse(&mut self, place: &Expression) -> Identifier {
match place {
Expression::Identifier(identifier) => *identifier,
Expression::ArrayAccess(array_access) => {
let array_name = self.access_recurse(&array_access.array);
let members = self.0.array_members.entry(array_name.name).or_insert_with(|| {
let ty = self.0.state.type_table.get(&array_access.array.id()).unwrap();
let Type::Array(arr) = ty else { panic!("Type checking should have prevented this.") };
(0..arr.length())
.map(|i| {
let id = self.0.state.node_builder.next_id();
let symbol = self.0.state.assigner.unique_symbol(format_args!("{array_name}#{i}"), "$");
self.0.state.type_table.insert(id, arr.element_type().clone());
Identifier::new(symbol, id)
})
.collect()
});
let Expression::Literal(lit) = &array_access.index else {
panic!("Const propagation should have ensured this is a literal.");
};
members[lit
.as_u32()
.expect("Const propagation should have ensured this is in range, and consequently a valid u32.")
as usize]
}
Expression::MemberAccess(member_access) => {
let struct_name = self.access_recurse(&member_access.inner);
let members = self.0.struct_members.entry(struct_name.name).or_insert_with(|| {
let ty = self.0.state.type_table.get(&member_access.inner.id()).unwrap();
let Type::Composite(comp) = ty else {
panic!("Type checking should have prevented this.");
};
let struct_ = self
.0
.state
.symbol_table
.lookup_struct(comp.id.name)
.or_else(|| {
self.0
.state
.symbol_table
.lookup_record(Location::new(comp.program.unwrap_or(self.0.program), comp.id.name))
})
.unwrap();
struct_
.members
.iter()
.map(|member| {
let name = member.name();
let id = self.0.state.node_builder.next_id();
let symbol = self.0.state.assigner.unique_symbol(format_args!("{struct_name}#{name}"), "$");
self.0.state.type_table.insert(id, member.type_.clone());
(member.name(), Identifier::new(symbol, id))
})
.collect()
});
*members.get(&member_access.name.name).expect("Type checking should have ensured this is valid.")
}
Expression::TupleAccess(_) => panic!("TupleAccess writes should have been removed by Destructuring"),
_ => panic!("Type checking should have ensured there are no other places for assignments"),
}
}
}