cache_macro/
lib.rs

//! cache-macro
//! ================
//!
//! A procedural macro to automatically cache the result of a function given a set of inputs.
//!
//! # Example:
//!
//! ```rust
//! use cache_macro::cache;
//! use lru_cache::LruCache;
//!
//! #[cache(LruCache : LruCache::new(20))]
//! fn fib(x: u32) -> u64 {
//!     println!("{:?}", x);
//!     if x <= 1 {
//!         1
//!     } else {
//!         fib(x - 1) + fib(x - 2)
//!     }
//! }
//!
//! assert_eq!(fib(19), 6765);
//! ```
//!
//! In the above example, the body of `fib` is only executed twenty times, once for each value
//! from 0 to 19. All of the intermediate results produced by the recursion hit the cache.
//!
//! # Usage:
//!
//! Simply place `#[cache(CacheType : constructor)]` above your function. The function must satisfy
//! a few requirements for the macro to work:
//!
//! * All arguments and return values must implement `Clone`.
//! * The function may not take `self` in any form.
//!
//! The cache type used must accept two generic parameters `<Args, Return>` and must support the
//! methods `get_mut(&Args) -> Option<&mut Return>` and `insert(Args, Return)`. The `lru-cache`
//! (for LRU caching) and `expiring_map` (for time-to-live caching) crates currently meet these
//! requirements.
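//!
//! As an illustrative sketch (not part of this crate, `lru-cache`, or `expiring_map`), a
//! hypothetical `NaiveCache` type shaped like the following would satisfy that interface:
//!
//! ```rust
//! use std::collections::HashMap;
//! use std::hash::Hash;
//!
//! struct NaiveCache<K, V> {
//!     map: HashMap<K, V>,
//! }
//!
//! impl<K: Eq + Hash, V> NaiveCache<K, V> {
//!     fn new() -> Self {
//!         NaiveCache { map: HashMap::new() }
//!     }
//!
//!     // Look up a previously stored value for `key`.
//!     fn get_mut(&mut self, key: &K) -> Option<&mut V> {
//!         self.map.get_mut(key)
//!     }
//!
//!     // Store a freshly computed value under `key`.
//!     fn insert(&mut self, key: K, value: V) {
//!         self.map.insert(key, value);
//!     }
//! }
//! ```
//!
//! With such a type in scope, the attribute would be written as
//! `#[cache(NaiveCache : NaiveCache::new())]`.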
//!
//! Currently, this crate requires nightly Rust. However, once the 2018 edition and the procedural
//! macro diagnostic interface stabilize, it should be able to run on stable.
//!
//! # Configuration:
//!
//! The `cache` macro can be configured by adding additional attributes under `#[cache(...)]`.
//!
//! All configuration attributes take the form `#[cache_cfg(...)]`. The available attributes are:
//!
//! * `#[cache_cfg(ignore_args = ...)]`
//!
//! This allows certain arguments to be ignored for the purposes of caching. That means they are not
//! part of the cache key and thus must never influence the output of the function. This can be
//! useful for diagnostic settings, counting the number of times the function is executed, or other
//! introspection purposes.
//!
//! `ignore_args` takes a comma-separated list of variable identifiers to ignore.
//!
//! ### Example:
//! ```rust
//! use cache_macro::cache;
//! use lru_cache::LruCache;
//! #[cache(LruCache : LruCache::new(20))]
//! #[cache_cfg(ignore_args = call_count)]
//! fn fib(x: u64, call_count: &mut u32) -> u64 {
//!     *call_count += 1;
//!     if x <= 1 {
//!         1
//!     } else {
//!         fib(x - 1, call_count) + fib(x - 2, call_count)
//!     }
//! }
//!
//! let mut call_count = 0;
//! assert_eq!(fib(39, &mut call_count), 102_334_155);
//! assert_eq!(call_count, 40);
//! ```
//!
//! The `call_count` argument can vary; caching is done based only on `x`.
//!
//! * `#[cache_cfg(thread_local)]`
//!
//! Store the cache in thread-local storage instead of global static storage. This avoids the
//! overhead of Mutex locking, but each thread is given its own cache, and caching in one thread
//! will not affect any other thread.
//!
//! Expanding on the first example:
//!
//! ```rust
//! use cache_macro::cache;
//! use lru_cache::LruCache;
//!
//! #[cache(LruCache : LruCache::new(20))]
//! #[cache_cfg(thread_local)]
//! fn fib(x: u32) -> u64 {
//!     println!("{:?}", x);
//!     if x <= 1 {
//!         1
//!     } else {
//!         fib(x - 1) + fib(x - 2)
//!     }
//! }
//!
//! assert_eq!(fib(19), 6765);
//! ```
//!
//! # Details
//! The created cache is stored as a static variable protected by a mutex unless the
//! `#[cache_cfg(thread_local)]` configuration is added.
//!
//! With the default settings, the Fibonacci example will generate the following code:
//!
//! ```rust
//! fn __lru_base_fib(x: u32) -> u64 {
//!     if x <= 1 { 1 } else { fib(x - 1) + fib(x - 2) }
//! }
//! fn fib(x: u32) -> u64 {
//!     use lazy_static::lazy_static;
//!     use std::sync::Mutex;
//!
//!     lazy_static! {
//!         static ref cache: Mutex<::lru_cache::LruCache<(u32,), u64>> =
//!             Mutex::new(::lru_cache::LruCache::new(20usize));
//!     }
//!
//!     let cloned_args = (x.clone(),);
//!     let mut cache_unlocked = cache.lock().unwrap();
//!     let stored_result = cache_unlocked.get_mut(&cloned_args);
//!     if let Some(stored_result) = stored_result {
//!         return stored_result.clone();
//!     };
//!     drop(cache_unlocked);
//!     let ret = __lru_base_fib(x);
//!     let mut cache_unlocked = cache.lock().unwrap();
//!     cache_unlocked.insert(cloned_args, ret.clone());
//!     ret
//! }
//! ```
//!
//! Whereas if you use `#[cache_cfg(thread_local)]`, the generated code will look like:
//!
//! ```rust
//! fn __lru_base_fib(x: u32) -> u64 {
//!     if x <= 1 { 1 } else { fib(x - 1) + fib(x - 2) }
//! }
//! fn fib(x: u32) -> u64 {
//!     use std::cell::RefCell;
//!     use std::thread_local;
//!
//!     thread_local!(
//!         static cache: RefCell<::lru_cache::LruCache<(u32,), u64>> =
//!             RefCell::new(::lru_cache::LruCache::new(20usize));
//!     );
//!
//!     cache.with(|c| {
//!         let mut cache_ref = c.borrow_mut();
//!         let cloned_args = (x.clone(),);
//!
//!         let stored_result = cache_ref.get_mut(&cloned_args);
//!         if let Some(stored_result) = stored_result {
//!             return stored_result.clone();
//!         }
//!
//!         // Don't hold a mutable borrow across the recursive function call
//!         drop(cache_ref);
//!
//!         let ret = __lru_base_fib(x);
//!         c.borrow_mut().insert(cloned_args, ret.clone());
//!         ret
//!     })
//! }
//! ```
//!
#![feature(extern_crate_item_prelude)]
#![feature(proc_macro_diagnostic)]
#![recursion_limit = "128"]
extern crate proc_macro;

use proc_macro::TokenStream;
use syn;
use syn::{Token, parse_quote};
use syn::spanned::Spanned;
use syn::punctuated::Punctuated;
use quote::quote;
use proc_macro2;

mod config;
mod error;

use self::error::{DiagnosticError, Result};
use syn::parse::Parse;
use syn::parse::ParseStream;
use syn::parse_macro_input;

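// The parsed arguments of the `#[cache(...)]` attribute. For
// `#[cache(LruCache : LruCache::new(20))]`, this captures the cache type (`LruCache`)
// and the expression used to construct it (`LruCache::new(20)`).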
struct Attr {
    cache_type: syn::Type,
    cache_creation_expr: syn::Expr,
}

impl Parse for Attr {
    fn parse(input: ParseStream) -> syn::parse::Result<Self> {
        let cache_type: syn::Type = input.parse()?;
        input.parse::<Token![:]>()?;
        let cache_creation_expr: syn::Expr = input.parse()?;
        Ok(Attr {
            cache_type,
            cache_creation_expr,
        })
    }
}

// The proc-macro entry point. This is a shim around `lru_cache_impl`, which does the real
// work and can use `Result` and the `?` operator. On error, emit the diagnostic and return
// the original item unchanged.
#[proc_macro_attribute]
pub fn cache(attr: TokenStream, item: TokenStream) -> TokenStream {
    let attr = parse_macro_input!(attr as Attr);

    match lru_cache_impl(attr, item.clone()) {
        Ok(tokens) => tokens,
        Err(e) => {
            e.emit();
            item
        }
    }
}

// The main implementation of the macro.
fn lru_cache_impl(attr: Attr, item: TokenStream) -> Result<TokenStream> {
    let mut original_fn: syn::ItemFn = match syn::parse(item.clone()) {
        Ok(ast) => ast,
        Err(e) => {
            let diag = proc_macro2::Span::call_site().unstable()
                .error("lru_cache may only be used on functions");
            return Err(DiagnosticError::new_with_syn_error(diag, e));
        }
    };

    let (macro_config, out_attributes) = {
        let attribs = &original_fn.attrs[..];
        config::Config::parse_from_attributes(attribs)?
    };
    original_fn.attrs = out_attributes;

    let mut new_fn = original_fn.clone();

    let return_type = get_cache_fn_return_type(&original_fn)?;

    // Rename the original function; the caching wrapper takes its place under the original name.
    let new_name = format!("__lru_base_{}", original_fn.ident);
    original_fn.ident = syn::Ident::new(&new_name[..], original_fn.ident.span());

    let (call_args, types, cache_args) = get_args_and_types(&original_fn, &macro_config)?;
    let cloned_args = make_cloned_args_tuple(&cache_args);
    let fn_path = path_from_ident(original_fn.ident.clone());

    let fn_call = syn::ExprCall {
        attrs: Vec::new(),
        paren_token: syn::token::Paren { span: proc_macro2::Span::call_site() },
        args: call_args.clone(),
        func: Box::new(fn_path),
    };

    let tuple_type = syn::TypeTuple {
        paren_token: syn::token::Paren { span: proc_macro2::Span::call_site() },
        elems: types,
    };

    let cache_type = &attr.cache_type;
    let cache_type_with_generics: syn::Type = parse_quote! {
        #cache_type<#tuple_type, #return_type>
    };

    let lru_body = build_cache_body(&cache_type_with_generics, &attr.cache_creation_expr, &cloned_args,
        &fn_call, &macro_config);

    new_fn.block = Box::new(lru_body);

    let out = quote! {
        #original_fn

        #new_fn
    };
    Ok(out.into())
}

// Build the body of the caching function. What is constructed depends on the config value.
fn build_cache_body(full_cache_type: &syn::Type, cache_new: &syn::Expr,
                    cloned_args: &syn::ExprTuple, inner_fn_call: &syn::ExprCall,
                    config: &config::Config) -> syn::Block
{
    if config.use_tls {
        build_tls_cache_body(full_cache_type, cache_new, cloned_args, inner_fn_call)
    } else {
        build_mutex_cache_body(full_cache_type, cache_new, cloned_args, inner_fn_call)
    }
}

// Build the body of the caching function which puts the cache in thread-local storage.
fn build_tls_cache_body(full_cache_type: &syn::Type, cache_new: &syn::Expr,
                     cloned_args: &syn::ExprTuple, inner_fn_call: &syn::ExprCall) -> syn::Block
{
    parse_quote! {
        {
            use std::cell::RefCell;
            use std::thread_local;
            thread_local!(
                static cache: RefCell<#full_cache_type> =
                    RefCell::new(#cache_new);
            );
            cache.with(|c| {
                let mut cache_ref = c.borrow_mut();
                let cloned_args = #cloned_args;

                let stored_result = cache_ref.get_mut(&cloned_args);
                if let Some(stored_result) = stored_result {
                    return stored_result.clone()
                }

                // Don't hold a mutable borrow across
                // the recursive function call
                drop(cache_ref);

                let ret = #inner_fn_call;
                c.borrow_mut().insert(cloned_args, ret.clone());
                ret
            })
        }
    }
}

// Build the body of the caching function which guards the static cache with a mutex.
fn build_mutex_cache_body(full_cache_type: &syn::Type, cache_new: &syn::Expr,
                     cloned_args: &syn::ExprTuple, inner_fn_call: &syn::ExprCall) -> syn::Block
{
    parse_quote! {
        {
            use lazy_static::lazy_static;
            use std::sync::Mutex;

            lazy_static! {
                static ref cache: Mutex<#full_cache_type> =
                    Mutex::new(#cache_new);
            }

            let cloned_args = #cloned_args;

            let mut cache_unlocked = cache.lock().unwrap();
            let stored_result = cache_unlocked.get_mut(&cloned_args);
            if let Some(stored_result) = stored_result {
                return stored_result.clone();
            };

            // must unlock here to allow potentially recursive call
            drop(cache_unlocked);

            let ret = #inner_fn_call;
            let mut cache_unlocked = cache.lock().unwrap();
            cache_unlocked.insert(cloned_args, ret.clone());
            ret
        }
    }
}

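// Extract the return type of the function being cached; caching a function with no
// return value is reported as an error.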
fn get_cache_fn_return_type(original_fn: &syn::ItemFn) -> Result<Box<syn::Type>> {
    if let syn::ReturnType::Type(_, ref ty) = original_fn.decl.output {
        Ok(ty.clone())
    } else {
        let diag = original_fn.ident.span().unstable()
            .error("There's no point in caching the output of a function that has no output");
        return Err(DiagnosticError::new(diag));
    }
}

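// Wrap a bare identifier in a path expression so it can be used as the callee of a function call.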
fn path_from_ident(ident: syn::Ident) -> syn::Expr {
    let mut segments: Punctuated<_, Token![::]> = Punctuated::new();
    segments.push(syn::PathSegment { ident, arguments: syn::PathArguments::None });
    syn::Expr::Path(syn::ExprPath { attrs: Vec::new(), qself: None, path: syn::Path { leading_colon: None, segments } })
}

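// Build the tuple expression `(arg0.clone(), arg1.clone(), ...)` used as the cache key.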
fn make_cloned_args_tuple(args: &Punctuated<syn::Expr, Token![,]>) -> syn::ExprTuple {
    let mut cloned_args = Punctuated::<_, Token![,]>::new();
    for arg in args {
        let call = syn::ExprMethodCall {
            attrs: Vec::new(),
            receiver: Box::new(arg.clone()),
            dot_token: syn::token::Dot { spans: [arg.span(); 1] },
            method: syn::Ident::new("clone", proc_macro2::Span::call_site()),
            turbofish: None,
            paren_token: syn::token::Paren { span: proc_macro2::Span::call_site() },
            args: Punctuated::new(),
        };
        cloned_args.push(syn::Expr::MethodCall(call));
    }
    syn::ExprTuple {
        attrs: Vec::new(),
        paren_token: syn::token::Paren { span: proc_macro2::Span::call_site() },
        elems: cloned_args,
    }
}

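// Split the function's arguments into the expressions used to call the inner function, the
// types that make up the cache-key tuple, and the argument expressions cloned into the cache
// key. Arguments listed in `ignore_args` are excluded from the latter two.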
fn get_args_and_types(f: &syn::ItemFn, config: &config::Config) ->
        Result<(Punctuated<syn::Expr, Token![,]>, Punctuated<syn::Type, Token![,]>, Punctuated<syn::Expr, Token![,]>)>
{
    let mut call_args = Punctuated::<_, Token![,]>::new();
    let mut types = Punctuated::<_, Token![,]>::new();
    let mut cache_args = Punctuated::<_, Token![,]>::new();

    for input in &f.decl.inputs {
        match input {
            syn::FnArg::SelfValue(p) => {
                let diag = p.span().unstable()
                    .error("`self` arguments are currently unsupported by lru_cache");
                return Err(DiagnosticError::new(diag));
            }
            syn::FnArg::SelfRef(p) => {
                let diag = p.span().unstable()
                    .error("`&self` arguments are currently unsupported by lru_cache");
                return Err(DiagnosticError::new(diag));
            }
            syn::FnArg::Captured(arg_captured) => {
                let mut segments: syn::punctuated::Punctuated<_, Token![::]> = syn::punctuated::Punctuated::new();
                let arg_name;
                if let syn::Pat::Ident(ref pat_ident) = arg_captured.pat {
                    arg_name = pat_ident.ident.clone();
                    if let Some(m) = pat_ident.mutability {
                        if !config.ignore_args.contains(&arg_name) {
                            let diag = m.span.unstable()
                                .error("`mut` arguments are not supported with lru_cache as this could lead to incorrect results being stored");
                            return Err(DiagnosticError::new(diag));
                        }
                    }
                    segments.push(syn::PathSegment { ident: pat_ident.ident.clone(), arguments: syn::PathArguments::None });
                } else {
                    let diag = arg_captured.span().unstable()
                        .error("unsupported argument kind");
                    return Err(DiagnosticError::new(diag));
                }

                let arg_path = syn::Expr::Path(syn::ExprPath { attrs: Vec::new(), qself: None, path: syn::Path { leading_colon: None, segments } });

                if !config.ignore_args.contains(&arg_name) {
                    // If the arg type is a reference, remove the reference because the arg will be cloned
                    if let syn::Type::Reference(type_reference) = &arg_captured.ty {
                        types.push(type_reference.elem.as_ref().to_owned()); // as_ref -> to_owned unboxes the type
                    } else {
                        types.push(arg_captured.ty.clone());
                    }

                    cache_args.push(arg_path.clone());
                }

                call_args.push(arg_path);
            },
            syn::FnArg::Inferred(p) => {
                let diag = p.span().unstable()
                    .error("inferred arguments are currently unsupported by lru_cache");
                return Err(DiagnosticError::new(diag));
            }
            syn::FnArg::Ignored(p) => {
                let diag = p.span().unstable()
                    .error("ignored arguments are currently unsupported by lru_cache");
                return Err(DiagnosticError::new(diag));
            }
        }
    }

    // A single-element cache key needs the trailing comma so the type is the tuple `(T,)`
    // rather than a parenthesized `T`.
    if types.len() == 1 {
        types.push_punct(syn::token::Comma { spans: [proc_macro2::Span::call_site(); 1] })
    }

    Ok((call_args, types, cache_args))
}