mcp_sdk/
protocol.rs

1use super::transport::{
2    JsonRpcError, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, Message,
3    Transport,
4};
5use super::types::ErrorCode;
6use anyhow::anyhow;
7use anyhow::Result;
8use async_trait::async_trait;
9use serde::de::DeserializeOwned;
10use serde::Serialize;
11use std::sync::atomic::Ordering;
12use std::time::Duration;
13use std::{
14    collections::HashMap,
15    sync::{atomic::AtomicU64, Arc},
16};
17use tokio::sync::oneshot;
18use tokio::sync::Mutex;
19use tokio::time::timeout;
20use tracing::debug;
21
22#[derive(Clone)]
23pub struct Protocol<T: Transport> {
24    transport: Arc<T>,
25
26    request_id: Arc<AtomicU64>,
27    pending_requests: Arc<Mutex<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>>,
28    request_handlers: Arc<Mutex<HashMap<String, Box<dyn RequestHandler>>>>,
29    notification_handlers: Arc<Mutex<HashMap<String, Box<dyn NotificationHandler>>>>,
30}
31
32impl<T: Transport> Protocol<T> {
33    pub fn builder(transport: T) -> ProtocolBuilder<T> {
34        ProtocolBuilder::new(transport)
35    }
36
37    pub fn notify(&self, method: &str, params: Option<serde_json::Value>) -> Result<()> {
38        let notification = JsonRpcNotification {
39            method: method.to_string(),
40            params,
41            ..Default::default()
42        };
43        let msg = JsonRpcMessage::Notification(notification);
44        self.transport.send(&msg)?;
45        Ok(())
46    }
47
48    pub async fn request(
49        &self,
50        method: &str,
51        params: Option<serde_json::Value>,
52        options: RequestOptions,
53    ) -> Result<JsonRpcResponse> {
54        let id = self.request_id.fetch_add(1, Ordering::SeqCst);
55
56        // Create a oneshot channel for this request
57        let (tx, rx) = oneshot::channel();
58
59        // Store the sender
60        {
61            let mut pending = self.pending_requests.lock().await;
62            pending.insert(id, tx);
63        }
64
65        // Send the request
66        let msg = JsonRpcMessage::Request(JsonRpcRequest {
67            id,
68            method: method.to_string(),
69            params,
70            ..Default::default()
71        });
72        self.transport.send(&msg)?;
73
74        // Wait for response with timeout
75        match timeout(options.timeout, rx)
76            .await
77            .map_err(|_| anyhow!("Request timed out"))?
78        {
79            Ok(response) => Ok(response),
80            Err(_) => {
81                // Clean up the pending request if receiver was dropped
82                let mut pending = self.pending_requests.lock().await;
83                pending.remove(&id);
84                Err(anyhow!("Request cancelled"))
85            }
86        }
87    }
88
89    pub async fn listen(&self) -> Result<()> {
90        debug!("Listening for requests");
91        loop {
92            let message: Message = self.transport.receive()?;
93            match message {
94                JsonRpcMessage::Request(request) => self.handle_request(request).await?,
95                JsonRpcMessage::Response(response) => {
96                    // Remove and send response through the channel
97                    let id = response.id;
98                    let mut pending = self.pending_requests.lock().await;
99                    if let Some(tx) = pending.remove(&id) {
100                        // If send fails, the receiver was dropped, just continue
101                        let _ = tx.send(response);
102                    }
103                }
104                JsonRpcMessage::Notification(json_rpc_notification) => {
105                    let handlers = self.notification_handlers.lock().await;
106                    if let Some(handler) = handlers.get(&json_rpc_notification.method) {
107                        handler.handle(json_rpc_notification)?;
108                    }
109                }
110            }
111        }
112    }
113
114    async fn handle_request(&self, request: JsonRpcRequest) -> Result<()> {
115        let handlers = self.request_handlers.lock().await;
116        if let Some(handler) = handlers.get(&request.method) {
117            match handler.handle(request.clone()).await {
118                Ok(response) => {
119                    let msg = JsonRpcMessage::Response(response);
120                    self.transport.send(&msg)?;
121                }
122                Err(e) => {
123                    let error_response = JsonRpcResponse {
124                        id: request.id,
125                        result: None,
126                        error: Some(JsonRpcError {
127                            code: ErrorCode::InternalError as i32,
128                            message: e.to_string(),
129                            data: None,
130                        }),
131                        ..Default::default()
132                    };
133                    let msg = JsonRpcMessage::Response(error_response);
134                    self.transport.send(&msg)?;
135                }
136            }
137        } else {
138            self.transport
139                .send(&JsonRpcMessage::Response(JsonRpcResponse {
140                    id: request.id,
141                    error: Some(JsonRpcError {
142                        code: ErrorCode::MethodNotFound as i32,
143                        message: format!("Method not found: {}", request.method),
144                        data: None,
145                    }),
146                    ..Default::default()
147                }))?;
148        }
149        Ok(())
150    }
151}
152
/// The default request timeout, in milliseconds
pub const DEFAULT_REQUEST_TIMEOUT_MSEC: u64 = 60000;

/// Per-call options for [`Protocol::request`].
///
/// `RequestOptions::default()` uses a timeout of [`DEFAULT_REQUEST_TIMEOUT_MSEC`]
/// milliseconds; override it with [`RequestOptions::timeout`].
#[derive(Debug, Clone, Copy)]
pub struct RequestOptions {
    /// How long `Protocol::request` waits for a response before failing.
    timeout: Duration,
}
158
159impl RequestOptions {
160    pub fn timeout(self, timeout: Duration) -> Self {
161        Self { timeout }
162    }
163}
164
165impl Default for RequestOptions {
166    fn default() -> Self {
167        Self {
168            timeout: Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MSEC),
169        }
170    }
171}
172
173pub struct ProtocolBuilder<T: Transport> {
174    transport: T,
175    request_handlers: HashMap<String, Box<dyn RequestHandler>>,
176    notification_handlers: HashMap<String, Box<dyn NotificationHandler>>,
177}
178impl<T: Transport> ProtocolBuilder<T> {
179    pub fn new(transport: T) -> Self {
180        Self {
181            transport,
182            request_handlers: HashMap::new(),
183            notification_handlers: HashMap::new(),
184        }
185    }
186    /// Register a typed request handler
187    pub fn request_handler<Req, Resp>(
188        mut self,
189        method: &str,
190        handler: impl Fn(Req) -> Result<Resp> + Send + Sync + 'static,
191    ) -> Self
192    where
193        Req: DeserializeOwned + Send + Sync + 'static,
194        Resp: Serialize + Send + Sync + 'static,
195    {
196        let handler = TypedRequestHandler {
197            handler,
198            _phantom: std::marker::PhantomData,
199        };
200
201        self.request_handlers
202            .insert(method.to_string(), Box::new(handler));
203        self
204    }
205
206    pub fn has_request_handler(&self, method: &str) -> bool {
207        self.request_handlers.contains_key(method)
208    }
209
210    pub fn notification_handler<N>(
211        mut self,
212        method: &str,
213        handler: impl Fn(N) -> Result<()> + Send + Sync + 'static,
214    ) -> Self
215    where
216        N: DeserializeOwned + Send + Sync + 'static,
217    {
218        self.notification_handlers.insert(
219            method.to_string(),
220            Box::new(TypedNotificationHandler {
221                handler,
222                _phantom: std::marker::PhantomData,
223            }),
224        );
225        self
226    }
227
228    pub fn build(self) -> Protocol<T> {
229        Protocol {
230            transport: Arc::new(self.transport),
231            request_handlers: Arc::new(Mutex::new(self.request_handlers)),
232            notification_handlers: Arc::new(Mutex::new(self.notification_handlers)),
233            request_id: Arc::new(AtomicU64::new(0)),
234            pending_requests: Arc::new(Mutex::new(HashMap::new())),
235        }
236    }
237}
238
239// Wrapper for handler types using async trait
240#[async_trait]
241trait RequestHandler: Send + Sync {
242    async fn handle(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse>;
243}
244
245trait NotificationHandler: Send + Sync {
246    fn handle(&self, notification: JsonRpcNotification) -> Result<()>;
247}
248
249// Typed handler implementations
250struct TypedRequestHandler<Req, Resp, F>
251where
252    Req: DeserializeOwned + Send + Sync + 'static,
253    Resp: Serialize + Send + Sync + 'static,
254    F: Fn(Req) -> Result<Resp> + Send + Sync + 'static,
255{
256    handler: F,
257    _phantom: std::marker::PhantomData<(Req, Resp)>,
258}
259
260#[async_trait]
261impl<Req, Resp, F> RequestHandler for TypedRequestHandler<Req, Resp, F>
262where
263    Req: DeserializeOwned + Send + Sync + 'static,
264    Resp: Serialize + Send + Sync + 'static,
265    F: Fn(Req) -> Result<Resp> + Send + Sync + 'static,
266{
267    async fn handle(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse> {
268        // If params is None or null, deserialize as unit type using Value::Null
269        let params: Req = if request.params.is_none() || request.params.as_ref().unwrap().is_null()
270        {
271            serde_json::from_value(serde_json::Value::Null)?
272        } else {
273            serde_json::from_value(request.params.unwrap())?
274        };
275        let result = (self.handler)(params)?;
276        Ok(JsonRpcResponse {
277            id: request.id,
278            result: Some(serde_json::to_value(result)?),
279            error: None,
280            ..Default::default()
281        })
282    }
283}
284pub struct TypedNotificationHandler<N, F>
285where
286    N: DeserializeOwned + Send + Sync + 'static,
287    F: Fn(N) -> Result<()> + Send + Sync + 'static,
288{
289    handler: F,
290    _phantom: std::marker::PhantomData<N>,
291}
292
293#[async_trait]
294impl<N, F> NotificationHandler for TypedNotificationHandler<N, F>
295where
296    N: DeserializeOwned + Send + Sync + 'static,
297    F: Fn(N) -> Result<()> + Send + Sync + 'static,
298{
299    fn handle(&self, notification: JsonRpcNotification) -> Result<()> {
300        let params: N =
301            if notification.params.is_none() || notification.params.as_ref().unwrap().is_null() {
302                serde_json::from_value(serde_json::Value::Null)?
303            } else {
304                serde_json::from_value(notification.params.unwrap())?
305            };
306        (self.handler)(params)
307    }
308}