Skip to main content

rmcp/handler/server/
router.rs

1use std::sync::Arc;
2
3use prompt::{IntoPromptRoute, PromptRoute};
4use tool::{IntoToolRoute, ToolRoute};
5
6use super::ServerHandler;
7use crate::{
8    RoleServer, Service,
9    model::{ClientNotification, ClientRequest, ListPromptsResult, ListToolsResult, ServerResult},
10    service::NotificationContext,
11};
12
13pub mod prompt;
14pub mod tool;
15
16#[non_exhaustive]
17pub struct Router<S> {
18    pub tool_router: tool::ToolRouter<S>,
19    pub prompt_router: prompt::PromptRouter<S>,
20    pub service: Arc<S>,
21    peer_slot: Arc<std::sync::OnceLock<crate::service::Peer<RoleServer>>>,
22}
23
24impl<S> Router<S>
25where
26    S: ServerHandler,
27{
28    pub fn new(service: S) -> Self {
29        let (notifier, peer_slot) = tool::ToolRouter::<S>::deferred_peer_notifier();
30        let mut tool_router = tool::ToolRouter::new();
31        tool_router.set_notifier(notifier);
32        Self {
33            tool_router,
34            prompt_router: prompt::PromptRouter::new(),
35            service: Arc::new(service),
36            peer_slot,
37        }
38    }
39
40    pub fn with_tool<R, A>(mut self, route: R) -> Self
41    where
42        R: IntoToolRoute<S, A>,
43    {
44        self.tool_router.add_route(route.into_tool_route());
45        self
46    }
47
48    pub fn with_tools(mut self, routes: impl IntoIterator<Item = ToolRoute<S>>) -> Self {
49        for route in routes {
50            self.tool_router.add_route(route);
51        }
52        self
53    }
54
55    pub fn with_prompt<R, A: 'static>(mut self, route: R) -> Self
56    where
57        R: IntoPromptRoute<S, A>,
58    {
59        self.prompt_router.add_route(route.into_prompt_route());
60        self
61    }
62
63    pub fn with_prompts(mut self, routes: impl IntoIterator<Item = PromptRoute<S>>) -> Self {
64        for route in routes {
65            self.prompt_router.add_route(route);
66        }
67        self
68    }
69}
70
71impl<S> Service<RoleServer> for Router<S>
72where
73    S: ServerHandler,
74{
75    async fn handle_notification(
76        &self,
77        notification: <RoleServer as crate::service::ServiceRole>::PeerNot,
78        context: NotificationContext<RoleServer>,
79    ) -> Result<(), crate::ErrorData> {
80        if matches!(
81            &notification,
82            ClientNotification::InitializedNotification(_)
83        ) {
84            let _ = self.peer_slot.set(context.peer.clone());
85        }
86        self.service
87            .handle_notification(notification, context)
88            .await
89    }
90    async fn handle_request(
91        &self,
92        request: <RoleServer as crate::service::ServiceRole>::PeerReq,
93        context: crate::service::RequestContext<RoleServer>,
94    ) -> Result<<RoleServer as crate::service::ServiceRole>::Resp, crate::ErrorData> {
95        match request {
96            ClientRequest::CallToolRequest(request) => {
97                if self
98                    .tool_router
99                    .map
100                    .contains_key(request.params.name.as_ref())
101                    || !self.tool_router.transparent_when_not_found
102                {
103                    let tool_call_context = crate::handler::server::tool::ToolCallContext::new(
104                        self.service.as_ref(),
105                        request.params,
106                        context,
107                    );
108                    let result = self.tool_router.call(tool_call_context).await?;
109                    Ok(ServerResult::CallToolResult(result))
110                } else {
111                    self.service
112                        .handle_request(ClientRequest::CallToolRequest(request), context)
113                        .await
114                }
115            }
116            ClientRequest::ListToolsRequest(_) => {
117                let tools = self.tool_router.list_all();
118                Ok(ServerResult::ListToolsResult(ListToolsResult {
119                    tools,
120                    ..Default::default()
121                }))
122            }
123            ClientRequest::GetPromptRequest(request) => {
124                if self.prompt_router.has_route(request.params.name.as_ref()) {
125                    let prompt_context = crate::handler::server::prompt::PromptContext::new(
126                        self.service.as_ref(),
127                        request.params.name,
128                        request.params.arguments,
129                        context,
130                    );
131                    let result = self.prompt_router.get_prompt(prompt_context).await?;
132                    Ok(ServerResult::GetPromptResult(result))
133                } else {
134                    self.service
135                        .handle_request(ClientRequest::GetPromptRequest(request), context)
136                        .await
137                }
138            }
139            ClientRequest::ListPromptsRequest(_) => {
140                let prompts = self.prompt_router.list_all();
141                Ok(ServerResult::ListPromptsResult(ListPromptsResult {
142                    prompts,
143                    ..Default::default()
144                }))
145            }
146            rest => self.service.handle_request(rest, context).await,
147        }
148    }
149
150    fn get_info(&self) -> <RoleServer as crate::service::ServiceRole>::Info {
151        let mut info = ServerHandler::get_info(&self.service);
152        info.capabilities
153            .tools
154            .get_or_insert_with(Default::default)
155            .list_changed = Some(true);
156        info
157    }
158}
159
#[cfg(test)]
mod tests {
    use std::{sync::Arc, time::Duration};

    use super::*;
    use crate::{
        model::{CallToolResult, ClientNotification, ServerNotification, Tool},
        service::{AtomicU32RequestIdProvider, Peer, PeerSinkMessage, RequestIdProvider},
    };

    /// Handler relying entirely on the default `ServerHandler` methods.
    struct DummyHandler;
    impl ServerHandler for DummyHandler {}

    /// Waits up to one second for the next outbound message and acknowledges it.
    /// Panics on timeout, on a closed channel, or if the message is not a notification.
    async fn recv_notification(
        rx: &mut tokio::sync::mpsc::Receiver<PeerSinkMessage<RoleServer>>,
    ) -> ServerNotification {
        let waited = tokio::time::timeout(Duration::from_secs(1), rx.recv()).await;
        let message = waited.expect("timed out").expect("channel closed");
        match message {
            PeerSinkMessage::Notification {
                notification,
                responder,
            } => {
                // Confirm delivery so the sending side is not left waiting.
                let _ = responder.send(Ok(()));
                notification
            }
            other => panic!("expected notification, got {other:?}"),
        }
    }

    /// Asserts that the next outbound message is a `ToolListChangedNotification`.
    async fn expect_tool_list_changed(
        rx: &mut tokio::sync::mpsc::Receiver<PeerSinkMessage<RoleServer>>,
    ) {
        let received = recv_notification(rx).await;
        assert!(matches!(
            received,
            ServerNotification::ToolListChangedNotification(_)
        ));
    }

    /// Once the client is initialized, disabling and then re-enabling a tool should
    /// each send a `ToolListChangedNotification` through the captured peer.
    #[tokio::test]
    async fn test_router_deferred_notifier_e2e() {
        let mut router = Router::new(DummyHandler).with_tool(tool::ToolRoute::new_dyn(
            Tool::new("my_tool", "test", Arc::new(Default::default())),
            |_ctx| Box::pin(async { Ok(CallToolResult::default()) }),
        ));

        let ids: Arc<dyn RequestIdProvider> = Arc::new(AtomicU32RequestIdProvider::default());
        let (peer, mut outbound) = Peer::<RoleServer>::new(ids, None);

        let init_context = crate::service::NotificationContext {
            peer: peer.clone(),
            meta: Default::default(),
            extensions: Default::default(),
        };
        let initialized = ClientNotification::InitializedNotification(Default::default());
        router
            .handle_notification(initialized, init_context)
            .await
            .unwrap();

        router.tool_router.disable_route("my_tool");
        expect_tool_list_changed(&mut outbound).await;

        router.tool_router.enable_route("my_tool");
        expect_tool_list_changed(&mut outbound).await;
    }
}
228}