use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
pub use super::{IoError, MutToolHandler, ToolHandler};
use crate::{McpServer, OutgoingMessage, Output, ToolRegistry};
/// Runs `server` over the process's standard input and standard output.
///
/// Convenience wrapper around [`run_on`]: requests are read line by line from a
/// buffered `stdin`, and every outgoing message is written to `stdout`.
///
/// # Errors
///
/// Propagates whatever error [`run_on`] returns.
pub async fn run_stdio<R, H>(server: McpServer<R>, handler: H) -> Result<(), IoError>
where
    R: ToolRegistry,
    H: MutToolHandler<R>,
{
    run_on(
        BufReader::new(tokio::io::stdin()),
        tokio::io::stdout(),
        server,
        handler,
    )
    .await
}
/// Drives an [`McpServer`] over an arbitrary line-delimited transport.
///
/// Each line read from `input` is parsed as one message and handed to `server`.
/// Messages the server emits directly, and the responses built from tool calls
/// dispatched to `handler`, are written to `output` via `write_message`.
///
/// Blank (whitespace-only) lines carry no message and are skipped. They no
/// longer end the session with a parse error.
///
/// Returns `Ok(())` once `input` reaches end-of-file.
///
/// # Errors
///
/// - [`IoError::Io`] if reading from `input` or writing to `output` fails.
/// - [`IoError::Parse`] if a non-blank line cannot be parsed as a message.
/// - [`IoError::Serialize`] if an outgoing message cannot be serialized.
/// - [`IoError::Protocol`] if the server reports a protocol error.
pub async fn run_on<R, H, I, O>(
    mut input: I,
    mut output: O,
    mut server: McpServer<R>,
    mut handler: H,
) -> Result<(), IoError>
where
    R: ToolRegistry,
    H: MutToolHandler<R>,
    I: AsyncBufReadExt + Unpin,
    O: AsyncWriteExt + Unpin,
{
    // A single buffer is reused for every line to avoid a per-message allocation.
    let mut line = String::new();
    loop {
        line.clear();
        let bytes = input.read_line(&mut line).await.map_err(IoError::Io)?;
        if bytes == 0 {
            // `read_line` reports 0 bytes only at end-of-file on `input`.
            break;
        }
        // Strips the `\n` / `\r\n` delimiter along with any other trailing whitespace.
        let text = line.trim_end();
        if text.is_empty() {
            // An empty line holds no message. Feeding it to the parser would
            // abort the whole session, so ignore it instead.
            continue;
        }
        let msg = crate::parse_line(text).map_err(IoError::Parse)?;
        match server.handle(msg) {
            Output::Send(response) => {
                write_message(&mut output, response).await?;
            }
            Output::ToolCall { tool, responder } => {
                // NOTE(review): the first handler argument is always `None` here;
                // its meaning is defined by `MutToolHandler` — confirm.
                let response = responder.respond(handler.handle(None, tool).await);
                write_message(&mut output, response).await?;
            }
            Output::ProtocolError(e) => {
                return Err(IoError::Protocol(e));
            }
            Output::None => {}
        }
    }
    Ok(())
}
/// Writes `msg` to `w` as a single newline-terminated line of JSON.
///
/// The message is fully encoded in memory first and then written with one
/// `write_all` call. The writer is flushed after every message so nothing stays
/// buffered.
///
/// # Errors
///
/// Returns [`IoError::Serialize`] if encoding fails, or [`IoError::Io`] if the
/// write or the flush fails.
async fn write_message(
    w: &mut (impl AsyncWriteExt + Unpin),
    msg: OutgoingMessage,
) -> Result<(), IoError> {
    let mut frame: Vec<u8> = Vec::new();
    serde_json::to_writer(&mut frame, msg.as_inner()).map_err(IoError::Serialize)?;
    frame.push(b'\n');

    w.write_all(&frame).await.map_err(IoError::Io)?;
    w.flush().await.map_err(IoError::Io)?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use super::{IoError, run_on};
    use crate::{McpServer, NoTools};

    /// Builds a minimal server with no tools registered.
    fn test_server() -> McpServer<NoTools> {
        McpServer::builder().name("test").version("1.0").build()
    }

    /// Feeds `input` through `run_on` with a no-op tool handler and returns the
    /// session outcome together with every byte written to the output.
    async fn drive(input: &str) -> (Result<(), IoError>, Vec<u8>) {
        let mut written = Vec::new();
        let outcome = run_on(
            Cursor::new(input),
            &mut written,
            test_server(),
            |_, _: NoTools| async { Ok::<_, std::convert::Infallible>(String::new()) },
        )
        .await;
        (outcome, written)
    }

    #[tokio::test]
    async fn full_session() {
        let script = r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-06-18","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}}
{"jsonrpc":"2.0","method":"notifications/initialized"}
{"jsonrpc":"2.0","id":2,"method":"ping"}
"#;
        let (outcome, written) = drive(script).await;
        assert!(outcome.is_ok());

        let text = String::from_utf8(written).expect("valid utf8");
        let replies: Vec<&str> = text.lines().collect();
        // The notification produces no reply, so only two responses are expected.
        assert_eq!(replies.len(), 2);

        let initialize: serde_json::Value = serde_json::from_str(replies[0]).expect("valid json");
        assert_eq!(initialize["id"], 1);
        assert!(initialize["result"]["protocolVersion"].is_string());

        let pong: serde_json::Value = serde_json::from_str(replies[1]).expect("valid json");
        assert_eq!(pong["id"], 2);
    }

    #[tokio::test]
    async fn parse_error() {
        let (outcome, _) = drive("not valid json\n").await;
        assert!(matches!(outcome, Err(IoError::Parse(_))));
    }

    #[tokio::test]
    async fn protocol_error_on_unexpected_message() {
        // The session opens with a notification instead of `initialize`.
        let script = r#"{"jsonrpc":"2.0","method":"notifications/initialized"}
"#;
        let (outcome, _) = drive(script).await;
        assert!(matches!(outcome, Err(IoError::Protocol(_))));
    }
}