use syn::{
Block, Local, ExprIf, ExprWhile,
ExprForLoop, ExprMethodCall, ExprBlock,
ItemFn, Expr, ExprCall, Stmt, ExprClosure};
use proc_macro2::TokenStream as ProcTokenStream;
use std::collections::HashSet;
use proc_macro::TokenStream;
use quote::quote;
pub fn assert_call_impl(whitelist: &[String], function: &ItemFn) -> ProcTokenStream {
let mut errors = Vec::new();
let block: &Box<Block> = &function.block;
let mut called_functions = HashSet::new();
check_block_for_calls(block, whitelist, &mut errors, &mut called_functions);
let whitelist_set: HashSet<String> = whitelist.iter().cloned().collect();
let missed_calls: Vec<_> = whitelist_set.difference(&called_functions).collect();
if !missed_calls.is_empty() {
for missed in missed_calls {
errors.push(Error::new(format!("Function `{}` not called", missed)));
}
}
if !errors.is_empty() {
let mut error_message = String::from("Function missing required calls:\n");
for error in &errors {
error_message.push_str(&format!(" - {}\n", error.message));
}
return TokenStream::from(quote! {
compile_error!(#error_message);
}).into();
}
TokenStream::from(quote! { #function }).into()
}
/// A single diagnostic produced while checking a function body.
#[derive(Debug)]
struct Error {
    /// Text of the diagnostic, included verbatim in the final compile error.
    message: String,
}

impl Error {
    /// Builds an [`Error`] that carries `message`.
    fn new(message: String) -> Self {
        Self { message }
    }
}
/// Records `name` in `called_functions` if it appears in `whitelist`.
///
/// Names that are not whitelisted are ignored. `_errors` is accepted so the
/// signature matches the other checkers, but it is not used here.
fn check_whitelist(
    name: &str,
    whitelist: &[String],
    _errors: &mut Vec<Error>,
    called_functions: &mut HashSet<String>,
) {
    // Compare as `&str` so no temporary `String` is allocated per lookup.
    let is_whitelisted = whitelist.iter().any(|allowed| allowed.as_str() == name);

    // Only allocate the owned key the first time a name is seen.
    if is_whitelisted && !called_functions.contains(name) {
        called_functions.insert(name.to_owned());
    }
}
fn _print_ast<T>(item: &T, label: &str)
where
T: quote::ToTokens,
{
let tokens: ProcTokenStream = quote! { #item };
let item_string = tokens.to_string();
println!("{}: {}", label, item_string);
}
/// Scans every statement of `block` for calls to whitelisted names.
///
/// * Expression statements are handed to `check_expr_for_calls`.
/// * `let` statements have their initializer inspected and, for
///   `let … else { … }`, the diverging `else` expression as well.
/// * Any other statement kind is skipped.
///
/// Matches are recorded in `called_functions`; `errors` is passed through to
/// the expression checker.
fn check_block_for_calls(
    block: &Block,
    whitelist: &[String],
    errors: &mut Vec<Error>,
    called_functions: &mut HashSet<String>,
) {
    for stmt in &block.stmts {
        match stmt {
            Stmt::Expr(expr, _) => {
                check_expr_for_calls(expr, whitelist, errors, called_functions);
            }
            Stmt::Local(Local { init: Some(init), .. }) => {
                check_expr_for_calls(&init.expr, whitelist, errors, called_functions);
                // `let PAT = EXPR else { ... }`: the diverging branch can
                // contain calls too, so it must not be skipped.
                if let Some((_, diverge)) = &init.diverge {
                    check_expr_for_calls(diverge, whitelist, errors, called_functions);
                }
            }
            _ => {}
        }
    }
}
/// Recursively inspects `expr` for calls to whitelisted names.
///
/// Handled expression kinds:
///
/// * function calls — the last path segment of the callee is checked against
///   the whitelist; a non-path callee and all arguments are inspected too;
/// * method calls — the method name is checked; receiver and arguments are
///   inspected, so every link of a chain such as `a.foo().bar()` is seen;
/// * blocks, `if`/`else`, `while`, `for` and closures — conditions, iterated
///   expressions and bodies are inspected.
///
/// Every other expression kind is ignored and not descended into.
///
/// Matches are recorded in `called_functions` through `check_whitelist`.
fn check_expr_for_calls(
    expr: &Expr,
    whitelist: &[String],
    errors: &mut Vec<Error>,
    called_functions: &mut HashSet<String>,
) {
    match expr {
        Expr::Call(ExprCall { func, args, .. }) => {
            if let Expr::Path(callee) = &**func {
                // Only the final segment is compared, so `module::foo()`
                // matches a whitelist entry of `foo`.
                if let Some(segment) = callee.path.segments.last() {
                    check_whitelist(
                        &segment.ident.to_string(),
                        whitelist,
                        errors,
                        called_functions,
                    );
                }
            } else {
                // The callee is itself an expression, e.g. `make_fn()(x)`.
                check_expr_for_calls(func, whitelist, errors, called_functions);
            }
            for arg in args {
                check_expr_for_calls(arg, whitelist, errors, called_functions);
            }
        }
        Expr::MethodCall(ExprMethodCall { receiver, method, args, .. }) => {
            check_whitelist(&method.to_string(), whitelist, errors, called_functions);
            check_expr_for_calls(receiver, whitelist, errors, called_functions);
            for arg in args {
                check_expr_for_calls(arg, whitelist, errors, called_functions);
            }
        }
        Expr::Block(ExprBlock { block, .. }) => {
            check_block_for_calls(block, whitelist, errors, called_functions);
        }
        Expr::If(ExprIf { cond, then_branch, else_branch, .. }) => {
            check_expr_for_calls(cond, whitelist, errors, called_functions);
            check_block_for_calls(then_branch, whitelist, errors, called_functions);
            if let Some((_, else_expr)) = else_branch {
                // Recurse on the else expression itself (a block or a nested
                // `if`). Recursing on the enclosing `if` here would never
                // terminate for `else if` chains.
                check_expr_for_calls(else_expr, whitelist, errors, called_functions);
            }
        }
        Expr::While(ExprWhile { cond, body, .. }) => {
            check_expr_for_calls(cond, whitelist, errors, called_functions);
            check_block_for_calls(body, whitelist, errors, called_functions);
        }
        Expr::ForLoop(ExprForLoop { expr: iterable, body, .. }) => {
            check_expr_for_calls(iterable, whitelist, errors, called_functions);
            check_block_for_calls(body, whitelist, errors, called_functions);
        }
        Expr::Closure(ExprClosure { body, .. }) => {
            // A block body is dispatched through the `Expr::Block` arm above.
            check_expr_for_calls(body, whitelist, errors, called_functions);
        }
        _ => {}
    }
}