1use std::collections::VecDeque;
4
5use quote::quote;
6
7use crate::{
8 nfa_static,
9 working_nfa::{EpsilonType, WorkingNFA, WorkingTransition},
10};
11
/// One thread of the Pike VM.
///
/// It records the VM state the thread currently occupies, plus a `(start, end)` pair of
/// byte offsets for each of the `N` capture groups. The generated matchers in this module
/// start every bound at `usize::MAX` and overwrite it when a capture boundary is crossed.
#[derive(Debug, Clone)]
pub struct PikeVMThread<const N: usize, S: Send + Sync + Copy + Eq> {
    /// The VM state this thread is currently in.
    pub state: S,
    /// `(start, end)` byte offsets recorded for each capture group.
    pub captures: [(usize, usize); N],
}
17
/// A compiled matcher built from two generated functions.
///
/// `N` is the number of capture slots returned by [`PikeVM::exec`]. The functions are plain
/// function pointers, so a `PikeVM` can be assembled in a `const` context via
/// [`PikeVM::__load`].
pub struct PikeVM<const N: usize> {
    test_fn: fn(&str) -> bool,
    exec_fn: for<'a> fn(&'a str) -> Option<[Option<&'a str>; N]>,
}

impl<const N: usize> PikeVM<N> {
    /// Returns whether `text` matches, by delegating to the stored `test` function.
    pub fn test(&self, text: &str) -> bool {
        let run = self.test_fn;
        run(text)
    }

    /// Runs the stored `exec` function on `text`.
    ///
    /// Returns `Some` with one optional substring per capture slot on a match, or `None`.
    pub fn exec<'a>(&self, text: &'a str) -> Option<[Option<&'a str>; N]> {
        let run = self.exec_fn;
        run(text)
    }

    /// Assembles a matcher from its two generated functions.
    ///
    /// This is the constructor invoked by the code emitted in `serialize_pike_vm_token_stream`.
    pub const fn __load(
        test_fn: fn(&str) -> bool,
        exec_fn: for<'a> fn(&'a str) -> Option<[Option<&'a str>; N]>,
    ) -> PikeVM<N> {
        Self { test_fn, exec_fn }
    }
}
45
46fn vmstate_label(idx: usize) -> proc_macro2::Ident {
47 let label = format!("State{idx}");
48 return proc_macro2::Ident::new(&label, proc_macro2::Span::call_site());
49}
50
51fn compute_excluded_states(nfa: &WorkingNFA) -> Vec<bool> {
54 let mut out = vec![true; nfa.states.len()];
55 out[0] = false;
56 out[nfa.states.len() - 1] = false;
57 for (from, state) in nfa.states.iter().enumerate() {
58 for t in &state.transitions {
59 out[from] = false;
60 out[t.to] = false;
61 }
62 }
63
64 return out;
65}
66
67fn serialize_pike_vm_symbol_propogation(
71 nfa: &WorkingNFA,
72) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
73 let WorkingNFA { states } = nfa;
74 let capture_groups = nfa.num_capture_groups();
75 let excluded_states = compute_excluded_states(nfa);
76
77 fn make_symbol_transition_test(t: &WorkingTransition) -> proc_macro2::TokenStream {
78 let WorkingTransition { symbol, to } = t;
79 let symbol = nfa_static::AtomStatic::serialize_as_token_stream(symbol);
80 return quote! {
81 {
82 let symbol = #symbol;
83 if symbol.check(c) {
84 new_list[#to] = true;
85 }
86 }
87 };
88 }
89
90 fn make_symbol_transition_exec(t: &WorkingTransition) -> proc_macro2::TokenStream {
91 let WorkingTransition { symbol, to } = t;
92 let symbol = nfa_static::AtomStatic::serialize_as_token_stream(symbol);
93 let to_label = vmstate_label(*to);
94 return quote! {
95 {
96 let symbol = #symbol;
97 if symbol.check(c) {
98 out.push(
99 ::ere_core::pike_vm::PikeVMThread {
100 state: VMStates::#to_label,
101 captures: thread.captures.clone(),
102 },
103 );
104 }
105 }
106 };
107 }
108
109 let transition_symbols_defs_test: proc_macro2::TokenStream = states
110 .iter()
111 .enumerate()
112 .filter(|(i, _)| !excluded_states[*i])
113 .map(|(i, state)| {
114 let state_transitions: proc_macro2::TokenStream = state
115 .transitions
116 .iter()
117 .map(make_symbol_transition_test)
118 .collect();
119
120 return quote! {
121 if list[#i] {
122 #state_transitions
123 }
124 };
125 })
126 .collect();
127
128 let transition_symbols_defs_exec: proc_macro2::TokenStream = states
129 .iter()
130 .enumerate()
131 .filter(|(i, _)| !excluded_states[*i])
132 .map(|(i, state)| {
133 let label = vmstate_label(i);
134 let state_transitions: proc_macro2::TokenStream = state
135 .transitions
136 .iter()
137 .map(make_symbol_transition_exec)
138 .collect();
139
140 return quote! {
141 VMStates::#label => {
142 #state_transitions
143 }
144 };
145 })
146 .collect();
147
148 let transition_symbols_test = quote! {
149 fn transition_symbols_test(
150 list: &[bool],
151 new_list: &mut [bool],
152 c: char,
153 ) {
154 #transition_symbols_defs_test
155 }
156 };
157 let transition_symbols_exec = quote! {
158 fn transition_symbols_exec(
159 threads: &[::ere_core::pike_vm::PikeVMThread<#capture_groups, VMStates>],
160 c: char,
161 ) -> ::std::vec::Vec<::ere_core::pike_vm::PikeVMThread<#capture_groups, VMStates>> {
162 let mut out = ::std::vec::Vec::<::ere_core::pike_vm::PikeVMThread<#capture_groups, VMStates>>::new();
163 for thread in threads {
164 match thread.state {
165 #transition_symbols_defs_exec
166 }
167 }
168 return out;
169 }
170 };
171
172 return (transition_symbols_test, transition_symbols_exec);
173}
174
175#[derive(Clone, PartialEq, Eq)]
176struct ThreadUpdates {
177 pub state: usize,
178 pub update_captures: Vec<(bool, bool)>,
179 pub start_only: bool,
180 pub end_only: bool,
181}
182impl ThreadUpdates {
183 pub fn serialize_thread_update_test(&self) -> proc_macro2::TokenStream {
185 let new_state = self.state;
186 return quote! {{
187 list[#new_state] = true;
188 }};
189 }
190 pub fn serialize_thread_update_exec(&self) -> proc_macro2::TokenStream {
193 let new_state = vmstate_label(self.state);
194 let mut capture_updates = proc_macro2::TokenStream::new();
195 for (i, (start, end)) in self.update_captures.iter().cloned().enumerate() {
196 if start {
197 capture_updates.extend(quote! {
198 new_thread.captures[#i].0 = idx;
199 });
200 }
201 if end {
202 capture_updates.extend(quote! {
203 new_thread.captures[#i].1 = idx;
204 });
205 }
206 }
207
208 return quote! {{
209 let mut new_thread = thread.clone();
210 new_thread.state = VMStates::#new_state;
211
212 #capture_updates
213
214 out.push(new_thread);
215 }};
216 }
217}
218
219fn calculate_epsilon_propogations(nfa: &WorkingNFA, state: usize) -> Vec<ThreadUpdates> {
220 let WorkingNFA { states } = nfa;
221 let capture_groups = nfa.num_capture_groups();
222 let mut new_threads = vec![ThreadUpdates {
224 state,
225 update_captures: vec![(false, false); capture_groups],
226 start_only: false,
227 end_only: false,
228 }];
229 let mut queue = VecDeque::new();
230 queue.push_back(new_threads[0].clone());
231 while let Some(thread) = queue.pop_front() {
232 for e in &states[thread.state].epsilons {
234 let mut new_thread = thread.clone();
235 new_thread.state = e.to;
236 match e.special {
237 EpsilonType::None => {}
238 EpsilonType::StartAnchor => new_thread.start_only = true,
239 EpsilonType::EndAnchor => new_thread.end_only = true,
240 EpsilonType::StartCapture(capture_group) => {
241 new_thread.update_captures[capture_group].0 = true
242 }
243 EpsilonType::EndCapture(capture_group) => {
244 new_thread.update_captures[capture_group].1 = true
245 }
246 }
247
248 if new_threads.contains(&new_thread) {
249 continue;
250 }
251 queue.push_back(new_thread.clone());
252 new_threads.push(new_thread);
253 }
254 }
255 return new_threads;
256}
257
258fn serialize_pike_vm_epsilon_propogation(
260 nfa: &WorkingNFA,
261) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
262 let WorkingNFA { states } = nfa;
263 let capture_groups = nfa.num_capture_groups();
264 let excluded_states = compute_excluded_states(nfa);
265
266 let mut transition_epsilons_test = proc_macro2::TokenStream::new();
268 let mut transition_epsilons_exec = proc_macro2::TokenStream::new();
269 for (i, _) in states.iter().enumerate() {
270 if excluded_states[i] {
271 continue;
274 }
275 let mut new_threads = calculate_epsilon_propogations(nfa, i);
276 new_threads.retain(|t| !excluded_states[t.state]);
277
278 let start_end_threads: proc_macro2::TokenStream = new_threads
280 .iter()
281 .filter(|t| t.start_only && t.end_only)
282 .map(ThreadUpdates::serialize_thread_update_test)
283 .collect();
284 let start_threads: proc_macro2::TokenStream = new_threads
285 .iter()
286 .filter(|t| t.start_only && !t.end_only)
287 .map(ThreadUpdates::serialize_thread_update_test)
288 .collect();
289 let end_threads: proc_macro2::TokenStream = new_threads
290 .iter()
291 .filter(|t| !t.start_only && t.end_only)
292 .map(ThreadUpdates::serialize_thread_update_test)
293 .collect();
294 let normal_threads: proc_macro2::TokenStream = new_threads
295 .iter()
296 .filter(|t| !t.start_only && !t.end_only)
297 .map(ThreadUpdates::serialize_thread_update_test)
298 .collect();
299
300 let label = vmstate_label(i);
301 transition_epsilons_test.extend(quote! {
302 if list[#i] {
303 if is_start && is_end {
304 #start_end_threads
305 }
306 if is_start {
307 #start_threads
308 }
309 if is_end {
310 #end_threads
311 }
312 #normal_threads
313 }
314 });
315
316 let start_end_threads: proc_macro2::TokenStream = new_threads
318 .iter()
319 .filter(|t| t.start_only && t.end_only)
320 .map(ThreadUpdates::serialize_thread_update_exec)
321 .collect();
322 let start_threads: proc_macro2::TokenStream = new_threads
323 .iter()
324 .filter(|t| t.start_only && !t.end_only)
325 .map(ThreadUpdates::serialize_thread_update_exec)
326 .collect();
327 let end_threads: proc_macro2::TokenStream = new_threads
328 .iter()
329 .filter(|t| !t.start_only && t.end_only)
330 .map(ThreadUpdates::serialize_thread_update_exec)
331 .collect();
332 let normal_threads: proc_macro2::TokenStream = new_threads
333 .iter()
334 .filter(|t| !t.start_only && !t.end_only)
335 .map(ThreadUpdates::serialize_thread_update_exec)
336 .collect();
337
338 transition_epsilons_exec.extend(quote! {
339 VMStates::#label => {
340 if is_start && is_end {
341 #start_end_threads
342 }
343 if is_start {
344 #start_threads
345 }
346 if is_end {
347 #end_threads
348 }
349 #normal_threads
350 }
351 });
352 }
353
354 let transition_epsilons_test = quote! {
355 fn transition_epsilons_test(
356 list: &mut [bool],
357 idx: usize,
358 len: usize,
359 ) {
360 let is_start = idx == 0;
361 let is_end = idx == len;
362 #transition_epsilons_test
363 }
364 };
365 let transition_epsilons_exec = quote! {
366 fn transition_epsilons_exec(
367 threads: &[::ere_core::pike_vm::PikeVMThread<#capture_groups, VMStates>],
368 idx: usize,
369 len: usize,
370 ) -> ::std::vec::Vec<::ere_core::pike_vm::PikeVMThread<#capture_groups, VMStates>> {
371 let is_start = idx == 0;
372 let is_end = idx == len;
373 let mut out = ::std::vec::Vec::<::ere_core::pike_vm::PikeVMThread<#capture_groups, VMStates>>::new();
374 for thread in threads {
375 match thread.state {
376 #transition_epsilons_exec
377 }
378 }
379 return out;
380 }
381 };
382
383 return (transition_epsilons_test, transition_epsilons_exec);
384}
385
386pub(crate) fn serialize_pike_vm_token_stream(nfa: &WorkingNFA) -> proc_macro2::TokenStream {
389 let WorkingNFA { states, .. } = nfa;
390 let capture_groups = nfa.num_capture_groups();
391 let excluded_states = compute_excluded_states(nfa);
392
393 let enum_states: proc_macro2::TokenStream = std::iter::IntoIterator::into_iter(0..states.len())
394 .filter(|i| !excluded_states[*i])
395 .map(|i| {
396 let label = vmstate_label(i);
397 return quote! { #label, };
398 })
399 .collect();
400 let state_count = states.len(); let accept_state = vmstate_label(states.len() - 1);
402
403 let (transition_symbols_test, transition_symbols_exec) =
404 serialize_pike_vm_symbol_propogation(nfa);
405 let (transition_epsilons_test, transition_epsilons_exec) =
406 serialize_pike_vm_epsilon_propogation(nfa);
407
408 return quote! {{
409 #[derive(Clone, Copy, PartialEq, Eq, Debug)]
410 enum VMStates {
411 #enum_states
412 }
413
414 #transition_symbols_test
415 #transition_symbols_exec
416 #transition_epsilons_test
417 #transition_epsilons_exec
418
419 fn test(text: &str) -> bool {
420 let mut list = [false; #state_count];
421 let mut new_list = [false; #state_count];
422 list[0] = true;
423
424 transition_epsilons_test(&mut list, 0, text.len());
425 for (i, c) in text.char_indices() {
426 transition_symbols_test(&list, &mut new_list, c);
427 if new_list.iter().all(|b| !b) {
428 return false;
429 }
430 ::std::mem::swap(&mut list, &mut new_list);
431 transition_epsilons_test(&mut list, i + c.len_utf8(), text.len());
432 new_list.fill(false);
433 }
434
435 return list[#state_count - 1];
436 }
437 fn exec<'a>(text: &'a str) -> Option<[Option<&'a str>; #capture_groups]> {
438 let mut threads = ::std::vec::Vec::<::ere_core::pike_vm::PikeVMThread<#capture_groups, VMStates>>::new();
439 threads.push(::ere_core::pike_vm::PikeVMThread {
440 state: VMStates::State0,
441 captures: [(usize::MAX, usize::MAX); #capture_groups],
442 });
443
444 let new_threads = transition_epsilons_exec(&threads, 0, text.len());
445 threads = new_threads;
446
447 for (i, c) in text.char_indices() {
448 let new_threads = transition_symbols_exec(&threads, c);
449 threads = new_threads;
450 let new_threads = transition_epsilons_exec(&threads, i + c.len_utf8(), text.len());
451 threads = new_threads;
452 if threads.is_empty() {
453 return None;
454 }
455 }
456
457 let final_capture_bounds = threads
458 .into_iter()
459 .find(|t| t.state == VMStates::#accept_state)?
460 .captures;
461 let mut captures = [::core::option::Option::None; #capture_groups];
462 for (i, (start, end)) in final_capture_bounds.into_iter().enumerate() {
463 if start != usize::MAX {
464 assert_ne!(end, usize::MAX);
465 captures[i] = text.get(start..end);
467 assert!(captures[i].is_some());
468 } else {
469 assert_eq!(end, usize::MAX);
470 }
471 }
472 return Some(captures);
473 }
474
475 ::ere_core::pike_vm::PikeVM::<#capture_groups>::__load(test, exec)
476 }};
477}