use proc_macro::TokenStream;
use quote::quote;
use syn::{
parse_macro_input, spanned::Spanned, Error, FnArg, Ident, ItemFn, Pat, PatIdent, ReturnType,
};
#[proc_macro_attribute]
pub fn memo(_attr: TokenStream, item: TokenStream) -> TokenStream {
let input_fn = parse_macro_input!(item as ItemFn);
let fn_name = &input_fn.sig.ident;
let fn_vis = &input_fn.vis;
let fn_block = &input_fn.block;
let fn_inputs = &input_fn.sig.inputs;
let fn_output = &input_fn.sig.output;
let no_cache_name = Ident::new(&format!("{}_no_cache", fn_name), fn_name.span());
let cache_name = Ident::new(&fn_name.to_string().to_uppercase(), fn_name.span());
let _global_cache_name = Ident::new(&format!("global_cache_{}", fn_name), fn_name.span());
let (args, param_types): (Vec<_>, Vec<_>) = input_fn
.sig
.inputs
.iter()
.map(|arg| match arg {
FnArg::Typed(pat_type) => {
let ident = match &*pat_type.pat {
Pat::Ident(PatIdent { ident, .. }) => ident.clone(),
_ => {
return Err(Error::new(
pat_type.span(),
"only simple identifiers are supported",
))
}
};
let ty = &*pat_type.ty;
Ok((ident, ty))
}
_ => Err(Error::new(arg.span(), "self parameters are not supported")),
})
.collect::<Result<Vec<_>, _>>()
.unwrap()
.into_iter()
.unzip();
let key_type = if param_types.len() == 1 {
quote! { #(#param_types)* }
} else {
quote! { (#(#param_types),*) }
};
let return_type = match fn_output {
ReturnType::Default => quote! { () },
ReturnType::Type(_, ty) => quote! { #ty },
};
let key_exprs = args.iter().map(|arg| quote! { #arg.clone() });
let key_tuple = quote! { (#(#key_exprs),*) };
let call_args = args.iter().map(|arg| quote! { #arg });
let create_cache = quote! {
static #cache_name: ::std::sync::LazyLock<::std::sync::Mutex<::std::collections::HashMap<#key_type, #return_type>>> = ::std::sync::LazyLock::new(|| {
::std::sync::Mutex::new(::std::collections::HashMap::new())
});
};
let no_cache_fn = quote! {
#fn_vis fn #no_cache_name(#fn_inputs) #fn_output #fn_block
};
let cached_fn = quote! {
#fn_vis fn #fn_name(#fn_inputs) #fn_output {
let key = #key_tuple;
{
let cache = #cache_name.lock().unwrap();
if let Some(result) = cache.get(&key) {
return result.clone();
}
}
let result = #no_cache_name(#(#call_args),*);
let mut cache = #cache_name.lock().unwrap();
cache.insert(key, result.clone());
result
}
};
let expanded = quote! {
#create_cache
#no_cache_fn
#cached_fn
};
expanded.into()
}