1use std::collections::HashMap;
2
3use crate::{
4 error::McpResult,
5 types::{
6 messages::{
7 CallToolRequest, GetPromptRequest, ListPromptsResult, ListResourcesResult,
8 ListToolsResult, ReadResourceRequest,
9 },
10 prompt::{GetPromptResult, Prompt},
11 resource::{ReadResourceResult, Resource, ResourceTemplate},
12 tool::{CallToolResult, Tool},
13 },
14};
15
16use crate::server::handler::{PromptHandlerFn, ResourceHandlerFn, ToolHandlerFn};
17
/// A registered tool: the tool's definition together with the handler that runs it.
pub struct ToolRoute {
    /// Tool definition; a clone of this is what `Router::list_tools` returns.
    pub tool: Tool,
    /// Handler called with the incoming `CallToolRequest` when this tool is invoked.
    pub handler: ToolHandlerFn,
}
24
/// A registered concrete resource: the resource's definition together with its read handler.
pub struct ResourceRoute {
    /// Resource definition; a clone of this is what `Router::list_resources` returns.
    pub resource: Resource,
    /// Handler called with the incoming `ReadResourceRequest` when the request URI
    /// equals the URI this resource was registered under.
    pub handler: ResourceHandlerFn,
}
31
/// A registered resource template: a URI template together with the handler that
/// serves every URI matching it.
pub struct ResourceTemplateRoute {
    /// Template definition; its `uri_template` is what request URIs are matched against.
    pub template: ResourceTemplate,
    /// Handler called with the incoming `ReadResourceRequest` when the request URI
    /// matches the template and no exact resource URI matched first.
    pub handler: ResourceHandlerFn,
}
36
/// A registered prompt: the prompt's definition together with the handler that renders it.
pub struct PromptRoute {
    /// Prompt definition; a clone of this is what `Router::list_prompts` returns.
    pub prompt: Prompt,
    /// Handler called with the incoming `GetPromptRequest` when this prompt is requested.
    pub handler: PromptHandlerFn,
}
43
/// Registry that dispatches tool, resource, and prompt requests to their handlers.
///
/// The derived `Default` yields a router with nothing registered.
#[derive(Default)]
pub struct Router {
    // Tools keyed by tool name.
    tools: HashMap<String, ToolRoute>,
    // Concrete resources keyed by their exact URI.
    resources: HashMap<String, ResourceRoute>,
    // Template routes in registration order; the first matching template wins.
    resource_templates: Vec<ResourceTemplateRoute>,
    // Prompts keyed by prompt name.
    prompts: HashMap<String, PromptRoute>,
}
54
55impl Router {
56 pub fn new() -> Self {
57 Self::default()
58 }
59
60 pub fn add_tool(&mut self, tool: Tool, handler: ToolHandlerFn) {
63 self.tools
64 .insert(tool.name.clone(), ToolRoute { tool, handler });
65 }
66
67 pub fn list_tools(&self, _cursor: Option<&str>) -> ListToolsResult {
68 let tools: Vec<Tool> = self.tools.values().map(|r| r.tool.clone()).collect();
69 ListToolsResult {
70 tools,
71 next_cursor: None,
72 }
73 }
74
75 pub async fn call_tool(&self, req: CallToolRequest) -> McpResult<CallToolResult> {
76 let route = self
77 .tools
78 .get(&req.name)
79 .ok_or_else(|| crate::error::McpError::ToolNotFound(req.name.clone()))?;
80 (route.handler)(req).await
81 }
82
83 pub fn add_resource(&mut self, resource: Resource, handler: ResourceHandlerFn) {
86 self.resources
87 .insert(resource.uri.clone(), ResourceRoute { resource, handler });
88 }
89
90 pub fn add_resource_template(
91 &mut self,
92 template: ResourceTemplate,
93 handler: ResourceHandlerFn,
94 ) {
95 self.resource_templates
96 .push(ResourceTemplateRoute { template, handler });
97 }
98
99 pub fn list_resources(&self, _cursor: Option<&str>) -> ListResourcesResult {
100 let resources: Vec<Resource> = self
101 .resources
102 .values()
103 .map(|r| r.resource.clone())
104 .collect();
105 ListResourcesResult {
106 resources,
107 next_cursor: None,
108 }
109 }
110
111 pub async fn read_resource(&self, req: ReadResourceRequest) -> McpResult<ReadResourceResult> {
112 if let Some(route) = self.resources.get(&req.uri) {
113 return (route.handler)(req).await;
114 }
115 for tpl in &self.resource_templates {
116 if uri_matches_template(&req.uri, &tpl.template.uri_template) {
117 return (tpl.handler)(req).await;
118 }
119 }
120 Err(crate::error::McpError::ResourceNotFound(req.uri))
121 }
122
123 pub fn add_prompt(&mut self, prompt: Prompt, handler: PromptHandlerFn) {
126 self.prompts
127 .insert(prompt.name.clone(), PromptRoute { prompt, handler });
128 }
129
130 pub fn list_prompts(&self, _cursor: Option<&str>) -> ListPromptsResult {
131 let prompts: Vec<Prompt> = self.prompts.values().map(|r| r.prompt.clone()).collect();
132 ListPromptsResult {
133 prompts,
134 next_cursor: None,
135 }
136 }
137
138 pub async fn get_prompt(&self, req: GetPromptRequest) -> McpResult<GetPromptResult> {
139 let route = self
140 .prompts
141 .get(&req.name)
142 .ok_or_else(|| crate::error::McpError::PromptNotFound(req.name.clone()))?;
143 (route.handler)(req).await
144 }
145
146 pub fn has_tools(&self) -> bool {
149 !self.tools.is_empty()
150 }
151 pub fn has_resources(&self) -> bool {
152 !self.resources.is_empty() || !self.resource_templates.is_empty()
153 }
154 pub fn has_prompts(&self) -> bool {
155 !self.prompts.is_empty()
156 }
157}
158
159fn uri_matches_template(uri: &str, template: &str) -> bool {
162 let re = template_to_pattern(template);
163 pattern_match(uri, &re)
164}
165
/// Converts a URI template into an anchored, regex-like pattern string.
///
/// The result starts with `^` and ends with `$`. Each `{...}` placeholder
/// becomes `[^/]+` (one or more non-slash characters); an opening brace with
/// no closing brace swallows the rest of the template as a single placeholder.
/// Regex metacharacters in literal text are prefixed with a backslash; every
/// other character is copied through unchanged.
fn template_to_pattern(template: &str) -> String {
    const PLACEHOLDER: &str = "[^/]+";

    let mut pattern = String::from("^");
    let mut remaining = template.chars();
    while let Some(ch) = remaining.next() {
        if ch == '{' {
            // Discard the variable name up to and including the closing brace
            // (or to the end of the template when the brace is never closed).
            let _ = remaining.find(|&inner| inner == '}');
            pattern.push_str(PLACEHOLDER);
            continue;
        }
        let needs_escape = matches!(
            ch,
            '.' | '*' | '+' | '?' | '(' | ')' | '[' | ']' | '^' | '$' | '|' | '\\'
        );
        if needs_escape {
            pattern.push('\\');
        }
        pattern.push(ch);
    }
    pattern.push('$');
    pattern
}
190
191fn pattern_match(text: &str, pattern: &str) -> bool {
192 let trimmed = pattern.trim_start_matches('^').trim_end_matches('$');
193 match_inner(text, trimmed)
194}
195
/// Returns `true` when the whole of `text` matches the unanchored `pattern`.
///
/// The pattern language is the small subset emitted by `template_to_pattern`:
/// - `[^/]+` matches one or more characters, none of which is `/`;
/// - `\x` matches the character `x` literally;
/// - any other character matches itself (a lone trailing backslash matches a
///   backslash).
///
/// All slicing is done on character boundaries, so non-ASCII text and
/// non-ASCII escaped characters are matched instead of causing a panic.
/// Literal runs are consumed iteratively; recursion happens only at
/// placeholders, so stack depth does not grow with the length of the text.
fn match_inner(text: &str, pattern: &str) -> bool {
    let mut text = text;
    let mut pattern = pattern;
    loop {
        if pattern.is_empty() {
            return text.is_empty();
        }

        if let Some(rest) = pattern.strip_prefix("[^/]+") {
            // The placeholder may consume any non-empty prefix of the current
            // path segment. Try each prefix, shortest first, cutting only at
            // character boundaries.
            let segment_len = text.find('/').unwrap_or(text.len());
            return text[..segment_len]
                .char_indices()
                .map(|(start, ch)| start + ch.len_utf8())
                .any(|end| match_inner(&text[end..], rest));
        }

        let mut pattern_chars = pattern.chars();
        let literal = match pattern_chars.next() {
            // Escaped character: the one after the backslash is the literal.
            Some('\\') => pattern_chars.next().unwrap_or('\\'),
            Some(ch) => ch,
            None => return text.is_empty(),
        };
        match text.strip_prefix(literal) {
            Some(remaining) => {
                text = remaining;
                pattern = pattern_chars.as_str();
            }
            None => return false,
        }
    }
}
220}