use futures::channel::mpsc;
use futures::{SinkExt, StreamExt};
use rustc_hash::FxHashMap;
use crate::mcp_server::{McpConnectionTo, McpServerConnect};
use crate::role;
use crate::role::HasPeer;
use crate::schema::{
McpConnectRequest, McpConnectResponse, McpDisconnectNotification, McpOverAcpMessage,
};
use crate::util::MatchDispatchFrom;
use crate::{
Agent, Channel, ConnectTo, ConnectionTo, Dispatch, HandleDispatchFrom, Handled, Responder,
Role, UntypedMessage,
};
use std::sync::Arc;
/// Serves MCP-over-ACP traffic addressed to a single ACP URL.
///
/// Each accepted `McpConnectRequest` gets its own connection id, an MCP server
/// (built by `mcp_connect`) and an MCP client component; `connections` routes
/// later agent messages for that id to the right connection.
pub(super) struct McpActiveSession<Counterpart: Role> {
/// URL this session answers to; connect requests for any other URL are left unhandled.
acp_url: String,
/// Factory invoked once per accepted connect request to create that connection's MCP server.
mcp_connect: Arc<dyn McpServerConnect<Counterpart>>,
/// Connection id -> sender feeding that connection's forwarding task, which
/// passes each queued dispatch on to the MCP server.
connections: FxHashMap<String, mpsc::Sender<Dispatch>>,
}
impl<Counterpart: Role> McpActiveSession<Counterpart>
where
    Counterpart: HasPeer<Agent>,
{
    /// Creates a session bound to `acp_url` with no open MCP connections.
    ///
    /// `mcp_connect` is invoked once per accepted connect request to build the
    /// MCP server that serves that connection.
    pub fn new(acp_url: String, mcp_connect: Arc<dyn McpServerConnect<Counterpart>>) -> Self {
        Self {
            acp_url,
            mcp_connect,
            connections: FxHashMap::default(),
        }
    }

    /// Handles an `McpConnectRequest` sent by the agent.
    ///
    /// Requests whose `acp_url` differs from this session's are returned
    /// unhandled (`Handled::No`, `retry: false`). Otherwise a fresh connection
    /// id is registered and two components, joined by an in-process duplex
    /// channel, are spawned on `acp_connection`:
    ///
    /// * an MCP client component that tags every dispatch it receives with the
    ///   connection id (as an `McpOverAcpMessage`) and proxies it to the agent,
    ///   and whose spawned task forwards dispatches queued for this connection
    ///   to the MCP server;
    /// * the MCP server produced by `self.mcp_connect`.
    ///
    /// The responder receives the new connection id on success, or the spawn
    /// error on failure. On failure the connection id is unregistered again.
    ///
    /// # Errors
    /// Returns an error if replying through `responder` fails.
    async fn handle_connect_request(
        &mut self,
        request: McpConnectRequest,
        responder: Responder<McpConnectResponse>,
        acp_connection: &ConnectionTo<Counterpart>,
    ) -> Result<Handled<(McpConnectRequest, Responder<McpConnectResponse>)>, crate::Error> {
        // Not addressed to this session: hand it back so another handler can try it.
        if request.acp_url != self.acp_url {
            return Ok(Handled::No {
                message: (request, responder),
                retry: false,
            });
        }

        let connection_id = format!("mcp-over-acp-connection:{}", uuid::Uuid::new_v4());

        // Agent messages for this connection are queued here and drained by the
        // client component's spawned task below. The sender stored in
        // `self.connections` is the only sender for this channel.
        let (mcp_server_tx, mut mcp_server_rx) = mpsc::channel(128);
        self.connections.insert(connection_id.clone(), mcp_server_tx);

        // In-process link between the MCP client component and the MCP server.
        let (client_channel, server_channel) = Channel::duplex();

        let client_component = {
            let connection_id = connection_id.clone();
            let acp_connection = acp_connection.clone();
            role::mcp::Client
                .builder()
                .on_receive_dispatch(
                    // Tag each incoming dispatch with this connection id and
                    // proxy it to the agent over ACP.
                    async move |message: Dispatch, _mcp_connection| {
                        let wrapped = message.map(
                            |request, responder| {
                                (
                                    McpOverAcpMessage {
                                        connection_id: connection_id.clone(),
                                        message: request,
                                        meta: None,
                                    },
                                    responder,
                                )
                            },
                            |notification| McpOverAcpMessage {
                                connection_id: connection_id.clone(),
                                message: notification,
                                meta: None,
                            },
                        );
                        acp_connection.send_proxied_message_to(Agent, wrapped)
                    },
                    crate::on_receive_dispatch!(),
                )
                // Forward dispatches queued by the agent to the MCP server until
                // the channel is closed (i.e. its sender is dropped).
                .with_spawned(move |mcp_connection| async move {
                    while let Some(msg) = mcp_server_rx.next().await {
                        mcp_connection.send_proxied_message_to(role::mcp::Server, msg)?;
                    }
                    Ok(())
                })
        };

        let spawned_server = self.mcp_connect.connect(McpConnectionTo {
            acp_url: request.acp_url.clone(),
            connection: acp_connection.clone(),
        });

        let spawn_results = acp_connection
            .spawn(async move { client_component.connect_to(client_channel).await })
            .and_then(|()| {
                acp_connection.spawn(async move { spawned_server.connect_to(server_channel).await })
            });

        match spawn_results {
            Ok(()) => {
                responder.respond(McpConnectResponse {
                    connection_id,
                    meta: None,
                })?;
                Ok(Handled::Yes)
            }
            Err(err) => {
                // The connection never came up: unregister it so later messages
                // for this id are not routed into a dead channel. Dropping the
                // only sender also closes the channel, so a client component
                // that was already spawned ends its forwarding loop.
                self.connections.remove(&connection_id);
                responder.respond_with_error(err)?;
                Ok(Handled::Yes)
            }
        }
    }

    /// Routes an agent request addressed to an existing MCP-over-ACP connection.
    ///
    /// The inner message and its responder are queued on that connection's
    /// channel (waiting if its 128-message buffer is full), from where the
    /// forwarding task passes them to the MCP server. Requests for an unknown
    /// `connection_id` are returned unhandled (`retry: false`).
    ///
    /// # Errors
    /// Returns an internal error if the channel rejects the message —
    /// presumably because its receiving side has been dropped.
    async fn handle_mcp_over_acp_request(
        &mut self,
        request: McpOverAcpMessage<UntypedMessage>,
        responder: Responder<serde_json::Value>,
    ) -> Result<
        Handled<(
            McpOverAcpMessage<UntypedMessage>,
            Responder<serde_json::Value>,
        )>,
        crate::Error,
    > {
        let Some(mcp_server_tx) = self.connections.get_mut(&request.connection_id) else {
            return Ok(Handled::No {
                message: (request, responder),
                retry: false,
            });
        };
        mcp_server_tx
            .send(Dispatch::Request(request.message, responder))
            .await
            .map_err(crate::Error::into_internal_error)?;
        Ok(Handled::Yes)
    }

    /// Routes an agent notification addressed to an existing MCP-over-ACP connection.
    ///
    /// Mirrors `handle_mcp_over_acp_request`, but without a responder.
    /// Notifications for an unknown `connection_id` are returned unhandled
    /// (`retry: false`).
    ///
    /// # Errors
    /// Returns an internal error if the channel rejects the message —
    /// presumably because its receiving side has been dropped.
    async fn handle_mcp_over_acp_notification(
        &mut self,
        notification: McpOverAcpMessage<UntypedMessage>,
    ) -> Result<Handled<McpOverAcpMessage<UntypedMessage>>, crate::Error> {
        let Some(mcp_server_tx) = self.connections.get_mut(&notification.connection_id) else {
            return Ok(Handled::No {
                message: notification,
                retry: false,
            });
        };
        mcp_server_tx
            .send(Dispatch::Notification(notification.message))
            .await
            .map_err(crate::Error::into_internal_error)?;
        Ok(Handled::Yes)
    }

    /// Handles an `McpDisconnectNotification` from the agent.
    ///
    /// Removes the connection's sender from `self.connections`. That sender is
    /// the only one for the connection's channel, so dropping it closes the
    /// channel and ends the connection's forwarding loop. Notifications for an
    /// unknown `connection_id` are returned unhandled (`retry: false`).
    async fn handle_mcp_disconnect_notification(
        &mut self,
        successor_notification: McpDisconnectNotification,
    ) -> Result<Handled<McpDisconnectNotification>, crate::Error> {
        if self
            .connections
            .remove(&successor_notification.connection_id)
            .is_some()
        {
            Ok(Handled::Yes)
        } else {
            Ok(Handled::No {
                message: successor_notification,
                retry: false,
            })
        }
    }
}
/// Dispatch entry point: offers messages from the `Agent` to this session's MCP-over-ACP handlers.
impl<Counterpart: Role> HandleDispatchFrom<Counterpart> for McpActiveSession<Counterpart>
where
Counterpart: HasPeer<Agent>,
{
/// Label describing this handler in dispatch-chain descriptions.
// NOTE(review): the label reads "McpServerSession" while the type is
// `McpActiveSession` — possibly a stale name. It is runtime output, so confirm
// nothing depends on it before renaming.
fn describe_chain(&self) -> impl std::fmt::Debug {
"McpServerSession"
}
/// Tries each MCP-over-ACP handler in turn on `message`.
///
/// Arms, in order, each restricted to messages from `Agent`:
/// 1. `McpConnectRequest` -> `handle_connect_request`
/// 2. `McpOverAcpMessage<UntypedMessage>` requests -> `handle_mcp_over_acp_request`
/// 3. `McpOverAcpMessage<UntypedMessage>` notifications -> `handle_mcp_over_acp_notification`
/// 4. `McpDisconnectNotification` -> `handle_mcp_disconnect_notification`
///
/// Each arm's message type is inferred from the handler its closure calls, so
/// the closures must keep calling these exact methods.
// NOTE(review): presumably `.done()` yields `Handled::No` for a message that no
// arm claimed (including arms whose handler returned `Handled::No`) — confirm
// against `MatchDispatchFrom`.
async fn handle_dispatch_from(
&mut self,
message: Dispatch,
connection: ConnectionTo<Counterpart>,
) -> Result<Handled<Dispatch>, crate::Error> {
MatchDispatchFrom::new(message, &connection)
.if_request_from(Agent, async |request, responder| {
self.handle_connect_request(request, responder, &connection)
.await
})
.await
.if_request_from(Agent, async |request, responder| {
self.handle_mcp_over_acp_request(request, responder).await
})
.await
.if_notification_from(Agent, async |notification| {
self.handle_mcp_over_acp_notification(notification).await
})
.await
.if_notification_from(Agent, async |notification| {
self.handle_mcp_disconnect_notification(notification).await
})
.await
.done()
}
}