ere_core/
working_u8_dfa.rs

//! Working data structure for a tagged DFA over `u8`s.
2//! Primarily intended for use at compile time, converted from [`crate::U8NFA`].
3//!
4//! For more information, read https://en.wikipedia.org/wiki/Tagged_Deterministic_Finite_Automaton
5//!
6//! Additional references:
7//! - [NFAs with Tagged Transitions, their Conversion to Deterministic Automata and Application to Regular Expressions](https://laurikari.net/ville/spire2000-tnfa.pdf) by Ville Laurikari, 2000
8//! - [Tagged Deterministic Finite Automata with Lookahead](https://arxiv.org/pdf/1907.08837) by Ulya Trofimovich, 2019
9
10use std::{collections::HashSet, ops::RangeInclusive};
11
12use crate::{
13    epsilon_propogation::{EpsilonPropogation, Tag},
14    working_u8_nfa::U8NFA,
15};
16
/// Index of a state in [`U8NFA::states`] of the NFA a DFA was built from.
///
/// DFA states identify the NFA states they represent by these indices.
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub struct SubNFAStateID(pub usize);
20
/// A byte-range transition between two DFA states, together with the tag operations
/// that are applied when it is taken.
#[derive(Debug)]
pub struct U8DFATransition {
    /// Index of the destination state in [`U8DFA::states`].
    pub to: usize,
    /// The inclusive range of input bytes on which this transition is taken.
    pub symbol: RangeInclusive<u8>,
    /// For each of the NFA states the new DFA state represents, the index of one of the previous DFA state's NFA states.
    ///
    /// The indexed NFA state's tags will be copied to the new NFA state's tags, before any updates are applied.
    ///
    /// Basically, if a DFA state representing `N` NFA states has `old_tags: [NFAStateTags; N]` then
    /// the new DFA state representing `copy_tags.len()` NFA states will have
    /// `new_tags: copy_tags.map(|i| old_tags[i])`.
    /// This means indices are local NFA state indices, not global NFA state indices.
    pub copy_tags: Vec<usize>,
    /// After tags are copied, these tags will be updated.
    ///
    /// Each entry is `(local NFA state index in the new DFA state, tag update)`.
    pub add_tags: Vec<(usize, Tag)>,
}
39
40/// Final epsilon-like transition when at end, allows end anchors within it
41pub struct U8DFAAcceptTransition {
42    /// Local NFA state index that we move to accept from (and to copy tags from)
43    pub nfa_state: usize,
44    pub add_tags: Vec<Tag>,
45}
46impl U8DFAAcceptTransition {
47    pub fn from_epsilon_prop(local_from_idx: usize, epsilon_prop: &EpsilonPropogation) -> Self {
48        return Self {
49            nfa_state: local_from_idx,
50            add_tags: epsilon_prop.update_tags.clone(),
51        };
52    }
53    /// return 0 if there are no capture groups
54    pub fn max_capture_group(&self) -> usize {
55        return self
56            .add_tags
57            .iter()
58            .map(Tag::capture_group)
59            .max()
60            .unwrap_or(0);
61    }
62}
63
/// The zero-length accept behaviour of a DFA state, reduced to the accept transition(s)
/// that can still be chosen under priority order.
///
/// "Anchored" accepts come from `end_only` epsilon propagations and are only valid at the end
/// of the input; "unanchored" accepts are valid anywhere, including at the end.
pub enum U8DFAAccept {
    /// If there are only end-anchored accept(s), this is the highest priority one.
    Anchored(U8DFAAcceptTransition),
    /// If there are both end-anchored and non-end-anchored accept(s),
    /// where the highest anchored one is higher priority than the highest non-anchored one:
    ///
    /// Is a pair `(anchored, non_anchored)`. At the end of input the anchored one applies;
    /// otherwise the non-anchored one does.
    Both(U8DFAAcceptTransition, U8DFAAcceptTransition),
    /// If there is a non-end-anchored accept(s), with no higher priority anchored accept(s).
    ///
    /// Lower-priority anchored accepts are dropped, since the non-anchored one also works at the end.
    Unanchored(U8DFAAcceptTransition),
    /// If there are no accept(s).
    None,
}
77impl U8DFAAccept {
78    /// ## Params
79    /// - `local_from_idx` is the local NFA state index that we move to accept *from* (and to copy tags from)
80    /// - `epsilon_prop` is the epsilon propogations of the NFA state
81    /// - `accept_state_idx` is the index of the accept state in the NFA
82    pub fn from_epsilon_prop<'a>(
83        local_from_idx: usize,
84        epsilon_prop: impl IntoIterator<Item = &'a EpsilonPropogation>,
85        accept_state_idx: usize,
86    ) -> U8DFAAccept {
87        let accept_transitions: Vec<_> = epsilon_prop
88            .into_iter()
89            .filter(|ep| ep.state == accept_state_idx)
90            .collect();
91        let anchored_accept = accept_transitions
92            .iter()
93            .cloned()
94            .enumerate()
95            .find(|(_, ep)| ep.end_only);
96        let unanchored_accept = accept_transitions
97            .iter()
98            .cloned()
99            .enumerate()
100            .find(|(_, ep)| !ep.end_only);
101
102        match (anchored_accept, unanchored_accept) {
103            (Some((_, anchored)), None) => U8DFAAccept::Anchored(
104                U8DFAAcceptTransition::from_epsilon_prop(local_from_idx, anchored),
105            ),
106            (None, Some((_, unanchored))) => U8DFAAccept::Unanchored(
107                U8DFAAcceptTransition::from_epsilon_prop(local_from_idx, unanchored),
108            ),
109            (None, None) => U8DFAAccept::None,
110            (Some((anchored_idx, anchored)), Some((unanchored_idx, unanchored))) => {
111                if anchored_idx < unanchored_idx {
112                    // anchored is higher priority
113                    // So we get the tags for the anchored first (if at end), and otherwise we get the tags for the unanchored
114                    U8DFAAccept::Both(
115                        U8DFAAcceptTransition::from_epsilon_prop(local_from_idx, anchored),
116                        U8DFAAcceptTransition::from_epsilon_prop(local_from_idx, unanchored),
117                    )
118                } else {
119                    // unanchored is higher priority
120                    // Since unanchored works even at the end, we don't need an extra anchored transition
121                    U8DFAAccept::Unanchored(U8DFAAcceptTransition::from_epsilon_prop(
122                        local_from_idx,
123                        unanchored,
124                    ))
125                }
126            }
127        }
128    }
129    /// Updates the accept transitions with more transitions that take lower priority.
130    ///
131    /// This is basically equivalent to [`Self::from_epsilon_prop`] with the combined epsilon propogations of the two.
132    pub fn update_with_lower_priority(self, other: U8DFAAccept) -> U8DFAAccept {
133        match (self, other) {
134            (U8DFAAccept::Anchored(a), U8DFAAccept::Unanchored(b) | U8DFAAccept::Both(_, b)) => {
135                U8DFAAccept::Both(a, b)
136            }
137            (this @ (U8DFAAccept::Both(_, _) | U8DFAAccept::Unanchored(_)), _) => this,
138            (this @ U8DFAAccept::Anchored(_), U8DFAAccept::Anchored(_)) => this,
139            (U8DFAAccept::None, other) => other,
140            (this, U8DFAAccept::None) => this,
141        }
142    }
143}
144
/// A single state of a [`U8DFA`], representing an ordered subset of the NFA's states.
pub struct U8DFAState {
    /// Each DFA state represents a subset of the NFA states.
    ///
    /// When executed, the dfa state will store the tags for each of the NFA states it represents.
    ///
    /// Should be sorted by priority order of the NFA state threads it represents.
    /// State uniqueness includes priority order, so multiple DFA states may represent the same set of NFA states,
    /// just in different orders.
    pub nfa_states: Vec<SubNFAStateID>,
    /// Outgoing byte-range transitions, built from disjoint ranges (see `split_ranges_u8`).
    /// Empty until the state has been expanded.
    pub transitions: Vec<U8DFATransition>,
    /// The highest-priority zero-length (i.e. epsilon) transition(s) to the accept state
    /// from the NFA state(s) this DFA state represents.
    ///
    /// `U8DFAAccept::None` until the state has been expanded (and afterwards if no accept is reachable).
    pub accept: U8DFAAccept,
}
159impl U8DFAState {
160    /// Creates a new start state for the DFA and expands it to create stubs for all the states
161    /// it has transitions to. Unlike normal states, the start state's transitions are generated
162    /// including transitions in the NFA with start anchors.
163    ///
164    /// ## Returns
165    /// A pair `(start_state, new_states)`
166    ///
167    /// where `new_states` are the initial set of states and all need to be expanded with [`U8DFAState::expand`].
168    pub fn new_start_state(nfa: &U8NFA) -> (U8DFAState, Vec<U8DFAState>) {
169        let epsilon_prop: Vec<EpsilonPropogation> =
170            EpsilonPropogation::calculate_epsilon_propogations_u8(nfa, 0);
171        let accept = U8DFAAccept::from_epsilon_prop(0, &epsilon_prop, nfa.states.len() - 1);
172
173        let transitions: Vec<_> = epsilon_prop
174            .iter()
175            .filter(|ep| !ep.end_only)
176            .flat_map(|ep| {
177                nfa.states[ep.state]
178                    .transitions
179                    .iter()
180                    .map(|tr| (tr.symbol.0.clone(), (ep.clone(), tr)))
181            })
182            .collect();
183        let transitions = transitions
184            .iter()
185            .map(|(range, value)| (range.clone(), value));
186        let mut byte_ranges_transitions = split_ranges_u8(transitions);
187        for (_, nfa_tr) in &mut byte_ranges_transitions {
188            // remove lower priority transitions to the same nfa state
189            nfa_tr.dedup_by_key_all(|tr| tr.1.to);
190        }
191
192        let mut new_states = Vec::new();
193        let mut start_state = U8DFAState {
194            nfa_states: vec![SubNFAStateID(0)],
195            transitions: Vec::new(),
196            accept,
197        };
198        for (range, nfa_tr) in byte_ranges_transitions {
199            let nfa_states = nfa_tr.iter().map(|(_, tr)| SubNFAStateID(tr.to)).collect();
200
201            let new_state = U8DFAState {
202                nfa_states,
203                transitions: Vec::new(),   // will do when expanded
204                accept: U8DFAAccept::None, // will do when expanded
205            };
206            let new_state_idx = new_states.len();
207            new_states.push(new_state);
208
209            let add_tags = nfa_tr
210                .iter()
211                .enumerate()
212                .flat_map(|(i, (ep, _))| ep.update_tags.iter().map(move |tag| (i, tag.clone())))
213                .collect();
214            let dfa_tr = U8DFATransition {
215                to: new_state_idx,
216                symbol: range,
217                copy_tags: vec![0; nfa_tr.len()],
218                add_tags,
219            };
220            start_state.transitions.push(dfa_tr);
221        }
222        return (start_state, new_states);
223    }
224    /// Given a state with `nfa_states` set but empty transitions and no accept,
225    /// expands the state to include all possible transitions.
226    ///
227    /// ## Params
228    /// - `nfa` is the original nfa
229    /// - `curr_dfa_states` is the current list of states in the DFA.
230    ///   The returned new states will be appended to the end of this list.
231    ///
232    /// ## Returns
233    /// A list of new states, which will be added to [`U8DFA::states`]. They will only have `nfa_states` set,
234    /// and thus will need to have `expand` called on them to get the full set of transitions.
235    fn expand(&mut self, nfa: &U8NFA, curr_dfa_states: &[U8DFAState]) -> Vec<U8DFAState> {
236        assert!(self.transitions.is_empty());
237        assert!(matches!(self.accept, U8DFAAccept::None));
238
239        struct SubNFATransition {
240            /// Local NFA state index in the DFA state we are expanding
241            from: usize,
242            /// The epsilon propogation before the symbol transition
243            ep: EpsilonPropogation,
244            /// The index of the destination NFA state in [`U8NFA::states`]
245            to: SubNFAStateID,
246        }
247
248        let mut transitions = Vec::new();
249        for (local_nfa_state_idx, nfa_state_id) in self.nfa_states.iter().enumerate() {
250            // maintaining priority: self.nfa_states is sorted by priority
251
252            let epsilon_prop: Vec<EpsilonPropogation> =
253                EpsilonPropogation::calculate_epsilon_propogations_u8(nfa, nfa_state_id.0);
254
255            let mut tmp = U8DFAAccept::None;
256            std::mem::swap(&mut tmp, &mut self.accept);
257            self.accept = tmp.update_with_lower_priority(U8DFAAccept::from_epsilon_prop(
258                local_nfa_state_idx,
259                &epsilon_prop,
260                nfa.states.len() - 1,
261            ));
262
263            for ep in epsilon_prop {
264                if ep.start_only || ep.end_only {
265                    continue;
266                }
267                let nfa_prop_state = &nfa.states[ep.state];
268                for tr in &nfa_prop_state.transitions {
269                    let symbol = tr.symbol.0.clone();
270                    let nfa_tr = SubNFATransition {
271                        from: local_nfa_state_idx,
272                        ep: ep.clone(),
273                        to: SubNFAStateID(tr.to),
274                    };
275                    transitions.push((symbol, nfa_tr));
276                }
277            }
278        }
279
280        let mut new_states: Vec<U8DFAState> = Vec::new();
281        // now we have all nfa transitions from all the nfa states
282        // we need to combine and split them into byte ranges
283        let transitions = transitions
284            .iter()
285            .map(|(range, value)| (range.clone(), value));
286        let byte_ranges_transitions = split_ranges_u8(transitions);
287        for (range, mut nfa_tr) in byte_ranges_transitions {
288            // remove lower priority transitions to the same nfa state
289            nfa_tr.dedup_by_key_all(|tr| tr.to);
290
291            let nfa_states = nfa_tr.iter().map(|nfa_tr| nfa_tr.to).collect();
292            let new_state_idx = curr_dfa_states
293                .iter()
294                .enumerate()
295                .find(|(_, existing_state)| existing_state.nfa_states == nfa_states)
296                .map(|(i, _)| i)
297                .or_else(|| {
298                    new_states
299                        .iter()
300                        .enumerate()
301                        .find(|(_, new_state)| new_state.nfa_states == nfa_states)
302                        .map(|(i, _)| i + curr_dfa_states.len())
303                })
304                .unwrap_or_else(|| curr_dfa_states.len() + new_states.len());
305            if new_state_idx >= new_states.len() + curr_dfa_states.len() {
306                // new state needs to be added
307                new_states.push(U8DFAState {
308                    nfa_states,
309                    transitions: Vec::new(),   // will do when expanded
310                    accept: U8DFAAccept::None, // will do when expanded
311                });
312            }
313
314            let add_tags = nfa_tr
315                .iter()
316                .enumerate()
317                .flat_map(|(i, nfa_tr)| {
318                    nfa_tr
319                        .ep
320                        .update_tags
321                        .iter()
322                        .map(move |tag| (i, tag.clone()))
323                })
324                .collect();
325            let copy_tags = nfa_tr.iter().map(|tr| tr.from).collect();
326            let dfa_tr = U8DFATransition {
327                to: new_state_idx,
328                symbol: range,
329                copy_tags,
330                add_tags,
331            };
332            self.transitions.push(dfa_tr);
333        }
334
335        return new_states;
336    }
337}
338
/// A DFA over `u8`s, fully constructed. Intended for use at compile time.
///
/// Build one with [`U8DFA::from_nfa`].
pub struct U8DFA {
    /// The 'start' state is not stored in the `states` vec, and can never be transitioned to.
    /// This allows us to exclude transitions with start anchors from all other states,
    /// while implicitly including them in the start state.
    ///
    /// The start state always begins with one NFA state (0) with no tags.
    pub start_state: U8DFAState,
    /// Every DFA state except the start state; [`U8DFATransition::to`] indexes into this vec.
    ///
    /// Unique by [`U8DFAState::nfa_states`], including priority order.
    pub states: Vec<U8DFAState>,
}
350impl U8DFA {
351    /// Creates a TDFA from a TNFA. This should be the primary way to create a `U8DFA`.
352    ///
353    /// Since the size of a DFA is worst-case exponential in the number of NFA states,
354    /// the maximum number of states is `max_states`.
355    /// If the number of states exceeds `max_states` then `None` will be returned.
356    pub fn from_nfa(nfa: &U8NFA, max_states: usize) -> Option<Self> {
357        let mut states; // Created DFA states (except start)
358        let mut stack = Vec::new(); // DFA states to expand out from. added on creation
359
360        let start_state = {
361            let (start_state, new_states) = U8DFAState::new_start_state(nfa);
362            stack.extend(0..new_states.len());
363            states = new_states;
364
365            if states.len() > max_states {
366                return None;
367            }
368
369            start_state
370        };
371
372        // other states
373        while let Some(dfa_state_id) = stack.pop() {
374            let mut dfa_state = U8DFAState {
375                nfa_states: states[dfa_state_id].nfa_states.clone(),
376                transitions: Default::default(),
377                accept: U8DFAAccept::None,
378            };
379            std::mem::swap(&mut dfa_state, &mut states[dfa_state_id]);
380            let new_states = dfa_state.expand(nfa, &states);
381            states[dfa_state_id] = dfa_state;
382
383            for new_state in new_states {
384                if states.iter().any(|s| s.nfa_states == new_state.nfa_states) {
385                    continue;
386                }
387                stack.push(states.len());
388                states.push(new_state);
389
390                if states.len() > max_states {
391                    return None;
392                }
393            }
394        }
395
396        return Some(U8DFA {
397            start_state,
398            states,
399        });
400    }
401
402    pub fn num_capture_groups(&self) -> usize {
403        let mut max = 0;
404
405        for state in self.states.iter().chain(std::iter::once(&self.start_state)) {
406            for tr in &state.transitions {
407                for (_, tag) in &tr.add_tags {
408                    max = std::cmp::max(max, tag.capture_group());
409                }
410            }
411            match &state.accept {
412                U8DFAAccept::Anchored(t) | U8DFAAccept::Unanchored(t) => {
413                    max = std::cmp::max(max, t.max_capture_group());
414                }
415                U8DFAAccept::Both(t1, t2) => {
416                    max = std::cmp::max(max, t1.max_capture_group());
417                    max = std::cmp::max(max, t2.max_capture_group());
418                }
419                U8DFAAccept::None => {}
420            }
421        }
422        return max + 1;
423    }
424
425    /// Returns the default bound for the DFA.
426    ///
427    /// Should probably be tuned.
428    pub fn default_bound(nfa_states: usize) -> usize {
429        return std::cmp::max(100, nfa_states * 2);
430    }
431}
432
433impl U8DFA {
434    /// Writes a LaTeX TikZ representation to visualize the graph.
435    ///
436    /// If `include_doc` is `true`, will include the headers.
437    /// Otherwise, you should include `\usepackage{tikz}` and `\usetikzlibrary{automata, positioning}`.
438    pub fn to_tikz(&self, include_doc: bool) -> String {
439        let accept_state = crate::visualization::LatexGraphState {
440            label: "accept".to_string(),
441            transitions: Vec::new(),
442            initial: false,
443            accept: true,
444        };
445        fn make_label(nfa_indices: &[SubNFAStateID]) -> String {
446            let nfa_indices: Vec<_> = nfa_indices.iter().map(|s| s.0.to_string()).collect();
447            return nfa_indices.join(",");
448        }
449        let map_state =
450            |(i, state): (usize, &U8DFAState)| -> crate::visualization::LatexGraphState {
451                let transitions =
452                    state
453                        .transitions
454                        .iter()
455                        .map(|t| crate::visualization::LatexGraphTransition {
456                            label: crate::visualization::escape_latex(
457                                DisplayRange(t.symbol.clone()).to_string(),
458                            ),
459                            to: t.to + 1,
460                        });
461                let accept = match &state.accept {
462                    U8DFAAccept::Anchored(_) => Some(crate::visualization::LatexGraphTransition {
463                        to: self.states.len() + 1,
464                        label: "\\$".to_string(),
465                    }),
466                    U8DFAAccept::Both(_, _) => Some(crate::visualization::LatexGraphTransition {
467                        to: self.states.len() + 1,
468                        label: String::new(),
469                    }),
470                    U8DFAAccept::Unanchored(_) => {
471                        Some(crate::visualization::LatexGraphTransition {
472                            to: self.states.len() + 1,
473                            label: String::new(),
474                        })
475                    }
476                    U8DFAAccept::None => None,
477                };
478
479                let transitions = transitions.chain(accept).collect();
480
481                return crate::visualization::LatexGraphState {
482                    label: make_label(&state.nfa_states),
483                    transitions,
484                    initial: i == 0,
485                    accept: false,
486                };
487            };
488
489        let graph = crate::visualization::LatexGraph {
490            states: std::iter::once(&self.start_state)
491                .chain(self.states.iter())
492                .enumerate()
493                .map(map_state)
494                .chain(std::iter::once(accept_state))
495                .collect(),
496        };
497        return graph.to_tikz(include_doc);
498    }
499}
500
/// Splits possibly-overlapping byte ranges into disjoint ranges, each carrying every value
/// whose range covers it.
///
/// E.g. `[(0..=5, 'a'), (3..=10, 'b')]` becomes `[(0..=2, ['a']), (3..=5, ['a', 'b']), (6..=10, ['b'])]`
///
/// Values keep their input order within each output range. Adjacent bytes are merged into one
/// range only when they carry the very same references (compared by address) in the same order.
/// Bytes covered by no value are omitted, and output ranges are sorted ascending.
fn split_ranges_u8<'a, T>(
    items: impl IntoIterator<Item = (RangeInclusive<u8>, &'a T)>,
) -> Vec<(RangeInclusive<u8>, Vec<&'a T>)> {
    // Identity comparison: equal only if both hold the same references in the same order.
    fn same_refs<T>(lhs: &[&T], rhs: &[&T]) -> bool {
        lhs.len() == rhs.len() && lhs.iter().zip(rhs).all(|(x, y)| std::ptr::eq(*x, *y))
    }

    // One bucket per possible byte value.
    let mut per_byte: Vec<Vec<&'a T>> = (0..=u8::MAX).map(|_| Vec::new()).collect();
    for (range, value) in items {
        range.for_each(|byte| per_byte[usize::from(byte)].push(value));
    }

    let mut result = Vec::new();
    let mut run_start: u8 = 0;
    let mut run_items: Vec<&'a T> = Vec::new();
    for (byte, bucket) in (0..=u8::MAX).zip(per_byte) {
        if same_refs(&bucket, &run_items) {
            continue;
        }
        // `run_items` is empty at byte 0, so `byte - 1` never underflows.
        if !run_items.is_empty() {
            result.push((run_start..=byte - 1, run_items));
        }
        run_start = byte;
        run_items = bucket;
    }
    if !run_items.is_empty() {
        result.push((run_start..=u8::MAX, run_items));
    }

    debug_assert!(result.iter().all(|(_, group)| !group.is_empty()));
    debug_assert!(result
        .windows(2)
        .all(|pair| pair[0].0.end() < pair[1].0.start()));
    result
}
544
trait VecExt<T> {
    /// Deduplicates the vector by key, keeping the first occurrence of each key.
    /// Unlike [`Vec::dedup_by_key`], this method removes all duplicates, not just adjacent ones.
    ///
    /// The relative order of the kept elements is preserved.
    fn dedup_by_key_all<K: Eq + std::hash::Hash>(&mut self, key: impl Fn(&T) -> K);
}
impl<T> VecExt<T> for Vec<T> {
    fn dedup_by_key_all<K: Eq + std::hash::Hash>(&mut self, key: impl Fn(&T) -> K) {
        // TODO: we could use a vec for smaller sizes
        let mut seen = HashSet::new();
        // `HashSet::insert` returns `false` when the key was already present, so this keeps
        // exactly the first element per key with one hash lookup per element
        // (instead of `contains` followed by `insert`).
        self.retain(|x| seen.insert(key(x)));
    }
}
564
/// Displays a single byte the way it is labelled on transitions.
/// - Common whitespace (`\t`, `\n`, `\r`, space) is shown as a quoted escape
/// - Other printable ASCII and ASCII whitespace is shown via the `char`'s `Debug` form
/// - Everything else is shown as `0x`-prefixed two-digit hex
struct DisplayByteChar(u8);
impl std::fmt::Display for DisplayByteChar {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let byte = self.0;
        let fixed = match byte {
            b'\t' => Some("'\\t'"),
            b'\n' => Some("'\\n'"),
            b'\r' => Some("'\\r'"),
            b' ' => Some("' '"),
            _ => None,
        };
        if let Some(text) = fixed {
            f.write_str(text)
        } else if byte.is_ascii_graphic() || byte.is_ascii_whitespace() {
            write!(f, "{:?}", char::from(byte))
        } else {
            write!(f, "0x{:02x}", byte)
        }
    }
}
582
583/// Newtype for displaying a byte range for transitions.
584/// See [`DisplayByteChar`] for details.
585struct DisplayRange(RangeInclusive<u8>);
586impl std::fmt::Display for DisplayRange {
587    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
588        let start = self.0.start();
589        let end = self.0.end();
590        if start == end {
591            return write!(f, "{}", DisplayByteChar(*start));
592        } else {
593            return write!(f, "{}..={}", DisplayByteChar(*start), DisplayByteChar(*end));
594        }
595    }
596}
597
#[cfg(test)]
mod tests {
    use crate::{
        config::Config, parse_tree::ERE, simplified_tree::SimplifiedTreeNode,
        working_nfa::WorkingNFA,
    };

    use super::*;

    /// End-to-end smoke test: a realistic anchored pattern converts to a DFA within the
    /// state bound, and the resulting DFA has the expected coarse shape.
    #[test]
    fn phone_number() {
        let ere = ERE::parse_str(r"^(\+1 )?[0-9]{3}-[0-9]{3}-[0-9]{4}$").unwrap();
        let (tree, _) = SimplifiedTreeNode::from_ere(&ere, &Config::default());
        let nfa = WorkingNFA::new(&tree);
        let nfa = U8NFA::new(&nfa);
        let dfa = U8DFA::from_nfa(&nfa, 100).expect("phone-number DFA should fit in 100 states");

        // The pattern cannot match the empty string, so the start state must not accept
        // and must have at least one byte transition.
        assert!(matches!(dfa.start_state.accept, U8DFAAccept::None));
        assert!(!dfa.start_state.transitions.is_empty());
        assert!(!dfa.states.is_empty());
    }
}