use anyhow::{Context, Result};
use std::process::Stdio;
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, ChildStdin, ChildStdout, Command};
use tokio::sync::Mutex;
use crate::types::{JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse};
/// A JSON-RPC transport to an MCP server.
///
/// Each variant wraps a concrete transport and the `impl Transport` methods
/// forward to it.
#[derive(Debug)]
pub enum Transport {
    /// Talks to a spawned child process over its stdin/stdout.
    Stdio(StdioTransport),
    /// Talks to a server over HTTP POST; only compiled with the `http` feature.
    #[cfg(feature = "http")]
    Http(HttpTransport),
}
impl Transport {
    /// Hands `request` to the wrapped transport.
    ///
    /// The stdio variant writes it to the server right away; the HTTP variant
    /// only queues it until the next receive call performs the POST.
    pub async fn send_request(&mut self, request: &JsonRpcRequest) -> Result<()> {
        match *self {
            Self::Stdio(ref mut stdio) => stdio.send_request(request).await,
            #[cfg(feature = "http")]
            Self::Http(ref mut http) => http.send_request(request).await,
        }
    }

    /// Waits for the next message and requires it to be a JSON-RPC response.
    pub async fn receive_response(&mut self) -> Result<JsonRpcResponse> {
        match *self {
            Self::Stdio(ref mut stdio) => stdio.receive_response().await,
            #[cfg(feature = "http")]
            Self::Http(ref mut http) => http.receive_response().await,
        }
    }

    /// Waits for the next message, which may be a response or a notification.
    pub async fn receive_message(&mut self) -> Result<JsonRpcMessage> {
        match *self {
            Self::Stdio(ref mut stdio) => stdio.receive_message().await,
            #[cfg(feature = "http")]
            Self::Http(ref mut http) => http.receive_message().await,
        }
    }

    /// Shuts the transport down.
    pub async fn close(&mut self) -> Result<()> {
        match *self {
            Self::Stdio(ref mut stdio) => stdio.close().await,
            // The HTTP transport owns no child process to terminate, so there
            // is nothing to tear down.
            #[cfg(feature = "http")]
            Self::Http(_) => Ok(()),
        }
    }
}
/// Transport to an MCP server running as a child process, exchanging one
/// JSON-RPC message per line over the child's stdin/stdout.
///
/// Each handle sits behind its own `tokio::sync::Mutex`.
/// NOTE(review): the `Arc` wrappers suggest the handles were meant to be
/// shared (e.g. with a background reader task), but nothing visible here
/// clones them — confirm whether they are still needed.
#[derive(Debug)]
pub struct StdioTransport {
    /// Write end of the child's stdin; serialized requests are written here.
    stdin: Arc<Mutex<ChildStdin>>,
    /// Buffered read end of the child's stdout; messages are read line by line.
    stdout: Arc<Mutex<BufReader<ChildStdout>>>,
    /// The spawned server process, retained so `close` can kill it.
    child: Arc<Mutex<Child>>,
}
impl StdioTransport {
pub async fn new(command: &str, args: &[String]) -> Result<Self> {
let mut child = Command::new(command)
.args(args)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::inherit())
.spawn()
.context(format!("Failed to spawn MCP server: {}", command))?;
let stdin = child.stdin.take().context("Failed to get stdin handle")?;
let stdout = child.stdout.take().context("Failed to get stdout handle")?;
Ok(Self {
stdin: Arc::new(Mutex::new(stdin)),
stdout: Arc::new(Mutex::new(BufReader::new(stdout))),
child: Arc::new(Mutex::new(child)),
})
}
pub async fn send_request(&mut self, request: &JsonRpcRequest) -> Result<()> {
let json =
serde_json::to_string(request).context("Failed to serialize JSON-RPC request")?;
let mut stdin = self.stdin.lock().await;
stdin
.write_all(json.as_bytes())
.await
.context("Failed to write to stdin")?;
stdin
.write_all(b"\n")
.await
.context("Failed to write newline")?;
stdin.flush().await.context("Failed to flush stdin")?;
Ok(())
}
pub async fn receive_response(&mut self) -> Result<JsonRpcResponse> {
let mut stdout = self.stdout.lock().await;
let mut line = String::new();
stdout
.read_line(&mut line)
.await
.context("Failed to read from stdout")?;
if line.is_empty() {
anyhow::bail!("EOF reached, server closed");
}
serde_json::from_str(&line).context("Failed to parse JSON-RPC response")
}
pub async fn receive_message(&mut self) -> Result<JsonRpcMessage> {
let mut stdout = self.stdout.lock().await;
let mut line = String::new();
match stdout.read_line(&mut line).await {
Ok(0) => {
anyhow::bail!("MCP server closed connection (EOF on stdout)");
}
Ok(_) => {
}
Err(e) => {
let error_msg = if e.kind() == std::io::ErrorKind::BrokenPipe {
"MCP server process terminated unexpectedly (broken pipe). The server may have crashed during tool execution. Check stderr output for panic messages.".to_string()
} else if e.kind() == std::io::ErrorKind::UnexpectedEof {
"MCP server process exited unexpectedly (unexpected EOF)".to_string()
} else {
format!(
"Failed to read from MCP server stdout: {} (kind: {:?})",
e,
e.kind()
)
};
anyhow::bail!("{}", error_msg);
}
}
if line.is_empty() {
anyhow::bail!("MCP server returned empty response");
}
let value: serde_json::Value =
serde_json::from_str(&line).context("Failed to parse JSON-RPC message")?;
let has_valid_id = value.get("id").map(|id| !id.is_null()).unwrap_or(false);
if has_valid_id {
let response: JsonRpcResponse =
serde_json::from_value(value).context("Failed to parse as JSON-RPC response")?;
Ok(JsonRpcMessage::Response(response))
} else {
let notification: JsonRpcNotification = serde_json::from_value(value)
.context("Failed to parse as JSON-RPC notification")?;
Ok(JsonRpcMessage::Notification(notification))
}
}
pub async fn close(&mut self) -> Result<()> {
let mut child = self.child.lock().await;
child
.kill()
.await
.context("Failed to kill MCP server process")?;
Ok(())
}
}
/// Transport that POSTs each JSON-RPC request to `<base_url>/mcp` and parses
/// the HTTP response body as a single JSON-RPC message.
///
/// `send_request` only queues the serialized request; the POST happens on the
/// next receive call.
#[cfg(feature = "http")]
#[derive(Debug)]
pub struct HttpTransport {
    /// HTTP client used for every POST.
    client: reqwest::Client,
    /// Full endpoint URL: the base URL with trailing `/` removed, plus `/mcp`.
    mcp_url: String,
    /// Serialized request queued by `send_request`, taken by the next receive.
    pending: Option<String>,
}
#[cfg(feature = "http")]
impl HttpTransport {
    /// Upper bound on how much of a non-2xx response body is echoed into the
    /// error message, so a large HTML error page does not bloat it.
    const MAX_ERROR_BODY_CHARS: usize = 512;

    /// Creates a transport targeting `<base_url>/mcp`.
    ///
    /// Trailing slashes on `base_url` are stripped, so `http://host/` and
    /// `http://host` both yield `http://host/mcp`.
    pub fn new(base_url: impl Into<String>) -> Self {
        let base = base_url.into().trim_end_matches('/').to_string();
        Self {
            client: reqwest::Client::new(),
            mcp_url: format!("{}/mcp", base),
            pending: None,
        }
    }

    /// Serializes `request` and queues it for the next receive call.
    ///
    /// No network I/O happens here. A previously queued request that has not
    /// been received yet is replaced.
    ///
    /// # Errors
    ///
    /// Returns an error if the request cannot be serialized.
    pub async fn send_request(&mut self, request: &JsonRpcRequest) -> Result<()> {
        let body =
            serde_json::to_string(request).context("Failed to serialize JSON-RPC request")?;
        self.pending = Some(body);
        Ok(())
    }

    /// POSTs the queued request and requires the reply to be a response.
    ///
    /// # Errors
    ///
    /// Returns an error if nothing is queued, the HTTP exchange fails, or the
    /// reply is a notification.
    pub async fn receive_response(&mut self) -> Result<JsonRpcResponse> {
        match self.post_and_receive().await? {
            JsonRpcMessage::Response(r) => Ok(r),
            JsonRpcMessage::Notification(_) => {
                anyhow::bail!("Expected JSON-RPC response, got notification")
            }
        }
    }

    /// POSTs the queued request and returns the reply, response or notification.
    ///
    /// # Errors
    ///
    /// Returns an error if nothing is queued or the HTTP exchange fails.
    pub async fn receive_message(&mut self) -> Result<JsonRpcMessage> {
        self.post_and_receive().await
    }

    /// Sends the queued request and parses the response body.
    ///
    /// The queued request is consumed even if the POST fails.
    async fn post_and_receive(&mut self) -> Result<JsonRpcMessage> {
        let body = self
            .pending
            .take()
            .context("receive called before send_request")?;
        // NOTE(review): no `Accept` or `Mcp-Session-Id` header is sent, and an
        // SSE (`text/event-stream`) reply is not handled. Servers that require
        // either will reject the request or be misparsed; confirm against the
        // target server.
        let resp = self
            .client
            .post(&self.mcp_url)
            .header(reqwest::header::CONTENT_TYPE, "application/json")
            .body(body)
            .send()
            .await
            .context("HTTP POST to MCP server failed")?;
        let status = resp.status();
        if !status.is_success() {
            // Best effort: surface (a prefix of) the server's error body to aid
            // debugging. A failure to read it just means no detail.
            let raw = resp.text().await.unwrap_or_default();
            let detail: String = raw
                .trim()
                .chars()
                .take(Self::MAX_ERROR_BODY_CHARS)
                .collect();
            if detail.is_empty() {
                anyhow::bail!("MCP server returned HTTP {}", status);
            }
            anyhow::bail!("MCP server returned HTTP {}: {}", status, detail);
        }
        let text = resp
            .text()
            .await
            .context("Failed to read MCP response body")?;
        // e.g. 202 Accepted carries no body. Report that directly instead of
        // as an opaque JSON parse error.
        if text.trim().is_empty() {
            anyhow::bail!("MCP server returned an empty response body (HTTP {})", status);
        }
        Self::parse_message(&text)
    }

    /// Parses a response body and classifies it.
    ///
    /// A body with a non-null `id` becomes a response; anything else becomes
    /// a notification.
    fn parse_message(text: &str) -> Result<JsonRpcMessage> {
        let value: serde_json::Value =
            serde_json::from_str(text).context("Failed to parse MCP response as JSON")?;
        let has_valid_id = matches!(value.get("id"), Some(id) if !id.is_null());
        if has_valid_id {
            let response: JsonRpcResponse =
                serde_json::from_value(value).context("Failed to parse as JSON-RPC response")?;
            Ok(JsonRpcMessage::Response(response))
        } else {
            let notification: JsonRpcNotification = serde_json::from_value(value)
                .context("Failed to parse as JSON-RPC notification")?;
            Ok(JsonRpcMessage::Notification(notification))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Spawning a real executable succeeds, and the process can be shut down.
    #[tokio::test]
    async fn test_stdio_transport_echo() {
        let mut transport = StdioTransport::new("echo", &["test".to_string()])
            .await
            .expect("spawning `echo` should succeed");
        transport
            .close()
            .await
            .expect("closing the echo transport should succeed");
    }

    /// A command that cannot be spawned surfaces as an error, not a panic.
    #[tokio::test]
    async fn test_stdio_transport_spawn_failure() {
        let result =
            StdioTransport::new("this-command-should-not-exist-mcp-transport-test", &[]).await;
        assert!(result.is_err());
    }

    /// Output that is not JSON is reported as an error by `receive_message`.
    #[tokio::test]
    async fn test_stdio_transport_rejects_non_json_output() {
        let mut transport = StdioTransport::new("echo", &["not json".to_string()])
            .await
            .expect("spawning `echo` should succeed");
        assert!(transport.receive_message().await.is_err());
    }

    #[test]
    fn test_json_rpc_serialization() {
        let request =
            JsonRpcRequest::new(1, "initialize".to_string(), Some(json!({"test": "value"})))
                .unwrap();
        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("jsonrpc"));
        assert!(json.contains("2.0"));
        assert!(json.contains("initialize"));
    }
}