Skip to main content

libpetri_debug/
debug_protocol_handler.rs

1//! Framework-agnostic handler for the Petri net debug protocol.
2//!
3//! Manages debug subscriptions, event filtering, breakpoints, and replay
4//! for connected clients. Decoupled from any specific WebSocket framework
5//! via the [`ResponseSink`] trait.
6
7use std::collections::HashMap;
8
9use libpetri_event::net_event::NetEvent;
10
11use crate::debug_command::{BreakpointConfig, BreakpointType, DebugCommand, EventFilter};
12use crate::debug_response::{DebugResponse, NetEventInfo, SessionSummary};
13use crate::debug_session_registry::{DebugSession, DebugSessionRegistry, build_net_structure};
14use crate::marking_cache::{MarkingCache, compute_state};
15use crate::net_event_converter::{
16    event_type_name, extract_place_name, extract_transition_name, to_event_info,
17};
18
19/// Callback for sending responses to a connected client.
20pub trait ResponseSink: Send + Sync {
21    fn send(&self, response: DebugResponse);
22}
23
24/// Blanket impl for closures.
25impl<F: Fn(DebugResponse) + Send + Sync> ResponseSink for F {
26    fn send(&self, response: DebugResponse) {
27        self(response);
28    }
29}
30
/// Upper bound on the number of historical events packed into a single
/// `EventBatch` response.
const BATCH_SIZE: usize = 500;
33
34/// Debug protocol handler managing client connections and command dispatch.
35pub struct DebugProtocolHandler {
36    session_registry: DebugSessionRegistry,
37    clients: HashMap<String, ClientState>,
38}
39
40struct ClientState {
41    sink: Box<dyn ResponseSink>,
42    subscriptions: SubscriptionState,
43}
44
45impl DebugProtocolHandler {
46    /// Creates a new protocol handler.
47    pub fn new(session_registry: DebugSessionRegistry) -> Self {
48        Self {
49            session_registry,
50            clients: HashMap::new(),
51        }
52    }
53
54    /// Returns a reference to the session registry.
55    pub fn session_registry(&self) -> &DebugSessionRegistry {
56        &self.session_registry
57    }
58
59    /// Returns a mutable reference to the session registry.
60    pub fn session_registry_mut(&mut self) -> &mut DebugSessionRegistry {
61        &mut self.session_registry
62    }
63
64    /// Registers a new client connection.
65    pub fn client_connected(&mut self, client_id: String, sink: Box<dyn ResponseSink>) {
66        self.clients.insert(
67            client_id,
68            ClientState {
69                sink,
70                subscriptions: SubscriptionState::new(),
71            },
72        );
73    }
74
75    /// Cleans up when a client disconnects.
76    pub fn client_disconnected(&mut self, client_id: &str) {
77        self.clients.remove(client_id);
78    }
79
80    /// Handles a command from a connected client.
81    pub fn handle_command(&mut self, client_id: &str, command: DebugCommand) {
82        if !self.clients.contains_key(client_id) {
83            return;
84        }
85
86        let result = match command {
87            DebugCommand::ListSessions { limit, active_only } => {
88                self.handle_list_sessions(client_id, limit, active_only)
89            }
90            DebugCommand::Subscribe {
91                session_id,
92                mode,
93                from_index,
94            } => self.handle_subscribe(client_id, session_id, mode, from_index),
95            DebugCommand::Unsubscribe { session_id } => {
96                self.handle_unsubscribe(client_id, session_id)
97            }
98            DebugCommand::Seek {
99                session_id,
100                timestamp,
101            } => self.handle_seek(client_id, session_id, timestamp),
102            DebugCommand::PlaybackSpeed { session_id, speed } => {
103                self.handle_playback_speed(client_id, session_id, speed)
104            }
105            DebugCommand::Filter { session_id, filter } => {
106                self.handle_set_filter(client_id, session_id, filter)
107            }
108            DebugCommand::Pause { session_id } => self.handle_pause(client_id, session_id),
109            DebugCommand::Resume { session_id } => self.handle_resume(client_id, session_id),
110            DebugCommand::StepForward { session_id } => {
111                self.handle_step_forward(client_id, session_id)
112            }
113            DebugCommand::StepBackward { session_id } => {
114                self.handle_step_backward(client_id, session_id)
115            }
116            DebugCommand::SetBreakpoint {
117                session_id,
118                breakpoint,
119            } => self.handle_set_breakpoint(client_id, session_id, breakpoint),
120            DebugCommand::ClearBreakpoint {
121                session_id,
122                breakpoint_id,
123            } => self.handle_clear_breakpoint(client_id, session_id, breakpoint_id),
124            DebugCommand::ListBreakpoints { session_id } => {
125                self.handle_list_breakpoints(client_id, session_id)
126            }
127            DebugCommand::ListArchives { .. }
128            | DebugCommand::ImportArchive { .. }
129            | DebugCommand::UploadArchive { .. } => {
130                // Archive commands not yet implemented
131                Ok(())
132            }
133        };
134
135        if let Err(e) = result {
136            if let Some(client) = self.clients.get(client_id) {
137                send(
138                    &*client.sink,
139                    DebugResponse::Error {
140                        code: "COMMAND_ERROR".into(),
141                        message: e,
142                        session_id: None,
143                    },
144                );
145            }
146        }
147    }
148
149    /// Delivers a live event to all subscribed clients for the given session.
150    pub fn broadcast_event(&mut self, session_id: &str, event: &NetEvent) {
151        let event_info = to_event_info(event);
152
153        // Collect client IDs to avoid borrow issues
154        let client_ids: Vec<String> = self.clients.keys().cloned().collect();
155
156        for client_id in client_ids {
157            let client = self.clients.get_mut(&client_id).unwrap();
158            let Some(sub) = client.subscriptions.sessions.get_mut(session_id) else {
159                continue;
160            };
161
162            if sub.paused {
163                continue;
164            }
165
166            if !matches_filter(&sub.filter, event) {
167                sub.event_index += 1;
168                continue;
169            }
170
171            // Check breakpoints
172            let hit_bp = check_breakpoints(&sub.breakpoints, event);
173            let idx = sub.event_index;
174            sub.event_index += 1;
175
176            if let Some(bp) = hit_bp {
177                sub.paused = true;
178                send(
179                    &*client.sink,
180                    DebugResponse::BreakpointHit {
181                        session_id: session_id.to_string(),
182                        breakpoint_id: bp.id.clone(),
183                        event: event_info.clone(),
184                        event_index: idx,
185                    },
186                );
187            }
188
189            send(
190                &*client.sink,
191                DebugResponse::Event {
192                    session_id: session_id.to_string(),
193                    index: idx,
194                    event: event_info.clone(),
195                },
196            );
197        }
198    }
199
200    // ======================== Command Handlers ========================
201
202    fn handle_list_sessions(
203        &self,
204        client_id: &str,
205        limit: Option<usize>,
206        active_only: Option<bool>,
207    ) -> Result<(), String> {
208        let limit = limit.unwrap_or(50);
209        let sessions = if active_only.unwrap_or(false) {
210            self.session_registry.list_active_sessions(limit)
211        } else {
212            self.session_registry.list_sessions(limit)
213        };
214
215        let summaries: Vec<SessionSummary> = sessions.iter().map(|s| session_summary(s)).collect();
216
217        send_to(
218            &self.clients,
219            client_id,
220            DebugResponse::SessionList {
221                sessions: summaries,
222            },
223        );
224        Ok(())
225    }
226
227    fn handle_subscribe(
228        &mut self,
229        client_id: &str,
230        session_id: String,
231        mode: crate::debug_command::SubscriptionMode,
232        from_index: Option<usize>,
233    ) -> Result<(), String> {
234        let session = self
235            .session_registry
236            .get_session(&session_id)
237            .ok_or_else(|| format!("Session not found: {session_id}"))?;
238
239        let events = session.event_store.events();
240        let computed = compute_state(&events);
241        let structure = build_net_structure(session);
242        let from_index = from_index.unwrap_or(0);
243
244        let mode_str = match mode {
245            crate::debug_command::SubscriptionMode::Live => "live",
246            crate::debug_command::SubscriptionMode::Replay => "replay",
247        };
248
249        let current_marking = computed
250            .marking
251            .iter()
252            .map(|(k, v)| (k.clone(), v.clone()))
253            .collect();
254
255        let client = self.clients.get(client_id).unwrap();
256        send(
257            &*client.sink,
258            DebugResponse::Subscribed {
259                session_id: session_id.clone(),
260                net_name: session.net_name.clone(),
261                dot_diagram: session.dot_diagram.clone(),
262                structure,
263                current_marking,
264                enabled_transitions: computed.enabled_transitions.clone(),
265                in_flight_transitions: computed.in_flight_transitions.clone(),
266                event_count: session.event_store.event_count(),
267                mode: mode_str.into(),
268            },
269        );
270
271        // Send historical events
272        let historical = session.event_store.events_from(from_index);
273        let converted: Vec<NetEventInfo> = historical.iter().map(|e| to_event_info(e)).collect();
274        send_in_batches(
275            &self.clients,
276            client_id,
277            &session_id,
278            from_index,
279            &converted,
280        );
281
282        let event_index = from_index + historical.len();
283        let paused = matches!(mode, crate::debug_command::SubscriptionMode::Replay);
284
285        let client = self.clients.get_mut(client_id).unwrap();
286        client
287            .subscriptions
288            .add_subscription(session_id, event_index, paused);
289
290        Ok(())
291    }
292
293    fn handle_unsubscribe(&mut self, client_id: &str, session_id: String) -> Result<(), String> {
294        if let Some(client) = self.clients.get_mut(client_id) {
295            client.subscriptions.cancel(&session_id);
296        }
297        send_to(
298            &self.clients,
299            client_id,
300            DebugResponse::Unsubscribed { session_id },
301        );
302        Ok(())
303    }
304
305    fn handle_seek(
306        &mut self,
307        client_id: &str,
308        session_id: String,
309        timestamp: String,
310    ) -> Result<(), String> {
311        let session = self
312            .session_registry
313            .get_session(&session_id)
314            .ok_or("Session not found")?;
315
316        let events = session.event_store.events();
317        let target_ts: u64 = timestamp.parse().unwrap_or(0);
318
319        let mut target_index = events.len();
320        for (i, e) in events.iter().enumerate() {
321            if e.timestamp() >= target_ts {
322                target_index = i;
323                break;
324            }
325        }
326
327        let client = self.clients.get_mut(client_id).unwrap();
328        client
329            .subscriptions
330            .set_event_index(&session_id, target_index);
331        let computed = client
332            .subscriptions
333            .compute_state_at(&events, &session_id, target_index);
334
335        send(
336            &*client.sink,
337            DebugResponse::MarkingSnapshot {
338                session_id,
339                marking: computed.marking,
340                enabled_transitions: computed.enabled_transitions,
341                in_flight_transitions: computed.in_flight_transitions,
342            },
343        );
344        Ok(())
345    }
346
347    fn handle_playback_speed(
348        &mut self,
349        client_id: &str,
350        session_id: String,
351        speed: f64,
352    ) -> Result<(), String> {
353        let client = self.clients.get_mut(client_id).unwrap();
354        client.subscriptions.set_speed(&session_id, speed);
355        let paused = client.subscriptions.is_paused(&session_id);
356        let current_index = client.subscriptions.get_event_index(&session_id);
357        send(
358            &*client.sink,
359            DebugResponse::PlaybackStateChanged {
360                session_id,
361                paused,
362                speed,
363                current_index,
364            },
365        );
366        Ok(())
367    }
368
369    fn handle_set_filter(
370        &mut self,
371        client_id: &str,
372        session_id: String,
373        filter: EventFilter,
374    ) -> Result<(), String> {
375        let client = self.clients.get_mut(client_id).unwrap();
376        client.subscriptions.set_filter(&session_id, filter.clone());
377        send(
378            &*client.sink,
379            DebugResponse::FilterApplied { session_id, filter },
380        );
381        Ok(())
382    }
383
384    fn handle_pause(&mut self, client_id: &str, session_id: String) -> Result<(), String> {
385        let client = self.clients.get_mut(client_id).unwrap();
386        client.subscriptions.set_paused(&session_id, true);
387        let speed = client.subscriptions.get_speed(&session_id);
388        let current_index = client.subscriptions.get_event_index(&session_id);
389        send(
390            &*client.sink,
391            DebugResponse::PlaybackStateChanged {
392                session_id,
393                paused: true,
394                speed,
395                current_index,
396            },
397        );
398        Ok(())
399    }
400
401    fn handle_resume(&mut self, client_id: &str, session_id: String) -> Result<(), String> {
402        let client = self.clients.get_mut(client_id).unwrap();
403        client.subscriptions.set_paused(&session_id, false);
404        let speed = client.subscriptions.get_speed(&session_id);
405        let current_index = client.subscriptions.get_event_index(&session_id);
406        send(
407            &*client.sink,
408            DebugResponse::PlaybackStateChanged {
409                session_id,
410                paused: false,
411                speed,
412                current_index,
413            },
414        );
415        Ok(())
416    }
417
418    fn handle_step_forward(&mut self, client_id: &str, session_id: String) -> Result<(), String> {
419        let session = self
420            .session_registry
421            .get_session(&session_id)
422            .ok_or("Session not found")?;
423
424        let events = session.event_store.events();
425        let client = self.clients.get_mut(client_id).unwrap();
426        let current_index = client.subscriptions.get_event_index(&session_id);
427
428        if current_index < events.len() {
429            let event_info = to_event_info(&events[current_index]);
430            send(
431                &*client.sink,
432                DebugResponse::Event {
433                    session_id: session_id.clone(),
434                    index: current_index,
435                    event: event_info,
436                },
437            );
438            client
439                .subscriptions
440                .set_event_index(&session_id, current_index + 1);
441        }
442        Ok(())
443    }
444
445    fn handle_step_backward(&mut self, client_id: &str, session_id: String) -> Result<(), String> {
446        let session = self
447            .session_registry
448            .get_session(&session_id)
449            .ok_or("Session not found")?;
450
451        let events = session.event_store.events();
452        let client = self.clients.get_mut(client_id).unwrap();
453        let current_index = client.subscriptions.get_event_index(&session_id);
454
455        if current_index > 0 {
456            let new_index = current_index - 1;
457            client.subscriptions.set_event_index(&session_id, new_index);
458            let computed = client
459                .subscriptions
460                .compute_state_at(&events, &session_id, new_index);
461
462            send(
463                &*client.sink,
464                DebugResponse::MarkingSnapshot {
465                    session_id,
466                    marking: computed.marking,
467                    enabled_transitions: computed.enabled_transitions,
468                    in_flight_transitions: computed.in_flight_transitions,
469                },
470            );
471        }
472        Ok(())
473    }
474
475    fn handle_set_breakpoint(
476        &mut self,
477        client_id: &str,
478        session_id: String,
479        breakpoint: BreakpointConfig,
480    ) -> Result<(), String> {
481        let client = self.clients.get_mut(client_id).unwrap();
482        client
483            .subscriptions
484            .add_breakpoint(&session_id, breakpoint.clone());
485        send(
486            &*client.sink,
487            DebugResponse::BreakpointSet {
488                session_id,
489                breakpoint,
490            },
491        );
492        Ok(())
493    }
494
495    fn handle_clear_breakpoint(
496        &mut self,
497        client_id: &str,
498        session_id: String,
499        breakpoint_id: String,
500    ) -> Result<(), String> {
501        let client = self.clients.get_mut(client_id).unwrap();
502        client
503            .subscriptions
504            .remove_breakpoint(&session_id, &breakpoint_id);
505        send(
506            &*client.sink,
507            DebugResponse::BreakpointCleared {
508                session_id,
509                breakpoint_id,
510            },
511        );
512        Ok(())
513    }
514
515    fn handle_list_breakpoints(&self, client_id: &str, session_id: String) -> Result<(), String> {
516        let client = self.clients.get(client_id).unwrap();
517        let breakpoints = client.subscriptions.get_breakpoints(&session_id);
518        send(
519            &*client.sink,
520            DebugResponse::BreakpointList {
521                session_id,
522                breakpoints,
523            },
524        );
525        Ok(())
526    }
527}
528
529// ======================== Helper Functions ========================
530
531fn send(sink: &dyn ResponseSink, response: DebugResponse) {
532    sink.send(response);
533}
534
535fn send_to(clients: &HashMap<String, ClientState>, client_id: &str, response: DebugResponse) {
536    if let Some(client) = clients.get(client_id) {
537        send(&*client.sink, response);
538    }
539}
540
541fn send_in_batches(
542    clients: &HashMap<String, ClientState>,
543    client_id: &str,
544    session_id: &str,
545    start_index: usize,
546    events: &[NetEventInfo],
547) {
548    let Some(client) = clients.get(client_id) else {
549        return;
550    };
551
552    if events.is_empty() {
553        send(
554            &*client.sink,
555            DebugResponse::EventBatch {
556                session_id: session_id.to_string(),
557                start_index,
558                events: vec![],
559                has_more: false,
560            },
561        );
562        return;
563    }
564
565    for (i, chunk) in events.chunks(BATCH_SIZE).enumerate() {
566        let chunk_start = start_index + i * BATCH_SIZE;
567        let has_more = chunk_start + chunk.len() < start_index + events.len();
568        send(
569            &*client.sink,
570            DebugResponse::EventBatch {
571                session_id: session_id.to_string(),
572                start_index: chunk_start,
573                events: chunk.to_vec(),
574                has_more,
575            },
576        );
577    }
578}
579
580fn session_summary(session: &DebugSession) -> SessionSummary {
581    SessionSummary {
582        session_id: session.session_id.clone(),
583        net_name: session.net_name.clone(),
584        start_time: session.start_time.to_string(),
585        active: session.active,
586        event_count: session.event_store.event_count(),
587    }
588}
589
590fn matches_filter(filter: &Option<EventFilter>, event: &NetEvent) -> bool {
591    let Some(filter) = filter else { return true };
592
593    // Event type: include then exclude
594    let event_type = event_type_name(event);
595    if let Some(ref types) = filter.event_types {
596        if !types.is_empty() && !types.iter().any(|t| t == event_type) {
597            return false;
598        }
599    }
600    if let Some(ref types) = filter.exclude_event_types {
601        if !types.is_empty() && types.iter().any(|t| t == event_type) {
602            return false;
603        }
604    }
605
606    // Transition name: extract once, include then exclude
607    let need_transition = filter.transition_names.as_ref().is_some_and(|n| !n.is_empty())
608        || filter.exclude_transition_names.as_ref().is_some_and(|n| !n.is_empty());
609    if need_transition {
610        let t_name = extract_transition_name(event);
611        if let Some(ref names) = filter.transition_names {
612            if !names.is_empty() {
613                match t_name {
614                    Some(n) => {
615                        if !names.iter().any(|t| t == n) {
616                            return false;
617                        }
618                    }
619                    None => return false,
620                }
621            }
622        }
623        if let Some(ref names) = filter.exclude_transition_names {
624            if !names.is_empty() {
625                if let Some(n) = t_name {
626                    if names.iter().any(|t| t == n) {
627                        return false;
628                    }
629                }
630            }
631        }
632    }
633
634    // Place name: extract once, include then exclude
635    let need_place = filter.place_names.as_ref().is_some_and(|n| !n.is_empty())
636        || filter.exclude_place_names.as_ref().is_some_and(|n| !n.is_empty());
637    if need_place {
638        let p_name = extract_place_name(event);
639        if let Some(ref names) = filter.place_names {
640            if !names.is_empty() {
641                match p_name {
642                    Some(n) => {
643                        if !names.iter().any(|t| t == n) {
644                            return false;
645                        }
646                    }
647                    None => return false,
648                }
649            }
650        }
651        if let Some(ref names) = filter.exclude_place_names {
652            if !names.is_empty() {
653                if let Some(n) = p_name {
654                    if names.iter().any(|t| t == n) {
655                        return false;
656                    }
657                }
658            }
659        }
660    }
661
662    true
663}
664
665fn matches_breakpoint(bp: &BreakpointConfig, event: &NetEvent) -> bool {
666    if !bp.enabled {
667        return false;
668    }
669    match bp.bp_type {
670        BreakpointType::TransitionEnabled => {
671            matches!(event, NetEvent::TransitionEnabled { transition_name, .. }
672                if bp.target.as_ref().is_none_or(|t| t == transition_name.as_ref()))
673        }
674        BreakpointType::TransitionStart => {
675            matches!(event, NetEvent::TransitionStarted { transition_name, .. }
676                if bp.target.as_ref().is_none_or(|t| t == transition_name.as_ref()))
677        }
678        BreakpointType::TransitionComplete => {
679            matches!(event, NetEvent::TransitionCompleted { transition_name, .. }
680                if bp.target.as_ref().is_none_or(|t| t == transition_name.as_ref()))
681        }
682        BreakpointType::TransitionFail => {
683            matches!(event, NetEvent::TransitionFailed { transition_name, .. }
684                if bp.target.as_ref().is_none_or(|t| t == transition_name.as_ref()))
685        }
686        BreakpointType::TokenAdded => {
687            matches!(event, NetEvent::TokenAdded { place_name, .. }
688                if bp.target.as_ref().is_none_or(|t| t == place_name.as_ref()))
689        }
690        BreakpointType::TokenRemoved => {
691            matches!(event, NetEvent::TokenRemoved { place_name, .. }
692                if bp.target.as_ref().is_none_or(|t| t == place_name.as_ref()))
693        }
694    }
695}
696
697fn check_breakpoints(
698    breakpoints: &HashMap<String, BreakpointConfig>,
699    event: &NetEvent,
700) -> Option<BreakpointConfig> {
701    for bp in breakpoints.values() {
702        if matches_breakpoint(bp, event) {
703            return Some(bp.clone());
704        }
705    }
706    None
707}
708
709// ======================== Subscription State ========================
710
711struct SessionSubscription {
712    event_index: usize,
713    marking_cache: MarkingCache,
714    breakpoints: HashMap<String, BreakpointConfig>,
715    paused: bool,
716    speed: f64,
717    filter: Option<EventFilter>,
718}
719
720struct SubscriptionState {
721    sessions: HashMap<String, SessionSubscription>,
722}
723
724impl SubscriptionState {
725    fn new() -> Self {
726        Self {
727            sessions: HashMap::new(),
728        }
729    }
730
731    fn add_subscription(&mut self, session_id: String, event_index: usize, paused: bool) {
732        self.sessions.insert(
733            session_id,
734            SessionSubscription {
735                event_index,
736                marking_cache: MarkingCache::new(),
737                breakpoints: HashMap::new(),
738                paused,
739                speed: 1.0,
740                filter: None,
741            },
742        );
743    }
744
745    fn cancel(&mut self, session_id: &str) {
746        self.sessions.remove(session_id);
747    }
748
749    fn is_paused(&self, session_id: &str) -> bool {
750        self.sessions.get(session_id).is_some_and(|s| s.paused)
751    }
752
753    fn set_paused(&mut self, session_id: &str, paused: bool) {
754        if let Some(sub) = self.sessions.get_mut(session_id) {
755            sub.paused = paused;
756        }
757    }
758
759    fn get_speed(&self, session_id: &str) -> f64 {
760        self.sessions.get(session_id).map_or(1.0, |s| s.speed)
761    }
762
763    fn set_speed(&mut self, session_id: &str, speed: f64) {
764        if let Some(sub) = self.sessions.get_mut(session_id) {
765            sub.speed = speed;
766        }
767    }
768
769    fn get_event_index(&self, session_id: &str) -> usize {
770        self.sessions.get(session_id).map_or(0, |s| s.event_index)
771    }
772
773    fn set_event_index(&mut self, session_id: &str, index: usize) {
774        if let Some(sub) = self.sessions.get_mut(session_id) {
775            sub.event_index = index;
776        }
777    }
778
779    fn compute_state_at(
780        &mut self,
781        events: &[NetEvent],
782        session_id: &str,
783        target_index: usize,
784    ) -> crate::marking_cache::ComputedState {
785        if let Some(sub) = self.sessions.get_mut(session_id) {
786            sub.marking_cache.compute_at(events, target_index)
787        } else {
788            compute_state(&events[..target_index.min(events.len())])
789        }
790    }
791
792    fn set_filter(&mut self, session_id: &str, filter: EventFilter) {
793        if let Some(sub) = self.sessions.get_mut(session_id) {
794            sub.filter = Some(filter);
795        }
796    }
797
798    fn add_breakpoint(&mut self, session_id: &str, breakpoint: BreakpointConfig) {
799        if let Some(sub) = self.sessions.get_mut(session_id) {
800            sub.breakpoints.insert(breakpoint.id.clone(), breakpoint);
801        }
802    }
803
804    fn remove_breakpoint(&mut self, session_id: &str, breakpoint_id: &str) {
805        if let Some(sub) = self.sessions.get_mut(session_id) {
806            sub.breakpoints.remove(breakpoint_id);
807        }
808    }
809
810    fn get_breakpoints(&self, session_id: &str) -> Vec<BreakpointConfig> {
811        self.sessions
812            .get(session_id)
813            .map_or_else(Vec::new, |s| s.breakpoints.values().cloned().collect())
814    }
815}
816
817#[cfg(test)]
818mod tests {
819    use super::*;
820    use crate::debug_event_store::DebugEventStore;
821    use std::sync::{Arc, Mutex};
822
823    fn make_handler_with_net() -> (DebugProtocolHandler, Arc<DebugEventStore>) {
824        use libpetri_core::input::one;
825        use libpetri_core::output::out_place;
826        use libpetri_core::place::Place;
827        use libpetri_core::transition::Transition;
828
829        let p1 = Place::<i32>::new("p1");
830        let p2 = Place::<i32>::new("p2");
831        let t = Transition::builder("t1")
832            .input(one(&p1))
833            .output(out_place(&p2))
834            .build();
835        let net = libpetri_core::petri_net::PetriNet::builder("test")
836            .transition(t)
837            .build();
838
839        let mut registry = DebugSessionRegistry::new();
840        let store = registry.register("s1".into(), &net);
841        let handler = DebugProtocolHandler::new(registry);
842        (handler, store)
843    }
844
845    fn collector_sink() -> (Box<dyn ResponseSink>, Arc<Mutex<Vec<DebugResponse>>>) {
846        let collected = Arc::new(Mutex::new(Vec::new()));
847        let collected_clone = Arc::clone(&collected);
848        let sink: Box<dyn ResponseSink> = Box::new(move |resp: DebugResponse| {
849            collected_clone.lock().unwrap().push(resp);
850        });
851        (sink, collected)
852    }
853
854    #[test]
855    fn list_sessions() {
856        let (mut handler, _store) = make_handler_with_net();
857        let (sink, collected) = collector_sink();
858        handler.client_connected("c1".into(), sink);
859
860        handler.handle_command(
861            "c1",
862            DebugCommand::ListSessions {
863                limit: None,
864                active_only: None,
865            },
866        );
867
868        let responses = collected.lock().unwrap();
869        assert_eq!(responses.len(), 1);
870        match &responses[0] {
871            DebugResponse::SessionList { sessions } => {
872                assert_eq!(sessions.len(), 1);
873                assert_eq!(sessions[0].net_name, "test");
874            }
875            _ => panic!("expected SessionList"),
876        }
877    }
878
879    #[test]
880    fn subscribe_and_unsubscribe() {
881        let (mut handler, _store) = make_handler_with_net();
882        let (sink, collected) = collector_sink();
883        handler.client_connected("c1".into(), sink);
884
885        handler.handle_command(
886            "c1",
887            DebugCommand::Subscribe {
888                session_id: "s1".into(),
889                mode: crate::debug_command::SubscriptionMode::Live,
890                from_index: None,
891            },
892        );
893
894        {
895            let responses = collected.lock().unwrap();
896            assert!(responses.len() >= 1);
897            match &responses[0] {
898                DebugResponse::Subscribed {
899                    session_id,
900                    net_name,
901                    ..
902                } => {
903                    assert_eq!(session_id, "s1");
904                    assert_eq!(net_name, "test");
905                }
906                _ => panic!("expected Subscribed"),
907            }
908        }
909
910        handler.handle_command(
911            "c1",
912            DebugCommand::Unsubscribe {
913                session_id: "s1".into(),
914            },
915        );
916
917        let responses = collected.lock().unwrap();
918        let last = responses.last().unwrap();
919        match last {
920            DebugResponse::Unsubscribed { session_id } => {
921                assert_eq!(session_id, "s1");
922            }
923            _ => panic!("expected Unsubscribed"),
924        }
925    }
926
927    #[test]
928    fn subscribe_to_nonexistent_session() {
929        let (mut handler, _store) = make_handler_with_net();
930        let (sink, collected) = collector_sink();
931        handler.client_connected("c1".into(), sink);
932
933        handler.handle_command(
934            "c1",
935            DebugCommand::Subscribe {
936                session_id: "nonexistent".into(),
937                mode: crate::debug_command::SubscriptionMode::Live,
938                from_index: None,
939            },
940        );
941
942        let responses = collected.lock().unwrap();
943        match &responses[0] {
944            DebugResponse::Error { code, .. } => assert_eq!(code, "COMMAND_ERROR"),
945            _ => panic!("expected Error"),
946        }
947    }
948
949    #[test]
950    fn pause_and_resume() {
951        let (mut handler, _store) = make_handler_with_net();
952        let (sink, collected) = collector_sink();
953        handler.client_connected("c1".into(), sink);
954
955        handler.handle_command(
956            "c1",
957            DebugCommand::Subscribe {
958                session_id: "s1".into(),
959                mode: crate::debug_command::SubscriptionMode::Live,
960                from_index: None,
961            },
962        );
963
964        handler.handle_command(
965            "c1",
966            DebugCommand::Pause {
967                session_id: "s1".into(),
968            },
969        );
970
971        let responses = collected.lock().unwrap();
972        let pause_resp = responses
973            .iter()
974            .find(|r| matches!(r, DebugResponse::PlaybackStateChanged { paused: true, .. }));
975        assert!(pause_resp.is_some());
976    }
977
978    #[test]
979    fn set_and_list_breakpoints() {
980        let (mut handler, _store) = make_handler_with_net();
981        let (sink, collected) = collector_sink();
982        handler.client_connected("c1".into(), sink);
983
984        handler.handle_command(
985            "c1",
986            DebugCommand::Subscribe {
987                session_id: "s1".into(),
988                mode: crate::debug_command::SubscriptionMode::Live,
989                from_index: None,
990            },
991        );
992
993        handler.handle_command(
994            "c1",
995            DebugCommand::SetBreakpoint {
996                session_id: "s1".into(),
997                breakpoint: BreakpointConfig {
998                    id: "bp1".into(),
999                    bp_type: BreakpointType::TransitionStart,
1000                    target: Some("t1".into()),
1001                    enabled: true,
1002                },
1003            },
1004        );
1005
1006        handler.handle_command(
1007            "c1",
1008            DebugCommand::ListBreakpoints {
1009                session_id: "s1".into(),
1010            },
1011        );
1012
1013        let responses = collected.lock().unwrap();
1014        let bp_list = responses
1015            .iter()
1016            .find(|r| matches!(r, DebugResponse::BreakpointList { .. }));
1017        match bp_list.unwrap() {
1018            DebugResponse::BreakpointList { breakpoints, .. } => {
1019                assert_eq!(breakpoints.len(), 1);
1020                assert_eq!(breakpoints[0].id, "bp1");
1021            }
1022            _ => unreachable!(),
1023        }
1024    }
1025
1026    #[test]
1027    fn broadcast_event_to_subscribers() {
1028        let (mut handler, store) = make_handler_with_net();
1029        let (sink, collected) = collector_sink();
1030        handler.client_connected("c1".into(), sink);
1031
1032        handler.handle_command(
1033            "c1",
1034            DebugCommand::Subscribe {
1035                session_id: "s1".into(),
1036                mode: crate::debug_command::SubscriptionMode::Live,
1037                from_index: None,
1038            },
1039        );
1040
1041        let event = NetEvent::TransitionStarted {
1042            transition_name: Arc::from("t1"),
1043            timestamp: 1000,
1044        };
1045        store.append(event.clone());
1046        handler.broadcast_event("s1", &event);
1047
1048        let responses = collected.lock().unwrap();
1049        let event_resp = responses
1050            .iter()
1051            .find(|r| matches!(r, DebugResponse::Event { .. }));
1052        assert!(event_resp.is_some());
1053    }
1054
1055    #[test]
1056    fn filter_matching() {
1057        let event = NetEvent::TransitionStarted {
1058            transition_name: Arc::from("t1"),
1059            timestamp: 0,
1060        };
1061
1062        // No filter — matches all
1063        assert!(matches_filter(&None, &event));
1064
1065        // Type filter matching
1066        let filter = EventFilter {
1067            event_types: Some(vec!["TransitionStarted".into()]),
1068            ..Default::default()
1069        };
1070        assert!(matches_filter(&Some(filter), &event));
1071
1072        // Type filter not matching
1073        let filter = EventFilter {
1074            event_types: Some(vec!["TokenAdded".into()]),
1075            ..Default::default()
1076        };
1077        assert!(!matches_filter(&Some(filter), &event));
1078
1079        // Transition name filter
1080        let filter = EventFilter {
1081            transition_names: Some(vec!["t1".into()]),
1082            ..Default::default()
1083        };
1084        assert!(matches_filter(&Some(filter), &event));
1085
1086        let filter = EventFilter {
1087            transition_names: Some(vec!["t2".into()]),
1088            ..Default::default()
1089        };
1090        assert!(!matches_filter(&Some(filter), &event));
1091    }
1092
1093    #[test]
1094    fn filter_exclusion() {
1095        let event = NetEvent::TransitionStarted {
1096            transition_name: Arc::from("t1"),
1097            timestamp: 0,
1098        };
1099
1100        // Exclude by transition name
1101        let filter = EventFilter {
1102            exclude_transition_names: Some(vec!["t1".into()]),
1103            ..Default::default()
1104        };
1105        assert!(!matches_filter(&Some(filter), &event));
1106
1107        // Exclude different transition — should pass
1108        let filter = EventFilter {
1109            exclude_transition_names: Some(vec!["t2".into()]),
1110            ..Default::default()
1111        };
1112        assert!(matches_filter(&Some(filter), &event));
1113
1114        // Exclude by event type
1115        let filter = EventFilter {
1116            exclude_event_types: Some(vec!["TransitionStarted".into()]),
1117            ..Default::default()
1118        };
1119        assert!(!matches_filter(&Some(filter), &event));
1120    }
1121
1122    #[test]
1123    fn filter_combined_include_exclude() {
1124        let event_t1 = NetEvent::TransitionStarted {
1125            transition_name: Arc::from("t1"),
1126            timestamp: 0,
1127        };
1128        let event_t2 = NetEvent::TransitionStarted {
1129            transition_name: Arc::from("t2"),
1130            timestamp: 0,
1131        };
1132
1133        // Include TransitionStarted, exclude t2
1134        let filter = EventFilter {
1135            event_types: Some(vec!["TransitionStarted".into()]),
1136            exclude_transition_names: Some(vec!["t2".into()]),
1137            ..Default::default()
1138        };
1139        assert!(matches_filter(&Some(filter.clone()), &event_t1));
1140        assert!(!matches_filter(&Some(filter), &event_t2));
1141    }
1142
1143    #[test]
1144    fn breakpoint_matching() {
1145        let event = NetEvent::TransitionStarted {
1146            transition_name: Arc::from("t1"),
1147            timestamp: 0,
1148        };
1149
1150        let bp = BreakpointConfig {
1151            id: "bp1".into(),
1152            bp_type: BreakpointType::TransitionStart,
1153            target: Some("t1".into()),
1154            enabled: true,
1155        };
1156        assert!(matches_breakpoint(&bp, &event));
1157
1158        // Disabled breakpoint
1159        let bp_disabled = BreakpointConfig {
1160            id: "bp2".into(),
1161            bp_type: BreakpointType::TransitionStart,
1162            target: Some("t1".into()),
1163            enabled: false,
1164        };
1165        assert!(!matches_breakpoint(&bp_disabled, &event));
1166
1167        // Wrong target
1168        let bp_wrong = BreakpointConfig {
1169            id: "bp3".into(),
1170            bp_type: BreakpointType::TransitionStart,
1171            target: Some("t2".into()),
1172            enabled: true,
1173        };
1174        assert!(!matches_breakpoint(&bp_wrong, &event));
1175
1176        // Wildcard (no target)
1177        let bp_wild = BreakpointConfig {
1178            id: "bp4".into(),
1179            bp_type: BreakpointType::TransitionStart,
1180            target: None,
1181            enabled: true,
1182        };
1183        assert!(matches_breakpoint(&bp_wild, &event));
1184    }
1185
1186    #[test]
1187    fn client_disconnect_cleanup() {
1188        let (mut handler, _store) = make_handler_with_net();
1189        let (sink, _collected) = collector_sink();
1190        handler.client_connected("c1".into(), sink);
1191        handler.client_disconnected("c1");
1192        assert!(handler.clients.is_empty());
1193    }
1194
1195    #[test]
1196    fn step_forward_and_backward() {
1197        let (mut handler, store) = make_handler_with_net();
1198        let (sink, collected) = collector_sink();
1199        handler.client_connected("c1".into(), sink);
1200
1201        // Add some events
1202        for i in 0..5 {
1203            store.append(NetEvent::TokenAdded {
1204                place_name: Arc::from("p1"),
1205                timestamp: i,
1206            });
1207        }
1208
1209        handler.handle_command(
1210            "c1",
1211            DebugCommand::Subscribe {
1212                session_id: "s1".into(),
1213                mode: crate::debug_command::SubscriptionMode::Replay,
1214                from_index: Some(0),
1215            },
1216        );
1217
1218        // Step forward
1219        handler.handle_command(
1220            "c1",
1221            DebugCommand::StepForward {
1222                session_id: "s1".into(),
1223            },
1224        );
1225
1226        // Step backward
1227        handler.handle_command(
1228            "c1",
1229            DebugCommand::StepBackward {
1230                session_id: "s1".into(),
1231            },
1232        );
1233
1234        let responses = collected.lock().unwrap();
1235        assert!(responses.len() >= 3); // subscribed + batch + step responses
1236    }
1237}