use std::{marker::PhantomData, sync::Arc};
use agent_client_protocol_schema::NewSessionRequest;
use futures::{StreamExt, channel::mpsc};
use uuid::Uuid;
use crate::{
Agent, Client, ConnectTo, ConnectionTo, Dispatch, DynConnectTo, HandleDispatchFrom, Handled,
Role,
jsonrpc::{
DynamicHandlerRegistration,
run::{NullRun, RunWithConnectionTo},
},
mcp_server::{
McpConnectionTo, McpServerConnect, active_session::McpActiveSession,
builder::McpServerBuilder,
},
role::{self, HasPeer},
util::MatchDispatchFrom,
};
/// An MCP server definition paired with the `Run` value that travels with it.
///
/// `Counterpart` is the [`Role`] this server is parameterised over; `Run`
/// defaults to [`NullRun`].
pub struct McpServer<Counterpart: Role, Run = NullRun> {
/// Zero-sized marker for the `Counterpart` type parameter.
phantom: PhantomData<Counterpart>,
/// URL identifying this server; `McpServer::new` fills it with `acp:<uuid-v4>`.
acp_url: String,
/// Shared, type-erased connector for this server.
connect: Arc<dyn McpServerConnect<Counterpart>>,
/// Caller-supplied `Run` value, stored exactly as passed to `McpServer::new`.
responder: Run,
}
/// Diagnostic formatting for [`McpServer`].
///
/// Only `Run` needs to be `Debug`: the sole printed field that mentions
/// `Counterpart` is the `PhantomData` marker, and `PhantomData<T>` implements
/// `Debug` for every `T`. Requiring `Counterpart: Debug` would needlessly hide
/// this impl for role types that do not implement `Debug`.
impl<Counterpart: Role, Run: std::fmt::Debug> std::fmt::Debug for McpServer<Counterpart, Run> {
    /// Writes `phantom`, `acp_url` and `responder`.
    ///
    /// The `connect` trait object is left out, so the output is closed with
    /// `finish_non_exhaustive` to signal that a field is omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("McpServer")
            .field("phantom", &self.phantom)
            .field("acp_url", &self.acp_url)
            .field("responder", &self.responder)
            .finish_non_exhaustive()
    }
}
/// Builder entry point, available on `McpServer<Host, NullRun>`.
impl<Host: Role> McpServer<Host, NullRun> {
    /// Starts a [`McpServerBuilder`] for a server called `name`.
    ///
    /// `name` is rendered with [`ToString::to_string`] and the resulting
    /// `String` is handed to `McpServerBuilder::new`.
    pub fn builder(name: impl ToString) -> McpServerBuilder<Host, NullRun> {
        let server_name = name.to_string();
        McpServerBuilder::new(server_name)
    }
}
impl<Counterpart: Role, Run> McpServer<Counterpart, Run>
where
    Run: RunWithConnectionTo<Counterpart>,
{
    /// Creates a server from the connector `c` and the given `responder`.
    ///
    /// Every instance receives its own freshly generated `acp:<uuid-v4>` URL,
    /// and the connector is stored behind an `Arc` as a trait object.
    pub fn new(c: impl McpServerConnect<Counterpart>, responder: Run) -> Self {
        let acp_url = format!("acp:{}", Uuid::new_v4());
        let connect: Arc<dyn McpServerConnect<Counterpart>> = Arc::new(c);
        Self {
            phantom: PhantomData,
            acp_url,
            connect,
            responder,
        }
    }

    /// Splits the server into a [`McpNewSessionHandler`] (built from the URL
    /// and connector) and the stored `responder`, returned in that order.
    pub(crate) fn into_handler_and_responder(self) -> (McpNewSessionHandler<Counterpart>, Run)
    where
        Counterpart: HasPeer<Agent>,
    {
        // The marker field carries no data, so it is simply dropped here.
        let Self {
            acp_url,
            connect,
            responder,
            ..
        } = self;
        let handler = McpNewSessionHandler::new(acp_url, connect);
        (handler, responder)
    }
}
/// Handler state produced by `McpServer::into_handler_and_responder`.
///
/// It can append this server's entry to a `NewSessionRequest` and owns the
/// `McpActiveSession` that other messages are delegated to.
pub(crate) struct McpNewSessionHandler<Counterpart: Role>
where
Counterpart: HasPeer<Agent>,
{
/// URL placed in the `McpServerHttp` entry pushed onto new-session requests.
acp_url: String,
/// Connector; its `name()` labels the pushed entry and the chain description.
connect: Arc<dyn McpServerConnect<Counterpart>>,
/// Session handler constructed in `new` from clones of `acp_url` and `connect`.
active_session: McpActiveSession<Counterpart>,
}
impl<Counterpart: Role> McpNewSessionHandler<Counterpart>
where
    Counterpart: HasPeer<Agent>,
{
    /// Builds the handler for the server identified by `acp_url`.
    ///
    /// The embedded `McpActiveSession` gets its own copy of the URL and a
    /// second handle to the same connector.
    pub fn new(acp_url: String, connect: Arc<dyn McpServerConnect<Counterpart>>) -> Self {
        let active_session = McpActiveSession::new(acp_url.clone(), Arc::clone(&connect));
        Self {
            acp_url,
            connect,
            active_session,
        }
    }

    /// Appends an HTTP MCP-server entry (connector name plus this handler's
    /// URL) to `request.mcp_servers`.
    fn modify_new_session_request(&self, request: &mut NewSessionRequest) {
        let http_entry =
            crate::schema::McpServerHttp::new(self.connect.name(), self.acp_url.clone());
        let advertised = crate::schema::McpServer::Http(http_entry);
        request.mcp_servers.push(advertised);
    }
}
impl<Counterpart: Role> McpNewSessionHandler<Counterpart>
where
    Counterpart: HasPeer<Agent>,
{
    /// Adds this server to `request`, then registers the active-session
    /// handler on `cx` as a dynamic handler.
    ///
    /// Returns whatever `ConnectionTo::add_dynamic_handler` returns: the
    /// registration on success, or its error.
    pub fn into_dynamic_handler(
        self,
        request: &mut NewSessionRequest,
        cx: &ConnectionTo<Counterpart>,
    ) -> Result<DynamicHandlerRegistration<Counterpart>, crate::Error>
    where
        Counterpart: HasPeer<Agent>,
    {
        // The request must be amended before the session handler is moved out of `self`.
        self.modify_new_session_request(request);
        let session_handler = self.active_session;
        cx.add_dynamic_handler(session_handler)
    }
}
impl<Counterpart: Role> HandleDispatchFrom<Counterpart> for McpNewSessionHandler<Counterpart>
where
    Counterpart: HasPeer<Client> + HasPeer<Agent>,
{
    /// Routes one incoming `message`.
    ///
    /// A `NewSessionRequest` arriving from `Client` has this server's entry
    /// appended and is then reported as `Handled::No` (with `retry: false`)
    /// carrying the modified request and its responder — presumably so that
    /// later handlers still process it. Anything else is delegated to the
    /// active-session handler.
    async fn handle_dispatch_from(
        &mut self,
        message: Dispatch,
        cx: ConnectionTo<Counterpart>,
    ) -> Result<Handled<Dispatch>, crate::Error> {
        let after_new_session = MatchDispatchFrom::new(message, &cx)
            .if_request_from(Client, async |mut new_session: NewSessionRequest, reply| {
                self.modify_new_session_request(&mut new_session);
                Ok(Handled::No {
                    message: (new_session, reply),
                    retry: false,
                })
            })
            .await;

        after_new_session
            .otherwise_delegate(&mut self.active_session)
            .await
    }

    /// Describes this handler as `McpServer(<connector name>)`.
    fn describe_chain(&self) -> impl std::fmt::Debug {
        let server_name = self.connect.name();
        format!("McpServer({server_name})")
    }
}
impl<Run> ConnectTo<role::mcp::Client> for McpServer<role::mcp::Client, Run>
where
    Run: RunWithConnectionTo<role::mcp::Client> + 'static,
{
    /// Serves `client` by proxying between it and the server obtained from
    /// the stored connector.
    ///
    /// * Dispatches received from the client are queued on an unbounded
    ///   channel; if the queue can no longer accept them the handler fails
    ///   with the internal error "nobody listening to mcp server".
    /// * The closure passed to `with_spawned` asks the connector for the
    ///   server, forwards every dispatch coming from that server back to the
    ///   client connection, and drains the queue into the server connection.
    ///   It returns `Ok(())` once the queue is closed.
    async fn connect_to(
        self,
        client: impl ConnectTo<role::mcp::Server>,
    ) -> Result<(), crate::Error> {
        let Self {
            acp_url,
            connect,
            responder,
            ..
        } = self;

        // Client -> server traffic is handed over through this queue.
        let (to_backend, mut from_client) = mpsc::unbounded();

        let front = role::mcp::Server
            .builder()
            .with_responder(responder)
            .on_receive_dispatch(
                async |inbound: Dispatch, _cx| {
                    to_backend
                        .unbounded_send(inbound)
                        .map_err(|_| crate::util::internal_error("nobody listening to mcp server"))
                },
                crate::on_receive_dispatch!(),
            )
            .with_spawned(async move |client_cx| {
                let session = McpConnectionTo {
                    acp_url,
                    connection: client_cx.clone(),
                };
                let backend: DynConnectTo<role::mcp::Client> = connect.connect(session);

                role::mcp::Client
                    .builder()
                    .on_receive_dispatch(
                        // Server -> client: relay straight to the client connection.
                        async |outbound: Dispatch, _| {
                            client_cx.send_proxied_message(outbound)
                        },
                        crate::on_receive_dispatch!(),
                    )
                    .connect_with(backend, async |backend_cx| {
                        // Client -> server: pump the queue until it is closed.
                        loop {
                            match from_client.next().await {
                                Some(queued) => {
                                    backend_cx.send_proxied_message(queued)?;
                                }
                                None => break Ok(()),
                            }
                        }
                    })
                    .await
            });

        front.connect_to(client).await
    }
}