// memoize_inner/lib.rs

1#![crate_type = "proc-macro"]
2#![allow(unused_imports)] // Spurious complaints about a required trait import.
3use syn::{self, parse, parse_macro_input, spanned::Spanned, Expr, ExprCall, ItemFn, Path};
4
5use proc_macro::TokenStream;
6use quote::{self, ToTokens};
7
8mod kw {
9    syn::custom_keyword!(Capacity);
10    syn::custom_keyword!(TimeToLive);
11    syn::custom_keyword!(SharedCache);
12    syn::custom_keyword!(CustomHasher);
13    syn::custom_keyword!(HasherInit);
14    syn::custom_keyword!(Ignore);
15    syn::custom_punctuation!(Colon, :);
16}
17
18#[derive(Default, Clone)]
19struct CacheOptions {
20    lru_max_entries: Option<usize>,
21    time_to_live: Option<Expr>,
22    shared_cache: bool,
23    custom_hasher: Option<Path>,
24    custom_hasher_initializer: Option<ExprCall>,
25    ignore: Vec<syn::Ident>,
26}
27
28#[derive(Clone)]
29enum CacheOption {
30    LRUMaxEntries(usize),
31    TimeToLive(Expr),
32    SharedCache,
33    CustomHasher(Path),
34    HasherInit(ExprCall),
35    Ignore(syn::Ident),
36}
37
38// To extend option parsing, add functionality here.
39#[allow(unreachable_code)]
40impl parse::Parse for CacheOption {
41    fn parse(input: parse::ParseStream) -> syn::Result<Self> {
42        let la = input.lookahead1();
43        if la.peek(kw::Capacity) {
44            #[cfg(not(feature = "full"))]
45            return Err(syn::Error::new(input.span(),
46            "memoize error: Capacity specified, but the feature 'full' is not enabled! To fix this, compile with `--features=full`.",
47            ));
48
49            input.parse::<kw::Capacity>().unwrap();
50            input.parse::<kw::Colon>().unwrap();
51            let cap: syn::LitInt = input.parse().unwrap();
52
53            return Ok(CacheOption::LRUMaxEntries(cap.base10_parse()?));
54        }
55        if la.peek(kw::TimeToLive) {
56            #[cfg(not(feature = "full"))]
57            return Err(syn::Error::new(input.span(),
58            "memoize error: TimeToLive specified, but the feature 'full' is not enabled! To fix this, compile with `--features=full`.",
59            ));
60
61            input.parse::<kw::TimeToLive>().unwrap();
62            input.parse::<kw::Colon>().unwrap();
63            let cap: syn::Expr = input.parse().unwrap();
64
65            return Ok(CacheOption::TimeToLive(cap));
66        }
67        if la.peek(kw::SharedCache) {
68            input.parse::<kw::SharedCache>().unwrap();
69            return Ok(CacheOption::SharedCache);
70        }
71        if la.peek(kw::CustomHasher) {
72            input.parse::<kw::CustomHasher>().unwrap();
73            input.parse::<kw::Colon>().unwrap();
74            let cap: syn::Path = input.parse().unwrap();
75            return Ok(CacheOption::CustomHasher(cap));
76        }
77        if la.peek(kw::HasherInit) {
78            input.parse::<kw::HasherInit>().unwrap();
79            input.parse::<kw::Colon>().unwrap();
80            let cap: syn::ExprCall = input.parse().unwrap();
81            return Ok(CacheOption::HasherInit(cap));
82        }
83        if la.peek(kw::Ignore) {
84            input.parse::<kw::Ignore>().unwrap();
85            input.parse::<kw::Colon>().unwrap();
86            let ignore_ident = input.parse::<syn::Ident>().unwrap();
87            return Ok(CacheOption::Ignore(ignore_ident));
88        }
89        Err(la.error())
90    }
91}
92
93impl parse::Parse for CacheOptions {
94    fn parse(input: parse::ParseStream) -> syn::Result<Self> {
95        let f: syn::punctuated::Punctuated<CacheOption, syn::Token![,]> =
96            input.parse_terminated(CacheOption::parse)?;
97        let mut opts = Self::default();
98
99        for opt in f {
100            match opt {
101                CacheOption::LRUMaxEntries(cap) => opts.lru_max_entries = Some(cap),
102                CacheOption::TimeToLive(sec) => opts.time_to_live = Some(sec),
103                CacheOption::CustomHasher(hasher) => opts.custom_hasher = Some(hasher),
104                CacheOption::HasherInit(init) => opts.custom_hasher_initializer = Some(init),
105                CacheOption::SharedCache => opts.shared_cache = true,
106                CacheOption::Ignore(ident) => opts.ignore.push(ident),
107            }
108        }
109        Ok(opts)
110    }
111}
112
113// This implementation of the storage backend does not depend on any more crates.
114#[cfg(not(feature = "full"))]
115mod store {
116    use crate::CacheOptions;
117    use proc_macro::TokenStream;
118
119    /// Returns tokenstreams (for quoting) of the store type and an expression to initialize it.
120    pub(crate) fn construct_cache(
121        _options: &CacheOptions,
122        key_type: proc_macro2::TokenStream,
123        value_type: proc_macro2::TokenStream,
124    ) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
125        // This is the unbounded default.
126        if let Some(hasher) = &_options.custom_hasher {
127            return (
128                quote::quote! { #hasher<#key_type, #value_type> },
129                quote::quote! { #hasher::new() },
130            );
131        } else {
132            (
133                quote::quote! { std::collections::HashMap<#key_type, #value_type> },
134                quote::quote! { std::collections::HashMap::new() },
135            )
136        }
137    }
138
139    /// Returns names of methods as TokenStreams to insert and get (respectively) elements from a
140    /// store.
141    pub(crate) fn cache_access_methods(
142        _options: &CacheOptions,
143    ) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
144        (quote::quote! { insert }, quote::quote! { get })
145    }
146}
147
// This implementation of the storage backend also depends on the `lru` crate.
#[cfg(feature = "full")]
mod store {
    use crate::CacheOptions;
    use proc_macro::TokenStream;

    /// Returns TokenStreams to be used in quote!{} for parametrizing the memoize store variable,
    /// and initializing it.
    ///
    /// First return value: type of the store (`Container<K, V>`).
    /// Second return value: initializer expression (`Container::new()`, or the `HasherInit`
    /// expression for a custom map).
    ///
    /// With `TimeToLive`, each stored value is paired with the `std::time::Instant` at which it
    /// was inserted. With `Capacity`, the store is `::memoize::lru::LruCache`. Combining
    /// `Capacity` with `CustomHasher` yields a `compile_error!` in the generated code.
    pub(crate) fn construct_cache(
        options: &CacheOptions,
        key_type: proc_macro2::TokenStream,
        value_type: proc_macro2::TokenStream,
    ) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
        // With a TTL, values carry their insertion time so expiry can be checked on lookup.
        let value_type = match options.time_to_live {
            None => quote::quote! {#value_type},
            Some(_) => quote::quote! {(std::time::Instant, #value_type)},
        };

        match options.lru_max_entries {
            // This is the unbounded default.
            None => match &options.custom_hasher {
                Some(hasher) => {
                    // Some map types have no `new()`; prefer the user-supplied initializer.
                    let init = match &options.custom_hasher_initializer {
                        Some(hasher_init) => quote::quote! { #hasher_init },
                        None => quote::quote! { #hasher::new() },
                    };
                    (quote::quote! { #hasher<#key_type, #value_type> }, init)
                }
                None => (
                    quote::quote! { std::collections::HashMap<#key_type, #value_type> },
                    quote::quote! { std::collections::HashMap::new() },
                ),
            },
            Some(_) if options.custom_hasher.is_some() => (
                quote::quote! { compile_error!("Cannot use LRU cache and a custom hasher at the same time") },
                quote::quote! { std::collections::HashMap::new() },
            ),
            Some(cap) => (
                quote::quote! { ::memoize::lru::LruCache<#key_type, #value_type> },
                quote::quote! { ::memoize::lru::LruCache::new(#cap) },
            ),
        }
    }

    /// Returns names of methods as TokenStreams to insert and get (respectively) elements from a
    /// store.
    ///
    /// The LRU store inserts with `put`; the map-like stores insert with `insert`.
    pub(crate) fn cache_access_methods(
        options: &CacheOptions,
    ) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
        match options.lru_max_entries {
            None => (quote::quote! { insert }, quote::quote! { get }),
            Some(_) => (quote::quote! { put }, quote::quote! { get }),
        }
    }
}
217
218/**
219 * memoize is an attribute to create a memoized version of a (simple enough) function.
220 *
221 * So far, it works on non-method functions with one or more arguments returning a [`Clone`]-able
222 * value. Arguments that are cached must be [`Clone`]-able and [`Hash`]-able as well. Several clones
223 * happen within the storage and recall layer, with the assumption being that `memoize` is used to
224 * cache such expensive functions that very few `clone()`s do not matter. `memoize` doesn't work on
225 * methods (functions with `[&/&mut/]self` receiver).
226 *
227 * Calls are memoized for the lifetime of a program, using a statically allocated, Mutex-protected
228 * HashMap.
229 *
230 * Memoizing functions is very simple: As long as the above-stated requirements are fulfilled,
231 * simply use the `#[memoize::memoize]` attribute:
232 *
233 * ```
234 * use memoize::memoize;
235 * #[memoize]
236 * fn hello(arg: String, arg2: usize) -> bool {
237 *      arg.len()%2 == arg2
238 * }
239 *
240 * // `hello` is only called once.
241 * assert!(! hello("World".to_string(), 0));
242 * assert!(! hello("World".to_string(), 0));
243 * ```
244 *
245 * If you need to use the un-memoized function, it is always available as `memoized_original_{fn}`,
246 * in this case: `memoized_original_hello()`.
247 *
248 * Parameters can be ignored by the cache using the `Ignore` parameter. `Ignore` can be specified
249 * multiple times, once per each parameter. `Ignore`d parameters do not need to implement [`Clone`]
250 * or [`Hash`]. 
251 * 
252 * See the `examples` for concrete applications.
253 *
254 * *The following descriptions need the `full` feature enabled.*
255 *
256 * The `memoize` attribute can take further parameters in order to use an LRU cache:
257 * `#[memoize(Capacity: 1234)]`. In that case, instead of a `HashMap` we use an `lru::LruCache`
258 * with the given capacity.
259 * `#[memoize(TimeToLive: Duration::from_secs(2))]`. In that case, cached value will be actual
260 * no longer than duration provided and refreshed with next request. If you prefer chrono::Duration,
261 * it can be also used: `#[memoize(TimeToLive: chrono::Duration::hours(9).to_std().unwrap()]`
262 *
263 * You can also specify a custom hasher: `#[memoize(CustomHasher: ahash::HashMap)]`, as some hashers don't use a `new()` method to initialize them, you can also specifiy a `HasherInit` parameter, like this: `#[memoize(CustomHasher: FxHashMap, HasherInit: FxHashMap::default())]`, so it will initialize your `FxHashMap` with `FxHashMap::default()` insteado of `FxHashMap::new()`
264 *
265 * This mechanism can, in principle, be extended (in the source code) to any other cache mechanism.
266 *
267 * `memoized_flush_<function name>()` allows you to clear the underlying memoization cache of a
268 * function. This function is generated with the same visibility as the memoized function.
269 *
270 */
271#[proc_macro_attribute]
272pub fn memoize(attr: TokenStream, item: TokenStream) -> TokenStream {
273    let func = parse_macro_input!(item as ItemFn);
274    let sig = &func.sig;
275
276    let fn_name = &sig.ident.to_string();
277    let renamed_name = format!("memoized_original_{}", fn_name);
278    let flush_name = syn::Ident::new(format!("memoized_flush_{}", fn_name).as_str(), sig.span());
279    let size_name = syn::Ident::new(format!("memoized_size_{}", fn_name).as_str(), sig.span());
280    let map_name = format!("memoized_mapping_{}", fn_name);
281
282    if let Some(syn::FnArg::Receiver(_)) = sig.inputs.first() {
283        return quote::quote! { compile_error!("Cannot memoize methods!"); }.into();
284    }
285
286    // Parse options from macro attributes
287    let options: CacheOptions = syn::parse(attr.clone()).unwrap();
288
289    // Extracted from the function signature.
290    let input_params = match check_signature(sig, &options) {
291        Ok(p) => p,
292        Err(e) => return e.to_compile_error().into(),
293    };
294
295    // Input types and names that are actually stored in the cache.
296    let memoized_input_types: Vec<Box<syn::Type>> = input_params
297        .iter()
298        .filter_map(|p| {
299            if p.is_memoized {
300                Some(p.arg_type.clone())
301            } else {
302                None
303            }
304        })
305        .collect();
306    let memoized_input_names: Vec<syn::Ident> = input_params
307        .iter()
308        .filter_map(|p| {
309            if p.is_memoized {
310                Some(p.arg_name.clone())
311            } else {
312                None
313            }
314        })
315        .collect();
316
317    // For each input, expression to be passe through to the original function.
318    // Cached arguments are cloned, original arguments are forwarded as-is
319    let fn_forwarded_exprs: Vec<_> = input_params
320        .iter()
321        .map(|p| {
322            let ident = p.arg_name.clone();
323            if p.is_memoized {
324                quote::quote! { #ident.clone() }
325            } else {
326                quote::quote! { #ident }
327            }
328        })
329        .collect();
330
331    let input_tuple_type = quote::quote! { (#(#memoized_input_types),*) };
332    let return_type = match &sig.output {
333        syn::ReturnType::Default => quote::quote! { () },
334        syn::ReturnType::Type(_, ty) => ty.to_token_stream(),
335    };
336
337    // Construct storage for the memoized keys and return values.
338    let store_ident = syn::Ident::new(&map_name.to_uppercase(), sig.span());
339    let (cache_type, cache_init) =
340        store::construct_cache(&options, input_tuple_type, return_type.clone());
341    let store = if options.shared_cache {
342        quote::quote! {
343            ::memoize::lazy_static::lazy_static! {
344                static ref #store_ident : std::sync::Mutex<#cache_type> =
345                    std::sync::Mutex::new(#cache_init);
346            }
347        }
348    } else {
349        quote::quote! {
350            std::thread_local! {
351                static #store_ident : std::cell::RefCell<#cache_type> =
352                    std::cell::RefCell::new(#cache_init);
353            }
354        }
355    };
356
357    // Rename original function.
358    let mut renamed_fn = func.clone();
359    renamed_fn.sig.ident = syn::Ident::new(&renamed_name, func.sig.span());
360    let memoized_id = &renamed_fn.sig.ident;
361
362    // Construct memoizer function, which calls the original function.
363    let syntax_names_tuple = quote::quote! { (#(#memoized_input_names),*) };
364    let syntax_names_tuple_cloned = quote::quote! { (#(#memoized_input_names.clone()),*) };
365    let forwarding_tuple = quote::quote! { (#(#fn_forwarded_exprs),*) };
366    let (insert_fn, get_fn) = store::cache_access_methods(&options);
367    let (read_memo, memoize) = match options.time_to_live {
368        None => (
369            quote::quote!(ATTR_MEMOIZE_HM__.#get_fn(&#syntax_names_tuple_cloned).cloned()),
370            quote::quote!(ATTR_MEMOIZE_HM__.#insert_fn(#syntax_names_tuple, ATTR_MEMOIZE_RETURN__.clone());),
371        ),
372        Some(ttl) => (
373            quote::quote! {
374                ATTR_MEMOIZE_HM__.#get_fn(&#syntax_names_tuple_cloned).and_then(|(last_updated, ATTR_MEMOIZE_RETURN__)|
375                    (last_updated.elapsed() < #ttl).then(|| ATTR_MEMOIZE_RETURN__.clone())
376                )
377            },
378            quote::quote!(ATTR_MEMOIZE_HM__.#insert_fn(#syntax_names_tuple, (std::time::Instant::now(), ATTR_MEMOIZE_RETURN__.clone()));),
379        ),
380    };
381
382    let memoizer = if options.shared_cache {
383        quote::quote! {
384            {
385                let mut ATTR_MEMOIZE_HM__ = #store_ident.lock().unwrap();
386                if let Some(ATTR_MEMOIZE_RETURN__) = #read_memo {
387                    return ATTR_MEMOIZE_RETURN__
388                }
389            }
390            let ATTR_MEMOIZE_RETURN__ = #memoized_id #forwarding_tuple;
391
392            let mut ATTR_MEMOIZE_HM__ = #store_ident.lock().unwrap();
393            #memoize
394
395            ATTR_MEMOIZE_RETURN__
396        }
397    } else {
398        quote::quote! {
399            let ATTR_MEMOIZE_RETURN__ = #store_ident.with(|ATTR_MEMOIZE_HM__| {
400                let mut ATTR_MEMOIZE_HM__ = ATTR_MEMOIZE_HM__.borrow_mut();
401                #read_memo
402            });
403            if let Some(ATTR_MEMOIZE_RETURN__) = ATTR_MEMOIZE_RETURN__ {
404                return ATTR_MEMOIZE_RETURN__;
405            }
406
407            let ATTR_MEMOIZE_RETURN__ = #memoized_id #forwarding_tuple;
408
409            #store_ident.with(|ATTR_MEMOIZE_HM__| {
410                let mut ATTR_MEMOIZE_HM__ = ATTR_MEMOIZE_HM__.borrow_mut();
411                #memoize
412            });
413
414            ATTR_MEMOIZE_RETURN__
415        }
416    };
417
418    let vis = &func.vis;
419
420    let flusher = if options.shared_cache {
421        quote::quote! {
422            #vis fn #flush_name() {
423                #store_ident.lock().unwrap().clear();
424            }
425        }
426    } else {
427        quote::quote! {
428            #vis fn #flush_name() {
429                #store_ident.with(|ATTR_MEMOIZE_HM__| ATTR_MEMOIZE_HM__.borrow_mut().clear());
430            }
431        }
432    };
433
434    let size_func = if options.shared_cache {
435        quote::quote! {
436            #vis fn #size_name() -> usize {
437                #store_ident.lock().unwrap().len()
438            }
439        }
440    } else {
441        quote::quote! {
442            #vis fn #size_name() -> usize {
443                #store_ident.with(|ATTR_MEMOIZE_HM__| ATTR_MEMOIZE_HM__.borrow().len())
444            }
445        }
446    };
447
448    quote::quote! {
449        #renamed_fn
450        #flusher
451        #size_func
452        #store
453
454        #[allow(unused_variables, unused_mut)]
455        #vis #sig {
456            #memoizer
457        }
458    }
459    .into()
460}
461
462/// An argument of the memoized function.
463struct FnArgument {
464    /// Type of the argument.
465    arg_type: Box<syn::Type>,
466
467    /// Identifier (name) of the argument.
468    arg_name: syn::Ident,
469
470    /// Whether or not this specific argument is included in the memoization.
471    is_memoized: bool,
472}
473
474fn check_signature(
475    sig: &syn::Signature,
476    options: &CacheOptions,
477) -> Result<Vec<FnArgument>, syn::Error> {
478    if sig.inputs.is_empty() {
479        return Ok(vec![]);
480    }
481
482    let mut params = vec![];
483
484    for a in &sig.inputs {
485        if let syn::FnArg::Typed(ref arg) = a {
486            let arg_type = arg.ty.clone();
487
488            if let syn::Pat::Ident(patident) = &*arg.pat {
489                let arg_name = patident.ident.clone();
490                let is_memoized = !options.ignore.contains(&arg_name);
491                params.push(FnArgument {
492                    arg_type,
493                    arg_name,
494                    is_memoized,
495                });
496            } else {
497                return Err(syn::Error::new(
498                    sig.span(),
499                    "Cannot memoize arbitrary patterns!",
500                ));
501            }
502        }
503    }
504    Ok(params)
505}
506
#[cfg(test)]
mod tests {
    // Placeholder: no unit tests are defined in this crate yet.
}