1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, ItemFn, ReturnType};
4
5#[proc_macro_attribute]
6pub fn memoize(_args: TokenStream, input: TokenStream) -> TokenStream {
7 let input_fn = parse_macro_input!(input as ItemFn);
8
9 let vis = &input_fn.vis;
10 let sig = &input_fn.sig;
11 let fn_name = &sig.ident;
12 let inputs = &sig.inputs;
13 let block = &input_fn.block;
14
15 let is_async=sig.asyncness.is_some();
16
17 let output_type = match &sig.output {
18 ReturnType::Default => {
19 return syn::Error::new_spanned(sig, "Functions must have a return type to be memoized.")
20 .to_compile_error()
21 .into();
22 }
23 ReturnType::Type(_, ty) => ty.clone(),
24 };
25
26 let mut arg_idents = Vec::new();
27 let mut arg_types = Vec::new();
28 for input in inputs.iter() {
29 if let syn::FnArg::Typed(pat_ty) = input {
30 if let syn::Pat::Ident(pat_ident) = &*pat_ty.pat {
31 arg_idents.push(pat_ident.ident.clone());
32 arg_types.push(pat_ty.ty.clone());
33 } else {
34 return syn::Error::new_spanned(&pat_ty.pat, "Argument must be a simple identifier")
35 .to_compile_error()
36 .into();
37 }
38 } else {
39 return syn::Error::new_spanned(input, "Methods with `self` are not supported by this macro.")
40 .to_compile_error()
41 .into();
42 }
43 }
44
45 if arg_idents.len() != 1 {
46 return syn::Error::new_spanned(
47 sig,
48 "The #[memoize] macro currently supports exactly one argument."
49 )
50 .to_compile_error()
51 .into();
52 }
53
54 let arg_ident = &arg_idents[0];
55 let arg_type = &arg_types[0];
56
57 let fn_name_caps=fn_name.to_string().to_uppercase();
58 let memoizer_name = syn::Ident::new(&format!("__{}_MEMOIZER", fn_name_caps), fn_name.span());
59
60 let memoizer_type = if is_async {
61 quote! { memoizee::AsyncMemoizer::<#arg_type, #output_type> }
62 } else {
63 quote! { memoizee::SyncMemoizer::<#arg_type, #output_type> }
64 };
65
66 let gen_memoizer = if is_async {
67 quote! {
68 memoizee::Lazy::new(|| {
69 #memoizer_type::new(move |key: #arg_type| {
70 Box::pin(async move {
71 let #arg_ident = key;
72 #block
73 })
74 })
75 })
76 }
77 } else {
78 quote! {
79 memoizee::Lazy::new(|| {
80 #memoizer_type::new(move |key: #arg_type| {
81 let #arg_ident = key;
82 #block
83 })
84 })
85 }
86 };
87
88 let call_memoizer = if is_async {
89 quote! {
90 #memoizer_name.of(#arg_ident).await
91 }
92 } else {
93 quote! {
94 #memoizer_name.of(#arg_ident)
95 }
96 };
97
98 let expanded = quote! {
99 static #memoizer_name: memoizee::Lazy<#memoizer_type> = #gen_memoizer;
100
101 #vis #sig {
102 #call_memoizer
103 }
104 };
105
106 println!("GENERATED\n{expanded}");
107
108 expanded.into()
109}