async_mcp/
protocol.rs

1use super::transport::{
2    JsonRpcError, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, Message,
3    Transport,
4};
5use super::types::ErrorCode;
6use anyhow::anyhow;
7use anyhow::Result;
8use async_trait::async_trait;
9use serde::de::DeserializeOwned;
10use serde::Serialize;
11use std::pin::Pin;
12use std::sync::atomic::Ordering;
13use std::time::Duration;
14use std::{
15    collections::HashMap,
16    sync::{atomic::AtomicU64, Arc},
17};
18use tokio::sync::oneshot;
19use tokio::sync::Mutex;
20use tokio::time::timeout;
21use tracing::debug;
22
23#[derive(Clone)]
24pub struct Protocol<T: Transport> {
25    transport: Arc<T>,
26
27    request_id: Arc<AtomicU64>,
28    pending_requests: Arc<Mutex<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>>,
29    request_handlers: Arc<Mutex<HashMap<String, Box<dyn RequestHandler>>>>,
30    notification_handlers: Arc<Mutex<HashMap<String, Box<dyn NotificationHandler>>>>,
31}
32
33impl<T: Transport> Protocol<T> {
34    pub fn builder(transport: T) -> ProtocolBuilder<T> {
35        ProtocolBuilder::new(transport)
36    }
37
38    pub async fn notify(&self, method: &str, params: Option<serde_json::Value>) -> Result<()> {
39        let notification = JsonRpcNotification {
40            method: method.to_string(),
41            params,
42            ..Default::default()
43        };
44        let msg = JsonRpcMessage::Notification(notification);
45        self.transport.send(&msg).await?;
46        Ok(())
47    }
48
49    pub async fn request(
50        &self,
51        method: &str,
52        params: Option<serde_json::Value>,
53        options: RequestOptions,
54    ) -> Result<JsonRpcResponse> {
55        let id = self.request_id.fetch_add(1, Ordering::SeqCst);
56
57        // Create a oneshot channel for this request
58        let (tx, rx) = oneshot::channel();
59
60        // Store the sender
61        {
62            let mut pending = self.pending_requests.lock().await;
63            pending.insert(id, tx);
64        }
65
66        // Send the request
67        let msg = JsonRpcMessage::Request(JsonRpcRequest {
68            id,
69            method: method.to_string(),
70            params,
71            ..Default::default()
72        });
73        self.transport.send(&msg).await?;
74
75        // Wait for response with timeout
76        match timeout(options.timeout, rx)
77            .await
78            .map_err(|_| anyhow!("Request timed out"))?
79        {
80            Ok(response) => Ok(response),
81            Err(_) => {
82                // Clean up the pending request if receiver was dropped
83                let mut pending = self.pending_requests.lock().await;
84                pending.remove(&id);
85                Err(anyhow!("Request cancelled"))
86            }
87        }
88    }
89
90    pub async fn listen(&self) -> Result<()> {
91        debug!("Listening for requests");
92        loop {
93            let message: Option<Message> = self.transport.receive().await?;
94
95            // Exit loop when transport signals shutdown with None
96            if message.is_none() {
97                break;
98            }
99
100            match message.unwrap() {
101                JsonRpcMessage::Request(request) => self.handle_request(request).await?,
102                JsonRpcMessage::Response(response) => {
103                    let id = response.id;
104                    let mut pending = self.pending_requests.lock().await;
105                    if let Some(tx) = pending.remove(&id) {
106                        let _ = tx.send(response);
107                    }
108                }
109                JsonRpcMessage::Notification(notification) => {
110                    let handlers = self.notification_handlers.lock().await;
111                    if let Some(handler) = handlers.get(&notification.method) {
112                        handler.handle(notification).await?;
113                    }
114                }
115            }
116        }
117        Ok(())
118    }
119
120    async fn handle_request(&self, request: JsonRpcRequest) -> Result<()> {
121        let handlers = self.request_handlers.lock().await;
122        if let Some(handler) = handlers.get(&request.method) {
123            match handler.handle(request.clone()).await {
124                Ok(response) => {
125                    let msg = JsonRpcMessage::Response(response);
126                    self.transport.send(&msg).await?;
127                }
128                Err(e) => {
129                    let error_response = JsonRpcResponse {
130                        id: request.id,
131                        result: None,
132                        error: Some(JsonRpcError {
133                            code: ErrorCode::InternalError as i32,
134                            message: e.to_string(),
135                            data: None,
136                        }),
137                        ..Default::default()
138                    };
139                    let msg = JsonRpcMessage::Response(error_response);
140                    self.transport.send(&msg).await?;
141                }
142            }
143        } else {
144            self.transport
145                .send(&JsonRpcMessage::Response(JsonRpcResponse {
146                    id: request.id,
147                    error: Some(JsonRpcError {
148                        code: ErrorCode::MethodNotFound as i32,
149                        message: format!("Method not found: {}", request.method),
150                        data: None,
151                    }),
152                    ..Default::default()
153                }))
154                .await?;
155        }
156        Ok(())
157    }
158}
159
/// The default request timeout, in milliseconds.
pub const DEFAULT_REQUEST_TIMEOUT_MSEC: u64 = 60000;

/// Per-call options for an outgoing request.
///
/// Start from `RequestOptions::default()` and adjust with the builder methods, e.g.
/// `RequestOptions::default().timeout(Duration::from_secs(5))`.
#[derive(Debug, Clone)]
pub struct RequestOptions {
    /// How long to wait for a response before giving up.
    timeout: Duration,
}

impl RequestOptions {
    /// Returns these options with the response timeout set to `timeout`.
    ///
    /// Only the timeout changes; every other option is carried over from `self`.
    #[must_use]
    pub fn timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }
}

impl Default for RequestOptions {
    /// Uses [`DEFAULT_REQUEST_TIMEOUT_MSEC`] as the response timeout.
    fn default() -> Self {
        Self {
            timeout: Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MSEC),
        }
    }
}
179
180pub struct ProtocolBuilder<T: Transport> {
181    transport: T,
182    request_handlers: HashMap<String, Box<dyn RequestHandler>>,
183    notification_handlers: HashMap<String, Box<dyn NotificationHandler>>,
184}
185impl<T: Transport> ProtocolBuilder<T> {
186    pub fn new(transport: T) -> Self {
187        Self {
188            transport,
189            request_handlers: HashMap::new(),
190            notification_handlers: HashMap::new(),
191        }
192    }
193    /// Register a typed request handler
194    pub fn request_handler<Req, Resp>(
195        mut self,
196        method: &str,
197        handler: impl Fn(Req) -> Pin<Box<dyn std::future::Future<Output = Result<Resp>> + Send>>
198            + Send
199            + Sync
200            + 'static,
201    ) -> Self
202    where
203        Req: DeserializeOwned + Send + Sync + 'static,
204        Resp: Serialize + Send + Sync + 'static,
205    {
206        let handler = TypedRequestHandler {
207            handler: Box::new(handler),
208            _phantom: std::marker::PhantomData,
209        };
210
211        self.request_handlers
212            .insert(method.to_string(), Box::new(handler));
213        self
214    }
215
216    pub fn has_request_handler(&self, method: &str) -> bool {
217        self.request_handlers.contains_key(method)
218    }
219
220    pub fn notification_handler<N>(
221        mut self,
222        method: &str,
223        handler: impl Fn(N) -> Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
224            + Send
225            + Sync
226            + 'static,
227    ) -> Self
228    where
229        N: DeserializeOwned + Send + Sync + 'static,
230    {
231        self.notification_handlers.insert(
232            method.to_string(),
233            Box::new(TypedNotificationHandler {
234                handler: Box::new(handler),
235                _phantom: std::marker::PhantomData,
236            }),
237        );
238        self
239    }
240
241    pub fn build(self) -> Protocol<T> {
242        Protocol {
243            transport: Arc::new(self.transport),
244            request_handlers: Arc::new(Mutex::new(self.request_handlers)),
245            notification_handlers: Arc::new(Mutex::new(self.notification_handlers)),
246            request_id: Arc::new(AtomicU64::new(0)),
247            pending_requests: Arc::new(Mutex::new(HashMap::new())),
248        }
249    }
250}
251
252// Update the handler traits to be async
253#[async_trait]
254trait RequestHandler: Send + Sync {
255    async fn handle(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse>;
256}
257
258#[async_trait]
259trait NotificationHandler: Send + Sync {
260    async fn handle(&self, notification: JsonRpcNotification) -> Result<()>;
261}
262
263// Update the TypedRequestHandler to use async handlers
264struct TypedRequestHandler<Req, Resp>
265where
266    Req: DeserializeOwned + Send + Sync + 'static,
267    Resp: Serialize + Send + Sync + 'static,
268{
269    handler: Box<
270        dyn Fn(Req) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Resp>> + Send>>
271            + Send
272            + Sync,
273    >,
274    _phantom: std::marker::PhantomData<(Req, Resp)>,
275}
276
277#[async_trait]
278impl<Req, Resp> RequestHandler for TypedRequestHandler<Req, Resp>
279where
280    Req: DeserializeOwned + Send + Sync + 'static,
281    Resp: Serialize + Send + Sync + 'static,
282{
283    async fn handle(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse> {
284        let params: Req = if request.params.is_none() || request.params.as_ref().unwrap().is_null()
285        {
286            serde_json::from_value(serde_json::Value::Null)?
287        } else {
288            serde_json::from_value(request.params.unwrap())?
289        };
290        let result = (self.handler)(params).await?;
291        Ok(JsonRpcResponse {
292            id: request.id,
293            result: Some(serde_json::to_value(result)?),
294            error: None,
295            ..Default::default()
296        })
297    }
298}
299
300struct TypedNotificationHandler<N>
301where
302    N: DeserializeOwned + Send + Sync + 'static,
303{
304    handler: Box<
305        dyn Fn(N) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
306            + Send
307            + Sync,
308    >,
309    _phantom: std::marker::PhantomData<N>,
310}
311
312#[async_trait]
313impl<N> NotificationHandler for TypedNotificationHandler<N>
314where
315    N: DeserializeOwned + Send + Sync + 'static,
316{
317    async fn handle(&self, notification: JsonRpcNotification) -> Result<()> {
318        let params: N =
319            if notification.params.is_none() || notification.params.as_ref().unwrap().is_null() {
320                serde_json::from_value(serde_json::Value::Null)?
321            } else {
322                serde_json::from_value(notification.params.unwrap())?
323            };
324        (self.handler)(params).await
325    }
326}