use async_trait::async_trait;
use serde_json::Value;
use std::collections::HashMap;
use std::process::Stdio;
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter};
use tokio::process::{Child, Command};
use tokio::sync::{Mutex, mpsc};
use tokio::time::{Duration, timeout};
use crate::core::error::{McpError, McpResult};
use crate::protocol::types::{JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, error_codes};
use crate::transport::traits::{
ConnectionState, ServerRequestHandler, ServerTransport, Transport, TransportConfig,
};
/// Client-side transport that exchanges newline-delimited JSON-RPC messages
/// with a spawned child process over the child's stdin and stdout.
pub struct StdioClientTransport {
/// Handle to the spawned server process; taken out by `close` and by `Drop`.
child: Option<Child>,
/// Buffered writer over the child's stdin; `None` once `close` has taken it.
stdin_writer: Option<BufWriter<tokio::process::ChildStdin>>,
/// Always `None` as built by `with_config`: the stdout reader is moved into
/// the background message-processing task instead of being stored here.
#[allow(dead_code)]
stdout_reader: Option<BufReader<tokio::process::ChildStdout>>,
/// Receiving end of the channel the background task pushes notifications into.
notification_receiver: Option<mpsc::UnboundedReceiver<JsonRpcNotification>>,
/// Requests awaiting a response, keyed by request id; shared with the
/// background task, which completes the matching oneshot sender.
pending_requests: Arc<Mutex<HashMap<Value, tokio::sync::oneshot::Sender<JsonRpcResponse>>>>,
/// Transport settings; `read_timeout_ms` bounds how long `send_request` waits.
config: TransportConfig,
/// Connection state reported by `is_connected` and `connection_info`.
state: ConnectionState,
}
impl StdioClientTransport {
pub async fn new<S: AsRef<str>>(command: S, args: Vec<S>) -> McpResult<Self> {
Self::with_config(command, args, TransportConfig::default()).await
}
pub async fn with_config<S: AsRef<str>>(
command: S,
args: Vec<S>,
config: TransportConfig,
) -> McpResult<Self> {
let command_str = command.as_ref();
let args_str: Vec<&str> = args.iter().map(|s| s.as_ref()).collect();
tracing::debug!("Starting MCP server: {} {:?}", command_str, args_str);
let mut child = Command::new(command_str)
.args(&args_str)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()
.map_err(|e| McpError::transport(format!("Failed to start server process: {e}")))?;
let stdin = child
.stdin
.take()
.ok_or_else(|| McpError::transport("Failed to get stdin handle"))?;
let stdout = child
.stdout
.take()
.ok_or_else(|| McpError::transport("Failed to get stdout handle"))?;
let stdin_writer = BufWriter::new(stdin);
let stdout_reader = BufReader::new(stdout);
let (notification_sender, notification_receiver) = mpsc::unbounded_channel();
let pending_requests = Arc::new(Mutex::new(HashMap::new()));
let reader_pending_requests = pending_requests.clone();
let reader = stdout_reader;
tokio::spawn(async move {
Self::message_processor(reader, notification_sender, reader_pending_requests).await;
});
Ok(Self {
child: Some(child),
stdin_writer: Some(stdin_writer),
stdout_reader: None, notification_receiver: Some(notification_receiver),
pending_requests,
config,
state: ConnectionState::Connected,
})
}
async fn message_processor(
mut reader: BufReader<tokio::process::ChildStdout>,
notification_sender: mpsc::UnboundedSender<JsonRpcNotification>,
pending_requests: Arc<Mutex<HashMap<Value, tokio::sync::oneshot::Sender<JsonRpcResponse>>>>,
) {
let mut line = String::new();
loop {
line.clear();
match reader.read_line(&mut line).await {
Ok(0) => {
tracing::debug!("STDIO reader reached EOF");
break;
}
Ok(_) => {
let line = line.trim();
if line.is_empty() {
continue;
}
tracing::trace!("Received: {}", line);
if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(line) {
let mut pending = pending_requests.lock().await;
match pending.remove(&response.id) {
Some(sender) => {
let _ = sender.send(response);
}
_ => {
tracing::warn!(
"Received response for unknown request ID: {:?}",
response.id
);
}
}
}
else if let Ok(notification) =
serde_json::from_str::<JsonRpcNotification>(line)
{
if notification_sender.send(notification).is_err() {
tracing::debug!("Notification receiver dropped");
break;
}
} else {
tracing::warn!("Failed to parse message: {}", line);
}
}
Err(e) => {
tracing::error!("Error reading from stdout: {}", e);
break;
}
}
}
}
}
#[async_trait]
impl Transport for StdioClientTransport {
    /// Writes `request` as one JSON line to the child's stdin and waits for
    /// the response with the same id.
    ///
    /// The wait is bounded by `config.read_timeout_ms` (60 000 ms when unset).
    ///
    /// # Errors
    /// Returns an error if the transport has no stdin writer, the request
    /// cannot be serialized, the write or flush fails, the wait times out, or
    /// the response channel is closed. On any of these failures the request is
    /// removed from the pending map, so failed requests do not accumulate.
    async fn send_request(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
        let writer = self
            .stdin_writer
            .as_mut()
            .ok_or_else(|| McpError::transport("Transport not connected"))?;

        // Serialize before registering the waiter so that a serialization
        // failure cannot leave a stale entry behind.
        let request_line = serde_json::to_string(&request).map_err(McpError::serialization)?;

        let (sender, receiver) = tokio::sync::oneshot::channel();
        self.pending_requests
            .lock()
            .await
            .insert(request.id.clone(), sender);

        tracing::trace!("Sending: {}", request_line);
        let timeout_duration = Duration::from_millis(self.config.read_timeout_ms.unwrap_or(60_000));

        let outcome: McpResult<JsonRpcResponse> = async {
            writer
                .write_all(request_line.as_bytes())
                .await
                .map_err(|e| McpError::transport(format!("Failed to write request: {e}")))?;
            writer
                .write_all(b"\n")
                .await
                .map_err(|e| McpError::transport(format!("Failed to write newline: {e}")))?;
            writer
                .flush()
                .await
                .map_err(|e| McpError::transport(format!("Failed to flush: {e}")))?;

            timeout(timeout_duration, receiver)
                .await
                .map_err(|_| McpError::timeout("Request timeout"))?
                .map_err(|_| McpError::transport("Response channel closed"))
        }
        .await;

        if outcome.is_err() {
            // The response will never be consumed; drop the registration.
            self.pending_requests.lock().await.remove(&request.id);
        }
        outcome
    }

    /// Writes `notification` as one JSON line to the child's stdin and flushes.
    ///
    /// # Errors
    /// Returns an error if the transport has no stdin writer, serialization
    /// fails, or the write or flush fails.
    async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()> {
        let writer = self
            .stdin_writer
            .as_mut()
            .ok_or_else(|| McpError::transport("Transport not connected"))?;
        let notification_line =
            serde_json::to_string(&notification).map_err(McpError::serialization)?;
        tracing::trace!("Sending notification: {}", notification_line);
        writer
            .write_all(notification_line.as_bytes())
            .await
            .map_err(|e| McpError::transport(format!("Failed to write notification: {e}")))?;
        writer
            .write_all(b"\n")
            .await
            .map_err(|e| McpError::transport(format!("Failed to write newline: {e}")))?;
        writer
            .flush()
            .await
            .map_err(|e| McpError::transport(format!("Failed to flush: {e}")))?;
        Ok(())
    }

    /// Returns the next queued notification without waiting.
    ///
    /// Yields `Ok(None)` when nothing is queued or when there is no receiver.
    ///
    /// # Errors
    /// Returns a transport error once the notification channel is disconnected.
    async fn receive_notification(&mut self) -> McpResult<Option<JsonRpcNotification>> {
        if let Some(ref mut receiver) = self.notification_receiver {
            match receiver.try_recv() {
                Ok(notification) => Ok(Some(notification)),
                Err(mpsc::error::TryRecvError::Empty) => Ok(None),
                Err(mpsc::error::TryRecvError::Disconnected) => {
                    Err(McpError::transport("Notification channel disconnected"))
                }
            }
        } else {
            Ok(None)
        }
    }

    /// Shuts down the child's stdin, then waits up to five seconds for the
    /// process to exit and kills it if it has not.
    ///
    /// Failures along the way are logged rather than returned; the state ends
    /// as `Disconnected`.
    async fn close(&mut self) -> McpResult<()> {
        tracing::debug!("Closing STDIO transport");
        self.state = ConnectionState::Closing;
        if let Some(mut writer) = self.stdin_writer.take() {
            // Best effort: the child may already have closed its end.
            let _ = writer.shutdown().await;
        }
        if let Some(mut child) = self.child.take() {
            match timeout(Duration::from_secs(5), child.wait()).await {
                Ok(Ok(status)) => {
                    tracing::debug!("Server process exited with status: {}", status);
                }
                Ok(Err(e)) => {
                    tracing::warn!("Error waiting for server process: {}", e);
                }
                Err(_) => {
                    tracing::warn!("Timeout waiting for server process, killing it");
                    let _ = child.kill().await;
                }
            }
        }
        self.state = ConnectionState::Disconnected;
        Ok(())
    }

    /// Returns `true` only while the state is `ConnectionState::Connected`.
    fn is_connected(&self) -> bool {
        matches!(self.state, ConnectionState::Connected)
    }

    /// Returns a short description that includes the current state.
    fn connection_info(&self) -> String {
        let state = &self.state;
        format!("STDIO transport (state: {state:?})")
    }
}
/// Server-side transport that reads newline-delimited JSON-RPC requests from
/// this process's stdin and writes responses to its stdout.
pub struct StdioServerTransport {
/// Buffered reader over stdin; taken out by `start`.
stdin_reader: Option<BufReader<tokio::io::Stdin>>,
/// Buffered writer over stdout; taken out by `start`, and used by
/// `send_notification` while it is still present.
stdout_writer: Option<BufWriter<tokio::io::Stdout>>,
/// Transport settings; stored but not read by the code in this file.
#[allow(dead_code)]
config: TransportConfig,
/// Set to `true` by `start` and to `false` by `stop`.
running: bool,
/// Handler invoked for each parsed request; installed by `set_request_handler`.
request_handler: Option<ServerRequestHandler>,
}
impl StdioServerTransport {
    /// Builds a server transport with `TransportConfig::default()`.
    pub fn new() -> Self {
        let defaults = TransportConfig::default();
        Self::with_config(defaults)
    }

    /// Builds a server transport over this process's stdin and stdout using
    /// `config`. The transport starts out not running and with no request
    /// handler installed.
    pub fn with_config(config: TransportConfig) -> Self {
        Self {
            stdin_reader: Some(BufReader::new(tokio::io::stdin())),
            stdout_writer: Some(BufWriter::new(tokio::io::stdout())),
            config,
            running: false,
            request_handler: None,
        }
    }
}
#[async_trait]
impl ServerTransport for StdioServerTransport {
async fn start(&mut self) -> McpResult<()> {
tracing::debug!("Starting STDIO server transport");
let mut reader = self
.stdin_reader
.take()
.ok_or_else(|| McpError::transport("STDIN reader already taken"))?;
let mut writer = self
.stdout_writer
.take()
.ok_or_else(|| McpError::transport("STDOUT writer already taken"))?;
self.running = true;
let request_handler = self.request_handler.clone();
let mut line = String::new();
while self.running {
line.clear();
match reader.read_line(&mut line).await {
Ok(0) => {
tracing::debug!("STDIN closed, stopping server");
break;
}
Ok(_) => {
let line = line.trim();
if line.is_empty() {
continue;
}
tracing::trace!("Received: {}", line);
match serde_json::from_str::<JsonRpcRequest>(line) {
Ok(request) => {
let response_result = if let Some(ref handler) = request_handler {
handler(request.clone()).await
} else {
Err(McpError::protocol(format!(
"Method '{}' not found",
request.method
)))
};
let response_or_error = match response_result {
Ok(response) => serde_json::to_string(&response),
Err(error) => {
let json_rpc_error = crate::protocol::types::JsonRpcError {
jsonrpc: "2.0".to_string(),
id: request.id,
error: crate::protocol::types::ErrorObject {
code: match error {
McpError::Protocol(ref msg) if msg.contains("not found") => {
error_codes::METHOD_NOT_FOUND
}
_ => crate::protocol::types::error_codes::INTERNAL_ERROR,
},
message: error.to_string(),
data: None,
},
};
serde_json::to_string(&json_rpc_error)
}
};
let response_line =
response_or_error.map_err(McpError::serialization)?;
tracing::trace!("Sending: {}", response_line);
writer
.write_all(response_line.as_bytes())
.await
.map_err(|e| {
McpError::transport(format!("Failed to write response: {e}"))
})?;
writer.write_all(b"\n").await.map_err(|e| {
McpError::transport(format!("Failed to write newline: {e}"))
})?;
writer.flush().await.map_err(|e| {
McpError::transport(format!("Failed to flush: {e}"))
})?;
}
Err(e) => {
tracing::warn!("Failed to parse request: {} - Error: {}", line, e);
}
}
}
Err(e) => {
tracing::error!("Error reading from stdin: {}", e);
return Err(McpError::io(e));
}
}
}
Ok(())
}
fn set_request_handler(&mut self, handler: ServerRequestHandler) {
self.request_handler = Some(handler);
}
async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()> {
let writer = self
.stdout_writer
.as_mut()
.ok_or_else(|| McpError::transport("STDOUT writer not available"))?;
let notification_line =
serde_json::to_string(¬ification).map_err(McpError::serialization)?;
tracing::trace!("Sending notification: {}", notification_line);
writer
.write_all(notification_line.as_bytes())
.await
.map_err(|e| McpError::transport(format!("Failed to write notification: {e}")))?;
writer
.write_all(b"\n")
.await
.map_err(|e| McpError::transport(format!("Failed to write newline: {e}")))?;
writer
.flush()
.await
.map_err(|e| McpError::transport(format!("Failed to flush: {e}")))?;
Ok(())
}
async fn stop(&mut self) -> McpResult<()> {
tracing::debug!("Stopping STDIO server transport");
self.running = false;
Ok(())
}
fn is_running(&self) -> bool {
self.running
}
fn server_info(&self) -> String {
format!("STDIO server transport (running: {})", self.running)
}
}
impl StdioServerTransport {
    /// Test-mode request entry point: never dispatches anything and always
    /// answers with a protocol error that names the requested method.
    pub async fn handle_request(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
        let message = format!("Method '{}' not found (test mode)", request.method);
        Err(McpError::protocol(message))
    }
}
impl Default for StdioServerTransport {
fn default() -> Self {
Self::new()
}
}
impl Drop for StdioClientTransport {
    /// Requests termination of the child process if the transport still owns
    /// it (that is, `close` has not already taken it).
    fn drop(&mut self) {
        let Some(mut process) = self.child.take() else {
            return;
        };
        // Best effort: `drop` has no way to report a failure.
        let _ = process.start_kill();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A freshly built server is idle and still owns both stdio handles.
    #[test]
    fn test_stdio_server_creation() {
        let server = StdioServerTransport::new();
        assert!(!server.is_running());
        assert!(server.stdin_reader.is_some());
        assert!(server.stdout_writer.is_some());
    }

    /// The configuration passed to `with_config` is stored unchanged.
    #[test]
    fn test_stdio_server_with_config() {
        let custom = TransportConfig {
            read_timeout_ms: Some(30_000),
            ..Default::default()
        };
        let server = StdioServerTransport::with_config(custom);
        assert_eq!(server.config.read_timeout_ms, Some(30_000));
    }

    /// `handle_request` rejects any method with a protocol error naming it.
    #[tokio::test]
    async fn test_stdio_server_handle_request() {
        let mut server = StdioServerTransport::new();
        let probe = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: json!(1),
            method: "unknown_method".to_string(),
            params: None,
        };
        let outcome = server.handle_request(probe).await;
        assert!(outcome.is_err());
        let McpError::Protocol(msg) = outcome.unwrap_err() else {
            panic!("Expected Protocol error");
        };
        assert!(msg.contains("unknown_method"));
    }
}