1extern crate proc_macro;
2
3use std::collections::{HashMap, HashSet};
4
5use proc_macro::TokenStream;
6use quote::{ToTokens, format_ident, quote};
7use syn::{FnArg, ItemFn, Meta, Pat, Token, Visibility, parse_macro_input, punctuated::Punctuated};
8
9#[proc_macro_attribute]
63pub fn precalculate(attr: TokenStream, item: TokenStream) -> TokenStream {
64 let metas: Punctuated<Meta, Token![,]> =
65 parse_macro_input!(attr with Punctuated::parse_terminated);
66
67 #[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
68 enum Options {
69 Fallback,
70 Option,
71 Panic,
72 }
73
74 let mut mode = Vec::new();
75 let mut range_map = HashMap::<String, proc_macro2::TokenStream>::new();
76 for meta in metas {
77 match meta {
78 Meta::NameValue(mnv) => {
79 let ident = mnv
80 .path
81 .get_ident()
82 .expect("Attribute key must be an identifier")
83 .to_string();
84 let value_expr = mnv.value.into_token_stream();
85 if range_map.insert(ident.clone(), value_expr).is_some() {
86 panic!("Duplicated key: {ident}");
87 }
88 }
89 Meta::Path(opt) => {
90 match opt.to_token_stream().to_string().trim() {
91 "option" => mode.push(Options::Option),
92 "panic" => mode.push(Options::Panic),
93 "fallback" => mode.push(Options::Fallback),
94 opt => panic!("Unknown option: {opt}"),
95 };
96 }
97 _ => (),
98 }
99 }
100
101 let mode = match mode.len() {
102 0 => Options::Fallback,
103 1 => mode[0],
104 _ => {
105 panic!(
106 "precalculate macro may only take one operating mode at a time, found: {:?}.",
107 mode
108 )
109 }
110 };
111
112 let mut func = parse_macro_input!(item as ItemFn);
113 let visibility = func.vis.clone();
114 let func_ident = func.sig.ident.clone();
115 let new_func_ident = format_ident!("_{func_ident}_original");
116 func.vis = Visibility::Public(syn::token::Pub::default());
117 func.sig.ident = new_func_ident.clone();
118 let func_return_type = &func.sig.output;
119 let mut return_ty = match func_return_type {
120 syn::ReturnType::Default => panic!("Function must have a return type."),
121 syn::ReturnType::Type(_, ty) => ty.clone(),
122 };
123
124 let mut arg_info = Vec::new();
125 for arg in &func.sig.inputs {
126 if let FnArg::Typed(pat_type) = arg
127 && let Pat::Ident(pat_ident) = &*pat_type.pat
128 {
129 let arg_name = pat_ident.ident.to_string();
130 let arg_type = &pat_type.ty;
131 if let Some(range_expr) = range_map.get(&arg_name) {
132 arg_info.push((
133 pat_ident.ident.clone(),
134 arg_type.clone(),
135 range_expr.clone(),
136 ));
137 } else {
138 panic!("Argument '{arg_name}' does not have a specified range.");
139 }
140 }
141 }
142
143 let const_defs = arg_info.iter().map(|(ident, ty, range_expr)| {
144 let upper_ident = ident.to_string().to_uppercase();
145 let range_ident = format_ident!("{}_RANGE", upper_ident);
146 let min_ident = format_ident!("{}_MIN", upper_ident);
147 let max_ident = format_ident!("{}_MAX", upper_ident);
148 let size_ident = format_ident!("{}_SIZE", upper_ident);
149
150 quote! {
151 const #range_ident: std::ops::RangeInclusive<#ty> = #range_expr;
152 const #min_ident: #ty = *#range_ident.start();
153 const #max_ident: #ty = *#range_ident.end();
154 const #size_ident: usize = (#max_ident as isize - #min_ident as isize + 1) as usize;
155 }
156 });
157
158 let table_type = arg_info
159 .iter()
160 .rev()
161 .fold(quote! { #return_ty }, |inner, (ident, _, _)| {
162 let size_ident = format_ident!("{}_SIZE", ident.to_string().to_uppercase());
163 quote! { [#inner; #size_ident] }
164 });
165
166 let func_args = arg_info.iter().map(|(ident, _, _)| ident);
167
168 let generate_table_fn = {
169 let table_init_value = quote! { recuerdame::PrecalcConst::DEFAULT };
170 let table_init_expr =
171 arg_info
172 .iter()
173 .rev()
174 .fold(table_init_value, |inner, (ident, _, _)| {
175 let size_ident = format_ident!("{}_SIZE", ident.to_string().to_uppercase());
176 quote! { [#inner; #size_ident] }
177 });
178
179 let mut nested_loops = {
180 let value_calcs = arg_info.iter().map(|(ident, ty, _)| {
181 let min_ident = format_ident!("{}_MIN", ident.to_string().to_uppercase());
182 let loop_var = format_ident!("{}_idx", ident);
183 quote! { let #ident = #min_ident + #loop_var as #ty; }
184 });
185 let table_access = arg_info
186 .iter()
187 .fold(quote! { table }, |acc, (ident, _, _)| {
188 let loop_var = format_ident!("{}_idx", ident);
189 quote! { #acc[#loop_var] }
190 });
191
192 let func_args = func_args.clone();
193
194 quote! {
195 #(#value_calcs)*
196 #table_access = #new_func_ident(#(#func_args),*);
197 }
198 };
199
200 for (ident, _, _) in arg_info.iter().rev() {
201 let loop_var = format_ident!("{}_idx", ident);
202 let size_ident = format_ident!("{}_SIZE", ident.to_string().to_uppercase());
203 nested_loops = quote! {
204 let mut #loop_var: usize = 0;
205 while #loop_var < #size_ident {
206 #nested_loops
207 #loop_var += 1;
208 }
209 };
210 }
211
212 quote! {
213 const fn generate_table() -> #table_type {
214 let mut table = #table_init_expr;
215 #nested_loops
216 table
217 }
218 }
219 };
220
221 let mod_name = format_ident!("_mod_precalc_{}", func_ident);
222
223 let precalc_fn = {
224 let lookup_table_ident =
225 format_ident!("LOOKUP_TABLE_{}", func_ident.to_string().to_uppercase());
226
227 let fn_params = arg_info.iter().map(|(ident, ty, _)| quote! { #ident: #ty });
228 let index_calcs = arg_info.iter().map(|(ident, _ty, _)| {
229 let min_ident = format_ident!("{}_MIN", ident.to_string().to_uppercase());
230 let index_var = format_ident!("{}_idx", ident);
231 quote! { let #index_var = (#ident - #min_ident) as usize; }
232 });
233
234 let bounds_check_expr = {
235 let per_ident_check = arg_info.iter().map(|(ident, _ty, _)| {
236 let min_ident = format_ident!("{}_MIN", ident.to_string().to_uppercase());
237 let max_ident = format_ident!("{}_MAX", ident.to_string().to_uppercase());
238 quote! { #min_ident <= #ident && #ident <= #max_ident }
239 });
240
241 quote! { #(#per_ident_check &&)* true }
242 };
243
244 let mut table_access =
245 arg_info
246 .iter()
247 .fold(quote! { #lookup_table_ident }, |acc, (ident, _, _)| {
248 let index_var = format_ident!("{}_idx", ident);
249 quote! { #acc[#index_var] }
250 });
251
252 let mode_check = match mode {
253 Options::Panic => None,
254 Options::Fallback => Some(quote! {
255 if !(#bounds_check_expr) {
256 return #new_func_ident(#(#func_args),*);
257 }
258 }),
259 Options::Option => {
260 *return_ty.as_mut() = syn::Type::Verbatim(quote! { Option<#return_ty> });
262 table_access = quote! { Some(#table_access)};
264 Some(quote! {
265 if !(#bounds_check_expr) {
266 return None;
267 }
268 })
269 }
270 };
271
272 quote! {
273 pub const fn #func_ident(#(#fn_params),*) -> #return_ty {
274 #mode_check
275 #(#index_calcs)*
276 #table_access
277 }
278 }
279 };
280
281 let lookup_table_ident =
282 format_ident!("LOOKUP_TABLE_{}", func_ident.to_string().to_uppercase());
283 let expanded = quote! {
284
285 mod #mod_name {
286
287 use super::*;
288
289 #func
290
291 #(#const_defs)*
292
293 #generate_table_fn
294
295 pub const #lookup_table_ident: &'static #table_type = &generate_table();
296
297 #precalc_fn
298 }
299
300 #[allow(unused_imports)]
301 #visibility use #mod_name::#func_ident;
302 };
303
304 expanded.into()
305}