ere_core/
working_nfa.rs

1//! Implements the primary compile-time intermediate [`WorkingNFA`] structure for optimization.
2
3use crate::parse_tree::Atom;
4use crate::simplified_tree::SimplifiedTreeNode;
5use quote::{quote, ToTokens};
6
/// The kind of an epsilon (input-free) transition in a [`WorkingNFA`].
///
/// `None` is a plain, unconditional epsilon. The other variants mark anchors and
/// capture-group boundaries, which the optimizer treats similarly to symbol transitions
/// rather than freely collapsing them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EpsilonType {
    /// A plain, unconditional epsilon transition.
    None,
    /// Only passable at the start of the text (`^`).
    StartAnchor,
    /// Only passable at the end of the text (`$`).
    EndAnchor,
    /// Entering capture group number `n`.
    StartCapture(usize),
    /// Leaving capture group number `n`.
    EndCapture(usize),
}
15impl ToTokens for EpsilonType {
16    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
17        match self {
18            EpsilonType::None => tokens.extend(quote! { ::ere::working_nfa::EpsilonType::None }),
19            EpsilonType::StartAnchor => {
20                tokens.extend(quote! { ::ere::working_nfa::EpsilonType::StartAnchor })
21            }
22            EpsilonType::EndAnchor => {
23                tokens.extend(quote! { ::ere::working_nfa::EpsilonType::EndAnchor })
24            }
25            EpsilonType::StartCapture(group_num) => tokens.extend(quote! {
26                ::ere::working_nfa::EpsilonType::StartCapture(#group_num)
27            }),
28            EpsilonType::EndCapture(group_num) => tokens.extend(quote! {
29                ::ere::working_nfa::EpsilonType::EndCapture(#group_num)
30            }),
31        };
32    }
33}
34
/// An epsilon transition for the [`WorkingNFA`]: a transition that consumes no input.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EpsilonTransition {
    /// Index of the destination state within the owning NFA's state list.
    pub(crate) to: usize,
    /// The kind of epsilon transition (plain, anchor, or capture boundary).
    pub(crate) special: EpsilonType,
}
41impl EpsilonTransition {
42    pub(crate) const fn new(to: usize) -> EpsilonTransition {
43        return EpsilonTransition {
44            to,
45            special: EpsilonType::None,
46        };
47    }
48    pub(crate) const fn with_offset(self, offset: usize) -> EpsilonTransition {
49        return EpsilonTransition {
50            to: self.to + offset,
51            special: self.special,
52        };
53    }
54    pub(crate) fn inplace_offset(&mut self, offset: usize) {
55        self.to += offset;
56    }
57    pub(crate) const fn add_offset(&self, offset: usize) -> EpsilonTransition {
58        return EpsilonTransition {
59            to: self.to + offset,
60            special: self.special,
61        };
62    }
63    /// Only intended for internal use by macros.
64    pub const fn __load(to: usize, special: EpsilonType) -> EpsilonTransition {
65        return EpsilonTransition { to, special };
66    }
67}
68impl std::fmt::Display for EpsilonTransition {
69    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70        return write!(f, "-> {}", self.to);
71    }
72}
73impl ToTokens for EpsilonTransition {
74    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
75        let EpsilonTransition { to, special } = self;
76        tokens.extend(quote! {
77            ere_core::working_nfa::EpsilonTransition::__load(
78                #to,
79                #special,
80            )
81        });
82    }
83}
84
/// A symbol-consuming transition for the [`WorkingNFA`], labelled with `symbol`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct WorkingTransition {
    /// Index of the destination state within the owning NFA's state list.
    pub(crate) to: usize,
    /// The symbol (atom) this transition is labelled with.
    pub(crate) symbol: Atom,
}
90impl WorkingTransition {
91    pub fn new(to: usize, symbol: Atom) -> WorkingTransition {
92        return WorkingTransition { to, symbol };
93    }
94    pub fn with_offset(mut self, offset: usize) -> WorkingTransition {
95        self.inplace_offset(offset);
96        return self;
97    }
98    pub fn inplace_offset(&mut self, offset: usize) {
99        self.to += offset;
100    }
101    pub fn add_offset(&self, offset: usize) -> WorkingTransition {
102        return WorkingTransition {
103            to: self.to + offset,
104            symbol: self.symbol.clone(),
105        };
106    }
107}
108impl std::fmt::Display for WorkingTransition {
109    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
110        return write!(f, "-({})> {}", self.symbol, self.to);
111    }
112}
113
/// A single state of a [`WorkingNFA`], holding its outgoing symbol and epsilon transitions.
///
/// The symbol transitions are assumed to have priority before epsilon transitions,
/// meaning that of all the propagated next symbol transitions available,
/// those going out from the previous transition's destination will come first.
///
/// So, if the symbol transitions should have a lower priority than those found
/// via an epsilon transition, they should be in a new state with a lower priority
/// epsilon transition to it.
/// Due to the way we construct NFAs, this should initially be the case:
/// no state has both incoming and outgoing symbol transitions
/// (they are always separated by at least one epsilon transition), and priority is maintained.
/// We can then optimize wherever there are no competing transitions.
#[derive(Debug, Clone)]
pub struct WorkingState {
    /// Outgoing symbol transitions; their order is significant for priority.
    pub(crate) transitions: Vec<WorkingTransition>,
    /// Outgoing epsilon transitions; earlier entries are preferred (higher priority).
    pub(crate) epsilons: Vec<EpsilonTransition>,
}
130impl WorkingState {
131    pub const fn new() -> WorkingState {
132        return WorkingState {
133            transitions: Vec::new(),
134            epsilons: Vec::new(),
135        };
136    }
137    pub fn with_transition(mut self, to: usize, symbol: Atom) -> WorkingState {
138        self.transitions.push(WorkingTransition::new(to, symbol));
139        return self;
140    }
141    pub fn with_epsilon(mut self, to: usize) -> WorkingState {
142        self.epsilons.push(EpsilonTransition::new(to));
143        return self;
144    }
145    pub fn with_epsilon_special(mut self, to: usize, special: EpsilonType) -> WorkingState {
146        self.epsilons.push(EpsilonTransition { to, special });
147        return self;
148    }
149    pub fn with_offset(mut self, offset: usize) -> WorkingState {
150        self.inplace_offset(offset);
151        return self;
152    }
153    pub fn inplace_offset(&mut self, offset: usize) {
154        for t in &mut self.transitions {
155            t.inplace_offset(offset);
156        }
157        for e in &mut self.epsilons {
158            e.inplace_offset(offset);
159        }
160    }
161    pub fn add_offset(&self, offset: usize) -> WorkingState {
162        return WorkingState {
163            transitions: self
164                .transitions
165                .iter()
166                .map(|t| t.add_offset(offset))
167                .collect(),
168            epsilons: self.epsilons.iter().map(|e| e.add_offset(offset)).collect(),
169        };
170    }
171}
172impl std::fmt::Display for WorkingState {
173    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
174        for t in &self.transitions {
175            writeln!(f, "  {t}")?;
176        }
177        for e in &self.epsilons {
178            writeln!(f, "  {e}")?;
179        }
180        return Ok(());
181    }
182}
183
/// The compile-time intermediate NFA used for optimization.
///
/// Each NFA has one start state (`0`) and one accept state (`states.len() - 1`).
#[derive(Debug, Clone)]
pub struct WorkingNFA {
    /// All states; every transition's `to` is an index into this vector.
    pub(crate) states: Vec<WorkingState>,
}
189impl WorkingNFA {
190    /// Makes an NFA that matches with zero length.
191    fn nfa_empty() -> WorkingNFA {
192        let states = vec![WorkingState::new()];
193        return WorkingNFA { states };
194    }
195    /// Makes an NFA matching a some symbol.
196    fn nfa_symbol(c: &Atom) -> WorkingNFA {
197        let states = vec![
198            WorkingState::new().with_transition(1, c.clone()),
199            WorkingState::new(),
200        ];
201        return WorkingNFA { states };
202    }
203    /// Makes a union of NFAs.
204    fn nfa_union(nodes: &[WorkingNFA]) -> WorkingNFA {
205        let states_count = 2 + nodes.iter().map(|n| n.states.len()).sum::<usize>();
206        let mut states = vec![WorkingState::new()];
207        for nfa in nodes {
208            let sub_nfa_start = states.len();
209            states[0]
210                .epsilons
211                .push(EpsilonTransition::new(sub_nfa_start));
212            states.extend(
213                nfa.states
214                    .iter()
215                    .map(|state| state.add_offset(sub_nfa_start)),
216            );
217            states
218                .last_mut()
219                .unwrap()
220                .epsilons
221                .push(EpsilonTransition::new(states_count - 1));
222        }
223        states.push(WorkingState::new());
224        assert_eq!(states_count, states.len());
225
226        return WorkingNFA { states };
227    }
228    fn build_union(nodes: &[SimplifiedTreeNode]) -> WorkingNFA {
229        let sub_nfas: Vec<WorkingNFA> = nodes.iter().map(WorkingNFA::build).collect();
230        return WorkingNFA::nfa_union(&sub_nfas);
231    }
232    /// Wraps an NFA part in a capture group.
233    fn nfa_capture(nfa: &WorkingNFA, group_num: usize) -> WorkingNFA {
234        let states_count = 2 + nfa.states.len();
235        let mut states: Vec<WorkingState> = std::iter::once(
236            WorkingState::new().with_epsilon_special(1, EpsilonType::StartCapture(group_num)),
237        )
238        .chain(nfa.states.iter().map(|state| state.add_offset(1)))
239        .chain(std::iter::once(WorkingState::new()))
240        .collect();
241        assert_eq!(states_count, states.len());
242        states[states_count - 2].epsilons.push(EpsilonTransition {
243            to: states_count - 1,
244            special: EpsilonType::EndCapture(group_num),
245        });
246
247        return WorkingNFA { states };
248    }
249    fn build_capture(tree: &SimplifiedTreeNode, group_num: usize) -> WorkingNFA {
250        let nfa = WorkingNFA::build(tree);
251        return WorkingNFA::nfa_capture(&nfa, group_num);
252    }
253    /// Makes an NFA that matches a concatenation of NFAs.
254    fn nfa_concat<T: IntoIterator<Item = WorkingNFA>>(nodes: T) -> WorkingNFA {
255        let mut states = vec![WorkingState::new().with_epsilon(1)];
256
257        for nfa in nodes {
258            let states_count = states.len();
259            states.extend(
260                nfa.states
261                    .into_iter()
262                    .map(|state| state.with_offset(states_count)),
263            );
264            let states_count = states.len();
265            states
266                .last_mut()
267                .unwrap()
268                .epsilons
269                .push(EpsilonTransition::new(states_count));
270        }
271
272        states.push(WorkingState::new());
273        return WorkingNFA { states };
274    }
275    fn build_concat<'a, T: IntoIterator<Item = &'a SimplifiedTreeNode>>(nodes: T) -> WorkingNFA {
276        return WorkingNFA::nfa_concat(nodes.into_iter().map(WorkingNFA::build));
277    }
278    /// Makes an NFA that matches some NFA concatenated with itself multiple times.
279    fn nfa_repeat(nfa: &WorkingNFA, times: usize) -> WorkingNFA {
280        return WorkingNFA::nfa_concat(std::iter::repeat(nfa).cloned().take(times));
281    }
282    fn build_repeat(tree: &SimplifiedTreeNode, times: usize) -> WorkingNFA {
283        let nfa = WorkingNFA::build(tree);
284        return WorkingNFA::nfa_repeat(&nfa, times);
285    }
286    /// Makes an NFA that matches some NFA concatenated with itself up to some number of times.
287    fn nfa_upto(nfa: &WorkingNFA, times: usize, longest: bool) -> WorkingNFA {
288        let end_state_idx = 1 + (nfa.states.len() + 1) * times;
289
290        let state0 = if longest {
291            WorkingState::new()
292                .with_epsilon(1)
293                .with_epsilon(end_state_idx - 1)
294        } else {
295            WorkingState::new()
296                .with_epsilon(end_state_idx - 1)
297                .with_epsilon(1)
298        };
299        let mut states = vec![state0];
300        for i in 0..times {
301            let states_count = states.len();
302            states.extend(
303                nfa.states
304                    .iter()
305                    .map(|state| state.add_offset(states_count)),
306            );
307            let transition_state_idx = states.len();
308            states
309                .last_mut()
310                .unwrap()
311                .epsilons
312                .push(EpsilonTransition::new(transition_state_idx));
313            let mut transition_state = WorkingState::new();
314            if i + 1 != times {
315                if longest {
316                    transition_state
317                        .epsilons
318                        .push(EpsilonTransition::new(states.len() + 1));
319                }
320
321                transition_state
322                    .epsilons
323                    .push(EpsilonTransition::new(end_state_idx - 1));
324                if !longest {
325                    transition_state
326                        .epsilons
327                        .push(EpsilonTransition::new(states.len() + 1));
328                }
329            }
330            states.push(transition_state);
331        }
332
333        return WorkingNFA { states };
334    }
335    fn build_upto(tree: &SimplifiedTreeNode, times: usize, longest: bool) -> WorkingNFA {
336        let nfa = WorkingNFA::build(tree);
337        return WorkingNFA::nfa_upto(&nfa, times, longest);
338    }
339    /// Makes an NFA that matches some NFA concatenated with itself any number of times.
340    fn nfa_star(nfa: WorkingNFA, longest: bool) -> WorkingNFA {
341        let end_state_idx = 1 + nfa.states.len();
342        let mut start_state = WorkingState::new();
343        if !longest {
344            start_state
345                .epsilons
346                .push(EpsilonTransition::new(end_state_idx));
347        }
348        start_state.epsilons.push(EpsilonTransition::new(1));
349        if longest {
350            start_state
351                .epsilons
352                .push(EpsilonTransition::new(end_state_idx));
353        }
354        let mut states: Vec<WorkingState> = std::iter::once(start_state)
355            .chain(nfa.states.into_iter().map(|state| state.with_offset(1)))
356            .chain(std::iter::once(WorkingState::new()))
357            .collect();
358        states[end_state_idx - 1]
359            .epsilons
360            .push(EpsilonTransition::new(0));
361        return WorkingNFA { states };
362    }
363    fn build_star(tree: &SimplifiedTreeNode, longest: bool) -> WorkingNFA {
364        let nfa = WorkingNFA::build(tree);
365        return WorkingNFA::nfa_star(nfa, longest);
366    }
367    /// Makes an NFA that matches zero length but only at the text start
368    fn nfa_start() -> WorkingNFA {
369        let states = vec![
370            WorkingState::new().with_epsilon_special(1, EpsilonType::StartAnchor),
371            WorkingState::new(),
372        ];
373        return WorkingNFA { states };
374    }
375    /// Makes an NFA that matches zero length but only at the text end
376    fn nfa_end() -> WorkingNFA {
377        let states = vec![
378            WorkingState::new().with_epsilon_special(1, EpsilonType::EndAnchor),
379            WorkingState::new(),
380        ];
381        return WorkingNFA { states };
382    }
383    /// Makes an NFA that never matches.
384    fn nfa_never() -> WorkingNFA {
385        let states = vec![WorkingState::new(), WorkingState::new()];
386        return WorkingNFA { states };
387    }
388    /// Recursively builds an inefficient but valid NFA based loosely on Thompson's Algorithm.
389    ///
390    /// Should be optimized using [`WorkingNFA::optimize_pass`]
391    pub fn build(tree: &SimplifiedTreeNode) -> WorkingNFA {
392        return match tree {
393            SimplifiedTreeNode::Empty => WorkingNFA::nfa_empty(),
394            SimplifiedTreeNode::Symbol(c) => WorkingNFA::nfa_symbol(c),
395            SimplifiedTreeNode::Union(nodes) => WorkingNFA::build_union(nodes),
396            SimplifiedTreeNode::Capture(tree, group_num) => {
397                WorkingNFA::build_capture(&tree, *group_num)
398            }
399            SimplifiedTreeNode::Concat(nodes) => WorkingNFA::build_concat(nodes),
400            SimplifiedTreeNode::Repeat(tree, times) => WorkingNFA::build_repeat(tree, *times),
401            SimplifiedTreeNode::UpTo(tree, times, longest) => {
402                WorkingNFA::build_upto(tree, *times, *longest)
403            }
404            SimplifiedTreeNode::Star(tree, longest) => WorkingNFA::build_star(tree, *longest),
405            SimplifiedTreeNode::Start => WorkingNFA::nfa_start(),
406            SimplifiedTreeNode::End => WorkingNFA::nfa_end(),
407            SimplifiedTreeNode::Never => WorkingNFA::nfa_never(),
408        };
409    }
410    pub fn new(tree: &SimplifiedTreeNode) -> WorkingNFA {
411        let mut nfa = WorkingNFA::build(tree);
412
413        nfa.clean_start_anchors();
414        nfa.clean_end_anchors();
415
416        // add loops at start and end in case we lack anchors
417        nfa = WorkingNFA::nfa_concat([
418            WorkingNFA::nfa_star(
419                WorkingNFA::nfa_symbol(&Atom::NonmatchingList(Vec::new())),
420                false,
421            ),
422            nfa,
423            WorkingNFA::nfa_star(
424                WorkingNFA::nfa_symbol(&Atom::NonmatchingList(Vec::new())),
425                false,
426            ),
427        ]);
428
429        // Then remove redundant transitions from nodes before/after anchors
430        // May include the loops we just added
431        let zero_symbol_states: Vec<bool> =
432            std::iter::zip(nfa.nodes_after_end(), nfa.nodes_before_start())
433                .map(|(a, b)| a || b)
434                .collect();
435        for (from, state) in nfa.states.iter_mut().enumerate() {
436            if zero_symbol_states[from] {
437                state.transitions = Vec::new();
438            }
439        }
440
441        // nfa.remove_unreachable();
442        // Finally, do normal optimization passes
443        // println!("{}", nfa.to_tikz(true));
444        while nfa.optimize_pass() {
445            // println!("{}", nfa.to_tikz(true));
446        }
447        nfa.remove_unreachable();
448        return nfa;
449    }
450
451    /// Removes start anchors that will never be satisfied
452    /// (basically turning them into a `Never` to allow further optimization)
453    fn clean_start_anchors(&mut self) {
454        let mut zero_len_reachable = vec![false; self.states.len()];
455        zero_len_reachable[0] = true;
456        let mut stack = vec![0];
457        while let Some(state) = stack.pop() {
458            for e in &self.states[state].epsilons {
459                if !zero_len_reachable[e.to] {
460                    stack.push(e.to);
461                }
462                zero_len_reachable[e.to] = true;
463            }
464        }
465
466        for (i, state) in self.states.iter_mut().enumerate() {
467            state
468                .epsilons
469                .retain(|e| e.special != EpsilonType::StartAnchor || zero_len_reachable[i]);
470        }
471    }
472
473    /// Removes end anchors that will never be satisfied
474    /// (basically turning them into a `Never` to allow further optimization)    
475    fn clean_end_anchors(&mut self) {
476        let mut zero_len_reachable = vec![false; self.states.len()];
477        zero_len_reachable[self.states.len() - 1] = true;
478
479        let mut reverse_epsilons = vec![Vec::new(); self.states.len()];
480        for (i, state) in self.states.iter().enumerate() {
481            for e in &state.epsilons {
482                reverse_epsilons[e.to].push(i);
483            }
484        }
485
486        let mut stack = vec![self.states.len() - 1];
487        while let Some(state) = stack.pop() {
488            for src in &reverse_epsilons[state] {
489                if !zero_len_reachable[*src] {
490                    stack.push(*src);
491                }
492                zero_len_reachable[*src] = true;
493            }
494        }
495
496        for state in self.states.iter_mut() {
497            state
498                .epsilons
499                .retain(|e| e.special != EpsilonType::EndAnchor || zero_len_reachable[e.to]);
500        }
501    }
502    /// Finds all nodes that are only ever visited after a `$`.
503    fn nodes_after_end(&self) -> Vec<bool> {
504        let mut nodes = vec![true; self.states.len()];
505        nodes[0] = false;
506
507        let mut stack = vec![0];
508        while let Some(from) = stack.pop() {
509            for e in self.states[from].epsilons.iter() {
510                if nodes[e.to] && e.special != EpsilonType::EndAnchor {
511                    nodes[e.to] = false;
512                    stack.push(e.to);
513                }
514            }
515            for t in self.states[from].transitions.iter() {
516                if nodes[t.to] {
517                    nodes[t.to] = false;
518                    stack.push(t.to);
519                }
520            }
521        }
522        return nodes;
523    }
524    /// Finds all nodes that are only ever visited before a `^`.
525    fn nodes_before_start(&self) -> Vec<bool> {
526        let mut reverse = vec![Vec::new(); self.states.len()];
527        for (i, state) in self.states.iter().enumerate() {
528            for e in &state.epsilons {
529                if e.special != EpsilonType::StartAnchor {
530                    reverse[e.to].push(i);
531                }
532            }
533            for t in &state.transitions {
534                reverse[t.to].push(i);
535            }
536        }
537
538        let mut nodes = vec![true; self.states.len()];
539        nodes[self.states.len() - 1] = false;
540
541        let mut stack = vec![self.states.len() - 1];
542        while let Some(to) = stack.pop() {
543            for from in &reverse[to] {
544                if nodes[*from] {
545                    nodes[*from] = false;
546                    stack.push(*from);
547                }
548            }
549        }
550        return nodes;
551    }
552
553    /// Helper function for removing a set of states.
554    ///
555    /// These states should have no incoming transitions.
556    fn remove_dead_states<T: IntoIterator<Item = bool>>(&mut self, dead_states: T) {
557        let state_map: Vec<usize> = dead_states
558            .into_iter()
559            .scan(0, |s, dead| {
560                if dead {
561                    return Some(usize::MAX);
562                } else {
563                    let out = *s;
564                    *s += 1;
565                    return Some(out);
566                }
567            })
568            .collect();
569        self.states = self
570            .states
571            .iter()
572            .enumerate()
573            .filter(|(i, _)| state_map[*i] != usize::MAX)
574            .map(|(_, state)| state)
575            .cloned()
576            .collect();
577
578        for state in &mut self.states {
579            for t in &mut state.transitions {
580                t.to = state_map[t.to];
581            }
582            for t in &mut state.epsilons {
583                t.to = state_map[t.to];
584            }
585        }
586    }
587
588    /// De-duplicates identical transitions
589    /// (`a -e> b`, `a -e> b`) -> (`a -e> b`)
590    ///
591    /// Returns `true` if changes were made.
592    /// The highest-priority transition will be kept.
593    ///
594    /// ---
595    ///
596    /// Typically these are caused by optimizations that merge paths.
597    fn dedupe_transitions(&mut self) -> bool {
598        let mut changed = false;
599
600        for state in &mut self.states {
601            // state transitions
602            let keep: Vec<bool> = state
603                .transitions
604                .iter()
605                .enumerate()
606                .map(|(i, e)| state.transitions[..=i].contains(e))
607                .collect();
608            let prev_len = state.transitions.len();
609            let mut i = 0;
610            state.transitions.retain(|_| {
611                let idx = i;
612                i += 1;
613                return keep[idx];
614            });
615            if state.transitions.len() != prev_len {
616                changed = true;
617            }
618
619            // epsilon transitions
620            let keep: Vec<bool> = state
621                .epsilons
622                .iter()
623                .enumerate()
624                .map(|(i, e)| !state.epsilons[..i].contains(e))
625                .collect();
626            let prev_len = state.epsilons.len();
627            let mut i = 0;
628            state.epsilons.retain(|_| {
629                let idx = i;
630                i += 1;
631                return keep[idx];
632            });
633            if state.epsilons.len() != prev_len {
634                changed = true;
635            }
636        }
637
638        return changed;
639    }
640
641    /// Various operations to optimize the NFA graph.
642    ///
643    /// Returns `true` if changes were made (meaning another pass should be tried).
644    fn optimize_pass(&mut self) -> bool {
645        let mut changed = false;
646        let state_count = self.states.len();
647        debug_assert!(state_count >= 2);
648
649        let mut dead_states = vec![false; self.states.len()];
650
651        // Skip redundant states
652        // Special transitions (anchors + capture groups) are treated similar to non-epsilon transitions
653        'state_loop: for state_idx in 1..state_count - 1 {
654            // merge states with same outgoing
655            for other_idx in 0..state_count - 1 {
656                if self.states[state_idx].epsilons == self.states[other_idx].epsilons
657                    && self.states[state_idx].transitions == self.states[other_idx].transitions
658                    && state_idx != other_idx
659                    && (!self.states[state_idx].epsilons.is_empty()
660                        || !self.states[state_idx].transitions.is_empty())
661                {
662                    // TODO: if the two states have self-loops, they currently are not counted
663                    // as equivalent even if they should be.
664
665                    // I think symbol transition order matters here because it may have been created by previous
666                    // optimizations, which originated from epsilon transitions where it was important.
667                    dead_states[state_idx] = true;
668                    changed = true;
669                    self.states[state_idx].epsilons = Vec::new();
670                    self.states[state_idx].transitions = Vec::new();
671                    // divert other states to other
672                    for s in &mut self.states {
673                        for ep in &mut s.epsilons {
674                            if ep.to == state_idx {
675                                ep.to = other_idx;
676                            }
677                        }
678                        for tr in &mut s.transitions {
679                            if tr.to == state_idx {
680                                tr.to = other_idx;
681                            }
682                        }
683                    }
684                    continue 'state_loop;
685                }
686            }
687
688            // dedupe transitions
689            changed |= self.dedupe_transitions();
690
691            // skip redundant
692            let incoming: Vec<(usize, usize)> = self
693                .states
694                .iter()
695                .enumerate()
696                .flat_map(|(s_i, s)| s.transitions.iter().enumerate().map(move |(t, _)| (s_i, t)))
697                .filter(|(s, t)| self.states[*s].transitions[*t].to == state_idx)
698                .collect();
699            let incoming_eps: Vec<(usize, usize)> = self
700                .states
701                .iter()
702                .enumerate()
703                .flat_map(|(s_i, s)| s.epsilons.iter().enumerate().map(move |(e, _)| (s_i, e)))
704                .filter(|(s, e)| self.states[*s].epsilons[*e].to == state_idx)
705                .collect();
706
707            match (
708                incoming.as_slice(),
709                incoming_eps.as_slice(),
710                self.states[state_idx].transitions.len(),
711                self.states[state_idx].epsilons.len(),
712            ) {
713                // `as -xes> b -e> c` can become `as -xes> c` (assuming no other transitions)
714                (incoming, incoming_eps, 0, 1)
715                    if self.states[state_idx].epsilons[0].special == EpsilonType::None =>
716                {
717                    let to = self.states[state_idx].epsilons[0].to;
718                    for (s, t) in incoming {
719                        self.states[*s].transitions[*t].to = to;
720                    }
721                    for (s, e) in incoming_eps {
722                        self.states[*s].epsilons[*e].to = to;
723                    }
724                    dead_states[state_idx] = true;
725                    self.states[state_idx].epsilons = Vec::new();
726                    changed = true;
727                    continue;
728                }
729                // `a -e> b -es> cs` can become `a -es> cs` (assuming no other transitions)
730                (&[], &[(incoming_state, incoming_eps)], 0, _)
731                    if self.states[incoming_state].epsilons[incoming_eps].special
732                        == EpsilonType::None =>
733                {
734                    let outgoing_eps = std::mem::take(&mut self.states[state_idx].epsilons);
735                    let after = self.states[incoming_state]
736                        .epsilons
737                        .split_off(incoming_eps + 1);
738                    self.states[incoming_state].epsilons.pop();
739                    self.states[incoming_state]
740                        .epsilons
741                        .extend_from_slice(&outgoing_eps);
742                    self.states[incoming_state]
743                        .epsilons
744                        .extend_from_slice(&after);
745
746                    dead_states[state_idx] = true;
747                    changed = true;
748                    continue;
749                }
750                _ => {}
751            }
752
753            // TODO:
754            // `a -e> b -xes> cs` can become `a -xes> cs` (assuming no other transitions)
755            // `a -e> b -e> a` can combine `a` and `b` (including other transitions)
756            // TODO: might cause additional overhead in some cases, should we do
757            // ??? `a -x> b -es> cs` can become `a -xs> cs`
758            // ??? `as -es> b -x> c` can become `as -xs> c`
759        }
760
761        if changed {
762            self.remove_dead_states(dead_states);
763            return true;
764        }
765        return false;
766    }
767
768    /// Finds the states that can be reached from the start via any path
769    fn states_reachable_start(&self) -> Vec<bool> {
770        let mut reachable = vec![false; self.states.len()];
771        reachable[0] = true;
772        let mut stack = vec![0];
773
774        while let Some(state) = stack.pop() {
775            for src in &self.states[state].epsilons {
776                if !reachable[src.to] {
777                    stack.push(src.to);
778                }
779                reachable[src.to] = true;
780            }
781            for src in &self.states[state].transitions {
782                if !reachable[src.to] {
783                    stack.push(src.to);
784                }
785                reachable[src.to] = true;
786            }
787        }
788
789        return reachable;
790    }
791    /// Finds the states that can reach the end via any path
792    fn states_reachable_end(&self) -> Vec<bool> {
793        let mut reverse = vec![Vec::new(); self.states.len()];
794        for (i, state) in self.states.iter().enumerate() {
795            for e in &state.epsilons {
796                reverse[e.to].push(i);
797            }
798            for t in &state.transitions {
799                reverse[t.to].push(i);
800            }
801        }
802
803        let mut reachable = vec![false; self.states.len()];
804        reachable[self.states.len() - 1] = true;
805        let mut stack = vec![self.states.len() - 1];
806
807        while let Some(state) = stack.pop() {
808            for src in &reverse[state] {
809                if !reachable[*src] {
810                    stack.push(*src);
811                }
812                reachable[*src] = true;
813            }
814        }
815
816        return reachable;
817    }
818
819    /// Removes all nodes that cannot be reached or cannot reach the end.
820    ///
821    /// Ignores special epsilon types (so should be called after they have been resolved)
822    fn remove_unreachable(&mut self) {
823        let reach_start = self.states_reachable_start();
824        let reach_end = self.states_reachable_end();
825
826        // Remove transitions that involve redundant states
827        for state in &mut self.states {
828            state
829                .epsilons
830                .retain(|e| reach_start[e.to] && reach_end[e.to]);
831            state
832                .transitions
833                .retain(|t| reach_start[t.to] && reach_end[t.to]);
834        }
835
836        // Then remove the states
837        self.remove_dead_states(
838            std::iter::zip(reach_start.into_iter(), reach_end.into_iter()).map(|(a, b)| !a || !b),
839        );
840    }
841
842    /// Finds the number of capture groups in this NFA
843    pub fn num_capture_groups(&self) -> usize {
844        return self
845            .states
846            .iter()
847            .flat_map(|state| &state.epsilons)
848            .map(|eps| match eps.special {
849                EpsilonType::StartCapture(n) => n,
850                _ => 0,
851            })
852            .max()
853            .unwrap_or(0)
854            + 1;
855    }
856
857    /// Returns whether each there is any matching path where the capture group is unused
858    pub fn capture_group_is_optional(&self, group_num: usize) -> bool {
859        let mut reached_states = vec![false; self.states.len()];
860        reached_states[0] = true;
861        let mut stack = vec![0];
862        while let Some(idx) = stack.pop() {
863            for t in &self.states[idx].transitions {
864                if !reached_states[t.to] {
865                    reached_states[t.to] = true;
866                    stack.push(t.to);
867                }
868            }
869            for e in &self.states[idx].epsilons {
870                if !reached_states[e.to] && e.special != EpsilonType::StartCapture(group_num) {
871                    // the end capture should not be reachable without a preceding
872                    // start capture, so we only need to check the start.
873                    debug_assert_ne!(e.special, EpsilonType::EndCapture(group_num));
874                    reached_states[e.to] = true;
875                    stack.push(e.to);
876                }
877            }
878        }
879
880        return *reached_states.last().unwrap();
881    }
882
883    /// Writes a LaTeX TikZ representation to visualize the graph.
884    ///
885    /// If `include_doc` is `true`, will include the headers.
886    /// Otherwise, you should include `\usepackage{tikz}` and `\usetikzlibrary{automata, positioning}`.
887    pub fn to_tikz(&self, include_doc: bool) -> String {
888        let map_state =
889            |(i, state): (usize, &WorkingState)| -> crate::visualization::LatexGraphState {
890                let transitions =
891                    state
892                        .transitions
893                        .iter()
894                        .map(|t| crate::visualization::LatexGraphTransition {
895                            label: crate::visualization::escape_latex(t.symbol.to_string()),
896                            to: t.to,
897                        });
898                let epsilons = state.epsilons.iter().enumerate().map(|(i, e)| {
899                    let label = match e.special {
900                        EpsilonType::None => format!(r"$\epsilon_{{{i}}}$"),
901                        EpsilonType::StartAnchor => format!(r"{{\textasciicircum}}$_{{{i}}}$"),
902                        EpsilonType::EndAnchor => format!(r"$\$_{{{i}}}$"),
903                        EpsilonType::StartCapture(group) => format!("${group}(_{{{i}}}$"),
904                        EpsilonType::EndCapture(group) => format!("$){group}_{{{i}}}$"),
905                    };
906                    return crate::visualization::LatexGraphTransition { label, to: e.to };
907                });
908                let transitions = transitions.chain(epsilons).collect();
909                return crate::visualization::LatexGraphState {
910                    label: format!("q{i}"),
911                    transitions,
912                    initial: i == 0,
913                    accept: i + 1 == self.states.len(),
914                };
915            };
916
917        let graph = crate::visualization::LatexGraph {
918            states: self.states.iter().enumerate().map(map_state).collect(),
919        };
920        return graph.to_tikz(include_doc);
921    }
922
923    /// Using the classical NFA algorithm to do a simple boolean test on a string.
924    pub fn test(&self, text: &str) -> bool {
925        let mut list = vec![false; self.states.len()];
926        let mut new_list = vec![false; self.states.len()];
927        list[0] = true;
928
929        // Adds all states reachable by epsilon transitions
930        let propogate_epsilon = |list: &mut Vec<bool>, idx: usize| {
931            let mut stack: Vec<usize> = list
932                .iter()
933                .enumerate()
934                .filter_map(|(i, set)| set.then_some(i))
935                .collect();
936
937            while let Some(from) = stack.pop() {
938                for EpsilonTransition { to, special } in &self.states[from].epsilons {
939                    if list[from]
940                        && !list[*to]
941                        && (match special {
942                            EpsilonType::StartAnchor => idx == 0,
943                            EpsilonType::EndAnchor => idx == text.len(),
944                            _ => true,
945                        })
946                    {
947                        stack.push(*to);
948                        list[*to] = true;
949                    }
950                }
951            }
952        };
953
954        for (i, c) in text.char_indices() {
955            propogate_epsilon(&mut list, i);
956            for (from, state) in self.states.iter().enumerate() {
957                if !list[from] {
958                    continue;
959                }
960
961                for WorkingTransition { to, symbol } in &state.transitions {
962                    if symbol.check(c) {
963                        new_list[*to] = true;
964                    }
965                }
966            }
967            let tmp = list;
968            list = new_list;
969            new_list = tmp;
970            new_list.fill(false);
971        }
972        propogate_epsilon(&mut list, text.len());
973        return *list.last().unwrap_or(&false);
974    }
975}
976impl std::fmt::Display for WorkingNFA {
977    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
978        for (i, state) in self.states.iter().enumerate() {
979            writeln!(f, "State {i}:")?;
980            for e in &state.epsilons {
981                writeln!(f, "  {e}")?;
982            }
983            for t in &state.transitions {
984                writeln!(f, "  {t}")?;
985            }
986        }
987        return Ok(());
988    }
989}
990
991#[cfg(test)]
992mod tests {
993    use super::*;
994    use crate::{config::Config, parse_tree::ERE};
995
996    #[test]
997    fn abbc_raw() {
998        let nfa = WorkingNFA {
999            states: vec![
1000                WorkingState::new().with_transition(1, 'a'.into()),
1001                WorkingState::new().with_transition(2, 'b'.into()),
1002                WorkingState::new()
1003                    .with_transition(3, 'c'.into())
1004                    .with_epsilon(1),
1005                WorkingState::new(),
1006            ],
1007        };
1008        println!("{}", nfa.to_tikz(true));
1009
1010        assert!(nfa.test("abc"));
1011        assert!(nfa.test("abbc"));
1012        assert!(nfa.test("abbbc"));
1013        assert!(nfa.test("abbbbc"));
1014
1015        assert!(!nfa.test("ac"));
1016        assert!(!nfa.test("abcc"));
1017        assert!(!nfa.test("bac"));
1018        assert!(!nfa.test("acb"));
1019    }
1020
1021    #[test]
1022    fn phone_number() {
1023        let ere = ERE::parse_str(r"^(\+1 )?[0-9]{3}-[0-9]{3}-[0-9]{4}$").unwrap();
1024        let (tree, capture_groups) = SimplifiedTreeNode::from_ere(&ere, &Config::default());
1025        assert_eq!(capture_groups, 2);
1026        let nfa = WorkingNFA::new(&tree);
1027        println!("{}", nfa.to_tikz(true));
1028
1029        assert!(nfa.test("012-345-6789"));
1030        assert!(nfa.test("987-654-3210"));
1031        assert!(nfa.test("+1 555-555-5555"));
1032        assert!(nfa.test("123-555-9876"));
1033
1034        assert!(!nfa.test("abcd"));
1035        assert!(!nfa.test("0123456789"));
1036        assert!(!nfa.test("012--345-6789"));
1037        assert!(!nfa.test("(555) 555-5555"));
1038        assert!(!nfa.test("1 555-555-5555"));
1039    }
1040
1041    #[test]
1042    fn double_loop() {
1043        let ere = ERE::parse_str(r"^.*(.*)*$").unwrap();
1044        let (tree, capture_groups) = SimplifiedTreeNode::from_ere(&ere, &Config::default());
1045        assert_eq!(capture_groups, 2);
1046        let nfa = WorkingNFA::new(&tree);
1047        // println!("{}", nfa.to_tikz(true));
1048
1049        assert!(nfa.test(""));
1050        assert!(nfa.test("asdf"));
1051        assert!(nfa.test("1234567"));
1052        assert!(nfa.test("0"));
1053
1054        assert!(!nfa.test("\0"));
1055    }
1056
1057    #[test]
1058    fn good_anchored_start() {
1059        let ere = ERE::parse_str(r"^a|b*^c|d^|n").unwrap();
1060        let (tree, capture_groups) = SimplifiedTreeNode::from_ere(&ere, &Config::default());
1061        assert_eq!(capture_groups, 1);
1062        let nfa = WorkingNFA::new(&tree);
1063        // println!("{}", nfa.to_tikz(true));
1064
1065        assert!(nfa.test("a"));
1066        assert!(nfa.test("c"));
1067        assert!(nfa.test("cq"));
1068        assert!(nfa.test("wwwnwww"));
1069
1070        assert!(!nfa.test(""));
1071        assert!(!nfa.test("qb"));
1072        assert!(!nfa.test("qc"));
1073        assert!(!nfa.test("b"));
1074        assert!(!nfa.test("bc"));
1075        assert!(!nfa.test("bbbbbbc"));
1076        assert!(!nfa.test("d"));
1077    }
1078
1079    #[test]
1080    fn good_anchored_end() {
1081        let ere = ERE::parse_str(r"a$|b$c*|$d|n").unwrap();
1082        let (tree, capture_groups) = SimplifiedTreeNode::from_ere(&ere, &Config::default());
1083        assert_eq!(capture_groups, 1);
1084        let nfa = WorkingNFA::new(&tree);
1085        println!("{}", nfa.to_tikz(true));
1086
1087        assert!(nfa.test("a"));
1088        assert!(nfa.test("b"));
1089        assert!(nfa.test("qb"));
1090        assert!(nfa.test("wwwnwww"));
1091
1092        assert!(!nfa.test(""));
1093        assert!(!nfa.test("bq"));
1094        assert!(!nfa.test("qc"));
1095        assert!(!nfa.test("c"));
1096        assert!(!nfa.test("bc"));
1097        assert!(!nfa.test("bcccccc"));
1098        assert!(!nfa.test("d"));
1099    }
1100
1101    #[test]
1102    fn range_digit() {
1103        let ere = ERE::parse_str(r"^[[:digit:].]$").unwrap();
1104        let (tree, capture_groups) = SimplifiedTreeNode::from_ere(&ere, &Config::default());
1105        assert_eq!(capture_groups, 1);
1106        let nfa = WorkingNFA::new(&tree);
1107        // println!("{}", nfa.to_tikz(true));
1108
1109        assert!(nfa.test("0"));
1110        assert!(nfa.test("1"));
1111        assert!(nfa.test("9"));
1112        assert!(nfa.test("."));
1113
1114        assert!(!nfa.test(""));
1115        assert!(!nfa.test("a"));
1116        assert!(!nfa.test("11"));
1117        assert!(!nfa.test("1."));
1118        assert!(!nfa.test(".2"));
1119        assert!(!nfa.test("09"));
1120        assert!(!nfa.test("d"));
1121    }
1122}