mcp_attr/
server.rs

1//! Module for implementing MCP server
2
3use std::{future::Future, sync::Arc};
4
5use jsoncall::{
6    ErrorCode, Handler, Hook, NotificationContext, Params, RequestContextAs, RequestId, Response,
7    Result, Session, SessionContext, SessionOptions, SessionResult, bail_public,
8};
9use serde::{Serialize, de::DeserializeOwned};
10use serde_json::Map;
11
12use crate::{
13    common::McpCancellationHook,
14    schema::{
15        CallToolRequestParams, CallToolResult, CancelledNotificationParams, ClientCapabilities,
16        CompleteRequestParams, CompleteResult, CreateMessageRequestParams, CreateMessageResult,
17        GetPromptRequestParams, GetPromptResult, Implementation, InitializeRequestParams,
18        InitializeResult, InitializedNotificationParams, ListPromptsRequestParams,
19        ListPromptsResult, ListResourceTemplatesRequestParams, ListResourceTemplatesResult,
20        ListResourcesRequestParams, ListResourcesResult, ListRootsRequestParams, ListRootsResult,
21        ListToolsRequestParams, ListToolsResult, PingRequestParams, ProgressNotificationParams,
22        ReadResourceRequestParams, ReadResourceResult, Root, ServerCapabilities,
23        ServerCapabilitiesPrompts, ServerCapabilitiesResources, ServerCapabilitiesTools,
24    },
25    server::errors::{prompt_not_found, tool_not_found},
26    utils::{Empty, ProtocolVersion},
27};
28
29pub mod errors;
30mod mcp_server_attr;
31
32pub use mcp_server_attr::mcp_server;
33
34struct SessionData {
35    initialize: InitializeRequestParams,
36    protocol_version: ProtocolVersion,
37}
38
39struct McpServerHandler {
40    server: Arc<dyn DynMcpServer>,
41    data: Option<Arc<SessionData>>,
42    is_initizlized: bool,
43}
44impl Handler for McpServerHandler {
45    fn hook(&self) -> Arc<dyn Hook> {
46        Arc::new(McpCancellationHook)
47    }
48    fn request(
49        &mut self,
50        method: &str,
51        params: Params,
52        cx: jsoncall::RequestContext,
53    ) -> Result<Response> {
54        match method {
55            "initialize" => return cx.handle(self.initialize(params.to()?)),
56            "ping" => return cx.handle(self.ping(params.to_opt()?)),
57            _ => {}
58        }
59        let (Some(data), true) = (&self.data, self.is_initizlized) else {
60            bail_public!(_, "Server not initialized");
61        };
62        let d = data.clone();
63        match method {
64            "prompts/list" => self.call_opt(params, cx, |s, p, cx| s.dyn_prompts_list(p, cx, d)),
65            "prompts/get" => self.call(params, cx, |s, p, cx| s.dyn_prompts_get(p, cx, d)),
66            "resources/list" => {
67                self.call_opt(params, cx, |s, p, cx| s.dyn_resources_list(p, cx, d))
68            }
69            "resources/templates/list" => self.call_opt(params, cx, |s, p, cx| {
70                s.dyn_resources_templates_list(p, cx, d)
71            }),
72            "resources/read" => self.call(params, cx, |s, p, cx| s.dyn_resources_read(p, cx, d)),
73            "tools/list" => self.call_opt(params, cx, |s, p, cx| s.dyn_tools_list(p, cx, d)),
74            "tools/call" => self.call(params, cx, |s, p, cx| s.dyn_tools_call(p, cx, d)),
75            "completion/complete" => {
76                self.call(params, cx, |s, p, cx| s.dyn_completion_complete(p, cx, d))
77            }
78            _ => cx.method_not_found(),
79        }
80    }
81    fn notification(
82        &mut self,
83        method: &str,
84        params: Params,
85        cx: NotificationContext,
86    ) -> Result<Response> {
87        match method {
88            "notifications/initialized" => cx.handle(self.initialized(params.to_opt()?)),
89            "notifications/cancelled" => self.notifications_cancelled(params.to()?, cx),
90            _ => cx.method_not_found(),
91        }
92    }
93}
94impl McpServerHandler {
95    pub fn new(server: impl McpServer) -> Self {
96        Self {
97            server: Arc::new(server),
98            data: None,
99            is_initizlized: false,
100        }
101    }
102}
103impl McpServerHandler {
104    fn initialize(&mut self, p: InitializeRequestParams) -> Result<InitializeResult> {
105        self.data = Some(Arc::new(SessionData {
106            initialize: p,
107            protocol_version: ProtocolVersion::LATEST,
108        }));
109        Ok(self.server.initialize_result())
110    }
111    fn initialized(&mut self, _p: Option<InitializedNotificationParams>) -> Result<()> {
112        if self.data.is_none() {
113            bail_public!(
114                _,
115                "`initialize` request must be called before `initialized` notification"
116            );
117        }
118        self.is_initizlized = true;
119        Ok(())
120    }
121    fn ping(&self, _p: Option<PingRequestParams>) -> Result<Empty> {
122        Ok(Empty::default())
123    }
124    fn notifications_cancelled(
125        &self,
126        p: CancelledNotificationParams,
127        cx: NotificationContext,
128    ) -> Result<Response> {
129        cx.session().cancel_incoming_request(&p.request_id, None);
130        cx.handle(Ok(()))
131    }
132
133    // fn logging_set_level(&self, p: SetLevelRequestParams) -> Result<()> {
134    //     todo!()
135    // }
136
137    // fn resources_subscribe(&self, p: SubscribeRequestParams) -> Result<()> {
138    //     todo!()
139    // }
140
141    // fn resources_unsubscribe(&self, p: UnsubscribeRequestParams) -> Result<()> {
142    //     todo!()
143    // }
144
145    fn call<P, R>(
146        &self,
147        p: Params,
148        cx: jsoncall::RequestContext,
149        f: impl FnOnce(Arc<dyn DynMcpServer>, P, RequestContextAs<R>) -> Result<Response>,
150    ) -> Result<Response>
151    where
152        P: DeserializeOwned,
153        R: Serialize,
154    {
155        f(self.server.clone(), p.to()?, cx.to())
156    }
157    fn call_opt<P, R>(
158        &self,
159        p: Params,
160        cx: jsoncall::RequestContext,
161        f: impl FnOnce(Arc<dyn DynMcpServer>, P, RequestContextAs<R>) -> Result<Response>,
162    ) -> Result<Response>
163    where
164        P: DeserializeOwned + Default,
165        R: Serialize,
166    {
167        f(
168            self.server.clone(),
169            p.to_opt()?.unwrap_or_default(),
170            cx.to(),
171        )
172    }
173}
174
175trait DynMcpServer: Send + Sync + 'static {
176    fn initialize_result(&self) -> InitializeResult;
177
178    fn dyn_prompts_list(
179        self: Arc<Self>,
180        p: ListPromptsRequestParams,
181        cx: RequestContextAs<ListPromptsResult>,
182        data: Arc<SessionData>,
183    ) -> Result<Response>;
184
185    fn dyn_prompts_get(
186        self: Arc<Self>,
187        p: GetPromptRequestParams,
188        cx: RequestContextAs<GetPromptResult>,
189        data: Arc<SessionData>,
190    ) -> Result<Response>;
191
192    fn dyn_resources_list(
193        self: Arc<Self>,
194        p: ListResourcesRequestParams,
195        cx: RequestContextAs<ListResourcesResult>,
196        data: Arc<SessionData>,
197    ) -> Result<Response>;
198
199    fn dyn_resources_read(
200        self: Arc<Self>,
201        p: ReadResourceRequestParams,
202        cx: RequestContextAs<ReadResourceResult>,
203        data: Arc<SessionData>,
204    ) -> Result<Response>;
205
206    fn dyn_resources_templates_list(
207        self: Arc<Self>,
208        p: ListResourceTemplatesRequestParams,
209        cx: RequestContextAs<ListResourceTemplatesResult>,
210        data: Arc<SessionData>,
211    ) -> Result<Response>;
212
213    fn dyn_tools_list(
214        self: Arc<Self>,
215        p: ListToolsRequestParams,
216        cx: RequestContextAs<ListToolsResult>,
217        data: Arc<SessionData>,
218    ) -> Result<Response>;
219
220    fn dyn_tools_call(
221        self: Arc<Self>,
222        p: CallToolRequestParams,
223        cx: RequestContextAs<CallToolResult>,
224        data: Arc<SessionData>,
225    ) -> Result<Response>;
226
227    fn dyn_completion_complete(
228        self: Arc<Self>,
229        p: CompleteRequestParams,
230        cx: RequestContextAs<CompleteResult>,
231        data: Arc<SessionData>,
232    ) -> Result<Response>;
233}
234impl<T: McpServer> DynMcpServer for T {
235    fn initialize_result(&self) -> InitializeResult {
236        InitializeResult {
237            capabilities: self.capabilities(),
238            instructions: self.instructions(),
239            meta: Map::new(),
240            protocol_version: ProtocolVersion::LATEST.to_string(),
241            server_info: self.server_info(),
242        }
243    }
244    fn dyn_prompts_list(
245        self: Arc<Self>,
246        p: ListPromptsRequestParams,
247        cx: RequestContextAs<ListPromptsResult>,
248        data: Arc<SessionData>,
249    ) -> Result<Response> {
250        let mut mcp_cx = RequestContext::new(&cx, data);
251        cx.handle_async(async move { self.prompts_list(p, &mut mcp_cx).await })
252    }
253
254    fn dyn_prompts_get(
255        self: Arc<Self>,
256        p: GetPromptRequestParams,
257        cx: RequestContextAs<GetPromptResult>,
258        data: Arc<SessionData>,
259    ) -> Result<Response> {
260        let mut mcp_cx = RequestContext::new(&cx, data);
261        cx.handle_async(async move { self.prompts_get(p, &mut mcp_cx).await })
262    }
263
264    fn dyn_resources_list(
265        self: Arc<Self>,
266        p: ListResourcesRequestParams,
267        cx: RequestContextAs<ListResourcesResult>,
268        data: Arc<SessionData>,
269    ) -> Result<Response> {
270        let mut mcp_cx = RequestContext::new(&cx, data);
271        cx.handle_async(async move { self.resources_list(p, &mut mcp_cx).await })
272    }
273
274    fn dyn_resources_templates_list(
275        self: Arc<Self>,
276        p: ListResourceTemplatesRequestParams,
277        cx: RequestContextAs<ListResourceTemplatesResult>,
278        data: Arc<SessionData>,
279    ) -> Result<Response> {
280        let mut mcp_cx = RequestContext::new(&cx, data);
281        cx.handle_async(async move { self.resources_templates_list(p, &mut mcp_cx).await })
282    }
283
284    fn dyn_resources_read(
285        self: Arc<Self>,
286        p: ReadResourceRequestParams,
287        cx: RequestContextAs<ReadResourceResult>,
288        data: Arc<SessionData>,
289    ) -> Result<Response> {
290        let mut mcp_cx = RequestContext::new(&cx, data);
291        cx.handle_async(async move { self.resources_read(p, &mut mcp_cx).await })
292    }
293
294    fn dyn_tools_list(
295        self: Arc<Self>,
296        p: ListToolsRequestParams,
297        cx: RequestContextAs<ListToolsResult>,
298        data: Arc<SessionData>,
299    ) -> Result<Response> {
300        let mut mcp_cx = RequestContext::new(&cx, data);
301        cx.handle_async(async move { self.tools_list(p, &mut mcp_cx).await })
302    }
303
304    fn dyn_tools_call(
305        self: Arc<Self>,
306        p: CallToolRequestParams,
307        cx: RequestContextAs<CallToolResult>,
308        data: Arc<SessionData>,
309    ) -> Result<Response> {
310        let mut mcp_cx = RequestContext::new(&cx, data);
311        cx.handle_async(async move { self.tools_call(p, &mut mcp_cx).await })
312    }
313
314    fn dyn_completion_complete(
315        self: Arc<Self>,
316        p: CompleteRequestParams,
317        cx: RequestContextAs<CompleteResult>,
318        data: Arc<SessionData>,
319    ) -> Result<Response> {
320        let mut mcp_cx = RequestContext::new(&cx, data);
321        cx.handle_async(async move { self.completion_complete(p, &mut mcp_cx).await })
322    }
323}
324
325/// Trait for implementing MCP server
326pub trait McpServer: Send + Sync + 'static {
327    /// Returns `server_info` used in the [`initialize`] request response
328    ///
329    /// [`initialize`]: https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/lifecycle/#initialization
330    fn server_info(&self) -> Implementation {
331        Implementation::from_compile_time_env()
332    }
333
334    /// Returns `instructions` used in the [`initialize`] request response
335    ///
336    /// [`initialize`]: https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/lifecycle/#initialization
337    fn instructions(&self) -> Option<String> {
338        None
339    }
340
341    /// Returns `capabilities` used in the [`initialize`] request response
342    ///
343    /// [`initialize`]: https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/lifecycle/#initialization
344    fn capabilities(&self) -> ServerCapabilities {
345        ServerCapabilities {
346            prompts: Some(ServerCapabilitiesPrompts {
347                ..Default::default()
348            }),
349            resources: Some(ServerCapabilitiesResources {
350                ..Default::default()
351            }),
352            tools: Some(ServerCapabilitiesTools {
353                ..Default::default()
354            }),
355            ..Default::default()
356        }
357    }
358
359    /// Handles [`prompts/list`]
360    ///
361    /// [`prompts/list`]: https://spec.modelcontextprotocol.io/specification/2024-11-05/server/prompts/#listing-prompts
362    #[allow(unused_variables)]
363    fn prompts_list(
364        &self,
365        p: ListPromptsRequestParams,
366        cx: &mut RequestContext,
367    ) -> impl Future<Output = Result<ListPromptsResult>> + Send {
368        async { Ok(ListPromptsResult::default()) }
369    }
370
371    /// Handles [`prompts/get`]
372    ///
373    /// [`prompts/get`]: https://spec.modelcontextprotocol.io/specification/2024-11-05/server/prompts/#getting-a-prompt
374    #[allow(unused_variables)]
375    fn prompts_get(
376        &self,
377        p: GetPromptRequestParams,
378        cx: &mut RequestContext,
379    ) -> impl Future<Output = Result<GetPromptResult>> + Send {
380        async move { Err(prompt_not_found(&p.name)) }
381    }
382
383    /// Handles [`resources/list`]
384    ///
385    /// [`resources/list`]: https://spec.modelcontextprotocol.io/specification/2024-11-05/server/resources/#listing-resources
386    #[allow(unused_variables)]
387    fn resources_list(
388        &self,
389        p: ListResourcesRequestParams,
390        cx: &mut RequestContext,
391    ) -> impl Future<Output = Result<ListResourcesResult>> + Send {
392        async { Ok(ListResourcesResult::default()) }
393    }
394
395    /// Handles [`resources/templates/list`]
396    ///
397    /// [`resources/templates/list`]: https://spec.modelcontextprotocol.io/specification/2024-11-05/server/resources/#resource-templates
398    #[allow(unused_variables)]
399    fn resources_templates_list(
400        &self,
401        p: ListResourceTemplatesRequestParams,
402        cx: &mut RequestContext,
403    ) -> impl Future<Output = Result<ListResourceTemplatesResult>> + Send {
404        async { Ok(ListResourceTemplatesResult::default()) }
405    }
406
407    /// Handles [`resources/read`]
408    ///
409    /// [`resources/read`]: https://spec.modelcontextprotocol.io/specification/2024-11-05/server/resources/#reading-resources
410    #[allow(unused_variables)]
411    fn resources_read(
412        &self,
413        p: ReadResourceRequestParams,
414        cx: &mut RequestContext,
415    ) -> impl Future<Output = Result<ReadResourceResult>> + Send {
416        async move { bail_public!(ErrorCode::INVALID_PARAMS, "Resource `{}` not found", p.uri) }
417    }
418
419    /// Handles [`tools/list`]
420    ///
421    /// [`tools/list`]: https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/#listing-tools
422    #[allow(unused_variables)]
423    fn tools_list(
424        &self,
425        p: ListToolsRequestParams,
426        cx: &mut RequestContext,
427    ) -> impl Future<Output = Result<ListToolsResult>> + Send {
428        async { Ok(ListToolsResult::default()) }
429    }
430
431    /// Handles [`tools/call`]
432    ///
433    /// [`tools/call`]: https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/#calling-a-tool
434    #[allow(unused_variables)]
435    fn tools_call(
436        &self,
437        p: CallToolRequestParams,
438        cx: &mut RequestContext,
439    ) -> impl Future<Output = Result<CallToolResult>> + Send {
440        async move { Err(tool_not_found(&p.name)) }
441    }
442
443    /// Handles [`completion/complete`]
444    ///
445    /// [`completion/complete`]: https://spec.modelcontextprotocol.io/specification/2024-11-05/server/utilities/completion/#completing-a-prompt
446    #[allow(unused_variables)]
447    fn completion_complete(
448        &self,
449        p: CompleteRequestParams,
450        cx: &mut RequestContext,
451    ) -> impl Future<Output = Result<CompleteResult>> + Send {
452        async { Ok(CompleteResult::default()) }
453    }
454
455    /// Gets the JSON RPC `Handler`
456    fn into_handler(self) -> impl Handler + Send + Sync + 'static
457    where
458        Self: Sized + Send + Sync + 'static,
459    {
460        McpServerHandler::new(self)
461    }
462}
463
464/// Context for retrieving request-related information and calling client features
465pub struct RequestContext {
466    session: SessionContext,
467    id: RequestId,
468    data: Arc<SessionData>,
469}
470
471impl RequestContext {
472    fn new(cx: &RequestContextAs<impl Serialize>, data: Arc<SessionData>) -> Self {
473        Self {
474            session: cx.session(),
475            id: cx.id().clone(),
476            data,
477        }
478    }
479
480    /// Gets client information
481    pub fn client_info(&self) -> &Implementation {
482        &self.data.initialize.client_info
483    }
484
485    /// Gets client capabilities
486    pub fn client_capabilities(&self) -> &ClientCapabilities {
487        &self.data.initialize.capabilities
488    }
489
490    /// Protocol version of the current session
491    pub fn protocol_version(&self) -> ProtocolVersion {
492        self.data.protocol_version
493    }
494
495    /// Notifies progress of the request associated with this context
496    ///
497    /// See [`notifications/progress`]
498    ///
499    /// [`notifications/progress`]: https://spec.modelcontextprotocol.io/specification/2024-11-05/server/notifications/#progress-notification
500    pub fn progress(&self, progress: f64, total: Option<f64>) {
501        self.session
502            .notification(
503                "notifications/progress",
504                Some(&ProgressNotificationParams {
505                    progress,
506                    total,
507                    progress_token: self.id.clone(),
508                }),
509            )
510            .unwrap();
511    }
512
513    /// Calls [`sampling/createMessage`]
514    pub async fn sampling_create_message(
515        &self,
516        p: CreateMessageRequestParams,
517    ) -> SessionResult<CreateMessageResult> {
518        self.session
519            .request("sampling/createMessage", Some(&p))
520            .await
521    }
522
523    /// Calls [`roots/list`]
524    pub async fn roots_list(&self) -> SessionResult<Vec<Root>> {
525        let res: ListRootsResult = self
526            .session
527            .request("roots/list", Some(&ListRootsRequestParams::default()))
528            .await?;
529        Ok(res.roots)
530    }
531}
532
533/// Runs an MCP server using stdio transport
534pub async fn serve_stdio(server: impl McpServer) -> SessionResult<()> {
535    Session::from_stdio(McpServerHandler::new(server), &SessionOptions::default())
536        .wait()
537        .await
538}
539
540/// Runs an MCP server using stdio transport with specified options
541pub async fn serve_stdio_with(
542    server: impl McpServer,
543    options: &SessionOptions,
544) -> SessionResult<()> {
545    Session::from_stdio(McpServerHandler::new(server), options)
546        .wait()
547        .await
548}