use crate::CompilerState;
use leo_ast::{
ArrayAccess,
AssignStatement,
AstVisitor,
DefinitionPlace,
DefinitionStatement,
Expression,
Identifier,
IntegerType,
Literal,
Location,
MemberAccess,
Node as _,
Path,
Program,
Statement,
Type,
};
use leo_span::Symbol;
use indexmap::IndexMap;
/// State shared by the write-transforming step: an exclusive borrow of the compiler
/// state plus the tables of member variables recorded for every array, struct or
/// record variable that is written through an index or a field.
pub struct WriteTransformingVisitor<'a> {
/// Compiler state. The code in this file uses its `node_builder`, `type_table`,
/// `assigner` and `symbol_table`.
pub state: &'a mut CompilerState,
/// For each struct/record variable written through a member access:
/// field name -> identifier standing in for that field.
pub struct_members: IndexMap<Symbol, IndexMap<Symbol, Identifier>>,
/// For each array variable written through an index: one identifier per element,
/// stored in element order.
pub array_members: IndexMap<Symbol, Vec<Identifier>>,
/// Name of the program whose functions are currently being processed. Used as the
/// fallback program when a composite type does not name one during record lookup.
pub program: Symbol,
}
impl<'a> WriteTransformingVisitor<'a> {
pub fn new(state: &'a mut CompilerState, program: &Program) -> Self {
let visitor = WriteTransformingVisitor {
state,
struct_members: Default::default(),
array_members: Default::default(),
program: Symbol::intern(""),
};
let mut wtf = WriteTransformingFiller(visitor);
wtf.fill(program);
wtf.0
}
/// Appends to `accumulate` one definition per member variable recorded for `name`,
/// each initialised by reading the matching element (`name[i]`) or field
/// (`name.field`) of `name`.
///
/// Directly after a member's definition, the member is expanded recursively, so
/// nested arrays/structs that have recorded members of their own are covered too.
/// Nothing is appended when `name` has no recorded members.
///
/// # Panics
///
/// Panics if a member identifier has no entry in the type table.
pub fn define_variable_members(&mut self, name: Identifier, accumulate: &mut Vec<Statement>) {
    // In both branches the member collection is cloned out of the table once, because
    // `self` has to be borrowed mutably again for the recursive call inside the loop.
    if let Some(members) = self.array_members.get(&name.name).cloned() {
        for (i, member) in members.into_iter().enumerate() {
            // NOTE(review): the literal is built as an unsuffixed `U8` but registered in
            // the type table as `U32`, while `reconstruct_assign_recurse` builds its
            // index as `U32` with a "{i}u32" string. Left as-is; confirm which is intended.
            let index = Literal::integer(
                IntegerType::U8,
                i.to_string(),
                Default::default(),
                self.state.node_builder.next_id(),
            );
            self.state.type_table.insert(index.id(), Type::Integer(IntegerType::U32));

            // `name[i]`, typed like the member that will hold its value.
            let access = ArrayAccess {
                array: Path::from(name).into_absolute().into(),
                index: index.into(),
                span: Default::default(),
                id: self.state.node_builder.next_id(),
            };
            self.state.type_table.insert(access.id(), self.state.type_table.get(&member.id()).unwrap().clone());

            let def = DefinitionStatement {
                place: DefinitionPlace::Single(member),
                type_: None,
                value: access.into(),
                span: Default::default(),
                id: self.state.node_builder.next_id(),
            };
            accumulate.push(def.into());

            // The member may itself be an aggregate with recorded members.
            self.define_variable_members(member, accumulate);
        }
    } else if let Some(members) = self.struct_members.get(&name.name).cloned() {
        for (field_name, member) in members {
            // `name.field_name`, typed like the member that will hold its value.
            let access = MemberAccess {
                inner: Path::from(name).into_absolute().into(),
                name: Identifier::new(field_name, self.state.node_builder.next_id()),
                span: Default::default(),
                id: self.state.node_builder.next_id(),
            };
            self.state.type_table.insert(access.id(), self.state.type_table.get(&member.id()).unwrap().clone());

            let def = DefinitionStatement {
                place: DefinitionPlace::Single(member),
                type_: None,
                value: access.into(),
                span: Default::default(),
                id: self.state.node_builder.next_id(),
            };
            accumulate.push(def.into());

            // The member may itself be an aggregate with recorded members.
            self.define_variable_members(member, accumulate);
        }
    }
}
pub fn reconstruct_assign_place(&mut self, input: Expression) -> Identifier {
use Expression::*;
match input {
ArrayAccess(array_access) => {
let identifier = self.reconstruct_assign_place(array_access.array);
self.get_array_member(identifier.name, &array_access.index).expect("We have visited all array writes.")
}
Path(path) => leo_ast::Identifier { name: path.identifier().name, span: path.span, id: path.id },
MemberAccess(member_access) => {
let identifier = self.reconstruct_assign_place(member_access.inner);
self.get_struct_member(identifier.name, member_access.name.name)
.expect("We have visited all struct writes.")
}
TupleAccess(_) => panic!("TupleAccess writes should have been removed by Destructuring"),
_ => panic!("Type checking should have ensured there are no other places for assignments"),
}
}
/// Expands the assignment `place = value` into the statements needed to keep the
/// member variables recorded for `place` in sync, appending them to `accumulate`.
///
/// * `place` has array members and `value` is an array literal: each member is
///   assigned the matching literal element (recursively). No assignment to `place`
///   itself is emitted in this case.
/// * `place` has array members and `value` is anything else: `place = value` is
///   emitted, then each member is assigned `place[i]` (recursively).
/// * `place` has struct members: the same two cases, using struct literals and
///   `place.field` reads.
/// * `place` has no recorded members: a plain `place = value` is emitted.
///
/// # Panics
///
/// Panics if a struct literal names a field with no recorded member, or if one of
/// its initializers carries no expression.
pub fn reconstruct_assign_recurse(&self, place: Identifier, value: Expression, accumulate: &mut Vec<Statement>) {
    if let Some(elements) = self.array_members.get(&place.name) {
        match value {
            Expression::Array(literal) => {
                // Element-wise: the i-th member receives the i-th literal element.
                for (&element, rhs) in elements.iter().zip(literal.elements) {
                    self.reconstruct_assign_recurse(element, rhs, accumulate);
                }
            }
            whole => {
                // Assign the aggregate first, then refresh each member from it.
                accumulate.push(
                    AssignStatement {
                        place: Path::from(place).into_absolute().into(),
                        value: whole,
                        span: Default::default(),
                        id: self.state.node_builder.next_id(),
                    }
                    .into(),
                );
                for (position, &element) in elements.iter().enumerate() {
                    // The index literal takes its node id before the access does.
                    let index = Literal::integer(
                        IntegerType::U32,
                        format!("{position}u32"),
                        Default::default(),
                        self.state.node_builder.next_id(),
                    );
                    let read = ArrayAccess {
                        array: Path::from(place).into_absolute().into(),
                        index: index.into(),
                        span: Default::default(),
                        id: self.state.node_builder.next_id(),
                    };
                    self.reconstruct_assign_recurse(element, read.into(), accumulate);
                }
            }
        }
        return;
    }

    if let Some(fields) = self.struct_members.get(&place.name) {
        match value {
            Expression::Struct(literal) => {
                // Field-wise: each initializer feeds the member recorded for its field.
                for initializer in literal.members {
                    let target = *fields.get(&initializer.identifier.name).expect("Member should exist.");
                    let rhs = initializer.expression.expect("This should have been normalized to have a value.");
                    self.reconstruct_assign_recurse(target, rhs, accumulate);
                }
            }
            whole => {
                // Assign the aggregate first, then refresh each member from it.
                accumulate.push(
                    AssignStatement {
                        place: Path::from(place).into_absolute().into(),
                        value: whole,
                        span: Default::default(),
                        id: self.state.node_builder.next_id(),
                    }
                    .into(),
                );
                for (&field, &target) in fields.iter() {
                    // The field identifier takes its node id before the access does.
                    let field_ident = Identifier::new(field, self.state.node_builder.next_id());
                    let read = MemberAccess {
                        inner: Path::from(place).into_absolute().into(),
                        name: field_ident,
                        span: Default::default(),
                        id: self.state.node_builder.next_id(),
                    };
                    self.reconstruct_assign_recurse(target, read.into(), accumulate);
                }
            }
        }
        return;
    }

    // No recorded members: a single ordinary assignment is all that is needed.
    accumulate.push(
        AssignStatement {
            value,
            place: Path::from(place).into_absolute().into(),
            id: self.state.node_builder.next_id(),
            span: Default::default(),
        }
        .into(),
    );
}
}
/// Wrapper around a `WriteTransformingVisitor` used while its member tables are being
/// populated. `WriteTransformingVisitor::new` wraps a fresh visitor in it, calls
/// `fill`, and takes the visitor back out through field `0`.
struct WriteTransformingFiller<'a>(WriteTransformingVisitor<'a>);
// AST visitor for the filler. It takes no additional input, produces no output, and
// overrides only the two hooks below.
impl AstVisitor for WriteTransformingFiller<'_> {
    type AdditionalInput = ();
    type Output = ();

    // Expressions are skipped: this override does nothing, because the filler only
    // inspects the place of assignment statements (see `visit_assign`).
    fn visit_expression(&mut self, _input: &Expression, _additional: &Self::AdditionalInput) -> Self::Output {}

    // Registers member variables for the place being assigned to. The identifier
    // returned by `access_recurse` is not needed here.
    fn visit_assign(&mut self, input: &AssignStatement) {
        let _ = self.access_recurse(&input.place);
    }
}
impl WriteTransformingFiller<'_> {
/// Visits the body of every function in `program`: first the functions of each
/// module, then those of each program scope. Before a body is visited, the wrapped
/// visitor's `program` is set to the name of the owning module's program or of the
/// owning scope's program id.
fn fill(&mut self, program: &Program) {
    for (_, module) in program.modules.iter() {
        self.0.program = module.program_name;
        for block in module.functions.iter().map(|(_, function)| &function.block) {
            self.visit_block(block);
        }
    }

    for (_, scope) in program.program_scopes.iter() {
        let scope_program = scope.program_id.name.name;
        for block in scope.functions.iter().map(|(_, function)| &function.block) {
            // Assigned per function, so a scope without functions leaves `program` untouched.
            self.0.program = scope_program;
            self.visit_block(block);
        }
    }
}
/// Resolves an assignment place to the identifier that stands for it, creating the
/// member variables of every array or struct it passes through the first time that
/// variable is seen.
///
/// * A path resolves to an identifier with the path's final name, span and node id.
/// * An array access resolves its base, makes sure `array_members` has an entry for
///   the base, and returns the member at the (literal) index.
/// * A member access resolves its base, makes sure `struct_members` has an entry for
///   the base, and returns the member recorded for the accessed field.
///
/// # Panics
///
/// Panics if a base has no type-table entry or a type of the wrong kind, if an array
/// length is not available as a `u32`, if an index is not a literal usable as an
/// in-range `u32`, if the struct/record cannot be found in the symbol table, if the
/// field is unknown, or if `place` is a tuple access or any other expression.
fn access_recurse(&mut self, place: &Expression) -> Identifier {
match place {
Expression::Path(path) => Identifier { name: path.identifier().name, span: path.span, id: path.id },
Expression::ArrayAccess(array_access) => {
let array_name = self.access_recurse(&array_access.array);
// First write through this array: create one identifier per element. The closure
// only touches `self.0.state`, which is a different field from the
// `self.0.array_members` entry being filled.
let members = self.0.array_members.entry(array_name.name).or_insert_with(|| {
let ty = self.0.state.type_table.get(&array_access.array.id()).unwrap();
let Type::Array(arr) = ty else { panic!("Type checking should have prevented this.") };
(0..arr.length.as_u32().expect("length should be known at this point"))
.map(|i| {
// Per element: new node id, new symbol, then the element type is recorded for that id.
// NOTE(review): `unique_symbol` presumably returns a fresh name derived from the
// given text and separator - confirm against the assigner.
let id = self.0.state.node_builder.next_id();
let symbol = self.0.state.assigner.unique_symbol(format_args!("{array_name}#{i}"), "$");
self.0.state.type_table.insert(id, arr.element_type().clone());
Identifier::new(symbol, id)
})
.collect()
});
let Expression::Literal(lit) = &array_access.index else {
panic!("Const propagation should have ensured this is a literal.");
};
// Plain indexing: an index past the end of the member list panics here.
members[lit
.as_u32()
.expect("Const propagation should have ensured this is in range, and consequently a valid u32.")
as usize]
}
Expression::MemberAccess(member_access) => {
let struct_name = self.access_recurse(&member_access.inner);
// First write through this struct/record: create one identifier per declared member.
let members = self.0.struct_members.entry(struct_name.name).or_insert_with(|| {
let ty = self.0.state.type_table.get(&member_access.inner.id()).unwrap();
let Type::Composite(comp) = ty else {
panic!("Type checking should have prevented this.");
};
// Try the struct lookup first; otherwise look for a record in the program named by
// the type, or in the current program when the type names none.
let struct_ = self
.0
.state
.symbol_table
.lookup_struct(&comp.path.absolute_path())
.or_else(|| {
self.0.state.symbol_table.lookup_record(&Location::new(
comp.program.unwrap_or(self.0.program),
comp.path.absolute_path(),
))
})
.unwrap();
struct_
.members
.iter()
.map(|member| {
// Per member: new node id, new symbol, then the member type is recorded for that id.
let name = member.name();
let id = self.0.state.node_builder.next_id();
let symbol = self.0.state.assigner.unique_symbol(format_args!("{struct_name}#{name}"), "$");
self.0.state.type_table.insert(id, member.type_.clone());
(member.name(), Identifier::new(symbol, id))
})
.collect()
});
*members.get(&member_access.name.name).expect("Type checking should have ensured this is valid.")
}
Expression::TupleAccess(_) => panic!("TupleAccess writes should have been removed by Destructuring"),
_ => panic!("Type checking should have ensured there are no other places for assignments"),
}
}
}