use std::collections::{HashMap, HashSet};
use std::iter::FromIterator;
use crate::{
error::*, parse_tree::Scrutinee, parse_tree::*, span::Span, type_engine::IntegerBits, AstNode,
AstNodeContent, CodeBlock, Declaration, Expression, ReturnStatement, TypeInfo, WhileLoop,
};
/// Orders `nodes` so that each node comes after the nodes it depends on.
///
/// Dependencies are gathered for every named top-level declaration. If any function is
/// recursive (directly or through a chain of calls), the recursion errors are returned,
/// sorted by `span().0`, together with an empty node list. Otherwise every node is fed, in
/// its original order, through [`insert_into_ordered_nodes`] to build the ordered list.
pub(crate) fn order_ast_nodes_by_dependency<'sc>(
    nodes: Vec<AstNode<'sc>>,
) -> CompileResult<'sc, Vec<AstNode<'sc>>> {
    let decl_dependencies =
        DependencyMap::from_iter(nodes.iter().filter_map(Dependencies::gather_from_decl_node));

    let mut recursion_errors = find_recursive_calls(&decl_dependencies);
    if recursion_errors.is_empty() {
        let mut ordered_nodes: Vec<AstNode<'sc>> = Vec::new();
        for node in nodes {
            ordered_nodes = insert_into_ordered_nodes(&decl_dependencies, ordered_nodes, node);
        }
        return ok(ordered_nodes, Vec::new(), Vec::new());
    }

    // The errors come out in hash-map iteration order; sort them by `span().0` (presumably
    // the start position) so the report is stable.
    recursion_errors.sort_by(|lhs, rhs| lhs.span().0.cmp(&rhs.span().0));
    err(Vec::new(), recursion_errors)
}
/// Runs the recursion check from every symbol keyed in `decl_dependencies` and collects
/// the error found for each one (symbols that are not recursive functions contribute none).
fn find_recursive_calls<'sc>(decl_dependencies: &DependencyMap<'sc>) -> Vec<CompileError<'sc>> {
    decl_dependencies
        .keys()
        .filter_map(|declared_sym| find_recursive_call(decl_dependencies, declared_sym))
        .collect()
}
/// Checks whether the function declared as `fn_sym` can reach itself through calls.
///
/// Only `Fn` symbols carrying a declaration span are checked (that span is attached to the
/// resulting error); any other symbol yields `None`.
fn find_recursive_call<'sc>(
    decl_dependencies: &DependencyMap<'sc>,
    fn_sym: &DependentSymbol<'sc>,
) -> Option<CompileError<'sc>> {
    let fn_span = match fn_sym {
        DependentSymbol::Fn(_, Some(span)) => span,
        _ => return None,
    };
    find_recursive_call_chain(decl_dependencies, fn_sym, fn_span, &mut Vec::new())
}
/// Depth-first search for a call path that leads back to the function at `chain[0]`.
///
/// `chain` holds the names of the functions on the current search path (the starting
/// function first). Only `Fn` symbols are followed. When a function already on the path
/// is reached again, an error is built only if it is the starting function; the error's
/// call chain is the path without that starting function. `fn_span` is the span reported
/// in the error. `chain` is restored to its original contents before returning.
fn find_recursive_call_chain<'sc>(
    decl_dependencies: &DependencyMap<'sc>,
    fn_sym: &DependentSymbol<'sc>,
    fn_span: &Span<'sc>,
    chain: &mut Vec<&'sc str>,
) -> Option<CompileError<'sc>> {
    let name: &'sc str = match fn_sym {
        DependentSymbol::Fn(name, _) => *name,
        _ => return None,
    };

    if chain.contains(&name) {
        // A cycle that does not pass through the starting function is not reported from
        // here; it is found again when the search starts from one of its own members.
        if chain[0] == name {
            return Some(build_recursion_error(name, fn_span.clone(), &chain[1..]));
        }
        return None;
    }

    let deps_set = decl_dependencies.get(fn_sym)?;
    chain.push(name);
    let found = deps_set.deps.iter().find_map(|callee| {
        find_recursive_call_chain(decl_dependencies, callee, fn_span, chain)
    });
    chain.pop();
    found
}
/// Builds the error for a recursive function `fn_sym`.
///
/// An empty `chain` means the function calls itself directly (`RecursiveCall`). Otherwise
/// `chain` lists the intermediate functions and is rendered as "a", "a and b", or
/// "a, b and c" inside a `RecursiveCallChain`.
fn build_recursion_error<'sc>(
    fn_sym: &'sc str,
    span: Span<'sc>,
    chain: &[&'sc str],
) -> CompileError<'sc> {
    let call_chain = match chain.split_last() {
        None => {
            return CompileError::RecursiveCall {
                fn_name: fn_sym,
                span,
            }
        }
        Some((last, [])) => last.to_string(),
        Some((last, leading)) => format!("{} and {}", leading.join(", "), last),
    };
    CompileError::RecursiveCallChain {
        fn_name: fn_sym,
        call_chain,
        span,
    }
}
/// Maps each named top-level declaration to the set of symbols it refers to.
type DependencyMap<'sc> = HashMap<DependentSymbol<'sc>, Dependencies<'sc>>;
/// Inserts `node` directly in front of the first node in `ordered_nodes` that depends on
/// it (as decided by [`depends_on`]), or appends it when no existing node does.
/// Returns the updated list.
fn insert_into_ordered_nodes<'sc>(
    decl_dependencies: &DependencyMap<'sc>,
    mut ordered_nodes: Vec<AstNode<'sc>>,
    node: AstNode<'sc>,
) -> Vec<AstNode<'sc>> {
    let first_dependant = ordered_nodes
        .iter()
        .position(|existing| depends_on(decl_dependencies, existing, &node));
    match first_dependant {
        Some(idx) => ordered_nodes.insert(idx, node),
        None => ordered_nodes.push(node),
    }
    ordered_nodes
}
fn depends_on<'sc>(
decl_dependencies: &DependencyMap<'sc>,
dependant_node: &AstNode<'sc>,
dependee_node: &AstNode<'sc>,
) -> bool {
match (&dependant_node.content, &dependee_node.content) {
(AstNodeContent::IncludeStatement(_), AstNodeContent::IncludeStatement(_)) => false,
(_, AstNodeContent::IncludeStatement(_)) => true,
(AstNodeContent::IncludeStatement(_), AstNodeContent::UseStatement(_)) => false,
(AstNodeContent::UseStatement(_), AstNodeContent::UseStatement(_)) => false,
(_, AstNodeContent::UseStatement(_)) => true,
(AstNodeContent::IncludeStatement(_), AstNodeContent::Declaration(_)) => false,
(AstNodeContent::UseStatement(_), AstNodeContent::Declaration(_)) => false,
(AstNodeContent::Declaration(dependant), AstNodeContent::Declaration(dependee)) => {
match (decl_name(dependant), decl_name(dependee)) {
(Some(dependant_name), Some(dependee_name)) => decl_dependencies
.get(&dependant_name)
.map(|deps_set| deps_set.deps.contains(&dependee_name))
.unwrap_or(false),
_ => false,
}
}
(_, AstNodeContent::Declaration(_)) => true,
_ => false,
}
}
/// The set of symbols a declaration refers to, accumulated by the `gather_from_*` methods.
#[derive(Debug)]
struct Dependencies<'sc> {
    /// Every symbol the declaration was found to reference.
    deps: HashSet<DependentSymbol<'sc>>,
}
impl<'sc> Dependencies<'sc> {
fn gather_from_decl_node(
node: &AstNode<'sc>,
) -> Option<(DependentSymbol<'sc>, Dependencies<'sc>)> {
match &node.content {
AstNodeContent::Declaration(decl) => decl_name(decl).map(|name| {
(
name,
Dependencies {
deps: HashSet::new(),
}
.gather_from_decl(decl),
)
}),
_ => None,
}
}
fn gather_from_decl(self, decl: &Declaration<'sc>) -> Self {
match decl {
Declaration::VariableDeclaration(VariableDeclaration {
type_ascription,
body,
..
}) => self
.gather_from_typeinfo(type_ascription)
.gather_from_expr(body),
Declaration::ConstantDeclaration(ConstantDeclaration {
type_ascription,
value,
..
}) => self
.gather_from_typeinfo(type_ascription)
.gather_from_expr(value),
Declaration::FunctionDeclaration(fn_decl) => self.gather_from_fn_decl(fn_decl),
Declaration::StructDeclaration(StructDeclaration {
fields,
type_parameters,
..
}) => self
.gather_from_iter(fields.iter(), |deps, field| {
deps.gather_from_typeinfo(&field.r#type)
})
.gather_from_traits(type_parameters),
Declaration::EnumDeclaration(EnumDeclaration {
variants,
type_parameters,
..
}) => self
.gather_from_iter(variants.iter(), |deps, variant| {
deps.gather_from_typeinfo(&variant.r#type)
})
.gather_from_traits(type_parameters),
Declaration::Reassignment(decl) => self.gather_from_expr(&decl.rhs),
Declaration::TraitDeclaration(TraitDeclaration {
interface_surface,
methods,
type_parameters,
..
}) => self
.gather_from_iter(interface_surface.iter(), |deps, sig| {
deps.gather_from_iter(sig.parameters.iter(), |deps, param| {
deps.gather_from_typeinfo(¶m.r#type)
})
.gather_from_typeinfo(&sig.return_type)
})
.gather_from_iter(methods.iter(), |deps, fn_decl| {
deps.gather_from_fn_decl(fn_decl)
})
.gather_from_traits(type_parameters),
Declaration::ImplTrait(ImplTrait {
trait_name,
type_implementing_for,
type_arguments,
functions,
..
}) => self
.gather_from_call_path(trait_name, false, false)
.gather_from_typeinfo(type_implementing_for)
.gather_from_traits(type_arguments)
.gather_from_iter(functions.iter(), |deps, fn_decl| {
deps.gather_from_fn_decl(fn_decl)
}),
Declaration::ImplSelf(ImplSelf {
type_implementing_for,
type_arguments,
functions,
..
}) => self
.gather_from_typeinfo(type_implementing_for)
.gather_from_traits(type_arguments)
.gather_from_iter(functions.iter(), |deps, fn_decl| {
deps.gather_from_fn_decl(fn_decl)
}),
Declaration::AbiDeclaration(AbiDeclaration {
interface_surface,
methods,
..
}) => self
.gather_from_iter(interface_surface.iter(), |deps, sig| {
deps.gather_from_iter(sig.parameters.iter(), |deps, param| {
deps.gather_from_typeinfo(¶m.r#type)
})
.gather_from_typeinfo(&sig.return_type)
})
.gather_from_iter(methods.iter(), |deps, fn_decl| {
deps.gather_from_fn_decl(fn_decl)
}),
Declaration::StorageDeclaration(StorageDeclaration { fields, .. }) => self
.gather_from_iter(
fields.iter(),
|deps,
StorageField {
r#type,
initializer,
..
}| {
deps.gather_from_typeinfo(r#type)
.gather_from_expr(initializer)
},
),
}
}
fn gather_from_fn_decl(self, fn_decl: &FunctionDeclaration<'sc>) -> Self {
let FunctionDeclaration {
parameters,
return_type,
body,
type_parameters,
..
} = fn_decl;
self.gather_from_iter(parameters.iter(), |deps, param| {
deps.gather_from_typeinfo(¶m.r#type)
})
.gather_from_typeinfo(return_type)
.gather_from_block(body)
.gather_from_traits(type_parameters)
}
fn gather_from_expr(mut self, expr: &Expression<'sc>) -> Self {
match expr {
Expression::VariableExpression { .. } => self,
Expression::FunctionApplication {
name, arguments, ..
} => self
.gather_from_call_path(name, false, true)
.gather_from_iter(arguments.iter(), |deps, arg| deps.gather_from_expr(arg)),
Expression::LazyOperator { lhs, rhs, .. } => {
self.gather_from_expr(lhs).gather_from_expr(rhs)
}
Expression::IfExp {
condition,
then,
r#else,
..
} => if let Some(else_expr) = r#else {
self.gather_from_expr(else_expr)
} else {
self
}
.gather_from_expr(condition)
.gather_from_expr(then),
Expression::CodeBlock { contents, .. } => self.gather_from_block(contents),
Expression::Array { contents, .. } => {
self.gather_from_iter(contents.iter(), |deps, expr| deps.gather_from_expr(expr))
}
Expression::ArrayIndex { prefix, index, .. } => {
self.gather_from_expr(prefix).gather_from_expr(index)
}
Expression::StructExpression {
struct_name,
fields,
..
} => {
self.deps.insert(DependentSymbol::Symbol(
struct_name.primary_name.to_string(),
));
self.gather_from_iter(fields.iter(), |deps, field| {
deps.gather_from_expr(&field.value)
})
}
Expression::SubfieldExpression { prefix, .. } => self.gather_from_expr(prefix),
Expression::DelineatedPath { call_path, .. } => {
self.gather_from_call_path(call_path, true, false)
}
Expression::MethodApplication { arguments, .. } => {
self.gather_from_iter(arguments.iter(), |deps, arg| deps.gather_from_expr(arg))
}
Expression::AsmExpression { asm, .. } => self
.gather_from_iter(asm.registers.iter(), |deps, register| {
deps.gather_from_opt_expr(®ister.initializer)
})
.gather_from_typeinfo(&asm.return_type),
Expression::MatchExpression {
primary_expression,
branches,
..
} => self.gather_from_expr(primary_expression).gather_from_iter(
branches.iter(),
|deps, branch| {
match &branch.condition {
MatchCondition::CatchAll(_) => deps,
MatchCondition::Scrutinee(scrutinee) => {
deps.gather_from_scrutinee(scrutinee)
}
}
.gather_from_expr(&branch.result)
},
),
Expression::AbiCast { .. } => self,
Expression::Literal { .. } => self,
Expression::Unit { .. } => self,
Expression::DelayedMatchTypeResolution { .. } => self,
}
}
fn gather_from_scrutinee(self, scrutinee: &Scrutinee<'sc>) -> Self {
match scrutinee {
Scrutinee::Unit { .. } => self,
Scrutinee::Literal { .. } => self,
Scrutinee::Variable { .. } => self,
Scrutinee::StructScrutinee { .. } => self,
Scrutinee::EnumScrutinee { .. } => self,
}
}
fn gather_from_opt_expr(self, opt_expr: &Option<Expression<'sc>>) -> Self {
match opt_expr {
None => self,
Some(expr) => self.gather_from_expr(expr),
}
}
fn gather_from_block(self, block: &CodeBlock<'sc>) -> Self {
self.gather_from_iter(block.contents.iter(), |deps, node| {
deps.gather_from_node(node)
})
}
fn gather_from_node(self, node: &AstNode<'sc>) -> Self {
match &node.content {
AstNodeContent::ReturnStatement(ReturnStatement { expr }) => {
self.gather_from_expr(expr)
}
AstNodeContent::Expression(expr) => self.gather_from_expr(expr),
AstNodeContent::ImplicitReturnExpression(expr) => self.gather_from_expr(expr),
AstNodeContent::Declaration(decl) => self.gather_from_decl(decl),
AstNodeContent::WhileLoop(WhileLoop { condition, body }) => {
self.gather_from_expr(condition).gather_from_block(body)
}
AstNodeContent::UseStatement(_) => self,
AstNodeContent::IncludeStatement(_) => self,
}
}
fn gather_from_call_path(
mut self,
call_path: &CallPath<'sc>,
use_prefix: bool,
is_fn_app: bool,
) -> Self {
if call_path.prefixes.is_empty() {
self.deps.insert(if is_fn_app {
DependentSymbol::Fn(call_path.suffix.primary_name, None)
} else {
DependentSymbol::Symbol(call_path.suffix.primary_name.to_string())
});
} else if use_prefix && call_path.prefixes.len() == 1 {
self.deps.insert(DependentSymbol::Symbol(
call_path.prefixes[0].primary_name.to_string(),
));
}
self
}
fn gather_from_traits(mut self, type_parameters: &[TypeParameter<'sc>]) -> Self {
for type_param in type_parameters {
for constraint in &type_param.trait_constraints {
self.deps.insert(DependentSymbol::Symbol(
constraint.name.primary_name.to_string(),
));
}
}
self
}
fn gather_from_typeinfo(mut self, type_info: &TypeInfo) -> Self {
if let TypeInfo::Custom { name } = type_info {
self.deps.insert(DependentSymbol::Symbol(name.clone()));
}
self
}
fn gather_from_iter<I: Iterator, F: FnMut(Self, I::Item) -> Self>(self, iter: I, f: F) -> Self {
iter.fold(self, f)
}
}
/// A name that a declaration can be keyed by, or depend on, in a [`DependencyMap`].
#[derive(Debug, Eq)]
enum DependentSymbol<'sc> {
    /// A non-function name: constants, structs, enums, traits, ABIs, and referenced types.
    Symbol(String),
    /// A function name. The span is `Some` for the declaration itself and `None` at call
    /// sites; it is ignored by the `PartialEq` and `Hash` impls.
    Fn(&'sc str, Option<Span<'sc>>),
    /// An impl block: the trait name (`"self"` for inherent impls) and the name of the
    /// implementing type as produced by `type_info_name`.
    Impl(&'sc str, String), }
impl<'sc> PartialEq for DependentSymbol<'sc> {
fn eq(&self, rhs: &Self) -> bool {
match (self, rhs) {
(DependentSymbol::Symbol(l), DependentSymbol::Symbol(r)) => l.eq(r),
(DependentSymbol::Fn(l, _), DependentSymbol::Fn(r, _)) => l.eq(r),
(DependentSymbol::Impl(lt, ls), DependentSymbol::Impl(rt, rs)) => {
lt.eq(rt) && ls.eq(rs)
}
_ => false,
}
}
}
use std::hash::{Hash, Hasher};
impl<'sc> Hash for DependentSymbol<'sc> {
fn hash<H: Hasher>(&self, state: &mut H) {
match self {
DependentSymbol::Symbol(s) => s.hash(state),
DependentSymbol::Fn(s, _) => s.hash(state),
DependentSymbol::Impl(t, s) => {
t.hash(state);
s.hash(state)
}
}
}
}
/// Returns the symbol under which `decl` is keyed in the dependency map.
///
/// Functions are keyed by name together with their declaration span. Constants, structs,
/// enums, traits, and ABIs are keyed by name. Impls are keyed by trait name (`"self"` for
/// inherent impls) and implementing type. Impls of path-qualified traits, variable
/// declarations, reassignments, and storage declarations have no key and yield `None`.
fn decl_name<'sc>(decl: &Declaration<'sc>) -> Option<DependentSymbol<'sc>> {
    let symbol = match decl {
        Declaration::FunctionDeclaration(fn_decl) => {
            DependentSymbol::Fn(fn_decl.name.primary_name, Some(fn_decl.span.clone()))
        }
        Declaration::ConstantDeclaration(d) => {
            DependentSymbol::Symbol(d.name.primary_name.to_string())
        }
        Declaration::StructDeclaration(d) => {
            DependentSymbol::Symbol(d.name.primary_name.to_string())
        }
        Declaration::EnumDeclaration(d) => DependentSymbol::Symbol(d.name.primary_name.to_string()),
        Declaration::TraitDeclaration(d) => {
            DependentSymbol::Symbol(d.name.primary_name.to_string())
        }
        Declaration::AbiDeclaration(d) => DependentSymbol::Symbol(d.name.primary_name.to_string()),
        Declaration::ImplSelf(d) => {
            DependentSymbol::Impl("self", type_info_name(&d.type_implementing_for))
        }
        Declaration::ImplTrait(d) if d.trait_name.prefixes.is_empty() => DependentSymbol::Impl(
            d.trait_name.suffix.primary_name,
            type_info_name(&d.type_implementing_for),
        ),
        Declaration::ImplTrait(_)
        | Declaration::VariableDeclaration(_)
        | Declaration::Reassignment(_)
        | Declaration::StorageDeclaration(_) => return None,
    };
    Some(symbol)
}
/// Returns the name used to identify `type_info` as the implementing type of an impl.
///
/// Custom types use their own name, type references are rendered as `T<id>`, unknown
/// generics as `generic <name>`, and every other type gets a fixed label.
fn type_info_name(type_info: &TypeInfo) -> String {
    let label = match type_info {
        TypeInfo::Custom { name } => return name.to_string(),
        TypeInfo::Ref(id) => return format!("T{}", id),
        TypeInfo::UnknownGeneric { name } => return format!("generic {}", name),
        TypeInfo::UnsignedInteger(IntegerBits::Eight) => "uint8",
        TypeInfo::UnsignedInteger(IntegerBits::Sixteen) => "uint16",
        TypeInfo::UnsignedInteger(IntegerBits::ThirtyTwo) => "uint32",
        TypeInfo::UnsignedInteger(IntegerBits::SixtyFour) => "uint64",
        TypeInfo::Str(_) => "str",
        TypeInfo::Boolean => "bool",
        TypeInfo::Unit => "unit",
        TypeInfo::SelfType => "self",
        TypeInfo::Byte => "byte",
        TypeInfo::B256 => "b256",
        TypeInfo::Numeric => "numeric",
        TypeInfo::Contract => "contract",
        TypeInfo::ErrorRecovery => "err_recov",
        TypeInfo::Unknown => "unknown",
        TypeInfo::ContractCaller { .. } => "contract caller",
        TypeInfo::Struct { .. } => "struct",
        TypeInfo::Enum { .. } => "enum",
        TypeInfo::Array(..) => "array",
    };
    label.to_string()
}