1use std::collections::HashMap;
2
3use crate::{
4 error::McpResult,
5 types::{
6 messages::{
7 CallToolRequest, CompleteRequest, CompleteResult, GetPromptRequest, ListPromptsResult,
8 ListResourcesResult, ListToolsResult, ReadResourceRequest,
9 },
10 prompt::{GetPromptResult, Prompt},
11 resource::{ReadResourceResult, Resource, ResourceTemplate},
12 tool::{CallToolResult, Tool},
13 },
14};
15
16use crate::server::handler::{
17 CompletionHandlerFn, PromptHandlerFn, ResourceHandlerFn, ToolHandlerFn,
18};
19
20pub struct ToolRoute {
23 pub tool: Tool,
24 pub handler: ToolHandlerFn,
25}
26
27pub struct ResourceRoute {
30 pub resource: Resource,
31 pub handler: ResourceHandlerFn,
32}
33
34pub struct ResourceTemplateRoute {
35 pub template: ResourceTemplate,
36 pub handler: ResourceHandlerFn,
37}
38
39pub struct PromptRoute {
42 pub prompt: Prompt,
43 pub handler: PromptHandlerFn,
44 pub completion_handler: Option<CompletionHandlerFn>,
45}
46
47pub struct CompletionRoute {
50 pub handler: CompletionHandlerFn,
51}
52
53#[derive(Default)]
57pub struct Router {
58 tools: HashMap<String, ToolRoute>,
59 resources: HashMap<String, ResourceRoute>,
60 resource_templates: Vec<ResourceTemplateRoute>,
61 prompts: HashMap<String, PromptRoute>,
62 completion_handler: Option<CompletionHandlerFn>,
64 resource_completions: HashMap<String, CompletionHandlerFn>,
66}
67
68impl Router {
69 pub fn new() -> Self {
70 Self {
71 tools: HashMap::new(),
72 resources: HashMap::new(),
73 resource_templates: Vec::new(),
74 prompts: HashMap::new(),
75 completion_handler: None,
76 resource_completions: HashMap::new(),
77 }
78 }
79
80 pub fn add_tool(&mut self, tool: Tool, handler: ToolHandlerFn) {
83 self.tools
84 .insert(tool.name.clone(), ToolRoute { tool, handler });
85 }
86
87 pub fn list_tools(&self, _cursor: Option<&str>) -> ListToolsResult {
88 let tools: Vec<Tool> = self.tools.values().map(|r| r.tool.clone()).collect();
89 ListToolsResult {
90 tools,
91 next_cursor: None,
92 }
93 }
94
95 pub async fn call_tool(&self, req: CallToolRequest) -> McpResult<CallToolResult> {
96 let route = self
97 .tools
98 .get(&req.name)
99 .ok_or_else(|| crate::error::McpError::ToolNotFound(req.name.clone()))?;
100 (route.handler)(req).await
101 }
102
103 pub fn add_resource(&mut self, resource: Resource, handler: ResourceHandlerFn) {
106 self.resources
107 .insert(resource.uri.clone(), ResourceRoute { resource, handler });
108 }
109
110 pub fn add_resource_template(
111 &mut self,
112 template: ResourceTemplate,
113 handler: ResourceHandlerFn,
114 ) {
115 self.resource_templates
116 .push(ResourceTemplateRoute { template, handler });
117 }
118
119 pub fn list_resources(&self, _cursor: Option<&str>) -> ListResourcesResult {
120 let resources: Vec<Resource> = self
121 .resources
122 .values()
123 .map(|r| r.resource.clone())
124 .collect();
125 ListResourcesResult {
126 resources,
127 next_cursor: None,
128 }
129 }
130
131 pub async fn read_resource(&self, req: ReadResourceRequest) -> McpResult<ReadResourceResult> {
132 if let Some(route) = self.resources.get(&req.uri) {
133 return (route.handler)(req).await;
134 }
135 for tpl in &self.resource_templates {
136 if uri_matches_template(&req.uri, &tpl.template.uri_template) {
137 return (tpl.handler)(req).await;
138 }
139 }
140 Err(crate::error::McpError::ResourceNotFound(req.uri))
141 }
142
143 pub fn add_prompt(&mut self, prompt: Prompt, handler: PromptHandlerFn) {
146 self.prompts.insert(
147 prompt.name.clone(),
148 PromptRoute {
149 prompt,
150 handler,
151 completion_handler: None,
152 },
153 );
154 }
155
156 pub fn add_prompt_with_completion(
157 &mut self,
158 prompt: Prompt,
159 handler: PromptHandlerFn,
160 completion_handler: CompletionHandlerFn,
161 ) {
162 self.prompts.insert(
163 prompt.name.clone(),
164 PromptRoute {
165 prompt,
166 handler,
167 completion_handler: Some(completion_handler),
168 },
169 );
170 }
171
172 pub fn list_prompts(&self, _cursor: Option<&str>) -> ListPromptsResult {
173 let prompts: Vec<Prompt> = self.prompts.values().map(|r| r.prompt.clone()).collect();
174 ListPromptsResult {
175 prompts,
176 next_cursor: None,
177 }
178 }
179
180 pub async fn get_prompt(&self, req: GetPromptRequest) -> McpResult<GetPromptResult> {
181 let route = self
182 .prompts
183 .get(&req.name)
184 .ok_or_else(|| crate::error::McpError::PromptNotFound(req.name.clone()))?;
185 (route.handler)(req).await
186 }
187
188 pub fn set_completion_handler(&mut self, handler: CompletionHandlerFn) {
192 self.completion_handler = Some(handler);
193 }
194
195 pub fn add_resource_completion(&mut self, uri_pattern: String, handler: CompletionHandlerFn) {
197 self.resource_completions.insert(uri_pattern, handler);
198 }
199
200 pub async fn complete(&self, req: CompleteRequest) -> McpResult<CompleteResult> {
203 use crate::types::messages::CompletionReference;
204
205 match &req.reference {
206 CompletionReference::Prompt { name } => {
207 if let Some(route) = self.prompts.get(name) {
209 if let Some(handler) = &route.completion_handler {
210 return handler(req).await;
211 }
212 }
213 }
214 CompletionReference::Resource { uri } => {
215 if let Some(handler) = self.resource_completions.get(uri) {
217 return handler(req.clone()).await;
218 }
219 for (pattern, handler) in &self.resource_completions {
221 if uri_matches_template(uri, pattern) {
222 return handler(req.clone()).await;
223 }
224 }
225 }
226 }
227
228 if let Some(handler) = &self.completion_handler {
230 return handler(req).await;
231 }
232
233 Ok(CompleteResult::empty())
235 }
236
237 pub fn has_completions(&self) -> bool {
238 self.completion_handler.is_some()
239 || !self.resource_completions.is_empty()
240 || self
241 .prompts
242 .values()
243 .any(|r| r.completion_handler.is_some())
244 }
245
246 pub fn has_tools(&self) -> bool {
249 !self.tools.is_empty()
250 }
251 pub fn has_resources(&self) -> bool {
252 !self.resources.is_empty() || !self.resource_templates.is_empty()
253 }
254 pub fn has_prompts(&self) -> bool {
255 !self.prompts.is_empty()
256 }
257}
258
259fn uri_matches_template(uri: &str, template: &str) -> bool {
262 let re = template_to_pattern(template);
263 pattern_match(uri, &re)
264}
265
/// Converts a URI template into the anchored pattern understood by `pattern_match`.
///
/// Each `{...}` expression, whatever its contents, becomes `[^/]+`. Characters that
/// carry meaning in regex syntax are backslash-escaped so they match literally. An
/// unterminated `{` swallows the rest of the template and still yields a single
/// `[^/]+`.
fn template_to_pattern(template: &str) -> String {
    let mut pattern = String::with_capacity(template.len() + 2);
    pattern.push('^');
    let mut remaining = template.chars();
    while let Some(ch) = remaining.next() {
        match ch {
            '{' => {
                // Drop the expression body through the closing brace (or to the
                // end of the template when the brace is never closed).
                let _ = remaining.find(|&inner| inner == '}');
                pattern.push_str("[^/]+");
            }
            '.' | '*' | '+' | '?' | '(' | ')' | '[' | ']' | '^' | '$' | '|' | '\\' => {
                pattern.push('\\');
                pattern.push(ch);
            }
            literal => pattern.push(literal),
        }
    }
    pattern.push('$');
    pattern
}
290
291fn pattern_match(text: &str, pattern: &str) -> bool {
292 let trimmed = pattern.trim_start_matches('^').trim_end_matches('$');
293 match_inner(text, trimmed)
294}
295
/// Recursively matches `text` against an unanchored pattern from `template_to_pattern`.
///
/// The pattern language is deliberately tiny:
/// * `[^/]+` matches one or more characters other than `/` (a template variable);
/// * `\x` matches the literal character `x`; a lone trailing `\` matches itself;
/// * any other character matches itself.
///
/// Variable segments are only split at UTF-8 character boundaries. Slicing at
/// arbitrary byte offsets would panic on non-ASCII URIs.
fn match_inner(text: &str, pattern: &str) -> bool {
    if let Some(rest) = pattern.strip_prefix("[^/]+") {
        // A variable consumes at least one character and may not cross a '/', so
        // the candidate split points are the char boundaries inside the first
        // path segment of `text`.
        let segment_len = text.find('/').unwrap_or(text.len());
        return text[..segment_len]
            .char_indices()
            .map(|(start, c)| start + c.len_utf8())
            .any(|end| match_inner(&text[end..], rest));
    }

    let mut pattern_chars = pattern.chars();
    let expected = match pattern_chars.next() {
        // An exhausted pattern matches only exhausted text.
        None => return text.is_empty(),
        Some('\\') => pattern_chars.next().unwrap_or('\\'),
        Some(c) => c,
    };
    match text.strip_prefix(expected) {
        Some(remaining) => match_inner(remaining, pattern_chars.as_str()),
        None => false,
    }
}