// mcp_kit/server/router.rs

1use std::collections::HashMap;
2
3use crate::{
4    error::McpResult,
5    types::{
6        messages::{
7            CallToolRequest, CompleteRequest, CompleteResult, GetPromptRequest, ListPromptsResult,
8            ListResourcesResult, ListToolsResult, ReadResourceRequest,
9        },
10        prompt::{GetPromptResult, Prompt},
11        resource::{ReadResourceResult, Resource, ResourceTemplate},
12        tool::{CallToolResult, Tool},
13    },
14};
15
16use crate::server::handler::{
17    CompletionHandlerFn, PromptHandlerFn, ResourceHandlerFn, ToolHandlerFn,
18};
19
20// ─── Tool route ───────────────────────────────────────────────────────────────
21
/// A registered tool: its definition plus the handler that serves calls to it.
///
/// `Router` stores these keyed by `tool.name`.
pub struct ToolRoute {
    /// Tool definition; cloned into the result of `Router::list_tools`.
    pub tool: Tool,
    /// Handler that `Router::call_tool` invokes with the incoming request.
    pub handler: ToolHandlerFn,
}
26
27// ─── Resource route ───────────────────────────────────────────────────────────
28
/// A registered concrete resource: its definition plus the handler that reads it.
///
/// `Router` stores these keyed by `resource.uri`.
pub struct ResourceRoute {
    /// Resource definition; cloned into the result of `Router::list_resources`.
    pub resource: Resource,
    /// Handler that `Router::read_resource` invokes when the request URI
    /// equals this resource's URI exactly.
    pub handler: ResourceHandlerFn,
}
33
/// A registered resource template: a URI template plus the handler that
/// serves every URI matching it.
pub struct ResourceTemplateRoute {
    /// Template definition; `Router::read_resource` matches request URIs
    /// against its `uri_template`.
    pub template: ResourceTemplate,
    /// Handler invoked with the request when the template matches.
    pub handler: ResourceHandlerFn,
}
38
39// ─── Prompt route ─────────────────────────────────────────────────────────────
40
/// A registered prompt: its definition, the handler that renders it, and an
/// optional handler for completion requests that reference it.
///
/// `Router` stores these keyed by `prompt.name`.
pub struct PromptRoute {
    /// Prompt definition; cloned into the result of `Router::list_prompts`.
    pub prompt: Prompt,
    /// Handler that `Router::get_prompt` invokes with the incoming request.
    pub handler: PromptHandlerFn,
    /// Prompt-specific completion handler. `Router::complete` uses it for
    /// completion requests whose reference names this prompt; `None` means
    /// such requests fall through to the global handler.
    pub completion_handler: Option<CompletionHandlerFn>,
}
46
47// ─── Completion route ─────────────────────────────────────────────────────────
48
/// Wrapper around a single completion handler.
///
/// NOTE(review): nothing in this file constructs or reads this type; it is
/// presumably used elsewhere in the crate — confirm before changing it.
pub struct CompletionRoute {
    /// The wrapped completion handler.
    pub handler: CompletionHandlerFn,
}
52
53// ─── Router ───────────────────────────────────────────────────────────────────
54
/// Central routing table — maps tool names, resource URIs / URI templates and
/// prompt names to their registered handlers, and holds the completion
/// handlers used by `Router::complete`.
#[derive(Default)]
pub struct Router {
    /// Tool routes keyed by tool name.
    tools: HashMap<String, ToolRoute>,
    /// Concrete resource routes keyed by exact resource URI.
    resources: HashMap<String, ResourceRoute>,
    /// Resource template routes, kept in registration order; the first
    /// matching template wins when reading a resource.
    resource_templates: Vec<ResourceTemplateRoute>,
    /// Prompt routes keyed by prompt name.
    prompts: HashMap<String, PromptRoute>,
    /// Global completion handler: the fallback used when no prompt- or
    /// resource-specific handler applies to a completion request.
    completion_handler: Option<CompletionHandlerFn>,
    /// Resource-specific completion handlers keyed by URI or URI template.
    resource_completions: HashMap<String, CompletionHandlerFn>,
}
67
68impl Router {
69    pub fn new() -> Self {
70        Self {
71            tools: HashMap::new(),
72            resources: HashMap::new(),
73            resource_templates: Vec::new(),
74            prompts: HashMap::new(),
75            completion_handler: None,
76            resource_completions: HashMap::new(),
77        }
78    }
79
80    // ── Tool registration ─────────────────────────────────────────────────────
81
82    pub fn add_tool(&mut self, tool: Tool, handler: ToolHandlerFn) {
83        self.tools
84            .insert(tool.name.clone(), ToolRoute { tool, handler });
85    }
86
87    pub fn list_tools(&self, _cursor: Option<&str>) -> ListToolsResult {
88        let tools: Vec<Tool> = self.tools.values().map(|r| r.tool.clone()).collect();
89        ListToolsResult {
90            tools,
91            next_cursor: None,
92        }
93    }
94
95    pub async fn call_tool(&self, req: CallToolRequest) -> McpResult<CallToolResult> {
96        let route = self
97            .tools
98            .get(&req.name)
99            .ok_or_else(|| crate::error::McpError::ToolNotFound(req.name.clone()))?;
100        (route.handler)(req).await
101    }
102
103    // ── Resource registration ─────────────────────────────────────────────────
104
105    pub fn add_resource(&mut self, resource: Resource, handler: ResourceHandlerFn) {
106        self.resources
107            .insert(resource.uri.clone(), ResourceRoute { resource, handler });
108    }
109
110    pub fn add_resource_template(
111        &mut self,
112        template: ResourceTemplate,
113        handler: ResourceHandlerFn,
114    ) {
115        self.resource_templates
116            .push(ResourceTemplateRoute { template, handler });
117    }
118
119    pub fn list_resources(&self, _cursor: Option<&str>) -> ListResourcesResult {
120        let resources: Vec<Resource> = self
121            .resources
122            .values()
123            .map(|r| r.resource.clone())
124            .collect();
125        ListResourcesResult {
126            resources,
127            next_cursor: None,
128        }
129    }
130
131    pub async fn read_resource(&self, req: ReadResourceRequest) -> McpResult<ReadResourceResult> {
132        if let Some(route) = self.resources.get(&req.uri) {
133            return (route.handler)(req).await;
134        }
135        for tpl in &self.resource_templates {
136            if uri_matches_template(&req.uri, &tpl.template.uri_template) {
137                return (tpl.handler)(req).await;
138            }
139        }
140        Err(crate::error::McpError::ResourceNotFound(req.uri))
141    }
142
143    // ── Prompt registration ───────────────────────────────────────────────────
144
145    pub fn add_prompt(&mut self, prompt: Prompt, handler: PromptHandlerFn) {
146        self.prompts.insert(
147            prompt.name.clone(),
148            PromptRoute {
149                prompt,
150                handler,
151                completion_handler: None,
152            },
153        );
154    }
155
156    pub fn add_prompt_with_completion(
157        &mut self,
158        prompt: Prompt,
159        handler: PromptHandlerFn,
160        completion_handler: CompletionHandlerFn,
161    ) {
162        self.prompts.insert(
163            prompt.name.clone(),
164            PromptRoute {
165                prompt,
166                handler,
167                completion_handler: Some(completion_handler),
168            },
169        );
170    }
171
172    pub fn list_prompts(&self, _cursor: Option<&str>) -> ListPromptsResult {
173        let prompts: Vec<Prompt> = self.prompts.values().map(|r| r.prompt.clone()).collect();
174        ListPromptsResult {
175            prompts,
176            next_cursor: None,
177        }
178    }
179
180    pub async fn get_prompt(&self, req: GetPromptRequest) -> McpResult<GetPromptResult> {
181        let route = self
182            .prompts
183            .get(&req.name)
184            .ok_or_else(|| crate::error::McpError::PromptNotFound(req.name.clone()))?;
185        (route.handler)(req).await
186    }
187
188    // ── Completion registration ───────────────────────────────────────────────
189
190    /// Set a global completion handler that handles all completion requests.
191    pub fn set_completion_handler(&mut self, handler: CompletionHandlerFn) {
192        self.completion_handler = Some(handler);
193    }
194
195    /// Set a completion handler for a specific resource URI pattern.
196    pub fn add_resource_completion(&mut self, uri_pattern: String, handler: CompletionHandlerFn) {
197        self.resource_completions.insert(uri_pattern, handler);
198    }
199
200    /// Handle a completion request.
201    /// Priority: prompt-specific → resource-specific → global → empty result
202    pub async fn complete(&self, req: CompleteRequest) -> McpResult<CompleteResult> {
203        use crate::types::messages::CompletionReference;
204
205        match &req.reference {
206            CompletionReference::Prompt { name } => {
207                // Check prompt-specific completion handler
208                if let Some(route) = self.prompts.get(name) {
209                    if let Some(handler) = &route.completion_handler {
210                        return handler(req).await;
211                    }
212                }
213            }
214            CompletionReference::Resource { uri } => {
215                // Check resource-specific completion handler
216                if let Some(handler) = self.resource_completions.get(uri) {
217                    return handler(req.clone()).await;
218                }
219                // Check template matches
220                for (pattern, handler) in &self.resource_completions {
221                    if uri_matches_template(uri, pattern) {
222                        return handler(req.clone()).await;
223                    }
224                }
225            }
226        }
227
228        // Fall back to global handler
229        if let Some(handler) = &self.completion_handler {
230            return handler(req).await;
231        }
232
233        // Default: empty completion
234        Ok(CompleteResult::empty())
235    }
236
237    pub fn has_completions(&self) -> bool {
238        self.completion_handler.is_some()
239            || !self.resource_completions.is_empty()
240            || self
241                .prompts
242                .values()
243                .any(|r| r.completion_handler.is_some())
244    }
245
246    // ── Capability introspection ──────────────────────────────────────────────
247
248    pub fn has_tools(&self) -> bool {
249        !self.tools.is_empty()
250    }
251    pub fn has_resources(&self) -> bool {
252        !self.resources.is_empty() || !self.resource_templates.is_empty()
253    }
254    pub fn has_prompts(&self) -> bool {
255        !self.prompts.is_empty()
256    }
257}
258
259// ─── URI template matching ────────────────────────────────────────────────────
260
261fn uri_matches_template(uri: &str, template: &str) -> bool {
262    let re = template_to_pattern(template);
263    pattern_match(uri, &re)
264}
265
/// Compiles a URI template into an anchored, regex-like pattern string.
///
/// * Each `{…}` placeholder becomes `[^/]+`; an unclosed `{` swallows the
///   rest of the template.
/// * The characters `. * + ? ( ) [ ] ^ $ | \` are escaped with a backslash.
/// * Every other character is copied through unchanged.
///
/// The result is wrapped in `^` … `$`.
fn template_to_pattern(template: &str) -> String {
    // Appends `literal` to `out`, escaping characters with special meaning.
    fn push_literal(out: &mut String, literal: &str) {
        for ch in literal.chars() {
            if matches!(
                ch,
                '.' | '*' | '+' | '?' | '(' | ')' | '[' | ']' | '^' | '$' | '|' | '\\'
            ) {
                out.push('\\');
            }
            out.push(ch);
        }
    }

    let mut pattern = String::from("^");
    let mut remaining = template;

    // Walk the template one placeholder at a time: literal text up to the
    // next `{`, then the placeholder itself.
    while let Some(open) = remaining.find('{') {
        push_literal(&mut pattern, &remaining[..open]);
        pattern.push_str("[^/]+");

        let after_open = &remaining[open + 1..];
        remaining = match after_open.find('}') {
            Some(close) => &after_open[close + 1..],
            None => "",
        };
    }

    push_literal(&mut pattern, remaining);
    pattern.push('$');
    pattern
}
290
/// Matches `text` in full against a pattern produced by `template_to_pattern`.
///
/// At most one leading `^` and one trailing `$` anchor are removed before
/// matching. A trailing `$` counts as an anchor only when it is not escaped,
/// i.e. when it is preceded by an even number of backslashes, so a pattern
/// ending in the literal `\$` keeps its dollar sign.
fn pattern_match(text: &str, pattern: &str) -> bool {
    let body = pattern.strip_prefix('^').unwrap_or(pattern);
    let body = match body.strip_suffix('$') {
        Some(head)
            if head.bytes().rev().take_while(|&b| b == b'\\').count() % 2 == 0 =>
        {
            head
        }
        // Either there is no trailing `$`, or it is the escaped literal `\$`.
        _ => body,
    };
    match_inner(text, body)
}

/// Matches `text` in full against an unanchored pattern.
///
/// The pattern language has three elements:
/// * `[^/]+` — one or more characters, none of which is `/`;
/// * `\x` — the literal character `x`;
/// * any other character — itself.
///
/// Literal runs are consumed in a loop, so recursion only happens at
/// placeholders. All slicing is done on character boundaries, so non-ASCII
/// input is safe.
fn match_inner(text: &str, pattern: &str) -> bool {
    let mut text = text;
    let mut pattern = pattern;

    loop {
        if let Some(rest) = pattern.strip_prefix("[^/]+") {
            // The placeholder may consume any non-empty prefix of the current
            // path segment; try every prefix that ends on a char boundary.
            let segment_len = text.find('/').unwrap_or(text.len());
            return text[..segment_len]
                .char_indices()
                .map(|(start, ch)| start + ch.len_utf8())
                .any(|end| match_inner(&text[end..], rest));
        }

        let mut chars = pattern.chars();
        let literal = match chars.next() {
            // Pattern exhausted: succeed only if the text is exhausted too.
            None => return text.is_empty(),
            // Escaped character; a lone trailing backslash matches itself.
            Some('\\') => chars.next().unwrap_or('\\'),
            Some(ch) => ch,
        };

        match text.strip_prefix(literal) {
            Some(tail) => {
                text = tail;
                pattern = chars.as_str();
            }
            None => return false,
        }
    }
}
320}