mcp_server/
lib.rs

1use std::{
2    pin::Pin,
3    task::{Context, Poll},
4};
5
6use futures::{Future, Stream};
7use mcp_spec::protocol::{JsonRpcError, JsonRpcMessage, JsonRpcRequest, JsonRpcResponse};
8use pin_project::pin_project;
9use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
10use tower_service::Service;
11
12mod errors;
13pub use errors::{BoxError, RouterError, ServerError, TransportError};
14
15pub mod router;
16pub use router::Router;
17
18/// A transport layer that handles JSON-RPC messages over byte
19#[pin_project]
20pub struct ByteTransport<R, W> {
21    // Reader is a BufReader on the underlying stream (stdin or similar) buffering
22    // the underlying data across poll calls, we clear one line (\n) during each
23    // iteration of poll_next from this buffer
24    #[pin]
25    reader: BufReader<R>,
26    #[pin]
27    writer: W,
28}
29
30impl<R, W> ByteTransport<R, W>
31where
32    R: AsyncRead,
33    W: AsyncWrite,
34{
35    pub fn new(reader: R, writer: W) -> Self {
36        Self {
37            // Default BufReader capacity is 8 * 1024, increase this to 2MB to the file size limit
38            // allows the buffer to have the capacity to read very large calls
39            reader: BufReader::with_capacity(2 * 1024 * 1024, reader),
40            writer,
41        }
42    }
43}
44
45impl<R, W> Stream for ByteTransport<R, W>
46where
47    R: AsyncRead + Unpin,
48    W: AsyncWrite + Unpin,
49{
50    type Item = Result<JsonRpcMessage, TransportError>;
51
52    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
53        let mut this = self.project();
54        let mut buf = Vec::new();
55
56        let mut reader = this.reader.as_mut();
57        let mut read_future = Box::pin(reader.read_until(b'\n', &mut buf));
58        match read_future.as_mut().poll(cx) {
59            Poll::Ready(Ok(0)) => Poll::Ready(None), // EOF
60            Poll::Ready(Ok(_)) => {
61                // Convert to UTF-8 string
62                let line = match String::from_utf8(buf) {
63                    Ok(s) => s,
64                    Err(e) => return Poll::Ready(Some(Err(TransportError::Utf8(e)))),
65                };
66                // Log incoming message here before serde conversion to
67                // track incomplete chunks which are not valid JSON
68                tracing::info!(json = %line, "incoming message");
69
70                // Parse JSON and validate message format
71                match serde_json::from_str::<serde_json::Value>(&line) {
72                    Ok(value) => {
73                        // Validate basic JSON-RPC structure
74                        if !value.is_object() {
75                            return Poll::Ready(Some(Err(TransportError::InvalidMessage(
76                                "Message must be a JSON object".into(),
77                            ))));
78                        }
79                        let obj = value.as_object().unwrap(); // Safe due to check above
80
81                        // Check jsonrpc version field
82                        if !obj.contains_key("jsonrpc") || obj["jsonrpc"] != "2.0" {
83                            return Poll::Ready(Some(Err(TransportError::InvalidMessage(
84                                "Missing or invalid jsonrpc version".into(),
85                            ))));
86                        }
87
88                        // Now try to parse as proper message
89                        match serde_json::from_value::<JsonRpcMessage>(value) {
90                            Ok(msg) => Poll::Ready(Some(Ok(msg))),
91                            Err(e) => Poll::Ready(Some(Err(TransportError::Json(e)))),
92                        }
93                    }
94                    Err(e) => Poll::Ready(Some(Err(TransportError::Json(e)))),
95                }
96            }
97            Poll::Ready(Err(e)) => Poll::Ready(Some(Err(TransportError::Io(e)))),
98            Poll::Pending => Poll::Pending,
99        }
100    }
101}
102
103impl<R, W> ByteTransport<R, W>
104where
105    R: AsyncRead + Unpin,
106    W: AsyncWrite + Unpin,
107{
108    pub async fn write_message(&mut self, msg: JsonRpcMessage) -> Result<(), std::io::Error> {
109        let json = serde_json::to_string(&msg)?;
110        Pin::new(&mut self.writer)
111            .write_all(json.as_bytes())
112            .await?;
113        Pin::new(&mut self.writer).write_all(b"\n").await?;
114        Pin::new(&mut self.writer).flush().await?;
115        Ok(())
116    }
117}
118
119/// The main server type that processes incoming requests
120pub struct Server<S> {
121    service: S,
122}
123
124impl<S> Server<S>
125where
126    S: Service<JsonRpcRequest, Response = JsonRpcResponse> + Send,
127    S::Error: Into<BoxError>,
128    S::Future: Send,
129{
130    pub fn new(service: S) -> Self {
131        Self { service }
132    }
133
134    // TODO transport trait instead of byte transport if we implement others
135    pub async fn run<R, W>(self, mut transport: ByteTransport<R, W>) -> Result<(), ServerError>
136    where
137        R: AsyncRead + Unpin,
138        W: AsyncWrite + Unpin,
139    {
140        use futures::StreamExt;
141        let mut service = self.service;
142
143        tracing::info!("Server started");
144        while let Some(msg_result) = transport.next().await {
145            let _span = tracing::span!(tracing::Level::INFO, "message_processing").entered();
146            match msg_result {
147                Ok(msg) => {
148                    match msg {
149                        JsonRpcMessage::Request(request) => {
150                            // Serialize request for logging
151                            let id = request.id;
152                            let request_json = serde_json::to_string(&request)
153                                .unwrap_or_else(|_| "Failed to serialize request".to_string());
154
155                            tracing::info!(
156                                request_id = ?id,
157                                method = ?request.method,
158                                json = %request_json,
159                                "Received request"
160                            );
161
162                            // Process the request using our service
163                            let response = match service.call(request).await {
164                                Ok(resp) => resp,
165                                Err(e) => {
166                                    let error_msg = e.into().to_string();
167                                    tracing::error!(error = %error_msg, "Request processing failed");
168                                    JsonRpcResponse {
169                                        jsonrpc: "2.0".to_string(),
170                                        id,
171                                        result: None,
172                                        error: Some(mcp_spec::protocol::ErrorData {
173                                            code: mcp_spec::protocol::INTERNAL_ERROR,
174                                            message: error_msg,
175                                            data: None,
176                                        }),
177                                    }
178                                }
179                            };
180
181                            // Serialize response for logging
182                            let response_json = serde_json::to_string(&response)
183                                .unwrap_or_else(|_| "Failed to serialize response".to_string());
184
185                            tracing::info!(
186                                response_id = ?response.id,
187                                json = %response_json,
188                                "Sending response"
189                            );
190                            // Send the response back
191                            if let Err(e) = transport
192                                .write_message(JsonRpcMessage::Response(response))
193                                .await
194                            {
195                                return Err(ServerError::Transport(TransportError::Io(e)));
196                            }
197                        }
198                        JsonRpcMessage::Response(_)
199                        | JsonRpcMessage::Notification(_)
200                        | JsonRpcMessage::Nil
201                        | JsonRpcMessage::Error(_) => {
202                            // Ignore responses, notifications and nil messages for now
203                            continue;
204                        }
205                    }
206                }
207                Err(e) => {
208                    // Convert transport error to JSON-RPC error response
209                    let error = match e {
210                        TransportError::Json(_) | TransportError::InvalidMessage(_) => {
211                            mcp_spec::protocol::ErrorData {
212                                code: mcp_spec::protocol::PARSE_ERROR,
213                                message: e.to_string(),
214                                data: None,
215                            }
216                        }
217                        TransportError::Protocol(_) => mcp_spec::protocol::ErrorData {
218                            code: mcp_spec::protocol::INVALID_REQUEST,
219                            message: e.to_string(),
220                            data: None,
221                        },
222                        _ => mcp_spec::protocol::ErrorData {
223                            code: mcp_spec::protocol::INTERNAL_ERROR,
224                            message: e.to_string(),
225                            data: None,
226                        },
227                    };
228
229                    let error_response = JsonRpcMessage::Error(JsonRpcError {
230                        jsonrpc: "2.0".to_string(),
231                        id: None,
232                        error,
233                    });
234
235                    if let Err(e) = transport.write_message(error_response).await {
236                        return Err(ServerError::Transport(TransportError::Io(e)));
237                    }
238                }
239            }
240        }
241
242        Ok(())
243    }
244}
245
246// Define a specific service implementation that we need for any
247// Any router implements this
248pub trait BoundedService:
249    Service<
250        JsonRpcRequest,
251        Response = JsonRpcResponse,
252        Error = BoxError,
253        Future = Pin<Box<dyn Future<Output = Result<JsonRpcResponse, BoxError>> + Send>>,
254    > + Send
255    + 'static
256{
257}
258
259// Implement it for any type that meets the bounds
260impl<T> BoundedService for T where
261    T: Service<
262            JsonRpcRequest,
263            Response = JsonRpcResponse,
264            Error = BoxError,
265            Future = Pin<Box<dyn Future<Output = Result<JsonRpcResponse, BoxError>> + Send>>,
266        > + Send
267        + 'static
268{
269}