// turbomcp_client/client/dispatcher.rs
//! Message dispatcher for routing JSON-RPC messages
//!
//! This module implements the message routing layer that solves the bidirectional
//! communication problem. It runs a background task that reads ALL messages from
//! the transport and routes them appropriately:
//!
//! - **Responses** → Routed to waiting `request()` calls via oneshot channels
//! - **Requests** → Routed to the registered request handler (for elicitation, sampling, etc.)
//! - **Notifications** → Routed to the registered notification handler
//!
//! ## Architecture
//!
//! ```text
//! ┌──────────────────────────────────────────────┐
//! │ MessageDispatcher                            │
//! │                                              │
//! │ Background Task (tokio::spawn):              │
//! │ loop {                                       │
//! │     msg = transport.receive().await          │
//! │     match parse(msg) {                       │
//! │         Response => send to oneshot channel  │
//! │         Request => call request_handler      │
//! │         Notification => call notif_handler   │
//! │     }                                        │
//! │ }                                            │
//! └──────────────────────────────────────────────┘
//! ```
//!
//! This guarantees there is only ONE consumer of `transport.receive()`,
//! eliminating race conditions by centralizing all message routing.
31
32use std::collections::HashMap;
33use std::sync::{Arc, Mutex}; // Use std::sync::Mutex for simpler synchronous access
34
35use tokio::sync::{Notify, oneshot};
36use turbomcp_protocol::jsonrpc::{
37 JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse,
38};
39use turbomcp_protocol::{Error, MessageId, Result};
40use turbomcp_transport::{Transport, TransportMessage};
41
/// Type alias for request handler functions
///
/// A synchronous callback invoked by the routing task for each server-initiated
/// request. Note the plain `Fn(...) -> Result<()>` signature: the callback
/// itself is *not* async — it typically spawns or schedules the real work and
/// is responsible for arranging that a response is sent back via the transport.
type RequestHandler = Arc<dyn Fn(JsonRpcRequest) -> Result<()> + Send + Sync>;

/// Type alias for notification handler functions
///
/// A synchronous callback invoked by the routing task for each server-initiated
/// notification. Notifications carry no ID, so no response is expected.
type NotificationHandler = Arc<dyn Fn(JsonRpcNotification) -> Result<()> + Send + Sync>;
52
/// Message dispatcher that routes incoming JSON-RPC messages
///
/// The dispatcher solves the bidirectional communication problem by being the
/// SINGLE consumer of `transport.receive()`. It runs a background task that
/// continuously reads messages and routes them to the appropriate handlers.
///
/// # Design Principles
///
/// 1. **Single Responsibility**: Only handles message routing, not processing
/// 2. **Thread-Safe**: All state protected by `Arc<Mutex<...>>` (std mutex;
///    guards are short-lived and never held across an `.await`)
/// 3. **Graceful Shutdown**: Supports clean shutdown via a `Notify` signal
/// 4. **Error Resilient**: Continues running even if individual messages fail
/// 5. **Production-Ready**: Comprehensive logging and error handling
///
/// # Known Limitations
///
/// **Response Waiter Cleanup**: If a request is made but the response never arrives
/// (e.g., server crash, network partition), the oneshot sender remains in the
/// `response_waiters` HashMap indefinitely. This has minimal impact because:
/// - Oneshot senders have a small memory footprint (~24 bytes)
/// - In practice, responses arrive or clients time out and drop the receiver
/// - When a receiver is dropped, the send fails gracefully (error is ignored)
///
/// Future enhancement: Add a background cleanup task or request timeout mechanism
/// to remove stale entries after a configurable duration.
///
/// # Example
///
/// ```rust,ignore
/// let dispatcher = MessageDispatcher::new(Arc::new(transport));
///
/// // Register handlers (the setters are synchronous — no `.await`)
/// dispatcher.set_request_handler(Arc::new(|req| {
///     // Handle server-initiated requests (elicitation, sampling)
///     Ok(())
/// }));
///
/// // Wait for a response to a specific request
/// let id = MessageId::from("req-123");
/// let receiver = dispatcher.wait_for_response(id.clone());
///
/// // The background task routes the response when it arrives
/// let response = receiver.await?;
/// ```
pub(super) struct MessageDispatcher {
    /// Map of request IDs to oneshot senders for response routing
    ///
    /// When `ProtocolClient::request()` sends a request, it registers a oneshot
    /// channel here. When the dispatcher receives the corresponding response,
    /// it removes the entry and sends the response through the channel.
    response_waiters: Arc<Mutex<HashMap<MessageId, oneshot::Sender<JsonRpcResponse>>>>,

    /// Optional handler for server-initiated requests (elicitation, sampling)
    ///
    /// Set by the Client to handle incoming requests from the server.
    /// The handler is responsible for processing the request and sending a response.
    request_handler: Arc<Mutex<Option<RequestHandler>>>,

    /// Optional handler for server-initiated notifications
    ///
    /// Set by the Client to handle incoming notifications from the server.
    notification_handler: Arc<Mutex<Option<NotificationHandler>>>,

    /// Shutdown signal for graceful termination
    ///
    /// When `shutdown()` is called, this notify wakes up the background task
    /// which then exits cleanly.
    shutdown: Arc<Notify>,
}
122
123impl MessageDispatcher {
124 /// Create a new message dispatcher and start the background routing task
125 ///
126 /// The dispatcher immediately spawns a background task that continuously
127 /// reads messages from the transport and routes them appropriately.
128 ///
129 /// # Arguments
130 ///
131 /// * `transport` - The transport to read messages from
132 ///
133 /// # Returns
134 ///
135 /// Returns a new `MessageDispatcher` with the routing task running.
136 pub fn new<T: Transport + 'static>(transport: Arc<T>) -> Arc<Self> {
137 let dispatcher = Arc::new(Self {
138 response_waiters: Arc::new(Mutex::new(HashMap::new())),
139 request_handler: Arc::new(Mutex::new(None)),
140 notification_handler: Arc::new(Mutex::new(None)),
141 shutdown: Arc::new(Notify::new()),
142 });
143
144 // Start background routing task
145 Self::spawn_routing_task(dispatcher.clone(), transport);
146
147 dispatcher
148 }
149
150 /// Register a request handler for server-initiated requests
151 ///
152 /// This handler will be called when the server sends a request (like
153 /// elicitation/create or sampling/createMessage). The handler is responsible
154 /// for processing the request and sending a response back.
155 ///
156 /// # Arguments
157 ///
158 /// * `handler` - Function to handle incoming requests
159 pub fn set_request_handler(&self, handler: RequestHandler) {
160 *self.request_handler.lock().expect("handler mutex poisoned") = Some(handler);
161 tracing::debug!("Request handler registered with dispatcher");
162 }
163
164 /// Register a notification handler for server-initiated notifications
165 ///
166 /// This handler will be called when the server sends a notification.
167 ///
168 /// # Arguments
169 ///
170 /// * `handler` - Function to handle incoming notifications
171 pub fn set_notification_handler(&self, handler: NotificationHandler) {
172 *self
173 .notification_handler
174 .lock()
175 .expect("handler mutex poisoned") = Some(handler);
176 tracing::debug!("Notification handler registered with dispatcher");
177 }
178
179 /// Wait for a response to a specific request ID
180 ///
181 /// This method is called by `ProtocolClient::request()` before sending a request.
182 /// It registers a oneshot channel that will receive the response when it arrives.
183 ///
184 /// # Arguments
185 ///
186 /// * `id` - The request ID to wait for
187 ///
188 /// # Returns
189 ///
190 /// Returns a oneshot receiver that will be sent the response when it arrives.
191 ///
192 /// # Example
193 ///
194 /// ```rust,ignore
195 /// // Register waiter before sending request
196 /// let id = MessageId::from("req-123");
197 /// let receiver = dispatcher.wait_for_response(id.clone()).await;
198 ///
199 /// // Send request...
200 ///
201 /// // Wait for response
202 /// let response = receiver.await?;
203 /// ```
204 pub fn wait_for_response(&self, id: MessageId) -> oneshot::Receiver<JsonRpcResponse> {
205 let (tx, rx) = oneshot::channel();
206 self.response_waiters
207 .lock()
208 .expect("response_waiters mutex poisoned")
209 .insert(id.clone(), tx);
210 tracing::trace!("Registered response waiter for request ID: {:?}", id);
211 rx
212 }
213
    /// Signal the dispatcher to shutdown gracefully
    ///
    /// Wakes the background routing task via the `Notify`; the task breaks out
    /// of its select loop and terminates after the current iteration. If the
    /// task is not currently waiting, the notification permit is stored and
    /// consumed on its next `notified().await`.
    ///
    /// This method is called automatically when the Client is dropped,
    /// ensuring proper cleanup of background resources.
    pub fn shutdown(&self) {
        self.shutdown.notify_one();
        tracing::info!("Message dispatcher shutdown initiated");
    }
225
    /// Spawn the background routing task
    ///
    /// The task continuously reads messages from the transport and routes them
    /// via `Self::route_message`. It runs until `shutdown()` is called or the
    /// transport is closed.
    ///
    /// Error handling: consecutive receive failures are counted, logging is
    /// throttled after `max_consecutive_errors`, and the backoff delay scales
    /// with the failure count and the transport's reported state.
    ///
    /// # Arguments
    ///
    /// * `dispatcher` - Arc reference to the dispatcher
    /// * `transport` - Arc reference to the transport
    fn spawn_routing_task<T: Transport + 'static>(dispatcher: Arc<Self>, transport: Arc<T>) {
        // Clone only the fields the task needs, rather than moving the Arc<Self>
        // into the task, so the task does not keep the dispatcher itself alive.
        let response_waiters = dispatcher.response_waiters.clone();
        let request_handler = dispatcher.request_handler.clone();
        let notification_handler = dispatcher.notification_handler.clone();
        let shutdown = dispatcher.shutdown.clone();

        tokio::spawn(async move {
            tracing::info!("Message dispatcher routing task started");

            let mut consecutive_errors = 0u32;
            let max_consecutive_errors = 20; // After 20 consecutive errors, back off significantly

            loop {
                tokio::select! {
                    // Graceful shutdown
                    _ = shutdown.notified() => {
                        tracing::info!("Message dispatcher routing task shutting down");
                        break;
                    }

                    // Read and route messages.
                    // NOTE(review): if shutdown wins the race, the in-flight
                    // `receive()` future is dropped — this assumes receive()
                    // is cancel-safe; confirm for each transport implementation.
                    result = transport.receive() => {
                        match result {
                            Ok(Some(msg)) => {
                                // Successfully received message - reset error counter
                                consecutive_errors = 0;

                                // Route the message; routing failures are logged
                                // but never terminate the loop.
                                if let Err(e) = Self::route_message(
                                    msg,
                                    &response_waiters,
                                    &request_handler,
                                    &notification_handler,
                                ).await {
                                    tracing::error!("Error routing message: {}", e);
                                }
                            }
                            Ok(None) => {
                                // No message available - transport returned None.
                                // Brief sleep to avoid busy-waiting.
                                tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
                            }
                            Err(e) => {
                                consecutive_errors += 1;

                                // Check transport state to determine error severity
                                let state = transport.state().await;
                                let is_fatal = matches!(state, turbomcp_transport::TransportState::Disconnected
                                    | turbomcp_transport::TransportState::Failed { .. });

                                if consecutive_errors == 1 {
                                    // First error - log at error level
                                    tracing::error!("Transport receive error: {}", e);
                                } else if consecutive_errors <= max_consecutive_errors {
                                    // Subsequent errors - log at warn to reduce noise
                                    tracing::warn!("Transport receive error (attempt {}): {}", consecutive_errors, e);
                                } else {
                                    // Too many errors - log once and suppress further logs
                                    if consecutive_errors == max_consecutive_errors + 1 {
                                        tracing::error!(
                                            "Transport in failed state ({}), suppressing further error logs. Waiting for recovery...",
                                            state
                                        );
                                    }
                                }

                                // Exponential backoff based on error count and transport state
                                let delay_ms = if is_fatal {
                                    // Fatal error - wait longer to avoid spam
                                    if consecutive_errors > max_consecutive_errors {
                                        5000 // 5 seconds when transport is dead
                                    } else {
                                        1000 // 1 second initially
                                    }
                                } else {
                                    // Transient error - capped exponential backoff:
                                    // 200ms, 400ms, ... topping out at 3.2s (exponent capped at 5)
                                    100u64.saturating_mul(2u64.saturating_pow(consecutive_errors.min(5)))
                                };

                                tokio::time::sleep(tokio::time::Duration::from_millis(delay_ms)).await;
                            }
                        }
                    }
                }
            }

            tracing::info!("Message dispatcher routing task terminated");
        });
    }
325
326 /// Route an incoming message to the appropriate handler
327 ///
328 /// This is the core routing logic. It parses the raw transport message as
329 /// a JSON-RPC message and routes it based on type:
330 ///
331 /// - **Response**: Look up the waiting oneshot channel and send the response
332 /// - **Request**: Call the registered request handler
333 /// - **Notification**: Call the registered notification handler
334 ///
335 /// # Arguments
336 ///
337 /// * `msg` - The raw transport message to route
338 /// * `response_waiters` - Map of request IDs to oneshot senders
339 /// * `request_handler` - Optional request handler
340 /// * `notification_handler` - Optional notification handler
341 ///
342 /// # Errors
343 ///
344 /// Returns an error if the message cannot be parsed as valid JSON-RPC.
345 /// Handler errors are logged but do not propagate.
346 async fn route_message(
347 msg: TransportMessage,
348 response_waiters: &Arc<Mutex<HashMap<MessageId, oneshot::Sender<JsonRpcResponse>>>>,
349 request_handler: &Arc<Mutex<Option<RequestHandler>>>,
350 notification_handler: &Arc<Mutex<Option<NotificationHandler>>>,
351 ) -> Result<()> {
352 // Parse as JSON-RPC message
353 let json_msg: JsonRpcMessage = serde_json::from_slice(&msg.payload)
354 .map_err(|e| Error::protocol(format!("Invalid JSON-RPC message: {}", e)))?;
355
356 match json_msg {
357 JsonRpcMessage::Response(response) => {
358 // Route to waiting request() call
359 // ResponseId is Option<RequestId> where RequestId = MessageId
360 if let Some(request_id) = &response.id.0 {
361 if let Some(tx) = response_waiters
362 .lock()
363 .expect("response_waiters mutex poisoned")
364 .remove(request_id)
365 {
366 tracing::trace!("Routing response to request ID: {:?}", request_id);
367 // Send response through oneshot channel
368 // Ignore error if receiver was dropped (request timed out)
369 let _ = tx.send(response);
370 } else {
371 tracing::warn!(
372 "Received response for unknown/expired request ID: {:?}",
373 request_id
374 );
375 }
376 } else {
377 tracing::warn!("Received response with null ID (parse error)");
378 }
379 }
380
381 JsonRpcMessage::Request(request) => {
382 // Route to request handler (elicitation, sampling, etc.)
383 tracing::debug!(
384 "Routing server-initiated request: method={}, id={:?}",
385 request.method,
386 request.id
387 );
388
389 if let Some(handler) = request_handler
390 .lock()
391 .expect("request_handler mutex poisoned")
392 .as_ref()
393 {
394 // Call handler (handler is responsible for sending response)
395 if let Err(e) = handler(request) {
396 tracing::error!("Request handler error: {}", e);
397 }
398 } else {
399 tracing::warn!(
400 "Received server request but no handler registered: method={}",
401 request.method
402 );
403 }
404 }
405
406 JsonRpcMessage::Notification(notification) => {
407 // Route to notification handler
408 tracing::debug!(
409 "Routing server notification: method={}",
410 notification.method
411 );
412
413 if let Some(handler) = notification_handler
414 .lock()
415 .expect("notification_handler mutex poisoned")
416 .as_ref()
417 {
418 if let Err(e) = handler(notification) {
419 tracing::error!("Notification handler error: {}", e);
420 }
421 } else {
422 tracing::debug!(
423 "Received notification but no handler registered: method={}",
424 notification.method
425 );
426 }
427 }
428
429 // Allow deprecated for defensive pattern matching on batch types
430 // These exist only for defensive deserialization per MCP 2025-06-18 spec
431 #[allow(deprecated)]
432 JsonRpcMessage::RequestBatch(_)
433 | JsonRpcMessage::ResponseBatch(_)
434 | JsonRpcMessage::MessageBatch(_) => {
435 // Batch operations not supported per MCP 2025-06-18 specification
436 tracing::debug!("Received batch message (not supported per MCP specification)");
437 }
438 }
439
440 Ok(())
441 }
442}
443
/// Manual `Debug` implementation.
///
/// The trait-object handlers and the oneshot-sender map do not implement
/// `Debug`, so each field is rendered as an opaque type placeholder instead
/// of its contents.
impl std::fmt::Debug for MessageDispatcher {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MessageDispatcher")
            .field("response_waiters", &"<Arc<Mutex<HashMap>>>")
            .field("request_handler", &"<Arc<Mutex<Option<Handler>>>>")
            .field("notification_handler", &"<Arc<Mutex<Option<Handler>>>>")
            .field("shutdown", &"<Arc<Notify>>")
            .finish()
    }
}
454
#[cfg(test)]
mod tests {

    // Note: Full integration tests with mock transport will be added
    // in tests/bidirectional_integration.rs

    /// Smoke test: ensures the module compiles and basic structures are
    /// well-formed. Behavioral testing requires a mock transport, since
    /// `MessageDispatcher::new` needs a `Transport` implementation.
    #[test]
    fn test_dispatcher_creation() {
        // Smoke test to ensure the module compiles and basic structures work
        // Full testing requires a mock transport
    }
}