roam_session/
diagnostic.rs

//! Diagnostic state tracking for SIGUSR1 dumps.
//!
//! Tracks in-flight RPC requests and open channels to help debug hung connections.
4
5use std::collections::HashMap;
6use std::fmt::Write as _;
7use std::sync::{Arc, LazyLock, RwLock, Weak};
8use std::time::Instant;
9
/// A boxed callback that appends free-form diagnostic text to the given buffer.
///
/// The `Send + Sync` bounds allow callbacks to be stored in shared state and
/// invoked from whichever thread produces a dump.
pub type DiagnosticCallback = Box<dyn Fn(&mut String) + Send + Sync>;
12
13/// Global registry of all diagnostic states.
14/// Each connection/driver registers its state here.
15static DIAGNOSTIC_REGISTRY: LazyLock<RwLock<Vec<Weak<DiagnosticState>>>> =
16    LazyLock::new(|| RwLock::new(Vec::new()));
17
/// Method name registry: maps a numeric `method_id` to a human-readable name.
///
/// Names are `&'static str`, so lookups can hand them out without cloning.
static METHOD_NAMES: LazyLock<RwLock<HashMap<u64, &'static str>>> =
    LazyLock::new(|| RwLock::new(HashMap::new()));
21
/// Whether to record extra debug info.
///
/// The environment is read lazily on first access and the result is cached
/// for the rest of the process. Debug recording is on whenever `ROAM_DEBUG`
/// is set to any valid-Unicode value (e.g. `ROAM_DEBUG=1`); the value itself
/// is not inspected, so `ROAM_DEBUG=0` also enables it.
static DEBUG_ENABLED: LazyLock<bool> = LazyLock::new(|| std::env::var("ROAM_DEBUG").is_ok());
25
26/// Check if debug recording is enabled.
27pub fn debug_enabled() -> bool {
28    *DEBUG_ENABLED
29}
30
31/// Register a method name for diagnostic display.
32pub fn register_method_name(method_id: u64, name: &'static str) {
33    if let Ok(mut names) = METHOD_NAMES.write() {
34        names.insert(method_id, name);
35    }
36}
37
38/// Look up a method name by ID.
39pub fn get_method_name(method_id: u64) -> Option<&'static str> {
40    METHOD_NAMES.read().ok()?.get(&method_id).copied()
41}
42
43/// Register a diagnostic state for SIGUSR1 dumps.
44pub fn register_diagnostic_state(state: &Arc<DiagnosticState>) {
45    if let Ok(mut registry) = DIAGNOSTIC_REGISTRY.write() {
46        // Clean up dead entries while we're here
47        registry.retain(|weak| weak.strong_count() > 0);
48        registry.push(Arc::downgrade(state));
49    }
50}
51
52/// Dump all diagnostic states to a string.
53pub fn dump_all_diagnostics() -> String {
54    let mut output = String::new();
55
56    let states: Vec<Arc<DiagnosticState>> = {
57        // Use try_read to avoid deadlocking if called from signal handler
58        let Ok(registry) = DIAGNOSTIC_REGISTRY.try_read() else {
59            return "ERROR: Could not acquire diagnostic registry lock (held by another thread)\n"
60                .to_string();
61        };
62        registry.iter().filter_map(|weak| weak.upgrade()).collect()
63    };
64
65    if states.is_empty() {
66        return String::new();
67    }
68
69    for state in states {
70        // Only include states that have something to report
71        if let Some(content) = state.dump_if_nonempty() {
72            let _ = writeln!(output, "[{}] {}", state.name, content);
73        }
74    }
75
76    output
77}
78
/// Direction of an RPC request, from the point of view of the local side.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RequestDirection {
    /// We sent the request and are waiting for the response.
    Outgoing,
    /// We received the request and are processing it.
    Incoming,
}
87
88/// An in-flight RPC request.
89#[derive(Debug, Clone)]
90pub struct InFlightRequest {
91    pub request_id: u64,
92    pub method_id: u64,
93    pub started: Instant,
94    pub direction: RequestDirection,
95    /// Optional structured arguments (only recorded when ROAM_DEBUG_ARGS is set).
96    pub args: Option<HashMap<String, String>>,
97}
98
/// Direction of a channel, from the point of view of the local side.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChannelDirection {
    /// We're sending on this channel.
    Tx,
    /// We're receiving on this channel.
    Rx,
}
107
108/// An open streaming channel.
109#[derive(Debug, Clone)]
110pub struct OpenChannel {
111    pub channel_id: u64,
112    pub started: Instant,
113    pub direction: ChannelDirection,
114    /// The request that opened this channel (if known).
115    pub request_id: Option<u64>,
116}
117
118/// Diagnostic state for a single connection.
119pub struct DiagnosticState {
120    /// Human-readable name for this connection (e.g., "cell-http", "host→markdown")
121    pub name: String,
122
123    /// In-flight requests
124    requests: RwLock<HashMap<u64, InFlightRequest>>,
125
126    /// Open channels
127    channels: RwLock<HashMap<u64, OpenChannel>>,
128
129    /// Custom diagnostic callbacks
130    custom_diagnostics: RwLock<Vec<DiagnosticCallback>>,
131}
132
133impl DiagnosticState {
134    /// Create a new diagnostic state.
135    pub fn new(name: impl Into<String>) -> Self {
136        Self {
137            name: name.into(),
138            requests: RwLock::new(HashMap::new()),
139            channels: RwLock::new(HashMap::new()),
140            custom_diagnostics: RwLock::new(Vec::new()),
141        }
142    }
143
144    /// Record an outgoing request (we're calling remote).
145    pub fn record_outgoing_request(
146        &self,
147        request_id: u64,
148        method_id: u64,
149        args: Option<HashMap<String, String>>,
150    ) {
151        if let Ok(mut requests) = self.requests.write() {
152            requests.insert(
153                request_id,
154                InFlightRequest {
155                    request_id,
156                    method_id,
157                    started: Instant::now(),
158                    direction: RequestDirection::Outgoing,
159                    args,
160                },
161            );
162        }
163    }
164
165    /// Record an incoming request (remote is calling us).
166    pub fn record_incoming_request(
167        &self,
168        request_id: u64,
169        method_id: u64,
170        args: Option<HashMap<String, String>>,
171    ) {
172        if let Ok(mut requests) = self.requests.write() {
173            requests.insert(
174                request_id,
175                InFlightRequest {
176                    request_id,
177                    method_id,
178                    started: Instant::now(),
179                    direction: RequestDirection::Incoming,
180                    args,
181                },
182            );
183        }
184    }
185
186    /// Mark a request as completed.
187    pub fn complete_request(&self, request_id: u64) {
188        if let Ok(mut requests) = self.requests.write() {
189            requests.remove(&request_id);
190        }
191    }
192
193    /// Record a channel being opened.
194    pub fn record_channel_open(
195        &self,
196        channel_id: u64,
197        direction: ChannelDirection,
198        request_id: Option<u64>,
199    ) {
200        if let Ok(mut channels) = self.channels.write() {
201            channels.insert(
202                channel_id,
203                OpenChannel {
204                    channel_id,
205                    started: Instant::now(),
206                    direction,
207                    request_id,
208                },
209            );
210        }
211    }
212
213    /// Record a channel being closed.
214    pub fn record_channel_close(&self, channel_id: u64) {
215        if let Ok(mut channels) = self.channels.write() {
216            channels.remove(&channel_id);
217        }
218    }
219
220    /// Associate channels with a request (called after channels are opened but before request is sent).
221    pub fn associate_channels_with_request(&self, channel_ids: &[u64], request_id: u64) {
222        if let Ok(mut channels) = self.channels.write() {
223            for &channel_id in channel_ids {
224                if let Some(channel) = channels.get_mut(&channel_id) {
225                    channel.request_id = Some(request_id);
226                }
227            }
228        }
229    }
230
231    /// Add a custom diagnostic callback.
232    pub fn add_custom_diagnostic<F>(&self, callback: F)
233    where
234        F: Fn(&mut String) + Send + Sync + 'static,
235    {
236        if let Ok(mut diagnostics) = self.custom_diagnostics.write() {
237            diagnostics.push(Box::new(callback));
238        }
239    }
240
241    /// Dump this state if non-empty, returning None if there's nothing to report.
242    /// Uses try_read() to avoid deadlocking when called from signal handlers.
243    /// Output is compact: single line for summary, one line per request.
244    pub fn dump_if_nonempty(&self) -> Option<String> {
245        let now = Instant::now();
246        let mut parts = Vec::new();
247        let mut details = Vec::new();
248
249        // Check requests
250        if let Ok(requests) = self.requests.try_read() {
251            let mut outgoing: Vec<_> = requests
252                .values()
253                .filter(|r| r.direction == RequestDirection::Outgoing)
254                .collect();
255            let mut incoming: Vec<_> = requests
256                .values()
257                .filter(|r| r.direction == RequestDirection::Incoming)
258                .collect();
259
260            outgoing.sort_by_key(|r| std::cmp::Reverse(r.started));
261            incoming.sort_by_key(|r| std::cmp::Reverse(r.started));
262
263            if !outgoing.is_empty() {
264                parts.push(format!("{}⬆", outgoing.len()));
265                for req in outgoing {
266                    let elapsed = now.duration_since(req.started);
267                    let method_name = get_method_name(req.method_id).unwrap_or("?");
268                    let mut line = format!(
269                        "  ⬆#{} {} {:.1}s",
270                        req.request_id,
271                        method_name,
272                        elapsed.as_secs_f64()
273                    );
274                    if let Some(args) = &req.args {
275                        let args_str: Vec<_> =
276                            args.iter().map(|(k, v)| format!("{}={}", k, v)).collect();
277                        if !args_str.is_empty() {
278                            let _ = write!(line, " ({})", args_str.join(", "));
279                        }
280                    }
281                    details.push(line);
282                }
283            }
284
285            if !incoming.is_empty() {
286                parts.push(format!("{}⬇", incoming.len()));
287                for req in incoming {
288                    let elapsed = now.duration_since(req.started);
289                    let method_name = get_method_name(req.method_id).unwrap_or("?");
290                    let mut line = format!(
291                        "  ⬇#{} {} {:.1}s",
292                        req.request_id,
293                        method_name,
294                        elapsed.as_secs_f64()
295                    );
296                    if let Some(args) = &req.args {
297                        let args_str: Vec<_> =
298                            args.iter().map(|(k, v)| format!("{}={}", k, v)).collect();
299                        if !args_str.is_empty() {
300                            let _ = write!(line, " ({})", args_str.join(", "));
301                        }
302                    }
303                    details.push(line);
304                }
305            }
306        }
307
308        // Check channels
309        if let Ok(channels) = self.channels.try_read()
310            && !channels.is_empty()
311        {
312            let tx_count = channels
313                .values()
314                .filter(|c| c.direction == ChannelDirection::Tx)
315                .count();
316            let rx_count = channels
317                .values()
318                .filter(|c| c.direction == ChannelDirection::Rx)
319                .count();
320            if tx_count > 0 {
321                parts.push(format!("{}tx", tx_count));
322            }
323            if rx_count > 0 {
324                parts.push(format!("{}rx", rx_count));
325            }
326        }
327
328        if parts.is_empty() {
329            return None;
330        }
331
332        let mut output = parts.join(" ");
333        for detail in details {
334            let _ = write!(output, "\n{}", detail);
335        }
336        Some(output)
337    }
338}
339
340impl std::fmt::Debug for DiagnosticState {
341    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
342        f.debug_struct("DiagnosticState")
343            .field("name", &self.name)
344            .finish_non_exhaustive()
345    }
346}