// mcp_server_rs/router/traits.rs

1use std::{future::Future, pin::Pin};
2
3use mcp_core_rs::{
4    Resource, ResourceContents, Tool,
5    content::Content,
6    prompt::{Prompt, PromptMessage, PromptMessageRole},
7    protocol::{
8        capabilities::ServerCapabilities,
9        message::{JsonRpcRequest, JsonRpcResponse},
10        result::{
11            CallToolResult, GetPromptResult, Implementation, InitializeResult, ListPromptsResult,
12            ListResourcesResult, ListToolsResult, ReadResourceResult,
13        },
14    },
15};
16use mcp_error_rs::{Error, Result};
17use serde_json::Value;
18
19type PromptFuture = Pin<Box<dyn Future<Output = Result<String>> + Send + 'static>>;
20
21pub trait Router: Send + Sync + 'static {
22    fn name(&self) -> String;
23    // in the protocol, instructions are optional but we make it required
24    fn instructions(&self) -> String;
25    fn capabilities(&self) -> ServerCapabilities;
26    fn list_tools(&self) -> Vec<Tool>;
27    fn call_tool(
28        &self,
29        tool_name: &str,
30        arguments: Value,
31    ) -> Pin<Box<dyn Future<Output = Result<Vec<Content>>> + Send + 'static>>;
32    fn list_resources(&self) -> Vec<Resource>;
33    fn read_resource(
34        &self,
35        uri: &str,
36    ) -> Pin<Box<dyn Future<Output = Result<String>> + Send + 'static>>;
37    fn list_prompts(&self) -> Vec<Prompt>;
38    fn get_prompt(&self, prompt_name: &str) -> PromptFuture;
39
40    // Helper method to create base response
41    fn create_response(&self, id: Option<u64>) -> JsonRpcResponse {
42        JsonRpcResponse::empty(id)
43    }
44
45    fn handle_initialize(
46        &self,
47        req: JsonRpcRequest,
48    ) -> impl Future<Output = Result<JsonRpcResponse>> + Send {
49        async move {
50            let result = InitializeResult {
51                protocol_version: "2024-11-05".to_string(),
52                capabilities: self.capabilities(),
53                server_info: Implementation {
54                    name: self.name(),
55                    version: env!("CARGO_PKG_VERSION").to_string(),
56                },
57                instructions: Some(self.instructions()),
58            };
59
60            let mut response = self.create_response(req.id);
61            response.result = Some(
62                serde_json::to_value(result)
63                    .map_err(|e| Error::System(format!("JSON serialization error: {}", e)))?,
64            );
65
66            Ok(response)
67        }
68    }
69
70    fn handle_tools_list(
71        &self,
72        req: JsonRpcRequest,
73    ) -> impl Future<Output = Result<JsonRpcResponse>> + Send {
74        async move {
75            let tools = self.list_tools();
76
77            let result = ListToolsResult {
78                tools,
79                next_cursor: None,
80            };
81            let mut response = self.create_response(req.id);
82            response.result = Some(
83                serde_json::to_value(result)
84                    .map_err(|e| Error::System(format!("JSON serialization error: {}", e)))?,
85            );
86
87            Ok(response)
88        }
89    }
90
91    fn handle_tools_call(
92        &self,
93        req: JsonRpcRequest,
94    ) -> impl Future<Output = Result<JsonRpcResponse>> + Send {
95        async move {
96            let params = req
97                .params
98                .ok_or_else(|| Error::InvalidParameters("Missing parameters".into()))?;
99
100            let name = params
101                .get("name")
102                .and_then(Value::as_str)
103                .ok_or_else(|| Error::InvalidParameters("Missing tool name".into()))?;
104
105            let arguments = params.get("arguments").cloned().unwrap_or(Value::Null);
106
107            let result = match self.call_tool(name, arguments).await {
108                Ok(result) => CallToolResult {
109                    content: result,
110                    is_error: None,
111                },
112                Err(err) => CallToolResult {
113                    content: vec![Content::text(err.to_string())],
114                    is_error: Some(true),
115                },
116            };
117
118            let mut response = self.create_response(req.id);
119            response.result = Some(
120                serde_json::to_value(result)
121                    .map_err(|e| Error::System(format!("JSON serialization error: {}", e)))?,
122            );
123
124            Ok(response)
125        }
126    }
127
128    fn handle_resources_list(
129        &self,
130        req: JsonRpcRequest,
131    ) -> impl Future<Output = Result<JsonRpcResponse>> + Send {
132        async move {
133            let resources = self.list_resources();
134
135            let result = ListResourcesResult {
136                resources,
137                next_cursor: None,
138            };
139            let mut response = self.create_response(req.id);
140            response.result = Some(
141                serde_json::to_value(result)
142                    .map_err(|e| Error::System(format!("JSON serialization error: {}", e)))?,
143            );
144
145            Ok(response)
146        }
147    }
148
149    fn handle_resources_read(
150        &self,
151        req: JsonRpcRequest,
152    ) -> impl Future<Output = Result<JsonRpcResponse>> + Send {
153        async move {
154            let params = req
155                .params
156                .ok_or_else(|| Error::InvalidParameters("Missing parameters".into()))?;
157
158            let uri = params
159                .get("uri")
160                .and_then(Value::as_str)
161                .ok_or_else(|| Error::InvalidParameters("Missing resource URI".into()))?;
162
163            let contents = self.read_resource(uri).await.map_err(Error::from)?;
164
165            let result = ReadResourceResult {
166                contents: vec![ResourceContents::TextResourceContents {
167                    uri: uri.to_string(),
168                    mime_type: Some("text/plain".to_string()),
169                    text: contents,
170                }],
171            };
172
173            let mut response = self.create_response(req.id);
174            response.result = Some(
175                serde_json::to_value(result)
176                    .map_err(|e| Error::System(format!("JSON serialization error: {}", e)))?,
177            );
178
179            Ok(response)
180        }
181    }
182
183    fn handle_prompts_list(
184        &self,
185        req: JsonRpcRequest,
186    ) -> impl Future<Output = Result<JsonRpcResponse>> + Send {
187        async move {
188            let prompts = self.list_prompts();
189
190            let result = ListPromptsResult { prompts };
191
192            let mut response = self.create_response(req.id);
193            response.result = Some(
194                serde_json::to_value(result)
195                    .map_err(|e| Error::System(format!("JSON serialization error: {}", e)))?,
196            );
197
198            Ok(response)
199        }
200    }
201
202    fn handle_prompts_get(
203        &self,
204        req: JsonRpcRequest,
205    ) -> impl Future<Output = Result<JsonRpcResponse>> + Send {
206        async move {
207            // Validate and extract parameters
208            let params = req
209                .params
210                .ok_or_else(|| Error::InvalidParameters("Missing parameters".into()))?;
211
212            // Extract "name" field
213            let prompt_name = params
214                .get("name")
215                .and_then(Value::as_str)
216                .ok_or_else(|| Error::InvalidParameters("Missing prompt name".into()))?;
217
218            // Extract "arguments" field
219            let arguments = params
220                .get("arguments")
221                .and_then(Value::as_object)
222                .ok_or_else(|| Error::InvalidParameters("Missing arguments object".into()))?;
223
224            // Fetch the prompt definition first
225            let prompt = self
226                .list_prompts()
227                .into_iter()
228                .find(|p| p.name == prompt_name)
229                .ok_or_else(|| Error::System(format!("Prompt '{}' not found", prompt_name)))?;
230
231            // Validate required arguments
232            if let Some(args) = &prompt.arguments {
233                for arg in args {
234                    if arg.required.is_some()
235                        && arg.required.unwrap()
236                        && (!arguments.contains_key(&arg.name)
237                            || arguments
238                                .get(&arg.name)
239                                .and_then(Value::as_str)
240                                .is_none_or(str::is_empty))
241                    {
242                        return Err(Error::InvalidParameters(format!(
243                            "Missing required argument: '{}'",
244                            arg.name
245                        )));
246                    }
247                }
248            }
249
250            // Now get the prompt content
251            let description = self
252                .get_prompt(prompt_name)
253                .await
254                .map_err(|e| Error::System(e.to_string()))?;
255
256            // Validate prompt arguments for potential security issues from user text input
257            // Checks:
258            // - Prompt must be less than 10000 total characters
259            // - Argument keys must be less than 1000 characters
260            // - Argument values must be less than 1000 characters
261            // - Dangerous patterns, eg "../", "//", "\\\\", "<script>", "{{", "}}"
262            for (key, value) in arguments.iter() {
263                // Check for empty or overly long keys/values
264                if key.is_empty() || key.len() > 1000 {
265                    return Err(Error::InvalidParameters(
266                        "Argument keys must be between 1-1000 characters".into(),
267                    ));
268                }
269
270                let value_str = value.as_str().unwrap_or_default();
271                if value_str.len() > 1000 {
272                    return Err(Error::InvalidParameters(
273                        "Argument values must not exceed 1000 characters".into(),
274                    ));
275                }
276
277                // Check for potentially dangerous patterns
278                let dangerous_patterns = ["../", "//", "\\\\", "<script>", "{{", "}}"];
279                for pattern in dangerous_patterns {
280                    if key.contains(pattern) || value_str.contains(pattern) {
281                        return Err(Error::InvalidParameters(format!(
282                            "Arguments contain potentially unsafe pattern: {}",
283                            pattern
284                        )));
285                    }
286                }
287            }
288
289            // Validate the prompt description length
290            if description.len() > 10000 {
291                return Err(Error::System(
292                    "Prompt description exceeds maximum allowed length".into(),
293                ));
294            }
295
296            // Create a mutable copy of the description to fill in arguments
297            let mut description_filled = description.clone();
298
299            // Replace each argument placeholder with its value from the arguments object
300            for (key, value) in arguments {
301                let placeholder = format!("{{{}}}", key);
302                description_filled =
303                    description_filled.replace(&placeholder, value.as_str().unwrap_or_default());
304            }
305
306            let messages = vec![PromptMessage::new_text(
307                PromptMessageRole::User,
308                description_filled.to_string(),
309            )];
310
311            // Build the final response
312            let mut response = self.create_response(req.id);
313            response.result = Some(
314                serde_json::to_value(GetPromptResult {
315                    description: Some(description_filled),
316                    messages,
317                })
318                .map_err(|e| Error::System(format!("JSON serialization error: {}", e)))?,
319            );
320            Ok(response)
321        }
322    }
323}