use crate::grammar::SymbolId;
use std::collections::{HashMap, HashSet};
/// A decoded parse-table action for one `(state, terminal)` cell.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum Action {
    /// Shift the lookahead and move to the given state.
    Shift(usize),
    /// Reduce by the given rule (the parser treats rule 0 as accept).
    Reduce(usize),
    /// Conflict cell: either shift to `shift_state` or reduce by `reduce_rule`,
    /// decided at parse time.
    ShiftOrReduce { shift_state: usize, reduce_rule: usize },
    /// No valid action.
    Error,
}

/// Packed 32-bit form of an [`Action`] as stored in the action table.
///
/// - `0` encodes [`Action::Error`].
/// - A value that is positive when read as `i32` encodes `Shift(value)`.
/// - A negative value is the bitwise complement of a payload whose low 12 bits hold the
///   reduce rule and whose next 19 bits hold a shift state. A zero shift state means a
///   plain `Reduce`; anything else is a `ShiftOrReduce`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(transparent)]
pub(crate) struct ActionEntry(pub(crate) u32);
impl ActionEntry {
    /// Width of the reduce-rule field inside a complemented payload.
    const RULE_BITS: u32 = 12;
    /// Mask selecting the reduce-rule field.
    const RULE_MASK: u32 = (1 << Self::RULE_BITS) - 1;
    /// Mask selecting the (already shifted down) 19-bit shift-state field.
    const STATE_MASK: u32 = 0x7FFFF;

    /// The all-zero entry, meaning "no action".
    pub const ERROR: ActionEntry = ActionEntry(0);

    /// Encodes `Shift(state)`. `state` must be non-zero (zero is the error entry) and
    /// below `2^31`, which is checked in debug builds only.
    pub fn shift(state: usize) -> Self {
        debug_assert!(state > 0, "Shift(0) is reserved for Error");
        debug_assert!(state < 0x80000000, "Shift state too large");
        Self(state as u32)
    }

    /// Encodes `Reduce(rule)`. `rule` must fit in 12 bits (debug-checked).
    pub fn reduce(rule: usize) -> Self {
        debug_assert!(rule < 0x1000, "Reduce rule too large (max 4095)");
        Self::pack(0, rule as u32)
    }

    /// Encodes a shift/reduce conflict cell. `shift_state` must be non-zero and fit in
    /// 19 bits, and `reduce_rule` must fit in 12 bits (debug-checked).
    pub fn shift_or_reduce(shift_state: usize, reduce_rule: usize) -> Self {
        debug_assert!(shift_state > 0, "Shift(0) is reserved for Error");
        debug_assert!(shift_state < 0x80000, "Shift state too large (max 19 bits)");
        debug_assert!(reduce_rule < 0x1000, "Reduce rule too large (max 4095)");
        Self::pack(shift_state as u32, reduce_rule as u32)
    }

    /// Combines the two fields and complements them, which sets the sign bit for
    /// in-range inputs. Fields are not masked here, so out-of-range values spill
    /// exactly as in an unmasked `!(rule | state << 12)`.
    fn pack(shift_state: u32, rule: u32) -> Self {
        Self(!(rule | (shift_state << Self::RULE_BITS)))
    }

    /// Decodes the packed value back into an [`Action`].
    pub fn decode(&self) -> Action {
        match (self.0 as i32).signum() {
            1 => Action::Shift(self.0 as usize),
            0 => Action::Error,
            _ => {
                let payload = !self.0;
                let rule = (payload & Self::RULE_MASK) as usize;
                match ((payload >> Self::RULE_BITS) & Self::STATE_MASK) as usize {
                    0 => Action::Reduce(rule),
                    state => Action::ShiftOrReduce { shift_state: state, reduce_rule: rule },
                }
            }
        }
    }
}
/// Renders an internal (`__`-prefixed) symbol name in compact, user-facing notation.
///
/// `__X_star`, `__X_plus` and `__X_opt` become `X*`, `X+` and `X?`. `__X_sep_Y` (split
/// at the first `_sep_`) becomes `X % Y`. Every other name is returned unchanged.
fn format_sym(s: &str) -> String {
    const QUANTIFIERS: [(&str, char); 3] = [("_star", '*'), ("_plus", '+'), ("_opt", '?')];
    let inner = match s.strip_prefix("__") {
        Some(inner) => inner,
        None => return s.to_string(),
    };
    // Suffix forms take priority over the separator form, in this fixed order.
    for &(suffix, marker) in QUANTIFIERS.iter() {
        if let Some(base) = inner.strip_suffix(suffix) {
            return format!("{}{}", base, marker);
        }
    }
    match inner.split_once("_sep_") {
        Some((base, sep)) => format!("{} % {}", base, sep),
        None => s.to_string(),
    }
}
/// Read-only view over precomputed LR action/goto tables stored in displacement-compressed
/// form.
///
/// For both tables, the cell for `(state, col)` lives at index `base[state] + col` of the
/// `*_data` array. It is valid only when the matching `*_check` array holds `state` at the
/// same index; otherwise the cell is empty.
#[derive(Debug, Clone, Copy)]
pub struct ParseTable<'a> {
/// Packed `ActionEntry` values.
action_data: &'a [u32],
/// Per-state displacement into `action_data` / `action_check`.
action_base: &'a [i32],
/// Owning state of each `action_data` cell.
action_check: &'a [u32],
/// Goto target states.
goto_data: &'a [u32],
/// Per-state displacement into `goto_data` / `goto_check`.
goto_base: &'a [i32],
/// Owning state of each `goto_data` cell.
goto_check: &'a [u32],
/// `(lhs symbol id, rhs length)` for each rule, indexed by rule number.
rules: &'a [(u32, u8)],
/// Number of terminal symbols; goto columns are `symbol id - num_terminals`.
num_terminals: u32,
}
impl<'a> ParseTable<'a> {
    /// Wraps precomputed table slices. No validation is performed.
    // One parameter per generated array/value; bundling them would only move the problem.
    #[allow(clippy::too_many_arguments)]
    pub const fn new(
        action_data: &'a [u32],
        action_base: &'a [i32],
        action_check: &'a [u32],
        goto_data: &'a [u32],
        goto_base: &'a [i32],
        goto_check: &'a [u32],
        rules: &'a [(u32, u8)],
        num_terminals: u32,
    ) -> Self {
        ParseTable {
            action_data,
            action_base,
            action_check,
            goto_data,
            goto_base,
            goto_check,
            rules,
            num_terminals,
        }
    }

    /// Returns the action for `terminal` in `state`.
    ///
    /// A cell outside the table, or one whose check entry belongs to another state, is
    /// reported as [`Action::Error`]. Negative offsets become huge indices after the cast
    /// and fail the bounds test.
    ///
    /// # Panics
    /// Panics if `state` is out of range for `action_base`.
    pub(crate) fn action(&self, state: usize, terminal: SymbolId) -> Action {
        let col = terminal.0 as i32;
        let displacement = self.action_base[state];
        let idx = displacement.wrapping_add(col) as usize;
        if idx < self.action_check.len() && self.action_check[idx] == state as u32 {
            ActionEntry(self.action_data[idx]).decode()
        } else {
            Action::Error
        }
    }

    /// Returns the goto target for `non_terminal` from `state`, or `None` when the table
    /// has no such entry.
    ///
    /// Passing a terminal id (below `num_terminals`) also yields `None`.
    ///
    /// # Panics
    /// Panics if `state` is out of range for `goto_base`.
    pub fn goto(&self, state: usize, non_terminal: SymbolId) -> Option<usize> {
        // Goto columns start at the first non-terminal, so a terminal has no column.
        // A plain `u32` subtraction would panic in debug builds. In release builds it would
        // wrap to a negative column that could hit an unrelated cell of the same state.
        let col = non_terminal.0.checked_sub(self.num_terminals)? as i32;
        let displacement = self.goto_base[state];
        let idx = displacement.wrapping_add(col) as usize;
        if idx < self.goto_check.len() && self.goto_check[idx] == state as u32 {
            Some(self.goto_data[idx] as usize)
        } else {
            None
        }
    }

    /// Returns `(lhs symbol, rhs length)` for `rule`.
    ///
    /// # Panics
    /// Panics if `rule` is out of range.
    pub fn rule_info(&self, rule: usize) -> (SymbolId, usize) {
        let (lhs, len) = self.rules[rule];
        (SymbolId(lhs), len as usize)
    }

    /// Returns all rules as raw `(lhs id, rhs length)` pairs, indexed by rule number.
    pub fn rules(&self) -> &[(u32, u8)] {
        self.rules
    }
}
/// Grammar metadata the parser consults when rendering readable error messages.
pub trait ErrorContext {
/// Returns the grammar name of symbol `id`.
fn symbol_name(&self, id: SymbolId) -> &str;
/// Returns the symbol associated with `state`, used to label parsed stack segments.
/// NOTE(review): presumably the symbol consumed to enter the state — confirm with implementors.
fn state_symbol(&self, state: usize) -> SymbolId;
/// Returns the LR items of `state` as `(rule, dot position)` pairs.
fn state_items(&self, state: usize) -> &[(u16, u8)];
/// Returns the right-hand-side symbol ids of `rule`.
fn rule_rhs(&self, rule: usize) -> &[u32];
}
/// Operator precedence carried by a token: a binding level plus associativity.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Precedence {
    /// Left-associative at the given level.
    Left(u8),
    /// Right-associative at the given level.
    Right(u8),
}
impl Precedence {
    /// Returns the numeric level. The parser treats a larger level as binding tighter.
    pub fn level(&self) -> u8 {
        match *self {
            Self::Left(n) => n,
            Self::Right(n) => n,
        }
    }

    /// Returns the associativity as a code: `0` for left, `1` for right.
    pub fn assoc(&self) -> u8 {
        u8::from(matches!(self, Self::Right(_)))
    }
}
/// Computes, for every symbol id, whether it can derive the empty string.
///
/// The returned vector is long enough to index every terminal and every symbol mentioned
/// by any rule. Terminals are never nullable.
fn compute_nullable(table: &ParseTable, ctx: &impl ErrorContext) -> Vec<bool> {
    let rules = table.rules();
    // Size the table to cover all terminals and every symbol any rule mentions, so the
    // indexing below stays in bounds.
    let mut width = table.num_terminals as usize;
    for (idx, &(lhs, _)) in rules.iter().enumerate() {
        let rhs_bound = ctx
            .rule_rhs(idx)
            .iter()
            .map(|&sym| sym as usize + 1)
            .max()
            .unwrap_or(0);
        width = width.max(lhs as usize + 1).max(rhs_bound);
    }
    let mut nullable = vec![false; width];
    // Fixed point: mark every left-hand side whose whole right-hand side is nullable (an
    // empty rule qualifies at once), until a full pass changes nothing.
    loop {
        let mut progressed = false;
        for (idx, &(lhs, _)) in rules.iter().enumerate() {
            let lhs = lhs as usize;
            if nullable[lhs] {
                continue;
            }
            if ctx.rule_rhs(idx).iter().all(|&sym| nullable[sym as usize]) {
                nullable[lhs] = true;
                progressed = true;
            }
        }
        if !progressed {
            break nullable;
        }
    }
}
/// Returns the symbols that can begin `sequence`, looking through nullable non-terminals.
///
/// Scanning goes left to right. Each nullable non-terminal contributes the leading symbols
/// of its own rules (via `expand_nullable`), and scanning continues past it. The first
/// terminal or non-nullable non-terminal is added as-is and ends the scan. Non-nullable
/// non-terminals are therefore reported by themselves rather than expanded.
fn expected_from_sequence(
    sequence: &[u32],
    table: &ParseTable,
    ctx: &impl ErrorContext,
    nullable: &[bool],
    num_terminals: usize,
) -> HashSet<usize> {
    let mut firsts = HashSet::new();
    let stops_scan = |id: usize| id < num_terminals || !nullable.get(id).copied().unwrap_or(false);
    for id in sequence.iter().map(|&sym| sym as usize) {
        if stops_scan(id) {
            firsts.insert(id);
            return firsts;
        }
        // Each nullable symbol gets its own visited set.
        let mut visited = HashSet::new();
        expand_nullable(id, table, ctx, nullable, num_terminals, &mut firsts, &mut visited);
    }
    firsts
}
/// Adds to `result` the symbols that can start any rule whose left-hand side is `sym`.
///
/// Within each such rule, leading nullable non-terminals are expanded recursively. The
/// first terminal or non-nullable symbol is recorded and ends that rule's scan. A symbol
/// already in `visited` is not expanded again, which also stops recursive rules.
fn expand_nullable(
    sym: usize,
    table: &ParseTable,
    ctx: &impl ErrorContext,
    nullable: &[bool],
    num_terminals: usize,
    result: &mut HashSet<usize>,
    visited: &mut HashSet<usize>,
) {
    if visited.contains(&sym) {
        return;
    }
    visited.insert(sym);
    let own_rules = table
        .rules()
        .iter()
        .enumerate()
        .filter(|&(_, &(lhs, _))| lhs as usize == sym);
    for (rule_idx, _) in own_rules {
        for id in ctx.rule_rhs(rule_idx).iter().map(|&s| s as usize) {
            let see_through = id >= num_terminals && nullable.get(id).copied().unwrap_or(false);
            if !see_through {
                result.insert(id);
                break;
            }
            expand_nullable(id, table, ctx, nullable, num_terminals, result, visited);
        }
    }
}
/// Returns `true` when every symbol of `sequence` is nullable; an empty sequence qualifies.
/// Symbol ids beyond the end of `nullable` count as non-nullable.
fn is_sequence_nullable(sequence: &[u32], nullable: &[bool]) -> bool {
    !sequence
        .iter()
        .any(|&sym| !matches!(nullable.get(sym as usize), Some(true)))
}
/// Error returned when the action table has no entry for the current state and lookahead.
#[derive(Debug, Clone)]
pub struct ParseError {
    /// The lookahead terminal that could not be handled (`SymbolId::EOF` at end of input).
    terminal: SymbolId,
}
impl ParseError {
    /// Returns the terminal that triggered the error.
    pub fn terminal(&self) -> SymbolId {
        let ParseError { terminal } = self;
        *terminal
    }
}
impl std::fmt::Display for ParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("unexpected terminal ")?;
        write!(f, "{:?}", self.terminal())
    }
}
impl std::error::Error for ParseError {}
/// A lexed token handed to the parser: its terminal id plus an optional precedence.
#[derive(Debug, Clone)]
pub struct Token {
    /// Terminal symbol id of this token.
    pub terminal: SymbolId,
    /// Precedence used to resolve shift/reduce conflicts, if any.
    pub prec: Option<Precedence>,
}
impl Token {
    /// Creates a token with no precedence.
    pub fn new(terminal: SymbolId) -> Self {
        Token { terminal, prec: None }
    }

    /// Creates a token that carries `prec`.
    pub fn with_prec(terminal: SymbolId, prec: Precedence) -> Self {
        let mut token = Token::new(terminal);
        token.prec = Some(prec);
        token
    }
}
/// One entry of the parser's LR stack.
#[derive(Debug, Clone, Copy)]
struct StackEntry {
/// LR automaton state.
state: usize,
/// Precedence context compared against the lookahead when this entry is on top.
prec: Option<Precedence>,
/// Index of the first token covered by this entry's symbol.
token_idx: usize,
}
/// Table-driven LR parser that the caller drives token by token via `maybe_reduce`/`shift`.
pub struct Parser<'a> {
/// Tables being interpreted.
table: ParseTable<'a>,
/// Current top of the stack, kept outside `stack`.
state: StackEntry,
/// Every entry below the top, bottom first.
stack: Vec<StackEntry>,
/// Number of tokens shifted so far, which is also the index of the next token.
token_count: usize,
}
impl<'a> Parser<'a> {
pub fn new(table: ParseTable<'a>) -> Self {
Self {
table,
state: StackEntry { state: 0, prec: None, token_idx: 0 },
stack: Vec::new(),
token_count: 0,
}
}
pub fn maybe_reduce(&mut self, lookahead: Option<&Token>) -> Result<Option<(usize, usize, usize)>, ParseError> {
let terminal = lookahead.map(|t| t.terminal).unwrap_or(SymbolId::EOF);
let lookahead_prec = lookahead.and_then(|t| t.prec);
match self.table.action(self.state.state, terminal) {
Action::Reduce(rule) => {
if rule == 0 {
Ok(Some((0, 0, 0))) } else {
let (len, start_idx) = self.do_reduce(rule);
Ok(Some((rule, len, start_idx)))
}
}
Action::Shift(_) => Ok(None),
Action::ShiftOrReduce { reduce_rule, .. } => {
let should_reduce = match (self.state.prec, lookahead_prec) {
(Some(sp), Some(tp)) => {
if tp.level() > sp.level() {
false
} else if tp.level() < sp.level() {
true
} else {
matches!(sp, Precedence::Left(_))
}
}
_ => false,
};
if should_reduce {
let (len, start_idx) = self.do_reduce(reduce_rule);
Ok(Some((reduce_rule, len, start_idx)))
} else {
Ok(None)
}
}
Action::Error => {
Err(ParseError { terminal })
}
}
}
pub fn shift(&mut self, token: &Token) {
let next_state = match self.table.action(self.state.state, token.terminal) {
Action::Shift(s) => s,
Action::ShiftOrReduce { shift_state, .. } => shift_state,
_ => panic!("shift called when action is not shift"),
};
let prec = token.prec.or(self.state.prec);
self.stack.push(self.state);
self.state = StackEntry {
state: next_state,
prec,
token_idx: self.token_count,
};
self.token_count += 1;
}
fn do_reduce(&mut self, rule: usize) -> (usize, usize) {
let (lhs, len) = self.table.rule_info(rule);
let start_idx = match len {
0 => self.token_count, 1 => self.state.token_idx, _ => self.stack[self.stack.len() - len + 1].token_idx, };
if len == 0 {
if let Some(next_state) = self.table.goto(self.state.state, lhs) {
self.stack.push(self.state);
self.state = StackEntry { state: next_state, prec: None, token_idx: start_idx };
}
} else {
for _ in 0..(len - 1) {
self.stack.pop();
}
let anchor = self.stack.last().unwrap();
let captured_prec = if len == 1 { self.state.prec } else { anchor.prec };
if let Some(next_state) = self.table.goto(anchor.state, lhs) {
self.state = StackEntry { state: next_state, prec: captured_prec, token_idx: start_idx };
}
}
(len, start_idx)
}
pub fn state(&self) -> usize {
self.state.state
}
pub fn stack_depth(&self) -> usize {
self.stack.len()
}
pub fn token_count(&self) -> usize {
self.token_count
}
pub fn state_at(&self, depth: usize) -> usize {
let idx = depth + 1;
if idx < self.stack.len() {
self.stack[idx].state
} else {
self.state.state
}
}
pub fn format_error(&self, err: &ParseError, ctx: &impl ErrorContext) -> String {
self.format_error_with(err, ctx, &HashMap::new(), &[])
}
pub fn format_error_with(
&self,
err: &ParseError,
ctx: &impl ErrorContext,
display_names: &HashMap<&str, &str>,
tokens: &[&str],
) -> String {
let mut full_stack: Vec<StackEntry> = self.stack.clone();
full_stack.push(self.state);
let error_token_idx = self.token_count;
let display = |id: SymbolId| -> &str {
let name = ctx.symbol_name(id);
display_names.get(name).copied().unwrap_or(name)
};
let stack_spans = || -> Vec<(usize, usize, usize)> {
let mut spans = Vec::with_capacity(full_stack.len());
for i in 0..full_stack.len() {
let start = full_stack[i].token_idx;
let end = if i + 1 < full_stack.len() {
full_stack[i + 1].token_idx
} else {
error_token_idx
};
spans.push((start, end, full_stack[i].state));
}
spans
};
let nullable = compute_nullable(&self.table, ctx);
let num_terminals = self.table.num_terminals as usize;
let mut relevant_items = Vec::new();
self.collect_relevant_items(ctx, self.state.state, self.stack.len() + 1, &mut relevant_items);
let expected_syms = self.compute_expected(ctx, &relevant_items, &nullable, num_terminals);
let mut expected: Vec<_> = expected_syms.iter()
.map(|&sym| format_sym(display(SymbolId(sym as u32))))
.collect();
expected.sort();
let found_name = tokens.get(error_token_idx)
.copied()
.unwrap_or_else(|| display(err.terminal));
let mut msg = format!("unexpected '{}'", found_name);
if !expected.is_empty() {
msg.push_str(&format!(", expected: {}", expected.join(", ")));
}
if !tokens.is_empty() && error_token_idx <= tokens.len() {
let spans = stack_spans();
let relevant: Vec<_> = spans.into_iter()
.skip(1) .filter(|(start, end, _)| end > start) .collect();
if !relevant.is_empty() {
let mut token_line = String::new();
let mut label_line = String::new();
for (start, end, state) in relevant.iter().rev().take(4).rev() {
let sym = ctx.state_symbol(*state);
let name = format_sym(display(sym));
let span_text = if end - start == 1 {
tokens[*start].to_string()
} else if end - start <= 3 {
tokens[*start..*end].join(" ")
} else {
format!("{} ... {}", tokens[*start], tokens[end - 1])
};
let width = span_text.chars().count().max(name.len());
if !token_line.is_empty() {
token_line.push_str(" ");
label_line.push_str(" ");
}
token_line.push_str(&format!("{:^width$}", span_text, width = width));
label_line.push_str(&format!("{:^width$}", name, width = width));
}
msg.push_str(&format!("\n {}\n {}", token_line, label_line));
}
} else if full_stack.len() > 1 {
let path: Vec<_> = full_stack[1..]
.iter()
.map(|e| display(ctx.state_symbol(e.state)))
.collect();
msg.push_str(&format!("\n after: {}", path.join(" ")));
}
let display_items = &relevant_items;
let mut seen = HashSet::new();
for &(rule, dot) in display_items {
let rhs = ctx.rule_rhs(rule);
let lhs = self.table.rule_info(rule).0;
if ctx.symbol_name(lhs) == "__start" {
continue;
}
let lhs_name = format_sym(display(lhs));
let before: Vec<_> = rhs[..dot]
.iter()
.map(|&id| format_sym(display(SymbolId(id))))
.collect();
let after: Vec<_> = rhs[dot..]
.iter()
.map(|&id| format_sym(display(SymbolId(id))))
.collect();
let line = format!(
"\n in {}: {} \u{2022} {}",
lhs_name,
before.join(" "),
after.join(" ")
);
if seen.insert(line.clone()) {
msg.push_str(&line);
}
}
msg
}
fn collect_relevant_items(
&self,
ctx: &impl ErrorContext,
state: usize,
stack_len: usize,
result: &mut Vec<(usize, usize)>,
) {
for &(rule, dot) in ctx.state_items(state) {
let rule = rule as usize;
let dot = dot as usize;
let rhs = ctx.rule_rhs(rule);
let lhs = self.table.rule_info(rule).0;
if ctx.symbol_name(lhs) == "__start" {
result.push((rule, dot));
continue;
}
if dot == 0 { continue; }
if dot < rhs.len() {
result.push((rule, dot));
} else {
let consumed = rhs.len();
if stack_len > consumed {
let caller_state = self.state_at_idx(stack_len - consumed - 1);
if let Some(goto_state) = self.table.goto(caller_state, lhs) {
self.collect_relevant_items(ctx, goto_state, stack_len - consumed + 1, result);
}
}
}
}
}
fn compute_expected(
&self,
ctx: &impl ErrorContext,
items: &[(usize, usize)],
nullable: &[bool],
num_terminals: usize,
) -> HashSet<usize> {
let mut expected = HashSet::new();
let stack_len = self.stack.len() + 1;
for &(rule, dot) in items {
let rhs = ctx.rule_rhs(rule);
let lhs = self.table.rule_info(rule).0;
let suffix = &rhs[dot..];
expected.extend(expected_from_sequence(suffix, &self.table, ctx, nullable, num_terminals));
if is_sequence_nullable(suffix, nullable) && stack_len > dot {
expected.extend(self.compute_follow_from_context(
ctx, lhs, stack_len - dot,
nullable, num_terminals, &mut HashSet::new(),
));
}
}
expected
}
fn state_at_idx(&self, idx: usize) -> usize {
if idx < self.stack.len() {
self.stack[idx].state
} else {
self.state.state
}
}
fn compute_follow_from_context(
&self,
ctx: &impl ErrorContext,
nonterminal: SymbolId,
caller_idx: usize,
nullable: &[bool],
num_terminals: usize,
visited: &mut HashSet<(usize, u32)>,
) -> HashSet<usize> {
if nonterminal == self.table.rule_info(0).0 {
let mut result = HashSet::new();
result.insert(0); return result;
}
if caller_idx == 0 {
let mut result = HashSet::new();
result.insert(0); return result;
}
let caller_state = self.state_at_idx(caller_idx - 1);
if !visited.insert((caller_idx, nonterminal.0)) {
return HashSet::new();
}
let mut expected = HashSet::new();
for &(rule, dot) in ctx.state_items(caller_state) {
let rule = rule as usize;
let dot = dot as usize;
let rhs = ctx.rule_rhs(rule);
if dot < rhs.len() && rhs[dot] == nonterminal.0 {
let suffix = &rhs[dot + 1..];
let lhs = self.table.rule_info(rule).0;
let consumed = dot;
if suffix.is_empty() {
if caller_idx > consumed {
expected.extend(self.compute_follow_from_context(
ctx, lhs, caller_idx - consumed,
nullable, num_terminals, visited,
));
} else {
expected.insert(0);
}
} else {
expected.extend(expected_from_sequence(suffix, &self.table, ctx, nullable, num_terminals));
if is_sequence_nullable(suffix, nullable) {
if caller_idx > consumed {
expected.extend(self.compute_follow_from_context(
ctx, lhs, caller_idx - consumed,
nullable, num_terminals, visited,
));
} else {
expected.insert(0);
}
}
}
}
}
expected
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::grammar::SymbolId;
use crate::table::CompiledTable;
// Every Action variant, including the accept rule 0 and a conflict cell, must survive an
// encode/decode round trip.
#[test]
fn test_action_entry_encoding() {
let shift = ActionEntry::shift(42);
assert_eq!(shift.decode(), Action::Shift(42));
let reduce = ActionEntry::reduce(7);
assert_eq!(reduce.decode(), Action::Reduce(7));
let accept = ActionEntry::reduce(0);
assert_eq!(accept.decode(), Action::Reduce(0));
let error = ActionEntry::ERROR;
assert_eq!(error.decode(), Action::Error);
let sor = ActionEntry::shift_or_reduce(10, 5);
match sor.decode() {
Action::ShiftOrReduce { shift_state, reduce_rule } => {
assert_eq!(shift_state, 10);
assert_eq!(reduce_rule, 5);
}
other => panic!("Expected ShiftOrReduce, got {:?}", other),
}
}
use crate::meta::parse_grammar;
use crate::lr::to_grammar_internal;
// Drives the parser by hand over the one-token input `a` for the grammar `s = a`.
#[test]
fn test_parse_single_token() {
let grammar = to_grammar_internal(&parse_grammar(r#"
grammar Simple { start s; terminals { a } s = a; }
"#).unwrap()).unwrap();
let compiled = CompiledTable::build_with_algorithm(&grammar, crate::lr::LrAlgorithm::default());
let mut parser = Parser::new(compiled.table());
let a_id = compiled.symbol_id("a").unwrap();
let token = Token::new(a_id);
// Nothing can be reduced before `a`, so the parser asks the caller to shift.
assert!(matches!(parser.maybe_reduce(Some(&token)), Ok(None)));
parser.shift(&token);
// At end of input, the user rule (rule 1, length 1, starting at token 0) is reduced...
let result = parser.maybe_reduce(None);
assert!(matches!(result, Ok(Some((1, 1, 0)))));
// ...and then the accept rule 0 is reported.
let result = parser.maybe_reduce(None);
assert!(matches!(result, Ok(Some((0, 0, 0)))));
}
// A terminal id the table does not know (presumably nothing uses 99 in this tiny grammar)
// has no action in the start state.
#[test]
fn test_parse_error() {
let grammar = to_grammar_internal(&parse_grammar(r#"
grammar Simple { start s; terminals { a } s = a; }
"#).unwrap()).unwrap();
let compiled = CompiledTable::build_with_algorithm(&grammar, crate::lr::LrAlgorithm::default());
let mut parser = Parser::new(compiled.table());
let wrong_id = SymbolId(99);
let token = Token::new(wrong_id);
let result = parser.maybe_reduce(Some(&token));
assert!(result.is_err());
}
// `b` is a declared terminal that is invalid at the start. The rendered message must name
// it and mention the rule `s`.
#[test]
fn test_format_error() {
let grammar = to_grammar_internal(&parse_grammar(r#"
grammar Simple { start s; terminals { a, b } s = a; }
"#).unwrap()).unwrap();
let compiled = CompiledTable::build_with_algorithm(&grammar, crate::lr::LrAlgorithm::default());
let mut parser = Parser::new(compiled.table());
let b_id = compiled.symbol_id("b").unwrap();
let token = Token::new(b_id);
let err = parser.maybe_reduce(Some(&token)).unwrap_err();
let msg = parser.format_error(&err, &compiled);
assert!(msg.contains("unexpected"), "msg: {}", msg);
assert!(msg.contains("'b'"), "msg: {}", msg);
assert!(msg.contains("s"), "msg: {}", msg);
}
}