1use proc_macro2::TokenStream;
16use quote::quote;
17use regex_automata::{dense, DFA};
18use std::collections::{BTreeSet, HashMap};
19use std::ops::RangeInclusive;
20use syn::{parse::*, *};
21
22type RegexDfa = dense::Standard<Vec<usize>, usize>;
23
/// Classification of a single DFA state as seen by the code generator.
#[derive(Debug, Clone, PartialEq, Eq)]
enum State {
    /// The state is a match state of the DFA.
    Match,
    /// The state is the DFA's dead state.
    Dead,
    /// Any other state: maps each target state id to the set of input bytes
    /// that move the automaton to that target.
    Transitions(HashMap<usize, BTreeSet<u8>>),
}
30
31fn range_to_tokens(range: RangeInclusive<u8>) -> TokenStream {
32 let (start, end) = range.into_inner();
33 if start == end {
34 quote!(#start)
35 } else {
36 quote!(#start..=#end)
37 }
38}
39
40impl State {
41 fn from_regex(regex: &RegexDfa, state: usize) -> Self {
42 if regex.is_match_state(state) {
43 Self::Match
44 } else if regex.is_dead_state(state) {
45 Self::Dead
46 } else {
47 let mut transitions = HashMap::new();
48
49 for byte in 0..=255 {
50 let next = regex.next_state(state, byte);
51 transitions
52 .entry(next)
53 .or_insert_with(BTreeSet::new)
54 .insert(byte);
55 }
56
57 Self::Transitions(transitions)
58 }
59 }
60
61 fn handle(&self, byte: &Ident, states: &HashMap<usize, State>) -> Expr {
62 match self {
63 Self::Match => parse_quote!(return true),
64 Self::Dead => parse_quote!(return false),
65 Self::Transitions(transitions) => {
66 let branches = transitions.iter().map(|(target, bytes)| {
67 let mut ranges = vec![];
68 let mut range: Option<RangeInclusive<u8>> = None;
69 for &byte in bytes {
70 if let Some(range) = &mut range {
71 if *range.end() == byte - 1 {
72 *range = *range.start()..=byte;
73 continue;
74 } else {
75 ranges.push(range_to_tokens(range.clone()));
76 }
77 }
78 range = Some(byte..=byte);
79 }
80
81 if let Some(range) = range {
82 ranges.push(range_to_tokens(range));
83 }
84
85 let handler = match states[target] {
86 Self::Match => quote!(return true),
87 Self::Dead => quote!(return false),
88 _ => quote!(#target),
89 };
90
91 quote!(#(#ranges)|* => #handler)
92 });
93
94 parse_quote! {
95 match #byte {
96 #(#branches),*
97 }
98 }
99 }
100 }
101 }
102}
103
104struct Dfa {
105 start: usize,
106 states: HashMap<usize, State>,
107}
108
109impl Dfa {
110 fn add_states(&mut self, regex: &RegexDfa, id: usize) {
111 let state = State::from_regex(regex, id);
112
113 self.states.insert(id, state.clone());
114
115 if let State::Transitions(transitions) = &state {
116 for target in transitions.keys() {
117 if !self.states.contains_key(target) {
118 self.add_states(regex, *target);
119 }
120 }
121 }
122 }
123
124 fn from_regex(regex: &RegexDfa) -> Self {
125 let start = regex.start_state();
126 let mut dfa = Self {
127 start,
128 states: HashMap::new(),
129 };
130
131 dfa.add_states(regex, start);
132
133 dfa
134 }
135
136 fn handle(&self, input: &Ident) -> Expr {
137 let byte = parse_quote!(byte);
138 let start = self.start;
139
140 let branches = self.states.iter().map(|(id, state)| {
141 let body = state.handle(&byte, &self.states);
142 quote!(#id => #body)
143 });
144
145 parse_quote! {{
146 let mut i = 0;
147 let mut state = #start;
148
149 while i < #input.len() {
150 let #byte = #input[i];
151
152 state = match state {
153 #(#branches,)*
154 #[allow(unconditional_panic)]
155 _ => [][0],
156 };
157
158 i += 1;
159 }
160
161 return false;
162 }}
163 }
164}
165
166fn build_dfa(regex: &str) -> RegexDfa {
167 let (regex, anchored) = if let Some(regex) = regex.strip_prefix('^') {
168 (regex, true)
169 } else {
170 (regex, false)
171 };
172
173 let dfa = dense::Builder::new()
174 .byte_classes(false)
175 .premultiply(false)
176 .minimize(true)
177 .anchored(anchored)
178 .build(regex)
179 .unwrap();
180
181 if let dense::DenseDFA::Standard(dfa) = dfa {
182 dfa
183 } else {
184 unreachable!()
185 }
186}
187
188struct Args {
189 regex: String,
190 expr: Expr,
191}
192
193impl Parse for Args {
194 fn parse(input: ParseStream) -> Result<Self> {
195 let regex_lit: LitStr = input.parse()?;
196 let _comma_token: Token![,] = input.parse()?;
197 let expr = input.parse()?;
198
199 Ok(Self {
200 regex: regex_lit.value(),
201 expr,
202 })
203 }
204}
205
206#[proc_macro]
208pub fn match_regex(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
209 let args = parse_macro_input!(input as Args);
210 let regex = build_dfa(&args.regex);
211 let dfa = Dfa::from_regex(®ex);
212 let input_token = parse_quote!(input);
213 let block = dfa.handle(&input_token);
214 let input_expr = args.expr;
215
216 let tokens = quote! {{
217 const fn match_regex(#input_token: &[u8]) -> bool {
218 #block
219 }
220
221 match_regex(#input_expr)
222 }};
223
224 tokens.into()
225}