fncache_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::punctuated::Punctuated;
4use syn::{parse::Parse, parse::ParseStream, parse_macro_input, ItemFn, Lit, Token};
5use syn::{Error, Result};
6
/// Strategy used to derive the cache key for a call to a `#[fncache]` function.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum KeyDerivation {
    /// Key on the function name plus the `Debug` rendering of every argument
    /// (`key_derivation = "runtime"`, the default).
    Runtime,
    /// Key on the module path and function name only, so every call shares one entry
    /// (`key_derivation = "compile_time"`).
    CompileTime,
}
12
13/// Parse the attributes passed to the fncache macro
14struct FncacheArgs {
15    ttl: Option<u64>,
16    key_derivation: KeyDerivation,
17}
18
19impl Parse for FncacheArgs {
20    fn parse(input: ParseStream) -> Result<Self> {
21        let vars = Punctuated::<syn::MetaNameValue, Token![,]>::parse_terminated(input)?;
22
23        let mut ttl = None;
24        let mut key_derivation = KeyDerivation::Runtime;
25
26        for var in vars {
27            let ident = var
28                .path
29                .get_ident()
30                .ok_or_else(|| Error::new_spanned(&var.path, "Expected identifier"))?;
31
32            if ident == "ttl" {
33                match &var.lit {
34                    Lit::Int(lit) => {
35                        ttl = Some(lit.base10_parse()?);
36                    }
37                    _ => return Err(Error::new_spanned(&var.lit, "ttl must be an integer")),
38                }
39            } else if ident == "key_derivation" {
40                match &var.lit {
41                    Lit::Str(lit_str) => {
42                        let value = lit_str.value();
43                        if value == "runtime" {
44                            key_derivation = KeyDerivation::Runtime;
45                        } else if value == "compile_time" {
46                            key_derivation = KeyDerivation::CompileTime;
47                        } else {
48                            return Err(Error::new_spanned(
49                                &var.lit,
50                                "key_derivation must be either 'runtime' or 'compile_time'",
51                            ));
52                        }
53                    }
54                    _ => {
55                        return Err(Error::new_spanned(
56                            &var.lit,
57                            "key_derivation must be a string literal",
58                        ))
59                    }
60                }
61            }
62        }
63
64        Ok(FncacheArgs {
65            ttl,
66            key_derivation,
67        })
68    }
69}
70
71#[proc_macro_attribute]
72pub fn fncache(attr: TokenStream, item: TokenStream) -> TokenStream {
73    let args = syn::parse_macro_input::parse::<FncacheArgs>(attr.clone()).unwrap_or_else(|_| {
74        FncacheArgs {
75            ttl: None,
76            key_derivation: KeyDerivation::Runtime,
77        }
78    });
79
80    let use_compile_time_keys = match args.key_derivation {
81        KeyDerivation::CompileTime => true,
82        KeyDerivation::Runtime => false,
83    };
84
85    let ttl_seconds = args.ttl.unwrap_or(60);
86
87    let input_fn = parse_macro_input!(item as ItemFn);
88
89    let vis = &input_fn.vis;
90    let sig = &input_fn.sig;
91    let block = &input_fn.block;
92    let attrs = &input_fn.attrs;
93
94    let fn_name = &sig.ident;
95    let asyncness = &sig.asyncness;
96    let _generics = &sig.generics;
97    let inputs = &sig.inputs;
98    let _output = &sig.output;
99
100    let is_async = asyncness.is_some();
101
102    let arg_names = inputs.iter().map(|arg| match arg {
103        syn::FnArg::Receiver(_) => quote! { self },
104        syn::FnArg::Typed(pat_type) => {
105            if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
106                let ident = &pat_ident.ident;
107                quote! { #ident }
108            } else {
109                quote! { _ }
110            }
111        }
112    });
113
114    let arg_names1: Vec<_> = arg_names.clone().collect();
115    let _arg_names2: Vec<_> = arg_names.collect();
116
117    let expanded = if is_async {
118        quote! {
119            #(#attrs)*
120            #vis #sig {
121                use fncache::backends::CacheBackend;
122                use std::time::Duration;
123                use futures::TryFutureExt;
124
125                let key = if #use_compile_time_keys {
126                    format!("{}-ct-{}", module_path!(), stringify!(#fn_name))
127                } else {
128                    format!("{}-{:?}", stringify!(#fn_name), (#(&(#arg_names1)),*))
129                };
130
131                if let Ok(cache_guard) = fncache::global_cache().lock() {
132                    if let Ok(Some(cached)) = cache_guard.get(&key).await {
133                        if let Ok(deserialized) = bincode::deserialize::<_>(&cached) {
134                            return deserialized;
135                        }
136                    }
137                }
138
139                let result = #block;
140
141                if let Ok(serialized) = bincode::serialize(&result) {
142                    if let Ok(mut cache_guard) = fncache::global_cache().lock() {
143                        let _ = cache_guard.set(
144                            key,
145                            serialized,
146                            Some(Duration::from_secs(#ttl_seconds))
147                        ).await;
148                    }
149                }
150
151                result
152            }
153        }
154    } else {
155        quote! {
156            #(#attrs)*
157            #vis #sig {
158                use fncache::backends::CacheBackend;
159                use std::time::Duration;
160                use futures::executor;
161
162                let key = if #use_compile_time_keys {
163                    format!("{}-ct-{}", module_path!(), stringify!(#fn_name))
164                } else {
165                    format!("{}-{:?}", stringify!(#fn_name), (#(&(#arg_names1)),*))
166                };
167
168                if let Ok(cache_guard) = fncache::global_cache().lock() {
169                    if let Ok(Some(cached)) = executor::block_on(cache_guard.get(&key)) {
170                        if let Ok(deserialized) = bincode::deserialize::<_>(&cached) {
171                            return deserialized;
172                        }
173                    }
174                }
175
176                let result = #block;
177
178                if let Ok(serialized) = bincode::serialize(&result) {
179                    if let Ok(mut cache_guard) = fncache::global_cache().lock() {
180                        let _ = executor::block_on(cache_guard.set(
181                            key,
182                            serialized,
183                            Some(Duration::from_secs(#ttl_seconds))
184                        ));
185                    }
186                }
187
188                result
189            }
190        }
191    };
192
193    expanded.into()
194}