trie_match/
lib.rs

1#![cfg_attr(feature = "cfg_attribute", feature(proc_macro_expand))]
2
3//! # `trie_match! {}`
4//!
5//! This macro speeds up Rust's `match` expression for comparing strings by using a compact
6//! double-array data structure.
7//!
8//! ## Usage
9//!
10//! Simply wrap the existing match expression with the `trie_match! {}` macro as
11//! follows:
12//!
13//! ```
14//! use trie_match::trie_match;
15//!
16//! let x = "abd";
17//!
18//! let result = trie_match! {
19//!     match x {
20//!         "a" => 0,
21//!         "abc" => 1,
22//!         pat @ ("abd" | "bcde") => pat.len(),
23//!         "bc" => 3,
24//!         _ => 4,
25//!     }
26//! };
27//!
28//! assert_eq!(result, 3);
29//! ```
30#![cfg_attr(
31    feature = "cfg_attribute",
32    doc = r#"
33## `cfg` attribute
34
35Only when using Nightly Rust, this macro supports conditional compilation with
36the `cfg` attribute. To use this feature, enable `features = ["cfg_attribute"]`
37in your `Cargo.toml`.
38
39### Example
40
41```
42use trie_match::trie_match;
43
44let x = "abd";
45
46let result = trie_match! {
47    match x {
48        #[cfg(not(feature = "foo"))]
49        "a" => 0,
50        "abc" => 1,
51        #[cfg(feature = "bar")]
52        "abd" | "bcc" => 2,
53        "bc" => 3,
54        _ => 4,
55    }
56};
57
58assert_eq!(result, 4);
59```
60"#
61)]
62//!
63//! ## Limitations
64//!
65//! The following points differ from the normal `match` expression:
66//!
67//! * Only supports strings, byte strings, and u8 slices as patterns.
68//! * The wildcard is evaluated last. (The normal `match` expression does not
69//!   match patterns after the wildcard.)
70//! * Guards are unavailable.
71
72mod trie;
73
74extern crate proc_macro;
75
76use std::collections::HashMap;
77
78use proc_macro2::{Span, TokenStream};
79use quote::{format_ident, quote};
80use syn::{
81    parse_macro_input, spanned::Spanned, Arm, Error, Expr, ExprLit, ExprMatch, Lit, Pat, PatIdent,
82    PatOr, PatReference, PatSlice, PatWild,
83};
84
85#[cfg(feature = "cfg_attribute")]
86use proc_macro2::Ident;
87#[cfg(feature = "cfg_attribute")]
88use syn::{Attribute, Meta};
89
90use crate::trie::Sparse;
91
/// Reported for a pattern that is neither a string literal, a byte string literal, nor a slice.
static ERROR_UNEXPECTED_PATTERN: &str =
    "`trie_match` only supports string literals, byte string literals, and u8 slices as patterns";

/// Reported when an attribute is attached to a pattern.
static ERROR_ATTRIBUTE_NOT_SUPPORTED: &str = "attribute not supported here";

/// Reported when a match arm carries an `if` guard.
static ERROR_GUARD_NOT_SUPPORTED: &str = "match guard not supported";

/// Reported when a byte pattern or a wildcard appears more than once.
static ERROR_UNREACHABLE_PATTERN: &str = "unreachable pattern";

/// Reported when no arm provides a wildcard.
static ERROR_PATTERN_NOT_COVERED: &str = "non-exhaustive patterns: `_` not covered";

/// Reported when a slice element is not a `u8` integer or byte literal.
static ERROR_EXPECTED_U8_LITERAL: &str = "expected `u8` integer literal";

/// Reported when the alternatives of one arm do not bind the same identifier.
static ERROR_VARIABLE_NOT_MATCH: &str = "variable is not bound in all patterns";

/// Reported for any attribute on a match arm while the `cfg_attribute` feature is disabled.
#[cfg(not(feature = "cfg_attribute"))]
static ERROR_ATTRIBUTE_NOT_SUPPORTED_CFG: &str = concat!(
    "attribute not supported here\n",
    "note: consider enabling the `cfg_attribute` feature: ",
    "https://docs.rs/trie-match/latest/trie_match/#cfg-attribute",
);

/// Reported for an arm attribute that is not of the form `#[cfg(...)]`.
#[cfg(feature = "cfg_attribute")]
static ERROR_NOT_CFG_ATTRIBUTE: &str = "only supports the cfg attribute";
108
109/// Converts a literal pattern into a byte sequence.
110fn convert_literal_pattern(pat: &ExprLit) -> Result<Option<Vec<u8>>, Error> {
111    let ExprLit { attrs, lit } = pat;
112    if let Some(attr) = attrs.first() {
113        return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED));
114    }
115    match lit {
116        Lit::Str(s) => Ok(Some(s.value().into())),
117        Lit::ByteStr(s) => Ok(Some(s.value())),
118        _ => Err(Error::new(lit.span(), ERROR_UNEXPECTED_PATTERN)),
119    }
120}
121
122/// Converts a slice pattern into a byte sequence.
123fn convert_slice_pattern(pat: &PatSlice) -> Result<Option<Vec<u8>>, Error> {
124    let PatSlice { attrs, elems, .. } = pat;
125    if let Some(attr) = attrs.first() {
126        return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED));
127    }
128    let mut result = vec![];
129    for elem in elems {
130        match elem {
131            Pat::Lit(ExprLit { attrs, lit }) => {
132                if let Some(attr) = attrs.first() {
133                    return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED));
134                }
135                match lit {
136                    Lit::Int(i) => {
137                        let int_type = i.suffix();
138                        if int_type != "u8" && !int_type.is_empty() {
139                            return Err(Error::new(i.span(), ERROR_EXPECTED_U8_LITERAL));
140                        }
141                        result.push(i.base10_parse::<u8>()?);
142                    }
143                    Lit::Byte(b) => {
144                        result.push(b.value());
145                    }
146                    _ => {
147                        return Err(Error::new(elem.span(), ERROR_EXPECTED_U8_LITERAL));
148                    }
149                }
150            }
151            _ => {
152                return Err(Error::new(elem.span(), ERROR_EXPECTED_U8_LITERAL));
153            }
154        }
155    }
156    Ok(Some(result))
157}
158
159/// Checks a wildcard pattern and returns `None`.
160///
161/// The reason the type is `Result<Option<Vec<u8>>, Error>` instead of `Result<(), Error>` is for
162/// consistency with other functions.
163fn convert_wildcard_pattern(pat: &PatWild) -> Result<Option<Vec<u8>>, Error> {
164    let PatWild { attrs, .. } = pat;
165    if let Some(attr) = attrs.first() {
166        return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED));
167    }
168    Ok(None)
169}
170
171/// Converts a reference pattern (e.g. `&[0, 1, ...]`) into a byte sequence.
172fn convert_reference_pattern(pat: &PatReference) -> Result<Option<Vec<u8>>, Error> {
173    let PatReference { attrs, pat, .. } = pat;
174    if let Some(attr) = attrs.first() {
175        return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED));
176    }
177    match &**pat {
178        Pat::Lit(pat) => convert_literal_pattern(pat),
179        Pat::Slice(pat) => convert_slice_pattern(pat),
180        Pat::Reference(pat) => convert_reference_pattern(pat),
181        _ => Err(Error::new(pat.span(), ERROR_UNEXPECTED_PATTERN)),
182    }
183}
184
185struct PatternBytes {
186    /// Bound variable identifier.
187    ident: Option<PatIdent>,
188
189    /// Byte sequence of this pattern. `None` is for a wildcard.
190    bytes: Option<Vec<u8>>,
191}
192
193impl PatternBytes {
194    const fn new(ident: Option<PatIdent>, bytes: Option<Vec<u8>>) -> Self {
195        Self { ident, bytes }
196    }
197}
198
199/// Retrieves pattern strings from the given token.
200///
201/// None indicates a wild card pattern (`_`).
202fn retrieve_match_patterns(
203    pat: &Pat,
204    ident: Option<PatIdent>,
205    pat_bytes_set: &mut Vec<PatternBytes>,
206    pat_set: &mut Vec<Pat>,
207) -> Result<(), Error> {
208    match pat {
209        Pat::Lit(lit) => {
210            pat_set.push(pat.clone());
211            pat_bytes_set.push(PatternBytes::new(ident, convert_literal_pattern(lit)?));
212        }
213        Pat::Slice(slice) => {
214            pat_set.push(pat.clone());
215            pat_bytes_set.push(PatternBytes::new(ident, convert_slice_pattern(slice)?));
216        }
217        Pat::Wild(pat) => {
218            pat_bytes_set.push(PatternBytes::new(ident, convert_wildcard_pattern(pat)?));
219        }
220        Pat::Reference(reference) => {
221            pat_set.push(pat.clone());
222            pat_bytes_set.push(PatternBytes::new(
223                ident,
224                convert_reference_pattern(reference)?,
225            ));
226        }
227        Pat::Ident(pat) => {
228            if let Some(attr) = pat.attrs.first() {
229                return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED));
230            }
231            let mut pat = pat.clone();
232            if let Some((_, subpat)) = pat.subpat.take() {
233                retrieve_match_patterns(&subpat, Some(pat), pat_bytes_set, pat_set)?;
234            } else {
235                pat_bytes_set.push(PatternBytes::new(Some(pat), None));
236            }
237        }
238        Pat::Paren(pat) => {
239            retrieve_match_patterns(&pat.pat, ident, pat_bytes_set, pat_set)?;
240        }
241        Pat::Or(PatOr {
242            attrs,
243            leading_vert: None,
244            cases,
245        }) => {
246            if let Some(attr) = attrs.first() {
247                return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED));
248            }
249            for pat in cases {
250                retrieve_match_patterns(pat, ident.clone(), pat_bytes_set, pat_set)?;
251            }
252        }
253        _ => {
254            return Err(Error::new(pat.span(), ERROR_UNEXPECTED_PATTERN));
255        }
256    }
257    Ok(())
258}
259
/// Decides whether a match arm guarded by `#[cfg(...)]` attributes is enabled.
///
/// Each predicate is wrapped in a `cfg!(...)` invocation and expanded through the compiler.
/// Returns `Ok(false)` as soon as one expansion prints as `false`, and `Ok(true)` if none
/// does (including when `attrs` is empty).
///
/// # Errors
///
/// Fails if an attribute is not a list-style `cfg` attribute, or if expanding the `cfg!`
/// invocation fails.
#[cfg(feature = "cfg_attribute")]
fn evaluate_cfg_attribute(attrs: &[Attribute]) -> Result<bool, Error> {
    for attr in attrs {
        let name = attr.path().get_ident().map(Ident::to_string);
        let list = match &attr.meta {
            Meta::List(list) if name.as_deref() == Some("cfg") => list,
            _ => return Err(Error::new(attr.span(), ERROR_NOT_CFG_ATTRIBUTE)),
        };
        let predicate = &list.tokens;
        let invocation: proc_macro::TokenStream = quote! { cfg!(#predicate) }.into();
        let expanded = invocation
            .expand_expr()
            .map_err(|e| Error::new(predicate.span(), e.to_string()))?;
        if expanded.to_string() == "false" {
            return Ok(false);
        }
    }
    Ok(true)
}
281
282struct MatchInfo {
283    bodies: Vec<Expr>,
284    pattern_map: HashMap<Vec<u8>, usize>,
285    wildcard_idx: usize,
286    bound_vals: Vec<Option<PatIdent>>,
287    pat_set: Vec<Pat>,
288}
289
290fn parse_match_arms(arms: Vec<Arm>) -> Result<MatchInfo, Error> {
291    let mut pattern_map = HashMap::new();
292    let mut wildcard_idx = None;
293    let mut bound_vals = vec![];
294    let mut bodies = vec![];
295    let mut pat_set = vec![];
296    let mut i = 0;
297    #[allow(clippy::explicit_counter_loop)]
298    for Arm {
299        attrs,
300        pat,
301        guard,
302        body,
303        ..
304    } in arms
305    {
306        #[cfg(feature = "cfg_attribute")]
307        if !evaluate_cfg_attribute(&attrs)? {
308            continue;
309        }
310        #[cfg(not(feature = "cfg_attribute"))]
311        if let Some(attr) = attrs.first() {
312            return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED_CFG));
313        }
314
315        if let Some((if_token, _)) = guard {
316            return Err(Error::new(if_token.span(), ERROR_GUARD_NOT_SUPPORTED));
317        }
318        let mut pat_bytes_set = vec![];
319        retrieve_match_patterns(&pat, None, &mut pat_bytes_set, &mut pat_set)?;
320        let bound_val = pat_bytes_set[0].ident.clone();
321        for PatternBytes { ident, bytes } in pat_bytes_set {
322            if ident != bound_val {
323                return Err(Error::new(
324                    ident.or(bound_val).unwrap().span(),
325                    ERROR_VARIABLE_NOT_MATCH,
326                ));
327            }
328            if let Some(bytes) = bytes {
329                if pattern_map.contains_key(&bytes) {
330                    return Err(Error::new(pat.span(), ERROR_UNREACHABLE_PATTERN));
331                }
332                pattern_map.insert(bytes, i);
333            } else {
334                if wildcard_idx.is_some() {
335                    return Err(Error::new(pat.span(), ERROR_UNREACHABLE_PATTERN));
336                }
337                wildcard_idx.replace(i);
338            }
339        }
340        bound_vals.push(bound_val);
341        bodies.push(*body);
342        i += 1;
343    }
344    let Some(wildcard_idx) = wildcard_idx else {
345        return Err(Error::new(Span::call_site(), ERROR_PATTERN_NOT_COVERED));
346    };
347    Ok(MatchInfo {
348        bodies,
349        pattern_map,
350        wildcard_idx,
351        bound_vals,
352        pat_set,
353    })
354}
355
356fn trie_match_inner(input: ExprMatch) -> Result<TokenStream, Error> {
357    let ExprMatch {
358        attrs, expr, arms, ..
359    } = input;
360    let MatchInfo {
361        bodies,
362        pattern_map,
363        wildcard_idx,
364        bound_vals,
365        pat_set,
366    } = parse_match_arms(arms)?;
367    let mut trie = Sparse::new();
368    for (k, v) in pattern_map {
369        if v == wildcard_idx {
370            continue;
371        }
372        trie.add(k, v);
373    }
374    let (bases, checks, outs) = trie.build_double_array_trie(wildcard_idx);
375
376    let out_check = outs.iter().zip(checks).map(|(out, check)| {
377        let out = format_ident!("V{out}");
378        quote! { (__TrieMatchValue::#out, #check) }
379    });
380    let arm = bodies
381        .iter()
382        .zip(bound_vals)
383        .enumerate()
384        .map(|(i, (body, bound_val))| {
385            let i = format_ident!("V{i}");
386            let bound_val = bound_val.map_or_else(|| quote! { _ }, |val| quote! { #val });
387            quote! { (__TrieMatchValue::#i, #bound_val ) => #body }
388        });
389    let enumvalue = (0..bodies.len()).map(|i| format_ident!("V{i}"));
390    let wildcard_ident = format_ident!("V{wildcard_idx}");
391    Ok(quote! {
392        {
393            #[derive(Clone, Copy)]
394            enum __TrieMatchValue {
395                #( #enumvalue, )*
396            }
397            #( #attrs )*
398            match #expr {
399                // This is for type inference.
400                query @ ( #( #pat_set | )* _) => {
401                    match (|query| unsafe {
402                        let query_ref = ::core::convert::AsRef::<[u8]>::as_ref(&query);
403                        let bases: &'static [i32] = &[ #( #bases, )* ];
404                        let out_checks: &'static [(__TrieMatchValue, u8)] = &[ #( #out_check, )* ];
405                        let mut pos = 0;
406                        let mut base = bases[0];
407                        for &b in query_ref {
408                            pos = base.wrapping_add(i32::from(b)) as usize;
409                            if let Some((_, check)) = out_checks.get(pos) {
410                                if *check == b {
411                                    base = *bases.get_unchecked(pos);
412                                    continue;
413                                }
414                            }
415                            return (__TrieMatchValue::#wildcard_ident, query);
416                        }
417                        (out_checks.get_unchecked(pos).0, query)
418                    })(query) {
419                        #( #arm, )*
420                    }
421                }
422            }
423        }
424    })
425}
426
427/// Generates a match expression that uses a trie structure.
428///
429/// # Examples
430///
431/// ```
432/// use trie_match::trie_match;
433///
434/// let x = "abd";
435///
436/// let result = trie_match! {
437///     match x {
438///         "a" => 0,
439///         "abc" => 1,
440///         pat @ ("abd" | "bcde") => pat.len(),
441///         "bc" => 3,
442///         _ => 4,
443///     }
444/// };
445///
446/// assert_eq!(result, 3);
447/// ```
448#[proc_macro]
449pub fn trie_match(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
450    let input = parse_macro_input!(input as ExprMatch);
451    trie_match_inner(input)
452        .unwrap_or_else(Error::into_compile_error)
453        .into()
454}