//! Procedural macro for building compile-time lookup tables.
//!
//! The crate exports a single function-like macro, `lut!`. It takes a
//! closure-like expression whose parameters are integer range patterns and
//! expands it into a `static` table plus a closure that indexes that table.
#![deny(
    missing_docs,
    missing_debug_implementations,
    missing_copy_implementations,
    trivial_casts,
    trivial_numeric_casts,
    unsafe_code,
    unstable_features,
    unused_import_braces,
    unused_qualifications
)]

extern crate proc_macro;

28struct Lut {
29 #[allow(unused)]
30 or1_token: syn::Token![|],
31 inputs: syn::punctuated::Punctuated<Param, syn::Token![,]>,
32 #[allow(unused)]
33 or2_token: syn::Token![|],
34 #[allow(unused)]
35 arrow_token: syn::Token![->],
36 return_type: syn::Type,
37 body: syn::Expr,
38}
39
40struct Param {
41 ident: syn::Ident,
42 lo: usize,
43 exclusive_end: bool,
44 hi: usize,
45}
46
47#[proc_macro]
53pub fn lut(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
54 let input = syn::parse_macro_input!(input as Lut);
55
56 let table_data = input.inputs.iter().rev().fold(input.body, |body, param| {
57 if param.exclusive_end {
58 generate_array(¶m.ident, param.lo..param.hi, body)
59 } else {
60 generate_array(¶m.ident, param.lo..=param.hi, body)
61 }
62 });
63
64 let lut_access = input
65 .inputs
66 .iter()
67 .fold(quote::quote!(__LUT), |expr, param| {
68 let ident = ¶m.ident;
69 quote::quote!(#expr[#ident])
70 });
71
72 let lut_params = input.inputs.iter().map(|param| {
73 let ident = ¶m.ident;
74 quote::quote!(#ident: usize)
75 });
76
77 let lut_type = input
78 .inputs
79 .iter()
80 .rev()
81 .fold(input.return_type, |ty, param| {
82 let count = if param.exclusive_end {
83 param.hi - param.lo
84 } else {
85 param.hi - param.lo + 1
86 };
87 quote::quote!([#ty; #count]).into()
88 });
89
90 let output = quote::quote!({
91 static __LUT: #lut_type = #table_data;
92 |#(#lut_params),*| #lut_access
93 });
94
95 output.into()
96}
97
98fn generate_array(
99 ident: &syn::Ident,
100 range: impl Iterator<Item = usize>,
101 body: syn::Expr,
102) -> syn::Expr {
103 let items = range.map(|n| {
104 quote::quote!({
105 #[allow(non_upper_case_globals)]
106 const #ident: usize = #n;
107 #body
108 })
109 });
110 quote::quote!([#(#items),*]).into()
111}
112
113impl Param {
114 fn from_pat(pat: syn::Pat) -> syn::Result<Self> {
115 use syn::spanned::Spanned;
116 match pat {
117 syn::Pat::Ident(pat_ident) => Self::from_pat_ident(pat_ident),
118 other => Err(syn::Error::new(
119 other.span(),
120 "this parameter must have a range pattern (e.g. `x @ 1..2` or `y @ 3..=4`)",
121 )),
122 }
123 }
124
125 fn from_pat_ident(pat_ident: syn::PatIdent) -> syn::Result<Self> {
126 use syn::spanned::Spanned;
127 match pat_ident {
128 syn::PatIdent {
129 ident,
130 subpat,
131 ..
132 } => match subpat {
133 Some((_, pat)) => {
134 let pat_span = pat.span();
135 match *pat {
136 syn::Pat::Range(syn::PatRange {
137 lo,
138 limits,
139 hi,
140 ..
141 }) => match *lo {
142 syn::Expr::Lit(syn::ExprLit {
143 lit: syn::Lit::Int(lo),
144 ..
145 }) => {
146 let lo = lo.base10_parse()?;
147 match *hi {
148 syn::Expr::Lit(syn::ExprLit {
149 lit: syn::Lit::Int(hi),
150 ..
151 }) => {
152 let hi = hi.base10_parse()?;
153 if hi < lo {
154 return Err(syn::Error::new(pat_span, format!("range lower bound {} must be less than upper bound {}", lo, hi)));
155 }
156 let exclusive_end = match limits {
157 syn::RangeLimits::Closed(_) => false,
158 syn::RangeLimits::HalfOpen(_) => true,
159 };
160 Ok(Param {
161 ident,
162 lo,
163 exclusive_end,
164 hi,
165 })
166 }
167 expr => Err(syn::Error::new(
168 expr.span(),
169 "must be an integer literal",
170 )),
171 }
172 }
173 expr => Err(syn::Error::new(expr.span(), "must be an integer literal")),
174 },
175 pat => Err(syn::Error::new(
176 pat.span(),
177 "only range patterns allowed (e.g. `1..2` or `3..=4`)",
178 )),
179 }
180 }
181 None => Err(syn::Error::new(
182 ident.span(),
183 format!(
184 "this parameter must have a specified range pattern (e.g. `{} @ 1..2`)",
185 ident
186 ),
187 )),
188 },
189 }
190 }
191}
192
193impl syn::parse::Parse for Lut {
194 fn parse(input: syn::parse::ParseStream) -> syn::parse::Result<Self> {
195 let or1_token: syn::Token![|] = input.parse()?;
196
197 let mut inputs = syn::punctuated::Punctuated::new();
198 loop {
199 if input.peek(syn::Token![|]) {
200 break;
201 }
202 let value = Param::from_pat(input.parse::<syn::Pat>()?)?;
203 inputs.push_value(value);
204 if input.peek(syn::Token![|]) {
205 break;
206 }
207 let punct: syn::Token![,] = input.parse()?;
208 inputs.push_punct(punct);
209 }
210
211 let or2_token: syn::Token![|] = input.parse()?;
212
213 let arrow_token: syn::Token![->] = input.parse()?;
214 let return_type: syn::Type = input.parse()?;
215 let body: syn::Block = input.parse()?;
216 let body = syn::Expr::Block(syn::ExprBlock {
217 attrs: Vec::new(),
218 label: None,
219 block: body,
220 });
221
222 Ok(Lut {
223 or1_token,
224 inputs,
225 or2_token,
226 arrow_token,
227 return_type,
228 body,
229 })
230 }
231}