1use proc_macro::TokenStream;
2use quote::quote;
3use syn::punctuated::Punctuated;
4use syn::{parse::Parse, parse::ParseStream, parse_macro_input, ItemFn, Lit, Token};
5use syn::{Error, Result};
6
/// Strategy used by `#[fncache]` to build the cache key for a function.
enum KeyDerivation {
    /// The key is fixed at macro-expansion time (module path plus function name);
    /// the call's arguments do not participate in it.
    CompileTime,
    /// The key is the function name plus the `Debug` rendering of the arguments,
    /// computed on every call. This is the default when the attribute omits it.
    Runtime,
}
12
13struct FncacheArgs {
15 ttl: Option<u64>,
16 key_derivation: KeyDerivation,
17}
18
19impl Parse for FncacheArgs {
20 fn parse(input: ParseStream) -> Result<Self> {
21 let vars = Punctuated::<syn::MetaNameValue, Token![,]>::parse_terminated(input)?;
22
23 let mut ttl = None;
24 let mut key_derivation = KeyDerivation::Runtime;
25
26 for var in vars {
27 let ident = var
28 .path
29 .get_ident()
30 .ok_or_else(|| Error::new_spanned(&var.path, "Expected identifier"))?;
31
32 if ident == "ttl" {
33 match &var.lit {
34 Lit::Int(lit) => {
35 ttl = Some(lit.base10_parse()?);
36 }
37 _ => return Err(Error::new_spanned(&var.lit, "ttl must be an integer")),
38 }
39 } else if ident == "key_derivation" {
40 match &var.lit {
41 Lit::Str(lit_str) => {
42 let value = lit_str.value();
43 if value == "runtime" {
44 key_derivation = KeyDerivation::Runtime;
45 } else if value == "compile_time" {
46 key_derivation = KeyDerivation::CompileTime;
47 } else {
48 return Err(Error::new_spanned(
49 &var.lit,
50 "key_derivation must be either 'runtime' or 'compile_time'",
51 ));
52 }
53 }
54 _ => {
55 return Err(Error::new_spanned(
56 &var.lit,
57 "key_derivation must be a string literal",
58 ))
59 }
60 }
61 }
62 }
63
64 Ok(FncacheArgs {
65 ttl,
66 key_derivation,
67 })
68 }
69}
70
71#[proc_macro_attribute]
72pub fn fncache(attr: TokenStream, item: TokenStream) -> TokenStream {
73 let args = syn::parse_macro_input::parse::<FncacheArgs>(attr.clone()).unwrap_or_else(|_| {
74 FncacheArgs {
75 ttl: None,
76 key_derivation: KeyDerivation::Runtime,
77 }
78 });
79
80 let use_compile_time_keys = match args.key_derivation {
81 KeyDerivation::CompileTime => true,
82 KeyDerivation::Runtime => false,
83 };
84
85 let ttl_seconds = args.ttl.unwrap_or(60);
86
87 let input_fn = parse_macro_input!(item as ItemFn);
88
89 let vis = &input_fn.vis;
90 let sig = &input_fn.sig;
91 let block = &input_fn.block;
92 let attrs = &input_fn.attrs;
93
94 let fn_name = &sig.ident;
95 let asyncness = &sig.asyncness;
96 let _generics = &sig.generics;
97 let inputs = &sig.inputs;
98 let _output = &sig.output;
99
100 let is_async = asyncness.is_some();
101
102 let arg_names = inputs.iter().map(|arg| match arg {
103 syn::FnArg::Receiver(_) => quote! { self },
104 syn::FnArg::Typed(pat_type) => {
105 if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
106 let ident = &pat_ident.ident;
107 quote! { #ident }
108 } else {
109 quote! { _ }
110 }
111 }
112 });
113
114 let arg_names1: Vec<_> = arg_names.clone().collect();
115 let _arg_names2: Vec<_> = arg_names.collect();
116
117 let expanded = if is_async {
118 quote! {
119 #(#attrs)*
120 #vis #sig {
121 use fncache::backends::CacheBackend;
122 use std::time::Duration;
123 use futures::TryFutureExt;
124
125 let key = if #use_compile_time_keys {
126 format!("{}-ct-{}", module_path!(), stringify!(#fn_name))
127 } else {
128 format!("{}-{:?}", stringify!(#fn_name), (#(&(#arg_names1)),*))
129 };
130
131 if let Ok(cache_guard) = fncache::global_cache().lock() {
132 if let Ok(Some(cached)) = cache_guard.get(&key).await {
133 if let Ok(deserialized) = bincode::deserialize::<_>(&cached) {
134 return deserialized;
135 }
136 }
137 }
138
139 let result = #block;
140
141 if let Ok(serialized) = bincode::serialize(&result) {
142 if let Ok(mut cache_guard) = fncache::global_cache().lock() {
143 let _ = cache_guard.set(
144 key,
145 serialized,
146 Some(Duration::from_secs(#ttl_seconds))
147 ).await;
148 }
149 }
150
151 result
152 }
153 }
154 } else {
155 quote! {
156 #(#attrs)*
157 #vis #sig {
158 use fncache::backends::CacheBackend;
159 use std::time::Duration;
160 use futures::executor;
161
162 let key = if #use_compile_time_keys {
163 format!("{}-ct-{}", module_path!(), stringify!(#fn_name))
164 } else {
165 format!("{}-{:?}", stringify!(#fn_name), (#(&(#arg_names1)),*))
166 };
167
168 if let Ok(cache_guard) = fncache::global_cache().lock() {
169 if let Ok(Some(cached)) = executor::block_on(cache_guard.get(&key)) {
170 if let Ok(deserialized) = bincode::deserialize::<_>(&cached) {
171 return deserialized;
172 }
173 }
174 }
175
176 let result = #block;
177
178 if let Ok(serialized) = bincode::serialize(&result) {
179 if let Ok(mut cache_guard) = fncache::global_cache().lock() {
180 let _ = executor::block_on(cache_guard.set(
181 key,
182 serialized,
183 Some(Duration::from_secs(#ttl_seconds))
184 ));
185 }
186 }
187
188 result
189 }
190 }
191 };
192
193 expanded.into()
194}