use crate::generator::{
analyze_char_classes, analyze_first_byte_dispatch, classify_rules, compute_first_sets,
is_left_recursive, is_tail_loop, DispatchCascade,
};
use crate::mplg::MplgOutput;
use mpl::symbols::{Metasymbol, TerminalSymbol, E};
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use std::collections::BTreeSet;
use syn::{Expr, ExprCall, ExprLit, ExprPath, Ident, Lit};
/// Converts a CamelCase-style name to snake_case.
///
/// Every ASCII uppercase letter is lowercased, and an underscore is inserted
/// in front of it unless it is the first character of the input. All other
/// characters are copied through unchanged.
fn snake(s: &str) -> String {
    // A few spare bytes for the underscores that are likely to be inserted.
    let mut snake_cased = String::with_capacity(s.len() + 4);
    let mut at_start = true;
    for ch in s.chars() {
        if !at_start && ch.is_ascii_uppercase() {
            snake_cased.push('_');
        }
        snake_cased.push(ch.to_ascii_lowercase());
        at_start = false;
    }
    snake_cased
}
fn fn_ident(name: &str) -> Ident {
format_ident!("parse_{}", snake(name))
}
fn const_ident(name: &str) -> Ident {
format_ident!("{}", snake(name).to_uppercase())
}
/// Emits an expression that evaluates grammar element `e` at the position
/// given by `pos_tok`.
///
/// The emitted expression has type `Result<u32, ()>` (the new position on
/// success) and expects `state`, `S` and `EMIT` to be in scope at the place
/// where it is spliced in.
///
/// * a variable becomes a call to that rule's generated `parse_*` function;
/// * `Empty` succeeds without consuming input;
/// * `Failure` always fails;
/// * `Any(n)` consumes `n` bytes if that many remain;
/// * `All` consumes everything up to the end of the input;
/// * any other terminal is delegated to `emit_terminal`.
///
/// # Panics
///
/// Panics for the `Omit` metasymbol, which this generator does not support.
fn eval_e(e: &E<&str, &str>, pos_tok: TokenStream) -> TokenStream {
    match e {
        E::V(rule_name) => {
            let callee = fn_ident(rule_name);
            quote! { #callee::<S, EMIT>(state, #pos_tok) }
        }
        E::T(TerminalSymbol::Original(raw)) => emit_terminal(raw, pos_tok),
        E::T(TerminalSymbol::Metasymbol(Metasymbol::Empty)) => quote! {
            {
                let p = #pos_tok;
                if EMIT {
                    state.push_leaf(::mpl::fast::kind::EMPTY, p, p);
                }
                Ok::<u32, ()>(p)
            }
        },
        E::T(TerminalSymbol::Metasymbol(Metasymbol::Failure)) => {
            quote! { Err::<u32, ()>(()) }
        }
        E::T(TerminalSymbol::Metasymbol(Metasymbol::Any(n))) => {
            let byte_count = *n as u32;
            quote! {
                {
                    let p = #pos_tok;
                    let want = p as usize + #byte_count as usize;
                    if want <= state.input().len() {
                        if EMIT {
                            state.push_leaf(::mpl::fast::kind::ANY, p, want as u32);
                        }
                        Ok::<u32, ()>(want as u32)
                    } else {
                        state.note_failure_at(p);
                        Err::<u32, ()>(())
                    }
                }
            }
        }
        E::T(TerminalSymbol::Metasymbol(Metasymbol::All)) => quote! {
            {
                let p = #pos_tok;
                let end = state.input().len() as u32;
                if EMIT {
                    state.push_leaf(::mpl::fast::kind::ALL, p, end);
                }
                Ok::<u32, ()>(end)
            }
        },
        E::T(TerminalSymbol::Metasymbol(Metasymbol::Omit)) => {
            unimplemented!("Omit metasymbol not supported in FastParse")
        }
    }
}
/// Emits the matcher for a user-written terminal expression.
///
/// `raw` is parsed as a Rust expression and must be one of:
///
/// * `Char('c')` — matches the single input byte whose value equals the code
///   point of `c`. Only code points up to U+00FF are accepted, because the
///   generated comparison is against exactly one byte.
/// * `Str("...")` — matches the UTF-8 bytes of the string literal.
///
/// The emitted expression has type `Result<u32, ()>`; on a mismatch it calls
/// `state.note_failure_at` with the start position before failing.
///
/// # Panics
///
/// Panics (at code-generation time) when `raw` is not a valid expression,
/// when it is not one of the supported forms, or when a `Char` literal's code
/// point does not fit in one byte.
fn emit_terminal(raw: &str, pos_tok: TokenStream) -> TokenStream {
    let parsed: Expr = syn::parse_str(raw).unwrap_or_else(|_| {
        panic!("FastParse: cannot parse terminal expression `{raw}`");
    });
    if let Expr::Call(ExprCall { func, args, .. }) = &parsed {
        if let Expr::Path(ExprPath { path, .. }) = &**func {
            if path.is_ident("Char") {
                if let Some(Expr::Lit(ExprLit { lit: Lit::Char(c), .. })) = args.first() {
                    // A plain `as u8` cast would silently keep only the low
                    // byte of a larger code point (e.g. U+3042 would match
                    // b'B'), so reject anything that does not fit instead of
                    // generating a parser that matches the wrong byte.
                    let code_point = u32::from(c.value());
                    if code_point > u32::from(u8::MAX) {
                        panic!(
                            "FastParse: terminal `{}` does not fit in one byte; use `Str(..)` to match its UTF-8 encoding",
                            raw
                        );
                    }
                    let byte = code_point as u8;
                    return quote! {
                        {
                            let p = #pos_tok;
                            let pi = p as usize;
                            if pi < state.input().len() && state.input()[pi] == #byte {
                                if EMIT {
                                    state.push_leaf(::mpl::fast::kind::TERMINAL, p, p + 1);
                                }
                                Ok::<u32, ()>(p + 1)
                            } else {
                                state.note_failure_at(p);
                                Err::<u32, ()>(())
                            }
                        }
                    };
                }
            } else if path.is_ident("Str") {
                if let Some(Expr::Lit(ExprLit { lit: Lit::Str(s), .. })) = args.first() {
                    // The pattern is embedded as a byte-string literal holding
                    // the UTF-8 encoding of the string.
                    let bytes_lit = syn::LitByteStr::new(s.value().as_bytes(), s.span());
                    return quote! {
                        {
                            let p = #pos_tok;
                            let pi = p as usize;
                            let pat: &[u8] = #bytes_lit;
                            let end = pi + pat.len();
                            if end <= state.input().len() && &state.input()[pi..end] == pat {
                                if EMIT {
                                    state.push_leaf(::mpl::fast::kind::TERMINAL, p, end as u32);
                                }
                                Ok::<u32, ()>(end as u32)
                            } else {
                                state.note_failure_at(p);
                                Err::<u32, ()>(())
                            }
                        }
                    };
                }
            }
        }
    }
    panic!("FastParse: unsupported terminal expression `{raw}`")
}
/// Returns `true` when `e` is the `Failure` metasymbol.
fn is_failure(e: &E<&str, &str>) -> bool {
    match e {
        E::T(TerminalSymbol::Metasymbol(symbol)) => matches!(symbol, Metasymbol::Failure),
        E::T(TerminalSymbol::Original(_)) | E::V(_) => false,
    }
}
/// Returns `true` when `e` is a variable (a reference to another rule)
/// rather than a terminal.
fn is_variable(e: &E<&str, &str>) -> bool {
    match e {
        E::V(_) => true,
        E::T(_) => false,
    }
}
/// Emits a block expression that evaluates to a `[bool; 256]` table in which
/// exactly the entries indexed by the members of `bytes` are `true`.
fn generate_byte_table(bytes: &BTreeSet<u8>) -> TokenStream {
    // One assignment statement per member byte.
    let assignments: Vec<TokenStream> = bytes
        .iter()
        .copied()
        .map(|byte| quote! { __t[#byte as usize] = true; })
        .collect();
    quote! {
        {
            let mut __t = [false; 256];
            #(#assignments)*
            __t
        }
    }
}
/// Emits the parse function for a rule that has an associated byte set,
/// together with a `const` lookup table for that set.
///
/// In the generated function:
///
/// * when `EMIT` is `false`, the function succeeds by consuming exactly one
///   byte if the byte at `pos` is in the table, and otherwise records the
///   failure position and fails — `cascade_body` is not reached;
/// * when `EMIT` is `true`, `cascade_body` is executed.
fn generate_cascade_dispatcher(
    rule: &mpl::rules::Rule<&str, &str>,
    bytes: &BTreeSet<u8>,
    cascade_body: TokenStream,
) -> TokenStream {
    let rule_name = rule.value;
    let parse_fn = fn_ident(rule_name);
    // The table constant is named after the rule: `__<RULE>_BYTES`.
    let table = format_ident!("__{}_BYTES", const_ident(rule_name));
    let table_value = generate_byte_table(bytes);
    quote! {
        const #table: [bool; 256] = #table_value;
        #[inline(always)]
        fn #parse_fn<S: ::mpl::fast::ParseState, const EMIT: bool>(
            state: &mut S,
            pos: u32,
        ) -> ::core::result::Result<u32, ()> {
            if !EMIT {
                let p = pos as usize;
                if p < state.input().len() && #table[state.input()[p] as usize] {
                    return Ok(pos + 1);
                }
                state.note_failure_at(pos);
                return Err(());
            }
            #cascade_body
        }
    }
}
fn generate_rule_inner(rule: &mpl::rules::Rule<&str, &str>) -> TokenStream {
let name = rule.value;
let f = fn_ident(name);
let kind_const = const_ident(name);
let lhs = &rule.equal.first.lhs;
let rhs = &rule.equal.first.rhs;
let alt = &rule.equal.second.0;
let eval_lhs = eval_e(lhs, quote! { pos });
let eval_rhs = eval_e(rhs, quote! { end_b });
let first_choice = quote! {
{
let cp = if EMIT { state.checkpoint() } else { 0 };
let s = if EMIT { state.push_start(var::#kind_const, pos) } else { 0 };
if let Ok(end_b) = #eval_lhs {
if let Ok(end_c) = #eval_rhs {
if EMIT {
state.push_end(var::#kind_const, end_c, s);
}
return Ok(end_c);
}
}
if EMIT {
state.truncate(cp);
}
}
};
let second_choice = if is_failure(alt) {
quote! { Err(()) }
} else if is_variable(alt) {
let eval_alt = eval_e(alt, quote! { pos });
quote! {
{
let cp2 = if EMIT { state.checkpoint() } else { 0 };
let s = if EMIT { state.push_start(var::#kind_const, pos) } else { 0 };
if let Ok(end) = #eval_alt {
if EMIT {
state.push_end(var::#kind_const, end, s);
}
return Ok(end);
}
if EMIT {
state.truncate(cp2);
}
Err(())
}
}
} else {
let eval_alt = eval_e(alt, quote! { pos });
quote! {
{ #eval_alt }
}
};
quote! {
#[inline]
fn #f<S: ::mpl::fast::ParseState, const EMIT: bool>(
state: &mut S,
pos: u32,
) -> ::core::result::Result<u32, ()> {
#first_choice
#second_choice
}
}
}
fn generate_rule_body(rule: &mpl::rules::Rule<&str, &str>) -> TokenStream {
let kind_const = const_ident(rule.value);
let lhs = &rule.equal.first.lhs;
let rhs = &rule.equal.first.rhs;
let alt = &rule.equal.second.0;
let eval_lhs = eval_e(lhs, quote! { pos });
let eval_rhs = eval_e(rhs, quote! { end_b });
let first_choice = quote! {
{
let cp = if EMIT { state.checkpoint() } else { 0 };
let s = if EMIT { state.push_start(var::#kind_const, pos) } else { 0 };
if let Ok(end_b) = #eval_lhs {
if let Ok(end_c) = #eval_rhs {
if EMIT {
state.push_end(var::#kind_const, end_c, s);
}
return Ok(end_c);
}
}
if EMIT {
state.truncate(cp);
}
}
};
let second_choice = if is_failure(alt) {
quote! { Err(()) }
} else if is_variable(alt) {
let eval_alt = eval_e(alt, quote! { pos });
quote! {
{
let cp2 = if EMIT { state.checkpoint() } else { 0 };
let s = if EMIT { state.push_start(var::#kind_const, pos) } else { 0 };
if let Ok(end) = #eval_alt {
if EMIT {
state.push_end(var::#kind_const, end, s);
}
return Ok(end);
}
if EMIT {
state.truncate(cp2);
}
Err(())
}
}
} else {
let eval_alt = eval_e(alt, quote! { pos });
quote! {
{ #eval_alt }
}
};
quote! {
#first_choice
#second_choice
}
}
/// Emits the parse function for a rule whose alternatives can be selected by
/// looking at the next input byte.
///
/// In the generated function:
///
/// * when `EMIT` is `true`, `cascade_body` is executed;
/// * when `EMIT` is `false`, the byte at `pos` is matched against one arm per
///   distinct set of candidate alternatives. Each arm calls the candidates'
///   parse functions in cascade order, chained with `or_else`, so a later one
///   runs only if the earlier ones failed. End of input, or a byte no
///   alternative can start with, records the failure position and fails.
fn generate_first_byte_dispatcher(
    rule: &mpl::rules::Rule<&str, &str>,
    cascade: &DispatchCascade<'_>,
    cascade_body: TokenStream,
) -> TokenStream {
    let dispatcher_fn = fn_ident(rule.value);

    // For every possible byte value, the alternatives (by index, ascending)
    // that may start with it.
    let mut alts_by_byte: Vec<Vec<usize>> = vec![Vec::new(); 256];
    for (alt_index, alternative) in cascade.alternatives.iter().enumerate() {
        for byte in alternative.first.iter_bytes() {
            alts_by_byte[byte as usize].push(alt_index);
        }
    }

    // Bytes that lead to the same candidate list share one match arm.
    let mut bytes_by_alts: std::collections::BTreeMap<Vec<usize>, BTreeSet<u8>> =
        std::collections::BTreeMap::new();
    for (byte, alts) in alts_by_byte.into_iter().enumerate() {
        if !alts.is_empty() {
            bytes_by_alts.entry(alts).or_default().insert(byte as u8);
        }
    }

    let arms = bytes_by_alts.iter().map(|(alt_indices, bytes)| {
        let byte_pats: Vec<u8> = bytes.iter().copied().collect();
        let mut calls = alt_indices.iter().map(|&alt_index| {
            let alt_fn = fn_ident(cascade.alternatives[alt_index].rule_name);
            quote! { #alt_fn::<S, EMIT>(state, pos) }
        });
        let first_call = calls.next().unwrap_or_default();
        let call_chain = calls.fold(first_call, |chain, call| {
            quote! {
                #chain.or_else(|_| #call)
            }
        });
        quote! { #(#byte_pats)|* => #call_chain, }
    });

    quote! {
        #[inline(always)]
        fn #dispatcher_fn<S: ::mpl::fast::ParseState, const EMIT: bool>(
            state: &mut S,
            pos: u32,
        ) -> ::core::result::Result<u32, ()> {
            if EMIT {
                #cascade_body
            } else {
                let p = pos as usize;
                if p >= state.input().len() {
                    state.note_failure_at(pos);
                    return Err(());
                }
                match state.input()[p] {
                    #(#arms)*
                    _ => {
                        state.note_failure_at(pos);
                        Err(())
                    }
                }
            }
        }
    }
}
/// Emits the parse function for a rule recognised as a tail loop.
///
/// In the generated function:
///
/// * when `EMIT` is `true`, `cascade_body` is executed;
/// * when `EMIT` is `false`, the rule's left element is evaluated repeatedly,
///   each time from where the previous evaluation ended, until it fails. The
///   function then succeeds with the last position reached (which is `pos`
///   itself if the element never matched).
fn generate_tail_loop(
    rule: &mpl::rules::Rule<&str, &str>,
    cascade_body: TokenStream,
) -> TokenStream {
    let loop_fn = fn_ident(rule.value);
    // The repeated element reads its start position from the loop cursor `p`.
    let step = eval_e(&rule.equal.first.lhs, quote! { p });
    quote! {
        #[inline(always)]
        fn #loop_fn<S: ::mpl::fast::ParseState, const EMIT: bool>(
            state: &mut S,
            pos: u32,
        ) -> ::core::result::Result<u32, ()> {
            if EMIT {
                #cascade_body
            } else {
                let mut p = pos;
                while let Ok(np) = #step {
                    p = np;
                }
                Ok(p)
            }
        }
    }
}
/// Emits the parse function for a rule detected as left-recursive.
///
/// In the generated function, when `EMIT` is `false`:
///
/// 1. If the state already holds a memo entry for this rule at `pos`, its
///    stored result (`Done`) or current seed (`InProgress`) is returned
///    immediately; `None` maps to `Err(())`.
/// 2. Otherwise an `InProgress` entry with no seed is stored, and the rule
///    body is evaluated repeatedly. After each evaluation that ends further
///    than the best result so far, that result becomes the new seed; the loop
///    stops at the first evaluation that fails or does not end further.
/// 3. The best result is stored as `Done` and returned.
///
/// When `EMIT` is `true`, `cascade_body` is executed directly.
///
/// NOTE(review): the `EMIT` branch has no memo or seed handling. If the body
/// re-enters this same function at the same position, nothing visible here
/// stops the recursion — confirm how emitting parses of left-recursive rules
/// are expected to terminate.
fn generate_left_recursive(
rule: &mpl::rules::Rule<&str, &str>,
cascade_body: TokenStream,
) -> TokenStream {
let name = rule.value;
let f = fn_ident(name);
let kind_const = const_ident(name);
quote! {
#[inline(always)]
fn #f<S: ::mpl::fast::ParseState, const EMIT: bool>(
state: &mut S,
pos: u32,
) -> ::core::result::Result<u32, ()> {
if EMIT {
#cascade_body
} else {
let key_kind = var::#kind_const;
// A recursive re-entry at the same position is answered from the memo
// instead of descending again.
if let Some(entry) = state.lr_memo_get(key_kind, pos) {
return match entry {
::mpl::fast::LrMemoEntry::Done { result } |
::mpl::fast::LrMemoEntry::InProgress { seed: result } => {
result.map(Ok).unwrap_or(Err(()))
}
};
}
// Start with an empty seed: the first recursive re-entry fails.
state.lr_memo_set(
key_kind,
pos,
::mpl::fast::LrMemoEntry::InProgress { seed: None },
);
let mut best: Option<u32> = None;
loop {
// The body is wrapped in an immediately-invoked closure so that any
// `return` inside it yields this attempt's result rather than leaving
// the generated function.
let result: ::core::result::Result<u32, ()> = (|| -> ::core::result::Result<u32, ()> {
#cascade_body
})();
let new_end = result.ok();
// Keep iterating only while the match strictly grows.
let grew = match (new_end, best) {
(Some(_), None) => true,
(Some(n), Some(b)) => n > b,
_ => false,
};
if !grew { break; }
best = new_end;
state.lr_memo_set(
key_kind,
pos,
::mpl::fast::LrMemoEntry::InProgress { seed: best },
);
}
state.lr_memo_set(
key_kind,
pos,
::mpl::fast::LrMemoEntry::Done { result: best },
);
best.map(Ok).unwrap_or(Err(()))
}
}
}
}
/// Picks the code-generation strategy for `rule` and emits its parse function.
///
/// Strategies are tried in this order, and the first that applies wins:
///
/// 1. left-recursive rules → `generate_left_recursive`;
/// 2. rules with a byte class of at least four bytes →
///    `generate_cascade_dispatcher`;
/// 3. rules with a first-byte dispatch cascade →
///    `generate_first_byte_dispatcher`;
/// 4. tail loops → `generate_tail_loop`;
/// 5. everything else → `generate_rule_inner`.
///
/// Every specialised strategy is handed the generic rule body from
/// `generate_rule_body`.
fn generate_rule(
    rule: &mpl::rules::Rule<&str, &str>,
    char_classes: &std::collections::HashMap<&str, BTreeSet<u8>>,
    dispatches: &std::collections::HashMap<&str, DispatchCascade<'_>>,
) -> TokenStream {
    if is_left_recursive(rule) {
        generate_left_recursive(rule, generate_rule_body(rule))
    } else if let Some(bytes) = char_classes
        .get(rule.value)
        .filter(|bytes| bytes.len() >= 4)
    {
        generate_cascade_dispatcher(rule, bytes, generate_rule_body(rule))
    } else if let Some(cascade) = dispatches.get(rule.value) {
        generate_first_byte_dispatcher(rule, cascade, generate_rule_body(rule))
    } else if is_tail_loop(rule) {
        generate_tail_loop(rule, generate_rule_body(rule))
    } else {
        generate_rule_inner(rule)
    }
}
/// Generates the complete FastParse backend for the parser type
/// `parser_ident` from the grammar `lines`.
///
/// Only `MplgOutput::Rule` lines are used; if there are none, an empty token
/// stream is returned. Otherwise the output consists of:
///
/// * a private module `__<snake-cased parser name>_fast` containing
///   - `var`, with one `u32` constant per rule (its index in grammar order),
///   - `NEEDS_MEMO`, one `bool` per entry of the rule classification,
///     documented with a summary of that analysis,
///   - one parse function per rule,
///   - the entry points `recognize`, `recognize_only`, `parse`,
///     `recognize_in` and `check`, all of which start at the first rule and
///     require the whole input to be consumed;
/// * an `impl` block on `parser_ident` exposing those entry points as
///   `fast_*` methods, plus `fast_needs_memo`.
pub fn generate_fast(parser_ident: &Ident, lines: &[MplgOutput]) -> TokenStream {
    let mut rules: Vec<&mpl::rules::Rule<&str, &str>> = Vec::new();
    for line in lines {
        if let MplgOutput::Rule(rule) = line {
            rules.push(rule);
        }
    }
    if rules.is_empty() {
        return quote! {};
    }

    // Rule kinds are numbered by their position in the grammar.
    let var_consts = rules.iter().enumerate().map(|(index, rule)| {
        let kind = const_ident(rule.value);
        let index = index as u32;
        quote! { pub const #kind: u32 = #index; }
    });

    let firsts = compute_first_sets(&rules);
    let char_classes = analyze_char_classes(&rules);
    let dispatches = analyze_first_byte_dispatch(&rules, &firsts, &char_classes);
    let rule_fns = rules
        .iter()
        .map(|rule| generate_rule(rule, &char_classes, &dispatches));
    // Parsing always starts at the first rule of the grammar.
    let start_fn = fn_ident(rules[0].value);

    let classification = classify_rules(&rules, &firsts);
    let needs_memo_entries = classification.iter().map(|(_, needs)| {
        let flag = *needs;
        quote! { #flag }
    });
    let needs_memo_count = classification.len();
    let memo_rule_names: Vec<String> = classification
        .iter()
        .filter(|(_, b)| *b)
        .map(|(name, _)| name.clone())
        .collect();
    let backtracking_count = memo_rule_names.len();
    let memo_rule_list = if memo_rule_names.is_empty() {
        "(none)".to_string()
    } else {
        memo_rule_names.join(", ")
    };
    let needs_memo_doc = format!(
        "Mizushima cut analysis: {}/{} rules require memoisation\n\nRules needing memoisation: {}",
        backtracking_count, needs_memo_count, memo_rule_list
    );

    let mod_ident = format_ident!("__{}_fast", snake(&parser_ident.to_string()));
    quote! {
        #[allow(non_snake_case, unused, clippy::all)]
        mod #mod_ident {
            use ::mpl::fast::ParserState;
            pub mod var {
                #(#var_consts)*
            }
            #[doc = #needs_memo_doc]
            pub const NEEDS_MEMO: [bool; #needs_memo_count] = [#(#needs_memo_entries),*];
            #(#rule_fns)*
            pub fn recognize(input: &[u8]) -> bool {
                let mut state = ParserState::new(input);
                matches!(
                    #start_fn::<ParserState<'_>, true>(&mut state, 0),
                    Ok(end) if end as usize == input.len()
                )
            }
            pub fn recognize_only(input: &[u8]) -> bool {
                let mut state = ::mpl::fast::CheckState::new(input);
                matches!(
                    #start_fn::<::mpl::fast::CheckState<'_>, false>(&mut state, 0),
                    Ok(end) if end as usize == input.len()
                )
            }
            pub fn parse(
                input: &[u8],
            ) -> ::core::result::Result<Vec<::mpl::fast::Token>, ::mpl::fast::ParseError> {
                let mut state = ParserState::new(input);
                use ::mpl::fast::ParseState as _;
                match #start_fn::<ParserState<'_>, true>(&mut state, 0) {
                    Ok(end) if end as usize == input.len() => Ok(state.tokens),
                    Ok(end) => Err(::mpl::fast::ParseError::at(end)),
                    Err(_) => Err(::mpl::fast::ParseError::at(state.furthest())),
                }
            }
            pub fn recognize_in(input: &[u8], arena: &::mpl::fast::bumpalo::Bump) -> bool {
                let mut state = ParserState::new_in(input, arena);
                matches!(
                    #start_fn::<ParserState<'_>, true>(&mut state, 0),
                    Ok(end) if end as usize == input.len()
                )
            }
            pub fn check(
                input: &[u8],
            ) -> ::core::result::Result<(), ::mpl::fast::ParseError> {
                let mut state = ::mpl::fast::CheckState::new(input);
                use ::mpl::fast::ParseState as _;
                match #start_fn::<::mpl::fast::CheckState<'_>, false>(&mut state, 0) {
                    Ok(end) if end as usize == input.len() => Ok(()),
                    Ok(end) => Err(::mpl::fast::ParseError::at(end)),
                    Err(_) => Err(::mpl::fast::ParseError::at(state.furthest())),
                }
            }
        }
        impl #parser_ident {
            pub fn fast_recognize(&self, input: &[u8]) -> bool {
                #mod_ident::recognize(input)
            }
            pub fn fast_recognize_only(&self, input: &[u8]) -> bool {
                #mod_ident::recognize_only(input)
            }
            pub fn fast_parse(&self, input: &[u8])
                -> ::core::result::Result<Vec<::mpl::fast::Token>, ::mpl::fast::ParseError>
            {
                #mod_ident::parse(input)
            }
            pub fn fast_check(&self, input: &[u8])
                -> ::core::result::Result<(), ::mpl::fast::ParseError>
            {
                #mod_ident::check(input)
            }
            pub const fn fast_needs_memo() -> &'static [bool] {
                &#mod_ident::NEEDS_MEMO
            }
            pub fn fast_recognize_in(
                &self,
                input: &[u8],
                arena: &::mpl::fast::bumpalo::Bump,
            ) -> bool {
                #mod_ident::recognize_in(input, arena)
            }
        }
    }
}