// async_mcp/protocol.rs

use super::transport::{
    JsonRpcError, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, Transport,
};
use super::types::ErrorCode;
use anyhow::anyhow;
use anyhow::Result;
use async_trait::async_trait;
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::pin::Pin;
use std::sync::atomic::Ordering;
use std::time::Duration;
use std::{
    collections::HashMap,
    sync::{atomic::AtomicU64, Arc},
};
use tokio::sync::oneshot;
use tokio::sync::Mutex;
use tokio::time::timeout;
use tracing::debug;

22#[derive(Clone)]
23pub struct Protocol<T: Transport> {
24    transport: Arc<T>,
25
26    request_id: Arc<AtomicU64>,
27    pending_requests: Arc<Mutex<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>>,
28    request_handlers: Arc<Mutex<HashMap<String, Box<dyn RequestHandler>>>>,
29    notification_handlers: Arc<Mutex<HashMap<String, Box<dyn NotificationHandler>>>>,
30}
31
32impl<T: Transport> Protocol<T> {
33    pub fn builder(transport: T) -> ProtocolBuilder<T> {
34        ProtocolBuilder::new(transport)
35    }
36
37    pub async fn notify(&self, method: &str, params: Option<serde_json::Value>) -> Result<()> {
38        let notification = JsonRpcNotification {
39            method: method.to_string(),
40            params,
41            ..Default::default()
42        };
43        let msg = JsonRpcMessage::Notification(notification);
44        self.transport.send(&msg).await?;
45        Ok(())
46    }
47
48    pub async fn request(
49        &self,
50        method: &str,
51        params: Option<serde_json::Value>,
52        options: RequestOptions,
53    ) -> Result<JsonRpcResponse> {
54        let id = self.request_id.fetch_add(1, Ordering::SeqCst);
55
56        // Create a oneshot channel for this request
57        let (tx, rx) = oneshot::channel();
58
59        // Store the sender
60        {
61            let mut pending = self.pending_requests.lock().await;
62            pending.insert(id, tx);
63        }
64
65        // Send the request
66        let msg = JsonRpcMessage::Request(JsonRpcRequest {
67            id,
68            method: method.to_string(),
69            params,
70            ..Default::default()
71        });
72        self.transport.send(&msg).await?;
73
74        // Wait for response with timeout
75        match timeout(options.timeout, rx)
76            .await
77            .map_err(|_| anyhow!("Request timed out"))?
78        {
79            Ok(response) => Ok(response),
80            Err(_) => {
81                // Clean up the pending request if receiver was dropped
82                let mut pending = self.pending_requests.lock().await;
83                pending.remove(&id);
84                Err(anyhow!("Request cancelled"))
85            }
86        }
87    }
88
89    pub async fn listen(&self) -> Result<()> {
90        debug!("Listening for requests");
91        loop {
92            let message = self.transport.receive().await;
93
94            let message = match message {
95                Ok(msg) => msg,
96                Err(e) => {
97                    tracing::error!("Failed to parse message: {:?}", e);
98                    continue;
99                }
100            };
101
102            // Exit loop when transport signals shutdown with None
103            if message.is_none() {
104                break;
105            }
106
107            match message.unwrap() {
108                JsonRpcMessage::Request(request) => self.handle_request(request).await?,
109                JsonRpcMessage::Response(response) => {
110                    let id = response.id;
111                    let mut pending = self.pending_requests.lock().await;
112                    if let Some(tx) = pending.remove(&id) {
113                        let _ = tx.send(response);
114                    }
115                }
116                JsonRpcMessage::Notification(notification) => {
117                    let handlers = self.notification_handlers.lock().await;
118                    if let Some(handler) = handlers.get(&notification.method) {
119                        handler.handle(notification).await?;
120                    }
121                }
122            }
123        }
124        Ok(())
125    }
126
127    async fn handle_request(&self, request: JsonRpcRequest) -> Result<()> {
128        let handlers = self.request_handlers.lock().await;
129        if let Some(handler) = handlers.get(&request.method) {
130            match handler.handle(request.clone()).await {
131                Ok(response) => {
132                    let msg = JsonRpcMessage::Response(response);
133                    self.transport.send(&msg).await?;
134                }
135                Err(e) => {
136                    let error_response = JsonRpcResponse {
137                        id: request.id,
138                        result: None,
139                        error: Some(JsonRpcError {
140                            code: ErrorCode::InternalError as i32,
141                            message: e.to_string(),
142                            data: None,
143                        }),
144                        ..Default::default()
145                    };
146                    let msg = JsonRpcMessage::Response(error_response);
147                    self.transport.send(&msg).await?;
148                }
149            }
150        } else {
151            self.transport
152                .send(&JsonRpcMessage::Response(JsonRpcResponse {
153                    id: request.id,
154                    error: Some(JsonRpcError {
155                        code: ErrorCode::MethodNotFound as i32,
156                        message: format!("Method not found: {}", request.method),
157                        data: None,
158                    }),
159                    ..Default::default()
160                }))
161                .await?;
162        }
163        Ok(())
164    }
165}
166
/// The default request timeout, in milliseconds
pub const DEFAULT_REQUEST_TIMEOUT_MSEC: u64 = 60000;

/// Per-call options for `Protocol::request`.
///
/// The default timeout is [`DEFAULT_REQUEST_TIMEOUT_MSEC`] milliseconds.
#[derive(Debug, Clone, Copy)]
pub struct RequestOptions {
    /// How long to wait for the response before giving up.
    timeout: Duration,
}

impl RequestOptions {
    /// Returns these options with the response timeout replaced by `timeout`.
    #[must_use]
    pub fn timeout(self, timeout: Duration) -> Self {
        Self { timeout }
    }
}

impl Default for RequestOptions {
    fn default() -> Self {
        Self {
            timeout: Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MSEC),
        }
    }
}

187pub struct ProtocolBuilder<T: Transport> {
188    transport: T,
189    request_handlers: HashMap<String, Box<dyn RequestHandler>>,
190    notification_handlers: HashMap<String, Box<dyn NotificationHandler>>,
191}
192impl<T: Transport> ProtocolBuilder<T> {
193    pub fn new(transport: T) -> Self {
194        Self {
195            transport,
196            request_handlers: HashMap::new(),
197            notification_handlers: HashMap::new(),
198        }
199    }
200    /// Register a typed request handler
201    pub fn request_handler<Req, Resp>(
202        mut self,
203        method: &str,
204        handler: impl Fn(Req) -> Pin<Box<dyn std::future::Future<Output = Result<Resp>> + Send>>
205            + Send
206            + Sync
207            + 'static,
208    ) -> Self
209    where
210        Req: DeserializeOwned + Send + Sync + 'static,
211        Resp: Serialize + Send + Sync + 'static,
212    {
213        let handler = TypedRequestHandler {
214            handler: Box::new(handler),
215            _phantom: std::marker::PhantomData,
216        };
217
218        self.request_handlers
219            .insert(method.to_string(), Box::new(handler));
220        self
221    }
222
223    pub fn has_request_handler(&self, method: &str) -> bool {
224        self.request_handlers.contains_key(method)
225    }
226
227    pub fn notification_handler<N>(
228        mut self,
229        method: &str,
230        handler: impl Fn(N) -> Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
231            + Send
232            + Sync
233            + 'static,
234    ) -> Self
235    where
236        N: DeserializeOwned + Send + Sync + 'static,
237    {
238        self.notification_handlers.insert(
239            method.to_string(),
240            Box::new(TypedNotificationHandler {
241                handler: Box::new(handler),
242                _phantom: std::marker::PhantomData,
243            }),
244        );
245        self
246    }
247
248    pub fn build(self) -> Protocol<T> {
249        Protocol {
250            transport: Arc::new(self.transport),
251            request_handlers: Arc::new(Mutex::new(self.request_handlers)),
252            notification_handlers: Arc::new(Mutex::new(self.notification_handlers)),
253            request_id: Arc::new(AtomicU64::new(0)),
254            pending_requests: Arc::new(Mutex::new(HashMap::new())),
255        }
256    }
257}
258
259// Update the handler traits to be async
260#[async_trait]
261trait RequestHandler: Send + Sync {
262    async fn handle(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse>;
263}
264
265#[async_trait]
266trait NotificationHandler: Send + Sync {
267    async fn handle(&self, notification: JsonRpcNotification) -> Result<()>;
268}
269
270// Update the TypedRequestHandler to use async handlers
271struct TypedRequestHandler<Req, Resp>
272where
273    Req: DeserializeOwned + Send + Sync + 'static,
274    Resp: Serialize + Send + Sync + 'static,
275{
276    handler: Box<
277        dyn Fn(Req) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Resp>> + Send>>
278            + Send
279            + Sync,
280    >,
281    _phantom: std::marker::PhantomData<(Req, Resp)>,
282}
283
284#[async_trait]
285impl<Req, Resp> RequestHandler for TypedRequestHandler<Req, Resp>
286where
287    Req: DeserializeOwned + Send + Sync + 'static,
288    Resp: Serialize + Send + Sync + 'static,
289{
290    async fn handle(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse> {
291        let params: Req = if request.params.is_none() || request.params.as_ref().unwrap().is_null()
292        {
293            serde_json::from_value(serde_json::Value::Null)?
294        } else {
295            serde_json::from_value(request.params.unwrap())?
296        };
297        let result = (self.handler)(params).await?;
298        Ok(JsonRpcResponse {
299            id: request.id,
300            result: Some(serde_json::to_value(result)?),
301            error: None,
302            ..Default::default()
303        })
304    }
305}
306
307struct TypedNotificationHandler<N>
308where
309    N: DeserializeOwned + Send + Sync + 'static,
310{
311    handler: Box<
312        dyn Fn(N) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
313            + Send
314            + Sync,
315    >,
316    _phantom: std::marker::PhantomData<N>,
317}
318
319#[async_trait]
320impl<N> NotificationHandler for TypedNotificationHandler<N>
321where
322    N: DeserializeOwned + Send + Sync + 'static,
323{
324    async fn handle(&self, notification: JsonRpcNotification) -> Result<()> {
325        let params: N =
326            if notification.params.is_none() || notification.params.as_ref().unwrap().is_null() {
327                serde_json::from_value(serde_json::Value::Null)?
328            } else {
329                match &notification.params {
330                    Some(params) => {
331                        let res = serde_json::from_value(params.clone());
332                        match res {
333                            Ok(r) => r,
334                            Err(e) => {
335                                tracing::warn!(
336                                    "Failed to parse notification params: {:?}. Params: {:?}",
337                                    e,
338                                    notification.params
339                                );
340                                serde_json::from_value(serde_json::Value::Null)?
341                            }
342                        }
343                    }
344                    None => serde_json::from_value(serde_json::Value::Null)?,
345                }
346            };
347        (self.handler)(params).await
348    }
349}