Skip to main content

libpetri_debug/
marking_cache.rs

1//! Caches computed state snapshots at periodic intervals for efficient seek/step.
2
3use std::collections::{HashMap, HashSet};
4
5use libpetri_event::net_event::NetEvent;
6
7use crate::debug_response::TokenInfo;
8
/// Spacing, measured in events, between two consecutive cached snapshots.
///
/// Slot `k` of a [`MarkingCache`] holds the state reached after the first
/// `(k + 1) * SNAPSHOT_INTERVAL` events of the log have been replayed.
pub const SNAPSHOT_INTERVAL: usize = 256;
11
12/// Computed state from replaying events.
13#[derive(Debug, Clone)]
14pub struct ComputedState {
15    pub marking: HashMap<String, Vec<TokenInfo>>,
16    pub enabled_transitions: Vec<String>,
17    pub in_flight_transitions: Vec<String>,
18}
19
20/// Cache of computed state snapshots for efficient seek/step operations.
21pub struct MarkingCache {
22    snapshots: Vec<ComputedState>,
23}
24
25impl MarkingCache {
26    /// Creates a new empty cache.
27    pub fn new() -> Self {
28        Self {
29            snapshots: Vec::new(),
30        }
31    }
32
33    /// Computes the state at the given event index, using cached snapshots
34    /// to minimize the number of events that need to be replayed.
35    pub fn compute_at(&mut self, events: &[NetEvent], target_index: usize) -> ComputedState {
36        if target_index == 0 {
37            return compute_state(&[]);
38        }
39
40        self.ensure_cached_up_to(events, target_index);
41
42        if self.snapshots.is_empty() {
43            return compute_state(&events[..target_index.min(events.len())]);
44        }
45
46        // Find highest snapshot <= target_index
47        let snapshot_slot = (target_index / SNAPSHOT_INTERVAL)
48            .min(self.snapshots.len())
49            .saturating_sub(1);
50        let snapshot_event_index = (snapshot_slot + 1) * SNAPSHOT_INTERVAL;
51
52        if snapshot_event_index == target_index {
53            return self.snapshots[snapshot_slot].clone();
54        }
55
56        let end = target_index.min(events.len());
57        if snapshot_event_index >= end {
58            return compute_state(&events[..end]);
59        }
60
61        replay_delta(
62            &self.snapshots[snapshot_slot],
63            &events[snapshot_event_index..end],
64        )
65    }
66
67    /// Invalidates the cache.
68    pub fn invalidate(&mut self) {
69        self.snapshots.clear();
70    }
71
72    fn ensure_cached_up_to(&mut self, events: &[NetEvent], target_index: usize) {
73        let needed_snapshots = target_index / SNAPSHOT_INTERVAL;
74
75        while self.snapshots.len() < needed_snapshots {
76            let next_snapshot_index = (self.snapshots.len() + 1) * SNAPSHOT_INTERVAL;
77            if next_snapshot_index > events.len() {
78                break;
79            }
80
81            if self.snapshots.is_empty() {
82                self.snapshots
83                    .push(compute_state(&events[..next_snapshot_index]));
84            } else {
85                let prev_snapshot_index = self.snapshots.len() * SNAPSHOT_INTERVAL;
86                let delta = &events[prev_snapshot_index..next_snapshot_index];
87                let state = replay_delta(self.snapshots.last().unwrap(), delta);
88                self.snapshots.push(state);
89            }
90        }
91    }
92}
93
94impl Default for MarkingCache {
95    fn default() -> Self {
96        Self::new()
97    }
98}
99
100/// Computes marking, enabled transitions, and in-flight transitions from events.
101pub fn compute_state(events: &[NetEvent]) -> ComputedState {
102    let mut marking = HashMap::new();
103    let mut enabled = HashSet::new();
104    let mut in_flight = HashSet::new();
105    apply_events(&mut marking, &mut enabled, &mut in_flight, events);
106    to_computed_state(marking, enabled, in_flight)
107}
108
109/// Applies events to mutable accumulator collections.
110pub fn apply_events(
111    marking: &mut HashMap<String, Vec<TokenInfo>>,
112    enabled: &mut HashSet<String>,
113    in_flight: &mut HashSet<String>,
114    events: &[NetEvent],
115) {
116    for event in events {
117        match event {
118            NetEvent::TokenAdded {
119                place_name,
120                timestamp,
121            } => {
122                marking
123                    .entry(place_name.to_string())
124                    .or_default()
125                    .push(TokenInfo {
126                        id: None,
127                        token_type: "unknown".into(),
128                        value: None,
129                        timestamp: Some(timestamp.to_string()),
130                    });
131            }
132            NetEvent::TokenRemoved { place_name, .. } => {
133                if let Some(tokens) = marking.get_mut(place_name.as_ref()) {
134                    if !tokens.is_empty() {
135                        tokens.remove(0);
136                    }
137                }
138            }
139            NetEvent::MarkingSnapshot { marking: m, .. } => {
140                marking.clear();
141                for (name, count) in m {
142                    let tokens = (0..*count)
143                        .map(|_| TokenInfo {
144                            id: None,
145                            token_type: "unknown".into(),
146                            value: None,
147                            timestamp: None,
148                        })
149                        .collect();
150                    marking.insert(name.to_string(), tokens);
151                }
152            }
153            NetEvent::TransitionEnabled {
154                transition_name, ..
155            } => {
156                enabled.insert(transition_name.to_string());
157            }
158            NetEvent::TransitionStarted {
159                transition_name, ..
160            } => {
161                enabled.remove(transition_name.as_ref());
162                in_flight.insert(transition_name.to_string());
163            }
164            NetEvent::TransitionCompleted {
165                transition_name, ..
166            }
167            | NetEvent::TransitionFailed {
168                transition_name, ..
169            }
170            | NetEvent::TransitionTimedOut {
171                transition_name, ..
172            }
173            | NetEvent::ActionTimedOut {
174                transition_name, ..
175            } => {
176                in_flight.remove(transition_name.as_ref());
177            }
178            _ => {}
179        }
180    }
181}
182
183fn to_computed_state(
184    marking: HashMap<String, Vec<TokenInfo>>,
185    enabled: HashSet<String>,
186    in_flight: HashSet<String>,
187) -> ComputedState {
188    ComputedState {
189        marking,
190        enabled_transitions: enabled.into_iter().collect(),
191        in_flight_transitions: in_flight.into_iter().collect(),
192    }
193}
194
195fn replay_delta(base: &ComputedState, delta: &[NetEvent]) -> ComputedState {
196    let mut marking: HashMap<String, Vec<TokenInfo>> = base
197        .marking
198        .iter()
199        .map(|(k, v)| (k.clone(), v.clone()))
200        .collect();
201    let mut enabled: HashSet<String> = base.enabled_transitions.iter().cloned().collect();
202    let mut in_flight: HashSet<String> = base.in_flight_transitions.iter().cloned().collect();
203    apply_events(&mut marking, &mut enabled, &mut in_flight, delta);
204    to_computed_state(marking, enabled, in_flight)
205}
206
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    // --- Event constructors -------------------------------------------------

    fn token_added(place: &str, ts: u64) -> NetEvent {
        NetEvent::TokenAdded {
            place_name: Arc::from(place),
            timestamp: ts,
        }
    }

    fn token_removed(place: &str, ts: u64) -> NetEvent {
        NetEvent::TokenRemoved {
            place_name: Arc::from(place),
            timestamp: ts,
        }
    }

    fn transition_enabled(name: &str, ts: u64) -> NetEvent {
        NetEvent::TransitionEnabled {
            transition_name: Arc::from(name),
            timestamp: ts,
        }
    }

    fn transition_started(name: &str, ts: u64) -> NetEvent {
        NetEvent::TransitionStarted {
            transition_name: Arc::from(name),
            timestamp: ts,
        }
    }

    fn transition_completed(name: &str, ts: u64) -> NetEvent {
        NetEvent::TransitionCompleted {
            transition_name: Arc::from(name),
            timestamp: ts,
        }
    }

    /// Builds a log of `count` `TokenAdded` events, all on place `p1`.
    fn tokens_on_p1(count: u64) -> Vec<NetEvent> {
        (0..count).map(|i| token_added("p1", i)).collect()
    }

    /// Number of tokens the state holds on `place`, or `None` if the place is unknown.
    fn token_count(state: &ComputedState, place: &str) -> Option<usize> {
        state.marking.get(place).map(|t| t.len())
    }

    // --- compute_state ------------------------------------------------------

    #[test]
    fn compute_empty_state() {
        let state = compute_state(&[]);
        assert!(state.marking.is_empty());
        assert!(state.enabled_transitions.is_empty());
        assert!(state.in_flight_transitions.is_empty());
    }

    #[test]
    fn compute_state_with_tokens() {
        let events = [
            token_added("p1", 0),
            token_added("p1", 1),
            token_added("p2", 2),
            token_removed("p1", 3),
        ];
        let state = compute_state(&events);
        assert_eq!(token_count(&state, "p1"), Some(1));
        assert_eq!(token_count(&state, "p2"), Some(1));
    }

    #[test]
    fn compute_state_remove_from_empty_place_is_noop() {
        let events = [
            token_added("p1", 0),
            token_removed("p1", 1),
            token_removed("p1", 2),
            token_removed("p2", 3),
        ];
        let state = compute_state(&events);
        // The place stays known but empty; a never-seen place is not created.
        assert_eq!(token_count(&state, "p1"), Some(0));
        assert_eq!(token_count(&state, "p2"), None);
    }

    #[test]
    fn compute_state_with_transitions() {
        let events = [
            transition_enabled("t1", 0),
            transition_started("t1", 1),
            transition_completed("t1", 2),
        ];
        let state = compute_state(&events);
        assert!(state.enabled_transitions.is_empty());
        assert!(state.in_flight_transitions.is_empty());
    }

    #[test]
    fn compute_state_in_flight() {
        let events = [transition_enabled("t1", 0), transition_started("t1", 1)];
        let state = compute_state(&events);
        assert!(state.enabled_transitions.is_empty());
        assert!(state.in_flight_transitions.contains(&"t1".to_string()));
    }

    // --- MarkingCache -------------------------------------------------------

    #[test]
    fn marking_cache_basic() {
        let mut cache = MarkingCache::new();
        let events = tokens_on_p1(10);

        let state = cache.compute_at(&events, 5);
        assert_eq!(token_count(&state, "p1"), Some(5));

        let state = cache.compute_at(&events, 10);
        assert_eq!(token_count(&state, "p1"), Some(10));
    }

    #[test]
    fn marking_cache_index_zero_is_empty() {
        let mut cache = MarkingCache::default();
        let events = tokens_on_p1(10);

        let state = cache.compute_at(&events, 0);
        assert!(state.marking.is_empty());
        assert!(state.enabled_transitions.is_empty());
        assert!(state.in_flight_transitions.is_empty());
    }

    #[test]
    fn marking_cache_with_snapshots() {
        let mut cache = MarkingCache::new();
        // Create enough events to trigger snapshot creation
        let events = tokens_on_p1(512);

        let state = cache.compute_at(&events, 300);
        assert_eq!(token_count(&state, "p1"), Some(300));

        // Second query should use cached snapshot
        let state = cache.compute_at(&events, 260);
        assert_eq!(token_count(&state, "p1"), Some(260));
    }

    #[test]
    fn marking_cache_exact_snapshot_boundary() {
        let mut cache = MarkingCache::new();
        let events = tokens_on_p1(512);

        let state = cache.compute_at(&events, SNAPSHOT_INTERVAL);
        assert_eq!(token_count(&state, "p1"), Some(SNAPSHOT_INTERVAL));

        let state = cache.compute_at(&events, 2 * SNAPSHOT_INTERVAL);
        assert_eq!(token_count(&state, "p1"), Some(2 * SNAPSHOT_INTERVAL));
    }

    #[test]
    fn marking_cache_target_past_end_is_clamped() {
        // Log length on a snapshot boundary.
        let mut cache = MarkingCache::new();
        let events = tokens_on_p1(512);
        let state = cache.compute_at(&events, 1000);
        assert_eq!(token_count(&state, "p1"), Some(512));

        // Log length between two snapshot positions.
        let mut cache = MarkingCache::new();
        let events = tokens_on_p1(300);
        let state = cache.compute_at(&events, 1000);
        assert_eq!(token_count(&state, "p1"), Some(300));
    }

    #[test]
    fn marking_cache_small_target_with_warm_cache() {
        let mut cache = MarkingCache::new();
        let events = tokens_on_p1(512);

        let _ = cache.compute_at(&events, 300);
        // The cached snapshot lies after the target and must not be used.
        let state = cache.compute_at(&events, 5);
        assert_eq!(token_count(&state, "p1"), Some(5));
    }

    #[test]
    fn marking_cache_tracks_transitions_across_snapshot() {
        let mut events = vec![transition_enabled("t1", 0), transition_started("t1", 1)];
        events.extend((0..300).map(|i| token_added("p1", i + 2)));
        events.push(transition_completed("t1", 302));

        let mut cache = MarkingCache::new();

        // Position 280 is reached by replaying on top of the snapshot at 256.
        let state = cache.compute_at(&events, 280);
        assert!(state.in_flight_transitions.contains(&"t1".to_string()));
        assert_eq!(token_count(&state, "p1"), Some(278));

        let state = cache.compute_at(&events, events.len());
        assert!(state.in_flight_transitions.is_empty());
        assert_eq!(token_count(&state, "p1"), Some(300));
    }

    #[test]
    fn marking_cache_invalidate() {
        let mut cache = MarkingCache::new();
        let events = tokens_on_p1(512);

        let _ = cache.compute_at(&events, 300);
        cache.invalidate();

        // After invalidation, still produces correct results
        let state = cache.compute_at(&events, 300);
        assert_eq!(token_count(&state, "p1"), Some(300));
    }
}
326}