use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{ItemFn, Token, WhereClause, parse::Parse, parse_quote};
/// Custom keywords accepted inside `#[print_recursion_tree(...)]`.
mod keyword {
    // Keyword that sets `RecursionConfig::print_each_pass`.
    syn::custom_keyword!(print_each_pass);

    // Keyword that sets `RecursionConfig::print_recursion_counter`.
    syn::custom_keyword!(print_recursion_counter);
}
/// Options parsed from the arguments of `#[print_recursion_tree(...)]`.
///
/// Both flags default to `false`; a flag becomes `true` when its keyword is
/// present in the attribute arguments.
#[derive(Default)]
struct RecursionConfig {
    /// Print the partial tree after every call instead of only when the
    /// outermost call returns.
    print_each_pass: bool,
    /// Also print the total recursion count whenever the tree is printed.
    print_recursion_counter: bool,
}
impl Parse for RecursionConfig {
    /// Parses the argument list of `#[print_recursion_tree(...)]`.
    ///
    /// Accepts any sequence of the keywords `print_each_pass` and
    /// `print_recursion_counter`, each optionally followed by a comma; every
    /// keyword present sets the matching flag. An empty list yields the
    /// default (both flags `false`).
    ///
    /// # Errors
    /// Any other token produces an error that names the accepted keywords.
    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
        let mut config = RecursionConfig::default();
        while !input.is_empty() {
            // Lookahead records which keywords were tried, so the error for an
            // unknown token lists both accepted options.
            let lookahead = input.lookahead1();
            if lookahead.peek(keyword::print_each_pass) {
                input.parse::<keyword::print_each_pass>()?;
                config.print_each_pass = true;
            } else if lookahead.peek(keyword::print_recursion_counter) {
                input.parse::<keyword::print_recursion_counter>()?;
                config.print_recursion_counter = true;
            } else {
                return Err(lookahead.error());
            }
            // The separator is optional so existing argument lists keep parsing.
            if input.peek(Token![,]) {
                input.parse::<Token![,]>()?;
            }
        }
        Ok(config)
    }
}
#[proc_macro_attribute]
pub fn print_recursion_tree(attr: TokenStream, item: TokenStream) -> TokenStream {
let mut item_fn: ItemFn = syn::parse_macro_input!(item);
let config = syn::parse_macro_input!(attr as RecursionConfig);
let print_each_pass_ident = format_ident!("{}", config.print_each_pass);
let print_recursion_counter_ident = format_ident!("{}", config.print_recursion_counter);
let fn_gen = &mut item_fn.sig.generics;
let fn_gen_params = &fn_gen.params;
let type_params_name:Vec<_> = fn_gen_params.iter().filter(|p| match p{
syn::GenericParam::Type(_) => true,
_ => false
}).map(|p| match p{
syn::GenericParam::Type(type_param) => &type_param.ident,
syn::GenericParam::Const(const_param) => &const_param.ident,
syn::GenericParam::Lifetime(life_time) => &life_time.lifetime.ident,
}).collect();
let where_clause_originial = fn_gen.where_clause.clone();
let mut fn_generics_clone = fn_gen.clone();
let where_clause_modified = fn_generics_clone.where_clause.get_or_insert_with(|| WhereClause {
where_token: parse_quote!(where),
predicates: syn::punctuated::Punctuated::new(),
});
type_params_name.iter().for_each(|p| {
where_clause_modified.predicates.push(parse_quote!(#p: std::fmt::Debug));
});
let async_fn = item_fn.sig.asyncness.is_some();
if async_fn{
panic!("async function are not supported by print_recursion_tree");
}
let input_args = &item_fn.sig.inputs;
let fn_ident = &item_fn.sig.ident;
let fn_name = item_fn.sig.ident.to_string();
let fn_return_type = &item_fn.sig.output;
let fn_body = &item_fn.block;
let fn_visibility = &item_fn.vis;
let renamed_fn_name_string = format!("__debug_recursion__{}", &fn_name);
let fn_name_renamed = format_ident!("{}", renamed_fn_name_string);
let recursion_call_counter_ident = format_ident!("{}{}", renamed_fn_name_string.to_uppercase(), "_COUNTER");
let recursion_call_remining_counter_ident = format_ident!("{}{}", renamed_fn_name_string.to_uppercase(), "_REMAINING");
let get_tree_fn_ident = format_ident!("{}{}", "__debug_recursion__get_tree_", &fn_name);
let get_counter_fn_ident = format_ident!("{}{}", "__debug_recursion__get_counter_", &fn_name);
let tree_builder_ident = format_ident!("{}{}", renamed_fn_name_string.to_uppercase(), "_TREE");
let input_args_clone1 = input_args.clone();
let callable_args = input_args_clone1.iter().filter_map(|arg| {
if let syn::FnArg::Typed(pat_type) = arg {
if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
Some(&pat_ident.ident)
} else {
None
}
} else {
None
}
}).collect::<Vec<_>>();
let renamed_fn: proc_macro2::TokenStream = quote! {
#fn_visibility fn #fn_name_renamed #fn_gen (#input_args) #fn_return_type #where_clause_originial #fn_body
};
let debug_str = callable_args.clone().iter().map(|_| "{:?}").collect::<Vec<&str>>().join(",");
let proxy_fn: proc_macro2::TokenStream = quote! {
#fn_visibility fn #fn_ident #fn_gen (#input_args) #fn_return_type #where_clause_modified {
{
use ptree::TreeBuilder;
use ptree::print_tree;
{
let args = format!(#debug_str, #(#callable_args),*);
let mut total_call_counter = #recursion_call_counter_ident.lock().unwrap();
let mut remaining_call_counter = #recursion_call_remining_counter_ident.lock().unwrap();
let mut tree = #tree_builder_ident.lock().unwrap();
if *remaining_call_counter == 0{
*tree = TreeBuilder::new(#fn_name.to_string());
*total_call_counter = 0;
}
*total_call_counter = *total_call_counter + 1;
*remaining_call_counter = *remaining_call_counter + 1;
tree.begin_child(args);
}
let result = #fn_name_renamed(#(#callable_args),*);
let mut print_flag = false;
{
let mut tree = #tree_builder_ident.lock().unwrap();
let mut remaining_call_counter = #recursion_call_remining_counter_ident.lock().unwrap();
let mut total_call_counter = #recursion_call_counter_ident.lock().unwrap();
if *remaining_call_counter > 0{
*remaining_call_counter = *remaining_call_counter - 1;
}
tree.add_empty_child(format!("={:?}", result));
print_flag = #print_each_pass_ident || *remaining_call_counter == 0;
if print_flag {
println!("---------------");
print_tree(&tree.build()).unwrap();
println!("---------------");
}
tree.end_child();
}
if #print_recursion_counter_ident && print_flag {
let count = #get_counter_fn_ident();
println!("Total Number Of Recursions: {}", count);
}
return result;
}
}
};
let get_tree_fn: proc_macro2::TokenStream = quote! {
fn #get_tree_fn_ident() -> String{
let mut tree = #tree_builder_ident.lock().unwrap();
let mut tree_as_vec = Vec::<u8>::new();
ptree::write_tree(&(*tree).build(), &mut tree_as_vec).unwrap();
String::from_utf8(tree_as_vec).unwrap()
}
};
let get_counter_fn: proc_macro2::TokenStream = quote! {
fn #get_counter_fn_ident() -> u16{
let counter = #recursion_call_counter_ident.lock().unwrap();
if *counter > 0 {
return *counter - 1;
}
return *counter;
}
};
let lazy_static_initialization: proc_macro2::TokenStream = quote! {
lazy_static::lazy_static! {
static ref #recursion_call_counter_ident: std::sync::Mutex<u16> = std::sync::Mutex::new(0u16);
static ref #recursion_call_remining_counter_ident: std::sync::Mutex<u16> = std::sync::Mutex::new(0u16);
static ref #tree_builder_ident:std::sync::Mutex<ptree::TreeBuilder> = std::sync::Mutex::new(ptree::TreeBuilder::new("tree".to_string()));
}
};
let mut token_stream: proc_macro2::TokenStream = renamed_fn.into();
token_stream.extend(get_counter_fn.into_iter());
token_stream.extend(get_tree_fn.into_iter());
token_stream.extend(lazy_static_initialization.into_iter());
token_stream.extend(proxy_fn.into_iter());
return token_stream.into();
}