mcp_core/
protocol.rs

1//! # MCP Protocol Implementation
2//!
3//! This module implements the core JSON-RPC protocol layer used by the MCP system.
4//! It provides the infrastructure for sending and receiving JSON-RPC requests,
5//! notifications, and responses between MCP clients and servers.
6//!
7//! The protocol layer is transport-agnostic and can work with any transport
8//! implementation that conforms to the `Transport` trait.
9//!
10//! Key components include:
11//! - `Protocol`: The main protocol handler
12//! - `ProtocolBuilder`: A builder for configuring protocols
13//! - Request and notification handlers
14//! - Timeout and error handling
15
16use super::transport::{JsonRpcError, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse};
17use super::types::ErrorCode;
18use anyhow::Result;
19use async_trait::async_trait;
20use serde::de::DeserializeOwned;
21use serde::Serialize;
22use serde_json::json;
23use std::pin::Pin;
24
25use std::sync::atomic::{AtomicU64, Ordering};
26use std::time::Duration;
27use std::{collections::HashMap, sync::Arc};
28use tokio::sync::{oneshot, Mutex};
29
30/// The core protocol handler for MCP.
31///
32/// The `Protocol` struct manages the lifecycle of JSON-RPC requests and responses,
33/// dispatches incoming requests to the appropriate handlers, and manages
34/// pending requests and their responses.
35#[derive(Clone)]
36pub struct Protocol {
37    request_id: Arc<AtomicU64>,
38    pending_requests: Arc<Mutex<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>>,
39    request_handlers: Arc<Mutex<HashMap<String, Box<dyn RequestHandler>>>>,
40    notification_handlers: Arc<Mutex<HashMap<String, Box<dyn NotificationHandler>>>>,
41}
42
43impl Protocol {
44    /// Creates a new protocol builder.
45    ///
46    /// # Returns
47    ///
48    /// A `ProtocolBuilder` for configuring the protocol
49    pub fn builder() -> ProtocolBuilder {
50        ProtocolBuilder::new()
51    }
52
53    /// Handles an incoming JSON-RPC request.
54    ///
55    /// This method dispatches the request to the appropriate handler based on
56    /// the request method, and returns the handler's response.
57    ///
58    /// # Arguments
59    ///
60    /// * `request` - The incoming JSON-RPC request
61    ///
62    /// # Returns
63    ///
64    /// A `JsonRpcResponse` containing the handler's response or an error
65    pub async fn handle_request(&self, request: JsonRpcRequest) -> JsonRpcResponse {
66        let handlers = self.request_handlers.lock().await;
67        if let Some(handler) = handlers.get(&request.method) {
68            match handler.handle(request.clone()).await {
69                Ok(response) => response,
70                Err(e) => JsonRpcResponse {
71                    id: request.id,
72                    result: None,
73                    error: Some(JsonRpcError {
74                        code: ErrorCode::InternalError as i32,
75                        message: e.to_string(),
76                        data: None,
77                    }),
78                    ..Default::default()
79                },
80            }
81        } else {
82            JsonRpcResponse {
83                id: request.id,
84                error: Some(JsonRpcError {
85                    code: ErrorCode::MethodNotFound as i32,
86                    message: format!("Method not found: {}", request.method),
87                    data: None,
88                }),
89                ..Default::default()
90            }
91        }
92    }
93
94    /// Handles an incoming JSON-RPC notification.
95    ///
96    /// This method dispatches the notification to the appropriate handler based on
97    /// the notification method.
98    ///
99    /// # Arguments
100    ///
101    /// * `request` - The incoming JSON-RPC notification
102    pub async fn handle_notification(&self, request: JsonRpcNotification) {
103        let handlers = self.notification_handlers.lock().await;
104        if let Some(handler) = handlers.get(&request.method) {
105            match handler.handle(request.clone()).await {
106                Ok(_) => tracing::info!("Received notification: {:?}", request.method),
107                Err(e) => tracing::error!("Error handling notification: {}", e),
108            }
109        } else {
110            tracing::debug!("No handler for notification: {}", request.method);
111        }
112    }
113
114    /// Generates a new unique message ID for requests.
115    ///
116    /// # Returns
117    ///
118    /// A unique message ID
119    pub fn new_message_id(&self) -> u64 {
120        self.request_id.fetch_add(1, Ordering::SeqCst)
121    }
122
123    /// Creates a new request ID and channel for receiving the response.
124    ///
125    /// # Returns
126    ///
127    /// A tuple containing the request ID and a receiver for the response
128    pub async fn create_request(&self) -> (u64, oneshot::Receiver<JsonRpcResponse>) {
129        let id = self.new_message_id();
130        let (tx, rx) = oneshot::channel();
131
132        {
133            let mut pending = self.pending_requests.lock().await;
134            pending.insert(id, tx);
135        }
136
137        (id, rx)
138    }
139
140    /// Handles an incoming JSON-RPC response.
141    ///
142    /// This method delivers the response to the appropriate waiting request,
143    /// if any.
144    ///
145    /// # Arguments
146    ///
147    /// * `response` - The incoming JSON-RPC response
148    pub async fn handle_response(&self, response: JsonRpcResponse) {
149        if let Some(tx) = self.pending_requests.lock().await.remove(&response.id) {
150            let _ = tx.send(response);
151        }
152    }
153
154    /// Cancels a pending request and sends an error response.
155    ///
156    /// # Arguments
157    ///
158    /// * `id` - The ID of the request to cancel
159    pub async fn cancel_response(&self, id: u64) {
160        if let Some(tx) = self.pending_requests.lock().await.remove(&id) {
161            let _ = tx.send(JsonRpcResponse {
162                id,
163                result: None,
164                error: Some(JsonRpcError {
165                    code: ErrorCode::RequestTimeout as i32,
166                    message: "Request cancelled".to_string(),
167                    data: None,
168                }),
169                ..Default::default()
170            });
171        }
172    }
173}
174
/// Default request timeout in milliseconds (60 seconds).
pub const DEFAULT_REQUEST_TIMEOUT_MSEC: u64 = 60_000;
177
/// Options for customizing requests.
///
/// Currently the only setting is the request timeout. The type derives
/// `Debug` and `Clone` so options can be logged and reused across requests.
#[derive(Debug, Clone)]
pub struct RequestOptions {
    /// The timeout duration for the request
    pub timeout: Duration,
}
186
187impl RequestOptions {
188    /// Sets the timeout for the request.
189    ///
190    /// # Arguments
191    ///
192    /// * `timeout` - The timeout duration
193    ///
194    /// # Returns
195    ///
196    /// The modified options instance
197    pub fn timeout(self, timeout: Duration) -> Self {
198        Self { timeout }
199    }
200}
201
202impl Default for RequestOptions {
203    fn default() -> Self {
204        Self {
205            timeout: Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MSEC),
206        }
207    }
208}
209
210/// Builder for creating configured protocols.
211///
212/// The `ProtocolBuilder` provides a fluent API for configuring and creating
213/// protocols with specific request and notification handlers.
214#[derive(Clone)]
215pub struct ProtocolBuilder {
216    request_handlers: Arc<Mutex<HashMap<String, Box<dyn RequestHandler>>>>,
217    notification_handlers: Arc<Mutex<HashMap<String, Box<dyn NotificationHandler>>>>,
218}
219
220impl ProtocolBuilder {
221    /// Creates a new protocol builder.
222    ///
223    /// # Returns
224    ///
225    /// A new `ProtocolBuilder` instance
226    pub fn new() -> Self {
227        Self {
228            request_handlers: Arc::new(Mutex::new(HashMap::new())),
229            notification_handlers: Arc::new(Mutex::new(HashMap::new())),
230        }
231    }
232
233    /// Registers a typed request handler.
234    ///
235    /// # Arguments
236    ///
237    /// * `method` - The method name to handle
238    /// * `handler` - The handler function
239    ///
240    /// # Returns
241    ///
242    /// The modified builder instance
243    pub fn request_handler<Req, Resp>(
244        self,
245        method: &str,
246        handler: impl Fn(Req) -> Pin<Box<dyn std::future::Future<Output = Result<Resp>> + Send>>
247            + Send
248            + Sync
249            + 'static,
250    ) -> Self
251    where
252        Req: DeserializeOwned + Send + Sync + 'static,
253        Resp: Serialize + Send + Sync + 'static,
254    {
255        let handler = TypedRequestHandler {
256            handler: Box::new(handler),
257            _phantom: std::marker::PhantomData,
258        };
259
260        if let Ok(mut handlers) = self.request_handlers.try_lock() {
261            handlers.insert(method.to_string(), Box::new(handler));
262        }
263        self
264    }
265
266    /// Checks if a request handler exists for a method.
267    ///
268    /// # Arguments
269    ///
270    /// * `method` - The method name to check
271    ///
272    /// # Returns
273    ///
274    /// `true` if a handler exists, `false` otherwise
275    pub fn has_request_handler(&self, method: &str) -> bool {
276        self.request_handlers
277            .try_lock()
278            .map(|handlers| handlers.contains_key(method))
279            .unwrap_or(false)
280    }
281
282    /// Registers a typed notification handler.
283    ///
284    /// # Arguments
285    ///
286    /// * `method` - The method name to handle
287    /// * `handler` - The handler function
288    ///
289    /// # Returns
290    ///
291    /// The modified builder instance
292    pub fn notification_handler<N>(
293        self,
294        method: &str,
295        handler: impl Fn(N) -> Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
296            + Send
297            + Sync
298            + 'static,
299    ) -> Self
300    where
301        N: DeserializeOwned + Send + Sync + 'static,
302    {
303        let handler = TypedNotificationHandler {
304            handler: Box::new(handler),
305            _phantom: std::marker::PhantomData,
306        };
307
308        if let Ok(mut handlers) = self.notification_handlers.try_lock() {
309            handlers.insert(method.to_string(), Box::new(handler));
310        }
311        self
312    }
313
314    /// Checks if a notification handler exists for a method.
315    ///
316    /// # Arguments
317    ///
318    /// * `method` - The method name to check
319    ///
320    /// # Returns
321    ///
322    /// `true` if a handler exists, `false` otherwise
323    pub fn has_notification_handler(&self, method: &str) -> bool {
324        self.notification_handlers
325            .try_lock()
326            .map(|handlers| handlers.contains_key(method))
327            .unwrap_or(false)
328    }
329
330    /// Builds the protocol with the configured handlers.
331    ///
332    /// # Returns
333    ///
334    /// A new `Protocol` instance
335    pub fn build(self) -> Protocol {
336        Protocol {
337            request_id: Arc::new(AtomicU64::new(0)),
338            pending_requests: Arc::new(Mutex::new(HashMap::new())),
339            request_handlers: self.request_handlers,
340            notification_handlers: self.notification_handlers,
341        }
342    }
343}
344
345/// Trait for handling JSON-RPC requests.
346///
347/// Implementors of this trait can handle incoming JSON-RPC requests
348/// and produce responses.
349#[async_trait]
350trait RequestHandler: Send + Sync {
351    /// Handles an incoming JSON-RPC request.
352    ///
353    /// # Arguments
354    ///
355    /// * `request` - The incoming JSON-RPC request
356    ///
357    /// # Returns
358    ///
359    /// A `Result` containing the response or an error
360    async fn handle(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse>;
361}
362
363/// Trait for handling JSON-RPC notifications.
364///
365/// Implementors of this trait can handle incoming JSON-RPC notifications.
366#[async_trait]
367trait NotificationHandler: Send + Sync {
368    /// Handles an incoming JSON-RPC notification.
369    ///
370    /// # Arguments
371    ///
372    /// * `notification` - The incoming JSON-RPC notification
373    ///
374    /// # Returns
375    ///
376    /// A `Result` indicating success or failure
377    async fn handle(&self, notification: JsonRpcNotification) -> Result<()>;
378}
379
380/// A typed request handler.
381///
382/// This struct adapts a typed handler function to the `RequestHandler` trait,
383/// handling the deserialization of the request and serialization of the response.
384struct TypedRequestHandler<Req, Resp>
385where
386    Req: DeserializeOwned + Send + Sync + 'static,
387    Resp: Serialize + Send + Sync + 'static,
388{
389    handler: Box<
390        dyn Fn(Req) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Resp>> + Send>>
391            + Send
392            + Sync,
393    >,
394    _phantom: std::marker::PhantomData<(Req, Resp)>,
395}
396
397#[async_trait]
398impl<Req, Resp> RequestHandler for TypedRequestHandler<Req, Resp>
399where
400    Req: DeserializeOwned + Send + Sync + 'static,
401    Resp: Serialize + Send + Sync + 'static,
402{
403    async fn handle(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse> {
404        let params: Req = if request.params.is_none() || request.params.as_ref().unwrap().is_null()
405        {
406            serde_json::from_value(json!({}))?
407        } else {
408            serde_json::from_value(request.params.unwrap())?
409        };
410        let result = (self.handler)(params).await?;
411        Ok(JsonRpcResponse {
412            id: request.id,
413            result: Some(serde_json::to_value(result)?),
414            error: None,
415            ..Default::default()
416        })
417    }
418}
419
420/// A typed notification handler.
421///
422/// This struct adapts a typed handler function to the `NotificationHandler` trait,
423/// handling the deserialization of the notification.
424struct TypedNotificationHandler<N>
425where
426    N: DeserializeOwned + Send + Sync + 'static,
427{
428    handler: Box<
429        dyn Fn(N) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
430            + Send
431            + Sync,
432    >,
433    _phantom: std::marker::PhantomData<N>,
434}
435
436#[async_trait]
437impl<N> NotificationHandler for TypedNotificationHandler<N>
438where
439    N: DeserializeOwned + Send + Sync + 'static,
440{
441    async fn handle(&self, notification: JsonRpcNotification) -> Result<()> {
442        let params: N =
443            if notification.params.is_none() || notification.params.as_ref().unwrap().is_null() {
444                serde_json::from_value(serde_json::Value::Null)?
445            } else {
446                serde_json::from_value(notification.params.unwrap())?
447            };
448        (self.handler)(params).await
449    }
450}