1use proc_macro::TokenStream;
2use quote::{quote, format_ident};
3use syn::{AttributeArgs, parse_macro_input, parse_quote, ItemFn, ReturnType};
4use darling::FromMeta;
5
6#[derive(Debug, FromMeta)]
7struct MacroArgs {
8 t: u64,
9 fmt: String,
10
11 #[darling(default)]
12 reference: Option<bool>,
13}
14
15#[proc_macro_attribute]
19pub fn timed(args: TokenStream, item: TokenStream) -> TokenStream {
20 let attr_args = parse_macro_input!(args as AttributeArgs);
21 let mut func = parse_macro_input!(item as ItemFn);
22
23 let args = match MacroArgs::from_list(&attr_args) {
24 Ok(v) => v,
25 Err(e) => { return TokenStream::from(e.write_errors()); }
26 };
27
28 let innerfn_name = format_ident!("inner_{}", func.sig.ident);
29 let time_const_var = format_ident!("__TIME_SEC_{}", func.sig.ident);
30 let time_var = format_ident!("__LAST_TIME_{}", func.sig.ident);
31 let last_val_var = format_ident!("__LAST_VAL_{}", func.sig.ident);
32
33 let ret_type = if let ReturnType::Type(_, ret_type) = &func.sig.output {
34 ret_type
35 } else {
36 panic!("mincache: function must have a return type (for now)");
37 };
38
39 let stmts = func.block.stmts.clone();
40 let inputs = &func.sig.inputs;
41
42 let function_args = func.sig.inputs
43 .iter()
44 .map(|arg| match arg {
45 syn::FnArg::Typed(arg) => &arg.pat,
46 _ => panic!("mincache: function arguments must be named")
47 })
48 .collect::<Vec<_>>();
49
50 func.block.stmts.clear();
51
52 let no_cooldown = if args.reference.unwrap_or(false) {
54 quote! {
55 {
56 #[allow(unused_unsafe)]
57 return unsafe { (*#last_val_var).as_ref().unwrap_unchecked() };
58 }
59 }
60 } else {
61 quote! {
62 {
63 #[allow(unused_unsafe)]
64 return unsafe { (*#last_val_var).clone().unwrap_unchecked() }
65 }
66 }
67 };
68
69 let initialize = if args.reference.unwrap_or(false) {
70 quote! {
71 {
72 let __ret = #innerfn_name( #(#function_args),* );
74 unsafe {
75 *#time_var.get_mut() = Some(__now);
76 *#last_val_var.get_mut() = Some(__ret);
77 }
78 return __ret;
79 }
80 }
81 } else {
82 quote! {
83 {
84 let __ret = #innerfn_name( #(#function_args),* );
86 unsafe {
87 *#time_var.get_mut() = Some(__now);
88 *#last_val_var.get_mut() = Some(__ret.clone());
89 }
90 return __ret;
91 }
92 }
93 };
94
95
96 func.block.stmts.push(parse_quote! {
97 {
98 #[inline(always)]
99 fn #innerfn_name( #inputs ) -> #ret_type {
100 #(#stmts)*
101 }
102
103 let __now = std::time::Instant::now();
104 match *#time_var {
105 Some(last_time) if __now.duration_since(last_time) < #time_const_var => #no_cooldown,
106 _ => #initialize
107 }
108 }
109 });
110
111
112 let timefn = format_ident!("from_{}", args.fmt);
113 let time = args.t;
114
115 TokenStream::from(quote! {
116 #[allow(non_upper_case_globals)]
117 const #time_const_var: core::time::Duration = core::time::Duration::#timefn(#time);
118 #[allow(non_upper_case_globals)]
119 static #time_var: mincache::SyncUnsafeCell<Option<std::time::Instant>> = mincache::SyncUnsafeCell::new(None);
120 #[allow(non_upper_case_globals)]
121 static #last_val_var: mincache::SyncUnsafeCell<Option<#ret_type>> = mincache::SyncUnsafeCell::new(None);
122
123 #func
124 })
125}