mcp_core/
protocol.rs

1use super::transport::{JsonRpcError, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse};
2use super::types::ErrorCode;
3use anyhow::Result;
4use async_trait::async_trait;
5use serde::de::DeserializeOwned;
6use serde::Serialize;
7use std::pin::Pin;
8
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::time::Duration;
11use std::{collections::HashMap, sync::Arc};
12use tokio::sync::{oneshot, Mutex};
13
14#[derive(Clone)]
15pub struct Protocol {
16    request_id: Arc<AtomicU64>,
17    pending_requests: Arc<Mutex<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>>,
18    request_handlers: Arc<Mutex<HashMap<String, Box<dyn RequestHandler>>>>,
19    notification_handlers: Arc<Mutex<HashMap<String, Box<dyn NotificationHandler>>>>,
20}
21
22impl Protocol {
23    pub fn builder() -> ProtocolBuilder {
24        ProtocolBuilder::new()
25    }
26
27    pub async fn handle_request(&self, request: JsonRpcRequest) -> JsonRpcResponse {
28        let handlers = self.request_handlers.lock().await;
29        if let Some(handler) = handlers.get(&request.method) {
30            match handler.handle(request.clone()).await {
31                Ok(response) => response,
32                Err(e) => JsonRpcResponse {
33                    id: request.id,
34                    result: None,
35                    error: Some(JsonRpcError {
36                        code: ErrorCode::InternalError as i32,
37                        message: e.to_string(),
38                        data: None,
39                    }),
40                    ..Default::default()
41                },
42            }
43        } else {
44            JsonRpcResponse {
45                id: request.id,
46                error: Some(JsonRpcError {
47                    code: ErrorCode::MethodNotFound as i32,
48                    message: format!("Method not found: {}", request.method),
49                    data: None,
50                }),
51                ..Default::default()
52            }
53        }
54    }
55
56    pub async fn handle_notification(&self, request: JsonRpcNotification) {
57        let handlers = self.notification_handlers.lock().await;
58        if let Some(handler) = handlers.get(&request.method) {
59            match handler.handle(request.clone()).await {
60                Ok(_) => tracing::info!("Received notification: {:?}", request.method),
61                Err(e) => tracing::error!("Error handling notification: {}", e),
62            }
63        } else {
64            tracing::debug!("No handler for notification: {}", request.method);
65        }
66    }
67
68    pub async fn create_request(&self) -> (u64, oneshot::Receiver<JsonRpcResponse>) {
69        let id = self.request_id.fetch_add(1, Ordering::SeqCst);
70        let (tx, rx) = oneshot::channel();
71
72        {
73            let mut pending = self.pending_requests.lock().await;
74            pending.insert(id, tx);
75        }
76
77        (id, rx)
78    }
79
80    pub async fn handle_response(&self, response: JsonRpcResponse) {
81        if let Some(tx) = self.pending_requests.lock().await.remove(&response.id) {
82            let _ = tx.send(response);
83        }
84    }
85
86    pub async fn cancel_response(&self, id: u64) {
87        if let Some(tx) = self.pending_requests.lock().await.remove(&id) {
88            let _ = tx.send(JsonRpcResponse {
89                id,
90                result: None,
91                error: Some(JsonRpcError {
92                    code: ErrorCode::RequestTimeout as i32,
93                    message: "Request cancelled".to_string(),
94                    data: None,
95                }),
96                ..Default::default()
97            });
98        }
99    }
100}
101
/// The default request timeout, in milliseconds
pub const DEFAULT_REQUEST_TIMEOUT_MSEC: u64 = 60000;

/// Options attached to a request; currently just a timeout.
///
/// The type only holds a `Duration`, so it derives `Copy` and the comparison/debug
/// traits. That lets one set of options be logged, compared, and reused freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RequestOptions {
    /// Timeout for the request; defaults to `DEFAULT_REQUEST_TIMEOUT_MSEC` milliseconds.
    pub timeout: Duration,
}

impl RequestOptions {
    /// Return these options with the timeout replaced by `timeout`.
    ///
    /// The field is assigned in place rather than the struct being rebuilt. Any fields
    /// added to `RequestOptions` later are therefore carried over instead of reset.
    #[must_use]
    pub fn timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }
}

impl Default for RequestOptions {
    /// Options with a timeout of `DEFAULT_REQUEST_TIMEOUT_MSEC` milliseconds (60 s).
    fn default() -> Self {
        Self {
            timeout: Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MSEC),
        }
    }
}
121
122#[derive(Clone)]
123pub struct ProtocolBuilder {
124    request_handlers: Arc<Mutex<HashMap<String, Box<dyn RequestHandler>>>>,
125    notification_handlers: Arc<Mutex<HashMap<String, Box<dyn NotificationHandler>>>>,
126}
127
128impl ProtocolBuilder {
129    pub fn new() -> Self {
130        Self {
131            request_handlers: Arc::new(Mutex::new(HashMap::new())),
132            notification_handlers: Arc::new(Mutex::new(HashMap::new())),
133        }
134    }
135
136    /// Register a typed request handler
137    pub fn request_handler<Req, Resp>(
138        self,
139        method: &str,
140        handler: impl Fn(Req) -> Pin<Box<dyn std::future::Future<Output = Result<Resp>> + Send>>
141            + Send
142            + Sync
143            + 'static,
144    ) -> Self
145    where
146        Req: DeserializeOwned + Send + Sync + 'static,
147        Resp: Serialize + Send + Sync + 'static,
148    {
149        let handler = TypedRequestHandler {
150            handler: Box::new(handler),
151            _phantom: std::marker::PhantomData,
152        };
153
154        if let Ok(mut handlers) = self.request_handlers.try_lock() {
155            handlers.insert(method.to_string(), Box::new(handler));
156        }
157        self
158    }
159
160    pub fn has_request_handler(&self, method: &str) -> bool {
161        self.request_handlers
162            .try_lock()
163            .map(|handlers| handlers.contains_key(method))
164            .unwrap_or(false)
165    }
166
167    pub fn notification_handler<N>(
168        self,
169        method: &str,
170        handler: impl Fn(N) -> Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
171            + Send
172            + Sync
173            + 'static,
174    ) -> Self
175    where
176        N: DeserializeOwned + Send + Sync + 'static,
177    {
178        let handler = TypedNotificationHandler {
179            handler: Box::new(handler),
180            _phantom: std::marker::PhantomData,
181        };
182
183        if let Ok(mut handlers) = self.notification_handlers.try_lock() {
184            handlers.insert(method.to_string(), Box::new(handler));
185        }
186        self
187    }
188
189    pub fn has_notification_handler(&self, method: &str) -> bool {
190        self.notification_handlers
191            .try_lock()
192            .map(|handlers| handlers.contains_key(method))
193            .unwrap_or(false)
194    }
195
196    pub fn build(self) -> Protocol {
197        Protocol {
198            request_id: Arc::new(AtomicU64::new(0)),
199            pending_requests: Arc::new(Mutex::new(HashMap::new())),
200            request_handlers: self.request_handlers,
201            notification_handlers: self.notification_handlers,
202        }
203    }
204}
205
206// Update the handler traits to be async
207#[async_trait]
208trait RequestHandler: Send + Sync {
209    async fn handle(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse>;
210}
211
212#[async_trait]
213trait NotificationHandler: Send + Sync {
214    async fn handle(&self, notification: JsonRpcNotification) -> Result<()>;
215}
216
217// Update the TypedRequestHandler to use async handlers
218struct TypedRequestHandler<Req, Resp>
219where
220    Req: DeserializeOwned + Send + Sync + 'static,
221    Resp: Serialize + Send + Sync + 'static,
222{
223    handler: Box<
224        dyn Fn(Req) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Resp>> + Send>>
225            + Send
226            + Sync,
227    >,
228    _phantom: std::marker::PhantomData<(Req, Resp)>,
229}
230
231#[async_trait]
232impl<Req, Resp> RequestHandler for TypedRequestHandler<Req, Resp>
233where
234    Req: DeserializeOwned + Send + Sync + 'static,
235    Resp: Serialize + Send + Sync + 'static,
236{
237    async fn handle(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse> {
238        let params: Req = if request.params.is_none() || request.params.as_ref().unwrap().is_null()
239        {
240            serde_json::from_value(serde_json::Value::Null)?
241        } else {
242            serde_json::from_value(request.params.unwrap())?
243        };
244        let result = (self.handler)(params).await?;
245        Ok(JsonRpcResponse {
246            id: request.id,
247            result: Some(serde_json::to_value(result)?),
248            error: None,
249            ..Default::default()
250        })
251    }
252}
253
254struct TypedNotificationHandler<N>
255where
256    N: DeserializeOwned + Send + Sync + 'static,
257{
258    handler: Box<
259        dyn Fn(N) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
260            + Send
261            + Sync,
262    >,
263    _phantom: std::marker::PhantomData<N>,
264}
265
266#[async_trait]
267impl<N> NotificationHandler for TypedNotificationHandler<N>
268where
269    N: DeserializeOwned + Send + Sync + 'static,
270{
271    async fn handle(&self, notification: JsonRpcNotification) -> Result<()> {
272        let params: N =
273            if notification.params.is_none() || notification.params.as_ref().unwrap().is_null() {
274                serde_json::from_value(serde_json::Value::Null)?
275            } else {
276                serde_json::from_value(notification.params.unwrap())?
277            };
278        (self.handler)(params).await
279    }
280}