rmcp/handler/server/
router.rs1use std::sync::Arc;
2
3use prompt::{IntoPromptRoute, PromptRoute};
4use tool::{IntoToolRoute, ToolRoute};
5
6use super::ServerHandler;
7use crate::{
8 RoleServer, Service,
9 model::{ClientNotification, ClientRequest, ListPromptsResult, ListToolsResult, ServerResult},
10 service::NotificationContext,
11};
12
13pub mod prompt;
14pub mod tool;
15
16#[non_exhaustive]
17pub struct Router<S> {
18 pub tool_router: tool::ToolRouter<S>,
19 pub prompt_router: prompt::PromptRouter<S>,
20 pub service: Arc<S>,
21 peer_slot: Arc<std::sync::OnceLock<crate::service::Peer<RoleServer>>>,
22}
23
24impl<S> Router<S>
25where
26 S: ServerHandler,
27{
28 pub fn new(service: S) -> Self {
29 let (notifier, peer_slot) = tool::ToolRouter::<S>::deferred_peer_notifier();
30 let mut tool_router = tool::ToolRouter::new();
31 tool_router.set_notifier(notifier);
32 Self {
33 tool_router,
34 prompt_router: prompt::PromptRouter::new(),
35 service: Arc::new(service),
36 peer_slot,
37 }
38 }
39
40 pub fn with_tool<R, A>(mut self, route: R) -> Self
41 where
42 R: IntoToolRoute<S, A>,
43 {
44 self.tool_router.add_route(route.into_tool_route());
45 self
46 }
47
48 pub fn with_tools(mut self, routes: impl IntoIterator<Item = ToolRoute<S>>) -> Self {
49 for route in routes {
50 self.tool_router.add_route(route);
51 }
52 self
53 }
54
55 pub fn with_prompt<R, A: 'static>(mut self, route: R) -> Self
56 where
57 R: IntoPromptRoute<S, A>,
58 {
59 self.prompt_router.add_route(route.into_prompt_route());
60 self
61 }
62
63 pub fn with_prompts(mut self, routes: impl IntoIterator<Item = PromptRoute<S>>) -> Self {
64 for route in routes {
65 self.prompt_router.add_route(route);
66 }
67 self
68 }
69}
70
71impl<S> Service<RoleServer> for Router<S>
72where
73 S: ServerHandler,
74{
75 async fn handle_notification(
76 &self,
77 notification: <RoleServer as crate::service::ServiceRole>::PeerNot,
78 context: NotificationContext<RoleServer>,
79 ) -> Result<(), crate::ErrorData> {
80 if matches!(
81 ¬ification,
82 ClientNotification::InitializedNotification(_)
83 ) {
84 let _ = self.peer_slot.set(context.peer.clone());
85 }
86 self.service
87 .handle_notification(notification, context)
88 .await
89 }
90 async fn handle_request(
91 &self,
92 request: <RoleServer as crate::service::ServiceRole>::PeerReq,
93 context: crate::service::RequestContext<RoleServer>,
94 ) -> Result<<RoleServer as crate::service::ServiceRole>::Resp, crate::ErrorData> {
95 match request {
96 ClientRequest::CallToolRequest(request) => {
97 if self
98 .tool_router
99 .map
100 .contains_key(request.params.name.as_ref())
101 || !self.tool_router.transparent_when_not_found
102 {
103 let tool_call_context = crate::handler::server::tool::ToolCallContext::new(
104 self.service.as_ref(),
105 request.params,
106 context,
107 );
108 let result = self.tool_router.call(tool_call_context).await?;
109 Ok(ServerResult::CallToolResult(result))
110 } else {
111 self.service
112 .handle_request(ClientRequest::CallToolRequest(request), context)
113 .await
114 }
115 }
116 ClientRequest::ListToolsRequest(_) => {
117 let tools = self.tool_router.list_all();
118 Ok(ServerResult::ListToolsResult(ListToolsResult {
119 tools,
120 ..Default::default()
121 }))
122 }
123 ClientRequest::GetPromptRequest(request) => {
124 if self.prompt_router.has_route(request.params.name.as_ref()) {
125 let prompt_context = crate::handler::server::prompt::PromptContext::new(
126 self.service.as_ref(),
127 request.params.name,
128 request.params.arguments,
129 context,
130 );
131 let result = self.prompt_router.get_prompt(prompt_context).await?;
132 Ok(ServerResult::GetPromptResult(result))
133 } else {
134 self.service
135 .handle_request(ClientRequest::GetPromptRequest(request), context)
136 .await
137 }
138 }
139 ClientRequest::ListPromptsRequest(_) => {
140 let prompts = self.prompt_router.list_all();
141 Ok(ServerResult::ListPromptsResult(ListPromptsResult {
142 prompts,
143 ..Default::default()
144 }))
145 }
146 rest => self.service.handle_request(rest, context).await,
147 }
148 }
149
150 fn get_info(&self) -> <RoleServer as crate::service::ServiceRole>::Info {
151 let mut info = ServerHandler::get_info(&self.service);
152 info.capabilities
153 .tools
154 .get_or_insert_with(Default::default)
155 .list_changed = Some(true);
156 info
157 }
158}
159
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::{
        model::{CallToolResult, ClientNotification, ServerNotification, Tool},
        service::{AtomicU32RequestIdProvider, Peer, PeerSinkMessage, RequestIdProvider},
    };

    /// Handler that relies entirely on `ServerHandler`'s default methods.
    struct DummyHandler;
    impl ServerHandler for DummyHandler {}

    /// Waits up to one second for the next outbound peer message.
    ///
    /// The message is acknowledged through its responder and its notification
    /// is returned. Panics on timeout, on a closed channel, or if the message
    /// is not a notification.
    async fn recv_notification(
        rx: &mut tokio::sync::mpsc::Receiver<PeerSinkMessage<RoleServer>>,
    ) -> ServerNotification {
        let deadline = std::time::Duration::from_secs(1);
        let next_message = rx.recv();
        let message = tokio::time::timeout(deadline, next_message)
            .await
            .expect("timed out")
            .expect("channel closed");
        let (notification, responder) = match message {
            PeerSinkMessage::Notification {
                notification,
                responder,
            } => (notification, responder),
            other => panic!("expected notification, got {other:?}"),
        };
        // Acknowledge the send. This presumably unblocks the sending side;
        // a failed acknowledgement is irrelevant to the test.
        let _ = responder.send(Ok(()));
        notification
    }

    #[tokio::test]
    async fn test_router_deferred_notifier_e2e() {
        let my_tool: tool::ToolRoute<DummyHandler> = tool::ToolRoute::new_dyn(
            Tool::new("my_tool", "test", Arc::new(Default::default())),
            |_ctx| Box::pin(async { Ok(CallToolResult::default()) }),
        );
        let mut router = Router::new(DummyHandler).with_tool(my_tool);

        let ids: Arc<dyn RequestIdProvider> = Arc::new(AtomicU32RequestIdProvider::default());
        let (peer, mut rx) = Peer::<RoleServer>::new(ids, None);

        // Deliver `initialized` so the router records the peer for its notifier.
        let init_ctx = crate::service::NotificationContext {
            peer: peer.clone(),
            meta: Default::default(),
            extensions: Default::default(),
        };
        router
            .handle_notification(
                ClientNotification::InitializedNotification(Default::default()),
                init_ctx,
            )
            .await
            .unwrap();

        // Disabling and then re-enabling the route should each emit one
        // tool-list-changed notification.
        router.tool_router.disable_route("my_tool");
        assert!(matches!(
            recv_notification(&mut rx).await,
            ServerNotification::ToolListChangedNotification(_)
        ));

        router.tool_router.enable_route("my_tool");
        assert!(matches!(
            recv_notification(&mut rx).await,
            ServerNotification::ToolListChangedNotification(_)
        ));
    }
}