mcpkit_server/
router.rs

1//! Request routing for MCP servers.
2//!
3//! This module provides the routing infrastructure that dispatches
4//! incoming requests to the appropriate handler methods.
5//!
6//! # MCP Method Categories
7//!
8//! - **Initialization**: `initialize`, `ping`
9//! - **Tools**: `tools/list`, `tools/call`
10//! - **Resources**: `resources/list`, `resources/read`, `resources/subscribe`
11//! - **Prompts**: `prompts/list`, `prompts/get`
12//! - **Tasks**: `tasks/list`, `tasks/get`, `tasks/cancel`
13//! - **Sampling**: `sampling/createMessage`
14//! - **Completions**: `completion/complete`
15
16use mcpkit_core::error::McpError;
17use mcpkit_core::protocol::Request;
18use serde_json::Value;
19
/// Method name strings for the MCP requests this server knows about.
///
/// Each constant is the exact wire-level `method` value of a request.
pub mod methods {
    // --- Lifecycle -------------------------------------------------------

    /// Opens a session and negotiates capabilities.
    pub const INITIALIZE: &str = "initialize";
    /// Liveness check for the connection.
    pub const PING: &str = "ping";

    // --- Tools -----------------------------------------------------------

    /// Enumerates the tools that can be called.
    pub const TOOLS_LIST: &str = "tools/list";
    /// Invokes one tool with the supplied arguments.
    pub const TOOLS_CALL: &str = "tools/call";

    // --- Resources -------------------------------------------------------

    /// Enumerates the available resources.
    pub const RESOURCES_LIST: &str = "resources/list";
    /// Fetches the contents of one resource.
    pub const RESOURCES_READ: &str = "resources/read";
    /// Enumerates the available resource templates.
    pub const RESOURCES_TEMPLATES_LIST: &str = "resources/templates/list";
    /// Starts receiving update notifications for a resource.
    pub const RESOURCES_SUBSCRIBE: &str = "resources/subscribe";
    /// Stops receiving update notifications for a resource.
    pub const RESOURCES_UNSUBSCRIBE: &str = "resources/unsubscribe";

    // --- Prompts ---------------------------------------------------------

    /// Enumerates the available prompts.
    pub const PROMPTS_LIST: &str = "prompts/list";
    /// Fetches one prompt, optionally with arguments.
    pub const PROMPTS_GET: &str = "prompts/get";

    // --- Tasks -----------------------------------------------------------

    /// Enumerates running tasks.
    pub const TASKS_LIST: &str = "tasks/list";
    /// Fetches the status of one task.
    pub const TASKS_GET: &str = "tasks/get";
    /// Cancels a running task.
    pub const TASKS_CANCEL: &str = "tasks/cancel";

    // --- Sampling, completion, logging, elicitation ----------------------

    /// Asks the client to sample from a language model.
    pub const SAMPLING_CREATE_MESSAGE: &str = "sampling/createMessage";

    /// Asks for completion suggestions.
    pub const COMPLETION_COMPLETE: &str = "completion/complete";

    /// Changes the logging level.
    pub const LOGGING_SET_LEVEL: &str = "logging/setLevel";

    /// Creates an elicitation request.
    pub const ELICITATION_CREATE: &str = "elicitation/create";
}
67
/// Method name strings for MCP notifications.
///
/// Each constant is the exact wire-level `method` value of a notification.
pub mod notifications {
    // --- Session and request lifecycle -----------------------------------

    /// Client signal that initialization completed successfully.
    pub const INITIALIZED: &str = "notifications/initialized";
    /// Signals that a request was cancelled.
    pub const CANCELLED: &str = "notifications/cancelled";
    /// Progress report for a long-running operation.
    pub const PROGRESS: &str = "notifications/progress";
    /// Carries a log message.
    pub const MESSAGE: &str = "notifications/message";

    // --- Change notifications --------------------------------------------

    /// The content of a resource changed.
    pub const RESOURCES_UPDATED: &str = "notifications/resources/updated";
    /// The set of available resources changed.
    pub const RESOURCES_LIST_CHANGED: &str = "notifications/resources/list_changed";
    /// The set of available tools changed.
    pub const TOOLS_LIST_CHANGED: &str = "notifications/tools/list_changed";
    /// The set of available prompts changed.
    pub const PROMPTS_LIST_CHANGED: &str = "notifications/prompts/list_changed";
}
87
88/// Represents a parsed MCP request with typed parameters.
89///
90/// This enum provides a type-safe representation of all MCP request types,
91/// with parameters parsed into their appropriate structures.
92#[derive(Debug)]
93pub enum ParsedRequest {
94    /// Initialize request to establish connection.
95    Initialize(InitializeParams),
96    /// Ping request to check connection health.
97    Ping,
98
99    /// Request to list available tools.
100    ToolsList(ListParams),
101    /// Request to call a specific tool.
102    ToolsCall(ToolCallParams),
103
104    /// Request to list available resources.
105    ResourcesList(ListParams),
106    /// Request to read a resource's contents.
107    ResourcesRead(ResourceReadParams),
108    /// Request to list resource templates.
109    ResourcesTemplatesList(ListParams),
110    /// Request to subscribe to resource updates.
111    ResourcesSubscribe(ResourceSubscribeParams),
112    /// Request to unsubscribe from resource updates.
113    ResourcesUnsubscribe(ResourceUnsubscribeParams),
114
115    /// Request to list available prompts.
116    PromptsList(ListParams),
117    /// Request to get a specific prompt.
118    PromptsGet(PromptGetParams),
119
120    /// Request to list running tasks.
121    TasksList(ListParams),
122    /// Request to get a task's status.
123    TasksGet(TaskGetParams),
124    /// Request to cancel a running task.
125    TasksCancel(TaskCancelParams),
126
127    /// Request for the client to sample from a language model.
128    SamplingCreateMessage(SamplingParams),
129
130    /// Request for completion suggestions.
131    CompletionComplete(CompletionParams),
132
133    /// Request to set the logging level.
134    LoggingSetLevel(LogLevelParams),
135
136    /// An unrecognized method name.
137    Unknown(String),
138}
139
/// Parameters shared by every `*/list` request.
///
/// Derives `Clone`, `PartialEq` and `Eq` so parsed parameters can be copied
/// and compared directly (for example in tests).
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct ListParams {
    /// Pagination cursor; `None` when the request did not supply one.
    pub cursor: Option<String>,
}
146
147/// Initialize request parameters.
148#[derive(Debug)]
149pub struct InitializeParams {
150    /// The protocol version requested by the client.
151    pub protocol_version: String,
152    /// Information about the client.
153    pub client_info: ClientInfo,
154    /// Client capabilities.
155    pub capabilities: Value,
156}
157
158/// Client info from initialize request.
159#[derive(Debug)]
160pub struct ClientInfo {
161    /// The name of the client application.
162    pub name: String,
163    /// The version of the client application.
164    pub version: String,
165}
166
167/// Tool call parameters.
168#[derive(Debug)]
169pub struct ToolCallParams {
170    /// The name of the tool to call.
171    pub name: String,
172    /// Arguments to pass to the tool.
173    pub arguments: Value,
174}
175
/// Parameters of a `resources/read` request.
///
/// Derives `Clone`, `PartialEq` and `Eq` so parsed parameters can be copied
/// and compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ResourceReadParams {
    /// The URI of the resource to read.
    pub uri: String,
}
182
/// Parameters of a `resources/subscribe` request.
///
/// Derives `Clone`, `PartialEq` and `Eq` so parsed parameters can be copied
/// and compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ResourceSubscribeParams {
    /// The URI of the resource to subscribe to.
    pub uri: String,
}
189
/// Parameters of a `resources/unsubscribe` request.
///
/// Derives `Clone`, `PartialEq` and `Eq` so parsed parameters can be copied
/// and compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ResourceUnsubscribeParams {
    /// The URI of the resource to unsubscribe from.
    pub uri: String,
}
196
197/// Prompt get parameters.
198#[derive(Debug)]
199pub struct PromptGetParams {
200    /// The name of the prompt to get.
201    pub name: String,
202    /// Optional arguments to pass to the prompt.
203    pub arguments: Option<Value>,
204}
205
/// Parameters of a `tasks/get` request.
///
/// Derives `Clone`, `PartialEq` and `Eq` so parsed parameters can be copied
/// and compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TaskGetParams {
    /// The ID of the task to get.
    pub task_id: String,
}
212
/// Parameters of a `tasks/cancel` request.
///
/// Derives `Clone`, `PartialEq` and `Eq` so parsed parameters can be copied
/// and compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TaskCancelParams {
    /// The ID of the task to cancel.
    pub task_id: String,
}
219
220/// Sampling create message parameters.
221#[derive(Debug)]
222pub struct SamplingParams {
223    /// The messages to sample from.
224    pub messages: Vec<Value>,
225    /// Optional model preferences.
226    pub model_preferences: Option<Value>,
227    /// Optional system prompt.
228    pub system_prompt: Option<String>,
229    /// Optional maximum number of tokens.
230    pub max_tokens: Option<u32>,
231}
232
/// Parameters of a `completion/complete` request.
///
/// Derives `Clone`, `PartialEq` and `Eq` so parsed parameters can be copied
/// and compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompletionParams {
    /// The type of reference (e.g., "ref/resource", "ref/prompt").
    pub ref_type: String,
    /// The value of the reference (URI or name).
    pub ref_value: String,
    /// Optional argument for completion context.
    pub argument: Option<CompletionArgument>,
}

/// The argument being completed, giving context for a completion request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompletionArgument {
    /// The name of the argument.
    pub name: String,
    /// The current value being completed.
    pub value: String,
}
252
/// Parameters of a `logging/setLevel` request.
///
/// Derives `Clone`, `PartialEq` and `Eq` so parsed parameters can be copied
/// and compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LogLevelParams {
    /// The log level to set (e.g., "debug", "info", "warn", "error").
    pub level: String,
}
259
260/// Parse a request into a typed representation.
261pub fn parse_request(request: &Request) -> Result<ParsedRequest, McpError> {
262    let method = request.method.as_ref();
263    let params = request.params.as_ref();
264
265    match method {
266        methods::INITIALIZE => {
267            let params = params.ok_or_else(|| {
268                McpError::invalid_params(method, "missing params")
269            })?;
270
271            Ok(ParsedRequest::Initialize(InitializeParams {
272                protocol_version: params
273                    .get("protocolVersion")
274                    .and_then(|v| v.as_str())
275                    .unwrap_or("unknown")
276                    .to_string(),
277                client_info: ClientInfo {
278                    name: params
279                        .get("clientInfo")
280                        .and_then(|v| v.get("name"))
281                        .and_then(|v| v.as_str())
282                        .unwrap_or("unknown")
283                        .to_string(),
284                    version: params
285                        .get("clientInfo")
286                        .and_then(|v| v.get("version"))
287                        .and_then(|v| v.as_str())
288                        .unwrap_or("unknown")
289                        .to_string(),
290                },
291                capabilities: params
292                    .get("capabilities")
293                    .cloned()
294                    .unwrap_or(Value::Object(serde_json::Map::new())),
295            }))
296        }
297
298        methods::PING => Ok(ParsedRequest::Ping),
299
300        methods::TOOLS_LIST => Ok(ParsedRequest::ToolsList(parse_list_params(params))),
301
302        methods::TOOLS_CALL => {
303            let params = params.ok_or_else(|| {
304                McpError::invalid_params(method, "missing params")
305            })?;
306
307            let name = params
308                .get("name")
309                .and_then(|v| v.as_str())
310                .ok_or_else(|| McpError::invalid_params(method, "missing name"))?
311                .to_string();
312
313            let arguments = params
314                .get("arguments")
315                .cloned()
316                .unwrap_or(Value::Object(serde_json::Map::new()));
317
318            Ok(ParsedRequest::ToolsCall(ToolCallParams { name, arguments }))
319        }
320
321        methods::RESOURCES_LIST => Ok(ParsedRequest::ResourcesList(parse_list_params(params))),
322
323        methods::RESOURCES_READ => {
324            let params = params.ok_or_else(|| {
325                McpError::invalid_params(method, "missing params")
326            })?;
327
328            let uri = params
329                .get("uri")
330                .and_then(|v| v.as_str())
331                .ok_or_else(|| McpError::invalid_params(method, "missing uri"))?
332                .to_string();
333
334            Ok(ParsedRequest::ResourcesRead(ResourceReadParams { uri }))
335        }
336
337        methods::RESOURCES_TEMPLATES_LIST => {
338            Ok(ParsedRequest::ResourcesTemplatesList(parse_list_params(params)))
339        }
340
341        methods::RESOURCES_SUBSCRIBE => {
342            let params = params.ok_or_else(|| {
343                McpError::invalid_params(method, "missing params")
344            })?;
345
346            let uri = params
347                .get("uri")
348                .and_then(|v| v.as_str())
349                .ok_or_else(|| McpError::invalid_params(method, "missing uri"))?
350                .to_string();
351
352            Ok(ParsedRequest::ResourcesSubscribe(ResourceSubscribeParams { uri }))
353        }
354
355        methods::RESOURCES_UNSUBSCRIBE => {
356            let params = params.ok_or_else(|| {
357                McpError::invalid_params(method, "missing params")
358            })?;
359
360            let uri = params
361                .get("uri")
362                .and_then(|v| v.as_str())
363                .ok_or_else(|| McpError::invalid_params(method, "missing uri"))?
364                .to_string();
365
366            Ok(ParsedRequest::ResourcesUnsubscribe(ResourceUnsubscribeParams { uri }))
367        }
368
369        methods::PROMPTS_LIST => Ok(ParsedRequest::PromptsList(parse_list_params(params))),
370
371        methods::PROMPTS_GET => {
372            let params = params.ok_or_else(|| {
373                McpError::invalid_params(method, "missing params")
374            })?;
375
376            let name = params
377                .get("name")
378                .and_then(|v| v.as_str())
379                .ok_or_else(|| McpError::invalid_params(method, "missing name"))?
380                .to_string();
381
382            let arguments = params.get("arguments").cloned();
383
384            Ok(ParsedRequest::PromptsGet(PromptGetParams { name, arguments }))
385        }
386
387        methods::TASKS_LIST => Ok(ParsedRequest::TasksList(parse_list_params(params))),
388
389        methods::TASKS_GET => {
390            let params = params.ok_or_else(|| {
391                McpError::invalid_params(method, "missing params")
392            })?;
393
394            let task_id = params
395                .get("taskId")
396                .and_then(|v| v.as_str())
397                .ok_or_else(|| {
398                    McpError::invalid_params(method, "missing taskId")
399                })?
400                .to_string();
401
402            Ok(ParsedRequest::TasksGet(TaskGetParams { task_id }))
403        }
404
405        methods::TASKS_CANCEL => {
406            let params = params.ok_or_else(|| {
407                McpError::invalid_params(method, "missing params")
408            })?;
409
410            let task_id = params
411                .get("taskId")
412                .and_then(|v| v.as_str())
413                .ok_or_else(|| {
414                    McpError::invalid_params(method, "missing taskId")
415                })?
416                .to_string();
417
418            Ok(ParsedRequest::TasksCancel(TaskCancelParams { task_id }))
419        }
420
421        methods::SAMPLING_CREATE_MESSAGE => {
422            let params = params.ok_or_else(|| {
423                McpError::invalid_params(method, "missing params")
424            })?;
425
426            let messages = params
427                .get("messages")
428                .and_then(|v| v.as_array())
429                .ok_or_else(|| {
430                    McpError::invalid_params(method, "missing messages")
431                })?
432                .clone();
433
434            Ok(ParsedRequest::SamplingCreateMessage(SamplingParams {
435                messages,
436                model_preferences: params.get("modelPreferences").cloned(),
437                system_prompt: params
438                    .get("systemPrompt")
439                    .and_then(|v| v.as_str())
440                    .map(String::from),
441                max_tokens: params.get("maxTokens").and_then(|v| v.as_u64()).map(|v| v as u32),
442            }))
443        }
444
445        methods::COMPLETION_COMPLETE => {
446            let params = params.ok_or_else(|| {
447                McpError::invalid_params(method, "missing params")
448            })?;
449
450            let ref_obj = params.get("ref").ok_or_else(|| {
451                McpError::invalid_params(method, "missing ref")
452            })?;
453
454            Ok(ParsedRequest::CompletionComplete(CompletionParams {
455                ref_type: ref_obj
456                    .get("type")
457                    .and_then(|v| v.as_str())
458                    .unwrap_or("")
459                    .to_string(),
460                ref_value: ref_obj
461                    .get("uri")
462                    .or_else(|| ref_obj.get("name"))
463                    .and_then(|v| v.as_str())
464                    .unwrap_or("")
465                    .to_string(),
466                argument: params.get("argument").map(|arg| CompletionArgument {
467                    name: arg.get("name").and_then(|v| v.as_str()).unwrap_or("").to_string(),
468                    value: arg.get("value").and_then(|v| v.as_str()).unwrap_or("").to_string(),
469                }),
470            }))
471        }
472
473        methods::LOGGING_SET_LEVEL => {
474            let params = params.ok_or_else(|| {
475                McpError::invalid_params(method, "missing params")
476            })?;
477
478            let level = params
479                .get("level")
480                .and_then(|v| v.as_str())
481                .ok_or_else(|| {
482                    McpError::invalid_params(method, "missing level")
483                })?
484                .to_string();
485
486            Ok(ParsedRequest::LoggingSetLevel(LogLevelParams { level }))
487        }
488
489        _ => Ok(ParsedRequest::Unknown(method.to_string())),
490    }
491}
492
493/// Parse common list parameters.
494fn parse_list_params(params: Option<&Value>) -> ListParams {
495    ListParams {
496        cursor: params
497            .and_then(|p| p.get("cursor"))
498            .and_then(|v| v.as_str())
499            .map(String::from),
500    }
501}
502
#[cfg(test)]
mod tests {
    use super::*;
    use mcpkit_core::protocol::Request;

    /// Build a request with id 1, attaching params only when some are given.
    fn make_request(method: &'static str, params: Option<Value>) -> Request {
        match params {
            Some(p) => Request::with_params(method, 1u64, p),
            None => Request::new(method, 1u64),
        }
    }

    #[test]
    fn test_parse_ping() {
        let parsed = parse_request(&make_request("ping", None)).unwrap();
        assert!(matches!(parsed, ParsedRequest::Ping));
    }

    #[test]
    fn test_parse_tools_list() {
        let parsed = parse_request(&make_request("tools/list", None)).unwrap();
        assert!(matches!(parsed, ParsedRequest::ToolsList(_)));
    }

    #[test]
    fn test_parse_tools_call() {
        let payload = serde_json::json!({
            "name": "search",
            "arguments": {"query": "test"}
        });
        let parsed = parse_request(&make_request("tools/call", Some(payload))).unwrap();

        match parsed {
            ParsedRequest::ToolsCall(params) => assert_eq!(params.name, "search"),
            _ => panic!("Expected ToolsCall"),
        }
    }

    #[test]
    fn test_parse_unknown_method() {
        let parsed = parse_request(&make_request("unknown/method", None)).unwrap();

        match parsed {
            ParsedRequest::Unknown(method) => assert_eq!(method, "unknown/method"),
            _ => panic!("Expected Unknown"),
        }
    }

    #[test]
    fn test_parse_initialize() {
        let payload = serde_json::json!({
            "protocolVersion": "2025-11-25",
            "clientInfo": {
                "name": "test-client",
                "version": "1.0.0"
            },
            "capabilities": {}
        });
        let parsed = parse_request(&make_request("initialize", Some(payload))).unwrap();

        match parsed {
            ParsedRequest::Initialize(params) => {
                assert_eq!(params.protocol_version, "2025-11-25");
                assert_eq!(params.client_info.name, "test-client");
            }
            _ => panic!("Expected Initialize"),
        }
    }
}