1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
//! An attribute macro to memoise a function.
//!
//! # Usage
//!
//! Add the `memoise` attribute to an ordinary function to memoise its
//! results by the given key arguments.
//! For example:
//!
//! ```
//! use memoise::memoise;
//!
//! #[memoise(keys(n = 100))]
//! fn fib(n: usize) -> usize {
//!     if n == 0 || n == 1 {
//!         return n;
//!     }
//!     fib(n - 1) + fib(n - 2)
//! }
//! ```
//!
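//! You must statically specify an inclusive upper bound for each key
//! argument. Key arguments must be `usize`, and the return type must be
//! `Copy`. Calling a memoised function with a key argument above its bound
//! causes a panic.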
//!
//! You can specify multiple keys:
//!
//! ```
//! use memoise::memoise;
//!
//! #[memoise(keys(n = 100, m = 50))]
//! fn comb(n: usize, m: usize) -> usize {
//!     if m == 0 {
//!         return 1;
//!     }
//!     if n == 0 {
//!         return 0;
//!     }
//!     comb(n - 1, m - 1) + comb(n - 1, m)
//! }
//! ```
//!
//! If a memoised function's result also depends on something other than its
//! key arguments (for example, non-key arguments), cached values can become
//! stale. Reset the memo table by calling the automatically defined function
//! `<function-name>_reset`. For the code above, `comb_reset` is defined, so
//! you can call it to clear the table.
//!
//! ```ignore
//! let a = comb(10, 5); // calculation
//! comb_reset();        // reset the memoization table
//! let a = comb(10, 5); // calculation executed again
//! ```

extern crate proc_macro;

use darling::FromMeta;
use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::quote;
use std::collections::HashMap;
use syn::{parse_macro_input, parse_quote, AttributeArgs, Expr, Ident, ItemFn, ReturnType, Type};

/// Arguments accepted by `#[memoise(...)]`.
///
/// `darling` parses these from `keys(name = bound, ...)`. Each entry maps the
/// name of a key parameter to the largest value that parameter may take.
#[derive(FromMeta, Debug)]
struct MemoiseArgs {
    /// Key parameter name -> inclusive upper bound used to size the memo table.
    keys: HashMap<String, usize>,
}

#[proc_macro_attribute]
pub fn memoise(attr: TokenStream, item: TokenStream) -> TokenStream {
    let attr = parse_macro_input!(attr as AttributeArgs);
    let item_fn = parse_macro_input!(item as ItemFn);

    let args = MemoiseArgs::from_list(&attr).unwrap();
    let fn_sig = item_fn.sig;
    let fn_block = item_fn.block;

    let cache_ident = Ident::new(&fn_sig.ident.to_string().to_uppercase(), Span::call_site());
    let ret_type = if let ReturnType::Type(_, ty) = &fn_sig.output {
        ty
    } else {
        panic!("function return type is required");
    };

    let lengths = args.keys.values().collect::<Vec<_>>();

    let cache_type = lengths.iter().rev().fold(
        parse_quote! { Option<#ret_type> },
        |acc: Type, limit| parse_quote! { [#acc; #limit + 1] },
    );

    let cache_init = lengths
        .iter()
        .rev()
        .fold(parse_quote! { None }, |acc: Expr, limit| {
            parse_quote! {
                [#acc; #limit + 1]
            }
        });

    let key_vars = args
        .keys
        .keys()
        .map(|k| Ident::new(k, Span::call_site()))
        .collect::<Vec<_>>();

    let reset_expr = (0..args.keys.len()).fold(quote! { *r = None }, |acc, _| {
        quote! {
            for r in r.iter_mut() {
                #acc
            }
        }
    });

    let reset_expr: Expr = parse_quote! {
        {
            let mut r = cache.borrow_mut();
            #reset_expr;
        }
    };

    let reset_fn = Ident::new(
        &format!("{}_reset", fn_sig.ident.to_string()),
        Span::call_site(),
    );

    let gen = quote! {
        thread_local!(
            static #cache_ident: std::cell::RefCell<#cache_type> =
                std::cell::RefCell::new(#cache_init));

        fn #reset_fn() {
            #cache_ident.with(|cache| #reset_expr);
        }

        #fn_sig {
            if let Some(ret) = #cache_ident.with(|cache| cache.borrow()#([#key_vars])*) {
                return ret;
            }

            let ret: #ret_type = (|| #fn_block )();

            #cache_ident.with(|cache| {
                let mut bm = cache.borrow_mut();
                bm #([#key_vars])* = Some(ret);
            });

            ret
        }
    };

    gen.into()
}