use async_trait::async_trait;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, DuplexStream};
/// Asynchronous, line-oriented I/O interface.
///
/// Both implementations in this file (`StdIo` and `MockIo`) read one line at a
/// time, hand it back with surrounding whitespace trimmed, and append `\n` to
/// every line they write.
#[async_trait]
pub trait McpIo: Send + Sync {
/// Reads the next line of input.
///
/// The implementations in this file return `Ok(None)` when the underlying
/// reader reports zero bytes read, and `Ok(Some(line))` with the trimmed
/// line otherwise. Errors from the underlying reader are propagated.
async fn read_line(&mut self) -> std::io::Result<Option<String>>;
/// Writes `line` to the output.
///
/// The implementations in this file follow `line` with a single `\n` and do
/// not flush; call `flush` to push the data out.
async fn write_line(&mut self, line: &str) -> std::io::Result<()>;
/// Flushes the underlying writer.
async fn flush(&mut self) -> std::io::Result<()>;
}
/// Line-oriented I/O over the process's standard input and standard output.
pub struct StdIo {
/// Buffered wrapper around tokio's stdin handle; lines are read from here.
reader: BufReader<tokio::io::Stdin>,
/// Tokio stdout handle that written lines and flushes go to.
writer: tokio::io::Stdout,
/// Scratch buffer that is cleared and refilled on every line read.
buffer: String,
}
impl StdIo {
    /// Builds a `StdIo` bound to the process's stdin and stdout.
    ///
    /// Stdin is wrapped in a `BufReader` so that whole lines can be read from
    /// it; the internal line buffer starts out empty.
    pub fn new() -> Self {
        // Acquire the handles in the same order as the fields are declared.
        let stdin = tokio::io::stdin();
        let stdout = tokio::io::stdout();

        Self {
            reader: BufReader::new(stdin),
            writer: stdout,
            buffer: String::new(),
        }
    }
}
impl Default for StdIo {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl McpIo for StdIo {
async fn read_line(&mut self) -> std::io::Result<Option<String>> {
self.buffer.clear();
let bytes_read = self.reader.read_line(&mut self.buffer).await?;
if bytes_read == 0 {
Ok(None) } else {
Ok(Some(self.buffer.trim().to_string()))
}
}
async fn write_line(&mut self, line: &str) -> std::io::Result<()> {
self.writer.write_all(line.as_bytes()).await?;
self.writer.write_all(b"\n").await?;
Ok(())
}
async fn flush(&mut self) -> std::io::Result<()> {
self.writer.flush().await
}
}
/// Line-oriented I/O over one end of an in-memory `DuplexStream`.
///
/// The unit tests in this file use it in connected pairs (see `create_pair`).
pub struct MockIo {
/// Buffered read half of the split duplex stream; lines are read from here.
reader: BufReader<tokio::io::ReadHalf<DuplexStream>>,
/// Write half of the split duplex stream; written lines go here.
writer: tokio::io::WriteHalf<DuplexStream>,
/// Scratch buffer that is cleared and refilled on every line read.
buffer: String,
}
impl MockIo {
    /// Wraps one end of a duplex stream.
    ///
    /// The stream is split into a read half (wrapped in a `BufReader`) and a
    /// write half; the internal line buffer starts out empty.
    pub fn new(stream: DuplexStream) -> Self {
        let (rx, tx) = tokio::io::split(stream);
        Self {
            reader: BufReader::new(rx),
            writer: tx,
            buffer: String::new(),
        }
    }

    /// Creates two connected `MockIo` endpoints backed by `tokio::io::duplex`.
    ///
    /// `buffer_size` is passed straight through to `tokio::io::duplex`.
    /// The returned tuple is `(server_io, client_io)`: lines written to one
    /// endpoint are read from the other.
    pub fn create_pair(buffer_size: usize) -> (Self, Self) {
        let (client_end, server_end) = tokio::io::duplex(buffer_size);
        // Server endpoint first, matching the order of the returned tuple.
        (Self::new(server_end), Self::new(client_end))
    }
}
#[async_trait]
impl McpIo for MockIo {
async fn read_line(&mut self) -> std::io::Result<Option<String>> {
self.buffer.clear();
let bytes_read = self.reader.read_line(&mut self.buffer).await?;
if bytes_read == 0 {
Ok(None) } else {
Ok(Some(self.buffer.trim().to_string()))
}
}
async fn write_line(&mut self, line: &str) -> std::io::Result<()> {
self.writer.write_all(line.as_bytes()).await?;
self.writer.write_all(b"\n").await?;
Ok(())
}
async fn flush(&mut self) -> std::io::Result<()> {
self.writer.flush().await
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A line written on either endpoint is readable on the other.
    #[tokio::test]
    async fn test_mock_io_bidirectional() {
        let (mut server_io, mut client_io) = MockIo::create_pair(1024);

        // client -> server
        client_io.write_line("Hello from client").await.unwrap();
        client_io.flush().await.unwrap();
        let at_server = server_io.read_line().await.unwrap();
        assert_eq!(at_server.as_deref(), Some("Hello from client"));

        // server -> client
        server_io.write_line("Hello from server").await.unwrap();
        server_io.flush().await.unwrap();
        let at_client = client_io.read_line().await.unwrap();
        assert_eq!(at_client.as_deref(), Some("Hello from server"));
    }

    /// Several lines written back-to-back are read back one at a time, in order.
    #[tokio::test]
    async fn test_mock_io_multiple_lines() {
        let (mut server_io, mut client_io) = MockIo::create_pair(1024);
        let lines = ["line1", "line2", "line3"];

        for line in lines {
            client_io.write_line(line).await.unwrap();
        }
        client_io.flush().await.unwrap();

        for expected in lines {
            let got = server_io.read_line().await.unwrap();
            assert_eq!(got.as_deref(), Some(expected));
        }
    }

    /// An empty line arrives as `Some("")`, not as end of input.
    #[tokio::test]
    async fn test_mock_io_empty_lines() {
        let (mut server_io, mut client_io) = MockIo::create_pair(1024);

        client_io.write_line("").await.unwrap();
        client_io.flush().await.unwrap();

        let got = server_io.read_line().await.unwrap();
        assert_eq!(got, Some(String::new()));
    }

    /// Dropping the peer endpoint makes the next read return `None`.
    #[tokio::test]
    async fn test_mock_io_eof() {
        let (mut server_io, client_io) = MockIo::create_pair(1024);
        drop(client_io);

        let got = server_io.read_line().await.unwrap();
        assert_eq!(got, None);
    }

    /// Constructing a `StdIo` via `new` does not panic.
    #[test]
    fn test_stdio_new() {
        drop(StdIo::new());
    }

    /// Constructing a `StdIo` via `Default` does not panic.
    #[test]
    fn test_stdio_default() {
        drop(StdIo::default());
    }

    /// A freshly built `StdIo` starts with an empty line buffer.
    /// (Despite the name, nothing is cloned here.)
    #[test]
    fn test_stdio_clone_safety() {
        let stdio = StdIo::new();
        assert!(stdio.buffer.is_empty());
    }

    /// Leading and trailing whitespace is stripped from every line read.
    #[tokio::test]
    async fn test_mock_io_whitespace_handling() {
        let (mut server_io, mut client_io) = MockIo::create_pair(1024);
        // (line as sent, line as expected on the receiving side)
        let cases = [
            (" leading spaces", "leading spaces"),
            ("trailing spaces ", "trailing spaces"),
            ("\ttabs\t", "tabs"),
        ];

        for (sent, _) in cases {
            client_io.write_line(sent).await.unwrap();
        }
        client_io.flush().await.unwrap();

        for (_, expected) in cases {
            let got = server_io.read_line().await.unwrap();
            assert_eq!(got.as_deref(), Some(expected));
        }
    }

    /// A 4096-character line passes through an 8192-byte duplex buffer intact.
    #[tokio::test]
    async fn test_mock_io_large_messages() {
        let (mut server_io, mut client_io) = MockIo::create_pair(8192);
        let payload = "x".repeat(4096);

        client_io.write_line(&payload).await.unwrap();
        client_io.flush().await.unwrap();

        let got = server_io.read_line().await.unwrap();
        assert_eq!(got, Some(payload));
    }

    /// Repeated write/read round trips never leak a previous line into the next.
    #[tokio::test]
    async fn test_mock_io_buffer_reuse() {
        let (mut server_io, mut client_io) = MockIo::create_pair(1024);

        for i in 0..5 {
            let sent = format!("message{}", i);
            client_io.write_line(&sent).await.unwrap();
            client_io.flush().await.unwrap();

            let got = server_io.read_line().await.unwrap();
            assert_eq!(got, Some(sent));
        }
    }

    /// A spawned writer task and a reader on the test task see lines in order.
    #[tokio::test]
    async fn test_mock_io_concurrent_operations() {
        let (mut server_io, mut client_io) = MockIo::create_pair(4096);

        let writer_task = tokio::spawn(async move {
            for i in 0..10 {
                client_io.write_line(&format!("msg{}", i)).await.unwrap();
                client_io.flush().await.unwrap();
            }
        });

        for i in 0..10 {
            let got = server_io.read_line().await.unwrap();
            assert_eq!(got, Some(format!("msg{}", i)));
        }

        writer_task.await.unwrap();
    }
}