mcp_server/
router.rs

1use std::{
2    future::Future,
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7type PromptFuture = Pin<Box<dyn Future<Output = Result<String, PromptError>> + Send + 'static>>;
8
9use mcp_spec::{
10    content::Content,
11    handler::{PromptError, ResourceError, ToolError},
12    prompt::{Prompt, PromptMessage, PromptMessageRole},
13    protocol::{
14        CallToolResult, GetPromptResult, Implementation, InitializeResult, JsonRpcRequest,
15        JsonRpcResponse, ListPromptsResult, ListResourcesResult, ListToolsResult,
16        PromptsCapability, ReadResourceResult, ResourcesCapability, ServerCapabilities,
17        ToolsCapability,
18    },
19    ResourceContents,
20};
21use serde_json::Value;
22use tower_service::Service;
23
24use crate::{BoxError, RouterError};
25
26/// Builder for configuring and constructing capabilities
27pub struct CapabilitiesBuilder {
28    tools: Option<ToolsCapability>,
29    prompts: Option<PromptsCapability>,
30    resources: Option<ResourcesCapability>,
31}
32
33impl Default for CapabilitiesBuilder {
34    fn default() -> Self {
35        Self::new()
36    }
37}
38
39impl CapabilitiesBuilder {
40    pub fn new() -> Self {
41        Self {
42            tools: None,
43            prompts: None,
44            resources: None,
45        }
46    }
47
48    /// Add multiple tools to the router
49    pub fn with_tools(mut self, list_changed: bool) -> Self {
50        self.tools = Some(ToolsCapability {
51            list_changed: Some(list_changed),
52        });
53        self
54    }
55
56    /// Enable prompts capability
57    pub fn with_prompts(mut self, list_changed: bool) -> Self {
58        self.prompts = Some(PromptsCapability {
59            list_changed: Some(list_changed),
60        });
61        self
62    }
63
64    /// Enable resources capability
65    pub fn with_resources(mut self, subscribe: bool, list_changed: bool) -> Self {
66        self.resources = Some(ResourcesCapability {
67            subscribe: Some(subscribe),
68            list_changed: Some(list_changed),
69        });
70        self
71    }
72
73    /// Build the router with automatic capability inference
74    pub fn build(self) -> ServerCapabilities {
75        // Create capabilities based on what's configured
76        ServerCapabilities {
77            tools: self.tools,
78            prompts: self.prompts,
79            resources: self.resources,
80        }
81    }
82}
83
84pub trait Router: Send + Sync + 'static {
85    fn name(&self) -> String;
86    // in the protocol, instructions are optional but we make it required
87    fn instructions(&self) -> String;
88    fn capabilities(&self) -> ServerCapabilities;
89    fn list_tools(&self) -> Vec<mcp_spec::tool::Tool>;
90    fn call_tool(
91        &self,
92        tool_name: &str,
93        arguments: Value,
94    ) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>>;
95    fn list_resources(&self) -> Vec<mcp_spec::resource::Resource>;
96    fn read_resource(
97        &self,
98        uri: &str,
99    ) -> Pin<Box<dyn Future<Output = Result<String, ResourceError>> + Send + 'static>>;
100    fn list_prompts(&self) -> Vec<Prompt>;
101    fn get_prompt(&self, prompt_name: &str) -> PromptFuture;
102
103    // Helper method to create base response
104    fn create_response(&self, id: Option<u64>) -> JsonRpcResponse {
105        JsonRpcResponse {
106            jsonrpc: "2.0".to_string(),
107            id,
108            result: None,
109            error: None,
110        }
111    }
112
113    fn handle_initialize(
114        &self,
115        req: JsonRpcRequest,
116    ) -> impl Future<Output = Result<JsonRpcResponse, RouterError>> + Send {
117        async move {
118            let result = InitializeResult {
119                protocol_version: "2024-11-05".to_string(),
120                capabilities: self.capabilities().clone(),
121                server_info: Implementation {
122                    name: self.name(),
123                    version: env!("CARGO_PKG_VERSION").to_string(),
124                },
125                instructions: Some(self.instructions()),
126            };
127
128            let mut response = self.create_response(req.id);
129            response.result =
130                Some(serde_json::to_value(result).map_err(|e| {
131                    RouterError::Internal(format!("JSON serialization error: {}", e))
132                })?);
133
134            Ok(response)
135        }
136    }
137
138    fn handle_tools_list(
139        &self,
140        req: JsonRpcRequest,
141    ) -> impl Future<Output = Result<JsonRpcResponse, RouterError>> + Send {
142        async move {
143            let tools = self.list_tools();
144
145            let result = ListToolsResult {
146                tools,
147                next_cursor: None,
148            };
149            let mut response = self.create_response(req.id);
150            response.result =
151                Some(serde_json::to_value(result).map_err(|e| {
152                    RouterError::Internal(format!("JSON serialization error: {}", e))
153                })?);
154
155            Ok(response)
156        }
157    }
158
159    fn handle_tools_call(
160        &self,
161        req: JsonRpcRequest,
162    ) -> impl Future<Output = Result<JsonRpcResponse, RouterError>> + Send {
163        async move {
164            let params = req
165                .params
166                .ok_or_else(|| RouterError::InvalidParams("Missing parameters".into()))?;
167
168            let name = params
169                .get("name")
170                .and_then(Value::as_str)
171                .ok_or_else(|| RouterError::InvalidParams("Missing tool name".into()))?;
172
173            let arguments = params.get("arguments").cloned().unwrap_or(Value::Null);
174
175            let result = match self.call_tool(name, arguments).await {
176                Ok(result) => CallToolResult {
177                    content: result,
178                    is_error: None,
179                },
180                Err(err) => CallToolResult {
181                    content: vec![Content::text(err.to_string())],
182                    is_error: Some(true),
183                },
184            };
185
186            let mut response = self.create_response(req.id);
187            response.result =
188                Some(serde_json::to_value(result).map_err(|e| {
189                    RouterError::Internal(format!("JSON serialization error: {}", e))
190                })?);
191
192            Ok(response)
193        }
194    }
195
196    fn handle_resources_list(
197        &self,
198        req: JsonRpcRequest,
199    ) -> impl Future<Output = Result<JsonRpcResponse, RouterError>> + Send {
200        async move {
201            let resources = self.list_resources();
202
203            let result = ListResourcesResult {
204                resources,
205                next_cursor: None,
206            };
207            let mut response = self.create_response(req.id);
208            response.result =
209                Some(serde_json::to_value(result).map_err(|e| {
210                    RouterError::Internal(format!("JSON serialization error: {}", e))
211                })?);
212
213            Ok(response)
214        }
215    }
216
217    fn handle_resources_read(
218        &self,
219        req: JsonRpcRequest,
220    ) -> impl Future<Output = Result<JsonRpcResponse, RouterError>> + Send {
221        async move {
222            let params = req
223                .params
224                .ok_or_else(|| RouterError::InvalidParams("Missing parameters".into()))?;
225
226            let uri = params
227                .get("uri")
228                .and_then(Value::as_str)
229                .ok_or_else(|| RouterError::InvalidParams("Missing resource URI".into()))?;
230
231            let contents = self.read_resource(uri).await.map_err(RouterError::from)?;
232
233            let result = ReadResourceResult {
234                contents: vec![ResourceContents::TextResourceContents {
235                    uri: uri.to_string(),
236                    mime_type: Some("text/plain".to_string()),
237                    text: contents,
238                }],
239            };
240
241            let mut response = self.create_response(req.id);
242            response.result =
243                Some(serde_json::to_value(result).map_err(|e| {
244                    RouterError::Internal(format!("JSON serialization error: {}", e))
245                })?);
246
247            Ok(response)
248        }
249    }
250
251    fn handle_prompts_list(
252        &self,
253        req: JsonRpcRequest,
254    ) -> impl Future<Output = Result<JsonRpcResponse, RouterError>> + Send {
255        async move {
256            let prompts = self.list_prompts();
257
258            let result = ListPromptsResult { prompts };
259
260            let mut response = self.create_response(req.id);
261            response.result =
262                Some(serde_json::to_value(result).map_err(|e| {
263                    RouterError::Internal(format!("JSON serialization error: {}", e))
264                })?);
265
266            Ok(response)
267        }
268    }
269
270    fn handle_prompts_get(
271        &self,
272        req: JsonRpcRequest,
273    ) -> impl Future<Output = Result<JsonRpcResponse, RouterError>> + Send {
274        async move {
275            // Validate and extract parameters
276            let params = req
277                .params
278                .ok_or_else(|| RouterError::InvalidParams("Missing parameters".into()))?;
279
280            // Extract "name" field
281            let prompt_name = params
282                .get("name")
283                .and_then(Value::as_str)
284                .ok_or_else(|| RouterError::InvalidParams("Missing prompt name".into()))?;
285
286            // Extract "arguments" field
287            let arguments = params
288                .get("arguments")
289                .and_then(Value::as_object)
290                .ok_or_else(|| RouterError::InvalidParams("Missing arguments object".into()))?;
291
292            // Fetch the prompt definition first
293            let prompt = self
294                .list_prompts()
295                .into_iter()
296                .find(|p| p.name == prompt_name)
297                .ok_or_else(|| {
298                    RouterError::PromptNotFound(format!("Prompt '{}' not found", prompt_name))
299                })?;
300
301            // Validate required arguments
302            if let Some(args) = &prompt.arguments {
303                for arg in args {
304                    if arg.required.is_some()
305                        && arg.required.unwrap()
306                        && (!arguments.contains_key(&arg.name)
307                            || arguments
308                                .get(&arg.name)
309                                .and_then(Value::as_str)
310                                .is_none_or(str::is_empty))
311                    {
312                        return Err(RouterError::InvalidParams(format!(
313                            "Missing required argument: '{}'",
314                            arg.name
315                        )));
316                    }
317                }
318            }
319
320            // Now get the prompt content
321            let description = self
322                .get_prompt(prompt_name)
323                .await
324                .map_err(|e| RouterError::Internal(e.to_string()))?;
325
326            // Validate prompt arguments for potential security issues from user text input
327            // Checks:
328            // - Prompt must be less than 10000 total characters
329            // - Argument keys must be less than 1000 characters
330            // - Argument values must be less than 1000 characters
331            // - Dangerous patterns, eg "../", "//", "\\\\", "<script>", "{{", "}}"
332            for (key, value) in arguments.iter() {
333                // Check for empty or overly long keys/values
334                if key.is_empty() || key.len() > 1000 {
335                    return Err(RouterError::InvalidParams(
336                        "Argument keys must be between 1-1000 characters".into(),
337                    ));
338                }
339
340                let value_str = value.as_str().unwrap_or_default();
341                if value_str.len() > 1000 {
342                    return Err(RouterError::InvalidParams(
343                        "Argument values must not exceed 1000 characters".into(),
344                    ));
345                }
346
347                // Check for potentially dangerous patterns
348                let dangerous_patterns = ["../", "//", "\\\\", "<script>", "{{", "}}"];
349                for pattern in dangerous_patterns {
350                    if key.contains(pattern) || value_str.contains(pattern) {
351                        return Err(RouterError::InvalidParams(format!(
352                            "Arguments contain potentially unsafe pattern: {}",
353                            pattern
354                        )));
355                    }
356                }
357            }
358
359            // Validate the prompt description length
360            if description.len() > 10000 {
361                return Err(RouterError::Internal(
362                    "Prompt description exceeds maximum allowed length".into(),
363                ));
364            }
365
366            // Create a mutable copy of the description to fill in arguments
367            let mut description_filled = description.clone();
368
369            // Replace each argument placeholder with its value from the arguments object
370            for (key, value) in arguments {
371                let placeholder = format!("{{{}}}", key);
372                description_filled =
373                    description_filled.replace(&placeholder, value.as_str().unwrap_or_default());
374            }
375
376            let messages = vec![PromptMessage::new_text(
377                PromptMessageRole::User,
378                description_filled.to_string(),
379            )];
380
381            // Build the final response
382            let mut response = self.create_response(req.id);
383            response.result = Some(
384                serde_json::to_value(GetPromptResult {
385                    description: Some(description_filled),
386                    messages,
387                })
388                .map_err(|e| RouterError::Internal(format!("JSON serialization error: {}", e)))?,
389            );
390            Ok(response)
391        }
392    }
393}
394
395pub struct RouterService<T>(pub T);
396
397impl<T> Service<JsonRpcRequest> for RouterService<T>
398where
399    T: Router + Clone + Send + Sync + 'static,
400{
401    type Response = JsonRpcResponse;
402    type Error = BoxError;
403    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
404
405    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
406        Poll::Ready(Ok(()))
407    }
408
409    fn call(&mut self, req: JsonRpcRequest) -> Self::Future {
410        let this = self.0.clone();
411
412        Box::pin(async move {
413            let result = match req.method.as_str() {
414                "initialize" => this.handle_initialize(req).await,
415                "tools/list" => this.handle_tools_list(req).await,
416                "tools/call" => this.handle_tools_call(req).await,
417                "resources/list" => this.handle_resources_list(req).await,
418                "resources/read" => this.handle_resources_read(req).await,
419                "prompts/list" => this.handle_prompts_list(req).await,
420                "prompts/get" => this.handle_prompts_get(req).await,
421                _ => {
422                    let mut response = this.create_response(req.id);
423                    response.error = Some(RouterError::MethodNotFound(req.method).into());
424                    Ok(response)
425                }
426            };
427
428            result.map_err(BoxError::from)
429        })
430    }
431}