use std::collections::HashMap;
use std::fs::OpenOptions;
use std::io::{BufWriter, Write};
use std::path::Path;
use std::time::Instant;
use agent_client_protocol::schema::{McpOverAcpMessage, SuccessorMessage};
use agent_client_protocol::{DynConnectTo, JsonRpcMessage, Role, UntypedMessage, jsonrpcmsg};
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use crate::ComponentIndex;
use crate::snoop::SnooperComponent;
/// One record in a trace.
///
/// Serialized with an internal `"type"` tag holding the snake_case variant name
/// (`"request"`, `"response"`, or `"notification"`) next to the fields of the wrapped event.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
#[non_exhaustive]
pub enum TraceEvent {
/// A JSON-RPC request that carried an id.
Request(RequestEvent),
/// A response matched (by id) to a previously traced request.
Response(ResponseEvent),
/// A JSON-RPC request message without an id.
Notification(NotificationEvent),
}
/// Protocol a traced message belongs to; serialized as `"acp"` or `"mcp"`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum Protocol {
/// A plain ACP message (the starting protocol for every request the tracer parses).
Acp,
/// A message that was carried inside an MCP-over-ACP envelope (`McpOverAcpMessage`).
Mcp,
}
/// A traced JSON-RPC request (a request message that carried an id).
///
/// Fields are serialized as JSON keys in declaration order.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct RequestEvent {
/// Seconds since the producing `TraceWriter` was created.
pub ts: f64,
/// Protocol of the innermost (unwrapped) message.
pub protocol: Protocol,
/// `Debug` rendering of the sending component's `ComponentIndex`.
pub from: String,
/// `Debug` rendering of the receiving component's `ComponentIndex`.
pub to: String,
/// The request id as a JSON value (string, number, or null).
pub id: serde_json::Value,
/// Method name after any successor / MCP-over-ACP envelopes were unwrapped.
pub method: String,
/// Optional session label; omitted from the JSON when `None`.
/// NOTE(review): the call sites visible in this file always pass `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub session: Option<String>,
/// Parameters of the innermost (unwrapped) message.
pub params: serde_json::Value,
}
/// A traced JSON-RPC response, attributed to the request with the same id.
///
/// Fields are serialized as JSON keys in declaration order.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct ResponseEvent {
/// Seconds since the producing `TraceWriter` was created.
pub ts: f64,
/// `Debug` rendering of the responder (the original request's receiver).
pub from: String,
/// `Debug` rendering of the original requester.
pub to: String,
/// The id shared with the matching request, as a JSON value.
pub id: serde_json::Value,
/// `true` when the response carried an error object and no result.
pub is_error: bool,
/// The result value, the serialized error object, or `null` when neither was present.
pub payload: serde_json::Value,
}
/// A traced JSON-RPC notification (a request message without an id).
///
/// Fields are serialized as JSON keys in declaration order.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct NotificationEvent {
/// Seconds since the producing `TraceWriter` was created.
pub ts: f64,
/// Protocol of the innermost (unwrapped) message.
pub protocol: Protocol,
/// `Debug` rendering of the sending component's `ComponentIndex`.
pub from: String,
/// `Debug` rendering of the receiving component's `ComponentIndex`.
pub to: String,
/// Method name after any successor / MCP-over-ACP envelopes were unwrapped.
pub method: String,
/// Optional session label; omitted from the JSON when `None`.
/// NOTE(review): the call sites visible in this file always pass `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub session: Option<String>,
/// Parameters of the innermost (unwrapped) message.
pub params: serde_json::Value,
}
/// A destination for [`TraceEvent`]s.
///
/// Implementors must be `Send + 'static`; `TraceWriter` stores its destination as a
/// `Box<dyn WriteEvent>`.
pub trait WriteEvent
where
    Self: Send + 'static,
{
    /// Records a single event.
    ///
    /// # Errors
    ///
    /// Returns an I/O error when the event could not be recorded.
    fn write_event(&mut self, event: &TraceEvent) -> std::io::Result<()>;
}
/// A `WriteEvent` sink that writes each event to `writer` as one line of JSON.
pub(crate) struct EventWriter<W> {
/// Destination for the serialized events.
writer: W,
}
impl<W: Write> EventWriter<W> {
pub fn new(writer: W) -> Self {
Self { writer }
}
}
impl<W> WriteEvent for EventWriter<W>
where
    W: Write + Send + 'static,
{
    /// Serializes `event` as a single line of JSON, then flushes so the line is visible
    /// immediately.
    ///
    /// # Errors
    ///
    /// Returns serialization failures (wrapped with `std::io::Error::other`) and any I/O
    /// error from writing the trailing newline or flushing.
    fn write_event(&mut self, event: &TraceEvent) -> std::io::Result<()> {
        let out = &mut self.writer;
        if let Err(err) = serde_json::to_writer(&mut *out, event) {
            return Err(std::io::Error::other(err));
        }
        out.write_all(b"\n")?;
        out.flush()
    }
}
impl WriteEvent for futures::channel::mpsc::UnboundedSender<TraceEvent> {
    /// Sends a clone of `event` over the channel.
    ///
    /// # Errors
    ///
    /// Returns an `ErrorKind::BrokenPipe` I/O error wrapping the send error when the channel
    /// no longer accepts messages (for example, because the receiver was dropped).
    fn write_event(&mut self, event: &TraceEvent) -> std::io::Result<()> {
        use std::io::{Error, ErrorKind};

        if let Err(send_err) = self.unbounded_send(event.clone()) {
            return Err(Error::new(ErrorKind::BrokenPipe, send_err));
        }
        Ok(())
    }
}
/// Turns snooped JSON-RPC traffic into `TraceEvent`s and sends them to a destination.
pub struct TraceWriter {
/// Where events are sent; write failures are ignored by `TraceWriter::write_event`.
dest: Box<dyn WriteEvent>,
/// Reference instant for the `ts` field of every emitted event.
start_time: Instant,
/// Sender and receiver of each traced request, keyed by its JSON id, so the matching
/// response can be attributed; an entry is removed when its response is traced.
/// NOTE(review): keyed by id alone, so two in-flight requests on different links that reuse
/// the same id would overwrite each other's entry — confirm ids are unique across links.
request_details: FxHashMap<serde_json::Value, RequestDetails>,
}
impl std::fmt::Debug for TraceWriter {
    /// Shows only `start_time`; the destination and the in-flight request table are
    /// omitted, which `finish_non_exhaustive` marks with a trailing `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("TraceWriter");
        builder.field("start_time", &self.start_time);
        builder.finish_non_exhaustive()
    }
}
/// What the tracer remembers about a traced request until its response is seen.
struct RequestDetails {
/// Protocol of the request (recorded but not currently read).
#[expect(dead_code)]
protocol: Protocol,
/// Method of the request (recorded but not currently read).
#[expect(dead_code)]
method: String,
/// Component that sent the request.
request_from: ComponentIndex,
/// Component that received the request.
request_to: ComponentIndex,
}
impl TraceWriter {
pub fn new<D: WriteEvent>(dest: D) -> Self {
Self {
dest: Box::new(dest),
start_time: Instant::now(),
request_details: HashMap::default(),
}
}
pub fn from_path(path: impl AsRef<Path>) -> std::io::Result<Self> {
let file = OpenOptions::new()
.create(true)
.write(true)
.truncate(true)
.open(path.as_ref())?;
Ok(Self::new(EventWriter::new(BufWriter::new(file))))
}
fn elapsed(&self) -> f64 {
self.start_time.elapsed().as_secs_f64()
}
fn write_event(&mut self, event: &TraceEvent) {
drop(self.dest.write_event(event));
}
#[expect(clippy::too_many_arguments)]
fn request(
&mut self,
protocol: Protocol,
from: ComponentIndex,
to: ComponentIndex,
id: serde_json::Value,
method: String,
session: Option<String>,
params: serde_json::Value,
) {
self.request_details.insert(
id.clone(),
RequestDetails {
protocol,
method: method.clone(),
request_from: from,
request_to: to,
},
);
self.write_event(&TraceEvent::Request(RequestEvent {
ts: self.elapsed(),
protocol,
from: format!("{from:?}"),
to: format!("{to:?}"),
id,
method,
session,
params,
}));
}
fn response(
&mut self,
from: ComponentIndex,
to: ComponentIndex,
id: serde_json::Value,
is_error: bool,
payload: serde_json::Value,
) {
self.write_event(&TraceEvent::Response(ResponseEvent {
ts: self.elapsed(),
from: format!("{from:?}"),
to: format!("{to:?}"),
id,
is_error,
payload,
}));
}
fn notification(
&mut self,
protocol: Protocol,
from: ComponentIndex,
to: ComponentIndex,
method: impl Into<String>,
session: Option<String>,
params: serde_json::Value,
) {
self.write_event(&TraceEvent::Notification(NotificationEvent {
ts: self.elapsed(),
protocol,
from: format!("{from:?}"),
to: format!("{to:?}"),
method: method.into(),
session,
params,
}));
}
fn trace_message(&mut self, traced_message: TracedMessage) {
let TracedMessage {
component_index,
successor_index,
incoming,
message,
} = traced_message;
match message {
jsonrpcmsg::Message::Request(req) => {
let MessageInfo {
successor,
id,
protocol,
method,
params,
} = MessageInfo::from_req(req);
let (from, to) = match (successor, incoming, component_index, successor_index) {
(Successor(false), Incoming(true), ComponentIndex::Proxy(proxy_index), _) => (
ComponentIndex::predecessor_of(proxy_index),
ComponentIndex::Proxy(proxy_index),
),
(Successor(true), Incoming(true), component_index, successor_index) => {
(successor_index, component_index)
}
(Successor(true), Incoming(false), component_index, ComponentIndex::Agent) => {
(component_index, ComponentIndex::Agent)
}
_ => return,
};
match id {
Some(id) => {
self.request(protocol, from, to, id_to_json(&id), method, None, params);
}
None => {
self.notification(protocol, from, to, method, None, params);
}
}
}
jsonrpcmsg::Message::Response(resp) => {
if let Some(id) = resp.id {
let id = id_to_json(&id);
if let Some(RequestDetails {
protocol: _,
method: _,
request_from,
request_to,
}) = self.request_details.remove(&id)
{
let (is_error, payload) = match (&resp.result, &resp.error) {
(Some(result), _) => (false, result.clone()),
(_, Some(error)) => {
(true, serde_json::to_value(error).unwrap_or_default())
}
(None, None) => (false, serde_json::Value::Null),
};
self.response(request_to, request_from, id, is_error, payload);
}
}
}
}
}
pub(crate) fn spawn(
mut self: TraceWriter,
) -> (
TraceHandle,
impl std::future::Future<Output = Result<(), agent_client_protocol::Error>>,
) {
use futures::StreamExt;
let (tx, mut rx) = futures::channel::mpsc::unbounded();
let future = async move {
while let Some(event) = rx.next().await {
self.trace_message(event);
}
Ok(())
};
(TraceHandle { tx }, future)
}
}
/// Cloneable handle that forwards snooped messages to the future returned by
/// `TraceWriter::spawn`.
#[derive(Clone, Debug)]
pub(crate) struct TraceHandle {
/// Sending half of the channel drained by the `TraceWriter::spawn` future.
tx: futures::channel::mpsc::UnboundedSender<TracedMessage>,
}
impl TraceHandle {
    /// Queues a copy of `message`, tagged with where it was observed, for the trace writer.
    ///
    /// # Errors
    ///
    /// Returns an internal error when the channel to the trace writer is closed.
    fn trace_message(
        &self,
        component_index: ComponentIndex,
        successor_index: ComponentIndex,
        incoming: Incoming,
        message: &jsonrpcmsg::Message,
    ) -> Result<(), agent_client_protocol::Error> {
        let traced = TracedMessage {
            component_index,
            successor_index,
            incoming,
            message: message.clone(),
        };
        self.tx
            .unbounded_send(traced)
            .map_err(agent_client_protocol::util::internal_error)
    }

    /// Wraps `proxy` in a `SnooperComponent` that reports each observed message to the
    /// trace writer.
    ///
    /// The first callback passed to `SnooperComponent::new` tags messages `Incoming(true)`
    /// and the second tags them `Incoming(false)` — presumably messages into and out of the
    /// proxy respectively; confirm against `SnooperComponent`.
    pub fn bridge_component<R: Role>(
        &self,
        proxy_index: ComponentIndex,
        successor_index: ComponentIndex,
        proxy: impl agent_client_protocol::ConnectTo<R>,
    ) -> DynConnectTo<R> {
        let inbound = self.clone();
        let outbound = self.clone();
        DynConnectTo::new(SnooperComponent::new(
            proxy,
            move |msg| inbound.trace_message(proxy_index, successor_index, Incoming(true), msg),
            move |msg| {
                outbound.trace_message(proxy_index, successor_index, Incoming(false), msg)
            },
        ))
    }
}
fn id_to_json(id: &jsonrpcmsg::Id) -> serde_json::Value {
match id {
jsonrpcmsg::Id::String(s) => serde_json::Value::String(s.clone()),
jsonrpcmsg::Id::Number(n) => serde_json::Value::Number((*n).into()),
jsonrpcmsg::Id::Null => serde_json::Value::Null,
}
}
/// A snooped message plus the context needed to work out who sent it to whom.
#[derive(Debug)]
struct TracedMessage {
/// Index of the component whose snooper observed the message.
component_index: ComponentIndex,
/// Index supplied as that component's successor (see `TraceHandle::bridge_component`).
successor_index: ComponentIndex,
/// `Incoming(true)` from the first snoop callback in `bridge_component`,
/// `Incoming(false)` from the second.
incoming: Incoming,
/// Owned copy of the observed message.
message: jsonrpcmsg::Message,
}
/// A request with its successor / MCP-over-ACP envelopes stripped (see `MessageInfo::from_req`).
#[derive(Debug)]
struct MessageInfo {
/// `Successor(true)` if at least one `SuccessorMessage` envelope was unwrapped.
successor: Successor,
/// Id of the outer request; `None` means it is traced as a notification.
id: Option<jsonrpcmsg::Id>,
/// `Protocol::Mcp` if an MCP-over-ACP envelope was unwrapped, otherwise the starting protocol.
protocol: Protocol,
/// Method of the innermost message.
method: String,
/// Params of the innermost message.
params: serde_json::Value,
}
/// Whether a request was found wrapped in a `SuccessorMessage` envelope.
#[derive(Clone, Copy, Debug)]
struct Successor(bool);
/// Which snoop callback observed a message: `Incoming(true)` for the first callback
/// registered by `TraceHandle::bridge_component`, `Incoming(false)` for the second.
#[derive(Clone, Copy, Debug)]
struct Incoming(bool);
impl MessageInfo {
    /// Builds a `MessageInfo` from a raw request, starting as unwrapped ACP and then
    /// stripping any envelopes found in the payload.
    ///
    /// # Panics
    ///
    /// Panics if the request cannot be parsed as an `UntypedMessage`, which the `expect`
    /// message asserts cannot happen.
    fn from_req(req: jsonrpcmsg::Request) -> Self {
        let parsed = UntypedMessage::parse_message(&req.method, &req.params);
        let untyped = parsed.expect("untyped message is infallible");
        Self::from_untyped(Successor(false), req.id, Protocol::Acp, untyped)
    }

    /// Peels envelopes off `untyped` until neither kind matches.
    ///
    /// Each `SuccessorMessage` layer sets `successor` to `Successor(true)`; each
    /// `McpOverAcpMessage` layer sets `protocol` to `Protocol::Mcp`. The successor envelope
    /// is checked before the MCP one at every layer. The id is carried through unchanged.
    fn from_untyped(
        mut successor: Successor,
        id: Option<jsonrpcmsg::Id>,
        mut protocol: Protocol,
        mut untyped: UntypedMessage,
    ) -> Self {
        loop {
            if let Ok(envelope) = SuccessorMessage::parse_message(&untyped.method, &untyped.params)
            {
                successor = Successor(true);
                untyped = envelope.message;
            } else if let Ok(envelope) =
                McpOverAcpMessage::parse_message(&untyped.method, &untyped.params)
            {
                protocol = Protocol::Mcp;
                untyped = envelope.message;
            } else {
                return Self {
                    successor,
                    id,
                    protocol,
                    method: untyped.method,
                    params: untyped.params,
                };
            }
        }
    }
}