numeric_lut/
lib.rs

1//! # `numeric-lut`
2//!
3//! A library for generating numeric lookup functions.  Currently, it requires the use of the
4//! `proc_macro_hygiene` nightly feature.
5//!
6//! ## Examples
7//!
8//! ```
9//! #![feature(proc_macro_hygiene)]
10//! let lut = numeric_lut::lut!(|x @ 0..8, y @ 0..16| -> u32 { x as u32 + y as u32 });
11//! let x = lut(3, 10);
12//! assert_eq!(13, x);
13//! ```
14#![deny(
15    missing_docs,
16    missing_debug_implementations,
17    missing_copy_implementations,
18    trivial_casts,
19    trivial_numeric_casts,
20    unsafe_code,
21    unstable_features,
22    unused_import_braces,
23    unused_qualifications
24)]
25
26extern crate proc_macro;
27
/// The parsed input of the `lut!` macro: a closure-like `|params| -> ReturnType { body }`.
struct Lut {
    /// The opening `|` of the parameter list. Not read by the expansion (hence the allow).
    #[allow(unused)]
    or1_token: syn::Token![|],
    /// The comma-separated range parameters, in declaration order; each becomes one table
    /// dimension and one `usize` argument of the generated closure.
    inputs: syn::punctuated::Punctuated<Param, syn::Token![,]>,
    /// The closing `|` of the parameter list. Not read by the expansion (hence the allow).
    #[allow(unused)]
    or2_token: syn::Token![|],
    /// The `->` before the return type. Not read by the expansion (hence the allow).
    #[allow(unused)]
    arrow_token: syn::Token![->],
    /// The element type of the innermost table arrays, i.e. what the generated closure returns.
    return_type: syn::Type,
    /// The closure body, stored as a block expression; it is replicated once per table entry.
    body: syn::Expr,
}
39
/// One lookup-table dimension, parsed from a parameter such as `x @ 0..8` or `y @ 1..=4`.
struct Param {
    /// The parameter name: bound as a `usize` constant while the table is built, and used as
    /// a `usize` argument name of the generated closure.
    ident: syn::Ident,
    /// Lower bound of the range (always inclusive).
    lo: usize,
    /// `true` for a half-open range (`lo..hi`), `false` for a closed range (`lo..=hi`).
    exclusive_end: bool,
    /// Upper bound of the range; exclusive or inclusive according to `exclusive_end`.
    hi: usize,
}
46
47/// Generates a numeric lookup function.
48///
49/// The macro is function-like and accepts an expression that looks like a closure.  Only parameters
50/// that use range patterns (like `x @ 0..1`) are accepted.  All parameters are implicitly of type
51/// `usize` since they will be used as indices for lookup tables.
52#[proc_macro]
53pub fn lut(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
54    let input = syn::parse_macro_input!(input as Lut);
55
56    let table_data = input.inputs.iter().rev().fold(input.body, |body, param| {
57        if param.exclusive_end {
58            generate_array(&param.ident, param.lo..param.hi, body)
59        } else {
60            generate_array(&param.ident, param.lo..=param.hi, body)
61        }
62    });
63
64    let lut_access = input
65        .inputs
66        .iter()
67        .fold(quote::quote!(__LUT), |expr, param| {
68            let ident = &param.ident;
69            quote::quote!(#expr[#ident])
70        });
71
72    let lut_params = input.inputs.iter().map(|param| {
73        let ident = &param.ident;
74        quote::quote!(#ident: usize)
75    });
76
77    let lut_type = input
78        .inputs
79        .iter()
80        .rev()
81        .fold(input.return_type, |ty, param| {
82            let count = if param.exclusive_end {
83                param.hi - param.lo
84            } else {
85                param.hi - param.lo + 1
86            };
87            quote::quote!([#ty; #count]).into()
88        });
89
90    let output = quote::quote!({
91        static __LUT: #lut_type = #table_data;
92        |#(#lut_params),*| #lut_access
93    });
94
95    output.into()
96}
97
98fn generate_array(
99    ident: &syn::Ident,
100    range: impl Iterator<Item = usize>,
101    body: syn::Expr,
102) -> syn::Expr {
103    let items = range.map(|n| {
104        quote::quote!({
105            #[allow(non_upper_case_globals)]
106            const #ident: usize = #n;
107            #body
108        })
109    });
110    quote::quote!([#(#items),*]).into()
111}
112
113impl Param {
114    fn from_pat(pat: syn::Pat) -> syn::Result<Self> {
115        use syn::spanned::Spanned;
116        match pat {
117            syn::Pat::Ident(pat_ident) => Self::from_pat_ident(pat_ident),
118            other => Err(syn::Error::new(
119                other.span(),
120                "this parameter must have a range pattern (e.g. `x @ 1..2` or `y @ 3..=4`)",
121            )),
122        }
123    }
124
125    fn from_pat_ident(pat_ident: syn::PatIdent) -> syn::Result<Self> {
126        use syn::spanned::Spanned;
127        match pat_ident {
128            syn::PatIdent {
129                ident,
130                subpat,
131                ..
132            } => match subpat {
133                Some((_, pat)) => {
134                    let pat_span = pat.span();
135                    match *pat {
136                        syn::Pat::Range(syn::PatRange {
137                            lo,
138                            limits,
139                            hi,
140                            ..
141                        }) => match *lo {
142                            syn::Expr::Lit(syn::ExprLit {
143                                lit: syn::Lit::Int(lo),
144                                ..
145                            }) => {
146                                let lo = lo.base10_parse()?;
147                                match *hi {
148                                    syn::Expr::Lit(syn::ExprLit {
149                                        lit: syn::Lit::Int(hi),
150                                        ..
151                                    }) => {
152                                        let hi = hi.base10_parse()?;
153                                        if hi < lo {
154                                            return Err(syn::Error::new(pat_span, format!("range lower bound {} must be less than upper bound {}", lo, hi)));
155                                        }
156                                        let exclusive_end = match limits {
157                                            syn::RangeLimits::Closed(_) => false,
158                                            syn::RangeLimits::HalfOpen(_) => true,
159                                        };
160                                        Ok(Param {
161                                            ident,
162                                            lo,
163                                            exclusive_end,
164                                            hi,
165                                        })
166                                    }
167                                    expr => Err(syn::Error::new(
168                                        expr.span(),
169                                        "must be an integer literal",
170                                    )),
171                                }
172                            }
173                            expr => Err(syn::Error::new(expr.span(), "must be an integer literal")),
174                        },
175                        pat => Err(syn::Error::new(
176                            pat.span(),
177                            "only range patterns allowed (e.g. `1..2` or `3..=4`)",
178                        )),
179                    }
180                }
181                None => Err(syn::Error::new(
182                    ident.span(),
183                    format!(
184                        "this parameter must have a specified range pattern (e.g. `{} @ 1..2`)",
185                        ident
186                    ),
187                )),
188            },
189        }
190    }
191}
192
193impl syn::parse::Parse for Lut {
194    fn parse(input: syn::parse::ParseStream) -> syn::parse::Result<Self> {
195        let or1_token: syn::Token![|] = input.parse()?;
196
197        let mut inputs = syn::punctuated::Punctuated::new();
198        loop {
199            if input.peek(syn::Token![|]) {
200                break;
201            }
202            let value = Param::from_pat(input.parse::<syn::Pat>()?)?;
203            inputs.push_value(value);
204            if input.peek(syn::Token![|]) {
205                break;
206            }
207            let punct: syn::Token![,] = input.parse()?;
208            inputs.push_punct(punct);
209        }
210
211        let or2_token: syn::Token![|] = input.parse()?;
212
213        let arrow_token: syn::Token![->] = input.parse()?;
214        let return_type: syn::Type = input.parse()?;
215        let body: syn::Block = input.parse()?;
216        let body = syn::Expr::Block(syn::ExprBlock {
217            attrs: Vec::new(),
218            label: None,
219            block: body,
220        });
221
222        Ok(Lut {
223            or1_token,
224            inputs,
225            or2_token,
226            arrow_token,
227            return_type,
228            body,
229        })
230    }
231}