amazon_interview_redemption_regex_parser/
regex.rs

1use std::{iter::Peekable, str::Chars};
2
3use corrida::Corrida;
4use gerber::{nfa::*, nfa_state_creator};
5
6type RState = State<2, char>;
7/// Parses a regex string into an NFA. Returns an error if the regex is invalid.
8pub fn parse_regex<'a>(regex_string: &str, arena: &'a Corrida) -> Result<Nfa<'a ,RState>, &'static str> {
9    nfa_state_creator!(($), new_state, arena, char, 2);
10    let create_state = |is_final| new_state!(is_final);
11    
12    fn parse_base<'a>(cur: &'a mut RState, chars: &mut Peekable<Chars>, create_state: &impl Fn(bool) -> &'a mut State<2, char>) -> Result<(&'a mut RState, &'a mut RState), &'static str> {
13        let (base_start, base_end) = match chars.next() {
14            Some('(') => {
15                let (start_node, end_state) = parse_group::<false>(chars, create_state)?;
16                cur.push_transition(None, Some(start_node));
17                (cur, end_state)
18            },
19            Some(c) => {
20                if c == '+' || c == '*' || c == '?' {
21                    return Err("Got an operator (+, *, ?) when there was no base to skip/repeat");
22                }
23
24                let new_state = create_state(false);
25                cur.push_transition(Some(c), Some(new_state));
26
27                (cur, new_state)
28            },
29            None => {
30                panic!("How did we get here.")
31            }
32        };
33
34        // Should be on operators, if not end base
35        let mut add_skip = false;
36        let mut add_cycle = false;
37
38        while let Some(c) = chars.peek() {
39            match c {
40                '+' => {
41                    add_cycle = true;
42                },
43                '*' => {
44                    add_cycle = true;
45                    add_skip = true;
46                },
47                '?' => {
48                    add_skip = true;
49                },
50                _ => {
51                    break;
52                }
53            }
54
55            chars.next(); //eat
56        }
57
58        if add_cycle {
59            base_end.push_transition(None, Some(base_start));
60        }
61        if add_skip {
62            base_start.push_transition(None, Some(base_end));
63        }
64
65        Ok((base_start, base_end))
66    }
67
68    fn parse_concat<'a>(chars: &mut Peekable<Chars>,create_state: &impl Fn(bool) -> &'a mut State<2, char>) -> Result<(&'a mut RState, Option<&'a mut RState>), &'static str> {
69        let mut cur = create_state(false);
70        let mut pattern_start = None;
71
72        while chars.peek().is_some() && chars.peek() != Some(&')') && chars.peek() != Some(&'|') {
73            let (base_start, base_end) = parse_base(cur, chars, create_state)?;
74
75            if pattern_start.is_none() {
76                pattern_start = Some(base_start);
77            }
78
79            cur = base_end;
80        }
81
82        Ok(match pattern_start {
83            Some(start) => (start, Some(cur)),
84            None => (cur, None)
85        })
86    }
87
88    fn parse_group<'a, const OUTERMOST: bool>(chars: &mut Peekable<Chars>, create_state: &impl Fn(bool) -> &'a mut State<2, char>) -> Result<(&'a mut RState, &'a mut RState), &'static str> {
89        fn add_to_union(union_start: &mut RState, union_end: &mut RState, concat_start: &mut RState, concat_end: Option<&mut RState>) {
90            union_start.push_transition(None, Some(concat_start));
91            let concat_end = concat_end.unwrap_or(concat_start);
92            concat_end.push_transition(None, Some(union_end));
93        }
94
95        let (concat_start, concat_end_opt) = parse_concat(chars, create_state)?;
96
97        let (group_start, group_end) = if let Some(&'|') = chars.peek() {
98            let (union_start, union_end) = (create_state(false), create_state(false));
99            add_to_union(union_start, union_end, concat_start, concat_end_opt);
100            
101            loop {
102                chars.next(); // eat '|'
103                let (concat_start, concat_end_opt) = parse_concat(chars, create_state)?;
104                add_to_union(union_start, union_end, concat_start, concat_end_opt);
105                if chars.peek() != Some(&'|') { break; }
106            }
107
108            (union_start, union_end)
109        } else {
110            let concat_end = concat_end_opt.unwrap_or_else(|| {
111                let end = create_state(true);
112                concat_start.push_transition(None, Some(end));
113                end
114            });
115            (concat_start, concat_end)
116        };
117
118        match chars.next() {
119            Some(')') if OUTERMOST => {
120                return Err("Attempted to close a group in the outermost context ( no matching '(' )");
121            },
122            None if !OUTERMOST => {
123                return Err("EOF when not all groups were closed, '(' without matching ')'");
124            },
125            _ => {}
126        }
127
128        // SAFETY, in both arms the non null comes from an exclusive reference, so all good.
129        Ok((group_start, group_end))
130    }
131
132    let mut chars = regex_string.chars().peekable();
133    let (start_node, end_node) = parse_group::<true>(&mut chars, &create_state)?;
134    end_node.set_accept(true);
135
136    Ok(Nfa::new(start_node))
137}
138
#[cfg(test)]
mod tests {
    use std::time::Instant;

    use super::*;

    /// Literal, `*` and an alternation with an empty branch.
    #[test]
    pub fn test_basics() {
        let arena = Corrida::new(None);
        let nfa = parse_regex("ab*(c|)", &arena).unwrap();

        assert!(!nfa.simulate_iter("".chars()));
        assert!(nfa.simulate_iter("a".chars()));
        assert!(nfa.simulate_iter("ab".chars()));
        assert!(nfa.simulate_iter("ac".chars()));
        assert!(nfa.simulate_iter("abb".chars()));
        assert!(!nfa.simulate_iter("abbcc".chars()));
        assert!(!nfa.simulate_iter("abbbac".chars()));
        assert!(!nfa.simulate_iter("abaa".chars()));
        assert!(nfa.simulate_iter("abbbbbbbc".chars()));
    }

    /// Malformed patterns are reported as errors rather than parsed.
    #[test]
    pub fn test_rejects_malformed() {
        let arena = Corrida::new(None);

        // Operator with nothing to apply to.
        assert!(parse_regex("*a", &arena).is_err());
        // `)` without a matching `(`.
        assert!(parse_regex("a)", &arena).is_err());
        // `(` that is never closed.
        assert!(parse_regex("(a", &arena).is_err());
    }

    /// Many stacked `*` operators on a long input; prints NFA vs DFA timings.
    #[test]
    pub fn test_unfriendly() {
        let arena = Corrida::new(None);
        let nfa = parse_regex("a*b*a*b*a*b*a*b*a*b*(|)?a", &arena).unwrap();

        let mut test = vec!['b'; 100_000];
        let start = Instant::now();
        assert!(!nfa.simulate_slice(&test));
        test.push('a');
        assert!(nfa.simulate_slice(&test));
        let a = start.elapsed();

        test.pop();
        let dfa = nfa.as_dfa(&arena);

        let start = Instant::now();
        assert!(!dfa.simulate_slice(&test));
        test.push('a');
        assert!(dfa.simulate_slice(&test));
        let b = start.elapsed();

        println!("Unfriendly -- NFA {:?}, DFA {:?}", a, b);
    }

    /// Nested groups, alternations and operators on long single-letter inputs.
    #[test]
    pub fn complicated() {
        let arena = Corrida::new(None);
        let nfa = parse_regex("(((a|b)+c?(a|b)*)?(c(a|b)+|a?b?c+)((a|b|c)*)(a(a)+)?)+", &arena).unwrap();

        let testa = vec!['a'; 100_000];
        let testb = vec!['b'; 100_000];
        let testc = vec!['c'; 100_000];

        let start = Instant::now();
        assert!(!nfa.simulate_slice(&testa));
        assert!(!nfa.simulate_slice(&testb));
        assert!(nfa.simulate_slice(&testc));
        let a = start.elapsed();

        let dfa = nfa.as_dfa(&arena);
        let start = Instant::now();
        assert!(!dfa.simulate_slice(&testa));
        assert!(!dfa.simulate_slice(&testb));
        assert!(dfa.simulate_slice(&testc));
        let b = start.elapsed();

        println!("Complicated -- NFA {:?}, DFA {:?}", a, b);
    }

    /// The `a?^n a^n` pattern matched against `a^n`.
    #[test]
    pub fn test_as() {
        let arena = Corrida::new(None);
        const N: usize = 100;
        let nfa = parse_regex(&("a?".repeat(N) + &"a".repeat(N)), &arena).unwrap();

        let test = "a".repeat(N);

        let start = Instant::now();
        assert!(nfa.simulate_iter(test.chars()));
        let a = start.elapsed();

        // The "friendly" simulation (`simulate_iter_friendly`) explodes on this
        // regex, so it is not timed here and is reported as "N/A" below.

        let dfa = nfa.as_dfa(&arena);
        let start = Instant::now();
        assert!(dfa.simulate_iter(test.chars()));
        let c = start.elapsed();

        println!("a?^na^n -- NFA {:?}, NFA Friendly {:?}, DFA {:?}", a, "N/A", c);
    }
}