use crate::{Action, ParseTable, StateId};
use adze_ir::SymbolId;
use std::fmt;
/// Aggregate result of scanning a parse table for conflicting action cells.
///
/// Built by [`count_conflicts`]; the `Display` impl renders it as a
/// human-readable report.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConflictSummary {
/// Number of conflicting cells tallied as shift/reduce. Cells classified as
/// [`ConflictType::Mixed`] are counted here as well as in `reduce_reduce`.
pub shift_reduce: usize,
/// Number of conflicting cells tallied as reduce/reduce. Cells classified as
/// [`ConflictType::Mixed`] are counted here as well as in `shift_reduce`.
pub reduce_reduce: usize,
/// States owning at least one cell with more than one action. Each state is
/// listed once, in action-table order.
pub states_with_conflicts: Vec<StateId>,
/// One entry per conflicting cell, in scan order (by state, then by symbol
/// column within the state).
pub conflict_details: Vec<ConflictDetail>,
}
/// Description of a single conflicting cell of the action table.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConflictDetail {
/// State (row of the action table) in which the conflict occurs.
pub state: StateId,
/// Symbol of the conflicting column, resolved through the table's
/// `index_to_symbol` mapping. `count_conflicts` substitutes `SymbolId(0)`
/// when the column index has no entry in that mapping.
pub symbol: SymbolId,
/// Label used when displaying the conflict. `count_conflicts` currently
/// generates `symbol_<id>` from the numeric id rather than a grammar name.
pub symbol_name: String,
/// Classification of the cell as returned by `classify_conflict`.
pub conflict_type: ConflictType,
/// Copy of every action stored in the conflicting cell.
pub actions: Vec<Action>,
/// One value per entry of `actions`. `count_conflicts` currently fills
/// these with `0` (placeholder; no priority is looked up).
pub priorities: Vec<i32>,
}
/// Category assigned to a conflicting cell by `classify_conflict`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConflictType {
/// The cell offers both a shift and a reduce.
ShiftReduce,
/// The cell offers reduce actions but no shift.
ReduceReduce,
/// Fallback used when the cell fits neither category above, e.g. two
/// competing shifts. `count_conflicts` tallies such cells under both the
/// shift/reduce and the reduce/reduce counters.
Mixed,
}
/// Scans every cell of `table.action_table` and summarizes the conflicts found.
///
/// A cell is a conflict when it holds more than one action. Each such cell is
/// classified with [`classify_conflict`], tallied, and recorded as a
/// [`ConflictDetail`]. Cells classified as [`ConflictType::Mixed`] increment
/// both the shift/reduce and the reduce/reduce counters.
///
/// The returned `states_with_conflicts` lists each conflicting state once, in
/// table order; `conflict_details` is ordered by state and then by symbol
/// column.
///
/// # Panics
///
/// Only in builds with debug assertions enabled, when the table violates one
/// of the structural invariants checked by `debug_assert_table_invariants`.
pub fn count_conflicts(table: &ParseTable) -> ConflictSummary {
    debug_assert_table_invariants(table);

    let mut summary = ConflictSummary {
        shift_reduce: 0,
        reduce_reduce: 0,
        states_with_conflicts: Vec::new(),
        conflict_details: Vec::new(),
    };

    for (state_idx, state_actions) in table.action_table.iter().enumerate() {
        // `StateId` wraps a `u16`, so this cast truncates for indices above
        // `u16::MAX`; the invariant check above rejects such tables in debug
        // builds.
        let state_id = StateId(state_idx as u16);
        let mut state_has_conflict = false;

        for (symbol_idx, action_cell) in state_actions.iter().enumerate() {
            // Empty cells and cells with a single action are unambiguous.
            if action_cell.len() <= 1 {
                continue;
            }
            state_has_conflict = true;

            // Columns without a symbol mapping fall back to `SymbolId(0)`,
            // which is indistinguishable from a genuine symbol 0.
            let symbol_id = table
                .index_to_symbol
                .get(symbol_idx)
                .copied()
                .unwrap_or(SymbolId(0));

            let conflict_type = classify_conflict(action_cell);
            match conflict_type {
                ConflictType::ShiftReduce => summary.shift_reduce += 1,
                ConflictType::ReduceReduce => summary.reduce_reduce += 1,
                ConflictType::Mixed => {
                    summary.shift_reduce += 1;
                    summary.reduce_reduce += 1;
                }
            }

            summary.conflict_details.push(ConflictDetail {
                state: state_id,
                symbol: symbol_id,
                // Only the numeric id is available here; no grammar name is
                // looked up.
                symbol_name: format!("symbol_{}", symbol_id.0),
                conflict_type,
                actions: action_cell.clone(),
                // Placeholder: every action is reported with priority 0.
                priorities: vec![0; action_cell.len()],
            });
        }

        if state_has_conflict {
            summary.states_with_conflicts.push(state_id);
        }
    }

    summary
}

/// Debug-only structural checks on `table`; does nothing observable when
/// debug assertions are disabled.
fn debug_assert_table_invariants(table: &ParseTable) {
    debug_assert_eq!(
        table.state_count,
        table.action_table.len(),
        "ParseTable invariant violation: state_count ({}) != action_table.len() ({})",
        table.state_count,
        table.action_table.len()
    );
    debug_assert!(
        !table.action_table.is_empty(),
        "ParseTable invariant violation: action_table is empty but should have at least initial state"
    );
    // State indices are narrowed to the `u16` inside `StateId` by
    // `count_conflicts`; a larger table would yield wrapped state ids.
    debug_assert!(
        table.action_table.len() <= usize::from(u16::MAX) + 1,
        "ParseTable invariant violation: {} states do not fit in a u16 StateId",
        table.action_table.len()
    );
    for (state_idx, state_actions) in table.action_table.iter().enumerate() {
        debug_assert!(
            state_idx < table.state_count,
            "ParseTable invariant violation: state index {} >= state_count {}",
            state_idx,
            table.state_count
        );
        for symbol_idx in 0..state_actions.len() {
            // An empty `index_to_symbol` is tolerated (the column check is
            // skipped entirely in that case).
            debug_assert!(
                symbol_idx < table.index_to_symbol.len() || table.index_to_symbol.is_empty(),
                "ParseTable invariant violation: symbol index {} >= index_to_symbol.len() {}",
                symbol_idx,
                table.index_to_symbol.len()
            );
        }
    }
}
/// Classifies the set of actions stored in one action-table cell.
///
/// The result depends only on which kinds of action are present:
///
/// * at least one shift and at least one reduce: [`ConflictType::ShiftReduce`]
/// * at least one reduce and no shift: [`ConflictType::ReduceReduce`]
/// * no reduce at all (for example two competing shifts, or an empty slice):
///   [`ConflictType::Mixed`]
///
/// The number of actions is not checked: a slice holding a single reduce is
/// still reported as `ReduceReduce`, so callers should pass only cells that
/// are actually in conflict.
///
/// `Action::Fork` is transparent: the actions nested inside a fork (at any
/// depth) are treated exactly as if they appeared directly in `actions`.
/// `Accept`, `Error` and `Recover` do not influence the result.
pub fn classify_conflict(actions: &[Action]) -> ConflictType {
    let mut has_shift = false;
    let mut has_reduce = false;

    // Forks are flattened through a work list instead of being classified
    // recursively: a nested `ConflictType` cannot tell "shifts only" apart
    // from "nothing relevant", so the flags must be read off the actions
    // themselves to stay consistent with the unforked case.
    let mut pending: Vec<&[Action]> = vec![actions];
    while let Some(group) = pending.pop() {
        for action in group {
            match action {
                Action::Shift(_) => has_shift = true,
                Action::Reduce(_) => has_reduce = true,
                Action::Fork(inner) => pending.push(inner.as_slice()),
                Action::Accept | Action::Error | Action::Recover => {}
            }
        }
    }

    match (has_shift, has_reduce) {
        (true, true) => ConflictType::ShiftReduce,
        (false, true) => ConflictType::ReduceReduce,
        (true, false) | (false, false) => ConflictType::Mixed,
    }
}
/// Reports whether `state` has at least one action cell holding more than one
/// action.
///
/// Returns `false` when `state` lies outside `table.action_table`.
pub fn state_has_conflicts(table: &ParseTable, state: StateId) -> bool {
    match table.action_table.get(usize::from(state.0)) {
        Some(row) => row.iter().any(|cell| cell.len() > 1),
        None => false,
    }
}
/// Returns the details of every conflict located in `state`.
///
/// The whole table is analysed with `count_conflicts` first; entries that
/// belong to other states are then discarded, keeping the original order.
pub fn get_state_conflicts(table: &ParseTable, state: StateId) -> Vec<ConflictDetail> {
    let mut details = count_conflicts(table).conflict_details;
    details.retain(|detail| detail.state == state);
    details
}
pub fn find_conflicts_for_symbol(table: &ParseTable, symbol: SymbolId) -> Vec<ConflictDetail> {
let summary = count_conflicts(table);
summary
.conflict_details
.into_iter()
.filter(|detail| detail.symbol == symbol)
.collect()
}
/// Multi-line report: a header with the three totals, a blank line, then one
/// line per recorded conflict.
impl fmt::Display for ConflictSummary {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            shift_reduce,
            reduce_reduce,
            states_with_conflicts,
            conflict_details,
        } = self;
        let conflicted_states = states_with_conflicts.len();

        // Totals.
        writeln!(f, "=== Conflict Summary ===")?;
        writeln!(f, "Shift/Reduce conflicts: {}", shift_reduce)?;
        writeln!(f, "Reduce/Reduce conflicts: {}", reduce_reduce)?;
        writeln!(f, "States with conflicts: {}", conflicted_states)?;
        writeln!(f)?;

        // Per-conflict lines, each rendered by `ConflictDetail`'s `Display`.
        writeln!(f, "=== Conflict Details ===")?;
        for entry in conflict_details.iter() {
            writeln!(f, "{}", entry)?;
        }
        Ok(())
    }
}
/// Single-line rendering (no trailing newline): state number, symbol label
/// and id, the conflict type in `Debug` form, and the number of actions.
impl fmt::Display for ConflictDetail {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let state_number = self.state.0;
        let symbol_number = self.symbol.0;
        let action_count = self.actions.len();
        write!(
            f,
            "State {}, Symbol '{}' ({}): {:?} - {} actions",
            state_number, self.symbol_name, symbol_number, self.conflict_type, action_count
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Action;
    use adze_ir::RuleId;

    /// Wraps `action_table` in an otherwise empty `ParseTable` whose
    /// `state_count` matches the number of rows supplied.
    fn create_test_table(action_table: Vec<Vec<Vec<Action>>>) -> ParseTable {
        let state_count = action_table.len();
        ParseTable {
            action_table,
            goto_table: vec![],
            symbol_metadata: vec![],
            state_count,
            symbol_count: 1,
            symbol_to_index: Default::default(),
            index_to_symbol: Default::default(),
            external_scanner_states: vec![],
            rules: vec![],
            nonterminal_to_index: Default::default(),
            goto_indexing: crate::GotoIndexing::NonterminalMap,
            eof_symbol: SymbolId(0),
            start_symbol: SymbolId(0),
            grammar: adze_ir::Grammar::new("test".to_string()),
            initial_state: StateId(0),
            token_count: 0,
            external_token_count: 0,
            lex_modes: vec![],
            extras: vec![],
            dynamic_prec_by_rule: vec![],
            rule_assoc_by_rule: vec![],
            alias_sequences: vec![],
            field_names: vec![],
            field_map: Default::default(),
        }
    }

    /// A shift next to a reduce is a shift/reduce conflict.
    #[test]
    fn test_classify_shift_reduce() {
        let cell = [Action::Shift(StateId(5)), Action::Reduce(RuleId(3))];
        assert_eq!(classify_conflict(&cell), ConflictType::ShiftReduce);
    }

    /// Two reduces without any shift are a reduce/reduce conflict.
    #[test]
    fn test_classify_reduce_reduce() {
        let cell = [Action::Reduce(RuleId(3)), Action::Reduce(RuleId(7))];
        assert_eq!(classify_conflict(&cell), ConflictType::ReduceReduce);
    }

    /// Two shifts (no reduce) land in the `Mixed` fallback.
    #[test]
    fn test_classify_mixed() {
        let cell = [Action::Shift(StateId(1)), Action::Shift(StateId(2))];
        assert_eq!(classify_conflict(&cell), ConflictType::Mixed);
    }

    /// A fork wrapping a shift and a reduce classifies like the bare pair.
    #[test]
    fn test_classify_fork_shift_reduce() {
        let forked = Action::Fork(vec![Action::Shift(StateId(1)), Action::Reduce(RuleId(1))]);
        let cell = vec![forked];
        assert_eq!(classify_conflict(&cell), ConflictType::ShiftReduce);
    }

    /// A table whose only cell holds a single action reports no conflicts.
    #[test]
    fn test_empty_conflict_summary() {
        let single_action_cell = vec![Action::Shift(StateId(1))];
        let table = create_test_table(vec![vec![single_action_cell]]);

        let ConflictSummary {
            shift_reduce,
            reduce_reduce,
            states_with_conflicts,
            ..
        } = count_conflicts(&table);

        assert_eq!(shift_reduce, 0);
        assert_eq!(reduce_reduce, 0);
        assert!(states_with_conflicts.is_empty());
    }

    /// One shift/reduce cell yields one tally, one state and one detail.
    #[test]
    fn test_detect_shift_reduce_conflict() {
        let conflicting_cell = vec![Action::Shift(StateId(1)), Action::Reduce(RuleId(0))];
        let table = create_test_table(vec![vec![conflicting_cell]]);

        let report = count_conflicts(&table);
        assert_eq!(report.shift_reduce, 1);
        assert_eq!(report.reduce_reduce, 0);
        assert_eq!(report.states_with_conflicts.len(), 1);
        assert_eq!(report.conflict_details.len(), 1);

        let first = &report.conflict_details[0];
        assert_eq!(first.conflict_type, ConflictType::ShiftReduce);
        assert_eq!(first.actions.len(), 2);
    }

    /// Only the state owning a multi-action cell is flagged.
    #[test]
    fn test_state_has_conflicts() {
        let conflicting_row = vec![vec![Action::Shift(StateId(1)), Action::Reduce(RuleId(0))]];
        let clean_row = vec![vec![Action::Shift(StateId(2))]];
        let table = create_test_table(vec![conflicting_row, clean_row]);

        assert!(state_has_conflicts(&table, StateId(0)));
        assert!(!state_has_conflicts(&table, StateId(1)));
    }
}