use std::{
collections::{BTreeMap, BTreeSet},
num::NonZeroU64,
};
use bimap::{BiMap, Overwritten};
use machine_check_common::StateId;
use crate::{AbstrPanicState, WrappedState};
use mck::{concr::FullMachine, misc::MetaWrap};
use std::fmt::Debug;
/// Bidirectional store that assigns a unique [`StateId`] to each distinct state.
///
/// IDs come from a counter that only ever increases, so an ID is never handed
/// out twice, even after the state it named has been removed.
pub struct StateStore<M>
where
    M: FullMachine,
{
    /// One-to-one association between state IDs and their wrapped state data.
    map: BiMap<StateId, WrappedState<M>>,
    /// The ID that the next newly inserted state will receive.
    next_state_id: StateId,
}
impl<M> Debug for StateStore<M>
where
    M: FullMachine,
{
    /// Formats the store as a struct, listing the map entries in ascending
    /// `StateId` order followed by the next ID to be assigned.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Copy the entry references into a BTreeMap so they print sorted by ID.
        let sorted_entries: BTreeMap<_, _> = self.map.iter().collect();
        let mut builder = f.debug_struct("StateStore");
        builder.field("map", &sorted_entries);
        builder.field("next_state_id", &self.next_state_id);
        builder.finish()
    }
}
impl<M: FullMachine> StateStore<M> {
    /// Creates an empty store; the first state inserted receives
    /// `StateId(NonZeroU64::MIN)`.
    pub fn new() -> Self {
        Self {
            map: BiMap::new(),
            next_state_id: StateId(NonZeroU64::MIN),
        }
    }

    /// Returns the ID of `state`, assigning a fresh ID if it is not yet stored.
    ///
    /// The boolean is `true` when the state was newly inserted and `false`
    /// when an identical state was already present.
    ///
    /// # Panics
    ///
    /// Panics if the ID counter would overflow `u64`.
    pub fn state_id(&mut self, state: AbstrPanicState<M>) -> (StateId, bool) {
        let state = MetaWrap(state);
        if let Some(&existing_id) = self.map.get_by_right(&state) {
            return (existing_id, false);
        }

        let inserted_state_id = self.next_state_id;
        // Compute the successor before touching the map so that an overflow
        // panic leaves the store unchanged.
        let successor = inserted_state_id
            .0
            .checked_add(1)
            .expect("Next state id counter should not overflow");

        // The state was just checked to be absent, and the counter never hands
        // out an ID twice, so neither side of the bimap can be overwritten.
        // The insertion is kept outside `assert!` so it is not a hidden side
        // effect of the check.
        let overwritten = self.map.insert(inserted_state_id, state);
        assert!(matches!(overwritten, Overwritten::Neither));

        self.next_state_id.0 = successor;
        (inserted_state_id, true)
    }

    /// Returns the state stored under `state_id`.
    ///
    /// # Panics
    ///
    /// Panics if no state with this ID is present in the store.
    pub fn state_data(&self, state_id: StateId) -> &AbstrPanicState<M> {
        let Some(state) = self.map.get_by_left(&state_id) else {
            panic!("State {} should be in state map", state_id);
        };
        &state.0
    }

    /// Removes every stored state whose ID is not in `retained_states` and
    /// returns the IDs that were removed.
    ///
    /// IDs in `retained_states` that are not present in the store are ignored.
    /// Removed IDs are never reassigned, because the ID counter is not rewound.
    pub fn retain_states(&mut self, retained_states: &BTreeSet<StateId>) -> BTreeSet<StateId> {
        let mut removed_states = BTreeSet::new();
        // Single pass: decide, record removals, and drop entries together.
        self.map.retain(|state_id, _state_data| {
            let keep = retained_states.contains(state_id);
            if !keep {
                removed_states.insert(*state_id);
            }
            keep
        });
        removed_states
    }
}