use crate::{StateId, SymbolId};
use typed_arena::Arena;
/// One immutable frame of an arena-allocated graph-structured stack.
///
/// Frames are linked towards the bottom of the stack through `parent`. Several frames may point
/// at the same parent, which is how forked heads share their common prefix without copying it.
#[derive(Debug)]
pub struct ArenaStackNode<'a> {
    /// Parser state recorded in this frame.
    pub state: StateId,
    /// Symbol recorded alongside the state; `None` when the creator supplied none.
    pub symbol: Option<SymbolId>,
    /// Next frame towards the bottom of the stack; `None` marks the bottom frame.
    pub parent: Option<&'a ArenaStackNode<'a>>,
    /// Number of frames below this one, as recorded by whoever built the node.
    pub depth: usize,
}

impl<'a> ArenaStackNode<'a> {
    /// Lists the states from the bottom frame up to and including `self`, bottom first.
    pub fn get_states(&self) -> Vec<StateId> {
        let mut bottom_up = Vec::with_capacity(self.depth + 1);
        // Walk the parent links top-down, then flip the order so the bottom frame comes first.
        bottom_up.extend(std::iter::successors(Some(self), |frame| frame.parent).map(|frame| frame.state));
        bottom_up.reverse();
        bottom_up
    }

    /// Returns `true` when `self` and `other` sit directly on the very same parent frame
    /// (compared by address), or when neither has a parent at all.
    pub fn shares_prefix_with(&self, other: &ArenaStackNode<'a>) -> bool {
        if let (Some(mine), Some(theirs)) = (self.parent, other.parent) {
            return std::ptr::eq(mine, theirs);
        }
        // At least one side is a bottom frame: only two bottom frames count as sharing.
        self.parent.is_none() && other.parent.is_none()
    }
}
/// A graph-structured stack whose nodes all live in a borrowed `typed_arena::Arena`.
///
/// Each entry of `active_heads` is the top frame of one parallel stack; methods on this type take
/// a `head_idx` that indexes into that vector.
pub struct ArenaGSS<'a> {
    /// Arena that every node of this stack is allocated in.
    arena: &'a Arena<ArenaStackNode<'a>>,
    /// Top frames of the stacks that are still being extended.
    pub active_heads: Vec<&'a ArenaStackNode<'a>>,
    /// Top frames set aside by callers; `ArenaGSS`'s own methods only initialise it to empty.
    pub completed_heads: Vec<&'a ArenaStackNode<'a>>,
    /// Counters describing the work done by this stack.
    pub stats: ArenaGSSStats,
}
/// Bookkeeping counters for an [`ArenaGSS`]; all start at zero via `Default`.
#[derive(Debug, Default)]
pub struct ArenaGSSStats {
    /// Number of stack nodes allocated, including the initial bottom node.
    pub total_nodes_created: usize,
    /// Largest number of simultaneously active heads that has been recorded.
    pub max_active_heads: usize,
    /// Number of times a head has been duplicated.
    pub total_forks: usize,
    /// Number of heads removed because they were mergeable with another head.
    pub total_merges: usize,
    /// Bytes of arena storage attributed to node allocations.
    /// NOTE(review): confirm which code path keeps this counter up to date.
    pub arena_bytes_allocated: usize,
}
impl<'a> ArenaGSS<'a> {
    /// Creates a stack with a single active head: a bottom node holding `initial_state`.
    ///
    /// The bottom node is allocated in `arena` and is counted in `stats.total_nodes_created`
    /// and `stats.arena_bytes_allocated`.
    pub fn new(arena: &'a Arena<ArenaStackNode<'a>>, initial_state: StateId) -> Self {
        let root: &'a ArenaStackNode<'a> = arena.alloc(ArenaStackNode {
            state: initial_state,
            symbol: None,
            parent: None,
            depth: 0,
        });
        Self {
            arena,
            active_heads: vec![root],
            completed_heads: Vec::new(),
            stats: ArenaGSSStats {
                total_nodes_created: 1,
                max_active_heads: 1,
                arena_bytes_allocated: std::mem::size_of::<ArenaStackNode<'a>>(),
                ..Default::default()
            },
        }
    }

    /// Allocates `node` in the arena and records the allocation in `stats`.
    fn alloc_node(&mut self, node: ArenaStackNode<'a>) -> &'a ArenaStackNode<'a> {
        self.stats.total_nodes_created += 1;
        // Counts node payload bytes only; slack in the arena's chunks is not included.
        self.stats.arena_bytes_allocated += std::mem::size_of::<ArenaStackNode<'a>>();
        self.arena.alloc(node)
    }

    /// Duplicates head `head_idx` and returns the index of the new copy (always the last one).
    ///
    /// Both heads point at the same node, so they share the entire stack until one is pushed.
    ///
    /// # Panics
    /// Panics if `head_idx` is out of bounds for `active_heads`.
    pub fn fork_head(&mut self, head_idx: usize) -> usize {
        let head = self.active_heads[head_idx];
        self.active_heads.push(head);
        self.stats.total_forks += 1;
        self.stats.max_active_heads = self.stats.max_active_heads.max(self.active_heads.len());
        self.active_heads.len() - 1
    }

    /// Pushes a new node with `state`/`symbol` on top of head `head_idx` and makes it the head.
    ///
    /// # Panics
    /// Panics if `head_idx` is out of bounds for `active_heads`.
    pub fn push(&mut self, head_idx: usize, state: StateId, symbol: Option<SymbolId>) {
        let parent = self.active_heads[head_idx];
        let node = self.alloc_node(ArenaStackNode {
            state,
            symbol,
            parent: Some(parent),
            depth: parent.depth + 1,
        });
        self.active_heads[head_idx] = node;
    }

    /// Pops `count` nodes off head `head_idx` and returns their states, bottom-most first.
    ///
    /// Returns `None`, leaving the head untouched, when the stack does not have `count` nodes
    /// above its bottom node. A head must always point at a node, so the bottom node can never be
    /// popped. `count == 0` returns `Some(vec![])` without changing anything.
    ///
    /// # Panics
    /// Panics if `head_idx` is out of bounds for `active_heads`.
    pub fn pop(&mut self, head_idx: usize, count: usize) -> Option<Vec<StateId>> {
        let mut current = self.active_heads[head_idx];
        // Clamp the reservation so an absurd `count` cannot trigger a capacity-overflow panic;
        // `depth` is only a size hint here, the loop below checks the real links.
        let mut popped_states = Vec::with_capacity(count.min(current.depth));
        for _ in 0..count {
            let parent = current.parent?;
            popped_states.push(current.state);
            current = parent;
        }
        self.active_heads[head_idx] = current;
        popped_states.reverse();
        Some(popped_states)
    }

    /// Returns the state stored in the top node of head `head_idx`.
    ///
    /// # Panics
    /// Panics if `head_idx` is out of bounds for `active_heads`.
    pub fn top_state(&self, head_idx: usize) -> StateId {
        self.active_heads[head_idx].state
    }

    /// Returns `true` when two distinct head indices have equal top states and their top nodes
    /// share a prefix according to [`ArenaStackNode::shares_prefix_with`].
    ///
    /// # Panics
    /// Panics if either index is out of bounds (unless they are equal).
    pub fn can_merge(&self, idx1: usize, idx2: usize) -> bool {
        if idx1 == idx2 {
            return false;
        }
        let head1 = self.active_heads[idx1];
        let head2 = self.active_heads[idx2];
        head1.state == head2.state && head1.shares_prefix_with(head2)
    }

    /// Removes head `remove_idx` if it can merge with `keep_idx`; otherwise does nothing.
    ///
    /// Removal shifts every later head down by one index.
    pub fn merge_heads(&mut self, keep_idx: usize, remove_idx: usize) {
        if self.can_merge(keep_idx, remove_idx) {
            self.active_heads.remove(remove_idx);
            self.stats.total_merges += 1;
        }
    }

    /// Removes every head that can merge with an earlier head, keeping the first occurrence.
    ///
    /// Compares every pair of heads, so indices of the surviving heads may shift.
    pub fn deduplicate(&mut self) {
        let mut i = 0;
        while i < self.active_heads.len() {
            let mut j = i + 1;
            while j < self.active_heads.len() {
                if self.can_merge(i, j) {
                    // Do not advance `j`: the next head has just shifted into slot `j`.
                    self.merge_heads(i, j);
                } else {
                    j += 1;
                }
            }
            i += 1;
        }
    }

    /// Returns the counters accumulated by this stack.
    pub fn get_stats(&self) -> &ArenaGSSStats {
        &self.stats
    }
}
/// Owns an arena that can back many [`ArenaGSS`] sessions in turn.
///
/// Every session created by [`ArenaGSSManager::new_session`] allocates into the same arena, so
/// nodes stay allocated until [`ArenaGSSManager::clear`] runs or the manager is dropped.
pub struct ArenaGSSManager {
    // The `'static` is a placeholder: stored nodes really borrow from this arena, and the arena is
    // only ever handed out re-typed to the lifetime of a `&self` borrow (see `new_session`).
    arena: Arena<ArenaStackNode<'static>>,
}

impl Default for ArenaGSSManager {
    fn default() -> Self {
        Self::new()
    }
}

impl ArenaGSSManager {
    /// Creates a manager with an empty arena.
    pub fn new() -> Self {
        Self {
            arena: Arena::new(),
        }
    }

    /// Starts a new stack whose nodes are allocated in this manager's arena.
    ///
    /// The returned session borrows `self`, so the arena cannot be cleared or dropped while the
    /// session is alive.
    pub fn new_session<'a>(&'a self, initial_state: StateId) -> ArenaGSS<'a> {
        let ptr: *const Arena<ArenaStackNode<'a>> =
            (&self.arena as *const Arena<ArenaStackNode<'static>>).cast();
        // SAFETY: the two arena types differ only in a lifetime parameter, so they have identical
        // layout, and `ptr` comes from `&'a self`, so it points at a live arena for all of `'a`.
        // The cast is needed only because `Arena` is invariant in its element type. Nodes stored
        // through the returned reference point at other nodes of the same arena, and typed_arena
        // does not move allocated values while the arena lives. The arena can only be dropped by
        // `clear` (needs `&mut self`) or by dropping the manager, and neither can happen while
        // the `'a` borrow is live. `ArenaStackNode` has no `Drop` impl, so dropping the arena
        // never dereferences the stored parent links. This type must never read stored nodes
        // back out as `ArenaStackNode<'static>`.
        let arena: &'a Arena<ArenaStackNode<'a>> = unsafe { &*ptr };
        ArenaGSS::new(arena, initial_state)
    }

    /// Frees every node allocated by previous sessions by replacing the arena.
    ///
    /// Requires `&mut self`, so no session borrowing the old arena can still exist.
    pub fn clear(&mut self) {
        self.arena = Arena::new();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Pushing and forking: forked heads diverge once each is pushed independently.
    #[test]
    fn test_arena_gss_basic() {
        let arena = Arena::new();
        let mut gss = ArenaGSS::new(&arena, StateId(0));
        gss.push(0, StateId(1), Some(SymbolId(10)));
        gss.push(0, StateId(2), Some(SymbolId(20)));
        assert_eq!(gss.top_state(0), StateId(2));
        let fork_idx = gss.fork_head(0);
        assert_eq!(gss.active_heads.len(), 2);
        gss.push(0, StateId(3), None);
        gss.push(fork_idx, StateId(4), None);
        assert_ne!(gss.top_state(0), gss.top_state(fork_idx));
    }

    // Forks share the same nodes rather than copying them.
    #[test]
    fn test_arena_gss_shared_memory() {
        let arena = Arena::new();
        let mut gss = ArenaGSS::new(&arena, StateId(0));
        gss.push(0, StateId(1), None);
        gss.push(0, StateId(2), None);
        let fork1 = gss.fork_head(0);
        let fork2 = gss.fork_head(0);
        assert!(gss.active_heads[0].shares_prefix_with(gss.active_heads[fork1]));
        assert!(gss.active_heads[0].shares_prefix_with(gss.active_heads[fork2]));
        assert!(std::ptr::eq(
            gss.active_heads[0].parent.unwrap(),
            gss.active_heads[fork1].parent.unwrap()
        ));
    }

    // Sessions created through the manager allocate into the manager's arena.
    #[test]
    fn test_arena_manager() {
        let manager = ArenaGSSManager::new();
        {
            let mut gss = manager.new_session(StateId(0));
            gss.push(0, StateId(1), None);
            gss.push(0, StateId(2), None);
            assert_eq!(gss.top_state(0), StateId(2));
            assert_eq!(gss.stats.total_nodes_created, 3);
        }
    }

    // Popping within the stack returns states bottom-first and moves the head down;
    // popping far more than the stack holds is rejected without moving the head.
    #[test]
    fn test_arena_gss_pop() {
        let arena = Arena::new();
        let mut gss = ArenaGSS::new(&arena, StateId(0));
        gss.push(0, StateId(1), None);
        gss.push(0, StateId(2), None);
        assert_eq!(gss.pop(0, 2), Some(vec![StateId(1), StateId(2)]));
        assert_eq!(gss.top_state(0), StateId(0));
        assert_eq!(gss.pop(0, 5), None);
        assert_eq!(gss.top_state(0), StateId(0));
    }

    // Identical heads are merged; heads with different top states survive.
    #[test]
    fn test_arena_gss_deduplicate() {
        let arena = Arena::new();
        let mut gss = ArenaGSS::new(&arena, StateId(0));
        gss.push(0, StateId(1), None);
        let fork = gss.fork_head(0);
        gss.fork_head(0);
        gss.push(fork, StateId(2), None);
        gss.deduplicate();
        assert_eq!(gss.active_heads.len(), 2);
        assert_eq!(gss.stats.total_merges, 1);
        assert_eq!(gss.get_stats().max_active_heads, 3);
    }

    // get_states lists the path from the bottom node to the head, bottom first.
    #[test]
    fn test_get_states_order() {
        let arena = Arena::new();
        let mut gss = ArenaGSS::new(&arena, StateId(0));
        gss.push(0, StateId(1), None);
        gss.push(0, StateId(2), None);
        assert_eq!(
            gss.active_heads[0].get_states(),
            vec![StateId(0), StateId(1), StateId(2)]
        );
    }
}