1
2use proc_macro::TokenStream;
3use syn;
4use syn::{Token, parse_quote};
5use syn::spanned::Spanned;
6use syn::punctuated::Punctuated;
7use quote::quote;
8use proc_macro2;
9
10mod config;
11
12use syn::parse::Parse;
13use syn::parse::ParseStream;
14use syn::parse_macro_input;
15
16struct Attr {
17 cache_type: syn::Type,
18 cache_creation_expr: syn::Expr,
19}
20
21impl Parse for Attr {
22 fn parse(input: ParseStream) -> syn::parse::Result<Self> {
23 let cache_type: syn::Type = input.parse()?;
24 input.parse::<Token![:]>()?;
25 let cache_creation_expr: syn::Expr = input.parse()?;
26 Ok(Attr {
27 cache_type,
28 cache_creation_expr,
29 })
30 }
31}
32
33#[proc_macro_attribute]
34pub fn cache(attr: TokenStream, item: TokenStream) -> TokenStream {
35 let attr = parse_macro_input!(attr as Attr);
36
37 match algorithm_cache_impl(attr, item.clone()) {
38 Ok(tokens) => return tokens,
39 Err(e) => {
40 panic!("error = {:?}", e);
41 }
42 }
43}
44
45fn algorithm_cache_impl(attr: Attr, item: TokenStream) -> syn::parse::Result<TokenStream> {
47 let mut original_fn: syn::ItemFn = syn::parse(item.clone())?;
48 let (macro_config, out_attributes) =
49 {
50 let attribs = &original_fn.attrs[..];
51 config::Config::parse_from_attributes(attribs)?
52 };
53 original_fn.attrs = out_attributes;
54
55 let mut new_fn = original_fn.clone();
56 let return_type = get_cache_fn_return_type(&original_fn)?;
57 let new_name = format!("__cache_auto_{}", original_fn.sig.ident.to_string());
58 original_fn.sig.ident = syn::Ident::new(&new_name[..], original_fn.sig.ident.span());
59 let (call_args, types, cache_args) = get_args_and_types(&original_fn, ¯o_config)?;
60 let cloned_args = make_cloned_args_tuple(&cache_args);
61 let fn_path = path_from_ident(original_fn.sig.ident.clone());
62 let fn_call = syn::ExprCall {
63 attrs: Vec::new(),
64 paren_token: syn::token::Paren::default(),
65 args: call_args,
66 func: Box::new(fn_path)
67 };
68
69 let tuple_type = syn::TypeTuple {
70 paren_token: syn::token::Paren::default(),
71 elems: types,
72 };
73
74 let cache_type = &attr.cache_type;
75 let cache_type_with_generics: syn::Type = parse_quote! {
76 #cache_type<#tuple_type, #return_type, algorithm::DefaultHasher>
77 };
78 let lru_body = build_cache_body(&cache_type_with_generics, &attr.cache_creation_expr, &cloned_args,
79 &fn_call, ¯o_config);
80
81 new_fn.block = Box::new(lru_body);
82 let out = quote! {
83 #original_fn
84 #new_fn
85 };
86 Ok(out.into())
87}
88
89fn build_cache_body(full_cache_type: &syn::Type, cache_new: &syn::Expr,
91 cloned_args: &syn::ExprTuple, inner_fn_call: &syn::ExprCall,
92 config: &config::Config) -> syn::Block
93{
94 if config.use_thread {
95 build_mutex_cache_body(full_cache_type, cache_new, cloned_args, inner_fn_call)
96 } else {
97 build_tls_cache_body(full_cache_type, cache_new, cloned_args, inner_fn_call)
98 }
99}
100
101fn build_tls_cache_body(full_cache_type: &syn::Type, cache_new: &syn::Expr,
103 cloned_args: &syn::ExprTuple, inner_fn_call: &syn::ExprCall) -> syn::Block
104{
105 parse_quote! {
106 {
107 use std::cell::RefCell;
108 use std::thread_local;
109 thread_local!(
110 static cache: RefCell<#full_cache_type> =
111 RefCell::new(#cache_new);
112 );
113 cache.with(|c| {
114 let mut cache_ref = c.borrow_mut();
115 let cloned_args = #cloned_args;
116
117 let stored_result = cache_ref.get_mut(&cloned_args);
118 if let Some(stored_result) = stored_result {
119 return stored_result.clone()
120 }
121
122 drop(cache_ref);
125
126 let ret = #inner_fn_call;
127 c.borrow_mut().insert(cloned_args, ret.clone());
128 ret
129 })
130 }
131 }
132}
133
134fn build_mutex_cache_body(full_cache_type: &syn::Type, cache_new: &syn::Expr,
136 cloned_args: &syn::ExprTuple, inner_fn_call: &syn::ExprCall) -> syn::Block
137{
138 parse_quote! {
139 {
140 use lazy_static::lazy_static;
141 use std::sync::Mutex;
142
143 lazy_static! {
144 static ref cache: Mutex<#full_cache_type> =
145 Mutex::new(#cache_new);
146 }
147
148 let cloned_args = #cloned_args;
149
150 let mut cache_unlocked = cache.lock().unwrap();
151 let stored_result = cache_unlocked.get_mut(&cloned_args);
152 if let Some(stored_result) = stored_result {
153 return stored_result.clone();
154 };
155
156 drop(cache_unlocked);
158
159 let ret = #inner_fn_call;
160 let mut cache_unlocked = cache.lock().unwrap();
161 cache_unlocked.insert(cloned_args, ret.clone());
162 ret
163 }
164 }
165}
166
167fn get_cache_fn_return_type(original_fn: &syn::ItemFn) -> syn::Result<Box<syn::Type>> {
168 if let syn::ReturnType::Type(_, ref ty) = original_fn.sig.output {
169 Ok(ty.clone())
170 } else {
171 return Err(syn::Error::new_spanned(original_fn, "There's no point of caching the output of a function that has no output"))
172 }
173}
174
175fn path_from_ident(ident: syn::Ident) -> syn::Expr {
176 let mut segments: Punctuated<_, Token![::]> = Punctuated::new();
177 segments.push(syn::PathSegment { ident: ident, arguments: syn::PathArguments::None });
178 syn::Expr::Path(syn::ExprPath { attrs: Vec::new(), qself: None, path: syn::Path { leading_colon: None, segments: segments} })
179}
180
181fn make_cloned_args_tuple(args: &Punctuated<syn::Expr, Token![,]>) -> syn::ExprTuple {
182 let mut cloned_args = Punctuated::<_, Token![,]>::new();
183 for arg in args {
184 let call = syn::ExprMethodCall {
185 attrs: Vec::new(),
186 receiver: Box::new(arg.clone()),
187 dot_token: syn::token::Dot { spans: [arg.span(); 1] },
188 method: syn::Ident::new("clone", proc_macro2::Span::call_site()),
189 turbofish: None,
190 paren_token: syn::token::Paren::default(),
191 args: Punctuated::new(),
192 };
193 cloned_args.push(syn::Expr::MethodCall(call));
194 }
195 syn::ExprTuple {
196 attrs: Vec::new(),
197 paren_token: syn::token::Paren::default(),
198 elems: cloned_args,
199 }
200}
201
202fn get_args_and_types(f: &syn::ItemFn, config: &config::Config) ->
203 syn::Result<(Punctuated<syn::Expr, Token![,]>, Punctuated<syn::Type, Token![,]>, Punctuated<syn::Expr, Token![,]>)>
204{
205 let mut call_args = Punctuated::<_, Token![,]>::new();
206 let mut types = Punctuated::<_, Token![,]>::new();
207 let mut cache_args = Punctuated::<_, Token![,]>::new();
208
209 for input in &f.sig.inputs {
210 match input {
211 syn::FnArg::Receiver(_) => {
212 return Err(syn::Error::new(input.span(), "`self` arguments are currently unsupported by algorithm_cache"));
213
214 }
215 syn::FnArg::Typed(p) => {
216 let mut segments: syn::punctuated::Punctuated<_, Token![::]> = syn::punctuated::Punctuated::new();
217 let arg_name;
218 if let syn::Pat::Ident(ref pat_ident) = *p.pat {
219 arg_name = pat_ident.ident.clone();
220 segments.push(syn::PathSegment { ident: pat_ident.ident.clone(), arguments: syn::PathArguments::None });
221 } else {
222 return Err(syn::Error::new(input.span(), "unsupported argument kind"));
223 }
224
225 let arg_path = syn::Expr::Path(syn::ExprPath { attrs: Vec::new(), qself: None, path: syn::Path { leading_colon: None, segments } });
226 if !config.ignore_args.contains(&arg_name) {
227 if let syn::Type::Reference(type_reference) = &*p.ty {
229 if let Some(_) = type_reference.mutability {
230 call_args.push(arg_path);
231 continue;
232 }
234 types.push(type_reference.elem.as_ref().to_owned()); } else {
236 types.push((*p.ty).clone());
237 }
238
239 cache_args.push(arg_path.clone());
240 }
241 call_args.push(arg_path);
242 }
243 }
244 }
245
246 if types.len() == 1 {
247 types.push_punct(syn::token::Comma { spans: [proc_macro2::Span::call_site(); 1] })
248 }
249
250 Ok((call_args, types, cache_args))
251}