use super::*;
use leo_ast::{
Composite,
Constructor,
Function,
Location,
Mapping,
Member,
Mode,
NetworkName,
Program,
ProgramScope,
Type,
UpgradeVariant,
Variant,
snarkvm_admin_constructor,
snarkvm_checksum_constructor,
snarkvm_noupgrade_constructor,
};
use leo_span::{Symbol, sym};
use indexmap::IndexMap;
use itertools::Itertools;
use snarkvm::prelude::{CanaryV0, MainnetV0, TestnetV0};
use std::fmt::Write as _;
/// Panic message for the `.expect(..)` calls on `write!`/`writeln!` into an in-memory
/// `String` throughout this file; such writes are not expected to fail.
const EXPECT_STR: &str = "Failed to write code";
impl<'a> CodeGeneratingVisitor<'a> {
    /// Generates the Aleo instructions for `input`.
    ///
    /// The output contains, in order: an `import` line for every stub, the `program`
    /// header, every struct and record in the post order of the struct graph, every
    /// mapping, every non-async function (an async transition is immediately followed by
    /// the `finalize` block generated from its finalizer), and finally the constructor if
    /// the program scope declares one.
    ///
    /// Also stores the program id in `self.program_id` for the rest of the traversal.
    ///
    /// # Panics
    ///
    /// Panics if `input` has no program scope, if the struct graph's post order cannot be
    /// computed, or if an async transition's finalizer cannot be resolved. Earlier passes
    /// are presumably responsible for ruling these out.
    pub fn visit_program(&mut self, input: &'a Program) -> String {
        let mut program_string = String::new();

        for (program_name, _) in input.stubs.iter() {
            writeln!(program_string, "import {program_name}.aleo;").expect(EXPECT_STR);
        }

        // Only the first program scope is compiled.
        let program_scope: &ProgramScope =
            input.program_scopes.values().next().expect("a program must contain at least one program scope");
        self.program_id = Some(program_scope.program_id);
        let this_program = program_scope.program_id.name.name;
        writeln!(program_string, "program {};", program_scope.program_id).expect(EXPECT_STR);

        // Composites are emitted in the struct graph's post order, presumably so that each
        // one is declared before any composite that refers to it.
        let order = self.state.struct_graph.post_order().expect("failed to compute the post order of the struct graph");
        let lookup = |name: &[Symbol]| {
            self.state
                .symbol_table
                .lookup_struct(name)
                .or_else(|| self.state.symbol_table.lookup_record(&Location::new(this_program, name.to_vec())))
        };
        for name in order.into_iter() {
            if let Some(struct_) = lookup(&name) {
                program_string.push_str(&self.visit_struct_or_record(struct_, &name));
            }
        }

        for (_, mapping) in program_scope.mappings.iter() {
            program_string.push_str(&self.visit_mapping(mapping));
        }

        for (_, function) in program_scope.functions.iter() {
            // Async functions are not emitted on their own: each is emitted as the
            // `finalize` block right after the async transition whose finalizer it is.
            if function.variant == Variant::AsyncFunction {
                continue;
            }
            let mut function_string = self.visit_function(function);
            if function.variant == Variant::AsyncTransition {
                // `visit_function_with` names the finalize block after this caller.
                self.finalize_caller = Some(function.identifier.name);
                // Clone only the finalizer rather than the whole function symbol; owning it
                // also keeps it independent of `self` while `self` is mutably borrowed below.
                let finalizer = self
                    .state
                    .symbol_table
                    .lookup_function(&Location::new(this_program, vec![function.identifier.name]))
                    .expect("async transitions must be registered in the symbol table")
                    .finalizer
                    .clone()
                    .expect("async transitions must have a finalizer");
                let finalize_function = program_scope
                    .functions
                    .iter()
                    .find(|(name, _)| finalizer.location.path == [*name])
                    .map(|(_, finalize_fn)| finalize_fn)
                    .expect("the finalizer of an async transition must be defined in this program scope");
                function_string.push_str(&self.visit_function_with(finalize_function, &finalizer.future_inputs));
            }
            program_string.push_str(&function_string);
        }

        if let Some(constructor) = program_scope.constructor.as_ref() {
            program_string.push_str(&self.visit_constructor(constructor));
        }

        program_string
    }

    /// Emits `struct_` as a `record` when `struct_.is_record` is set, otherwise as a `struct`.
    fn visit_struct_or_record(&mut self, struct_: &'a Composite, absolute_path: &[Symbol]) -> String {
        if struct_.is_record {
            self.visit_record(struct_, absolute_path)
        } else {
            self.visit_struct(struct_, absolute_path)
        }
    }

    /// Emits a `struct` declaration named after the legalized `absolute_path`, with one
    /// `<member> as <type>;` line per member in declaration order.
    ///
    /// Registers `absolute_path` in `composite_mapping` as a non-record with visibility
    /// `private`.
    ///
    /// # Panics
    ///
    /// Panics if `absolute_path` cannot be legalized.
    fn visit_struct(&mut self, struct_: &'a Composite, absolute_path: &[Symbol]) -> String {
        self.composite_mapping.insert(absolute_path.to_vec(), (false, String::from("private")));
        let legalized_name = Self::legalize_path(absolute_path).unwrap_or_else(|| {
            panic!("path format cannot be legalized at this point: {}", absolute_path.iter().join("::"))
        });
        let mut output_string = format!("\nstruct {legalized_name}:\n");
        for var in struct_.members.iter() {
            writeln!(output_string, " {} as {};", var.identifier, Self::visit_type(&var.type_)).expect(EXPECT_STR);
        }
        output_string
    }

    /// Emits a `record` declaration named after `record.identifier`.
    ///
    /// The `owner` member is emitted first, followed by the remaining members in
    /// declaration order, each suffixed with its mode (`Mode::None` is emitted as
    /// `private`). Registers `absolute_path` in `composite_mapping` as a record.
    ///
    /// # Panics
    ///
    /// Panics if the record has no `owner` member.
    fn visit_record(&mut self, record: &'a Composite, absolute_path: &[Symbol]) -> String {
        self.composite_mapping.insert(absolute_path.to_vec(), (true, "record".into()));
        let mut output_string = format!("\nrecord {}:\n", record.identifier);

        // Index members by name (borrowing, not cloning) so `owner` can be pulled out and
        // placed first while the rest keep their declaration order.
        let mut member_map: IndexMap<Symbol, &Member> =
            record.members.iter().map(|member| (member.identifier.name, member)).collect();
        let owner = member_map.shift_remove(&sym::owner).expect("records must declare an `owner` member");
        let members = std::iter::once(owner).chain(member_map.into_iter().map(|(_, member)| member));

        for var in members {
            let mode = match var.mode {
                Mode::Constant => "constant",
                Mode::Public => "public",
                Mode::None | Mode::Private => "private",
            };
            writeln!(output_string, " {} as {}.{mode};", var.identifier, Self::visit_type(&var.type_))
                .expect(EXPECT_STR);
        }
        output_string
    }

    /// Emits the instructions for `function`.
    ///
    /// Resets the per-function state first: the register counter, the variable mapping
    /// (pre-seeded with `self`, `block` and `network`) and the set of internal record
    /// inputs. The header depends on the variant: transitions become `function`, plain
    /// functions become `closure`, and async functions become `finalize` named after
    /// `self.finalize_caller`. Inline functions produce an empty string.
    ///
    /// `futures` supplies, in order, the locations used to type each `Future` input.
    ///
    /// # Panics
    ///
    /// Panics on `Variant::Script`, on an async function visited without
    /// `finalize_caller` set, or when a future input has no matching location or a
    /// location whose path has more than one segment.
    fn visit_function_with(&mut self, function: &'a Function, futures: &[Location]) -> String {
        self.next_register = 0;
        self.variable_mapping = IndexMap::new();
        self.variant = Some(function.variant);
        self.variable_mapping.insert(sym::SelfLower, "self".to_string());
        self.variable_mapping.insert(sym::block, "block".to_string());
        self.variable_mapping.insert(sym::network, "network".to_string());
        self.current_function = Some(function);

        let mut function_string = match function.variant {
            Variant::Transition | Variant::AsyncTransition => format!("\nfunction {}:\n", function.identifier),
            Variant::Function => format!("\nclosure {}:\n", function.identifier),
            Variant::AsyncFunction => format!(
                "\nfinalize {}:\n",
                self.finalize_caller.expect("finalize_caller must be set before visiting an async function")
            ),
            // Inline functions emit nothing here (presumably inlined by an earlier pass).
            Variant::Inline => return String::new(),
            Variant::Script => panic!("script should not appear in native code"),
        };

        let mut futures = futures.iter();
        self.internal_record_inputs.clear();
        for input in function.input.iter() {
            let register_string = self.next_register();

            // Track registers holding records that belong to this program.
            if let Type::Composite(comp) = &input.type_ {
                let program = comp.program.unwrap_or_else(|| self.program_id.unwrap().name.name);
                if let Some(record) =
                    self.state.symbol_table.lookup_record(&Location::new(program, comp.path.absolute_path().to_vec()))
                    && (record.external.is_none() || record.external == self.program_id.map(|id| id.name.name))
                {
                    self.internal_record_inputs.insert(register_string.clone());
                }
            }

            let type_string = {
                self.variable_mapping.insert(input.identifier.name, register_string.clone());
                // Inputs without an explicit mode default to private for transitions and
                // public for async functions.
                let visibility = match (self.variant.unwrap(), input.mode) {
                    (Variant::AsyncTransition, Mode::None) | (Variant::Transition, Mode::None) => Mode::Private,
                    (Variant::AsyncFunction, Mode::None) => Mode::Public,
                    _ => input.mode,
                };
                if matches!(input.type_, Type::Future(_)) {
                    let location = futures
                        .next()
                        .expect("Type checking guarantees we have future locations for each future input");
                    let [future_name] = location.path.as_slice() else {
                        panic!("All futures must have a single segment paths since they don't belong to submodules.")
                    };
                    format!("{}.aleo/{}.future", location.program, future_name)
                } else {
                    self.visit_type_with_visibility(&input.type_, visibility)
                }
            };
            writeln!(function_string, " input {register_string} as {type_string};").expect(EXPECT_STR);
        }

        let block_string = self.visit_block(&function.block);
        // A closure or finalize body consisting only of `output` lines (or nothing at all)
        // gets a no-op instruction so that it contains at least one instruction.
        if matches!(self.variant.unwrap(), Variant::Function | Variant::AsyncFunction)
            && block_string.lines().all(|line| line.starts_with(" output "))
        {
            function_string.push_str(" assert.eq true true;\n");
        }
        function_string.push_str(&block_string);
        function_string
    }

    /// Emits `function` with no future-input locations.
    fn visit_function(&mut self, function: &'a Function) -> String {
        self.visit_function_with(function, &[])
    }

    /// Emits the program constructor.
    ///
    /// Generates the snarkVM constructor for the upgrade variant selected for
    /// `self.state.network` (admin, checksum, custom body, or no-upgrade), visiting a
    /// custom body in an `AsyncFunction` context with fresh register state. The result is
    /// then validated with `check_snarkvm_constructor` for that network.
    ///
    /// # Panics
    ///
    /// Panics if the upgrade variant cannot be determined, if a checksum mapping of this
    /// program has a multi-segment path, or if the generated constructor fails validation.
    fn visit_constructor(&mut self, constructor: &'a Constructor) -> String {
        self.next_register = 0;
        self.variable_mapping = IndexMap::new();
        self.variant = Some(Variant::AsyncFunction);
        self.variable_mapping.insert(sym::SelfLower, "self".to_string());
        self.variable_mapping.insert(sym::block, "block".to_string());
        self.variable_mapping.insert(sym::network, "network".to_string());

        let upgrade_variant = constructor
            .get_upgrade_variant_with_network(self.state.network)
            .expect("Type checking should have validated the upgrade variant");
        let constructor_string = match &upgrade_variant {
            UpgradeVariant::Admin { address } => snarkvm_admin_constructor(address),
            UpgradeVariant::Checksum { mapping, key, .. } => {
                let this_program =
                    self.program_id.expect("Program ID should be set before traversing the program").name.name;
                if mapping.program == this_program {
                    // A mapping of this program is referenced by its bare name.
                    let [mapping_name] = &mapping.path[..] else {
                        panic!("Mappings are only allowed in the top level program at this stage");
                    };
                    snarkvm_checksum_constructor(mapping_name, key)
                } else {
                    snarkvm_checksum_constructor(mapping, key)
                }
            }
            UpgradeVariant::Custom => format!("\nconstructor:\n{}\n", self.visit_block(&constructor.block)),
            UpgradeVariant::NoUpgrade => snarkvm_noupgrade_constructor(),
        };

        let validation = match self.state.network {
            NetworkName::MainnetV0 => check_snarkvm_constructor::<MainnetV0>(&constructor_string),
            NetworkName::TestnetV0 => check_snarkvm_constructor::<TestnetV0>(&constructor_string),
            NetworkName::CanaryV0 => check_snarkvm_constructor::<CanaryV0>(&constructor_string),
        };
        if let Err(e) = validation {
            panic!("Compilation produced an invalid constructor: {e}");
        }
        constructor_string
    }

    /// Emits a `mapping` declaration whose key and value types are both `public`, and
    /// records the legalized mapping name in `global_mapping` under the mapping's
    /// identifier.
    ///
    /// # Panics
    ///
    /// Panics if the name cannot be legalized, if the key or value type is a mapping or a
    /// tuple, or if an identifier type is missing from `composite_mapping` or names a
    /// record.
    fn visit_mapping(&mut self, mapping: &'a Mapping) -> String {
        // Legalize once; the same name is used for the declaration and `global_mapping`.
        let legalized_mapping_name = Self::legalize_path(&[mapping.identifier.name])
            .unwrap_or_else(|| panic!("path format cannot be legalized at this point: {}", mapping.identifier));
        let mut mapping_string = format!("\nmapping {legalized_mapping_name}:\n");

        let create_type = |type_: &Type| match type_ {
            Type::Mapping(_) | Type::Tuple(_) => panic!("Mappings cannot contain mappings or tuples."),
            Type::Identifier(identifier) => {
                let (is_record, _) = self
                    .composite_mapping
                    .get(&vec![identifier.name])
                    .expect("identifier types in mappings must name a composite emitted before the mapping");
                assert!(!is_record, "Type checking guarantees that mappings cannot contain records.");
                self.visit_type_with_visibility(type_, Mode::Public)
            }
            type_ => self.visit_type_with_visibility(type_, Mode::Public),
        };
        writeln!(mapping_string, " key as {};", create_type(&mapping.key_type)).expect(EXPECT_STR);
        writeln!(mapping_string, " value as {};", create_type(&mapping.value_type)).expect(EXPECT_STR);

        self.global_mapping.insert(mapping.identifier.name, legalized_mapping_name);
        mapping_string
    }
}