1use mcpkit_core::error::McpError;
17use mcpkit_core::protocol::Request;
18use serde_json::Value;
19
/// JSON-RPC request method names.
pub mod methods {
    // ----- lifecycle -------------------------------------------------------

    /// Method name for `initialize` requests.
    pub const INITIALIZE: &str = "initialize";
    /// Method name for `ping` requests.
    pub const PING: &str = "ping";

    // ----- tools -----------------------------------------------------------

    /// Method name for `tools/list` requests.
    pub const TOOLS_LIST: &str = "tools/list";
    /// Method name for `tools/call` requests.
    pub const TOOLS_CALL: &str = "tools/call";

    // ----- resources -------------------------------------------------------

    /// Method name for `resources/list` requests.
    pub const RESOURCES_LIST: &str = "resources/list";
    /// Method name for `resources/read` requests.
    pub const RESOURCES_READ: &str = "resources/read";
    /// Method name for `resources/templates/list` requests.
    pub const RESOURCES_TEMPLATES_LIST: &str = "resources/templates/list";
    /// Method name for `resources/subscribe` requests.
    pub const RESOURCES_SUBSCRIBE: &str = "resources/subscribe";
    /// Method name for `resources/unsubscribe` requests.
    pub const RESOURCES_UNSUBSCRIBE: &str = "resources/unsubscribe";

    // ----- prompts ---------------------------------------------------------

    /// Method name for `prompts/list` requests.
    pub const PROMPTS_LIST: &str = "prompts/list";
    /// Method name for `prompts/get` requests.
    pub const PROMPTS_GET: &str = "prompts/get";

    // ----- tasks -----------------------------------------------------------

    /// Method name for `tasks/list` requests.
    pub const TASKS_LIST: &str = "tasks/list";
    /// Method name for `tasks/get` requests.
    pub const TASKS_GET: &str = "tasks/get";
    /// Method name for `tasks/cancel` requests.
    pub const TASKS_CANCEL: &str = "tasks/cancel";

    // ----- sampling, completion, logging, elicitation ----------------------

    /// Method name for `sampling/createMessage` requests.
    pub const SAMPLING_CREATE_MESSAGE: &str = "sampling/createMessage";
    /// Method name for `completion/complete` requests.
    pub const COMPLETION_COMPLETE: &str = "completion/complete";
    /// Method name for `logging/setLevel` requests.
    pub const LOGGING_SET_LEVEL: &str = "logging/setLevel";
    /// Method name for `elicitation/create` requests.
    ///
    /// NOTE(review): `parse_request` has no dedicated arm for this method, so
    /// such requests currently parse as `ParsedRequest::Unknown`.
    pub const ELICITATION_CREATE: &str = "elicitation/create";
}
67
/// JSON-RPC notification method names.
pub mod notifications {
    // ----- general ---------------------------------------------------------

    /// Notification name `notifications/initialized`.
    pub const INITIALIZED: &str = "notifications/initialized";
    /// Notification name `notifications/cancelled`.
    pub const CANCELLED: &str = "notifications/cancelled";
    /// Notification name `notifications/progress`.
    pub const PROGRESS: &str = "notifications/progress";
    /// Notification name `notifications/message`.
    pub const MESSAGE: &str = "notifications/message";

    // ----- change notifications --------------------------------------------

    /// Notification name `notifications/resources/updated`.
    pub const RESOURCES_UPDATED: &str = "notifications/resources/updated";
    /// Notification name `notifications/resources/list_changed`.
    pub const RESOURCES_LIST_CHANGED: &str = "notifications/resources/list_changed";
    /// Notification name `notifications/tools/list_changed`.
    pub const TOOLS_LIST_CHANGED: &str = "notifications/tools/list_changed";
    /// Notification name `notifications/prompts/list_changed`.
    pub const PROMPTS_LIST_CHANGED: &str = "notifications/prompts/list_changed";
}
87
88#[derive(Debug)]
93pub enum ParsedRequest {
94 Initialize(InitializeParams),
96 Ping,
98
99 ToolsList(ListParams),
101 ToolsCall(ToolCallParams),
103
104 ResourcesList(ListParams),
106 ResourcesRead(ResourceReadParams),
108 ResourcesTemplatesList(ListParams),
110 ResourcesSubscribe(ResourceSubscribeParams),
112 ResourcesUnsubscribe(ResourceUnsubscribeParams),
114
115 PromptsList(ListParams),
117 PromptsGet(PromptGetParams),
119
120 TasksList(ListParams),
122 TasksGet(TaskGetParams),
124 TasksCancel(TaskCancelParams),
126
127 SamplingCreateMessage(SamplingParams),
129
130 CompletionComplete(CompletionParams),
132
133 LoggingSetLevel(LogLevelParams),
135
136 Unknown(String),
138}
139
140#[derive(Debug, Default)]
142pub struct ListParams {
143 pub cursor: Option<String>,
145}
146
147#[derive(Debug)]
149pub struct InitializeParams {
150 pub protocol_version: String,
152 pub client_info: ClientInfo,
154 pub capabilities: Value,
156}
157
158#[derive(Debug)]
160pub struct ClientInfo {
161 pub name: String,
163 pub version: String,
165}
166
167#[derive(Debug)]
169pub struct ToolCallParams {
170 pub name: String,
172 pub arguments: Value,
174}
175
176#[derive(Debug)]
178pub struct ResourceReadParams {
179 pub uri: String,
181}
182
183#[derive(Debug)]
185pub struct ResourceSubscribeParams {
186 pub uri: String,
188}
189
190#[derive(Debug)]
192pub struct ResourceUnsubscribeParams {
193 pub uri: String,
195}
196
197#[derive(Debug)]
199pub struct PromptGetParams {
200 pub name: String,
202 pub arguments: Option<Value>,
204}
205
206#[derive(Debug)]
208pub struct TaskGetParams {
209 pub task_id: String,
211}
212
213#[derive(Debug)]
215pub struct TaskCancelParams {
216 pub task_id: String,
218}
219
220#[derive(Debug)]
222pub struct SamplingParams {
223 pub messages: Vec<Value>,
225 pub model_preferences: Option<Value>,
227 pub system_prompt: Option<String>,
229 pub max_tokens: Option<u32>,
231}
232
233#[derive(Debug)]
235pub struct CompletionParams {
236 pub ref_type: String,
238 pub ref_value: String,
240 pub argument: Option<CompletionArgument>,
242}
243
244#[derive(Debug)]
246pub struct CompletionArgument {
247 pub name: String,
249 pub value: String,
251}
252
253#[derive(Debug)]
255pub struct LogLevelParams {
256 pub level: String,
258}
259
260pub fn parse_request(request: &Request) -> Result<ParsedRequest, McpError> {
262 let method = request.method.as_ref();
263 let params = request.params.as_ref();
264
265 match method {
266 methods::INITIALIZE => {
267 let params =
268 params.ok_or_else(|| McpError::invalid_params(method, "missing params"))?;
269
270 Ok(ParsedRequest::Initialize(InitializeParams {
271 protocol_version: params
272 .get("protocolVersion")
273 .and_then(|v| v.as_str())
274 .unwrap_or("unknown")
275 .to_string(),
276 client_info: ClientInfo {
277 name: params
278 .get("clientInfo")
279 .and_then(|v| v.get("name"))
280 .and_then(|v| v.as_str())
281 .unwrap_or("unknown")
282 .to_string(),
283 version: params
284 .get("clientInfo")
285 .and_then(|v| v.get("version"))
286 .and_then(|v| v.as_str())
287 .unwrap_or("unknown")
288 .to_string(),
289 },
290 capabilities: params
291 .get("capabilities")
292 .cloned()
293 .unwrap_or_else(|| Value::Object(serde_json::Map::new())),
294 }))
295 }
296
297 methods::PING => Ok(ParsedRequest::Ping),
298
299 methods::TOOLS_LIST => Ok(ParsedRequest::ToolsList(parse_list_params(params))),
300
301 methods::TOOLS_CALL => {
302 let params =
303 params.ok_or_else(|| McpError::invalid_params(method, "missing params"))?;
304
305 let name = params
306 .get("name")
307 .and_then(|v| v.as_str())
308 .ok_or_else(|| McpError::invalid_params(method, "missing name"))?
309 .to_string();
310
311 let arguments = params
312 .get("arguments")
313 .cloned()
314 .unwrap_or_else(|| Value::Object(serde_json::Map::new()));
315
316 Ok(ParsedRequest::ToolsCall(ToolCallParams { name, arguments }))
317 }
318
319 methods::RESOURCES_LIST => Ok(ParsedRequest::ResourcesList(parse_list_params(params))),
320
321 methods::RESOURCES_READ => {
322 let params =
323 params.ok_or_else(|| McpError::invalid_params(method, "missing params"))?;
324
325 let uri = params
326 .get("uri")
327 .and_then(|v| v.as_str())
328 .ok_or_else(|| McpError::invalid_params(method, "missing uri"))?
329 .to_string();
330
331 Ok(ParsedRequest::ResourcesRead(ResourceReadParams { uri }))
332 }
333
334 methods::RESOURCES_TEMPLATES_LIST => Ok(ParsedRequest::ResourcesTemplatesList(
335 parse_list_params(params),
336 )),
337
338 methods::RESOURCES_SUBSCRIBE => {
339 let params =
340 params.ok_or_else(|| McpError::invalid_params(method, "missing params"))?;
341
342 let uri = params
343 .get("uri")
344 .and_then(|v| v.as_str())
345 .ok_or_else(|| McpError::invalid_params(method, "missing uri"))?
346 .to_string();
347
348 Ok(ParsedRequest::ResourcesSubscribe(ResourceSubscribeParams {
349 uri,
350 }))
351 }
352
353 methods::RESOURCES_UNSUBSCRIBE => {
354 let params =
355 params.ok_or_else(|| McpError::invalid_params(method, "missing params"))?;
356
357 let uri = params
358 .get("uri")
359 .and_then(|v| v.as_str())
360 .ok_or_else(|| McpError::invalid_params(method, "missing uri"))?
361 .to_string();
362
363 Ok(ParsedRequest::ResourcesUnsubscribe(
364 ResourceUnsubscribeParams { uri },
365 ))
366 }
367
368 methods::PROMPTS_LIST => Ok(ParsedRequest::PromptsList(parse_list_params(params))),
369
370 methods::PROMPTS_GET => {
371 let params =
372 params.ok_or_else(|| McpError::invalid_params(method, "missing params"))?;
373
374 let name = params
375 .get("name")
376 .and_then(|v| v.as_str())
377 .ok_or_else(|| McpError::invalid_params(method, "missing name"))?
378 .to_string();
379
380 let arguments = params.get("arguments").cloned();
381
382 Ok(ParsedRequest::PromptsGet(PromptGetParams {
383 name,
384 arguments,
385 }))
386 }
387
388 methods::TASKS_LIST => Ok(ParsedRequest::TasksList(parse_list_params(params))),
389
390 methods::TASKS_GET => {
391 let params =
392 params.ok_or_else(|| McpError::invalid_params(method, "missing params"))?;
393
394 let task_id = params
395 .get("taskId")
396 .and_then(|v| v.as_str())
397 .ok_or_else(|| McpError::invalid_params(method, "missing taskId"))?
398 .to_string();
399
400 Ok(ParsedRequest::TasksGet(TaskGetParams { task_id }))
401 }
402
403 methods::TASKS_CANCEL => {
404 let params =
405 params.ok_or_else(|| McpError::invalid_params(method, "missing params"))?;
406
407 let task_id = params
408 .get("taskId")
409 .and_then(|v| v.as_str())
410 .ok_or_else(|| McpError::invalid_params(method, "missing taskId"))?
411 .to_string();
412
413 Ok(ParsedRequest::TasksCancel(TaskCancelParams { task_id }))
414 }
415
416 methods::SAMPLING_CREATE_MESSAGE => {
417 let params =
418 params.ok_or_else(|| McpError::invalid_params(method, "missing params"))?;
419
420 let messages = params
421 .get("messages")
422 .and_then(|v| v.as_array())
423 .ok_or_else(|| McpError::invalid_params(method, "missing messages"))?
424 .clone();
425
426 Ok(ParsedRequest::SamplingCreateMessage(SamplingParams {
427 messages,
428 model_preferences: params.get("modelPreferences").cloned(),
429 system_prompt: params
430 .get("systemPrompt")
431 .and_then(|v| v.as_str())
432 .map(String::from),
433 max_tokens: params
434 .get("maxTokens")
435 .and_then(serde_json::Value::as_u64)
436 .map(|v| v as u32),
437 }))
438 }
439
440 methods::COMPLETION_COMPLETE => {
441 let params =
442 params.ok_or_else(|| McpError::invalid_params(method, "missing params"))?;
443
444 let ref_obj = params
445 .get("ref")
446 .ok_or_else(|| McpError::invalid_params(method, "missing ref"))?;
447
448 Ok(ParsedRequest::CompletionComplete(CompletionParams {
449 ref_type: ref_obj
450 .get("type")
451 .and_then(|v| v.as_str())
452 .unwrap_or("")
453 .to_string(),
454 ref_value: ref_obj
455 .get("uri")
456 .or_else(|| ref_obj.get("name"))
457 .and_then(|v| v.as_str())
458 .unwrap_or("")
459 .to_string(),
460 argument: params.get("argument").map(|arg| CompletionArgument {
461 name: arg
462 .get("name")
463 .and_then(|v| v.as_str())
464 .unwrap_or("")
465 .to_string(),
466 value: arg
467 .get("value")
468 .and_then(|v| v.as_str())
469 .unwrap_or("")
470 .to_string(),
471 }),
472 }))
473 }
474
475 methods::LOGGING_SET_LEVEL => {
476 let params =
477 params.ok_or_else(|| McpError::invalid_params(method, "missing params"))?;
478
479 let level = params
480 .get("level")
481 .and_then(|v| v.as_str())
482 .ok_or_else(|| McpError::invalid_params(method, "missing level"))?
483 .to_string();
484
485 Ok(ParsedRequest::LoggingSetLevel(LogLevelParams { level }))
486 }
487
488 _ => Ok(ParsedRequest::Unknown(method.to_string())),
489 }
490}
491
492fn parse_list_params(params: Option<&Value>) -> ListParams {
494 ListParams {
495 cursor: params
496 .and_then(|p| p.get("cursor"))
497 .and_then(|v| v.as_str())
498 .map(String::from),
499 }
500}
501
502use crate::context::Context;
510use crate::handler::{PromptHandler, ResourceHandler, ToolHandler};
511use mcpkit_core::types::CallToolResult;
512
513pub async fn route_tools<TH: ToolHandler + Send + Sync>(
526 handler: &TH,
527 method: &str,
528 params: Option<&serde_json::Value>,
529 ctx: &Context<'_>,
530) -> Option<Result<serde_json::Value, McpError>> {
531 match method {
532 methods::TOOLS_LIST => {
533 tracing::debug!("Listing available tools");
534 let result = handler.list_tools(ctx).await;
535 match &result {
536 Ok(tools) => tracing::debug!(count = tools.len(), "Listed tools"),
537 Err(e) => tracing::warn!(error = %e, "Failed to list tools"),
538 }
539 Some(result.map(|tools| serde_json::json!({ "tools": tools })))
540 }
541 methods::TOOLS_CALL => {
542 let result = async {
543 let params = params.ok_or_else(|| {
544 McpError::invalid_params(methods::TOOLS_CALL, "missing params")
545 })?;
546 let name = params.get("name").and_then(|v| v.as_str()).ok_or_else(|| {
547 McpError::invalid_params(methods::TOOLS_CALL, "missing tool name")
548 })?;
549 let args = params
550 .get("arguments")
551 .cloned()
552 .unwrap_or_else(|| serde_json::json!({}));
553
554 tracing::info!(tool = %name, "Calling tool");
555 let start = std::time::Instant::now();
556 let output = handler.call_tool(name, args, ctx).await;
557 let duration = start.elapsed();
558
559 match &output {
560 Ok(_) => tracing::info!(
561 tool = %name,
562 duration_ms = duration.as_millis(),
563 "Tool call completed"
564 ),
565 Err(e) => tracing::warn!(
566 tool = %name,
567 duration_ms = duration.as_millis(),
568 error = %e,
569 "Tool call failed"
570 ),
571 }
572
573 let output = output?;
574 let result: CallToolResult = output.into();
575 Ok(serde_json::to_value(result).unwrap_or_else(|_| serde_json::json!({})))
576 }
577 .await;
578 Some(result)
579 }
580 _ => None,
581 }
582}
583
584pub async fn route_resources<RH: ResourceHandler + Send + Sync>(
597 handler: &RH,
598 method: &str,
599 params: Option<&serde_json::Value>,
600 ctx: &Context<'_>,
601) -> Option<Result<serde_json::Value, McpError>> {
602 match method {
603 methods::RESOURCES_LIST => {
604 tracing::debug!("Listing available resources");
605 let result = handler.list_resources(ctx).await;
606 match &result {
607 Ok(resources) => tracing::debug!(count = resources.len(), "Listed resources"),
608 Err(e) => tracing::warn!(error = %e, "Failed to list resources"),
609 }
610 Some(result.map(|resources| serde_json::json!({ "resources": resources })))
611 }
612 methods::RESOURCES_TEMPLATES_LIST => {
613 tracing::debug!("Listing available resource templates");
614 let result = handler.list_resource_templates(ctx).await;
615 match &result {
616 Ok(templates) => {
617 tracing::debug!(count = templates.len(), "Listed resource templates");
618 }
619 Err(e) => tracing::warn!(error = %e, "Failed to list resource templates"),
620 }
621 Some(result.map(|templates| serde_json::json!({ "resourceTemplates": templates })))
622 }
623 methods::RESOURCES_READ => {
624 let result = async {
625 let params = params.ok_or_else(|| {
626 McpError::invalid_params(methods::RESOURCES_READ, "missing params")
627 })?;
628 let uri = params.get("uri").and_then(|v| v.as_str()).ok_or_else(|| {
629 McpError::invalid_params(methods::RESOURCES_READ, "missing uri")
630 })?;
631
632 tracing::info!(uri = %uri, "Reading resource");
633 let start = std::time::Instant::now();
634 let contents = handler.read_resource(uri, ctx).await;
635 let duration = start.elapsed();
636
637 match &contents {
638 Ok(_) => tracing::info!(
639 uri = %uri,
640 duration_ms = duration.as_millis(),
641 "Resource read completed"
642 ),
643 Err(e) => tracing::warn!(
644 uri = %uri,
645 duration_ms = duration.as_millis(),
646 error = %e,
647 "Resource read failed"
648 ),
649 }
650
651 let contents = contents?;
652 Ok(serde_json::json!({ "contents": contents }))
653 }
654 .await;
655 Some(result)
656 }
657 _ => None,
658 }
659}
660
661pub async fn route_prompts<PH: PromptHandler + Send + Sync>(
674 handler: &PH,
675 method: &str,
676 params: Option<&serde_json::Value>,
677 ctx: &Context<'_>,
678) -> Option<Result<serde_json::Value, McpError>> {
679 match method {
680 methods::PROMPTS_LIST => {
681 tracing::debug!("Listing available prompts");
682 let result = handler.list_prompts(ctx).await;
683 match &result {
684 Ok(prompts) => tracing::debug!(count = prompts.len(), "Listed prompts"),
685 Err(e) => tracing::warn!(error = %e, "Failed to list prompts"),
686 }
687 Some(result.map(|prompts| serde_json::json!({ "prompts": prompts })))
688 }
689 methods::PROMPTS_GET => {
690 let result = async {
691 let params = params.ok_or_else(|| {
692 McpError::invalid_params(methods::PROMPTS_GET, "missing params")
693 })?;
694 let name = params.get("name").and_then(|v| v.as_str()).ok_or_else(|| {
695 McpError::invalid_params(methods::PROMPTS_GET, "missing prompt name")
696 })?;
697 let args = params.get("arguments").and_then(|v| v.as_object()).cloned();
698
699 tracing::info!(prompt = %name, "Getting prompt");
700 let start = std::time::Instant::now();
701 let prompt_result = handler.get_prompt(name, args, ctx).await;
702 let duration = start.elapsed();
703
704 match &prompt_result {
705 Ok(_) => tracing::info!(
706 prompt = %name,
707 duration_ms = duration.as_millis(),
708 "Prompt retrieval completed"
709 ),
710 Err(e) => tracing::warn!(
711 prompt = %name,
712 duration_ms = duration.as_millis(),
713 error = %e,
714 "Prompt retrieval failed"
715 ),
716 }
717
718 let result = prompt_result?;
719 Ok(serde_json::to_value(result).unwrap_or_else(|_| serde_json::json!({})))
720 }
721 .await;
722 Some(result)
723 }
724 _ => None,
725 }
726}
727
#[cfg(test)]
mod tests {
    use super::*;
    use mcpkit_core::protocol::Request;

    /// Builds a request with id `1`, attaching `params` only when provided.
    fn make_request(method: &'static str, params: Option<Value>) -> Request {
        match params {
            Some(body) => Request::with_params(method, 1u64, body),
            None => Request::new(method, 1u64),
        }
    }

    #[test]
    fn test_parse_ping() -> Result<(), Box<dyn std::error::Error>> {
        let parsed = parse_request(&make_request("ping", None))?;
        assert!(matches!(parsed, ParsedRequest::Ping));
        Ok(())
    }

    #[test]
    fn test_parse_tools_list() -> Result<(), Box<dyn std::error::Error>> {
        let parsed = parse_request(&make_request("tools/list", None))?;
        assert!(matches!(parsed, ParsedRequest::ToolsList(_)));
        Ok(())
    }

    #[test]
    fn test_parse_tools_call() -> Result<(), Box<dyn std::error::Error>> {
        let body = serde_json::json!({
            "name": "search",
            "arguments": {"query": "test"}
        });

        match parse_request(&make_request("tools/call", Some(body)))? {
            ParsedRequest::ToolsCall(call) => assert_eq!(call.name, "search"),
            _ => panic!("Expected ToolsCall"),
        }
        Ok(())
    }

    #[test]
    fn test_parse_unknown_method() -> Result<(), Box<dyn std::error::Error>> {
        match parse_request(&make_request("unknown/method", None))? {
            ParsedRequest::Unknown(name) => assert_eq!(name, "unknown/method"),
            _ => panic!("Expected Unknown"),
        }
        Ok(())
    }

    #[test]
    fn test_parse_initialize() -> Result<(), Box<dyn std::error::Error>> {
        let body = serde_json::json!({
            "protocolVersion": "2025-11-25",
            "clientInfo": {
                "name": "test-client",
                "version": "1.0.0"
            },
            "capabilities": {}
        });

        match parse_request(&make_request("initialize", Some(body)))? {
            ParsedRequest::Initialize(init) => {
                assert_eq!(init.protocol_version, "2025-11-25");
                assert_eq!(init.client_info.name, "test-client");
            }
            _ => panic!("Expected Initialize"),
        }
        Ok(())
    }
}