const_regex/
lib.rs

//! Proc macro to match regexes in const fns. The regex must be a string literal, but the bytes
//! to match can be any `&[u8]` expression, including one only known at runtime.
//!
//! The macro expects a `&[u8]`; to match a `&str`, pass it through `str::as_bytes`. Unless the
//! regex starts with `^`, it may match anywhere in the input.
//!
//! ```
//! const fn this_crate(bytes: &[u8]) -> bool {
//!     const_regex::match_regex!("^(meta-)*regex matching", bytes)
//! }
//!
//! assert!(this_crate(b"meta-meta-regex matching"));
//! assert!(!this_crate(b"a good idea"));
//! ```
14
15use proc_macro2::TokenStream;
16use quote::quote;
17use regex_automata::{dense, DFA};
18use std::collections::{BTreeSet, HashMap};
19use std::ops::RangeInclusive;
20use syn::{parse::*, *};
21
/// The concrete DFA type produced by `build_dfa`: the dense "standard" representation
/// (byte classes and premultiplication disabled) with `usize` state ids.
type RegexDfa = dense::Standard<Vec<usize>, usize>;
23
/// A single DFA state, reduced to what the code generator needs to know about it.
#[derive(PartialEq, Clone)]
enum State {
    /// Reaching this state means the input matches; the generated code returns `true`.
    Match,
    /// Reaching this state means the input can never match; the generated code returns `false`.
    Dead,
    /// Any other state: for each successor state id, the set of input bytes that lead to it.
    Transitions(HashMap<usize, BTreeSet<u8>>),
}
30
31fn range_to_tokens(range: RangeInclusive<u8>) -> TokenStream {
32    let (start, end) = range.into_inner();
33    if start == end {
34        quote!(#start)
35    } else {
36        quote!(#start..=#end)
37    }
38}
39
40impl State {
41    fn from_regex(regex: &RegexDfa, state: usize) -> Self {
42        if regex.is_match_state(state) {
43            Self::Match
44        } else if regex.is_dead_state(state) {
45            Self::Dead
46        } else {
47            let mut transitions = HashMap::new();
48
49            for byte in 0..=255 {
50                let next = regex.next_state(state, byte);
51                transitions
52                    .entry(next)
53                    .or_insert_with(BTreeSet::new)
54                    .insert(byte);
55            }
56
57            Self::Transitions(transitions)
58        }
59    }
60
61    fn handle(&self, byte: &Ident, states: &HashMap<usize, State>) -> Expr {
62        match self {
63            Self::Match => parse_quote!(return true),
64            Self::Dead => parse_quote!(return false),
65            Self::Transitions(transitions) => {
66                let branches = transitions.iter().map(|(target, bytes)| {
67                    let mut ranges = vec![];
68                    let mut range: Option<RangeInclusive<u8>> = None;
69                    for &byte in bytes {
70                        if let Some(range) = &mut range {
71                            if *range.end() == byte - 1 {
72                                *range = *range.start()..=byte;
73                                continue;
74                            } else {
75                                ranges.push(range_to_tokens(range.clone()));
76                            }
77                        }
78                        range = Some(byte..=byte);
79                    }
80
81                    if let Some(range) = range {
82                        ranges.push(range_to_tokens(range));
83                    }
84
85                    let handler = match states[target] {
86                        Self::Match => quote!(return true),
87                        Self::Dead => quote!(return false),
88                        _ => quote!(#target),
89                    };
90
91                    quote!(#(#ranges)|* => #handler)
92                });
93
94                parse_quote! {
95                    match #byte {
96                        #(#branches),*
97                    }
98                }
99            }
100        }
101    }
102}
103
104struct Dfa {
105    start: usize,
106    states: HashMap<usize, State>,
107}
108
109impl Dfa {
110    fn add_states(&mut self, regex: &RegexDfa, id: usize) {
111        let state = State::from_regex(regex, id);
112
113        self.states.insert(id, state.clone());
114
115        if let State::Transitions(transitions) = &state {
116            for target in transitions.keys() {
117                if !self.states.contains_key(target) {
118                    self.add_states(regex, *target);
119                }
120            }
121        }
122    }
123
124    fn from_regex(regex: &RegexDfa) -> Self {
125        let start = regex.start_state();
126        let mut dfa = Self {
127            start,
128            states: HashMap::new(),
129        };
130
131        dfa.add_states(regex, start);
132
133        dfa
134    }
135
136    fn handle(&self, input: &Ident) -> Expr {
137        let byte = parse_quote!(byte);
138        let start = self.start;
139
140        let branches = self.states.iter().map(|(id, state)| {
141            let body = state.handle(&byte, &self.states);
142            quote!(#id => #body)
143        });
144
145        parse_quote! {{
146            let mut i = 0;
147            let mut state = #start;
148
149            while i < #input.len() {
150                let #byte = #input[i];
151
152                state = match state {
153                    #(#branches,)*
154                    #[allow(unconditional_panic)]
155                    _ => [][0],
156                };
157
158                i += 1;
159            }
160
161            return false;
162        }}
163    }
164}
165
166fn build_dfa(regex: &str) -> RegexDfa {
167    let (regex, anchored) = if let Some(regex) = regex.strip_prefix('^') {
168        (regex, true)
169    } else {
170        (regex, false)
171    };
172
173    let dfa = dense::Builder::new()
174        .byte_classes(false)
175        .premultiply(false)
176        .minimize(true)
177        .anchored(anchored)
178        .build(regex)
179        .unwrap();
180
181    if let dense::DenseDFA::Standard(dfa) = dfa {
182        dfa
183    } else {
184        unreachable!()
185    }
186}
187
188struct Args {
189    regex: String,
190    expr: Expr,
191}
192
193impl Parse for Args {
194    fn parse(input: ParseStream) -> Result<Self> {
195        let regex_lit: LitStr = input.parse()?;
196        let _comma_token: Token![,] = input.parse()?;
197        let expr = input.parse()?;
198
199        Ok(Self {
200            regex: regex_lit.value(),
201            expr,
202        })
203    }
204}
205
206/// See crate documentation.
207#[proc_macro]
208pub fn match_regex(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
209    let args = parse_macro_input!(input as Args);
210    let regex = build_dfa(&args.regex);
211    let dfa = Dfa::from_regex(&regex);
212    let input_token = parse_quote!(input);
213    let block = dfa.handle(&input_token);
214    let input_expr = args.expr;
215
216    let tokens = quote! {{
217        const fn match_regex(#input_token: &[u8]) -> bool {
218            #block
219        }
220
221        match_regex(#input_expr)
222    }};
223
224    tokens.into()
225}