1use crate::{
11 epsilon_propogation::{EpsilonPropogation, Tag},
12 working_u8_nfa::{U8Transition, U8NFA},
13};
14use quote::{quote, ToTokens, TokenStreamExt as _};
15use std::fmt::Write;
16
/// A single matcher thread: the state it currently occupies plus the capture
/// group boundaries recorded so far.
///
/// Each capture is a `(start, end)` pair; an endpoint equal to `usize::MAX`
/// is rendered as `_` by the `Debug` implementation.
#[derive(Clone)]
pub struct Thread<const N: usize, S: Send + Sync + Copy + Eq> {
    /// The state this thread is currently in.
    pub state: S,
    /// One `(start, end)` pair per capture group.
    pub captures: [(usize, usize); N],
}

impl<const N: usize, S: Send + Sync + Copy + Eq + std::fmt::Debug> std::fmt::Debug
    for Thread<N, S>
{
    /// Formats the thread like a derived `Debug` would, except that capture
    /// endpoints equal to `usize::MAX` are printed as `_`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        /// Borrowed view over the capture pairs that prints them as one
        /// bracketed, comma-separated list.
        struct CaptureList<'c>(&'c [(usize, usize)]);

        impl std::fmt::Debug for CaptureList<'_> {
            fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                out.write_char('[')?;
                // Empty before the first element, ", " before every later one.
                let mut separator = "";
                for &(start, end) in self.0 {
                    out.write_str(separator)?;
                    separator = ", ";

                    let start_unset = start == usize::MAX;
                    let end_unset = end == usize::MAX;
                    if start_unset && end_unset {
                        out.write_str("(_, _)")?;
                    } else if end_unset {
                        write!(out, "({start}, _)")?;
                    } else if start_unset {
                        write!(out, "(_, {end})")?;
                    } else {
                        write!(out, "({start}, {end})")?;
                    }
                }
                out.write_char(']')
            }
        }

        let captures = CaptureList(&self.captures);
        let mut builder = f.debug_struct("Thread");
        builder.field("state", &self.state);
        builder.field("captures", &captures);
        builder.finish()
    }
}
51
52struct CachedNFA<'a> {
57 nfa: &'a U8NFA,
58 excluded_states: Vec<bool>,
59 capture_groups: usize,
60}
61impl<'a> CachedNFA<'a> {
62 fn new(nfa: &'a U8NFA) -> CachedNFA<'a> {
63 let excluded_states = compute_excluded_states(nfa);
64 assert_eq!(nfa.states.len(), excluded_states.len());
65 let capture_groups = nfa.num_capture_groups();
66 return CachedNFA {
67 nfa,
68 excluded_states,
69 capture_groups,
70 };
71 }
72}
73
74fn compute_excluded_states(nfa: &U8NFA) -> Vec<bool> {
77 let mut out = vec![true; nfa.states.len()];
78 out[0] = false;
79 out[nfa.states.len() - 1] = false;
80 for (from, state) in nfa.states.iter().enumerate() {
81 for t in &state.transitions {
82 out[from] = false;
83 out[t.to] = false;
84 }
85 }
86
87 return out;
88}
89
90#[derive(Clone, Copy, PartialEq, Eq)]
91struct ImplVMStateLabel(usize);
92impl ToTokens for ImplVMStateLabel {
93 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
94 let ImplVMStateLabel(idx) = self;
95 let label = format!("State{idx}");
96 tokens.append(proc_macro2::Ident::new(
97 &label,
98 proc_macro2::Span::call_site(),
99 ));
100 }
101}
102
103mod impl_test {
104 use quote::ToTokens;
105
106 use super::*;
107
108 struct ImplTransitionStateSymbol<'a> {
110 transition: &'a U8Transition,
111 }
112 impl<'a> ToTokens for ImplTransitionStateSymbol<'a> {
113 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
114 let &ImplTransitionStateSymbol { transition } = self;
115 let U8Transition { symbol, to } = transition;
116 let start = symbol.start();
117 let end = symbol.end();
118 tokens.extend(quote! {{
119 if #start <= c && c <= #end {
120 new_list[#to] = true;
121 }
122 }});
123 }
124 }
125
126 pub(super) struct TransitionSymbols<'a>(pub &'a CachedNFA<'a>);
140 impl<'a> ToTokens for TransitionSymbols<'a> {
141 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
142 let TransitionSymbols(nfa) = self;
143 let CachedNFA {
144 nfa,
145 excluded_states,
146 ..
147 } = nfa;
148
149 let transition_symbols_defs_test = nfa
150 .states
151 .iter()
152 .enumerate()
153 .filter(|(i, _)| !excluded_states[*i])
154 .map(|(i, state)| {
155 let state_transitions = state
156 .transitions
157 .iter()
158 .map(|t| ImplTransitionStateSymbol { transition: t });
159
160 return quote! {
161 if list[#i] {
162 #(#state_transitions)*
163 }
164 };
165 });
166
167 tokens.extend(quote! {
168 fn transition_symbols_test(
169 list: &[bool],
170 new_list: &mut [bool],
171 c: u8,
172 ) {
173 #(#transition_symbols_defs_test)*
174 }
175 });
176 }
177 }
178
179 pub(super) struct ImplTransitionStateEpsilon<'a> {
188 pub(super) from_state: ImplVMStateLabel,
189 pub(super) thread_updates: &'a [ThreadUpdates],
190 }
191 impl<'a> ToTokens for ImplTransitionStateEpsilon<'a> {
192 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
193 let &ImplTransitionStateEpsilon {
194 from_state,
195 thread_updates,
196 } = self;
197 let ImplVMStateLabel(from_state) = from_state;
198
199 let start_end_threads = thread_updates
201 .iter()
202 .filter(|t| t.0.start_only && t.0.end_only)
203 .map(ThreadUpdates::serialize_thread_update_test);
204 let start_threads = thread_updates
205 .iter()
206 .filter(|t| t.0.start_only && !t.0.end_only)
207 .map(ThreadUpdates::serialize_thread_update_test);
208 let end_threads = thread_updates
209 .iter()
210 .filter(|t| !t.0.start_only && t.0.end_only)
211 .map(ThreadUpdates::serialize_thread_update_test);
212 let normal_threads = thread_updates
213 .iter()
214 .filter(|t| !t.0.start_only && !t.0.end_only)
215 .map(ThreadUpdates::serialize_thread_update_test);
216
217 tokens.extend(quote! {
218 if list[#from_state] {
219 if is_start && is_end {
220 #(#start_end_threads)*
221 }
222 if is_start {
223 #(#start_threads)*
224 }
225 if is_end {
226 #(#end_threads)*
227 }
228 #(#normal_threads)*
229 }
230 });
231 }
232 }
233
234 pub(super) struct TransitionEpsilons<'a>(pub &'a CachedNFA<'a>);
246 impl<'a> ToTokens for TransitionEpsilons<'a> {
247 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
248 let TransitionEpsilons(nfa) = self;
249 let CachedNFA {
250 nfa,
251 excluded_states,
252 ..
253 } = nfa;
254 assert_eq!(nfa.states.len(), excluded_states.len());
255 let num_states = nfa.states.len();
256
257 let states_epsilon_transitions = std::iter::zip(nfa.states.iter(), excluded_states)
258 .enumerate()
259 .filter(|(_, (_, &excluded))| !excluded)
260 .map(|(i, _)| {
261 let mut new_threads = calculate_epsilon_propogations(nfa, i);
263 new_threads.retain(|t| {
264 !nfa.states[t.0.state].transitions.is_empty() || t.0.state + 1 == num_states
265 });
266
267 let label: ImplVMStateLabel = ImplVMStateLabel(i);
268
269 let state_epsilon_transitions_test = impl_test::ImplTransitionStateEpsilon {
270 from_state: label,
271 thread_updates: &new_threads,
272 };
273 state_epsilon_transitions_test.to_token_stream()
274 });
275
276 tokens.extend(quote! {
277 fn transition_epsilons_test(
278 list: &mut [bool],
279 idx: usize,
280 len: usize,
281 ) {
282 let is_start = idx == 0;
283 let is_end = idx == len;
284 #(#states_epsilon_transitions)*
285 }
286 });
287 }
288 }
289
290 pub(crate) struct TestFn<'a>(pub &'a CachedNFA<'a>);
291 impl<'a> ToTokens for TestFn<'a> {
292 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
293 let TestFn(nfa) = self;
294 let CachedNFA {
295 nfa: u8_nfa,
296 excluded_states,
297 ..
298 } = nfa;
299
300 let enum_states = excluded_states
301 .iter()
302 .enumerate()
303 .filter_map(|(i, excluded)| match excluded {
304 true => None,
305 false => Some(ImplVMStateLabel(i)),
306 });
307 let state_count = u8_nfa.states.len();
308 let accept_state = ImplVMStateLabel(state_count - 1);
309
310 let transition_symbols_test = TransitionSymbols(nfa);
311 let transition_epsilons_test = TransitionEpsilons(nfa);
312
313 tokens.extend(quote! {
314 fn test(text: &str) -> bool {
315 #[derive(Clone, Copy, PartialEq, Eq, Debug)]
316 enum VMStates {
317 #(#enum_states,)*
318 }
319
320 #transition_symbols_test
321 #transition_epsilons_test
322
323 let mut list = [false; #state_count];
324 let mut new_list = [false; #state_count];
325 list[0] = true;
326
327 transition_epsilons_test(&mut list, 0, text.len());
328 for (i, c) in text.bytes().enumerate() {
329 transition_symbols_test(&list, &mut new_list, c);
330 if new_list.iter().all(|b| !b) {
331 return false;
332 }
333 ::std::mem::swap(&mut list, &mut new_list);
334 transition_epsilons_test(&mut list, i + 1, text.len());
335 new_list.fill(false);
336 }
337
338 return list[#state_count - 1];
339 }
340 });
341 }
342 }
343}
344
345mod impl_exec {
346 use quote::ToTokens;
347
348 use super::*;
349
350 struct ImplTransition<'a> {
351 transition: &'a U8Transition,
352 }
353 impl<'a> ToTokens for ImplTransition<'a> {
354 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
355 let ImplTransition { transition } = self;
356 let U8Transition { symbol, to } = transition;
357 let start = symbol.start();
358 let end = symbol.end();
359 let to_label = ImplVMStateLabel(*to);
360 tokens.extend(quote! {{
361 if #start <= c && c <= #end {
362 new_threads.push(
363 ::ere::flat_lockstep_nfa_u8::Thread {
364 state: VMStates::#to_label,
365 captures: thread.captures.clone(),
366 },
367 );
368 }
369 }});
370 }
371 }
372
373 pub(super) struct TransitionSymbols<'a>(pub &'a CachedNFA<'a>);
388 impl<'a> ToTokens for TransitionSymbols<'a> {
389 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
390 let TransitionSymbols(nfa) = self;
391 let CachedNFA {
392 nfa,
393 capture_groups,
394 excluded_states,
395 } = nfa;
396
397 let transition_symbols_defs_exec = nfa
398 .states
399 .iter()
400 .enumerate()
401 .filter(|(i, _)| !excluded_states[*i])
402 .map(|(i, state)| {
403 let label = ImplVMStateLabel(i);
404 let state_transitions = state
405 .transitions
406 .iter()
407 .map(|t| ImplTransition { transition: t });
408
409 return quote! {
410 VMStates::#label => {
411 #(#state_transitions)*
412 }
413 };
414 });
415
416 tokens.extend(quote! {
417 fn transition_symbols_exec(
418 threads: &[::ere::flat_lockstep_nfa_u8::Thread<#capture_groups, VMStates>],
419 new_threads: &mut ::std::vec::Vec<::ere::flat_lockstep_nfa_u8::Thread<#capture_groups, VMStates>>,
420 c: u8,
421 ) {
422 for thread in threads {
423 match thread.state {
424 #(#transition_symbols_defs_exec)*
425 }
426 }
427 }
428 });
429 }
430 }
431
432 pub(super) struct ImplTransitionStateEpsilon<'a> {
441 pub(super) from_state: ImplVMStateLabel,
442 pub(super) thread_updates: &'a [ThreadUpdates],
443 }
444 impl<'a> ToTokens for ImplTransitionStateEpsilon<'a> {
445 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
446 let &ImplTransitionStateEpsilon {
447 from_state,
448 thread_updates,
449 } = self;
450
451 let start_end_threads = thread_updates
453 .iter()
454 .filter(|t| t.0.start_only && t.0.end_only)
455 .map(ThreadUpdates::serialize_thread_update_exec);
456 let start_threads = thread_updates
457 .iter()
458 .filter(|t| t.0.start_only && !t.0.end_only)
459 .map(ThreadUpdates::serialize_thread_update_exec);
460 let end_threads = thread_updates
461 .iter()
462 .filter(|t| !t.0.start_only && t.0.end_only)
463 .map(ThreadUpdates::serialize_thread_update_exec);
464 let normal_threads = thread_updates
465 .iter()
466 .filter(|t| !t.0.start_only && !t.0.end_only)
467 .map(ThreadUpdates::serialize_thread_update_exec);
468
469 tokens.extend(quote! {
470 VMStates::#from_state => {
471 if is_start && is_end {
472 #(#start_end_threads)*
473 }
474 if is_start {
475 #(#start_threads)*
476 }
477 if is_end {
478 #(#end_threads)*
479 }
480 #(#normal_threads)*
481 }
482 });
483 }
484 }
485
486 pub(super) struct TransitionEpsilons<'a>(pub &'a CachedNFA<'a>);
500 impl<'a> ToTokens for TransitionEpsilons<'a> {
501 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
502 let TransitionEpsilons(nfa) = self;
503 let CachedNFA {
504 nfa,
505 capture_groups,
506 excluded_states,
507 } = nfa;
508 assert_eq!(nfa.states.len(), excluded_states.len());
509 let num_states = nfa.states.len();
510
511 let states_epsilon_transitions = std::iter::zip(nfa.states.iter(), excluded_states)
512 .enumerate()
513 .filter(|(_, (_, &excluded))| !excluded)
514 .map(|(i, _)| {
515 let mut new_threads = calculate_epsilon_propogations(nfa, i);
517 new_threads.retain(|t| {
518 !nfa.states[t.0.state].transitions.is_empty() || t.0.state + 1 == num_states
519 });
520
521 let label: ImplVMStateLabel = ImplVMStateLabel(i);
522
523 let state_epsilon_transitions_exec = impl_exec::ImplTransitionStateEpsilon {
524 from_state: label,
525 thread_updates: &new_threads,
526 };
527 state_epsilon_transitions_exec.to_token_stream()
528 });
529
530 tokens.extend(quote! {
531 fn transition_epsilons_exec(
532 threads: &[::ere::flat_lockstep_nfa_u8::Thread<#capture_groups, VMStates>],
533 new_threads: &mut ::std::vec::Vec<::ere::flat_lockstep_nfa_u8::Thread<#capture_groups, VMStates>>,
534 idx: usize,
535 len: usize,
536 ) {
537 let is_start = idx == 0;
538 let is_end = idx == len;
539 let mut occupied_states = ::std::vec![false; #num_states];
540
541 for thread in threads {
542 match thread.state {
543 #(#states_epsilon_transitions)*
544 }
545 }
546 }
547 });
548 }
549 }
550
551 pub(crate) struct ExecFn<'a>(pub &'a CachedNFA<'a>);
552 impl<'a> ToTokens for ExecFn<'a> {
553 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
554 let ExecFn(nfa) = self;
555 let CachedNFA {
556 nfa: u8_nfa,
557 excluded_states,
558 capture_groups,
559 } = nfa;
560
561 let enum_states = excluded_states
562 .iter()
563 .enumerate()
564 .filter_map(|(i, excluded)| match excluded {
565 true => None,
566 false => Some(ImplVMStateLabel(i)),
567 });
568 let state_count = u8_nfa.states.len();
569 let accept_state = ImplVMStateLabel(state_count - 1);
570
571 let transition_symbols_exec = impl_exec::TransitionSymbols(&nfa);
572 let transition_epsilons_exec = impl_exec::TransitionEpsilons(&nfa);
573
574 tokens.extend(quote! {
575 fn exec<'a>(text: &'a str) -> Option<[Option<&'a str>; #capture_groups]> {
576 #[derive(Clone, Copy, PartialEq, Eq, Debug)]
577 enum VMStates {
578 #(#enum_states,)*
579 }
580
581 #transition_symbols_exec
582 #transition_epsilons_exec
583
584 let mut threads = ::std::vec::Vec::<::ere::flat_lockstep_nfa_u8::Thread<#capture_groups, VMStates>>::new();
585 let mut new_threads = ::std::vec::Vec::<::ere::flat_lockstep_nfa_u8::Thread<#capture_groups, VMStates>>::new();
586 threads.push(::ere::flat_lockstep_nfa_u8::Thread {
587 state: VMStates::State0,
588 captures: [(usize::MAX, usize::MAX); #capture_groups],
589 });
590
591 transition_epsilons_exec(&threads, &mut new_threads,0, text.len());
592 ::std::mem::swap(&mut threads, &mut new_threads);
593
594 for (i, c) in text.bytes().enumerate() {
595 new_threads.clear();
596 transition_symbols_exec(&threads, &mut new_threads, c);
597 ::std::mem::swap(&mut threads, &mut new_threads);
598
599 new_threads.clear();
600 transition_epsilons_exec(&threads, &mut new_threads, i + 1, text.len());
601 ::std::mem::swap(&mut threads, &mut new_threads);
602 if threads.is_empty() {
603 return ::core::option::Option::None;
604 }
605 }
606
607 let final_capture_bounds = threads
608 .into_iter()
609 .find(|t| t.state == VMStates::#accept_state)?
610 .captures;
611 let mut captures = [::core::option::Option::None; #capture_groups];
612 for (i, (start, end)) in final_capture_bounds.into_iter().enumerate() {
613 if start != usize::MAX {
614 assert_ne!(end, usize::MAX);
615 captures[i] = text.get(start..end);
617 assert!(captures[i].is_some());
618 } else {
619 assert_eq!(end, usize::MAX);
620 }
621 }
622 return ::core::option::Option::Some(captures);
623 }
624 });
625 }
626 }
627}
628
629#[derive(Clone, PartialEq, Eq)]
630struct ThreadUpdates(EpsilonPropogation);
631impl ThreadUpdates {
632 pub fn serialize_thread_update_test(&self) -> proc_macro2::TokenStream {
634 let new_state = self.0.state;
635 return quote! {{
636 list[#new_state] = true;
637 }};
638 }
639 pub fn serialize_thread_update_exec(&self) -> proc_macro2::TokenStream {
642 let new_state_idx = self.0.state;
643 let new_state = ImplVMStateLabel(self.0.state);
644 let capture_updates = self.0.update_tags.iter().map(|tag| match tag {
645 Tag::StartCapture(capture_group) => quote! {
646 new_thread.captures[#capture_group].0 = idx;
647 },
648 Tag::EndCapture(capture_group) => quote! {
649 new_thread.captures[#capture_group].1 = idx;
650 },
651 });
652
653 return quote! {
654 if !occupied_states[#new_state_idx] {
655 let mut new_thread = thread.clone();
656 new_thread.state = VMStates::#new_state;
657
658 #(#capture_updates)*
659
660 new_threads.push(new_thread);
661 occupied_states[#new_state_idx] = true;
662 }
663 };
664 }
665}
666
667fn calculate_epsilon_propogations(nfa: &U8NFA, state: usize) -> Vec<ThreadUpdates> {
668 let prop = EpsilonPropogation::calculate_epsilon_propogations_u8(nfa, state);
669 return prop.into_iter().map(ThreadUpdates).collect();
670}
671
672pub(crate) fn serialize_flat_lockstep_nfa_u8_token_stream(nfa: &U8NFA) -> proc_macro2::TokenStream {
675 let nfa = CachedNFA::new(nfa);
676
677 let test_fn = impl_test::TestFn(&nfa);
678 let exec_fn = impl_exec::ExecFn(&nfa);
679
680 return quote! {{
681 #test_fn
682 #exec_fn
683
684 (test, exec)
685 }};
686}