use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use super::CodegenContext;
use super::reduction::{self, ReductionInfo, SymbolKind, typed_symbol_indices};
use super::table::CodegenTableInfo;
use crate::lr::AltAction;
/// Emits the full token stream of the typed parser wrapper for a grammar.
///
/// The output consists of, in order: the per-non-terminal AST enums, the
/// `Types` trait plus `AstNode` impls, the hidden `__Value` union, the
/// `Parser<A>` struct, its inherent impls (`new`, `state`, `format_error`,
/// `error_info`, `recover`, `drain_values`, `push`, `finish`, `do_reduce`),
/// and `Default` / `Drop` impls.
///
/// # Errors
///
/// Propagates the `String` error returned by `reduction::analyze_reductions`.
pub fn generate(ctx: &CodegenContext, info: &CodegenTableInfo) -> Result<TokenStream, String> {
    // Fixed identifiers used throughout the generated code.
    let vis: TokenStream = "pub".parse().unwrap();
    let terminal_enum = format_ident!("Terminal");
    let types_trait = format_ident!("Types");
    let parser_struct = format_ident!("Parser");
    let value_union = format_ident!("__Value");
    let table_mod = format_ident!("__table");
    let gazelle_crate_path = ctx.gazelle_crate_path_tokens();

    let reductions = reduction::analyze_reductions(ctx)?;

    // Every non-terminal that carries a type, synthetic (`__`-prefixed) ones included.
    let all_typed_non_terminals: Vec<(String, String)> = ctx
        .grammar
        .symbols
        .non_terminal_ids()
        .filter_map(|id| {
            let ty = ctx.grammar.types.get(&id)?.as_ref()?;
            Some((ctx.grammar.symbols.name(id).to_string(), ty.clone()))
        })
        .collect();
    // The same list restricted to user-visible (non-synthetic) non-terminals.
    let typed_non_terminals: Vec<(String, String)> = all_typed_non_terminals
        .iter()
        .filter(|(name, _)| !name.starts_with("__"))
        .cloned()
        .collect();

    let start_nt = &ctx.start_symbol;
    let start_type_annotation = typed_non_terminals
        .iter()
        .find(|(name, _)| name == start_nt)
        .map(|(_, ty)| ty.clone());
    let start_field = format_ident!("__{}", start_nt.to_lowercase());

    let enum_code =
        generate_nonterminal_enums(ctx, &reductions, &typed_non_terminals, &types_trait, &vis);
    let (traits_code, reducer_bounds) = generate_traits(
        ctx,
        &types_trait,
        &typed_non_terminals,
        &reductions,
        &vis,
        &gazelle_crate_path,
    );
    let value_union_code =
        generate_value_union(ctx, &all_typed_non_terminals, &value_union, &types_trait);
    let shift_arms = generate_terminal_shift_arms(ctx, &terminal_enum, &value_union);
    let reduction_arms =
        generate_reduction_arms(ctx, &reductions, &value_union, &typed_non_terminals);
    let drop_arms = generate_drop_arms(ctx, info);

    // `finish` differs only in its success type and in what it does with the
    // final stack entry once rule 0 is reported.
    let (finish_ok_type, finish_accept) = match start_type_annotation {
        Some(start_type) => {
            let start_type_ident = format_ident!("{}", start_type);
            (
                quote! { A::#start_type_ident },
                quote! {
                    let union_val = self.value_stack.pop().unwrap();
                    return Ok(unsafe { std::mem::ManuallyDrop::into_inner(union_val.#start_field) });
                },
            )
        }
        None => (
            quote! { () },
            quote! {
                self.value_stack.pop();
                return Ok(());
            },
        ),
    };
    let finish_method = quote! {
        pub fn finish(mut self, actions: &mut A) -> Result<#finish_ok_type, (Self, A::Error)> {
            loop {
                match self.parser.maybe_reduce(None) {
                    Ok(Some((0, _, _))) => {
                        #finish_accept
                    }
                    Ok(Some((rule, _, start_idx))) => {
                        if let Err(e) = self.do_reduce(rule, start_idx, actions) {
                            return Err((self, e));
                        }
                    }
                    Ok(None) => unreachable!(),
                    Err(e) => {
                        self.drain_values();
                        self.parser.restore_checkpoint();
                        return Err((self, e.into()));
                    }
                }
            }
        }
    };

    // The parser struct itself.
    let parser_def = quote! {
        #vis struct #parser_struct<A: #types_trait> {
            parser: #gazelle_crate_path::Parser<'static>,
            value_stack: Vec<#value_union<A>>,
        }
    };

    // Methods that only need `A: Types`.
    let core_impl = quote! {
        impl<A: #types_trait> #parser_struct<A> {
            pub fn new() -> Self {
                Self {
                    parser: #gazelle_crate_path::Parser::new(#table_mod::TABLE),
                    value_stack: Vec::new(),
                }
            }
            pub fn state(&self) -> usize {
                self.parser.state()
            }
            pub fn format_error(
                &self,
                err: &#gazelle_crate_path::ParseError,
                display_names: Option<&std::collections::HashMap<&str, &str>>,
                tokens: Option<&[&str]>,
            ) -> String {
                self.parser.format_error(err, &#table_mod::ERROR_INFO, display_names, tokens)
            }
            pub fn error_info() -> &'static #gazelle_crate_path::ErrorInfo<'static> {
                &#table_mod::ERROR_INFO
            }
            pub fn recover(&mut self, buffer: &[#gazelle_crate_path::Token]) -> Vec<#gazelle_crate_path::RecoveryInfo> {
                self.parser.recover(buffer)
            }
            fn drain_values(&mut self) {
                for i in (0..self.value_stack.len()).rev() {
                    let union_val = self.value_stack.pop().unwrap();
                    let sym_id = #table_mod::STATE_SYMBOL[self.parser.state_at(i)];
                    unsafe {
                        match sym_id {
                            #(#drop_arms)*
                            _ => {}
                        }
                    }
                }
            }
        }
    };

    // Methods that additionally need the per-enum `Action` bounds.
    let driver_impl = quote! {
        #[allow(clippy::result_large_err)]
        impl<A: #types_trait #(#reducer_bounds)*> #parser_struct<A> {
            pub fn push(&mut self, terminal: #terminal_enum<A>, actions: &mut A) -> Result<(), A::Error> {
                let token = #gazelle_crate_path::Token {
                    terminal: terminal.symbol_id(),
                    prec: terminal.precedence(),
                };
                loop {
                    match self.parser.maybe_reduce(Some(token)) {
                        Ok(Some((rule, _, start_idx))) => {
                            self.do_reduce(rule, start_idx, actions)?;
                        }
                        Ok(None) => break,
                        Err(e) => {
                            self.drain_values();
                            self.parser.restore_checkpoint();
                            return Err(e.into());
                        }
                    }
                }
                self.parser.shift(token);
                match terminal {
                    #(#shift_arms)*
                }
                Ok(())
            }
            #finish_method
            fn do_reduce(&mut self, rule: usize, start_idx: usize, actions: &mut A) -> Result<(), A::Error> {
                if rule == 0 { return Ok(()); }
                actions.set_token_range(start_idx, self.parser.token_count());
                let original_rule_idx = rule - 1;
                let value = match original_rule_idx {
                    #(#reduction_arms)*
                    _ => return Ok(()),
                };
                self.value_stack.push(value);
                Ok(())
            }
        }
    };

    // `Default` delegates to `new`; `Drop` releases whatever is still on the value stack.
    let std_impls = quote! {
        impl<A: #types_trait> Default for #parser_struct<A> {
            fn default() -> Self { Self::new() }
        }
        impl<A: #types_trait> Drop for #parser_struct<A> {
            fn drop(&mut self) {
                self.drain_values();
            }
        }
    };

    Ok(quote! {
        #enum_code
        #traits_code
        #value_union_code
        #parser_def
        #core_impl
        #driver_impl
        #std_impls
    })
}
/// Emits one AST enum (plus a hand-written `Debug` impl) per non-terminal that
/// has at least one named variant.
///
/// Each variant carries one tuple field per typed right-hand-side symbol. If no
/// field of any variant mentions the type parameter `A`, a hidden `_Phantom`
/// variant is appended so that `A` is still used by the enum.
fn generate_nonterminal_enums(
    ctx: &CodegenContext,
    reductions: &[ReductionInfo],
    typed_non_terminals: &[(String, String)],
    types_trait: &syn::Ident,
    vis: &TokenStream,
) -> TokenStream {
    use std::collections::{BTreeMap, HashMap};

    // Terminal name -> associated type name (the first terminal id is skipped).
    let mut terminal_assoc_types: BTreeMap<&str, &str> = BTreeMap::new();
    for id in ctx.grammar.symbols.terminal_ids().skip(1) {
        if let Some(type_name) = ctx.grammar.types.get(&id).and_then(|t| t.as_ref()) {
            terminal_assoc_types.insert(ctx.grammar.symbols.name(id), type_name.as_str());
        }
    }

    // Non-terminal name -> result type name.
    let nt_result_types: HashMap<&str, &str> = typed_non_terminals
        .iter()
        .map(|(name, result_type)| (name.as_str(), result_type.as_str()))
        .collect();

    // Named alternatives grouped by their non-terminal, in sorted name order.
    let mut nt_variants: BTreeMap<&str, Vec<&ReductionInfo>> = BTreeMap::new();
    for info in reductions.iter().filter(|r| r.variant_name.is_some()) {
        nt_variants
            .entry(info.non_terminal.as_str())
            .or_default()
            .push(info);
    }

    let mut enums: Vec<TokenStream> = Vec::with_capacity(nt_variants.len());
    for (nt_name, variants) in &nt_variants {
        let enum_ident = enum_name(nt_name);

        let mut variant_defs = Vec::with_capacity(variants.len());
        let mut debug_arms = Vec::with_capacity(variants.len());
        let mut uses_a = false;

        for info in variants {
            let variant_name = format_ident!(
                "{}",
                crate::lr::to_camel_case(info.variant_name.as_ref().unwrap())
            );
            let variant_str = variant_name.to_string();
            let typed = typed_symbol_indices(&info.rhs_symbols);

            let fields: Vec<TokenStream> = typed
                .iter()
                .map(|&idx| {
                    symbol_to_field_type(
                        &info.rhs_symbols[idx],
                        &nt_result_types,
                        &terminal_assoc_types,
                        ctx,
                    )
                })
                .collect();
            uses_a = uses_a
                || typed.iter().any(|&idx| {
                    symbol_references_a(
                        &info.rhs_symbols[idx],
                        &nt_result_types,
                        &terminal_assoc_types,
                        ctx,
                    )
                });

            if fields.is_empty() {
                variant_defs.push(quote! { #variant_name });
                debug_arms.push(quote! { Self::#variant_name => f.write_str(#variant_str) });
            } else {
                // Bind each tuple field as f0, f1, ... for the Debug output.
                let bindings: Vec<_> = (0..fields.len())
                    .map(|i| format_ident!("f{}", i))
                    .collect();
                variant_defs.push(quote! { #variant_name(#(#fields),*) });
                debug_arms.push(quote! {
                    Self::#variant_name(#(#bindings),*) => f.debug_tuple(#variant_str)#(.field(#bindings))*.finish()
                });
            }
        }

        // Without any use of `A` the enum would not compile, so add an
        // uninhabited variant holding PhantomData<A>.
        let (phantom_variant, phantom_arm) = if uses_a {
            (quote! {}, quote! {})
        } else {
            (
                quote! {
                    , #[doc(hidden)] _Phantom(std::convert::Infallible, std::marker::PhantomData<A>)
                },
                quote! { _ => unreachable!(), },
            )
        };

        enums.push(quote! {
            #vis enum #enum_ident<A: #types_trait> {
                #(#variant_defs),*
                #phantom_variant
            }
            impl<A: #types_trait> std::fmt::Debug for #enum_ident<A> {
                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                    match self { #(#debug_arms,)* #phantom_arm }
                }
            }
        });
    }

    quote! { #(#enums)* }
}
/// Returns the Rust type tokens used for one right-hand-side symbol inside an
/// AST enum variant.
///
/// * Terminals with an associated type and non-terminals with a result type
///   become `A::<Type>`.
/// * Synthetic (`__`-prefixed) non-terminals use the type reported by
///   `ctx.get_type`, rendered with the `A::` prefix.
/// * Everything else becomes `()`.
fn symbol_to_field_type(
    sym: &reduction::SymbolInfo,
    nt_result_types: &std::collections::HashMap<&str, &str>,
    terminal_assoc_types: &std::collections::BTreeMap<&str, &str>,
    ctx: &CodegenContext,
) -> TokenStream {
    let unit = || quote! { () };
    let assoc_of_a = |assoc_name: &str| {
        let assoc = format_ident!("{}", assoc_name);
        quote! { A::#assoc }
    };
    let name = sym.name.as_str();

    // Terminals: either a declared payload type or nothing.
    if sym.kind != SymbolKind::NonTerminal {
        return match terminal_assoc_types.get(name) {
            Some(&assoc_name) => assoc_of_a(assoc_name),
            None => unit(),
        };
    }

    // User-declared non-terminal with a result type.
    if let Some(&result_type) = nt_result_types.get(name) {
        return assoc_of_a(result_type);
    }

    // Untyped, non-synthetic non-terminal.
    if !name.starts_with("__") {
        return unit();
    }

    // Synthetic non-terminal: defer to its recorded type, if any.
    match ctx.get_type(&sym.name) {
        Some(result_type) => synthetic_type_to_tokens_with_prefix(result_type, false),
        None => unit(),
    }
}
/// Reports whether the field type produced by `symbol_to_field_type` for `sym`
/// mentions the type parameter `A`.
///
/// This must stay in step with `symbol_to_field_type`:
/// * typed terminals and typed non-terminals render as `A::<Type>` -> `true`;
/// * a synthetic (`__`-prefixed) non-terminal renders its recorded type with
///   an `A::` prefix, so it references `A` unless that type is exactly
///   `Option<()>` or `Vec<()>`;
/// * everything else renders as `()` -> `false`.
///
/// A synthetic non-terminal whose type is a bare associated type (neither
/// `Option<..>` nor `Vec<..>`) was previously reported as not using `A`, even
/// though its field is emitted as `A::<Type>`; that made the caller add an
/// unnecessary phantom variant.
fn symbol_references_a(
    sym: &reduction::SymbolInfo,
    nt_result_types: &std::collections::HashMap<&str, &str>,
    terminal_assoc_types: &std::collections::BTreeMap<&str, &str>,
    ctx: &CodegenContext,
) -> bool {
    let name = sym.name.as_str();

    if sym.kind != SymbolKind::NonTerminal {
        return terminal_assoc_types.contains_key(name);
    }
    if nt_result_types.contains_key(name) {
        return true;
    }
    if !name.starts_with("__") {
        return false;
    }
    let Some(result_type) = ctx.get_type(&sym.name) else {
        return false;
    };

    // Peel a single `Option<..>` or `Vec<..>` wrapper, if present.
    let inner = result_type
        .strip_prefix("Option<")
        .and_then(|s| s.strip_suffix('>'))
        .or_else(|| {
            result_type
                .strip_prefix("Vec<")
                .and_then(|s| s.strip_suffix('>'))
        });

    // Only a wrapper around `()` avoids `A`; a wrapped or bare type name is
    // always emitted as `A::<name>`.
    inner != Some("()")
}
/// Emits the `Types` trait and the `AstNode` impls for the generated enums.
///
/// Returns the trait/impl tokens together with one `+ Action<Enum<A>>` bound
/// fragment per enum; the caller splices those bounds into an `impl` header.
///
/// The trait declares `type Error: From<ParseError>`, one `Debug`-bounded
/// associated type per distinct terminal payload type and non-terminal result
/// type (terminals first, duplicates skipped), and a default no-op
/// `set_token_range`.
fn generate_traits(
    ctx: &CodegenContext,
    types_trait: &syn::Ident,
    typed_non_terminals: &[(String, String)],
    reductions: &[ReductionInfo],
    vis: &TokenStream,
    gazelle_crate_path: &TokenStream,
) -> (TokenStream, Vec<TokenStream>) {
    // Candidate associated-type names: terminal payload types, then
    // non-terminal result types.
    let terminal_type_names = ctx
        .grammar
        .symbols
        .terminal_ids()
        .skip(1)
        .filter_map(|id| ctx.grammar.types.get(&id).and_then(|t| t.as_ref()))
        .map(|type_name| type_name.as_str());
    let result_type_names = typed_non_terminals
        .iter()
        .map(|(_, result_type)| result_type.as_str());

    let mut seen_types = std::collections::HashSet::new();
    let assoc_types: Vec<TokenStream> = terminal_type_names
        .chain(result_type_names)
        // `insert` returns false for a name that was already declared.
        .filter(|type_name| seen_types.insert(*type_name))
        .map(|type_name| {
            let type_ident = format_ident!("{}", type_name);
            quote! { type #type_ident: std::fmt::Debug; }
        })
        .collect();

    let mut reducer_bounds = Vec::new();
    let mut ast_node_impls = Vec::new();
    let mut seen_nt = std::collections::HashSet::new();
    for info in reductions.iter().filter(|r| r.variant_name.is_some()) {
        // One impl and one bound per non-terminal, at its first named variant.
        if !seen_nt.insert(&info.non_terminal) {
            continue;
        }
        let enum_ident = enum_name(&info.non_terminal);

        // Typed non-terminals reduce to their associated type, others to `()`.
        let declared = typed_non_terminals
            .iter()
            .find(|(n, _)| n == &info.non_terminal);
        let output_type = match declared {
            Some((_, result_type)) => {
                let result_ident = format_ident!("{}", result_type);
                quote! { A::#result_ident }
            }
            None => quote! { () },
        };

        ast_node_impls.push(quote! {
            impl<A: #types_trait> #gazelle_crate_path::AstNode for #enum_ident<A> {
                type Output = #output_type;
                type Error = A::Error;
            }
        });
        reducer_bounds.push(quote! { + #gazelle_crate_path::Action<#enum_ident<A>> });
    }

    let traits_code = quote! {
        #vis trait #types_trait: Sized {
            type Error: From<#gazelle_crate_path::ParseError>;
            #(#assoc_types)*
            #[allow(unused_variables)]
            fn set_token_range(&mut self, start: usize, end: usize) {}
        }
        #(#ast_node_impls)*
    };
    (traits_code, reducer_bounds)
}
/// Emits the hidden value union stored on the parser's value stack.
///
/// The union gets one `ManuallyDrop` field per typed terminal and per entry of
/// `typed_non_terminals`, each named `__<lowercased symbol name>`, followed by
/// a `__unit` field and a `__phantom` field that mentions `A`.
fn generate_value_union(
    ctx: &CodegenContext,
    typed_non_terminals: &[(String, String)],
    value_union: &syn::Ident,
    types_trait: &syn::Ident,
) -> TokenStream {
    // Typed terminals (the first terminal id is skipped).
    let terminal_fields = ctx
        .grammar
        .symbols
        .terminal_ids()
        .skip(1)
        .filter_map(|id| {
            let type_name = ctx.grammar.types.get(&id)?.as_ref()?;
            let field_name =
                format_ident!("__{}", ctx.grammar.symbols.name(id).to_lowercase());
            let assoc_type = format_ident!("{}", type_name);
            Some(quote! { #field_name: std::mem::ManuallyDrop<A::#assoc_type>, })
        });

    // Typed non-terminals; synthetic ones may be `Option<..>` / `Vec<..>`.
    let non_terminal_fields = typed_non_terminals.iter().map(|(name, result_type)| {
        let field_name = format_ident!("__{}", name.to_lowercase());
        let field_type = if name.starts_with("__") {
            synthetic_type_to_tokens_with_prefix(result_type, false)
        } else {
            let assoc_type = format_ident!("{}", result_type);
            quote! { A::#assoc_type }
        };
        quote! { #field_name: std::mem::ManuallyDrop<#field_type>, }
    });

    let fields: Vec<TokenStream> = terminal_fields.chain(non_terminal_fields).collect();

    quote! {
        #[doc(hidden)]
        union #value_union<A: #types_trait> {
            #(#fields)*
            __unit: (),
            __phantom: std::mem::ManuallyDrop<std::marker::PhantomData<A>>,
        }
    }
}
/// Renders the textual type of a synthetic non-terminal as tokens.
///
/// Recognised shapes are `Option<X>`, `Vec<X>` and a bare `X`. `X` is emitted
/// as `Self::X` when `use_self` is true and as `A::X` otherwise, except that a
/// wrapped `()` stays `()`.
fn synthetic_type_to_tokens_with_prefix(type_str: &str, use_self: bool) -> TokenStream {
    /// Returns the text between `wrapper` (e.g. `"Vec<"`) and a trailing `>`.
    fn unwrap_generic<'a>(ty: &'a str, wrapper: &str) -> Option<&'a str> {
        ty.strip_prefix(wrapper)?.strip_suffix('>')
    }

    let owner = if use_self {
        quote! { Self }
    } else {
        quote! { A }
    };
    // Associated-type path on the chosen owner; a wrapped `()` stays `()`.
    let element = |name: &str| {
        if name == "()" {
            quote! { () }
        } else {
            let ident = format_ident!("{}", name);
            quote! { #owner::#ident }
        }
    };

    if let Some(inner) = unwrap_generic(type_str, "Option<") {
        let elem = element(inner);
        quote! { Option<#elem> }
    } else if let Some(inner) = unwrap_generic(type_str, "Vec<") {
        let elem = element(inner);
        quote! { Vec<#elem> }
    } else {
        // A bare type is always treated as an associated-type name.
        let ident = format_ident!("{}", type_str);
        quote! { #owner::#ident }
    }
}
/// Emits the `match terminal { .. }` arms that push a shifted terminal's value
/// onto the value stack.
///
/// The variant pattern depends on whether the terminal carries a payload
/// and/or a precedence; terminals without a payload push the `__unit` field.
/// A final arm marks the `__Phantom` variant as unreachable.
fn generate_terminal_shift_arms(
    ctx: &CodegenContext,
    terminal_enum: &syn::Ident,
    value_union: &syn::Ident,
) -> Vec<TokenStream> {
    let mut arms: Vec<TokenStream> = ctx
        .grammar
        .symbols
        .terminal_ids()
        .skip(1)
        .map(|id| {
            let name = ctx.grammar.symbols.name(id);
            let variant_name = format_ident!("{}", crate::lr::to_camel_case(name));
            let has_payload = ctx
                .grammar
                .types
                .get(&id)
                .and_then(|t| t.as_ref())
                .is_some();
            let is_prec = ctx.grammar.symbols.is_prec_terminal(id);

            // Shape of the enum variant being matched.
            let pattern = match (is_prec, has_payload) {
                (false, true) => quote! { #terminal_enum::#variant_name(v) },
                (false, false) => quote! { #terminal_enum::#variant_name },
                (true, true) => quote! { #terminal_enum::#variant_name(v, _prec) },
                (true, false) => quote! { #terminal_enum::#variant_name(_prec) },
            };

            // Union value that ends up on the stack.
            let pushed = if has_payload {
                let field_name = format_ident!("__{}", name.to_lowercase());
                quote! { #value_union { #field_name: std::mem::ManuallyDrop::new(v) } }
            } else {
                quote! { #value_union { __unit: () } }
            };

            quote! {
                #pattern => {
                    self.value_stack.push(#pushed);
                }
            }
        })
        .collect();

    arms.push(quote! {
        #terminal_enum::__Phantom(_) => unreachable!(),
    });
    arms
}
fn generate_reduction_arms(
ctx: &CodegenContext,
reductions: &[ReductionInfo],
value_union: &syn::Ident,
_typed_non_terminals: &[(String, String)],
) -> Vec<TokenStream> {
let gazelle_crate_path = ctx.gazelle_crate_path_tokens();
let mut arms = Vec::new();
for (idx, info) in reductions.iter().enumerate() {
let lhs_field = format_ident!("__{}", info.non_terminal.to_lowercase());
let idx_lit = idx;
let mut stmts = Vec::new();
for (i, sym) in info.rhs_symbols.iter().enumerate().rev() {
let pop_expr = quote! { self.value_stack.pop().unwrap() };
if sym.ty.is_some() {
let field_name = match sym.kind {
SymbolKind::UnitTerminal => {
stmts.push(quote! { let _ = #pop_expr; });
continue;
}
SymbolKind::PayloadTerminal | SymbolKind::PrecTerminal => {
format_ident!("__{}", sym.name.to_lowercase())
}
SymbolKind::NonTerminal => {
format_ident!("__{}", sym.name.to_lowercase())
}
};
let var_name = format_ident!("v{}", i);
let extract = quote! { std::mem::ManuallyDrop::into_inner(#pop_expr.#field_name) };
stmts.push(quote! { let #var_name = unsafe { #extract }; });
} else {
stmts.push(quote! { let _ = #pop_expr; });
}
}
let has_result_type = ctx
.grammar
.symbols
.non_terminal_ids()
.find(|&id| ctx.grammar.symbols.name(id) == info.non_terminal)
.and_then(|id| ctx.grammar.types.get(&id)?.as_ref())
.is_some();
let result = if let Some(variant_name) = &info.variant_name {
let enum_name = enum_name(&info.non_terminal);
let variant_ident = format_ident!("{}", crate::lr::to_camel_case(variant_name));
let args: Vec<_> = typed_symbol_indices(&info.rhs_symbols)
.iter()
.map(|sym_idx| format_ident!("v{}", sym_idx))
.collect();
let node_expr = if args.is_empty() {
quote! { #enum_name::#variant_ident }
} else {
quote! { #enum_name::#variant_ident(#(#args),*) }
};
if has_result_type {
quote! { #value_union { #lhs_field: std::mem::ManuallyDrop::new(
#gazelle_crate_path::Action::build(actions, #node_expr)?
) } }
} else {
quote! { {
#gazelle_crate_path::Action::build(actions, #node_expr)?;
#value_union { __unit: () }
} }
}
} else {
match &info.action {
AltAction::Named(_) => {
quote! { #value_union { __unit: () } }
}
AltAction::OptSome => {
let is_unit = info
.rhs_symbols
.first()
.map(|s| s.ty.is_none())
.unwrap_or(true);
if is_unit {
quote! { #value_union { #lhs_field: std::mem::ManuallyDrop::new(Some(())) } }
} else {
quote! { #value_union { #lhs_field: std::mem::ManuallyDrop::new(Some(v0)) } }
}
}
AltAction::OptNone => {
quote! { #value_union { #lhs_field: std::mem::ManuallyDrop::new(None) } }
}
AltAction::VecEmpty => {
quote! { #value_union { #lhs_field: std::mem::ManuallyDrop::new(Vec::new()) } }
}
AltAction::VecSingle => {
let is_unit = info
.rhs_symbols
.first()
.map(|s| s.ty.is_none())
.unwrap_or(true);
if is_unit {
quote! { #value_union { #lhs_field: std::mem::ManuallyDrop::new(vec![()]) } }
} else {
quote! { #value_union { #lhs_field: std::mem::ManuallyDrop::new(vec![v0]) } }
}
}
AltAction::VecAppend => {
let last_idx = info.rhs_symbols.len() - 1;
let is_unit = info
.rhs_symbols
.get(last_idx)
.map(|s| s.ty.is_none())
.unwrap_or(true);
if is_unit {
quote! { { let mut v0 = v0; v0.push(()); #value_union { #lhs_field: std::mem::ManuallyDrop::new(v0) } } }
} else {
let elem_var = format_ident!("v{}", last_idx);
quote! { { let mut v0 = v0; v0.push(#elem_var); #value_union { #lhs_field: std::mem::ManuallyDrop::new(v0) } } }
}
}
}
};
arms.push(quote! {
#idx_lit => {
#(#stmts)*
#result
}
});
}
arms
}
fn generate_drop_arms(ctx: &CodegenContext, info: &CodegenTableInfo) -> Vec<TokenStream> {
let mut arms = Vec::new();
for id in ctx.grammar.symbols.terminal_ids().skip(1) {
if ctx
.grammar
.types
.get(&id)
.and_then(|t| t.as_ref())
.is_some()
{
let name = ctx.grammar.symbols.name(id);
if let Some((_, table_id)) = info.terminal_ids.iter().find(|(n, _)| n == name) {
let field_name = format_ident!("__{}", name.to_lowercase());
arms.push(quote! {
#table_id => { std::mem::ManuallyDrop::into_inner(union_val.#field_name); }
});
}
}
}
for id in ctx.grammar.symbols.non_terminal_ids() {
if ctx
.grammar
.types
.get(&id)
.and_then(|t| t.as_ref())
.is_some()
{
let name = ctx.grammar.symbols.name(id);
let field_name = format_ident!("__{}", name.to_lowercase());
if let Some((_, table_id)) = info.non_terminal_ids.iter().find(|(n, _)| n == name) {
arms.push(quote! {
#table_id => { std::mem::ManuallyDrop::into_inner(union_val.#field_name); }
});
}
}
}
arms
}
fn enum_name(nt_name: &str) -> syn::Ident {
format_ident!("{}", crate::lr::to_camel_case(nt_name))
}