mcp-core 0.1.50

A Rust library implementing the Model Context Protocol (MCP)
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
//! # MCP Protocol Implementation
//!
//! This module implements the core JSON-RPC protocol layer used by the MCP system.
//! It provides the infrastructure for sending and receiving JSON-RPC requests,
//! notifications, and responses between MCP clients and servers.
//!
//! The protocol layer is transport-agnostic and can work with any transport
//! implementation that conforms to the `Transport` trait.
//!
//! Key components include:
//! - `Protocol`: The main protocol handler
//! - `ProtocolBuilder`: A builder for configuring protocols
//! - Request and notification handlers
//! - Timeout and error handling

use super::transport::{JsonRpcError, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse};
use super::types::ErrorCode;
use anyhow::Result;
use async_trait::async_trait;
use serde::de::DeserializeOwned;
use serde::Serialize;
use serde_json::json;
use std::pin::Pin;

use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
use std::{collections::HashMap, sync::Arc};
use tokio::sync::{oneshot, Mutex};

/// The core protocol handler for MCP.
///
/// The `Protocol` struct manages the lifecycle of JSON-RPC requests and responses,
/// dispatches incoming requests to the appropriate handlers, and manages
/// pending requests and their responses.
///
/// Every field is wrapped in an `Arc`, so cloning a `Protocol` is cheap and all
/// clones share the same ID counter, pending-request table, and handler maps.
#[derive(Clone)]
pub struct Protocol {
    /// Counter used by `new_message_id` to hand out outgoing request IDs.
    request_id: Arc<AtomicU64>,
    /// Outgoing requests still awaiting a response, keyed by request ID. Each value is
    /// the sender half of the channel whose receiver `create_request` returned.
    pending_requests: Arc<Mutex<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>>,
    /// Request handlers keyed by JSON-RPC method name (filled in by `ProtocolBuilder`).
    request_handlers: Arc<Mutex<HashMap<String, Box<dyn RequestHandler>>>>,
    /// Notification handlers keyed by JSON-RPC method name (filled in by `ProtocolBuilder`).
    notification_handlers: Arc<Mutex<HashMap<String, Box<dyn NotificationHandler>>>>,
}

impl Protocol {
    /// Creates a new protocol builder.
    ///
    /// # Returns
    ///
    /// A `ProtocolBuilder` for configuring the protocol
    pub fn builder() -> ProtocolBuilder {
        ProtocolBuilder::new()
    }

    /// Handles an incoming JSON-RPC request.
    ///
    /// This method dispatches the request to the handler registered for
    /// `request.method` and returns the handler's response.
    ///
    /// # Arguments
    ///
    /// * `request` - The incoming JSON-RPC request
    ///
    /// # Returns
    ///
    /// The handler's `JsonRpcResponse`. If the handler returns an error, an error
    /// response with `ErrorCode::InternalError` and the error's message is returned.
    /// If no handler is registered for the method, an error response with
    /// `ErrorCode::MethodNotFound` is returned. Error responses carry the request's `id`.
    pub async fn handle_request(&self, request: JsonRpcRequest) -> JsonRpcResponse {
        // Copy the ID out up front so the request (including its params payload) can be
        // moved into the handler instead of being cloned.
        let id = request.id;
        // NOTE(review): this guard is held while the handler future is awaited, so
        // concurrent `handle_request` calls on this `Protocol` (or its clones) wait for
        // the running handler to finish.
        let handlers = self.request_handlers.lock().await;
        match handlers.get(&request.method) {
            Some(handler) => match handler.handle(request).await {
                Ok(response) => response,
                Err(e) => JsonRpcResponse {
                    id,
                    result: None,
                    error: Some(JsonRpcError {
                        code: ErrorCode::InternalError as i32,
                        message: e.to_string(),
                        data: None,
                    }),
                    ..Default::default()
                },
            },
            None => JsonRpcResponse {
                id,
                error: Some(JsonRpcError {
                    code: ErrorCode::MethodNotFound as i32,
                    message: format!("Method not found: {}", request.method),
                    data: None,
                }),
                ..Default::default()
            },
        }
    }

    /// Handles an incoming JSON-RPC notification.
    ///
    /// This method dispatches the notification to the handler registered for
    /// `request.method`. The outcome is only logged: success at `info`, handler
    /// errors at `error`, and unknown methods at `debug`.
    ///
    /// # Arguments
    ///
    /// * `request` - The incoming JSON-RPC notification
    pub async fn handle_notification(&self, request: JsonRpcNotification) {
        let handlers = self.notification_handlers.lock().await;
        match handlers.get(&request.method) {
            Some(handler) => {
                // Only the method name is needed after dispatch (for logging), so keep
                // that rather than cloning the whole notification and its params.
                let method = request.method.clone();
                match handler.handle(request).await {
                    Ok(()) => tracing::info!("Received notification: {:?}", method),
                    Err(e) => tracing::error!("Error handling notification: {}", e),
                }
            }
            None => tracing::debug!("No handler for notification: {}", request.method),
        }
    }

    /// Generates a new unique message ID for requests.
    ///
    /// IDs come from a counter shared by all clones of this `Protocol`, incremented
    /// by one per call.
    ///
    /// # Returns
    ///
    /// A unique message ID
    pub fn new_message_id(&self) -> u64 {
        self.request_id.fetch_add(1, Ordering::SeqCst)
    }

    /// Creates a new request ID and channel for receiving the response.
    ///
    /// The sender half is stored in the pending-request table until
    /// `handle_response` or `cancel_response` removes it.
    ///
    /// # Returns
    ///
    /// A tuple containing the request ID and a receiver for the response
    pub async fn create_request(&self) -> (u64, oneshot::Receiver<JsonRpcResponse>) {
        let id = self.new_message_id();
        let (tx, rx) = oneshot::channel();

        {
            let mut pending = self.pending_requests.lock().await;
            pending.insert(id, tx);
        }

        (id, rx)
    }

    /// Handles an incoming JSON-RPC response.
    ///
    /// This method delivers the response to the request waiting on `response.id`,
    /// if any. Responses with no matching pending entry are dropped. A receiver that
    /// has already gone away is ignored.
    ///
    /// # Arguments
    ///
    /// * `response` - The incoming JSON-RPC response
    pub async fn handle_response(&self, response: JsonRpcResponse) {
        if let Some(tx) = self.pending_requests.lock().await.remove(&response.id) {
            let _ = tx.send(response);
        }
    }

    /// Cancels a pending request and sends an error response.
    ///
    /// If `id` is still pending, its waiter receives an error response with
    /// `ErrorCode::RequestTimeout` and the message "Request cancelled". Unknown IDs
    /// are ignored.
    ///
    /// # Arguments
    ///
    /// * `id` - The ID of the request to cancel
    pub async fn cancel_response(&self, id: u64) {
        if let Some(tx) = self.pending_requests.lock().await.remove(&id) {
            let _ = tx.send(JsonRpcResponse {
                id,
                result: None,
                error: Some(JsonRpcError {
                    code: ErrorCode::RequestTimeout as i32,
                    message: "Request cancelled".to_string(),
                    data: None,
                }),
                ..Default::default()
            });
        }
    }
}

/// The default request timeout, in milliseconds (60 seconds).
pub const DEFAULT_REQUEST_TIMEOUT_MSEC: u64 = 60000;

/// Options for customizing requests.
///
/// This struct allows configuring various aspects of request handling,
/// such as timeouts. `RequestOptions::default()` uses a timeout of
/// [`DEFAULT_REQUEST_TIMEOUT_MSEC`] milliseconds.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RequestOptions {
    /// The timeout duration for the request
    pub timeout: Duration,
}

impl RequestOptions {
    /// Sets the timeout for the request.
    ///
    /// Only the timeout is updated on `self`, which is then returned. As a result,
    /// any other options added to this struct later are preserved rather than reset.
    ///
    /// # Arguments
    ///
    /// * `timeout` - The timeout duration
    ///
    /// # Returns
    ///
    /// The modified options instance
    #[must_use]
    pub fn timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }
}

impl Default for RequestOptions {
    fn default() -> Self {
        Self {
            timeout: Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MSEC),
        }
    }
}

/// Builder for creating configured protocols.
///
/// The `ProtocolBuilder` provides a fluent API for configuring and creating
/// protocols with specific request and notification handlers. Registration
/// methods take the builder by value and return it, so discarding the returned
/// builder loses the handlers registered on it; hence `#[must_use]`.
///
/// The handler maps live behind `Arc`, so a cloned builder shares them with the
/// original. Handlers registered through either clone are visible to both, and to
/// any `Protocol` built from them.
#[must_use = "a ProtocolBuilder does nothing until `build` is called"]
#[derive(Clone)]
pub struct ProtocolBuilder {
    /// Request handlers keyed by JSON-RPC method name.
    request_handlers: Arc<Mutex<HashMap<String, Box<dyn RequestHandler>>>>,
    /// Notification handlers keyed by JSON-RPC method name.
    notification_handlers: Arc<Mutex<HashMap<String, Box<dyn NotificationHandler>>>>,
}

impl ProtocolBuilder {
    /// Creates a new protocol builder.
    ///
    /// # Returns
    ///
    /// A new `ProtocolBuilder` instance
    pub fn new() -> Self {
        Self {
            request_handlers: Arc::new(Mutex::new(HashMap::new())),
            notification_handlers: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Registers a typed request handler.
    ///
    /// # Arguments
    ///
    /// * `method` - The method name to handle
    /// * `handler` - The handler function
    ///
    /// # Returns
    ///
    /// The modified builder instance
    pub fn request_handler<Req, Resp>(
        self,
        method: &str,
        handler: impl Fn(Req) -> Pin<Box<dyn std::future::Future<Output = Result<Resp>> + Send>>
            + Send
            + Sync
            + 'static,
    ) -> Self
    where
        Req: DeserializeOwned + Send + Sync + 'static,
        Resp: Serialize + Send + Sync + 'static,
    {
        let handler = TypedRequestHandler {
            handler: Box::new(handler),
            _phantom: std::marker::PhantomData,
        };

        if let Ok(mut handlers) = self.request_handlers.try_lock() {
            handlers.insert(method.to_string(), Box::new(handler));
        }
        self
    }

    /// Checks if a request handler exists for a method.
    ///
    /// # Arguments
    ///
    /// * `method` - The method name to check
    ///
    /// # Returns
    ///
    /// `true` if a handler exists, `false` otherwise
    pub fn has_request_handler(&self, method: &str) -> bool {
        self.request_handlers
            .try_lock()
            .map(|handlers| handlers.contains_key(method))
            .unwrap_or(false)
    }

    /// Registers a typed notification handler.
    ///
    /// # Arguments
    ///
    /// * `method` - The method name to handle
    /// * `handler` - The handler function
    ///
    /// # Returns
    ///
    /// The modified builder instance
    pub fn notification_handler<N>(
        self,
        method: &str,
        handler: impl Fn(N) -> Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
            + Send
            + Sync
            + 'static,
    ) -> Self
    where
        N: DeserializeOwned + Send + Sync + 'static,
    {
        let handler = TypedNotificationHandler {
            handler: Box::new(handler),
            _phantom: std::marker::PhantomData,
        };

        if let Ok(mut handlers) = self.notification_handlers.try_lock() {
            handlers.insert(method.to_string(), Box::new(handler));
        }
        self
    }

    /// Checks if a notification handler exists for a method.
    ///
    /// # Arguments
    ///
    /// * `method` - The method name to check
    ///
    /// # Returns
    ///
    /// `true` if a handler exists, `false` otherwise
    pub fn has_notification_handler(&self, method: &str) -> bool {
        self.notification_handlers
            .try_lock()
            .map(|handlers| handlers.contains_key(method))
            .unwrap_or(false)
    }

    /// Builds the protocol with the configured handlers.
    ///
    /// # Returns
    ///
    /// A new `Protocol` instance
    pub fn build(self) -> Protocol {
        Protocol {
            request_id: Arc::new(AtomicU64::new(0)),
            pending_requests: Arc::new(Mutex::new(HashMap::new())),
            request_handlers: self.request_handlers,
            notification_handlers: self.notification_handlers,
        }
    }
}

/// Object-safe interface for anything that can answer a JSON-RPC request.
///
/// The `Send + Sync` supertraits let boxed handlers (`Box<dyn RequestHandler>`)
/// be kept in the shared handler maps and awaited from async tasks.
#[async_trait]
trait RequestHandler: Send + Sync {
    /// Processes `incoming` and produces the response to send back.
    ///
    /// # Errors
    ///
    /// Returns an error when the request cannot be handled; the caller decides
    /// how that failure is reported to the peer.
    async fn handle(&self, incoming: JsonRpcRequest) -> Result<JsonRpcResponse>;
}

/// Object-safe interface for anything that can react to a JSON-RPC notification.
///
/// Notifications get no reply, so a handler only reports success or failure.
/// The `Send + Sync` supertraits let boxed handlers be shared across tasks.
#[async_trait]
trait NotificationHandler: Send + Sync {
    /// Processes `incoming`.
    ///
    /// # Errors
    ///
    /// Returns an error when the notification cannot be processed.
    async fn handle(&self, incoming: JsonRpcNotification) -> Result<()>;
}

/// A typed request handler.
///
/// This struct adapts a typed handler function to the `RequestHandler` trait,
/// handling the deserialization of the request and serialization of the response.
struct TypedRequestHandler<Req, Resp>
where
    Req: DeserializeOwned + Send + Sync + 'static,
    Resp: Serialize + Send + Sync + 'static,
{
    handler: Box<
        dyn Fn(Req) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Resp>> + Send>>
            + Send
            + Sync,
    >,
    _phantom: std::marker::PhantomData<(Req, Resp)>,
}

#[async_trait]
impl<Req, Resp> RequestHandler for TypedRequestHandler<Req, Resp>
where
    Req: DeserializeOwned + Send + Sync + 'static,
    Resp: Serialize + Send + Sync + 'static,
{
    /// Deserializes `request.params` into `Req`, runs the wrapped handler, and
    /// serializes its output as the response `result`.
    ///
    /// Absent or `null` params are treated as "no arguments". An empty object (`{}`)
    /// is tried first, so parameter types that accept an empty object deserialize.
    /// If that fails, `null` is tried so unit-like types such as `()` also work. If
    /// both fail, the error from the `{}` attempt is returned.
    ///
    /// # Errors
    ///
    /// Returns an error if the params cannot be deserialized into `Req`, if the
    /// handler itself fails, or if its result cannot be serialized.
    async fn handle(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse> {
        let params: Req = match request.params {
            Some(value) if !value.is_null() => serde_json::from_value(value)?,
            _ => serde_json::from_value::<Req>(json!({})).or_else(|first_err| {
                serde_json::from_value::<Req>(serde_json::Value::Null).map_err(|_| first_err)
            })?,
        };
        let result = (self.handler)(params).await?;
        Ok(JsonRpcResponse {
            id: request.id,
            result: Some(serde_json::to_value(result)?),
            error: None,
            ..Default::default()
        })
    }
}

/// A typed notification handler.
///
/// This struct adapts a typed handler function to the `NotificationHandler` trait,
/// handling the deserialization of the notification.
struct TypedNotificationHandler<N>
where
    N: DeserializeOwned + Send + Sync + 'static,
{
    handler: Box<
        dyn Fn(N) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
            + Send
            + Sync,
    >,
    _phantom: std::marker::PhantomData<N>,
}

#[async_trait]
impl<N> NotificationHandler for TypedNotificationHandler<N>
where
    N: DeserializeOwned + Send + Sync + 'static,
{
    async fn handle(&self, notification: JsonRpcNotification) -> Result<()> {
        let params: N =
            if notification.params.is_none() || notification.params.as_ref().unwrap().is_null() {
                serde_json::from_value(serde_json::Value::Null)?
            } else {
                serde_json::from_value(notification.params.unwrap())?
            };
        (self.handler)(params).await
    }
}