1use crate::{
11 epsilon_propogation::{EpsilonPropogation, Tag},
12 nfa_static,
13 working_nfa::{WorkingNFA, WorkingTransition},
14};
15use quote::{quote, ToTokens, TokenStreamExt};
16use std::fmt::Write;
17
/// One simulation thread of the generated matcher: the state it currently occupies
/// together with the capture-group endpoints it has recorded so far.
///
/// An endpoint equal to `usize::MAX` is treated as "not recorded" by the `Debug`
/// implementation and rendered as `_`.
#[derive(Clone)]
pub struct Thread<const N: usize, S: Send + Sync + Copy + Eq> {
    /// State the thread currently occupies.
    pub state: S,
    /// `(start, end)` endpoints, one pair for each of the `N` capture groups.
    pub captures: [(usize, usize); N],
}

impl<const N: usize, S: Send + Sync + Copy + Eq + std::fmt::Debug> std::fmt::Debug
    for Thread<N, S>
{
    /// Formats the thread as `Thread { state: .., captures: [(s, e), ..] }`, printing
    /// `_` in place of any endpoint that still holds `usize::MAX`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        /// Borrowed view over the capture array with the compact `_`-aware rendering.
        struct CaptureList<'a, const N: usize>(&'a [(usize, usize); N]);

        impl<const N: usize> std::fmt::Debug for CaptureList<'_, N> {
            fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                out.write_char('[')?;
                for (position, &(start, end)) in self.0.iter().enumerate() {
                    // Separator goes before every pair except the first.
                    if position > 0 {
                        out.write_str(", ")?;
                    }
                    let start_unset = start == usize::MAX;
                    let end_unset = end == usize::MAX;
                    if start_unset && end_unset {
                        out.write_str("(_, _)")?;
                    } else if end_unset {
                        write!(out, "({start}, _)")?;
                    } else if start_unset {
                        write!(out, "(_, {end})")?;
                    } else {
                        write!(out, "({start}, {end})")?;
                    }
                }
                out.write_char(']')
            }
        }

        let mut builder = f.debug_struct("Thread");
        builder.field("state", &self.state);
        builder.field("captures", &CaptureList(&self.captures));
        builder.finish()
    }
}
52
53struct CachedNFA<'a> {
55 nfa: &'a WorkingNFA,
56 excluded_states: Vec<bool>,
57 capture_groups: usize,
58}
59impl<'a> CachedNFA<'a> {
60 fn new(nfa: &'a WorkingNFA) -> CachedNFA<'a> {
61 let excluded_states = compute_excluded_states(nfa);
62 assert_eq!(nfa.states.len(), excluded_states.len());
63 let capture_groups = nfa.num_capture_groups();
64 return CachedNFA {
65 nfa,
66 excluded_states,
67 capture_groups,
68 };
69 }
70}
71
72fn compute_excluded_states(nfa: &WorkingNFA) -> Vec<bool> {
75 let mut out = vec![true; nfa.states.len()];
76 out[0] = false;
77 out[nfa.states.len() - 1] = false;
78 for (from, state) in nfa.states.iter().enumerate() {
79 for t in &state.transitions {
80 out[from] = false;
81 out[t.to] = false;
82 }
83 }
84
85 return out;
86}
87
88#[derive(Clone, Copy, PartialEq, Eq)]
89struct ImplVMStateLabel(usize);
90impl ToTokens for ImplVMStateLabel {
91 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
92 let ImplVMStateLabel(idx) = self;
93 let label = format!("State{idx}");
94 tokens.append(proc_macro2::Ident::new(
95 &label,
96 proc_macro2::Span::call_site(),
97 ));
98 }
99}
100
101mod impl_test {
102 use quote::ToTokens;
103
104 use super::*;
105
106 struct ImplTransitionStateSymbol<'a> {
108 transition: &'a WorkingTransition,
109 }
110 impl<'a> ToTokens for ImplTransitionStateSymbol<'a> {
111 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
112 let ImplTransitionStateSymbol { transition } = self;
113 let WorkingTransition { symbol, to } = transition;
114 let symbol = nfa_static::AtomStatic::serialize_as_token_stream(symbol);
115 tokens.extend(quote! {{
116 let symbol = #symbol;
117 if symbol.check(c) {
118 new_list[#to] = true;
119 }
120 }});
121 }
122 }
123
124 pub(super) struct TransitionSymbols<'a>(pub &'a CachedNFA<'a>);
128 impl<'a> ToTokens for TransitionSymbols<'a> {
129 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
130 let TransitionSymbols(nfa) = self;
131 let CachedNFA {
132 nfa,
133 excluded_states,
134 ..
135 } = nfa;
136
137 let transition_symbols_defs_test = nfa
138 .states
139 .iter()
140 .enumerate()
141 .filter(|(i, _)| !excluded_states[*i])
142 .map(|(i, state)| {
143 let state_transitions = state
144 .transitions
145 .iter()
146 .map(|t| ImplTransitionStateSymbol { transition: t });
147
148 return quote! {
149 if list[#i] {
150 #(#state_transitions)*
151 }
152 };
153 });
154
155 tokens.extend(quote! {
156 fn transition_symbols_test(
157 list: &[bool],
158 new_list: &mut [bool],
159 c: char,
160 ) {
161 #(#transition_symbols_defs_test)*
162 }
163 });
164 }
165 }
166
167 pub(super) struct ImplTransitionStateEpsilon<'a> {
176 pub(super) from_state: ImplVMStateLabel,
177 pub(super) thread_updates: &'a [ThreadUpdates],
178 }
179 impl<'a> ToTokens for ImplTransitionStateEpsilon<'a> {
180 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
181 let &ImplTransitionStateEpsilon {
182 from_state,
183 thread_updates,
184 } = self;
185 let ImplVMStateLabel(from_state) = from_state;
186
187 let start_end_threads = thread_updates
189 .iter()
190 .filter(|t| t.0.start_only && t.0.end_only)
191 .map(ThreadUpdates::serialize_thread_update_test);
192 let start_threads = thread_updates
193 .iter()
194 .filter(|t| t.0.start_only && !t.0.end_only)
195 .map(ThreadUpdates::serialize_thread_update_test);
196 let end_threads = thread_updates
197 .iter()
198 .filter(|t| !t.0.start_only && t.0.end_only)
199 .map(ThreadUpdates::serialize_thread_update_test);
200 let normal_threads = thread_updates
201 .iter()
202 .filter(|t| !t.0.start_only && !t.0.end_only)
203 .map(ThreadUpdates::serialize_thread_update_test);
204
205 tokens.extend(quote! {
206 if list[#from_state] {
207 if is_start && is_end {
208 #(#start_end_threads)*
209 }
210 if is_start {
211 #(#start_threads)*
212 }
213 if is_end {
214 #(#end_threads)*
215 }
216 #(#normal_threads)*
217 }
218 });
219 }
220 }
221
222 pub(super) struct TransitionEpsilons<'a>(pub &'a CachedNFA<'a>);
234 impl<'a> ToTokens for TransitionEpsilons<'a> {
235 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
236 let TransitionEpsilons(nfa) = self;
237 let CachedNFA {
238 nfa,
239 excluded_states,
240 ..
241 } = nfa;
242 assert_eq!(nfa.states.len(), excluded_states.len());
243 let num_states = nfa.states.len();
244
245 let states_epsilon_transitions = std::iter::zip(nfa.states.iter(), excluded_states)
246 .enumerate()
247 .filter(|(_, (_, &excluded))| !excluded)
248 .map(|(i, _)| {
249 let mut new_threads = calculate_epsilon_propogations(nfa, i);
251 new_threads.retain(|t| {
252 !nfa.states[t.0.state].transitions.is_empty() || t.0.state + 1 == num_states
253 });
254
255 let label: ImplVMStateLabel = ImplVMStateLabel(i);
256
257 let state_epsilon_transitions_test = impl_test::ImplTransitionStateEpsilon {
258 from_state: label,
259 thread_updates: &new_threads,
260 };
261 state_epsilon_transitions_test.to_token_stream()
262 });
263
264 tokens.extend(quote! {
265 fn transition_epsilons_test(
266 list: &mut [bool],
267 idx: usize,
268 len: usize,
269 ) {
270 let is_start = idx == 0;
271 let is_end = idx == len;
272 #(#states_epsilon_transitions)*
273 }
274 });
275 }
276 }
277
278 pub(crate) struct TestFn<'a>(pub &'a CachedNFA<'a>);
279 impl<'a> ToTokens for TestFn<'a> {
280 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
281 let TestFn(nfa) = self;
282 let CachedNFA {
283 nfa: u8_nfa,
284 excluded_states,
285 ..
286 } = nfa;
287
288 let enum_states = excluded_states
289 .iter()
290 .enumerate()
291 .filter_map(|(i, excluded)| match excluded {
292 true => None,
293 false => Some(ImplVMStateLabel(i)),
294 });
295 let state_count = u8_nfa.states.len();
296 let accept_state = ImplVMStateLabel(state_count - 1);
297
298 let transition_symbols_test = TransitionSymbols(nfa);
299 let transition_epsilons_test = TransitionEpsilons(nfa);
300
301 tokens.extend(quote! {
302 fn test(text: &str) -> bool {
303 #[derive(Clone, Copy, PartialEq, Eq, Debug)]
304 enum VMStates {
305 #(#enum_states,)*
306 }
307
308 #transition_symbols_test
309 #transition_epsilons_test
310
311 let mut list = [false; #state_count];
312 let mut new_list = [false; #state_count];
313 list[0] = true;
314
315 transition_epsilons_test(&mut list, 0, text.len());
316 for (i, c) in text.char_indices() {
317 transition_symbols_test(&list, &mut new_list, c);
318 if new_list.iter().all(|b| !b) {
319 return false;
320 }
321 ::std::mem::swap(&mut list, &mut new_list);
322 transition_epsilons_test(&mut list, i + c.len_utf8(), text.len());
323 new_list.fill(false);
324 }
325
326 return list[#state_count - 1];
327 }
328 });
329 }
330 }
331}
332
333mod impl_exec {
334 use quote::ToTokens;
335
336 use super::*;
337
338 struct ImplTransition<'a> {
339 transition: &'a WorkingTransition,
340 }
341 impl<'a> ToTokens for ImplTransition<'a> {
342 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
343 let ImplTransition { transition } = self;
344 let WorkingTransition { symbol, to } = transition;
345 let symbol = nfa_static::AtomStatic::serialize_as_token_stream(symbol);
346 let to_label = ImplVMStateLabel(*to);
347 tokens.extend(quote! {{
348 let symbol = #symbol;
349 if symbol.check(c) {
350 new_threads.push(
351 ::ere::flat_lockstep_nfa::Thread {
352 state: VMStates::#to_label,
353 captures: thread.captures.clone(),
354 },
355 );
356 }
357 }});
358 }
359 }
360
361 pub(super) struct TransitionSymbols<'a>(pub &'a CachedNFA<'a>);
376 impl<'a> ToTokens for TransitionSymbols<'a> {
377 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
378 let TransitionSymbols(nfa) = self;
379 let CachedNFA {
380 nfa,
381 capture_groups,
382 excluded_states,
383 } = nfa;
384
385 let transition_symbols_defs_exec = nfa
386 .states
387 .iter()
388 .enumerate()
389 .filter(|(i, _)| !excluded_states[*i])
390 .map(|(i, state)| {
391 let label = ImplVMStateLabel(i);
392 let state_transitions = state
393 .transitions
394 .iter()
395 .map(|t| ImplTransition { transition: t });
396
397 return quote! {
398 VMStates::#label => {
399 #(#state_transitions)*
400 }
401 };
402 });
403
404 tokens.extend(quote! {
405 fn transition_symbols_exec(
406 threads: &[::ere::flat_lockstep_nfa::Thread<#capture_groups, VMStates>],
407 new_threads: &mut ::std::vec::Vec<::ere::flat_lockstep_nfa::Thread<#capture_groups, VMStates>>,
408 c: char,
409 ) {
410 for thread in threads {
411 match thread.state {
412 #(#transition_symbols_defs_exec)*
413 }
414 }
415 }
416 });
417 }
418 }
419
420 pub(super) struct ImplTransitionStateEpsilon<'a> {
429 pub(super) from_state: ImplVMStateLabel,
430 pub(super) thread_updates: &'a [ThreadUpdates],
431 }
432 impl<'a> ToTokens for ImplTransitionStateEpsilon<'a> {
433 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
434 let &ImplTransitionStateEpsilon {
435 from_state,
436 thread_updates,
437 } = self;
438
439 let start_end_threads = thread_updates
441 .iter()
442 .filter(|t| t.0.start_only && t.0.end_only)
443 .map(ThreadUpdates::serialize_thread_update_exec);
444 let start_threads = thread_updates
445 .iter()
446 .filter(|t| t.0.start_only && !t.0.end_only)
447 .map(ThreadUpdates::serialize_thread_update_exec);
448 let end_threads = thread_updates
449 .iter()
450 .filter(|t| !t.0.start_only && t.0.end_only)
451 .map(ThreadUpdates::serialize_thread_update_exec);
452 let normal_threads = thread_updates
453 .iter()
454 .filter(|t| !t.0.start_only && !t.0.end_only)
455 .map(ThreadUpdates::serialize_thread_update_exec);
456
457 tokens.extend(quote! {
458 VMStates::#from_state => {
459 if is_start && is_end {
460 #(#start_end_threads)*
461 }
462 if is_start {
463 #(#start_threads)*
464 }
465 if is_end {
466 #(#end_threads)*
467 }
468 #(#normal_threads)*
469 }
470 });
471 }
472 }
473
474 pub(super) struct TransitionEpsilons<'a>(pub &'a CachedNFA<'a>);
487 impl<'a> ToTokens for TransitionEpsilons<'a> {
488 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
489 let TransitionEpsilons(nfa) = self;
490 let CachedNFA {
491 nfa,
492 capture_groups,
493 excluded_states,
494 } = nfa;
495 assert_eq!(nfa.states.len(), excluded_states.len());
496 let num_states = nfa.states.len();
497
498 let states_epsilon_transitions = std::iter::zip(nfa.states.iter(), excluded_states)
499 .enumerate()
500 .filter(|(_, (_, &excluded))| !excluded)
501 .map(|(i, _)| {
502 let mut new_threads = calculate_epsilon_propogations(nfa, i);
504 new_threads.retain(|t| {
505 !nfa.states[t.0.state].transitions.is_empty() || t.0.state + 1 == num_states
506 });
507
508 let label: ImplVMStateLabel = ImplVMStateLabel(i);
509
510 let state_epsilon_transitions_exec = impl_exec::ImplTransitionStateEpsilon {
511 from_state: label,
512 thread_updates: &new_threads,
513 };
514 state_epsilon_transitions_exec.to_token_stream()
515 });
516
517 tokens.extend(quote! {
518 fn transition_epsilons_exec(
519 threads: &[::ere::flat_lockstep_nfa::Thread<#capture_groups, VMStates>],
520 new_threads: &mut ::std::vec::Vec<::ere::flat_lockstep_nfa::Thread<#capture_groups, VMStates>>,
521 idx: usize,
522 len: usize,
523 ) {
524 let is_start = idx == 0;
525 let is_end = idx == len;
526 let mut occupied_states = ::std::vec![false; #num_states];
527 for thread in threads {
528 match thread.state {
529 #(#states_epsilon_transitions)*
530 }
531 }
532 }
533 });
534 }
535 }
536
537 pub(crate) struct ExecFn<'a>(pub &'a CachedNFA<'a>);
538 impl<'a> ToTokens for ExecFn<'a> {
539 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
540 let ExecFn(nfa) = self;
541 let CachedNFA {
542 nfa: u8_nfa,
543 excluded_states,
544 capture_groups,
545 } = nfa;
546
547 let enum_states = excluded_states
548 .iter()
549 .enumerate()
550 .filter_map(|(i, excluded)| match excluded {
551 true => None,
552 false => Some(ImplVMStateLabel(i)),
553 });
554 let state_count = u8_nfa.states.len();
555 let accept_state = ImplVMStateLabel(state_count - 1);
556
557 let transition_symbols_exec = impl_exec::TransitionSymbols(&nfa);
558 let transition_epsilons_exec = impl_exec::TransitionEpsilons(&nfa);
559
560 tokens.extend(quote! {
561 fn exec<'a>(text: &'a str) -> Option<[Option<&'a str>; #capture_groups]> {
562 #[derive(Clone, Copy, PartialEq, Eq, Debug)]
563 enum VMStates {
564 #(#enum_states,)*
565 }
566
567 #transition_symbols_exec
568 #transition_epsilons_exec
569
570 let mut threads = ::std::vec::Vec::<::ere::flat_lockstep_nfa::Thread<#capture_groups, VMStates>>::new();
571 let mut new_threads = ::std::vec::Vec::<::ere::flat_lockstep_nfa::Thread<#capture_groups, VMStates>>::new();
572 threads.push(::ere::flat_lockstep_nfa::Thread {
573 state: VMStates::State0,
574 captures: [(usize::MAX, usize::MAX); #capture_groups],
575 });
576
577 transition_epsilons_exec(&threads, &mut new_threads, 0, text.len());
578 ::std::mem::swap(&mut threads, &mut new_threads);
579
580 for (i, c) in text.char_indices() {
581 new_threads.clear();
582 transition_symbols_exec(&threads, &mut new_threads, c);
583 ::std::mem::swap(&mut threads, &mut new_threads);
584
585 new_threads.clear();
586 transition_epsilons_exec(&threads, &mut new_threads, i + c.len_utf8(), text.len());
587 ::std::mem::swap(&mut threads, &mut new_threads);
588
589 if threads.is_empty() {
590 return None;
591 }
592 }
593
594 let final_capture_bounds = threads
595 .into_iter()
596 .find(|t| t.state == VMStates::#accept_state)?
597 .captures;
598 let mut captures = [::core::option::Option::None; #capture_groups];
599 for (i, (start, end)) in final_capture_bounds.into_iter().enumerate() {
600 if start != usize::MAX {
601 assert_ne!(end, usize::MAX);
602 captures[i] = text.get(start..end);
604 assert!(captures[i].is_some());
605 } else {
606 assert_eq!(end, usize::MAX);
607 }
608 }
609 return Some(captures);
610 }
611 });
612 }
613 }
614}
615
616#[derive(Clone, PartialEq, Eq)]
617struct ThreadUpdates(EpsilonPropogation);
618impl ThreadUpdates {
619 pub fn serialize_thread_update_test(&self) -> proc_macro2::TokenStream {
621 let new_state = self.0.state;
622 return quote! {{
623 list[#new_state] = true;
624 }};
625 }
626 pub fn serialize_thread_update_exec(&self) -> proc_macro2::TokenStream {
629 let new_state_idx = self.0.state;
630 let new_state = ImplVMStateLabel(self.0.state);
631 let capture_updates = self.0.update_tags.iter().map(|tag| match tag {
632 Tag::StartCapture(capture_group) => quote! {
633 new_thread.captures[#capture_group].0 = idx;
634 },
635 Tag::EndCapture(capture_group) => quote! {
636 new_thread.captures[#capture_group].1 = idx;
637 },
638 });
639
640 return quote! {
641 if !occupied_states[#new_state_idx] {
642 let mut new_thread = thread.clone();
643 new_thread.state = VMStates::#new_state;
644
645 #(#capture_updates)*
646
647 new_threads.push(new_thread);
648 occupied_states[#new_state_idx] = true;
649 }
650 };
651 }
652}
653
654fn calculate_epsilon_propogations(nfa: &WorkingNFA, state: usize) -> Vec<ThreadUpdates> {
655 let prop = EpsilonPropogation::calculate_epsilon_propogations_char(nfa, state);
656 return prop.into_iter().map(ThreadUpdates).collect();
657}
658
659pub(crate) fn serialize_flat_lockstep_nfa_token_stream(
664 nfa: &WorkingNFA,
665) -> proc_macro2::TokenStream {
666 let nfa = CachedNFA::new(nfa);
667
668 let test_fn = impl_test::TestFn(&nfa);
669 let exec_fn = impl_exec::ExecFn(&nfa);
670
671 return quote! {{
672 #test_fn
673 #exec_fn
674
675 (test, exec)
676 }};
677}