use parking_lot::Mutex;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{Notify, oneshot};
use turbomcp_protocol::jsonrpc::{
JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse,
};
use turbomcp_protocol::{Error, MessageId, Result};
use turbomcp_transport::{Transport, TransportMessage};
/// Callback for server-initiated JSON-RPC requests; called inline on the
/// dispatcher's routing task for every incoming request.
type RequestHandler = Arc<dyn Fn(JsonRpcRequest) -> Result<()> + Sync + Send>;

/// Callback for server-initiated JSON-RPC notifications; called inline on the
/// dispatcher's routing task for every incoming notification.
type NotificationHandler = Arc<dyn Fn(JsonRpcNotification) -> Result<()> + Sync + Send>;

/// Demultiplexes messages read from a transport.
///
/// Responses are matched by request ID to one-shot waiters; requests and
/// notifications initiated by the server are forwarded to the optional
/// registered handlers. All state is behind `Arc`s so the background routing
/// task can share it with the dispatcher handle.
pub(super) struct MessageDispatcher {
    /// Pending response channels, keyed by the outgoing request's ID.
    response_waiters: Arc<Mutex<HashMap<MessageId, oneshot::Sender<JsonRpcResponse>>>>,
    /// Handler for server-initiated requests (`None` until registered).
    request_handler: Arc<Mutex<Option<RequestHandler>>>,
    /// Handler for server notifications (`None` until registered).
    notification_handler: Arc<Mutex<Option<NotificationHandler>>>,
    /// Signalled by `shutdown()` to stop the routing task.
    shutdown: Arc<Notify>,
}
impl MessageDispatcher {
pub fn new<T: Transport + 'static>(transport: Arc<T>) -> Arc<Self> {
let dispatcher = Arc::new(Self {
response_waiters: Arc::new(Mutex::new(HashMap::new())),
request_handler: Arc::new(Mutex::new(None)),
notification_handler: Arc::new(Mutex::new(None)),
shutdown: Arc::new(Notify::new()),
});
Self::spawn_routing_task(dispatcher.clone(), transport);
dispatcher
}
pub fn set_request_handler(&self, handler: RequestHandler) {
*self.request_handler.lock() = Some(handler);
tracing::debug!("Request handler registered with dispatcher");
}
pub fn set_notification_handler(&self, handler: NotificationHandler) {
*self.notification_handler.lock() = Some(handler);
tracing::debug!("Notification handler registered with dispatcher");
}
pub fn wait_for_response(&self, id: MessageId) -> oneshot::Receiver<JsonRpcResponse> {
let (tx, rx) = oneshot::channel();
self.response_waiters.lock().insert(id.clone(), tx);
tracing::trace!("Registered response waiter for request ID: {:?}", id);
rx
}
pub fn remove_response_waiter(&self, id: &MessageId) {
self.response_waiters.lock().remove(id);
tracing::trace!("Removed response waiter for request ID: {:?}", id);
}
#[cfg(test)]
pub fn response_waiter_count(&self) -> usize {
self.response_waiters.lock().len()
}
pub fn shutdown(&self) {
self.response_waiters.lock().clear();
self.shutdown.notify_one();
tracing::info!("Message dispatcher shutdown initiated");
}
fn spawn_routing_task<T: Transport + 'static>(dispatcher: Arc<Self>, transport: Arc<T>) {
let response_waiters = dispatcher.response_waiters.clone();
let request_handler = dispatcher.request_handler.clone();
let notification_handler = dispatcher.notification_handler.clone();
let shutdown = dispatcher.shutdown.clone();
tokio::spawn(async move {
tracing::info!("Message dispatcher routing task started");
let mut consecutive_errors = 0u32;
let max_consecutive_errors = 20; let mut has_ever_connected = false;
loop {
tokio::select! {
biased;
_ = shutdown.notified() => {
tracing::info!("Message dispatcher routing task shutting down");
break;
}
result = transport.receive() => {
match result {
Ok(Some(msg)) => {
consecutive_errors = 0;
has_ever_connected = true;
if let Err(e) = Self::route_message(
msg,
&response_waiters,
&request_handler,
¬ification_handler,
).await {
tracing::error!("Error routing message: {}", e);
}
}
Ok(None) => {
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
}
Err(e) => {
consecutive_errors += 1;
let state = transport.state().await;
let is_fatal = matches!(state, turbomcp_transport::TransportState::Disconnected
| turbomcp_transport::TransportState::Failed { .. });
if !has_ever_connected && consecutive_errors <= 3 {
tracing::debug!("Transport connecting (attempt {}): {}", consecutive_errors, e);
} else if consecutive_errors == 1 || (consecutive_errors == 4 && !has_ever_connected) {
tracing::error!("Transport receive error: {}", e);
} else if consecutive_errors <= max_consecutive_errors {
tracing::warn!("Transport receive error (attempt {}): {}", consecutive_errors, e);
} else {
if consecutive_errors == max_consecutive_errors + 1 {
tracing::error!(
"Transport in failed state ({}), suppressing further error logs. Waiting for recovery...",
state
);
}
}
let delay_ms = if is_fatal {
if consecutive_errors > max_consecutive_errors {
5000 } else {
1000 }
} else if !has_ever_connected {
50u64.saturating_mul(2u64.saturating_pow(consecutive_errors.min(4)))
} else {
100u64.saturating_mul(2u64.saturating_pow(consecutive_errors.min(5)))
};
tokio::time::sleep(tokio::time::Duration::from_millis(delay_ms)).await;
}
}
}
}
}
tracing::info!("Message dispatcher routing task terminated");
});
}
async fn route_message(
msg: TransportMessage,
response_waiters: &Arc<Mutex<HashMap<MessageId, oneshot::Sender<JsonRpcResponse>>>>,
request_handler: &Arc<Mutex<Option<RequestHandler>>>,
notification_handler: &Arc<Mutex<Option<NotificationHandler>>>,
) -> Result<()> {
let json_msg: JsonRpcMessage = serde_json::from_slice(&msg.payload)
.map_err(|e| Error::internal(format!("Invalid JSON-RPC message: {}", e)))?;
match json_msg {
JsonRpcMessage::Response(response) => {
if let Some(request_id) = &response.id.0 {
if let Some(tx) = response_waiters.lock().remove(request_id) {
tracing::trace!("Routing response to request ID: {:?}", request_id);
let _ = tx.send(response);
} else {
tracing::warn!(
"Received response for unknown/expired request ID: {:?}",
request_id
);
}
} else {
tracing::debug!(
"Received response with null ID (server parse error): {:?}",
response.error()
);
}
}
JsonRpcMessage::Request(request) => {
tracing::debug!(
"Routing server-initiated request: method={}, id={:?}",
request.method,
request.id
);
if let Some(handler) = request_handler.lock().as_ref() {
if let Err(e) = handler(request) {
tracing::error!("Request handler error: {}", e);
}
} else {
tracing::warn!(
"Received server request but no handler registered: method={}",
request.method
);
}
}
JsonRpcMessage::Notification(notification) => {
tracing::debug!(
"Routing server notification: method={}",
notification.method
);
if let Some(handler) = notification_handler.lock().as_ref() {
if let Err(e) = handler(notification) {
tracing::error!("Notification handler error: {}", e);
}
} else {
tracing::debug!(
"Received notification but no handler registered: method={}",
notification.method
);
}
}
}
Ok(())
}
}
impl std::fmt::Debug for MessageDispatcher {
    /// Writes an opaque summary of the dispatcher.
    ///
    /// The fields hold locks, channels and trait objects, so each one is
    /// rendered as a fixed type placeholder instead of its contents. No lock
    /// is taken while formatting.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const PLACEHOLDERS: [(&str, &str); 4] = [
            ("response_waiters", "<Arc<Mutex<HashMap>>>"),
            ("request_handler", "<Arc<Mutex<Option<Handler>>>>"),
            ("notification_handler", "<Arc<Mutex<Option<Handler>>>>"),
            ("shutdown", "<Arc<Notify>>"),
        ];

        let mut out = f.debug_struct("MessageDispatcher");
        for &(field, placeholder) in PLACEHOLDERS.iter() {
            out.field(field, &placeholder);
        }
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::Arc;
    use turbomcp_transport::{
        TransportCapabilities, TransportConfig, TransportMessage, TransportMetrics,
        TransportResult, TransportState, TransportType,
    };

    /// Boxed `Send` future borrowing the transport for `'a`; the return type
    /// of every async `Transport` method implemented below.
    type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;

    /// Wraps an already-computed value in a boxed future that resolves immediately.
    fn resolved<'a, T: Send + 'a>(value: T) -> BoxFuture<'a, T> {
        Box::pin(std::future::ready(value))
    }

    /// Transport stub: every operation succeeds, `receive` never yields a
    /// message, and `state` always reports `Disconnected`.
    #[derive(Debug, Default)]
    struct NoopTransport {
        capabilities: TransportCapabilities,
    }

    impl Transport for NoopTransport {
        fn transport_type(&self) -> TransportType {
            TransportType::Stdio
        }

        fn capabilities(&self) -> &TransportCapabilities {
            &self.capabilities
        }

        fn state(&self) -> BoxFuture<'_, TransportState> {
            resolved(TransportState::Disconnected)
        }

        fn connect(&self) -> BoxFuture<'_, TransportResult<()>> {
            resolved(Ok(()))
        }

        fn disconnect(&self) -> BoxFuture<'_, TransportResult<()>> {
            resolved(Ok(()))
        }

        fn send(&self, _message: TransportMessage) -> BoxFuture<'_, TransportResult<()>> {
            resolved(Ok(()))
        }

        fn receive(&self) -> BoxFuture<'_, TransportResult<Option<TransportMessage>>> {
            resolved(Ok(None))
        }

        fn metrics(&self) -> BoxFuture<'_, TransportMetrics> {
            resolved(TransportMetrics::default())
        }

        fn configure(&self, _config: TransportConfig) -> BoxFuture<'_, TransportResult<()>> {
            resolved(Ok(()))
        }
    }

    /// Builds a dispatcher whose routing task polls a fresh [`NoopTransport`].
    fn noop_dispatcher() -> Arc<MessageDispatcher> {
        MessageDispatcher::new(Arc::new(NoopTransport::default()))
    }

    #[tokio::test]
    async fn test_dispatcher_creation() {
        noop_dispatcher().shutdown();
    }

    #[tokio::test]
    async fn test_remove_response_waiter() {
        let dispatcher = noop_dispatcher();
        let request_id = MessageId::from("req-123");
        let has_waiter = |id: &MessageId| dispatcher.response_waiters.lock().contains_key(id);

        let _receiver = dispatcher.wait_for_response(request_id.clone());
        assert!(has_waiter(&request_id));

        dispatcher.remove_response_waiter(&request_id);
        assert!(!has_waiter(&request_id));

        dispatcher.shutdown();
    }
}