1use std::{
2 pin::Pin,
3 task::{Context, Poll},
4};
5
6use futures::{Future, Stream};
7use mcp_spec::protocol::{JsonRpcError, JsonRpcMessage, JsonRpcRequest, JsonRpcResponse};
8use pin_project::pin_project;
9use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
10use tower_service::Service;
11
12mod errors;
13pub use errors::{BoxError, RouterError, ServerError, TransportError};
14
15pub mod router;
16pub use router::Router;
17
18#[pin_project]
20pub struct ByteTransport<R, W> {
21 #[pin]
25 reader: BufReader<R>,
26 #[pin]
27 writer: W,
28}
29
30impl<R, W> ByteTransport<R, W>
31where
32 R: AsyncRead,
33 W: AsyncWrite,
34{
35 pub fn new(reader: R, writer: W) -> Self {
36 Self {
37 reader: BufReader::with_capacity(2 * 1024 * 1024, reader),
40 writer,
41 }
42 }
43}
44
45impl<R, W> Stream for ByteTransport<R, W>
46where
47 R: AsyncRead + Unpin,
48 W: AsyncWrite + Unpin,
49{
50 type Item = Result<JsonRpcMessage, TransportError>;
51
52 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
53 let mut this = self.project();
54 let mut buf = Vec::new();
55
56 let mut reader = this.reader.as_mut();
57 let mut read_future = Box::pin(reader.read_until(b'\n', &mut buf));
58 match read_future.as_mut().poll(cx) {
59 Poll::Ready(Ok(0)) => Poll::Ready(None), Poll::Ready(Ok(_)) => {
61 let line = match String::from_utf8(buf) {
63 Ok(s) => s,
64 Err(e) => return Poll::Ready(Some(Err(TransportError::Utf8(e)))),
65 };
66 tracing::info!(json = %line, "incoming message");
69
70 match serde_json::from_str::<serde_json::Value>(&line) {
72 Ok(value) => {
73 if !value.is_object() {
75 return Poll::Ready(Some(Err(TransportError::InvalidMessage(
76 "Message must be a JSON object".into(),
77 ))));
78 }
79 let obj = value.as_object().unwrap(); if !obj.contains_key("jsonrpc") || obj["jsonrpc"] != "2.0" {
83 return Poll::Ready(Some(Err(TransportError::InvalidMessage(
84 "Missing or invalid jsonrpc version".into(),
85 ))));
86 }
87
88 match serde_json::from_value::<JsonRpcMessage>(value) {
90 Ok(msg) => Poll::Ready(Some(Ok(msg))),
91 Err(e) => Poll::Ready(Some(Err(TransportError::Json(e)))),
92 }
93 }
94 Err(e) => Poll::Ready(Some(Err(TransportError::Json(e)))),
95 }
96 }
97 Poll::Ready(Err(e)) => Poll::Ready(Some(Err(TransportError::Io(e)))),
98 Poll::Pending => Poll::Pending,
99 }
100 }
101}
102
103impl<R, W> ByteTransport<R, W>
104where
105 R: AsyncRead + Unpin,
106 W: AsyncWrite + Unpin,
107{
108 pub async fn write_message(&mut self, msg: JsonRpcMessage) -> Result<(), std::io::Error> {
109 let json = serde_json::to_string(&msg)?;
110 Pin::new(&mut self.writer)
111 .write_all(json.as_bytes())
112 .await?;
113 Pin::new(&mut self.writer).write_all(b"\n").await?;
114 Pin::new(&mut self.writer).flush().await?;
115 Ok(())
116 }
117}
118
/// JSON-RPC server that dispatches requests read from a `ByteTransport` to a
/// wrapped service (see `Server::run`).
pub struct Server<S> {
    // The request handler; consumed by `run`, which calls it once per request.
    service: S,
}
124impl<S> Server<S>
125where
126 S: Service<JsonRpcRequest, Response = JsonRpcResponse> + Send,
127 S::Error: Into<BoxError>,
128 S::Future: Send,
129{
130 pub fn new(service: S) -> Self {
131 Self { service }
132 }
133
134 pub async fn run<R, W>(self, mut transport: ByteTransport<R, W>) -> Result<(), ServerError>
136 where
137 R: AsyncRead + Unpin,
138 W: AsyncWrite + Unpin,
139 {
140 use futures::StreamExt;
141 let mut service = self.service;
142
143 tracing::info!("Server started");
144 while let Some(msg_result) = transport.next().await {
145 let _span = tracing::span!(tracing::Level::INFO, "message_processing").entered();
146 match msg_result {
147 Ok(msg) => {
148 match msg {
149 JsonRpcMessage::Request(request) => {
150 let id = request.id;
152 let request_json = serde_json::to_string(&request)
153 .unwrap_or_else(|_| "Failed to serialize request".to_string());
154
155 tracing::info!(
156 request_id = ?id,
157 method = ?request.method,
158 json = %request_json,
159 "Received request"
160 );
161
162 let response = match service.call(request).await {
164 Ok(resp) => resp,
165 Err(e) => {
166 let error_msg = e.into().to_string();
167 tracing::error!(error = %error_msg, "Request processing failed");
168 JsonRpcResponse {
169 jsonrpc: "2.0".to_string(),
170 id,
171 result: None,
172 error: Some(mcp_spec::protocol::ErrorData {
173 code: mcp_spec::protocol::INTERNAL_ERROR,
174 message: error_msg,
175 data: None,
176 }),
177 }
178 }
179 };
180
181 let response_json = serde_json::to_string(&response)
183 .unwrap_or_else(|_| "Failed to serialize response".to_string());
184
185 tracing::info!(
186 response_id = ?response.id,
187 json = %response_json,
188 "Sending response"
189 );
190 if let Err(e) = transport
192 .write_message(JsonRpcMessage::Response(response))
193 .await
194 {
195 return Err(ServerError::Transport(TransportError::Io(e)));
196 }
197 }
198 JsonRpcMessage::Response(_)
199 | JsonRpcMessage::Notification(_)
200 | JsonRpcMessage::Nil
201 | JsonRpcMessage::Error(_) => {
202 continue;
204 }
205 }
206 }
207 Err(e) => {
208 let error = match e {
210 TransportError::Json(_) | TransportError::InvalidMessage(_) => {
211 mcp_spec::protocol::ErrorData {
212 code: mcp_spec::protocol::PARSE_ERROR,
213 message: e.to_string(),
214 data: None,
215 }
216 }
217 TransportError::Protocol(_) => mcp_spec::protocol::ErrorData {
218 code: mcp_spec::protocol::INVALID_REQUEST,
219 message: e.to_string(),
220 data: None,
221 },
222 _ => mcp_spec::protocol::ErrorData {
223 code: mcp_spec::protocol::INTERNAL_ERROR,
224 message: e.to_string(),
225 data: None,
226 },
227 };
228
229 let error_response = JsonRpcMessage::Error(JsonRpcError {
230 jsonrpc: "2.0".to_string(),
231 id: None,
232 error,
233 });
234
235 if let Err(e) = transport.write_message(error_response).await {
236 return Err(ServerError::Transport(TransportError::Io(e)));
237 }
238 }
239 }
240 }
241
242 Ok(())
243 }
244}
245
246pub trait BoundedService:
249 Service<
250 JsonRpcRequest,
251 Response = JsonRpcResponse,
252 Error = BoxError,
253 Future = Pin<Box<dyn Future<Output = Result<JsonRpcResponse, BoxError>> + Send>>,
254 > + Send
255 + 'static
256{
257}
258
259impl<T> BoundedService for T where
261 T: Service<
262 JsonRpcRequest,
263 Response = JsonRpcResponse,
264 Error = BoxError,
265 Future = Pin<Box<dyn Future<Output = Result<JsonRpcResponse, BoxError>> + Send>>,
266 > + Send
267 + 'static
268{
269}