1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
//! An attribute macro to memoize function calls
//!
//! # Usage
//!
//! Just add `#[memo]` to your function.
//!
//! ```
//! use memor::memo;
//! #[memo]
//! fn fib(n: i64) -> i64 {
//!     if n == 0 || n == 1 {
//!         n
//!     } else {
//!         fib(n - 1) + fib(n - 2)
//!     }
//! }
//!
//! assert_eq!(12586269025, fib(50));
//! ```
//!
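//! Various functions can be memoized.
//! Because the arguments are stored as keys of a `std::collections::HashMap` internally,
//! this macro can be applied to any function whose arguments all implement `Hash` and `Eq`.
//! The function must also declare a return type that implements `Clone`, since cached
//! results are cloned out of the map.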
//!
//! ```
//! use memor::memo;
//! #[derive(Hash, Eq, PartialEq)]
//! struct Foo {
//!     a: usize,
//!     b: usize,
//! }
//!
//! #[memo]
//! fn foo(Foo { a, b }: Foo, c: usize) -> usize {
//!     if a == 0 || b == 0 || c == 0 {
//!         1
//!     } else {
//!         foo(Foo { a, b: b - 1 }, c)
//!             .wrapping_add(foo(Foo { a: a - 1, b }, c))
//!             .wrapping_add(foo(Foo { a, b }, c - 1))
//!     }
//! }
//!
//! assert_eq!(foo(Foo { a: 50, b: 50 }, 50), 6753084261833197057);
//! ```
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{parse_macro_input, parse_quote, Expr, FnArg, ItemFn, ReturnType, Type};

#[proc_macro_attribute]
pub fn memo(_: TokenStream, input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as ItemFn);

    let mut key_types = Vec::<Type>::new();
    let mut keys = Vec::<Expr>::new();
    input.sig.inputs.iter().for_each(|x| match x {
        FnArg::Typed(pat) => {
            key_types.push((*pat.ty).clone());
            let p = &pat.pat;
            keys.push(parse_quote!(#p));
        }
        _ => unimplemented!("Unimplemented function signature: {:?}", x),
    });

    let key_type: Type = parse_quote! {(#(#key_types),*)};
    let ret_type: Type = match &input.sig.output {
        ReturnType::Type(_, ty) => (**ty).clone(),
        _ => panic!("required: return type"),
    };
    let memo_name = format_ident!("{}_MEMO", input.sig.ident.to_string().to_uppercase());
    let fn_sig = input.sig;
    let fn_block = input.block;

    let expanded = quote! {
            thread_local!(
                static #memo_name: std::cell::RefCell<std::collections::HashMap<#key_type, #ret_type>> =
                    std::cell::RefCell::new(std::collections::HashMap::new())
            );

            #fn_sig {
                if let Some(ret) = #memo_name.with(|memo| memo.borrow().get(&(#(#keys),*)).cloned()) {
                    ret
                } else {
                    let ret: #ret_type = (||#fn_block)();
                    #memo_name.with(|memo| {
                        memo.borrow_mut().insert((#(#keys),*), ret.clone());
                    });
                    ret
                }
            }
    };

    expanded.into()
}