mcp_attr/
server.rs

1//! Module for implementing MCP server
2
3use std::{future::Future, sync::Arc};
4
5use jsoncall::{
6    ErrorCode, Handler, Hook, NotificationContext, Params, RequestContextAs, RequestId, Response,
7    Result, Session, SessionContext, SessionOptions, SessionResult, bail_public,
8};
9use serde::{Serialize, de::DeserializeOwned};
10use serde_json::Map;
11
12use crate::{
13    common::McpCancellationHook,
14    schema::{
15        CallToolRequestParams, CallToolResult, CancelledNotificationParams, ClientCapabilities,
16        CompleteRequestParams, CompleteResult, CreateMessageRequestParams, CreateMessageResult,
17        GetPromptRequestParams, GetPromptResult, Implementation, InitializeRequestParams,
18        InitializeResult, InitializedNotificationParams, ListPromptsRequestParams,
19        ListPromptsResult, ListResourceTemplatesRequestParams, ListResourceTemplatesResult,
20        ListResourcesRequestParams, ListResourcesResult, ListRootsRequestParams, ListRootsResult,
21        ListToolsRequestParams, ListToolsResult, PingRequestParams, ProgressNotificationParams,
22        ReadResourceRequestParams, ReadResourceResult, Root, ServerCapabilities,
23        ServerCapabilitiesPrompts, ServerCapabilitiesResources, ServerCapabilitiesTools,
24    },
25    server::errors::{prompt_not_found, tool_not_found},
26    utils::{Empty, ProtocolVersion},
27};
28
29pub mod builder;
30pub mod errors;
31mod mcp_server_attr;
32
33pub use builder::{McpServerBuilder, prompt, resource, route, tool};
34pub use mcp_server_attr::mcp_server;
35
/// Per-session state recorded when the `initialize` request is handled.
struct SessionData {
    /// Parameters the client sent with `initialize`; the source of the client
    /// info and client capabilities exposed through `RequestContext`.
    initialize: InitializeRequestParams,
    /// Protocol version selected for this session while handling `initialize`.
    protocol_version: ProtocolVersion,
}
40
/// JSON-RPC handler that routes MCP requests and notifications to a server
/// implementation and tracks the initialization handshake.
struct McpServerHandler {
    /// Type-erased server that routed requests are forwarded to.
    server: Arc<dyn DynMcpServer>,
    /// Session state; `None` until an `initialize` request has been handled.
    data: Option<Arc<SessionData>>,
    /// Set to `true` once the `notifications/initialized` notification has been
    /// accepted; other requests are rejected until then.
    // NOTE(review): the field name is a misspelling of `is_initialized`; a
    // rename has to touch every use of the field in this file at once.
    is_initizlized: bool,
}
46impl Handler for McpServerHandler {
47    fn hook(&self) -> Arc<dyn Hook> {
48        Arc::new(McpCancellationHook)
49    }
50    fn request(
51        &mut self,
52        method: &str,
53        params: Params,
54        cx: jsoncall::RequestContext,
55    ) -> Result<Response> {
56        match method {
57            "initialize" => return cx.handle(self.initialize(params.to()?)),
58            "ping" => return cx.handle(self.ping(params.to_opt()?)),
59            _ => {}
60        }
61        let (Some(data), true) = (&self.data, self.is_initizlized) else {
62            bail_public!(_, "Server not initialized");
63        };
64        let d = data.clone();
65        match method {
66            "prompts/list" => self.call_opt(params, cx, |s, p, cx| s.dyn_prompts_list(p, cx, d)),
67            "prompts/get" => self.call(params, cx, |s, p, cx| s.dyn_prompts_get(p, cx, d)),
68            "resources/list" => {
69                self.call_opt(params, cx, |s, p, cx| s.dyn_resources_list(p, cx, d))
70            }
71            "resources/templates/list" => self.call_opt(params, cx, |s, p, cx| {
72                s.dyn_resources_templates_list(p, cx, d)
73            }),
74            "resources/read" => self.call(params, cx, |s, p, cx| s.dyn_resources_read(p, cx, d)),
75            "tools/list" => self.call_opt(params, cx, |s, p, cx| s.dyn_tools_list(p, cx, d)),
76            "tools/call" => self.call(params, cx, |s, p, cx| s.dyn_tools_call(p, cx, d)),
77            "completion/complete" => {
78                self.call(params, cx, |s, p, cx| s.dyn_completion_complete(p, cx, d))
79            }
80            _ => cx.method_not_found(),
81        }
82    }
83    fn notification(
84        &mut self,
85        method: &str,
86        params: Params,
87        cx: NotificationContext,
88    ) -> Result<Response> {
89        match method {
90            "notifications/initialized" => cx.handle(self.initialized(params.to_opt()?)),
91            "notifications/cancelled" => self.notifications_cancelled(params.to()?, cx),
92            _ => cx.method_not_found(),
93        }
94    }
95}
96impl McpServerHandler {
97    pub fn new(server: impl McpServer) -> Self {
98        Self {
99            server: Arc::new(server),
100            data: None,
101            is_initizlized: false,
102        }
103    }
104}
105impl McpServerHandler {
106    fn initialize(&mut self, p: InitializeRequestParams) -> Result<InitializeResult> {
107        let mut protocol_version = ProtocolVersion::LATEST;
108        if p.protocol_version == ProtocolVersion::V_2024_11_05.as_str() {
109            protocol_version = ProtocolVersion::V_2024_11_05;
110        }
111        self.data = Some(Arc::new(SessionData {
112            initialize: p,
113            protocol_version,
114        }));
115        Ok(self.server.initialize_result(protocol_version))
116    }
117    fn initialized(&mut self, _p: Option<InitializedNotificationParams>) -> Result<()> {
118        if self.data.is_none() {
119            bail_public!(
120                _,
121                "`initialize` request must be called before `initialized` notification"
122            );
123        }
124        self.is_initizlized = true;
125        Ok(())
126    }
127    fn ping(&self, _p: Option<PingRequestParams>) -> Result<Empty> {
128        Ok(Empty::default())
129    }
130    fn notifications_cancelled(
131        &self,
132        p: CancelledNotificationParams,
133        cx: NotificationContext,
134    ) -> Result<Response> {
135        cx.session().cancel_incoming_request(&p.request_id, None);
136        cx.handle(Ok(()))
137    }
138
139    // fn logging_set_level(&self, p: SetLevelRequestParams) -> Result<()> {
140    //     todo!()
141    // }
142
143    // fn resources_subscribe(&self, p: SubscribeRequestParams) -> Result<()> {
144    //     todo!()
145    // }
146
147    // fn resources_unsubscribe(&self, p: UnsubscribeRequestParams) -> Result<()> {
148    //     todo!()
149    // }
150
151    fn call<P, R>(
152        &self,
153        p: Params,
154        cx: jsoncall::RequestContext,
155        f: impl FnOnce(Arc<dyn DynMcpServer>, P, RequestContextAs<R>) -> Result<Response>,
156    ) -> Result<Response>
157    where
158        P: DeserializeOwned,
159        R: Serialize,
160    {
161        f(self.server.clone(), p.to()?, cx.to())
162    }
163    fn call_opt<P, R>(
164        &self,
165        p: Params,
166        cx: jsoncall::RequestContext,
167        f: impl FnOnce(Arc<dyn DynMcpServer>, P, RequestContextAs<R>) -> Result<Response>,
168    ) -> Result<Response>
169    where
170        P: DeserializeOwned + Default,
171        R: Serialize,
172    {
173        f(
174            self.server.clone(),
175            p.to_opt()?.unwrap_or_default(),
176            cx.to(),
177        )
178    }
179}
180
/// Type-erased view of an [`McpServer`] used by `McpServerHandler`.
///
/// The handler stores its server as `Arc<dyn DynMcpServer>`. Each `dyn_*`
/// method receives the typed request params, a request context typed for the
/// corresponding result, and the session data recorded by `initialize`, and
/// returns the JSON-RPC `Response`. A blanket impl in this file provides the
/// trait for every `McpServer`.
trait DynMcpServer: Send + Sync + 'static {
    /// Builds the result of the `initialize` request for the given protocol version.
    fn initialize_result(&self, protocol_version: ProtocolVersion) -> InitializeResult;

    /// Serves a `prompts/list` request.
    fn dyn_prompts_list(
        self: Arc<Self>,
        p: ListPromptsRequestParams,
        cx: RequestContextAs<ListPromptsResult>,
        data: Arc<SessionData>,
    ) -> Result<Response>;

    /// Serves a `prompts/get` request.
    fn dyn_prompts_get(
        self: Arc<Self>,
        p: GetPromptRequestParams,
        cx: RequestContextAs<GetPromptResult>,
        data: Arc<SessionData>,
    ) -> Result<Response>;

    /// Serves a `resources/list` request.
    fn dyn_resources_list(
        self: Arc<Self>,
        p: ListResourcesRequestParams,
        cx: RequestContextAs<ListResourcesResult>,
        data: Arc<SessionData>,
    ) -> Result<Response>;

    /// Serves a `resources/read` request.
    fn dyn_resources_read(
        self: Arc<Self>,
        p: ReadResourceRequestParams,
        cx: RequestContextAs<ReadResourceResult>,
        data: Arc<SessionData>,
    ) -> Result<Response>;

    /// Serves a `resources/templates/list` request.
    fn dyn_resources_templates_list(
        self: Arc<Self>,
        p: ListResourceTemplatesRequestParams,
        cx: RequestContextAs<ListResourceTemplatesResult>,
        data: Arc<SessionData>,
    ) -> Result<Response>;

    /// Serves a `tools/list` request.
    fn dyn_tools_list(
        self: Arc<Self>,
        p: ListToolsRequestParams,
        cx: RequestContextAs<ListToolsResult>,
        data: Arc<SessionData>,
    ) -> Result<Response>;

    /// Serves a `tools/call` request.
    fn dyn_tools_call(
        self: Arc<Self>,
        p: CallToolRequestParams,
        cx: RequestContextAs<CallToolResult>,
        data: Arc<SessionData>,
    ) -> Result<Response>;

    /// Serves a `completion/complete` request.
    fn dyn_completion_complete(
        self: Arc<Self>,
        p: CompleteRequestParams,
        cx: RequestContextAs<CompleteResult>,
        data: Arc<SessionData>,
    ) -> Result<Response>;
}
240impl<T: McpServer> DynMcpServer for T {
241    fn initialize_result(&self, protocol_version: ProtocolVersion) -> InitializeResult {
242        InitializeResult {
243            capabilities: self.capabilities(),
244            instructions: self.instructions(),
245            meta: Map::new(),
246            protocol_version: protocol_version.as_str().to_string(),
247            server_info: self.server_info(),
248        }
249    }
250    fn dyn_prompts_list(
251        self: Arc<Self>,
252        p: ListPromptsRequestParams,
253        cx: RequestContextAs<ListPromptsResult>,
254        data: Arc<SessionData>,
255    ) -> Result<Response> {
256        let mut mcp_cx = RequestContext::new(&cx, data);
257        cx.handle_async(async move { self.prompts_list(p, &mut mcp_cx).await })
258    }
259
260    fn dyn_prompts_get(
261        self: Arc<Self>,
262        p: GetPromptRequestParams,
263        cx: RequestContextAs<GetPromptResult>,
264        data: Arc<SessionData>,
265    ) -> Result<Response> {
266        let mut mcp_cx = RequestContext::new(&cx, data);
267        cx.handle_async(async move { self.prompts_get(p, &mut mcp_cx).await })
268    }
269
270    fn dyn_resources_list(
271        self: Arc<Self>,
272        p: ListResourcesRequestParams,
273        cx: RequestContextAs<ListResourcesResult>,
274        data: Arc<SessionData>,
275    ) -> Result<Response> {
276        let mut mcp_cx = RequestContext::new(&cx, data);
277        cx.handle_async(async move { self.resources_list(p, &mut mcp_cx).await })
278    }
279
280    fn dyn_resources_templates_list(
281        self: Arc<Self>,
282        p: ListResourceTemplatesRequestParams,
283        cx: RequestContextAs<ListResourceTemplatesResult>,
284        data: Arc<SessionData>,
285    ) -> Result<Response> {
286        let mut mcp_cx = RequestContext::new(&cx, data);
287        cx.handle_async(async move { self.resources_templates_list(p, &mut mcp_cx).await })
288    }
289
290    fn dyn_resources_read(
291        self: Arc<Self>,
292        p: ReadResourceRequestParams,
293        cx: RequestContextAs<ReadResourceResult>,
294        data: Arc<SessionData>,
295    ) -> Result<Response> {
296        let mut mcp_cx = RequestContext::new(&cx, data);
297        cx.handle_async(async move { self.resources_read(p, &mut mcp_cx).await })
298    }
299
300    fn dyn_tools_list(
301        self: Arc<Self>,
302        p: ListToolsRequestParams,
303        cx: RequestContextAs<ListToolsResult>,
304        data: Arc<SessionData>,
305    ) -> Result<Response> {
306        let mut mcp_cx = RequestContext::new(&cx, data);
307        cx.handle_async(async move { self.tools_list(p, &mut mcp_cx).await })
308    }
309
310    fn dyn_tools_call(
311        self: Arc<Self>,
312        p: CallToolRequestParams,
313        cx: RequestContextAs<CallToolResult>,
314        data: Arc<SessionData>,
315    ) -> Result<Response> {
316        let mut mcp_cx = RequestContext::new(&cx, data);
317        cx.handle_async(async move { self.tools_call(p, &mut mcp_cx).await })
318    }
319
320    fn dyn_completion_complete(
321        self: Arc<Self>,
322        p: CompleteRequestParams,
323        cx: RequestContextAs<CompleteResult>,
324        data: Arc<SessionData>,
325    ) -> Result<Response> {
326        let mut mcp_cx = RequestContext::new(&cx, data);
327        cx.handle_async(async move { self.completion_complete(p, &mut mcp_cx).await })
328    }
329}
330
331/// Trait for implementing MCP server
332pub trait McpServer: Send + Sync + 'static {
333    /// Returns `server_info` used in the [`initialize`] request response
334    ///
335    /// [`initialize`]: https://modelcontextprotocol.io/specification/2025-03-26/basic/lifecycle/#initialization
336    fn server_info(&self) -> Implementation {
337        Implementation::from_compile_time_env()
338    }
339
340    /// Returns `instructions` used in the [`initialize`] request response
341    ///
342    /// [`initialize`]: https://modelcontextprotocol.io/specification/2025-03-26/basic/lifecycle/#initialization
343    fn instructions(&self) -> Option<String> {
344        None
345    }
346
347    /// Returns `capabilities` used in the [`initialize`] request response
348    ///
349    /// [`initialize`]: https://modelcontextprotocol.io/specification/2025-03-26/basic/lifecycle/#initialization
350    fn capabilities(&self) -> ServerCapabilities {
351        ServerCapabilities {
352            prompts: Some(ServerCapabilitiesPrompts {
353                ..Default::default()
354            }),
355            resources: Some(ServerCapabilitiesResources {
356                ..Default::default()
357            }),
358            tools: Some(ServerCapabilitiesTools {
359                ..Default::default()
360            }),
361            ..Default::default()
362        }
363    }
364
365    /// Handles [`prompts/list`]
366    ///
367    /// [`prompts/list`]: https://modelcontextprotocol.io/specification/2025-03-26/server/prompts/#listing-prompts
368    #[allow(unused_variables)]
369    fn prompts_list(
370        &self,
371        p: ListPromptsRequestParams,
372        cx: &mut RequestContext,
373    ) -> impl Future<Output = Result<ListPromptsResult>> + Send {
374        async { Ok(ListPromptsResult::default()) }
375    }
376
377    /// Handles [`prompts/get`]
378    ///
379    /// [`prompts/get`]: https://modelcontextprotocol.io/specification/2025-03-26/server/prompts/#getting-a-prompt
380    #[allow(unused_variables)]
381    fn prompts_get(
382        &self,
383        p: GetPromptRequestParams,
384        cx: &mut RequestContext,
385    ) -> impl Future<Output = Result<GetPromptResult>> + Send {
386        async move { Err(prompt_not_found(&p.name)) }
387    }
388
389    /// Handles [`resources/list`]
390    ///
391    /// [`resources/list`]: https://modelcontextprotocol.io/specification/2025-03-26/server/resources/#listing-resources
392    #[allow(unused_variables)]
393    fn resources_list(
394        &self,
395        p: ListResourcesRequestParams,
396        cx: &mut RequestContext,
397    ) -> impl Future<Output = Result<ListResourcesResult>> + Send {
398        async { Ok(ListResourcesResult::default()) }
399    }
400
401    /// Handles [`resources/templates/list`]
402    ///
403    /// [`resources/templates/list`]: https://modelcontextprotocol.io/specification/2025-03-26/server/resources/#resource-templates
404    #[allow(unused_variables)]
405    fn resources_templates_list(
406        &self,
407        p: ListResourceTemplatesRequestParams,
408        cx: &mut RequestContext,
409    ) -> impl Future<Output = Result<ListResourceTemplatesResult>> + Send {
410        async { Ok(ListResourceTemplatesResult::default()) }
411    }
412
413    /// Handles [`resources/read`]
414    ///
415    /// [`resources/read`]: https://modelcontextprotocol.io/specification/2025-03-26/server/resources/#reading-resources
416    #[allow(unused_variables)]
417    fn resources_read(
418        &self,
419        p: ReadResourceRequestParams,
420        cx: &mut RequestContext,
421    ) -> impl Future<Output = Result<ReadResourceResult>> + Send {
422        async move { bail_public!(ErrorCode::INVALID_PARAMS, "Resource `{}` not found", p.uri) }
423    }
424
425    /// Handles [`tools/list`]
426    ///
427    /// [`tools/list`]: https://modelcontextprotocol.io/specification/2025-03-26/server/tools/#listing-tools
428    #[allow(unused_variables)]
429    fn tools_list(
430        &self,
431        p: ListToolsRequestParams,
432        cx: &mut RequestContext,
433    ) -> impl Future<Output = Result<ListToolsResult>> + Send {
434        async { Ok(ListToolsResult::default()) }
435    }
436
437    /// Handles [`tools/call`]
438    ///
439    /// [`tools/call`]: https://modelcontextprotocol.io/specification/2025-03-26/server/tools/#calling-tools
440    #[allow(unused_variables)]
441    fn tools_call(
442        &self,
443        p: CallToolRequestParams,
444        cx: &mut RequestContext,
445    ) -> impl Future<Output = Result<CallToolResult>> + Send {
446        async move { Err(tool_not_found(&p.name)) }
447    }
448
449    /// Handles [`completion/complete`]
450    ///
451    /// [`completion/complete`]: https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/completion/#requesting-completions
452    #[allow(unused_variables)]
453    fn completion_complete(
454        &self,
455        p: CompleteRequestParams,
456        cx: &mut RequestContext,
457    ) -> impl Future<Output = Result<CompleteResult>> + Send {
458        async { Ok(CompleteResult::default()) }
459    }
460
461    /// Gets the JSON RPC `Handler`
462    fn into_handler(self) -> impl Handler + Send + Sync + 'static
463    where
464        Self: Sized + Send + Sync + 'static,
465    {
466        McpServerHandler::new(self)
467    }
468}
469
/// Context for retrieving request-related information and calling client features
pub struct RequestContext {
    /// Session used to send notifications and requests to the client.
    session: SessionContext,
    /// Id of the request this context belongs to.
    id: RequestId,
    /// Session state recorded when the `initialize` request was handled.
    data: Arc<SessionData>,
}
476
477impl RequestContext {
478    fn new(cx: &RequestContextAs<impl Serialize>, data: Arc<SessionData>) -> Self {
479        Self {
480            session: cx.session(),
481            id: cx.id().clone(),
482            data,
483        }
484    }
485
486    /// Gets client information
487    pub fn client_info(&self) -> &Implementation {
488        &self.data.initialize.client_info
489    }
490
491    /// Gets client capabilities
492    pub fn client_capabilities(&self) -> &ClientCapabilities {
493        &self.data.initialize.capabilities
494    }
495
496    /// Protocol version of the current session
497    pub fn protocol_version(&self) -> ProtocolVersion {
498        self.data.protocol_version
499    }
500
501    /// Notifies progress of the request associated with this context
502    ///
503    /// See [`notifications/progress`]
504    ///
505    /// [`notifications/progress`]: https://modelcontextprotocol.io/specification/2025-03-26/basic/utilities/progress/
506    pub fn progress(&self, progress: f64, total: Option<f64>, message: Option<String>) {
507        self.session
508            .notification(
509                "notifications/progress",
510                Some(&ProgressNotificationParams {
511                    progress,
512                    total,
513                    message,
514                    progress_token: self.id.clone(),
515                }),
516            )
517            .unwrap();
518    }
519
520    /// Calls [`sampling/createMessage`]
521    ///
522    /// [`sampling/createMessage`]: https://modelcontextprotocol.io/specification/2025-03-26/client/sampling/#creating-messages
523    pub async fn sampling_create_message(
524        &self,
525        p: CreateMessageRequestParams,
526    ) -> SessionResult<CreateMessageResult> {
527        self.session
528            .request("sampling/createMessage", Some(&p))
529            .await
530    }
531
532    /// Calls [`roots/list`]
533    ///
534    /// [`roots/list`]: https://modelcontextprotocol.io/specification/2025-03-26/client/roots/#listing-roots
535    pub async fn roots_list(&self) -> SessionResult<Vec<Root>> {
536        let res: ListRootsResult = self
537            .session
538            .request("roots/list", Some(&ListRootsRequestParams::default()))
539            .await?;
540        Ok(res.roots)
541    }
542}
543
544/// Runs an MCP server using stdio transport
545pub async fn serve_stdio(server: impl McpServer) -> SessionResult<()> {
546    Session::from_stdio(McpServerHandler::new(server), &SessionOptions::default())
547        .wait()
548        .await
549}
550
551/// Runs an MCP server using stdio transport with specified options
552pub async fn serve_stdio_with(
553    server: impl McpServer,
554    options: &SessionOptions,
555) -> SessionResult<()> {
556    Session::from_stdio(McpServerHandler::new(server), options)
557        .wait()
558        .await
559}