mcp_runner/transport/stdio.rs

1use super::json_rpc::{JsonRpcMessage, JsonRpcRequest, JsonRpcResponse, error_codes};
2use crate::error::{Error, Result};
3use crate::transport::Transport;
4use async_process::{ChildStdin, ChildStdout};
5use async_trait::async_trait;
6use futures_lite::io::{AsyncReadExt, AsyncWriteExt};
7use serde_json::Value;
8use std::collections::HashMap;
9use std::sync::{Arc, Mutex};
10use tokio::sync::oneshot;
11use tokio::task::JoinHandle;
12use tracing::{self, Instrument, span};
13use uuid::Uuid;
14
15/// StdioTransport provides communication with an MCP server via standard I/O.
16///
17/// This implementation uses JSON-RPC over standard input/output to communicate with
18/// an MCP server. It handles concurrent requests using a background task for reading
19/// responses and dispatches them to the appropriate handler.
20/// Most public methods are instrumented with `tracing` spans.
21///
22/// # Example
23///
24/// ```no_run
25/// use mcp_runner::transport::StdioTransport;
26/// use mcp_runner::error::Result;
27/// use async_process::{Child, Command};
28/// use serde_json::json;
29///
30/// #[tokio::main]
31/// async fn main() -> Result<()> {
32///     // Start a child process
33///     let mut child = Command::new("mcp-server")
34///         .stdin(std::process::Stdio::piped())
35///         .stdout(std::process::Stdio::piped())
36///         .spawn()
37///         .expect("Failed to start MCP server");
38///     
39///     // Take ownership of stdin/stdout
40///     let stdin = child.stdin.take().expect("Failed to get stdin");
41///     let stdout = child.stdout.take().expect("Failed to get stdout");
42///     
43///     // Create transport
44///     let transport = StdioTransport::new("example-server".to_string(), stdin, stdout);
45///     
46///     // Initialize the transport
47///     transport.initialize().await?;
48///     
49///     // List available tools
50///     let tools = transport.list_tools().await?;
51///     println!("Available tools: {:?}", tools);
52///     
53///     Ok(())
54/// }
55/// ```
56pub struct StdioTransport {
57    /// Server name
58    name: String,
59    /// Child process stdin
60    stdin: Arc<Mutex<ChildStdin>>,
61    /// Response handlers
62    response_handlers: Arc<Mutex<HashMap<String, oneshot::Sender<JsonRpcResponse>>>>,
63    /// Task handle for reading from stdout
64    reader_task: Option<JoinHandle<()>>,
65}
66
67impl StdioTransport {
68    /// Creates a new StdioTransport instance.
69    ///
70    /// This constructor takes ownership of the child process's stdin and stdout,
71    /// and sets up a background task to process incoming JSON-RPC messages.
72    /// This method is instrumented with `tracing`.
73    ///
74    /// # Arguments
75    ///
76    /// * `name` - A name for this transport, typically the server name
77    /// * `stdin` - The child process's stdin handle
78    /// * `stdout` - The child process's stdout handle
79    ///
80    /// # Returns
81    ///
82    /// A new `StdioTransport` instance
83    #[tracing::instrument(skip(stdin, stdout), fields(name = %name))]
84    pub fn new(name: String, stdin: ChildStdin, mut stdout: ChildStdout) -> Self {
85        tracing::debug!("Creating new StdioTransport");
86        let stdin = Arc::new(Mutex::new(stdin));
87        let response_handlers = Arc::new(Mutex::new(HashMap::<
88            String,
89            oneshot::Sender<JsonRpcResponse>,
90        >::new()));
91
92        // Clone for the reader task
93        let response_handlers_clone = Arc::clone(&response_handlers);
94
95        // Create the span for the reader task explicitly
96        let reader_span = span!(tracing::Level::INFO, "stdout_reader", name = %name);
97
98        // Spawn a task to read from stdout
99        // Use .instrument() on the future
100        let reader_task = tokio::spawn(async move {
101            tracing::debug!("Stdout reader task started");
102            // Process stdout line by line
103            let mut buffer = Vec::new();
104            let mut buf = [0u8; 1];
105
106            loop {
107                // Try to read a single byte
108                match stdout.read(&mut buf).await {
109                    Ok(0) => {
110                        tracing::debug!("Stdout reached EOF");
111                        break;
112                    } // EOF
113                    Ok(_) => {
114                        if buf[0] == b'\n' {
115                            // Process the line
116                            if let Ok(line) = String::from_utf8(buffer.clone()) {
117                                let trimmed_line = line.trim();
118                                if trimmed_line.is_empty() {
119                                    // Ignore empty lines
120                                    buffer.clear();
121                                    continue;
122                                }
123
124                                // Check if the line looks like a JSON object before attempting to parse
125                                if !trimmed_line.starts_with('{') {
126                                    tracing::trace!(output = "stdout", line = %trimmed_line, "Ignoring non-JSON line");
127                                    buffer.clear();
128                                    continue;
129                                }
130
131                                // Attempt to parse as JSON-RPC
132                                tracing::trace!(output = "stdout", line = %trimmed_line, "Attempting to parse line as JSON-RPC");
133                                match serde_json::from_str::<JsonRpcMessage>(trimmed_line) {
134                                    Ok(JsonRpcMessage::Response(response)) => {
135                                        // Get ID as string
136                                        let id_str = match &response.id {
137                                            Value::String(s) => s.clone(),
138                                            Value::Number(n) => n.to_string(),
139                                            _ => {
140                                                tracing::warn!(response_id = ?response.id, "Received response with unexpected ID type");
141                                                continue;
142                                            }
143                                        };
144                                        tracing::debug!(response_id = %id_str, "Received JSON-RPC response");
145
146                                        // Send response to handler - handle lock errors gracefully
147                                        if let Ok(mut handlers) = response_handlers_clone.lock() {
148                                            if let Some(sender) = handlers.remove(&id_str) {
149                                                if sender.send(response).is_err() {
150                                                    tracing::warn!(response_id = %id_str, "Response handler dropped before response could be sent");
151                                                }
152                                            } else {
153                                                tracing::warn!(response_id = %id_str, "Received response for unknown or timed out request");
154                                            }
155                                        } else {
156                                            tracing::error!("Response handler lock poisoned!");
157                                        }
158                                    }
159                                    Ok(JsonRpcMessage::Request(req)) => {
160                                        tracing::warn!(method = %req.method, "Received unexpected JSON-RPC request from server");
161                                    }
162                                    Ok(JsonRpcMessage::Notification(notif)) => {
163                                        tracing::debug!(method = %notif.method, "Received JSON-RPC notification from server");
164                                    }
165                                    Err(e) => {
166                                        // Keep WARN for lines that start like JSON but fail to parse
167                                        tracing::warn!(line = %trimmed_line, error = %e, "Failed to parse potential JSON-RPC message");
168                                    }
169                                }
170                            } else {
171                                // Log if line is not valid UTF-8
172                                tracing::warn!(bytes = ?buffer, "Received non-UTF8 data on stdout");
173                            }
174                            buffer.clear();
175                        } else {
176                            buffer.push(buf[0]);
177                        }
178                    }
179                    Err(e) => {
180                        tracing::error!(error = %e, "Error reading from stdout");
181                        break;
182                    } // Error
183                }
184            }
185            tracing::debug!("Stdout reader task finished");
186        }.instrument(reader_span)); // Apply the span to the future
187
188        Self {
189            name,
190            stdin,
191            response_handlers,
192            reader_task: Some(reader_task),
193        }
194    }
195
196    /// Gets the name of the server associated with this transport.
197    ///
198    /// # Returns
199    ///
200    /// A string slice containing the server name.
201    pub fn name(&self) -> &str {
202        &self.name
203    }
204
205    /// Writes data to the child process's stdin.
206    ///
207    /// This is a helper function that handles the complexity of writing to
208    /// stdin in a thread-safe and non-blocking way.
209    /// This method is instrumented with `tracing`.
210    ///
211    /// # Arguments
212    ///
213    /// * `data` - The bytes to write to stdin
214    ///
215    /// # Returns
216    ///
217    /// A `Result<()>` indicating success or failure
218    #[tracing::instrument(skip(self, data), fields(name = %self.name))]
219    async fn write_to_stdin(&self, data: Vec<u8>) -> Result<()> {
220        tracing::trace!(bytes_len = data.len(), "Writing to stdin");
221        let stdin_clone = self.stdin.clone();
222
223        tokio::task::spawn_blocking(move || -> Result<()> {
224            let stdin_lock = stdin_clone
225                .lock()
226                .map_err(|_| Error::Communication("Failed to acquire stdin lock".to_string()))?;
227
228            let mut stdin = stdin_lock;
229
230            futures_lite::future::block_on(async {
231                stdin.write_all(&data).await.map_err(|e| {
232                    Error::Communication(format!("Failed to write to stdin: {}", e))
233                })?;
234                stdin
235                    .flush()
236                    .await
237                    .map_err(|e| Error::Communication(format!("Failed to flush stdin: {}", e)))?;
238                Ok::<(), Error>(())
239            })?;
240
241            Ok(())
242        })
243        .await
244        .map_err(|e| {
245            tracing::error!(error = %e, "Stdin write task panicked");
246            Error::Communication(format!("Task join error: {}", e))
247        })??;
248
249        tracing::trace!("Finished writing to stdin");
250        Ok(())
251    }
252
253    /// Sends a JSON-RPC request and waits for a response.
254    ///
255    /// This method handles the details of sending a request, registering a response
256    /// handler, and waiting for the response to arrive.
257    /// This method is instrumented with `tracing`.
258    ///
259    /// # Arguments
260    ///
261    /// * `request` - The JSON-RPC request to send
262    ///
263    /// # Returns
264    ///
265    /// A `Result<JsonRpcResponse>` containing the response if successful
266    #[tracing::instrument(skip(self, request), fields(name = %self.name, method = %request.method, request_id = ?request.id))]
267    pub async fn send_request(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse> {
268        tracing::debug!("Sending JSON-RPC request");
269        let id_str = match &request.id {
270            Value::String(s) => s.clone(),
271            Value::Number(n) => n.to_string(),
272            _ => return Err(Error::Communication("Invalid request ID type".to_string())),
273        };
274
275        let (sender, receiver) = oneshot::channel();
276
277        {
278            let mut handlers = self.response_handlers.lock().map_err(|_| {
279                Error::Communication("Failed to lock response handlers".to_string())
280            })?;
281            handlers.insert(id_str, sender);
282        }
283
284        let request_json = serde_json::to_string(&request)
285            .map_err(|e| Error::Serialization(format!("Failed to serialize request: {}", e)))?;
286        tracing::trace!(request_json = %request_json, "Sending request JSON");
287        let request_bytes = request_json.into_bytes();
288        let mut request_bytes_with_newline = request_bytes;
289        request_bytes_with_newline.push(b'\n');
290
291        self.write_to_stdin(request_bytes_with_newline).await?;
292
293        tracing::debug!("Waiting for response");
294        let response = receiver.await.map_err(|_| {
295            tracing::warn!("Sender dropped before response received (likely timeout or closed)");
296            Error::Communication("Failed to receive response".to_string())
297        })?;
298
299        if let Some(error) = &response.error {
300            tracing::error!(error_code = error.code, error_message = %error.message, "Received JSON-RPC error response");
301            return Err(Error::JsonRpc(error.message.clone()));
302        }
303
304        tracing::debug!("Received successful response");
305        Ok(response)
306    }
307
308    /// Sends a JSON-RPC notification (no response expected).
309    ///
310    /// Unlike requests, notifications don't expect a response, so this method
311    /// just sends the message without setting up a response handler.
312    /// This method is instrumented with `tracing`.
313    ///
314    /// # Arguments
315    ///
316    /// * `notification` - The JSON-RPC notification to send
317    ///
318    /// # Returns
319    ///
320    /// A `Result<()>` indicating success or failure
321    #[tracing::instrument(skip(self, notification), fields(name = %self.name, method = notification.get("method").and_then(|v| v.as_str())))]
322    pub async fn send_notification(&self, notification: serde_json::Value) -> Result<()> {
323        tracing::debug!("Sending JSON-RPC notification");
324        let notification_json = serde_json::to_string(&notification).map_err(|e| {
325            Error::Serialization(format!("Failed to serialize notification: {}", e))
326        })?;
327        tracing::trace!(notification_json = %notification_json, "Sending notification JSON");
328        let notification_bytes = notification_json.into_bytes();
329        let mut notification_bytes_with_newline = notification_bytes;
330        notification_bytes_with_newline.push(b'\n');
331
332        self.write_to_stdin(notification_bytes_with_newline).await
333    }
334
335    /// Initializes the MCP server.
336    ///
337    /// Sends the `notifications/initialized` notification to the server,
338    /// indicating that the client is ready to communicate.
339    /// This method is instrumented with `tracing`.
340    ///
341    /// # Returns
342    ///
343    /// A `Result<()>` indicating success or failure
344    #[tracing::instrument(skip(self), fields(name = %self.name))]
345    pub async fn initialize(&self) -> Result<()> {
346        tracing::info!("Initializing MCP connection");
347        let notification = serde_json::json!({
348            "jsonrpc": "2.0",
349            "method": "notifications/initialized"
350        });
351
352        self.send_notification(notification).await
353    }
354
355    /// Lists available tools provided by the MCP server.
356    ///
357    /// This method is instrumented with `tracing`.
358    ///
359    /// # Returns
360    ///
361    /// A `Result<Vec<Value>>` containing the list of tools if successful
362    #[tracing::instrument(skip(self), fields(name = %self.name))]
363    pub async fn list_tools(&self) -> Result<Vec<Value>> {
364        tracing::debug!("Listing tools");
365        let request_id = Uuid::new_v4().to_string();
366        let request = JsonRpcRequest::list_tools(request_id);
367
368        let response = self.send_request(request).await?;
369
370        if let Some(Value::Object(result)) = response.result {
371            if let Some(Value::Array(tools)) = result.get("tools") {
372                return Ok(tools.clone());
373            }
374        }
375
376        Ok(Vec::new())
377    }
378
379    /// Calls a tool provided by the MCP server.
380    ///
381    /// This method is instrumented with `tracing`.
382    ///
383    /// # Arguments
384    ///
385    /// * `name` - The name of the tool to call
386    /// * `args` - The arguments to pass to the tool
387    ///
388    /// # Returns
389    ///
390    /// A `Result<Value>` containing the tool's response if successful
391    #[tracing::instrument(skip(self, args), fields(name = %self.name, tool_name = %name.as_ref()))]
392    pub async fn call_tool(
393        &self,
394        name: impl AsRef<str> + std::fmt::Debug,
395        args: Value,
396    ) -> Result<Value> {
397        tracing::debug!(args = ?args, "Calling tool");
398        let request_id = Uuid::new_v4().to_string();
399        let request = JsonRpcRequest::call_tool(request_id, name.as_ref().to_string(), args);
400
401        let response = self.send_request(request).await?;
402
403        response
404            .result
405            .ok_or_else(|| Error::Communication("No result in response".to_string()))
406    }
407
408    /// Lists available resources provided by the MCP server.
409    ///
410    /// This method is instrumented with `tracing`.
411    ///
412    /// # Returns
413    ///
414    /// A `Result<Vec<Value>>` containing the list of resources if successful
415    #[tracing::instrument(skip(self), fields(name = %self.name))]
416    pub async fn list_resources(&self) -> Result<Vec<Value>> {
417        tracing::debug!("Listing resources");
418        let request_id = Uuid::new_v4().to_string();
419        let request = JsonRpcRequest::list_resources(request_id);
420
421        let response = self.send_request(request).await?;
422
423        if let Some(Value::Object(result)) = response.result {
424            if let Some(Value::Array(resources)) = result.get("resources") {
425                return Ok(resources.clone());
426            }
427        }
428
429        Ok(Vec::new())
430    }
431
432    /// Retrieves a specific resource from the MCP server.
433    ///
434    /// This method is instrumented with `tracing`.
435    ///
436    /// # Arguments
437    ///
438    /// * `uri` - The URI of the resource to retrieve
439    ///
440    /// # Returns
441    ///
442    /// A `Result<Value>` containing the resource data if successful
443    #[tracing::instrument(skip(self), fields(name = %self.name, uri = %uri.as_ref()))]
444    pub async fn get_resource(&self, uri: impl AsRef<str> + std::fmt::Debug) -> Result<Value> {
445        tracing::debug!("Getting resource");
446        let request_id = Uuid::new_v4().to_string();
447        let request = JsonRpcRequest::get_resource(request_id, uri.as_ref().to_string());
448
449        let response = self.send_request(request).await?;
450
451        response
452            .result
453            .ok_or_else(|| Error::Communication("No result in response".to_string()))
454    }
455
456    /// Closes the transport and cleans up resources.
457    ///
458    /// This method should be called when the transport is no longer needed
459    /// to ensure proper cleanup of background tasks and resources.
460    /// This method is instrumented with `tracing`.
461    ///
462    /// # Returns
463    ///
464    /// A `Result<()>` indicating success or failure
465    #[tracing::instrument(skip(self), fields(name = %self.name))]
466    pub async fn close(&mut self) -> Result<()> {
467        tracing::info!("Closing transport");
468        if let Some(task) = self.reader_task.take() {
469            task.abort();
470            let _ = task.await;
471        }
472
473        if let Ok(mut handlers) = self.response_handlers.lock() {
474            for (_, sender) in handlers.drain() {
475                let _ = sender.send(JsonRpcResponse {
476                    jsonrpc: "2.0".to_string(),
477                    id: Value::Null,
478                    result: None,
479                    error: Some(super::json_rpc::JsonRpcError {
480                        code: error_codes::SERVER_ERROR,
481                        message: "Connection closed".to_string(),
482                        data: None,
483                    }),
484                });
485            }
486        }
487
488        Ok(())
489    }
490}
491
492#[async_trait]
493impl Transport for StdioTransport {
494    /// This method is instrumented with `tracing`.
495    #[tracing::instrument(skip(self), fields(name = %self.name()))]
496    async fn initialize(&self) -> Result<()> {
497        self.initialize().await
498    }
499
500    /// This method is instrumented with `tracing`.
501    #[tracing::instrument(skip(self), fields(name = %self.name()))]
502    async fn list_tools(&self) -> Result<Vec<Value>> {
503        self.list_tools().await
504    }
505
506    /// This method is instrumented with `tracing`.
507    #[tracing::instrument(skip(self, args), fields(name = %self.name(), tool_name = %name))]
508    async fn call_tool(&self, name: &str, args: Value) -> Result<Value> {
509        self.call_tool(name.to_string(), args).await
510    }
511
512    /// This method is instrumented with `tracing`.
513    #[tracing::instrument(skip(self), fields(name = %self.name()))]
514    async fn list_resources(&self) -> Result<Vec<Value>> {
515        self.list_resources().await
516    }
517
518    /// This method is instrumented with `tracing`.
519    #[tracing::instrument(skip(self), fields(name = %self.name(), uri = %uri))]
520    async fn get_resource(&self, uri: &str) -> Result<Value> {
521        self.get_resource(uri.to_string()).await
522    }
523}