use crate::working_u8_nfa::U8NFA;
use proc_macro2::TokenStream;
use quote::quote;
/// A compiled one-pass matcher: a thin handle around two generated functions.
///
/// `N` is the number of capture-group slots returned by [`U8OnePass::exec`].
pub struct U8OnePass<const N: usize> {
    test_fn: fn(&str) -> bool,
    exec_fn: for<'a> fn(&'a str) -> Option<[Option<&'a str>; N]>,
}
impl<const N: usize> U8OnePass<N> {
    /// Returns the result of the stored `test` function on `text`.
    pub fn test(&self, text: &str) -> bool {
        let matcher = self.test_fn;
        matcher(text)
    }
    /// Returns the result of the stored `exec` function on `text`: `None` when it reports no
    /// match, otherwise one optional sub-slice of `text` per capture slot.
    pub fn exec<'a>(&self, text: &'a str) -> Option<[Option<&'a str>; N]> {
        let executor = self.exec_fn;
        executor(text)
    }
}
/// Wraps a `(test, exec)` function pair into a [`U8OnePass`].
///
/// NOTE(review): the `__` prefix suggests this is only meant to be called from macro-generated
/// code (the codegen in this file emits such a pair) — confirm before using it directly.
pub const fn __load_u8onepass<const N: usize>(
    test_fn: fn(&str) -> bool,
    exec_fn: for<'a> fn(&'a str) -> Option<[Option<&'a str>; N]>,
) -> U8OnePass<N> {
    U8OnePass { test_fn, exec_fn }
}
/// Builds the identifier `State{idx}`, used as the `VMStates` variant name for NFA state `idx`.
fn vmstate_label(idx: usize) -> proc_macro2::Ident {
    let span = proc_macro2::Span::call_site();
    proc_macro2::Ident::new(&format!("State{}", idx), span)
}
/// Flags the NFA states that need no runtime representation.
///
/// Entry `i` is `false` (kept) when state `i` is the start state (index 0) or the target of at
/// least one symbol transition. It is `true` (excluded) otherwise: such a state can be reached,
/// if at all, only through epsilon moves, and the callers fold those into the epsilon-closure
/// transition tables.
///
/// Panics if `nfa` has no states (index 0 is written unconditionally).
fn compute_excluded_states(nfa: &U8NFA) -> Vec<bool> {
    let mut excluded: Vec<bool> = nfa.states.iter().map(|_| true).collect();
    excluded[0] = false;
    let targets = nfa
        .states
        .iter()
        .flat_map(|state| &state.transitions)
        .map(|transition| transition.to);
    for target in targets {
        excluded[target] = false;
    }
    excluded
}
/// One way of leaving a state: the NFA state reached, plus everything picked up along the
/// epsilon path that led there.
#[derive(Clone, PartialEq, Eq)]
struct ThreadUpdates {
    /// Index of the NFA state this path arrives at.
    pub state: usize,
    /// Per capture group: (start position recorded on this path, end position recorded).
    pub update_captures: Vec<(bool, bool)>,
    /// The path crossed a start anchor, so it is only usable at byte position 0.
    pub start_only: bool,
    /// The path crossed an end anchor, so it is only usable at the end of the input.
    pub end_only: bool,
}
impl ThreadUpdates {
    /// Creates an update for `state` with `num_captures` untouched capture slots and no anchor
    /// constraints.
    pub fn new(state: usize, num_captures: usize) -> ThreadUpdates {
        let update_captures = std::iter::repeat((false, false))
            .take(num_captures)
            .collect();
        Self {
            state,
            start_only: false,
            end_only: false,
            update_captures,
        }
    }
}
pub(crate) fn serialize_one_pass_token_stream(nfa: &U8NFA) -> Option<TokenStream> {
let num_captures = nfa.num_capture_groups();
let mut symbol_transitions = vec![Vec::new(); nfa.states.len()];
let mut accept_transitions = vec![Vec::new(); nfa.states.len()];
for (state_idx, _) in nfa.states.iter().enumerate() {
let mut stack = vec![ThreadUpdates::new(state_idx, num_captures)];
let mut reached = vec![ThreadUpdates::new(state_idx, num_captures)];
while let Some(thread) = stack.pop() {
if thread.state + 1 == nfa.states.len() {
accept_transitions[state_idx].push(thread.clone());
}
for ep in &nfa.states[thread.state].epsilons {
let mut new_thread = thread.clone();
new_thread.state = ep.to;
match ep.special {
crate::working_nfa::EpsilonType::None => (),
crate::working_nfa::EpsilonType::StartAnchor => new_thread.start_only = true,
crate::working_nfa::EpsilonType::EndAnchor => new_thread.end_only = true,
crate::working_nfa::EpsilonType::StartCapture(c) => {
new_thread.update_captures[c].0 = true
}
crate::working_nfa::EpsilonType::EndCapture(c) => {
new_thread.update_captures[c].1 = true
}
}
if !reached.contains(&new_thread) {
reached.push(new_thread.clone());
stack.push(new_thread);
}
}
for tr in &nfa.states[thread.state].transitions {
let new_transition = (
tr.symbol.0.clone(),
ThreadUpdates {
state: tr.to,
update_captures: thread.update_captures.clone(),
start_only: thread.start_only,
end_only: thread.end_only,
},
);
if !symbol_transitions[state_idx].contains(&new_transition) {
symbol_transitions[state_idx].push(new_transition);
}
}
}
}
for state_transitions in &mut symbol_transitions {
state_transitions.sort_by_key(|(range, _)| *range.start());
let overlap = !state_transitions.windows(2).all(|ranges| {
if let &[(a, _), (b, _)] = &ranges {
return a.end() < b.start();
} else {
unreachable!("Vec::windows does not use const generics so we have to do this.");
}
});
if overlap {
return None;
}
}
if let Some(_) = nfa.topological_ordering() {
return Some(codegen_functional(
nfa,
num_captures,
symbol_transitions,
accept_transitions,
));
} else {
return Some(codegen_vmlike(
nfa,
num_captures,
symbol_transitions,
accept_transitions,
));
}
}
/// A capture-slot write scheduled at some byte offset inside a [`Run`].
///
/// The derived `Ord` orders `StartCapture` before `EndCapture`. That order decides how tags
/// at the same offset are arranged when `compute_runs` sorts them.
#[derive(Clone, Copy, PartialEq, Eq, Debug, PartialOrd, Ord)]
enum Tag {
    /// Write the current position into the start slot of the given capture group.
    StartCapture(usize),
    /// Write the current position into the end slot of the given capture group.
    EndCapture(usize),
}
/// A straight-line chain of states. Every step has exactly one byte range to match, so the
/// whole chain can be checked at once against a fixed-length slice.
#[derive(Clone)]
struct Run {
    /// State the chain begins in.
    start_state: usize,
    /// Byte range required at each position of the chain, in order.
    symbols: Vec<std::ops::RangeInclusive<u8>>,
    /// `(offset, tag)` pairs. Each tag is applied at the position of `symbols[offset]`, i.e.
    /// before that byte is consumed.
    tags: Vec<(usize, Tag)>,
    /// State reached after all of `symbols` have been consumed.
    end_state: usize,
}
/// How a single state participates in the runs found by `compute_runs`.
#[derive(Clone)]
enum StateRunInclusion {
    /// The state begins this run; its generated function checks the whole run in one step.
    Start(Run),
    /// The state lies strictly inside a run and gets no generated function of its own.
    Internal,
    /// A run finishes in this state; it is generated like an ordinary state.
    End,
    /// The state is not part of any run.
    None,
}
fn compute_runs(
symbol_transitions: &Vec<Vec<(std::ops::RangeInclusive<u8>, ThreadUpdates)>>,
) -> Vec<StateRunInclusion> {
let mut incoming_count = vec![0; symbol_transitions.len()];
for state_transitions in symbol_transitions {
for (_, update) in state_transitions {
incoming_count[update.state] += 1;
}
}
let run_next: Vec<_> = symbol_transitions
.iter()
.map(|state_transitions| {
let &[(range, update)] = &state_transitions.as_slice() else {
return None;
};
if update.end_only || update.start_only || incoming_count[update.state] != 1 {
return None;
}
let tags: Vec<_> = update
.update_captures
.iter()
.enumerate()
.flat_map(|(i, set_tags)| match set_tags {
(false, false) => Vec::new(),
(true, false) => vec![Tag::StartCapture(i)],
(false, true) => vec![Tag::EndCapture(i)],
(true, true) => vec![Tag::StartCapture(i), Tag::EndCapture(i)],
})
.collect();
return Some((update.state, range, tags));
})
.collect();
let mut run_start: Vec<bool> = run_next.iter().map(Option::is_some).collect();
for next in &run_next {
if let Some((next, _, _)) = next {
run_start[*next] = false;
}
}
let mut out = vec![StateRunInclusion::None; symbol_transitions.len()];
for (start_state, is_start) in run_start.iter().enumerate() {
if !*is_start {
continue;
}
let mut symbols = Vec::new();
let mut tags = Vec::new();
let mut internal = Vec::new(); let mut state = start_state;
while let Some(Some((next, range, transition_tags))) = run_next.get(state) {
if *next == start_state {
break;
}
for tag in transition_tags {
tags.push((symbols.len(), tag.clone()));
}
symbols.push((*range).clone());
internal.push(*next);
state = *next;
}
tags.sort();
if state == start_state {
continue;
}
out[start_state] = StateRunInclusion::Start(Run {
start_state,
symbols,
tags,
end_state: state,
});
for internal_state in internal {
out[internal_state] = StateRunInclusion::Internal;
}
out[state] = StateRunInclusion::End;
}
return out;
}
/// Emits a `{ ...; (test, exec) }` block that simulates the one-pass automaton with an explicit
/// `VMStates` variable and a `match` over `(state, byte)`. `serialize_one_pass_token_stream`
/// uses it when the NFA has no topological ordering.
///
/// Only non-excluded states (see `compute_excluded_states`) become `VMStates` variants.
///
/// Anchor handling in the generated code:
/// * A symbol transition whose path crossed a start anchor is only taken while reading byte 0
///   (`if i == 0`).
/// * An accepting path that crossed a start anchor is only valid when the final position
///   (`text.len()`) is 0, i.e. `text.is_empty()`. The loop variable `i` is not in scope after
///   the loop, so it must not be used there.
/// * A symbol transition whose path crossed an end anchor can never be followed by a byte and
///   is skipped.
fn codegen_vmlike(
    nfa: &U8NFA,
    num_captures: usize,
    symbol_transitions: Vec<Vec<(std::ops::RangeInclusive<u8>, ThreadUpdates)>>,
    accept_transitions: Vec<Vec<ThreadUpdates>>,
) -> TokenStream {
    // Guard for a symbol transition taken inside the byte loop, where `i` is the byte index.
    fn loop_guard(thread: &ThreadUpdates) -> TokenStream {
        if thread.start_only {
            quote! { if i == 0 }
        } else {
            TokenStream::new()
        }
    }
    // Guard for an accepting path checked after the loop, where the position is `text.len()`.
    fn final_guard(thread: &ThreadUpdates) -> TokenStream {
        if thread.start_only {
            quote! { if text.is_empty() }
        } else {
            TokenStream::new()
        }
    }
    // Emits the capture-slot writes recorded in `thread`, each storing `position`.
    fn capture_writes(thread: &ThreadUpdates, position: &TokenStream) -> TokenStream {
        let mut out = TokenStream::new();
        for (group_num, &(start, end)) in thread.update_captures.iter().enumerate() {
            if start {
                out.extend(quote! {
                    captures[#group_num].0 = #position;
                });
            }
            if end {
                out.extend(quote! {
                    captures[#group_num].1 = #position;
                });
            }
        }
        out
    }

    let U8NFA { states, .. } = nfa;
    let excluded_states = compute_excluded_states(nfa);
    let live_states: Vec<usize> = (0..states.len())
        .filter(|&idx| !excluded_states[idx])
        .collect();

    let enum_states: TokenStream = live_states
        .iter()
        .map(|&idx| {
            let label = vmstate_label(idx);
            quote! { #label, }
        })
        .collect();

    let make_test_match_statements = |state_idx: usize| -> TokenStream {
        let mut out = TokenStream::new();
        let this_state = vmstate_label(state_idx);
        for (range, thread) in &symbol_transitions[state_idx] {
            // Excluded states have no variant; end-anchored paths cannot consume another byte.
            if excluded_states[thread.state] || thread.end_only {
                continue;
            }
            let range_start = *range.start();
            let range_end = *range.end();
            let conditions = loop_guard(thread);
            let to = vmstate_label(thread.state);
            out.extend(quote! {
                (VMStates::#this_state, #range_start..=#range_end) #conditions => {
                    state = VMStates::#to;
                }
            });
        }
        out
    };
    let make_test_match_statements_final = |state_idx: usize| -> TokenStream {
        let this_state = vmstate_label(state_idx);
        accept_transitions[state_idx]
            .iter()
            .map(|thread| {
                let conditions = final_guard(thread);
                quote! {
                    VMStates::#this_state #conditions => true,
                }
            })
            .collect()
    };
    let make_exec_match_statements = |state_idx: usize| -> TokenStream {
        let mut out = TokenStream::new();
        let this_state = vmstate_label(state_idx);
        let position = quote! { i };
        for (range, thread) in &symbol_transitions[state_idx] {
            if excluded_states[thread.state] || thread.end_only {
                continue;
            }
            let range_start = *range.start();
            let range_end = *range.end();
            let conditions = loop_guard(thread);
            let capture_updates = capture_writes(thread, &position);
            let to = vmstate_label(thread.state);
            out.extend(quote! {
                (VMStates::#this_state, #range_start..=#range_end) #conditions => {
                    #capture_updates
                    state = VMStates::#to;
                }
            });
        }
        out
    };
    let make_exec_match_statements_final = |state_idx: usize| -> TokenStream {
        let this_state = vmstate_label(state_idx);
        let position = quote! { text.len() };
        accept_transitions[state_idx]
            .iter()
            .map(|thread| {
                let conditions = final_guard(thread);
                let capture_updates = capture_writes(thread, &position);
                quote! {
                    VMStates::#this_state #conditions => {
                        #capture_updates
                    }
                }
            })
            .collect()
    };

    let test_match_statements: TokenStream = live_states
        .iter()
        .map(|&idx| make_test_match_statements(idx))
        .collect();
    let test_match_statements_final: TokenStream = live_states
        .iter()
        .map(|&idx| make_test_match_statements_final(idx))
        .collect();
    let exec_match_statements: TokenStream = live_states
        .iter()
        .map(|&idx| make_exec_match_statements(idx))
        .collect();
    let exec_match_statements_final: TokenStream = live_states
        .iter()
        .map(|&idx| make_exec_match_statements_final(idx))
        .collect();

    quote! {{
        #[derive(Clone, Copy, PartialEq, Eq, Debug)]
        enum VMStates {
            #enum_states
        }
        fn test(text: &str) -> bool {
            let mut state: VMStates = VMStates::State0;
            for (i, c) in text.bytes().enumerate() {
                match (state, c) {
                    #test_match_statements
                    _ => return false,
                }
            }
            return match state {
                #test_match_statements_final
                _ => false,
            };
        }
        fn exec<'a>(text: &'a str) -> Option<[Option<&'a str>; #num_captures]> {
            let mut state: VMStates = VMStates::State0;
            let mut captures: [(usize, usize); #num_captures] = [(usize::MAX, usize::MAX); #num_captures];
            for (i, c) in text.bytes().enumerate() {
                match (state, c) {
                    #exec_match_statements
                    _ => return ::core::option::Option::None,
                }
            }
            match state {
                #exec_match_statements_final
                _ => return ::core::option::Option::None,
            }
            let mut capture_strs = [::core::option::Option::None; #num_captures];
            for (i, (start, end)) in captures.into_iter().enumerate() {
                if start != usize::MAX {
                    assert_ne!(end, usize::MAX);
                    capture_strs[i] = text.get(start..end);
                    assert!(capture_strs[i].is_some());
                } else {
                    assert_eq!(end, usize::MAX);
                }
            }
            return ::core::option::Option::Some(capture_strs);
        }
        (test, exec)
    }}
}
/// Emits a `{ ...; (test, exec) }` block in which every live NFA state becomes its own function.
/// Each function consumes input and then calls the function of the next state.
/// `serialize_one_pass_token_stream` uses it when the NFA has a topological ordering.
///
/// Each call consumes at least one byte before calling the next state's function, so nesting
/// depth grows with the number of bytes consumed. Rust does not guarantee tail calls.
///
/// Straight-line runs found by `compute_runs` are compiled into a single function that checks
/// a fixed-length slice at once. The one exception is a run whose start or internal states can
/// accept: such a run is split back into per-state functions, because a run function requires
/// all of its bytes and would otherwise reject inputs that end inside the run.
///
/// A start anchor is honoured through the `start` parameter. It is `true` only in the entry
/// call to `func_state_0`, which is the only call made at byte position 0.
fn codegen_functional(
    nfa: &U8NFA,
    num_captures: usize,
    symbol_transitions: Vec<Vec<(std::ops::RangeInclusive<u8>, ThreadUpdates)>>,
    accept_transitions: Vec<Vec<ThreadUpdates>>,
) -> TokenStream {
    // `(lo <= run_part[k]) & (run_part[k] <= hi) & ...` for every byte of the run; the caller
    // closes the chain with a trailing `true`.
    fn run_conditions(run: &Run) -> TokenStream {
        run.symbols
            .iter()
            .enumerate()
            .map(|(k, range)| {
                let lower = range.start();
                let upper = range.end();
                quote! {
                    (#lower <= run_part[#k]) & (run_part[#k] <= #upper) &
                }
            })
            .collect()
    }
    // Match guard for a symbol transition leaving state `state_idx`. An end anchor can never be
    // followed by a byte; a start anchor only holds in the entry call to state 0.
    fn transition_guard(tu: &ThreadUpdates, state_idx: usize) -> TokenStream {
        if tu.end_only || (tu.start_only && state_idx != 0) {
            quote! { if false }
        } else if tu.start_only {
            quote! { if start }
        } else {
            TokenStream::new()
        }
    }
    // Capture-slot writes recorded in `tu`, each storing the current `byte_idx`.
    fn make_capture_statements(tu: &ThreadUpdates) -> TokenStream {
        let mut out = TokenStream::new();
        for (capture_group, &(start, end)) in tu.update_captures.iter().enumerate() {
            if start {
                out.extend(quote! {
                    captures[#capture_group].0 = byte_idx;
                });
            }
            if end {
                out.extend(quote! {
                    captures[#capture_group].1 = byte_idx;
                });
            }
        }
        out
    }

    let U8NFA { states, .. } = nfa;
    let excluded_states = compute_excluded_states(nfa);
    let fn_idents: Vec<_> = (0..states.len())
        .map(|i| {
            proc_macro2::Ident::new(&format!("func_state_{i}"), proc_macro2::Span::call_site())
        })
        .collect();

    let mut runs = compute_runs(&symbol_transitions);
    // `compute_runs` only inspects symbol transitions. A run function demands all of its bytes
    // up front, so it cannot stop early to accept. Any run whose start or internal state has an
    // accept transition is therefore split back into ordinary per-state functions.
    for head in 0..runs.len() {
        let StateRunInclusion::Start(run) = &runs[head] else {
            continue;
        };
        let mut chain = Vec::with_capacity(run.symbols.len() + 1);
        let mut state = run.start_state;
        for _ in 0..run.symbols.len() {
            chain.push(state);
            // Every non-final state of a run has exactly one symbol transition.
            state = symbol_transitions[state][0].1.state;
        }
        debug_assert_eq!(state, run.end_state);
        if chain.iter().any(|&s| !accept_transitions[s].is_empty()) {
            chain.push(state);
            for s in chain {
                runs[s] = StateRunInclusion::None;
            }
        }
    }

    // (state, run starting there) for every state that gets a generated function.
    let plan: Vec<(usize, Option<&Run>)> = runs
        .iter()
        .enumerate()
        .filter(|(i, _)| !excluded_states[*i])
        .filter_map(|(i, inclusion)| match inclusion {
            StateRunInclusion::Start(run) => Some((i, Some(run))),
            StateRunInclusion::Internal => None,
            StateRunInclusion::End | StateRunInclusion::None => Some((i, None)),
        })
        .collect();

    let make_test_func = |(i, run): (usize, Option<&Run>)| -> TokenStream {
        let fn_ident = &fn_idents[i];
        if let Some(run) = run {
            debug_assert_eq!(run.start_state, i);
            let run_length = run.symbols.len();
            let new_state_fn_ident = &fn_idents[run.end_state];
            let conditions = run_conditions(run);
            return quote! {
                fn #fn_ident<'a>(mut bytes: ::core::slice::Iter<'a, u8>, start: bool) -> bool {
                    let ::core::option::Option::Some((run_part, rest)) = bytes.as_slice().split_at_checked(#run_length) else {
                        return false;
                    };
                    let result = #conditions true;
                    return result && #new_state_fn_ident(rest.iter(), false);
                }
            };
        }
        let accepts = &accept_transitions[i];
        let end_case = if accepts.is_empty() {
            quote! { false }
        } else if accepts.iter().any(|tu| !tu.start_only) {
            quote! { true }
        } else if i == 0 {
            quote! { start }
        } else {
            quote! { false }
        };
        let cases: TokenStream = symbol_transitions[i]
            .iter()
            .map(|(range, tu)| {
                let start = *range.start();
                let end = *range.end();
                let conditions = transition_guard(tu, i);
                let new_state_fn_ident = &fn_idents[tu.state];
                quote! {
                    ::core::option::Option::Some(#start..=#end) #conditions => #new_state_fn_ident(bytes, false),
                }
            })
            .collect();
        quote! {
            fn #fn_ident<'a>(mut bytes: ::core::slice::Iter<'a, u8>, start: bool) -> bool {
                return match bytes.next() {
                    ::core::option::Option::None => #end_case,
                    #cases
                    ::core::option::Option::Some(_) => false,
                }
            }
        }
    };

    let make_exec_func = |(i, run): (usize, Option<&Run>)| -> TokenStream {
        let fn_ident = &fn_idents[i];
        if let Some(run) = run {
            debug_assert_eq!(run.start_state, i);
            let run_length = run.symbols.len();
            let new_state_fn_ident = &fn_idents[run.end_state];
            let conditions = run_conditions(run);
            let apply_tags: TokenStream = run
                .tags
                .iter()
                .map(|(offset, tag)| match tag {
                    Tag::StartCapture(capture_idx) => {
                        quote! { captures[#capture_idx].0 = #offset + byte_idx; }
                    }
                    Tag::EndCapture(capture_idx) => {
                        quote! { captures[#capture_idx].1 = #offset + byte_idx; }
                    }
                })
                .collect();
            return quote! {
                fn #fn_ident<'a>(
                    mut bytes: ::core::slice::Iter<'a, u8>,
                    byte_idx: usize,
                    mut captures: [(usize, usize); #num_captures],
                    start: bool,
                ) -> Option<[(usize, usize); #num_captures]> {
                    let ::core::option::Option::Some((run_part, rest)) = bytes.as_slice().split_at_checked(#run_length) else {
                        return ::core::option::Option::None;
                    };
                    let result = #conditions true;
                    if !result {
                        return ::core::option::Option::None;
                    }
                    #apply_tags
                    return #new_state_fn_ident(rest.into_iter(), byte_idx + #run_length, captures, false);
                }
            };
        }
        let end_case = match accept_transitions[i].as_slice() {
            [] => quote! { ::core::option::Option::None },
            [tu] if tu.start_only && i == 0 => {
                let capture_statements = make_capture_statements(tu);
                quote! {
                    #capture_statements
                    if start {
                        ::core::option::Option::Some(captures)
                    } else {
                        ::core::option::Option::None
                    }
                }
            }
            [tu] if tu.start_only && i != 0 => quote! { ::core::option::Option::None },
            [tu] => {
                let capture_statements = make_capture_statements(tu);
                quote! {
                    #capture_statements
                    ::core::option::Option::Some(captures)
                }
            }
            // Several accepting paths would make the captures ambiguous; refuse at compile time.
            _ => quote! {
                ::core::compile_error!("Should only be one thread update on accept for one-pass")
            },
        };
        let cases: TokenStream = symbol_transitions[i]
            .iter()
            .map(|(range, tu)| {
                let start = *range.start();
                let end = *range.end();
                let conditions = transition_guard(tu, i);
                let new_state_fn_ident = &fn_idents[tu.state];
                let capture_statements = make_capture_statements(tu);
                quote! {
                    ::core::option::Option::Some(#start..=#end) #conditions => {
                        #capture_statements
                        #new_state_fn_ident(bytes, byte_idx + 1, captures, false)
                    }
                }
            })
            .collect();
        quote! {
            fn #fn_ident<'a>(
                mut bytes: ::core::slice::Iter<'a, u8>,
                byte_idx: usize,
                mut captures: [(usize, usize); #num_captures],
                start: bool,
            ) -> Option<[(usize, usize); #num_captures]> {
                return match bytes.next() {
                    ::core::option::Option::None => {
                        #end_case
                    }
                    #cases
                    ::core::option::Option::Some(_) => ::core::option::Option::None,
                };
            }
        }
    };

    let test_funcs: TokenStream = plan.iter().map(|&entry| make_test_func(entry)).collect();
    let exec_funcs: TokenStream = plan.iter().map(|&entry| make_exec_func(entry)).collect();

    quote! {{
        fn test<'a>(text: &'a str) -> bool {
            #test_funcs
            return func_state_0(text.as_bytes().iter(), true);
        }
        fn exec<'a>(text: &'a str) -> ::core::option::Option<[::core::option::Option<&'a str>; #num_captures]> {
            let captures: [(usize, usize); #num_captures] = [(usize::MAX, usize::MAX); #num_captures];
            #exec_funcs
            let captures = func_state_0(text.as_bytes().iter(), 0, captures, true)?;
            let mut capture_strs = [::core::option::Option::None; #num_captures];
            for (i, (start, end)) in captures.into_iter().enumerate() {
                if start != usize::MAX {
                    assert_ne!(end, usize::MAX);
                    capture_strs[i] = text.get(start..end);
                    assert!(capture_strs[i].is_some());
                } else {
                    assert_eq!(end, usize::MAX);
                }
            }
            return ::core::option::Option::Some(capture_strs);
        }
        (test, exec)
    }}
}