mcpkit_core/debug/
inspector.rs

1//! Message inspection utilities.
2//!
3//! The message inspector captures protocol messages for analysis
4//! and debugging.
5
6use crate::protocol::{Message, RequestId};
7use std::collections::HashMap;
8use std::sync::{Arc, RwLock};
9use std::time::{Duration, Instant};
10
11/// A record of a captured message.
12#[derive(Debug, Clone)]
13pub struct MessageRecord {
14    /// Timestamp when the message was captured.
15    pub timestamp: Instant,
16    /// The direction of the message.
17    pub direction: MessageDirection,
18    /// The captured message.
19    pub message: Message,
20    /// Size in bytes (approximate).
21    pub size_bytes: usize,
22    /// Optional context/tags.
23    pub tags: Vec<String>,
24}
25
26impl MessageRecord {
27    /// Create a new message record.
28    #[must_use]
29    pub fn new(direction: MessageDirection, message: Message) -> Self {
30        let size_bytes = serde_json::to_string(&message)
31            .map(|s| s.len())
32            .unwrap_or(0);
33
34        Self {
35            timestamp: Instant::now(),
36            direction,
37            message,
38            size_bytes,
39            tags: Vec::new(),
40        }
41    }
42
43    /// Add a tag to the record.
44    #[must_use]
45    pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
46        self.tags.push(tag.into());
47        self
48    }
49
50    /// Get the method name if applicable.
51    #[must_use]
52    pub fn method(&self) -> Option<&str> {
53        self.message.method()
54    }
55
56    /// Get the request ID if applicable.
57    #[must_use]
58    pub fn request_id(&self) -> Option<&RequestId> {
59        match &self.message {
60            Message::Request(req) => Some(&req.id),
61            Message::Response(res) => Some(&res.id),
62            Message::Notification(_) => None,
63        }
64    }
65
66    /// Check if this is a request message.
67    #[must_use]
68    pub fn is_request(&self) -> bool {
69        matches!(self.message, Message::Request(_))
70    }
71
72    /// Check if this is a response message.
73    #[must_use]
74    pub fn is_response(&self) -> bool {
75        matches!(self.message, Message::Response(_))
76    }
77
78    /// Check if this is a notification message.
79    #[must_use]
80    pub fn is_notification(&self) -> bool {
81        matches!(self.message, Message::Notification(_))
82    }
83
84    /// Check if the response indicates an error.
85    #[must_use]
86    pub fn is_error(&self) -> bool {
87        matches!(&self.message, Message::Response(r) if r.error.is_some())
88    }
89}
90
/// Which way a captured message travelled.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MessageDirection {
    /// Traffic the client sends.
    Outbound,
    /// Traffic the client receives.
    Inbound,
}
99
/// Aggregate counters describing a set of captured messages.
#[derive(Debug, Clone, Default)]
pub struct MessageStats {
    /// Number of messages considered.
    pub total_messages: usize,
    /// How many of them were requests.
    pub requests: usize,
    /// How many of them were responses.
    pub responses: usize,
    /// How many of them were notifications.
    pub notifications: usize,
    /// How many responses carried an error.
    pub errors: usize,
    /// Sum of the approximate message sizes, in bytes.
    pub total_bytes: usize,
    /// Message count keyed by method name.
    pub by_method: HashMap<String, usize>,
    /// Mean response time keyed by method name.
    pub avg_response_time: HashMap<String, Duration>,
}

impl MessageStats {
    /// Ratio `errors / responses` as a float.
    ///
    /// Yields `0.0` when no responses were seen instead of dividing by zero.
    #[must_use]
    pub fn error_rate(&self) -> f64 {
        match self.responses {
            0 => 0.0,
            response_count => self.errors as f64 / response_count as f64,
        }
    }
}
132
133/// Message inspector for capturing and analyzing protocol traffic.
134///
135/// The inspector can be used as a debugging aid during development
136/// to understand the message flow between client and server.
137#[derive(Debug)]
138pub struct MessageInspector {
139    /// Captured messages.
140    records: Arc<RwLock<Vec<MessageRecord>>>,
141    /// Maximum number of records to keep (0 = unlimited).
142    max_records: usize,
143    /// Whether capturing is enabled.
144    enabled: Arc<RwLock<bool>>,
145    /// Pending requests for response time tracking.
146    pending_requests: Arc<RwLock<HashMap<RequestId, (Instant, String)>>>,
147    /// Response times by method.
148    response_times: Arc<RwLock<HashMap<String, Vec<Duration>>>>,
149}
150
151impl Default for MessageInspector {
152    fn default() -> Self {
153        Self::new()
154    }
155}
156
157impl MessageInspector {
158    /// Create a new message inspector.
159    #[must_use]
160    pub fn new() -> Self {
161        Self {
162            records: Arc::new(RwLock::new(Vec::new())),
163            max_records: 10000,
164            enabled: Arc::new(RwLock::new(true)),
165            pending_requests: Arc::new(RwLock::new(HashMap::new())),
166            response_times: Arc::new(RwLock::new(HashMap::new())),
167        }
168    }
169
170    /// Create an inspector with a maximum record limit.
171    #[must_use]
172    pub fn with_max_records(mut self, max: usize) -> Self {
173        self.max_records = max;
174        self
175    }
176
177    /// Enable or disable message capture.
178    pub fn set_enabled(&self, enabled: bool) {
179        if let Ok(mut flag) = self.enabled.write() {
180            *flag = enabled;
181        }
182    }
183
184    /// Check if capturing is enabled.
185    #[must_use]
186    pub fn is_enabled(&self) -> bool {
187        self.enabled.read().map(|e| *e).unwrap_or(false)
188    }
189
190    /// Record an outbound message.
191    pub fn record_outbound(&self, message: Message) {
192        self.record(MessageDirection::Outbound, message);
193    }
194
195    /// Record an inbound message.
196    pub fn record_inbound(&self, message: Message) {
197        self.record(MessageDirection::Inbound, message);
198    }
199
200    /// Record a message.
201    fn record(&self, direction: MessageDirection, message: Message) {
202        if !self.is_enabled() {
203            return;
204        }
205
206        let record = MessageRecord::new(direction, message.clone());
207
208        // Track pending requests for response time calculation
209        if let Message::Request(ref req) = message {
210            if let Ok(mut pending) = self.pending_requests.write() {
211                pending.insert(req.id.clone(), (Instant::now(), req.method.to_string()));
212            }
213        }
214
215        // Calculate response time for completed requests
216        if let Message::Response(ref res) = message {
217            if let Ok(mut pending) = self.pending_requests.write() {
218                if let Some((start, method)) = pending.remove(&res.id) {
219                    let duration = start.elapsed();
220                    if let Ok(mut times) = self.response_times.write() {
221                        times.entry(method).or_default().push(duration);
222                    }
223                }
224            }
225        }
226
227        // Store the record
228        if let Ok(mut records) = self.records.write() {
229            records.push(record);
230
231            // Trim if over limit
232            if self.max_records > 0 && records.len() > self.max_records {
233                let excess = records.len() - self.max_records;
234                records.drain(0..excess);
235            }
236        }
237    }
238
239    /// Get all captured records.
240    #[must_use]
241    pub fn records(&self) -> Vec<MessageRecord> {
242        self.records.read().map(|r| r.clone()).unwrap_or_default()
243    }
244
245    /// Get the number of captured records.
246    #[must_use]
247    pub fn len(&self) -> usize {
248        self.records.read().map(|r| r.len()).unwrap_or(0)
249    }
250
251    /// Check if there are no captured records.
252    #[must_use]
253    pub fn is_empty(&self) -> bool {
254        self.len() == 0
255    }
256
257    /// Clear all captured records.
258    pub fn clear(&self) {
259        if let Ok(mut records) = self.records.write() {
260            records.clear();
261        }
262        if let Ok(mut pending) = self.pending_requests.write() {
263            pending.clear();
264        }
265        if let Ok(mut times) = self.response_times.write() {
266            times.clear();
267        }
268    }
269
270    /// Get message statistics.
271    #[must_use]
272    #[allow(clippy::field_reassign_with_default)]
273    pub fn stats(&self) -> MessageStats {
274        let records = self.records();
275        let mut stats = MessageStats::default();
276
277        stats.total_messages = records.len();
278
279        for record in &records {
280            stats.total_bytes += record.size_bytes;
281
282            match &record.message {
283                Message::Request(req) => {
284                    stats.requests += 1;
285                    *stats.by_method.entry(req.method.to_string()).or_insert(0) += 1;
286                }
287                Message::Response(res) => {
288                    stats.responses += 1;
289                    if res.error.is_some() {
290                        stats.errors += 1;
291                    }
292                }
293                Message::Notification(notif) => {
294                    stats.notifications += 1;
295                    *stats.by_method.entry(notif.method.to_string()).or_insert(0) += 1;
296                }
297            }
298        }
299
300        // Calculate average response times
301        if let Ok(times) = self.response_times.read() {
302            for (method, durations) in times.iter() {
303                if !durations.is_empty() {
304                    let total: Duration = durations.iter().sum();
305                    let avg = total / durations.len() as u32;
306                    stats.avg_response_time.insert(method.clone(), avg);
307                }
308            }
309        }
310
311        stats
312    }
313
314    /// Find requests without responses.
315    #[must_use]
316    pub fn pending_requests(&self) -> Vec<(RequestId, String, Duration)> {
317        self.pending_requests
318            .read()
319            .map(|pending| {
320                pending
321                    .iter()
322                    .map(|(id, (start, method))| (id.clone(), method.clone(), start.elapsed()))
323                    .collect()
324            })
325            .unwrap_or_default()
326    }
327
328    /// Filter records by method.
329    #[must_use]
330    pub fn filter_by_method(&self, method: &str) -> Vec<MessageRecord> {
331        self.records()
332            .into_iter()
333            .filter(|r| r.method() == Some(method))
334            .collect()
335    }
336
337    /// Filter records by direction.
338    #[must_use]
339    pub fn filter_by_direction(&self, direction: MessageDirection) -> Vec<MessageRecord> {
340        self.records()
341            .into_iter()
342            .filter(|r| r.direction == direction)
343            .collect()
344    }
345
346    /// Get only error responses.
347    #[must_use]
348    pub fn errors(&self) -> Vec<MessageRecord> {
349        self.records()
350            .into_iter()
351            .filter(MessageRecord::is_error)
352            .collect()
353    }
354
355    /// Export records to JSON.
356    ///
357    /// # Errors
358    ///
359    /// Returns an error if serialization fails.
360    pub fn to_json(&self) -> Result<String, serde_json::Error> {
361        let records: Vec<_> = self
362            .records()
363            .into_iter()
364            .map(|r| {
365                serde_json::json!({
366                    "direction": format!("{:?}", r.direction),
367                    "message": r.message,
368                    "size_bytes": r.size_bytes,
369                    "tags": r.tags,
370                })
371            })
372            .collect();
373
374        serde_json::to_string_pretty(&records)
375    }
376}
377
378impl Clone for MessageInspector {
379    fn clone(&self) -> Self {
380        Self {
381            records: Arc::clone(&self.records),
382            max_records: self.max_records,
383            enabled: Arc::clone(&self.enabled),
384            pending_requests: Arc::clone(&self.pending_requests),
385            response_times: Arc::clone(&self.response_times),
386        }
387    }
388}
389
#[cfg(test)]
mod tests {
    //! Tests for [`MessageInspector`] using only its public API.

    use super::*;
    use crate::protocol::{Request, Response};

    #[test]
    fn test_message_inspector() {
        let inspector = MessageInspector::new();

        // Record some messages
        let req = Message::Request(Request::new("test/method", 1));
        inspector.record_outbound(req);

        let res = Message::Response(Response::success(RequestId::from(1), serde_json::json!({})));
        inspector.record_inbound(res);

        assert_eq!(inspector.len(), 2);

        let stats = inspector.stats();
        assert_eq!(stats.total_messages, 2);
        assert_eq!(stats.requests, 1);
        assert_eq!(stats.responses, 1);
        assert_eq!(stats.errors, 0);
        assert!(inspector.errors().is_empty());
    }

    #[test]
    fn test_filter_by_method() {
        let inspector = MessageInspector::new();

        inspector.record_outbound(Message::Request(Request::new("method/a", 1)));
        inspector.record_outbound(Message::Request(Request::new("method/b", 2)));
        inspector.record_outbound(Message::Request(Request::new("method/a", 3)));

        let filtered = inspector.filter_by_method("method/a");
        assert_eq!(filtered.len(), 2);
        assert_eq!(inspector.filter_by_direction(MessageDirection::Outbound).len(), 3);
        assert!(inspector.filter_by_direction(MessageDirection::Inbound).is_empty());
    }

    #[test]
    fn test_max_records() {
        let inspector = MessageInspector::new().with_max_records(5);

        for i in 0..10 {
            inspector.record_outbound(Message::Request(Request::new("test", i)));
        }

        assert_eq!(inspector.len(), 5);
        // The oldest records are evicted first, so request 5 is now the oldest.
        assert_eq!(
            inspector.records()[0].request_id(),
            Some(&RequestId::from(5))
        );
    }

    #[test]
    fn test_enable_disable() {
        let inspector = MessageInspector::new();

        inspector.record_outbound(Message::Request(Request::new("test", 1)));
        assert_eq!(inspector.len(), 1);

        inspector.set_enabled(false);
        assert!(!inspector.is_enabled());
        inspector.record_outbound(Message::Request(Request::new("test", 2)));
        assert_eq!(inspector.len(), 1); // Still 1, not recorded

        inspector.set_enabled(true);
        inspector.record_outbound(Message::Request(Request::new("test", 3)));
        assert_eq!(inspector.len(), 2);
    }

    #[test]
    fn test_response_time_tracking() {
        let inspector = MessageInspector::new();

        inspector.record_outbound(Message::Request(Request::new("test/method", 1)));
        assert_eq!(inspector.pending_requests().len(), 1);

        inspector.record_inbound(Message::Response(Response::success(
            RequestId::from(1),
            serde_json::json!({}),
        )));
        assert!(inspector.pending_requests().is_empty());
        assert!(inspector
            .stats()
            .avg_response_time
            .contains_key("test/method"));
    }

    #[test]
    fn test_clear() {
        let inspector = MessageInspector::new();

        inspector.record_outbound(Message::Request(Request::new("test", 1)));
        assert!(!inspector.is_empty());

        inspector.clear();
        assert!(inspector.is_empty());
        assert!(inspector.pending_requests().is_empty());
        assert!(inspector.stats().avg_response_time.is_empty());
    }

    #[test]
    fn test_to_json() {
        let inspector = MessageInspector::new();
        inspector.record_outbound(Message::Request(Request::new("test", 1)));

        let json = inspector.to_json().expect("captured messages serialize");
        assert!(json.contains("\"Outbound\""));
        assert!(json.contains("\"size_bytes\""));
    }
}