use proc_macro::TokenStream;
use quote::quote;
use syn::{
parse_macro_input, spanned::Spanned, Error, FnArg, Ident, Type, ItemFn, Pat, PatIdent, ReturnType,
parse::{Parse, ParseStream}, punctuated::Punctuated, Token
};
use std::collections::HashSet;
struct KeyArgs {
args: Vec<Ident>,
}
impl Parse for KeyArgs {
fn parse(input: ParseStream) -> syn::Result<Self> {
let args = Punctuated::<Ident, Token![,]>::parse_terminated(input)?;
Ok(Self {
args: args.into_iter().collect(),
})
}
}
#[proc_macro_attribute]
pub fn memo(attr: TokenStream, item: TokenStream) -> TokenStream {
let input_fn = parse_macro_input!(item as ItemFn);
let _key_arg_names = parse_macro_input!(attr as KeyArgs).args;
let fn_name = &input_fn.sig.ident;
let fn_vis = &input_fn.vis;
let fn_block = &input_fn.block;
let fn_inputs = &input_fn.sig.inputs;
let fn_output = &input_fn.sig.output;
let fn_sig = &input_fn.sig;
let no_cache_name = Ident::new(&format!("{}_no_cache", fn_name), fn_name.span());
let cache_name = Ident::new(&fn_name.to_string().to_uppercase(), fn_name.span());
let _global_cache_name = Ident::new(&format!("global_cache_{}", fn_name), fn_name.span());
let (args, param_types): (Vec<_>, Vec<_>) = input_fn
.sig
.inputs
.iter()
.map(|arg| match arg {
FnArg::Typed(pat_type) => {
let ident = match &*pat_type.pat {
Pat::Ident(PatIdent { ident, .. }) => ident.clone(),
_ => {
return Err(Error::new(
pat_type.span(),
"only simple identifiers are supported",
))
}
};
let ty = &*pat_type.ty;
Ok((ident, ty))
}
_ => Err(Error::new(arg.span(), "self parameters are not supported")),
})
.collect::<Result<Vec<_>, _>>()
.unwrap()
.into_iter()
.unzip();
let mut immutable_references = HashSet::<String>::new();
for arg in input_fn.sig.inputs.iter() {
if let FnArg::Typed(pat_type) = arg {
if let Type::Reference(ty_ref) = &*pat_type.ty {
if ty_ref.mutability.is_none() {
if let Pat::Ident(pat_ident) = &*pat_type.pat {
immutable_references.insert(pat_ident.ident.to_string());
}
} else {
if let Pat::Ident(pat_ident) = &*pat_type.pat {
return Error::new(ty_ref.span(), format!("memo supports only immutable references in parameters, but {} is mutable", pat_ident.ident))
.to_compile_error()
.into();
}
}
}
}
}
if args.is_empty() { return quote! {#fn_vis #fn_sig #fn_block}.into() ; }
let (key_args, key_types): (Vec<_>, Vec<_>) = args.iter()
.zip(param_types.into_iter())
.filter(|(arg, _)| !immutable_references.contains(&arg.to_string()))
.map(|(arg, ty)| (arg.clone(), ty.clone()))
.unzip();
let key_type = if key_types.len() == 1 {
quote! { #(#key_types)* }
} else {
quote! { (#(#key_types),*) }
};
let key_exprs = key_args.iter().map(|arg| quote! { #arg.clone() });
let key_tuple = quote! { (#(#key_exprs),*) };
let call_args = args.iter().map(|arg| quote! { #arg });
let return_type = match fn_output {
ReturnType::Default => quote! { () },
ReturnType::Type(_, ty) => quote! { #ty },
};
let create_cache = quote! {
static #cache_name: ::std::sync::LazyLock<::std::sync::Mutex<::std::collections::HashMap<#key_type, #return_type>>> = ::std::sync::LazyLock::new(|| {
::std::sync::Mutex::new(::std::collections::HashMap::new())
});
};
let no_cache_fn = quote! {
#fn_vis fn #no_cache_name(#fn_inputs) #fn_output #fn_block
};
let cached_fn = quote! {
#fn_vis fn #fn_name(#fn_inputs) #fn_output {
let key = #key_tuple;
{
let cache = #cache_name.lock().unwrap();
if let Some(result) = cache.get(&key) {
return result.clone();
}
}
let result = #no_cache_name(#(#call_args),*);
let mut cache = #cache_name.lock().unwrap();
cache.insert(key, result.clone());
result
}
};
let expanded = quote! {
#create_cache
#no_cache_fn
#cached_fn
};
expanded.into()
}