mcpkit_server/
router.rs

1//! Request routing for MCP servers.
2//!
3//! This module provides the routing infrastructure that dispatches
4//! incoming requests to the appropriate handler methods.
5//!
6//! # MCP Method Categories
7//!
8//! - **Initialization**: `initialize`, `ping`
9//! - **Tools**: `tools/list`, `tools/call`
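//! - **Resources**: `resources/list`, `resources/read`, `resources/templates/list`,
//!   `resources/subscribe`, `resources/unsubscribe`
//! - **Prompts**: `prompts/list`, `prompts/get`
//! - **Tasks**: `tasks/list`, `tasks/get`, `tasks/cancel`
//! - **Sampling**: `sampling/createMessage`
//! - **Completions**: `completion/complete`
//! - **Logging**: `logging/setLevel`
//! - **Elicitation**: `elicitation/create` (name constant only; `parse_request` has no
//!   dedicated arm for it)
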
15
16use mcpkit_core::error::McpError;
17use mcpkit_core::protocol::Request;
18use serde_json::Value;
19
/// Wire-level names of the MCP request methods, as given by the MCP specification.
///
/// Each constant holds the exact string that appears in a request's `method` field.
pub mod methods {
    // ---- Initialization -------------------------------------------------

    /// Opens the connection and negotiates capabilities.
    pub const INITIALIZE: &str = "initialize";
    /// Liveness check for the connection.
    pub const PING: &str = "ping";

    // ---- Tools ----------------------------------------------------------

    /// Lists the tools that are available.
    pub const TOOLS_LIST: &str = "tools/list";
    /// Calls one tool with the supplied arguments.
    pub const TOOLS_CALL: &str = "tools/call";

    // ---- Resources ------------------------------------------------------

    /// Lists the resources that are available.
    pub const RESOURCES_LIST: &str = "resources/list";
    /// Reads the contents of one resource.
    pub const RESOURCES_READ: &str = "resources/read";
    /// Lists the resource templates that are available.
    pub const RESOURCES_TEMPLATES_LIST: &str = "resources/templates/list";
    /// Subscribes to updates for a resource.
    pub const RESOURCES_SUBSCRIBE: &str = "resources/subscribe";
    /// Cancels a subscription to resource updates.
    pub const RESOURCES_UNSUBSCRIBE: &str = "resources/unsubscribe";

    // ---- Prompts --------------------------------------------------------

    /// Lists the prompts that are available.
    pub const PROMPTS_LIST: &str = "prompts/list";
    /// Fetches one prompt, with arguments.
    pub const PROMPTS_GET: &str = "prompts/get";

    // ---- Tasks ----------------------------------------------------------

    /// Lists running tasks.
    pub const TASKS_LIST: &str = "tasks/list";
    /// Fetches the status of one task.
    pub const TASKS_GET: &str = "tasks/get";
    /// Cancels a running task.
    pub const TASKS_CANCEL: &str = "tasks/cancel";

    // ---- Sampling -------------------------------------------------------

    /// Asks the client to sample from a language model.
    pub const SAMPLING_CREATE_MESSAGE: &str = "sampling/createMessage";

    // ---- Completions ----------------------------------------------------

    /// Asks for completion suggestions.
    pub const COMPLETION_COMPLETE: &str = "completion/complete";

    // ---- Logging --------------------------------------------------------

    /// Sets the logging level.
    pub const LOGGING_SET_LEVEL: &str = "logging/setLevel";

    // ---- Elicitation ----------------------------------------------------

    /// Creates an elicitation request.
    pub const ELICITATION_CREATE: &str = "elicitation/create";
}
67
/// Wire-level names of the MCP notifications, as given by the MCP specification.
///
/// Each constant holds the exact string that appears in a notification's `method` field.
pub mod notifications {
    // ---- Lifecycle and bookkeeping --------------------------------------

    /// Client-sent once initialization has succeeded.
    pub const INITIALIZED: &str = "notifications/initialized";
    /// Signals that a request was cancelled.
    pub const CANCELLED: &str = "notifications/cancelled";
    /// Progress report for a long-running operation.
    pub const PROGRESS: &str = "notifications/progress";
    /// Carries a log message.
    pub const MESSAGE: &str = "notifications/message";

    // ---- Change notifications -------------------------------------------

    /// The content of a resource changed.
    pub const RESOURCES_UPDATED: &str = "notifications/resources/updated";
    /// The set of available resources changed.
    pub const RESOURCES_LIST_CHANGED: &str = "notifications/resources/list_changed";
    /// The set of available tools changed.
    pub const TOOLS_LIST_CHANGED: &str = "notifications/tools/list_changed";
    /// The set of available prompts changed.
    pub const PROMPTS_LIST_CHANGED: &str = "notifications/prompts/list_changed";
}
87
88/// Represents a parsed MCP request with typed parameters.
89///
90/// This enum provides a type-safe representation of all MCP request types,
91/// with parameters parsed into their appropriate structures.
92#[derive(Debug)]
93pub enum ParsedRequest {
94    /// Initialize request to establish connection.
95    Initialize(InitializeParams),
96    /// Ping request to check connection health.
97    Ping,
98
99    /// Request to list available tools.
100    ToolsList(ListParams),
101    /// Request to call a specific tool.
102    ToolsCall(ToolCallParams),
103
104    /// Request to list available resources.
105    ResourcesList(ListParams),
106    /// Request to read a resource's contents.
107    ResourcesRead(ResourceReadParams),
108    /// Request to list resource templates.
109    ResourcesTemplatesList(ListParams),
110    /// Request to subscribe to resource updates.
111    ResourcesSubscribe(ResourceSubscribeParams),
112    /// Request to unsubscribe from resource updates.
113    ResourcesUnsubscribe(ResourceUnsubscribeParams),
114
115    /// Request to list available prompts.
116    PromptsList(ListParams),
117    /// Request to get a specific prompt.
118    PromptsGet(PromptGetParams),
119
120    /// Request to list running tasks.
121    TasksList(ListParams),
122    /// Request to get a task's status.
123    TasksGet(TaskGetParams),
124    /// Request to cancel a running task.
125    TasksCancel(TaskCancelParams),
126
127    /// Request for the client to sample from a language model.
128    SamplingCreateMessage(SamplingParams),
129
130    /// Request for completion suggestions.
131    CompletionComplete(CompletionParams),
132
133    /// Request to set the logging level.
134    LoggingSetLevel(LogLevelParams),
135
136    /// An unrecognized method name.
137    Unknown(String),
138}
139
/// Parameters shared by the list-style requests: an optional pagination cursor.
///
/// `Default` yields a value with no cursor. `Clone`, `PartialEq` and `Eq` are
/// derived so callers can copy parsed parameters and compare them directly.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct ListParams {
    /// Pagination cursor, when the request supplied one.
    pub cursor: Option<String>,
}
146
147/// Initialize request parameters.
148#[derive(Debug)]
149pub struct InitializeParams {
150    /// The protocol version requested by the client.
151    pub protocol_version: String,
152    /// Information about the client.
153    pub client_info: ClientInfo,
154    /// Client capabilities.
155    pub capabilities: Value,
156}
157
158/// Client info from initialize request.
159#[derive(Debug)]
160pub struct ClientInfo {
161    /// The name of the client application.
162    pub name: String,
163    /// The version of the client application.
164    pub version: String,
165}
166
167/// Tool call parameters.
168#[derive(Debug)]
169pub struct ToolCallParams {
170    /// The name of the tool to call.
171    pub name: String,
172    /// Arguments to pass to the tool.
173    pub arguments: Value,
174}
175
/// Parameters of a `resources/read` request.
///
/// `Clone`, `PartialEq` and `Eq` are derived so callers can copy and compare
/// parsed values directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ResourceReadParams {
    /// URI of the resource to read.
    pub uri: String,
}
182
/// Parameters of a `resources/subscribe` request.
///
/// `Clone`, `PartialEq` and `Eq` are derived so callers can copy and compare
/// parsed values directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ResourceSubscribeParams {
    /// URI of the resource to subscribe to.
    pub uri: String,
}
189
/// Parameters of a `resources/unsubscribe` request.
///
/// `Clone`, `PartialEq` and `Eq` are derived so callers can copy and compare
/// parsed values directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ResourceUnsubscribeParams {
    /// URI of the resource to unsubscribe from.
    pub uri: String,
}
196
197/// Prompt get parameters.
198#[derive(Debug)]
199pub struct PromptGetParams {
200    /// The name of the prompt to get.
201    pub name: String,
202    /// Optional arguments to pass to the prompt.
203    pub arguments: Option<Value>,
204}
205
/// Parameters of a `tasks/get` request.
///
/// `Clone`, `PartialEq` and `Eq` are derived so callers can copy and compare
/// parsed values directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TaskGetParams {
    /// ID of the task whose status is requested.
    pub task_id: String,
}
212
/// Parameters of a `tasks/cancel` request.
///
/// `Clone`, `PartialEq` and `Eq` are derived so callers can copy and compare
/// parsed values directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TaskCancelParams {
    /// ID of the task to cancel.
    pub task_id: String,
}
219
220/// Sampling create message parameters.
221#[derive(Debug)]
222pub struct SamplingParams {
223    /// The messages to sample from.
224    pub messages: Vec<Value>,
225    /// Optional model preferences.
226    pub model_preferences: Option<Value>,
227    /// Optional system prompt.
228    pub system_prompt: Option<String>,
229    /// Optional maximum number of tokens.
230    pub max_tokens: Option<u32>,
231}
232
/// Parameters of a `completion/complete` request.
///
/// `Clone`, `PartialEq` and `Eq` are derived so callers can copy and compare
/// parsed values directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompletionParams {
    /// Kind of reference being completed (e.g., "ref/resource", "ref/prompt").
    pub ref_type: String,
    /// Value of the reference: a URI or a name.
    pub ref_value: String,
    /// Argument that gives context for the completion, when supplied.
    pub argument: Option<CompletionArgument>,
}

/// The argument being completed, giving context for a completion request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompletionArgument {
    /// Name of the argument.
    pub name: String,
    /// Current (partial) value being completed.
    pub value: String,
}
252
/// Parameters of a `logging/setLevel` request.
///
/// `Clone`, `PartialEq` and `Eq` are derived so callers can copy and compare
/// parsed values directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LogLevelParams {
    /// The log level to set (e.g., "debug", "info", "warning", "error").
    ///
    /// Held as a plain string; this type does not restrict it to known level names.
    pub level: String,
}
259
260/// Parse a request into a typed representation.
261pub fn parse_request(request: &Request) -> Result<ParsedRequest, McpError> {
262    let method = request.method.as_ref();
263    let params = request.params.as_ref();
264
265    match method {
266        methods::INITIALIZE => {
267            let params =
268                params.ok_or_else(|| McpError::invalid_params(method, "missing params"))?;
269
270            Ok(ParsedRequest::Initialize(InitializeParams {
271                protocol_version: params
272                    .get("protocolVersion")
273                    .and_then(|v| v.as_str())
274                    .unwrap_or("unknown")
275                    .to_string(),
276                client_info: ClientInfo {
277                    name: params
278                        .get("clientInfo")
279                        .and_then(|v| v.get("name"))
280                        .and_then(|v| v.as_str())
281                        .unwrap_or("unknown")
282                        .to_string(),
283                    version: params
284                        .get("clientInfo")
285                        .and_then(|v| v.get("version"))
286                        .and_then(|v| v.as_str())
287                        .unwrap_or("unknown")
288                        .to_string(),
289                },
290                capabilities: params
291                    .get("capabilities")
292                    .cloned()
293                    .unwrap_or_else(|| Value::Object(serde_json::Map::new())),
294            }))
295        }
296
297        methods::PING => Ok(ParsedRequest::Ping),
298
299        methods::TOOLS_LIST => Ok(ParsedRequest::ToolsList(parse_list_params(params))),
300
301        methods::TOOLS_CALL => {
302            let params =
303                params.ok_or_else(|| McpError::invalid_params(method, "missing params"))?;
304
305            let name = params
306                .get("name")
307                .and_then(|v| v.as_str())
308                .ok_or_else(|| McpError::invalid_params(method, "missing name"))?
309                .to_string();
310
311            let arguments = params
312                .get("arguments")
313                .cloned()
314                .unwrap_or_else(|| Value::Object(serde_json::Map::new()));
315
316            Ok(ParsedRequest::ToolsCall(ToolCallParams { name, arguments }))
317        }
318
319        methods::RESOURCES_LIST => Ok(ParsedRequest::ResourcesList(parse_list_params(params))),
320
321        methods::RESOURCES_READ => {
322            let params =
323                params.ok_or_else(|| McpError::invalid_params(method, "missing params"))?;
324
325            let uri = params
326                .get("uri")
327                .and_then(|v| v.as_str())
328                .ok_or_else(|| McpError::invalid_params(method, "missing uri"))?
329                .to_string();
330
331            Ok(ParsedRequest::ResourcesRead(ResourceReadParams { uri }))
332        }
333
334        methods::RESOURCES_TEMPLATES_LIST => Ok(ParsedRequest::ResourcesTemplatesList(
335            parse_list_params(params),
336        )),
337
338        methods::RESOURCES_SUBSCRIBE => {
339            let params =
340                params.ok_or_else(|| McpError::invalid_params(method, "missing params"))?;
341
342            let uri = params
343                .get("uri")
344                .and_then(|v| v.as_str())
345                .ok_or_else(|| McpError::invalid_params(method, "missing uri"))?
346                .to_string();
347
348            Ok(ParsedRequest::ResourcesSubscribe(ResourceSubscribeParams {
349                uri,
350            }))
351        }
352
353        methods::RESOURCES_UNSUBSCRIBE => {
354            let params =
355                params.ok_or_else(|| McpError::invalid_params(method, "missing params"))?;
356
357            let uri = params
358                .get("uri")
359                .and_then(|v| v.as_str())
360                .ok_or_else(|| McpError::invalid_params(method, "missing uri"))?
361                .to_string();
362
363            Ok(ParsedRequest::ResourcesUnsubscribe(
364                ResourceUnsubscribeParams { uri },
365            ))
366        }
367
368        methods::PROMPTS_LIST => Ok(ParsedRequest::PromptsList(parse_list_params(params))),
369
370        methods::PROMPTS_GET => {
371            let params =
372                params.ok_or_else(|| McpError::invalid_params(method, "missing params"))?;
373
374            let name = params
375                .get("name")
376                .and_then(|v| v.as_str())
377                .ok_or_else(|| McpError::invalid_params(method, "missing name"))?
378                .to_string();
379
380            let arguments = params.get("arguments").cloned();
381
382            Ok(ParsedRequest::PromptsGet(PromptGetParams {
383                name,
384                arguments,
385            }))
386        }
387
388        methods::TASKS_LIST => Ok(ParsedRequest::TasksList(parse_list_params(params))),
389
390        methods::TASKS_GET => {
391            let params =
392                params.ok_or_else(|| McpError::invalid_params(method, "missing params"))?;
393
394            let task_id = params
395                .get("taskId")
396                .and_then(|v| v.as_str())
397                .ok_or_else(|| McpError::invalid_params(method, "missing taskId"))?
398                .to_string();
399
400            Ok(ParsedRequest::TasksGet(TaskGetParams { task_id }))
401        }
402
403        methods::TASKS_CANCEL => {
404            let params =
405                params.ok_or_else(|| McpError::invalid_params(method, "missing params"))?;
406
407            let task_id = params
408                .get("taskId")
409                .and_then(|v| v.as_str())
410                .ok_or_else(|| McpError::invalid_params(method, "missing taskId"))?
411                .to_string();
412
413            Ok(ParsedRequest::TasksCancel(TaskCancelParams { task_id }))
414        }
415
416        methods::SAMPLING_CREATE_MESSAGE => {
417            let params =
418                params.ok_or_else(|| McpError::invalid_params(method, "missing params"))?;
419
420            let messages = params
421                .get("messages")
422                .and_then(|v| v.as_array())
423                .ok_or_else(|| McpError::invalid_params(method, "missing messages"))?
424                .clone();
425
426            Ok(ParsedRequest::SamplingCreateMessage(SamplingParams {
427                messages,
428                model_preferences: params.get("modelPreferences").cloned(),
429                system_prompt: params
430                    .get("systemPrompt")
431                    .and_then(|v| v.as_str())
432                    .map(String::from),
433                max_tokens: params
434                    .get("maxTokens")
435                    .and_then(serde_json::Value::as_u64)
436                    .map(|v| v as u32),
437            }))
438        }
439
440        methods::COMPLETION_COMPLETE => {
441            let params =
442                params.ok_or_else(|| McpError::invalid_params(method, "missing params"))?;
443
444            let ref_obj = params
445                .get("ref")
446                .ok_or_else(|| McpError::invalid_params(method, "missing ref"))?;
447
448            Ok(ParsedRequest::CompletionComplete(CompletionParams {
449                ref_type: ref_obj
450                    .get("type")
451                    .and_then(|v| v.as_str())
452                    .unwrap_or("")
453                    .to_string(),
454                ref_value: ref_obj
455                    .get("uri")
456                    .or_else(|| ref_obj.get("name"))
457                    .and_then(|v| v.as_str())
458                    .unwrap_or("")
459                    .to_string(),
460                argument: params.get("argument").map(|arg| CompletionArgument {
461                    name: arg
462                        .get("name")
463                        .and_then(|v| v.as_str())
464                        .unwrap_or("")
465                        .to_string(),
466                    value: arg
467                        .get("value")
468                        .and_then(|v| v.as_str())
469                        .unwrap_or("")
470                        .to_string(),
471                }),
472            }))
473        }
474
475        methods::LOGGING_SET_LEVEL => {
476            let params =
477                params.ok_or_else(|| McpError::invalid_params(method, "missing params"))?;
478
479            let level = params
480                .get("level")
481                .and_then(|v| v.as_str())
482                .ok_or_else(|| McpError::invalid_params(method, "missing level"))?
483                .to_string();
484
485            Ok(ParsedRequest::LoggingSetLevel(LogLevelParams { level }))
486        }
487
488        _ => Ok(ParsedRequest::Unknown(method.to_string())),
489    }
490}
491
492/// Parse common list parameters.
493fn parse_list_params(params: Option<&Value>) -> ListParams {
494    ListParams {
495        cursor: params
496            .and_then(|p| p.get("cursor"))
497            .and_then(|v| v.as_str())
498            .map(String::from),
499    }
500}
501
#[cfg(test)]
mod tests {
    use super::*;
    use mcpkit_core::protocol::Request;

    /// Builds a request with id `1`, attaching params only when some are given.
    fn make_request(method: &'static str, params: Option<Value>) -> Request {
        match params {
            Some(payload) => Request::with_params(method, 1u64, payload),
            None => Request::new(method, 1u64),
        }
    }

    /// `ping` parses without params.
    #[test]
    fn test_parse_ping() {
        let parsed = parse_request(&make_request("ping", None)).unwrap();
        assert!(matches!(parsed, ParsedRequest::Ping));
    }

    /// `tools/list` parses without params.
    #[test]
    fn test_parse_tools_list() {
        let parsed = parse_request(&make_request("tools/list", None)).unwrap();
        assert!(matches!(parsed, ParsedRequest::ToolsList(_)));
    }

    /// `tools/call` carries the tool name through.
    #[test]
    fn test_parse_tools_call() {
        let payload = serde_json::json!({
            "name": "search",
            "arguments": {"query": "test"}
        });
        let parsed = parse_request(&make_request("tools/call", Some(payload))).unwrap();

        match parsed {
            ParsedRequest::ToolsCall(call) => assert_eq!(call.name, "search"),
            _ => panic!("Expected ToolsCall"),
        }
    }

    /// An unrecognized method comes back as `Unknown` with its name intact.
    #[test]
    fn test_parse_unknown_method() {
        let parsed = parse_request(&make_request("unknown/method", None)).unwrap();

        match parsed {
            ParsedRequest::Unknown(name) => assert_eq!(name, "unknown/method"),
            _ => panic!("Expected Unknown"),
        }
    }

    /// `initialize` extracts the protocol version and the client name.
    #[test]
    fn test_parse_initialize() {
        let payload = serde_json::json!({
            "protocolVersion": "2025-11-25",
            "clientInfo": {
                "name": "test-client",
                "version": "1.0.0"
            },
            "capabilities": {}
        });
        let parsed = parse_request(&make_request("initialize", Some(payload))).unwrap();

        match parsed {
            ParsedRequest::Initialize(init) => {
                assert_eq!(init.protocol_version, "2025-11-25");
                assert_eq!(init.client_info.name, "test-client");
            }
            _ => panic!("Expected Initialize"),
        }
    }
}