Skip to main content

mcp_kit/server/
router.rs

1use std::collections::HashMap;
2
3use crate::{
4    error::McpResult,
5    types::{
6        messages::{
7            CallToolRequest, GetPromptRequest, ListPromptsResult, ListResourcesResult,
8            ListToolsResult, ReadResourceRequest,
9        },
10        prompt::{GetPromptResult, Prompt},
11        resource::{ReadResourceResult, Resource, ResourceTemplate},
12        tool::{CallToolResult, Tool},
13    },
14};
15
16use crate::server::handler::{PromptHandlerFn, ResourceHandlerFn, ToolHandlerFn};
17
18// ─── Tool route ───────────────────────────────────────────────────────────────
19
20pub struct ToolRoute {
21    pub tool: Tool,
22    pub handler: ToolHandlerFn,
23}
24
25// ─── Resource route ───────────────────────────────────────────────────────────
26
27pub struct ResourceRoute {
28    pub resource: Resource,
29    pub handler: ResourceHandlerFn,
30}
31
32pub struct ResourceTemplateRoute {
33    pub template: ResourceTemplate,
34    pub handler: ResourceHandlerFn,
35}
36
37// ─── Prompt route ─────────────────────────────────────────────────────────────
38
39pub struct PromptRoute {
40    pub prompt: Prompt,
41    pub handler: PromptHandlerFn,
42}
43
44// ─── Router ───────────────────────────────────────────────────────────────────
45
46/// Central routing table — maps method names to handlers.
47#[derive(Default)]
48pub struct Router {
49    tools: HashMap<String, ToolRoute>,
50    resources: HashMap<String, ResourceRoute>,
51    resource_templates: Vec<ResourceTemplateRoute>,
52    prompts: HashMap<String, PromptRoute>,
53}
54
55impl Router {
56    pub fn new() -> Self {
57        Self::default()
58    }
59
60    // ── Tool registration ─────────────────────────────────────────────────────
61
62    pub fn add_tool(&mut self, tool: Tool, handler: ToolHandlerFn) {
63        self.tools
64            .insert(tool.name.clone(), ToolRoute { tool, handler });
65    }
66
67    pub fn list_tools(&self, _cursor: Option<&str>) -> ListToolsResult {
68        let tools: Vec<Tool> = self.tools.values().map(|r| r.tool.clone()).collect();
69        ListToolsResult {
70            tools,
71            next_cursor: None,
72        }
73    }
74
75    pub async fn call_tool(&self, req: CallToolRequest) -> McpResult<CallToolResult> {
76        let route = self
77            .tools
78            .get(&req.name)
79            .ok_or_else(|| crate::error::McpError::ToolNotFound(req.name.clone()))?;
80        (route.handler)(req).await
81    }
82
83    // ── Resource registration ─────────────────────────────────────────────────
84
85    pub fn add_resource(&mut self, resource: Resource, handler: ResourceHandlerFn) {
86        self.resources
87            .insert(resource.uri.clone(), ResourceRoute { resource, handler });
88    }
89
90    pub fn add_resource_template(
91        &mut self,
92        template: ResourceTemplate,
93        handler: ResourceHandlerFn,
94    ) {
95        self.resource_templates
96            .push(ResourceTemplateRoute { template, handler });
97    }
98
99    pub fn list_resources(&self, _cursor: Option<&str>) -> ListResourcesResult {
100        let resources: Vec<Resource> = self
101            .resources
102            .values()
103            .map(|r| r.resource.clone())
104            .collect();
105        ListResourcesResult {
106            resources,
107            next_cursor: None,
108        }
109    }
110
111    pub async fn read_resource(&self, req: ReadResourceRequest) -> McpResult<ReadResourceResult> {
112        if let Some(route) = self.resources.get(&req.uri) {
113            return (route.handler)(req).await;
114        }
115        for tpl in &self.resource_templates {
116            if uri_matches_template(&req.uri, &tpl.template.uri_template) {
117                return (tpl.handler)(req).await;
118            }
119        }
120        Err(crate::error::McpError::ResourceNotFound(req.uri))
121    }
122
123    // ── Prompt registration ───────────────────────────────────────────────────
124
125    pub fn add_prompt(&mut self, prompt: Prompt, handler: PromptHandlerFn) {
126        self.prompts
127            .insert(prompt.name.clone(), PromptRoute { prompt, handler });
128    }
129
130    pub fn list_prompts(&self, _cursor: Option<&str>) -> ListPromptsResult {
131        let prompts: Vec<Prompt> = self.prompts.values().map(|r| r.prompt.clone()).collect();
132        ListPromptsResult {
133            prompts,
134            next_cursor: None,
135        }
136    }
137
138    pub async fn get_prompt(&self, req: GetPromptRequest) -> McpResult<GetPromptResult> {
139        let route = self
140            .prompts
141            .get(&req.name)
142            .ok_or_else(|| crate::error::McpError::PromptNotFound(req.name.clone()))?;
143        (route.handler)(req).await
144    }
145
146    // ── Capability introspection ──────────────────────────────────────────────
147
148    pub fn has_tools(&self) -> bool {
149        !self.tools.is_empty()
150    }
151    pub fn has_resources(&self) -> bool {
152        !self.resources.is_empty() || !self.resource_templates.is_empty()
153    }
154    pub fn has_prompts(&self) -> bool {
155        !self.prompts.is_empty()
156    }
157}
158
159// ─── URI template matching ────────────────────────────────────────────────────
160
161fn uri_matches_template(uri: &str, template: &str) -> bool {
162    let re = template_to_pattern(template);
163    pattern_match(uri, &re)
164}
165
/// Lowers a URI template into the anchored mini-pattern consumed by `pattern_match`.
///
/// * Each `{...}` expression, whatever its contents, becomes the placeholder
///   `[^/]+`. An unterminated `{` swallows the remainder of the template.
/// * Regex metacharacters are prefixed with a backslash so they match literally.
/// * All other characters (including `}` outside an expression) are copied as-is.
/// * The result is wrapped in `^` ... `$`.
fn template_to_pattern(template: &str) -> String {
    const META: &[char] = &['.', '*', '+', '?', '(', ')', '[', ']', '^', '$', '|', '\\'];

    let mut out = String::with_capacity(template.len() + 2);
    out.push('^');

    let mut remaining = template.chars();
    while let Some(ch) = remaining.next() {
        match ch {
            '{' => {
                // Discard the expression body up to and including the closing brace.
                let _ = remaining.find(|&c| c == '}');
                out.push_str("[^/]+");
            }
            c if META.contains(&c) => {
                out.push('\\');
                out.push(c);
            }
            c => out.push(c),
        }
    }

    out.push('$');
    out
}
190
191fn pattern_match(text: &str, pattern: &str) -> bool {
192    let trimmed = pattern.trim_start_matches('^').trim_end_matches('$');
193    match_inner(text, trimmed)
194}
195
/// Recursively matches all of `text` against an un-anchored mini-pattern.
///
/// The pattern dialect (as produced by `template_to_pattern`) has three forms:
/// * `[^/]+` — one or more characters, none of which is `/`;
/// * `\c`    — the literal character `c`;
/// * any other character — itself, literally.
///
/// Returns `true` only if the entire `text` is consumed.
///
/// NOTE(review): placeholders backtrack over every possible length, so several
/// adjacent placeholders within one path segment cost polynomially more in the
/// segment length.
fn match_inner(text: &str, pattern: &str) -> bool {
    if pattern.is_empty() {
        return text.is_empty();
    }

    if let Some(rest) = pattern.strip_prefix("[^/]+") {
        // The placeholder may only consume characters within the current segment.
        let segment_end = text.find('/').unwrap_or(text.len());
        let segment = &text[..segment_end];
        // Try every non-empty prefix of the segment, ending on char boundaries.
        // Stepping byte-by-byte would panic when slicing through a multi-byte
        // UTF-8 character (e.g. `héllo`), and the URI comes from the client.
        // An empty segment yields no candidates, so the placeholder fails.
        return segment
            .char_indices()
            .map(|(i, c)| i + c.len_utf8())
            .any(|end| match_inner(&text[end..], rest));
    }

    // A backslash escapes the next character; a lone trailing backslash is literal.
    let literal = match pattern.strip_prefix('\\') {
        Some(escaped) if !escaped.is_empty() => escaped,
        _ => pattern,
    };
    let mut chars = literal.chars();
    match chars.next() {
        Some(expected) => match text.strip_prefix(expected) {
            Some(remaining) => match_inner(remaining, chars.as_str()),
            None => false,
        },
        // Unreachable: `pattern` is non-empty, so `literal` is too.
        None => text.is_empty(),
    }
}