ere_core/
pike_vm.rs

1//! Implements a Pike VM-like regex engine.
2
3use std::collections::VecDeque;
4
5use quote::quote;
6
7use crate::{
8    nfa_static,
9    working_nfa::{EpsilonType, WorkingNFA, WorkingTransition},
10};
11
/// One thread of execution inside a generated Pike VM.
///
/// * `state` - the VM state the thread currently occupies.
/// * `captures` - one `(start, end)` pair per capture group. The generated `exec`
///   function initializes both halves to `usize::MAX` and later uses them as bounds
///   for slicing the searched text.
#[derive(Debug, Clone)]
pub struct PikeVMThread<const N: usize, S>
where
    S: Send + Sync + Copy + Eq,
{
    pub state: S,
    pub captures: [(usize, usize); N],
}
17
/// A regex engine in the spirit of the Pike VM (not an exact implementation, but close
/// enough to borrow the name).
///
/// Because the matcher is generated at compile time, thread splitting can be inlined
/// into the generated code instead of being interpreted.
///
/// The [`WorkingNFA`] has already been optimized before it reaches this module, so the
/// NFA structure differs slightly from the textbook one: all epsilon transitions are
/// flattened so that at most a single epsilon step happens between symbols.
/// This still needs review for its effect on binary size; the expectation is
/// worst-case `O(n^2)` in the number of states, and far fewer on average.
pub struct PikeVM<const N: usize> {
    test_fn: fn(&str) -> bool,
    exec_fn: for<'a> fn(&'a str) -> Option<[Option<&'a str>; N]>,
}

impl<const N: usize> PikeVM<N> {
    /// Runs the stored `test_fn` on `text` and returns its verdict.
    pub fn test(&self, text: &str) -> bool {
        let run = self.test_fn;
        run(text)
    }

    /// Runs the stored `exec_fn` on `text`, returning the capture groups it produced
    /// (borrowed from `text`), or `None` if it reported no match.
    pub fn exec<'a>(&self, text: &'a str) -> Option<[Option<&'a str>; N]> {
        let run = self.exec_fn;
        run(text)
    }

    /// Only intended for internal use by macros.
    ///
    /// Bundles the two generated functions into a [`PikeVM`].
    pub const fn __load(
        test_fn: fn(&str) -> bool,
        exec_fn: for<'a> fn(&'a str) -> Option<[Option<&'a str>; N]>,
    ) -> PikeVM<N> {
        Self { test_fn, exec_fn }
    }
}
45
46fn vmstate_label(idx: usize) -> proc_macro2::Ident {
47    let label = format!("State{idx}");
48    return proc_macro2::Ident::new(&label, proc_macro2::Span::call_site());
49}
50
51/// Since we are shortcutting the epsilon transitions, we can skip printing
52/// states that have only epsilon transitions and are not the start/end states
53fn compute_excluded_states(nfa: &WorkingNFA) -> Vec<bool> {
54    let mut out = vec![true; nfa.states.len()];
55    out[0] = false;
56    out[nfa.states.len() - 1] = false;
57    for (from, state) in nfa.states.iter().enumerate() {
58        for t in &state.transitions {
59            out[from] = false;
60            out[t.to] = false;
61        }
62    }
63
64    return out;
65}
66
67/// Assumes the `VMStates` enum is already created locally in the token stream
68///
69/// Creates the function for running symbol transitions on the pike VM
70fn serialize_pike_vm_symbol_propogation(
71    nfa: &WorkingNFA,
72) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
73    let WorkingNFA { states } = nfa;
74    let capture_groups = nfa.num_capture_groups();
75    let excluded_states = compute_excluded_states(nfa);
76
77    fn make_symbol_transition_test(t: &WorkingTransition) -> proc_macro2::TokenStream {
78        let WorkingTransition { symbol, to } = t;
79        let symbol = nfa_static::AtomStatic::serialize_as_token_stream(symbol);
80        return quote! {
81            {
82                let symbol = #symbol;
83                if symbol.check(c) {
84                    new_list[#to] = true;
85                }
86            }
87        };
88    }
89
90    fn make_symbol_transition_exec(t: &WorkingTransition) -> proc_macro2::TokenStream {
91        let WorkingTransition { symbol, to } = t;
92        let symbol = nfa_static::AtomStatic::serialize_as_token_stream(symbol);
93        let to_label = vmstate_label(*to);
94        return quote! {
95            {
96                let symbol = #symbol;
97                if symbol.check(c) {
98                    out.push(
99                        ::ere_core::pike_vm::PikeVMThread {
100                            state: VMStates::#to_label,
101                            captures: thread.captures.clone(),
102                        },
103                    );
104                }
105            }
106        };
107    }
108
109    let transition_symbols_defs_test: proc_macro2::TokenStream = states
110        .iter()
111        .enumerate()
112        .filter(|(i, _)| !excluded_states[*i])
113        .map(|(i, state)| {
114            let state_transitions: proc_macro2::TokenStream = state
115                .transitions
116                .iter()
117                .map(make_symbol_transition_test)
118                .collect();
119
120            return quote! {
121                if list[#i] {
122                    #state_transitions
123                }
124            };
125        })
126        .collect();
127
128    let transition_symbols_defs_exec: proc_macro2::TokenStream = states
129        .iter()
130        .enumerate()
131        .filter(|(i, _)| !excluded_states[*i])
132        .map(|(i, state)| {
133            let label = vmstate_label(i);
134            let state_transitions: proc_macro2::TokenStream = state
135                .transitions
136                .iter()
137                .map(make_symbol_transition_exec)
138                .collect();
139
140            return quote! {
141                VMStates::#label => {
142                    #state_transitions
143                }
144            };
145        })
146        .collect();
147
148    let transition_symbols_test = quote! {
149        fn transition_symbols_test(
150            list: &[bool],
151            new_list: &mut [bool],
152            c: char,
153        ) {
154            #transition_symbols_defs_test
155        }
156    };
157    let transition_symbols_exec = quote! {
158        fn transition_symbols_exec(
159            threads: &[::ere_core::pike_vm::PikeVMThread<#capture_groups, VMStates>],
160            c: char,
161        ) -> ::std::vec::Vec<::ere_core::pike_vm::PikeVMThread<#capture_groups, VMStates>> {
162            let mut out = ::std::vec::Vec::<::ere_core::pike_vm::PikeVMThread<#capture_groups, VMStates>>::new();
163            for thread in threads {
164                match thread.state {
165                    #transition_symbols_defs_exec
166                }
167            }
168            return out;
169        }
170    };
171
172    return (transition_symbols_test, transition_symbols_exec);
173}
174
175#[derive(Clone, PartialEq, Eq)]
176struct ThreadUpdates {
177    pub state: usize,
178    pub update_captures: Vec<(bool, bool)>,
179    pub start_only: bool,
180    pub end_only: bool,
181}
182impl ThreadUpdates {
183    /// Creates a block which takes `list: &mut [bool; STATE_NUM]` from its local context, updates it in-place using `self` (compile-time).
184    pub fn serialize_thread_update_test(&self) -> proc_macro2::TokenStream {
185        let new_state = self.state;
186        return quote! {{
187            list[#new_state] = true;
188        }};
189    }
190    /// Creates a block which takes `thread` from its local context, updates it using `self` (compile-time),
191    /// and appends it to `out` from its local context.
192    pub fn serialize_thread_update_exec(&self) -> proc_macro2::TokenStream {
193        let new_state = vmstate_label(self.state);
194        let mut capture_updates = proc_macro2::TokenStream::new();
195        for (i, (start, end)) in self.update_captures.iter().cloned().enumerate() {
196            if start {
197                capture_updates.extend(quote! {
198                    new_thread.captures[#i].0 = idx;
199                });
200            }
201            if end {
202                capture_updates.extend(quote! {
203                    new_thread.captures[#i].1 = idx;
204                });
205            }
206        }
207
208        return quote! {{
209            let mut new_thread = thread.clone();
210            new_thread.state = VMStates::#new_state;
211
212            #capture_updates
213
214            out.push(new_thread);
215        }};
216    }
217}
218
219fn calculate_epsilon_propogations(nfa: &WorkingNFA, state: usize) -> Vec<ThreadUpdates> {
220    let WorkingNFA { states } = nfa;
221    let capture_groups = nfa.num_capture_groups();
222    // reduce epsilons to occur in a single step
223    let mut new_threads = vec![ThreadUpdates {
224        state,
225        update_captures: vec![(false, false); capture_groups],
226        start_only: false,
227        end_only: false,
228    }];
229    let mut queue = VecDeque::new();
230    queue.push_back(new_threads[0].clone());
231    while let Some(thread) = queue.pop_front() {
232        // enumerate next step of new threads
233        for e in &states[thread.state].epsilons {
234            let mut new_thread = thread.clone();
235            new_thread.state = e.to;
236            match e.special {
237                EpsilonType::None => {}
238                EpsilonType::StartAnchor => new_thread.start_only = true,
239                EpsilonType::EndAnchor => new_thread.end_only = true,
240                EpsilonType::StartCapture(capture_group) => {
241                    new_thread.update_captures[capture_group].0 = true
242                }
243                EpsilonType::EndCapture(capture_group) => {
244                    new_thread.update_captures[capture_group].1 = true
245                }
246            }
247
248            if new_threads.contains(&new_thread) {
249                continue;
250            }
251            queue.push_back(new_thread.clone());
252            new_threads.push(new_thread);
253        }
254    }
255    return new_threads;
256}
257
258/// Assumes the `VMStates` enum is already created locally in the token stream
259fn serialize_pike_vm_epsilon_propogation(
260    nfa: &WorkingNFA,
261) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
262    let WorkingNFA { states } = nfa;
263    let capture_groups = nfa.num_capture_groups();
264    let excluded_states = compute_excluded_states(nfa);
265
266    // Generate code to propogate/split a thread according to epsilon transitions
267    let mut transition_epsilons_test = proc_macro2::TokenStream::new();
268    let mut transition_epsilons_exec = proc_macro2::TokenStream::new();
269    for (i, _) in states.iter().enumerate() {
270        if excluded_states[i] {
271            // since we propogate, some states are now useless if they are intermediate
272            // with only epsilon transitions
273            continue;
274        }
275        let mut new_threads = calculate_epsilon_propogations(nfa, i);
276        new_threads.retain(|t| !excluded_states[t.state]);
277
278        // Write epsilon-propogation of threads to the token stream for test
279        let start_end_threads: proc_macro2::TokenStream = new_threads
280            .iter()
281            .filter(|t| t.start_only && t.end_only)
282            .map(ThreadUpdates::serialize_thread_update_test)
283            .collect();
284        let start_threads: proc_macro2::TokenStream = new_threads
285            .iter()
286            .filter(|t| t.start_only && !t.end_only)
287            .map(ThreadUpdates::serialize_thread_update_test)
288            .collect();
289        let end_threads: proc_macro2::TokenStream = new_threads
290            .iter()
291            .filter(|t| !t.start_only && t.end_only)
292            .map(ThreadUpdates::serialize_thread_update_test)
293            .collect();
294        let normal_threads: proc_macro2::TokenStream = new_threads
295            .iter()
296            .filter(|t| !t.start_only && !t.end_only)
297            .map(ThreadUpdates::serialize_thread_update_test)
298            .collect();
299
300        let label = vmstate_label(i);
301        transition_epsilons_test.extend(quote! {
302            if list[#i] {
303                if is_start && is_end {
304                    #start_end_threads
305                }
306                if is_start {
307                    #start_threads
308                }
309                if is_end {
310                    #end_threads
311                }
312                #normal_threads
313            }
314        });
315
316        // Write epsilon-propogation of threads to the token stream for exec
317        let start_end_threads: proc_macro2::TokenStream = new_threads
318            .iter()
319            .filter(|t| t.start_only && t.end_only)
320            .map(ThreadUpdates::serialize_thread_update_exec)
321            .collect();
322        let start_threads: proc_macro2::TokenStream = new_threads
323            .iter()
324            .filter(|t| t.start_only && !t.end_only)
325            .map(ThreadUpdates::serialize_thread_update_exec)
326            .collect();
327        let end_threads: proc_macro2::TokenStream = new_threads
328            .iter()
329            .filter(|t| !t.start_only && t.end_only)
330            .map(ThreadUpdates::serialize_thread_update_exec)
331            .collect();
332        let normal_threads: proc_macro2::TokenStream = new_threads
333            .iter()
334            .filter(|t| !t.start_only && !t.end_only)
335            .map(ThreadUpdates::serialize_thread_update_exec)
336            .collect();
337
338        transition_epsilons_exec.extend(quote! {
339            VMStates::#label => {
340                if is_start && is_end {
341                    #start_end_threads
342                }
343                if is_start {
344                    #start_threads
345                }
346                if is_end {
347                    #end_threads
348                }
349                #normal_threads
350            }
351        });
352    }
353
354    let transition_epsilons_test = quote! {
355        fn transition_epsilons_test(
356            list: &mut [bool],
357            idx: usize,
358            len: usize,
359        ) {
360            let is_start = idx == 0;
361            let is_end = idx == len;
362            #transition_epsilons_test
363        }
364    };
365    let transition_epsilons_exec = quote! {
366        fn transition_epsilons_exec(
367            threads: &[::ere_core::pike_vm::PikeVMThread<#capture_groups, VMStates>],
368            idx: usize,
369            len: usize,
370        ) -> ::std::vec::Vec<::ere_core::pike_vm::PikeVMThread<#capture_groups, VMStates>> {
371            let is_start = idx == 0;
372            let is_end = idx == len;
373            let mut out = ::std::vec::Vec::<::ere_core::pike_vm::PikeVMThread<#capture_groups, VMStates>>::new();
374            for thread in threads {
375                match thread.state {
376                    #transition_epsilons_exec
377                }
378            }
379            return out;
380        }
381    };
382
383    return (transition_epsilons_test, transition_epsilons_exec);
384}
385
386/// Converts a [`WorkingNFA`] into a format that, when returned by a proc macro, will
387/// create the corresponding Pike VM.
388pub(crate) fn serialize_pike_vm_token_stream(nfa: &WorkingNFA) -> proc_macro2::TokenStream {
389    let WorkingNFA { states, .. } = nfa;
390    let capture_groups = nfa.num_capture_groups();
391    let excluded_states = compute_excluded_states(nfa);
392
393    let enum_states: proc_macro2::TokenStream = std::iter::IntoIterator::into_iter(0..states.len())
394        .filter(|i| !excluded_states[*i])
395        .map(|i| {
396            let label = vmstate_label(i);
397            return quote! { #label, };
398        })
399        .collect();
400    let state_count = states.len(); // TODO: not all of these are used, so we may be able to slightly reduce usage.
401    let accept_state = vmstate_label(states.len() - 1);
402
403    let (transition_symbols_test, transition_symbols_exec) =
404        serialize_pike_vm_symbol_propogation(nfa);
405    let (transition_epsilons_test, transition_epsilons_exec) =
406        serialize_pike_vm_epsilon_propogation(nfa);
407
408    return quote! {{
409        #[derive(Clone, Copy, PartialEq, Eq, Debug)]
410        enum VMStates {
411            #enum_states
412        }
413
414        #transition_symbols_test
415        #transition_symbols_exec
416        #transition_epsilons_test
417        #transition_epsilons_exec
418
419        fn test(text: &str) -> bool {
420            let mut list = [false; #state_count];
421            let mut new_list = [false; #state_count];
422            list[0] = true;
423
424            transition_epsilons_test(&mut list, 0, text.len());
425            for (i, c) in text.char_indices() {
426                transition_symbols_test(&list, &mut new_list, c);
427                if new_list.iter().all(|b| !b) {
428                    return false;
429                }
430                ::std::mem::swap(&mut list, &mut new_list);
431                transition_epsilons_test(&mut list, i + c.len_utf8(), text.len());
432                new_list.fill(false);
433            }
434
435            return list[#state_count - 1];
436        }
437        fn exec<'a>(text: &'a str) -> Option<[Option<&'a str>; #capture_groups]> {
438            let mut threads = ::std::vec::Vec::<::ere_core::pike_vm::PikeVMThread<#capture_groups, VMStates>>::new();
439            threads.push(::ere_core::pike_vm::PikeVMThread {
440                state: VMStates::State0,
441                captures: [(usize::MAX, usize::MAX); #capture_groups],
442            });
443
444            let new_threads = transition_epsilons_exec(&threads, 0, text.len());
445            threads = new_threads;
446
447            for (i, c) in text.char_indices() {
448                let new_threads = transition_symbols_exec(&threads, c);
449                threads = new_threads;
450                let new_threads = transition_epsilons_exec(&threads, i + c.len_utf8(), text.len());
451                threads = new_threads;
452                if threads.is_empty() {
453                    return None;
454                }
455            }
456
457            let final_capture_bounds = threads
458                .into_iter()
459                .find(|t| t.state == VMStates::#accept_state)?
460                .captures;
461            let mut captures = [::core::option::Option::None; #capture_groups];
462            for (i, (start, end)) in final_capture_bounds.into_iter().enumerate() {
463                if start != usize::MAX {
464                    assert_ne!(end, usize::MAX);
465                    // assert!(start <= end);
466                    captures[i] = text.get(start..end);
467                    assert!(captures[i].is_some());
468                } else {
469                    assert_eq!(end, usize::MAX);
470                }
471            }
472            return Some(captures);
473        }
474
475        ::ere_core::pike_vm::PikeVM::<#capture_groups>::__load(test, exec)
476    }};
477}