turbomcp_client/client/
dispatcher.rs

1//! Message dispatcher for routing JSON-RPC messages
2//!
3//! This module implements the message routing layer that solves the bidirectional
4//! communication problem. It runs a background task that reads ALL messages from
5//! the transport and routes them appropriately:
6//!
7//! - **Responses** → Routed to waiting `request()` calls via oneshot channels
8//! - **Requests** → Routed to registered request handler (for elicitation, sampling, etc.)
9//! - **Notifications** → Routed to registered notification handler
10//!
11//! ## Architecture
12//!
13//! ```text
14//! ┌──────────────────────────────────────────────┐
15//! │          MessageDispatcher                   │
16//! │                                              │
17//! │  Background Task (tokio::spawn):             │
18//! │  loop {                                      │
19//! │    msg = transport.receive().await           │
20//! │    match parse(msg) {                        │
21//! │      Response => send to oneshot channel     │
22//! │      Request => call request_handler         │
23//! │      Notification => call notif_handler      │
24//! │    }                                         │
25//! │  }                                           │
26//! └──────────────────────────────────────────────┘
27//! ```
28//!
29//! This ensures that there's only ONE consumer of `transport.receive()`,
30//! eliminating race conditions by centralizing all message routing.
31
32use std::collections::HashMap;
33use std::sync::{Arc, Mutex}; // Use std::sync::Mutex for simpler synchronous access
34
35use tokio::sync::{Notify, oneshot};
36use turbomcp_protocol::jsonrpc::{
37    JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse,
38};
39use turbomcp_protocol::{Error, MessageId, Result};
40use turbomcp_transport::{Transport, TransportMessage};
41
42/// Type alias for request handler functions
43///
44/// The handler receives a request and processes it asynchronously.
45/// It's responsible for sending responses back via the transport.
46type RequestHandler = Arc<dyn Fn(JsonRpcRequest) -> Result<()> + Send + Sync>;
47
48/// Type alias for notification handler functions
49///
50/// The handler receives a notification and processes it asynchronously.
51type NotificationHandler = Arc<dyn Fn(JsonRpcNotification) -> Result<()> + Send + Sync>;
52
53/// Message dispatcher that routes incoming JSON-RPC messages
54///
55/// The dispatcher solves the bidirectional communication problem by being the
56/// SINGLE consumer of `transport.receive()`. It runs a background task that
57/// continuously reads messages and routes them to the appropriate handlers.
58///
59/// # Design Principles
60///
61/// 1. **Single Responsibility**: Only handles message routing, not processing
62/// 2. **Thread-Safe**: All state protected by Arc<Mutex<...>>
63/// 3. **Graceful Shutdown**: Supports clean shutdown via Notify signal
64/// 4. **Error Resilient**: Continues running even if individual messages fail
65/// 5. **Production-Ready**: Comprehensive logging and error handling
66///
67/// # Known Limitations
68///
69/// **Response Waiter Cleanup**: If a request is made but the response never arrives
70/// (e.g., server crash, network partition), the oneshot sender remains in the
71/// `response_waiters` HashMap indefinitely. This has minimal impact because:
72/// - Oneshot senders have a small memory footprint (~24 bytes)
73/// - In practice, responses arrive or clients timeout and drop the receiver
74/// - When a receiver is dropped, the send fails gracefully (error is ignored)
75///
76/// Future enhancement: Add a background cleanup task or request timeout mechanism
77/// to remove stale entries after a configurable duration.
78///
79/// # Example
80///
81/// ```rust,ignore
82/// let dispatcher = MessageDispatcher::new(Arc::new(transport));
83///
84/// // Register handlers
85/// dispatcher.set_request_handler(Arc::new(|req| {
86///     // Handle server-initiated requests (elicitation, sampling)
87///     Ok(())
88/// })).await;
89///
90/// // Wait for a response to a specific request
91/// let id = MessageId::from("req-123");
92/// let receiver = dispatcher.wait_for_response(id.clone()).await;
93///
94/// // The background task routes the response when it arrives
95/// let response = receiver.await?;
96/// ```
97pub(super) struct MessageDispatcher {
98    /// Map of request IDs to oneshot senders for response routing
99    ///
100    /// When `ProtocolClient::request()` sends a request, it registers a oneshot
101    /// channel here. When the dispatcher receives the corresponding response,
102    /// it sends it through the channel.
103    response_waiters: Arc<Mutex<HashMap<MessageId, oneshot::Sender<JsonRpcResponse>>>>,
104
105    /// Optional handler for server-initiated requests (elicitation, sampling)
106    ///
107    /// This is set by the Client to handle incoming requests from the server.
108    /// The handler is responsible for processing the request and sending a response.
109    request_handler: Arc<Mutex<Option<RequestHandler>>>,
110
111    /// Optional handler for server-initiated notifications
112    ///
113    /// This is set by the Client to handle incoming notifications from the server.
114    notification_handler: Arc<Mutex<Option<NotificationHandler>>>,
115
116    /// Shutdown signal for graceful termination
117    ///
118    /// When `shutdown()` is called, this notify wakes up the background task
119    /// which then exits cleanly.
120    shutdown: Arc<Notify>,
121}
122
123impl MessageDispatcher {
124    /// Create a new message dispatcher and start the background routing task
125    ///
126    /// The dispatcher immediately spawns a background task that continuously
127    /// reads messages from the transport and routes them appropriately.
128    ///
129    /// # Arguments
130    ///
131    /// * `transport` - The transport to read messages from
132    ///
133    /// # Returns
134    ///
135    /// Returns a new `MessageDispatcher` with the routing task running.
136    pub fn new<T: Transport + 'static>(transport: Arc<T>) -> Arc<Self> {
137        let dispatcher = Arc::new(Self {
138            response_waiters: Arc::new(Mutex::new(HashMap::new())),
139            request_handler: Arc::new(Mutex::new(None)),
140            notification_handler: Arc::new(Mutex::new(None)),
141            shutdown: Arc::new(Notify::new()),
142        });
143
144        // Start background routing task
145        Self::spawn_routing_task(dispatcher.clone(), transport);
146
147        dispatcher
148    }
149
150    /// Register a request handler for server-initiated requests
151    ///
152    /// This handler will be called when the server sends a request (like
153    /// elicitation/create or sampling/createMessage). The handler is responsible
154    /// for processing the request and sending a response back.
155    ///
156    /// # Arguments
157    ///
158    /// * `handler` - Function to handle incoming requests
159    pub fn set_request_handler(&self, handler: RequestHandler) {
160        *self.request_handler.lock().expect("handler mutex poisoned") = Some(handler);
161        tracing::debug!("Request handler registered with dispatcher");
162    }
163
164    /// Register a notification handler for server-initiated notifications
165    ///
166    /// This handler will be called when the server sends a notification.
167    ///
168    /// # Arguments
169    ///
170    /// * `handler` - Function to handle incoming notifications
171    pub fn set_notification_handler(&self, handler: NotificationHandler) {
172        *self
173            .notification_handler
174            .lock()
175            .expect("handler mutex poisoned") = Some(handler);
176        tracing::debug!("Notification handler registered with dispatcher");
177    }
178
179    /// Wait for a response to a specific request ID
180    ///
181    /// This method is called by `ProtocolClient::request()` before sending a request.
182    /// It registers a oneshot channel that will receive the response when it arrives.
183    ///
184    /// # Arguments
185    ///
186    /// * `id` - The request ID to wait for
187    ///
188    /// # Returns
189    ///
190    /// Returns a oneshot receiver that will be sent the response when it arrives.
191    ///
192    /// # Example
193    ///
194    /// ```rust,ignore
195    /// // Register waiter before sending request
196    /// let id = MessageId::from("req-123");
197    /// let receiver = dispatcher.wait_for_response(id.clone()).await;
198    ///
199    /// // Send request...
200    ///
201    /// // Wait for response
202    /// let response = receiver.await?;
203    /// ```
204    pub fn wait_for_response(&self, id: MessageId) -> oneshot::Receiver<JsonRpcResponse> {
205        let (tx, rx) = oneshot::channel();
206        self.response_waiters
207            .lock()
208            .expect("response_waiters mutex poisoned")
209            .insert(id.clone(), tx);
210        tracing::trace!("Registered response waiter for request ID: {:?}", id);
211        rx
212    }
213
214    /// Signal the dispatcher to shutdown gracefully
215    ///
216    /// This notifies the background routing task to exit cleanly.
217    /// The task will finish processing the current message and then terminate.
218    ///
219    /// This method is called automatically when the Client is dropped,
220    /// ensuring proper cleanup of background resources.
221    pub fn shutdown(&self) {
222        self.shutdown.notify_one();
223        tracing::info!("Message dispatcher shutdown initiated");
224    }
225
226    /// Spawn the background routing task
227    ///
228    /// This task continuously reads messages from the transport and routes them
229    /// to the appropriate handlers. It runs until `shutdown()` is called or
230    /// the transport is closed.
231    ///
232    /// # Arguments
233    ///
234    /// * `dispatcher` - Arc reference to the dispatcher
235    /// * `transport` - Arc reference to the transport
236    fn spawn_routing_task<T: Transport + 'static>(dispatcher: Arc<Self>, transport: Arc<T>) {
237        let response_waiters = dispatcher.response_waiters.clone();
238        let request_handler = dispatcher.request_handler.clone();
239        let notification_handler = dispatcher.notification_handler.clone();
240        let shutdown = dispatcher.shutdown.clone();
241
242        tokio::spawn(async move {
243            tracing::info!("Message dispatcher routing task started");
244
245            let mut consecutive_errors = 0u32;
246            let max_consecutive_errors = 20; // After 20 consecutive errors, back off significantly
247
248            loop {
249                tokio::select! {
250                    // Graceful shutdown
251                    _ = shutdown.notified() => {
252                        tracing::info!("Message dispatcher routing task shutting down");
253                        break;
254                    }
255
256                    // Read and route messages
257                    result = transport.receive() => {
258                        match result {
259                            Ok(Some(msg)) => {
260                                // Successfully received message - reset error counter
261                                consecutive_errors = 0;
262
263                                // Route the message
264                                if let Err(e) = Self::route_message(
265                                    msg,
266                                    &response_waiters,
267                                    &request_handler,
268                                    &notification_handler,
269                                ).await {
270                                    tracing::error!("Error routing message: {}", e);
271                                }
272                            }
273                            Ok(None) => {
274                                // No message available - transport returned None
275                                // Brief sleep to avoid busy-waiting
276                                tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
277                            }
278                            Err(e) => {
279                                consecutive_errors += 1;
280
281                                // Check transport state to determine error severity
282                                let state = transport.state().await;
283                                let is_fatal = matches!(state, turbomcp_transport::TransportState::Disconnected
284                                                             | turbomcp_transport::TransportState::Failed { .. });
285
286                                if consecutive_errors == 1 {
287                                    // First error - log at error level
288                                    tracing::error!("Transport receive error: {}", e);
289                                } else if consecutive_errors <= max_consecutive_errors {
290                                    // Subsequent errors - log at warn to reduce noise
291                                    tracing::warn!("Transport receive error (attempt {}): {}", consecutive_errors, e);
292                                } else {
293                                    // Too many errors - log once and suppress further logs
294                                    if consecutive_errors == max_consecutive_errors + 1 {
295                                        tracing::error!(
296                                            "Transport in failed state ({}), suppressing further error logs. Waiting for recovery...",
297                                            state
298                                        );
299                                    }
300                                }
301
302                                // Exponential backoff based on error count and transport state
303                                let delay_ms = if is_fatal {
304                                    // Fatal error - wait longer to avoid spam
305                                    if consecutive_errors > max_consecutive_errors {
306                                        5000 // 5 seconds when transport is dead
307                                    } else {
308                                        1000 // 1 second initially
309                                    }
310                                } else {
311                                    // Transient error - shorter backoff
312                                    100u64.saturating_mul(2u64.saturating_pow(consecutive_errors.min(5)))
313                                };
314
315                                tokio::time::sleep(tokio::time::Duration::from_millis(delay_ms)).await;
316                            }
317                        }
318                    }
319                }
320            }
321
322            tracing::info!("Message dispatcher routing task terminated");
323        });
324    }
325
326    /// Route an incoming message to the appropriate handler
327    ///
328    /// This is the core routing logic. It parses the raw transport message as
329    /// a JSON-RPC message and routes it based on type:
330    ///
331    /// - **Response**: Look up the waiting oneshot channel and send the response
332    /// - **Request**: Call the registered request handler
333    /// - **Notification**: Call the registered notification handler
334    ///
335    /// # Arguments
336    ///
337    /// * `msg` - The raw transport message to route
338    /// * `response_waiters` - Map of request IDs to oneshot senders
339    /// * `request_handler` - Optional request handler
340    /// * `notification_handler` - Optional notification handler
341    ///
342    /// # Errors
343    ///
344    /// Returns an error if the message cannot be parsed as valid JSON-RPC.
345    /// Handler errors are logged but do not propagate.
346    async fn route_message(
347        msg: TransportMessage,
348        response_waiters: &Arc<Mutex<HashMap<MessageId, oneshot::Sender<JsonRpcResponse>>>>,
349        request_handler: &Arc<Mutex<Option<RequestHandler>>>,
350        notification_handler: &Arc<Mutex<Option<NotificationHandler>>>,
351    ) -> Result<()> {
352        // Parse as JSON-RPC message
353        let json_msg: JsonRpcMessage = serde_json::from_slice(&msg.payload)
354            .map_err(|e| Error::protocol(format!("Invalid JSON-RPC message: {}", e)))?;
355
356        match json_msg {
357            JsonRpcMessage::Response(response) => {
358                // Route to waiting request() call
359                // ResponseId is Option<RequestId> where RequestId = MessageId
360                if let Some(request_id) = &response.id.0 {
361                    if let Some(tx) = response_waiters
362                        .lock()
363                        .expect("response_waiters mutex poisoned")
364                        .remove(request_id)
365                    {
366                        tracing::trace!("Routing response to request ID: {:?}", request_id);
367                        // Send response through oneshot channel
368                        // Ignore error if receiver was dropped (request timed out)
369                        let _ = tx.send(response);
370                    } else {
371                        tracing::warn!(
372                            "Received response for unknown/expired request ID: {:?}",
373                            request_id
374                        );
375                    }
376                } else {
377                    tracing::warn!("Received response with null ID (parse error)");
378                }
379            }
380
381            JsonRpcMessage::Request(request) => {
382                // Route to request handler (elicitation, sampling, etc.)
383                tracing::debug!(
384                    "Routing server-initiated request: method={}, id={:?}",
385                    request.method,
386                    request.id
387                );
388
389                if let Some(handler) = request_handler
390                    .lock()
391                    .expect("request_handler mutex poisoned")
392                    .as_ref()
393                {
394                    // Call handler (handler is responsible for sending response)
395                    if let Err(e) = handler(request) {
396                        tracing::error!("Request handler error: {}", e);
397                    }
398                } else {
399                    tracing::warn!(
400                        "Received server request but no handler registered: method={}",
401                        request.method
402                    );
403                }
404            }
405
406            JsonRpcMessage::Notification(notification) => {
407                // Route to notification handler
408                tracing::debug!(
409                    "Routing server notification: method={}",
410                    notification.method
411                );
412
413                if let Some(handler) = notification_handler
414                    .lock()
415                    .expect("notification_handler mutex poisoned")
416                    .as_ref()
417                {
418                    if let Err(e) = handler(notification) {
419                        tracing::error!("Notification handler error: {}", e);
420                    }
421                } else {
422                    tracing::debug!(
423                        "Received notification but no handler registered: method={}",
424                        notification.method
425                    );
426                }
427            }
428
429            JsonRpcMessage::RequestBatch(_)
430            | JsonRpcMessage::ResponseBatch(_)
431            | JsonRpcMessage::MessageBatch(_) => {
432                // Batch operations not yet supported
433                tracing::debug!("Received batch message (not yet supported)");
434            }
435        }
436
437        Ok(())
438    }
439}
440
441impl std::fmt::Debug for MessageDispatcher {
442    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
443        f.debug_struct("MessageDispatcher")
444            .field("response_waiters", &"<Arc<Mutex<HashMap>>>")
445            .field("request_handler", &"<Arc<Mutex<Option<Handler>>>>")
446            .field("notification_handler", &"<Arc<Mutex<Option<Handler>>>>")
447            .field("shutdown", &"<Arc<Notify>>")
448            .finish()
449    }
450}
451
#[cfg(test)]
mod tests {
    use super::*;

    // Integration tests that drive the routing task with a mock transport
    // belong in tests/bidirectional_integration.rs. The unit tests below
    // exercise the registration API directly, without a transport or runtime.

    /// Build a dispatcher from its fields, bypassing `new()` so that no
    /// transport is needed and no routing task is spawned.
    fn dispatcher_without_transport() -> MessageDispatcher {
        MessageDispatcher {
            response_waiters: Arc::new(Mutex::new(HashMap::new())),
            request_handler: Arc::new(Mutex::new(None)),
            notification_handler: Arc::new(Mutex::new(None)),
            shutdown: Arc::new(Notify::new()),
        }
    }

    #[test]
    fn test_dispatcher_creation() {
        let dispatcher = dispatcher_without_transport();
        assert!(dispatcher.response_waiters.lock().unwrap().is_empty());
        assert!(dispatcher.request_handler.lock().unwrap().is_none());
        assert!(dispatcher.notification_handler.lock().unwrap().is_none());
    }

    #[test]
    fn wait_for_response_registers_waiter() {
        let dispatcher = dispatcher_without_transport();
        let id = MessageId::from("req-1");

        // Keep the receiver alive so the waiter is still live.
        let _receiver = dispatcher.wait_for_response(id.clone());

        let waiters = dispatcher.response_waiters.lock().unwrap();
        assert_eq!(waiters.len(), 1);
        assert!(waiters.contains_key(&id));
    }

    #[test]
    fn handlers_can_be_registered() {
        let dispatcher = dispatcher_without_transport();

        dispatcher.set_request_handler(Arc::new(
            |_request: JsonRpcRequest| -> turbomcp_protocol::Result<()> { Ok(()) },
        ));
        dispatcher.set_notification_handler(Arc::new(
            |_notification: JsonRpcNotification| -> turbomcp_protocol::Result<()> { Ok(()) },
        ));

        assert!(dispatcher.request_handler.lock().unwrap().is_some());
        assert!(dispatcher.notification_handler.lock().unwrap().is_some());
    }

    #[test]
    fn debug_output_names_the_type() {
        let dispatcher = dispatcher_without_transport();
        let rendered = format!("{:?}", dispatcher);
        assert!(rendered.contains("MessageDispatcher"));
        assert!(rendered.contains("response_waiters"));
    }
}