dfa_regex/
lib.rs

1use std::{
2    collections::{HashMap, HashSet, VecDeque},
3    fmt::{self, Display, Formatter},
4};
5
6use proc_macro::TokenStream;
7use quote::{quote, ToTokens};
8use thiserror::Error;
9
/// Root of a parsed pattern: the top-level alternation.
#[derive(Debug)]
struct Ast(Union);

/// Alternation (`a+b`), as a left-leaning tree: either `Union(rest, last_alternative)` or a
/// single sequence.
#[derive(Debug)]
enum Union {
    Union(Box<Union>, Box<Concat>),
    Concat(Box<Concat>),
}

/// Sequencing (`ab`), as a left-leaning tree: either `Concat(rest, last_item)` or a single
/// postfix-operated atom.
#[derive(Debug)]
enum Concat {
    Concat(Box<Concat>, Box<Star>),
    Star(Box<Star>),
}

/// An atom with an optional postfix operator: `t*`, `t?`, or plain `t`.
// NOTE(review): the reason for this allow is not evident from the code; confirm it is
// still needed with the current clippy.
#[allow(clippy::enum_variant_names)]
#[derive(Debug)]
enum Star {
    /// `t*`: zero or more repetitions.
    Star(Box<Terminal>),
    /// `t?`: zero or one occurrence.
    Optional(Box<Terminal>),
    /// `t`: exactly one occurrence.
    Terminal(Box<Terminal>),
}

/// An atom: `.`, a literal character, or a parenthesised sub-pattern.
#[derive(Debug)]
enum Terminal {
    AnyChar,
    Char(char),
    Group(Box<Ast>),
}
39
40#[derive(Debug, Error)]
41enum ParseError {
42    #[error("unexpected character: {0}")]
43    UnexpectedChar(char),
44    #[error("unexpected end of input")]
45    UnexpectedEnd,
46}
47
/// Parsing cursor: the not-yet-consumed remainder of the pattern.
///
/// Only the ASCII space `' '` is treated as insignificant; every other character
/// (including tabs and newlines) is significant. All slicing is done on UTF-8 character
/// boundaries, so non-ASCII patterns are handled correctly.
struct Ctx<'a>(&'a str);

impl<'a> Ctx<'a> {
    /// Starts a cursor at the beginning of `s`.
    fn new(s: &'a str) -> Self {
        Self(s)
    }

    /// Returns the next non-space character without consuming anything.
    fn peek_skip_whitespace(&self) -> Option<char> {
        self.0.chars().find(|&c| c != ' ')
    }

    /// Consumes leading spaces plus the next non-space character, and returns that
    /// character. If only spaces (or nothing) remain, returns `None` and consumes nothing.
    fn next_skip_whitespace(&mut self) -> Option<char> {
        // `char_indices` yields *byte* offsets, which is what slicing needs; advancing by
        // `len_utf8` keeps the slice on a character boundary for multi-byte characters.
        let (i, c) = self.0.char_indices().find(|&(_, c)| c != ' ')?;
        self.0 = &self.0[i + c.len_utf8()..];
        Some(c)
    }

    /// Consumes and returns the very next character, even if it is a space.
    fn next_with_whitespace(&mut self) -> Option<char> {
        let c = self.0.chars().next()?;
        self.0 = &self.0[c.len_utf8()..];
        Some(c)
    }
}
78
79trait Parse {
80    fn parse(chars: &mut Ctx) -> Result<Self, ParseError>
81    where
82        Self: Sized;
83}
84
85impl Parse for Ast {
86    fn parse(chars: &mut Ctx) -> Result<Self, ParseError> {
87        Union::parse(chars).map(Ast)
88    }
89}
90
91impl Parse for Union {
92    fn parse(chars: &mut Ctx) -> Result<Self, ParseError> {
93        let mut left = Union::Concat(Box::new(Concat::parse(chars)?));
94        while let Some('+') = chars.peek_skip_whitespace() {
95            chars.next_skip_whitespace();
96            let right = Concat::parse(chars)?;
97            left = Union::Union(Box::new(left), Box::new(right));
98        }
99        Ok(left)
100    }
101}
102
103impl Parse for Concat {
104    fn parse(chars: &mut Ctx) -> Result<Self, ParseError> {
105        let mut left = Concat::Star(Box::new(Star::parse(chars)?));
106        while let Some(c) = chars.peek_skip_whitespace() {
107            if c == '+' {
108                break;
109            }
110            let right = Star::parse(chars)?;
111            left = Concat::Concat(Box::new(left), Box::new(right));
112        }
113        Ok(left)
114    }
115}
116
117impl Parse for Star {
118    fn parse(chars: &mut Ctx) -> Result<Self, ParseError> {
119        let left = Terminal::parse(chars)?;
120        match chars.peek_skip_whitespace() {
121            Some('*') => {
122                chars.next_skip_whitespace();
123                Ok(Star::Star(Box::new(left)))
124            },
125            Some('?') => {
126                chars.next_skip_whitespace();
127                Ok(Star::Optional(Box::new(left)))
128            },
129            _ => Ok(Star::Terminal(Box::new(left))),
130        }
131    }
132}
133
134impl Parse for Terminal {
135    fn parse(chars: &mut Ctx) -> Result<Self, ParseError> {
136        match chars.next_skip_whitespace() {
137            Some('.') => Ok(Terminal::AnyChar),
138            Some('(') => {
139                let ast = Ast::parse(chars)?;
140                match chars.next_skip_whitespace() {
141                    Some(')') => Ok(Terminal::Group(Box::new(ast))),
142                    Some(c) => Err(ParseError::UnexpectedChar(c)),
143                    None => Err(ParseError::UnexpectedEnd),
144                }
145            },
146            Some('\\') => {
147                let c = match chars.next_with_whitespace() {
148                    Some(c) => c,
149                    None => return Err(ParseError::UnexpectedEnd),
150                };
151                Ok(Terminal::Char(c))
152            },
153            Some(c) => Ok(Terminal::Char(c)),
154            None => Err(ParseError::UnexpectedEnd),
155        }
156    }
157}
158
159impl Display for Ast {
160    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
161        write!(f, "{}", self.0)
162    }
163}
164
165impl Display for Union {
166    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
167        match self {
168            Union::Union(left, right) => write!(f, "({}+{})", left, right),
169            Union::Concat(concat) => write!(f, "{}", concat),
170        }
171    }
172}
173
174impl Display for Concat {
175    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
176        match self {
177            Concat::Concat(left, right) => write!(f, "({}{})", left, right),
178            Concat::Star(star) => write!(f, "{}", star),
179        }
180    }
181}
182
183impl Display for Star {
184    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
185        match self {
186            Star::Star(optional) => write!(f, "({}*)", optional),
187            Star::Optional(optional) => write!(f, "({}?)", optional),
188            Star::Terminal(optional) => write!(f, "{}", optional),
189        }
190    }
191}
192
193impl Display for Terminal {
194    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
195        match self {
196            Terminal::AnyChar => write!(f, "."),
197            Terminal::Char(c) => write!(f, "{}", c),
198            Terminal::Group(ast) => write!(f, "({})", ast),
199        }
200    }
201}
202
203#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
204enum NfaTransitions {
205    Epsilon,
206    AnyChar,
207    Char(char),
208}
209
210impl NfaTransitions {
211    fn to_dfa(self) -> DfaTransitions {
212        match self {
213            NfaTransitions::AnyChar => DfaTransitions::AnyChar,
214            NfaTransitions::Char(c) => DfaTransitions::Char(c),
215            _ => unreachable!(),
216        }
217    }
218}
219
220#[derive(Debug)]
221struct Nfa {
222    start: usize,
223    accept: usize,
224    transitions: Vec<HashSet<(NfaTransitions, usize)>>,
225}
226
227impl Nfa {
228    fn new() -> Self {
229        Self {
230            start: 1,
231            accept: 0,
232            transitions: vec![HashSet::new(), HashSet::new()],
233        }
234    }
235
236    fn new_state(&mut self) -> usize {
237        let state = self.transitions.len();
238        self.transitions.push(HashSet::new());
239        state
240    }
241
242    fn add_transition(&mut self, from: usize, to: usize, epsilon: NfaTransitions) {
243        self.transitions[from].insert((epsilon, to));
244    }
245
246    fn add_epsilon_transition(&mut self, from: usize, to: usize) {
247        self.add_transition(from, to, NfaTransitions::Epsilon);
248    }
249
250    fn epsilon_closure(&self, state: usize) -> HashSet<usize> {
251        let mut closure = HashSet::new();
252        let mut stack = VecDeque::new();
253        stack.push_back(state);
254        while let Some(state) = stack.pop_front() {
255            if closure.contains(&state) {
256                continue;
257            }
258            closure.insert(state);
259            for (transition, next) in &self.transitions[state] {
260                if *transition == NfaTransitions::Epsilon {
261                    stack.push_back(*next);
262                }
263            }
264        }
265        closure
266    }
267
268    fn to_dfa(&self) -> Dfa {
269        Dfa::product_construction(self)
270    }
271}
272
273trait ToNfa {
274    fn add_to_nfa(&self, nfa: &mut Nfa, from: usize, to: usize);
275    fn as_nfa(&self) -> Nfa {
276        let mut nfa = Nfa::new();
277        let start = nfa.start;
278        let accept = nfa.accept;
279        self.add_to_nfa(&mut nfa, start, accept);
280        nfa
281    }
282}
283
284impl ToNfa for Ast {
285    fn add_to_nfa(&self, nfa: &mut Nfa, from: usize, to: usize) {
286        self.0.add_to_nfa(nfa, from, to);
287    }
288}
289
290impl ToNfa for Union {
291    fn add_to_nfa(&self, nfa: &mut Nfa, from: usize, to: usize) {
292        match self {
293            Union::Union(left, right) => {
294                left.add_to_nfa(nfa, from, to);
295                right.add_to_nfa(nfa, from, to);
296            },
297            Union::Concat(concat) => concat.add_to_nfa(nfa, from, to),
298        }
299    }
300}
301
302impl ToNfa for Concat {
303    fn add_to_nfa(&self, nfa: &mut Nfa, from: usize, to: usize) {
304        match self {
305            Concat::Concat(left, right) => {
306                let mid = nfa.new_state();
307                left.add_to_nfa(nfa, from, mid);
308                right.add_to_nfa(nfa, mid, to);
309            },
310            Concat::Star(star) => star.add_to_nfa(nfa, from, to),
311        }
312    }
313}
314
315impl ToNfa for Star {
316    fn add_to_nfa(&self, nfa: &mut Nfa, from: usize, to: usize) {
317        match self {
318            Star::Star(optional) => {
319                let mid = nfa.new_state();
320                nfa.add_epsilon_transition(from, mid);
321                nfa.add_epsilon_transition(mid, to);
322                optional.add_to_nfa(nfa, mid, mid);
323            },
324            Star::Optional(ast) => {
325                ast.add_to_nfa(nfa, from, to);
326                nfa.add_epsilon_transition(from, to);
327            },
328            Star::Terminal(optional) => optional.add_to_nfa(nfa, from, to),
329        }
330    }
331}
332
333impl ToNfa for Terminal {
334    fn add_to_nfa(&self, nfa: &mut Nfa, from: usize, to: usize) {
335        match self {
336            Terminal::AnyChar => {
337                nfa.add_transition(from, to, NfaTransitions::AnyChar);
338            },
339            Terminal::Char(c) => nfa.add_transition(from, to, NfaTransitions::Char(*c)),
340            Terminal::Group(ast) => ast.add_to_nfa(nfa, from, to),
341        }
342    }
343}
344
345// what if there's a AnyChar transition and a Char transition from a state?
346#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
347enum DfaTransitions {
348    AnyChar,
349    Char(char),
350}
351
352#[derive(Debug)]
353struct Dfa {
354    start: usize,
355    accept: usize,
356    accept_states: HashSet<usize>,
357    transitions: Vec<HashMap<DfaTransitions, usize>>,
358}
359
360impl Dfa {
361    fn new() -> Self {
362        let mut accept_states = HashSet::new();
363        accept_states.insert(0);
364        Self {
365            start: 1,
366            accept: 0,
367            accept_states,
368            transitions: vec![HashMap::new(), HashMap::new()],
369        }
370    }
371
372    fn new_state(&mut self) -> usize {
373        let state = self.transitions.len();
374        self.transitions.push(HashMap::new());
375        state
376    }
377
378    fn add_transition(&mut self, from: usize, to: usize, transition: DfaTransitions) {
379        self.transitions[from].insert(transition, to);
380    }
381
382    fn product_construction(nfa: &Nfa) -> Self {
383        let mut dfa = Dfa::new();
384        let initial_states = nfa.epsilon_closure(nfa.start);
385        let mut states = HashMap::new();
386        states.insert(dfa.start, initial_states);
387        states.insert(dfa.accept, HashSet::from_iter([nfa.accept]));
388        let mut queue = VecDeque::new();
389        let mut visited = HashSet::new();
390        queue.push_back(dfa.start);
391        while let Some(state) = queue.pop_front() {
392            if visited.contains(&state) {
393                continue;
394            }
395            visited.insert(state);
396            let mut transitions = HashMap::new();
397            let s = &states[&state];
398            if s.contains(&nfa.accept) {
399                dfa.accept_states.insert(state);
400            }
401            for state in s {
402                for transition in &nfa.transitions[*state] {
403                    if transition.0 == NfaTransitions::Epsilon {
404                        continue;
405                    }
406                    let next_states = transitions.entry(transition.0).or_insert_with(HashSet::new);
407                    next_states.extend(nfa.epsilon_closure(transition.1));
408                }
409            }
410            for (transition, next_states) in transitions {
411                let next_state = 'a: {
412                    for (state, set) in &states {
413                        if set == &next_states {
414                            break 'a *state;
415                        }
416                    }
417                    let next_state = dfa.new_state();
418                    states.insert(next_state, next_states);
419                    next_state
420                };
421                dfa.add_transition(state, next_state, transition.to_dfa());
422                queue.push_back(next_state);
423            }
424            // for each state, there are multiple possible transitions
425            // we need to find the set of states that can be reached by each
426            // transition for epsilon moves, we compute the epsilon
427            // closure for non-epsilon moves, we compute the set of
428            // states that can be reached by the move
429            //
430            // first set is {start}, we compute a -> {...} for each transition a
431            // from {start, epsilon_closure(start)}
432            // each set of states is then a state in the dfa
433            // we then move to the next set of states
434        }
435        dfa
436    }
437}
438
439impl ToTokens for Dfa {
440    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
441        let start = self.start;
442        let char_transitions =
443            self.transitions
444                .iter()
445                .enumerate()
446                .flat_map(|(from, transitions)| {
447                    transitions
448                        .iter()
449                        .filter(|i| *i.0 != DfaTransitions::AnyChar)
450                        .map(move |(transition, to)| match transition {
451                            DfaTransitions::Char(c) => {
452                                quote! { (#from, #c) => #to, }
453                            },
454                            _ => unreachable!(),
455                        })
456                });
457        let any_char_transitions =
458            self.transitions
459                .iter()
460                .enumerate()
461                .flat_map(|(from, transitions)| {
462                    transitions
463                        .iter()
464                        .filter(|i| *i.0 == DfaTransitions::AnyChar)
465                        .map(move |(transition, to)| match transition {
466                            DfaTransitions::AnyChar => {
467                                quote! { (#from, _) => #to, }
468                            },
469                            _ => unreachable!(),
470                        })
471                });
472        let accept_states = self.accept_states.iter().collect::<Vec<_>>();
473        let accept_states = quote! { #(state == #accept_states)||* };
474        tokens.extend(quote! {
475            let mut state = #start;
476            while let Some(c) = chars.next() {
477                state = match (state, c) {
478                    #(#char_transitions)*
479                    #(#any_char_transitions)*
480                    _ => return false,
481                };
482            }
483            #accept_states
484        });
485    }
486}
487
488struct Input {
489    name: syn::Ident,
490    value: syn::LitStr,
491}
492
493impl syn::parse::Parse for Input {
494    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
495        let name = input.parse()?;
496        input.parse::<syn::Token![=>]>()?;
497        let value = input.parse()?;
498        Ok(Self { name, value })
499    }
500}
501
502/// Create a regex matcher
503///
504/// # Regex Syntax
505/// - `.`: matches any character
506/// - `a`: matches the character `a`
507/// - `a*`: matches zero or more `a`
508/// - `a?`: matches zero or one `a`
509/// - `a+b`: matches `a` or `b`
510/// - `(a+b)*`: matches zero or more `a` or `b`
511///
512/// # Example
513/// ```rust
514/// use dfa_regex::regex;
515///
516/// regex!(Foo => "a*");
517/// assert!(Foo::matches("aaaa"));
518/// assert!(Foo::matches(""));
519/// assert!(!Foo::matches("b"));
520/// ```
521#[proc_macro]
522pub fn regex(input: TokenStream) -> TokenStream {
523    let input = syn::parse_macro_input!(input as Input);
524    let lit = input.value.value();
525    let mut chars = Ctx::new(&lit);
526    let ast = Ast::parse(&mut chars).unwrap();
527    let nfa = ast.as_nfa();
528    let dfa = nfa.to_dfa();
529    let name = input.name;
530    quote! {
531        struct #name;
532
533        impl #name {
534            fn matches(s: &str) -> bool {
535                let mut chars = s.chars();
536                #dfa
537            }
538        }
539    }
540    .into()
541}