use crate::error::MCPError;
use serde::{de::DeserializeOwned, Serialize};
use std::io;
/// Boxed, thread-shareable (`Send + Sync`) callback receiving the raw text of an inbound message.
pub type MessageCallback = Box<dyn Fn(&str) + Sync + Send>;
/// Boxed, thread-shareable (`Send + Sync`) callback receiving an error reported by a transport.
pub type ErrorCallback = Box<dyn Fn(&MCPError) + Sync + Send>;
/// Boxed, thread-shareable (`Send + Sync`) callback invoked when a transport closes.
pub type CloseCallback = Box<dyn Fn() + Sync + Send>;
/// A bidirectional channel that exchanges `serde`-serializable messages.
///
/// Implementors expose an explicit lifecycle (`start` / `close`) plus optional
/// callbacks for errors, inbound messages, and closure.
pub trait Transport {
    /// Prepares the transport for sending and receiving.
    fn start(&mut self) -> Result<(), MCPError>;

    /// Serializes `message` and sends it to the peer.
    fn send<T: Serialize>(&mut self, message: &T) -> Result<(), MCPError>;

    /// Receives the next message and deserializes it into `T`.
    fn receive<T: DeserializeOwned>(&mut self) -> Result<T, MCPError>;

    /// Shuts the transport down.
    fn close(&mut self) -> Result<(), MCPError>;

    /// Installs (`Some`) or removes (`None`) the close callback.
    fn set_on_close(&mut self, callback: Option<CloseCallback>);

    /// Installs (`Some`) or removes (`None`) the error callback.
    fn set_on_error(&mut self, callback: Option<ErrorCallback>);

    /// Installs (`Some`) or removes (`None`) the inbound-message callback.
    ///
    /// Unlike the other setters this one is generic, so callers can pass a bare
    /// closure without boxing it. NOTE(review): generic methods (this one,
    /// `send`, `receive`) mean the trait cannot be used as `dyn Transport`.
    fn set_on_message<F: Fn(&str) + Send + Sync + 'static>(&mut self, callback: Option<F>);
}
pub mod stdio {
use super::*;
use std::io::{BufRead, BufReader, Write};
pub struct StdioTransport {
reader: BufReader<Box<dyn io::Read + Send>>,
writer: Box<dyn io::Write + Send>,
is_connected: bool,
on_close: Option<CloseCallback>,
on_error: Option<ErrorCallback>,
on_message: Option<MessageCallback>,
}
impl Default for StdioTransport {
fn default() -> Self {
Self::new()
}
}
impl StdioTransport {
pub fn new() -> Self {
Self {
reader: BufReader::new(Box::new(io::stdin())),
writer: Box::new(io::stdout()),
is_connected: false,
on_close: None,
on_error: None,
on_message: None,
}
}
pub fn with_reader_writer(
reader: Box<dyn io::Read + Send>,
writer: Box<dyn io::Write + Send>,
) -> Self {
Self {
reader: BufReader::new(reader),
writer,
is_connected: false,
on_close: None,
on_error: None,
on_message: None,
}
}
fn handle_error(&self, error: &MCPError) {
if let Some(callback) = &self.on_error {
callback(error);
}
}
}
impl Transport for StdioTransport {
fn start(&mut self) -> Result<(), MCPError> {
if self.is_connected {
return Ok(());
}
self.is_connected = true;
Ok(())
}
fn send<T: Serialize>(&mut self, message: &T) -> Result<(), MCPError> {
if !self.is_connected {
let error = MCPError::Transport("Transport not connected".to_string());
self.handle_error(&error);
return Err(error);
}
let json = match serde_json::to_string(message) {
Ok(json) => json,
Err(e) => {
let error = MCPError::Serialization(e);
self.handle_error(&error);
return Err(error);
}
};
match writeln!(self.writer, "{}", json) {
Ok(_) => match self.writer.flush() {
Ok(_) => Ok(()),
Err(e) => {
let error = MCPError::Transport(format!("Failed to flush: {}", e));
self.handle_error(&error);
Err(error)
}
},
Err(e) => {
let error = MCPError::Transport(format!("Failed to write: {}", e));
self.handle_error(&error);
Err(error)
}
}
}
fn receive<T: DeserializeOwned>(&mut self) -> Result<T, MCPError> {
if !self.is_connected {
let error = MCPError::Transport("Transport not connected".to_string());
self.handle_error(&error);
return Err(error);
}
let mut line = String::new();
match self.reader.read_line(&mut line) {
Ok(_) => {
if let Some(callback) = &self.on_message {
callback(&line);
}
match serde_json::from_str(&line) {
Ok(parsed) => Ok(parsed),
Err(e) => {
let error = MCPError::Serialization(e);
self.handle_error(&error);
Err(error)
}
}
}
Err(e) => {
let error = MCPError::Transport(format!("Failed to read: {}", e));
self.handle_error(&error);
Err(error)
}
}
}
fn close(&mut self) -> Result<(), MCPError> {
if !self.is_connected {
return Ok(());
}
self.is_connected = false;
if let Some(callback) = &self.on_close {
callback();
}
Ok(())
}
fn set_on_close(&mut self, callback: Option<CloseCallback>) {
self.on_close = callback;
}
fn set_on_error(&mut self, callback: Option<ErrorCallback>) {
self.on_error = callback;
}
fn set_on_message<F>(&mut self, callback: Option<F>)
where
F: Fn(&str) + Send + Sync + 'static,
{
self.on_message = callback.map(|f| Box::new(f) as Box<dyn Fn(&str) + Send + Sync>);
}
}
}
pub mod sse;