ere_core/engines/
flat_lockstep_nfa_u8.rs

//! Implements an NFA-like regex engine over `u8`s.
//! The engine keeps all threads in lockstep (all threads are at the same input index),
//! and the NFA's epsilon transitions are flattened into a single epsilon step between symbols
//! (including the handling of anchors and capture tags).
//!
//! Currently we flatten all epsilon transitions for the VM so that epsilon transitions take at most a single step between symbols.
//! I'll have to review this to make sure it does not cause large binary-size overhead,
//! but it should be worst-case `O(n^2)` in the number of states, and much smaller on average.
9
10use crate::{
11    epsilon_propogation::{EpsilonPropogation, Tag},
12    working_u8_nfa::{U8Transition, U8NFA},
13};
14use quote::{quote, ToTokens, TokenStreamExt as _};
15use std::fmt::Write;
16
/// One thread of a generated engine: the VM state it currently occupies plus the
/// byte offsets recorded for each of its `N` capture groups.
///
/// A capture endpoint equal to `usize::MAX` means "not recorded yet".
#[derive(Clone)]
pub struct Thread<const N: usize, S: Send + Sync + Copy + Eq> {
    pub state: S,
    pub captures: [(usize, usize); N],
}
impl<const N: usize, S> std::fmt::Debug for Thread<N, S>
where
    S: Send + Sync + Copy + Eq + std::fmt::Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        /// Prints capture pairs as `[(start, end), ...]`, with `_` standing in for any
        /// endpoint that still holds the `usize::MAX` sentinel.
        struct Endpoints<'c, const M: usize>(&'c [(usize, usize); M]);
        impl<const M: usize> std::fmt::Debug for Endpoints<'_, M> {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                // Writes a single endpoint, rendering the unset sentinel as `_`.
                fn endpoint(f: &mut std::fmt::Formatter<'_>, value: usize) -> std::fmt::Result {
                    if value == usize::MAX {
                        f.write_char('_')
                    } else {
                        write!(f, "{value}")
                    }
                }

                f.write_char('[')?;
                let mut separator = "";
                for &(start, end) in self.0.iter() {
                    f.write_str(separator)?;
                    separator = ", ";
                    f.write_char('(')?;
                    endpoint(f, start)?;
                    f.write_str(", ")?;
                    endpoint(f, end)?;
                    f.write_char(')')?;
                }
                f.write_char(']')
            }
        }

        f.debug_struct("Thread")
            .field("state", &self.state)
            .field("captures", &Endpoints(&self.captures))
            .finish()
    }
}
51
52/// The NFA and some precomputed data to go with it.
53///
54/// Helps avoid recomputing the same data all over
55/// and also packages data together for convenience.
56struct CachedNFA<'a> {
57    nfa: &'a U8NFA,
58    excluded_states: Vec<bool>,
59    capture_groups: usize,
60}
61impl<'a> CachedNFA<'a> {
62    fn new(nfa: &'a U8NFA) -> CachedNFA<'a> {
63        let excluded_states = compute_excluded_states(nfa);
64        assert_eq!(nfa.states.len(), excluded_states.len());
65        let capture_groups = nfa.num_capture_groups();
66        return CachedNFA {
67            nfa,
68            excluded_states,
69            capture_groups,
70        };
71    }
72}
73
74/// Since we are shortcutting the epsilon transitions, we can skip printing
75/// states that have only epsilon transitions and are not the start/end states
76fn compute_excluded_states(nfa: &U8NFA) -> Vec<bool> {
77    let mut out = vec![true; nfa.states.len()];
78    out[0] = false;
79    out[nfa.states.len() - 1] = false;
80    for (from, state) in nfa.states.iter().enumerate() {
81        for t in &state.transitions {
82            out[from] = false;
83            out[t.to] = false;
84        }
85    }
86
87    return out;
88}
89
90#[derive(Clone, Copy, PartialEq, Eq)]
91struct ImplVMStateLabel(usize);
92impl ToTokens for ImplVMStateLabel {
93    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
94        let ImplVMStateLabel(idx) = self;
95        let label = format!("State{idx}");
96        tokens.append(proc_macro2::Ident::new(
97            &label,
98            proc_macro2::Span::call_site(),
99        ));
100    }
101}
102
103mod impl_test {
104    use quote::ToTokens;
105
106    use super::*;
107
108    /// Implements symbol transitions for a single state
109    struct ImplTransitionStateSymbol<'a> {
110        transition: &'a U8Transition,
111    }
112    impl<'a> ToTokens for ImplTransitionStateSymbol<'a> {
113        fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
114            let &ImplTransitionStateSymbol { transition } = self;
115            let U8Transition { symbol, to } = transition;
116            let start = symbol.start();
117            let end = symbol.end();
118            tokens.extend(quote! {{
119                if #start <= c && c <= #end {
120                    new_list[#to] = true;
121                }
122            }});
123        }
124    }
125
126    /// Assumes the `VMStates` enum is already created locally in the token stream
127    ///
128    /// Creates the function `transition_symbols_test` for running symbol transitions on the engine
129    ///
130    /// ```ignore
131    /// fn transition_symbols_test(
132    ///     list: &[bool],
133    ///     new_list: &mut [bool],
134    ///     c: u8,
135    /// ) {
136    ///     // ...
137    /// }
138    /// ```
139    pub(super) struct TransitionSymbols<'a>(pub &'a CachedNFA<'a>);
140    impl<'a> ToTokens for TransitionSymbols<'a> {
141        fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
142            let TransitionSymbols(nfa) = self;
143            let CachedNFA {
144                nfa,
145                excluded_states,
146                ..
147            } = nfa;
148
149            let transition_symbols_defs_test = nfa
150                .states
151                .iter()
152                .enumerate()
153                .filter(|(i, _)| !excluded_states[*i])
154                .map(|(i, state)| {
155                    let state_transitions = state
156                        .transitions
157                        .iter()
158                        .map(|t| ImplTransitionStateSymbol { transition: t });
159
160                    return quote! {
161                        if list[#i] {
162                            #(#state_transitions)*
163                        }
164                    };
165                });
166
167            tokens.extend(quote! {
168                fn transition_symbols_test(
169                    list: &[bool],
170                    new_list: &mut [bool],
171                    c: u8,
172                ) {
173                    #(#transition_symbols_defs_test)*
174                }
175            });
176        }
177    }
178
179    /// Implements epsilon transitions for a single state
180    ///
181    /// Becomes:
182    /// ```ignore
183    /// if list[#from_state] {
184    ///     // ...
185    /// }
186    /// ```
187    pub(super) struct ImplTransitionStateEpsilon<'a> {
188        pub(super) from_state: ImplVMStateLabel,
189        pub(super) thread_updates: &'a [ThreadUpdates],
190    }
191    impl<'a> ToTokens for ImplTransitionStateEpsilon<'a> {
192        fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
193            let &ImplTransitionStateEpsilon {
194                from_state,
195                thread_updates,
196            } = self;
197            let ImplVMStateLabel(from_state) = from_state;
198
199            // Write epsilon-propogation of threads to the token stream for test
200            let start_end_threads = thread_updates
201                .iter()
202                .filter(|t| t.0.start_only && t.0.end_only)
203                .map(ThreadUpdates::serialize_thread_update_test);
204            let start_threads = thread_updates
205                .iter()
206                .filter(|t| t.0.start_only && !t.0.end_only)
207                .map(ThreadUpdates::serialize_thread_update_test);
208            let end_threads = thread_updates
209                .iter()
210                .filter(|t| !t.0.start_only && t.0.end_only)
211                .map(ThreadUpdates::serialize_thread_update_test);
212            let normal_threads = thread_updates
213                .iter()
214                .filter(|t| !t.0.start_only && !t.0.end_only)
215                .map(ThreadUpdates::serialize_thread_update_test);
216
217            tokens.extend(quote! {
218                if list[#from_state] {
219                    if is_start && is_end {
220                        #(#start_end_threads)*
221                    }
222                    if is_start {
223                        #(#start_threads)*
224                    }
225                    if is_end {
226                        #(#end_threads)*
227                    }
228                    #(#normal_threads)*
229                }
230            });
231        }
232    }
233
234    /// Implements a function that runs all epsilon transitions for all threads.
235    ///
236    /// ```ignore
237    /// fn transition_epsilons_test(
238    ///     list: &mut [bool],
239    ///     idx: usize,
240    ///     len: usize,
241    /// ) {
242    ///     // ...
243    /// }
244    /// ```
245    pub(super) struct TransitionEpsilons<'a>(pub &'a CachedNFA<'a>);
246    impl<'a> ToTokens for TransitionEpsilons<'a> {
247        fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
248            let TransitionEpsilons(nfa) = self;
249            let CachedNFA {
250                nfa,
251                excluded_states,
252                ..
253            } = nfa;
254            assert_eq!(nfa.states.len(), excluded_states.len());
255            let num_states = nfa.states.len();
256
257            let states_epsilon_transitions = std::iter::zip(nfa.states.iter(), excluded_states)
258                .enumerate()
259                .filter(|(_, (_, &excluded))| !excluded)
260                .map(|(i, _)| {
261                    // all reachable states with next transition as epsilon
262                    let mut new_threads = calculate_epsilon_propogations(nfa, i);
263                    new_threads.retain(|t| {
264                        !nfa.states[t.0.state].transitions.is_empty() || t.0.state + 1 == num_states
265                    });
266
267                    let label: ImplVMStateLabel = ImplVMStateLabel(i);
268
269                    let state_epsilon_transitions_test = impl_test::ImplTransitionStateEpsilon {
270                        from_state: label,
271                        thread_updates: &new_threads,
272                    };
273                    state_epsilon_transitions_test.to_token_stream()
274                });
275
276            tokens.extend(quote! {
277                fn transition_epsilons_test(
278                    list: &mut [bool],
279                    idx: usize,
280                    len: usize,
281                ) {
282                    let is_start = idx == 0;
283                    let is_end = idx == len;
284                    #(#states_epsilon_transitions)*
285                }
286            });
287        }
288    }
289
290    pub(crate) struct TestFn<'a>(pub &'a CachedNFA<'a>);
291    impl<'a> ToTokens for TestFn<'a> {
292        fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
293            let TestFn(nfa) = self;
294            let CachedNFA {
295                nfa: u8_nfa,
296                excluded_states,
297                ..
298            } = nfa;
299
300            let enum_states = excluded_states
301                .iter()
302                .enumerate()
303                .filter_map(|(i, excluded)| match excluded {
304                    true => None,
305                    false => Some(ImplVMStateLabel(i)),
306                });
307            let state_count = u8_nfa.states.len();
308            let accept_state = ImplVMStateLabel(state_count - 1);
309
310            let transition_symbols_test = TransitionSymbols(nfa);
311            let transition_epsilons_test = TransitionEpsilons(nfa);
312
313            tokens.extend(quote! {
314                fn test(text: &str) -> bool {
315                    #[derive(Clone, Copy, PartialEq, Eq, Debug)]
316                    enum VMStates {
317                        #(#enum_states,)*
318                    }
319
320                    #transition_symbols_test
321                    #transition_epsilons_test
322
323                    let mut list = [false; #state_count];
324                    let mut new_list = [false; #state_count];
325                    list[0] = true;
326
327                    transition_epsilons_test(&mut list, 0, text.len());
328                    for (i, c) in text.bytes().enumerate() {
329                        transition_symbols_test(&list, &mut new_list, c);
330                        if new_list.iter().all(|b| !b) {
331                            return false;
332                        }
333                        ::std::mem::swap(&mut list, &mut new_list);
334                        transition_epsilons_test(&mut list, i + 1, text.len());
335                        new_list.fill(false);
336                    }
337
338                    return list[#state_count - 1];
339                }
340            });
341        }
342    }
343}
344
345mod impl_exec {
346    use quote::ToTokens;
347
348    use super::*;
349
350    struct ImplTransition<'a> {
351        transition: &'a U8Transition,
352    }
353    impl<'a> ToTokens for ImplTransition<'a> {
354        fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
355            let ImplTransition { transition } = self;
356            let U8Transition { symbol, to } = transition;
357            let start = symbol.start();
358            let end = symbol.end();
359            let to_label = ImplVMStateLabel(*to);
360            tokens.extend(quote! {{
361                if #start <= c && c <= #end {
362                    new_threads.push(
363                        ::ere::flat_lockstep_nfa_u8::Thread {
364                            state: VMStates::#to_label,
365                            captures: thread.captures.clone(),
366                        },
367                    );
368                }
369            }});
370        }
371    }
372
373    /// Assumes the `VMStates` enum is already created locally in the token stream
374    ///
375    /// Creates the function `transition_symbols_exec` for running symbol transitions on the engine
376    /// - expects `new_threads` to be empty
377    ///
378    /// ```ignore
379    /// fn transition_symbols_exec(
380    ///     threads: &[::ere::flat_lockstep_nfa_u8::Thread<#capture_groups, VMStates>],
381    ///     new_threads: &mut ::std::vec::Vec<::ere::flat_lockstep_nfa_u8::Thread<#capture_groups, VMStates>>,
382    ///     c: u8,
383    /// ) {
384    ///     // ...
385    /// }
386    /// ```
387    pub(super) struct TransitionSymbols<'a>(pub &'a CachedNFA<'a>);
388    impl<'a> ToTokens for TransitionSymbols<'a> {
389        fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
390            let TransitionSymbols(nfa) = self;
391            let CachedNFA {
392                nfa,
393                capture_groups,
394                excluded_states,
395            } = nfa;
396
397            let transition_symbols_defs_exec = nfa
398                .states
399                .iter()
400                .enumerate()
401                .filter(|(i, _)| !excluded_states[*i])
402                .map(|(i, state)| {
403                    let label = ImplVMStateLabel(i);
404                    let state_transitions = state
405                        .transitions
406                        .iter()
407                        .map(|t| ImplTransition { transition: t });
408
409                    return quote! {
410                        VMStates::#label => {
411                            #(#state_transitions)*
412                        }
413                    };
414                });
415
416            tokens.extend(quote! {
417                fn transition_symbols_exec(
418                    threads: &[::ere::flat_lockstep_nfa_u8::Thread<#capture_groups, VMStates>],
419                    new_threads: &mut ::std::vec::Vec<::ere::flat_lockstep_nfa_u8::Thread<#capture_groups, VMStates>>,
420                    c: u8,
421                ) {
422                    for thread in threads {
423                        match thread.state {
424                            #(#transition_symbols_defs_exec)*
425                        }
426                    }
427                }
428            });
429        }
430    }
431
432    /// Implements epsilon transitions for a single state
433    ///
434    /// Becomes:
435    /// ```ignore
436    /// VMStates::#from_state => {
437    ///     // ...
438    /// }
439    /// ```
440    pub(super) struct ImplTransitionStateEpsilon<'a> {
441        pub(super) from_state: ImplVMStateLabel,
442        pub(super) thread_updates: &'a [ThreadUpdates],
443    }
444    impl<'a> ToTokens for ImplTransitionStateEpsilon<'a> {
445        fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
446            let &ImplTransitionStateEpsilon {
447                from_state,
448                thread_updates,
449            } = self;
450
451            // Write epsilon-propogation of threads to the token stream for exec
452            let start_end_threads = thread_updates
453                .iter()
454                .filter(|t| t.0.start_only && t.0.end_only)
455                .map(ThreadUpdates::serialize_thread_update_exec);
456            let start_threads = thread_updates
457                .iter()
458                .filter(|t| t.0.start_only && !t.0.end_only)
459                .map(ThreadUpdates::serialize_thread_update_exec);
460            let end_threads = thread_updates
461                .iter()
462                .filter(|t| !t.0.start_only && t.0.end_only)
463                .map(ThreadUpdates::serialize_thread_update_exec);
464            let normal_threads = thread_updates
465                .iter()
466                .filter(|t| !t.0.start_only && !t.0.end_only)
467                .map(ThreadUpdates::serialize_thread_update_exec);
468
469            tokens.extend(quote! {
470                VMStates::#from_state => {
471                    if is_start && is_end {
472                        #(#start_end_threads)*
473                    }
474                    if is_start {
475                        #(#start_threads)*
476                    }
477                    if is_end {
478                        #(#end_threads)*
479                    }
480                    #(#normal_threads)*
481                }
482            });
483        }
484    }
485
486    /// Implements a function that runs all epsilon transitions for all threads.
487    /// - expects `new_threads` to be empty
488    ///
489    /// ```ignore
490    /// fn transition_epsilons_exec(
491    ///     threads: &[::ere::flat_lockstep_nfa_u8::Thread<#capture_groups, VMStates>],
492    ///     new_threads: &mut ::std::vec::Vec<::ere::flat_lockstep_nfa_u8::Thread<#capture_groups, VMStates>>,
493    ///     idx: usize,
494    ///     len: usize,
495    /// ) {
496    ///     // ...
497    /// }
498    /// ```
499    pub(super) struct TransitionEpsilons<'a>(pub &'a CachedNFA<'a>);
500    impl<'a> ToTokens for TransitionEpsilons<'a> {
501        fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
502            let TransitionEpsilons(nfa) = self;
503            let CachedNFA {
504                nfa,
505                capture_groups,
506                excluded_states,
507            } = nfa;
508            assert_eq!(nfa.states.len(), excluded_states.len());
509            let num_states = nfa.states.len();
510
511            let states_epsilon_transitions = std::iter::zip(nfa.states.iter(), excluded_states)
512                .enumerate()
513                .filter(|(_, (_, &excluded))| !excluded)
514                .map(|(i, _)| {
515                    // all reachable states with next transition as epsilon
516                    let mut new_threads = calculate_epsilon_propogations(nfa, i);
517                    new_threads.retain(|t| {
518                        !nfa.states[t.0.state].transitions.is_empty() || t.0.state + 1 == num_states
519                    });
520
521                    let label: ImplVMStateLabel = ImplVMStateLabel(i);
522
523                    let state_epsilon_transitions_exec = impl_exec::ImplTransitionStateEpsilon {
524                        from_state: label,
525                        thread_updates: &new_threads,
526                    };
527                    state_epsilon_transitions_exec.to_token_stream()
528                });
529
530            tokens.extend(quote! {
531                fn transition_epsilons_exec(
532                    threads: &[::ere::flat_lockstep_nfa_u8::Thread<#capture_groups, VMStates>],
533                    new_threads: &mut ::std::vec::Vec<::ere::flat_lockstep_nfa_u8::Thread<#capture_groups, VMStates>>,
534                    idx: usize,
535                    len: usize,
536                ) {
537                    let is_start = idx == 0;
538                    let is_end = idx == len;
539                    let mut occupied_states = ::std::vec![false; #num_states];
540
541                    for thread in threads {
542                        match thread.state {
543                            #(#states_epsilon_transitions)*
544                        }
545                    }
546                }
547            });
548        }
549    }
550
551    pub(crate) struct ExecFn<'a>(pub &'a CachedNFA<'a>);
552    impl<'a> ToTokens for ExecFn<'a> {
553        fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
554            let ExecFn(nfa) = self;
555            let CachedNFA {
556                nfa: u8_nfa,
557                excluded_states,
558                capture_groups,
559            } = nfa;
560
561            let enum_states = excluded_states
562                .iter()
563                .enumerate()
564                .filter_map(|(i, excluded)| match excluded {
565                    true => None,
566                    false => Some(ImplVMStateLabel(i)),
567                });
568            let state_count = u8_nfa.states.len();
569            let accept_state = ImplVMStateLabel(state_count - 1);
570
571            let transition_symbols_exec = impl_exec::TransitionSymbols(&nfa);
572            let transition_epsilons_exec = impl_exec::TransitionEpsilons(&nfa);
573
574            tokens.extend(quote! {
575                fn exec<'a>(text: &'a str) -> Option<[Option<&'a str>; #capture_groups]> {
576                    #[derive(Clone, Copy, PartialEq, Eq, Debug)]
577                    enum VMStates {
578                        #(#enum_states,)*
579                    }
580
581                    #transition_symbols_exec
582                    #transition_epsilons_exec
583
584                    let mut threads = ::std::vec::Vec::<::ere::flat_lockstep_nfa_u8::Thread<#capture_groups, VMStates>>::new();
585                    let mut new_threads = ::std::vec::Vec::<::ere::flat_lockstep_nfa_u8::Thread<#capture_groups, VMStates>>::new();
586                    threads.push(::ere::flat_lockstep_nfa_u8::Thread {
587                        state: VMStates::State0,
588                        captures: [(usize::MAX, usize::MAX); #capture_groups],
589                    });
590
591                    transition_epsilons_exec(&threads, &mut new_threads,0, text.len());
592                    ::std::mem::swap(&mut threads, &mut new_threads);
593
594                    for (i, c) in text.bytes().enumerate() {
595                        new_threads.clear();
596                        transition_symbols_exec(&threads, &mut new_threads, c);
597                        ::std::mem::swap(&mut threads, &mut new_threads);
598
599                        new_threads.clear();
600                        transition_epsilons_exec(&threads, &mut new_threads, i + 1, text.len());
601                        ::std::mem::swap(&mut threads, &mut new_threads);
602                        if threads.is_empty() {
603                            return ::core::option::Option::None;
604                        }
605                    }
606
607                    let final_capture_bounds = threads
608                        .into_iter()
609                        .find(|t| t.state == VMStates::#accept_state)?
610                        .captures;
611                    let mut captures = [::core::option::Option::None; #capture_groups];
612                    for (i, (start, end)) in final_capture_bounds.into_iter().enumerate() {
613                        if start != usize::MAX {
614                            assert_ne!(end, usize::MAX);
615                            // assert!(start <= end);
616                            captures[i] = text.get(start..end);
617                            assert!(captures[i].is_some());
618                        } else {
619                            assert_eq!(end, usize::MAX);
620                        }
621                    }
622                    return ::core::option::Option::Some(captures);
623                }
624            });
625        }
626    }
627}
628
629#[derive(Clone, PartialEq, Eq)]
630struct ThreadUpdates(EpsilonPropogation);
631impl ThreadUpdates {
632    /// Creates a block which takes `list: &mut [bool; STATE_NUM]` from its local context, updates it in-place using `self` (compile-time).
633    pub fn serialize_thread_update_test(&self) -> proc_macro2::TokenStream {
634        let new_state = self.0.state;
635        return quote! {{
636            list[#new_state] = true;
637        }};
638    }
639    /// Creates a block which takes `thread` from its local context, updates it using `self` (compile-time),
640    /// and appends it to `new_threads` from its local context.
641    pub fn serialize_thread_update_exec(&self) -> proc_macro2::TokenStream {
642        let new_state_idx = self.0.state;
643        let new_state = ImplVMStateLabel(self.0.state);
644        let capture_updates = self.0.update_tags.iter().map(|tag| match tag {
645            Tag::StartCapture(capture_group) => quote! {
646                new_thread.captures[#capture_group].0 = idx;
647            },
648            Tag::EndCapture(capture_group) => quote! {
649                new_thread.captures[#capture_group].1 = idx;
650            },
651        });
652
653        return quote! {
654            if !occupied_states[#new_state_idx] {
655                let mut new_thread = thread.clone();
656                new_thread.state = VMStates::#new_state;
657
658                #(#capture_updates)*
659
660                new_threads.push(new_thread);
661                occupied_states[#new_state_idx] = true;
662            }
663        };
664    }
665}
666
667fn calculate_epsilon_propogations(nfa: &U8NFA, state: usize) -> Vec<ThreadUpdates> {
668    let prop = EpsilonPropogation::calculate_epsilon_propogations_u8(nfa, state);
669    return prop.into_iter().map(ThreadUpdates).collect();
670}
671
672/// Converts a [`U8NFA`] into a format that, when returned by a proc macro, will
673/// create the corresponding engine.
674pub(crate) fn serialize_flat_lockstep_nfa_u8_token_stream(nfa: &U8NFA) -> proc_macro2::TokenStream {
675    let nfa = CachedNFA::new(nfa);
676
677    let test_fn = impl_test::TestFn(&nfa);
678    let exec_fn = impl_exec::ExecFn(&nfa);
679
680    return quote! {{
681        #test_fn
682        #exec_fn
683
684        (test, exec)
685    }};
686}