use crate::{
errors::ServerError,
transport::{ByteTransport, TransportError},
};
use kuri_mcp_protocol::jsonrpc::{JsonRpcResponse, SendableMessage};
use std::{convert::Infallible, pin::Pin};
use tokio::io::{AsyncRead, AsyncWrite};
use tower::Service;
/// Message server that owns a single service value of type `S`.
///
/// The bounds on `S` are deliberately placed on the `impl` blocks rather than on
/// the struct definition itself.
pub struct Server<S> {
    /// The wrapped service that incoming messages are handed to.
    service: S,
}
/// Emits a debug-level tracing event carrying the JSON form of an outgoing response.
///
/// The value is serialized with `serde_json`; if serialization fails, a fixed
/// placeholder string is logged in its place so that tracing never aborts the
/// caller.
fn trace_response(response: &Option<JsonRpcResponse>) {
    let payload = match serde_json::to_string(response) {
        Ok(json) => json,
        Err(_) => String::from("Failed to serialize response"),
    };
    tracing::debug!(json = %payload, "Sending response");
}
impl<S> Server<S>
where
    S: Service<SendableMessage, Response = Option<JsonRpcResponse>, Error = Infallible>,
{
    /// Creates a server that hands every incoming message to `service`.
    pub fn new(service: S) -> Self {
        Self { service }
    }

    /// Waits until the inner service reports readiness, then dispatches `message` to it.
    ///
    /// The `tower::Service` contract requires `poll_ready` to return `Ready(Ok(()))`
    /// before each `call`; services (and middleware wrapped around them) are permitted
    /// to panic otherwise, so readiness is always awaited first.
    ///
    /// Returns whatever the service produced for the message, which may be `None`.
    async fn dispatch(&mut self, message: SendableMessage) -> Option<JsonRpcResponse> {
        std::future::poll_fn(|cx| self.service.poll_ready(cx))
            .await
            .expect("MCPService cannot return an error.");
        self.service
            .call(message)
            .await
            .expect("MCPService cannot return an error.")
    }

    /// Handles one item read from the transport.
    ///
    /// * A request is forwarded to the service and, if the service yields a response,
    ///   that response is written back through `transport`.
    /// * A notification is forwarded to the service; nothing is written back.
    /// * A transport error is logged and otherwise ignored.
    ///
    /// The `request_id` and `method` fields of the surrounding tracing span are filled
    /// in from the message where available.
    ///
    /// # Errors
    ///
    /// Returns [`ServerError::Transport`] when writing a response to the transport fails.
    #[tracing::instrument(level = "trace", fields(request_id, method), skip(self, transport))]
    async fn process_message<R, W>(
        &mut self,
        transport: &mut Pin<&mut ByteTransport<R, W>>,
        msg_result: Result<SendableMessage, TransportError>,
    ) -> Result<(), ServerError>
    where
        R: AsyncRead + Unpin,
        W: AsyncWrite + Unpin,
    {
        use valuable::Valuable;

        match msg_result {
            Ok(SendableMessage::Request(request)) => {
                // Borrow the id directly for recording; no clone is needed because the
                // borrow ends before `request` is moved into the service call.
                tracing::Span::current()
                    .record("request_id", request.id.as_value())
                    .record("method", &request.method);

                let response = self.dispatch(SendableMessage::from(request)).await;
                trace_response(&response);

                if let Some(response) = response {
                    transport
                        .write_message(response)
                        .await
                        .map_err(|e| ServerError::Transport(TransportError::Io(e)))?;
                }
            }
            Ok(SendableMessage::Notification(notification)) => {
                tracing::Span::current().record("method", &notification.method);
                // Any value the service returns for a notification is discarded.
                self.dispatch(SendableMessage::from(notification)).await;
            }
            Err(e) => {
                // Best-effort: a failed read is logged and the server keeps running.
                tracing::error!(error = ?e, "Transport error");
            }
        }

        Ok(())
    }

    /// Runs the server, consuming it together with `transport`.
    ///
    /// Items are pulled from the transport one at a time and processed sequentially
    /// until the transport stream ends.
    ///
    /// # Errors
    ///
    /// Stops and returns the error as soon as processing a message fails, which
    /// happens when a response cannot be written to the transport.
    pub async fn run<R, W>(mut self, mut transport: ByteTransport<R, W>) -> Result<(), ServerError>
    where
        R: AsyncRead + Unpin,
        W: AsyncWrite + Unpin,
    {
        use futures::StreamExt;

        let mut transport = Pin::new(&mut transport);
        tracing::info!("Server started");

        while let Some(msg_result) = transport.next().await {
            self.process_message(&mut transport, msg_result).await?;
        }

        Ok(())
    }
}