use crate::sema::ast::{DestructureField, Expression, Namespace, Statement};
use crate::sema::Recurse;
use indexmap::IndexSet;
use solang_parser::pt;
/// Worklist of callees discovered while walking statements and expressions.
///
/// Both collections are `IndexSet`s, so they keep first-insertion order and
/// silently ignore repeated inserts of the same number.
#[derive(Default)]
struct CallList {
    /// Solidity function numbers (indices into `ns.functions`) referenced
    /// through `Expression::InternalFunction`.
    pub solidity: IndexSet<usize>,
    /// Yul function numbers declared by inline assembly blocks.
    pub yul: IndexSet<usize>,
}
/// Transitively collect every function reachable from contract `contract_no`.
///
/// The starting set is the contract's current `all_functions` plus any function
/// referenced from a state-variable initializer. Any function referenced (via
/// `Expression::InternalFunction`) from a scanned body is scanned in turn and,
/// unless its `loc` is `pt::Loc::Builtin`, inserted into `all_functions` with
/// the value `usize::MAX`. Yul functions declared by inline assembly in any
/// scanned body are appended to the contract's `yul_functions`, each at most
/// once. Finally `emits_events` is replaced by the de-duplicated union (in
/// first-seen order) of `emits_events` over all functions in `all_functions`.
pub fn add_external_functions(contract_no: usize, ns: &mut Namespace) {
    let mut call_list = CallList::default();

    for var in &ns.contracts[contract_no].variables {
        if let Some(init) = &var.initializer {
            init.recurse(&mut call_list, check_expression);
        }
    }

    for function_no in ns.contracts[contract_no].all_functions.keys() {
        let func = &ns.functions[*function_no];

        for stmt in &func.body {
            stmt.recurse(&mut call_list, check_statement);
        }
    }

    // The bodies of functions already in `all_functions` were scanned just
    // above. Leaving them in the worklist would rescan them, recording their
    // Yul functions a second time and overwriting their existing
    // `all_functions` value with `usize::MAX`.
    call_list.solidity.retain(|function_no| {
        !ns.contracts[contract_no]
            .all_functions
            .contains_key(function_no)
    });

    while !call_list.solidity.is_empty() || !call_list.yul.is_empty() {
        let mut new_call_list = CallList::default();

        // Scan the newly discovered functions for further callees.
        for function_no in &call_list.solidity {
            let func = &ns.functions[*function_no];

            for stmt in &func.body {
                stmt.recurse(&mut new_call_list, check_statement);
            }
        }

        // Builtins are scanned but never become part of the contract.
        for function_no in &call_list.solidity {
            if ns.functions[*function_no].loc != pt::Loc::Builtin {
                ns.contracts[contract_no]
                    .all_functions
                    .insert(*function_no, usize::MAX);
            }
        }

        // Record Yul functions without duplicates: a builtin is never added to
        // `all_functions`, so it can be rescanned on a later iteration.
        let yul_functions = &mut ns.contracts[contract_no].yul_functions;
        for yul_function_no in call_list.yul.iter().chain(&new_call_list.yul) {
            if !yul_functions.contains(yul_function_no) {
                yul_functions.push(*yul_function_no);
            }
        }

        // Only callees not yet in the contract need scanning next round.
        let all_functions = &ns.contracts[contract_no].all_functions;
        call_list.solidity = new_call_list
            .solidity
            .into_iter()
            .filter(|function_no| !all_functions.contains_key(function_no))
            .collect();
        call_list.yul.clear();
    }

    // Union of the events emitted by every function now in the contract.
    let mut emits_events = Vec::new();
    for function_no in ns.contracts[contract_no].all_functions.keys() {
        for event_no in &ns.functions[*function_no].emits_events {
            if !emits_events.contains(event_no) {
                emits_events.push(*event_no);
            }
        }
    }
    ns.contracts[contract_no].emits_events = emits_events;
}
/// Expression visitor passed to `recurse`: records the number of every
/// `Expression::InternalFunction` it meets in `call_list.solidity`.
///
/// Always returns `true` so that the walk carries on into sub-expressions.
fn check_expression(expr: &Expression, call_list: &mut CallList) -> bool {
    let keep_walking = true;

    if let Expression::InternalFunction {
        function_no: callee,
        ..
    } = expr
    {
        call_list.solidity.insert(*callee);
    }

    keep_walking
}
/// Statement visitor passed to `Statement::recurse` to collect call targets.
///
/// For each statement, the expressions it holds directly (variable
/// initializer, branch/loop condition, expression statement, `delete` operand,
/// destructure source and target expressions, return values, the `try`
/// expression, `emit` arguments) are walked with `check_expression`, which
/// records internal function references in `call_list.solidity`. An inline
/// assembly statement adds every Yul function number in its `functions` range
/// to `call_list.yul`.
///
/// Always returns `true`; presumably this tells `recurse` to continue into
/// nested statements (NOTE(review): confirm against the `Recurse` impl).
fn check_statement(stmt: &Statement, call_list: &mut CallList) -> bool {
match stmt {
Statement::VariableDecl(_, _, _, Some(expr)) => {
expr.recurse(call_list, check_expression);
}
// Declaration without an initializer: no expression to scan.
Statement::VariableDecl(_, _, _, None) => (),
// Only the condition is scanned here; the branch bodies are statements.
Statement::If(_, _, cond, _, _) => {
cond.recurse(call_list, check_expression);
}
// Only the loop condition is scanned here; other parts of the loop are
// not touched by this arm.
Statement::For {
cond: Some(cond), ..
} => {
cond.recurse(call_list, check_expression);
}
Statement::For { cond: None, .. } => (),
Statement::DoWhile(_, _, _, cond) | Statement::While(_, _, cond, _) => {
cond.recurse(call_list, check_expression);
}
Statement::Expression(_, _, expr) => {
expr.recurse(call_list, check_expression);
}
Statement::Delete(_, _, expr) => {
expr.recurse(call_list, check_expression);
}
// Scan the right-hand side, then any destructure targets that are
// expressions.
Statement::Destructure(_, fields, expr) => {
expr.recurse(call_list, check_expression);
for field in fields {
if let DestructureField::Expression(expr) = field {
expr.recurse(call_list, check_expression);
}
}
}
Statement::Return(_, exprs) => {
for e in exprs {
e.recurse(call_list, check_expression);
}
}
// Only the expression being tried is scanned by this arm.
Statement::TryCatch(_, _, try_catch) => {
try_catch.expr.recurse(call_list, check_expression);
}
Statement::Emit { args, .. } => {
for e in args {
e.recurse(call_list, check_expression);
}
}
// These variants hold no expression that this visitor inspects.
Statement::Block { .. }
| Statement::Break(_)
| Statement::Continue(_)
| Statement::Underscore(_) => (),
// Every Yul function declared by this assembly block is recorded; the
// range is half-open (`start..end`).
Statement::Assembly(inline_assembly, _) => {
for func_no in inline_assembly.functions.start..inline_assembly.functions.end {
call_list.yul.insert(func_no);
}
}
}
true
}