1use crate::{
10 nfa_static,
11 working_nfa::{EpsilonType, WorkingNFA, WorkingTransition},
12};
13use quote::quote;
14use std::fmt::Write;
15
/// One live thread of the generated Pike VM.
///
/// `state` is the VM state the thread currently occupies; the generated code
/// instantiates `S` with its own `VMStates` enum. `captures[g]` holds the
/// `(start, end)` byte offsets recorded so far for capture group `g`, where
/// `usize::MAX` marks an endpoint that has not been recorded.
#[derive(Clone)]
pub struct PikeVMThread<const N: usize, S>
where
    S: Send + Sync + Copy + Eq,
{
    pub state: S,
    pub captures: [(usize, usize); N],
}
21impl<const N: usize, S: Send + Sync + Copy + Eq + std::fmt::Debug> std::fmt::Debug
22 for PikeVMThread<N, S>
23{
24 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25 struct CapturesDebug<'a, const N: usize>(&'a [(usize, usize); N]);
26 impl<'a, const N: usize> std::fmt::Debug for CapturesDebug<'a, N> {
27 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28 f.write_char('[')?;
29 for (i, endpoints) in self.0.iter().enumerate() {
30 if i != 0 {
31 f.write_str(", ")?;
32 }
33 match endpoints {
34 (usize::MAX, usize::MAX) => f.write_str("(_, _)")?,
35 (start, usize::MAX) => write!(f, "({start}, _)")?,
36 (usize::MAX, end) => write!(f, "(_, {end})")?,
37 (start, end) => write!(f, "({start}, {end})")?,
38 }
39 }
40 return f.write_char(']');
41 }
42 }
43 return f
44 .debug_struct("PikeVMThread")
45 .field("state", &self.state)
46 .field("captures", &CapturesDebug(&self.captures))
47 .finish();
48 }
49}
50
51fn vmstate_label(idx: usize) -> proc_macro2::Ident {
52 let label = format!("State{idx}");
53 return proc_macro2::Ident::new(&label, proc_macro2::Span::call_site());
54}
55
56fn compute_excluded_states(nfa: &WorkingNFA) -> Vec<bool> {
59 let mut out = vec![true; nfa.states.len()];
60 out[0] = false;
61 out[nfa.states.len() - 1] = false;
62 for (from, state) in nfa.states.iter().enumerate() {
63 for t in &state.transitions {
64 out[from] = false;
65 out[t.to] = false;
66 }
67 }
68
69 return out;
70}
71
72fn serialize_pike_vm_symbol_propogation(
76 nfa: &WorkingNFA,
77) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
78 let WorkingNFA { states } = nfa;
79 let capture_groups = nfa.num_capture_groups();
80 let excluded_states = compute_excluded_states(nfa);
81
82 fn make_symbol_transition_test(t: &WorkingTransition) -> proc_macro2::TokenStream {
83 let WorkingTransition { symbol, to } = t;
84 let symbol = nfa_static::AtomStatic::serialize_as_token_stream(symbol);
85 return quote! {
86 {
87 let symbol = #symbol;
88 if symbol.check(c) {
89 new_list[#to] = true;
90 }
91 }
92 };
93 }
94
95 fn make_symbol_transition_exec(t: &WorkingTransition) -> proc_macro2::TokenStream {
96 let WorkingTransition { symbol, to } = t;
97 let symbol = nfa_static::AtomStatic::serialize_as_token_stream(symbol);
98 let to_label = vmstate_label(*to);
99 return quote! {
100 {
101 let symbol = #symbol;
102 if symbol.check(c) {
103 out.push(
104 ::ere::pike_vm::PikeVMThread {
105 state: VMStates::#to_label,
106 captures: thread.captures.clone(),
107 },
108 );
109 }
110 }
111 };
112 }
113
114 let transition_symbols_defs_test: proc_macro2::TokenStream = states
115 .iter()
116 .enumerate()
117 .filter(|(i, _)| !excluded_states[*i])
118 .map(|(i, state)| {
119 let state_transitions: proc_macro2::TokenStream = state
120 .transitions
121 .iter()
122 .map(make_symbol_transition_test)
123 .collect();
124
125 return quote! {
126 if list[#i] {
127 #state_transitions
128 }
129 };
130 })
131 .collect();
132
133 let transition_symbols_defs_exec: proc_macro2::TokenStream = states
134 .iter()
135 .enumerate()
136 .filter(|(i, _)| !excluded_states[*i])
137 .map(|(i, state)| {
138 let label = vmstate_label(i);
139 let state_transitions: proc_macro2::TokenStream = state
140 .transitions
141 .iter()
142 .map(make_symbol_transition_exec)
143 .collect();
144
145 return quote! {
146 VMStates::#label => {
147 #state_transitions
148 }
149 };
150 })
151 .collect();
152
153 let transition_symbols_test = quote! {
154 fn transition_symbols_test(
155 list: &[bool],
156 new_list: &mut [bool],
157 c: char,
158 ) {
159 #transition_symbols_defs_test
160 }
161 };
162 let transition_symbols_exec = quote! {
163 fn transition_symbols_exec(
164 threads: &[::ere::pike_vm::PikeVMThread<#capture_groups, VMStates>],
165 c: char,
166 ) -> ::std::vec::Vec<::ere::pike_vm::PikeVMThread<#capture_groups, VMStates>> {
167 let mut out = ::std::vec::Vec::<::ere::pike_vm::PikeVMThread<#capture_groups, VMStates>>::new();
168 for thread in threads {
169 match thread.state {
170 #transition_symbols_defs_exec
171 }
172 }
173 return out;
174 }
175 };
176
177 return (transition_symbols_test, transition_symbols_exec);
178}
179
/// One entry of a statically computed epsilon closure: the state an epsilon
/// path ends in, plus the conditions and capture effects gathered along it.
#[derive(PartialEq, Eq, Clone)]
struct ThreadUpdates {
    /// Index of the NFA state the epsilon path ends in.
    pub state: usize,
    /// Per capture group, whether the path crossed the group's start marker
    /// (`.0`) and/or end marker (`.1`). Each crossing means the generated code
    /// must store the current offset into that endpoint.
    pub update_captures: Vec<(bool, bool)>,
    /// The path crossed a start anchor, so it may only be taken at offset 0.
    pub start_only: bool,
    /// The path crossed an end anchor, so it may only be taken at the end of input.
    pub end_only: bool,
}
187impl ThreadUpdates {
188 pub fn serialize_thread_update_test(&self) -> proc_macro2::TokenStream {
190 let new_state = self.state;
191 return quote! {{
192 list[#new_state] = true;
193 }};
194 }
195 pub fn serialize_thread_update_exec(&self) -> proc_macro2::TokenStream {
198 let new_state_idx = self.state;
199 let new_state = vmstate_label(self.state);
200 let mut capture_updates = proc_macro2::TokenStream::new();
201 for (i, (start, end)) in self.update_captures.iter().cloned().enumerate() {
202 if start {
203 capture_updates.extend(quote! {
204 new_thread.captures[#i].0 = idx;
205 });
206 }
207 if end {
208 capture_updates.extend(quote! {
209 new_thread.captures[#i].1 = idx;
210 });
211 }
212 }
213
214 return quote! {
215 if !occupied_states[#new_state_idx] {
216 let mut new_thread = thread.clone();
217 new_thread.state = VMStates::#new_state;
218
219 #capture_updates
220
221 out.push(new_thread);
222 occupied_states[#new_state_idx] = true;
223 }
224 };
225 }
226}
227
228fn calculate_epsilon_propogations(nfa: &WorkingNFA, state: usize) -> Vec<ThreadUpdates> {
229 let WorkingNFA { states } = nfa;
230 let capture_groups = nfa.num_capture_groups();
231 let mut new_threads = vec![];
233 fn traverse(
234 thread: ThreadUpdates,
235 states: &Vec<crate::working_nfa::WorkingState>,
236 out: &mut Vec<ThreadUpdates>,
237 ) {
238 out.push(thread.clone());
239 for e in &states[thread.state].epsilons {
240 let mut new_thread = thread.clone();
241 new_thread.state = e.to;
242 match e.special {
243 EpsilonType::None => {}
244 EpsilonType::StartAnchor => new_thread.start_only = true,
245 EpsilonType::EndAnchor => new_thread.end_only = true,
246 EpsilonType::StartCapture(capture_group) => {
247 new_thread.update_captures[capture_group].0 = true
248 }
249 EpsilonType::EndCapture(capture_group) => {
250 new_thread.update_captures[capture_group].1 = true
251 }
252 }
253
254 if !out.contains(&new_thread) {
255 traverse(new_thread, states, out);
256 }
257 }
258 }
259 traverse(
260 ThreadUpdates {
261 state,
262 update_captures: vec![(false, false); capture_groups],
263 start_only: false,
264 end_only: false,
265 },
266 states,
267 &mut new_threads,
268 );
269 return new_threads;
270}
271
272fn serialize_pike_vm_epsilon_propogation(
274 nfa: &WorkingNFA,
275) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
276 let WorkingNFA { states } = nfa;
277 let capture_groups = nfa.num_capture_groups();
278 let num_states = states.len();
279 let excluded_states = compute_excluded_states(nfa);
280
281 let mut transition_epsilons_test = proc_macro2::TokenStream::new();
283 let mut transition_epsilons_exec = proc_macro2::TokenStream::new();
284 for (i, _) in states.iter().enumerate() {
285 if excluded_states[i] {
286 continue;
289 }
290 let mut new_threads = calculate_epsilon_propogations(nfa, i);
292 new_threads
293 .retain(|t| !states[t.state].transitions.is_empty() || t.state + 1 == num_states);
294
295 let start_end_threads: proc_macro2::TokenStream = new_threads
297 .iter()
298 .filter(|t| t.start_only && t.end_only)
299 .map(ThreadUpdates::serialize_thread_update_test)
300 .collect();
301 let start_threads: proc_macro2::TokenStream = new_threads
302 .iter()
303 .filter(|t| t.start_only && !t.end_only)
304 .map(ThreadUpdates::serialize_thread_update_test)
305 .collect();
306 let end_threads: proc_macro2::TokenStream = new_threads
307 .iter()
308 .filter(|t| !t.start_only && t.end_only)
309 .map(ThreadUpdates::serialize_thread_update_test)
310 .collect();
311 let normal_threads: proc_macro2::TokenStream = new_threads
312 .iter()
313 .filter(|t| !t.start_only && !t.end_only)
314 .map(ThreadUpdates::serialize_thread_update_test)
315 .collect();
316
317 let label = vmstate_label(i);
318 transition_epsilons_test.extend(quote! {
319 if list[#i] {
320 if is_start && is_end {
321 #start_end_threads
322 }
323 if is_start {
324 #start_threads
325 }
326 if is_end {
327 #end_threads
328 }
329 #normal_threads
330 }
331 });
332
333 let start_end_threads: proc_macro2::TokenStream = new_threads
335 .iter()
336 .filter(|t| t.start_only && t.end_only)
337 .map(ThreadUpdates::serialize_thread_update_exec)
338 .collect();
339 let start_threads: proc_macro2::TokenStream = new_threads
340 .iter()
341 .filter(|t| t.start_only && !t.end_only)
342 .map(ThreadUpdates::serialize_thread_update_exec)
343 .collect();
344 let end_threads: proc_macro2::TokenStream = new_threads
345 .iter()
346 .filter(|t| !t.start_only && t.end_only)
347 .map(ThreadUpdates::serialize_thread_update_exec)
348 .collect();
349 let normal_threads: proc_macro2::TokenStream = new_threads
350 .iter()
351 .filter(|t| !t.start_only && !t.end_only)
352 .map(ThreadUpdates::serialize_thread_update_exec)
353 .collect();
354
355 transition_epsilons_exec.extend(quote! {
356 VMStates::#label => {
357 if is_start && is_end {
358 #start_end_threads
359 }
360 if is_start {
361 #start_threads
362 }
363 if is_end {
364 #end_threads
365 }
366 #normal_threads
367 }
368 });
369 }
370
371 let transition_epsilons_test = quote! {
372 fn transition_epsilons_test(
373 list: &mut [bool],
374 idx: usize,
375 len: usize,
376 ) {
377 let is_start = idx == 0;
378 let is_end = idx == len;
379 #transition_epsilons_test
380 }
381 };
382 let transition_epsilons_exec = quote! {
383 fn transition_epsilons_exec(
384 threads: &[::ere::pike_vm::PikeVMThread<#capture_groups, VMStates>],
385 idx: usize,
386 len: usize,
387 ) -> ::std::vec::Vec<::ere::pike_vm::PikeVMThread<#capture_groups, VMStates>> {
388 let is_start = idx == 0;
389 let is_end = idx == len;
390 let mut occupied_states = ::std::vec![false; #num_states];
391 let mut out = ::std::vec::Vec::<::ere::pike_vm::PikeVMThread<#capture_groups, VMStates>>::new();
392 for thread in threads {
393 match thread.state {
394 #transition_epsilons_exec
395 }
396 }
397 return out;
398 }
399 };
400
401 return (transition_epsilons_test, transition_epsilons_exec);
402}
403
404pub(crate) fn serialize_pike_vm_token_stream(nfa: &WorkingNFA) -> proc_macro2::TokenStream {
409 let WorkingNFA { states, .. } = nfa;
410 let capture_groups = nfa.num_capture_groups();
411 let excluded_states = compute_excluded_states(nfa);
412
413 let enum_states: proc_macro2::TokenStream = std::iter::IntoIterator::into_iter(0..states.len())
414 .filter(|i| !excluded_states[*i])
415 .map(|i| {
416 let label = vmstate_label(i);
417 return quote! { #label, };
418 })
419 .collect();
420 let state_count = states.len(); let accept_state = vmstate_label(states.len() - 1);
422
423 let (transition_symbols_test, transition_symbols_exec) =
424 serialize_pike_vm_symbol_propogation(nfa);
425 let (transition_epsilons_test, transition_epsilons_exec) =
426 serialize_pike_vm_epsilon_propogation(nfa);
427
428 return quote! {{
429 #[derive(Clone, Copy, PartialEq, Eq, Debug)]
430 enum VMStates {
431 #enum_states
432 }
433
434 #transition_symbols_test
435 #transition_symbols_exec
436 #transition_epsilons_test
437 #transition_epsilons_exec
438
439 fn test(text: &str) -> bool {
440 let mut list = [false; #state_count];
441 let mut new_list = [false; #state_count];
442 list[0] = true;
443
444 transition_epsilons_test(&mut list, 0, text.len());
445 for (i, c) in text.char_indices() {
446 transition_symbols_test(&list, &mut new_list, c);
447 if new_list.iter().all(|b| !b) {
448 return false;
449 }
450 ::std::mem::swap(&mut list, &mut new_list);
451 transition_epsilons_test(&mut list, i + c.len_utf8(), text.len());
452 new_list.fill(false);
453 }
454
455 return list[#state_count - 1];
456 }
457 fn exec<'a>(text: &'a str) -> Option<[Option<&'a str>; #capture_groups]> {
458 let mut threads = ::std::vec::Vec::<::ere::pike_vm::PikeVMThread<#capture_groups, VMStates>>::new();
459 threads.push(::ere::pike_vm::PikeVMThread {
460 state: VMStates::State0,
461 captures: [(usize::MAX, usize::MAX); #capture_groups],
462 });
463
464 let new_threads = transition_epsilons_exec(&threads, 0, text.len());
465 threads = new_threads;
466
467 for (i, c) in text.char_indices() {
468 let new_threads = transition_symbols_exec(&threads, c);
469 threads = new_threads;
470 let new_threads = transition_epsilons_exec(&threads, i + c.len_utf8(), text.len());
471 threads = new_threads;
472 if threads.is_empty() {
473 return None;
474 }
475 }
476
477 let final_capture_bounds = threads
478 .into_iter()
479 .find(|t| t.state == VMStates::#accept_state)?
480 .captures;
481 let mut captures = [::core::option::Option::None; #capture_groups];
482 for (i, (start, end)) in final_capture_bounds.into_iter().enumerate() {
483 if start != usize::MAX {
484 assert_ne!(end, usize::MAX);
485 captures[i] = text.get(start..end);
487 assert!(captures[i].is_some());
488 } else {
489 assert_eq!(end, usize::MAX);
490 }
491 }
492 return Some(captures);
493 }
494
495 (test, exec)
496 }};
497}