mcp_core/
protocol.rs

1use super::transport::{JsonRpcError, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse};
2use super::types::ErrorCode;
3use anyhow::Result;
4use async_trait::async_trait;
5use serde::de::DeserializeOwned;
6use serde::Serialize;
7use serde_json::json;
8use std::pin::Pin;
9
10use std::sync::atomic::{AtomicU64, Ordering};
11use std::time::Duration;
12use std::{collections::HashMap, sync::Arc};
13use tokio::sync::{oneshot, Mutex};
14
15#[derive(Clone)]
16pub struct Protocol {
17    request_id: Arc<AtomicU64>,
18    pending_requests: Arc<Mutex<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>>,
19    request_handlers: Arc<Mutex<HashMap<String, Box<dyn RequestHandler>>>>,
20    notification_handlers: Arc<Mutex<HashMap<String, Box<dyn NotificationHandler>>>>,
21}
22
23impl Protocol {
24    pub fn builder() -> ProtocolBuilder {
25        ProtocolBuilder::new()
26    }
27
28    pub async fn handle_request(&self, request: JsonRpcRequest) -> JsonRpcResponse {
29        let handlers = self.request_handlers.lock().await;
30        if let Some(handler) = handlers.get(&request.method) {
31            match handler.handle(request.clone()).await {
32                Ok(response) => response,
33                Err(e) => JsonRpcResponse {
34                    id: request.id,
35                    result: None,
36                    error: Some(JsonRpcError {
37                        code: ErrorCode::InternalError as i32,
38                        message: e.to_string(),
39                        data: None,
40                    }),
41                    ..Default::default()
42                },
43            }
44        } else {
45            JsonRpcResponse {
46                id: request.id,
47                error: Some(JsonRpcError {
48                    code: ErrorCode::MethodNotFound as i32,
49                    message: format!("Method not found: {}", request.method),
50                    data: None,
51                }),
52                ..Default::default()
53            }
54        }
55    }
56
57    pub async fn handle_notification(&self, request: JsonRpcNotification) {
58        let handlers = self.notification_handlers.lock().await;
59        if let Some(handler) = handlers.get(&request.method) {
60            match handler.handle(request.clone()).await {
61                Ok(_) => tracing::info!("Received notification: {:?}", request.method),
62                Err(e) => tracing::error!("Error handling notification: {}", e),
63            }
64        } else {
65            tracing::debug!("No handler for notification: {}", request.method);
66        }
67    }
68
69    pub fn new_message_id(&self) -> u64 {
70        self.request_id.fetch_add(1, Ordering::SeqCst)
71    }
72
73    pub async fn create_request(&self) -> (u64, oneshot::Receiver<JsonRpcResponse>) {
74        let id = self.new_message_id();
75        let (tx, rx) = oneshot::channel();
76
77        {
78            let mut pending = self.pending_requests.lock().await;
79            pending.insert(id, tx);
80        }
81
82        (id, rx)
83    }
84
85    pub async fn handle_response(&self, response: JsonRpcResponse) {
86        if let Some(tx) = self.pending_requests.lock().await.remove(&response.id) {
87            let _ = tx.send(response);
88        }
89    }
90
91    pub async fn cancel_response(&self, id: u64) {
92        if let Some(tx) = self.pending_requests.lock().await.remove(&id) {
93            let _ = tx.send(JsonRpcResponse {
94                id,
95                result: None,
96                error: Some(JsonRpcError {
97                    code: ErrorCode::RequestTimeout as i32,
98                    message: "Request cancelled".to_string(),
99                    data: None,
100                }),
101                ..Default::default()
102            });
103        }
104    }
105}
106
/// The default request timeout, in milliseconds.
pub const DEFAULT_REQUEST_TIMEOUT_MSEC: u64 = 60000;

/// Options that apply to a single outgoing request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RequestOptions {
    /// Timeout for the request.
    pub timeout: Duration,
}

impl RequestOptions {
    /// Returns the options with `timeout` replaced by the given duration.
    ///
    /// This consumes `self`, so the returned value must be used.
    #[must_use]
    pub fn timeout(self, timeout: Duration) -> Self {
        Self { timeout }
    }
}

impl Default for RequestOptions {
    /// Uses [`DEFAULT_REQUEST_TIMEOUT_MSEC`] as the timeout.
    fn default() -> Self {
        Self {
            timeout: Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MSEC),
        }
    }
}
126
127#[derive(Clone)]
128pub struct ProtocolBuilder {
129    request_handlers: Arc<Mutex<HashMap<String, Box<dyn RequestHandler>>>>,
130    notification_handlers: Arc<Mutex<HashMap<String, Box<dyn NotificationHandler>>>>,
131}
132
133impl ProtocolBuilder {
134    pub fn new() -> Self {
135        Self {
136            request_handlers: Arc::new(Mutex::new(HashMap::new())),
137            notification_handlers: Arc::new(Mutex::new(HashMap::new())),
138        }
139    }
140
141    /// Register a typed request handler
142    pub fn request_handler<Req, Resp>(
143        self,
144        method: &str,
145        handler: impl Fn(Req) -> Pin<Box<dyn std::future::Future<Output = Result<Resp>> + Send>>
146            + Send
147            + Sync
148            + 'static,
149    ) -> Self
150    where
151        Req: DeserializeOwned + Send + Sync + 'static,
152        Resp: Serialize + Send + Sync + 'static,
153    {
154        let handler = TypedRequestHandler {
155            handler: Box::new(handler),
156            _phantom: std::marker::PhantomData,
157        };
158
159        if let Ok(mut handlers) = self.request_handlers.try_lock() {
160            handlers.insert(method.to_string(), Box::new(handler));
161        }
162        self
163    }
164
165    pub fn has_request_handler(&self, method: &str) -> bool {
166        self.request_handlers
167            .try_lock()
168            .map(|handlers| handlers.contains_key(method))
169            .unwrap_or(false)
170    }
171
172    pub fn notification_handler<N>(
173        self,
174        method: &str,
175        handler: impl Fn(N) -> Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
176            + Send
177            + Sync
178            + 'static,
179    ) -> Self
180    where
181        N: DeserializeOwned + Send + Sync + 'static,
182    {
183        let handler = TypedNotificationHandler {
184            handler: Box::new(handler),
185            _phantom: std::marker::PhantomData,
186        };
187
188        if let Ok(mut handlers) = self.notification_handlers.try_lock() {
189            handlers.insert(method.to_string(), Box::new(handler));
190        }
191        self
192    }
193
194    pub fn has_notification_handler(&self, method: &str) -> bool {
195        self.notification_handlers
196            .try_lock()
197            .map(|handlers| handlers.contains_key(method))
198            .unwrap_or(false)
199    }
200
201    pub fn build(self) -> Protocol {
202        Protocol {
203            request_id: Arc::new(AtomicU64::new(0)),
204            pending_requests: Arc::new(Mutex::new(HashMap::new())),
205            request_handlers: self.request_handlers,
206            notification_handlers: self.notification_handlers,
207        }
208    }
209}
210
211// Update the handler traits to be async
212#[async_trait]
213trait RequestHandler: Send + Sync {
214    async fn handle(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse>;
215}
216
217#[async_trait]
218trait NotificationHandler: Send + Sync {
219    async fn handle(&self, notification: JsonRpcNotification) -> Result<()>;
220}
221
222// Update the TypedRequestHandler to use async handlers
223struct TypedRequestHandler<Req, Resp>
224where
225    Req: DeserializeOwned + Send + Sync + 'static,
226    Resp: Serialize + Send + Sync + 'static,
227{
228    handler: Box<
229        dyn Fn(Req) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Resp>> + Send>>
230            + Send
231            + Sync,
232    >,
233    _phantom: std::marker::PhantomData<(Req, Resp)>,
234}
235
236#[async_trait]
237impl<Req, Resp> RequestHandler for TypedRequestHandler<Req, Resp>
238where
239    Req: DeserializeOwned + Send + Sync + 'static,
240    Resp: Serialize + Send + Sync + 'static,
241{
242    async fn handle(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse> {
243        let params: Req = if request.params.is_none() || request.params.as_ref().unwrap().is_null()
244        {
245            serde_json::from_value(json!({}))?
246        } else {
247            serde_json::from_value(request.params.unwrap())?
248        };
249        let result = (self.handler)(params).await?;
250        Ok(JsonRpcResponse {
251            id: request.id,
252            result: Some(serde_json::to_value(result)?),
253            error: None,
254            ..Default::default()
255        })
256    }
257}
258
259struct TypedNotificationHandler<N>
260where
261    N: DeserializeOwned + Send + Sync + 'static,
262{
263    handler: Box<
264        dyn Fn(N) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
265            + Send
266            + Sync,
267    >,
268    _phantom: std::marker::PhantomData<N>,
269}
270
271#[async_trait]
272impl<N> NotificationHandler for TypedNotificationHandler<N>
273where
274    N: DeserializeOwned + Send + Sync + 'static,
275{
276    async fn handle(&self, notification: JsonRpcNotification) -> Result<()> {
277        let params: N =
278            if notification.params.is_none() || notification.params.as_ref().unwrap().is_null() {
279                serde_json::from_value(serde_json::Value::Null)?
280            } else {
281                serde_json::from_value(notification.params.unwrap())?
282            };
283        (self.handler)(params).await
284    }
285}