mcp_protocol_sdk/server/
handlers.rs

1//! MCP server request handlers
2//!
3//! This module provides specialized handlers for different types of MCP requests,
4//! implementing the business logic for each protocol operation.
5
6use serde_json::Value;
7use std::collections::HashMap;
8
9use crate::core::error::{McpError, McpResult};
10use crate::protocol::{messages::*, types::*};
11
12/// Handler for initialization requests
13pub struct InitializeHandler;
14
15impl InitializeHandler {
16    /// Process an initialize request
17    pub async fn handle(
18        server_info: &ServerInfo,
19        capabilities: &ServerCapabilities,
20        params: Option<Value>,
21    ) -> McpResult<InitializeResult> {
22        let params: InitializeParams = match params {
23            Some(p) => serde_json::from_value(p)
24                .map_err(|e| McpError::Validation(format!("Invalid initialize params: {}", e)))?,
25            None => {
26                return Err(McpError::Validation(
27                    "Missing initialize parameters".to_string(),
28                ))
29            }
30        };
31
32        // Validate protocol version compatibility
33        if params.protocol_version != MCP_PROTOCOL_VERSION {
34            return Err(McpError::Protocol(format!(
35                "Unsupported protocol version: {}. Expected: {}",
36                params.protocol_version, MCP_PROTOCOL_VERSION
37            )));
38        }
39
40        // Validate client info
41        if params.client_info.name.is_empty() {
42            return Err(McpError::Validation(
43                "Client name cannot be empty".to_string(),
44            ));
45        }
46
47        if params.client_info.version.is_empty() {
48            return Err(McpError::Validation(
49                "Client version cannot be empty".to_string(),
50            ));
51        }
52
53        Ok(InitializeResult::new(
54            server_info.clone(),
55            capabilities.clone(),
56            MCP_PROTOCOL_VERSION.to_string(),
57        ))
58    }
59}
60
61/// Handler for tool-related requests
62pub struct ToolHandler;
63
64impl ToolHandler {
65    /// Handle tools/list request
66    pub async fn handle_list(
67        tools: &HashMap<String, crate::core::tool::Tool>,
68        params: Option<Value>,
69    ) -> McpResult<ListToolsResult> {
70        let _params: ListToolsParams = match params {
71            Some(p) => serde_json::from_value(p)
72                .map_err(|e| McpError::Validation(format!("Invalid list tools params: {}", e)))?,
73            None => ListToolsParams::default(),
74        };
75
76        // Pagination support will be added in future versions
77        let tools: Vec<ToolInfo> = tools
78            .values()
79            .filter(|tool| tool.enabled)
80            .map(|tool| {
81                // Convert from core::tool::ToolInfo to protocol::types::ToolInfo
82                ToolInfo {
83                    name: tool.info.name.clone(),
84                    description: tool.info.description.clone(),
85                    input_schema: tool.info.input_schema.clone(),
86                }
87            })
88            .collect();
89
90        Ok(ListToolsResult {
91            tools,
92            next_cursor: None,
93        })
94    }
95
96    /// Handle tools/call request
97    pub async fn handle_call(
98        tools: &HashMap<String, crate::core::tool::Tool>,
99        params: Option<Value>,
100    ) -> McpResult<CallToolResult> {
101        let params: CallToolParams = match params {
102            Some(p) => serde_json::from_value(p)
103                .map_err(|e| McpError::Validation(format!("Invalid call tool params: {}", e)))?,
104            None => {
105                return Err(McpError::Validation(
106                    "Missing tool call parameters".to_string(),
107                ))
108            }
109        };
110
111        if params.name.is_empty() {
112            return Err(McpError::Validation(
113                "Tool name cannot be empty".to_string(),
114            ));
115        }
116
117        let tool = tools
118            .get(&params.name)
119            .ok_or_else(|| McpError::ToolNotFound(params.name.clone()))?;
120
121        if !tool.enabled {
122            return Err(McpError::ToolNotFound(format!(
123                "Tool '{}' is disabled",
124                params.name
125            )));
126        }
127
128        let arguments = params.arguments.unwrap_or_default();
129        let result = tool.handler.call(arguments).await?;
130
131        Ok(CallToolResult {
132            content: result.content,
133            is_error: result.is_error,
134        })
135    }
136}
137
138/// Handler for resource-related requests
139pub struct ResourceHandler;
140
141impl ResourceHandler {
142    /// Handle resources/list request
143    pub async fn handle_list(
144        resources: &HashMap<String, crate::core::resource::Resource>,
145        params: Option<Value>,
146    ) -> McpResult<ListResourcesResult> {
147        let _params: ListResourcesParams = match params {
148            Some(p) => serde_json::from_value(p).map_err(|e| {
149                McpError::Validation(format!("Invalid list resources params: {}", e))
150            })?,
151            None => ListResourcesParams::default(),
152        };
153
154        // Pagination support will be added in future versions
155        let resources: Vec<ResourceInfo> = resources
156            .values()
157            .map(|resource| {
158                // Convert from core::resource::ResourceInfo to protocol::types::ResourceInfo
159                ResourceInfo {
160                    uri: resource.info.uri.clone(),
161                    name: resource.info.name.clone(),
162                    description: resource.info.description.clone(),
163                    mime_type: resource.info.mime_type.clone(),
164                }
165            })
166            .collect();
167
168        Ok(ListResourcesResult {
169            resources,
170            next_cursor: None,
171        })
172    }
173
174    /// Handle resources/read request
175    pub async fn handle_read(
176        resources: &HashMap<String, crate::core::resource::Resource>,
177        params: Option<Value>,
178    ) -> McpResult<ReadResourceResult> {
179        let params: ReadResourceParams = match params {
180            Some(p) => serde_json::from_value(p).map_err(|e| {
181                McpError::Validation(format!("Invalid read resource params: {}", e))
182            })?,
183            None => {
184                return Err(McpError::Validation(
185                    "Missing resource read parameters".to_string(),
186                ))
187            }
188        };
189
190        if params.uri.is_empty() {
191            return Err(McpError::Validation(
192                "Resource URI cannot be empty".to_string(),
193            ));
194        }
195
196        let resource = resources
197            .get(&params.uri)
198            .ok_or_else(|| McpError::ResourceNotFound(params.uri.clone()))?;
199
200        // Query parameter extraction from URI will be implemented in future versions
201        let query_params = HashMap::new();
202        let contents = resource.handler.read(&params.uri, &query_params).await?;
203
204        Ok(ReadResourceResult { contents })
205    }
206
207    /// Handle resources/subscribe request
208    pub async fn handle_subscribe(
209        resources: &HashMap<String, crate::core::resource::Resource>,
210        params: Option<Value>,
211    ) -> McpResult<SubscribeResourceResult> {
212        let params: SubscribeResourceParams = match params {
213            Some(p) => serde_json::from_value(p).map_err(|e| {
214                McpError::Validation(format!("Invalid subscribe resource params: {}", e))
215            })?,
216            None => {
217                return Err(McpError::Validation(
218                    "Missing resource subscribe parameters".to_string(),
219                ))
220            }
221        };
222
223        if params.uri.is_empty() {
224            return Err(McpError::Validation(
225                "Resource URI cannot be empty".to_string(),
226            ));
227        }
228
229        let resource = resources
230            .get(&params.uri)
231            .ok_or_else(|| McpError::ResourceNotFound(params.uri.clone()))?;
232
233        resource.handler.subscribe(&params.uri).await?;
234
235        Ok(SubscribeResourceResult {})
236    }
237
238    /// Handle resources/unsubscribe request
239    pub async fn handle_unsubscribe(
240        resources: &HashMap<String, crate::core::resource::Resource>,
241        params: Option<Value>,
242    ) -> McpResult<UnsubscribeResourceResult> {
243        let params: UnsubscribeResourceParams = match params {
244            Some(p) => serde_json::from_value(p).map_err(|e| {
245                McpError::Validation(format!("Invalid unsubscribe resource params: {}", e))
246            })?,
247            None => {
248                return Err(McpError::Validation(
249                    "Missing resource unsubscribe parameters".to_string(),
250                ))
251            }
252        };
253
254        if params.uri.is_empty() {
255            return Err(McpError::Validation(
256                "Resource URI cannot be empty".to_string(),
257            ));
258        }
259
260        let resource = resources
261            .get(&params.uri)
262            .ok_or_else(|| McpError::ResourceNotFound(params.uri.clone()))?;
263
264        resource.handler.unsubscribe(&params.uri).await?;
265
266        Ok(UnsubscribeResourceResult {})
267    }
268}
269
270/// Handler for prompt-related requests
271pub struct PromptHandler;
272
273impl PromptHandler {
274    /// Handle prompts/list request
275    pub async fn handle_list(
276        prompts: &HashMap<String, crate::core::prompt::Prompt>,
277        params: Option<Value>,
278    ) -> McpResult<ListPromptsResult> {
279        let _params: ListPromptsParams = match params {
280            Some(p) => serde_json::from_value(p)
281                .map_err(|e| McpError::Validation(format!("Invalid list prompts params: {}", e)))?,
282            None => ListPromptsParams::default(),
283        };
284
285        // Pagination support will be added in future versions
286        let prompts: Vec<PromptInfo> = prompts
287            .values()
288            .map(|prompt| {
289                // Convert from core::prompt::PromptInfo to protocol::types::PromptInfo
290                PromptInfo {
291                    name: prompt.info.name.clone(),
292                    description: prompt.info.description.clone(),
293                    arguments: prompt.info.arguments.as_ref().map(|args| {
294                        args.iter()
295                            .map(|arg| PromptArgument {
296                                name: arg.name.clone(),
297                                description: arg.description.clone(),
298                                required: arg.required,
299                            })
300                            .collect()
301                    }),
302                }
303            })
304            .collect();
305
306        Ok(ListPromptsResult {
307            prompts,
308            next_cursor: None,
309        })
310    }
311
312    /// Handle prompts/get request
313    pub async fn handle_get(
314        prompts: &HashMap<String, crate::core::prompt::Prompt>,
315        params: Option<Value>,
316    ) -> McpResult<GetPromptResult> {
317        let params: GetPromptParams = match params {
318            Some(p) => serde_json::from_value(p)
319                .map_err(|e| McpError::Validation(format!("Invalid get prompt params: {}", e)))?,
320            None => {
321                return Err(McpError::Validation(
322                    "Missing prompt get parameters".to_string(),
323                ))
324            }
325        };
326
327        if params.name.is_empty() {
328            return Err(McpError::Validation(
329                "Prompt name cannot be empty".to_string(),
330            ));
331        }
332
333        let prompt = prompts
334            .get(&params.name)
335            .ok_or_else(|| McpError::PromptNotFound(params.name.clone()))?;
336
337        let arguments = params.arguments.unwrap_or_default();
338        let result = prompt.handler.get(arguments).await?;
339
340        Ok(GetPromptResult {
341            description: result.description,
342            messages: result
343                .messages
344                .into_iter()
345                .map(|msg| {
346                    // Convert from core::prompt::PromptMessage to protocol::types::PromptMessage
347                    PromptMessage {
348                        role: msg.role,
349                        content: match msg.content {
350                            crate::protocol::types::PromptContent::Text { content_type, text } => {
351                                PromptContent::Text { content_type, text }
352                            }
353                            crate::protocol::types::PromptContent::Image {
354                                content_type,
355                                data,
356                                mime_type,
357                            } => PromptContent::Image {
358                                content_type,
359                                data,
360                                mime_type,
361                            },
362                        },
363                    }
364                })
365                .collect(),
366        })
367    }
368}
369
370/// Handler for sampling requests
371pub struct SamplingHandler;
372
373impl SamplingHandler {
374    /// Handle sampling/createMessage request
375    pub async fn handle_create_message(_params: Option<Value>) -> McpResult<CreateMessageResult> {
376        // Note: Sampling is typically handled by the client side (LLM),
377        // but servers can provide sampling capabilities if they have access to LLMs
378        Err(McpError::Protocol(
379            "Sampling not implemented on server side".to_string(),
380        ))
381    }
382}
383
384/// Handler for logging requests
385pub struct LoggingHandler;
386
387impl LoggingHandler {
388    /// Handle logging/setLevel request
389    pub async fn handle_set_level(params: Option<Value>) -> McpResult<SetLoggingLevelResult> {
390        let _params: SetLoggingLevelParams = match params {
391            Some(p) => serde_json::from_value(p).map_err(|e| {
392                McpError::Validation(format!("Invalid set logging level params: {}", e))
393            })?,
394            None => {
395                return Err(McpError::Validation(
396                    "Missing logging level parameters".to_string(),
397                ))
398            }
399        };
400
401        // Logging level management feature planned for future implementation
402        // This would typically integrate with a logging framework like tracing
403
404        Ok(SetLoggingLevelResult {})
405    }
406}
407
408/// Handler for ping requests
409pub struct PingHandler;
410
411impl PingHandler {
412    /// Handle ping request
413    pub async fn handle(_params: Option<Value>) -> McpResult<PingResult> {
414        Ok(PingResult {})
415    }
416}
417
418/// Helper functions for common validation patterns
419pub mod validation {
420    use super::*;
421
422    /// Validate that required parameters are present
423    pub fn require_params<T>(params: Option<Value>, error_msg: &str) -> McpResult<T>
424    where
425        T: serde::de::DeserializeOwned,
426    {
427        match params {
428            Some(p) => serde_json::from_value(p)
429                .map_err(|e| McpError::Validation(format!("{}: {}", error_msg, e))),
430            None => Err(McpError::Validation(error_msg.to_string())),
431        }
432    }
433
434    /// Validate that a string parameter is not empty
435    pub fn require_non_empty_string(value: &str, field_name: &str) -> McpResult<()> {
436        if value.is_empty() {
437            Err(McpError::Validation(format!(
438                "{} cannot be empty",
439                field_name
440            )))
441        } else {
442            Ok(())
443        }
444    }
445
446    /// Validate URI format
447    pub fn validate_uri_format(uri: &str) -> McpResult<()> {
448        if uri.is_empty() {
449            return Err(McpError::Validation("URI cannot be empty".to_string()));
450        }
451
452        // Basic URI validation - check for scheme or absolute path
453        if !uri.contains("://") && !uri.starts_with('/') && !uri.starts_with("file:") {
454            return Err(McpError::Validation(
455                "URI must have a scheme or be an absolute path".to_string(),
456            ));
457        }
458
459        Ok(())
460    }
461}
462
463/// Notification builders for common server events
464pub mod notifications {
465    use super::*;
466
467    /// Create a tools list changed notification
468    pub fn tools_list_changed() -> McpResult<JsonRpcNotification> {
469        Ok(JsonRpcNotification::new(
470            methods::TOOLS_LIST_CHANGED.to_string(),
471            Some(ToolListChangedParams {}),
472        )?)
473    }
474
475    /// Create a resources list changed notification
476    pub fn resources_list_changed() -> McpResult<JsonRpcNotification> {
477        Ok(JsonRpcNotification::new(
478            methods::RESOURCES_LIST_CHANGED.to_string(),
479            Some(ResourceListChangedParams {}),
480        )?)
481    }
482
483    /// Create a prompts list changed notification
484    pub fn prompts_list_changed() -> McpResult<JsonRpcNotification> {
485        Ok(JsonRpcNotification::new(
486            methods::PROMPTS_LIST_CHANGED.to_string(),
487            Some(PromptListChangedParams {}),
488        )?)
489    }
490
491    /// Create a resource updated notification
492    pub fn resource_updated(uri: String) -> McpResult<JsonRpcNotification> {
493        Ok(JsonRpcNotification::new(
494            methods::RESOURCES_UPDATED.to_string(),
495            Some(ResourceUpdatedParams { uri }),
496        )?)
497    }
498
499    /// Create a progress notification
500    pub fn progress(
501        progress_token: String,
502        progress: f32,
503        total: Option<u32>,
504    ) -> McpResult<JsonRpcNotification> {
505        Ok(JsonRpcNotification::new(
506            methods::PROGRESS.to_string(),
507            Some(ProgressParams {
508                progress_token,
509                progress,
510                total,
511            }),
512        )?)
513    }
514
515    /// Create a logging message notification
516    pub fn log_message(
517        level: LoggingLevel,
518        logger: Option<String>,
519        data: Value,
520    ) -> McpResult<JsonRpcNotification> {
521        Ok(JsonRpcNotification::new(
522            methods::LOGGING_MESSAGE.to_string(),
523            Some(LoggingMessageParams {
524                level,
525                logger,
526                data,
527            }),
528        )?)
529    }
530}
531
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A well-formed initialize request using the supported protocol version succeeds and
    /// echoes the server identity and protocol version back.
    #[tokio::test]
    async fn test_initialize_handler() {
        let server_info = ServerInfo {
            name: "test-server".to_string(),
            version: "1.0.0".to_string(),
        };
        let capabilities = ServerCapabilities::default();
        let request = json!({
            "protocolVersion": MCP_PROTOCOL_VERSION,
            "capabilities": {},
            "clientInfo": {
                "name": "test-client",
                "version": "1.0.0"
            }
        });

        let init_result = InitializeHandler::handle(&server_info, &capabilities, Some(request))
            .await
            .expect("initialize with the supported protocol version should succeed");

        assert_eq!(init_result.server_info.name, "test-server");
        assert_eq!(init_result.protocol_version, MCP_PROTOCOL_VERSION);
    }

    /// Ping needs no parameters and always succeeds.
    #[tokio::test]
    async fn test_ping_handler() {
        let outcome = PingHandler::handle(None).await;
        assert!(outcome.is_ok(), "ping must always succeed");
    }

    #[test]
    fn test_validation_helpers() {
        assert!(validation::require_non_empty_string("test", "field").is_ok());
        assert!(validation::require_non_empty_string("", "field").is_err());

        for uri in ["https://example.com", "file:///path", "/absolute/path"].iter() {
            assert!(
                validation::validate_uri_format(uri).is_ok(),
                "expected {:?} to be accepted",
                uri
            );
        }
        for uri in ["", "invalid"].iter() {
            assert!(
                validation::validate_uri_format(uri).is_err(),
                "expected {:?} to be rejected",
                uri
            );
        }
    }

    #[test]
    fn test_notification_builders() {
        let built = [
            notifications::tools_list_changed(),
            notifications::resources_list_changed(),
            notifications::prompts_list_changed(),
            notifications::resource_updated("file:///test".to_string()),
            notifications::progress("token".to_string(), 0.5, Some(100)),
            notifications::log_message(
                LoggingLevel::Info,
                Some("test".to_string()),
                json!({"message": "test log"}),
            ),
        ];

        for (index, notification) in built.iter().enumerate() {
            assert!(notification.is_ok(), "notification builder #{} failed", index);
        }
    }
}
596}