cache_macro_stable_rust/
lib.rs

1//! cache-macro
2//! ================
3//!
4//! A procedural macro to automatically cache the result of a function given a set of inputs.
5//!
6//! # Example:
7//!
8//! ```rust
9//! use cache_macro::cache;
10//! use lru_cache::LruCache;
11//!
12//! #[cache(LruCache : LruCache::new(20))]
13//! fn fib(x: u32) -> u64 {
14//!     println!("{:?}", x);
15//!     if x <= 1 {
16//!         1
17//!     } else {
18//!         fib(x - 1) + fib(x - 2)
19//!     }
20//! }
21//!
22//! assert_eq!(fib(19), 6765);
23//! ```
24//!
//! The above example executes the body of `fib` only twenty times, once for each value from 0 to 19;
//! every other recursive call is answered from the cache.
27//!
28//! # Usage:
29//!
//! Simply place `#[cache(CacheType : constructor)]` above your function. The function must satisfy a few
//! requirements to use `cache`:
32//!
//! * All arguments that form the cache key (every argument not listed in `ignore_args`, see below) and
//!   the return value must implement `Clone`.
34//! * The function may not take `self` in any form.
35//!
//! The cache type used must accept two generic parameters `<Args, Return>` and must support the methods
//! `get_mut(&K) -> Option<&mut V>` and `insert(K, V)`. `Args` is a tuple of the cached arguments' types,
//! with a leading `&` removed. The `lru-cache` (for LRU caching) and `expiring_map` (for time-to-live
//! caching) crates currently meet these requirements.
39//!
//! This crate does not enable any nightly-only `#![feature(...)]` gates, so it is intended to build on
//! stable Rust.
42//!
43//! # Configuration:
44//!
//! The `cache` macro can be configured by adding further attributes below `#[cache(...)]`.
46//!
47//! All configuration attributes take the form `#[cache_cfg(...)]`. The available attributes are:
48//!
49//! * `#[cache_cfg(ignore_args = ...)]`
50//!
//! This allows certain arguments to be ignored for the purposes of caching: they are not part of the
//! cache key, so they must never influence the output of the function. This can be useful for diagnostic
//! settings, counting how many times the function body runs, or other introspection purposes.
54//!
55//! `ignore_args` takes a comma-separated list of variable identifiers to ignore.
56//!
57//! ### Example:
58//! ```rust
59//! use cache_macro::cache;
60//! use lru_cache::LruCache;
61//! #[cache(LruCache : LruCache::new(20))]
62//! #[cache_cfg(ignore_args = call_count)]
63//! fn fib(x: u64, call_count: &mut u32) -> u64 {
64//!     *call_count += 1;
65//!     if x <= 1 {
66//!         1
67//!     } else {
68//!         fib(x - 1, call_count) + fib(x - 2, call_count)
69//!     }
70//! }
71//!
72//! let mut call_count = 0;
73//! assert_eq!(fib(39, &mut call_count), 102_334_155);
74//! assert_eq!(call_count, 40);
75//! ```
76//!
//! The `call_count` argument can vary; the cache key is built from `x` alone.
78//!
79//! * `#[cache_cfg(thread_local)]`
80//!
//! Store the cache in thread-local storage instead of global static storage. This avoids the overhead of Mutex locking,
//! but each thread gets its own cache, so results cached by one thread are never seen by another.
83//!
84//! Expanding on the first example:
85//!
86//! ```rust
87//! use cache_macro::cache;
88//! use lru_cache::LruCache;
89//!
90//! #[cache(LruCache : LruCache::new(20))]
91//! #[cache_cfg(thread_local)]
92//! fn fib(x: u32) -> u64 {
93//!     println!("{:?}", x);
94//!     if x <= 1 {
95//!         1
96//!     } else {
97//!         fib(x - 1) + fib(x - 2)
98//!     }
99//! }
100//!
101//! assert_eq!(fib(19), 6765);
102//! ```
103//!
104//! # Details
105//! The created cache is stored as a static variable protected by a mutex unless the `#[cache_cfg(thread_local)]`
106//! configuration is added.
107//!
//! With the default settings, the fibonacci example expands to code along these lines (the macro's
//! internal identifiers may differ):
109//!
110//! ```rust
111//! fn __lru_base_fib(x: u32) -> u64 {
112//!     if x <= 1 { 1 } else { fib(x - 1) + fib(x - 2) }
113//! }
114//! fn fib(x: u32) -> u64 {
115//!     use lazy_static::lazy_static;
116//!     use std::sync::Mutex;
117//!
118//!     lazy_static! {
119//!         static ref cache: Mutex<::lru_cache::LruCache<(u32,), u64>> =
120//!             Mutex::new(::lru_cache::LruCache::new(20usize));
121//!     }
122//!
123//!     let cloned_args = (x.clone(),);
124//!     let mut cache_unlocked = cache.lock().unwrap();
125//!     let stored_result = cache_unlocked.get_mut(&cloned_args);
126//!     if let Some(stored_result) = stored_result {
127//!         return stored_result.clone();
128//!     };
129//!     drop(cache_unlocked);
130//!     let ret = __lru_base_fib(x);
131//!     let mut cache_unlocked = cache.lock().unwrap();
132//!     cache_unlocked.insert(cloned_args, ret.clone());
133//!     ret
134//! }
135//!
136//! ```
137//!
//! If you add `#[cache_cfg(thread_local)]` instead, the generated code looks roughly like this:
139//!
140//!
141//! ```rust
142//! fn __lru_base_fib(x: u32) -> u64 {
143//!     if x <= 1 { 1 } else { fib(x - 1) + fib(x - 2) }
144//! }
145//! fn fib(x: u32) -> u64 {
146//!     use std::cell::RefCell;
147//!     use std::thread_local;
148//!
149//!     thread_local!(
150//!         static cache: RefCell<::lru_cache::LruCache<(u32,), u64>> =
151//!             RefCell::new(::lru_cache::LruCache::new(20usize));
152//!     );
153//!
154//!     cache.with(|c|
155//!         {
156//!             let mut cache_ref = c.borrow_mut();
157//!             let cloned_args = (x.clone(),);
158//!             let stored_result = cache_ref.get_mut(&cloned_args);
159//!             if let Some(stored_result) = stored_result {
160//!                 return stored_result.clone()
161//!             }
162//!
163//!             // Don't hold a mutable borrow across
164//!             // the recursive function call
165//!             drop(cache_ref);
166//!
167//!             let ret = __lru_base_fib(x);
168//!             c.borrow_mut().insert(cloned_args, ret.clone());
169//!             ret
170//!         })
171//! }
172//! ```
173//!
174//#![feature(extern_crate_item_prelude)]
175//#![feature(proc_macro_diagnostic)]
176#![recursion_limit="128"]
177extern crate proc_macro;
178
179use proc_macro::TokenStream;
180use syn;
181use syn::{Token, parse_quote};
182use syn::spanned::Spanned;
183use syn::punctuated::Punctuated;
184use quote::quote;
185use proc_macro2;
186
187mod config;
188mod error;
189
190use self::error::{DiagnosticError, Result};
191use syn::parse::Parse;
192use syn::parse::ParseStream;
193use syn::parse_macro_input;
194
195struct Attr {
196    cache_type: syn::Type,
197    cache_creation_expr: syn::Expr,
198}
199
200impl Parse for Attr {
201    fn parse(input: ParseStream) -> syn::parse::Result<Self> {
202        let cache_type: syn::Type = input.parse()?;
203        input.parse::<Token![:]>()?;
204        let cache_creation_expr: syn::Expr = input.parse()?;
205        Ok(Attr {
206            cache_type,
207            cache_creation_expr,
208        })
209    }
210}
211
212// Function shim to allow us to use `Result` and the `?` operator.
213#[proc_macro_attribute]
214pub fn cache(attr: TokenStream, item: TokenStream) -> TokenStream {
215    let attr = parse_macro_input!(attr as Attr);
216
217    match lru_cache_impl(attr, item.clone()) {
218        Ok(tokens) => return tokens,
219        Err(e) => {
220            e.emit();
221            return item;
222        }
223    }
224}
225
226// The main entry point for the macro.
227fn lru_cache_impl(attr: Attr, item: TokenStream) -> Result<TokenStream> {
228    let mut original_fn: syn::ItemFn = match syn::parse(item.clone()) {
229        Ok(ast) => ast,
230        Err(e) => {
231            return Err(DiagnosticError::new_with_syn_error(String::from("lru_cache may only be used on functions"), e));
232        }
233    };
234
235    let (macro_config, out_attributes) =
236        {
237            let attribs = &original_fn.attrs[..];
238            config::Config::parse_from_attributes(attribs)?
239        };
240    original_fn.attrs = out_attributes;
241
242    let mut new_fn = original_fn.clone();
243
244    let return_type = get_cache_fn_return_type(&original_fn)?;
245
246    let new_name = format!("__lru_base_{}", original_fn.ident.to_string());
247    original_fn.ident = syn::Ident::new(&new_name[..], original_fn.ident.span());
248
249    let (call_args, types, cache_args) = get_args_and_types(&original_fn, &macro_config)?;
250    let cloned_args = make_cloned_args_tuple(&cache_args);
251    let fn_path = path_from_ident(original_fn.ident.clone());
252
253    let fn_call = syn::ExprCall {
254        attrs: Vec::new(),
255        paren_token: syn::token::Paren { span: proc_macro2::Span::call_site() },
256        args: call_args.clone(),
257        func: Box::new(fn_path)
258    };
259
260    let tuple_type = syn::TypeTuple {
261        paren_token: syn::token::Paren { span: proc_macro2::Span::call_site() },
262        elems: types,
263    };
264
265    let cache_type = &attr.cache_type;
266    let cache_type_with_generics: syn::Type = parse_quote! {
267        #cache_type<#tuple_type, #return_type>
268    };
269
270    let lru_body = build_cache_body(&cache_type_with_generics, &attr.cache_creation_expr, &cloned_args,
271        &fn_call, &macro_config);
272
273
274    new_fn.block = Box::new(lru_body);
275
276    let out = quote! {
277        #original_fn
278
279        #new_fn
280    };
281    Ok(out.into())
282}
283
284// Build the body of the caching function. What is constructed depends on the config value.
285fn build_cache_body(full_cache_type: &syn::Type, cache_new: &syn::Expr,
286                    cloned_args: &syn::ExprTuple, inner_fn_call: &syn::ExprCall,
287                    config: &config::Config) -> syn::Block
288{
289    if config.use_tls {
290        build_tls_cache_body(full_cache_type, cache_new, cloned_args, inner_fn_call)
291    } else {
292        build_mutex_cache_body(full_cache_type, cache_new, cloned_args, inner_fn_call)
293    }
294}
295
296// Build the body of the caching function which puts the cache in thread-local storage.
297fn build_tls_cache_body(full_cache_type: &syn::Type, cache_new: &syn::Expr,
298                     cloned_args: &syn::ExprTuple, inner_fn_call: &syn::ExprCall) -> syn::Block
299{
300    parse_quote! {
301        {
302            use std::cell::RefCell;
303            use std::thread_local;
304            thread_local!(
305                static cache: RefCell<#full_cache_type> =
306                    RefCell::new(#cache_new);
307            );
308            cache.with(|c| {
309                let mut cache_ref = c.borrow_mut();
310                let cloned_args = #cloned_args;
311
312                let stored_result = cache_ref.get_mut(&cloned_args);
313                if let Some(stored_result) = stored_result {
314                    return stored_result.clone()
315                }
316
317                // Don't hold a mutable borrow across
318                // the recursive function call
319                drop(cache_ref);
320
321                let ret = #inner_fn_call;
322                c.borrow_mut().insert(cloned_args, ret.clone());
323                ret
324            })
325        }
326    }
327}
328
329// Build the body of the caching function which guards the static cache with a mutex.
330fn build_mutex_cache_body(full_cache_type: &syn::Type, cache_new: &syn::Expr,
331                     cloned_args: &syn::ExprTuple, inner_fn_call: &syn::ExprCall) -> syn::Block
332{
333    parse_quote! {
334        {
335            use lazy_static::lazy_static;
336            use std::sync::Mutex;
337
338            lazy_static! {
339                static ref cache: Mutex<#full_cache_type> =
340                    Mutex::new(#cache_new);
341            }
342
343            let cloned_args = #cloned_args;
344
345            let mut cache_unlocked = cache.lock().unwrap();
346            let stored_result = cache_unlocked.get_mut(&cloned_args);
347            if let Some(stored_result) = stored_result {
348                return stored_result.clone();
349            };
350
351            // must unlock here to allow potentially recursive call
352            drop(cache_unlocked);
353
354            let ret = #inner_fn_call;
355            let mut cache_unlocked = cache.lock().unwrap();
356            cache_unlocked.insert(cloned_args, ret.clone());
357            ret
358        }
359    }
360}
361
362fn get_cache_fn_return_type(original_fn: &syn::ItemFn) -> Result<Box<syn::Type>> {
363    if let syn::ReturnType::Type(_, ref ty) = original_fn.decl.output {
364        Ok(ty.clone())
365    } else {
366        return Err(DiagnosticError::new(String::from("There's no point of caching the output of a function that has no output")));
367    }
368}
369
370fn path_from_ident(ident: syn::Ident) -> syn::Expr {
371    let mut segments: Punctuated<_, Token![::]> = Punctuated::new();
372    segments.push(syn::PathSegment { ident: ident, arguments: syn::PathArguments::None });
373    syn::Expr::Path(syn::ExprPath { attrs: Vec::new(), qself: None, path: syn::Path { leading_colon: None, segments: segments} })
374}
375
376fn make_cloned_args_tuple(args: &Punctuated<syn::Expr, Token![,]>) -> syn::ExprTuple {
377    let mut cloned_args = Punctuated::<_, Token![,]>::new();
378    for arg in args {
379        let call = syn::ExprMethodCall {
380            attrs: Vec::new(),
381            receiver: Box::new(arg.clone()),
382            dot_token: syn::token::Dot { spans: [arg.span(); 1] },
383            method: syn::Ident::new("clone", proc_macro2::Span::call_site()),
384            turbofish: None,
385            paren_token: syn::token::Paren { span: proc_macro2::Span::call_site() },
386            args: Punctuated::new(),
387        };
388        cloned_args.push(syn::Expr::MethodCall(call));
389    }
390    syn::ExprTuple {
391        attrs: Vec::new(),
392        paren_token: syn::token::Paren { span: proc_macro2::Span::call_site() },
393        elems: cloned_args,
394    }
395}
396
397fn get_args_and_types(f: &syn::ItemFn, config: &config::Config) ->
398        Result<(Punctuated<syn::Expr, Token![,]>, Punctuated<syn::Type, Token![,]>, Punctuated<syn::Expr, Token![,]>)>
399{
400    let mut call_args = Punctuated::<_, Token![,]>::new();
401    let mut types = Punctuated::<_, Token![,]>::new();
402    let mut cache_args = Punctuated::<_, Token![,]>::new();
403
404    for input in &f.decl.inputs {
405        match input {
406            syn::FnArg::SelfValue(_p) => {
407                return Err(DiagnosticError::new(String::from("`self` arguments are currently unsupported by lru_cache")));
408            }
409            syn::FnArg::SelfRef(_p) => {
410                return Err(DiagnosticError::new(String::from("`&self` arguments are currently unsupported by lru_cache")));
411            }
412            syn::FnArg::Captured(arg_captured) => {
413                let mut segments: syn::punctuated::Punctuated<_, Token![::]> = syn::punctuated::Punctuated::new();
414                let arg_name;
415                if let syn::Pat::Ident(ref pat_ident) = arg_captured.pat {
416                    arg_name = pat_ident.ident.clone();
417                    segments.push(syn::PathSegment { ident: pat_ident.ident.clone(), arguments: syn::PathArguments::None });
418                } else {
419                    return Err(DiagnosticError::new(String::from("unsupported argument kind")));
420                }
421
422                let arg_path = syn::Expr::Path(syn::ExprPath { attrs: Vec::new(), qself: None, path: syn::Path { leading_colon: None, segments } });
423
424                if !config.ignore_args.contains(&arg_name) {
425
426                    // If the arg type is a reference, remove the reference because the arg will be cloned
427                    if let syn::Type::Reference(type_reference) = &arg_captured.ty {
428                        if let Some(_m) = type_reference.mutability {
429                            return Err(DiagnosticError::new(String::from("`mut` reference arguments are not supported as this could lead to incorrect results being stored")));
430                        }
431                        types.push(type_reference.elem.as_ref().to_owned()); // as_ref -> to_owned unboxes the type
432                    } else {
433                        types.push(arg_captured.ty.clone());
434                    }
435
436                    cache_args.push(arg_path.clone());
437                }
438
439
440                call_args.push(arg_path);
441            },
442            syn::FnArg::Inferred(_p) => {
443                return Err(DiagnosticError::new(String::from("inferred arguments are currently unsupported by lru_cache")));
444            }
445            syn::FnArg::Ignored(_p) => {
446                return Err(DiagnosticError::new(String::from("ignored arguments are currently unsupported by lru_cache")));
447            }
448        }
449    }
450
451    if types.len() == 1 {
452        types.push_punct(syn::token::Comma { spans: [proc_macro2::Span::call_site(); 1] })
453    }
454
455    Ok((call_args, types, cache_args))
456}