1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{parse_macro_input, parse_quote, Expr, FnArg, ItemFn, ReturnType, Type};
#[proc_macro_attribute]
pub fn memo(_: TokenStream, input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as ItemFn);
let mut key_types = Vec::<Type>::new();
let mut keys = Vec::<Expr>::new();
input.sig.inputs.iter().for_each(|x| match x {
FnArg::Typed(pat) => {
key_types.push((*pat.ty).clone());
let p = &pat.pat;
keys.push(parse_quote!(#p));
}
_ => unimplemented!("Unimplemented function signature: {:?}", x),
});
let key_type: Type = parse_quote! {(#(#key_types),*)};
let ret_type: Type = match &input.sig.output {
ReturnType::Type(_, ty) => (**ty).clone(),
_ => panic!("required: return type"),
};
let memo_name = format_ident!("{}_MEMO", input.sig.ident.to_string().to_uppercase());
let fn_sig = input.sig;
let fn_block = input.block;
let expanded = quote! {
thread_local!(
static #memo_name: std::cell::RefCell<std::collections::HashMap<#key_type, #ret_type>> =
std::cell::RefCell::new(std::collections::HashMap::new())
);
#fn_sig {
if let Some(ret) = #memo_name.with(|memo| memo.borrow().get(&(#(#keys),*)).cloned()) {
ret
} else {
let ret: #ret_type = (||#fn_block)();
#memo_name.with(|memo| {
memo.borrow_mut().insert((#(#keys),*), ret.clone());
});
ret
}
}
};
expanded.into()
}