1use crate::protocol::{Message, RequestId};
7use std::collections::HashMap;
8use std::sync::{Arc, RwLock};
9use std::time::{Duration, Instant};
10
11#[derive(Debug, Clone)]
13pub struct MessageRecord {
14 pub timestamp: Instant,
16 pub direction: MessageDirection,
18 pub message: Message,
20 pub size_bytes: usize,
22 pub tags: Vec<String>,
24}
25
26impl MessageRecord {
27 #[must_use]
29 pub fn new(direction: MessageDirection, message: Message) -> Self {
30 let size_bytes = serde_json::to_string(&message)
31 .map(|s| s.len())
32 .unwrap_or(0);
33
34 Self {
35 timestamp: Instant::now(),
36 direction,
37 message,
38 size_bytes,
39 tags: Vec::new(),
40 }
41 }
42
43 #[must_use]
45 pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
46 self.tags.push(tag.into());
47 self
48 }
49
50 #[must_use]
52 pub fn method(&self) -> Option<&str> {
53 self.message.method()
54 }
55
56 #[must_use]
58 pub fn request_id(&self) -> Option<&RequestId> {
59 match &self.message {
60 Message::Request(req) => Some(&req.id),
61 Message::Response(res) => Some(&res.id),
62 Message::Notification(_) => None,
63 }
64 }
65
66 #[must_use]
68 pub fn is_request(&self) -> bool {
69 matches!(self.message, Message::Request(_))
70 }
71
72 #[must_use]
74 pub fn is_response(&self) -> bool {
75 matches!(self.message, Message::Response(_))
76 }
77
78 #[must_use]
80 pub fn is_notification(&self) -> bool {
81 matches!(self.message, Message::Notification(_))
82 }
83
84 #[must_use]
86 pub fn is_error(&self) -> bool {
87 matches!(&self.message, Message::Response(r) if r.error.is_some())
88 }
89}
90
/// Direction a recorded message travelled in.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageDirection {
    /// Recorded through the inspector's outbound entry point.
    Outbound,
    /// Recorded through the inspector's inbound entry point.
    Inbound,
}
99
/// Aggregate counters describing a set of recorded messages.
#[derive(Debug, Clone, Default)]
pub struct MessageStats {
    /// Number of records the statistics were computed over.
    pub total_messages: usize,
    /// Number of request messages.
    pub requests: usize,
    /// Number of response messages.
    pub responses: usize,
    /// Number of notification messages.
    pub notifications: usize,
    /// Number of responses that carried an error.
    pub errors: usize,
    /// Sum of the recorded `size_bytes` values.
    pub total_bytes: usize,
    /// Message count keyed by method name.
    pub by_method: HashMap<String, usize>,
    /// Mean response time keyed by method name.
    pub avg_response_time: HashMap<String, Duration>,
}
120
121impl MessageStats {
122 #[must_use]
124 pub fn error_rate(&self) -> f64 {
125 if self.responses == 0 {
126 0.0
127 } else {
128 self.errors as f64 / self.responses as f64
129 }
130 }
131}
132
133#[derive(Debug)]
138pub struct MessageInspector {
139 records: Arc<RwLock<Vec<MessageRecord>>>,
141 max_records: usize,
143 enabled: Arc<RwLock<bool>>,
145 pending_requests: Arc<RwLock<HashMap<RequestId, (Instant, String)>>>,
147 response_times: Arc<RwLock<HashMap<String, Vec<Duration>>>>,
149}
150
151impl Default for MessageInspector {
152 fn default() -> Self {
153 Self::new()
154 }
155}
156
157impl MessageInspector {
158 #[must_use]
160 pub fn new() -> Self {
161 Self {
162 records: Arc::new(RwLock::new(Vec::new())),
163 max_records: 10000,
164 enabled: Arc::new(RwLock::new(true)),
165 pending_requests: Arc::new(RwLock::new(HashMap::new())),
166 response_times: Arc::new(RwLock::new(HashMap::new())),
167 }
168 }
169
170 #[must_use]
172 pub fn with_max_records(mut self, max: usize) -> Self {
173 self.max_records = max;
174 self
175 }
176
177 pub fn set_enabled(&self, enabled: bool) {
179 if let Ok(mut flag) = self.enabled.write() {
180 *flag = enabled;
181 }
182 }
183
184 #[must_use]
186 pub fn is_enabled(&self) -> bool {
187 self.enabled.read().map(|e| *e).unwrap_or(false)
188 }
189
190 pub fn record_outbound(&self, message: Message) {
192 self.record(MessageDirection::Outbound, message);
193 }
194
195 pub fn record_inbound(&self, message: Message) {
197 self.record(MessageDirection::Inbound, message);
198 }
199
200 fn record(&self, direction: MessageDirection, message: Message) {
202 if !self.is_enabled() {
203 return;
204 }
205
206 let record = MessageRecord::new(direction, message.clone());
207
208 if let Message::Request(ref req) = message {
210 if let Ok(mut pending) = self.pending_requests.write() {
211 pending.insert(req.id.clone(), (Instant::now(), req.method.to_string()));
212 }
213 }
214
215 if let Message::Response(ref res) = message {
217 if let Ok(mut pending) = self.pending_requests.write() {
218 if let Some((start, method)) = pending.remove(&res.id) {
219 let duration = start.elapsed();
220 if let Ok(mut times) = self.response_times.write() {
221 times.entry(method).or_default().push(duration);
222 }
223 }
224 }
225 }
226
227 if let Ok(mut records) = self.records.write() {
229 records.push(record);
230
231 if self.max_records > 0 && records.len() > self.max_records {
233 let excess = records.len() - self.max_records;
234 records.drain(0..excess);
235 }
236 }
237 }
238
239 #[must_use]
241 pub fn records(&self) -> Vec<MessageRecord> {
242 self.records.read().map(|r| r.clone()).unwrap_or_default()
243 }
244
245 #[must_use]
247 pub fn len(&self) -> usize {
248 self.records.read().map(|r| r.len()).unwrap_or(0)
249 }
250
251 #[must_use]
253 pub fn is_empty(&self) -> bool {
254 self.len() == 0
255 }
256
257 pub fn clear(&self) {
259 if let Ok(mut records) = self.records.write() {
260 records.clear();
261 }
262 if let Ok(mut pending) = self.pending_requests.write() {
263 pending.clear();
264 }
265 if let Ok(mut times) = self.response_times.write() {
266 times.clear();
267 }
268 }
269
270 #[must_use]
272 #[allow(clippy::field_reassign_with_default)]
273 pub fn stats(&self) -> MessageStats {
274 let records = self.records();
275 let mut stats = MessageStats::default();
276
277 stats.total_messages = records.len();
278
279 for record in &records {
280 stats.total_bytes += record.size_bytes;
281
282 match &record.message {
283 Message::Request(req) => {
284 stats.requests += 1;
285 *stats.by_method.entry(req.method.to_string()).or_insert(0) += 1;
286 }
287 Message::Response(res) => {
288 stats.responses += 1;
289 if res.error.is_some() {
290 stats.errors += 1;
291 }
292 }
293 Message::Notification(notif) => {
294 stats.notifications += 1;
295 *stats.by_method.entry(notif.method.to_string()).or_insert(0) += 1;
296 }
297 }
298 }
299
300 if let Ok(times) = self.response_times.read() {
302 for (method, durations) in times.iter() {
303 if !durations.is_empty() {
304 let total: Duration = durations.iter().sum();
305 let avg = total / durations.len() as u32;
306 stats.avg_response_time.insert(method.clone(), avg);
307 }
308 }
309 }
310
311 stats
312 }
313
314 #[must_use]
316 pub fn pending_requests(&self) -> Vec<(RequestId, String, Duration)> {
317 self.pending_requests
318 .read()
319 .map(|pending| {
320 pending
321 .iter()
322 .map(|(id, (start, method))| (id.clone(), method.clone(), start.elapsed()))
323 .collect()
324 })
325 .unwrap_or_default()
326 }
327
328 #[must_use]
330 pub fn filter_by_method(&self, method: &str) -> Vec<MessageRecord> {
331 self.records()
332 .into_iter()
333 .filter(|r| r.method() == Some(method))
334 .collect()
335 }
336
337 #[must_use]
339 pub fn filter_by_direction(&self, direction: MessageDirection) -> Vec<MessageRecord> {
340 self.records()
341 .into_iter()
342 .filter(|r| r.direction == direction)
343 .collect()
344 }
345
346 #[must_use]
348 pub fn errors(&self) -> Vec<MessageRecord> {
349 self.records()
350 .into_iter()
351 .filter(MessageRecord::is_error)
352 .collect()
353 }
354
355 pub fn to_json(&self) -> Result<String, serde_json::Error> {
361 let records: Vec<_> = self
362 .records()
363 .into_iter()
364 .map(|r| {
365 serde_json::json!({
366 "direction": format!("{:?}", r.direction),
367 "message": r.message,
368 "size_bytes": r.size_bytes,
369 "tags": r.tags,
370 })
371 })
372 .collect();
373
374 serde_json::to_string_pretty(&records)
375 }
376}
377
378impl Clone for MessageInspector {
379 fn clone(&self) -> Self {
380 Self {
381 records: Arc::clone(&self.records),
382 max_records: self.max_records,
383 enabled: Arc::clone(&self.enabled),
384 pending_requests: Arc::clone(&self.pending_requests),
385 response_times: Arc::clone(&self.response_times),
386 }
387 }
388}
389
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::{Request, Response};

    /// A request followed by its successful response is counted once each
    /// and produces no errors.
    #[test]
    fn test_message_inspector() {
        let inspector = MessageInspector::new();

        let request = Message::Request(Request::new("test/method", 1));
        inspector.record_outbound(request);

        let response =
            Message::Response(Response::success(RequestId::from(1), serde_json::json!({})));
        inspector.record_inbound(response);

        assert_eq!(inspector.len(), 2);

        let summary = inspector.stats();
        assert_eq!(summary.requests, 1);
        assert_eq!(summary.responses, 1);
        assert_eq!(summary.errors, 0);
    }

    /// Only records with the requested method are returned.
    #[test]
    fn test_filter_by_method() {
        let inspector = MessageInspector::new();

        inspector.record_outbound(Message::Request(Request::new("method/a", 1)));
        inspector.record_outbound(Message::Request(Request::new("method/b", 2)));
        inspector.record_outbound(Message::Request(Request::new("method/a", 3)));

        let matching = inspector.filter_by_method("method/a");
        assert_eq!(matching.len(), 2);
    }

    /// The history never grows past the configured cap.
    #[test]
    fn test_max_records() {
        let inspector = MessageInspector::new().with_max_records(5);

        for id in 0..10 {
            inspector.record_outbound(Message::Request(Request::new("test", id)));
        }

        assert_eq!(inspector.len(), 5);
    }

    /// Messages offered while the inspector is disabled are not stored.
    #[test]
    fn test_enable_disable() {
        let inspector = MessageInspector::new();

        inspector.record_outbound(Message::Request(Request::new("test", 1)));
        assert_eq!(inspector.len(), 1);

        inspector.set_enabled(false);
        inspector.record_outbound(Message::Request(Request::new("test", 2)));
        // The second request arrived while recording was off.
        assert_eq!(inspector.len(), 1);
    }
}