mcp_core/transport/server/stdio.rs

1use crate::protocol::{Protocol, RequestOptions};
2use crate::transport::{
3    JsonRpcError, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, Message, RequestId,
4    Transport,
5};
6use crate::types::ErrorCode;
7use anyhow::Result;
8use async_trait::async_trait;
9use std::future::Future;
10use std::io::{self, BufRead, Write};
11use std::pin::Pin;
12use tokio::time::timeout;
13use tracing::debug;
14
15#[derive(Clone)]
16pub struct ServerStdioTransport {
17    protocol: Protocol,
18}
19
20impl ServerStdioTransport {
21    pub fn new(protocol: Protocol) -> Self {
22        Self { protocol }
23    }
24}
25
26#[async_trait()]
27impl Transport for ServerStdioTransport {
28    async fn open(&self) -> Result<()> {
29        loop {
30            match self.poll_message().await {
31                Ok(Some(message)) => match message {
32                    Message::Request(request) => {
33                        let response = self.protocol.handle_request(request).await;
34                        self.send_response(response.id, response.result, response.error)
35                            .await?;
36                    }
37                    Message::Notification(notification) => {
38                        self.protocol.handle_notification(notification).await;
39                    }
40                    Message::Response(response) => {
41                        self.protocol.handle_response(response).await;
42                    }
43                },
44                Ok(None) => {
45                    break;
46                }
47                Err(e) => {
48                    tracing::error!("Error receiving message: {:?}", e);
49                }
50            }
51        }
52        Ok(())
53    }
54
55    async fn close(&self) -> Result<()> {
56        Ok(())
57    }
58
59    async fn poll_message(&self) -> Result<Option<Message>> {
60        let stdin = io::stdin();
61        let mut reader = stdin.lock();
62        let mut line = String::new();
63        reader.read_line(&mut line)?;
64        if line.is_empty() {
65            return Ok(None);
66        }
67
68        debug!("Received: {line}");
69        let message: Message = serde_json::from_str(&line)?;
70        Ok(Some(message))
71    }
72
73    fn request(
74        &self,
75        method: &str,
76        params: Option<serde_json::Value>,
77        options: RequestOptions,
78    ) -> Pin<Box<dyn Future<Output = Result<JsonRpcResponse>> + Send + Sync>> {
79        let protocol = self.protocol.clone();
80        let method = method.to_owned();
81        Box::pin(async move {
82            let (id, rx) = protocol.create_request().await;
83            let request = JsonRpcRequest {
84                id,
85                method,
86                jsonrpc: Default::default(),
87                params,
88            };
89            let serialized = serde_json::to_string(&request).unwrap_or_default();
90            debug!("Sending: {serialized}");
91
92            // Use Tokio's async stdout to perform thread-safe, nonblocking writes.
93            let mut stdout = io::stdout();
94            stdout.write_all(serialized.as_bytes())?;
95            stdout.write_all(b"\n")?;
96            stdout.flush()?;
97
98            let result = timeout(options.timeout, rx).await;
99            match result {
100                // The request future completed before the timeout.
101                Ok(inner_result) => match inner_result {
102                    Ok(response) => Ok(response),
103                    Err(_) => {
104                        protocol.cancel_response(id).await;
105                        Ok(JsonRpcResponse {
106                            id,
107                            result: None,
108                            error: Some(JsonRpcError {
109                                code: ErrorCode::RequestTimeout as i32,
110                                message: "Request cancelled".to_string(),
111                                data: None,
112                            }),
113                            ..Default::default()
114                        })
115                    }
116                },
117                // The timeout expired.
118                Err(_) => {
119                    protocol.cancel_response(id).await;
120                    Ok(JsonRpcResponse {
121                        id,
122                        result: None,
123                        error: Some(JsonRpcError {
124                            code: ErrorCode::RequestTimeout as i32,
125                            message: "Request cancelled".to_string(),
126                            data: None,
127                        }),
128                        ..Default::default()
129                    })
130                }
131            }
132        })
133    }
134
135    async fn send_notification(
136        &self,
137        method: &str,
138        params: Option<serde_json::Value>,
139    ) -> Result<()> {
140        let notification = JsonRpcNotification {
141            jsonrpc: Default::default(),
142            method: method.to_owned(),
143            params,
144        };
145        let serialized = serde_json::to_string(&notification).unwrap_or_default();
146        let stdout = io::stdout();
147        let mut writer = stdout.lock();
148        debug!("Sending: {serialized}");
149        writer.write_all(serialized.as_bytes())?;
150        writer.write_all(b"\n")?;
151        writer.flush()?;
152        Ok(())
153    }
154
155    async fn send_response(
156        &self,
157        id: RequestId,
158        result: Option<serde_json::Value>,
159        error: Option<JsonRpcError>,
160    ) -> Result<()> {
161        let response = JsonRpcResponse {
162            id,
163            result,
164            error,
165            jsonrpc: Default::default(),
166        };
167        let serialized = serde_json::to_string(&response).unwrap_or_default();
168        let stdout = io::stdout();
169        let mut writer = stdout.lock();
170        debug!("Sending: {serialized}");
171        writer.write_all(serialized.as_bytes())?;
172        writer.write_all(b"\n")?;
173        writer.flush()?;
174        Ok(())
175    }
176}