turbomcp_client/client/protocol.rs

//! Protocol client for JSON-RPC communication
//!
//! This module provides the ProtocolClient which handles the low-level
//! JSON-RPC protocol communication with MCP servers.
//!
//! ## Bidirectional Communication Architecture
//!
//! The ProtocolClient uses a MessageDispatcher to solve the bidirectional
//! communication problem. Instead of directly calling `transport.receive()`,
//! which created race conditions when multiple code paths tried to receive,
//! we now use a centralized message routing layer:
//!
//! ```text
//! ProtocolClient::request()
//!     ↓
//!   1. Register oneshot channel with dispatcher
//!   2. Send request via transport
//!   3. Wait on oneshot channel
//!     ↓
//! MessageDispatcher (background task)
//!     ↓
//!   Continuously reads transport.receive()
//!   Routes responses → oneshot channels
//!   Routes requests → Client handlers
//! ```
//!
//! This ensures there's only ONE consumer of transport.receive(),
//! eliminating the race condition.

30use std::sync::Arc;
31use std::sync::atomic::{AtomicU64, Ordering};
32
33use turbomcp_protocol::jsonrpc::{JsonRpcRequest, JsonRpcVersion};
34use turbomcp_protocol::{Error, Result};
35use turbomcp_transport::{Transport, TransportConfig, TransportMessage};
36
37use super::dispatcher::MessageDispatcher;
38
/// JSON-RPC protocol handler for MCP communication
///
/// Handles request/response correlation, serialization, and protocol-level concerns.
/// This is the abstraction layer between raw Transport and high-level Client APIs.
///
/// ## Architecture
///
/// The ProtocolClient now uses a MessageDispatcher to handle bidirectional
/// communication correctly. The dispatcher runs a background task that:
/// - Reads ALL messages from the transport
/// - Routes responses to waiting request() calls
/// - Routes incoming requests to registered handlers
///
/// This eliminates race conditions by centralizing all message routing
/// in a single background task.
#[derive(Debug)]
pub(super) struct ProtocolClient<T: Transport> {
    /// Shared transport handle; also held by the dispatcher's background
    /// task, which is the sole consumer of `transport.receive()`.
    transport: Arc<T>,
    /// Central message router: correlates responses to pending requests
    /// and forwards server-initiated requests to registered handlers.
    dispatcher: Arc<MessageDispatcher>,
    /// Monotonically increasing JSON-RPC request id counter (starts at 1).
    next_id: AtomicU64,
    /// Transport configuration for timeout enforcement (v2.2.0+)
    config: TransportConfig,
}
62
63impl<T: Transport + 'static> ProtocolClient<T> {
64    /// Create a new protocol client with message dispatcher
65    ///
66    /// This automatically starts the message routing background task.
67    pub(super) fn new(transport: T) -> Self {
68        let transport = Arc::new(transport);
69        let dispatcher = MessageDispatcher::new(transport.clone());
70
71        Self {
72            transport,
73            dispatcher,
74            next_id: AtomicU64::new(1),
75            config: TransportConfig::default(), // Use default timeout config
76        }
77    }
78
79    /// Create a new protocol client with custom transport configuration
80    ///
81    /// This allows setting custom timeouts and limits.
82    #[allow(dead_code)] // May be used in future
83    pub(super) fn with_config(transport: T, config: TransportConfig) -> Self {
84        let transport = Arc::new(transport);
85        let dispatcher = MessageDispatcher::new(transport.clone());
86
87        Self {
88            transport,
89            dispatcher,
90            next_id: AtomicU64::new(1),
91            config,
92        }
93    }
94
95    /// Get the message dispatcher for handler registration
96    ///
97    /// This allows the Client to register request/notification handlers
98    /// with the dispatcher.
99    pub(super) fn dispatcher(&self) -> &Arc<MessageDispatcher> {
100        &self.dispatcher
101    }
102
103    /// Send JSON-RPC request and await typed response
104    ///
105    /// ## New Architecture (v2.0+)
106    ///
107    /// Instead of calling `transport.receive()` directly (which created the
108    /// race condition), this method now:
109    ///
110    /// 1. Registers a oneshot channel with the dispatcher BEFORE sending
111    /// 2. Sends the request via transport
112    /// 3. Waits on the oneshot channel for the response
113    ///
114    /// The dispatcher's background task receives the response and routes it
115    /// to the oneshot channel. This ensures responses always reach the right
116    /// request() call, even when the server sends requests (elicitation, etc.)
117    /// in between.
118    ///
119    /// ## Example Flow with Elicitation
120    ///
121    /// ```text
122    /// Client: call_tool("test") → request(id=1)
123    ///   1. Register oneshot channel for id=1
124    ///   2. Send tools/call request
125    ///   3. Wait on channel...
126    ///
127    /// Server: Sends elicitation/create request (id=2)
128    ///   → Dispatcher routes to request handler
129    ///   → Client processes elicitation
130    ///   → Client sends elicitation response
131    ///
132    /// Server: Sends tools/call response (id=1)
133    ///   → Dispatcher routes to oneshot channel for id=1
134    ///   → request() receives response ✓
135    /// ```
136    pub(super) async fn request<R: serde::de::DeserializeOwned>(
137        &self,
138        method: &str,
139        params: Option<serde_json::Value>,
140    ) -> Result<R> {
141        // Wrap the entire operation in total timeout (if configured)
142        let operation = self.request_inner(method, params);
143
144        if let Some(total_timeout) = self.config.timeouts.total {
145            match tokio::time::timeout(total_timeout, operation).await {
146                Ok(result) => result,
147                Err(_) => {
148                    let err = turbomcp_transport::TransportError::TotalTimeout {
149                        operation: format!("{}()", method),
150                        timeout: total_timeout,
151                    };
152                    Err(Error::transport(err.to_string()))
153                }
154            }
155        } else {
156            operation.await
157        }
158    }
159
160    /// Inner request implementation without total timeout wrapper
161    async fn request_inner<R: serde::de::DeserializeOwned>(
162        &self,
163        method: &str,
164        params: Option<serde_json::Value>,
165    ) -> Result<R> {
166        // Generate unique request ID
167        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
168        let request_id = turbomcp_protocol::MessageId::from(id.to_string());
169
170        // Build JSON-RPC request
171        let request = JsonRpcRequest {
172            jsonrpc: JsonRpcVersion,
173            id: request_id.clone(),
174            method: method.to_string(),
175            params,
176        };
177
178        // Step 1: Register oneshot channel BEFORE sending request
179        // This ensures the dispatcher can route the response when it arrives
180        let response_receiver = self.dispatcher.wait_for_response(request_id.clone());
181
182        // Step 2: Serialize and send request
183        let payload = serde_json::to_vec(&request)
184            .map_err(|e| Error::protocol(format!("Failed to serialize request: {e}")))?;
185
186        let message = TransportMessage::new(
187            turbomcp_protocol::MessageId::from(format!("req-{id}")),
188            payload.into(),
189        );
190
191        self.transport
192            .send(message)
193            .await
194            .map_err(|e| Error::transport(format!("Transport send failed: {e}")))?;
195
196        // Step 3: Wait for response via oneshot channel with request timeout
197        // The dispatcher's background task will send the response when it arrives
198        let response = if let Some(request_timeout) = self.config.timeouts.request {
199            match tokio::time::timeout(request_timeout, response_receiver).await {
200                Ok(Ok(response)) => response,
201                Ok(Err(_)) => return Err(Error::transport("Response channel closed".to_string())),
202                Err(_) => {
203                    let err = turbomcp_transport::TransportError::RequestTimeout {
204                        operation: format!("{}()", method),
205                        timeout: request_timeout,
206                    };
207                    return Err(Error::transport(err.to_string()));
208                }
209            }
210        } else {
211            response_receiver
212                .await
213                .map_err(|_| Error::transport("Response channel closed".to_string()))?
214        };
215
216        // Handle JSON-RPC errors
217        if let Some(error) = response.error() {
218            tracing::info!(
219                "🔍 [protocol.rs] Received JSON-RPC error - code: {}, message: {}",
220                error.code,
221                error.message
222            );
223            let err = Error::rpc(error.code, &error.message);
224            tracing::info!(
225                "🔍 [protocol.rs] Created Error - kind: {:?}, jsonrpc_code: {}",
226                err.kind,
227                err.jsonrpc_error_code()
228            );
229            return Err(err);
230        }
231
232        // Deserialize result
233        serde_json::from_value(response.result().unwrap_or_default().clone())
234            .map_err(|e| Error::protocol(format!("Failed to deserialize response: {e}")))
235    }
236
237    /// Send JSON-RPC notification (no response expected)
238    pub(super) async fn notify(
239        &self,
240        method: &str,
241        params: Option<serde_json::Value>,
242    ) -> Result<()> {
243        let request = serde_json::json!({
244            "jsonrpc": "2.0",
245            "method": method,
246            "params": params
247        });
248
249        let payload = serde_json::to_vec(&request)
250            .map_err(|e| Error::protocol(format!("Failed to serialize notification: {e}")))?;
251
252        let message = TransportMessage::new(
253            turbomcp_protocol::MessageId::from("notification"),
254            payload.into(),
255        );
256
257        self.transport
258            .send(message)
259            .await
260            .map_err(|e| Error::transport(format!("Transport send failed: {e}")))
261    }
262
263    /// Connect the transport
264    #[allow(dead_code)] // Reserved for future use
265    pub(super) async fn connect(&self) -> Result<()> {
266        self.transport
267            .connect()
268            .await
269            .map_err(|e| Error::transport(format!("Transport connect failed: {e}")))
270    }
271
272    /// Disconnect the transport
273    #[allow(dead_code)] // Reserved for future use
274    pub(super) async fn disconnect(&self) -> Result<()> {
275        self.transport
276            .disconnect()
277            .await
278            .map_err(|e| Error::transport(format!("Transport disconnect failed: {e}")))
279    }
280
281    /// Get transport reference
282    ///
283    /// Returns an Arc reference to the transport, allowing it to be shared
284    /// with other components (like the message dispatcher).
285    pub(super) fn transport(&self) -> &Arc<T> {
286        &self.transport
287    }
288}