1use std::collections::{HashMap, HashSet};
4
5use libpetri_event::net_event::NetEvent;
6
7use crate::debug_response::TokenInfo;
8
/// Number of events between two consecutive cached snapshots.
///
/// `MarkingCache` stores one snapshot after every `SNAPSHOT_INTERVAL` events,
/// so a state lookup replays fewer than `SNAPSHOT_INTERVAL` events on top of a
/// cached snapshot whenever one is available.
pub const SNAPSHOT_INTERVAL: usize = 256;
11
12#[derive(Debug, Clone)]
14pub struct ComputedState {
15 pub marking: HashMap<String, Vec<TokenInfo>>,
16 pub enabled_transitions: Vec<String>,
17 pub in_flight_transitions: Vec<String>,
18}
19
20pub struct MarkingCache {
22 snapshots: Vec<ComputedState>,
23}
24
25impl MarkingCache {
26 pub fn new() -> Self {
28 Self {
29 snapshots: Vec::new(),
30 }
31 }
32
33 pub fn compute_at(&mut self, events: &[NetEvent], target_index: usize) -> ComputedState {
36 if target_index == 0 {
37 return compute_state(&[]);
38 }
39
40 self.ensure_cached_up_to(events, target_index);
41
42 if self.snapshots.is_empty() {
43 return compute_state(&events[..target_index.min(events.len())]);
44 }
45
46 let snapshot_slot = (target_index / SNAPSHOT_INTERVAL)
48 .min(self.snapshots.len())
49 .saturating_sub(1);
50 let snapshot_event_index = (snapshot_slot + 1) * SNAPSHOT_INTERVAL;
51
52 if snapshot_event_index == target_index {
53 return self.snapshots[snapshot_slot].clone();
54 }
55
56 let end = target_index.min(events.len());
57 if snapshot_event_index >= end {
58 return compute_state(&events[..end]);
59 }
60
61 replay_delta(
62 &self.snapshots[snapshot_slot],
63 &events[snapshot_event_index..end],
64 )
65 }
66
67 pub fn invalidate(&mut self) {
69 self.snapshots.clear();
70 }
71
72 fn ensure_cached_up_to(&mut self, events: &[NetEvent], target_index: usize) {
73 let needed_snapshots = target_index / SNAPSHOT_INTERVAL;
74
75 while self.snapshots.len() < needed_snapshots {
76 let next_snapshot_index = (self.snapshots.len() + 1) * SNAPSHOT_INTERVAL;
77 if next_snapshot_index > events.len() {
78 break;
79 }
80
81 if self.snapshots.is_empty() {
82 self.snapshots
83 .push(compute_state(&events[..next_snapshot_index]));
84 } else {
85 let prev_snapshot_index = self.snapshots.len() * SNAPSHOT_INTERVAL;
86 let delta = &events[prev_snapshot_index..next_snapshot_index];
87 let state = replay_delta(self.snapshots.last().unwrap(), delta);
88 self.snapshots.push(state);
89 }
90 }
91 }
92}
93
94impl Default for MarkingCache {
95 fn default() -> Self {
96 Self::new()
97 }
98}
99
100pub fn compute_state(events: &[NetEvent]) -> ComputedState {
102 let mut marking = HashMap::new();
103 let mut enabled = HashSet::new();
104 let mut in_flight = HashSet::new();
105 apply_events(&mut marking, &mut enabled, &mut in_flight, events);
106 to_computed_state(marking, enabled, in_flight)
107}
108
109pub fn apply_events(
111 marking: &mut HashMap<String, Vec<TokenInfo>>,
112 enabled: &mut HashSet<String>,
113 in_flight: &mut HashSet<String>,
114 events: &[NetEvent],
115) {
116 for event in events {
117 match event {
118 NetEvent::TokenAdded {
119 place_name,
120 timestamp,
121 } => {
122 marking
123 .entry(place_name.to_string())
124 .or_default()
125 .push(TokenInfo {
126 id: None,
127 token_type: "unknown".into(),
128 value: None,
129 timestamp: Some(timestamp.to_string()),
130 });
131 }
132 NetEvent::TokenRemoved { place_name, .. } => {
133 if let Some(tokens) = marking.get_mut(place_name.as_ref()) {
134 if !tokens.is_empty() {
135 tokens.remove(0);
136 }
137 }
138 }
139 NetEvent::MarkingSnapshot { marking: m, .. } => {
140 marking.clear();
141 for (name, count) in m {
142 let tokens = (0..*count)
143 .map(|_| TokenInfo {
144 id: None,
145 token_type: "unknown".into(),
146 value: None,
147 timestamp: None,
148 })
149 .collect();
150 marking.insert(name.to_string(), tokens);
151 }
152 }
153 NetEvent::TransitionEnabled {
154 transition_name, ..
155 } => {
156 enabled.insert(transition_name.to_string());
157 }
158 NetEvent::TransitionStarted {
159 transition_name, ..
160 } => {
161 enabled.remove(transition_name.as_ref());
162 in_flight.insert(transition_name.to_string());
163 }
164 NetEvent::TransitionCompleted {
165 transition_name, ..
166 }
167 | NetEvent::TransitionFailed {
168 transition_name, ..
169 }
170 | NetEvent::TransitionTimedOut {
171 transition_name, ..
172 }
173 | NetEvent::ActionTimedOut {
174 transition_name, ..
175 } => {
176 in_flight.remove(transition_name.as_ref());
177 }
178 _ => {}
179 }
180 }
181}
182
183fn to_computed_state(
184 marking: HashMap<String, Vec<TokenInfo>>,
185 enabled: HashSet<String>,
186 in_flight: HashSet<String>,
187) -> ComputedState {
188 ComputedState {
189 marking,
190 enabled_transitions: enabled.into_iter().collect(),
191 in_flight_transitions: in_flight.into_iter().collect(),
192 }
193}
194
195fn replay_delta(base: &ComputedState, delta: &[NetEvent]) -> ComputedState {
196 let mut marking: HashMap<String, Vec<TokenInfo>> = base
197 .marking
198 .iter()
199 .map(|(k, v)| (k.clone(), v.clone()))
200 .collect();
201 let mut enabled: HashSet<String> = base.enabled_transitions.iter().cloned().collect();
202 let mut in_flight: HashSet<String> = base.in_flight_transitions.iter().cloned().collect();
203 apply_events(&mut marking, &mut enabled, &mut in_flight, delta);
204 to_computed_state(marking, enabled, in_flight)
205}
206
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    // ---- event builders ----

    fn token_added(place: &str, ts: u64) -> NetEvent {
        NetEvent::TokenAdded {
            timestamp: ts,
            place_name: Arc::from(place),
        }
    }

    fn token_removed(place: &str, ts: u64) -> NetEvent {
        NetEvent::TokenRemoved {
            timestamp: ts,
            place_name: Arc::from(place),
        }
    }

    fn transition_enabled(name: &str, ts: u64) -> NetEvent {
        NetEvent::TransitionEnabled {
            timestamp: ts,
            transition_name: Arc::from(name),
        }
    }

    fn transition_started(name: &str, ts: u64) -> NetEvent {
        NetEvent::TransitionStarted {
            timestamp: ts,
            transition_name: Arc::from(name),
        }
    }

    fn transition_completed(name: &str, ts: u64) -> NetEvent {
        NetEvent::TransitionCompleted {
            timestamp: ts,
            transition_name: Arc::from(name),
        }
    }

    // ---- shared helpers ----

    /// Number of tokens recorded for `place`, or `None` if the place is absent.
    fn token_count(state: &ComputedState, place: &str) -> Option<usize> {
        state.marking.get(place).map(Vec::len)
    }

    /// A log of `n` `TokenAdded` events on place `p1`, timestamped `0..n`.
    fn p1_log(n: u64) -> Vec<NetEvent> {
        (0..n).map(|ts| token_added("p1", ts)).collect()
    }

    // ---- compute_state ----

    #[test]
    fn compute_empty_state() {
        let ComputedState {
            marking,
            enabled_transitions,
            in_flight_transitions,
        } = compute_state(&[]);
        assert!(marking.is_empty());
        assert!(enabled_transitions.is_empty());
        assert!(in_flight_transitions.is_empty());
    }

    #[test]
    fn compute_state_with_tokens() {
        let log = [
            token_added("p1", 0),
            token_added("p1", 1),
            token_added("p2", 2),
            token_removed("p1", 3),
        ];
        let state = compute_state(&log);
        assert_eq!(token_count(&state, "p1"), Some(1));
        assert_eq!(token_count(&state, "p2"), Some(1));
    }

    #[test]
    fn compute_state_with_transitions() {
        let log = [
            transition_enabled("t1", 0),
            transition_started("t1", 1),
            transition_completed("t1", 2),
        ];
        let state = compute_state(&log);
        assert!(state.enabled_transitions.is_empty());
        assert!(state.in_flight_transitions.is_empty());
    }

    #[test]
    fn compute_state_in_flight() {
        let log = [transition_enabled("t1", 0), transition_started("t1", 1)];
        let state = compute_state(&log);
        assert!(state.enabled_transitions.is_empty());
        assert!(state.in_flight_transitions.iter().any(|name| name == "t1"));
    }

    // ---- MarkingCache ----

    #[test]
    fn marking_cache_basic() {
        let log = p1_log(10);
        let mut cache = MarkingCache::new();

        for target in [5usize, 10] {
            let state = cache.compute_at(&log, target);
            assert_eq!(token_count(&state, "p1"), Some(target));
        }
    }

    #[test]
    fn marking_cache_with_snapshots() {
        let log = p1_log(512);
        let mut cache = MarkingCache::new();

        // The first lookup crosses the 256-event boundary and so builds a
        // snapshot; the second lookup finds it already cached.
        for target in [300usize, 260] {
            let state = cache.compute_at(&log, target);
            assert_eq!(token_count(&state, "p1"), Some(target));
        }
    }

    #[test]
    fn marking_cache_invalidate() {
        let log = p1_log(512);
        let mut cache = MarkingCache::new();

        let _ = cache.compute_at(&log, 300);
        cache.invalidate();

        // After invalidation the same lookup must still give the same answer.
        let state = cache.compute_at(&log, 300);
        assert_eq!(token_count(&state, "p1"), Some(300));
    }
}