use std::collections::{HashMap, HashSet};
use libpetri_event::net_event::NetEvent;
use crate::debug_response::TokenInfo;
/// Number of events between consecutive cached snapshots in [`MarkingCache`] (256).
pub const SNAPSHOT_INTERVAL: usize = 1 << 8;
/// Net state reconstructed by replaying a prefix of the event log.
#[derive(Debug, Clone)]
pub struct ComputedState {
    /// Tokens currently held by each place, keyed by place name.
    pub marking: HashMap<String, Vec<TokenInfo>>,
    /// Names of transitions that have been enabled and not yet started.
    pub enabled_transitions: Vec<String>,
    /// Names of transitions that have started and not yet completed, failed, or timed out.
    pub in_flight_transitions: Vec<String>,
}
/// Caches periodic replay snapshots of an event log.
///
/// A snapshot is taken every [`SNAPSHOT_INTERVAL`] events so that computing the
/// state at an arbitrary index only needs to replay the events after the
/// nearest earlier snapshot.
pub struct MarkingCache {
    /// `snapshots[i]` is the state after the first `(i + 1) * SNAPSHOT_INTERVAL` events.
    snapshots: Vec<ComputedState>,
}
impl MarkingCache {
    /// Creates an empty cache; snapshots are built lazily by
    /// [`compute_at`](Self::compute_at).
    pub fn new() -> Self {
        Self {
            snapshots: Vec::new(),
        }
    }

    /// Returns the state after applying the first `target_index` events.
    ///
    /// `target_index` is clamped to `events.len()`. Any missing snapshots up to
    /// that point are built first; the result is then the nearest snapshot at or
    /// before the target, with the remaining events replayed on top of it.
    ///
    /// Snapshots already in the cache are reused as-is, so call
    /// [`invalidate`](Self::invalidate) whenever previously seen events change.
    pub fn compute_at(&mut self, events: &[NetEvent], target_index: usize) -> ComputedState {
        let end = target_index.min(events.len());
        if end == 0 {
            return compute_state(&[]);
        }
        self.ensure_cached_up_to(events, end);

        // Number of snapshots covering at most `end` events. Working with the
        // clamped `end` (not `target_index`) lets a target past the end of the
        // log reuse the last snapshot instead of replaying everything, and keeps
        // every slice below in bounds.
        let usable = (end / SNAPSHOT_INTERVAL).min(self.snapshots.len());
        if usable == 0 {
            // Target lies before the first snapshot: replay from scratch.
            return compute_state(&events[..end]);
        }

        let snapshot = &self.snapshots[usable - 1];
        let snapshot_event_index = usable * SNAPSHOT_INTERVAL;
        if snapshot_event_index == end {
            snapshot.clone()
        } else {
            replay_delta(snapshot, &events[snapshot_event_index..end])
        }
    }

    /// Drops all cached snapshots so they are rebuilt on the next query.
    pub fn invalidate(&mut self) {
        self.snapshots.clear();
    }

    /// Appends snapshots until every multiple of [`SNAPSHOT_INTERVAL`] up to
    /// `target_index` is cached, stopping early if `events` is too short.
    fn ensure_cached_up_to(&mut self, events: &[NetEvent], target_index: usize) {
        let needed_snapshots = target_index / SNAPSHOT_INTERVAL;
        while self.snapshots.len() < needed_snapshots {
            let prev_snapshot_index = self.snapshots.len() * SNAPSHOT_INTERVAL;
            let next_snapshot_index = prev_snapshot_index + SNAPSHOT_INTERVAL;
            if next_snapshot_index > events.len() {
                break;
            }
            let delta = &events[prev_snapshot_index..next_snapshot_index];
            // The first snapshot starts from the empty state; later ones build
            // on the previous snapshot so each event is replayed only once.
            let state = match self.snapshots.last() {
                Some(prev) => replay_delta(prev, delta),
                None => compute_state(delta),
            };
            self.snapshots.push(state);
        }
    }
}
impl Default for MarkingCache {
fn default() -> Self {
Self::new()
}
}
/// Replays `events` starting from an empty marking with no enabled or
/// in-flight transitions, and returns the resulting state.
pub fn compute_state(events: &[NetEvent]) -> ComputedState {
    let (mut marking, mut enabled, mut in_flight) =
        (HashMap::new(), HashSet::new(), HashSet::new());
    apply_events(&mut marking, &mut enabled, &mut in_flight, events);
    to_computed_state(marking, enabled, in_flight)
}
/// Applies `events`, in order, to the mutable replay state.
///
/// * `marking` — tokens per place name. `TokenAdded` appends a placeholder
///   token (type `"unknown"`, no id or value) stamped with the event's
///   timestamp; `TokenRemoved` drops the first (earliest-added) token of that
///   place, if it has any; `MarkingSnapshot` replaces the whole marking with
///   untimestamped placeholder tokens matching the snapshot's per-place counts.
/// * `enabled` — a transition name is added on `TransitionEnabled` and removed
///   on `TransitionStarted`.
/// * `in_flight` — a transition name is added on `TransitionStarted` and removed
///   on `TransitionCompleted`, `TransitionFailed`, `TransitionTimedOut`, or
///   `ActionTimedOut`.
///
/// Any other event variant is ignored.
pub fn apply_events(
    marking: &mut HashMap<String, Vec<TokenInfo>>,
    enabled: &mut HashSet<String>,
    in_flight: &mut HashSet<String>,
    events: &[NetEvent],
) {
    events.iter().for_each(|event| match event {
        NetEvent::TokenAdded {
            place_name,
            timestamp,
        } => {
            let token = TokenInfo {
                id: None,
                token_type: "unknown".into(),
                value: None,
                timestamp: Some(timestamp.to_string()),
            };
            marking.entry(place_name.to_string()).or_default().push(token);
        }
        NetEvent::TokenRemoved { place_name, .. } => {
            let place_tokens = marking.get_mut(place_name.as_ref());
            if let Some(tokens) = place_tokens.filter(|tokens| !tokens.is_empty()) {
                tokens.remove(0);
            }
        }
        NetEvent::MarkingSnapshot {
            marking: snapshot, ..
        } => {
            marking.clear();
            for (place, count) in snapshot {
                let placeholders: Vec<TokenInfo> = (0..*count)
                    .map(|_| TokenInfo {
                        id: None,
                        token_type: "unknown".into(),
                        value: None,
                        timestamp: None,
                    })
                    .collect();
                marking.insert(place.to_string(), placeholders);
            }
        }
        NetEvent::TransitionEnabled {
            transition_name, ..
        } => {
            enabled.insert(transition_name.to_string());
        }
        NetEvent::TransitionStarted {
            transition_name, ..
        } => {
            // A started transition moves from "enabled" to "in flight".
            enabled.remove(transition_name.as_ref());
            in_flight.insert(transition_name.to_string());
        }
        NetEvent::TransitionCompleted {
            transition_name, ..
        }
        | NetEvent::TransitionFailed {
            transition_name, ..
        }
        | NetEvent::TransitionTimedOut {
            transition_name, ..
        }
        | NetEvent::ActionTimedOut {
            transition_name, ..
        } => {
            in_flight.remove(transition_name.as_ref());
        }
        _ => {}
    });
}
/// Converts the mutable replay accumulators into a [`ComputedState`].
///
/// Transition names are sorted so that equal states always produce identical
/// output: `HashSet` iteration order is unspecified and differs between set
/// instances, which would otherwise make the lists jitter on every replay.
fn to_computed_state(
    marking: HashMap<String, Vec<TokenInfo>>,
    enabled: HashSet<String>,
    in_flight: HashSet<String>,
) -> ComputedState {
    // Names within a set are unique, so an unstable sort is fully deterministic.
    let into_sorted_vec = |set: HashSet<String>| {
        let mut names: Vec<String> = set.into_iter().collect();
        names.sort_unstable();
        names
    };
    ComputedState {
        marking,
        enabled_transitions: into_sorted_vec(enabled),
        in_flight_transitions: into_sorted_vec(in_flight),
    }
}
/// Returns the state obtained by applying `delta` on top of `base`.
///
/// `base` is left untouched: its marking is cloned and its transition lists are
/// copied into fresh sets before the events are replayed.
fn replay_delta(base: &ComputedState, delta: &[NetEvent]) -> ComputedState {
    let mut marking = base.marking.clone();
    let mut enabled: HashSet<String> = base.enabled_transitions.iter().cloned().collect();
    let mut in_flight: HashSet<String> = base.in_flight_transitions.iter().cloned().collect();
    apply_events(&mut marking, &mut enabled, &mut in_flight, delta);
    to_computed_state(marking, enabled, in_flight)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    fn token_added(place: &str, ts: u64) -> NetEvent {
        NetEvent::TokenAdded {
            place_name: Arc::from(place),
            timestamp: ts,
        }
    }

    fn token_removed(place: &str, ts: u64) -> NetEvent {
        NetEvent::TokenRemoved {
            place_name: Arc::from(place),
            timestamp: ts,
        }
    }

    fn transition_enabled(name: &str, ts: u64) -> NetEvent {
        NetEvent::TransitionEnabled {
            transition_name: Arc::from(name),
            timestamp: ts,
        }
    }

    fn transition_started(name: &str, ts: u64) -> NetEvent {
        NetEvent::TransitionStarted {
            transition_name: Arc::from(name),
            timestamp: ts,
        }
    }

    fn transition_completed(name: &str, ts: u64) -> NetEvent {
        NetEvent::TransitionCompleted {
            transition_name: Arc::from(name),
            timestamp: ts,
        }
    }

    /// `n` consecutive `TokenAdded` events on place `p1`.
    fn p1_tokens(n: u64) -> Vec<NetEvent> {
        (0..n).map(|i| token_added("p1", i)).collect()
    }

    fn p1_len(state: &ComputedState) -> Option<usize> {
        state.marking.get("p1").map(|t| t.len())
    }

    #[test]
    fn compute_empty_state() {
        let state = compute_state(&[]);
        assert!(state.marking.is_empty());
        assert!(state.enabled_transitions.is_empty());
        assert!(state.in_flight_transitions.is_empty());
    }

    #[test]
    fn compute_state_with_tokens() {
        let events = [
            token_added("p1", 0),
            token_added("p1", 1),
            token_added("p2", 2),
            token_removed("p1", 3),
        ];
        let state = compute_state(&events);
        assert_eq!(state.marking.get("p1").map(|t| t.len()), Some(1));
        assert_eq!(state.marking.get("p2").map(|t| t.len()), Some(1));
    }

    #[test]
    fn compute_state_with_transitions() {
        let events = [
            transition_enabled("t1", 0),
            transition_started("t1", 1),
            transition_completed("t1", 2),
        ];
        let state = compute_state(&events);
        assert!(state.enabled_transitions.is_empty());
        assert!(state.in_flight_transitions.is_empty());
    }

    #[test]
    fn compute_state_in_flight() {
        let events = [transition_enabled("t1", 0), transition_started("t1", 1)];
        let state = compute_state(&events);
        assert!(state.enabled_transitions.is_empty());
        assert!(state.in_flight_transitions.contains(&"t1".to_string()));
    }

    #[test]
    fn marking_cache_basic() {
        let mut cache = MarkingCache::new();
        let events = p1_tokens(10);
        let state = cache.compute_at(&events, 5);
        assert_eq!(p1_len(&state), Some(5));
        let state = cache.compute_at(&events, 10);
        assert_eq!(p1_len(&state), Some(10));
    }

    #[test]
    fn marking_cache_with_snapshots() {
        let mut cache = MarkingCache::new();
        let events = p1_tokens(512);
        let state = cache.compute_at(&events, 300);
        assert_eq!(p1_len(&state), Some(300));
        let state = cache.compute_at(&events, 260);
        assert_eq!(p1_len(&state), Some(260));
    }

    #[test]
    fn marking_cache_invalidate() {
        let mut cache = MarkingCache::new();
        let events = p1_tokens(512);
        let _ = cache.compute_at(&events, 300);
        cache.invalidate();
        let state = cache.compute_at(&events, 300);
        assert_eq!(p1_len(&state), Some(300));
    }

    #[test]
    fn marking_cache_zero_target_is_empty() {
        let mut cache = MarkingCache::new();
        let events = p1_tokens(10);
        let state = cache.compute_at(&events, 0);
        assert!(state.marking.is_empty());
        assert!(state.enabled_transitions.is_empty());
        assert!(state.in_flight_transitions.is_empty());
    }

    #[test]
    fn marking_cache_empty_events() {
        let mut cache = MarkingCache::new();
        let state = cache.compute_at(&[], 42);
        assert!(state.marking.is_empty());
        assert!(state.enabled_transitions.is_empty());
        assert!(state.in_flight_transitions.is_empty());
    }

    #[test]
    fn marking_cache_exact_snapshot_boundary() {
        let mut cache = MarkingCache::new();
        let events = p1_tokens(512);
        let state = cache.compute_at(&events, SNAPSHOT_INTERVAL);
        assert_eq!(p1_len(&state), Some(SNAPSHOT_INTERVAL));
        let state = cache.compute_at(&events, 2 * SNAPSHOT_INTERVAL);
        assert_eq!(p1_len(&state), Some(2 * SNAPSHOT_INTERVAL));
    }

    #[test]
    fn marking_cache_target_past_end_is_clamped() {
        let mut cache = MarkingCache::new();
        let events = p1_tokens(512);
        let state = cache.compute_at(&events, 600);
        assert_eq!(p1_len(&state), Some(512));

        let mut cache = MarkingCache::new();
        let events = p1_tokens(300);
        let state = cache.compute_at(&events, 1000);
        assert_eq!(p1_len(&state), Some(300));
    }

    #[test]
    fn marking_cache_before_first_snapshot_after_caching() {
        let mut cache = MarkingCache::new();
        let events = p1_tokens(512);
        let _ = cache.compute_at(&events, 300);
        let state = cache.compute_at(&events, 100);
        assert_eq!(p1_len(&state), Some(100));
    }

    #[test]
    fn marking_cache_replays_transitions_across_snapshot() {
        let mut cache = MarkingCache::new();
        let mut events = vec![transition_enabled("t1", 0)];
        events.extend((1..300).map(|i| token_added("p1", i)));
        events.push(transition_started("t1", 300));
        let state = cache.compute_at(&events, events.len());
        assert_eq!(p1_len(&state), Some(299));
        assert!(state.enabled_transitions.is_empty());
        assert!(state.in_flight_transitions.contains(&"t1".to_string()));
    }

    #[test]
    fn marking_cache_matches_full_replay() {
        let events: Vec<NetEvent> = (0..600u64)
            .map(|i| {
                if i % 3 == 2 {
                    token_removed("p1", i)
                } else {
                    token_added("p1", i)
                }
            })
            .collect();
        let mut cache = MarkingCache::new();
        for &k in &[0usize, 1, 255, 256, 257, 511, 512, 513, 600, 700] {
            let cached = cache.compute_at(&events, k);
            let full = compute_state(&events[..k.min(events.len())]);
            assert_eq!(p1_len(&cached), p1_len(&full), "mismatch at index {k}");
        }
    }
}