1use std::{
2 collections::{HashMap, HashSet, VecDeque},
3 fmt::{self, Display, Formatter},
4};
5
6use proc_macro::TokenStream;
7use quote::{quote, ToTokens};
8use thiserror::Error;
9
10#[derive(Debug)]
11struct Ast(Union);
12
13#[derive(Debug)]
14enum Union {
15 Union(Box<Union>, Box<Concat>),
16 Concat(Box<Concat>),
17}
18
19#[derive(Debug)]
20enum Concat {
21 Concat(Box<Concat>, Box<Star>),
22 Star(Box<Star>),
23}
24
25#[allow(clippy::enum_variant_names)]
26#[derive(Debug)]
27enum Star {
28 Star(Box<Terminal>),
29 Optional(Box<Terminal>),
30 Terminal(Box<Terminal>),
31}
32
33#[derive(Debug)]
34enum Terminal {
35 AnyChar,
36 Char(char),
37 Group(Box<Ast>),
38}
39
40#[derive(Debug, Error)]
41enum ParseError {
42 #[error("unexpected character: {0}")]
43 UnexpectedChar(char),
44 #[error("unexpected end of input")]
45 UnexpectedEnd,
46}
47
/// Cursor over the not-yet-consumed tail of a pattern string.
struct Ctx<'a>(&'a str);

impl<'a> Ctx<'a> {
    /// Creates a cursor positioned at the start of `s`.
    fn new(s: &'a str) -> Self {
        Self(s)
    }

    /// Returns the next character that is not a space (`' '`) without consuming anything.
    fn peek_skip_whitespace(&self) -> Option<char> {
        self.0.chars().find(|&c| c != ' ')
    }

    /// Consumes any leading spaces together with the character that follows them and
    /// returns that character.
    ///
    /// Returns `None`, leaving the cursor untouched, when nothing but spaces remains.
    fn next_skip_whitespace(&mut self) -> Option<char> {
        let text = self.0;
        let (offset, c) = text.char_indices().find(|&(_, c)| c != ' ')?;
        // Advance by the character's encoded width: slicing at `offset + 1` would land
        // inside a multi-byte character and panic.
        self.0 = &text[offset + c.len_utf8()..];
        Some(c)
    }

    /// Consumes and returns the very next character, spaces included.
    ///
    /// Returns `None` when the input is exhausted.
    fn next_with_whitespace(&mut self) -> Option<char> {
        let text = self.0;
        let mut rest = text.chars();
        let c = rest.next()?;
        // `as_str` yields the remainder starting on a valid character boundary.
        self.0 = rest.as_str();
        Some(c)
    }
}
78
79trait Parse {
80 fn parse(chars: &mut Ctx) -> Result<Self, ParseError>
81 where
82 Self: Sized;
83}
84
85impl Parse for Ast {
86 fn parse(chars: &mut Ctx) -> Result<Self, ParseError> {
87 Union::parse(chars).map(Ast)
88 }
89}
90
91impl Parse for Union {
92 fn parse(chars: &mut Ctx) -> Result<Self, ParseError> {
93 let mut left = Union::Concat(Box::new(Concat::parse(chars)?));
94 while let Some('+') = chars.peek_skip_whitespace() {
95 chars.next_skip_whitespace();
96 let right = Concat::parse(chars)?;
97 left = Union::Union(Box::new(left), Box::new(right));
98 }
99 Ok(left)
100 }
101}
102
103impl Parse for Concat {
104 fn parse(chars: &mut Ctx) -> Result<Self, ParseError> {
105 let mut left = Concat::Star(Box::new(Star::parse(chars)?));
106 while let Some(c) = chars.peek_skip_whitespace() {
107 if c == '+' {
108 break;
109 }
110 let right = Star::parse(chars)?;
111 left = Concat::Concat(Box::new(left), Box::new(right));
112 }
113 Ok(left)
114 }
115}
116
117impl Parse for Star {
118 fn parse(chars: &mut Ctx) -> Result<Self, ParseError> {
119 let left = Terminal::parse(chars)?;
120 match chars.peek_skip_whitespace() {
121 Some('*') => {
122 chars.next_skip_whitespace();
123 Ok(Star::Star(Box::new(left)))
124 },
125 Some('?') => {
126 chars.next_skip_whitespace();
127 Ok(Star::Optional(Box::new(left)))
128 },
129 _ => Ok(Star::Terminal(Box::new(left))),
130 }
131 }
132}
133
134impl Parse for Terminal {
135 fn parse(chars: &mut Ctx) -> Result<Self, ParseError> {
136 match chars.next_skip_whitespace() {
137 Some('.') => Ok(Terminal::AnyChar),
138 Some('(') => {
139 let ast = Ast::parse(chars)?;
140 match chars.next_skip_whitespace() {
141 Some(')') => Ok(Terminal::Group(Box::new(ast))),
142 Some(c) => Err(ParseError::UnexpectedChar(c)),
143 None => Err(ParseError::UnexpectedEnd),
144 }
145 },
146 Some('\\') => {
147 let c = match chars.next_with_whitespace() {
148 Some(c) => c,
149 None => return Err(ParseError::UnexpectedEnd),
150 };
151 Ok(Terminal::Char(c))
152 },
153 Some(c) => Ok(Terminal::Char(c)),
154 None => Err(ParseError::UnexpectedEnd),
155 }
156 }
157}
158
159impl Display for Ast {
160 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
161 write!(f, "{}", self.0)
162 }
163}
164
165impl Display for Union {
166 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
167 match self {
168 Union::Union(left, right) => write!(f, "({}+{})", left, right),
169 Union::Concat(concat) => write!(f, "{}", concat),
170 }
171 }
172}
173
174impl Display for Concat {
175 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
176 match self {
177 Concat::Concat(left, right) => write!(f, "({}{})", left, right),
178 Concat::Star(star) => write!(f, "{}", star),
179 }
180 }
181}
182
183impl Display for Star {
184 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
185 match self {
186 Star::Star(optional) => write!(f, "({}*)", optional),
187 Star::Optional(optional) => write!(f, "({}?)", optional),
188 Star::Terminal(optional) => write!(f, "{}", optional),
189 }
190 }
191}
192
193impl Display for Terminal {
194 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
195 match self {
196 Terminal::AnyChar => write!(f, "."),
197 Terminal::Char(c) => write!(f, "{}", c),
198 Terminal::Group(ast) => write!(f, "({})", ast),
199 }
200 }
201}
202
/// Label on an NFA edge.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum NfaTransitions {
    /// Taken without consuming input.
    Epsilon,
    /// Produced by `.` in a pattern.
    AnyChar,
    /// Produced by a literal character in a pattern.
    Char(char),
}
209
210impl NfaTransitions {
211 fn to_dfa(self) -> DfaTransitions {
212 match self {
213 NfaTransitions::AnyChar => DfaTransitions::AnyChar,
214 NfaTransitions::Char(c) => DfaTransitions::Char(c),
215 _ => unreachable!(),
216 }
217 }
218}
219
220#[derive(Debug)]
221struct Nfa {
222 start: usize,
223 accept: usize,
224 transitions: Vec<HashSet<(NfaTransitions, usize)>>,
225}
226
227impl Nfa {
228 fn new() -> Self {
229 Self {
230 start: 1,
231 accept: 0,
232 transitions: vec![HashSet::new(), HashSet::new()],
233 }
234 }
235
236 fn new_state(&mut self) -> usize {
237 let state = self.transitions.len();
238 self.transitions.push(HashSet::new());
239 state
240 }
241
242 fn add_transition(&mut self, from: usize, to: usize, epsilon: NfaTransitions) {
243 self.transitions[from].insert((epsilon, to));
244 }
245
246 fn add_epsilon_transition(&mut self, from: usize, to: usize) {
247 self.add_transition(from, to, NfaTransitions::Epsilon);
248 }
249
250 fn epsilon_closure(&self, state: usize) -> HashSet<usize> {
251 let mut closure = HashSet::new();
252 let mut stack = VecDeque::new();
253 stack.push_back(state);
254 while let Some(state) = stack.pop_front() {
255 if closure.contains(&state) {
256 continue;
257 }
258 closure.insert(state);
259 for (transition, next) in &self.transitions[state] {
260 if *transition == NfaTransitions::Epsilon {
261 stack.push_back(*next);
262 }
263 }
264 }
265 closure
266 }
267
268 fn to_dfa(&self) -> Dfa {
269 Dfa::product_construction(self)
270 }
271}
272
273trait ToNfa {
274 fn add_to_nfa(&self, nfa: &mut Nfa, from: usize, to: usize);
275 fn as_nfa(&self) -> Nfa {
276 let mut nfa = Nfa::new();
277 let start = nfa.start;
278 let accept = nfa.accept;
279 self.add_to_nfa(&mut nfa, start, accept);
280 nfa
281 }
282}
283
284impl ToNfa for Ast {
285 fn add_to_nfa(&self, nfa: &mut Nfa, from: usize, to: usize) {
286 self.0.add_to_nfa(nfa, from, to);
287 }
288}
289
290impl ToNfa for Union {
291 fn add_to_nfa(&self, nfa: &mut Nfa, from: usize, to: usize) {
292 match self {
293 Union::Union(left, right) => {
294 left.add_to_nfa(nfa, from, to);
295 right.add_to_nfa(nfa, from, to);
296 },
297 Union::Concat(concat) => concat.add_to_nfa(nfa, from, to),
298 }
299 }
300}
301
302impl ToNfa for Concat {
303 fn add_to_nfa(&self, nfa: &mut Nfa, from: usize, to: usize) {
304 match self {
305 Concat::Concat(left, right) => {
306 let mid = nfa.new_state();
307 left.add_to_nfa(nfa, from, mid);
308 right.add_to_nfa(nfa, mid, to);
309 },
310 Concat::Star(star) => star.add_to_nfa(nfa, from, to),
311 }
312 }
313}
314
315impl ToNfa for Star {
316 fn add_to_nfa(&self, nfa: &mut Nfa, from: usize, to: usize) {
317 match self {
318 Star::Star(optional) => {
319 let mid = nfa.new_state();
320 nfa.add_epsilon_transition(from, mid);
321 nfa.add_epsilon_transition(mid, to);
322 optional.add_to_nfa(nfa, mid, mid);
323 },
324 Star::Optional(ast) => {
325 ast.add_to_nfa(nfa, from, to);
326 nfa.add_epsilon_transition(from, to);
327 },
328 Star::Terminal(optional) => optional.add_to_nfa(nfa, from, to),
329 }
330 }
331}
332
333impl ToNfa for Terminal {
334 fn add_to_nfa(&self, nfa: &mut Nfa, from: usize, to: usize) {
335 match self {
336 Terminal::AnyChar => {
337 nfa.add_transition(from, to, NfaTransitions::AnyChar);
338 },
339 Terminal::Char(c) => nfa.add_transition(from, to, NfaTransitions::Char(*c)),
340 Terminal::Group(ast) => ast.add_to_nfa(nfa, from, to),
341 }
342 }
343}
344
/// Label on a DFA edge; unlike the NFA there is no epsilon variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum DfaTransitions {
    /// Wildcard edge.
    AnyChar,
    /// Edge taken on one specific character.
    Char(char),
}
351
352#[derive(Debug)]
353struct Dfa {
354 start: usize,
355 accept: usize,
356 accept_states: HashSet<usize>,
357 transitions: Vec<HashMap<DfaTransitions, usize>>,
358}
359
360impl Dfa {
361 fn new() -> Self {
362 let mut accept_states = HashSet::new();
363 accept_states.insert(0);
364 Self {
365 start: 1,
366 accept: 0,
367 accept_states,
368 transitions: vec![HashMap::new(), HashMap::new()],
369 }
370 }
371
372 fn new_state(&mut self) -> usize {
373 let state = self.transitions.len();
374 self.transitions.push(HashMap::new());
375 state
376 }
377
378 fn add_transition(&mut self, from: usize, to: usize, transition: DfaTransitions) {
379 self.transitions[from].insert(transition, to);
380 }
381
382 fn product_construction(nfa: &Nfa) -> Self {
383 let mut dfa = Dfa::new();
384 let initial_states = nfa.epsilon_closure(nfa.start);
385 let mut states = HashMap::new();
386 states.insert(dfa.start, initial_states);
387 states.insert(dfa.accept, HashSet::from_iter([nfa.accept]));
388 let mut queue = VecDeque::new();
389 let mut visited = HashSet::new();
390 queue.push_back(dfa.start);
391 while let Some(state) = queue.pop_front() {
392 if visited.contains(&state) {
393 continue;
394 }
395 visited.insert(state);
396 let mut transitions = HashMap::new();
397 let s = &states[&state];
398 if s.contains(&nfa.accept) {
399 dfa.accept_states.insert(state);
400 }
401 for state in s {
402 for transition in &nfa.transitions[*state] {
403 if transition.0 == NfaTransitions::Epsilon {
404 continue;
405 }
406 let next_states = transitions.entry(transition.0).or_insert_with(HashSet::new);
407 next_states.extend(nfa.epsilon_closure(transition.1));
408 }
409 }
410 for (transition, next_states) in transitions {
411 let next_state = 'a: {
412 for (state, set) in &states {
413 if set == &next_states {
414 break 'a *state;
415 }
416 }
417 let next_state = dfa.new_state();
418 states.insert(next_state, next_states);
419 next_state
420 };
421 dfa.add_transition(state, next_state, transition.to_dfa());
422 queue.push_back(next_state);
423 }
424 }
435 dfa
436 }
437}
438
439impl ToTokens for Dfa {
440 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
441 let start = self.start;
442 let char_transitions =
443 self.transitions
444 .iter()
445 .enumerate()
446 .flat_map(|(from, transitions)| {
447 transitions
448 .iter()
449 .filter(|i| *i.0 != DfaTransitions::AnyChar)
450 .map(move |(transition, to)| match transition {
451 DfaTransitions::Char(c) => {
452 quote! { (#from, #c) => #to, }
453 },
454 _ => unreachable!(),
455 })
456 });
457 let any_char_transitions =
458 self.transitions
459 .iter()
460 .enumerate()
461 .flat_map(|(from, transitions)| {
462 transitions
463 .iter()
464 .filter(|i| *i.0 == DfaTransitions::AnyChar)
465 .map(move |(transition, to)| match transition {
466 DfaTransitions::AnyChar => {
467 quote! { (#from, _) => #to, }
468 },
469 _ => unreachable!(),
470 })
471 });
472 let accept_states = self.accept_states.iter().collect::<Vec<_>>();
473 let accept_states = quote! { #(state == #accept_states)||* };
474 tokens.extend(quote! {
475 let mut state = #start;
476 while let Some(c) = chars.next() {
477 state = match (state, c) {
478 #(#char_transitions)*
479 #(#any_char_transitions)*
480 _ => return false,
481 };
482 }
483 #accept_states
484 });
485 }
486}
487
488struct Input {
489 name: syn::Ident,
490 value: syn::LitStr,
491}
492
493impl syn::parse::Parse for Input {
494 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
495 let name = input.parse()?;
496 input.parse::<syn::Token![=>]>()?;
497 let value = input.parse()?;
498 Ok(Self { name, value })
499 }
500}
501
502#[proc_macro]
522pub fn regex(input: TokenStream) -> TokenStream {
523 let input = syn::parse_macro_input!(input as Input);
524 let lit = input.value.value();
525 let mut chars = Ctx::new(&lit);
526 let ast = Ast::parse(&mut chars).unwrap();
527 let nfa = ast.as_nfa();
528 let dfa = nfa.to_dfa();
529 let name = input.name;
530 quote! {
531 struct #name;
532
533 impl #name {
534 fn matches(s: &str) -> bool {
535 let mut chars = s.chars();
536 #dfa
537 }
538 }
539 }
540 .into()
541}