mincache_impl/
lib.rs

1use proc_macro::TokenStream;
2use quote::{quote, format_ident};
3use syn::{AttributeArgs, parse_macro_input, parse_quote, ItemFn, ReturnType};
4use darling::FromMeta;
5
6#[derive(Debug, FromMeta)]
7struct MacroArgs {
8	t: u64,
9	fmt: String,
10
11	#[darling(default)]
12	reference: Option<bool>,
13}
14
15/// A timed cache.
16/// The function will be called once, and then wait the specified amount of time until it is able to be called again.
17/// In the mean time it returns the cached value, being cloned each time as an owned value.
18#[proc_macro_attribute]
19pub fn timed(args: TokenStream, item: TokenStream) -> TokenStream {
20	let attr_args = parse_macro_input!(args as AttributeArgs);
21	let mut func = parse_macro_input!(item as ItemFn);
22
23	let args = match MacroArgs::from_list(&attr_args) {
24		Ok(v) => v,
25		Err(e) => { return TokenStream::from(e.write_errors()); }
26	};
27
28	let innerfn_name = format_ident!("inner_{}", func.sig.ident);
29	let time_const_var = format_ident!("__TIME_SEC_{}", func.sig.ident);
30	let time_var = format_ident!("__LAST_TIME_{}", func.sig.ident);
31	let last_val_var = format_ident!("__LAST_VAL_{}", func.sig.ident);
32
33	let ret_type = if let ReturnType::Type(_, ret_type) = &func.sig.output {
34		ret_type
35	} else {
36		panic!("mincache: function must have a return type (for now)");
37	};
38
39	let stmts = func.block.stmts.clone();
40	let inputs = &func.sig.inputs;
41
42	let function_args = func.sig.inputs
43		.iter()
44		.map(|arg| match arg {
45			syn::FnArg::Typed(arg) => &arg.pat,
46			_ => panic!("mincache: function arguments must be named")
47		})
48		.collect::<Vec<_>>();
49
50	func.block.stmts.clear();
51
52	// Cooldown hasn't passed. Return last value.
53	let no_cooldown = if args.reference.unwrap_or(false) {
54		quote! {
55			{
56				#[allow(unused_unsafe)]
57				return unsafe { (*#last_val_var).as_ref().unwrap_unchecked() };
58			}
59		}
60	} else {
61		quote! {
62			{
63				#[allow(unused_unsafe)]
64				return unsafe { (*#last_val_var).clone().unwrap_unchecked() }
65			}
66		}
67	};
68
69	let initialize = if args.reference.unwrap_or(false) {
70		quote! {
71			{
72				// First initialization OR time has passed
73				let __ret = #innerfn_name( #(#function_args),* );
74				unsafe {
75					*#time_var.get_mut() = Some(__now);
76					*#last_val_var.get_mut() = Some(__ret);
77				}
78				return __ret;
79			}
80		}
81	} else {
82		quote! {
83			{
84				// First initialization OR time has passed
85				let __ret = #innerfn_name( #(#function_args),* );
86				unsafe {
87					*#time_var.get_mut() = Some(__now);
88					*#last_val_var.get_mut() = Some(__ret.clone());
89				}
90				return __ret;
91			}
92		}
93	};
94
95
96	func.block.stmts.push(parse_quote! {
97		{
98			#[inline(always)]
99			fn #innerfn_name( #inputs ) -> #ret_type {
100				#(#stmts)*
101			}
102
103			let __now = std::time::Instant::now();
104			match *#time_var {
105				Some(last_time) if __now.duration_since(last_time) < #time_const_var  => #no_cooldown,
106				_ => #initialize
107			}
108		}
109	});
110
111
112	let timefn = format_ident!("from_{}", args.fmt);
113	let time = args.t;
114
115	TokenStream::from(quote! {
116		#[allow(non_upper_case_globals)]
117		const #time_const_var: core::time::Duration = core::time::Duration::#timefn(#time);
118		#[allow(non_upper_case_globals)]
119		static #time_var: mincache::SyncUnsafeCell<Option<std::time::Instant>> = mincache::SyncUnsafeCell::new(None);
120		#[allow(non_upper_case_globals)]
121		static #last_val_var: mincache::SyncUnsafeCell<Option<#ret_type>> = mincache::SyncUnsafeCell::new(None);
122
123		#func
124	})
125}