ere_core/
pike_vm.rs

//! Not exactly the PikeVM, but close enough that I am naming it that.
//! It works similarly, except that because the VM is generated at compile time,
//! thread splitting can be inlined directly into the generated code.
//!
//! Due to the optimizations done earlier for the [`WorkingNFA`], we also have a slightly different NFA structure.
//! Currently we flatten all epsilon transitions for the VM, so that there is at most a single epsilon step between symbols.
//! I'll have to review this to make sure it does not cause a large binary-size overhead,
//! but it should be worst-case `O(n^2)` in the number of states, and far less on average.
8
9use crate::{
10    nfa_static,
11    working_nfa::{EpsilonType, WorkingNFA, WorkingTransition},
12};
13use quote::quote;
14use std::fmt::Write;
15
/// A single thread of the Pike VM: the state it currently occupies together
/// with the capture-group bounds it has recorded so far.
///
/// Each entry of `captures` is a `(start, end)` pair. An endpoint equal to
/// `usize::MAX` is rendered as `_` by the [`Debug`](std::fmt::Debug) impl.
#[derive(Clone)]
pub struct PikeVMThread<const N: usize, S: Send + Sync + Copy + Eq> {
    /// State the thread is currently in.
    pub state: S,
    /// `(start, end)` bounds for each of the `N` capture groups.
    pub captures: [(usize, usize); N],
}
impl<const N: usize, S: Send + Sync + Copy + Eq + std::fmt::Debug> std::fmt::Debug
    for PikeVMThread<N, S>
{
    /// Formats like a derived `Debug`, except that capture endpoints equal to
    /// `usize::MAX` are printed as `_`, and the capture list always stays on
    /// one line.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        /// Borrowed view of the capture table that applies the `_` rendering.
        struct CaptureList<'a>(&'a [(usize, usize)]);

        impl<'a> std::fmt::Debug for CaptureList<'a> {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                /// Writes one endpoint, substituting `_` for `usize::MAX`.
                fn write_endpoint(
                    f: &mut std::fmt::Formatter<'_>,
                    value: usize,
                ) -> std::fmt::Result {
                    if value == usize::MAX {
                        f.write_char('_')
                    } else {
                        write!(f, "{value}")
                    }
                }

                f.write_char('[')?;
                // Empty before the first pair, ", " before every later one.
                let mut separator = "";
                for &(start, end) in self.0 {
                    f.write_str(separator)?;
                    separator = ", ";

                    f.write_char('(')?;
                    write_endpoint(f, start)?;
                    f.write_str(", ")?;
                    write_endpoint(f, end)?;
                    f.write_char(')')?;
                }
                f.write_char(']')
            }
        }

        f.debug_struct("PikeVMThread")
            .field("state", &self.state)
            .field("captures", &CaptureList(&self.captures))
            .finish()
    }
}
50
51fn vmstate_label(idx: usize) -> proc_macro2::Ident {
52    let label = format!("State{idx}");
53    return proc_macro2::Ident::new(&label, proc_macro2::Span::call_site());
54}
55
56/// Since we are shortcutting the epsilon transitions, we can skip printing
57/// states that have only epsilon transitions and are not the start/end states
58fn compute_excluded_states(nfa: &WorkingNFA) -> Vec<bool> {
59    let mut out = vec![true; nfa.states.len()];
60    out[0] = false;
61    out[nfa.states.len() - 1] = false;
62    for (from, state) in nfa.states.iter().enumerate() {
63        for t in &state.transitions {
64            out[from] = false;
65            out[t.to] = false;
66        }
67    }
68
69    return out;
70}
71
72/// Assumes the `VMStates` enum is already created locally in the token stream
73///
74/// Creates the function for running symbol transitions on the pike VM
75fn serialize_pike_vm_symbol_propogation(
76    nfa: &WorkingNFA,
77) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
78    let WorkingNFA { states } = nfa;
79    let capture_groups = nfa.num_capture_groups();
80    let excluded_states = compute_excluded_states(nfa);
81
82    fn make_symbol_transition_test(t: &WorkingTransition) -> proc_macro2::TokenStream {
83        let WorkingTransition { symbol, to } = t;
84        let symbol = nfa_static::AtomStatic::serialize_as_token_stream(symbol);
85        return quote! {
86            {
87                let symbol = #symbol;
88                if symbol.check(c) {
89                    new_list[#to] = true;
90                }
91            }
92        };
93    }
94
95    fn make_symbol_transition_exec(t: &WorkingTransition) -> proc_macro2::TokenStream {
96        let WorkingTransition { symbol, to } = t;
97        let symbol = nfa_static::AtomStatic::serialize_as_token_stream(symbol);
98        let to_label = vmstate_label(*to);
99        return quote! {
100            {
101                let symbol = #symbol;
102                if symbol.check(c) {
103                    out.push(
104                        ::ere::pike_vm::PikeVMThread {
105                            state: VMStates::#to_label,
106                            captures: thread.captures.clone(),
107                        },
108                    );
109                }
110            }
111        };
112    }
113
114    let transition_symbols_defs_test: proc_macro2::TokenStream = states
115        .iter()
116        .enumerate()
117        .filter(|(i, _)| !excluded_states[*i])
118        .map(|(i, state)| {
119            let state_transitions: proc_macro2::TokenStream = state
120                .transitions
121                .iter()
122                .map(make_symbol_transition_test)
123                .collect();
124
125            return quote! {
126                if list[#i] {
127                    #state_transitions
128                }
129            };
130        })
131        .collect();
132
133    let transition_symbols_defs_exec: proc_macro2::TokenStream = states
134        .iter()
135        .enumerate()
136        .filter(|(i, _)| !excluded_states[*i])
137        .map(|(i, state)| {
138            let label = vmstate_label(i);
139            let state_transitions: proc_macro2::TokenStream = state
140                .transitions
141                .iter()
142                .map(make_symbol_transition_exec)
143                .collect();
144
145            return quote! {
146                VMStates::#label => {
147                    #state_transitions
148                }
149            };
150        })
151        .collect();
152
153    let transition_symbols_test = quote! {
154        fn transition_symbols_test(
155            list: &[bool],
156            new_list: &mut [bool],
157            c: char,
158        ) {
159            #transition_symbols_defs_test
160        }
161    };
162    let transition_symbols_exec = quote! {
163        fn transition_symbols_exec(
164            threads: &[::ere::pike_vm::PikeVMThread<#capture_groups, VMStates>],
165            c: char,
166        ) -> ::std::vec::Vec<::ere::pike_vm::PikeVMThread<#capture_groups, VMStates>> {
167            let mut out = ::std::vec::Vec::<::ere::pike_vm::PikeVMThread<#capture_groups, VMStates>>::new();
168            for thread in threads {
169                match thread.state {
170                    #transition_symbols_defs_exec
171                }
172            }
173            return out;
174        }
175    };
176
177    return (transition_symbols_test, transition_symbols_exec);
178}
179
180#[derive(Clone, PartialEq, Eq)]
181struct ThreadUpdates {
182    pub state: usize,
183    pub update_captures: Vec<(bool, bool)>,
184    pub start_only: bool,
185    pub end_only: bool,
186}
187impl ThreadUpdates {
188    /// Creates a block which takes `list: &mut [bool; STATE_NUM]` from its local context, updates it in-place using `self` (compile-time).
189    pub fn serialize_thread_update_test(&self) -> proc_macro2::TokenStream {
190        let new_state = self.state;
191        return quote! {{
192            list[#new_state] = true;
193        }};
194    }
195    /// Creates a block which takes `thread` from its local context, updates it using `self` (compile-time),
196    /// and appends it to `out` from its local context.
197    pub fn serialize_thread_update_exec(&self) -> proc_macro2::TokenStream {
198        let new_state_idx = self.state;
199        let new_state = vmstate_label(self.state);
200        let mut capture_updates = proc_macro2::TokenStream::new();
201        for (i, (start, end)) in self.update_captures.iter().cloned().enumerate() {
202            if start {
203                capture_updates.extend(quote! {
204                    new_thread.captures[#i].0 = idx;
205                });
206            }
207            if end {
208                capture_updates.extend(quote! {
209                    new_thread.captures[#i].1 = idx;
210                });
211            }
212        }
213
214        return quote! {
215            if !occupied_states[#new_state_idx] {
216                let mut new_thread = thread.clone();
217                new_thread.state = VMStates::#new_state;
218
219                #capture_updates
220
221                out.push(new_thread);
222                occupied_states[#new_state_idx] = true;
223            }
224        };
225    }
226}
227
228fn calculate_epsilon_propogations(nfa: &WorkingNFA, state: usize) -> Vec<ThreadUpdates> {
229    let WorkingNFA { states } = nfa;
230    let capture_groups = nfa.num_capture_groups();
231    // reduce epsilons to occur in a single step
232    let mut new_threads = vec![];
233    fn traverse(
234        thread: ThreadUpdates,
235        states: &Vec<crate::working_nfa::WorkingState>,
236        out: &mut Vec<ThreadUpdates>,
237    ) {
238        out.push(thread.clone());
239        for e in &states[thread.state].epsilons {
240            let mut new_thread = thread.clone();
241            new_thread.state = e.to;
242            match e.special {
243                EpsilonType::None => {}
244                EpsilonType::StartAnchor => new_thread.start_only = true,
245                EpsilonType::EndAnchor => new_thread.end_only = true,
246                EpsilonType::StartCapture(capture_group) => {
247                    new_thread.update_captures[capture_group].0 = true
248                }
249                EpsilonType::EndCapture(capture_group) => {
250                    new_thread.update_captures[capture_group].1 = true
251                }
252            }
253
254            if !out.contains(&new_thread) {
255                traverse(new_thread, states, out);
256            }
257        }
258    }
259    traverse(
260        ThreadUpdates {
261            state,
262            update_captures: vec![(false, false); capture_groups],
263            start_only: false,
264            end_only: false,
265        },
266        states,
267        &mut new_threads,
268    );
269    return new_threads;
270}
271
272/// Assumes the `VMStates` enum is already created locally in the token stream
273fn serialize_pike_vm_epsilon_propogation(
274    nfa: &WorkingNFA,
275) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
276    let WorkingNFA { states } = nfa;
277    let capture_groups = nfa.num_capture_groups();
278    let num_states = states.len();
279    let excluded_states = compute_excluded_states(nfa);
280
281    // Generate code to propogate/split a thread according to epsilon transitions
282    let mut transition_epsilons_test = proc_macro2::TokenStream::new();
283    let mut transition_epsilons_exec = proc_macro2::TokenStream::new();
284    for (i, _) in states.iter().enumerate() {
285        if excluded_states[i] {
286            // since we propogate, some states are now useless if they are intermediate
287            // with only epsilon transitions
288            continue;
289        }
290        // all reachable states with next transition as epsilon
291        let mut new_threads = calculate_epsilon_propogations(nfa, i);
292        new_threads
293            .retain(|t| !states[t.state].transitions.is_empty() || t.state + 1 == num_states);
294
295        // Write epsilon-propogation of threads to the token stream for test
296        let start_end_threads: proc_macro2::TokenStream = new_threads
297            .iter()
298            .filter(|t| t.start_only && t.end_only)
299            .map(ThreadUpdates::serialize_thread_update_test)
300            .collect();
301        let start_threads: proc_macro2::TokenStream = new_threads
302            .iter()
303            .filter(|t| t.start_only && !t.end_only)
304            .map(ThreadUpdates::serialize_thread_update_test)
305            .collect();
306        let end_threads: proc_macro2::TokenStream = new_threads
307            .iter()
308            .filter(|t| !t.start_only && t.end_only)
309            .map(ThreadUpdates::serialize_thread_update_test)
310            .collect();
311        let normal_threads: proc_macro2::TokenStream = new_threads
312            .iter()
313            .filter(|t| !t.start_only && !t.end_only)
314            .map(ThreadUpdates::serialize_thread_update_test)
315            .collect();
316
317        let label = vmstate_label(i);
318        transition_epsilons_test.extend(quote! {
319            if list[#i] {
320                if is_start && is_end {
321                    #start_end_threads
322                }
323                if is_start {
324                    #start_threads
325                }
326                if is_end {
327                    #end_threads
328                }
329                #normal_threads
330            }
331        });
332
333        // Write epsilon-propogation of threads to the token stream for exec
334        let start_end_threads: proc_macro2::TokenStream = new_threads
335            .iter()
336            .filter(|t| t.start_only && t.end_only)
337            .map(ThreadUpdates::serialize_thread_update_exec)
338            .collect();
339        let start_threads: proc_macro2::TokenStream = new_threads
340            .iter()
341            .filter(|t| t.start_only && !t.end_only)
342            .map(ThreadUpdates::serialize_thread_update_exec)
343            .collect();
344        let end_threads: proc_macro2::TokenStream = new_threads
345            .iter()
346            .filter(|t| !t.start_only && t.end_only)
347            .map(ThreadUpdates::serialize_thread_update_exec)
348            .collect();
349        let normal_threads: proc_macro2::TokenStream = new_threads
350            .iter()
351            .filter(|t| !t.start_only && !t.end_only)
352            .map(ThreadUpdates::serialize_thread_update_exec)
353            .collect();
354
355        transition_epsilons_exec.extend(quote! {
356            VMStates::#label => {
357                if is_start && is_end {
358                    #start_end_threads
359                }
360                if is_start {
361                    #start_threads
362                }
363                if is_end {
364                    #end_threads
365                }
366                #normal_threads
367            }
368        });
369    }
370
371    let transition_epsilons_test = quote! {
372        fn transition_epsilons_test(
373            list: &mut [bool],
374            idx: usize,
375            len: usize,
376        ) {
377            let is_start = idx == 0;
378            let is_end = idx == len;
379            #transition_epsilons_test
380        }
381    };
382    let transition_epsilons_exec = quote! {
383        fn transition_epsilons_exec(
384            threads: &[::ere::pike_vm::PikeVMThread<#capture_groups, VMStates>],
385            idx: usize,
386            len: usize,
387        ) -> ::std::vec::Vec<::ere::pike_vm::PikeVMThread<#capture_groups, VMStates>> {
388            let is_start = idx == 0;
389            let is_end = idx == len;
390            let mut occupied_states = ::std::vec![false; #num_states];
391            let mut out = ::std::vec::Vec::<::ere::pike_vm::PikeVMThread<#capture_groups, VMStates>>::new();
392            for thread in threads {
393                match thread.state {
394                    #transition_epsilons_exec
395                }
396            }
397            return out;
398        }
399    };
400
401    return (transition_epsilons_test, transition_epsilons_exec);
402}
403
404/// Converts a [`WorkingNFA`] into a format that, when returned by a proc macro, will
405/// create the corresponding Pike VM.
406///
407/// Will evaluate to a `const` pair `(test_fn, exec_fn)`.
408pub(crate) fn serialize_pike_vm_token_stream(nfa: &WorkingNFA) -> proc_macro2::TokenStream {
409    let WorkingNFA { states, .. } = nfa;
410    let capture_groups = nfa.num_capture_groups();
411    let excluded_states = compute_excluded_states(nfa);
412
413    let enum_states: proc_macro2::TokenStream = std::iter::IntoIterator::into_iter(0..states.len())
414        .filter(|i| !excluded_states[*i])
415        .map(|i| {
416            let label = vmstate_label(i);
417            return quote! { #label, };
418        })
419        .collect();
420    let state_count = states.len(); // TODO: not all of these are used, so we may be able to slightly reduce usage.
421    let accept_state = vmstate_label(states.len() - 1);
422
423    let (transition_symbols_test, transition_symbols_exec) =
424        serialize_pike_vm_symbol_propogation(nfa);
425    let (transition_epsilons_test, transition_epsilons_exec) =
426        serialize_pike_vm_epsilon_propogation(nfa);
427
428    return quote! {{
429        #[derive(Clone, Copy, PartialEq, Eq, Debug)]
430        enum VMStates {
431            #enum_states
432        }
433
434        #transition_symbols_test
435        #transition_symbols_exec
436        #transition_epsilons_test
437        #transition_epsilons_exec
438
439        fn test(text: &str) -> bool {
440            let mut list = [false; #state_count];
441            let mut new_list = [false; #state_count];
442            list[0] = true;
443
444            transition_epsilons_test(&mut list, 0, text.len());
445            for (i, c) in text.char_indices() {
446                transition_symbols_test(&list, &mut new_list, c);
447                if new_list.iter().all(|b| !b) {
448                    return false;
449                }
450                ::std::mem::swap(&mut list, &mut new_list);
451                transition_epsilons_test(&mut list, i + c.len_utf8(), text.len());
452                new_list.fill(false);
453            }
454
455            return list[#state_count - 1];
456        }
457        fn exec<'a>(text: &'a str) -> Option<[Option<&'a str>; #capture_groups]> {
458            let mut threads = ::std::vec::Vec::<::ere::pike_vm::PikeVMThread<#capture_groups, VMStates>>::new();
459            threads.push(::ere::pike_vm::PikeVMThread {
460                state: VMStates::State0,
461                captures: [(usize::MAX, usize::MAX); #capture_groups],
462            });
463
464            let new_threads = transition_epsilons_exec(&threads, 0, text.len());
465            threads = new_threads;
466
467            for (i, c) in text.char_indices() {
468                let new_threads = transition_symbols_exec(&threads, c);
469                threads = new_threads;
470                let new_threads = transition_epsilons_exec(&threads, i + c.len_utf8(), text.len());
471                threads = new_threads;
472                if threads.is_empty() {
473                    return None;
474                }
475            }
476
477            let final_capture_bounds = threads
478                .into_iter()
479                .find(|t| t.state == VMStates::#accept_state)?
480                .captures;
481            let mut captures = [::core::option::Option::None; #capture_groups];
482            for (i, (start, end)) in final_capture_bounds.into_iter().enumerate() {
483                if start != usize::MAX {
484                    assert_ne!(end, usize::MAX);
485                    // assert!(start <= end);
486                    captures[i] = text.get(start..end);
487                    assert!(captures[i].is_some());
488                } else {
489                    assert_eq!(end, usize::MAX);
490                }
491            }
492            return Some(captures);
493        }
494
495        (test, exec)
496    }};
497}