Skip to main content

libpetri_debug/
debug_protocol_handler.rs

1//! Framework-agnostic handler for the Petri net debug protocol.
2//!
3//! Manages debug subscriptions, event filtering, breakpoints, and replay
4//! for connected clients. Decoupled from any specific WebSocket framework
5//! via the [`ResponseSink`] trait.
6
7use std::collections::HashMap;
8
9use libpetri_event::net_event::NetEvent;
10
11use crate::debug_command::{BreakpointConfig, BreakpointType, DebugCommand, EventFilter};
12use crate::debug_response::{DebugResponse, NetEventInfo, SessionSummary};
13use crate::debug_session_registry::{DebugSession, DebugSessionRegistry, build_net_structure};
14use crate::marking_cache::{MarkingCache, compute_state};
15use crate::net_event_converter::{
16    event_type_name, extract_place_name, extract_transition_name, to_event_info,
17};
18
/// Callback for sending responses to a connected client.
///
/// The handler stores sinks as `Box<dyn ResponseSink>`, one per client, and
/// calls them through a shared reference. Implementors must therefore be
/// `Send + Sync`.
pub trait ResponseSink: Send + Sync {
    /// Delivers a single `response` to the client.
    ///
    /// The method returns nothing, so delivery failures cannot be reported
    /// back to the handler through this call.
    fn send(&self, response: DebugResponse);
}
23
24/// Blanket impl for closures.
25impl<F: Fn(DebugResponse) + Send + Sync> ResponseSink for F {
26    fn send(&self, response: DebugResponse) {
27        self(response);
28    }
29}
30
/// Maximum events per batch when sending historical events.
///
/// `send_in_batches` uses this as the chunk size when splitting a backlog
/// into `EventBatch` responses.
const BATCH_SIZE: usize = 500;
33
/// Debug protocol handler managing client connections and command dispatch.
pub struct DebugProtocolHandler {
    /// Registry the command handlers use to look up debug sessions by id.
    session_registry: DebugSessionRegistry,
    /// Connected clients keyed by client id.
    clients: HashMap<String, ClientState>,
}
39
/// Per-client state held by the handler while the client is connected.
struct ClientState {
    /// Outbound channel used for every response sent to this client.
    sink: Box<dyn ResponseSink>,
    /// The client's session subscriptions and their playback state.
    subscriptions: SubscriptionState,
}
44
45impl DebugProtocolHandler {
46    /// Creates a new protocol handler.
47    pub fn new(session_registry: DebugSessionRegistry) -> Self {
48        Self {
49            session_registry,
50            clients: HashMap::new(),
51        }
52    }
53
54    /// Returns a reference to the session registry.
55    pub fn session_registry(&self) -> &DebugSessionRegistry {
56        &self.session_registry
57    }
58
59    /// Returns a mutable reference to the session registry.
60    pub fn session_registry_mut(&mut self) -> &mut DebugSessionRegistry {
61        &mut self.session_registry
62    }
63
64    /// Registers a new client connection.
65    pub fn client_connected(&mut self, client_id: String, sink: Box<dyn ResponseSink>) {
66        self.clients.insert(
67            client_id,
68            ClientState {
69                sink,
70                subscriptions: SubscriptionState::new(),
71            },
72        );
73    }
74
75    /// Cleans up when a client disconnects.
76    pub fn client_disconnected(&mut self, client_id: &str) {
77        self.clients.remove(client_id);
78    }
79
80    /// Handles a command from a connected client.
81    pub fn handle_command(&mut self, client_id: &str, command: DebugCommand) {
82        if !self.clients.contains_key(client_id) {
83            return;
84        }
85
86        let result = match command {
87            DebugCommand::ListSessions {
88                limit,
89                active_only,
90                tag_filter,
91            } => self.handle_list_sessions(client_id, limit, active_only, tag_filter),
92            DebugCommand::Subscribe {
93                session_id,
94                mode,
95                from_index,
96            } => self.handle_subscribe(client_id, session_id, mode, from_index),
97            DebugCommand::Unsubscribe { session_id } => {
98                self.handle_unsubscribe(client_id, session_id)
99            }
100            DebugCommand::Seek {
101                session_id,
102                timestamp,
103            } => self.handle_seek(client_id, session_id, timestamp),
104            DebugCommand::PlaybackSpeed { session_id, speed } => {
105                self.handle_playback_speed(client_id, session_id, speed)
106            }
107            DebugCommand::Filter { session_id, filter } => {
108                self.handle_set_filter(client_id, session_id, filter)
109            }
110            DebugCommand::Pause { session_id } => self.handle_pause(client_id, session_id),
111            DebugCommand::Resume { session_id } => self.handle_resume(client_id, session_id),
112            DebugCommand::StepForward { session_id } => {
113                self.handle_step_forward(client_id, session_id)
114            }
115            DebugCommand::StepBackward { session_id } => {
116                self.handle_step_backward(client_id, session_id)
117            }
118            DebugCommand::SetBreakpoint {
119                session_id,
120                breakpoint,
121            } => self.handle_set_breakpoint(client_id, session_id, breakpoint),
122            DebugCommand::ClearBreakpoint {
123                session_id,
124                breakpoint_id,
125            } => self.handle_clear_breakpoint(client_id, session_id, breakpoint_id),
126            DebugCommand::ListBreakpoints { session_id } => {
127                self.handle_list_breakpoints(client_id, session_id)
128            }
129            DebugCommand::ListArchives { .. }
130            | DebugCommand::ImportArchive { .. }
131            | DebugCommand::UploadArchive { .. } => {
132                // Archive commands not yet implemented
133                Ok(())
134            }
135        };
136
137        if let Err(e) = result {
138            if let Some(client) = self.clients.get(client_id) {
139                send(
140                    &*client.sink,
141                    DebugResponse::Error {
142                        code: "COMMAND_ERROR".into(),
143                        message: e,
144                        session_id: None,
145                    },
146                );
147            }
148        }
149    }
150
151    /// Delivers a live event to all subscribed clients for the given session.
152    pub fn broadcast_event(&mut self, session_id: &str, event: &NetEvent) {
153        let event_info = to_event_info(event);
154
155        // Collect client IDs to avoid borrow issues
156        let client_ids: Vec<String> = self.clients.keys().cloned().collect();
157
158        for client_id in client_ids {
159            let client = self.clients.get_mut(&client_id).unwrap();
160            let Some(sub) = client.subscriptions.sessions.get_mut(session_id) else {
161                continue;
162            };
163
164            if sub.paused {
165                continue;
166            }
167
168            if !matches_filter(&sub.filter, event) {
169                sub.event_index += 1;
170                continue;
171            }
172
173            // Check breakpoints
174            let hit_bp = check_breakpoints(&sub.breakpoints, event);
175            let idx = sub.event_index;
176            sub.event_index += 1;
177
178            if let Some(bp) = hit_bp {
179                sub.paused = true;
180                send(
181                    &*client.sink,
182                    DebugResponse::BreakpointHit {
183                        session_id: session_id.to_string(),
184                        breakpoint_id: bp.id.clone(),
185                        event: event_info.clone(),
186                        event_index: idx,
187                    },
188                );
189            }
190
191            send(
192                &*client.sink,
193                DebugResponse::Event {
194                    session_id: session_id.to_string(),
195                    index: idx,
196                    event: event_info.clone(),
197                },
198            );
199        }
200    }
201
202    // ======================== Command Handlers ========================
203
204    fn handle_list_sessions(
205        &self,
206        client_id: &str,
207        limit: Option<usize>,
208        active_only: Option<bool>,
209        tag_filter: Option<HashMap<String, String>>,
210    ) -> Result<(), String> {
211        let limit = limit.unwrap_or(50);
212        let filter = tag_filter.unwrap_or_default();
213        let sessions = if active_only.unwrap_or(false) {
214            self.session_registry.list_active_sessions_tagged(limit, &filter)
215        } else {
216            self.session_registry.list_sessions_tagged(limit, &filter)
217        };
218
219        let summaries: Vec<SessionSummary> = sessions
220            .iter()
221            .map(|s| session_summary(s, &self.session_registry))
222            .collect();
223
224        send_to(
225            &self.clients,
226            client_id,
227            DebugResponse::SessionList {
228                sessions: summaries,
229            },
230        );
231        Ok(())
232    }
233
234    fn handle_subscribe(
235        &mut self,
236        client_id: &str,
237        session_id: String,
238        mode: crate::debug_command::SubscriptionMode,
239        from_index: Option<usize>,
240    ) -> Result<(), String> {
241        let session = self
242            .session_registry
243            .get_session(&session_id)
244            .ok_or_else(|| format!("Session not found: {session_id}"))?;
245
246        let events = session.event_store.events();
247        let computed = compute_state(&events);
248        let structure = build_net_structure(session);
249        let from_index = from_index.unwrap_or(0);
250
251        let mode_str = match mode {
252            crate::debug_command::SubscriptionMode::Live => "live",
253            crate::debug_command::SubscriptionMode::Replay => "replay",
254        };
255
256        let current_marking = computed
257            .marking
258            .iter()
259            .map(|(k, v)| (k.clone(), v.clone()))
260            .collect();
261
262        let client = self.clients.get(client_id).unwrap();
263        send(
264            &*client.sink,
265            DebugResponse::Subscribed {
266                session_id: session_id.clone(),
267                net_name: session.net_name.clone(),
268                dot_diagram: session.dot_diagram.clone(),
269                structure,
270                current_marking,
271                enabled_transitions: computed.enabled_transitions.clone(),
272                in_flight_transitions: computed.in_flight_transitions.clone(),
273                event_count: session.event_store.event_count(),
274                mode: mode_str.into(),
275            },
276        );
277
278        // Send historical events
279        let historical = session.event_store.events_from(from_index);
280        let converted: Vec<NetEventInfo> = historical.iter().map(|e| to_event_info(e)).collect();
281        send_in_batches(
282            &self.clients,
283            client_id,
284            &session_id,
285            from_index,
286            &converted,
287        );
288
289        let event_index = from_index + historical.len();
290        let paused = matches!(mode, crate::debug_command::SubscriptionMode::Replay);
291
292        let client = self.clients.get_mut(client_id).unwrap();
293        client
294            .subscriptions
295            .add_subscription(session_id, event_index, paused);
296
297        Ok(())
298    }
299
300    fn handle_unsubscribe(&mut self, client_id: &str, session_id: String) -> Result<(), String> {
301        if let Some(client) = self.clients.get_mut(client_id) {
302            client.subscriptions.cancel(&session_id);
303        }
304        send_to(
305            &self.clients,
306            client_id,
307            DebugResponse::Unsubscribed { session_id },
308        );
309        Ok(())
310    }
311
312    fn handle_seek(
313        &mut self,
314        client_id: &str,
315        session_id: String,
316        timestamp: String,
317    ) -> Result<(), String> {
318        let session = self
319            .session_registry
320            .get_session(&session_id)
321            .ok_or("Session not found")?;
322
323        let events = session.event_store.events();
324        let target_ts: u64 = timestamp.parse().unwrap_or(0);
325
326        let mut target_index = events.len();
327        for (i, e) in events.iter().enumerate() {
328            if e.timestamp() >= target_ts {
329                target_index = i;
330                break;
331            }
332        }
333
334        let client = self.clients.get_mut(client_id).unwrap();
335        client
336            .subscriptions
337            .set_event_index(&session_id, target_index);
338        let computed = client
339            .subscriptions
340            .compute_state_at(&events, &session_id, target_index);
341
342        send(
343            &*client.sink,
344            DebugResponse::MarkingSnapshot {
345                session_id,
346                marking: computed.marking,
347                enabled_transitions: computed.enabled_transitions,
348                in_flight_transitions: computed.in_flight_transitions,
349            },
350        );
351        Ok(())
352    }
353
354    fn handle_playback_speed(
355        &mut self,
356        client_id: &str,
357        session_id: String,
358        speed: f64,
359    ) -> Result<(), String> {
360        let client = self.clients.get_mut(client_id).unwrap();
361        client.subscriptions.set_speed(&session_id, speed);
362        let paused = client.subscriptions.is_paused(&session_id);
363        let current_index = client.subscriptions.get_event_index(&session_id);
364        send(
365            &*client.sink,
366            DebugResponse::PlaybackStateChanged {
367                session_id,
368                paused,
369                speed,
370                current_index,
371            },
372        );
373        Ok(())
374    }
375
376    fn handle_set_filter(
377        &mut self,
378        client_id: &str,
379        session_id: String,
380        filter: EventFilter,
381    ) -> Result<(), String> {
382        let client = self.clients.get_mut(client_id).unwrap();
383        client.subscriptions.set_filter(&session_id, filter.clone());
384        send(
385            &*client.sink,
386            DebugResponse::FilterApplied { session_id, filter },
387        );
388        Ok(())
389    }
390
391    fn handle_pause(&mut self, client_id: &str, session_id: String) -> Result<(), String> {
392        let client = self.clients.get_mut(client_id).unwrap();
393        client.subscriptions.set_paused(&session_id, true);
394        let speed = client.subscriptions.get_speed(&session_id);
395        let current_index = client.subscriptions.get_event_index(&session_id);
396        send(
397            &*client.sink,
398            DebugResponse::PlaybackStateChanged {
399                session_id,
400                paused: true,
401                speed,
402                current_index,
403            },
404        );
405        Ok(())
406    }
407
408    fn handle_resume(&mut self, client_id: &str, session_id: String) -> Result<(), String> {
409        let client = self.clients.get_mut(client_id).unwrap();
410        client.subscriptions.set_paused(&session_id, false);
411        let speed = client.subscriptions.get_speed(&session_id);
412        let current_index = client.subscriptions.get_event_index(&session_id);
413        send(
414            &*client.sink,
415            DebugResponse::PlaybackStateChanged {
416                session_id,
417                paused: false,
418                speed,
419                current_index,
420            },
421        );
422        Ok(())
423    }
424
425    fn handle_step_forward(&mut self, client_id: &str, session_id: String) -> Result<(), String> {
426        let session = self
427            .session_registry
428            .get_session(&session_id)
429            .ok_or("Session not found")?;
430
431        let events = session.event_store.events();
432        let client = self.clients.get_mut(client_id).unwrap();
433        let current_index = client.subscriptions.get_event_index(&session_id);
434
435        if current_index < events.len() {
436            let event_info = to_event_info(&events[current_index]);
437            send(
438                &*client.sink,
439                DebugResponse::Event {
440                    session_id: session_id.clone(),
441                    index: current_index,
442                    event: event_info,
443                },
444            );
445            client
446                .subscriptions
447                .set_event_index(&session_id, current_index + 1);
448        }
449        Ok(())
450    }
451
452    fn handle_step_backward(&mut self, client_id: &str, session_id: String) -> Result<(), String> {
453        let session = self
454            .session_registry
455            .get_session(&session_id)
456            .ok_or("Session not found")?;
457
458        let events = session.event_store.events();
459        let client = self.clients.get_mut(client_id).unwrap();
460        let current_index = client.subscriptions.get_event_index(&session_id);
461
462        if current_index > 0 {
463            let new_index = current_index - 1;
464            client.subscriptions.set_event_index(&session_id, new_index);
465            let computed = client
466                .subscriptions
467                .compute_state_at(&events, &session_id, new_index);
468
469            send(
470                &*client.sink,
471                DebugResponse::MarkingSnapshot {
472                    session_id,
473                    marking: computed.marking,
474                    enabled_transitions: computed.enabled_transitions,
475                    in_flight_transitions: computed.in_flight_transitions,
476                },
477            );
478        }
479        Ok(())
480    }
481
482    fn handle_set_breakpoint(
483        &mut self,
484        client_id: &str,
485        session_id: String,
486        breakpoint: BreakpointConfig,
487    ) -> Result<(), String> {
488        let client = self.clients.get_mut(client_id).unwrap();
489        client
490            .subscriptions
491            .add_breakpoint(&session_id, breakpoint.clone());
492        send(
493            &*client.sink,
494            DebugResponse::BreakpointSet {
495                session_id,
496                breakpoint,
497            },
498        );
499        Ok(())
500    }
501
502    fn handle_clear_breakpoint(
503        &mut self,
504        client_id: &str,
505        session_id: String,
506        breakpoint_id: String,
507    ) -> Result<(), String> {
508        let client = self.clients.get_mut(client_id).unwrap();
509        client
510            .subscriptions
511            .remove_breakpoint(&session_id, &breakpoint_id);
512        send(
513            &*client.sink,
514            DebugResponse::BreakpointCleared {
515                session_id,
516                breakpoint_id,
517            },
518        );
519        Ok(())
520    }
521
522    fn handle_list_breakpoints(&self, client_id: &str, session_id: String) -> Result<(), String> {
523        let client = self.clients.get(client_id).unwrap();
524        let breakpoints = client.subscriptions.get_breakpoints(&session_id);
525        send(
526            &*client.sink,
527            DebugResponse::BreakpointList {
528                session_id,
529                breakpoints,
530            },
531        );
532        Ok(())
533    }
534}
535
536// ======================== Helper Functions ========================
537
538fn send(sink: &dyn ResponseSink, response: DebugResponse) {
539    sink.send(response);
540}
541
542fn send_to(clients: &HashMap<String, ClientState>, client_id: &str, response: DebugResponse) {
543    if let Some(client) = clients.get(client_id) {
544        send(&*client.sink, response);
545    }
546}
547
548fn send_in_batches(
549    clients: &HashMap<String, ClientState>,
550    client_id: &str,
551    session_id: &str,
552    start_index: usize,
553    events: &[NetEventInfo],
554) {
555    let Some(client) = clients.get(client_id) else {
556        return;
557    };
558
559    if events.is_empty() {
560        send(
561            &*client.sink,
562            DebugResponse::EventBatch {
563                session_id: session_id.to_string(),
564                start_index,
565                events: vec![],
566                has_more: false,
567            },
568        );
569        return;
570    }
571
572    for (i, chunk) in events.chunks(BATCH_SIZE).enumerate() {
573        let chunk_start = start_index + i * BATCH_SIZE;
574        let has_more = chunk_start + chunk.len() < start_index + events.len();
575        send(
576            &*client.sink,
577            DebugResponse::EventBatch {
578                session_id: session_id.to_string(),
579                start_index: chunk_start,
580                events: chunk.to_vec(),
581                has_more,
582            },
583        );
584    }
585}
586
587fn session_summary(session: &DebugSession, registry: &DebugSessionRegistry) -> SessionSummary {
588    SessionSummary {
589        session_id: session.session_id.clone(),
590        net_name: session.net_name.clone(),
591        start_time: session.start_time.to_string(),
592        active: session.active,
593        event_count: session.event_store.event_count(),
594        tags: registry.tags_for(&session.session_id),
595        end_time: session.end_time.map(|t| t.to_string()),
596        duration_ms: session.duration_ms(),
597    }
598}
599
600fn matches_filter(filter: &Option<EventFilter>, event: &NetEvent) -> bool {
601    let Some(filter) = filter else { return true };
602
603    // Event type: include then exclude
604    let event_type = event_type_name(event);
605    if let Some(ref types) = filter.event_types {
606        if !types.is_empty() && !types.iter().any(|t| t == event_type) {
607            return false;
608        }
609    }
610    if let Some(ref types) = filter.exclude_event_types {
611        if !types.is_empty() && types.iter().any(|t| t == event_type) {
612            return false;
613        }
614    }
615
616    // Transition name: extract once, include then exclude
617    let need_transition = filter.transition_names.as_ref().is_some_and(|n| !n.is_empty())
618        || filter.exclude_transition_names.as_ref().is_some_and(|n| !n.is_empty());
619    if need_transition {
620        let t_name = extract_transition_name(event);
621        if let Some(ref names) = filter.transition_names {
622            if !names.is_empty() {
623                match t_name {
624                    Some(n) => {
625                        if !names.iter().any(|t| t == n) {
626                            return false;
627                        }
628                    }
629                    None => return false,
630                }
631            }
632        }
633        if let Some(ref names) = filter.exclude_transition_names {
634            if !names.is_empty() {
635                if let Some(n) = t_name {
636                    if names.iter().any(|t| t == n) {
637                        return false;
638                    }
639                }
640            }
641        }
642    }
643
644    // Place name: extract once, include then exclude
645    let need_place = filter.place_names.as_ref().is_some_and(|n| !n.is_empty())
646        || filter.exclude_place_names.as_ref().is_some_and(|n| !n.is_empty());
647    if need_place {
648        let p_name = extract_place_name(event);
649        if let Some(ref names) = filter.place_names {
650            if !names.is_empty() {
651                match p_name {
652                    Some(n) => {
653                        if !names.iter().any(|t| t == n) {
654                            return false;
655                        }
656                    }
657                    None => return false,
658                }
659            }
660        }
661        if let Some(ref names) = filter.exclude_place_names {
662            if !names.is_empty() {
663                if let Some(n) = p_name {
664                    if names.iter().any(|t| t == n) {
665                        return false;
666                    }
667                }
668            }
669        }
670    }
671
672    true
673}
674
675fn matches_breakpoint(bp: &BreakpointConfig, event: &NetEvent) -> bool {
676    if !bp.enabled {
677        return false;
678    }
679    match bp.bp_type {
680        BreakpointType::TransitionEnabled => {
681            matches!(event, NetEvent::TransitionEnabled { transition_name, .. }
682                if bp.target.as_ref().is_none_or(|t| t == transition_name.as_ref()))
683        }
684        BreakpointType::TransitionStart => {
685            matches!(event, NetEvent::TransitionStarted { transition_name, .. }
686                if bp.target.as_ref().is_none_or(|t| t == transition_name.as_ref()))
687        }
688        BreakpointType::TransitionComplete => {
689            matches!(event, NetEvent::TransitionCompleted { transition_name, .. }
690                if bp.target.as_ref().is_none_or(|t| t == transition_name.as_ref()))
691        }
692        BreakpointType::TransitionFail => {
693            matches!(event, NetEvent::TransitionFailed { transition_name, .. }
694                if bp.target.as_ref().is_none_or(|t| t == transition_name.as_ref()))
695        }
696        BreakpointType::TokenAdded => {
697            matches!(event, NetEvent::TokenAdded { place_name, .. }
698                if bp.target.as_ref().is_none_or(|t| t == place_name.as_ref()))
699        }
700        BreakpointType::TokenRemoved => {
701            matches!(event, NetEvent::TokenRemoved { place_name, .. }
702                if bp.target.as_ref().is_none_or(|t| t == place_name.as_ref()))
703        }
704    }
705}
706
707fn check_breakpoints(
708    breakpoints: &HashMap<String, BreakpointConfig>,
709    event: &NetEvent,
710) -> Option<BreakpointConfig> {
711    for bp in breakpoints.values() {
712        if matches_breakpoint(bp, event) {
713            return Some(bp.clone());
714        }
715    }
716    None
717}
718
719// ======================== Subscription State ========================
720
/// Playback state of one client's subscription to one session.
struct SessionSubscription {
    /// Cursor into the session's event log: the index assigned to the next
    /// event delivered to the client.
    event_index: usize,
    /// Cache used when computing the marking at a given event index.
    marking_cache: MarkingCache,
    /// Breakpoints registered by the client, keyed by breakpoint id.
    breakpoints: HashMap<String, BreakpointConfig>,
    /// While true, live events are not delivered to the client.
    paused: bool,
    /// Playback speed requested by the client (1.0 for a new subscription);
    /// this file only stores it and echoes it back.
    speed: f64,
    /// Optional event filter; `None` lets every event through.
    filter: Option<EventFilter>,
}
729
/// All subscriptions held by a single client.
struct SubscriptionState {
    /// Subscriptions keyed by session id.
    sessions: HashMap<String, SessionSubscription>,
}
733
734impl SubscriptionState {
735    fn new() -> Self {
736        Self {
737            sessions: HashMap::new(),
738        }
739    }
740
741    fn add_subscription(&mut self, session_id: String, event_index: usize, paused: bool) {
742        self.sessions.insert(
743            session_id,
744            SessionSubscription {
745                event_index,
746                marking_cache: MarkingCache::new(),
747                breakpoints: HashMap::new(),
748                paused,
749                speed: 1.0,
750                filter: None,
751            },
752        );
753    }
754
755    fn cancel(&mut self, session_id: &str) {
756        self.sessions.remove(session_id);
757    }
758
759    fn is_paused(&self, session_id: &str) -> bool {
760        self.sessions.get(session_id).is_some_and(|s| s.paused)
761    }
762
763    fn set_paused(&mut self, session_id: &str, paused: bool) {
764        if let Some(sub) = self.sessions.get_mut(session_id) {
765            sub.paused = paused;
766        }
767    }
768
769    fn get_speed(&self, session_id: &str) -> f64 {
770        self.sessions.get(session_id).map_or(1.0, |s| s.speed)
771    }
772
773    fn set_speed(&mut self, session_id: &str, speed: f64) {
774        if let Some(sub) = self.sessions.get_mut(session_id) {
775            sub.speed = speed;
776        }
777    }
778
779    fn get_event_index(&self, session_id: &str) -> usize {
780        self.sessions.get(session_id).map_or(0, |s| s.event_index)
781    }
782
783    fn set_event_index(&mut self, session_id: &str, index: usize) {
784        if let Some(sub) = self.sessions.get_mut(session_id) {
785            sub.event_index = index;
786        }
787    }
788
789    fn compute_state_at(
790        &mut self,
791        events: &[NetEvent],
792        session_id: &str,
793        target_index: usize,
794    ) -> crate::marking_cache::ComputedState {
795        if let Some(sub) = self.sessions.get_mut(session_id) {
796            sub.marking_cache.compute_at(events, target_index)
797        } else {
798            compute_state(&events[..target_index.min(events.len())])
799        }
800    }
801
802    fn set_filter(&mut self, session_id: &str, filter: EventFilter) {
803        if let Some(sub) = self.sessions.get_mut(session_id) {
804            sub.filter = Some(filter);
805        }
806    }
807
808    fn add_breakpoint(&mut self, session_id: &str, breakpoint: BreakpointConfig) {
809        if let Some(sub) = self.sessions.get_mut(session_id) {
810            sub.breakpoints.insert(breakpoint.id.clone(), breakpoint);
811        }
812    }
813
814    fn remove_breakpoint(&mut self, session_id: &str, breakpoint_id: &str) {
815        if let Some(sub) = self.sessions.get_mut(session_id) {
816            sub.breakpoints.remove(breakpoint_id);
817        }
818    }
819
820    fn get_breakpoints(&self, session_id: &str) -> Vec<BreakpointConfig> {
821        self.sessions
822            .get(session_id)
823            .map_or_else(Vec::new, |s| s.breakpoints.values().cloned().collect())
824    }
825}
826
827#[cfg(test)]
828mod tests {
829    use super::*;
830    use crate::debug_event_store::DebugEventStore;
831    use std::sync::{Arc, Mutex};
832
833    fn make_handler_with_net() -> (DebugProtocolHandler, Arc<DebugEventStore>) {
834        use libpetri_core::input::one;
835        use libpetri_core::output::out_place;
836        use libpetri_core::place::Place;
837        use libpetri_core::transition::Transition;
838
839        let p1 = Place::<i32>::new("p1");
840        let p2 = Place::<i32>::new("p2");
841        let t = Transition::builder("t1")
842            .input(one(&p1))
843            .output(out_place(&p2))
844            .build();
845        let net = libpetri_core::petri_net::PetriNet::builder("test")
846            .transition(t)
847            .build();
848
849        let mut registry = DebugSessionRegistry::new();
850        let store = registry.register("s1".into(), &net);
851        let handler = DebugProtocolHandler::new(registry);
852        (handler, store)
853    }
854
855    fn collector_sink() -> (Box<dyn ResponseSink>, Arc<Mutex<Vec<DebugResponse>>>) {
856        let collected = Arc::new(Mutex::new(Vec::new()));
857        let collected_clone = Arc::clone(&collected);
858        let sink: Box<dyn ResponseSink> = Box::new(move |resp: DebugResponse| {
859            collected_clone.lock().unwrap().push(resp);
860        });
861        (sink, collected)
862    }
863
864    #[test]
865    fn list_sessions() {
866        let (mut handler, _store) = make_handler_with_net();
867        let (sink, collected) = collector_sink();
868        handler.client_connected("c1".into(), sink);
869
870        handler.handle_command(
871            "c1",
872            DebugCommand::ListSessions {
873                limit: None,
874                active_only: None,
875                tag_filter: None,
876            },
877        );
878
879        let responses = collected.lock().unwrap();
880        assert_eq!(responses.len(), 1);
881        match &responses[0] {
882            DebugResponse::SessionList { sessions } => {
883                assert_eq!(sessions.len(), 1);
884                assert_eq!(sessions[0].net_name, "test");
885            }
886            _ => panic!("expected SessionList"),
887        }
888    }
889
890    fn tagged_net() -> libpetri_core::petri_net::PetriNet {
891        use libpetri_core::input::one;
892        use libpetri_core::output::out_place;
893        use libpetri_core::place::Place;
894        use libpetri_core::transition::Transition;
895
896        let p1 = Place::<i32>::new("p1");
897        let p2 = Place::<i32>::new("p2");
898        let t = Transition::builder("t1")
899            .input(one(&p1))
900            .output(out_place(&p2))
901            .build();
902        libpetri_core::petri_net::PetriNet::builder("test")
903            .transition(t)
904            .build()
905    }
906
907    #[test]
908    fn list_sessions_filters_by_tag() {
909        let net = tagged_net();
910        let mut registry = DebugSessionRegistry::new();
911        let mut voice = HashMap::new();
912        voice.insert("channel".to_string(), "voice".to_string());
913        let mut text = HashMap::new();
914        text.insert("channel".to_string(), "text".to_string());
915
916        registry.register_with_tags("voice-1".into(), &net, voice.clone());
917        registry.register_with_tags("text-1".into(), &net, text);
918        registry.register_with_tags("voice-2".into(), &net, voice.clone());
919
920        let mut handler = DebugProtocolHandler::new(registry);
921        let (sink, collected) = collector_sink();
922        handler.client_connected("c1".into(), sink);
923
924        let mut filter = HashMap::new();
925        filter.insert("channel".to_string(), "voice".to_string());
926        handler.handle_command(
927            "c1",
928            DebugCommand::ListSessions {
929                limit: None,
930                active_only: None,
931                tag_filter: Some(filter),
932            },
933        );
934
935        let responses = collected.lock().unwrap();
936        assert_eq!(responses.len(), 1);
937        match &responses[0] {
938            DebugResponse::SessionList { sessions } => {
939                assert_eq!(sessions.len(), 2);
940                assert!(sessions.iter().all(|s| s.session_id.starts_with("voice")));
941                // Verify the new 1.6.0 wire fields are populated.
942                assert_eq!(sessions[0].tags.get("channel"), Some(&"voice".to_string()));
943            }
944            _ => panic!("expected SessionList"),
945        }
946    }
947
948    #[test]
949    fn list_sessions_populates_end_time_and_duration_ms() {
950        let net = tagged_net();
951        let mut registry = DebugSessionRegistry::new();
952        registry.register("s1".into(), &net);
953        std::thread::sleep(std::time::Duration::from_millis(2));
954        registry.complete("s1");
955
956        let mut handler = DebugProtocolHandler::new(registry);
957        let (sink, collected) = collector_sink();
958        handler.client_connected("c1".into(), sink);
959
960        handler.handle_command(
961            "c1",
962            DebugCommand::ListSessions {
963                limit: None,
964                active_only: None,
965                tag_filter: None,
966            },
967        );
968
969        let responses = collected.lock().unwrap();
970        if let DebugResponse::SessionList { sessions } = &responses[0] {
971            assert_eq!(sessions.len(), 1);
972            assert!(!sessions[0].active);
973            assert!(sessions[0].end_time.is_some());
974            assert!(sessions[0].duration_ms.is_some());
975        } else {
976            panic!("expected SessionList");
977        }
978    }
979
980    #[test]
981    fn subscribe_and_unsubscribe() {
982        let (mut handler, _store) = make_handler_with_net();
983        let (sink, collected) = collector_sink();
984        handler.client_connected("c1".into(), sink);
985
986        handler.handle_command(
987            "c1",
988            DebugCommand::Subscribe {
989                session_id: "s1".into(),
990                mode: crate::debug_command::SubscriptionMode::Live,
991                from_index: None,
992            },
993        );
994
995        {
996            let responses = collected.lock().unwrap();
997            assert!(responses.len() >= 1);
998            match &responses[0] {
999                DebugResponse::Subscribed {
1000                    session_id,
1001                    net_name,
1002                    ..
1003                } => {
1004                    assert_eq!(session_id, "s1");
1005                    assert_eq!(net_name, "test");
1006                }
1007                _ => panic!("expected Subscribed"),
1008            }
1009        }
1010
1011        handler.handle_command(
1012            "c1",
1013            DebugCommand::Unsubscribe {
1014                session_id: "s1".into(),
1015            },
1016        );
1017
1018        let responses = collected.lock().unwrap();
1019        let last = responses.last().unwrap();
1020        match last {
1021            DebugResponse::Unsubscribed { session_id } => {
1022                assert_eq!(session_id, "s1");
1023            }
1024            _ => panic!("expected Unsubscribed"),
1025        }
1026    }
1027
1028    #[test]
1029    fn subscribe_to_nonexistent_session() {
1030        let (mut handler, _store) = make_handler_with_net();
1031        let (sink, collected) = collector_sink();
1032        handler.client_connected("c1".into(), sink);
1033
1034        handler.handle_command(
1035            "c1",
1036            DebugCommand::Subscribe {
1037                session_id: "nonexistent".into(),
1038                mode: crate::debug_command::SubscriptionMode::Live,
1039                from_index: None,
1040            },
1041        );
1042
1043        let responses = collected.lock().unwrap();
1044        match &responses[0] {
1045            DebugResponse::Error { code, .. } => assert_eq!(code, "COMMAND_ERROR"),
1046            _ => panic!("expected Error"),
1047        }
1048    }
1049
1050    #[test]
1051    fn pause_and_resume() {
1052        let (mut handler, _store) = make_handler_with_net();
1053        let (sink, collected) = collector_sink();
1054        handler.client_connected("c1".into(), sink);
1055
1056        handler.handle_command(
1057            "c1",
1058            DebugCommand::Subscribe {
1059                session_id: "s1".into(),
1060                mode: crate::debug_command::SubscriptionMode::Live,
1061                from_index: None,
1062            },
1063        );
1064
1065        handler.handle_command(
1066            "c1",
1067            DebugCommand::Pause {
1068                session_id: "s1".into(),
1069            },
1070        );
1071
1072        let responses = collected.lock().unwrap();
1073        let pause_resp = responses
1074            .iter()
1075            .find(|r| matches!(r, DebugResponse::PlaybackStateChanged { paused: true, .. }));
1076        assert!(pause_resp.is_some());
1077    }
1078
1079    #[test]
1080    fn set_and_list_breakpoints() {
1081        let (mut handler, _store) = make_handler_with_net();
1082        let (sink, collected) = collector_sink();
1083        handler.client_connected("c1".into(), sink);
1084
1085        handler.handle_command(
1086            "c1",
1087            DebugCommand::Subscribe {
1088                session_id: "s1".into(),
1089                mode: crate::debug_command::SubscriptionMode::Live,
1090                from_index: None,
1091            },
1092        );
1093
1094        handler.handle_command(
1095            "c1",
1096            DebugCommand::SetBreakpoint {
1097                session_id: "s1".into(),
1098                breakpoint: BreakpointConfig {
1099                    id: "bp1".into(),
1100                    bp_type: BreakpointType::TransitionStart,
1101                    target: Some("t1".into()),
1102                    enabled: true,
1103                },
1104            },
1105        );
1106
1107        handler.handle_command(
1108            "c1",
1109            DebugCommand::ListBreakpoints {
1110                session_id: "s1".into(),
1111            },
1112        );
1113
1114        let responses = collected.lock().unwrap();
1115        let bp_list = responses
1116            .iter()
1117            .find(|r| matches!(r, DebugResponse::BreakpointList { .. }));
1118        match bp_list.unwrap() {
1119            DebugResponse::BreakpointList { breakpoints, .. } => {
1120                assert_eq!(breakpoints.len(), 1);
1121                assert_eq!(breakpoints[0].id, "bp1");
1122            }
1123            _ => unreachable!(),
1124        }
1125    }
1126
1127    #[test]
1128    fn broadcast_event_to_subscribers() {
1129        let (mut handler, store) = make_handler_with_net();
1130        let (sink, collected) = collector_sink();
1131        handler.client_connected("c1".into(), sink);
1132
1133        handler.handle_command(
1134            "c1",
1135            DebugCommand::Subscribe {
1136                session_id: "s1".into(),
1137                mode: crate::debug_command::SubscriptionMode::Live,
1138                from_index: None,
1139            },
1140        );
1141
1142        let event = NetEvent::TransitionStarted {
1143            transition_name: Arc::from("t1"),
1144            timestamp: 1000,
1145        };
1146        store.append(event.clone());
1147        handler.broadcast_event("s1", &event);
1148
1149        let responses = collected.lock().unwrap();
1150        let event_resp = responses
1151            .iter()
1152            .find(|r| matches!(r, DebugResponse::Event { .. }));
1153        assert!(event_resp.is_some());
1154    }
1155
1156    #[test]
1157    fn filter_matching() {
1158        let event = NetEvent::TransitionStarted {
1159            transition_name: Arc::from("t1"),
1160            timestamp: 0,
1161        };
1162
1163        // No filter — matches all
1164        assert!(matches_filter(&None, &event));
1165
1166        // Type filter matching
1167        let filter = EventFilter {
1168            event_types: Some(vec!["TransitionStarted".into()]),
1169            ..Default::default()
1170        };
1171        assert!(matches_filter(&Some(filter), &event));
1172
1173        // Type filter not matching
1174        let filter = EventFilter {
1175            event_types: Some(vec!["TokenAdded".into()]),
1176            ..Default::default()
1177        };
1178        assert!(!matches_filter(&Some(filter), &event));
1179
1180        // Transition name filter
1181        let filter = EventFilter {
1182            transition_names: Some(vec!["t1".into()]),
1183            ..Default::default()
1184        };
1185        assert!(matches_filter(&Some(filter), &event));
1186
1187        let filter = EventFilter {
1188            transition_names: Some(vec!["t2".into()]),
1189            ..Default::default()
1190        };
1191        assert!(!matches_filter(&Some(filter), &event));
1192    }
1193
1194    #[test]
1195    fn filter_exclusion() {
1196        let event = NetEvent::TransitionStarted {
1197            transition_name: Arc::from("t1"),
1198            timestamp: 0,
1199        };
1200
1201        // Exclude by transition name
1202        let filter = EventFilter {
1203            exclude_transition_names: Some(vec!["t1".into()]),
1204            ..Default::default()
1205        };
1206        assert!(!matches_filter(&Some(filter), &event));
1207
1208        // Exclude different transition — should pass
1209        let filter = EventFilter {
1210            exclude_transition_names: Some(vec!["t2".into()]),
1211            ..Default::default()
1212        };
1213        assert!(matches_filter(&Some(filter), &event));
1214
1215        // Exclude by event type
1216        let filter = EventFilter {
1217            exclude_event_types: Some(vec!["TransitionStarted".into()]),
1218            ..Default::default()
1219        };
1220        assert!(!matches_filter(&Some(filter), &event));
1221    }
1222
1223    #[test]
1224    fn filter_combined_include_exclude() {
1225        let event_t1 = NetEvent::TransitionStarted {
1226            transition_name: Arc::from("t1"),
1227            timestamp: 0,
1228        };
1229        let event_t2 = NetEvent::TransitionStarted {
1230            transition_name: Arc::from("t2"),
1231            timestamp: 0,
1232        };
1233
1234        // Include TransitionStarted, exclude t2
1235        let filter = EventFilter {
1236            event_types: Some(vec!["TransitionStarted".into()]),
1237            exclude_transition_names: Some(vec!["t2".into()]),
1238            ..Default::default()
1239        };
1240        assert!(matches_filter(&Some(filter.clone()), &event_t1));
1241        assert!(!matches_filter(&Some(filter), &event_t2));
1242    }
1243
1244    #[test]
1245    fn breakpoint_matching() {
1246        let event = NetEvent::TransitionStarted {
1247            transition_name: Arc::from("t1"),
1248            timestamp: 0,
1249        };
1250
1251        let bp = BreakpointConfig {
1252            id: "bp1".into(),
1253            bp_type: BreakpointType::TransitionStart,
1254            target: Some("t1".into()),
1255            enabled: true,
1256        };
1257        assert!(matches_breakpoint(&bp, &event));
1258
1259        // Disabled breakpoint
1260        let bp_disabled = BreakpointConfig {
1261            id: "bp2".into(),
1262            bp_type: BreakpointType::TransitionStart,
1263            target: Some("t1".into()),
1264            enabled: false,
1265        };
1266        assert!(!matches_breakpoint(&bp_disabled, &event));
1267
1268        // Wrong target
1269        let bp_wrong = BreakpointConfig {
1270            id: "bp3".into(),
1271            bp_type: BreakpointType::TransitionStart,
1272            target: Some("t2".into()),
1273            enabled: true,
1274        };
1275        assert!(!matches_breakpoint(&bp_wrong, &event));
1276
1277        // Wildcard (no target)
1278        let bp_wild = BreakpointConfig {
1279            id: "bp4".into(),
1280            bp_type: BreakpointType::TransitionStart,
1281            target: None,
1282            enabled: true,
1283        };
1284        assert!(matches_breakpoint(&bp_wild, &event));
1285    }
1286
1287    #[test]
1288    fn client_disconnect_cleanup() {
1289        let (mut handler, _store) = make_handler_with_net();
1290        let (sink, _collected) = collector_sink();
1291        handler.client_connected("c1".into(), sink);
1292        handler.client_disconnected("c1");
1293        assert!(handler.clients.is_empty());
1294    }
1295
1296    #[test]
1297    fn step_forward_and_backward() {
1298        let (mut handler, store) = make_handler_with_net();
1299        let (sink, collected) = collector_sink();
1300        handler.client_connected("c1".into(), sink);
1301
1302        // Add some events
1303        for i in 0..5 {
1304            store.append(NetEvent::TokenAdded {
1305                place_name: Arc::from("p1"),
1306                timestamp: i,
1307            });
1308        }
1309
1310        handler.handle_command(
1311            "c1",
1312            DebugCommand::Subscribe {
1313                session_id: "s1".into(),
1314                mode: crate::debug_command::SubscriptionMode::Replay,
1315                from_index: Some(0),
1316            },
1317        );
1318
1319        // Step forward
1320        handler.handle_command(
1321            "c1",
1322            DebugCommand::StepForward {
1323                session_id: "s1".into(),
1324            },
1325        );
1326
1327        // Step backward
1328        handler.handle_command(
1329            "c1",
1330            DebugCommand::StepBackward {
1331                session_id: "s1".into(),
1332            },
1333        );
1334
1335        let responses = collected.lock().unwrap();
1336        assert!(responses.len() >= 3); // subscribed + batch + step responses
1337    }
1338}