// ere_core/engines/flat_lockstep_nfa.rs

//! Implements an NFA-like regex engine over `char`s.
//! The engine keeps all threads in lockstep (all threads are at the same input index),
//! and the NFA's epsilon transitions are flattened to a single epsilon transition between symbols
//! (including handling anchors and capture tags).
//!
//! Currently we flatten all epsilon transitions for the VM so that there is at most a single epsilon step between symbols.
//! This still needs review to make sure it does not cause a large binary-size overhead,
//! but the number of flattened transitions should be at worst `O(n^2)` in the number of states, and far smaller on average.
9
10use crate::{
11    epsilon_propogation::{EpsilonPropogation, Tag},
12    nfa_static,
13    working_nfa::{WorkingNFA, WorkingTransition},
14};
15use quote::{quote, ToTokens, TokenStreamExt};
16use std::fmt::Write;
17
/// One thread of the capture-tracking engine.
///
/// It holds the thread's current VM state plus one `(start, end)` pair per capture group.
/// An endpoint equal to `usize::MAX` means "not recorded yet" and is rendered as `_` by `Debug`.
#[derive(Clone)]
pub struct Thread<const N: usize, S: Send + Sync + Copy + Eq> {
    pub state: S,
    pub captures: [(usize, usize); N],
}
impl<const N: usize, S: Send + Sync + Copy + Eq + std::fmt::Debug> std::fmt::Debug
    for Thread<N, S>
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        /// Prints the capture array as `[(a, b), ...]`, using `_` for unset endpoints.
        struct CaptureList<'c, const M: usize>(&'c [(usize, usize); M]);
        impl<'c, const M: usize> std::fmt::Debug for CaptureList<'c, M> {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                /// Writes one endpoint, substituting `_` for the `usize::MAX` sentinel.
                fn write_endpoint(
                    f: &mut std::fmt::Formatter<'_>,
                    value: usize,
                ) -> std::fmt::Result {
                    if value == usize::MAX {
                        f.write_char('_')
                    } else {
                        write!(f, "{value}")
                    }
                }

                f.write_char('[')?;
                let mut separator = "";
                for &(start, end) in self.0.iter() {
                    f.write_str(separator)?;
                    separator = ", ";
                    f.write_char('(')?;
                    write_endpoint(f, start)?;
                    f.write_str(", ")?;
                    write_endpoint(f, end)?;
                    f.write_char(')')?;
                }
                f.write_char(']')
            }
        }

        let mut view = f.debug_struct("Thread");
        view.field("state", &self.state);
        view.field("captures", &CaptureList(&self.captures));
        view.finish()
    }
}
52
53/// The NFA and some precomputed data to go with it.
54struct CachedNFA<'a> {
55    nfa: &'a WorkingNFA,
56    excluded_states: Vec<bool>,
57    capture_groups: usize,
58}
59impl<'a> CachedNFA<'a> {
60    fn new(nfa: &'a WorkingNFA) -> CachedNFA<'a> {
61        let excluded_states = compute_excluded_states(nfa);
62        assert_eq!(nfa.states.len(), excluded_states.len());
63        let capture_groups = nfa.num_capture_groups();
64        return CachedNFA {
65            nfa,
66            excluded_states,
67            capture_groups,
68        };
69    }
70}
71
72/// Since we are shortcutting the epsilon transitions, we can skip printing
73/// states that have only epsilon transitions and are not the start/end states
74fn compute_excluded_states(nfa: &WorkingNFA) -> Vec<bool> {
75    let mut out = vec![true; nfa.states.len()];
76    out[0] = false;
77    out[nfa.states.len() - 1] = false;
78    for (from, state) in nfa.states.iter().enumerate() {
79        for t in &state.transitions {
80            out[from] = false;
81            out[t.to] = false;
82        }
83    }
84
85    return out;
86}
87
88#[derive(Clone, Copy, PartialEq, Eq)]
89struct ImplVMStateLabel(usize);
90impl ToTokens for ImplVMStateLabel {
91    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
92        let ImplVMStateLabel(idx) = self;
93        let label = format!("State{idx}");
94        tokens.append(proc_macro2::Ident::new(
95            &label,
96            proc_macro2::Span::call_site(),
97        ));
98    }
99}
100
101mod impl_test {
102    use quote::ToTokens;
103
104    use super::*;
105
106    /// Implements symbol transitions for a single state
107    struct ImplTransitionStateSymbol<'a> {
108        transition: &'a WorkingTransition,
109    }
110    impl<'a> ToTokens for ImplTransitionStateSymbol<'a> {
111        fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
112            let ImplTransitionStateSymbol { transition } = self;
113            let WorkingTransition { symbol, to } = transition;
114            let symbol = nfa_static::AtomStatic::serialize_as_token_stream(symbol);
115            tokens.extend(quote! {{
116                let symbol = #symbol;
117                if symbol.check(c) {
118                    new_list[#to] = true;
119                }
120            }});
121        }
122    }
123
124    /// Assumes the `VMStates` enum is already created locally in the token stream
125    ///
126    /// Creates the function `transition_symbols_test` for running symbol transitions on the engine
127    pub(super) struct TransitionSymbols<'a>(pub &'a CachedNFA<'a>);
128    impl<'a> ToTokens for TransitionSymbols<'a> {
129        fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
130            let TransitionSymbols(nfa) = self;
131            let CachedNFA {
132                nfa,
133                excluded_states,
134                ..
135            } = nfa;
136
137            let transition_symbols_defs_test = nfa
138                .states
139                .iter()
140                .enumerate()
141                .filter(|(i, _)| !excluded_states[*i])
142                .map(|(i, state)| {
143                    let state_transitions = state
144                        .transitions
145                        .iter()
146                        .map(|t| ImplTransitionStateSymbol { transition: t });
147
148                    return quote! {
149                        if list[#i] {
150                            #(#state_transitions)*
151                        }
152                    };
153                });
154
155            tokens.extend(quote! {
156                fn transition_symbols_test(
157                    list: &[bool],
158                    new_list: &mut [bool],
159                    c: char,
160                ) {
161                    #(#transition_symbols_defs_test)*
162                }
163            });
164        }
165    }
166
167    /// Implements epsilon transitions for a single state
168    ///
169    /// Becomes:
170    /// ```ignore
171    /// if list[#from_state] {
172    ///     // ...
173    /// }
174    /// ```
175    pub(super) struct ImplTransitionStateEpsilon<'a> {
176        pub(super) from_state: ImplVMStateLabel,
177        pub(super) thread_updates: &'a [ThreadUpdates],
178    }
179    impl<'a> ToTokens for ImplTransitionStateEpsilon<'a> {
180        fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
181            let &ImplTransitionStateEpsilon {
182                from_state,
183                thread_updates,
184            } = self;
185            let ImplVMStateLabel(from_state) = from_state;
186
187            // Write epsilon-propogation of threads to the token stream for test
188            let start_end_threads = thread_updates
189                .iter()
190                .filter(|t| t.0.start_only && t.0.end_only)
191                .map(ThreadUpdates::serialize_thread_update_test);
192            let start_threads = thread_updates
193                .iter()
194                .filter(|t| t.0.start_only && !t.0.end_only)
195                .map(ThreadUpdates::serialize_thread_update_test);
196            let end_threads = thread_updates
197                .iter()
198                .filter(|t| !t.0.start_only && t.0.end_only)
199                .map(ThreadUpdates::serialize_thread_update_test);
200            let normal_threads = thread_updates
201                .iter()
202                .filter(|t| !t.0.start_only && !t.0.end_only)
203                .map(ThreadUpdates::serialize_thread_update_test);
204
205            tokens.extend(quote! {
206                if list[#from_state] {
207                    if is_start && is_end {
208                        #(#start_end_threads)*
209                    }
210                    if is_start {
211                        #(#start_threads)*
212                    }
213                    if is_end {
214                        #(#end_threads)*
215                    }
216                    #(#normal_threads)*
217                }
218            });
219        }
220    }
221
222    /// Implements a function that runs all epsilon transitions for all threads.
223    ///
224    /// ```ignore
225    /// fn transition_epsilons_test(
226    ///     list: &mut [bool],
227    ///     idx: usize,
228    ///     len: usize,
229    /// ) {
230    ///     // ...
231    /// }
232    /// ```
233    pub(super) struct TransitionEpsilons<'a>(pub &'a CachedNFA<'a>);
234    impl<'a> ToTokens for TransitionEpsilons<'a> {
235        fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
236            let TransitionEpsilons(nfa) = self;
237            let CachedNFA {
238                nfa,
239                excluded_states,
240                ..
241            } = nfa;
242            assert_eq!(nfa.states.len(), excluded_states.len());
243            let num_states = nfa.states.len();
244
245            let states_epsilon_transitions = std::iter::zip(nfa.states.iter(), excluded_states)
246                .enumerate()
247                .filter(|(_, (_, &excluded))| !excluded)
248                .map(|(i, _)| {
249                    // all reachable states with next transition as epsilon
250                    let mut new_threads = calculate_epsilon_propogations(nfa, i);
251                    new_threads.retain(|t| {
252                        !nfa.states[t.0.state].transitions.is_empty() || t.0.state + 1 == num_states
253                    });
254
255                    let label: ImplVMStateLabel = ImplVMStateLabel(i);
256
257                    let state_epsilon_transitions_test = impl_test::ImplTransitionStateEpsilon {
258                        from_state: label,
259                        thread_updates: &new_threads,
260                    };
261                    state_epsilon_transitions_test.to_token_stream()
262                });
263
264            tokens.extend(quote! {
265                fn transition_epsilons_test(
266                    list: &mut [bool],
267                    idx: usize,
268                    len: usize,
269                ) {
270                    let is_start = idx == 0;
271                    let is_end = idx == len;
272                    #(#states_epsilon_transitions)*
273                }
274            });
275        }
276    }
277
278    pub(crate) struct TestFn<'a>(pub &'a CachedNFA<'a>);
279    impl<'a> ToTokens for TestFn<'a> {
280        fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
281            let TestFn(nfa) = self;
282            let CachedNFA {
283                nfa: u8_nfa,
284                excluded_states,
285                ..
286            } = nfa;
287
288            let enum_states = excluded_states
289                .iter()
290                .enumerate()
291                .filter_map(|(i, excluded)| match excluded {
292                    true => None,
293                    false => Some(ImplVMStateLabel(i)),
294                });
295            let state_count = u8_nfa.states.len();
296            let accept_state = ImplVMStateLabel(state_count - 1);
297
298            let transition_symbols_test = TransitionSymbols(nfa);
299            let transition_epsilons_test = TransitionEpsilons(nfa);
300
301            tokens.extend(quote! {
302                fn test(text: &str) -> bool {
303                    #[derive(Clone, Copy, PartialEq, Eq, Debug)]
304                    enum VMStates {
305                        #(#enum_states,)*
306                    }
307
308                    #transition_symbols_test
309                    #transition_epsilons_test
310
311                    let mut list = [false; #state_count];
312                    let mut new_list = [false; #state_count];
313                    list[0] = true;
314
315                    transition_epsilons_test(&mut list, 0, text.len());
316                    for (i, c) in text.char_indices() {
317                        transition_symbols_test(&list, &mut new_list, c);
318                        if new_list.iter().all(|b| !b) {
319                            return false;
320                        }
321                        ::std::mem::swap(&mut list, &mut new_list);
322                        transition_epsilons_test(&mut list, i + c.len_utf8(), text.len());
323                        new_list.fill(false);
324                    }
325
326                    return list[#state_count - 1];
327                }
328            });
329        }
330    }
331}
332
333mod impl_exec {
334    use quote::ToTokens;
335
336    use super::*;
337
338    struct ImplTransition<'a> {
339        transition: &'a WorkingTransition,
340    }
341    impl<'a> ToTokens for ImplTransition<'a> {
342        fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
343            let ImplTransition { transition } = self;
344            let WorkingTransition { symbol, to } = transition;
345            let symbol = nfa_static::AtomStatic::serialize_as_token_stream(symbol);
346            let to_label = ImplVMStateLabel(*to);
347            tokens.extend(quote! {{
348                let symbol = #symbol;
349                if symbol.check(c) {
350                    new_threads.push(
351                        ::ere::flat_lockstep_nfa::Thread {
352                            state: VMStates::#to_label,
353                            captures: thread.captures.clone(),
354                        },
355                    );
356                }
357            }});
358        }
359    }
360
361    /// Assumes the `VMStates` enum is already created locally in the token stream
362    ///
363    /// Creates the function `transition_symbols_exec` for running symbol transitions on the flat lockstep NFA
364    /// - expects `new_threads` to be empty
365    ///
366    /// ```ignore
367    /// fn transition_symbols_exec(
368    ///     threads: &[::ere::flat_lockstep_nfa::Thread<#capture_groups, VMStates>],
369    ///     new_threads: &mut ::std::vec::Vec<::ere::flat_lockstep_nfa::Thread<#capture_groups, VMStates>>,
370    ///     c: char,
371    /// ) {
372    ///     // ...
373    /// }
374    /// ```
375    pub(super) struct TransitionSymbols<'a>(pub &'a CachedNFA<'a>);
376    impl<'a> ToTokens for TransitionSymbols<'a> {
377        fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
378            let TransitionSymbols(nfa) = self;
379            let CachedNFA {
380                nfa,
381                capture_groups,
382                excluded_states,
383            } = nfa;
384
385            let transition_symbols_defs_exec = nfa
386                .states
387                .iter()
388                .enumerate()
389                .filter(|(i, _)| !excluded_states[*i])
390                .map(|(i, state)| {
391                    let label = ImplVMStateLabel(i);
392                    let state_transitions = state
393                        .transitions
394                        .iter()
395                        .map(|t| ImplTransition { transition: t });
396
397                    return quote! {
398                        VMStates::#label => {
399                            #(#state_transitions)*
400                        }
401                    };
402                });
403
404            tokens.extend(quote! {
405                fn transition_symbols_exec(
406                    threads: &[::ere::flat_lockstep_nfa::Thread<#capture_groups, VMStates>],
407                    new_threads: &mut ::std::vec::Vec<::ere::flat_lockstep_nfa::Thread<#capture_groups, VMStates>>,
408                    c: char,
409                ) {
410                    for thread in threads {
411                        match thread.state {
412                            #(#transition_symbols_defs_exec)*
413                        }
414                    }
415                }
416            });
417        }
418    }
419
420    /// Implements epsilon transitions for a single state
421    ///
422    /// Becomes:
423    /// ```ignore
424    /// VMStates::#from_state => {
425    ///     // ...
426    /// }
427    /// ```
428    pub(super) struct ImplTransitionStateEpsilon<'a> {
429        pub(super) from_state: ImplVMStateLabel,
430        pub(super) thread_updates: &'a [ThreadUpdates],
431    }
432    impl<'a> ToTokens for ImplTransitionStateEpsilon<'a> {
433        fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
434            let &ImplTransitionStateEpsilon {
435                from_state,
436                thread_updates,
437            } = self;
438
439            // Write epsilon-propogation of threads to the token stream for exec
440            let start_end_threads = thread_updates
441                .iter()
442                .filter(|t| t.0.start_only && t.0.end_only)
443                .map(ThreadUpdates::serialize_thread_update_exec);
444            let start_threads = thread_updates
445                .iter()
446                .filter(|t| t.0.start_only && !t.0.end_only)
447                .map(ThreadUpdates::serialize_thread_update_exec);
448            let end_threads = thread_updates
449                .iter()
450                .filter(|t| !t.0.start_only && t.0.end_only)
451                .map(ThreadUpdates::serialize_thread_update_exec);
452            let normal_threads = thread_updates
453                .iter()
454                .filter(|t| !t.0.start_only && !t.0.end_only)
455                .map(ThreadUpdates::serialize_thread_update_exec);
456
457            tokens.extend(quote! {
458                VMStates::#from_state => {
459                    if is_start && is_end {
460                        #(#start_end_threads)*
461                    }
462                    if is_start {
463                        #(#start_threads)*
464                    }
465                    if is_end {
466                        #(#end_threads)*
467                    }
468                    #(#normal_threads)*
469                }
470            });
471        }
472    }
473
474    /// Implements a function that runs all epsilon transitions for all threads.
475    ///
476    /// ```ignore
477    /// fn transition_epsilons_exec(
478    ///     threads: &[::ere::flat_lockstep_nfa::Thread<#capture_groups, VMStates>],
479    ///     new_threads: &mut ::std::vec::Vec<::ere::flat_lockstep_nfa::Thread<#capture_groups, VMStates>>,
480    ///     idx: usize,
481    ///     len: usize,
482    /// ) {
483    ///     // ...
484    /// }
485    /// ```
486    pub(super) struct TransitionEpsilons<'a>(pub &'a CachedNFA<'a>);
487    impl<'a> ToTokens for TransitionEpsilons<'a> {
488        fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
489            let TransitionEpsilons(nfa) = self;
490            let CachedNFA {
491                nfa,
492                capture_groups,
493                excluded_states,
494            } = nfa;
495            assert_eq!(nfa.states.len(), excluded_states.len());
496            let num_states = nfa.states.len();
497
498            let states_epsilon_transitions = std::iter::zip(nfa.states.iter(), excluded_states)
499                .enumerate()
500                .filter(|(_, (_, &excluded))| !excluded)
501                .map(|(i, _)| {
502                    // all reachable states with next transition as epsilon
503                    let mut new_threads = calculate_epsilon_propogations(nfa, i);
504                    new_threads.retain(|t| {
505                        !nfa.states[t.0.state].transitions.is_empty() || t.0.state + 1 == num_states
506                    });
507
508                    let label: ImplVMStateLabel = ImplVMStateLabel(i);
509
510                    let state_epsilon_transitions_exec = impl_exec::ImplTransitionStateEpsilon {
511                        from_state: label,
512                        thread_updates: &new_threads,
513                    };
514                    state_epsilon_transitions_exec.to_token_stream()
515                });
516
517            tokens.extend(quote! {
518                fn transition_epsilons_exec(
519                    threads: &[::ere::flat_lockstep_nfa::Thread<#capture_groups, VMStates>],
520                    new_threads: &mut ::std::vec::Vec<::ere::flat_lockstep_nfa::Thread<#capture_groups, VMStates>>,
521                    idx: usize,
522                    len: usize,
523                ) {
524                    let is_start = idx == 0;
525                    let is_end = idx == len;
526                    let mut occupied_states = ::std::vec![false; #num_states];
527                    for thread in threads {
528                        match thread.state {
529                            #(#states_epsilon_transitions)*
530                        }
531                    }
532                }
533            });
534        }
535    }
536
537    pub(crate) struct ExecFn<'a>(pub &'a CachedNFA<'a>);
538    impl<'a> ToTokens for ExecFn<'a> {
539        fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
540            let ExecFn(nfa) = self;
541            let CachedNFA {
542                nfa: u8_nfa,
543                excluded_states,
544                capture_groups,
545            } = nfa;
546
547            let enum_states = excluded_states
548                .iter()
549                .enumerate()
550                .filter_map(|(i, excluded)| match excluded {
551                    true => None,
552                    false => Some(ImplVMStateLabel(i)),
553                });
554            let state_count = u8_nfa.states.len();
555            let accept_state = ImplVMStateLabel(state_count - 1);
556
557            let transition_symbols_exec = impl_exec::TransitionSymbols(&nfa);
558            let transition_epsilons_exec = impl_exec::TransitionEpsilons(&nfa);
559
560            tokens.extend(quote! {
561                fn exec<'a>(text: &'a str) -> Option<[Option<&'a str>; #capture_groups]> {
562                    #[derive(Clone, Copy, PartialEq, Eq, Debug)]
563                    enum VMStates {
564                        #(#enum_states,)*
565                    }
566
567                    #transition_symbols_exec
568                    #transition_epsilons_exec
569
570                    let mut threads = ::std::vec::Vec::<::ere::flat_lockstep_nfa::Thread<#capture_groups, VMStates>>::new();
571                    let mut new_threads = ::std::vec::Vec::<::ere::flat_lockstep_nfa::Thread<#capture_groups, VMStates>>::new();
572                    threads.push(::ere::flat_lockstep_nfa::Thread {
573                        state: VMStates::State0,
574                        captures: [(usize::MAX, usize::MAX); #capture_groups],
575                    });
576
577                    transition_epsilons_exec(&threads, &mut new_threads, 0, text.len());
578                    ::std::mem::swap(&mut threads, &mut new_threads);
579
580                    for (i, c) in text.char_indices() {
581                        new_threads.clear();
582                        transition_symbols_exec(&threads, &mut new_threads, c);
583                        ::std::mem::swap(&mut threads, &mut new_threads);
584
585                        new_threads.clear();
586                        transition_epsilons_exec(&threads, &mut new_threads, i + c.len_utf8(), text.len());
587                        ::std::mem::swap(&mut threads, &mut new_threads);
588
589                        if threads.is_empty() {
590                            return None;
591                        }
592                    }
593
594                    let final_capture_bounds = threads
595                        .into_iter()
596                        .find(|t| t.state == VMStates::#accept_state)?
597                        .captures;
598                    let mut captures = [::core::option::Option::None; #capture_groups];
599                    for (i, (start, end)) in final_capture_bounds.into_iter().enumerate() {
600                        if start != usize::MAX {
601                            assert_ne!(end, usize::MAX);
602                            // assert!(start <= end);
603                            captures[i] = text.get(start..end);
604                            assert!(captures[i].is_some());
605                        } else {
606                            assert_eq!(end, usize::MAX);
607                        }
608                    }
609                    return Some(captures);
610                }
611            });
612        }
613    }
614}
615
616#[derive(Clone, PartialEq, Eq)]
617struct ThreadUpdates(EpsilonPropogation);
618impl ThreadUpdates {
619    /// Creates a block which takes `list: &mut [bool; STATE_NUM]` from its local context, updates it in-place using `self` (compile-time).
620    pub fn serialize_thread_update_test(&self) -> proc_macro2::TokenStream {
621        let new_state = self.0.state;
622        return quote! {{
623            list[#new_state] = true;
624        }};
625    }
626    /// Creates a block which takes `thread` from its local context, updates it using `self` (compile-time),
627    /// and appends it to `new_threads` from its local context.
628    pub fn serialize_thread_update_exec(&self) -> proc_macro2::TokenStream {
629        let new_state_idx = self.0.state;
630        let new_state = ImplVMStateLabel(self.0.state);
631        let capture_updates = self.0.update_tags.iter().map(|tag| match tag {
632            Tag::StartCapture(capture_group) => quote! {
633                new_thread.captures[#capture_group].0 = idx;
634            },
635            Tag::EndCapture(capture_group) => quote! {
636                new_thread.captures[#capture_group].1 = idx;
637            },
638        });
639
640        return quote! {
641            if !occupied_states[#new_state_idx] {
642                let mut new_thread = thread.clone();
643                new_thread.state = VMStates::#new_state;
644
645                #(#capture_updates)*
646
647                new_threads.push(new_thread);
648                occupied_states[#new_state_idx] = true;
649            }
650        };
651    }
652}
653
654fn calculate_epsilon_propogations(nfa: &WorkingNFA, state: usize) -> Vec<ThreadUpdates> {
655    let prop = EpsilonPropogation::calculate_epsilon_propogations_char(nfa, state);
656    return prop.into_iter().map(ThreadUpdates).collect();
657}
658
659/// Converts a [`WorkingNFA`] into a format that, when returned by a proc macro, will
660/// create the corresponding engine.
661///
662/// Will evaluate to a `const` pair `(test_fn, exec_fn)`.
663pub(crate) fn serialize_flat_lockstep_nfa_token_stream(
664    nfa: &WorkingNFA,
665) -> proc_macro2::TokenStream {
666    let nfa = CachedNFA::new(nfa);
667
668    let test_fn = impl_test::TestFn(&nfa);
669    let exec_fn = impl_exec::ExecFn(&nfa);
670
671    return quote! {{
672        #test_fn
673        #exec_fn
674
675        (test, exec)
676    }};
677}