use super::config::McpServerConfig;
use super::server::{McpError, McpResult, McpServer, ServerInfo, ServerState};
use super::types::{
CallToolParams, CallToolResult, ClientInfo, GetPromptResult, InitializeParams,
InitializeResult, JsonRpcRequest, JsonRpcResponse, ListPromptsResult, ListResourcesResult,
ListToolsResult, ReadResourceResult, JSONRPC_VERSION, MCP_VERSION,
};
use async_trait::async_trait;
use serde_json::Value;
use std::collections::HashMap;
use std::process::Stdio;
use std::sync::atomic::{AtomicI64, Ordering};
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, Command};
use tokio::sync::{mpsc, oneshot, Mutex};
/// Bookkeeping for one in-flight JSON-RPC request, stored in
/// `McpStdioServer::pending` under its request id.
struct PendingRequest {
    /// One-shot channel on which the stdout reader task delivers the request's
    /// outcome: the `result` value, or the server's JSON-RPC error.
    sender: oneshot::Sender<McpResult<Value>>,
}
/// MCP client connection to a server launched as a child process and spoken to
/// over the child's stdin/stdout with newline-terminated JSON-RPC messages.
pub struct McpStdioServer {
    /// Launch and connection settings (name, command, args, env, cwd, request timeout).
    config: McpServerConfig,
    /// Current lifecycle state of the connection.
    state: ServerState,
    /// Details reported by the server in the `initialize` handshake; `None` until the
    /// handshake succeeds and again after `disconnect`.
    info: Option<ServerInfo>,
    /// Handle to the spawned server process, if one has been started.
    child: Option<Child>,
    /// Counter used to allocate JSON-RPC request ids (the first id handed out is 1).
    request_id: AtomicI64,
    /// Requests awaiting a response, keyed by request id; shared with the stdout
    /// reader task, which removes entries as responses arrive.
    pending: Arc<Mutex<HashMap<i64, PendingRequest>>>,
    /// Queue feeding the stdin writer task; `None` while disconnected.
    write_tx: Option<mpsc::Sender<String>>,
}
impl McpStdioServer {
    /// Creates a disconnected handle for the server described by `config`.
    ///
    /// No process is spawned until [`McpServer::connect`] is called.
    pub fn new(config: McpServerConfig) -> Self {
        Self {
            config,
            state: ServerState::Disconnected,
            info: None,
            child: None,
            request_id: AtomicI64::new(1),
            pending: Arc::new(Mutex::new(HashMap::new())),
            write_tx: None,
        }
    }

    /// Returns the next JSON-RPC request id (1, 2, 3, ... for a fresh instance).
    fn next_request_id(&self) -> i64 {
        self.request_id.fetch_add(1, Ordering::SeqCst)
    }

    /// Sends a JSON-RPC request to the server and waits for the matching response.
    ///
    /// # Errors
    /// - `McpError::Disconnected` if there is no writer (not connected), the writer task
    ///   has stopped, or the reply channel was dropped before a response arrived.
    /// - `McpError::Timeout` if no response arrives within `config.timeout`.
    /// - Whatever error the reader task delivered (e.g. `McpError::ServerError`).
    /// - JSON serialization errors, converted via `?`.
    async fn request(&self, method: &str, params: Option<Value>) -> McpResult<Value> {
        let write_tx = self.write_tx.as_ref().ok_or(McpError::Disconnected)?;
        let id = self.next_request_id();
        let request = JsonRpcRequest {
            jsonrpc: JSONRPC_VERSION.to_string(),
            id: id.into(),
            method: method.to_string(),
            params,
        };
        let request_json = serde_json::to_string(&request)?;

        let (tx, rx) = oneshot::channel();
        // Register before sending so a fast response cannot arrive ahead of its entry.
        self.pending
            .lock()
            .await
            .insert(id, PendingRequest { sender: tx });

        if write_tx.send(request_json).await.is_err() {
            // The writer task is gone; drop the registration so it does not linger.
            self.pending.lock().await.remove(&id);
            return Err(McpError::Disconnected);
        }

        match tokio::time::timeout(self.config.timeout, rx).await {
            Ok(Ok(result)) => result,
            // The sender was dropped without a reply (e.g. the pending map was cleared).
            Ok(Err(_)) => Err(McpError::Disconnected),
            Err(_) => {
                // Forget the request so a late response is simply discarded.
                self.pending.lock().await.remove(&id);
                Err(McpError::Timeout)
            }
        }
    }

    /// Runs the MCP `initialize` handshake.
    ///
    /// Sends `initialize`, parses the reply into an `InitializeResult`, then sends the
    /// `notifications/initialized` notification (best effort: a send failure is ignored)
    /// and records the server's reported details in `self.info`.
    ///
    /// # Errors
    /// Propagates any error from [`Self::request`] and any JSON (de)serialization error.
    async fn initialize(&mut self) -> McpResult<InitializeResult> {
        let params = InitializeParams {
            protocol_version: MCP_VERSION.to_string(),
            capabilities: Default::default(),
            client_info: ClientInfo::default(),
        };
        let result: Value = self
            .request("initialize", Some(serde_json::to_value(&params)?))
            .await?;
        let init_result: InitializeResult = serde_json::from_value(result)?;

        // A notification has no `id`, so no response is awaited.
        let notification = serde_json::json!({
            "jsonrpc": JSONRPC_VERSION,
            "method": "notifications/initialized"
        });
        if let Some(write_tx) = &self.write_tx {
            let _ = write_tx.send(serde_json::to_string(&notification)?).await;
        }

        self.info = Some(ServerInfo {
            name: self.config.name.clone(),
            server_name: Some(init_result.server_info.name.clone()),
            server_version: init_result.server_info.version.clone(),
            capabilities: init_result.capabilities.clone(),
            state: ServerState::Connected,
        });
        Ok(init_result)
    }
}
#[async_trait]
impl McpServer for McpStdioServer {
fn name(&self) -> &str {
&self.config.name
}
fn config(&self) -> &McpServerConfig {
&self.config
}
fn state(&self) -> ServerState {
self.state
}
fn info(&self) -> Option<&ServerInfo> {
self.info.as_ref()
}
async fn connect(&mut self) -> McpResult<()> {
if self.state == ServerState::Connected {
return Ok(());
}
self.state = ServerState::Connecting;
let command = self
.config
.command
.as_ref()
.ok_or_else(|| McpError::ConnectionFailed("No command specified".to_string()))?;
let mut cmd = Command::new(command);
cmd.args(&self.config.args)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped());
for (key, value) in &self.config.env {
cmd.env(key, value);
}
if let Some(cwd) = &self.config.cwd {
cmd.current_dir(cwd);
}
let mut child = cmd.spawn().map_err(|e| {
self.state = ServerState::Failed;
McpError::ConnectionFailed(format!("Failed to spawn process: {}", e))
})?;
let stdin = child.stdin.take().ok_or_else(|| {
self.state = ServerState::Failed;
McpError::ConnectionFailed("Failed to get stdin".to_string())
})?;
let stdout = child.stdout.take().ok_or_else(|| {
self.state = ServerState::Failed;
McpError::ConnectionFailed("Failed to get stdout".to_string())
})?;
self.child = Some(child);
let (write_tx, mut write_rx) = mpsc::channel::<String>(32);
self.write_tx = Some(write_tx);
let mut stdin = stdin;
tokio::spawn(async move {
while let Some(msg) = write_rx.recv().await {
if stdin.write_all(msg.as_bytes()).await.is_err() {
break;
}
if stdin.write_all(b"\n").await.is_err() {
break;
}
if stdin.flush().await.is_err() {
break;
}
}
});
let pending = self.pending.clone();
tokio::spawn(async move {
let reader = BufReader::new(stdout);
let mut lines = reader.lines();
while let Ok(Some(line)) = lines.next_line().await {
if line.is_empty() {
continue;
}
if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(&line) {
let id = match &response.id {
super::types::RequestId::Number(n) => *n,
super::types::RequestId::String(s) => s.parse().unwrap_or(-1),
};
let mut pending_guard = pending.lock().await;
if let Some(request) = pending_guard.remove(&id) {
let result = if let Some(error) = response.error {
Err(McpError::ServerError {
code: error.code,
message: error.message,
})
} else {
Ok(response.result.unwrap_or(Value::Null))
};
let _ = request.sender.send(result);
}
}
}
});
self.initialize().await?;
self.state = ServerState::Connected;
Ok(())
}
async fn disconnect(&mut self) -> McpResult<()> {
self.state = ServerState::ShuttingDown;
self.write_tx = None;
if let Some(mut child) = self.child.take() {
let _ = child.kill().await;
}
self.state = ServerState::Disconnected;
self.info = None;
Ok(())
}
async fn list_tools(&self) -> McpResult<ListToolsResult> {
if self.state != ServerState::Connected {
return Err(McpError::Disconnected);
}
let result = self.request("tools/list", None).await?;
let list: ListToolsResult = serde_json::from_value(result)?;
Ok(list)
}
async fn call_tool(&self, params: CallToolParams) -> McpResult<CallToolResult> {
if self.state != ServerState::Connected {
return Err(McpError::Disconnected);
}
let result = self
.request("tools/call", Some(serde_json::to_value(¶ms)?))
.await?;
let call_result: CallToolResult = serde_json::from_value(result)?;
Ok(call_result)
}
async fn list_resources(&self) -> McpResult<ListResourcesResult> {
if self.state != ServerState::Connected {
return Err(McpError::Disconnected);
}
let result = self.request("resources/list", None).await?;
let list: ListResourcesResult = serde_json::from_value(result)?;
Ok(list)
}
async fn read_resource(&self, uri: &str) -> McpResult<ReadResourceResult> {
if self.state != ServerState::Connected {
return Err(McpError::Disconnected);
}
let params = serde_json::json!({ "uri": uri });
let result = self.request("resources/read", Some(params)).await?;
let read_result: ReadResourceResult = serde_json::from_value(result)?;
Ok(read_result)
}
async fn list_prompts(&self) -> McpResult<ListPromptsResult> {
if self.state != ServerState::Connected {
return Err(McpError::Disconnected);
}
let result = self.request("prompts/list", None).await?;
let list: ListPromptsResult = serde_json::from_value(result)?;
Ok(list)
}
async fn get_prompt(
&self,
name: &str,
arguments: Option<HashMap<String, String>>,
) -> McpResult<GetPromptResult> {
if self.state != ServerState::Connected {
return Err(McpError::Disconnected);
}
let params = serde_json::json!({
"name": name,
"arguments": arguments
});
let result = self.request("prompts/get", Some(params)).await?;
let prompt_result: GetPromptResult = serde_json::from_value(result)?;
Ok(prompt_result)
}
}
impl std::fmt::Debug for McpStdioServer {
    /// Summarises the server by name, lifecycle state, and launch command, leaving out
    /// the child handle, channels, and pending-request map.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("McpStdioServer");
        out.field("name", &self.config.name);
        out.field("state", &self.state);
        out.field("command", &self.config.command);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an unconnected server named "test" that would launch `echo`.
    fn echo_server() -> McpStdioServer {
        McpStdioServer::new(McpServerConfig::stdio("test", "echo"))
    }

    #[test]
    fn test_stdio_server_creation() {
        let server = echo_server();
        assert_eq!(server.name(), "test");
        assert_eq!(server.state(), ServerState::Disconnected);
        assert!(server.info().is_none());
    }

    #[test]
    fn test_stdio_server_request_id() {
        let server = echo_server();
        let ids: Vec<i64> = (0..3).map(|_| server.next_request_id()).collect();
        assert_eq!(ids, vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn test_stdio_server_disconnect_when_not_connected() {
        let mut server = echo_server();
        assert!(server.disconnect().await.is_ok());
    }
}