mod circuit;
pub use circuit::*;
mod function;
pub use function::*;
use crate::{ArenaNode, AsgContext, AsgConvertError, ImportResolver, Input, Scope};
use leo_ast::{Identifier, PackageAccess, PackageOrPackages, Span};
use indexmap::IndexMap;
use std::cell::{Cell, RefCell};
/// An ASG program built from a `leo_ast::Program` by `Program::new`.
#[derive(Clone)]
pub struct Program<'a> {
/// The ASG context this program was built in (used for ids and allocation).
pub context: AsgContext<'a>,
/// Id obtained from `context.get_id()` when the program was constructed.
pub id: u32,
/// The name copied from the source AST program.
pub name: String,
/// Resolved imported packages, keyed by their dot-joined package path (e.g. `"a.b"`).
pub imported_modules: IndexMap<String, Program<'a>>,
/// Functions defined by this program itself (not imported ones), keyed by name.
pub functions: IndexMap<String, &'a Function<'a>>,
/// Circuits defined by this program itself (not imported ones), keyed by name.
pub circuits: IndexMap<String, &'a Circuit<'a>>,
/// The program scope; its parent scope holds the imported functions and circuits.
pub scope: &'a Scope<'a>,
}
/// A single symbol requested by a flattened import statement.
#[derive(Clone)]
enum ImportSymbol {
/// Import the named symbol under its own name.
Direct(String),
/// Import the symbol named by the first field under the alias given by the second field.
Alias(String, String),
/// Import every function and circuit of the package (`package.*`).
All,
}
/// Flattens one `PackageOrPackages` import tree into `output`.
///
/// `package_segments` is the package path accumulated so far. The package name at this level is
/// appended to it, and every access below it is resolved with
/// `resolve_import_package_access`, each starting from its own copy of the path.
fn resolve_import_package(
    output: &mut Vec<(Vec<String>, ImportSymbol, Span)>,
    mut package_segments: Vec<String>,
    package_or_packages: &PackageOrPackages,
) {
    match package_or_packages {
        PackageOrPackages::Package(package) => {
            package_segments.push(package.name.name.to_string());
            resolve_import_package_access(output, package_segments, &package.access);
        }
        PackageOrPackages::Packages(packages) => {
            package_segments.push(packages.name.name.to_string());
            // Borrow the accesses; cloning the whole vector only to iterate it is wasted work.
            for access in packages.accesses.iter() {
                resolve_import_package_access(output, package_segments.clone(), access);
            }
        }
    }
}
/// Flattens one `PackageAccess` node into `(package path, symbol, span)` entries in `output`.
///
/// - `Star` records an `ImportSymbol::All` for the current path.
/// - `SubPackage` appends the sub-package name to the path and recurses into its access.
/// - `Symbol` records a direct or aliased symbol import, using the symbol's span.
/// - `Multiple` appends the group name and resolves each contained access with its own copy of
///   the path.
fn resolve_import_package_access(
    output: &mut Vec<(Vec<String>, ImportSymbol, Span)>,
    mut package_segments: Vec<String>,
    package: &PackageAccess,
) {
    match package {
        PackageAccess::Star(span) => {
            output.push((package_segments, ImportSymbol::All, span.clone()));
        }
        PackageAccess::SubPackage(subpackage) => {
            // Same steps as the `Package` arm of `resolve_import_package`, done in place so the
            // boxed sub-package (and its whole access subtree) is not deep-cloned just to be
            // wrapped in a temporary `PackageOrPackages`.
            package_segments.push(subpackage.name.name.to_string());
            resolve_import_package_access(output, package_segments, &subpackage.access);
        }
        PackageAccess::Symbol(symbol) => {
            let span = symbol.symbol.span.clone();
            let imported = match symbol.alias.as_ref() {
                Some(alias) => ImportSymbol::Alias(symbol.symbol.name.to_string(), alias.name.to_string()),
                None => ImportSymbol::Direct(symbol.symbol.name.to_string()),
            };
            output.push((package_segments, imported, span));
        }
        PackageAccess::Multiple(packages) => {
            package_segments.push(packages.name.name.to_string());
            for subaccess in packages.accesses.iter() {
                resolve_import_package_access(output, package_segments.clone(), subaccess);
            }
        }
    }
}
impl<'a> Program<'a> {
pub fn new<T: ImportResolver<'a>>(
context: AsgContext<'a>,
program: &leo_ast::Program,
import_resolver: &mut T,
) -> Result<Program<'a>, AsgConvertError> {
let mut imported_symbols: Vec<(Vec<String>, ImportSymbol, Span)> = vec![];
for import in program.imports.iter() {
resolve_import_package(&mut imported_symbols, vec![], &import.package_or_packages);
}
let mut deduplicated_imports: IndexMap<Vec<String>, Span> = IndexMap::new();
for (package, _symbol, span) in imported_symbols.iter() {
deduplicated_imports.insert(package.clone(), span.clone());
}
let mut wrapped_resolver = crate::CoreImportResolver::new(import_resolver);
let mut resolved_packages: IndexMap<Vec<String>, Program> = IndexMap::new();
for (package, span) in deduplicated_imports.iter() {
let pretty_package = package.join(".");
let resolved_package = match wrapped_resolver.resolve_package(
context,
&package.iter().map(|x| &**x).collect::<Vec<_>>()[..],
span,
)? {
Some(x) => x,
None => return Err(AsgConvertError::unresolved_import(&*pretty_package, &Span::default())),
};
resolved_packages.insert(package.clone(), resolved_package);
}
let mut imported_functions: IndexMap<String, &'a Function<'a>> = IndexMap::new();
let mut imported_circuits: IndexMap<String, &'a Circuit<'a>> = IndexMap::new();
for (package, symbol, span) in imported_symbols.into_iter() {
let pretty_package = package.join(".");
let resolved_package = resolved_packages
.get(&package)
.expect("could not find preloaded package");
match symbol {
ImportSymbol::All => {
imported_functions.extend(resolved_package.functions.clone().into_iter());
imported_circuits.extend(resolved_package.circuits.clone().into_iter());
}
ImportSymbol::Direct(name) => {
if let Some(function) = resolved_package.functions.get(&name) {
imported_functions.insert(name.clone(), *function);
} else if let Some(circuit) = resolved_package.circuits.get(&name) {
imported_circuits.insert(name.clone(), *circuit);
} else {
return Err(AsgConvertError::unresolved_import(
&*format!("{}.{}", pretty_package, name),
&span,
));
}
}
ImportSymbol::Alias(name, alias) => {
if let Some(function) = resolved_package.functions.get(&name) {
imported_functions.insert(alias.clone(), *function);
} else if let Some(circuit) = resolved_package.circuits.get(&name) {
imported_circuits.insert(alias.clone(), *circuit);
} else {
return Err(AsgConvertError::unresolved_import(
&*format!("{}.{}", pretty_package, name),
&span,
));
}
}
}
}
let import_scope = match context.arena.alloc(ArenaNode::Scope(Scope {
context,
id: context.get_id(),
parent_scope: Cell::new(None),
circuit_self: Cell::new(None),
variables: RefCell::new(IndexMap::new()),
functions: RefCell::new(imported_functions),
circuits: RefCell::new(imported_circuits),
function: Cell::new(None),
input: Cell::new(None),
})) {
ArenaNode::Scope(c) => c,
_ => unimplemented!(),
};
let scope = import_scope.context.alloc_scope(Scope {
context,
input: Cell::new(Some(Input::new(import_scope))),
id: context.get_id(),
parent_scope: Cell::new(Some(import_scope)),
circuit_self: Cell::new(None),
variables: RefCell::new(IndexMap::new()),
functions: RefCell::new(IndexMap::new()),
circuits: RefCell::new(IndexMap::new()),
function: Cell::new(None),
});
for (name, circuit) in program.circuits.iter() {
assert_eq!(name.name, circuit.circuit_name.name);
let asg_circuit = Circuit::init(scope, circuit)?;
scope.circuits.borrow_mut().insert(name.name.to_string(), asg_circuit);
}
for (name, circuit) in program.circuits.iter() {
assert_eq!(name.name, circuit.circuit_name.name);
let asg_circuit = Circuit::init_member(scope, circuit)?;
scope.circuits.borrow_mut().insert(name.name.to_string(), asg_circuit);
}
for (name, function) in program.functions.iter() {
assert_eq!(name.name, function.identifier.name);
let function = Function::init(scope, function)?;
scope.functions.borrow_mut().insert(name.name.to_string(), function);
}
let mut functions = IndexMap::new();
for (name, function) in program.functions.iter() {
assert_eq!(name.name, function.identifier.name);
let asg_function = *scope.functions.borrow().get(name.name.as_ref()).unwrap();
asg_function.fill_from_ast(function)?;
let name = name.name.to_string();
if functions.contains_key(&name) {
return Err(AsgConvertError::duplicate_function_definition(&name, &function.span));
}
functions.insert(name, asg_function);
}
let mut circuits = IndexMap::new();
for (name, circuit) in program.circuits.iter() {
assert_eq!(name.name, circuit.circuit_name.name);
let asg_circuit = *scope.circuits.borrow().get(name.name.as_ref()).unwrap();
asg_circuit.fill_from_ast(circuit)?;
circuits.insert(name.name.to_string(), asg_circuit);
}
Ok(Program {
context,
id: context.get_id(),
name: program.name.clone(),
functions,
circuits,
imported_modules: resolved_packages
.into_iter()
.map(|(package, program)| (package.join("."), program))
.collect(),
scope,
})
}
pub(crate) fn set_core_mapping(&self, mapping: &str) {
for (_, circuit) in self.circuits.iter() {
circuit.core_mapping.replace(Some(mapping.to_string()));
}
}
}
/// Infinite source of internal identifier prefixes of the form `$_N_`.
///
/// `next` is the counter value the following identifier will be built from; it starts wherever
/// the caller initialises it and increases by one per yielded item.
struct InternalIdentifierGenerator {
    next: usize,
}

impl Iterator for InternalIdentifierGenerator {
    type Item = String;

    /// Yields `$_{counter}_` for the current counter and then advances it; never returns `None`.
    fn next(&mut self) -> Option<String> {
        let current = self.next;
        self.next = current + 1;
        Some(format!("$_{}_", current))
    }
}
/// Flattens `program` and all of its transitively imported modules into one AST program named
/// `"ast_aggregate"`.
///
/// The root program's `main` keeps its name. Every other function, and every circuit, is
/// renamed with a unique `$_N_` prefix so that identically named definitions from different
/// modules cannot collide. The new names are written back into the ASG nodes' `name` cells, so
/// the renaming is visible to anyone else holding these nodes.
///
/// Modules whose path starts with `core.` are not inlined. For each of them a
/// `<module>.*` import statement is emitted instead.
pub fn reform_ast<'a>(program: &Program<'a>) -> leo_ast::Program {
    // Gather every transitively imported module exactly once, keyed by its dotted package path.
    let mut all_programs: IndexMap<String, Program> = IndexMap::new();
    let mut program_stack = program.imported_modules.clone();
    while let Some((module, imported)) = program_stack.pop() {
        if all_programs.contains_key(&module) {
            continue;
        }
        program_stack.extend(imported.imported_modules.clone());
        all_programs.insert(module, imported);
    }
    // The root program is stored under the empty module path.
    all_programs.insert(String::new(), program.clone());

    // Core modules are referenced through import statements rather than inlined; only their
    // names are needed.
    let core_modules: Vec<String> = all_programs
        .keys()
        .filter(|module| module.starts_with("core."))
        .cloned()
        .collect();
    all_programs.retain(|module, _| !module.starts_with("core."));

    let mut all_circuits: IndexMap<String, &'a Circuit<'a>> = IndexMap::new();
    let mut all_functions: IndexMap<String, &'a Function<'a>> = IndexMap::new();
    let mut identifiers = InternalIdentifierGenerator { next: 0 };
    for (module, program) in all_programs.into_iter() {
        // Only the root program's `main` is the entry point. An imported module's `main` is an
        // ordinary function: keeping its bare name would make it collide with (and be silently
        // overwritten by) the root `main` here. Its module's calls to `main` would also resolve
        // to the root entry point.
        let is_root = module.is_empty();
        for (name, circuit) in program.circuits.iter() {
            let identifier = format!("{}{}", identifiers.next().unwrap(), name);
            circuit.name.borrow_mut().name = identifier.clone().into();
            all_circuits.insert(identifier, *circuit);
        }
        for (name, function) in program.functions.iter() {
            let identifier = if is_root && name == "main" {
                "main".to_string()
            } else {
                format!("{}{}", identifiers.next().unwrap(), name)
            };
            function.name.borrow_mut().name = identifier.clone().into();
            all_functions.insert(identifier, *function);
        }
    }

    leo_ast::Program {
        name: "ast_aggregate".to_string(),
        imports: core_modules
            .into_iter()
            .map(|module| leo_ast::ImportStatement {
                package_or_packages: leo_ast::PackageOrPackages::Package(leo_ast::Package {
                    name: Identifier::new(module.into()),
                    access: leo_ast::PackageAccess::Star(Span::default()),
                    span: Default::default(),
                }),
                span: Span::default(),
            })
            .collect(),
        expected_input: vec![],
        functions: all_functions
            .into_iter()
            .map(|(_, function)| (function.name.borrow().clone(), function.into()))
            .collect(),
        circuits: all_circuits
            .into_iter()
            .map(|(_, circuit)| (circuit.name.borrow().clone(), circuit.into()))
            .collect(),
    }
}
/// Converts an ASG program back into an AST program.
///
/// Only this program's own circuits and functions are emitted, under their current names;
/// `imports` and `expected_input` are left empty. Implemented as `From` (rather than `Into`) so
/// the standard blanket impl still provides `(&program).into()` for existing callers.
impl<'a> From<&Program<'a>> for leo_ast::Program {
    fn from(program: &Program<'a>) -> Self {
        leo_ast::Program {
            name: program.name.clone(),
            imports: vec![],
            expected_input: vec![],
            circuits: program
                .circuits
                .iter()
                .map(|(_, circuit)| (circuit.name.borrow().clone(), (*circuit).into()))
                .collect(),
            functions: program
                .functions
                .iter()
                .map(|(_, function)| (function.name.borrow().clone(), (*function).into()))
                .collect(),
        }
    }
}