1use std::collections::HashMap;
7use std::fs::OpenOptions;
8use std::io::{BufWriter, Write};
9use std::path::Path;
10use std::time::Instant;
11
12use agent_client_protocol::schema::v1::{
13 Notification as RpcNotification, Request as RpcRequest, RequestId, Response as RpcResponse,
14};
15use agent_client_protocol::schema::{McpOverAcpMessage, SuccessorMessage};
16use agent_client_protocol::{
17 DynConnectTo, JsonRpcMessage, RawJsonRpcMessage, RawJsonRpcParams, Role, UntypedMessage,
18};
19use rustc_hash::FxHashMap;
20use serde::{Deserialize, Serialize};
21
22use crate::ComponentIndex;
23use crate::snoop::SnooperComponent;
24
/// One traced message exchanged between two components of the chain.
///
/// Serialized as an internally tagged JSON object: the `"type"` field is
/// `"request"`, `"response"`, or `"notification"`, and the variant's fields sit
/// alongside it.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
#[non_exhaustive]
pub enum TraceEvent {
    /// A JSON-RPC request (a message carrying an id).
    Request(RequestEvent),

    /// A response, reported once it is matched (by id) to a previously traced request.
    Response(ResponseEvent),

    /// A JSON-RPC notification (a message without an id).
    Notification(NotificationEvent),
}
39
/// Protocol layer a traced request or notification belongs to.
///
/// Serialized in snake_case: `"acp"` or `"mcp"`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum Protocol {
    /// A plain Agent Client Protocol message.
    Acp,
    /// An MCP message that was carried inside an ACP `McpOverAcpMessage` envelope.
    Mcp,
}
50
/// A traced JSON-RPC request.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct RequestEvent {
    /// Seconds elapsed since the `TraceWriter` was created when this event was emitted.
    pub ts: f64,

    /// Protocol of the innermost message, after envelopes were unwrapped.
    pub protocol: Protocol,

    /// Sending component, rendered with its `Debug` formatting.
    pub from: String,

    /// Receiving component, rendered with its `Debug` formatting.
    pub to: String,

    /// JSON-RPC request id, serialized to JSON.
    pub id: serde_json::Value,

    /// Method name of the innermost message, after envelopes were unwrapped.
    pub method: String,

    /// Session identifier; omitted from the JSON output when `None`.
    /// The tracer in this module currently always sets it to `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub session: Option<String>,

    /// Parameters of the innermost message (`null` when the raw request had none).
    pub params: serde_json::Value,
}
80
/// A traced JSON-RPC response, reported in the reverse direction of its request.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct ResponseEvent {
    /// Seconds elapsed since the `TraceWriter` was created when this event was emitted.
    pub ts: f64,

    /// Component that answered; this is the original request's `to`.
    pub from: String,

    /// Component receiving the answer; this is the original request's `from`.
    pub to: String,

    /// Id of the request this response answers, serialized to JSON.
    pub id: serde_json::Value,

    /// `true` when the response was a JSON-RPC error rather than a result.
    pub is_error: bool,

    /// The `result` value on success, or the serialized `error` object on failure
    /// (`null` if serializing the error failed).
    pub payload: serde_json::Value,
}
103
/// A traced JSON-RPC notification.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct NotificationEvent {
    /// Seconds elapsed since the `TraceWriter` was created when this event was emitted.
    pub ts: f64,

    /// Protocol of the innermost message, after envelopes were unwrapped.
    pub protocol: Protocol,

    /// Sending component, rendered with its `Debug` formatting.
    pub from: String,

    /// Receiving component, rendered with its `Debug` formatting.
    pub to: String,

    /// Method name of the innermost message, after envelopes were unwrapped.
    pub method: String,

    /// Session identifier; omitted from the JSON output when `None`.
    /// The tracer in this module currently always sets it to `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub session: Option<String>,

    /// Parameters of the innermost message (`null` when the raw notification had none).
    pub params: serde_json::Value,
}
130
/// Destination for trace events.
///
/// The `Send + 'static` bound lets a `TraceWriter` store implementors as
/// `Box<dyn WriteEvent>` and move them into its background future.
pub trait WriteEvent: Send + 'static {
    /// Records one event.
    ///
    /// # Errors
    ///
    /// Returns an I/O error if the event could not be recorded. `TraceWriter`
    /// discards these errors.
    fn write_event(&mut self, event: &TraceEvent) -> std::io::Result<()>;
}
136
/// Writes trace events as newline-delimited JSON into an arbitrary `Write` sink.
pub(crate) struct EventWriter<W> {
    /// Underlying sink; receives one JSON object per line.
    writer: W,
}

impl<W: Write> EventWriter<W> {
    /// Wraps `writer`. Nothing is written until the first event is recorded.
    pub fn new(writer: W) -> Self {
        EventWriter { writer }
    }
}
147
148impl<W: Write + Send + 'static> WriteEvent for EventWriter<W> {
149 fn write_event(&mut self, event: &TraceEvent) -> std::io::Result<()> {
150 serde_json::to_writer(&mut self.writer, event).map_err(std::io::Error::other)?;
151 self.writer.write_all(b"\n")?;
152 self.writer.flush()
153 }
154}
155
156impl WriteEvent for futures::channel::mpsc::UnboundedSender<TraceEvent> {
158 fn write_event(&mut self, event: &TraceEvent) -> std::io::Result<()> {
159 self.unbounded_send(event.clone())
160 .map_err(|e| std::io::Error::new(std::io::ErrorKind::BrokenPipe, e))
161 }
162}
163
/// Turns snooped JSON-RPC traffic into `TraceEvent`s and hands them to a `WriteEvent` sink.
pub struct TraceWriter {
    /// Destination of every emitted event.
    dest: Box<dyn WriteEvent>,
    /// Reference point for the `ts` field of emitted events.
    start_time: Instant,

    /// Requests still waiting for a response, keyed by their JSON-serialized id.
    /// Remembering the endpoints lets the response be reported in the reverse direction.
    ///
    /// NOTE(review): keyed by id alone. If two hops of the chain use the same id
    /// concurrently, a later insert overwrites the earlier entry. Confirm ids are
    /// unique across the chain.
    request_details: FxHashMap<serde_json::Value, RequestDetails>,
}
173
174impl std::fmt::Debug for TraceWriter {
175 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
176 f.debug_struct("TraceWriter")
177 .field("start_time", &self.start_time)
178 .finish_non_exhaustive()
179 }
180}
181
/// What `TraceWriter` remembers about a traced request until its response is seen.
struct RequestDetails {
    /// Protocol of the request; recorded but currently never read.
    #[expect(dead_code)]
    protocol: Protocol,

    /// Method of the request; recorded but currently never read.
    #[expect(dead_code)]
    method: String,

    /// Sender of the request; becomes the `to` of the matching response.
    request_from: ComponentIndex,
    /// Recipient of the request; becomes the `from` of the matching response.
    request_to: ComponentIndex,
}
192
193impl TraceWriter {
194 pub fn new<D: WriteEvent>(dest: D) -> Self {
196 Self {
197 dest: Box::new(dest),
198 start_time: Instant::now(),
199 request_details: HashMap::default(),
200 }
201 }
202
203 pub fn from_path(path: impl AsRef<Path>) -> std::io::Result<Self> {
205 let file = OpenOptions::new()
206 .create(true)
207 .write(true)
208 .truncate(true)
209 .open(path.as_ref())?;
210 Ok(Self::new(EventWriter::new(BufWriter::new(file))))
211 }
212
213 fn elapsed(&self) -> f64 {
215 self.start_time.elapsed().as_secs_f64()
216 }
217
218 fn write_event(&mut self, event: &TraceEvent) {
220 drop(self.dest.write_event(event));
222 }
223
224 #[expect(clippy::too_many_arguments)]
226 fn request(
227 &mut self,
228 protocol: Protocol,
229 from: ComponentIndex,
230 to: ComponentIndex,
231 id: serde_json::Value,
232 method: String,
233 session: Option<String>,
234 params: serde_json::Value,
235 ) {
236 self.request_details.insert(
237 id.clone(),
238 RequestDetails {
239 protocol,
240 method: method.clone(),
241 request_from: from,
242 request_to: to,
243 },
244 );
245 self.write_event(&TraceEvent::Request(RequestEvent {
246 ts: self.elapsed(),
247 protocol,
248 from: format!("{from:?}"),
249 to: format!("{to:?}"),
250 id,
251 method,
252 session,
253 params,
254 }));
255 }
256
257 fn response(
259 &mut self,
260 from: ComponentIndex,
261 to: ComponentIndex,
262 id: serde_json::Value,
263 is_error: bool,
264 payload: serde_json::Value,
265 ) {
266 self.write_event(&TraceEvent::Response(ResponseEvent {
267 ts: self.elapsed(),
268 from: format!("{from:?}"),
269 to: format!("{to:?}"),
270 id,
271 is_error,
272 payload,
273 }));
274 }
275
276 fn notification(
278 &mut self,
279 protocol: Protocol,
280 from: ComponentIndex,
281 to: ComponentIndex,
282 method: impl Into<String>,
283 session: Option<String>,
284 params: serde_json::Value,
285 ) {
286 self.write_event(&TraceEvent::Notification(NotificationEvent {
287 ts: self.elapsed(),
288 protocol,
289 from: format!("{from:?}"),
290 to: format!("{to:?}"),
291 method: method.into(),
292 session,
293 params,
294 }));
295 }
296
297 fn trace_message(&mut self, traced_message: TracedMessage) {
299 let TracedMessage {
300 component_index,
301 successor_index,
302 incoming,
303 message,
304 } = traced_message;
305
306 match message {
317 RawJsonRpcMessage::Request(req) => {
318 let MessageInfo {
319 successor,
320 id,
321 protocol,
322 method,
323 params,
324 } = MessageInfo::from_request(req);
325
326 self.trace_request_or_notification(
327 incoming,
328 component_index,
329 successor_index,
330 successor,
331 id,
332 protocol,
333 method,
334 params,
335 );
336 }
337 RawJsonRpcMessage::Notification(notification) => {
338 let MessageInfo {
339 successor,
340 id,
341 protocol,
342 method,
343 params,
344 } = MessageInfo::from_notification(notification);
345
346 self.trace_request_or_notification(
347 incoming,
348 component_index,
349 successor_index,
350 successor,
351 id,
352 protocol,
353 method,
354 params,
355 );
356 }
357 RawJsonRpcMessage::Response(resp) => {
358 let (id, is_error, payload) = match resp {
362 RpcResponse::Result { id, result } => (id, false, result),
363 RpcResponse::Error { id, error } => {
364 (id, true, serde_json::to_value(error).unwrap_or_default())
365 }
366 };
367 let id = id_to_json(&id);
368 if let Some(RequestDetails {
369 protocol: _,
370 method: _,
371 request_from,
372 request_to,
373 }) = self.request_details.remove(&id)
374 {
375 self.response(request_to, request_from, id, is_error, payload);
376 }
377 }
378 }
379 }
380
381 #[expect(clippy::too_many_arguments)]
382 fn trace_request_or_notification(
383 &mut self,
384 incoming: Incoming,
385 component_index: ComponentIndex,
386 successor_index: ComponentIndex,
387 successor: Successor,
388 id: Option<RequestId>,
389 protocol: Protocol,
390 method: String,
391 params: serde_json::Value,
392 ) {
393 let (from, to) = match (successor, incoming, component_index, successor_index) {
394 (Successor(false), Incoming(true), ComponentIndex::Proxy(proxy_index), _) => (
396 ComponentIndex::predecessor_of(proxy_index),
397 ComponentIndex::Proxy(proxy_index),
398 ),
399
400 (Successor(true), Incoming(true), component_index, successor_index) => {
404 (successor_index, component_index)
405 }
406
407 (Successor(true), Incoming(false), component_index, ComponentIndex::Agent) => {
413 (component_index, ComponentIndex::Agent)
414 }
415
416 _ => return,
417 };
418
419 match id {
420 Some(id) => {
421 self.request(protocol, from, to, id_to_json(&id), method, None, params);
422 }
423 None => {
424 self.notification(protocol, from, to, method, None, params);
425 }
426 }
427 }
428
429 pub(crate) fn spawn(
434 mut self: TraceWriter,
435 ) -> (
436 TraceHandle,
437 impl std::future::Future<Output = Result<(), agent_client_protocol::Error>>,
438 ) {
439 use futures::StreamExt;
440
441 let (tx, mut rx) = futures::channel::mpsc::unbounded();
442
443 let future = async move {
444 while let Some(event) = rx.next().await {
445 self.trace_message(event);
446 }
447 Ok(())
448 };
449
450 (TraceHandle { tx }, future)
451 }
452}
453
/// Cloneable sender used to feed snooped messages to the future returned by
/// `TraceWriter::spawn`.
#[derive(Clone, Debug)]
pub(crate) struct TraceHandle {
    /// Channel drained by the `TraceWriter::spawn` future.
    tx: futures::channel::mpsc::UnboundedSender<TracedMessage>,
}
461
462impl TraceHandle {
463 fn trace_message(
465 &self,
466 component_index: ComponentIndex,
467 successor_index: ComponentIndex,
468 incoming: Incoming,
469 message: &RawJsonRpcMessage,
470 ) -> Result<(), agent_client_protocol::Error> {
471 self.tx
472 .unbounded_send(TracedMessage {
473 component_index,
474 successor_index,
475 incoming,
476 message: message.clone(),
477 })
478 .map_err(agent_client_protocol::util::internal_error)
479 }
480
481 pub fn bridge_component<R: Role>(
496 &self,
497 proxy_index: ComponentIndex,
498 successor_index: ComponentIndex,
499 proxy: impl agent_client_protocol::ConnectTo<R>,
500 ) -> DynConnectTo<R> {
501 DynConnectTo::new(SnooperComponent::new(
502 proxy,
503 {
504 let trace_handle = self.clone();
505 move |msg| {
506 trace_handle.trace_message(proxy_index, successor_index, Incoming(true), msg)
507 }
508 },
509 {
510 let trace_handle = self.clone();
511 move |msg| {
512 trace_handle.trace_message(proxy_index, successor_index, Incoming(false), msg)
513 }
514 },
515 ))
516 }
517}
518
519fn id_to_json(id: &RequestId) -> serde_json::Value {
521 serde_json::to_value(id).expect("RequestId serializes infallibly")
522}
523
524fn params_from_transport(params: Option<RawJsonRpcParams>) -> serde_json::Value {
525 params.map_or(serde_json::Value::Null, RawJsonRpcParams::into_value)
526}
527
/// A snooped message together with the point in the chain where it was observed.
#[derive(Debug)]
struct TracedMessage {
    /// Component whose snooper saw the message.
    component_index: ComponentIndex,
    /// That component's successor in the chain.
    successor_index: ComponentIndex,
    /// Which snooper callback reported the message (see `Incoming`).
    incoming: Incoming,
    /// The raw message, as cloned from the snooper callback.
    message: RawJsonRpcMessage,
}
537
/// A request or notification after successor / MCP-over-ACP envelopes were removed.
#[derive(Debug)]
struct MessageInfo {
    /// Whether any layer was a successor envelope.
    successor: Successor,
    /// Request id; `None` for notifications.
    id: Option<RequestId>,
    /// `Mcp` if any layer was an MCP-over-ACP envelope, otherwise `Acp`.
    protocol: Protocol,
    /// Method of the innermost message.
    method: String,
    /// Params of the innermost message.
    params: serde_json::Value,
}
547
/// `true` when the message was wrapped in a `SuccessorMessage` envelope
/// (set by `MessageInfo::from_untyped`).
#[derive(Copy, Clone, Debug)]
struct Successor(bool);
550
/// Which snooper callback reported a message.
///
/// `TraceHandle::bridge_component` tags its first callback `true` and its second
/// callback `false`. Presumably these are messages entering vs. leaving the wrapped
/// component — TODO confirm against `SnooperComponent`.
#[derive(Copy, Clone, Debug)]
struct Incoming(bool);
553
554impl MessageInfo {
555 fn from_request(req: RpcRequest<RawJsonRpcParams>) -> Self {
564 let untyped =
565 UntypedMessage::parse_message(&req.method, ¶ms_from_transport(req.params))
566 .expect("untyped message is infallible");
567 Self::from_untyped(Successor(false), Some(req.id), Protocol::Acp, untyped)
568 }
569
570 fn from_notification(notification: RpcNotification<RawJsonRpcParams>) -> Self {
571 let untyped = UntypedMessage::parse_message(
572 ¬ification.method,
573 ¶ms_from_transport(notification.params),
574 )
575 .expect("untyped message is infallible");
576 Self::from_untyped(Successor(false), None, Protocol::Acp, untyped)
577 }
578
579 fn from_untyped(
580 successor: Successor,
581 id: Option<RequestId>,
582 protocol: Protocol,
583 untyped: UntypedMessage,
584 ) -> Self {
585 if let Ok(m) = SuccessorMessage::parse_message(&untyped.method, &untyped.params) {
586 return Self::from_untyped(Successor(true), id, protocol, m.message);
587 }
588
589 if let Ok(m) = McpOverAcpMessage::parse_message(&untyped.method, &untyped.params) {
590 return Self::from_untyped(successor, id, Protocol::Mcp, m.message);
591 }
592
593 Self {
594 successor,
595 id,
596 protocol,
597 method: untyped.method,
598 params: untyped.params,
599 }
600 }
601}