algorithm_macro/
lib.rs

1
2use proc_macro::TokenStream;
3use syn;
4use syn::{Token, parse_quote};
5use syn::spanned::Spanned;
6use syn::punctuated::Punctuated;
7use quote::quote;
8use proc_macro2;
9
10mod config;
11
12use syn::parse::Parse;
13use syn::parse::ParseStream;
14use syn::parse_macro_input;
15
16struct Attr {
17    cache_type: syn::Type,
18    cache_creation_expr: syn::Expr,
19}
20
21impl Parse for Attr {
22    fn parse(input: ParseStream) -> syn::parse::Result<Self> {
23        let cache_type: syn::Type = input.parse()?;
24        input.parse::<Token![:]>()?;
25        let cache_creation_expr: syn::Expr = input.parse()?;
26        Ok(Attr {
27            cache_type,
28            cache_creation_expr,
29        })
30    }
31}
32
33#[proc_macro_attribute]
34pub fn cache(attr: TokenStream, item: TokenStream) -> TokenStream {
35    let attr = parse_macro_input!(attr as Attr);
36
37    match algorithm_cache_impl(attr, item.clone()) {
38        Ok(tokens) => return tokens,
39        Err(e) => {
40            panic!("error = {:?}", e);
41        }
42    }
43}
44
45// The main entry point for the macro.
46fn algorithm_cache_impl(attr: Attr, item: TokenStream) -> syn::parse::Result<TokenStream> {
47    let mut original_fn: syn::ItemFn = syn::parse(item.clone())?;
48    let (macro_config, out_attributes) =
49        {
50            let attribs = &original_fn.attrs[..];
51            config::Config::parse_from_attributes(attribs)?
52        };
53    original_fn.attrs = out_attributes;
54
55    let mut new_fn = original_fn.clone();
56    let return_type = get_cache_fn_return_type(&original_fn)?;
57    let new_name = format!("__cache_auto_{}", original_fn.sig.ident.to_string());
58    original_fn.sig.ident = syn::Ident::new(&new_name[..], original_fn.sig.ident.span());
59    let (call_args, types, cache_args) = get_args_and_types(&original_fn, &macro_config)?;
60    let cloned_args = make_cloned_args_tuple(&cache_args);
61    let fn_path = path_from_ident(original_fn.sig.ident.clone());
62    let fn_call = syn::ExprCall {
63        attrs: Vec::new(),
64        paren_token: syn::token::Paren::default(),
65        args: call_args,
66        func: Box::new(fn_path)
67    };
68
69    let tuple_type = syn::TypeTuple {
70        paren_token: syn::token::Paren::default(),
71        elems: types,
72    };
73
74    let cache_type = &attr.cache_type;
75    let cache_type_with_generics: syn::Type = parse_quote! {
76        #cache_type<#tuple_type, #return_type, algorithm::DefaultHasher>
77    };
78    let lru_body = build_cache_body(&cache_type_with_generics, &attr.cache_creation_expr, &cloned_args,
79        &fn_call, &macro_config);
80
81    new_fn.block = Box::new(lru_body);
82    let out = quote! {
83        #original_fn
84        #new_fn
85    };
86    Ok(out.into())
87}
88
89// Build the body of the caching function. What is constructed depends on the config value.
90fn build_cache_body(full_cache_type: &syn::Type, cache_new: &syn::Expr,
91                    cloned_args: &syn::ExprTuple, inner_fn_call: &syn::ExprCall,
92                    config: &config::Config) -> syn::Block
93{
94    if config.use_thread {
95        build_mutex_cache_body(full_cache_type, cache_new, cloned_args, inner_fn_call)
96    } else {
97        build_tls_cache_body(full_cache_type, cache_new, cloned_args, inner_fn_call)
98    }
99}
100
101// Build the body of the caching function which puts the cache in thread-local storage.
102fn build_tls_cache_body(full_cache_type: &syn::Type, cache_new: &syn::Expr,
103                     cloned_args: &syn::ExprTuple, inner_fn_call: &syn::ExprCall) -> syn::Block
104{
105    parse_quote! {
106        {
107            use std::cell::RefCell;
108            use std::thread_local;
109            thread_local!(
110                static cache: RefCell<#full_cache_type> =
111                    RefCell::new(#cache_new);
112            );
113            cache.with(|c| {
114                let mut cache_ref = c.borrow_mut();
115                let cloned_args = #cloned_args;
116
117                let stored_result = cache_ref.get_mut(&cloned_args);
118                if let Some(stored_result) = stored_result {
119                    return stored_result.clone()
120                }
121
122                // Don't hold a mutable borrow across
123                // the recursive function call
124                drop(cache_ref);
125
126                let ret = #inner_fn_call;
127                c.borrow_mut().insert(cloned_args, ret.clone());
128                ret
129            })
130        }
131    }
132}
133
134// Build the body of the caching function which guards the static cache with a mutex.
135fn build_mutex_cache_body(full_cache_type: &syn::Type, cache_new: &syn::Expr,
136                     cloned_args: &syn::ExprTuple, inner_fn_call: &syn::ExprCall) -> syn::Block
137{
138    parse_quote! {
139        {
140            use lazy_static::lazy_static;
141            use std::sync::Mutex;
142
143            lazy_static! {
144                static ref cache: Mutex<#full_cache_type> =
145                    Mutex::new(#cache_new);
146            }
147
148            let cloned_args = #cloned_args;
149
150            let mut cache_unlocked = cache.lock().unwrap();
151            let stored_result = cache_unlocked.get_mut(&cloned_args);
152            if let Some(stored_result) = stored_result {
153                return stored_result.clone();
154            };
155
156            // must unlock here to allow potentially recursive call
157            drop(cache_unlocked);
158
159            let ret = #inner_fn_call;
160            let mut cache_unlocked = cache.lock().unwrap();
161            cache_unlocked.insert(cloned_args, ret.clone());
162            ret
163        }
164    }
165}
166
167fn get_cache_fn_return_type(original_fn: &syn::ItemFn) -> syn::Result<Box<syn::Type>> {
168    if let syn::ReturnType::Type(_, ref ty) = original_fn.sig.output {
169        Ok(ty.clone())
170    } else {
171        return Err(syn::Error::new_spanned(original_fn, "There's no point of caching the output of a function that has no output"))
172    }
173}
174
175fn path_from_ident(ident: syn::Ident) -> syn::Expr {
176    let mut segments: Punctuated<_, Token![::]> = Punctuated::new();
177    segments.push(syn::PathSegment { ident: ident, arguments: syn::PathArguments::None });
178    syn::Expr::Path(syn::ExprPath { attrs: Vec::new(), qself: None, path: syn::Path { leading_colon: None, segments: segments} })
179}
180
181fn make_cloned_args_tuple(args: &Punctuated<syn::Expr, Token![,]>) -> syn::ExprTuple {
182    let mut cloned_args = Punctuated::<_, Token![,]>::new();
183    for arg in args {
184        let call = syn::ExprMethodCall {
185            attrs: Vec::new(),
186            receiver: Box::new(arg.clone()),
187            dot_token: syn::token::Dot { spans: [arg.span(); 1] },
188            method: syn::Ident::new("clone", proc_macro2::Span::call_site()),
189            turbofish: None,
190            paren_token: syn::token::Paren::default(),
191            args: Punctuated::new(),
192        };
193        cloned_args.push(syn::Expr::MethodCall(call));
194    }
195    syn::ExprTuple {
196        attrs: Vec::new(),
197        paren_token: syn::token::Paren::default(),
198        elems: cloned_args,
199    }
200}
201
202fn get_args_and_types(f: &syn::ItemFn, config: &config::Config) ->
203        syn::Result<(Punctuated<syn::Expr, Token![,]>, Punctuated<syn::Type, Token![,]>, Punctuated<syn::Expr, Token![,]>)>
204{
205    let mut call_args = Punctuated::<_, Token![,]>::new();
206    let mut types = Punctuated::<_, Token![,]>::new();
207    let mut cache_args = Punctuated::<_, Token![,]>::new();
208
209    for input in &f.sig.inputs {
210        match input {
211            syn::FnArg::Receiver(_) => {
212                return Err(syn::Error::new(input.span(), "`self` arguments are currently unsupported by algorithm_cache"));
213
214            }
215            syn::FnArg::Typed(p) => {
216                let mut segments: syn::punctuated::Punctuated<_, Token![::]> = syn::punctuated::Punctuated::new();
217                let arg_name;
218                if let syn::Pat::Ident(ref pat_ident) = *p.pat {
219                    arg_name = pat_ident.ident.clone();
220                    segments.push(syn::PathSegment { ident: pat_ident.ident.clone(), arguments: syn::PathArguments::None });
221                } else {
222                    return Err(syn::Error::new(input.span(), "unsupported argument kind"));
223                }
224
225                let arg_path = syn::Expr::Path(syn::ExprPath { attrs: Vec::new(), qself: None, path: syn::Path { leading_colon: None, segments } });
226                if !config.ignore_args.contains(&arg_name) {
227                    // If the arg type is a reference, remove the reference because the arg will be cloned
228                    if let syn::Type::Reference(type_reference) = &*p.ty {
229                        if let Some(_) = type_reference.mutability {
230                            call_args.push(arg_path);
231                            continue;
232                            // return Err(io::Error::new(io::ErrorKind::Other, "`mut` reference arguments are not supported as this could lead to incorrect results being stored"));
233                        }
234                        types.push(type_reference.elem.as_ref().to_owned()); // as_ref -> to_owned unboxes the type
235                    } else {
236                        types.push((*p.ty).clone());
237                    }
238
239                    cache_args.push(arg_path.clone());
240                }
241                call_args.push(arg_path);
242            }
243        }
244    }
245
246    if types.len() == 1 {
247        types.push_punct(syn::token::Comma { spans: [proc_macro2::Span::call_site(); 1] })
248    }
249
250    Ok((call_args, types, cache_args))
251}