mcp_core/
protocol.rs

1use super::transport::{
2    JsonRpcError, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, Message,
3    Transport,
4};
5use super::types::ErrorCode;
6use anyhow::anyhow;
7use anyhow::Result;
8use async_trait::async_trait;
9use serde::de::DeserializeOwned;
10use serde::Serialize;
11use std::pin::Pin;
12use std::sync::atomic::Ordering;
13use std::time::Duration;
14use std::{
15    collections::HashMap,
16    sync::{atomic::AtomicU64, Arc},
17};
18use tokio::sync::oneshot;
19use tokio::sync::Mutex;
20use tokio::time::timeout;
21use tracing::debug;
22
23#[derive(Clone)]
24pub struct Protocol<T: Transport> {
25    transport: Arc<T>,
26
27    request_id: Arc<AtomicU64>,
28    pending_requests: Arc<Mutex<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>>,
29    request_handlers: Arc<Mutex<HashMap<String, Box<dyn RequestHandler>>>>,
30    notification_handlers: Arc<Mutex<HashMap<String, Box<dyn NotificationHandler>>>>,
31}
32
33impl<T: Transport> Protocol<T> {
34    pub fn builder(transport: T) -> ProtocolBuilder<T> {
35        ProtocolBuilder::new(transport)
36    }
37
38    pub async fn notify(&self, method: &str, params: Option<serde_json::Value>) -> Result<()> {
39        let notification = JsonRpcNotification {
40            method: method.to_string(),
41            params,
42            ..Default::default()
43        };
44        let msg = JsonRpcMessage::Notification(notification);
45        self.transport.send(&msg).await?;
46        Ok(())
47    }
48
49    pub async fn request(
50        &self,
51        method: &str,
52        params: Option<serde_json::Value>,
53        options: RequestOptions,
54    ) -> Result<JsonRpcResponse> {
55        let id = self.request_id.fetch_add(1, Ordering::SeqCst);
56
57        // Create a oneshot channel for this request
58        let (tx, rx) = oneshot::channel();
59
60        // Store the sender
61        {
62            let mut pending = self.pending_requests.lock().await;
63            pending.insert(id, tx);
64        }
65
66        // Send the request
67        let msg = JsonRpcMessage::Request(JsonRpcRequest {
68            id,
69            method: method.to_string(),
70            params,
71            ..Default::default()
72        });
73
74        self.transport.send(&msg).await?;
75
76        // Wait for response with timeout
77        match timeout(options.timeout, rx)
78            .await
79            .map_err(|_| anyhow!("Request timed out"))?
80        {
81            Ok(response) => Ok(response),
82            Err(_) => {
83                // Clean up the pending request if receiver was dropped
84                let mut pending = self.pending_requests.lock().await;
85                pending.remove(&id);
86                Err(anyhow!("Request cancelled"))
87            }
88        }
89    }
90
91    pub async fn listen(&self) -> Result<()> {
92        debug!("Listening for requests");
93        loop {
94            let message: Option<Message> = self.transport.receive().await?;
95
96            // Exit loop when transport signals shutdown with None
97            if message.is_none() {
98                break;
99            }
100
101            match message.unwrap() {
102                JsonRpcMessage::Request(request) => self.handle_request(request).await?,
103                JsonRpcMessage::Response(response) => {
104                    let id = response.id;
105                    let mut pending = self.pending_requests.lock().await;
106                    if let Some(tx) = pending.remove(&id) {
107                        let _ = tx.send(response);
108                    }
109                }
110                JsonRpcMessage::Notification(notification) => {
111                    let handlers = self.notification_handlers.lock().await;
112                    if let Some(handler) = handlers.get(&notification.method) {
113                        handler.handle(notification).await?;
114                    }
115                }
116            }
117        }
118        Ok(())
119    }
120
121    async fn handle_request(&self, request: JsonRpcRequest) -> Result<()> {
122        let handlers = self.request_handlers.lock().await;
123        if let Some(handler) = handlers.get(&request.method) {
124            match handler.handle(request.clone()).await {
125                Ok(response) => {
126                    let msg = JsonRpcMessage::Response(response);
127                    self.transport.send(&msg).await?;
128                }
129                Err(e) => {
130                    let error_response = JsonRpcResponse {
131                        id: request.id,
132                        result: None,
133                        error: Some(JsonRpcError {
134                            code: ErrorCode::InternalError as i32,
135                            message: e.to_string(),
136                            data: None,
137                        }),
138                        ..Default::default()
139                    };
140                    let msg = JsonRpcMessage::Response(error_response);
141                    self.transport.send(&msg).await?;
142                }
143            }
144        } else {
145            self.transport
146                .send(&JsonRpcMessage::Response(JsonRpcResponse {
147                    id: request.id,
148                    error: Some(JsonRpcError {
149                        code: ErrorCode::MethodNotFound as i32,
150                        message: format!("Method not found: {}", request.method),
151                        data: None,
152                    }),
153                    ..Default::default()
154                }))
155                .await?;
156        }
157        Ok(())
158    }
159}
160
/// The default request timeout, in milliseconds
pub const DEFAULT_REQUEST_TIMEOUT_MSEC: u64 = 60000;

/// Per-call options for [`Protocol::request`].
///
/// Start from [`RequestOptions::default`] (a timeout of [`DEFAULT_REQUEST_TIMEOUT_MSEC`]
/// milliseconds) and adjust with the builder-style setters.
#[derive(Debug, Clone, Copy)]
pub struct RequestOptions {
    /// How long `request` waits for a response before failing with "Request timed out".
    timeout: Duration,
}

impl RequestOptions {
    /// Returns these options with the response timeout replaced by `timeout`.
    pub fn timeout(mut self, timeout: Duration) -> Self {
        // Update only this field so any options added later are carried over unchanged.
        self.timeout = timeout;
        self
    }
}

impl Default for RequestOptions {
    fn default() -> Self {
        Self {
            timeout: Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MSEC),
        }
    }
}
180
181pub struct ProtocolBuilder<T: Transport> {
182    transport: T,
183    request_handlers: HashMap<String, Box<dyn RequestHandler>>,
184    notification_handlers: HashMap<String, Box<dyn NotificationHandler>>,
185}
186impl<T: Transport> ProtocolBuilder<T> {
187    pub fn new(transport: T) -> Self {
188        Self {
189            transport,
190            request_handlers: HashMap::new(),
191            notification_handlers: HashMap::new(),
192        }
193    }
194    /// Register a typed request handler
195    pub fn request_handler<Req, Resp>(
196        mut self,
197        method: &str,
198        handler: impl Fn(Req) -> Pin<Box<dyn std::future::Future<Output = Result<Resp>> + Send>>
199            + Send
200            + Sync
201            + 'static,
202    ) -> Self
203    where
204        Req: DeserializeOwned + Send + Sync + 'static,
205        Resp: Serialize + Send + Sync + 'static,
206    {
207        let handler = TypedRequestHandler {
208            handler: Box::new(handler),
209            _phantom: std::marker::PhantomData,
210        };
211
212        self.request_handlers
213            .insert(method.to_string(), Box::new(handler));
214        self
215    }
216
217    pub fn has_request_handler(&self, method: &str) -> bool {
218        self.request_handlers.contains_key(method)
219    }
220
221    pub fn notification_handler<N>(
222        mut self,
223        method: &str,
224        handler: impl Fn(N) -> Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
225            + Send
226            + Sync
227            + 'static,
228    ) -> Self
229    where
230        N: DeserializeOwned + Send + Sync + 'static,
231    {
232        self.notification_handlers.insert(
233            method.to_string(),
234            Box::new(TypedNotificationHandler {
235                handler: Box::new(handler),
236                _phantom: std::marker::PhantomData,
237            }),
238        );
239        self
240    }
241
242    pub fn build(self) -> Protocol<T> {
243        Protocol {
244            transport: Arc::new(self.transport),
245            request_handlers: Arc::new(Mutex::new(self.request_handlers)),
246            notification_handlers: Arc::new(Mutex::new(self.notification_handlers)),
247            request_id: Arc::new(AtomicU64::new(0)),
248            pending_requests: Arc::new(Mutex::new(HashMap::new())),
249        }
250    }
251}
252
253// Update the handler traits to be async
254#[async_trait]
255trait RequestHandler: Send + Sync {
256    async fn handle(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse>;
257}
258
259#[async_trait]
260trait NotificationHandler: Send + Sync {
261    async fn handle(&self, notification: JsonRpcNotification) -> Result<()>;
262}
263
264// Update the TypedRequestHandler to use async handlers
265struct TypedRequestHandler<Req, Resp>
266where
267    Req: DeserializeOwned + Send + Sync + 'static,
268    Resp: Serialize + Send + Sync + 'static,
269{
270    handler: Box<
271        dyn Fn(Req) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Resp>> + Send>>
272            + Send
273            + Sync,
274    >,
275    _phantom: std::marker::PhantomData<(Req, Resp)>,
276}
277
278#[async_trait]
279impl<Req, Resp> RequestHandler for TypedRequestHandler<Req, Resp>
280where
281    Req: DeserializeOwned + Send + Sync + 'static,
282    Resp: Serialize + Send + Sync + 'static,
283{
284    async fn handle(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse> {
285        let params: Req = if request.params.is_none() || request.params.as_ref().unwrap().is_null()
286        {
287            serde_json::from_value(serde_json::Value::Null)?
288        } else {
289            serde_json::from_value(request.params.unwrap())?
290        };
291        let result = (self.handler)(params).await?;
292        Ok(JsonRpcResponse {
293            id: request.id,
294            result: Some(serde_json::to_value(result)?),
295            error: None,
296            ..Default::default()
297        })
298    }
299}
300
301struct TypedNotificationHandler<N>
302where
303    N: DeserializeOwned + Send + Sync + 'static,
304{
305    handler: Box<
306        dyn Fn(N) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
307            + Send
308            + Sync,
309    >,
310    _phantom: std::marker::PhantomData<N>,
311}
312
313#[async_trait]
314impl<N> NotificationHandler for TypedNotificationHandler<N>
315where
316    N: DeserializeOwned + Send + Sync + 'static,
317{
318    async fn handle(&self, notification: JsonRpcNotification) -> Result<()> {
319        let params: N =
320            if notification.params.is_none() || notification.params.as_ref().unwrap().is_null() {
321                serde_json::from_value(serde_json::Value::Null)?
322            } else {
323                serde_json::from_value(notification.params.unwrap())?
324            };
325        (self.handler)(params).await
326    }
327}