mcp_core/
protocol.rs

1use super::transport::{JsonRpcError, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse};
2use super::types::ErrorCode;
3use anyhow::Result;
4use async_trait::async_trait;
5use serde::de::DeserializeOwned;
6use serde::Serialize;
7use serde_json::json;
8use std::pin::Pin;
9
10use std::sync::atomic::{AtomicU64, Ordering};
11use std::time::Duration;
12use std::{collections::HashMap, sync::Arc};
13use tokio::sync::{oneshot, Mutex};
14
15#[derive(Clone)]
16pub struct Protocol {
17    request_id: Arc<AtomicU64>,
18    pending_requests: Arc<Mutex<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>>,
19    request_handlers: Arc<Mutex<HashMap<String, Box<dyn RequestHandler>>>>,
20    notification_handlers: Arc<Mutex<HashMap<String, Box<dyn NotificationHandler>>>>,
21}
22
23impl Protocol {
24    pub fn builder() -> ProtocolBuilder {
25        ProtocolBuilder::new()
26    }
27
28    pub async fn handle_request(&self, request: JsonRpcRequest) -> JsonRpcResponse {
29        let handlers = self.request_handlers.lock().await;
30        if let Some(handler) = handlers.get(&request.method) {
31            match handler.handle(request.clone()).await {
32                Ok(response) => response,
33                Err(e) => JsonRpcResponse {
34                    id: request.id,
35                    result: None,
36                    error: Some(JsonRpcError {
37                        code: ErrorCode::InternalError as i32,
38                        message: e.to_string(),
39                        data: None,
40                    }),
41                    ..Default::default()
42                },
43            }
44        } else {
45            JsonRpcResponse {
46                id: request.id,
47                error: Some(JsonRpcError {
48                    code: ErrorCode::MethodNotFound as i32,
49                    message: format!("Method not found: {}", request.method),
50                    data: None,
51                }),
52                ..Default::default()
53            }
54        }
55    }
56
57    pub async fn handle_notification(&self, request: JsonRpcNotification) {
58        let handlers = self.notification_handlers.lock().await;
59        if let Some(handler) = handlers.get(&request.method) {
60            match handler.handle(request.clone()).await {
61                Ok(_) => tracing::info!("Received notification: {:?}", request.method),
62                Err(e) => tracing::error!("Error handling notification: {}", e),
63            }
64        } else {
65            tracing::debug!("No handler for notification: {}", request.method);
66        }
67    }
68
69    pub async fn create_request(&self) -> (u64, oneshot::Receiver<JsonRpcResponse>) {
70        let id = self.request_id.fetch_add(1, Ordering::SeqCst);
71        let (tx, rx) = oneshot::channel();
72
73        {
74            let mut pending = self.pending_requests.lock().await;
75            pending.insert(id, tx);
76        }
77
78        (id, rx)
79    }
80
81    pub async fn handle_response(&self, response: JsonRpcResponse) {
82        if let Some(tx) = self.pending_requests.lock().await.remove(&response.id) {
83            let _ = tx.send(response);
84        }
85    }
86
87    pub async fn cancel_response(&self, id: u64) {
88        if let Some(tx) = self.pending_requests.lock().await.remove(&id) {
89            let _ = tx.send(JsonRpcResponse {
90                id,
91                result: None,
92                error: Some(JsonRpcError {
93                    code: ErrorCode::RequestTimeout as i32,
94                    message: "Request cancelled".to_string(),
95                    data: None,
96                }),
97                ..Default::default()
98            });
99        }
100    }
101}
102
103/// The default request timeout, in milliseconds
104pub const DEFAULT_REQUEST_TIMEOUT_MSEC: u64 = 60000;
105pub struct RequestOptions {
106    pub timeout: Duration,
107}
108
109impl RequestOptions {
110    pub fn timeout(self, timeout: Duration) -> Self {
111        Self { timeout }
112    }
113}
114
115impl Default for RequestOptions {
116    fn default() -> Self {
117        Self {
118            timeout: Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MSEC),
119        }
120    }
121}
122
123#[derive(Clone)]
124pub struct ProtocolBuilder {
125    request_handlers: Arc<Mutex<HashMap<String, Box<dyn RequestHandler>>>>,
126    notification_handlers: Arc<Mutex<HashMap<String, Box<dyn NotificationHandler>>>>,
127}
128
129impl ProtocolBuilder {
130    pub fn new() -> Self {
131        Self {
132            request_handlers: Arc::new(Mutex::new(HashMap::new())),
133            notification_handlers: Arc::new(Mutex::new(HashMap::new())),
134        }
135    }
136
137    /// Register a typed request handler
138    pub fn request_handler<Req, Resp>(
139        self,
140        method: &str,
141        handler: impl Fn(Req) -> Pin<Box<dyn std::future::Future<Output = Result<Resp>> + Send>>
142            + Send
143            + Sync
144            + 'static,
145    ) -> Self
146    where
147        Req: DeserializeOwned + Send + Sync + 'static,
148        Resp: Serialize + Send + Sync + 'static,
149    {
150        let handler = TypedRequestHandler {
151            handler: Box::new(handler),
152            _phantom: std::marker::PhantomData,
153        };
154
155        if let Ok(mut handlers) = self.request_handlers.try_lock() {
156            handlers.insert(method.to_string(), Box::new(handler));
157        }
158        self
159    }
160
161    pub fn has_request_handler(&self, method: &str) -> bool {
162        self.request_handlers
163            .try_lock()
164            .map(|handlers| handlers.contains_key(method))
165            .unwrap_or(false)
166    }
167
168    pub fn notification_handler<N>(
169        self,
170        method: &str,
171        handler: impl Fn(N) -> Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
172            + Send
173            + Sync
174            + 'static,
175    ) -> Self
176    where
177        N: DeserializeOwned + Send + Sync + 'static,
178    {
179        let handler = TypedNotificationHandler {
180            handler: Box::new(handler),
181            _phantom: std::marker::PhantomData,
182        };
183
184        if let Ok(mut handlers) = self.notification_handlers.try_lock() {
185            handlers.insert(method.to_string(), Box::new(handler));
186        }
187        self
188    }
189
190    pub fn has_notification_handler(&self, method: &str) -> bool {
191        self.notification_handlers
192            .try_lock()
193            .map(|handlers| handlers.contains_key(method))
194            .unwrap_or(false)
195    }
196
197    pub fn build(self) -> Protocol {
198        Protocol {
199            request_id: Arc::new(AtomicU64::new(0)),
200            pending_requests: Arc::new(Mutex::new(HashMap::new())),
201            request_handlers: self.request_handlers,
202            notification_handlers: self.notification_handlers,
203        }
204    }
205}
206
207// Update the handler traits to be async
208#[async_trait]
209trait RequestHandler: Send + Sync {
210    async fn handle(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse>;
211}
212
213#[async_trait]
214trait NotificationHandler: Send + Sync {
215    async fn handle(&self, notification: JsonRpcNotification) -> Result<()>;
216}
217
218// Update the TypedRequestHandler to use async handlers
219struct TypedRequestHandler<Req, Resp>
220where
221    Req: DeserializeOwned + Send + Sync + 'static,
222    Resp: Serialize + Send + Sync + 'static,
223{
224    handler: Box<
225        dyn Fn(Req) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Resp>> + Send>>
226            + Send
227            + Sync,
228    >,
229    _phantom: std::marker::PhantomData<(Req, Resp)>,
230}
231
232#[async_trait]
233impl<Req, Resp> RequestHandler for TypedRequestHandler<Req, Resp>
234where
235    Req: DeserializeOwned + Send + Sync + 'static,
236    Resp: Serialize + Send + Sync + 'static,
237{
238    async fn handle(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse> {
239        let params: Req = if request.params.is_none() || request.params.as_ref().unwrap().is_null()
240        {
241            serde_json::from_value(json!({}))?
242        } else {
243            serde_json::from_value(request.params.unwrap())?
244        };
245        let result = (self.handler)(params).await?;
246        Ok(JsonRpcResponse {
247            id: request.id,
248            result: Some(serde_json::to_value(result)?),
249            error: None,
250            ..Default::default()
251        })
252    }
253}
254
255struct TypedNotificationHandler<N>
256where
257    N: DeserializeOwned + Send + Sync + 'static,
258{
259    handler: Box<
260        dyn Fn(N) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
261            + Send
262            + Sync,
263    >,
264    _phantom: std::marker::PhantomData<N>,
265}
266
267#[async_trait]
268impl<N> NotificationHandler for TypedNotificationHandler<N>
269where
270    N: DeserializeOwned + Send + Sync + 'static,
271{
272    async fn handle(&self, notification: JsonRpcNotification) -> Result<()> {
273        let params: N =
274            if notification.params.is_none() || notification.params.as_ref().unwrap().is_null() {
275                serde_json::from_value(serde_json::Value::Null)?
276            } else {
277                serde_json::from_value(notification.params.unwrap())?
278            };
279        (self.handler)(params).await
280    }
281}