#![cfg_attr(feature = "strict_docs", allow(missing_docs))]
use adze_glr_core::Action;
use adze_ir::StateId;
use std::collections::HashMap;
/// Action table in which identical rows are stored only once.
///
/// Built by [`compress_action_table`] and read back with [`decompress_action`].
pub struct CompressedActionTable {
/// Maps each distinct row to its position in `unique_rows`. It is filled in
/// while compressing; nothing in this file reads it afterwards, hence the
/// `dead_code` allowance.
#[allow(dead_code)]
row_map: HashMap<Vec<Vec<Action>>, usize>,
/// The distinct rows, in order of first appearance in the input table.
/// Each row is indexed by symbol and each cell is a list of actions.
pub unique_rows: Vec<Vec<Vec<Action>>>,
/// For every state (its position in the input table), the index of that
/// state's row inside `unique_rows`.
pub state_to_row: Vec<usize>,
}
/// Sparse goto table that keeps only the populated cells.
///
/// Built by [`compress_goto_table`] and queried with [`decompress_goto`].
pub struct CompressedGotoTable {
/// Goto targets keyed by `(state index, symbol index)`; cells that were
/// `None` in the input table have no entry here.
pub entries: HashMap<(usize, usize), StateId>,
/// Number of rows in the table this was built from. Recorded at
/// construction and not read elsewhere in this file.
#[allow(dead_code)]
state_count: usize,
/// Length of the first row of the source table (0 for an empty table).
/// Recorded at construction and not read elsewhere in this file.
#[allow(dead_code)]
symbol_count: usize,
}
pub fn compress_action_table(table: &[Vec<Vec<Action>>]) -> CompressedActionTable {
let mut row_map = HashMap::new();
let mut unique_rows = Vec::new();
let mut state_to_row = Vec::new();
for row in table {
let row_index = if let Some(&idx) = row_map.get(row) {
idx
} else {
let idx = unique_rows.len();
row_map.insert(row.clone(), idx);
unique_rows.push(row.clone());
idx
};
state_to_row.push(row_index);
}
CompressedActionTable {
row_map,
unique_rows,
state_to_row,
}
}
pub fn decompress_action(
compressed: &CompressedActionTable,
state: usize,
symbol: usize,
) -> Action {
let row_index = compressed.state_to_row[state];
let action_cell = &compressed.unique_rows[row_index][symbol];
action_cell.first().cloned().unwrap_or(Action::Error)
}
/// Compresses a dense goto table into a sparse map of its populated cells.
///
/// Every `Some(target)` at `table[state][symbol]` becomes an entry keyed by
/// `(state, symbol)`; `None` cells are dropped. The recorded `symbol_count`
/// is the length of the first row (0 for an empty table).
pub fn compress_goto_table(table: &[Vec<Option<StateId>>]) -> CompressedGotoTable {
    let entries: HashMap<(usize, usize), StateId> = table
        .iter()
        .enumerate()
        .flat_map(|(state_idx, row)| {
            row.iter()
                .enumerate()
                .filter_map(move |(symbol_idx, &goto)| {
                    goto.map(|target| ((state_idx, symbol_idx), target))
                })
        })
        .collect();

    CompressedGotoTable {
        entries,
        state_count: table.len(),
        symbol_count: table.first().map_or(0, Vec::len),
    }
}
pub fn decompress_goto(
compressed: &CompressedGotoTable,
state: usize,
symbol: usize,
) -> Option<StateId> {
compressed.entries.get(&(state, symbol)).copied()
}
pub struct BitPackedActionTable {
error_mask: Vec<u64>, shift_data: Vec<u32>, reduce_data: Vec<u32>, fork_data: HashMap<(usize, usize), Vec<Action>>,
#[allow(dead_code)]
state_count: usize,
symbol_count: usize,
}
impl BitPackedActionTable {
pub fn from_table(table: &[Vec<Action>]) -> Self {
let state_count = table.len();
let symbol_count = if state_count > 0 { table[0].len() } else { 0 };
let total_cells = state_count * symbol_count;
let mask_words = total_cells.div_ceil(64);
let mut error_mask = vec![0u64; mask_words];
let mut shift_data = Vec::new();
let mut reduce_data = Vec::new();
let mut fork_data = HashMap::new();
for (state_idx, row) in table.iter().enumerate() {
for (symbol_idx, action) in row.iter().enumerate() {
let cell_idx = state_idx * symbol_count + symbol_idx;
match action {
Action::Error => {
let word_idx = cell_idx / 64;
let bit_idx = cell_idx % 64;
error_mask[word_idx] |= 1 << bit_idx;
}
Action::Shift(state) => {
shift_data.push(state.0 as u32);
}
Action::Reduce(rule) => {
reduce_data.push(rule.0 as u32);
}
Action::Accept => {
reduce_data.push(u32::MAX);
}
Action::Recover => {
let word_idx = cell_idx / 64;
let bit_idx = cell_idx % 64;
error_mask[word_idx] |= 1 << bit_idx;
}
Action::Fork(actions) => {
fork_data.insert((state_idx, symbol_idx), actions.clone());
}
_ => {
let word_idx = cell_idx / 64;
let bit_idx = cell_idx % 64;
error_mask[word_idx] |= 1 << bit_idx;
}
}
}
}
BitPackedActionTable {
error_mask,
shift_data,
reduce_data,
fork_data,
state_count,
symbol_count,
}
}
pub fn decompress(&self, state: usize, symbol: usize) -> Action {
let cell_idx = state * self.symbol_count + symbol;
let word_idx = cell_idx / 64;
let bit_idx = cell_idx % 64;
if (self.error_mask[word_idx] >> bit_idx) & 1 == 1 {
return Action::Error;
}
if let Some(actions) = self.fork_data.get(&(state, symbol)) {
return Action::Fork(actions.clone());
}
let mut data_idx = 0;
for i in 0..cell_idx {
let w_idx = i / 64;
let b_idx = i % 64;
if (self.error_mask[w_idx] >> b_idx) & 1 == 0 {
let s_idx = i / self.symbol_count;
let sym_idx = i % self.symbol_count;
if !self.fork_data.contains_key(&(s_idx, sym_idx)) {
data_idx += 1;
}
}
}
if data_idx < self.shift_data.len() {
Action::Shift(StateId(self.shift_data[data_idx] as u16))
} else {
let reduce_idx = data_idx - self.shift_data.len();
if reduce_idx < self.reduce_data.len() {
let rule_id = self.reduce_data[reduce_idx];
if rule_id == u32::MAX {
Action::Accept
} else {
Action::Reduce(adze_ir::RuleId(rule_id as u16))
}
} else {
Action::Error }
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two identical rows must share one stored row while every state still
    /// decodes to its original actions.
    #[test]
    fn test_row_deduplication() {
        let shift_row = vec![vec![Action::Error], vec![Action::Shift(StateId(1))]];
        let reduce_row = vec![
            vec![Action::Reduce(adze_ir::RuleId(0))],
            vec![Action::Error],
        ];
        let table = vec![shift_row.clone(), shift_row, reduce_row];

        let compressed = compress_action_table(&table);
        assert_eq!(compressed.unique_rows.len(), 2);

        let expected = [
            (0, 0, Action::Error),
            (0, 1, Action::Shift(StateId(1))),
            (1, 0, Action::Error),
            (1, 1, Action::Shift(StateId(1))),
            (2, 0, Action::Reduce(adze_ir::RuleId(0))),
        ];
        for (state, symbol, action) in expected {
            assert_eq!(decompress_action(&compressed, state, symbol), action);
        }
    }

    /// Only populated goto cells are stored, and lookups return them intact.
    #[test]
    fn test_sparse_goto_compression() {
        let table = vec![
            vec![None, Some(StateId(1)), None],
            vec![Some(StateId(2)), None, None],
            vec![None, None, Some(StateId(3))],
        ];

        let compressed = compress_goto_table(&table);
        assert_eq!(compressed.entries.len(), 3);

        let lookups = [
            ((0, 0), None),
            ((0, 1), Some(StateId(1))),
            ((1, 0), Some(StateId(2))),
            ((2, 2), Some(StateId(3))),
        ];
        for ((state, symbol), target) in lookups {
            assert_eq!(decompress_goto(&compressed, state, symbol), target);
        }
    }
}