use crate::{
working_nfa::EpsilonType,
working_u8_nfa::{U8Transition, U8NFA},
};
use quote::quote;
use std::fmt::Write;
/// A single thread of the byte-oriented Pike VM generated by this module.
///
/// `state` is the generated VM's state enum. `captures[i]` holds the
/// `(start, end)` byte offsets recorded for capture group `i`; an endpoint equal
/// to `usize::MAX` has not been recorded.
#[derive(Clone)]
pub struct U8PikeVMThread<const N: usize, S: Send + Sync + Copy + Eq> {
    pub state: S,
    pub captures: [(usize, usize); N],
}
impl<const N: usize, S: Send + Sync + Copy + Eq + std::fmt::Debug> std::fmt::Debug
    for U8PikeVMThread<N, S>
{
    /// Formats as `U8PikeVMThread { state: .., captures: [..] }`. Unrecorded
    /// capture endpoints (`usize::MAX`) are printed as `_`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Compact rendering of the capture array, e.g. `[(0, 2), (_, _)]`.
        struct CapturesDebug<'a, const N: usize>(&'a [(usize, usize); N]);
        impl<const N: usize> std::fmt::Debug for CapturesDebug<'_, N> {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.write_char('[')?;
                for (i, endpoints) in self.0.iter().enumerate() {
                    if i != 0 {
                        f.write_str(", ")?;
                    }
                    match endpoints {
                        (usize::MAX, usize::MAX) => f.write_str("(_, _)")?,
                        (start, usize::MAX) => write!(f, "({start}, _)")?,
                        (usize::MAX, end) => write!(f, "(_, {end})")?,
                        (start, end) => write!(f, "({start}, {end})")?,
                    }
                }
                f.write_char(']')
            }
        }
        // Use the real type name so debug output matches the type users see.
        f.debug_struct("U8PikeVMThread")
            .field("state", &self.state)
            .field("captures", &CapturesDebug(&self.captures))
            .finish()
    }
}
/// Builds the identifier `State{idx}`, the `VMStates` variant name used for
/// NFA state `idx` in generated code.
fn vmstate_label(idx: usize) -> proc_macro2::Ident {
    proc_macro2::Ident::new(&format!("State{idx}"), proc_macro2::Span::call_site())
}
/// Flags the NFA states that get no representation in the generated VM.
///
/// Entry `i` is `false` (kept) when state `i` is the first state (index 0), the
/// last state, the source of at least one byte transition, or the target of one.
/// All other entries are `true`. Excluded states get no `VMStates` variant and
/// no generated step code.
fn compute_excluded_states(nfa: &U8NFA) -> Vec<bool> {
    let mut excluded = vec![true; nfa.states.len()];
    excluded[0] = false;
    excluded[nfa.states.len() - 1] = false;
    for (source, state) in nfa.states.iter().enumerate() {
        if !state.transitions.is_empty() {
            excluded[source] = false;
        }
        for transition in &state.transitions {
            excluded[transition.to] = false;
        }
    }
    excluded
}
/// Emits the byte-consuming step of the generated VM as two functions.
///
/// - `transition_symbols_test(list, new_list, c)`: for every state marked in
///   `list`, sets `new_list[to]` for each of its transitions whose inclusive byte
///   range contains `c`.
/// - `transition_symbols_exec(threads, c)`: returns, in input order, a copy of
///   each thread moved along every transition whose inclusive byte range contains
///   `c`. Captures are carried over unchanged.
///
/// States flagged by `compute_excluded_states` get no code.
fn serialize_pike_vm_symbol_propogation(
    nfa: &U8NFA,
) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
    let capture_groups = nfa.num_capture_groups();
    let excluded_states = compute_excluded_states(nfa);

    let mut test_arms = proc_macro2::TokenStream::new();
    let mut exec_arms = proc_macro2::TokenStream::new();
    for (idx, state) in nfa.states.iter().enumerate() {
        if excluded_states[idx] {
            continue;
        }
        let mut test_body = proc_macro2::TokenStream::new();
        let mut exec_body = proc_macro2::TokenStream::new();
        for transition in &state.transitions {
            let U8Transition { symbol, to } = transition;
            let (lo, hi) = (symbol.start(), symbol.end());
            let to_label = vmstate_label(*to);
            test_body.extend(quote! {
                {
                    if #lo <= c && c <= #hi {
                        new_list[#to] = true;
                    }
                }
            });
            exec_body.extend(quote! {
                {
                    if #lo <= c && c <= #hi {
                        out.push(
                            ::ere::pike_vm_u8::U8PikeVMThread {
                                state: VMStates::#to_label,
                                captures: thread.captures.clone(),
                            },
                        );
                    }
                }
            });
        }
        let label = vmstate_label(idx);
        test_arms.extend(quote! {
            if list[#idx] {
                #test_body
            }
        });
        exec_arms.extend(quote! {
            VMStates::#label => {
                #exec_body
            }
        });
    }

    let test_fn = quote! {
        fn transition_symbols_test(
            list: &[bool],
            new_list: &mut [bool],
            c: u8,
        ) {
            #test_arms
        }
    };
    let exec_fn = quote! {
        fn transition_symbols_exec(
            threads: &[::ere::pike_vm_u8::U8PikeVMThread<#capture_groups, VMStates>],
            c: u8,
        ) -> ::std::vec::Vec<::ere::pike_vm_u8::U8PikeVMThread<#capture_groups, VMStates>> {
            let mut out = ::std::vec::Vec::<::ere::pike_vm_u8::U8PikeVMThread<#capture_groups, VMStates>>::new();
            for thread in threads {
                match thread.state {
                    #exec_arms
                }
            }
            return out;
        }
    };
    (test_fn, exec_fn)
}
/// One epsilon-reachable target of a source state, together with the side
/// effects of the epsilon path that reaches it.
#[derive(PartialEq, Eq, Clone)]
struct ThreadUpdates {
    /// Index of the NFA state the epsilon path ends in.
    pub state: usize,
    /// Per capture group: whether the path records the group's start / end at
    /// the current position.
    pub update_captures: Vec<(bool, bool)>,
    /// The path crosses a start anchor; only valid at position 0.
    pub start_only: bool,
    /// The path crosses an end anchor; only valid at the end of the text.
    pub end_only: bool,
}
impl ThreadUpdates {
    /// Code for the set-based (`test`) VM: marks the target state as live in
    /// `list`.
    pub fn serialize_thread_update_test(&self) -> proc_macro2::TokenStream {
        let target = self.state;
        quote! {{
            list[#target] = true;
        }}
    }

    /// Code for the capturing (`exec`) VM. If the target state is not yet in
    /// `occupied_states`, this pushes a copy of `thread` moved to that state, with
    /// each flagged capture endpoint set to `idx`, and marks the state occupied.
    pub fn serialize_thread_update_exec(&self) -> proc_macro2::TokenStream {
        let target_idx = self.state;
        let target_label = vmstate_label(target_idx);
        let capture_updates: proc_macro2::TokenStream = self
            .update_captures
            .iter()
            .enumerate()
            .map(|(group, &(record_start, record_end))| {
                let mut updates = proc_macro2::TokenStream::new();
                if record_start {
                    updates.extend(quote! {
                        new_thread.captures[#group].0 = idx;
                    });
                }
                if record_end {
                    updates.extend(quote! {
                        new_thread.captures[#group].1 = idx;
                    });
                }
                updates
            })
            .collect();
        quote! {
            if !occupied_states[#target_idx] {
                let mut new_thread = thread.clone();
                new_thread.state = VMStates::#target_label;
                #capture_updates
                out.push(new_thread);
                occupied_states[#target_idx] = true;
            }
        }
    }
}
/// Collects every `(target state, side effects)` pair reachable from `state` by
/// following only epsilon edges. This includes `state` itself, with no side
/// effects, as the first entry.
///
/// Each entry records which capture endpoints the path sets and whether it
/// crosses a start or end anchor. A target may appear more than once if it is
/// reachable with different side effects. Entries are in depth-first discovery
/// order, and each distinct entry appears exactly once.
fn calculate_epsilon_propogations(nfa: &U8NFA, state: usize) -> Vec<ThreadUpdates> {
    let capture_groups = nfa.num_capture_groups();

    // Depth-first walk over epsilon edges. A value is only recursed into if it is
    // not already in `out`, and `out` records it immediately. Each distinct
    // `ThreadUpdates` is therefore visited once. The set of possible values is
    // finite, so this terminates even on epsilon cycles.
    fn traverse(
        thread: ThreadUpdates,
        states: &[crate::working_u8_nfa::U8State],
        out: &mut Vec<ThreadUpdates>,
    ) {
        out.push(thread.clone());
        for e in &states[thread.state].epsilons {
            let mut next = thread.clone();
            next.state = e.to;
            match e.special {
                EpsilonType::None => {}
                EpsilonType::StartAnchor => next.start_only = true,
                EpsilonType::EndAnchor => next.end_only = true,
                EpsilonType::StartCapture(capture_group) => {
                    next.update_captures[capture_group].0 = true
                }
                EpsilonType::EndCapture(capture_group) => {
                    next.update_captures[capture_group].1 = true
                }
            }
            if !out.contains(&next) {
                traverse(next, states, out);
            }
        }
    }

    // Start from an empty list. `traverse` records the starting thread itself,
    // and seeding the list with it as well would emit that entry twice.
    let mut new_threads = Vec::new();
    traverse(
        ThreadUpdates {
            state,
            update_captures: vec![(false, false); capture_groups],
            start_only: false,
            end_only: false,
        },
        &nfa.states,
        &mut new_threads,
    );
    new_threads
}
fn serialize_pike_vm_epsilon_propogation(
nfa: &U8NFA,
) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
let U8NFA { states } = nfa;
let capture_groups = nfa.num_capture_groups();
let num_states = states.len();
let excluded_states = compute_excluded_states(nfa);
let mut transition_epsilons_test = proc_macro2::TokenStream::new();
let mut transition_epsilons_exec = proc_macro2::TokenStream::new();
for (i, _) in states.iter().enumerate() {
if excluded_states[i] {
continue;
}
let mut new_threads = calculate_epsilon_propogations(nfa, i);
new_threads
.retain(|t| !states[t.state].transitions.is_empty() || t.state + 1 == num_states);
let start_end_threads: proc_macro2::TokenStream = new_threads
.iter()
.filter(|t| t.start_only && t.end_only)
.map(ThreadUpdates::serialize_thread_update_test)
.collect();
let start_threads: proc_macro2::TokenStream = new_threads
.iter()
.filter(|t| t.start_only && !t.end_only)
.map(ThreadUpdates::serialize_thread_update_test)
.collect();
let end_threads: proc_macro2::TokenStream = new_threads
.iter()
.filter(|t| !t.start_only && t.end_only)
.map(ThreadUpdates::serialize_thread_update_test)
.collect();
let normal_threads: proc_macro2::TokenStream = new_threads
.iter()
.filter(|t| !t.start_only && !t.end_only)
.map(ThreadUpdates::serialize_thread_update_test)
.collect();
let label = vmstate_label(i);
transition_epsilons_test.extend(quote! {
if list[#i] {
if is_start && is_end {
#start_end_threads
}
if is_start {
#start_threads
}
if is_end {
#end_threads
}
#normal_threads
}
});
let start_end_threads: proc_macro2::TokenStream = new_threads
.iter()
.filter(|t| t.start_only && t.end_only)
.map(ThreadUpdates::serialize_thread_update_exec)
.collect();
let start_threads: proc_macro2::TokenStream = new_threads
.iter()
.filter(|t| t.start_only && !t.end_only)
.map(ThreadUpdates::serialize_thread_update_exec)
.collect();
let end_threads: proc_macro2::TokenStream = new_threads
.iter()
.filter(|t| !t.start_only && t.end_only)
.map(ThreadUpdates::serialize_thread_update_exec)
.collect();
let normal_threads: proc_macro2::TokenStream = new_threads
.iter()
.filter(|t| !t.start_only && !t.end_only)
.map(ThreadUpdates::serialize_thread_update_exec)
.collect();
transition_epsilons_exec.extend(quote! {
VMStates::#label => {
if is_start && is_end {
#start_end_threads
}
if is_start {
#start_threads
}
if is_end {
#end_threads
}
#normal_threads
}
});
}
let transition_epsilons_test = quote! {
fn transition_epsilons_test(
list: &mut [bool],
idx: usize,
len: usize,
) {
let is_start = idx == 0;
let is_end = idx == len;
#transition_epsilons_test
}
};
let transition_epsilons_exec = quote! {
fn transition_epsilons_exec(
threads: &[::ere::pike_vm_u8::U8PikeVMThread<#capture_groups, VMStates>],
idx: usize,
len: usize,
) -> ::std::vec::Vec<::ere::pike_vm_u8::U8PikeVMThread<#capture_groups, VMStates>> {
let is_start = idx == 0;
let is_end = idx == len;
let mut occupied_states = ::std::vec![false; #num_states];
let mut out = ::std::vec::Vec::<::ere::pike_vm_u8::U8PikeVMThread<#capture_groups, VMStates>>::new();
for thread in threads {
match thread.state {
#transition_epsilons_exec
}
}
return out;
}
};
return (transition_epsilons_test, transition_epsilons_exec);
}
/// Generates the full byte-level Pike VM matcher for `nfa`.
///
/// The returned tokens form a block expression. The block defines a `VMStates`
/// enum (one variant per non-excluded NFA state), the symbol and epsilon step
/// functions, and two matchers, and evaluates to the tuple `(test, exec)`:
///
/// - `test(text: &str) -> bool` tracks only which states are live. It returns
///   `false` as soon as no state survives a byte. Otherwise it reports whether the
///   last NFA state is live after the final byte.
/// - `exec(text: &str) -> Option<[Option<&str>; N]>` tracks full threads with
///   capture offsets. It returns the captures of the first thread in the final
///   thread list that is in the last NFA state, or `None` if there is none.
///
/// `Option` is spelled `::core::option::Option` everywhere in the generated code,
/// including `exec`'s signature. A call-site item named `Option` therefore cannot
/// shadow it.
pub(crate) fn serialize_pike_vm_token_stream(nfa: &U8NFA) -> proc_macro2::TokenStream {
    let U8NFA { states, .. } = nfa;
    let capture_groups = nfa.num_capture_groups();
    let excluded_states = compute_excluded_states(nfa);
    let enum_states: proc_macro2::TokenStream = (0..states.len())
        .filter(|i| !excluded_states[*i])
        .map(|i| {
            let label = vmstate_label(i);
            quote! { #label, }
        })
        .collect();
    let state_count = states.len();
    // The last NFA state is the one a successful match must end in.
    let accept_state = vmstate_label(state_count - 1);
    let (transition_symbols_test, transition_symbols_exec) =
        serialize_pike_vm_symbol_propogation(nfa);
    let (transition_epsilons_test, transition_epsilons_exec) =
        serialize_pike_vm_epsilon_propogation(nfa);
    quote! {{
        #[derive(Clone, Copy, PartialEq, Eq, Debug)]
        enum VMStates {
            #enum_states
        }
        #transition_symbols_test
        #transition_symbols_exec
        #transition_epsilons_test
        #transition_epsilons_exec
        fn test(text: &str) -> bool {
            let mut list = [false; #state_count];
            let mut new_list = [false; #state_count];
            list[0] = true;
            transition_epsilons_test(&mut list, 0, text.len());
            for (i, c) in text.bytes().enumerate() {
                transition_symbols_test(&list, &mut new_list, c);
                if new_list.iter().all(|b| !b) {
                    return false;
                }
                ::std::mem::swap(&mut list, &mut new_list);
                transition_epsilons_test(&mut list, i + 1, text.len());
                new_list.fill(false);
            }
            return list[#state_count - 1];
        }
        fn exec<'a>(text: &'a str) -> ::core::option::Option<[::core::option::Option<&'a str>; #capture_groups]> {
            let mut threads = ::std::vec::Vec::<::ere::pike_vm_u8::U8PikeVMThread<#capture_groups, VMStates>>::new();
            threads.push(::ere::pike_vm_u8::U8PikeVMThread {
                state: VMStates::State0,
                captures: [(usize::MAX, usize::MAX); #capture_groups],
            });
            let new_threads = transition_epsilons_exec(&threads, 0, text.len());
            threads = new_threads;
            for (i, c) in text.bytes().enumerate() {
                let new_threads = transition_symbols_exec(&threads, c);
                threads = new_threads;
                let new_threads = transition_epsilons_exec(&threads, i + 1, text.len());
                threads = new_threads;
                if threads.is_empty() {
                    return ::core::option::Option::None;
                }
            }
            let final_capture_bounds = threads
                .into_iter()
                .find(|t| t.state == VMStates::#accept_state)?
                .captures;
            let mut captures = [::core::option::Option::None; #capture_groups];
            for (i, (start, end)) in final_capture_bounds.into_iter().enumerate() {
                if start != usize::MAX {
                    assert_ne!(end, usize::MAX);
                    captures[i] = text.get(start..end);
                    assert!(captures[i].is_some());
                } else {
                    assert_eq!(end, usize::MAX);
                }
            }
            return ::core::option::Option::Some(captures);
        }
        (test, exec)
    }}
}