mcp_server_fishcode2025/
router.rs

1use std::{
2    future::Future,
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7type PromptFuture = Pin<Box<dyn Future<Output = Result<String, PromptError>> + Send + 'static>>;
8
9use mcp_core_fishcode2025::{
10    content::Content,
11    handler::{PromptError, ResourceError, ToolError},
12    prompt::{Prompt, PromptMessage, PromptMessageRole},
13    protocol::{
14        CallToolResult, GetPromptResult, Implementation, InitializeResult, JsonRpcRequest,
15        JsonRpcResponse, ListPromptsResult, ListResourcesResult, ListToolsResult,
16        PromptsCapability, ReadResourceResult, ResourcesCapability, ServerCapabilities,
17        ToolsCapability,
18    },
19    resource::Resource,
20    tool::Tool,
21    ResourceContents,
22};
23use serde_json::Value;
24use tower_service::Service;
25
26use crate::{BoxError, RouterError};
27
28/// Builder for configuring and constructing capabilities
29pub struct CapabilitiesBuilder {
30    tools: Option<ToolsCapability>,
31    prompts: Option<PromptsCapability>,
32    resources: Option<ResourcesCapability>,
33}
34
35impl Default for CapabilitiesBuilder {
36    fn default() -> Self {
37        Self::new()
38    }
39}
40
41impl CapabilitiesBuilder {
42    pub fn new() -> Self {
43        Self {
44            tools: None,
45            prompts: None,
46            resources: None,
47        }
48    }
49
50    /// Add multiple tools to the router
51    pub fn with_tools(mut self, list_changed: bool) -> Self {
52        self.tools = Some(ToolsCapability {
53            list_changed: Some(list_changed),
54        });
55        self
56    }
57
58    /// Enable prompts capability
59    pub fn with_prompts(mut self, list_changed: bool) -> Self {
60        self.prompts = Some(PromptsCapability {
61            list_changed: Some(list_changed),
62        });
63        self
64    }
65
66    /// Enable resources capability
67    pub fn with_resources(mut self, subscribe: bool, list_changed: bool) -> Self {
68        self.resources = Some(ResourcesCapability {
69            subscribe: Some(subscribe),
70            list_changed: Some(list_changed),
71        });
72        self
73    }
74
75    /// Build the router with automatic capability inference
76    pub fn build(self) -> ServerCapabilities {
77        // Create capabilities based on what's configured
78        ServerCapabilities {
79            tools: self.tools,
80            prompts: self.prompts,
81            resources: self.resources,
82        }
83    }
84}
85
86pub trait Router: Send + Sync + 'static {
87    fn name(&self) -> String;
88    // in the protocol, instructions are optional but we make it required
89    fn instructions(&self) -> String;
90    fn capabilities(&self) -> ServerCapabilities;
91    fn list_tools(&self) -> Vec<Tool>;
92    fn call_tool(
93        &self,
94        tool_name: &str,
95        arguments: Value,
96    ) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>>;
97    fn list_resources(&self) -> Vec<Resource>;
98    fn read_resource(
99        &self,
100        uri: &str,
101    ) -> Pin<Box<dyn Future<Output = Result<String, ResourceError>> + Send + 'static>>;
102    fn list_prompts(&self) -> Vec<Prompt>;
103    fn get_prompt(&self, prompt_name: &str) -> PromptFuture;
104
105    // Helper method to create base response
106    fn create_response(&self, id: Option<u64>) -> JsonRpcResponse {
107        JsonRpcResponse {
108            jsonrpc: "2.0".to_string(),
109            id,
110            result: None,
111            error: None,
112        }
113    }
114
115    fn handle_initialize(
116        &self,
117        req: JsonRpcRequest,
118    ) -> impl Future<Output = Result<JsonRpcResponse, RouterError>> + Send {
119        async move {
120            let result = InitializeResult {
121                protocol_version: "2024-11-05".to_string(),
122                capabilities: self.capabilities().clone(),
123                server_info: Implementation {
124                    name: self.name(),
125                    version: env!("CARGO_PKG_VERSION").to_string(),
126                },
127                instructions: Some(self.instructions()),
128            };
129
130            let mut response = self.create_response(req.id);
131            response.result =
132                Some(serde_json::to_value(result).map_err(|e| {
133                    RouterError::Internal(format!("JSON serialization error: {}", e))
134                })?);
135
136            Ok(response)
137        }
138    }
139
140    fn handle_tools_list(
141        &self,
142        req: JsonRpcRequest,
143    ) -> impl Future<Output = Result<JsonRpcResponse, RouterError>> + Send {
144        async move {
145            let tools = self.list_tools();
146
147            let result = ListToolsResult {
148                tools,
149                next_cursor: None,
150            };
151            let mut response = self.create_response(req.id);
152            response.result =
153                Some(serde_json::to_value(result).map_err(|e| {
154                    RouterError::Internal(format!("JSON serialization error: {}", e))
155                })?);
156
157            Ok(response)
158        }
159    }
160
161    fn handle_tools_call(
162        &self,
163        req: JsonRpcRequest,
164    ) -> impl Future<Output = Result<JsonRpcResponse, RouterError>> + Send {
165        async move {
166            let params = req
167                .params
168                .ok_or_else(|| RouterError::InvalidParams("Missing parameters".into()))?;
169
170            let name = params
171                .get("name")
172                .and_then(Value::as_str)
173                .ok_or_else(|| RouterError::InvalidParams("Missing tool name".into()))?;
174
175            let arguments = params.get("arguments").cloned().unwrap_or(Value::Null);
176
177            let result = match self.call_tool(name, arguments).await {
178                Ok(result) => CallToolResult {
179                    content: result,
180                    is_error: None,
181                },
182                Err(err) => CallToolResult {
183                    content: vec![Content::text(err.to_string())],
184                    is_error: Some(true),
185                },
186            };
187
188            let mut response = self.create_response(req.id);
189            response.result =
190                Some(serde_json::to_value(result).map_err(|e| {
191                    RouterError::Internal(format!("JSON serialization error: {}", e))
192                })?);
193
194            Ok(response)
195        }
196    }
197
198    fn handle_resources_list(
199        &self,
200        req: JsonRpcRequest,
201    ) -> impl Future<Output = Result<JsonRpcResponse, RouterError>> + Send {
202        async move {
203            let resources = self.list_resources();
204
205            let result = ListResourcesResult {
206                resources,
207                next_cursor: None,
208            };
209            let mut response = self.create_response(req.id);
210            response.result =
211                Some(serde_json::to_value(result).map_err(|e| {
212                    RouterError::Internal(format!("JSON serialization error: {}", e))
213                })?);
214
215            Ok(response)
216        }
217    }
218
219    fn handle_resources_read(
220        &self,
221        req: JsonRpcRequest,
222    ) -> impl Future<Output = Result<JsonRpcResponse, RouterError>> + Send {
223        async move {
224            let params = req
225                .params
226                .ok_or_else(|| RouterError::InvalidParams("Missing parameters".into()))?;
227
228            let uri = params
229                .get("uri")
230                .and_then(Value::as_str)
231                .ok_or_else(|| RouterError::InvalidParams("Missing resource URI".into()))?;
232
233            let contents = self.read_resource(uri).await.map_err(RouterError::from)?;
234
235            let result = ReadResourceResult {
236                contents: vec![ResourceContents::TextResourceContents {
237                    uri: uri.to_string(),
238                    mime_type: Some("text/plain".to_string()),
239                    text: contents,
240                }],
241            };
242
243            let mut response = self.create_response(req.id);
244            response.result =
245                Some(serde_json::to_value(result).map_err(|e| {
246                    RouterError::Internal(format!("JSON serialization error: {}", e))
247                })?);
248
249            Ok(response)
250        }
251    }
252
253    fn handle_prompts_list(
254        &self,
255        req: JsonRpcRequest,
256    ) -> impl Future<Output = Result<JsonRpcResponse, RouterError>> + Send {
257        async move {
258            let prompts = self.list_prompts();
259
260            let result = ListPromptsResult { prompts };
261
262            let mut response = self.create_response(req.id);
263            response.result =
264                Some(serde_json::to_value(result).map_err(|e| {
265                    RouterError::Internal(format!("JSON serialization error: {}", e))
266                })?);
267
268            Ok(response)
269        }
270    }
271
272    fn handle_prompts_get(
273        &self,
274        req: JsonRpcRequest,
275    ) -> impl Future<Output = Result<JsonRpcResponse, RouterError>> + Send {
276        async move {
277            // Validate and extract parameters
278            let params = req
279                .params
280                .ok_or_else(|| RouterError::InvalidParams("Missing parameters".into()))?;
281
282            // Extract "name" field
283            let prompt_name = params
284                .get("name")
285                .and_then(Value::as_str)
286                .ok_or_else(|| RouterError::InvalidParams("Missing prompt name".into()))?;
287
288            // Extract "arguments" field
289            let arguments = params
290                .get("arguments")
291                .and_then(Value::as_object)
292                .ok_or_else(|| RouterError::InvalidParams("Missing arguments object".into()))?;
293
294            // Fetch the prompt definition first
295            let prompt = self
296                .list_prompts()
297                .into_iter()
298                .find(|p| p.name == prompt_name)
299                .ok_or_else(|| {
300                    RouterError::PromptNotFound(format!("Prompt '{}' not found", prompt_name))
301                })?;
302
303            // Validate required arguments
304            if let Some(args) = &prompt.arguments {
305                for arg in args {
306                    if arg.required.is_some()
307                        && arg.required.unwrap()
308                        && (!arguments.contains_key(&arg.name)
309                            || arguments
310                                .get(&arg.name)
311                                .and_then(Value::as_str)
312                                .is_none_or(str::is_empty))
313                    {
314                        return Err(RouterError::InvalidParams(format!(
315                            "Missing required argument: '{}'",
316                            arg.name
317                        )));
318                    }
319                }
320            }
321
322            // Now get the prompt content
323            let description = self
324                .get_prompt(prompt_name)
325                .await
326                .map_err(|e| RouterError::Internal(e.to_string()))?;
327
328            // Validate prompt arguments for potential security issues from user text input
329            // Checks:
330            // - Prompt must be less than 10000 total characters
331            // - Argument keys must be less than 1000 characters
332            // - Argument values must be less than 1000 characters
333            // - Dangerous patterns, eg "../", "//", "\\\\", "<script>", "{{", "}}"
334            for (key, value) in arguments.iter() {
335                // Check for empty or overly long keys/values
336                if key.is_empty() || key.len() > 1000 {
337                    return Err(RouterError::InvalidParams(
338                        "Argument keys must be between 1-1000 characters".into(),
339                    ));
340                }
341
342                let value_str = value.as_str().unwrap_or_default();
343                if value_str.len() > 1000 {
344                    return Err(RouterError::InvalidParams(
345                        "Argument values must not exceed 1000 characters".into(),
346                    ));
347                }
348
349                // Check for potentially dangerous patterns
350                let dangerous_patterns = ["../", "//", "\\\\", "<script>", "{{", "}}"];
351                for pattern in dangerous_patterns {
352                    if key.contains(pattern) || value_str.contains(pattern) {
353                        return Err(RouterError::InvalidParams(format!(
354                            "Arguments contain potentially unsafe pattern: {}",
355                            pattern
356                        )));
357                    }
358                }
359            }
360
361            // Validate the prompt description length
362            if description.len() > 10000 {
363                return Err(RouterError::Internal(
364                    "Prompt description exceeds maximum allowed length".into(),
365                ));
366            }
367
368            // Create a mutable copy of the description to fill in arguments
369            let mut description_filled = description.clone();
370
371            // Replace each argument placeholder with its value from the arguments object
372            for (key, value) in arguments {
373                let placeholder = format!("{{{}}}", key);
374                description_filled =
375                    description_filled.replace(&placeholder, value.as_str().unwrap_or_default());
376            }
377
378            let messages = vec![PromptMessage::new_text(
379                PromptMessageRole::User,
380                description_filled.to_string(),
381            )];
382
383            // Build the final response
384            let mut response = self.create_response(req.id);
385            response.result = Some(
386                serde_json::to_value(GetPromptResult {
387                    description: Some(description_filled),
388                    messages,
389                })
390                .map_err(|e| RouterError::Internal(format!("JSON serialization error: {}", e)))?,
391            );
392            Ok(response)
393        }
394    }
395}
396
397pub struct RouterService<T>(pub T);
398
399impl<T> Service<JsonRpcRequest> for RouterService<T>
400where
401    T: Router + Clone + Send + Sync + 'static,
402{
403    type Response = JsonRpcResponse;
404    type Error = BoxError;
405    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
406
407    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
408        Poll::Ready(Ok(()))
409    }
410
411    fn call(&mut self, req: JsonRpcRequest) -> Self::Future {
412        let this = self.0.clone();
413
414        Box::pin(async move {
415            let result = match req.method.as_str() {
416                "initialize" => this.handle_initialize(req).await,
417                "tools/list" => this.handle_tools_list(req).await,
418                "tools/call" => this.handle_tools_call(req).await,
419                "resources/list" => this.handle_resources_list(req).await,
420                "resources/read" => this.handle_resources_read(req).await,
421                "prompts/list" => this.handle_prompts_list(req).await,
422                "prompts/get" => this.handle_prompts_get(req).await,
423                _ => {
424                    let mut response = this.create_response(req.id);
425                    response.error = Some(RouterError::MethodNotFound(req.method).into());
426                    Ok(response)
427                }
428            };
429
430            result.map_err(BoxError::from)
431        })
432    }
433}