ere_core/
working_nfa.rs

1//! Implements the primary compile-time intermediate [`WorkingNFA`] structure for optimization.
2
3use crate::parse_tree::Atom;
4use crate::simplified_tree::SimplifiedTreeNode;
5use quote::{quote, ToTokens};
6
/// Describes what, if anything, is special about an [`EpsilonTransition`].
///
/// Plain epsilons can always be crossed; the other kinds either restrict where in the text the
/// edge may be crossed (anchors) or mark a capture-group boundary.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum EpsilonType {
    /// An ordinary epsilon edge with no condition attached.
    None,
    /// `^`: may only be crossed at the very beginning of the text.
    StartAnchor,
    /// `$`: may only be crossed at the very end of the text.
    EndAnchor,
    /// Opens the capture group with the given number.
    StartCapture(usize),
    /// Closes the capture group with the given number.
    EndCapture(usize),
}
15impl ToTokens for EpsilonType {
16    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
17        match self {
18            EpsilonType::None => {
19                tokens.extend(quote! { ::ere_core::working_nfa::EpsilonType::None })
20            }
21            EpsilonType::StartAnchor => {
22                tokens.extend(quote! { ::ere_core::working_nfa::EpsilonType::StartAnchor })
23            }
24            EpsilonType::EndAnchor => {
25                tokens.extend(quote! { ::ere_core::working_nfa::EpsilonType::EndAnchor })
26            }
27            EpsilonType::StartCapture(group_num) => tokens.extend(quote! {
28                ::ere_core::working_nfa::EpsilonType::StartCapture(#group_num)
29            }),
30            EpsilonType::EndCapture(group_num) => tokens.extend(quote! {
31                ::ere_core::working_nfa::EpsilonType::EndCapture(#group_num)
32            }),
33        };
34    }
35}
36
37/// An epsilon transition for the [`WorkingNFA`]
38#[derive(Debug, Clone, Copy)]
39pub struct EpsilonTransition {
40    pub(crate) to: usize,
41    pub(crate) special: EpsilonType,
42}
43impl EpsilonTransition {
44    pub(crate) const fn new(to: usize) -> EpsilonTransition {
45        return EpsilonTransition {
46            to,
47            special: EpsilonType::None,
48        };
49    }
50    pub(crate) const fn with_offset(self, offset: usize) -> EpsilonTransition {
51        return EpsilonTransition {
52            to: self.to + offset,
53            special: self.special,
54        };
55    }
56    pub(crate) fn inplace_offset(&mut self, offset: usize) {
57        self.to += offset;
58    }
59    pub(crate) const fn add_offset(&self, offset: usize) -> EpsilonTransition {
60        return EpsilonTransition {
61            to: self.to + offset,
62            special: self.special,
63        };
64    }
65    /// Only intended for internal use by macros.
66    pub const fn __load(to: usize, special: EpsilonType) -> EpsilonTransition {
67        return EpsilonTransition { to, special };
68    }
69}
70impl std::fmt::Display for EpsilonTransition {
71    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72        return write!(f, "-> {}", self.to);
73    }
74}
75impl ToTokens for EpsilonTransition {
76    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
77        let EpsilonTransition { to, special } = self;
78        tokens.extend(quote! {
79            ere_core::working_nfa::EpsilonTransition::__load(
80                #to,
81                #special,
82            )
83        });
84    }
85}
86
87#[derive(Debug, Clone)]
88pub(crate) struct WorkingTransition {
89    pub(crate) to: usize,
90    pub(crate) symbol: Atom,
91}
92impl WorkingTransition {
93    pub fn new(to: usize, symbol: Atom) -> WorkingTransition {
94        return WorkingTransition { to, symbol };
95    }
96    pub fn with_offset(mut self, offset: usize) -> WorkingTransition {
97        self.inplace_offset(offset);
98        return self;
99    }
100    pub fn inplace_offset(&mut self, offset: usize) {
101        self.to += offset;
102    }
103    pub fn add_offset(&self, offset: usize) -> WorkingTransition {
104        return WorkingTransition {
105            to: self.to + offset,
106            symbol: self.symbol.clone(),
107        };
108    }
109}
110impl std::fmt::Display for WorkingTransition {
111    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
112        return write!(f, "-({})> {}", self.symbol, self.to);
113    }
114}
115
116#[derive(Debug, Clone)]
117pub struct WorkingState {
118    pub(crate) transitions: Vec<WorkingTransition>,
119    pub(crate) epsilons: Vec<EpsilonTransition>,
120}
121impl WorkingState {
122    pub const fn new() -> WorkingState {
123        return WorkingState {
124            transitions: Vec::new(),
125            epsilons: Vec::new(),
126        };
127    }
128    pub fn with_transition(mut self, to: usize, symbol: Atom) -> WorkingState {
129        self.transitions.push(WorkingTransition::new(to, symbol));
130        return self;
131    }
132    pub fn with_epsilon(mut self, to: usize) -> WorkingState {
133        self.epsilons.push(EpsilonTransition::new(to));
134        return self;
135    }
136    pub fn with_epsilon_special(mut self, to: usize, special: EpsilonType) -> WorkingState {
137        self.epsilons.push(EpsilonTransition { to, special });
138        return self;
139    }
140    pub fn with_offset(mut self, offset: usize) -> WorkingState {
141        self.inplace_offset(offset);
142        return self;
143    }
144    pub fn inplace_offset(&mut self, offset: usize) {
145        for t in &mut self.transitions {
146            t.inplace_offset(offset);
147        }
148        for e in &mut self.epsilons {
149            e.inplace_offset(offset);
150        }
151    }
152    pub fn add_offset(&self, offset: usize) -> WorkingState {
153        return WorkingState {
154            transitions: self
155                .transitions
156                .iter()
157                .map(|t| t.add_offset(offset))
158                .collect(),
159            epsilons: self.epsilons.iter().map(|e| e.add_offset(offset)).collect(),
160        };
161    }
162}
163impl std::fmt::Display for WorkingState {
164    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
165        for t in &self.transitions {
166            writeln!(f, "  {t}")?;
167        }
168        for e in &self.epsilons {
169            writeln!(f, "  {e}")?;
170        }
171        return Ok(());
172    }
173}
174
175/// Each NFA has one start state (`0`) and one accept state (`states.len() - 1`)
176#[derive(Debug, Clone)]
177pub struct WorkingNFA {
178    pub(crate) states: Vec<WorkingState>,
179}
180impl WorkingNFA {
181    /// Makes an NFA that matches with zero length.
182    fn nfa_empty() -> WorkingNFA {
183        let states = vec![WorkingState::new()];
184        return WorkingNFA { states };
185    }
186    /// Makes an NFA matching a some symbol.
187    fn nfa_symbol(c: &Atom) -> WorkingNFA {
188        let states = vec![
189            WorkingState::new().with_transition(1, c.clone()),
190            WorkingState::new(),
191        ];
192        return WorkingNFA { states };
193    }
194    /// Makes a union of NFAs.
195    fn nfa_union(nodes: &[WorkingNFA]) -> WorkingNFA {
196        let states_count = 2 + nodes.iter().map(|n| n.states.len()).sum::<usize>();
197        let mut states = vec![WorkingState::new()];
198        for nfa in nodes {
199            let sub_nfa_start = states.len();
200            states[0]
201                .epsilons
202                .push(EpsilonTransition::new(sub_nfa_start));
203            states.extend(
204                nfa.states
205                    .iter()
206                    .map(|state| state.add_offset(sub_nfa_start)),
207            );
208            states
209                .last_mut()
210                .unwrap()
211                .epsilons
212                .push(EpsilonTransition::new(states_count - 1));
213        }
214        states.push(WorkingState::new());
215        assert_eq!(states_count, states.len());
216
217        return WorkingNFA { states };
218    }
219    fn build_union(nodes: &[SimplifiedTreeNode]) -> WorkingNFA {
220        let sub_nfas: Vec<WorkingNFA> = nodes.iter().map(WorkingNFA::build).collect();
221        return WorkingNFA::nfa_union(&sub_nfas);
222    }
223    /// Wraps an NFA part in a capture group.
224    fn nfa_capture(nfa: &WorkingNFA, group_num: usize) -> WorkingNFA {
225        let states_count = 2 + nfa.states.len();
226        let mut states: Vec<WorkingState> = std::iter::once(
227            WorkingState::new().with_epsilon_special(1, EpsilonType::StartCapture(group_num)),
228        )
229        .chain(nfa.states.iter().map(|state| state.add_offset(1)))
230        .chain(std::iter::once(WorkingState::new()))
231        .collect();
232        assert_eq!(states_count, states.len());
233        states[states_count - 2].epsilons.push(EpsilonTransition {
234            to: states_count - 1,
235            special: EpsilonType::EndCapture(group_num),
236        });
237
238        return WorkingNFA { states };
239    }
240    fn build_capture(tree: &SimplifiedTreeNode, group_num: usize) -> WorkingNFA {
241        let nfa = WorkingNFA::build(tree);
242        return WorkingNFA::nfa_capture(&nfa, group_num);
243    }
244    /// Makes an NFA that matches a concatenation of NFAs.
245    fn nfa_concat<T: IntoIterator<Item = WorkingNFA>>(nodes: T) -> WorkingNFA {
246        let mut states = vec![WorkingState::new().with_epsilon(1)];
247
248        for nfa in nodes {
249            let states_count = states.len();
250            states.extend(
251                nfa.states
252                    .into_iter()
253                    .map(|state| state.with_offset(states_count)),
254            );
255            let states_count = states.len();
256            states
257                .last_mut()
258                .unwrap()
259                .epsilons
260                .push(EpsilonTransition::new(states_count));
261        }
262
263        states.push(WorkingState::new());
264        return WorkingNFA { states };
265    }
266    fn build_concat<'a, T: IntoIterator<Item = &'a SimplifiedTreeNode>>(nodes: T) -> WorkingNFA {
267        return WorkingNFA::nfa_concat(nodes.into_iter().map(WorkingNFA::build));
268    }
269    /// Makes an NFA that matches some NFA concatenated with itself multiple times.
270    fn nfa_repeat(nfa: &WorkingNFA, times: usize) -> WorkingNFA {
271        return WorkingNFA::nfa_concat(std::iter::repeat(nfa).cloned().take(times));
272    }
273    fn build_repeat(tree: &SimplifiedTreeNode, times: usize) -> WorkingNFA {
274        let nfa = WorkingNFA::build(tree);
275        return WorkingNFA::nfa_repeat(&nfa, times);
276    }
277    /// Makes an NFA that matches some NFA concatenated with itself up to some number of times.
278    fn nfa_upto(nfa: &WorkingNFA, times: usize, longest: bool) -> WorkingNFA {
279        let end_state_idx = 1 + (nfa.states.len() + 1) * times;
280
281        let mut states = vec![WorkingState::new()
282            .with_epsilon(1)
283            .with_epsilon(end_state_idx - 1)];
284        for i in 0..times {
285            let states_count = states.len();
286            states.extend(
287                nfa.states
288                    .iter()
289                    .map(|state| state.add_offset(states_count)),
290            );
291            let transition_state_idx = states.len();
292            states
293                .last_mut()
294                .unwrap()
295                .epsilons
296                .push(EpsilonTransition::new(transition_state_idx));
297            let mut transition_state = WorkingState::new();
298            if i + 1 != times {
299                if longest {
300                    transition_state
301                        .epsilons
302                        .push(EpsilonTransition::new(states.len() + 1));
303                }
304
305                transition_state
306                    .epsilons
307                    .push(EpsilonTransition::new(end_state_idx - 1));
308                if !longest {
309                    transition_state
310                        .epsilons
311                        .push(EpsilonTransition::new(states.len() + 1));
312                }
313            }
314            states.push(transition_state);
315        }
316
317        return WorkingNFA { states };
318    }
319    fn build_upto(tree: &SimplifiedTreeNode, times: usize, longest: bool) -> WorkingNFA {
320        let nfa = WorkingNFA::build(tree);
321        return WorkingNFA::nfa_upto(&nfa, times, longest);
322    }
323    /// Makes an NFA that matches some NFA concatenated with itself any number of times.
324    fn nfa_star(nfa: WorkingNFA, longest: bool) -> WorkingNFA {
325        let end_state_idx = 1 + nfa.states.len();
326        let mut start_state = WorkingState::new();
327        if !longest {
328            start_state
329                .epsilons
330                .push(EpsilonTransition::new(end_state_idx));
331        }
332        start_state.epsilons.push(EpsilonTransition::new(1));
333        if longest {
334            start_state
335                .epsilons
336                .push(EpsilonTransition::new(end_state_idx));
337        }
338        let mut states: Vec<WorkingState> = std::iter::once(start_state)
339            .chain(nfa.states.into_iter().map(|state| state.with_offset(1)))
340            .chain(std::iter::once(WorkingState::new()))
341            .collect();
342        states[end_state_idx - 1]
343            .epsilons
344            .push(EpsilonTransition::new(0));
345        return WorkingNFA { states };
346    }
347    fn build_star(tree: &SimplifiedTreeNode, longest: bool) -> WorkingNFA {
348        let nfa = WorkingNFA::build(tree);
349        return WorkingNFA::nfa_star(nfa, longest);
350    }
351    /// Makes an NFA that matches zero length but only at the text start
352    fn nfa_start() -> WorkingNFA {
353        let states = vec![
354            WorkingState::new().with_epsilon_special(1, EpsilonType::StartAnchor),
355            WorkingState::new(),
356        ];
357        return WorkingNFA { states };
358    }
359    /// Makes an NFA that matches zero length but only at the text end
360    fn nfa_end() -> WorkingNFA {
361        let states = vec![
362            WorkingState::new().with_epsilon_special(1, EpsilonType::EndAnchor),
363            WorkingState::new(),
364        ];
365        return WorkingNFA { states };
366    }
367    /// Makes an NFA that never matches.
368    fn nfa_never() -> WorkingNFA {
369        let states = vec![WorkingState::new(), WorkingState::new()];
370        return WorkingNFA { states };
371    }
372    /// Recursively builds an inefficient but valid NFA based loosely on Thompson's Algorithm.
373    ///
374    /// Should be optimized using [`WorkingNFA::optimize_pass`]
375    fn build(tree: &SimplifiedTreeNode) -> WorkingNFA {
376        return match tree {
377            SimplifiedTreeNode::Empty => WorkingNFA::nfa_empty(),
378            SimplifiedTreeNode::Symbol(c) => WorkingNFA::nfa_symbol(c),
379            SimplifiedTreeNode::Union(nodes) => WorkingNFA::build_union(nodes),
380            SimplifiedTreeNode::Capture(tree, group_num) => {
381                WorkingNFA::build_capture(&tree, *group_num)
382            }
383            SimplifiedTreeNode::Concat(nodes) => WorkingNFA::build_concat(nodes),
384            SimplifiedTreeNode::Repeat(tree, times) => WorkingNFA::build_repeat(tree, *times),
385            SimplifiedTreeNode::UpTo(tree, times, longest) => {
386                WorkingNFA::build_upto(tree, *times, *longest)
387            }
388            SimplifiedTreeNode::Star(tree, longest) => WorkingNFA::build_star(tree, *longest),
389            SimplifiedTreeNode::Start => WorkingNFA::nfa_start(),
390            SimplifiedTreeNode::End => WorkingNFA::nfa_end(),
391            SimplifiedTreeNode::Never => WorkingNFA::nfa_never(),
392        };
393    }
394    pub fn new(tree: &SimplifiedTreeNode) -> WorkingNFA {
395        let mut nfa = WorkingNFA::build(tree);
396
397        nfa.clean_start_anchors();
398        nfa.clean_end_anchors();
399
400        // add loops at start and end in case we lack anchors
401        nfa = WorkingNFA::nfa_concat([
402            WorkingNFA::nfa_star(
403                WorkingNFA::nfa_symbol(&Atom::NonmatchingList(Vec::new())),
404                false,
405            ),
406            nfa,
407            WorkingNFA::nfa_star(
408                WorkingNFA::nfa_symbol(&Atom::NonmatchingList(Vec::new())),
409                false,
410            ),
411        ]);
412
413        // Then remove redundant transitions from nodes before/after anchors
414        // May include the loops we just added
415        let zero_symbol_states: Vec<bool> =
416            std::iter::zip(nfa.nodes_after_end(), nfa.nodes_before_start())
417                .map(|(a, b)| a || b)
418                .collect();
419        for (from, state) in nfa.states.iter_mut().enumerate() {
420            if zero_symbol_states[from] {
421                state.transitions = Vec::new();
422            }
423        }
424
425        // Finally, do normal optimization passes
426        while nfa.optimize_pass() {}
427        nfa.remove_unreachable();
428        return nfa;
429    }
430
431    /// Removes start anchors that will never be satisfied
432    /// (basically turning them into a `Never` to allow further optimization)
433    fn clean_start_anchors(&mut self) {
434        let mut zero_len_reachable = vec![false; self.states.len()];
435        zero_len_reachable[0] = true;
436        let mut stack = vec![0];
437        while let Some(state) = stack.pop() {
438            for e in &self.states[state].epsilons {
439                if !zero_len_reachable[e.to] {
440                    stack.push(e.to);
441                }
442                zero_len_reachable[e.to] = true;
443            }
444        }
445
446        for (i, state) in self.states.iter_mut().enumerate() {
447            state
448                .epsilons
449                .retain(|e| e.special != EpsilonType::StartAnchor || zero_len_reachable[i]);
450        }
451    }
452
453    /// Removes end anchors that will never be satisfied
454    /// (basically turning them into a `Never` to allow further optimization)    
455    fn clean_end_anchors(&mut self) {
456        let mut zero_len_reachable = vec![false; self.states.len()];
457        zero_len_reachable[self.states.len() - 1] = true;
458
459        let mut reverse_epsilons = vec![Vec::new(); self.states.len()];
460        for (i, state) in self.states.iter().enumerate() {
461            for e in &state.epsilons {
462                reverse_epsilons[e.to].push(i);
463            }
464        }
465
466        let mut stack = vec![self.states.len() - 1];
467        while let Some(state) = stack.pop() {
468            for src in &reverse_epsilons[state] {
469                if !zero_len_reachable[*src] {
470                    stack.push(*src);
471                }
472                zero_len_reachable[*src] = true;
473            }
474        }
475
476        for state in self.states.iter_mut() {
477            state
478                .epsilons
479                .retain(|e| e.special != EpsilonType::EndAnchor || zero_len_reachable[e.to]);
480        }
481    }
482    /// Finds all nodes that are only ever visited after a `$`.
483    fn nodes_after_end(&self) -> Vec<bool> {
484        let mut nodes = vec![true; self.states.len()];
485        nodes[0] = false;
486
487        let mut stack = vec![0];
488        while let Some(from) = stack.pop() {
489            for e in self.states[from].epsilons.iter() {
490                if nodes[e.to] && e.special != EpsilonType::EndAnchor {
491                    nodes[e.to] = false;
492                    stack.push(e.to);
493                }
494            }
495            for t in self.states[from].transitions.iter() {
496                if nodes[t.to] {
497                    nodes[t.to] = false;
498                    stack.push(t.to);
499                }
500            }
501        }
502        return nodes;
503    }
504    /// Finds all nodes that are only ever visited before a `^`.
505    fn nodes_before_start(&self) -> Vec<bool> {
506        let mut reverse = vec![Vec::new(); self.states.len()];
507        for (i, state) in self.states.iter().enumerate() {
508            for e in &state.epsilons {
509                if e.special != EpsilonType::StartAnchor {
510                    reverse[e.to].push(i);
511                }
512            }
513            for t in &state.transitions {
514                reverse[t.to].push(i);
515            }
516        }
517
518        let mut nodes = vec![true; self.states.len()];
519        nodes[self.states.len() - 1] = false;
520
521        let mut stack = vec![self.states.len() - 1];
522        while let Some(to) = stack.pop() {
523            for from in &reverse[to] {
524                if nodes[*from] {
525                    nodes[*from] = false;
526                    stack.push(*from);
527                }
528            }
529        }
530        return nodes;
531    }
532
533    /// Helper function for removing a set of states.
534    ///
535    /// These states should have no incoming transitions.
536    fn remove_dead_states<T: IntoIterator<Item = bool>>(&mut self, dead_states: T) {
537        let state_map: Vec<usize> = dead_states
538            .into_iter()
539            .scan(0, |s, dead| {
540                if dead {
541                    return Some(usize::MAX);
542                } else {
543                    let out = *s;
544                    *s += 1;
545                    return Some(out);
546                }
547            })
548            .collect();
549        self.states = self
550            .states
551            .iter()
552            .enumerate()
553            .filter(|(i, _)| state_map[*i] != usize::MAX)
554            .map(|(_, state)| state)
555            .cloned()
556            .collect();
557
558        for state in &mut self.states {
559            for t in &mut state.transitions {
560                t.to = state_map[t.to];
561            }
562            for t in &mut state.epsilons {
563                t.to = state_map[t.to];
564            }
565        }
566    }
567
568    /// Optimizes the NFA graph.
569    ///
570    /// Returns `true` if changes were made (meaning another pass should be tried).
571    fn optimize_pass(&mut self) -> bool {
572        let mut changed = false;
573        let state_count = self.states.len();
574
575        let mut dead_states = vec![false; self.states.len()];
576
577        // Skip redundant states
578        // Special transitions (anchors + capture groups) are treated similar to non-epsilon transitions
579        for state_idx in 1..state_count - 1 {
580            let incoming: Vec<(usize, usize)> = self
581                .states
582                .iter()
583                .enumerate()
584                .flat_map(|(s_i, s)| s.transitions.iter().enumerate().map(move |(t, _)| (s_i, t)))
585                .filter(|(s, t)| self.states[*s].transitions[*t].to == state_idx)
586                .collect();
587            let incoming_eps: Vec<(usize, usize)> = self
588                .states
589                .iter()
590                .enumerate()
591                .flat_map(|(s_i, s)| s.epsilons.iter().enumerate().map(move |(e, _)| (s_i, e)))
592                .filter(|(s, e)| self.states[*s].epsilons[*e].to == state_idx)
593                .collect();
594
595            match (
596                incoming.as_slice(),
597                incoming_eps.as_slice(),
598                self.states[state_idx].transitions.len(),
599                self.states[state_idx].epsilons.len(),
600            ) {
601                // `as -xes> b -e> c` can become `as -xes> c` (assuming no other transitions)
602                (incoming, incoming_eps, 0, 1)
603                    if self.states[state_idx].epsilons[0].special == EpsilonType::None =>
604                {
605                    let to = self.states[state_idx].epsilons[0].to;
606                    for (s, t) in incoming {
607                        self.states[*s].transitions[*t].to = to;
608                    }
609                    for (s, e) in incoming_eps {
610                        self.states[*s].epsilons[*e].to = to;
611                    }
612                    dead_states[state_idx] = true;
613                    self.states[state_idx].epsilons = Vec::new();
614                    changed = true;
615                    continue;
616                }
617                // `a -e> b -es> cs` can become `a -es> cs` (assuming no other transitions)
618                (&[], &[(incoming_state, incoming_eps)], 0, _)
619                    if self.states[incoming_state].epsilons[incoming_eps].special
620                        == EpsilonType::None =>
621                {
622                    let outgoing_eps = std::mem::take(&mut self.states[state_idx].epsilons);
623                    let after = self.states[incoming_state]
624                        .epsilons
625                        .split_off(incoming_eps + 1);
626                    self.states[incoming_state].epsilons.pop();
627                    self.states[incoming_state]
628                        .epsilons
629                        .extend_from_slice(&outgoing_eps);
630                    self.states[incoming_state]
631                        .epsilons
632                        .extend_from_slice(&after);
633
634                    dead_states[state_idx] = true;
635                    changed = true;
636                    continue;
637                }
638                _ => {}
639            }
640
641            // TODO:
642            // `a -e> b -xes> cs` can become `a -xes> cs` (assuming no other transitions)
643            // `a -e> b -e> a` can combine `a` and `b` (including other transitions)
644            // TODO: might cause additional overhead in some cases, should we do
645            // ??? `a -x> b -es> cs` can become `a -xs> cs`
646            // ??? `as -es> b -x> c` can become `as -xs> c`
647        }
648        if !changed {
649            return changed;
650        }
651
652        self.remove_dead_states(dead_states);
653
654        return changed;
655    }
656
657    /// Finds the states that can be reached from the start via any path
658    fn states_reachable_start(&self) -> Vec<bool> {
659        let mut reachable = vec![false; self.states.len()];
660        reachable[0] = true;
661        let mut stack = vec![0];
662
663        while let Some(state) = stack.pop() {
664            for src in &self.states[state].epsilons {
665                if !reachable[src.to] {
666                    stack.push(src.to);
667                }
668                reachable[src.to] = true;
669            }
670            for src in &self.states[state].transitions {
671                if !reachable[src.to] {
672                    stack.push(src.to);
673                }
674                reachable[src.to] = true;
675            }
676        }
677
678        return reachable;
679    }
680    /// Finds the states that can reach the end via any path
681    fn states_reachable_end(&self) -> Vec<bool> {
682        let mut reverse = vec![Vec::new(); self.states.len()];
683        for (i, state) in self.states.iter().enumerate() {
684            for e in &state.epsilons {
685                reverse[e.to].push(i);
686            }
687            for t in &state.transitions {
688                reverse[t.to].push(i);
689            }
690        }
691
692        let mut reachable = vec![false; self.states.len()];
693        reachable[self.states.len() - 1] = true;
694        let mut stack = vec![self.states.len() - 1];
695
696        while let Some(state) = stack.pop() {
697            for src in &reverse[state] {
698                if !reachable[*src] {
699                    stack.push(*src);
700                }
701                reachable[*src] = true;
702            }
703        }
704
705        return reachable;
706    }
707
708    /// Removes all nodes that cannot be reached or cannot reach the end.
709    ///
710    /// Ignores special epsilon types (so should be called after they have been resolved)
711    fn remove_unreachable(&mut self) {
712        let reach_start = self.states_reachable_start();
713        let reach_end = self.states_reachable_end();
714
715        // Remove transitions that involve redundant states
716        for state in &mut self.states {
717            state
718                .epsilons
719                .retain(|e| reach_start[e.to] && reach_end[e.to]);
720            state
721                .transitions
722                .retain(|t| reach_start[t.to] && reach_end[t.to]);
723        }
724
725        // Then remove the states
726        self.remove_dead_states(
727            std::iter::zip(reach_start.into_iter(), reach_end.into_iter()).map(|(a, b)| !a || !b),
728        );
729    }
730
731    /// Finds the number of capture groups in this NFA
732    pub fn num_capture_groups(&self) -> usize {
733        return self
734            .states
735            .iter()
736            .flat_map(|state| &state.epsilons)
737            .map(|eps| match eps.special {
738                EpsilonType::StartCapture(n) => n,
739                _ => 0,
740            })
741            .max()
742            .unwrap_or(0)
743            + 1;
744    }
745
746    /// Writes a LaTeX TikZ representation to visualize the graph.
747    ///
748    /// If `include_doc` is `true`, will include the headers.
749    /// Otherwise, you should include `\usepackage{tikz}` and `\usetikzlibrary{automata, positioning}`.
750    pub fn to_tikz(&self, include_doc: bool) -> String {
751        let map_state =
752            |(i, state): (usize, &WorkingState)| -> crate::visualization::LatexGraphState {
753                let transitions =
754                    state
755                        .transitions
756                        .iter()
757                        .map(|t| crate::visualization::LatexGraphTransition {
758                            label: crate::visualization::escape_latex(t.symbol.to_string()),
759                            to: t.to,
760                        });
761                let epsilons = state.epsilons.iter().map(|e| {
762                    let label = match e.special {
763                        EpsilonType::None => r"$\epsilon$".to_string(),
764                        EpsilonType::StartAnchor => r"{\textasciicircum}".to_string(),
765                        EpsilonType::EndAnchor => r"\$".to_string(),
766                        EpsilonType::StartCapture(group) => format!("{group}("),
767                        EpsilonType::EndCapture(group) => format!("){group}"),
768                    };
769                    return crate::visualization::LatexGraphTransition { label, to: e.to };
770                });
771                let transitions = transitions.chain(epsilons).collect();
772                return crate::visualization::LatexGraphState {
773                    label: format!("q{i}"),
774                    transitions,
775                    initial: i == 0,
776                    accept: i + 1 == self.states.len(),
777                };
778            };
779
780        let graph = crate::visualization::LatexGraph {
781            states: self.states.iter().enumerate().map(map_state).collect(),
782        };
783        return graph.to_tikz(include_doc);
784    }
785
786    /// Using the classical NFA algorithm to do a simple boolean test on a string.
787    pub fn test(&self, text: &str) -> bool {
788        let mut list = vec![false; self.states.len()];
789        let mut new_list = vec![false; self.states.len()];
790        list[0] = true;
791
792        // Adds all states reachable by epsilon transitions
793        let propogate_epsilon = |list: &mut Vec<bool>, idx: usize| {
794            let mut stack: Vec<usize> = list
795                .iter()
796                .enumerate()
797                .filter_map(|(i, set)| set.then_some(i))
798                .collect();
799
800            while let Some(from) = stack.pop() {
801                for EpsilonTransition { to, special } in &self.states[from].epsilons {
802                    if list[from]
803                        && !list[*to]
804                        && (match special {
805                            EpsilonType::StartAnchor => idx == 0,
806                            EpsilonType::EndAnchor => idx == text.len(),
807                            _ => true,
808                        })
809                    {
810                        stack.push(*to);
811                        list[*to] = true;
812                    }
813                }
814            }
815        };
816
817        for (i, c) in text.char_indices() {
818            propogate_epsilon(&mut list, i);
819            for (from, state) in self.states.iter().enumerate() {
820                if !list[from] {
821                    continue;
822                }
823
824                for WorkingTransition { to, symbol } in &state.transitions {
825                    if symbol.check(c) {
826                        new_list[*to] = true;
827                    }
828                }
829            }
830            let tmp = list;
831            list = new_list;
832            new_list = tmp;
833            new_list.fill(false);
834        }
835        propogate_epsilon(&mut list, text.len());
836        return *list.last().unwrap_or(&false);
837    }
838}
839impl std::fmt::Display for WorkingNFA {
840    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
841        for (i, state) in self.states.iter().enumerate() {
842            writeln!(f, "State {i}:")?;
843            for e in &state.epsilons {
844                writeln!(f, "  {e}")?;
845            }
846            for t in &state.transitions {
847                writeln!(f, "  {t}")?;
848            }
849        }
850        return Ok(());
851    }
852}
853
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{config::Config, parse_tree::ERE};

    /// Parses `pattern` with the default [`Config`], asserts that the simplifier
    /// reports `expected_groups` capture groups, and builds the [`WorkingNFA`].
    fn compile(pattern: &str, expected_groups: usize) -> WorkingNFA {
        let ere = ERE::parse_str(pattern).unwrap();
        let (tree, capture_groups) = SimplifiedTreeNode::from_ere(&ere, &Config::default());
        assert_eq!(
            capture_groups, expected_groups,
            "capture group count for {pattern:?}"
        );
        return WorkingNFA::new(&tree);
    }

    /// Asserts that `nfa` accepts every string in `accepted` and rejects every
    /// string in `rejected`, naming the offending input on failure.
    fn check(nfa: &WorkingNFA, accepted: &[&str], rejected: &[&str]) {
        for text in accepted {
            assert!(nfa.test(text), "expected {text:?} to match");
        }
        for text in rejected {
            assert!(!nfa.test(text), "expected {text:?} not to match");
        }
    }

    #[test]
    fn abbc_raw() {
        let nfa = WorkingNFA {
            states: vec![
                WorkingState::new().with_transition(1, 'a'.into()),
                WorkingState::new().with_transition(2, 'b'.into()),
                WorkingState::new()
                    .with_transition(3, 'c'.into())
                    .with_epsilon(1),
                WorkingState::new(),
            ],
        };
        println!("{}", nfa.to_tikz(true));

        check(
            &nfa,
            &["abc", "abbc", "abbbc", "abbbbc"],
            &["ac", "abcc", "bac", "acb"],
        );
    }

    #[test]
    fn phone_number() {
        let nfa = compile(r"^(\+1 )?[0-9]{3}-[0-9]{3}-[0-9]{4}$", 2);
        println!("{}", nfa.to_tikz(true));

        check(
            &nfa,
            &[
                "012-345-6789",
                "987-654-3210",
                "+1 555-555-5555",
                "123-555-9876",
            ],
            &[
                "abcd",
                "0123456789",
                "012--345-6789",
                "(555) 555-5555",
                "1 555-555-5555",
            ],
        );
    }

    #[test]
    fn double_loop() {
        let nfa = compile(r"^.*(.*)*$", 2);

        check(&nfa, &["", "asdf", "1234567", "0"], &["\0"]);
    }

    #[test]
    fn good_anchored_start() {
        let nfa = compile(r"^a|b*^c|d^|n", 1);

        check(
            &nfa,
            &["a", "c", "cq", "wwwnwww"],
            &["", "qb", "qc", "b", "bc", "bbbbbbc", "d"],
        );
    }

    #[test]
    fn good_anchored_end() {
        let nfa = compile(r"a$|b$c*|$d|n", 1);
        println!("{}", nfa.to_tikz(true));

        check(
            &nfa,
            &["a", "b", "qb", "wwwnwww"],
            &["", "bq", "qc", "c", "bc", "bcccccc", "d"],
        );
    }

    #[test]
    fn range_digit() {
        let nfa = compile(r"^[[:digit:].]$", 1);

        check(
            &nfa,
            &["0", "1", "9", "."],
            &["", "a", "11", "1.", ".2", "09", "d"],
        );
    }
}