ere_core/
working_nfa.rs

1//! Implements the primary compile-time intermediate [`WorkingNFA`] structure for optimization.
2
3use crate::parse_tree::Atom;
4use crate::simplified_tree::SimplifiedTreeNode;
5use quote::{quote, ToTokens};
6
/// The kind of an epsilon (zero-width) transition in the [`WorkingNFA`].
///
/// Besides plain epsilons, zero-width anchors and capture-group boundaries
/// are also represented as epsilon transitions carrying one of these markers.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum EpsilonType {
    /// A plain epsilon transition with no extra condition.
    None,
    /// Only satisfied at the start of the text (`^`).
    StartAnchor,
    /// Only satisfied at the end of the text (`$`).
    EndAnchor,
    /// Marks the start of the capture group with the given number.
    StartCapture(usize),
    /// Marks the end of the capture group with the given number.
    EndCapture(usize),
}
15impl ToTokens for EpsilonType {
16    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
17        match self {
18            EpsilonType::None => tokens.extend(quote! { ::ere::working_nfa::EpsilonType::None }),
19            EpsilonType::StartAnchor => {
20                tokens.extend(quote! { ::ere::working_nfa::EpsilonType::StartAnchor })
21            }
22            EpsilonType::EndAnchor => {
23                tokens.extend(quote! { ::ere::working_nfa::EpsilonType::EndAnchor })
24            }
25            EpsilonType::StartCapture(group_num) => tokens.extend(quote! {
26                ::ere::working_nfa::EpsilonType::StartCapture(#group_num)
27            }),
28            EpsilonType::EndCapture(group_num) => tokens.extend(quote! {
29                ::ere::working_nfa::EpsilonType::EndCapture(#group_num)
30            }),
31        };
32    }
33}
34
35/// An epsilon transition for the [`WorkingNFA`]
36#[derive(Debug, Clone, Copy, PartialEq, Eq)]
37pub struct EpsilonTransition {
38    pub(crate) to: usize,
39    pub(crate) special: EpsilonType,
40}
41impl EpsilonTransition {
42    pub(crate) const fn new(to: usize) -> EpsilonTransition {
43        return EpsilonTransition {
44            to,
45            special: EpsilonType::None,
46        };
47    }
48    pub(crate) const fn with_offset(self, offset: usize) -> EpsilonTransition {
49        return EpsilonTransition {
50            to: self.to + offset,
51            special: self.special,
52        };
53    }
54    pub(crate) fn inplace_offset(&mut self, offset: usize) {
55        self.to += offset;
56    }
57    pub(crate) const fn add_offset(&self, offset: usize) -> EpsilonTransition {
58        return EpsilonTransition {
59            to: self.to + offset,
60            special: self.special,
61        };
62    }
63    /// Only intended for internal use by macros.
64    pub const fn __load(to: usize, special: EpsilonType) -> EpsilonTransition {
65        return EpsilonTransition { to, special };
66    }
67}
68impl std::fmt::Display for EpsilonTransition {
69    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70        return write!(f, "-> {}", self.to);
71    }
72}
73impl ToTokens for EpsilonTransition {
74    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
75        let EpsilonTransition { to, special } = self;
76        tokens.extend(quote! {
77            ere_core::working_nfa::EpsilonTransition::__load(
78                #to,
79                #special,
80            )
81        });
82    }
83}
84
85#[derive(Debug, Clone, PartialEq, Eq)]
86pub(crate) struct WorkingTransition {
87    pub(crate) to: usize,
88    pub(crate) symbol: Atom,
89}
90impl WorkingTransition {
91    pub fn new(to: usize, symbol: Atom) -> WorkingTransition {
92        return WorkingTransition { to, symbol };
93    }
94    pub fn with_offset(mut self, offset: usize) -> WorkingTransition {
95        self.inplace_offset(offset);
96        return self;
97    }
98    pub fn inplace_offset(&mut self, offset: usize) {
99        self.to += offset;
100    }
101    pub fn add_offset(&self, offset: usize) -> WorkingTransition {
102        return WorkingTransition {
103            to: self.to + offset,
104            symbol: self.symbol.clone(),
105        };
106    }
107}
108impl std::fmt::Display for WorkingTransition {
109    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
110        return write!(f, "-({})> {}", self.symbol, self.to);
111    }
112}
113
114/// The symbol transitions are assumed to have priority before epsilon transitions,
115/// meaning that of all the propogated next symbol transitions available,
116/// those going out from the previous transition's destination will come first.
117///
118/// So, if the symbol transitions should have a lower priority than those found
119/// via an epsilon transition, they should be in a new state with a lower priority
120/// epsilon transition to it.
121/// Due to the way we construct NFAs, this should initially be the case--
122/// No state has both incoming and outgoing symbol transitions
123/// (they are always separated by at least one epsilon transition) and priority is maintained.
124/// We then can optimize where there are no competing transitions.
125#[derive(Debug, Clone)]
126pub struct WorkingState {
127    pub(crate) transitions: Vec<WorkingTransition>,
128    pub(crate) epsilons: Vec<EpsilonTransition>,
129}
130impl WorkingState {
131    pub const fn new() -> WorkingState {
132        return WorkingState {
133            transitions: Vec::new(),
134            epsilons: Vec::new(),
135        };
136    }
137    pub fn with_transition(mut self, to: usize, symbol: Atom) -> WorkingState {
138        self.transitions.push(WorkingTransition::new(to, symbol));
139        return self;
140    }
141    pub fn with_epsilon(mut self, to: usize) -> WorkingState {
142        self.epsilons.push(EpsilonTransition::new(to));
143        return self;
144    }
145    pub fn with_epsilon_special(mut self, to: usize, special: EpsilonType) -> WorkingState {
146        self.epsilons.push(EpsilonTransition { to, special });
147        return self;
148    }
149    pub fn with_offset(mut self, offset: usize) -> WorkingState {
150        self.inplace_offset(offset);
151        return self;
152    }
153    pub fn inplace_offset(&mut self, offset: usize) {
154        for t in &mut self.transitions {
155            t.inplace_offset(offset);
156        }
157        for e in &mut self.epsilons {
158            e.inplace_offset(offset);
159        }
160    }
161    pub fn add_offset(&self, offset: usize) -> WorkingState {
162        return WorkingState {
163            transitions: self
164                .transitions
165                .iter()
166                .map(|t| t.add_offset(offset))
167                .collect(),
168            epsilons: self.epsilons.iter().map(|e| e.add_offset(offset)).collect(),
169        };
170    }
171}
172impl std::fmt::Display for WorkingState {
173    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
174        for t in &self.transitions {
175            writeln!(f, "  {t}")?;
176        }
177        for e in &self.epsilons {
178            writeln!(f, "  {e}")?;
179        }
180        return Ok(());
181    }
182}
183
184/// Each NFA has one start state (`0`) and one accept state (`states.len() - 1`)
185#[derive(Debug, Clone)]
186pub struct WorkingNFA {
187    pub(crate) states: Vec<WorkingState>,
188}
189impl WorkingNFA {
190    /// Makes an NFA that matches with zero length.
191    fn nfa_empty() -> WorkingNFA {
192        let states = vec![WorkingState::new()];
193        return WorkingNFA { states };
194    }
195    /// Makes an NFA matching a some symbol.
196    fn nfa_symbol(c: &Atom) -> WorkingNFA {
197        let states = vec![
198            WorkingState::new().with_transition(1, c.clone()),
199            WorkingState::new(),
200        ];
201        return WorkingNFA { states };
202    }
203    /// Makes a union of NFAs.
204    fn nfa_union(nodes: &[WorkingNFA]) -> WorkingNFA {
205        let states_count = 2 + nodes.iter().map(|n| n.states.len()).sum::<usize>();
206        let mut states = vec![WorkingState::new()];
207        for nfa in nodes {
208            let sub_nfa_start = states.len();
209            states[0]
210                .epsilons
211                .push(EpsilonTransition::new(sub_nfa_start));
212            states.extend(
213                nfa.states
214                    .iter()
215                    .map(|state| state.add_offset(sub_nfa_start)),
216            );
217            states
218                .last_mut()
219                .unwrap()
220                .epsilons
221                .push(EpsilonTransition::new(states_count - 1));
222        }
223        states.push(WorkingState::new());
224        assert_eq!(states_count, states.len());
225
226        return WorkingNFA { states };
227    }
228    fn build_union(nodes: &[SimplifiedTreeNode]) -> WorkingNFA {
229        let sub_nfas: Vec<WorkingNFA> = nodes.iter().map(WorkingNFA::build).collect();
230        return WorkingNFA::nfa_union(&sub_nfas);
231    }
232    /// Wraps an NFA part in a capture group.
233    fn nfa_capture(nfa: &WorkingNFA, group_num: usize) -> WorkingNFA {
234        let states_count = 2 + nfa.states.len();
235        let mut states: Vec<WorkingState> = std::iter::once(
236            WorkingState::new().with_epsilon_special(1, EpsilonType::StartCapture(group_num)),
237        )
238        .chain(nfa.states.iter().map(|state| state.add_offset(1)))
239        .chain(std::iter::once(WorkingState::new()))
240        .collect();
241        assert_eq!(states_count, states.len());
242        states[states_count - 2].epsilons.push(EpsilonTransition {
243            to: states_count - 1,
244            special: EpsilonType::EndCapture(group_num),
245        });
246
247        return WorkingNFA { states };
248    }
249    fn build_capture(tree: &SimplifiedTreeNode, group_num: usize) -> WorkingNFA {
250        let nfa = WorkingNFA::build(tree);
251        return WorkingNFA::nfa_capture(&nfa, group_num);
252    }
253    /// Makes an NFA that matches a concatenation of NFAs.
254    fn nfa_concat<T: IntoIterator<Item = WorkingNFA>>(nodes: T) -> WorkingNFA {
255        let mut states = vec![WorkingState::new().with_epsilon(1)];
256
257        for nfa in nodes {
258            let states_count = states.len();
259            states.extend(
260                nfa.states
261                    .into_iter()
262                    .map(|state| state.with_offset(states_count)),
263            );
264            let states_count = states.len();
265            states
266                .last_mut()
267                .unwrap()
268                .epsilons
269                .push(EpsilonTransition::new(states_count));
270        }
271
272        states.push(WorkingState::new());
273        return WorkingNFA { states };
274    }
275    fn build_concat<'a, T: IntoIterator<Item = &'a SimplifiedTreeNode>>(nodes: T) -> WorkingNFA {
276        return WorkingNFA::nfa_concat(nodes.into_iter().map(WorkingNFA::build));
277    }
278    /// Makes an NFA that matches some NFA concatenated with itself multiple times.
279    fn nfa_repeat(nfa: &WorkingNFA, times: usize) -> WorkingNFA {
280        return WorkingNFA::nfa_concat(std::iter::repeat(nfa).cloned().take(times));
281    }
282    fn build_repeat(tree: &SimplifiedTreeNode, times: usize) -> WorkingNFA {
283        let nfa = WorkingNFA::build(tree);
284        return WorkingNFA::nfa_repeat(&nfa, times);
285    }
286    /// Makes an NFA that matches some NFA concatenated with itself up to some number of times.
287    fn nfa_upto(nfa: &WorkingNFA, times: usize, longest: bool) -> WorkingNFA {
288        let end_state_idx = 1 + (nfa.states.len() + 1) * times;
289
290        let state0 = if longest {
291            WorkingState::new()
292                .with_epsilon(1)
293                .with_epsilon(end_state_idx - 1)
294        } else {
295            WorkingState::new()
296                .with_epsilon(end_state_idx - 1)
297                .with_epsilon(1)
298        };
299        let mut states = vec![state0];
300        for i in 0..times {
301            let states_count = states.len();
302            states.extend(
303                nfa.states
304                    .iter()
305                    .map(|state| state.add_offset(states_count)),
306            );
307            let transition_state_idx = states.len();
308            states
309                .last_mut()
310                .unwrap()
311                .epsilons
312                .push(EpsilonTransition::new(transition_state_idx));
313            let mut transition_state = WorkingState::new();
314            if i + 1 != times {
315                if longest {
316                    transition_state
317                        .epsilons
318                        .push(EpsilonTransition::new(states.len() + 1));
319                }
320
321                transition_state
322                    .epsilons
323                    .push(EpsilonTransition::new(end_state_idx - 1));
324                if !longest {
325                    transition_state
326                        .epsilons
327                        .push(EpsilonTransition::new(states.len() + 1));
328                }
329            }
330            states.push(transition_state);
331        }
332
333        return WorkingNFA { states };
334    }
335    fn build_upto(tree: &SimplifiedTreeNode, times: usize, longest: bool) -> WorkingNFA {
336        let nfa = WorkingNFA::build(tree);
337        return WorkingNFA::nfa_upto(&nfa, times, longest);
338    }
339    /// Makes an NFA that matches some NFA concatenated with itself any number of times.
340    fn nfa_star(nfa: WorkingNFA, longest: bool) -> WorkingNFA {
341        let end_state_idx = 1 + nfa.states.len();
342        let mut start_state = WorkingState::new();
343        if !longest {
344            start_state
345                .epsilons
346                .push(EpsilonTransition::new(end_state_idx));
347        }
348        start_state.epsilons.push(EpsilonTransition::new(1));
349        if longest {
350            start_state
351                .epsilons
352                .push(EpsilonTransition::new(end_state_idx));
353        }
354        let mut states: Vec<WorkingState> = std::iter::once(start_state)
355            .chain(nfa.states.into_iter().map(|state| state.with_offset(1)))
356            .chain(std::iter::once(WorkingState::new()))
357            .collect();
358        states[end_state_idx - 1]
359            .epsilons
360            .push(EpsilonTransition::new(0));
361        return WorkingNFA { states };
362    }
363    fn build_star(tree: &SimplifiedTreeNode, longest: bool) -> WorkingNFA {
364        let nfa = WorkingNFA::build(tree);
365        return WorkingNFA::nfa_star(nfa, longest);
366    }
367    /// Makes an NFA that matches zero length but only at the text start
368    fn nfa_start() -> WorkingNFA {
369        let states = vec![
370            WorkingState::new().with_epsilon_special(1, EpsilonType::StartAnchor),
371            WorkingState::new(),
372        ];
373        return WorkingNFA { states };
374    }
375    /// Makes an NFA that matches zero length but only at the text end
376    fn nfa_end() -> WorkingNFA {
377        let states = vec![
378            WorkingState::new().with_epsilon_special(1, EpsilonType::EndAnchor),
379            WorkingState::new(),
380        ];
381        return WorkingNFA { states };
382    }
383    /// Makes an NFA that never matches.
384    fn nfa_never() -> WorkingNFA {
385        let states = vec![WorkingState::new(), WorkingState::new()];
386        return WorkingNFA { states };
387    }
388    /// Recursively builds an inefficient but valid NFA based loosely on Thompson's Algorithm.
389    ///
390    /// Should be optimized using [`WorkingNFA::optimize_pass`]
391    pub fn build(tree: &SimplifiedTreeNode) -> WorkingNFA {
392        return match tree {
393            SimplifiedTreeNode::Empty => WorkingNFA::nfa_empty(),
394            SimplifiedTreeNode::Symbol(c) => WorkingNFA::nfa_symbol(c),
395            SimplifiedTreeNode::Union(nodes) => WorkingNFA::build_union(nodes),
396            SimplifiedTreeNode::Capture(tree, group_num) => {
397                WorkingNFA::build_capture(&tree, *group_num)
398            }
399            SimplifiedTreeNode::Concat(nodes) => WorkingNFA::build_concat(nodes),
400            SimplifiedTreeNode::Repeat(tree, times) => WorkingNFA::build_repeat(tree, times.get()),
401            SimplifiedTreeNode::UpTo(tree, times, longest) => {
402                WorkingNFA::build_upto(tree, times.get(), *longest)
403            }
404            SimplifiedTreeNode::Star(tree, longest) => WorkingNFA::build_star(tree, *longest),
405            SimplifiedTreeNode::Start => WorkingNFA::nfa_start(),
406            SimplifiedTreeNode::End => WorkingNFA::nfa_end(),
407            SimplifiedTreeNode::Never => WorkingNFA::nfa_never(),
408        };
409    }
410
411    /// Creates an NFA with the default `.*?` loops at the start and end (though they may be optimized away if not needed).
412    pub fn new(tree: &SimplifiedTreeNode) -> WorkingNFA {
413        return Self::new_loop_opt(tree, true, true);
414    }
415    /// Creates an NFA but allowing specification of whether to include the `.*?` loops at the start and end.
416    pub fn new_loop_opt(tree: &SimplifiedTreeNode, start_loop: bool, end_loop: bool) -> WorkingNFA {
417        let mut nfa = WorkingNFA::build(tree);
418
419        nfa.clean_start_anchors();
420        nfa.clean_end_anchors();
421
422        // add loops at start and end in case we lack anchors
423        if start_loop {
424            nfa = WorkingNFA::nfa_concat([
425                WorkingNFA::nfa_star(
426                    WorkingNFA::nfa_symbol(&Atom::NonmatchingList(Vec::new())),
427                    false,
428                ),
429                nfa,
430            ]);
431        }
432        if end_loop {
433            nfa = WorkingNFA::nfa_concat([
434                nfa,
435                WorkingNFA::nfa_star(
436                    WorkingNFA::nfa_symbol(&Atom::NonmatchingList(Vec::new())),
437                    false,
438                ),
439            ]);
440        }
441
442        let zero_symbol_states: Vec<bool> =
443            std::iter::zip(nfa.nodes_after_end(), nfa.nodes_before_start())
444                .map(|(a, b)| a || b)
445                .collect();
446        for (from, state) in nfa.states.iter_mut().enumerate() {
447            if zero_symbol_states[from] {
448                state.transitions = Vec::new();
449            }
450        }
451
452        // nfa.remove_unreachable();
453        // Finally, do normal optimization passes
454        // println!("{}", nfa.to_tikz(true));
455        while nfa.optimize_pass() {
456            // println!("{}", nfa.to_tikz(true));
457        }
458        nfa.remove_unreachable();
459        return nfa;
460    }
461
462    /// Removes start anchors that will never be satisfied
463    /// (basically turning them into a `Never` to allow further optimization)
464    fn clean_start_anchors(&mut self) {
465        let mut zero_len_reachable = vec![false; self.states.len()];
466        zero_len_reachable[0] = true;
467        let mut stack = vec![0];
468        while let Some(state) = stack.pop() {
469            for e in &self.states[state].epsilons {
470                if !zero_len_reachable[e.to] {
471                    stack.push(e.to);
472                }
473                zero_len_reachable[e.to] = true;
474            }
475        }
476
477        for (i, state) in self.states.iter_mut().enumerate() {
478            state
479                .epsilons
480                .retain(|e| e.special != EpsilonType::StartAnchor || zero_len_reachable[i]);
481        }
482    }
483
484    /// Removes end anchors that will never be satisfied
485    /// (basically turning them into a `Never` to allow further optimization)    
486    fn clean_end_anchors(&mut self) {
487        let mut zero_len_reachable = vec![false; self.states.len()];
488        zero_len_reachable[self.states.len() - 1] = true;
489
490        let mut reverse_epsilons = vec![Vec::new(); self.states.len()];
491        for (i, state) in self.states.iter().enumerate() {
492            for e in &state.epsilons {
493                reverse_epsilons[e.to].push(i);
494            }
495        }
496
497        let mut stack = vec![self.states.len() - 1];
498        while let Some(state) = stack.pop() {
499            for src in &reverse_epsilons[state] {
500                if !zero_len_reachable[*src] {
501                    stack.push(*src);
502                }
503                zero_len_reachable[*src] = true;
504            }
505        }
506
507        for state in self.states.iter_mut() {
508            state
509                .epsilons
510                .retain(|e| e.special != EpsilonType::EndAnchor || zero_len_reachable[e.to]);
511        }
512    }
513    /// Finds all nodes that are only ever visited after a `$`.
514    fn nodes_after_end(&self) -> Vec<bool> {
515        let mut nodes = vec![true; self.states.len()];
516        nodes[0] = false;
517
518        let mut stack = vec![0];
519        while let Some(from) = stack.pop() {
520            for e in self.states[from].epsilons.iter() {
521                if nodes[e.to] && e.special != EpsilonType::EndAnchor {
522                    nodes[e.to] = false;
523                    stack.push(e.to);
524                }
525            }
526            for t in self.states[from].transitions.iter() {
527                if nodes[t.to] {
528                    nodes[t.to] = false;
529                    stack.push(t.to);
530                }
531            }
532        }
533        return nodes;
534    }
535    /// Finds all nodes that are only ever visited before a `^`.
536    fn nodes_before_start(&self) -> Vec<bool> {
537        let mut reverse = vec![Vec::new(); self.states.len()];
538        for (i, state) in self.states.iter().enumerate() {
539            for e in &state.epsilons {
540                if e.special != EpsilonType::StartAnchor {
541                    reverse[e.to].push(i);
542                }
543            }
544            for t in &state.transitions {
545                reverse[t.to].push(i);
546            }
547        }
548
549        let mut nodes = vec![true; self.states.len()];
550        nodes[self.states.len() - 1] = false;
551
552        let mut stack = vec![self.states.len() - 1];
553        while let Some(to) = stack.pop() {
554            for from in &reverse[to] {
555                if nodes[*from] {
556                    nodes[*from] = false;
557                    stack.push(*from);
558                }
559            }
560        }
561        return nodes;
562    }
563
564    /// Helper function for removing a set of states.
565    ///
566    /// These states should have no incoming transitions.
567    fn remove_dead_states<T: IntoIterator<Item = bool>>(&mut self, dead_states: T) {
568        let state_map: Vec<usize> = dead_states
569            .into_iter()
570            .scan(0, |s, dead| {
571                if dead {
572                    return Some(usize::MAX);
573                } else {
574                    let out = *s;
575                    *s += 1;
576                    return Some(out);
577                }
578            })
579            .collect();
580        self.states = self
581            .states
582            .iter()
583            .enumerate()
584            .filter(|(i, _)| state_map[*i] != usize::MAX)
585            .map(|(_, state)| state)
586            .cloned()
587            .collect();
588
589        for state in &mut self.states {
590            for t in &mut state.transitions {
591                t.to = state_map[t.to];
592            }
593            for t in &mut state.epsilons {
594                t.to = state_map[t.to];
595            }
596        }
597    }
598
599    /// De-duplicates identical transitions
600    /// (`a -e> b`, `a -e> b`) -> (`a -e> b`)
601    ///
602    /// Returns `true` if changes were made.
603    /// The highest-priority transition will be kept.
604    ///
605    /// ---
606    ///
607    /// Typically these are caused by optimizations that merge paths.
608    fn dedupe_transitions(&mut self) -> bool {
609        let mut changed = false;
610
611        for state in &mut self.states {
612            // state transitions
613            let keep: Vec<bool> = state
614                .transitions
615                .iter()
616                .enumerate()
617                .map(|(i, e)| state.transitions[..=i].contains(e))
618                .collect();
619            let prev_len = state.transitions.len();
620            let mut i = 0;
621            state.transitions.retain(|_| {
622                let idx = i;
623                i += 1;
624                return keep[idx];
625            });
626            if state.transitions.len() != prev_len {
627                changed = true;
628            }
629
630            // epsilon transitions
631            let keep: Vec<bool> = state
632                .epsilons
633                .iter()
634                .enumerate()
635                .map(|(i, e)| !state.epsilons[..i].contains(e))
636                .collect();
637            let prev_len = state.epsilons.len();
638            let mut i = 0;
639            state.epsilons.retain(|_| {
640                let idx = i;
641                i += 1;
642                return keep[idx];
643            });
644            if state.epsilons.len() != prev_len {
645                changed = true;
646            }
647        }
648
649        return changed;
650    }
651
652    /// Various operations to optimize the NFA graph.
653    ///
654    /// Returns `true` if changes were made (meaning another pass should be tried).
655    fn optimize_pass(&mut self) -> bool {
656        let mut changed = false;
657        let state_count = self.states.len();
658        debug_assert!(state_count >= 2);
659
660        let mut dead_states = vec![false; self.states.len()];
661
662        // Skip redundant states
663        // Special transitions (anchors + capture groups) are treated similar to non-epsilon transitions
664        'state_loop: for state_idx in 1..state_count - 1 {
665            // merge states with same outgoing
666            for other_idx in 0..state_count - 1 {
667                if self.states[state_idx].epsilons == self.states[other_idx].epsilons
668                    && self.states[state_idx].transitions == self.states[other_idx].transitions
669                    && state_idx != other_idx
670                    && (!self.states[state_idx].epsilons.is_empty()
671                        || !self.states[state_idx].transitions.is_empty())
672                {
673                    // TODO: if the two states have self-loops, they currently are not counted
674                    // as equivalent even if they should be.
675
676                    // I think symbol transition order matters here because it may have been created by previous
677                    // optimizations, which originated from epsilon transitions where it was important.
678                    dead_states[state_idx] = true;
679                    changed = true;
680                    self.states[state_idx].epsilons = Vec::new();
681                    self.states[state_idx].transitions = Vec::new();
682                    // divert other states to other
683                    for s in &mut self.states {
684                        for ep in &mut s.epsilons {
685                            if ep.to == state_idx {
686                                ep.to = other_idx;
687                            }
688                        }
689                        for tr in &mut s.transitions {
690                            if tr.to == state_idx {
691                                tr.to = other_idx;
692                            }
693                        }
694                    }
695                    continue 'state_loop;
696                }
697            }
698
699            // dedupe transitions
700            changed |= self.dedupe_transitions();
701
702            // skip redundant
703            let incoming: Vec<(usize, usize)> = self
704                .states
705                .iter()
706                .enumerate()
707                .flat_map(|(s_i, s)| s.transitions.iter().enumerate().map(move |(t, _)| (s_i, t)))
708                .filter(|(s, t)| self.states[*s].transitions[*t].to == state_idx)
709                .collect();
710            let incoming_eps: Vec<(usize, usize)> = self
711                .states
712                .iter()
713                .enumerate()
714                .flat_map(|(s_i, s)| s.epsilons.iter().enumerate().map(move |(e, _)| (s_i, e)))
715                .filter(|(s, e)| self.states[*s].epsilons[*e].to == state_idx)
716                .collect();
717
718            match (
719                incoming.as_slice(),
720                incoming_eps.as_slice(),
721                self.states[state_idx].transitions.len(),
722                self.states[state_idx].epsilons.len(),
723            ) {
724                // `as -xes> b -e> c` can become `as -xes> c` (assuming no other transitions)
725                (incoming, incoming_eps, 0, 1)
726                    if self.states[state_idx].epsilons[0].special == EpsilonType::None =>
727                {
728                    let to = self.states[state_idx].epsilons[0].to;
729                    for (s, t) in incoming {
730                        self.states[*s].transitions[*t].to = to;
731                    }
732                    for (s, e) in incoming_eps {
733                        self.states[*s].epsilons[*e].to = to;
734                    }
735                    dead_states[state_idx] = true;
736                    self.states[state_idx].epsilons = Vec::new();
737                    changed = true;
738                    continue;
739                }
740                // `a -e> b -es> cs` can become `a -es> cs` (assuming no other transitions)
741                (&[], &[(incoming_state, incoming_eps)], 0, _)
742                    if self.states[incoming_state].epsilons[incoming_eps].special
743                        == EpsilonType::None =>
744                {
745                    let outgoing_eps = std::mem::take(&mut self.states[state_idx].epsilons);
746                    let after = self.states[incoming_state]
747                        .epsilons
748                        .split_off(incoming_eps + 1);
749                    self.states[incoming_state].epsilons.pop();
750                    self.states[incoming_state]
751                        .epsilons
752                        .extend_from_slice(&outgoing_eps);
753                    self.states[incoming_state]
754                        .epsilons
755                        .extend_from_slice(&after);
756
757                    dead_states[state_idx] = true;
758                    changed = true;
759                    continue;
760                }
761                _ => {}
762            }
763
764            // TODO:
765            // `a -e> b -xes> cs` can become `a -xes> cs` (assuming no other transitions)
766            // `a -e> b -e> a` can combine `a` and `b` (including other transitions)
767            // TODO: might cause additional overhead in some cases, should we do
768            // ??? `a -x> b -es> cs` can become `a -xs> cs`
769            // ??? `as -es> b -x> c` can become `as -xs> c`
770        }
771
772        if changed {
773            self.remove_dead_states(dead_states);
774            return true;
775        }
776        return false;
777    }
778
779    /// Finds the states that can be reached from the start via any path
780    fn states_reachable_start(&self) -> Vec<bool> {
781        let mut reachable = vec![false; self.states.len()];
782        reachable[0] = true;
783        let mut stack = vec![0];
784
785        while let Some(state) = stack.pop() {
786            for src in &self.states[state].epsilons {
787                if !reachable[src.to] {
788                    stack.push(src.to);
789                }
790                reachable[src.to] = true;
791            }
792            for src in &self.states[state].transitions {
793                if !reachable[src.to] {
794                    stack.push(src.to);
795                }
796                reachable[src.to] = true;
797            }
798        }
799
800        return reachable;
801    }
802    /// Finds the states that can reach the end via any path
803    fn states_reachable_end(&self) -> Vec<bool> {
804        let mut reverse = vec![Vec::new(); self.states.len()];
805        for (i, state) in self.states.iter().enumerate() {
806            for e in &state.epsilons {
807                reverse[e.to].push(i);
808            }
809            for t in &state.transitions {
810                reverse[t.to].push(i);
811            }
812        }
813
814        let mut reachable = vec![false; self.states.len()];
815        reachable[self.states.len() - 1] = true;
816        let mut stack = vec![self.states.len() - 1];
817
818        while let Some(state) = stack.pop() {
819            for src in &reverse[state] {
820                if !reachable[*src] {
821                    stack.push(*src);
822                }
823                reachable[*src] = true;
824            }
825        }
826
827        return reachable;
828    }
829
830    /// Removes all nodes that cannot be reached or cannot reach the end.
831    ///
832    /// Ignores special epsilon types (so should be called after they have been resolved)
833    fn remove_unreachable(&mut self) {
834        let reach_start = self.states_reachable_start();
835        let reach_end = self.states_reachable_end();
836
837        // Remove transitions that involve redundant states
838        for state in &mut self.states {
839            state
840                .epsilons
841                .retain(|e| reach_start[e.to] && reach_end[e.to]);
842            state
843                .transitions
844                .retain(|t| reach_start[t.to] && reach_end[t.to]);
845        }
846
847        // Then remove the states
848        self.remove_dead_states(
849            std::iter::zip(reach_start.into_iter(), reach_end.into_iter()).map(|(a, b)| !a || !b),
850        );
851    }
852
853    /// Finds the number of capture groups in this NFA
854    pub fn num_capture_groups(&self) -> usize {
855        return self
856            .states
857            .iter()
858            .flat_map(|state| &state.epsilons)
859            .map(|eps| match eps.special {
860                EpsilonType::StartCapture(n) => n,
861                _ => 0,
862            })
863            .max()
864            .unwrap_or(0)
865            + 1;
866    }
867
868    /// Returns whether each there is any matching path where the capture group is unused
869    pub fn capture_group_is_optional(&self, group_num: usize) -> bool {
870        let mut reached_states = vec![false; self.states.len()];
871        reached_states[0] = true;
872        let mut stack = vec![0];
873        while let Some(idx) = stack.pop() {
874            for t in &self.states[idx].transitions {
875                if !reached_states[t.to] {
876                    reached_states[t.to] = true;
877                    stack.push(t.to);
878                }
879            }
880            for e in &self.states[idx].epsilons {
881                if !reached_states[e.to] && e.special != EpsilonType::StartCapture(group_num) {
882                    // the end capture should not be reachable without a preceding
883                    // start capture, so we only need to check the start.
884                    debug_assert_ne!(e.special, EpsilonType::EndCapture(group_num));
885                    reached_states[e.to] = true;
886                    stack.push(e.to);
887                }
888            }
889        }
890
891        return *reached_states.last().unwrap();
892    }
893
894    /// Writes a LaTeX TikZ representation to visualize the graph.
895    ///
896    /// If `include_doc` is `true`, will include the headers.
897    /// Otherwise, you should include `\usepackage{tikz}` and `\usetikzlibrary{automata, positioning}`.
898    pub fn to_tikz(&self, include_doc: bool) -> String {
899        let map_state =
900            |(i, state): (usize, &WorkingState)| -> crate::visualization::LatexGraphState {
901                let transitions =
902                    state
903                        .transitions
904                        .iter()
905                        .map(|t| crate::visualization::LatexGraphTransition {
906                            label: crate::visualization::escape_latex(t.symbol.to_string()),
907                            to: t.to,
908                        });
909                let epsilons = state.epsilons.iter().enumerate().map(|(i, e)| {
910                    let label = match e.special {
911                        EpsilonType::None => format!(r"$\epsilon_{{{i}}}$"),
912                        EpsilonType::StartAnchor => format!(r"{{\textasciicircum}}$_{{{i}}}$"),
913                        EpsilonType::EndAnchor => format!(r"$\$_{{{i}}}$"),
914                        EpsilonType::StartCapture(group) => format!("${group}(_{{{i}}}$"),
915                        EpsilonType::EndCapture(group) => format!("$){group}_{{{i}}}$"),
916                    };
917                    return crate::visualization::LatexGraphTransition { label, to: e.to };
918                });
919                let transitions = transitions.chain(epsilons).collect();
920                return crate::visualization::LatexGraphState {
921                    label: format!("q{i}"),
922                    transitions,
923                    initial: i == 0,
924                    accept: i + 1 == self.states.len(),
925                };
926            };
927
928        let graph = crate::visualization::LatexGraph {
929            states: self.states.iter().enumerate().map(map_state).collect(),
930        };
931        return graph.to_tikz(include_doc);
932    }
933
934    /// Using the classical NFA algorithm to do a simple boolean test on a string.
935    pub fn test(&self, text: &str) -> bool {
936        let mut list = vec![false; self.states.len()];
937        let mut new_list = vec![false; self.states.len()];
938        list[0] = true;
939
940        // Adds all states reachable by epsilon transitions
941        let propogate_epsilon = |list: &mut Vec<bool>, idx: usize| {
942            let mut stack: Vec<usize> = list
943                .iter()
944                .enumerate()
945                .filter_map(|(i, set)| set.then_some(i))
946                .collect();
947
948            while let Some(from) = stack.pop() {
949                for EpsilonTransition { to, special } in &self.states[from].epsilons {
950                    if list[from]
951                        && !list[*to]
952                        && (match special {
953                            EpsilonType::StartAnchor => idx == 0,
954                            EpsilonType::EndAnchor => idx == text.len(),
955                            _ => true,
956                        })
957                    {
958                        stack.push(*to);
959                        list[*to] = true;
960                    }
961                }
962            }
963        };
964
965        for (i, c) in text.char_indices() {
966            propogate_epsilon(&mut list, i);
967            for (from, state) in self.states.iter().enumerate() {
968                if !list[from] {
969                    continue;
970                }
971
972                for WorkingTransition { to, symbol } in &state.transitions {
973                    if symbol.check(c) {
974                        new_list[*to] = true;
975                    }
976                }
977            }
978            let tmp = list;
979            list = new_list;
980            new_list = tmp;
981            new_list.fill(false);
982        }
983        propogate_epsilon(&mut list, text.len());
984        return *list.last().unwrap_or(&false);
985    }
986}
987impl std::fmt::Display for WorkingNFA {
988    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
989        for (i, state) in self.states.iter().enumerate() {
990            writeln!(f, "State {i}:")?;
991            for e in &state.epsilons {
992                writeln!(f, "  {e}")?;
993            }
994            for t in &state.transitions {
995                writeln!(f, "  {t}")?;
996            }
997        }
998        return Ok(());
999    }
1000}
1001
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{config::Config, parse_tree::ERE};

    /// Parses `pattern` with the default config, asserts that it produces
    /// `expected_groups` capture groups, and builds the corresponding NFA.
    fn compile(pattern: &str, expected_groups: usize) -> WorkingNFA {
        let ere = ERE::parse_str(pattern).unwrap();
        let (tree, capture_groups) = SimplifiedTreeNode::from_ere(&ere, &Config::default());
        assert_eq!(capture_groups, expected_groups);
        return WorkingNFA::new(&tree);
    }

    /// Asserts that `nfa` accepts every string in `accepted` and rejects every string in
    /// `rejected`, checking them in order.
    fn check(nfa: &WorkingNFA, accepted: &[&str], rejected: &[&str]) {
        for text in accepted {
            assert!(nfa.test(text), "expected {text:?} to match");
        }
        for text in rejected {
            assert!(!nfa.test(text), "expected {text:?} not to match");
        }
    }

    #[test]
    fn abbc_raw() {
        let nfa = WorkingNFA {
            states: vec![
                WorkingState::new().with_transition(1, 'a'.into()),
                WorkingState::new().with_transition(2, 'b'.into()),
                WorkingState::new()
                    .with_transition(3, 'c'.into())
                    .with_epsilon(1),
                WorkingState::new(),
            ],
        };
        println!("{}", nfa.to_tikz(true));

        check(
            &nfa,
            &["abc", "abbc", "abbbc", "abbbbc"],
            &["ac", "abcc", "bac", "acb"],
        );
    }

    #[test]
    fn phone_number() {
        let nfa = compile(r"^(\+1 )?[0-9]{3}-[0-9]{3}-[0-9]{4}$", 2);
        println!("{}", nfa.to_tikz(true));

        check(
            &nfa,
            &["012-345-6789", "987-654-3210", "+1 555-555-5555", "123-555-9876"],
            &[
                "abcd",
                "0123456789",
                "012--345-6789",
                "(555) 555-5555",
                "1 555-555-5555",
            ],
        );
    }

    #[test]
    fn double_loop() {
        let nfa = compile(r"^.*(.*)*$", 2);

        check(&nfa, &["", "asdf", "1234567", "0"], &["\0"]);
    }

    #[test]
    fn good_anchored_start() {
        let nfa = compile(r"^a|b*^c|d^|n", 1);

        check(
            &nfa,
            &["a", "c", "cq", "wwwnwww"],
            &["", "qb", "qc", "b", "bc", "bbbbbbc", "d"],
        );
    }

    #[test]
    fn good_anchored_end() {
        let nfa = compile(r"a$|b$c*|$d|n", 1);
        println!("{}", nfa.to_tikz(true));

        check(
            &nfa,
            &["a", "b", "qb", "wwwnwww"],
            &["", "bq", "qc", "c", "bc", "bcccccc", "d"],
        );
    }

    #[test]
    fn range_digit() {
        let nfa = compile(r"^[[:digit:].]$", 1);

        check(
            &nfa,
            &["0", "1", "9", "."],
            &["", "a", "11", "1.", ".2", "09", "d"],
        );
    }
}