use std::collections::HashMap;
use elara_core::{NodeId, StateAtom, StateId, StateTime, StateType};
/// Container for state atoms, keyed by their [`StateId`], together with a
/// holding area for events that cannot be applied yet.
#[derive(Debug, Default)]
pub struct StateField {
/// Every atom currently stored in the field, indexed by its id.
pub atoms: HashMap<StateId, StateAtom>,
/// Events parked by `quarantine` until `release_quarantine` finds all of
/// their listed dependencies in `atoms`.
quarantine: Vec<QuarantinedEvent>,
}
/// An event held back by a [`StateField`] until the atoms it depends on exist.
#[derive(Debug)]
pub struct QuarantinedEvent {
/// Event payload; handed back unchanged by `StateField::release_quarantine`.
pub event_data: Vec<u8>,
/// Ids that must all be present in the field before the event is released.
pub missing_deps: Vec<StateId>,
/// Time value supplied by the caller when the event was quarantined.
/// NOTE(review): nothing in this file reads it — presumably meant for
/// expiry or diagnostics; confirm against callers.
pub quarantined_at: StateTime,
}
impl StateField {
    /// Creates an empty field: no atoms and nothing in quarantine.
    pub fn new() -> Self {
        StateField::default()
    }

    /// Returns a shared reference to the atom stored under `id`, if any.
    pub fn get(&self, id: StateId) -> Option<&StateAtom> {
        self.atoms.get(&id)
    }

    /// Returns an exclusive reference to the atom stored under `id`, if any.
    pub fn get_mut(&mut self, id: StateId) -> Option<&mut StateAtom> {
        self.atoms.get_mut(&id)
    }

    /// Stores `atom` under its own `atom.id`.
    ///
    /// An atom already stored under that id is replaced and dropped.
    pub fn insert(&mut self, atom: StateAtom) {
        self.atoms.insert(atom.id, atom);
    }

    /// Removes the atom stored under `id` and returns it, or `None` if the id
    /// is not present.
    pub fn remove(&mut self, id: StateId) -> Option<StateAtom> {
        self.atoms.remove(&id)
    }

    /// Returns `true` if an atom is stored under `id`.
    pub fn contains(&self, id: StateId) -> bool {
        self.atoms.contains_key(&id)
    }

    /// Number of atoms in the field. Quarantined events are not counted.
    pub fn len(&self) -> usize {
        self.atoms.len()
    }

    /// Returns `true` if the field holds no atoms. Quarantined events are not
    /// considered.
    pub fn is_empty(&self) -> bool {
        self.atoms.is_empty()
    }

    /// Iterates over every `(id, atom)` pair.
    ///
    /// The order is that of the underlying `HashMap`, i.e. unspecified.
    pub fn iter(&self) -> impl Iterator<Item = (&StateId, &StateAtom)> {
        self.atoms.iter()
    }

    /// Iterates over the atoms whose `state_type` equals `state_type`, in
    /// unspecified order.
    pub fn iter_by_type(&self, state_type: StateType) -> impl Iterator<Item = &StateAtom> {
        self.atoms
            .values()
            .filter(move |atom| atom.state_type == state_type)
    }

    /// Collects the ids of all atoms for which
    /// `StateAtom::needs_prediction(threshold_ms)` returns `true`.
    ///
    /// `threshold_ms` is forwarded unchanged; its exact meaning is defined by
    /// `StateAtom`. The returned ids are in unspecified order.
    pub fn atoms_needing_prediction(&self, threshold_ms: u64) -> Vec<StateId> {
        self.atoms
            .iter()
            .filter(|(_, atom)| atom.needs_prediction(threshold_ms))
            .map(|(id, _)| *id)
            .collect()
    }

    /// Parks an event until its dependencies are available.
    ///
    /// * `event_data` - payload returned later by [`Self::release_quarantine`].
    /// * `missing_deps` - ids that must all exist in the field before release.
    /// * `now` - timestamp recorded on the quarantined entry.
    pub fn quarantine(&mut self, event_data: Vec<u8>, missing_deps: Vec<StateId>, now: StateTime) {
        self.quarantine.push(QuarantinedEvent {
            event_data,
            missing_deps,
            quarantined_at: now,
        });
    }

    /// Releases every quarantined event whose dependencies are now all present.
    ///
    /// Returns the payloads of the released events in the order they were
    /// quarantined. Events that still miss at least one dependency stay in
    /// quarantine, also in their original order. An event with an empty
    /// dependency list is always released.
    pub fn release_quarantine(&mut self) -> Vec<Vec<u8>> {
        let atoms = &self.atoms;
        // Take ownership of the queue so entries can be moved into either
        // bucket; the unreleased ones are stored back right below.
        let (ready, still_waiting): (Vec<_>, Vec<_>) = std::mem::take(&mut self.quarantine)
            .into_iter()
            .partition(|event| event.missing_deps.iter().all(|dep| atoms.contains_key(dep)));
        self.quarantine = still_waiting;
        ready.into_iter().map(|event| event.event_data).collect()
    }

    /// Number of events currently held in quarantine.
    pub fn quarantine_size(&self) -> usize {
        self.quarantine.len()
    }

    /// Sum of `StateAtom::memory_size()` over all atoms.
    ///
    /// Quarantined event buffers and the map's own bookkeeping are not
    /// included.
    pub fn memory_size(&self) -> usize {
        self.atoms.values().map(|atom| atom.memory_size()).sum()
    }

    /// Builds a fresh atom with `StateAtom::new(id, state_type, owner)`, stores
    /// it under `id` and returns a mutable reference to the stored value.
    ///
    /// An atom already stored under `id` is replaced by the new one.
    pub fn create_atom(
        &mut self,
        id: StateId,
        state_type: StateType,
        owner: NodeId,
    ) -> &mut StateAtom {
        let atom = StateAtom::new(id, state_type, owner);
        // A single entry lookup both stores the atom and yields the reference,
        // so no second lookup or `unwrap` is needed.
        match self.atoms.entry(id) {
            std::collections::hash_map::Entry::Occupied(mut slot) => {
                slot.insert(atom);
                slot.into_mut()
            }
            std::collections::hash_map::Entry::Vacant(slot) => slot.insert(atom),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created atom is stored, counted and retrievable with the
    /// type it was created with.
    #[test]
    fn test_state_field_basic() {
        let node = NodeId::new(1);
        let atom_id = StateId::new(100);
        let mut state = StateField::new();

        state.create_atom(atom_id, StateType::Core, node);

        assert!(state.contains(atom_id));
        assert_eq!(state.len(), 1);
        let stored = state.get(atom_id).unwrap();
        assert_eq!(stored.state_type, StateType::Core);
    }

    /// `iter_by_type` yields only the atoms of the requested type.
    #[test]
    fn test_state_field_iter_by_type() {
        let node = NodeId::new(1);
        let mut state = StateField::new();

        // Two core atoms and one perceptual atom.
        for (raw_id, kind) in [
            (1, StateType::Core),
            (2, StateType::Perceptual),
            (3, StateType::Core),
        ] {
            state.create_atom(StateId::new(raw_id), kind, node);
        }

        let core_count = state.iter_by_type(StateType::Core).count();
        assert_eq!(core_count, 2);
    }
}