ere_core/
working_u8_nfa.rs

1//! Implements `u8`-based version of [`crate::working_nfa`].
2//!
3//! Primarily involves converting from a [`WorkingNFA`] to a [`U8NFA`],
4//! which is used as an additional intermediate step for engines that match `u8`s
5//! instead of the more complex `char`s.
6
7use crate::working_nfa::{EpsilonType, WorkingNFA};
8use crate::{parse_tree::Atom, working_nfa::EpsilonTransition};
9use std::ops::RangeInclusive;
10use std::{usize, vec};
11
/// An inclusive range of byte values that a single transition accepts.
///
/// A one-byte atom is simply a range whose start and end coincide.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct U8Atom(pub RangeInclusive<u8>);
impl U8Atom {
    /// Returns `true` when the byte `c` falls inside this atom's range.
    #[inline]
    pub fn check(&self, c: u8) -> bool {
        self.0.contains(&c)
    }
    /// Lowest byte accepted by this atom.
    #[inline]
    pub const fn start(&self) -> u8 {
        *self.0.start()
    }
    /// Highest byte accepted by this atom.
    #[inline]
    pub const fn end(&self) -> u8 {
        *self.0.end()
    }
}
impl std::fmt::Display for U8Atom {
    /// Writes a single byte as its ASCII-escaped form, or a range as `[lo-hi]`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (lo, hi) = (self.0.start(), self.0.end());
        if lo != hi {
            write!(f, "[{}-{}]", lo.escape_ascii(), hi.escape_ascii())
        } else {
            write!(f, "{}", lo.escape_ascii())
        }
    }
}
impl From<u8> for U8Atom {
    /// Builds an atom that accepts exactly one byte value.
    fn from(value: u8) -> Self {
        Self(value..=value)
    }
}
impl From<RangeInclusive<u8>> for U8Atom {
    /// Wraps an existing inclusive byte range.
    fn from(value: RangeInclusive<u8>) -> Self {
        Self(value)
    }
}
impl TryFrom<char> for U8Atom {
    type Error = std::char::TryFromCharError;

    /// Converts a `char` that fits in a single `u8` (via `u8::try_from`) into a one-byte atom;
    /// any other `char` yields the conversion error.
    fn try_from(value: char) -> Result<Self, Self::Error> {
        let byte = u8::try_from(value)?;
        Ok(Self(byte..=byte))
    }
}
59
60#[derive(Debug, Clone, PartialEq, Eq)]
61pub(crate) struct U8Transition {
62    pub(crate) to: usize,
63    pub(crate) symbol: U8Atom,
64}
65impl U8Transition {
66    pub fn new(to: usize, symbol: U8Atom) -> U8Transition {
67        return U8Transition { to, symbol };
68    }
69    pub fn with_offset(mut self, offset: usize) -> U8Transition {
70        self.inplace_offset(offset);
71        return self;
72    }
73    pub fn inplace_offset(&mut self, offset: usize) {
74        self.to += offset;
75    }
76    pub fn add_offset(&self, offset: usize) -> U8Transition {
77        return U8Transition {
78            to: self.to + offset,
79            symbol: self.symbol.clone(),
80        };
81    }
82}
83impl std::fmt::Display for U8Transition {
84    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85        return write!(f, "-({})> {}", self.symbol, self.to);
86    }
87}
88
89#[derive(Debug, Clone)]
90pub struct U8State {
91    pub(crate) transitions: Vec<U8Transition>,
92    pub(crate) epsilons: Vec<EpsilonTransition>,
93}
94impl U8State {
95    pub const fn new() -> U8State {
96        return U8State {
97            transitions: Vec::new(),
98            epsilons: Vec::new(),
99        };
100    }
101    pub fn with_transition(mut self, to: usize, symbol: U8Atom) -> U8State {
102        self.transitions.push(U8Transition::new(to, symbol));
103        return self;
104    }
105    pub fn with_epsilon(mut self, to: usize) -> U8State {
106        self.epsilons.push(EpsilonTransition::new(to));
107        return self;
108    }
109    pub fn with_epsilon_special(mut self, to: usize, special: EpsilonType) -> U8State {
110        self.epsilons.push(EpsilonTransition { to, special });
111        return self;
112    }
113    pub fn with_offset(mut self, offset: usize) -> U8State {
114        self.inplace_offset(offset);
115        return self;
116    }
117    pub fn inplace_offset(&mut self, offset: usize) {
118        for t in &mut self.transitions {
119            t.inplace_offset(offset);
120        }
121        for e in &mut self.epsilons {
122            e.inplace_offset(offset);
123        }
124    }
125    pub fn add_offset(&self, offset: usize) -> U8State {
126        return U8State {
127            transitions: self
128                .transitions
129                .iter()
130                .map(|t| t.add_offset(offset))
131                .collect(),
132            epsilons: self.epsilons.iter().map(|e| e.add_offset(offset)).collect(),
133        };
134    }
135}
136impl std::fmt::Display for U8State {
137    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
138        for t in &self.transitions {
139            writeln!(f, "  {t}")?;
140        }
141        for e in &self.epsilons {
142            writeln!(f, "  {e}")?;
143        }
144        return Ok(());
145    }
146}
147
/// A byte-oriented NFA.
///
/// Each NFA has one start state (`0`) and one accept state (`states.len() - 1`)
#[derive(Debug, Clone)]
pub struct U8NFA {
    /// All states of the automaton; edges refer to their targets by index into this vector.
    pub(crate) states: Vec<U8State>,
}
153impl U8NFA {
154    /// Makes an NFA that matches with zero length.
155    fn nfa_empty() -> U8NFA {
156        let states = vec![U8State::new()];
157        return U8NFA { states };
158    }
159    /// Makes an NFA matching a byte range.
160    fn nfa_byte(c: &U8Atom) -> U8NFA {
161        let states = vec![U8State::new().with_transition(1, c.clone()), U8State::new()];
162        return U8NFA { states };
163    }
164    /// Makes an NFA matching a specific char.
165    fn nfa_symbol_char(c: char) -> U8NFA {
166        let mut bytes = [0u8; 4];
167        c.encode_utf8(&mut bytes);
168        let states = bytes
169            .iter()
170            .take(c.len_utf8())
171            .enumerate()
172            .map(|(i, byte)| U8State::new().with_transition(i + 1, (*byte).into()))
173            .chain(std::iter::once(U8State::new()))
174            .collect();
175        return U8NFA { states };
176    }
177    /// Makes an NFA matching a some symbol.
178    fn nfa_symbol(c: &Atom) -> U8NFA {
179        let ranges = c.to_ranges();
180        let mut states = vec![U8State::new()];
181
182        for range in ranges {
183            for byte_ranges in utf8_ranges::Utf8Sequences::new(*range.start(), *range.end()) {
184                let mut state = 0usize;
185                for (i, byte_range) in byte_ranges.into_iter().enumerate() {
186                    let byte_range = byte_range.start..=byte_range.end;
187                    if let Some(next_state) = states[state]
188                        .transitions
189                        .iter()
190                        .find(|a| a.symbol.0 == byte_range)
191                    {
192                        state = next_state.to;
193                    } else if i + 1 == byte_ranges.len() {
194                        states[state]
195                            .transitions
196                            .push(U8Transition::new(usize::MAX, U8Atom(byte_range)));
197                        break; // sanity check: should be unnecessary
198                    } else {
199                        let new_state_idx = states.len();
200                        states.push(U8State::new());
201                        states[state]
202                            .transitions
203                            .push(U8Transition::new(new_state_idx, U8Atom(byte_range)));
204                        state = new_state_idx;
205                    }
206                }
207            }
208        }
209
210        // then insert accept state, replacing [`usize::MAX`] placeholders
211        let accept_state_idx = states.len();
212        states.push(U8State::new());
213        for state in &mut states {
214            for transition in &mut state.transitions {
215                if transition.to == usize::MAX {
216                    transition.to = accept_state_idx;
217                }
218            }
219        }
220        // TODO: shared suffix optimizations
221
222        return U8NFA { states };
223    }
224    /// Makes a union of NFAs.
225    fn nfa_union(nodes: &[U8NFA]) -> U8NFA {
226        let states_count = 2 + nodes.iter().map(|n| n.states.len()).sum::<usize>();
227        let mut states = vec![U8State::new()];
228        for nfa in nodes {
229            let sub_nfa_start = states.len();
230            states[0]
231                .epsilons
232                .push(EpsilonTransition::new(sub_nfa_start));
233            states.extend(
234                nfa.states
235                    .iter()
236                    .map(|state| state.add_offset(sub_nfa_start)),
237            );
238            states
239                .last_mut()
240                .unwrap()
241                .epsilons
242                .push(EpsilonTransition::new(states_count - 1));
243        }
244        states.push(U8State::new());
245        assert_eq!(states_count, states.len());
246
247        return U8NFA { states };
248    }
249    /// Wraps an NFA part in a capture group.
250    fn nfa_capture(nfa: &U8NFA, group_num: usize) -> U8NFA {
251        let states_count = 2 + nfa.states.len();
252        let mut states: Vec<U8State> = std::iter::once(
253            U8State::new().with_epsilon_special(1, EpsilonType::StartCapture(group_num)),
254        )
255        .chain(nfa.states.iter().map(|state| state.add_offset(1)))
256        .chain(std::iter::once(U8State::new()))
257        .collect();
258        assert_eq!(states_count, states.len());
259        states[states_count - 2].epsilons.push(EpsilonTransition {
260            to: states_count - 1,
261            special: EpsilonType::EndCapture(group_num),
262        });
263
264        return U8NFA { states };
265    }
266    /// Makes an NFA that matches a concatenation of NFAs.
267    fn nfa_concat<T: IntoIterator<Item = U8NFA>>(nodes: T) -> U8NFA {
268        let mut states = vec![U8State::new().with_epsilon(1)];
269
270        for nfa in nodes {
271            let states_count = states.len();
272            states.extend(
273                nfa.states
274                    .into_iter()
275                    .map(|state| state.with_offset(states_count)),
276            );
277            let states_count = states.len();
278            states
279                .last_mut()
280                .unwrap()
281                .epsilons
282                .push(EpsilonTransition::new(states_count));
283        }
284
285        states.push(U8State::new());
286        return U8NFA { states };
287    }
288    /// Makes an NFA that matches some NFA concatenated with itself multiple times.
289    fn nfa_repeat(nfa: &U8NFA, times: usize) -> U8NFA {
290        return U8NFA::nfa_concat(std::iter::repeat(nfa).cloned().take(times));
291    }
292    /// Makes an NFA that matches some NFA concatenated with itself up to some number of times.
293    fn nfa_upto(nfa: &U8NFA, times: usize, longest: bool) -> U8NFA {
294        let end_state_idx = 1 + (nfa.states.len() + 1) * times;
295
296        let mut states = vec![U8State::new()
297            .with_epsilon(1)
298            .with_epsilon(end_state_idx - 1)];
299        for i in 0..times {
300            let states_count = states.len();
301            states.extend(
302                nfa.states
303                    .iter()
304                    .map(|state| state.add_offset(states_count)),
305            );
306            let transition_state_idx = states.len();
307            states
308                .last_mut()
309                .unwrap()
310                .epsilons
311                .push(EpsilonTransition::new(transition_state_idx));
312            let mut transition_state = U8State::new();
313            if i + 1 != times {
314                if longest {
315                    transition_state
316                        .epsilons
317                        .push(EpsilonTransition::new(states.len() + 1));
318                }
319
320                transition_state
321                    .epsilons
322                    .push(EpsilonTransition::new(end_state_idx - 1));
323                if !longest {
324                    transition_state
325                        .epsilons
326                        .push(EpsilonTransition::new(states.len() + 1));
327                }
328            }
329            states.push(transition_state);
330        }
331
332        return U8NFA { states };
333    }
334    /// Makes an NFA that matches some NFA concatenated with itself any number of times.
335    fn nfa_star(nfa: U8NFA, longest: bool) -> U8NFA {
336        let end_state_idx = 1 + nfa.states.len();
337        let mut start_state = U8State::new();
338        if !longest {
339            start_state
340                .epsilons
341                .push(EpsilonTransition::new(end_state_idx));
342        }
343        start_state.epsilons.push(EpsilonTransition::new(1));
344        if longest {
345            start_state
346                .epsilons
347                .push(EpsilonTransition::new(end_state_idx));
348        }
349        let mut states: Vec<U8State> = std::iter::once(start_state)
350            .chain(nfa.states.into_iter().map(|state| state.with_offset(1)))
351            .chain(std::iter::once(U8State::new()))
352            .collect();
353        states[end_state_idx - 1]
354            .epsilons
355            .push(EpsilonTransition::new(0));
356        return U8NFA { states };
357    }
358    /// Makes an NFA that matches zero length but only at the text start
359    fn nfa_start() -> U8NFA {
360        let states = vec![
361            U8State::new().with_epsilon_special(1, EpsilonType::StartAnchor),
362            U8State::new(),
363        ];
364        return U8NFA { states };
365    }
366    /// Makes an NFA that matches zero length but only at the text end
367    fn nfa_end() -> U8NFA {
368        let states = vec![
369            U8State::new().with_epsilon_special(1, EpsilonType::EndAnchor),
370            U8State::new(),
371        ];
372        return U8NFA { states };
373    }
374    /// Makes an NFA that never matches.
375    fn nfa_never() -> U8NFA {
376        let states = vec![U8State::new(), U8State::new()];
377        return U8NFA { states };
378    }
379    /// Converts from a char-based NFA
380    ///
381    /// Does not include any optimizations
382    fn build(nfa: &WorkingNFA) -> U8NFA {
383        let mut states: Vec<U8State> = Vec::new();
384        let mut sub_states: Vec<U8State> = Vec::new();
385        for state in &nfa.states {
386            let mut new_state = U8State {
387                transitions: Vec::new(),
388                epsilons: state.epsilons.clone(),
389            };
390            // Decompose char transitions into byte transitions
391            // `a -x> b` will become an expanded nfa with initial state `a` and accept state `b`
392            for t in &state.transitions {
393                let symbol_nfa = U8NFA::nfa_symbol(&t.symbol);
394                let symbol_nfa_accept = symbol_nfa.states.len() - 1;
395                let sub_states_offset = nfa.states.len() + sub_states.len() - 1;
396
397                // Updates transition indices
398                let map_transition = |sub_t: &U8Transition| {
399                    if sub_t.to == symbol_nfa_accept {
400                        U8Transition::new(t.to, sub_t.symbol.clone())
401                    } else {
402                        sub_t.add_offset(sub_states_offset)
403                    }
404                };
405
406                new_state
407                    .transitions
408                    .extend(symbol_nfa.states[0].transitions.iter().map(map_transition));
409
410                for sub_state in symbol_nfa
411                    .states
412                    .iter()
413                    .skip(1)
414                    .take(symbol_nfa.states.len() - 2)
415                {
416                    let sub_state = U8State {
417                        transitions: sub_state.transitions.iter().map(map_transition).collect(),
418                        epsilons: Vec::new(),
419                    };
420                    sub_states.push(sub_state);
421                }
422            }
423            states.push(new_state);
424        }
425
426        states.extend_from_slice(&sub_states);
427        let new_accept_state = states.len();
428        states.push(U8State::new());
429        states[nfa.states.len() - 1]
430            .epsilons
431            .push(EpsilonTransition::new(new_accept_state));
432        return U8NFA { states };
433    }
434    /// Converts from a char-based NFA
435    pub fn new(nfa: &WorkingNFA) -> U8NFA {
436        let mut nfa = U8NFA::build(nfa);
437
438        while nfa.optimize_pass() {}
439        nfa.remove_unreachable();
440        return nfa;
441    }
442
443    /// Helper function for removing a set of states.
444    ///
445    /// These states should have no incoming transitions.
446    fn remove_dead_states<T: IntoIterator<Item = bool>>(&mut self, dead_states: T) {
447        let state_map: Vec<usize> = dead_states
448            .into_iter()
449            .scan(0, |s, dead| {
450                if dead {
451                    return Some(usize::MAX);
452                } else {
453                    let out = *s;
454                    *s += 1;
455                    return Some(out);
456                }
457            })
458            .collect();
459        self.states = self
460            .states
461            .iter()
462            .enumerate()
463            .filter(|(i, _)| state_map[*i] != usize::MAX)
464            .map(|(_, state)| state)
465            .cloned()
466            .collect();
467
468        for state in &mut self.states {
469            for t in &mut state.transitions {
470                t.to = state_map[t.to];
471            }
472            for t in &mut state.epsilons {
473                t.to = state_map[t.to];
474            }
475        }
476    }
477
478    /// De-duplicates identical transitions
479    /// (`a -e> b`, `a -e> b`) -> (`a -e> b`)
480    ///
481    /// Returns `true` if changes were made.
482    /// The highest-priority transition will be kept.
483    ///
484    /// ---
485    ///
486    /// Typically these are caused by optimizations that merge paths.
487    fn dedupe_transitions(&mut self) -> bool {
488        let mut changed = false;
489
490        for state in &mut self.states {
491            // state transitions
492            let keep: Vec<bool> = state
493                .transitions
494                .iter()
495                .enumerate()
496                .map(|(i, e)| state.transitions[..=i].contains(e))
497                .collect();
498            let prev_len = state.transitions.len();
499            let mut i = 0;
500            state.transitions.retain(|_| {
501                let idx = i;
502                i += 1;
503                return keep[idx];
504            });
505            if state.transitions.len() != prev_len {
506                changed = true;
507            }
508
509            // epsilon transitions
510            let keep: Vec<bool> = state
511                .epsilons
512                .iter()
513                .enumerate()
514                .map(|(i, e)| !state.epsilons[..i].contains(e))
515                .collect();
516            let prev_len = state.epsilons.len();
517            let mut i = 0;
518            state.epsilons.retain(|_| {
519                let idx = i;
520                i += 1;
521                return keep[idx];
522            });
523            if state.epsilons.len() != prev_len {
524                changed = true;
525            }
526        }
527
528        return changed;
529    }
530
    /// Optimizes the NFA graph.
    ///
    /// Performs one sweep over every state except the first (start) and the last (accept),
    /// applying these rewrites in order to each state:
    ///
    /// 1. If another state (any index except the last) has exactly the same non-empty lists
    ///    of epsilon and byte transitions, the current state is marked dead, its outgoing
    ///    edges are cleared, and every edge that pointed at it is redirected to the other state.
    /// 2. Duplicate transitions are removed via `dedupe_transitions` (which operates on the
    ///    whole NFA, not just the current state).
    /// 3. A state whose only outgoing edge is a plain epsilon is bypassed: all incoming edges
    ///    are pointed directly at that epsilon's target.
    /// 4. A state with no byte transitions whose only incoming edge is a plain epsilon is
    ///    folded into its predecessor: the incoming epsilon is replaced, at the same position,
    ///    by the state's own outgoing epsilons.
    ///
    /// States marked dead are physically removed at the end of the pass.
    ///
    /// Returns `true` if changes were made (meaning another pass should be tried).
    fn optimize_pass(&mut self) -> bool {
        let mut changed = false;
        let state_count = self.states.len();

        // One flag per state; flagged states are removed after the sweep.
        let mut dead_states = vec![false; self.states.len()];

        // Skip redundant states
        // Special transitions (anchors + capture groups) are treated similar to non-epsilon transitions
        'state_loop: for state_idx in 1..state_count - 1 {
            // merge states with same outgoing
            for other_idx in 0..state_count - 1 {
                // The non-empty requirement also keeps already-cleared (dead) states from
                // being matched against each other.
                if self.states[state_idx].epsilons == self.states[other_idx].epsilons
                    && self.states[state_idx].transitions == self.states[other_idx].transitions
                    && state_idx != other_idx
                    && (!self.states[state_idx].epsilons.is_empty()
                        || !self.states[state_idx].transitions.is_empty())
                {
                    // TODO: if the two states have self-loops, they currently are not counted
                    // as equivalent even if they should be.

                    // I think symbol transition order matters here because it may have been created by previous
                    // optimizations, which originated from epsilon transitions where it was important.
                    dead_states[state_idx] = true;
                    changed = true;
                    self.states[state_idx].epsilons = Vec::new();
                    self.states[state_idx].transitions = Vec::new();
                    // divert other states to other
                    for s in &mut self.states {
                        for ep in &mut s.epsilons {
                            if ep.to == state_idx {
                                ep.to = other_idx;
                            }
                        }
                        for tr in &mut s.transitions {
                            if tr.to == state_idx {
                                tr.to = other_idx;
                            }
                        }
                    }
                    continue 'state_loop;
                }
            }

            // dedupe transitions
            changed |= self.dedupe_transitions();

            // Incoming byte transitions, as (source state index, index within its `transitions`).
            let incoming: Vec<(usize, usize)> = self
                .states
                .iter()
                .enumerate()
                .flat_map(|(s_i, s)| s.transitions.iter().enumerate().map(move |(t, _)| (s_i, t)))
                .filter(|(s, t)| self.states[*s].transitions[*t].to == state_idx)
                .collect();
            // Incoming epsilon transitions, as (source state index, index within its `epsilons`).
            let incoming_eps: Vec<(usize, usize)> = self
                .states
                .iter()
                .enumerate()
                .flat_map(|(s_i, s)| s.epsilons.iter().enumerate().map(move |(e, _)| (s_i, e)))
                .filter(|(s, e)| self.states[*s].epsilons[*e].to == state_idx)
                .collect();

            // Dispatch on (incoming byte edges, incoming epsilons, #outgoing byte edges, #outgoing epsilons).
            match (
                incoming.as_slice(),
                incoming_eps.as_slice(),
                self.states[state_idx].transitions.len(),
                self.states[state_idx].epsilons.len(),
            ) {
                // `as -xes> b -e> c` can become `as -xes> c` (assuming no other transitions)
                (incoming, incoming_eps, 0, 1)
                    if self.states[state_idx].epsilons[0].special == EpsilonType::None =>
                {
                    // Point every incoming edge straight at the target of the single plain epsilon.
                    let to = self.states[state_idx].epsilons[0].to;
                    for (s, t) in incoming {
                        self.states[*s].transitions[*t].to = to;
                    }
                    for (s, e) in incoming_eps {
                        self.states[*s].epsilons[*e].to = to;
                    }
                    dead_states[state_idx] = true;
                    self.states[state_idx].epsilons = Vec::new();
                    changed = true;
                    continue;
                }
                // `a -e> b -es> cs` can become `a -es> cs` (assuming no other transitions)
                (&[], &[(incoming_state, incoming_eps)], 0, _)
                    if self.states[incoming_state].epsilons[incoming_eps].special
                        == EpsilonType::None =>
                {
                    // Splice this state's epsilons into the predecessor's list in place of the
                    // incoming epsilon, keeping the order of everything before and after it.
                    let outgoing_eps = std::mem::take(&mut self.states[state_idx].epsilons);
                    let after = self.states[incoming_state]
                        .epsilons
                        .split_off(incoming_eps + 1);
                    // After `split_off`, the incoming epsilon is the last element; drop it.
                    self.states[incoming_state].epsilons.pop();
                    self.states[incoming_state]
                        .epsilons
                        .extend_from_slice(&outgoing_eps);
                    self.states[incoming_state]
                        .epsilons
                        .extend_from_slice(&after);

                    dead_states[state_idx] = true;
                    changed = true;
                    continue;
                }
                _ => {}
            }

            // TODO:
            // `a -e> b -xes> cs` can become `a -xes> cs` (assuming no other transitions)
            // `a -e> b -e> a` can combine `a` and `b` (including other transitions)
            // TODO: might cause additional overhead in some cases, should we do
            // ??? `a -x> b -es> cs` can become `a -xs> cs`
            // ??? `as -es> b -x> c` can become `as -xs> c`
        }
        // Nothing was rewritten, so there is nothing to remove.
        if !changed {
            return changed;
        }

        // Physically drop the states marked dead and renumber the rest.
        self.remove_dead_states(dead_states);

        return changed;
    }
656
657    /// Finds the states that can be reached from the start via any path
658    fn states_reachable_start(&self) -> Vec<bool> {
659        let mut reachable = vec![false; self.states.len()];
660        reachable[0] = true;
661        let mut stack = vec![0];
662
663        while let Some(state) = stack.pop() {
664            for src in &self.states[state].epsilons {
665                if !reachable[src.to] {
666                    stack.push(src.to);
667                }
668                reachable[src.to] = true;
669            }
670            for src in &self.states[state].transitions {
671                if !reachable[src.to] {
672                    stack.push(src.to);
673                }
674                reachable[src.to] = true;
675            }
676        }
677
678        return reachable;
679    }
680    /// Finds the states that can reach the end via any path
681    fn states_reachable_end(&self) -> Vec<bool> {
682        let mut reverse = vec![Vec::new(); self.states.len()];
683        for (i, state) in self.states.iter().enumerate() {
684            for e in &state.epsilons {
685                reverse[e.to].push(i);
686            }
687            for t in &state.transitions {
688                reverse[t.to].push(i);
689            }
690        }
691
692        let mut reachable = vec![false; self.states.len()];
693        reachable[self.states.len() - 1] = true;
694        let mut stack = vec![self.states.len() - 1];
695
696        while let Some(state) = stack.pop() {
697            for src in &reverse[state] {
698                if !reachable[*src] {
699                    stack.push(*src);
700                }
701                reachable[*src] = true;
702            }
703        }
704
705        return reachable;
706    }
707
708    /// Removes all nodes that cannot be reached or cannot reach the end.
709    ///
710    /// Ignores special epsilon types (so should be called after they have been resolved)
711    fn remove_unreachable(&mut self) {
712        let reach_start = self.states_reachable_start();
713        let reach_end = self.states_reachable_end();
714
715        // Remove transitions that involve redundant states
716        for state in &mut self.states {
717            state
718                .epsilons
719                .retain(|e| reach_start[e.to] && reach_end[e.to]);
720            state
721                .transitions
722                .retain(|t| reach_start[t.to] && reach_end[t.to]);
723        }
724
725        // Then remove the states
726        self.remove_dead_states(
727            std::iter::zip(reach_start.into_iter(), reach_end.into_iter()).map(|(a, b)| !a || !b),
728        );
729    }
730
731    /// Finds the number of capture groups in this NFA
732    pub fn num_capture_groups(&self) -> usize {
733        return self
734            .states
735            .iter()
736            .flat_map(|state| &state.epsilons)
737            .map(|eps| match eps.special {
738                EpsilonType::StartCapture(n) => n,
739                _ => 0,
740            })
741            .max()
742            .unwrap_or(0)
743            + 1;
744    }
745
746    /// Tries to find a [topological ordering](https://en.wikipedia.org/wiki/Topological_sorting)
747    /// from the start node to the accept node.
748    ///
749    /// If successful (the graph is a DAG), it will return a sequence of indices.
750    /// Since multiple topological orderings may exist for a graph, the returned ordering may not be unique.
751    ///
752    /// If there is no topological ordering (the graph contains cycles and is not a DAG)
753    /// then it will return `None`.
754    ///
755    /// Will also return `None` if some node could not be reached. This should not happen.
756    pub fn topological_ordering(&self) -> Option<Vec<usize>> {
757        let mut done = vec![false; self.states.len()];
758        let mut active = vec![false; self.states.len()];
759        let mut order = Vec::new();
760
761        enum StackItem {
762            PreVisit(usize),
763            PostVisit(usize),
764        }
765        let mut stack = vec![StackItem::PreVisit(0)];
766
767        while let Some(item) = stack.pop() {
768            match item {
769                StackItem::PreVisit(node) => {
770                    if done[node] {
771                        continue;
772                    } else if active[node] {
773                        return None;
774                    }
775                    active[node] = true;
776                    stack.push(StackItem::PostVisit(node));
777                    for tr in &self.states[node].transitions {
778                        stack.push(StackItem::PreVisit(tr.to));
779                    }
780                    for ep in &self.states[node].epsilons {
781                        stack.push(StackItem::PreVisit(ep.to));
782                    }
783                }
784                StackItem::PostVisit(node) => {
785                    done[node] = true;
786                    order.push(node);
787                }
788            }
789        }
790
791        if order.len() != self.states.len() {
792            return None;
793        }
794        order.reverse();
795        return Some(order);
796    }
797
798    /// Writes a LaTeX TikZ representation to visualize the graph.
799    ///
800    /// If `include_doc` is `true`, will include the headers.
801    /// Otherwise, you should include `\usepackage{tikz}` and `\usetikzlibrary{automata, positioning}`.
802    pub fn to_tikz(&self, include_doc: bool) -> String {
803        let map_state = |(i, state): (usize, &U8State)| -> crate::visualization::LatexGraphState {
804            let transitions =
805                state
806                    .transitions
807                    .iter()
808                    .map(|t| crate::visualization::LatexGraphTransition {
809                        label: crate::visualization::escape_latex(t.symbol.to_string()),
810                        to: t.to,
811                    });
812            let epsilons = state.epsilons.iter().enumerate().map(|(i, e)| {
813                let label = match e.special {
814                    EpsilonType::None => format!(r"$\epsilon_{{{i}}}$"),
815                    EpsilonType::StartAnchor => format!(r"{{\textasciicircum}}$_{{{i}}}$"),
816                    EpsilonType::EndAnchor => format!(r"$\$_{{{i}}}$"),
817                    EpsilonType::StartCapture(group) => format!("${group}(_{{{i}}}$"),
818                    EpsilonType::EndCapture(group) => format!("$){group}_{{{i}}}$"),
819                };
820                return crate::visualization::LatexGraphTransition { label, to: e.to };
821            });
822            let transitions = transitions.chain(epsilons).collect();
823            return crate::visualization::LatexGraphState {
824                label: format!("q{i}"),
825                transitions,
826                initial: i == 0,
827                accept: i + 1 == self.states.len(),
828            };
829        };
830
831        let graph = crate::visualization::LatexGraph {
832            states: self.states.iter().enumerate().map(map_state).collect(),
833        };
834        return graph.to_tikz(include_doc);
835    }
836
837    /// Using the classical NFA algorithm to do a simple boolean test on a string.
838    pub fn test(&self, text: &str) -> bool {
839        let mut list = vec![false; self.states.len()];
840        let mut new_list = vec![false; self.states.len()];
841        list[0] = true;
842
843        // Adds all states reachable by epsilon transitions
844        let propogate_epsilon = |list: &mut Vec<bool>, idx: usize| {
845            let mut stack: Vec<usize> = list
846                .iter()
847                .enumerate()
848                .filter_map(|(i, set)| set.then_some(i))
849                .collect();
850
851            while let Some(from) = stack.pop() {
852                for EpsilonTransition { to, special } in &self.states[from].epsilons {
853                    if list[from]
854                        && !list[*to]
855                        && (match special {
856                            EpsilonType::StartAnchor => idx == 0,
857                            EpsilonType::EndAnchor => idx == text.len(),
858                            _ => true,
859                        })
860                    {
861                        stack.push(*to);
862                        list[*to] = true;
863                    }
864                }
865            }
866        };
867
868        for (i, c) in text.as_bytes().iter().enumerate() {
869            propogate_epsilon(&mut list, i);
870            for (from, state) in self.states.iter().enumerate() {
871                if !list[from] {
872                    continue;
873                }
874
875                for U8Transition { to, symbol } in &state.transitions {
876                    if symbol.check(*c) {
877                        new_list[*to] = true;
878                    }
879                }
880            }
881            let tmp = list;
882            list = new_list;
883            new_list = tmp;
884            new_list.fill(false);
885        }
886        propogate_epsilon(&mut list, text.len());
887        return *list.last().unwrap_or(&false);
888    }
889}
890impl std::fmt::Display for U8NFA {
891    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
892        for (i, state) in self.states.iter().enumerate() {
893            writeln!(f, "State {i}:")?;
894            for e in &state.epsilons {
895                writeln!(f, "  {e}")?;
896            }
897            for t in &state.transitions {
898                writeln!(f, "  {t}")?;
899            }
900        }
901        return Ok(());
902    }
903}
904
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{config::Config, parse_tree::ERE, simplified_tree::SimplifiedTreeNode};

    /// Asserts that `nfa` accepts every string in `inputs`, naming the first one it rejects.
    #[track_caller]
    fn assert_accepts(nfa: &U8NFA, inputs: &[&str]) {
        for input in inputs {
            assert!(nfa.test(input), "expected {input:?} to be accepted");
        }
    }

    /// Asserts that `nfa` rejects every string in `inputs`, naming the first one it accepts.
    #[track_caller]
    fn assert_rejects(nfa: &U8NFA, inputs: &[&str]) {
        for input in inputs {
            assert!(!nfa.test(input), "expected {input:?} to be rejected");
        }
    }

    /// Hand-built NFA for `ab+c` (state 2 loops back to state 1 through an epsilon).
    #[test]
    fn abbc_raw() {
        let nfa = U8NFA {
            states: vec![
                U8State::new().with_transition(1, b'a'.into()),
                U8State::new().with_transition(2, b'b'.into()),
                U8State::new()
                    .with_transition(3, b'c'.into())
                    .with_epsilon(1),
                U8State::new(),
            ],
        };
        println!("{}", nfa.to_tikz(true));

        assert_accepts(&nfa, &["abc", "abbc", "abbbc", "abbbbc"]);
        assert_rejects(&nfa, &["ac", "abcc", "bac", "acb"]);
    }

    /// Fully anchored pattern with an optional group and bounded repetitions.
    #[test]
    fn phone_number() {
        let ere = ERE::parse_str(r"^(\+1 )?[0-9]{3}-[0-9]{3}-[0-9]{4}$").unwrap();
        let (tree, capture_groups) = SimplifiedTreeNode::from_ere(&ere, &Config::default());
        assert_eq!(capture_groups, 2);
        let nfa = WorkingNFA::new(&tree);
        let nfa = U8NFA::new(&nfa);
        println!("{}", nfa.to_tikz(true));

        assert_accepts(
            &nfa,
            &[
                "012-345-6789",
                "987-654-3210",
                "+1 555-555-5555",
                "123-555-9876",
            ],
        );
        assert_rejects(
            &nfa,
            &[
                "abcd",
                "0123456789",
                "012--345-6789",
                "(555) 555-5555",
                "1 555-555-5555",
            ],
        );
    }

    /// Nested stars must still terminate in the simulation.
    #[test]
    fn double_loop() {
        let ere = ERE::parse_str(r"^.*(.*)*$").unwrap();
        let (tree, capture_groups) = SimplifiedTreeNode::from_ere(&ere, &Config::default());
        assert_eq!(capture_groups, 2);
        let nfa = WorkingNFA::new(&tree);
        let nfa = U8NFA::new(&nfa);

        assert_accepts(&nfa, &["", "asdf", "1234567", "0"]);
        assert_rejects(&nfa, &["\0"]);
    }

    /// Start anchors placed at various positions inside alternatives.
    #[test]
    fn good_anchored_start() {
        let ere = ERE::parse_str(r"^a|b*^c|d^|n").unwrap();
        let (tree, capture_groups) = SimplifiedTreeNode::from_ere(&ere, &Config::default());
        assert_eq!(capture_groups, 1);
        let nfa = WorkingNFA::new(&tree);
        let nfa = U8NFA::new(&nfa);

        assert_accepts(&nfa, &["a", "c", "cq", "wwwnwww"]);
        assert_rejects(&nfa, &["", "qb", "qc", "b", "bc", "bbbbbbc", "d"]);
    }

    /// End anchors placed at various positions inside alternatives.
    #[test]
    fn good_anchored_end() {
        let ere = ERE::parse_str(r"a$|b$c*|$d|n").unwrap();
        let (tree, capture_groups) = SimplifiedTreeNode::from_ere(&ere, &Config::default());
        assert_eq!(capture_groups, 1);
        let nfa = WorkingNFA::new(&tree);
        let nfa = U8NFA::new(&nfa);
        println!("{}", nfa.to_tikz(true));

        assert_accepts(&nfa, &["a", "b", "qb", "wwwnwww"]);
        assert_rejects(&nfa, &["", "bq", "qc", "c", "bc", "bcccccc", "d"]);
    }

    /// Bracket expression mixing a POSIX class with a literal character.
    #[test]
    fn range_digit() {
        let ere = ERE::parse_str(r"^[[:digit:].]$").unwrap();
        let (tree, capture_groups) = SimplifiedTreeNode::from_ere(&ere, &Config::default());
        assert_eq!(capture_groups, 1);
        let nfa = WorkingNFA::new(&tree);
        let nfa = U8NFA::new(&nfa);
        println!("{}", nfa.to_tikz(true));

        assert_accepts(&nfa, &["0", "1", "9", "."]);
        assert_rejects(&nfa, &["", "a", "11", "1.", ".2", "09", "d"]);
    }

    /// The dot class must match every `char` except NUL, whatever its UTF-8 length.
    #[test]
    fn dot() {
        let nfa = U8NFA::nfa_symbol(&&Atom::CharClass(crate::parse_tree::CharClass::Dot));
        assert!(!nfa.test("\0"));
        for c in '\u{0001}'..=char::MAX {
            let txt = c.to_string();
            // Zero-padded UTF-8 encoding, only used to make a failure message readable.
            let mut bytes = [0; 4];
            c.encode_utf8(&mut bytes);
            assert!(
                nfa.test(&txt),
                "Expected {c} (code point: 0x{:X}, utf8: 0x{:02X}{:02X}{:02X}{:02X}) to be matched by regex dot.",
                c as u32,
                bytes[0],
                bytes[1],
                bytes[2],
                bytes[3]
            );
        }

        let ere = ERE::parse_str(r"^.$").unwrap();
        let (tree, capture_groups) = SimplifiedTreeNode::from_ere(&ere, &Config::default());
        assert_eq!(capture_groups, 1);
        let nfa = WorkingNFA::new(&tree);
        let nfa = U8NFA::new(&nfa);
        println!("{}", nfa.to_tikz(true));

        assert_accepts(&nfa, &["0", "1", "a", "\u{0001}", "9", ".", "\u{1234}"]);
        assert_rejects(
            &nfa,
            &[
                "",
                "\0",
                "ab",
                "11",
                "1.",
                ".2",
                "09",
                "\u{1234}\u{4321}",
            ],
        );
    }
}