mcp_transport_rs/server/
byte.rs

use std::{
    pin::Pin,
    task::{Context, Poll},
};

use async_trait::async_trait;
use futures::{Stream, stream::StreamExt};
use mcp_core_rs::{protocol::message::JsonRpcMessage, utils::parse_json_rpc_message};
use mcp_error_rs::{Error, Result};
use pin_project::pin_project;
use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};

use crate::server::traits::ServerTransport;
14
15#[pin_project]
16/// A transport that reads and writes JSON-RPC messages over byte streams.
17pub struct ByteTransport<R, W> {
18    #[pin]
19    reader: BufReader<R>,
20    #[pin]
21    writer: W,
22    buf: Vec<u8>,
23}
24
25impl<R, W> ByteTransport<R, W>
26where
27    R: AsyncRead,
28    W: AsyncWrite,
29{
30    /// Creates a new `ByteTransport` with the given reader and writer.
31    pub fn new(reader: R, writer: W) -> Self {
32        Self {
33            reader: BufReader::with_capacity(2 * 1024 * 1024, reader),
34            writer,
35            buf: Vec::with_capacity(2 * 1024 * 1024),
36        }
37    }
38}
39
40impl<R, W> Stream for ByteTransport<R, W>
41where
42    R: AsyncRead + Unpin,
43    W: AsyncWrite + Unpin,
44{
45    type Item = Result<JsonRpcMessage>;
46
47    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
48        let mut this = self.project();
49        this.buf.clear();
50
51        let mut reader = this.reader.as_mut();
52        let mut read_future = Box::pin(reader.read_until(b'\n', this.buf));
53        match read_future.as_mut().poll(cx) {
54            Poll::Ready(Ok(0)) => {
55                tracing::info!("Client closed connection (read 0 bytes)");
56                Poll::Ready(None)
57            }
58            Poll::Ready(Ok(_)) => {
59                let line = match String::from_utf8(std::mem::take(this.buf)) {
60                    Ok(s) => s,
61                    Err(e) => {
62                        tracing::warn!(?e, "Invalid UTF-8 line");
63                        return Poll::Ready(Some(Err(Error::Utf8(e))));
64                    }
65                };
66                Poll::Ready(Some(parse_json_rpc_message(&line)))
67            }
68            Poll::Ready(Err(e)) => Poll::Ready(Some(Err(Error::Io(e)))),
69            Poll::Pending => Poll::Pending,
70        }
71    }
72}
73
74#[async_trait]
75impl<R, W> ServerTransport for ByteTransport<R, W>
76where
77    R: AsyncRead + Unpin + Send + Sync,
78    W: AsyncWrite + Unpin + Send + Sync,
79{
80    async fn read_message(&mut self) -> Option<Result<JsonRpcMessage>> {
81        self.next().await
82    }
83
84    async fn write_message(&mut self, msg: JsonRpcMessage) -> Result<()> {
85        let mut this = Pin::new(self).project();
86        let json = serde_json::to_string(&msg)?;
87        this.writer.write_all(json.as_bytes()).await?;
88        this.writer.write_all(b"\n").await?;
89        this.writer.flush().await?;
90        Ok(())
91    }
92}