1use super::json_rpc::{JsonRpcMessage, JsonRpcRequest, JsonRpcResponse, error_codes};
2use crate::error::{Error, Result};
3use crate::transport::Transport;
4use async_process::{ChildStdin, ChildStdout};
5use async_trait::async_trait;
6use futures_lite::io::{AsyncReadExt, AsyncWriteExt};
7use serde_json::Value;
8use std::collections::HashMap;
9use std::sync::{Arc, Mutex};
10use tokio::sync::oneshot;
11use tokio::task::JoinHandle;
12use tracing::{self, Instrument, span};
13use uuid::Uuid;
14
15pub struct StdioTransport {
57 name: String,
59 stdin: Arc<Mutex<ChildStdin>>,
61 response_handlers: Arc<Mutex<HashMap<String, oneshot::Sender<JsonRpcResponse>>>>,
63 reader_task: Option<JoinHandle<()>>,
65}
66
67impl StdioTransport {
68 #[tracing::instrument(skip(stdin, stdout), fields(name = %name))]
84 pub fn new(name: String, stdin: ChildStdin, mut stdout: ChildStdout) -> Self {
85 tracing::debug!("Creating new StdioTransport");
86 let stdin = Arc::new(Mutex::new(stdin));
87 let response_handlers = Arc::new(Mutex::new(HashMap::<
88 String,
89 oneshot::Sender<JsonRpcResponse>,
90 >::new()));
91
92 let response_handlers_clone = Arc::clone(&response_handlers);
94
95 let reader_span = span!(tracing::Level::INFO, "stdout_reader", name = %name);
97
98 let reader_task = tokio::spawn(async move {
101 tracing::debug!("Stdout reader task started");
102 let mut buffer = Vec::new();
104 let mut buf = [0u8; 1];
105
106 loop {
107 match stdout.read(&mut buf).await {
109 Ok(0) => {
110 tracing::debug!("Stdout reached EOF");
111 break;
112 } Ok(_) => {
114 if buf[0] == b'\n' {
115 if let Ok(line) = String::from_utf8(buffer.clone()) {
117 let trimmed_line = line.trim();
118 if trimmed_line.is_empty() {
119 buffer.clear();
121 continue;
122 }
123
124 if !trimmed_line.starts_with('{') {
126 tracing::trace!(output = "stdout", line = %trimmed_line, "Ignoring non-JSON line");
127 buffer.clear();
128 continue;
129 }
130
131 tracing::trace!(output = "stdout", line = %trimmed_line, "Attempting to parse line as JSON-RPC");
133 match serde_json::from_str::<JsonRpcMessage>(trimmed_line) {
134 Ok(JsonRpcMessage::Response(response)) => {
135 let id_str = match &response.id {
137 Value::String(s) => s.clone(),
138 Value::Number(n) => n.to_string(),
139 _ => {
140 tracing::warn!(response_id = ?response.id, "Received response with unexpected ID type");
141 continue;
142 }
143 };
144 tracing::debug!(response_id = %id_str, "Received JSON-RPC response");
145
146 if let Ok(mut handlers) = response_handlers_clone.lock() {
148 if let Some(sender) = handlers.remove(&id_str) {
149 if sender.send(response).is_err() {
150 tracing::warn!(response_id = %id_str, "Response handler dropped before response could be sent");
151 }
152 } else {
153 tracing::warn!(response_id = %id_str, "Received response for unknown or timed out request");
154 }
155 } else {
156 tracing::error!("Response handler lock poisoned!");
157 }
158 }
159 Ok(JsonRpcMessage::Request(req)) => {
160 tracing::warn!(method = %req.method, "Received unexpected JSON-RPC request from server");
161 }
162 Ok(JsonRpcMessage::Notification(notif)) => {
163 tracing::debug!(method = %notif.method, "Received JSON-RPC notification from server");
164 }
165 Err(e) => {
166 tracing::warn!(line = %trimmed_line, error = %e, "Failed to parse potential JSON-RPC message");
168 }
169 }
170 } else {
171 tracing::warn!(bytes = ?buffer, "Received non-UTF8 data on stdout");
173 }
174 buffer.clear();
175 } else {
176 buffer.push(buf[0]);
177 }
178 }
179 Err(e) => {
180 tracing::error!(error = %e, "Error reading from stdout");
181 break;
182 } }
184 }
185 tracing::debug!("Stdout reader task finished");
186 }.instrument(reader_span)); Self {
189 name,
190 stdin,
191 response_handlers,
192 reader_task: Some(reader_task),
193 }
194 }
195
196 pub fn name(&self) -> &str {
202 &self.name
203 }
204
205 #[tracing::instrument(skip(self, data), fields(name = %self.name))]
219 async fn write_to_stdin(&self, data: Vec<u8>) -> Result<()> {
220 tracing::trace!(bytes_len = data.len(), "Writing to stdin");
221 let stdin_clone = self.stdin.clone();
222
223 tokio::task::spawn_blocking(move || -> Result<()> {
224 let stdin_lock = stdin_clone
225 .lock()
226 .map_err(|_| Error::Communication("Failed to acquire stdin lock".to_string()))?;
227
228 let mut stdin = stdin_lock;
229
230 futures_lite::future::block_on(async {
231 stdin.write_all(&data).await.map_err(|e| {
232 Error::Communication(format!("Failed to write to stdin: {}", e))
233 })?;
234 stdin
235 .flush()
236 .await
237 .map_err(|e| Error::Communication(format!("Failed to flush stdin: {}", e)))?;
238 Ok::<(), Error>(())
239 })?;
240
241 Ok(())
242 })
243 .await
244 .map_err(|e| {
245 tracing::error!(error = %e, "Stdin write task panicked");
246 Error::Communication(format!("Task join error: {}", e))
247 })??;
248
249 tracing::trace!("Finished writing to stdin");
250 Ok(())
251 }
252
253 #[tracing::instrument(skip(self, request), fields(name = %self.name, method = %request.method, request_id = ?request.id))]
267 pub async fn send_request(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse> {
268 tracing::debug!("Sending JSON-RPC request");
269 let id_str = match &request.id {
270 Value::String(s) => s.clone(),
271 Value::Number(n) => n.to_string(),
272 _ => return Err(Error::Communication("Invalid request ID type".to_string())),
273 };
274
275 let (sender, receiver) = oneshot::channel();
276
277 {
278 let mut handlers = self.response_handlers.lock().map_err(|_| {
279 Error::Communication("Failed to lock response handlers".to_string())
280 })?;
281 handlers.insert(id_str, sender);
282 }
283
284 let request_json = serde_json::to_string(&request)
285 .map_err(|e| Error::Serialization(format!("Failed to serialize request: {}", e)))?;
286 tracing::trace!(request_json = %request_json, "Sending request JSON");
287 let request_bytes = request_json.into_bytes();
288 let mut request_bytes_with_newline = request_bytes;
289 request_bytes_with_newline.push(b'\n');
290
291 self.write_to_stdin(request_bytes_with_newline).await?;
292
293 tracing::debug!("Waiting for response");
294 let response = receiver.await.map_err(|_| {
295 tracing::warn!("Sender dropped before response received (likely timeout or closed)");
296 Error::Communication("Failed to receive response".to_string())
297 })?;
298
299 if let Some(error) = &response.error {
300 tracing::error!(error_code = error.code, error_message = %error.message, "Received JSON-RPC error response");
301 return Err(Error::JsonRpc(error.message.clone()));
302 }
303
304 tracing::debug!("Received successful response");
305 Ok(response)
306 }
307
308 #[tracing::instrument(skip(self, notification), fields(name = %self.name, method = notification.get("method").and_then(|v| v.as_str())))]
322 pub async fn send_notification(&self, notification: serde_json::Value) -> Result<()> {
323 tracing::debug!("Sending JSON-RPC notification");
324 let notification_json = serde_json::to_string(¬ification).map_err(|e| {
325 Error::Serialization(format!("Failed to serialize notification: {}", e))
326 })?;
327 tracing::trace!(notification_json = %notification_json, "Sending notification JSON");
328 let notification_bytes = notification_json.into_bytes();
329 let mut notification_bytes_with_newline = notification_bytes;
330 notification_bytes_with_newline.push(b'\n');
331
332 self.write_to_stdin(notification_bytes_with_newline).await
333 }
334
335 #[tracing::instrument(skip(self), fields(name = %self.name))]
345 pub async fn initialize(&self) -> Result<()> {
346 tracing::info!("Initializing MCP connection");
347 let notification = serde_json::json!({
348 "jsonrpc": "2.0",
349 "method": "notifications/initialized"
350 });
351
352 self.send_notification(notification).await
353 }
354
355 #[tracing::instrument(skip(self), fields(name = %self.name))]
363 pub async fn list_tools(&self) -> Result<Vec<Value>> {
364 tracing::debug!("Listing tools");
365 let request_id = Uuid::new_v4().to_string();
366 let request = JsonRpcRequest::list_tools(request_id);
367
368 let response = self.send_request(request).await?;
369
370 if let Some(Value::Object(result)) = response.result {
371 if let Some(Value::Array(tools)) = result.get("tools") {
372 return Ok(tools.clone());
373 }
374 }
375
376 Ok(Vec::new())
377 }
378
379 #[tracing::instrument(skip(self, args), fields(name = %self.name, tool_name = %name.as_ref()))]
392 pub async fn call_tool(
393 &self,
394 name: impl AsRef<str> + std::fmt::Debug,
395 args: Value,
396 ) -> Result<Value> {
397 tracing::debug!(args = ?args, "Calling tool");
398 let request_id = Uuid::new_v4().to_string();
399 let request = JsonRpcRequest::call_tool(request_id, name.as_ref().to_string(), args);
400
401 let response = self.send_request(request).await?;
402
403 response
404 .result
405 .ok_or_else(|| Error::Communication("No result in response".to_string()))
406 }
407
408 #[tracing::instrument(skip(self), fields(name = %self.name))]
416 pub async fn list_resources(&self) -> Result<Vec<Value>> {
417 tracing::debug!("Listing resources");
418 let request_id = Uuid::new_v4().to_string();
419 let request = JsonRpcRequest::list_resources(request_id);
420
421 let response = self.send_request(request).await?;
422
423 if let Some(Value::Object(result)) = response.result {
424 if let Some(Value::Array(resources)) = result.get("resources") {
425 return Ok(resources.clone());
426 }
427 }
428
429 Ok(Vec::new())
430 }
431
432 #[tracing::instrument(skip(self), fields(name = %self.name, uri = %uri.as_ref()))]
444 pub async fn get_resource(&self, uri: impl AsRef<str> + std::fmt::Debug) -> Result<Value> {
445 tracing::debug!("Getting resource");
446 let request_id = Uuid::new_v4().to_string();
447 let request = JsonRpcRequest::get_resource(request_id, uri.as_ref().to_string());
448
449 let response = self.send_request(request).await?;
450
451 response
452 .result
453 .ok_or_else(|| Error::Communication("No result in response".to_string()))
454 }
455
456 #[tracing::instrument(skip(self), fields(name = %self.name))]
466 pub async fn close(&mut self) -> Result<()> {
467 tracing::info!("Closing transport");
468 if let Some(task) = self.reader_task.take() {
469 task.abort();
470 let _ = task.await;
471 }
472
473 if let Ok(mut handlers) = self.response_handlers.lock() {
474 for (_, sender) in handlers.drain() {
475 let _ = sender.send(JsonRpcResponse {
476 jsonrpc: "2.0".to_string(),
477 id: Value::Null,
478 result: None,
479 error: Some(super::json_rpc::JsonRpcError {
480 code: error_codes::SERVER_ERROR,
481 message: "Connection closed".to_string(),
482 data: None,
483 }),
484 });
485 }
486 }
487
488 Ok(())
489 }
490}
491
492#[async_trait]
493impl Transport for StdioTransport {
494 #[tracing::instrument(skip(self), fields(name = %self.name()))]
496 async fn initialize(&self) -> Result<()> {
497 self.initialize().await
498 }
499
500 #[tracing::instrument(skip(self), fields(name = %self.name()))]
502 async fn list_tools(&self) -> Result<Vec<Value>> {
503 self.list_tools().await
504 }
505
506 #[tracing::instrument(skip(self, args), fields(name = %self.name(), tool_name = %name))]
508 async fn call_tool(&self, name: &str, args: Value) -> Result<Value> {
509 self.call_tool(name.to_string(), args).await
510 }
511
512 #[tracing::instrument(skip(self), fields(name = %self.name()))]
514 async fn list_resources(&self) -> Result<Vec<Value>> {
515 self.list_resources().await
516 }
517
518 #[tracing::instrument(skip(self), fields(name = %self.name(), uri = %uri))]
520 async fn get_resource(&self, uri: &str) -> Result<Value> {
521 self.get_resource(uri.to_string()).await
522 }
523}