mcp_protocol_sdk/server/
handlers.rs

//! MCP server request handlers
//!
//! This module provides specialized handlers for the different types of MCP requests,
//! implementing the business logic for each protocol operation.
5
6use serde_json::Value;
7use std::collections::HashMap;
8
9use crate::core::error::{McpError, McpResult};
10use crate::protocol::{LATEST_PROTOCOL_VERSION, messages::*, methods, types::*};
11
12/// Handler for initialization requests
13pub struct InitializeHandler;
14
15impl InitializeHandler {
16    /// Process an initialize request
17    pub async fn handle(
18        server_info: &ServerInfo,
19        capabilities: &ServerCapabilities,
20        params: Option<Value>,
21    ) -> McpResult<InitializeResult> {
22        let params: InitializeParams = match params {
23            Some(p) => serde_json::from_value(p)
24                .map_err(|e| McpError::Validation(format!("Invalid initialize params: {e}")))?,
25            None => {
26                return Err(McpError::Validation(
27                    "Missing initialize parameters".to_string(),
28                ));
29            }
30        };
31
32        // Validate protocol version compatibility
33        if params.protocol_version != LATEST_PROTOCOL_VERSION {
34            let protocol_version = params.protocol_version;
35            let expected = LATEST_PROTOCOL_VERSION;
36            return Err(McpError::Protocol(format!(
37                "Unsupported protocol version: {protocol_version}. Expected: {expected}"
38            )));
39        }
40
41        // Validate client info
42        if params.client_info.name.is_empty() {
43            return Err(McpError::Validation(
44                "Client name cannot be empty".to_string(),
45            ));
46        }
47
48        if params.client_info.version.is_empty() {
49            return Err(McpError::Validation(
50                "Client version cannot be empty".to_string(),
51            ));
52        }
53
54        Ok(InitializeResult::new(
55            LATEST_PROTOCOL_VERSION.to_string(),
56            capabilities.clone(),
57            server_info.clone(),
58        ))
59    }
60}
61
62/// Handler for tool-related requests
63pub struct ToolHandler;
64
65impl ToolHandler {
66    /// Handle tools/list request
67    pub async fn handle_list(
68        tools: &HashMap<String, crate::core::tool::Tool>,
69        params: Option<Value>,
70    ) -> McpResult<ListToolsResult> {
71        let _params: ListToolsParams = match params {
72            Some(p) => serde_json::from_value(p)
73                .map_err(|e| McpError::Validation(format!("Invalid list tools params: {e}")))?,
74            None => ListToolsParams::default(),
75        };
76
77        // Pagination support will be added in future versions
78        let tools: Vec<ToolInfo> = tools
79            .values()
80            .filter(|tool| tool.enabled)
81            .map(|tool| {
82                // Convert from core::tool::ToolInfo to protocol::types::ToolInfo
83                ToolInfo {
84                    name: tool.info.name.clone(),
85                    description: tool.info.description.clone(),
86                    input_schema: tool.info.input_schema.clone(),
87                    annotations: None,
88                    title: None,
89                    meta: None,
90                }
91            })
92            .collect();
93
94        Ok(ListToolsResult {
95            tools,
96            next_cursor: None,
97            meta: None,
98        })
99    }
100
101    /// Handle tools/call request
102    pub async fn handle_call(
103        tools: &HashMap<String, crate::core::tool::Tool>,
104        params: Option<Value>,
105    ) -> McpResult<CallToolResult> {
106        let params: CallToolParams = match params {
107            Some(p) => serde_json::from_value(p)
108                .map_err(|e| McpError::Validation(format!("Invalid call tool params: {e}")))?,
109            None => {
110                return Err(McpError::Validation(
111                    "Missing tool call parameters".to_string(),
112                ));
113            }
114        };
115
116        if params.name.is_empty() {
117            return Err(McpError::Validation(
118                "Tool name cannot be empty".to_string(),
119            ));
120        }
121
122        let tool = tools
123            .get(&params.name)
124            .ok_or_else(|| McpError::ToolNotFound(params.name.clone()))?;
125
126        if !tool.enabled {
127            let name = &params.name;
128            return Err(McpError::ToolNotFound(format!("Tool '{name}' is disabled")));
129        }
130
131        let arguments = params.arguments.unwrap_or_default();
132        let result = tool.handler.call(arguments).await?;
133
134        Ok(CallToolResult {
135            content: result.content,
136            is_error: result.is_error,
137            structured_content: None,
138            meta: None,
139        })
140    }
141}
142
143/// Handler for resource-related requests
144pub struct ResourceHandler;
145
146impl ResourceHandler {
147    /// Handle resources/list request
148    pub async fn handle_list(
149        resources: &HashMap<String, crate::core::resource::Resource>,
150        params: Option<Value>,
151    ) -> McpResult<ListResourcesResult> {
152        let _params: ListResourcesParams = match params {
153            Some(p) => serde_json::from_value(p)
154                .map_err(|e| McpError::Validation(format!("Invalid list resources params: {e}")))?,
155            None => ListResourcesParams::default(),
156        };
157
158        // Pagination support will be added in future versions
159        let resources: Vec<ResourceInfo> = resources
160            .values()
161            .map(|resource| {
162                // Convert from core::resource::ResourceInfo to protocol::types::ResourceInfo
163                ResourceInfo {
164                    uri: resource.info.uri.clone(),
165                    name: resource.info.name.clone(),
166                    description: resource.info.description.clone(),
167                    mime_type: resource.info.mime_type.clone(),
168                    annotations: None,
169                    size: None,
170                    title: None,
171                    meta: None,
172                }
173            })
174            .collect();
175
176        Ok(ListResourcesResult {
177            resources,
178            next_cursor: None,
179            meta: None,
180        })
181    }
182
183    /// Handle resources/read request
184    pub async fn handle_read(
185        resources: &HashMap<String, crate::core::resource::Resource>,
186        params: Option<Value>,
187    ) -> McpResult<ReadResourceResult> {
188        let params: ReadResourceParams = match params {
189            Some(p) => serde_json::from_value(p)
190                .map_err(|e| McpError::Validation(format!("Invalid read resource params: {e}")))?,
191            None => {
192                return Err(McpError::Validation(
193                    "Missing resource read parameters".to_string(),
194                ));
195            }
196        };
197
198        if params.uri.is_empty() {
199            return Err(McpError::Validation(
200                "Resource URI cannot be empty".to_string(),
201            ));
202        }
203
204        let resource = resources
205            .get(&params.uri)
206            .ok_or_else(|| McpError::ResourceNotFound(params.uri.clone()))?;
207
208        // Query parameter extraction from URI will be implemented in future versions
209        let query_params = HashMap::new();
210        let contents = resource.handler.read(&params.uri, &query_params).await?;
211
212        Ok(ReadResourceResult {
213            contents,
214            meta: None,
215        })
216    }
217
218    /// Handle resources/subscribe request
219    pub async fn handle_subscribe(
220        resources: &HashMap<String, crate::core::resource::Resource>,
221        params: Option<Value>,
222    ) -> McpResult<SubscribeResourceResult> {
223        let params: SubscribeResourceParams = match params {
224            Some(p) => serde_json::from_value(p).map_err(|e| {
225                McpError::Validation(format!("Invalid subscribe resource params: {e}"))
226            })?,
227            None => {
228                return Err(McpError::Validation(
229                    "Missing resource subscribe parameters".to_string(),
230                ));
231            }
232        };
233
234        if params.uri.is_empty() {
235            return Err(McpError::Validation(
236                "Resource URI cannot be empty".to_string(),
237            ));
238        }
239
240        let resource = resources
241            .get(&params.uri)
242            .ok_or_else(|| McpError::ResourceNotFound(params.uri.clone()))?;
243
244        resource.handler.subscribe(&params.uri).await?;
245
246        Ok(SubscribeResourceResult { meta: None })
247    }
248
249    /// Handle resources/unsubscribe request
250    pub async fn handle_unsubscribe(
251        resources: &HashMap<String, crate::core::resource::Resource>,
252        params: Option<Value>,
253    ) -> McpResult<UnsubscribeResourceResult> {
254        let params: UnsubscribeResourceParams = match params {
255            Some(p) => serde_json::from_value(p).map_err(|e| {
256                McpError::Validation(format!("Invalid unsubscribe resource params: {e}"))
257            })?,
258            None => {
259                return Err(McpError::Validation(
260                    "Missing resource unsubscribe parameters".to_string(),
261                ));
262            }
263        };
264
265        if params.uri.is_empty() {
266            return Err(McpError::Validation(
267                "Resource URI cannot be empty".to_string(),
268            ));
269        }
270
271        let resource = resources
272            .get(&params.uri)
273            .ok_or_else(|| McpError::ResourceNotFound(params.uri.clone()))?;
274
275        resource.handler.unsubscribe(&params.uri).await?;
276
277        Ok(UnsubscribeResourceResult { meta: None })
278    }
279}
280
281/// Handler for prompt-related requests
282pub struct PromptHandler;
283
284impl PromptHandler {
285    /// Handle prompts/list request
286    pub async fn handle_list(
287        prompts: &HashMap<String, crate::core::prompt::Prompt>,
288        params: Option<Value>,
289    ) -> McpResult<ListPromptsResult> {
290        let _params: ListPromptsParams = match params {
291            Some(p) => serde_json::from_value(p)
292                .map_err(|e| McpError::Validation(format!("Invalid list prompts params: {e}")))?,
293            None => ListPromptsParams::default(),
294        };
295
296        // Pagination support will be added in future versions
297        let prompts: Vec<PromptInfo> = prompts
298            .values()
299            .map(|prompt| {
300                // Convert from core::prompt::PromptInfo to protocol::types::PromptInfo
301                PromptInfo {
302                    name: prompt.info.name.clone(),
303                    description: prompt.info.description.clone(),
304                    arguments: prompt.info.arguments.as_ref().map(|args| {
305                        args.iter()
306                            .map(|arg| PromptArgument {
307                                name: arg.name.clone(),
308                                description: arg.description.clone(),
309                                required: arg.required,
310                                title: None,
311                            })
312                            .collect()
313                    }),
314                    title: None,
315                    meta: None,
316                }
317            })
318            .collect();
319
320        Ok(ListPromptsResult {
321            prompts,
322            next_cursor: None,
323            meta: None,
324        })
325    }
326
327    /// Handle prompts/get request
328    pub async fn handle_get(
329        prompts: &HashMap<String, crate::core::prompt::Prompt>,
330        params: Option<Value>,
331    ) -> McpResult<GetPromptResult> {
332        let params: GetPromptParams = match params {
333            Some(p) => serde_json::from_value(p)
334                .map_err(|e| McpError::Validation(format!("Invalid get prompt params: {e}")))?,
335            None => {
336                return Err(McpError::Validation(
337                    "Missing prompt get parameters".to_string(),
338                ));
339            }
340        };
341
342        if params.name.is_empty() {
343            return Err(McpError::Validation(
344                "Prompt name cannot be empty".to_string(),
345            ));
346        }
347
348        let prompt = prompts
349            .get(&params.name)
350            .ok_or_else(|| McpError::PromptNotFound(params.name.clone()))?;
351
352        let arguments = params
353            .arguments
354            .unwrap_or_default()
355            .into_iter()
356            .map(|(k, v)| (k, serde_json::Value::String(v)))
357            .collect();
358        let result = prompt.handler.get(arguments).await?;
359
360        Ok(GetPromptResult {
361            description: result.description,
362            messages: result
363                .messages
364                .into_iter()
365                .map(|msg| {
366                    // Convert from core::prompt::PromptMessage to protocol::types::PromptMessage
367                    PromptMessage {
368                        role: msg.role,
369                        content: match msg.content {
370                            ContentBlock::Text { text, .. } => ContentBlock::Text {
371                                text,
372                                annotations: None,
373                                meta: None,
374                            },
375                            ContentBlock::Image {
376                                data, mime_type, ..
377                            } => ContentBlock::Image {
378                                data,
379                                mime_type,
380                                annotations: None,
381                                meta: None,
382                            },
383                            other => other,
384                        },
385                    }
386                })
387                .collect(),
388            meta: None,
389        })
390    }
391}
392
393/// Handler for sampling requests
394pub struct SamplingHandler;
395
396impl SamplingHandler {
397    /// Handle sampling/createMessage request
398    pub async fn handle_create_message(_params: Option<Value>) -> McpResult<CreateMessageResult> {
399        // Note: Sampling is typically handled by the client side (LLM),
400        // but servers can provide sampling capabilities if they have access to LLMs
401        Err(McpError::Protocol(
402            "Sampling not implemented on server side".to_string(),
403        ))
404    }
405}
406
407/// Handler for logging requests
408pub struct LoggingHandler;
409
410impl LoggingHandler {
411    /// Handle logging/setLevel request
412    pub async fn handle_set_level(params: Option<Value>) -> McpResult<SetLoggingLevelResult> {
413        let _params: SetLoggingLevelParams = match params {
414            Some(p) => serde_json::from_value(p).map_err(|e| {
415                McpError::Validation(format!("Invalid set logging level params: {e}"))
416            })?,
417            None => {
418                return Err(McpError::Validation(
419                    "Missing logging level parameters".to_string(),
420                ));
421            }
422        };
423
424        // Logging level management feature planned for future implementation
425        // This would typically integrate with a logging framework like tracing
426
427        Ok(SetLoggingLevelResult { meta: None })
428    }
429}
430
431/// Handler for ping requests
432pub struct PingHandler;
433
434impl PingHandler {
435    /// Handle ping request
436    pub async fn handle(_params: Option<Value>) -> McpResult<PingResult> {
437        Ok(PingResult { meta: None })
438    }
439}
440
441/// Helper functions for common validation patterns
442pub mod validation {
443    use super::*;
444
445    /// Validate that required parameters are present
446    pub fn require_params<T>(params: Option<Value>, error_msg: &str) -> McpResult<T>
447    where
448        T: serde::de::DeserializeOwned,
449    {
450        match params {
451            Some(p) => serde_json::from_value(p)
452                .map_err(|e| McpError::Validation(format!("{error_msg}: {e}"))),
453            None => Err(McpError::Validation(error_msg.to_string())),
454        }
455    }
456
457    /// Validate that a string parameter is not empty
458    pub fn require_non_empty_string(value: &str, field_name: &str) -> McpResult<()> {
459        if value.is_empty() {
460            Err(McpError::Validation(format!(
461                "{field_name} cannot be empty"
462            )))
463        } else {
464            Ok(())
465        }
466    }
467
468    /// Validate URI format
469    pub fn validate_uri_format(uri: &str) -> McpResult<()> {
470        if uri.is_empty() {
471            return Err(McpError::Validation("URI cannot be empty".to_string()));
472        }
473
474        // Basic URI validation - check for scheme or absolute path
475        if !uri.contains("://") && !uri.starts_with('/') && !uri.starts_with("file:") {
476            return Err(McpError::Validation(
477                "URI must have a scheme or be an absolute path".to_string(),
478            ));
479        }
480
481        Ok(())
482    }
483}
484
485/// Notification builders for common server events
486pub mod notifications {
487    use super::*;
488
489    /// Create a tools list changed notification
490    pub fn tools_list_changed() -> McpResult<JsonRpcNotification> {
491        Ok(JsonRpcNotification::new(
492            methods::TOOLS_LIST_CHANGED.to_string(),
493            Some(ToolListChangedParams { meta: None }),
494        )?)
495    }
496
497    /// Create a resources list changed notification
498    pub fn resources_list_changed() -> McpResult<JsonRpcNotification> {
499        Ok(JsonRpcNotification::new(
500            methods::RESOURCES_LIST_CHANGED.to_string(),
501            Some(ResourceListChangedParams { meta: None }),
502        )?)
503    }
504
505    /// Create a prompts list changed notification
506    pub fn prompts_list_changed() -> McpResult<JsonRpcNotification> {
507        Ok(JsonRpcNotification::new(
508            methods::PROMPTS_LIST_CHANGED.to_string(),
509            Some(PromptListChangedParams { meta: None }),
510        )?)
511    }
512
513    /// Create a resource updated notification
514    pub fn resource_updated(uri: String) -> McpResult<JsonRpcNotification> {
515        Ok(JsonRpcNotification::new(
516            methods::RESOURCES_UPDATED.to_string(),
517            Some(ResourceUpdatedParams { uri }),
518        )?)
519    }
520
521    /// Create a progress notification
522    pub fn progress(
523        progress_token: String,
524        progress: f32,
525        total: Option<f32>,
526    ) -> McpResult<JsonRpcNotification> {
527        Ok(JsonRpcNotification::new(
528            methods::PROGRESS.to_string(),
529            Some(ProgressParams {
530                progress_token: serde_json::Value::String(progress_token),
531                progress,
532                total,
533                message: None,
534            }),
535        )?)
536    }
537
538    /// Create a logging message notification
539    pub fn log_message(
540        level: LoggingLevel,
541        logger: Option<String>,
542        data: Value,
543    ) -> McpResult<JsonRpcNotification> {
544        Ok(JsonRpcNotification::new(
545            methods::LOGGING_MESSAGE.to_string(),
546            Some(LoggingMessageParams {
547                level,
548                logger,
549                data,
550            }),
551        )?)
552    }
553}
554
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A well-formed initialize request is accepted and echoes server info.
    #[tokio::test]
    async fn test_initialize_handler() {
        let info = ServerInfo {
            name: "test-server".to_string(),
            version: "1.0.0".to_string(),
            title: Some("Test Server".to_string()),
        };
        let caps = ServerCapabilities::default();
        let request = json!({
            "protocolVersion": LATEST_PROTOCOL_VERSION,
            "capabilities": {},
            "clientInfo": {
                "name": "test-client",
                "version": "1.0.0"
            }
        });

        let outcome = InitializeHandler::handle(&info, &caps, Some(request)).await;
        assert!(outcome.is_ok());

        let response = outcome.unwrap();
        assert_eq!(response.server_info.name, "test-server");
        assert_eq!(response.protocol_version, LATEST_PROTOCOL_VERSION);
    }

    /// Ping always succeeds, even without parameters.
    #[tokio::test]
    async fn test_ping_handler() {
        assert!(PingHandler::handle(None).await.is_ok());
    }

    /// Exercise the string and URI validation helpers.
    #[test]
    fn test_validation_helpers() {
        assert!(validation::require_non_empty_string("test", "field").is_ok());
        assert!(validation::require_non_empty_string("", "field").is_err());

        for uri in ["https://example.com", "file:///path", "/absolute/path"] {
            assert!(
                validation::validate_uri_format(uri).is_ok(),
                "expected {uri:?} to be accepted"
            );
        }
        for uri in ["", "invalid"] {
            assert!(
                validation::validate_uri_format(uri).is_err(),
                "expected {uri:?} to be rejected"
            );
        }
    }

    /// Every notification builder produces a notification without error.
    #[test]
    fn test_notification_builders() {
        assert!(notifications::tools_list_changed().is_ok());
        assert!(notifications::resources_list_changed().is_ok());
        assert!(notifications::prompts_list_changed().is_ok());
        assert!(notifications::resource_updated("file:///test".to_string()).is_ok());
        assert!(notifications::progress("token".to_string(), 0.5, Some(100.0)).is_ok());

        let log = notifications::log_message(
            LoggingLevel::Info,
            Some("test".to_string()),
            json!({"message": "test log"}),
        );
        assert!(log.is_ok());
    }
}