Skip to main content

recuerdame_macros/
lib.rs

1extern crate proc_macro;
2
3use std::collections::{HashMap, HashSet};
4
5use proc_macro::TokenStream;
6use quote::{ToTokens, format_ident, quote};
7use syn::{FnArg, ItemFn, Meta, Pat, Token, Visibility, parse_macro_input, punctuated::Punctuated};
8
9/// Precalculate all possible values for const function at compile time.
10///
11/// This macro builds a look-up table at compile time to avoid
12/// having to run complicated arithmentic at runtime.
13///
14/// This macro supports three operating modes:
15///  - **fallback** (Default): The fallback operating mode never panic (unless the implementation panics). It will use the look up table for the specified ranges and use the original implementation if outside of the range.
16///  - **option**: The option operating mode will change the function to return an [Option]. [Some] if the input is in range, [None] if not.
17///  - **panic**: If the input is outside of the range specified in the macro the function will panic.
18///
19/// The option and keep modes will require additional bounds checks which may come at a cost.
20///
21/// Please benchmark the functions to decide if it's worth using a look-up table.
22///
23/// Examples:
24/// ```rust
25/// use recuerdame::precalculate;
26///
27/// #[precalculate(a = 0..=10, b = 0..=4)]
28/// pub const fn add(a: i32, b: i32) -> i32 {
29///     a + b
30/// }
31///
32/// #[precalculate(a = 0..=10, b = 0..=4, option)]
33/// pub const fn add_opt(a: i32, b: i32) -> i32 {
34///     a + b
35/// }
36///
37/// #[precalculate(a = 0..=10, b = 0..=4, panic)]
38/// pub const fn add_panic(a: i32, b: i32) -> i32 {
39///     a + b
40/// }
41///
42/// #[test]
43/// fn it_works() {
44///     assert_eq!(add(8, 2), 10);
45///     assert_eq!(add(0, 0), 0);
46///     assert_eq!(add_keep(5, 4), 9);
47///     assert_eq!(add_keep(25, 0), 25);
48/// }
49///
50/// #[test]
51/// fn it_works_opt() {
52///     assert_eq!(add_opt(5, 4), Some(9));
53///     assert_eq!(add_opt(25, 0), None);
54/// }
55///
56/// #[test]
57/// #[should_panic]
58/// fn outside_bounds_panics() {
59///     add_panic(25, 9);
60/// }
61/// ```
62#[proc_macro_attribute]
63pub fn precalculate(attr: TokenStream, item: TokenStream) -> TokenStream {
64    let metas: Punctuated<Meta, Token![,]> =
65        parse_macro_input!(attr with Punctuated::parse_terminated);
66
67    #[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
68    enum Options {
69        Fallback,
70        Option,
71        Panic,
72    }
73
74    let mut mode = Vec::new();
75    let mut range_map = HashMap::<String, proc_macro2::TokenStream>::new();
76    for meta in metas {
77        match meta {
78            Meta::NameValue(mnv) => {
79                let ident = mnv
80                    .path
81                    .get_ident()
82                    .expect("Attribute key must be an identifier")
83                    .to_string();
84                let value_expr = mnv.value.into_token_stream();
85                if range_map.insert(ident.clone(), value_expr).is_some() {
86                    panic!("Duplicated key: {ident}");
87                }
88            }
89            Meta::Path(opt) => {
90                match opt.to_token_stream().to_string().trim() {
91                    "option" => mode.push(Options::Option),
92                    "panic" => mode.push(Options::Panic),
93                    "fallback" => mode.push(Options::Fallback),
94                    opt => panic!("Unknown option: {opt}"),
95                };
96            }
97            _ => (),
98        }
99    }
100
101    let mode = match mode.len() {
102        0 => Options::Fallback,
103        1 => mode[0],
104        _ => {
105            panic!(
106                "precalculate macro may only take one operating mode at a time, found: {:?}.",
107                mode
108            )
109        }
110    };
111
112    let mut func = parse_macro_input!(item as ItemFn);
113    let visibility = func.vis.clone();
114    let func_ident = func.sig.ident.clone();
115    let new_func_ident = format_ident!("_{func_ident}_original");
116    func.vis = Visibility::Public(syn::token::Pub::default());
117    func.sig.ident = new_func_ident.clone();
118    let func_return_type = &func.sig.output;
119    let mut return_ty = match func_return_type {
120        syn::ReturnType::Default => panic!("Function must have a return type."),
121        syn::ReturnType::Type(_, ty) => ty.clone(),
122    };
123
124    let mut arg_info = Vec::new();
125    for arg in &func.sig.inputs {
126        if let FnArg::Typed(pat_type) = arg
127            && let Pat::Ident(pat_ident) = &*pat_type.pat
128        {
129            let arg_name = pat_ident.ident.to_string();
130            let arg_type = &pat_type.ty;
131            if let Some(range_expr) = range_map.get(&arg_name) {
132                arg_info.push((
133                    pat_ident.ident.clone(),
134                    arg_type.clone(),
135                    range_expr.clone(),
136                ));
137            } else {
138                panic!("Argument '{arg_name}' does not have a specified range.");
139            }
140        }
141    }
142
143    let const_defs = arg_info.iter().map(|(ident, ty, range_expr)| {
144        let upper_ident = ident.to_string().to_uppercase();
145        let range_ident = format_ident!("{}_RANGE", upper_ident);
146        let min_ident = format_ident!("{}_MIN", upper_ident);
147        let max_ident = format_ident!("{}_MAX", upper_ident);
148        let size_ident = format_ident!("{}_SIZE", upper_ident);
149
150        quote! {
151            const #range_ident: std::ops::RangeInclusive<#ty> = #range_expr;
152            const #min_ident: #ty = *#range_ident.start();
153            const #max_ident: #ty = *#range_ident.end();
154            const #size_ident: usize = (#max_ident as isize - #min_ident as isize + 1) as usize;
155        }
156    });
157
158    let table_type = arg_info
159        .iter()
160        .rev()
161        .fold(quote! { #return_ty }, |inner, (ident, _, _)| {
162            let size_ident = format_ident!("{}_SIZE", ident.to_string().to_uppercase());
163            quote! { [#inner; #size_ident] }
164        });
165
166    let func_args = arg_info.iter().map(|(ident, _, _)| ident);
167
168    let generate_table_fn = {
169        let table_init_value = quote! { recuerdame::PrecalcConst::DEFAULT };
170        let table_init_expr =
171            arg_info
172                .iter()
173                .rev()
174                .fold(table_init_value, |inner, (ident, _, _)| {
175                    let size_ident = format_ident!("{}_SIZE", ident.to_string().to_uppercase());
176                    quote! { [#inner; #size_ident] }
177                });
178
179        let mut nested_loops = {
180            let value_calcs = arg_info.iter().map(|(ident, ty, _)| {
181                let min_ident = format_ident!("{}_MIN", ident.to_string().to_uppercase());
182                let loop_var = format_ident!("{}_idx", ident);
183                quote! { let #ident = #min_ident + #loop_var as #ty; }
184            });
185            let table_access = arg_info
186                .iter()
187                .fold(quote! { table }, |acc, (ident, _, _)| {
188                    let loop_var = format_ident!("{}_idx", ident);
189                    quote! { #acc[#loop_var] }
190                });
191
192            let func_args = func_args.clone();
193
194            quote! {
195                #(#value_calcs)*
196                #table_access = #new_func_ident(#(#func_args),*);
197            }
198        };
199
200        for (ident, _, _) in arg_info.iter().rev() {
201            let loop_var = format_ident!("{}_idx", ident);
202            let size_ident = format_ident!("{}_SIZE", ident.to_string().to_uppercase());
203            nested_loops = quote! {
204                let mut #loop_var: usize = 0;
205                while #loop_var < #size_ident {
206                    #nested_loops
207                    #loop_var += 1;
208                }
209            };
210        }
211
212        quote! {
213            const fn generate_table() -> #table_type {
214                let mut table = #table_init_expr;
215                #nested_loops
216                table
217            }
218        }
219    };
220
221    let mod_name = format_ident!("_mod_precalc_{}", func_ident);
222
223    let precalc_fn = {
224        let lookup_table_ident =
225            format_ident!("LOOKUP_TABLE_{}", func_ident.to_string().to_uppercase());
226
227        let fn_params = arg_info.iter().map(|(ident, ty, _)| quote! { #ident: #ty });
228        let index_calcs = arg_info.iter().map(|(ident, _ty, _)| {
229            let min_ident = format_ident!("{}_MIN", ident.to_string().to_uppercase());
230            let index_var = format_ident!("{}_idx", ident);
231            quote! { let #index_var = (#ident - #min_ident) as usize; }
232        });
233
234        let bounds_check_expr = {
235            let per_ident_check = arg_info.iter().map(|(ident, _ty, _)| {
236                let min_ident = format_ident!("{}_MIN", ident.to_string().to_uppercase());
237                let max_ident = format_ident!("{}_MAX", ident.to_string().to_uppercase());
238                quote! { #min_ident <= #ident && #ident <= #max_ident }
239            });
240
241            quote! { #(#per_ident_check &&)* true }
242        };
243
244        let mut table_access =
245            arg_info
246                .iter()
247                .fold(quote! { #lookup_table_ident }, |acc, (ident, _, _)| {
248                    let index_var = format_ident!("{}_idx", ident);
249                    quote! { #acc[#index_var] }
250                });
251
252        let mode_check = match mode {
253            Options::Panic => None,
254            Options::Fallback => Some(quote! {
255                if !(#bounds_check_expr) {
256                    return #new_func_ident(#(#func_args),*);
257                }
258            }),
259            Options::Option => {
260                // Change signature to return option
261                *return_ty.as_mut() = syn::Type::Verbatim(quote! { Option<#return_ty> });
262                // Change the table access expression to return Some
263                table_access = quote! { Some(#table_access)};
264                Some(quote! {
265                    if !(#bounds_check_expr) {
266                        return None;
267                    }
268                })
269            }
270        };
271
272        quote! {
273            pub const fn #func_ident(#(#fn_params),*) -> #return_ty {
274                #mode_check
275                #(#index_calcs)*
276                #table_access
277            }
278        }
279    };
280
281    let lookup_table_ident =
282        format_ident!("LOOKUP_TABLE_{}", func_ident.to_string().to_uppercase());
283    let expanded = quote! {
284
285        mod #mod_name {
286
287            use super::*;
288
289            #func
290
291            #(#const_defs)*
292
293            #generate_table_fn
294
295            pub const #lookup_table_ident: &'static #table_type = &generate_table();
296
297            #precalc_fn
298        }
299
300        #[allow(unused_imports)]
301        #visibility use #mod_name::#func_ident;
302    };
303
304    expanded.into()
305}