use std::collections::HashMap;
use std::process::Stdio;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use futures::future::BoxFuture;
use serde_json::Value;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, Command};
use tokio::sync::{Mutex, oneshot};
use super::super::types::{JsonRpcNotification, JsonRpcRequest, JsonRpcResponse};
use echo_core::error::{McpError, ReactError, Result};
use super::McpTransport;
/// Requests awaiting a response: request id -> one-shot sender that completes the
/// matching `send` call. Shared between the transport and its stdout reader task.
type PendingMap = Arc<Mutex<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>>;
/// MCP transport speaking JSON-RPC over a child process's stdin/stdout,
/// one JSON message per line.
pub struct StdioTransport {
/// Child's stdin; requests and notifications are written here as newline-terminated JSON.
stdin: Arc<Mutex<tokio::process::ChildStdin>>,
/// In-flight requests keyed by the id assigned in `send`.
pending: PendingMap,
/// Counter used to assign request ids (initialized to 1 in `new`).
next_id: Arc<AtomicU64>,
/// Handle to the child process; used by `close` and `Drop` to kill and reap it.
_child: Arc<Mutex<Child>>,
}
impl StdioTransport {
pub async fn new(command: &str, args: &[String], env: &[(String, String)]) -> Result<Self> {
let mut cmd = Command::new(command);
cmd.args(args);
for (k, v) in env {
cmd.env(k, v);
}
cmd.stdin(Stdio::piped());
cmd.stdout(Stdio::piped());
cmd.stderr(Stdio::piped());
let mut child = cmd.spawn().map_err(|e| {
ReactError::Mcp(McpError::ConnectionFailed(format!(
"无法启动 MCP 服务端 '{}': {}",
command, e
)))
})?;
let stdin = child.stdin.take().ok_or_else(|| {
ReactError::Mcp(McpError::ConnectionFailed(
"无法获取子进程 stdin".to_string(),
))
})?;
let stdout = child.stdout.take().ok_or_else(|| {
ReactError::Mcp(McpError::ConnectionFailed(
"无法获取子进程 stdout".to_string(),
))
})?;
let stderr = child.stderr.take();
let pending: PendingMap = Arc::new(Mutex::new(HashMap::new()));
let pending_clone = pending.clone();
tokio::spawn(async move {
let reader = BufReader::new(stdout);
let mut lines = reader.lines();
loop {
match lines.next_line().await {
Ok(Some(line)) => {
let line = line.trim().to_string();
if line.is_empty() {
continue;
}
let json: Value = match serde_json::from_str(&line) {
Ok(v) => v,
Err(e) => {
tracing::warn!(
"MCP stdio: 解析 stdout 行失败: {} | 原始内容: {}",
e,
line
);
continue;
}
};
if let Some(id) = json.get("id").and_then(|id| id.as_u64()) {
match serde_json::from_value::<JsonRpcResponse>(json) {
Ok(response) => {
let mut map = pending_clone.lock().await;
if let Some(tx) = map.remove(&id) {
let _ = tx.send(response);
}
}
Err(e) => {
tracing::warn!("MCP stdio: 解析响应失败: {}", e);
}
}
} else {
let method = json
.get("method")
.and_then(|m| m.as_str())
.unwrap_or("unknown");
tracing::debug!("MCP stdio: 收到服务端通知: {}", method);
}
}
Ok(None) => {
tracing::debug!("MCP stdio: stdout 已关闭");
let mut map = pending_clone.lock().await;
map.clear();
break;
}
Err(e) => {
tracing::warn!("MCP stdio: 读取 stdout 出错: {}", e);
break;
}
}
}
});
if let Some(stderr) = stderr {
tokio::spawn(async move {
let reader = BufReader::new(stderr);
let mut lines = reader.lines();
while let Ok(Some(line)) = lines.next_line().await {
let line = line.trim().to_string();
if !line.is_empty() {
tracing::debug!("MCP stderr: {}", line);
}
}
});
}
Ok(Self {
stdin: Arc::new(Mutex::new(stdin)),
pending,
next_id: Arc::new(AtomicU64::new(1)),
_child: Arc::new(Mutex::new(child)),
})
}
}
impl McpTransport for StdioTransport {
    /// Sends a JSON-RPC request and waits up to 120 s for its response.
    ///
    /// A fresh numeric id from `next_id` overwrites any id already on `request`.
    ///
    /// # Errors
    /// * `ProtocolError` if serialization fails, writing/flushing stdin fails, or no
    ///   response arrives before the timeout.
    /// * `TransportClosed` if the reader task dropped the waiter (stdout closed/failed).
    fn send(&self, request: JsonRpcRequest) -> BoxFuture<'_, Result<JsonRpcResponse>> {
        Box::pin(async move {
            let mut request = request;
            let id = self.next_id.fetch_add(1, Ordering::SeqCst);
            request.id = Some(Value::Number(id.into()));

            // Serialize before registering the waiter so a serialization failure
            // cannot leave an orphaned entry in `pending`.
            let line = serde_json::to_string(&request)
                .map_err(|e| ReactError::Mcp(McpError::ProtocolError(e.to_string())))?
                + "\n";

            // Register before writing so a fast response cannot race past us.
            let (tx, rx) = oneshot::channel::<JsonRpcResponse>();
            self.pending.lock().await.insert(id, tx);

            let written = {
                let mut stdin = self.stdin.lock().await;
                match stdin.write_all(line.as_bytes()).await {
                    Ok(()) => stdin.flush().await.map_err(|e| {
                        ReactError::Mcp(McpError::ProtocolError(format!("flush stdin 失败: {}", e)))
                    }),
                    Err(e) => Err(ReactError::Mcp(McpError::ProtocolError(format!(
                        "写入 stdin 失败: {}",
                        e
                    )))),
                }
            };
            if let Err(e) = written {
                // The request may never have reached the server; drop the waiter
                // instead of leaking it in `pending`.
                self.pending.lock().await.remove(&id);
                return Err(e);
            }

            const RESPONSE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(120);
            match tokio::time::timeout(RESPONSE_TIMEOUT, rx).await {
                Ok(Ok(response)) => Ok(response),
                Ok(Err(_)) => {
                    self.pending.lock().await.remove(&id);
                    Err(ReactError::Mcp(McpError::TransportClosed))
                }
                Err(_) => {
                    self.pending.lock().await.remove(&id);
                    Err(ReactError::Mcp(McpError::ProtocolError(format!(
                        "等待响应超时 (id={}, 超时 {:?})",
                        id, RESPONSE_TIMEOUT
                    ))))
                }
            }
        })
    }

    /// Writes a JSON-RPC notification (no response expected) as one line to stdin.
    ///
    /// # Errors
    /// `ProtocolError` if serialization, writing, or flushing fails.
    fn notify(&self, notification: JsonRpcNotification) -> BoxFuture<'_, Result<()>> {
        Box::pin(async move {
            let line = serde_json::to_string(&notification)
                .map_err(|e| ReactError::Mcp(McpError::ProtocolError(e.to_string())))?
                + "\n";
            let mut stdin = self.stdin.lock().await;
            stdin.write_all(line.as_bytes()).await.map_err(|e| {
                ReactError::Mcp(McpError::ProtocolError(format!("写入通知失败: {}", e)))
            })?;
            stdin.flush().await.map_err(|e| {
                ReactError::Mcp(McpError::ProtocolError(format!("flush 通知失败: {}", e)))
            })?;
            Ok(())
        })
    }

    /// Kills the child process (if still running) and waits for it to exit.
    ///
    /// Kill failures are logged, not returned.
    fn close(&self) -> BoxFuture<'_, ()> {
        Box::pin(async move {
            let mut child = self._child.lock().await;
            // If the process has already exited (e.g. `close` called twice), a kill
            // would fail and log a spurious warning, so skip it.
            if !matches!(child.try_wait(), Ok(Some(_))) {
                if let Err(e) = child.kill().await {
                    tracing::warn!("MCP stdio: 终止子进程失败: {}", e);
                }
                let _ = child.wait().await;
            }
            tracing::debug!("MCP stdio: 子进程已退出");
        })
    }

    /// This transport does not expose server notifications; they are only logged.
    fn notification_rx(&self) -> Option<Arc<dyn super::super::types::JsonRpcNotificationReceiver>> {
        None
    }
}
impl Drop for StdioTransport {
fn drop(&mut self) {
let child = self._child.clone();
tokio::spawn(async move {
let mut child = child.lock().await;
if let Err(e) = child.kill().await {
tracing::debug!("MCP stdio drop: kill 失败: {}", e);
}
let _ = child.wait().await;
});
}
}