1#![crate_type = "proc-macro"]
2#![allow(unused_imports)] use syn::{self, parse, parse_macro_input, spanned::Spanned, Expr, ExprCall, ItemFn, Path};
4
5use proc_macro::TokenStream;
6use quote::{self, ToTokens};
7
8mod kw {
9 syn::custom_keyword!(Capacity);
10 syn::custom_keyword!(TimeToLive);
11 syn::custom_keyword!(SharedCache);
12 syn::custom_keyword!(CustomHasher);
13 syn::custom_keyword!(HasherInit);
14 syn::custom_keyword!(Ignore);
15 syn::custom_punctuation!(Colon, :);
16}
17
18#[derive(Default, Clone)]
19struct CacheOptions {
20 lru_max_entries: Option<usize>,
21 time_to_live: Option<Expr>,
22 shared_cache: bool,
23 custom_hasher: Option<Path>,
24 custom_hasher_initializer: Option<ExprCall>,
25 ignore: Vec<syn::Ident>,
26}
27
28#[derive(Clone)]
29enum CacheOption {
30 LRUMaxEntries(usize),
31 TimeToLive(Expr),
32 SharedCache,
33 CustomHasher(Path),
34 HasherInit(ExprCall),
35 Ignore(syn::Ident),
36}
37
38#[allow(unreachable_code)]
40impl parse::Parse for CacheOption {
41 fn parse(input: parse::ParseStream) -> syn::Result<Self> {
42 let la = input.lookahead1();
43 if la.peek(kw::Capacity) {
44 #[cfg(not(feature = "full"))]
45 return Err(syn::Error::new(input.span(),
46 "memoize error: Capacity specified, but the feature 'full' is not enabled! To fix this, compile with `--features=full`.",
47 ));
48
49 input.parse::<kw::Capacity>().unwrap();
50 input.parse::<kw::Colon>().unwrap();
51 let cap: syn::LitInt = input.parse().unwrap();
52
53 return Ok(CacheOption::LRUMaxEntries(cap.base10_parse()?));
54 }
55 if la.peek(kw::TimeToLive) {
56 #[cfg(not(feature = "full"))]
57 return Err(syn::Error::new(input.span(),
58 "memoize error: TimeToLive specified, but the feature 'full' is not enabled! To fix this, compile with `--features=full`.",
59 ));
60
61 input.parse::<kw::TimeToLive>().unwrap();
62 input.parse::<kw::Colon>().unwrap();
63 let cap: syn::Expr = input.parse().unwrap();
64
65 return Ok(CacheOption::TimeToLive(cap));
66 }
67 if la.peek(kw::SharedCache) {
68 input.parse::<kw::SharedCache>().unwrap();
69 return Ok(CacheOption::SharedCache);
70 }
71 if la.peek(kw::CustomHasher) {
72 input.parse::<kw::CustomHasher>().unwrap();
73 input.parse::<kw::Colon>().unwrap();
74 let cap: syn::Path = input.parse().unwrap();
75 return Ok(CacheOption::CustomHasher(cap));
76 }
77 if la.peek(kw::HasherInit) {
78 input.parse::<kw::HasherInit>().unwrap();
79 input.parse::<kw::Colon>().unwrap();
80 let cap: syn::ExprCall = input.parse().unwrap();
81 return Ok(CacheOption::HasherInit(cap));
82 }
83 if la.peek(kw::Ignore) {
84 input.parse::<kw::Ignore>().unwrap();
85 input.parse::<kw::Colon>().unwrap();
86 let ignore_ident = input.parse::<syn::Ident>().unwrap();
87 return Ok(CacheOption::Ignore(ignore_ident));
88 }
89 Err(la.error())
90 }
91}
92
93impl parse::Parse for CacheOptions {
94 fn parse(input: parse::ParseStream) -> syn::Result<Self> {
95 let f: syn::punctuated::Punctuated<CacheOption, syn::Token![,]> =
96 input.parse_terminated(CacheOption::parse)?;
97 let mut opts = Self::default();
98
99 for opt in f {
100 match opt {
101 CacheOption::LRUMaxEntries(cap) => opts.lru_max_entries = Some(cap),
102 CacheOption::TimeToLive(sec) => opts.time_to_live = Some(sec),
103 CacheOption::CustomHasher(hasher) => opts.custom_hasher = Some(hasher),
104 CacheOption::HasherInit(init) => opts.custom_hasher_initializer = Some(init),
105 CacheOption::SharedCache => opts.shared_cache = true,
106 CacheOption::Ignore(ident) => opts.ignore.push(ident),
107 }
108 }
109 Ok(opts)
110 }
111}
112
113#[cfg(not(feature = "full"))]
115mod store {
116 use crate::CacheOptions;
117 use proc_macro::TokenStream;
118
119 pub(crate) fn construct_cache(
121 _options: &CacheOptions,
122 key_type: proc_macro2::TokenStream,
123 value_type: proc_macro2::TokenStream,
124 ) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
125 if let Some(hasher) = &_options.custom_hasher {
127 return (
128 quote::quote! { #hasher<#key_type, #value_type> },
129 quote::quote! { #hasher::new() },
130 );
131 } else {
132 (
133 quote::quote! { std::collections::HashMap<#key_type, #value_type> },
134 quote::quote! { std::collections::HashMap::new() },
135 )
136 }
137 }
138
139 pub(crate) fn cache_access_methods(
142 _options: &CacheOptions,
143 ) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
144 (quote::quote! { insert }, quote::quote! { get })
145 }
146}
147
/// Cache-store code generation for the `full` feature: adds an LRU store
/// (`Capacity`) and insertion timestamps (`TimeToLive`) on top of plain maps.
#[cfg(feature = "full")]
mod store {
    use crate::CacheOptions;

    /// Returns the tokens of the cache's type and of an expression that
    /// initialises it.
    ///
    /// With `TimeToLive`, each stored value becomes `(Instant, value)`.
    /// Combining `Capacity` with `CustomHasher` yields a `compile_error!` in
    /// place of the store type.
    pub(crate) fn construct_cache(
        options: &CacheOptions,
        key_type: proc_macro2::TokenStream,
        value_type: proc_macro2::TokenStream,
    ) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
        // Expiring entries remember when they were inserted.
        let value_type = if options.time_to_live.is_some() {
            quote::quote! {(std::time::Instant, #value_type)}
        } else {
            quote::quote! {#value_type}
        };

        match (options.lru_max_entries, &options.custom_hasher) {
            (Some(_), Some(_)) => (
                quote::quote! { compile_error!("Cannot use LRU cache and a custom hasher at the same time") },
                quote::quote! { std::collections::HashMap::new() },
            ),
            (Some(cap), None) => (
                quote::quote! { ::memoize::lru::LruCache<#key_type, #value_type> },
                quote::quote! { ::memoize::lru::LruCache::new(#cap) },
            ),
            (None, Some(hasher)) => {
                let init = match &options.custom_hasher_initializer {
                    Some(hasher_init) => quote::quote! { #hasher_init },
                    None => quote::quote! { #hasher::new() },
                };
                (quote::quote! { #hasher<#key_type, #value_type> }, init)
            }
            (None, None) => (
                quote::quote! { std::collections::HashMap<#key_type, #value_type> },
                quote::quote! { std::collections::HashMap::new() },
            ),
        }
    }

    /// Returns the method names used to insert into and read from the store.
    pub(crate) fn cache_access_methods(
        options: &CacheOptions,
    ) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
        // The LRU store inserts with `put`; map stores use `insert`.
        let insert = if options.lru_max_entries.is_some() {
            quote::quote! { put }
        } else {
            quote::quote! { insert }
        };
        (insert, quote::quote! { get })
    }
}
217
218#[proc_macro_attribute]
272pub fn memoize(attr: TokenStream, item: TokenStream) -> TokenStream {
273 let func = parse_macro_input!(item as ItemFn);
274 let sig = &func.sig;
275
276 let fn_name = &sig.ident.to_string();
277 let renamed_name = format!("memoized_original_{}", fn_name);
278 let flush_name = syn::Ident::new(format!("memoized_flush_{}", fn_name).as_str(), sig.span());
279 let size_name = syn::Ident::new(format!("memoized_size_{}", fn_name).as_str(), sig.span());
280 let map_name = format!("memoized_mapping_{}", fn_name);
281
282 if let Some(syn::FnArg::Receiver(_)) = sig.inputs.first() {
283 return quote::quote! { compile_error!("Cannot memoize methods!"); }.into();
284 }
285
286 let options: CacheOptions = syn::parse(attr.clone()).unwrap();
288
289 let input_params = match check_signature(sig, &options) {
291 Ok(p) => p,
292 Err(e) => return e.to_compile_error().into(),
293 };
294
295 let memoized_input_types: Vec<Box<syn::Type>> = input_params
297 .iter()
298 .filter_map(|p| {
299 if p.is_memoized {
300 Some(p.arg_type.clone())
301 } else {
302 None
303 }
304 })
305 .collect();
306 let memoized_input_names: Vec<syn::Ident> = input_params
307 .iter()
308 .filter_map(|p| {
309 if p.is_memoized {
310 Some(p.arg_name.clone())
311 } else {
312 None
313 }
314 })
315 .collect();
316
317 let fn_forwarded_exprs: Vec<_> = input_params
320 .iter()
321 .map(|p| {
322 let ident = p.arg_name.clone();
323 if p.is_memoized {
324 quote::quote! { #ident.clone() }
325 } else {
326 quote::quote! { #ident }
327 }
328 })
329 .collect();
330
331 let input_tuple_type = quote::quote! { (#(#memoized_input_types),*) };
332 let return_type = match &sig.output {
333 syn::ReturnType::Default => quote::quote! { () },
334 syn::ReturnType::Type(_, ty) => ty.to_token_stream(),
335 };
336
337 let store_ident = syn::Ident::new(&map_name.to_uppercase(), sig.span());
339 let (cache_type, cache_init) =
340 store::construct_cache(&options, input_tuple_type, return_type.clone());
341 let store = if options.shared_cache {
342 quote::quote! {
343 ::memoize::lazy_static::lazy_static! {
344 static ref #store_ident : std::sync::Mutex<#cache_type> =
345 std::sync::Mutex::new(#cache_init);
346 }
347 }
348 } else {
349 quote::quote! {
350 std::thread_local! {
351 static #store_ident : std::cell::RefCell<#cache_type> =
352 std::cell::RefCell::new(#cache_init);
353 }
354 }
355 };
356
357 let mut renamed_fn = func.clone();
359 renamed_fn.sig.ident = syn::Ident::new(&renamed_name, func.sig.span());
360 let memoized_id = &renamed_fn.sig.ident;
361
362 let syntax_names_tuple = quote::quote! { (#(#memoized_input_names),*) };
364 let syntax_names_tuple_cloned = quote::quote! { (#(#memoized_input_names.clone()),*) };
365 let forwarding_tuple = quote::quote! { (#(#fn_forwarded_exprs),*) };
366 let (insert_fn, get_fn) = store::cache_access_methods(&options);
367 let (read_memo, memoize) = match options.time_to_live {
368 None => (
369 quote::quote!(ATTR_MEMOIZE_HM__.#get_fn(&#syntax_names_tuple_cloned).cloned()),
370 quote::quote!(ATTR_MEMOIZE_HM__.#insert_fn(#syntax_names_tuple, ATTR_MEMOIZE_RETURN__.clone());),
371 ),
372 Some(ttl) => (
373 quote::quote! {
374 ATTR_MEMOIZE_HM__.#get_fn(&#syntax_names_tuple_cloned).and_then(|(last_updated, ATTR_MEMOIZE_RETURN__)|
375 (last_updated.elapsed() < #ttl).then(|| ATTR_MEMOIZE_RETURN__.clone())
376 )
377 },
378 quote::quote!(ATTR_MEMOIZE_HM__.#insert_fn(#syntax_names_tuple, (std::time::Instant::now(), ATTR_MEMOIZE_RETURN__.clone()));),
379 ),
380 };
381
382 let memoizer = if options.shared_cache {
383 quote::quote! {
384 {
385 let mut ATTR_MEMOIZE_HM__ = #store_ident.lock().unwrap();
386 if let Some(ATTR_MEMOIZE_RETURN__) = #read_memo {
387 return ATTR_MEMOIZE_RETURN__
388 }
389 }
390 let ATTR_MEMOIZE_RETURN__ = #memoized_id #forwarding_tuple;
391
392 let mut ATTR_MEMOIZE_HM__ = #store_ident.lock().unwrap();
393 #memoize
394
395 ATTR_MEMOIZE_RETURN__
396 }
397 } else {
398 quote::quote! {
399 let ATTR_MEMOIZE_RETURN__ = #store_ident.with(|ATTR_MEMOIZE_HM__| {
400 let mut ATTR_MEMOIZE_HM__ = ATTR_MEMOIZE_HM__.borrow_mut();
401 #read_memo
402 });
403 if let Some(ATTR_MEMOIZE_RETURN__) = ATTR_MEMOIZE_RETURN__ {
404 return ATTR_MEMOIZE_RETURN__;
405 }
406
407 let ATTR_MEMOIZE_RETURN__ = #memoized_id #forwarding_tuple;
408
409 #store_ident.with(|ATTR_MEMOIZE_HM__| {
410 let mut ATTR_MEMOIZE_HM__ = ATTR_MEMOIZE_HM__.borrow_mut();
411 #memoize
412 });
413
414 ATTR_MEMOIZE_RETURN__
415 }
416 };
417
418 let vis = &func.vis;
419
420 let flusher = if options.shared_cache {
421 quote::quote! {
422 #vis fn #flush_name() {
423 #store_ident.lock().unwrap().clear();
424 }
425 }
426 } else {
427 quote::quote! {
428 #vis fn #flush_name() {
429 #store_ident.with(|ATTR_MEMOIZE_HM__| ATTR_MEMOIZE_HM__.borrow_mut().clear());
430 }
431 }
432 };
433
434 let size_func = if options.shared_cache {
435 quote::quote! {
436 #vis fn #size_name() -> usize {
437 #store_ident.lock().unwrap().len()
438 }
439 }
440 } else {
441 quote::quote! {
442 #vis fn #size_name() -> usize {
443 #store_ident.with(|ATTR_MEMOIZE_HM__| ATTR_MEMOIZE_HM__.borrow().len())
444 }
445 }
446 };
447
448 quote::quote! {
449 #renamed_fn
450 #flusher
451 #size_func
452 #store
453
454 #[allow(unused_variables, unused_mut)]
455 #vis #sig {
456 #memoizer
457 }
458 }
459 .into()
460}
461
462struct FnArgument {
464 arg_type: Box<syn::Type>,
466
467 arg_name: syn::Ident,
469
470 is_memoized: bool,
472}
473
474fn check_signature(
475 sig: &syn::Signature,
476 options: &CacheOptions,
477) -> Result<Vec<FnArgument>, syn::Error> {
478 if sig.inputs.is_empty() {
479 return Ok(vec![]);
480 }
481
482 let mut params = vec![];
483
484 for a in &sig.inputs {
485 if let syn::FnArg::Typed(ref arg) = a {
486 let arg_type = arg.ty.clone();
487
488 if let syn::Pat::Ident(patident) = &*arg.pat {
489 let arg_name = patident.ident.clone();
490 let is_memoized = !options.ignore.contains(&arg_name);
491 params.push(FnArgument {
492 arg_type,
493 arg_name,
494 is_memoized,
495 });
496 } else {
497 return Err(syn::Error::new(
498 sig.span(),
499 "Cannot memoize arbitrary patterns!",
500 ));
501 }
502 }
503 }
504 Ok(params)
505}
506
// Empty placeholder. A proc-macro crate cannot invoke its own attribute macros,
// so `#[memoize]` has to be exercised from a separate crate.
// NOTE(review): confirm where those integration tests live.
#[cfg(test)]
mod tests {}