1#![cfg_attr(feature = "cfg_attribute", feature(proc_macro_expand))]
2
3#![cfg_attr(
31 feature = "cfg_attribute",
32 doc = r#"
33## `cfg` attribute
34
35Only when using Nightly Rust, this macro supports conditional compilation with
36the `cfg` attribute. To use this feature, enable `features = ["cfg_attribute"]`
37in your `Cargo.toml`.
38
39### Example
40
41```
42use trie_match::trie_match;
43
44let x = "abd";
45
46let result = trie_match! {
47 match x {
48 #[cfg(not(feature = "foo"))]
49 "a" => 0,
50 "abc" => 1,
51 #[cfg(feature = "bar")]
52 "abd" | "bcc" => 2,
53 "bc" => 3,
54 _ => 4,
55 }
56};
57
58assert_eq!(result, 4);
59```
60"#
61)]
62mod trie;
73
74extern crate proc_macro;
75
76use std::collections::HashMap;
77
78use proc_macro2::{Span, TokenStream};
79use quote::{format_ident, quote};
80use syn::{
81 parse_macro_input, spanned::Spanned, Arm, Error, Expr, ExprLit, ExprMatch, Lit, Pat, PatIdent,
82 PatOr, PatReference, PatSlice, PatWild,
83};
84
85#[cfg(feature = "cfg_attribute")]
86use proc_macro2::Ident;
87#[cfg(feature = "cfg_attribute")]
88use syn::{Attribute, Meta};
89
90use crate::trie::Sparse;
91
/// Reported for a pattern that is not a string literal, byte string literal, or `u8` slice.
static ERROR_UNEXPECTED_PATTERN: &str =
    "`trie_match` only supports string literals, byte string literals, and u8 slices as patterns";

/// Reported when an attribute is attached to a pattern or to one of its parts.
static ERROR_ATTRIBUTE_NOT_SUPPORTED: &str = "attribute not supported here";

/// Reported when a match arm carries an `if` guard.
static ERROR_GUARD_NOT_SUPPORTED: &str = "match guard not supported";

/// Reported when a byte pattern is listed twice or a second catch-all pattern appears.
static ERROR_UNREACHABLE_PATTERN: &str = "unreachable pattern";

/// Reported when no arm contains a catch-all pattern.
static ERROR_PATTERN_NOT_COVERED: &str = "non-exhaustive patterns: `_` not covered";

/// Reported for a slice element that is not a `u8` integer literal or byte literal.
static ERROR_EXPECTED_U8_LITERAL: &str = "expected `u8` integer literal";

/// Reported when the alternatives of one arm do not all bind the same variable.
static ERROR_VARIABLE_NOT_MATCH: &str = "variable is not bound in all patterns";

/// Reported for any attribute on a match arm when the `cfg_attribute` feature is disabled.
#[cfg(not(feature = "cfg_attribute"))]
static ERROR_ATTRIBUTE_NOT_SUPPORTED_CFG: &str =
    "attribute not supported here\nnote: consider enabling the `cfg_attribute` feature: \
     https://docs.rs/trie-match/latest/trie_match/#cfg-attribute";

/// Reported for a match-arm attribute that is not of the form `#[cfg(...)]`.
#[cfg(feature = "cfg_attribute")]
static ERROR_NOT_CFG_ATTRIBUTE: &str = "only supports the cfg attribute";
108
109fn convert_literal_pattern(pat: &ExprLit) -> Result<Option<Vec<u8>>, Error> {
111 let ExprLit { attrs, lit } = pat;
112 if let Some(attr) = attrs.first() {
113 return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED));
114 }
115 match lit {
116 Lit::Str(s) => Ok(Some(s.value().into())),
117 Lit::ByteStr(s) => Ok(Some(s.value())),
118 _ => Err(Error::new(lit.span(), ERROR_UNEXPECTED_PATTERN)),
119 }
120}
121
122fn convert_slice_pattern(pat: &PatSlice) -> Result<Option<Vec<u8>>, Error> {
124 let PatSlice { attrs, elems, .. } = pat;
125 if let Some(attr) = attrs.first() {
126 return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED));
127 }
128 let mut result = vec![];
129 for elem in elems {
130 match elem {
131 Pat::Lit(ExprLit { attrs, lit }) => {
132 if let Some(attr) = attrs.first() {
133 return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED));
134 }
135 match lit {
136 Lit::Int(i) => {
137 let int_type = i.suffix();
138 if int_type != "u8" && !int_type.is_empty() {
139 return Err(Error::new(i.span(), ERROR_EXPECTED_U8_LITERAL));
140 }
141 result.push(i.base10_parse::<u8>()?);
142 }
143 Lit::Byte(b) => {
144 result.push(b.value());
145 }
146 _ => {
147 return Err(Error::new(elem.span(), ERROR_EXPECTED_U8_LITERAL));
148 }
149 }
150 }
151 _ => {
152 return Err(Error::new(elem.span(), ERROR_EXPECTED_U8_LITERAL));
153 }
154 }
155 }
156 Ok(Some(result))
157}
158
159fn convert_wildcard_pattern(pat: &PatWild) -> Result<Option<Vec<u8>>, Error> {
164 let PatWild { attrs, .. } = pat;
165 if let Some(attr) = attrs.first() {
166 return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED));
167 }
168 Ok(None)
169}
170
171fn convert_reference_pattern(pat: &PatReference) -> Result<Option<Vec<u8>>, Error> {
173 let PatReference { attrs, pat, .. } = pat;
174 if let Some(attr) = attrs.first() {
175 return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED));
176 }
177 match &**pat {
178 Pat::Lit(pat) => convert_literal_pattern(pat),
179 Pat::Slice(pat) => convert_slice_pattern(pat),
180 Pat::Reference(pat) => convert_reference_pattern(pat),
181 _ => Err(Error::new(pat.span(), ERROR_UNEXPECTED_PATTERN)),
182 }
183}
184
185struct PatternBytes {
186 ident: Option<PatIdent>,
188
189 bytes: Option<Vec<u8>>,
191}
192
193impl PatternBytes {
194 const fn new(ident: Option<PatIdent>, bytes: Option<Vec<u8>>) -> Self {
195 Self { ident, bytes }
196 }
197}
198
199fn retrieve_match_patterns(
203 pat: &Pat,
204 ident: Option<PatIdent>,
205 pat_bytes_set: &mut Vec<PatternBytes>,
206 pat_set: &mut Vec<Pat>,
207) -> Result<(), Error> {
208 match pat {
209 Pat::Lit(lit) => {
210 pat_set.push(pat.clone());
211 pat_bytes_set.push(PatternBytes::new(ident, convert_literal_pattern(lit)?));
212 }
213 Pat::Slice(slice) => {
214 pat_set.push(pat.clone());
215 pat_bytes_set.push(PatternBytes::new(ident, convert_slice_pattern(slice)?));
216 }
217 Pat::Wild(pat) => {
218 pat_bytes_set.push(PatternBytes::new(ident, convert_wildcard_pattern(pat)?));
219 }
220 Pat::Reference(reference) => {
221 pat_set.push(pat.clone());
222 pat_bytes_set.push(PatternBytes::new(
223 ident,
224 convert_reference_pattern(reference)?,
225 ));
226 }
227 Pat::Ident(pat) => {
228 if let Some(attr) = pat.attrs.first() {
229 return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED));
230 }
231 let mut pat = pat.clone();
232 if let Some((_, subpat)) = pat.subpat.take() {
233 retrieve_match_patterns(&subpat, Some(pat), pat_bytes_set, pat_set)?;
234 } else {
235 pat_bytes_set.push(PatternBytes::new(Some(pat), None));
236 }
237 }
238 Pat::Paren(pat) => {
239 retrieve_match_patterns(&pat.pat, ident, pat_bytes_set, pat_set)?;
240 }
241 Pat::Or(PatOr {
242 attrs,
243 leading_vert: None,
244 cases,
245 }) => {
246 if let Some(attr) = attrs.first() {
247 return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED));
248 }
249 for pat in cases {
250 retrieve_match_patterns(pat, ident.clone(), pat_bytes_set, pat_set)?;
251 }
252 }
253 _ => {
254 return Err(Error::new(pat.span(), ERROR_UNEXPECTED_PATTERN));
255 }
256 }
257 Ok(())
258}
259
/// Evaluates the `#[cfg(...)]` attributes attached to a match arm.
///
/// Each attribute's predicate is wrapped in a `cfg!(...)` invocation and expanded. Returns
/// `Ok(false)` as soon as one predicate expands to `false` (later attributes are then not
/// inspected), and `Ok(true)` if none does, including when `attrs` is empty.
///
/// # Errors
///
/// Returns an error if an attribute is not a list-style `cfg` attribute, or if expanding the
/// `cfg!` invocation fails.
#[cfg(feature = "cfg_attribute")]
fn evaluate_cfg_attribute(attrs: &[Attribute]) -> Result<bool, Error> {
    for attr in attrs {
        let name = attr.path().get_ident().map(Ident::to_string);
        let predicate = match (name.as_deref(), &attr.meta) {
            (Some("cfg"), Meta::List(list)) => &list.tokens,
            _ => return Err(Error::new(attr.span(), ERROR_NOT_CFG_ATTRIBUTE)),
        };
        let invocation: proc_macro::TokenStream = quote! { cfg!(#predicate) }.into();
        let expanded = invocation
            .expand_expr()
            .map_err(|err| Error::new(predicate.span(), err.to_string()))?;
        if expanded.to_string() == "false" {
            return Ok(false);
        }
    }
    Ok(true)
}
281
282struct MatchInfo {
283 bodies: Vec<Expr>,
284 pattern_map: HashMap<Vec<u8>, usize>,
285 wildcard_idx: usize,
286 bound_vals: Vec<Option<PatIdent>>,
287 pat_set: Vec<Pat>,
288}
289
290fn parse_match_arms(arms: Vec<Arm>) -> Result<MatchInfo, Error> {
291 let mut pattern_map = HashMap::new();
292 let mut wildcard_idx = None;
293 let mut bound_vals = vec![];
294 let mut bodies = vec![];
295 let mut pat_set = vec![];
296 let mut i = 0;
297 #[allow(clippy::explicit_counter_loop)]
298 for Arm {
299 attrs,
300 pat,
301 guard,
302 body,
303 ..
304 } in arms
305 {
306 #[cfg(feature = "cfg_attribute")]
307 if !evaluate_cfg_attribute(&attrs)? {
308 continue;
309 }
310 #[cfg(not(feature = "cfg_attribute"))]
311 if let Some(attr) = attrs.first() {
312 return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED_CFG));
313 }
314
315 if let Some((if_token, _)) = guard {
316 return Err(Error::new(if_token.span(), ERROR_GUARD_NOT_SUPPORTED));
317 }
318 let mut pat_bytes_set = vec![];
319 retrieve_match_patterns(&pat, None, &mut pat_bytes_set, &mut pat_set)?;
320 let bound_val = pat_bytes_set[0].ident.clone();
321 for PatternBytes { ident, bytes } in pat_bytes_set {
322 if ident != bound_val {
323 return Err(Error::new(
324 ident.or(bound_val).unwrap().span(),
325 ERROR_VARIABLE_NOT_MATCH,
326 ));
327 }
328 if let Some(bytes) = bytes {
329 if pattern_map.contains_key(&bytes) {
330 return Err(Error::new(pat.span(), ERROR_UNREACHABLE_PATTERN));
331 }
332 pattern_map.insert(bytes, i);
333 } else {
334 if wildcard_idx.is_some() {
335 return Err(Error::new(pat.span(), ERROR_UNREACHABLE_PATTERN));
336 }
337 wildcard_idx.replace(i);
338 }
339 }
340 bound_vals.push(bound_val);
341 bodies.push(*body);
342 i += 1;
343 }
344 let Some(wildcard_idx) = wildcard_idx else {
345 return Err(Error::new(Span::call_site(), ERROR_PATTERN_NOT_COVERED));
346 };
347 Ok(MatchInfo {
348 bodies,
349 pattern_map,
350 wildcard_idx,
351 bound_vals,
352 pat_set,
353 })
354}
355
356fn trie_match_inner(input: ExprMatch) -> Result<TokenStream, Error> {
357 let ExprMatch {
358 attrs, expr, arms, ..
359 } = input;
360 let MatchInfo {
361 bodies,
362 pattern_map,
363 wildcard_idx,
364 bound_vals,
365 pat_set,
366 } = parse_match_arms(arms)?;
367 let mut trie = Sparse::new();
368 for (k, v) in pattern_map {
369 if v == wildcard_idx {
370 continue;
371 }
372 trie.add(k, v);
373 }
374 let (bases, checks, outs) = trie.build_double_array_trie(wildcard_idx);
375
376 let out_check = outs.iter().zip(checks).map(|(out, check)| {
377 let out = format_ident!("V{out}");
378 quote! { (__TrieMatchValue::#out, #check) }
379 });
380 let arm = bodies
381 .iter()
382 .zip(bound_vals)
383 .enumerate()
384 .map(|(i, (body, bound_val))| {
385 let i = format_ident!("V{i}");
386 let bound_val = bound_val.map_or_else(|| quote! { _ }, |val| quote! { #val });
387 quote! { (__TrieMatchValue::#i, #bound_val ) => #body }
388 });
389 let enumvalue = (0..bodies.len()).map(|i| format_ident!("V{i}"));
390 let wildcard_ident = format_ident!("V{wildcard_idx}");
391 Ok(quote! {
392 {
393 #[derive(Clone, Copy)]
394 enum __TrieMatchValue {
395 #( #enumvalue, )*
396 }
397 #( #attrs )*
398 match #expr {
399 query @ ( #( #pat_set | )* _) => {
401 match (|query| unsafe {
402 let query_ref = ::core::convert::AsRef::<[u8]>::as_ref(&query);
403 let bases: &'static [i32] = &[ #( #bases, )* ];
404 let out_checks: &'static [(__TrieMatchValue, u8)] = &[ #( #out_check, )* ];
405 let mut pos = 0;
406 let mut base = bases[0];
407 for &b in query_ref {
408 pos = base.wrapping_add(i32::from(b)) as usize;
409 if let Some((_, check)) = out_checks.get(pos) {
410 if *check == b {
411 base = *bases.get_unchecked(pos);
412 continue;
413 }
414 }
415 return (__TrieMatchValue::#wildcard_ident, query);
416 }
417 (out_checks.get_unchecked(pos).0, query)
418 })(query) {
419 #( #arm, )*
420 }
421 }
422 }
423 }
424 })
425}
426
427#[proc_macro]
449pub fn trie_match(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
450 let input = parse_macro_input!(input as ExprMatch);
451 trie_match_inner(input)
452 .unwrap_or_else(Error::into_compile_error)
453 .into()
454}