1use serde::{Deserialize, Serialize};
2use std::collections::HashSet;
3use std::sync::{Arc, Mutex};
4use tokio::sync::Notify;
5
/// Run state of the debugger, as reported by [`DebugControl::state`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DebugState {
    /// Execution may proceed. This is the initial state and is restored by
    /// `DebugControl::resume` and `DebugControl::step`.
    Running,
    /// A task has entered `DebugControl::wait` and is blocked until the
    /// controller is resumed or stepped.
    Paused,
}
14
/// Handle to a shared debugger controller.
///
/// Cloning is cheap: every clone shares the same inner state through an `Arc`.
/// Breakpoints, pause requests and the run state changed through one handle
/// are therefore visible through all of its clones.
#[derive(Clone)]
pub struct DebugControl {
    inner: Arc<DebugControlInner>,
}
23
24struct DebugControlInner {
25 breakpoints: Mutex<HashSet<String>>,
26 pause_next: Mutex<bool>,
27 notify: Notify,
28 state: Mutex<DebugState>,
29}
30
31impl DebugControl {
32 pub fn new() -> Self {
34 Self {
35 inner: Arc::new(DebugControlInner {
36 breakpoints: Mutex::new(HashSet::new()),
37 pause_next: Mutex::new(false),
38 notify: Notify::new(),
39 state: Mutex::new(DebugState::Running),
40 }),
41 }
42 }
43
44 pub fn set_breakpoint(&self, node_id: String) {
46 self.inner.breakpoints.lock().expect("debug mutex poisoned").insert(node_id);
47 }
48
49 pub fn remove_breakpoint(&self, node_id: &str) {
51 self.inner.breakpoints.lock().expect("debug mutex poisoned").remove(node_id);
52 }
53
54 pub fn pause(&self) {
56 *self.inner.pause_next.lock().expect("debug mutex poisoned") = true;
57 }
58
59 pub fn resume(&self) {
61 *self.inner.pause_next.lock().expect("debug mutex poisoned") = false;
62 *self.inner.state.lock().expect("debug mutex poisoned") = DebugState::Running;
63 self.inner.notify.notify_waiters();
64 }
65
66 pub fn step(&self) {
68 *self.inner.pause_next.lock().expect("debug mutex poisoned") = true;
69 *self.inner.state.lock().expect("debug mutex poisoned") = DebugState::Running;
70 self.inner.notify.notify_waiters();
71 }
72
73 pub fn should_pause(&self, node_id: &str) -> bool {
77 let breakpoints = self.inner.breakpoints.lock().expect("debug mutex poisoned");
78 let mut pause_next = self.inner.pause_next.lock().expect("debug mutex poisoned");
79 let hit_breakpoint = breakpoints.contains(node_id);
80 let pause_requested = *pause_next;
81
82 if hit_breakpoint || pause_requested {
83 *pause_next = false; true
85 } else {
86 false
87 }
88 }
89
90 pub async fn wait(&self) {
92 *self.inner.state.lock().expect("debug mutex poisoned") = DebugState::Paused;
93 self.inner.notify.notified().await;
95 }
96
97 pub async fn wait_if_needed(&self, node_id: &str) {
101 if self.should_pause(node_id) {
102 self.wait().await;
103 }
104 }
105
106 pub fn state(&self) -> DebugState {
108 *self.inner.state.lock().expect("debug mutex poisoned")
109 }
110
111 pub fn list_breakpoints(&self) -> Vec<String> {
113 self.inner
114 .breakpoints
115 .lock()
116 .unwrap()
117 .iter()
118 .cloned()
119 .collect()
120 }
121}
122
123impl Default for DebugControl {
124 fn default() -> Self {
125 Self::new()
126 }
127}