Skip to main content

rmcp/handler/server/
router.rs

1use std::sync::Arc;
2
3use prompt::{IntoPromptRoute, PromptRoute};
4use tool::{IntoToolRoute, ToolRoute};
5
6use super::ServerHandler;
7use crate::{
8    RoleServer, Service,
9    model::{ClientRequest, ListPromptsResult, ListToolsResult, ServerResult},
10    service::NotificationContext,
11};
12
13pub mod prompt;
14pub mod tool;
15
16#[non_exhaustive]
17pub struct Router<S> {
18    pub tool_router: tool::ToolRouter<S>,
19    pub prompt_router: prompt::PromptRouter<S>,
20    pub service: Arc<S>,
21}
22
23impl<S> Router<S>
24where
25    S: ServerHandler,
26{
27    pub fn new(service: S) -> Self {
28        Self {
29            tool_router: tool::ToolRouter::new(),
30            prompt_router: prompt::PromptRouter::new(),
31            service: Arc::new(service),
32        }
33    }
34
35    pub fn with_tool<R, A>(mut self, route: R) -> Self
36    where
37        R: IntoToolRoute<S, A>,
38    {
39        self.tool_router.add_route(route.into_tool_route());
40        self
41    }
42
43    pub fn with_tools(mut self, routes: impl IntoIterator<Item = ToolRoute<S>>) -> Self {
44        for route in routes {
45            self.tool_router.add_route(route);
46        }
47        self
48    }
49
50    pub fn with_prompt<R, A: 'static>(mut self, route: R) -> Self
51    where
52        R: IntoPromptRoute<S, A>,
53    {
54        self.prompt_router.add_route(route.into_prompt_route());
55        self
56    }
57
58    pub fn with_prompts(mut self, routes: impl IntoIterator<Item = PromptRoute<S>>) -> Self {
59        for route in routes {
60            self.prompt_router.add_route(route);
61        }
62        self
63    }
64}
65
66impl<S> Service<RoleServer> for Router<S>
67where
68    S: ServerHandler,
69{
70    async fn handle_notification(
71        &self,
72        notification: <RoleServer as crate::service::ServiceRole>::PeerNot,
73        context: NotificationContext<RoleServer>,
74    ) -> Result<(), crate::ErrorData> {
75        self.service
76            .handle_notification(notification, context)
77            .await
78    }
79    async fn handle_request(
80        &self,
81        request: <RoleServer as crate::service::ServiceRole>::PeerReq,
82        context: crate::service::RequestContext<RoleServer>,
83    ) -> Result<<RoleServer as crate::service::ServiceRole>::Resp, crate::ErrorData> {
84        match request {
85            ClientRequest::CallToolRequest(request) => {
86                if self.tool_router.has_route(request.params.name.as_ref())
87                    || !self.tool_router.transparent_when_not_found
88                {
89                    let tool_call_context = crate::handler::server::tool::ToolCallContext::new(
90                        self.service.as_ref(),
91                        request.params,
92                        context,
93                    );
94                    let result = self.tool_router.call(tool_call_context).await?;
95                    Ok(ServerResult::CallToolResult(result))
96                } else {
97                    self.service
98                        .handle_request(ClientRequest::CallToolRequest(request), context)
99                        .await
100                }
101            }
102            ClientRequest::ListToolsRequest(_) => {
103                let tools = self.tool_router.list_all();
104                Ok(ServerResult::ListToolsResult(ListToolsResult {
105                    tools,
106                    ..Default::default()
107                }))
108            }
109            ClientRequest::GetPromptRequest(request) => {
110                if self.prompt_router.has_route(request.params.name.as_ref()) {
111                    let prompt_context = crate::handler::server::prompt::PromptContext::new(
112                        self.service.as_ref(),
113                        request.params.name,
114                        request.params.arguments,
115                        context,
116                    );
117                    let result = self.prompt_router.get_prompt(prompt_context).await?;
118                    Ok(ServerResult::GetPromptResult(result))
119                } else {
120                    self.service
121                        .handle_request(ClientRequest::GetPromptRequest(request), context)
122                        .await
123                }
124            }
125            ClientRequest::ListPromptsRequest(_) => {
126                let prompts = self.prompt_router.list_all();
127                Ok(ServerResult::ListPromptsResult(ListPromptsResult {
128                    prompts,
129                    ..Default::default()
130                }))
131            }
132            rest => self.service.handle_request(rest, context).await,
133        }
134    }
135
136    fn get_info(&self) -> <RoleServer as crate::service::ServiceRole>::Info {
137        ServerHandler::get_info(&self.service)
138    }
139}