use futures_concurrency::future::TryJoin as _;
use rmcp::ServiceExt;
use sacp::mcp_server::{McpConnectionTo, McpServer, McpServerConnect};
use sacp::role::{self, HasPeer};
use sacp::{Agent, ByteStreams, ConnectTo, DynConnectTo, NullRun, Role};
use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
/// Extension trait for constructing an [`McpServer`] whose connections are served by an
/// `rmcp` service.
pub trait McpServerExt<Counterpart: Role>
where
    Counterpart: HasPeer<Agent>,
{
    /// Builds an [`McpServer`] (with a [`NullRun`] background task) named `name`.
    ///
    /// Every call to the returned server's `connect` invokes `new_fn` to produce a brand-new
    /// `rmcp` service instance. That instance is wrapped in `RmcpServerComponent` and handed
    /// back as a [`DynConnectTo`] for the MCP client side. The per-connection
    /// `McpConnectionTo` context is not used.
    fn from_rmcp<S>(
        name: impl ToString,
        new_fn: impl Fn() -> S + Send + Sync + 'static,
    ) -> McpServer<Counterpart, NullRun>
    where
        S: rmcp::Service<rmcp::RoleServer>,
    {
        /// Local adapter pairing the server name with the service factory.
        ///
        /// Generic parameters are redeclared on the impl below because nested items cannot
        /// reference the enclosing function's generics.
        struct RmcpServer<F> {
            name: String,
            factory: F,
        }

        impl<Counterpart, F, S> McpServerConnect<Counterpart> for RmcpServer<F>
        where
            Counterpart: Role,
            F: Fn() -> S + Send + Sync + 'static,
            S: rmcp::Service<rmcp::RoleServer>,
        {
            fn name(&self) -> String {
                self.name.to_owned()
            }

            fn connect(
                &self,
                _cx: McpConnectionTo<Counterpart>,
            ) -> DynConnectTo<role::mcp::Client> {
                // A fresh service is created for every connection.
                let make_service = &self.factory;
                let component = RmcpServerComponent {
                    service: make_service(),
                };
                DynConnectTo::new(component)
            }
        }

        let connector = RmcpServer {
            name: name.to_string(),
            factory: new_fn,
        };
        McpServer::new(connector, NullRun)
    }
}
/// Makes `McpServer::from_rmcp` available for every counterpart role that can reach an
/// [`Agent`] peer; all behavior comes from the trait's default method.
impl<Counterpart> McpServerExt<Counterpart> for McpServer<Counterpart>
where
    Counterpart: Role + HasPeer<Agent>,
{
}
/// A single `rmcp` service instance waiting to be bridged to an MCP client connection.
///
/// It is built per connection by `McpServerExt::from_rmcp` and driven by its
/// `ConnectTo<role::mcp::Client>` implementation.
struct RmcpServerComponent<S> {
    /// The `rmcp` server-side service; it is consumed when the connection is served.
    service: S,
}
impl<S> ConnectTo<role::mcp::Client> for RmcpServerComponent<S>
where
S: rmcp::Service<rmcp::RoleServer>,
{
async fn connect_to(
self,
client: impl ConnectTo<role::mcp::Server>,
) -> Result<(), sacp::Error> {
let (mcp_server_stream, mcp_client_stream) = tokio::io::duplex(8192);
let (mcp_server_read, mcp_server_write) = tokio::io::split(mcp_server_stream);
let (mcp_client_read, mcp_client_write) = tokio::io::split(mcp_client_stream);
let bytes_to_sacp = async {
let byte_streams =
ByteStreams::new(mcp_client_write.compat_write(), mcp_client_read.compat());
let _ = ConnectTo::<role::mcp::Client>::connect_to(byte_streams, client).await;
Ok(())
};
let bytes_to_rmcp = async {
let running_server = self
.service
.serve((mcp_server_read, mcp_server_write))
.await
.map_err(sacp::Error::into_internal_error)?;
running_server
.waiting()
.await
.map(|_quit_reason| ())
.map_err(sacp::Error::into_internal_error)
};
(bytes_to_sacp, bytes_to_rmcp).try_join().await?;
Ok(())
}
}