use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{mpsc, oneshot};
use turbomcp_core::error::{ErrorKind, McpError, McpResult};
use turbomcp_core::handler::McpHandler;
use crate::context::{McpSession, RequestContext};
use crate::router;
use crate::transport::MAX_MESSAGE_SIZE;
use turbomcp_transport::{
Transport, TransportCapabilities, TransportError, TransportMessage, TransportMetrics,
TransportResult, TransportState, TransportType,
};
/// Default capacity for each direction of the in-process channel pair; sends
/// apply backpressure (await) once this many messages are buffered.
const DEFAULT_CHANNEL_BUFFER: usize = 256;
/// Upper bound on concurrent server-to-client requests awaiting a reply;
/// further requests are rejected until earlier ones complete.
const MAX_PENDING_REQUESTS: usize = 64;
/// In-process transport backed by a pair of tokio mpsc channels.
///
/// This is the client-side endpoint returned by [`run_in_process`]; the
/// matching server endpoints are moved into the spawned server loop.
#[derive(Debug)]
pub struct ChannelTransport {
    /// Outgoing messages (this endpoint -> peer).
    tx: mpsc::Sender<TransportMessage>,
    /// Incoming messages (peer -> this endpoint); wrapped in an async mutex so
    /// `receive` can take `&self` while awaiting.
    rx: tokio::sync::Mutex<mpsc::Receiver<TransportMessage>>,
    /// Current connection state; sync mutex, never held across an await.
    state: parking_lot::Mutex<TransportState>,
    /// Static capability description reported to callers.
    capabilities: TransportCapabilities,
}
impl ChannelTransport {
fn new(tx: mpsc::Sender<TransportMessage>, rx: mpsc::Receiver<TransportMessage>) -> Self {
Self {
tx,
rx: tokio::sync::Mutex::new(rx),
state: parking_lot::Mutex::new(TransportState::Connected),
capabilities: TransportCapabilities {
max_message_size: Some(MAX_MESSAGE_SIZE),
supports_compression: false,
supports_streaming: false,
supports_bidirectional: true,
supports_multiplexing: false,
compression_algorithms: Vec::new(),
custom: std::collections::HashMap::new(),
},
}
}
}
impl Transport for ChannelTransport {
    /// Identifies this transport as the in-process channel variant.
    fn transport_type(&self) -> TransportType {
        TransportType::Channel
    }

    /// Capability set fixed at construction (bidirectional, size-capped).
    fn capabilities(&self) -> &TransportCapabilities {
        &self.capabilities
    }

    /// Returns a snapshot of the current connection state.
    fn state(
        &self,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = TransportState> + Send + '_>> {
        // Clone under the sync lock; the guard never lives across an await.
        Box::pin(async move { self.state.lock().clone() })
    }

    /// Marks the transport connected. The underlying channels are always
    /// live, so this only updates the state flag.
    fn connect(
        &self,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = TransportResult<()>> + Send + '_>> {
        Box::pin(async move {
            *self.state.lock() = TransportState::Connected;
            Ok(())
        })
    }

    /// Marks the transport disconnected (state flag only; channels stay open).
    fn disconnect(
        &self,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = TransportResult<()>> + Send + '_>> {
        Box::pin(async move {
            *self.state.lock() = TransportState::Disconnected;
            Ok(())
        })
    }

    /// Queues a message to the peer, awaiting if the channel buffer is full.
    ///
    /// NOTE(review): `state` is not consulted here, so sends still succeed
    /// after `disconnect()` as long as the channel itself is open — confirm
    /// this matches the `Transport` trait's contract.
    fn send(
        &self,
        message: TransportMessage,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = TransportResult<()>> + Send + '_>> {
        Box::pin(async move {
            self.tx
                .send(message)
                .await
                .map_err(|_| TransportError::ConnectionLost("Channel closed".to_string()))?;
            Ok(())
        })
    }

    /// Awaits the next inbound message; resolves to `Ok(None)` once the
    /// peer's sender is dropped. The async mutex serializes concurrent
    /// receivers.
    fn receive(
        &self,
    ) -> std::pin::Pin<
        Box<
            dyn std::future::Future<Output = TransportResult<Option<TransportMessage>>> + Send + '_,
        >,
    > {
        Box::pin(async move {
            let mut rx = self.rx.lock().await;
            Ok(rx.recv().await)
        })
    }

    /// Metrics are not tracked for this transport; always returns defaults.
    fn metrics(
        &self,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = TransportMetrics> + Send + '_>> {
        Box::pin(async move { TransportMetrics::default() })
    }
}
/// Cloneable handle that lets request-handling tasks issue server-to-client
/// requests and notifications through the server loop.
#[derive(Debug, Clone)]
struct ChannelSessionHandle {
    /// Commands are forwarded to the server loop, which owns the outgoing channel.
    request_tx: mpsc::Sender<SessionCommand>,
}
/// Messages sent from a [`ChannelSessionHandle`] to the server loop.
#[derive(Debug)]
enum SessionCommand {
    /// A server-to-client JSON-RPC request; the outcome is delivered back
    /// through `response_tx` once the client replies.
    Request {
        method: String,
        params: serde_json::Value,
        response_tx: oneshot::Sender<McpResult<serde_json::Value>>,
    },
    /// A fire-and-forget JSON-RPC notification (no id, no reply expected).
    Notify {
        method: String,
        params: serde_json::Value,
    },
}
#[async_trait::async_trait]
impl McpSession for ChannelSessionHandle {
    /// Issues a server-to-client request and awaits the matching response.
    async fn call(&self, method: &str, params: serde_json::Value) -> McpResult<serde_json::Value> {
        let (tx, rx) = oneshot::channel();
        let command = SessionCommand::Request {
            method: method.to_string(),
            params,
            response_tx: tx,
        };
        // The server loop owns the receiving end; a send failure means it exited.
        if self.request_tx.send(command).await.is_err() {
            return Err(McpError::internal("Session closed"));
        }
        match rx.await {
            Ok(result) => result,
            Err(_) => Err(McpError::internal("Response channel closed")),
        }
    }

    /// Sends a fire-and-forget notification to the client.
    async fn notify(&self, method: &str, params: serde_json::Value) -> McpResult<()> {
        let command = SessionCommand::Notify {
            method: method.to_string(),
            params,
        };
        self.request_tx
            .send(command)
            .await
            .map_err(|_| McpError::internal("Session closed"))
    }
}
/// Spawns `handler` as an in-process MCP server and returns the client-side
/// transport plus the server task's join handle.
///
/// Uses [`DEFAULT_CHANNEL_BUFFER`] for both channel directions; see
/// [`run_in_process_with_buffer`] to tune the capacity.
///
/// # Errors
/// Propagates any error from the handler's `on_initialize` hook.
pub async fn run_in_process<H: McpHandler + 'static>(
    handler: &H,
) -> McpResult<(ChannelTransport, tokio::task::JoinHandle<McpResult<()>>)> {
    run_in_process_with_buffer(handler, DEFAULT_CHANNEL_BUFFER).await
}
pub async fn run_in_process_with_buffer<H: McpHandler + 'static>(
handler: &H,
buffer_size: usize,
) -> McpResult<(ChannelTransport, tokio::task::JoinHandle<McpResult<()>>)> {
handler.on_initialize().await?;
let (client_tx, server_rx) = mpsc::channel::<TransportMessage>(buffer_size);
let (server_tx, client_rx) = mpsc::channel::<TransportMessage>(buffer_size);
let client_transport = ChannelTransport::new(client_tx, client_rx);
let handler = handler.clone();
let server_handle =
tokio::spawn(async move { run_server_loop(handler, server_rx, server_tx).await });
Ok((client_transport, server_handle))
}
/// Core server loop: multiplexes three event sources with `tokio::select!` —
/// inbound client messages, responses from spawned handler tasks, and
/// server-initiated commands coming from session handles.
///
/// Exits when the inbound channel closes (client transport dropped), drains
/// any in-flight handler responses, then runs the handler's shutdown hook.
async fn run_server_loop<H: McpHandler>(
    handler: H,
    mut incoming: mpsc::Receiver<TransportMessage>,
    outgoing: mpsc::Sender<TransportMessage>,
) -> McpResult<()> {
    // Channel used by session handles to submit server-to-client commands.
    let (cmd_tx, mut cmd_rx) = mpsc::channel::<SessionCommand>(32);
    let session_handle = Arc::new(ChannelSessionHandle { request_tx: cmd_tx });
    // Responses produced by the spawned per-request handler tasks.
    let (response_tx, mut response_rx) = mpsc::channel::<router::JsonRpcOutgoing>(32);
    // Server-to-client requests still awaiting a client reply, keyed by the
    // JSON-RPC id we assigned.
    let mut pending_requests =
        HashMap::<serde_json::Value, oneshot::Sender<McpResult<serde_json::Value>>>::new();
    // Monotonic counter backing the "s-{n}" request ids.
    let mut next_request_id = 1u64;
    loop {
        tokio::select! {
            msg = incoming.recv() => {
                // `None`: client transport dropped — leave the loop and shut down.
                let Some(msg) = msg else { break; };
                // Reject oversized payloads before paying the parse cost.
                if msg.payload.len() > MAX_MESSAGE_SIZE {
                    send_error_msg(
                        &outgoing,
                        None,
                        McpError::invalid_request(format!(
                            "Message exceeds maximum size of {MAX_MESSAGE_SIZE} bytes"
                        )),
                    ).await?;
                    continue;
                }
                let value: serde_json::Value = match serde_json::from_slice(&msg.payload) {
                    Ok(v) => v,
                    Err(e) => {
                        send_error_msg(&outgoing, None, McpError::parse_error(e.to_string())).await?;
                        continue;
                    }
                };
                // A message carrying an id plus `result`/`error` is the client's
                // reply to one of our server-initiated requests.
                if let Some(id) = value.get("id")
                    && (value.get("result").is_some() || value.get("error").is_some())
                {
                    // Unknown ids are silently dropped (e.g. reply for a request
                    // that was already rejected or completed).
                    if let Some(tx) = pending_requests.remove(id) {
                        if let Some(error) = value.get("error") {
                            // Map a structured JSON-RPC error back to an McpError;
                            // fall back to a generic internal error on bad shape.
                            let mcp_error = serde_json::from_value::<turbomcp_core::jsonrpc::JsonRpcError>(error.clone())
                                .map(|e| McpError::new(ErrorKind::from_i32(e.code), e.message))
                                .unwrap_or_else(|_| McpError::internal("Failed to parse error response"));
                            // The waiting caller may have given up; ignore send failure.
                            let _ = tx.send(Err(mcp_error));
                        } else {
                            let result = value.get("result").cloned().unwrap_or(serde_json::Value::Null);
                            let _ = tx.send(Ok(result));
                        }
                    }
                } else {
                    // Otherwise treat it as a client request/notification and route it.
                    match serde_json::from_value::<turbomcp_core::jsonrpc::JsonRpcIncoming>(value) {
                        Ok(request) => {
                            let h = handler.clone();
                            let session = session_handle.clone();
                            let resp_tx = response_tx.clone();
                            let ctx = RequestContext::channel().with_session(session);
                            let core_ctx = ctx.to_core_context();
                            // Handle each request on its own task so a slow
                            // handler cannot stall the select loop.
                            tokio::spawn(async move {
                                let response = router::route_request(&h, request, &core_ctx).await;
                                let _ = resp_tx.send(response).await;
                            });
                        }
                        Err(e) => {
                            send_error_msg(&outgoing, None, McpError::parse_error(e.to_string())).await?;
                        }
                    }
                }
            }
            // A handler task finished; forward its response. Notifications
            // produce nothing to send (`should_send` is false).
            Some(response) = response_rx.recv() => {
                if response.should_send() {
                    send_response_msg(&outgoing, &response).await?;
                }
            }
            // A session handle asked us to talk to the client.
            Some(cmd) = cmd_rx.recv() => {
                match cmd {
                    SessionCommand::Request { method, params, response_tx } => {
                        // Bound the pending table so an unresponsive client
                        // cannot grow it without limit.
                        if pending_requests.len() >= MAX_PENDING_REQUESTS {
                            let _ = response_tx.send(Err(McpError::internal(
                                "Too many pending server-to-client requests"
                            )));
                            continue;
                        }
                        // "s-" prefix keeps server-assigned ids distinct from client ids.
                        let id = serde_json::json!(format!("s-{next_request_id}"));
                        next_request_id += 1;
                        pending_requests.insert(id.clone(), response_tx);
                        let request = serde_json::json!({
                            "jsonrpc": "2.0",
                            "id": id,
                            "method": method,
                            "params": params
                        });
                        let payload = serde_json::to_vec(&request)
                            .map_err(|e| McpError::internal(e.to_string()))?;
                        // NOTE(review): the transport-level message id ("s-req-{n}")
                        // is distinct from the JSON-RPC id ("s-{n}") used for
                        // response matching above.
                        outgoing.send(TransportMessage::new(
                            turbomcp_protocol::MessageId::from(format!("s-req-{}", next_request_id - 1)),
                            payload.into(),
                        ))
                        .await
                        .map_err(|_| McpError::internal("Channel closed"))?;
                    }
                    SessionCommand::Notify { method, params } => {
                        // Notifications carry no id and expect no reply.
                        let notification = serde_json::json!({
                            "jsonrpc": "2.0",
                            "method": method,
                            "params": params
                        });
                        let payload = serde_json::to_vec(&notification)
                            .map_err(|e| McpError::internal(e.to_string()))?;
                        outgoing.send(TransportMessage::new(
                            turbomcp_protocol::MessageId::from("notification"),
                            payload.into(),
                        ))
                        .await
                        .map_err(|_| McpError::internal("Channel closed"))?;
                    }
                }
            }
        }
    }
    // Stop accepting new responses, then flush those already in flight so the
    // client still sees replies for requests that were mid-handling at exit.
    drop(response_tx);
    while let Some(response) = response_rx.recv().await {
        if response.should_send() {
            send_response_msg(&outgoing, &response).await?;
        }
    }
    handler.on_shutdown().await?;
    Ok(())
}
/// Serializes a routed JSON-RPC response and pushes it onto the outgoing channel.
async fn send_response_msg(
    tx: &mpsc::Sender<TransportMessage>,
    response: &router::JsonRpcOutgoing,
) -> McpResult<()> {
    let payload = router::serialize_response(response)?;
    // Derive the transport message id from the JSON-RPC id when one exists.
    let message_id = match response.id.as_ref() {
        Some(id) => turbomcp_protocol::MessageId::from(id.to_string()),
        None => turbomcp_protocol::MessageId::from("response"),
    };
    let message = TransportMessage::new(message_id, bytes::Bytes::from(payload));
    if tx.send(message).await.is_err() {
        return Err(McpError::internal("Channel closed"));
    }
    Ok(())
}
async fn send_error_msg(
tx: &mpsc::Sender<TransportMessage>,
id: Option<serde_json::Value>,
error: McpError,
) -> McpResult<()> {
let response = router::JsonRpcOutgoing::error(id, error);
send_response_msg(tx, &response).await
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::Value;
    use turbomcp_core::context::RequestContext as CoreRequestContext;
    use turbomcp_core::error::McpResult;
    use turbomcp_types::{
        Prompt, PromptResult, Resource, ResourceResult, ServerInfo, Tool, ToolResult,
    };

    /// Minimal handler fixture: exposes a single "ping" tool and rejects all
    /// resource/prompt lookups.
    #[derive(Clone)]
    struct TestHandler;

    impl McpHandler for TestHandler {
        fn server_info(&self) -> ServerInfo {
            ServerInfo::new("channel-test", "1.0.0")
        }
        fn list_tools(&self) -> Vec<Tool> {
            vec![Tool::new("ping", "Ping tool")]
        }
        fn list_resources(&self) -> Vec<Resource> {
            vec![]
        }
        fn list_prompts(&self) -> Vec<Prompt> {
            vec![]
        }
        // Always answers "pong" regardless of tool name/args.
        async fn call_tool(
            &self,
            _name: &str,
            _args: Value,
            _ctx: &CoreRequestContext,
        ) -> McpResult<ToolResult> {
            Ok(ToolResult::text("pong"))
        }
        async fn read_resource(
            &self,
            uri: &str,
            _ctx: &CoreRequestContext,
        ) -> McpResult<ResourceResult> {
            Err(McpError::resource_not_found(uri))
        }
        async fn get_prompt(
            &self,
            name: &str,
            _args: Option<Value>,
            _ctx: &CoreRequestContext,
        ) -> McpResult<PromptResult> {
            Err(McpError::prompt_not_found(name))
        }
    }

    /// End-to-end round-trip over the in-process transport: initialize
    /// handshake, then a tool call, then clean shutdown via drop.
    #[tokio::test]
    async fn test_channel_transport_roundtrip() {
        let handler = TestHandler;
        let (transport, server_handle) = run_in_process(&handler).await.unwrap();
        // Phase 1: initialize request; the response must carry our server info.
        let init_request = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "initialize",
            "params": {
                "protocolVersion": "2025-11-25",
                "clientInfo": { "name": "test", "version": "1.0.0" },
                "capabilities": {}
            }
        });
        let payload = serde_json::to_vec(&init_request).unwrap();
        transport
            .send(TransportMessage::new(
                turbomcp_protocol::MessageId::from("1"),
                payload.into(),
            ))
            .await
            .unwrap();
        let response = transport.receive().await.unwrap().unwrap();
        let value: serde_json::Value = serde_json::from_slice(&response.payload).unwrap();
        assert!(value.get("result").is_some());
        assert_eq!(value["result"]["serverInfo"]["name"], "channel-test");
        // Phase 2: tools/call round-trip through the ping tool.
        let ping_request = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 2,
            "method": "tools/call",
            "params": { "name": "ping", "arguments": {} }
        });
        let payload = serde_json::to_vec(&ping_request).unwrap();
        transport
            .send(TransportMessage::new(
                turbomcp_protocol::MessageId::from("2"),
                payload.into(),
            ))
            .await
            .unwrap();
        let response = transport.receive().await.unwrap().unwrap();
        let value: serde_json::Value = serde_json::from_slice(&response.payload).unwrap();
        assert!(value.get("result").is_some());
        // Phase 3: dropping the transport closes the channels, which makes the
        // server loop exit; awaiting the handle confirms clean shutdown.
        drop(transport);
        let _ = server_handle.await;
    }
}