mcp_protocol_sdk/server/handlers.rs

1//! MCP server request handlers
2//!
3//! This module provides specialized handlers for different types of MCP requests,
4//! implementing the business logic for each protocol operation.
5
6use serde_json::Value;
7use std::collections::HashMap;
8
9use crate::core::error::{McpError, McpResult};
10use crate::protocol::{messages::*, methods, types::*, LATEST_PROTOCOL_VERSION};
11
12/// Handler for initialization requests
13pub struct InitializeHandler;
14
15impl InitializeHandler {
16    /// Process an initialize request
17    pub async fn handle(
18        server_info: &ServerInfo,
19        capabilities: &ServerCapabilities,
20        params: Option<Value>,
21    ) -> McpResult<InitializeResult> {
22        let params: InitializeParams = match params {
23            Some(p) => serde_json::from_value(p)
24                .map_err(|e| McpError::Validation(format!("Invalid initialize params: {}", e)))?,
25            None => {
26                return Err(McpError::Validation(
27                    "Missing initialize parameters".to_string(),
28                ))
29            }
30        };
31
32        // Validate protocol version compatibility
33        if params.protocol_version != LATEST_PROTOCOL_VERSION {
34            return Err(McpError::Protocol(format!(
35                "Unsupported protocol version: {}. Expected: {}",
36                params.protocol_version, LATEST_PROTOCOL_VERSION
37            )));
38        }
39
40        // Validate client info
41        if params.client_info.name.is_empty() {
42            return Err(McpError::Validation(
43                "Client name cannot be empty".to_string(),
44            ));
45        }
46
47        if params.client_info.version.is_empty() {
48            return Err(McpError::Validation(
49                "Client version cannot be empty".to_string(),
50            ));
51        }
52
53        Ok(InitializeResult::new(
54            LATEST_PROTOCOL_VERSION.to_string(),
55            capabilities.clone(),
56            server_info.clone(),
57        ))
58    }
59}
60
61/// Handler for tool-related requests
62pub struct ToolHandler;
63
64impl ToolHandler {
65    /// Handle tools/list request
66    pub async fn handle_list(
67        tools: &HashMap<String, crate::core::tool::Tool>,
68        params: Option<Value>,
69    ) -> McpResult<ListToolsResult> {
70        let _params: ListToolsParams = match params {
71            Some(p) => serde_json::from_value(p)
72                .map_err(|e| McpError::Validation(format!("Invalid list tools params: {}", e)))?,
73            None => ListToolsParams::default(),
74        };
75
76        // Pagination support will be added in future versions
77        let tools: Vec<ToolInfo> = tools
78            .values()
79            .filter(|tool| tool.enabled)
80            .map(|tool| {
81                // Convert from core::tool::ToolInfo to protocol::types::ToolInfo
82                ToolInfo {
83                    name: tool.info.name.clone(),
84                    description: tool.info.description.clone(),
85                    input_schema: tool.info.input_schema.clone(),
86                    annotations: None,
87                }
88            })
89            .collect();
90
91        Ok(ListToolsResult {
92            tools,
93            next_cursor: None,
94            meta: None,
95        })
96    }
97
98    /// Handle tools/call request
99    pub async fn handle_call(
100        tools: &HashMap<String, crate::core::tool::Tool>,
101        params: Option<Value>,
102    ) -> McpResult<CallToolResult> {
103        let params: CallToolParams = match params {
104            Some(p) => serde_json::from_value(p)
105                .map_err(|e| McpError::Validation(format!("Invalid call tool params: {}", e)))?,
106            None => {
107                return Err(McpError::Validation(
108                    "Missing tool call parameters".to_string(),
109                ))
110            }
111        };
112
113        if params.name.is_empty() {
114            return Err(McpError::Validation(
115                "Tool name cannot be empty".to_string(),
116            ));
117        }
118
119        let tool = tools
120            .get(&params.name)
121            .ok_or_else(|| McpError::ToolNotFound(params.name.clone()))?;
122
123        if !tool.enabled {
124            return Err(McpError::ToolNotFound(format!(
125                "Tool '{}' is disabled",
126                params.name
127            )));
128        }
129
130        let arguments = params.arguments.unwrap_or_default();
131        let result = tool.handler.call(arguments).await?;
132
133        Ok(CallToolResult {
134            content: result.content,
135            is_error: result.is_error,
136            meta: None,
137        })
138    }
139}
140
141/// Handler for resource-related requests
142pub struct ResourceHandler;
143
144impl ResourceHandler {
145    /// Handle resources/list request
146    pub async fn handle_list(
147        resources: &HashMap<String, crate::core::resource::Resource>,
148        params: Option<Value>,
149    ) -> McpResult<ListResourcesResult> {
150        let _params: ListResourcesParams = match params {
151            Some(p) => serde_json::from_value(p).map_err(|e| {
152                McpError::Validation(format!("Invalid list resources params: {}", e))
153            })?,
154            None => ListResourcesParams::default(),
155        };
156
157        // Pagination support will be added in future versions
158        let resources: Vec<ResourceInfo> = resources
159            .values()
160            .map(|resource| {
161                // Convert from core::resource::ResourceInfo to protocol::types::ResourceInfo
162                ResourceInfo {
163                    uri: resource.info.uri.clone(),
164                    name: resource.info.name.clone(),
165                    description: resource.info.description.clone(),
166                    mime_type: resource.info.mime_type.clone(),
167                    annotations: None,
168                    size: None,
169                }
170            })
171            .collect();
172
173        Ok(ListResourcesResult {
174            resources,
175            next_cursor: None,
176            meta: None,
177        })
178    }
179
180    /// Handle resources/read request
181    pub async fn handle_read(
182        resources: &HashMap<String, crate::core::resource::Resource>,
183        params: Option<Value>,
184    ) -> McpResult<ReadResourceResult> {
185        let params: ReadResourceParams = match params {
186            Some(p) => serde_json::from_value(p).map_err(|e| {
187                McpError::Validation(format!("Invalid read resource params: {}", e))
188            })?,
189            None => {
190                return Err(McpError::Validation(
191                    "Missing resource read parameters".to_string(),
192                ))
193            }
194        };
195
196        if params.uri.is_empty() {
197            return Err(McpError::Validation(
198                "Resource URI cannot be empty".to_string(),
199            ));
200        }
201
202        let resource = resources
203            .get(&params.uri)
204            .ok_or_else(|| McpError::ResourceNotFound(params.uri.clone()))?;
205
206        // Query parameter extraction from URI will be implemented in future versions
207        let query_params = HashMap::new();
208        let contents = resource.handler.read(&params.uri, &query_params).await?;
209
210        Ok(ReadResourceResult {
211            contents,
212            meta: None,
213        })
214    }
215
216    /// Handle resources/subscribe request
217    pub async fn handle_subscribe(
218        resources: &HashMap<String, crate::core::resource::Resource>,
219        params: Option<Value>,
220    ) -> McpResult<SubscribeResourceResult> {
221        let params: SubscribeResourceParams = match params {
222            Some(p) => serde_json::from_value(p).map_err(|e| {
223                McpError::Validation(format!("Invalid subscribe resource params: {}", e))
224            })?,
225            None => {
226                return Err(McpError::Validation(
227                    "Missing resource subscribe parameters".to_string(),
228                ))
229            }
230        };
231
232        if params.uri.is_empty() {
233            return Err(McpError::Validation(
234                "Resource URI cannot be empty".to_string(),
235            ));
236        }
237
238        let resource = resources
239            .get(&params.uri)
240            .ok_or_else(|| McpError::ResourceNotFound(params.uri.clone()))?;
241
242        resource.handler.subscribe(&params.uri).await?;
243
244        Ok(SubscribeResourceResult { meta: None })
245    }
246
247    /// Handle resources/unsubscribe request
248    pub async fn handle_unsubscribe(
249        resources: &HashMap<String, crate::core::resource::Resource>,
250        params: Option<Value>,
251    ) -> McpResult<UnsubscribeResourceResult> {
252        let params: UnsubscribeResourceParams = match params {
253            Some(p) => serde_json::from_value(p).map_err(|e| {
254                McpError::Validation(format!("Invalid unsubscribe resource params: {}", e))
255            })?,
256            None => {
257                return Err(McpError::Validation(
258                    "Missing resource unsubscribe parameters".to_string(),
259                ))
260            }
261        };
262
263        if params.uri.is_empty() {
264            return Err(McpError::Validation(
265                "Resource URI cannot be empty".to_string(),
266            ));
267        }
268
269        let resource = resources
270            .get(&params.uri)
271            .ok_or_else(|| McpError::ResourceNotFound(params.uri.clone()))?;
272
273        resource.handler.unsubscribe(&params.uri).await?;
274
275        Ok(UnsubscribeResourceResult { meta: None })
276    }
277}
278
279/// Handler for prompt-related requests
280pub struct PromptHandler;
281
282impl PromptHandler {
283    /// Handle prompts/list request
284    pub async fn handle_list(
285        prompts: &HashMap<String, crate::core::prompt::Prompt>,
286        params: Option<Value>,
287    ) -> McpResult<ListPromptsResult> {
288        let _params: ListPromptsParams = match params {
289            Some(p) => serde_json::from_value(p)
290                .map_err(|e| McpError::Validation(format!("Invalid list prompts params: {}", e)))?,
291            None => ListPromptsParams::default(),
292        };
293
294        // Pagination support will be added in future versions
295        let prompts: Vec<PromptInfo> = prompts
296            .values()
297            .map(|prompt| {
298                // Convert from core::prompt::PromptInfo to protocol::types::PromptInfo
299                PromptInfo {
300                    name: prompt.info.name.clone(),
301                    description: prompt.info.description.clone(),
302                    arguments: prompt.info.arguments.as_ref().map(|args| {
303                        args.iter()
304                            .map(|arg| PromptArgument {
305                                name: arg.name.clone(),
306                                description: arg.description.clone(),
307                                required: arg.required,
308                            })
309                            .collect()
310                    }),
311                }
312            })
313            .collect();
314
315        Ok(ListPromptsResult {
316            prompts,
317            next_cursor: None,
318            meta: None,
319        })
320    }
321
322    /// Handle prompts/get request
323    pub async fn handle_get(
324        prompts: &HashMap<String, crate::core::prompt::Prompt>,
325        params: Option<Value>,
326    ) -> McpResult<GetPromptResult> {
327        let params: GetPromptParams = match params {
328            Some(p) => serde_json::from_value(p)
329                .map_err(|e| McpError::Validation(format!("Invalid get prompt params: {}", e)))?,
330            None => {
331                return Err(McpError::Validation(
332                    "Missing prompt get parameters".to_string(),
333                ))
334            }
335        };
336
337        if params.name.is_empty() {
338            return Err(McpError::Validation(
339                "Prompt name cannot be empty".to_string(),
340            ));
341        }
342
343        let prompt = prompts
344            .get(&params.name)
345            .ok_or_else(|| McpError::PromptNotFound(params.name.clone()))?;
346
347        let arguments = params
348            .arguments
349            .unwrap_or_default()
350            .into_iter()
351            .map(|(k, v)| (k, serde_json::Value::String(v)))
352            .collect();
353        let result = prompt.handler.get(arguments).await?;
354
355        Ok(GetPromptResult {
356            description: result.description,
357            messages: result
358                .messages
359                .into_iter()
360                .map(|msg| {
361                    // Convert from core::prompt::PromptMessage to protocol::types::PromptMessage
362                    PromptMessage {
363                        role: msg.role,
364                        content: match msg.content {
365                            Content::Text { text, .. } => Content::Text {
366                                text,
367                                annotations: None,
368                            },
369                            Content::Image {
370                                data, mime_type, ..
371                            } => Content::Image {
372                                data,
373                                mime_type,
374                                annotations: None,
375                            },
376                            other => other,
377                        },
378                    }
379                })
380                .collect(),
381            meta: None,
382        })
383    }
384}
385
386/// Handler for sampling requests
387pub struct SamplingHandler;
388
389impl SamplingHandler {
390    /// Handle sampling/createMessage request
391    pub async fn handle_create_message(_params: Option<Value>) -> McpResult<CreateMessageResult> {
392        // Note: Sampling is typically handled by the client side (LLM),
393        // but servers can provide sampling capabilities if they have access to LLMs
394        Err(McpError::Protocol(
395            "Sampling not implemented on server side".to_string(),
396        ))
397    }
398}
399
400/// Handler for logging requests
401pub struct LoggingHandler;
402
403impl LoggingHandler {
404    /// Handle logging/setLevel request
405    pub async fn handle_set_level(params: Option<Value>) -> McpResult<SetLoggingLevelResult> {
406        let _params: SetLoggingLevelParams = match params {
407            Some(p) => serde_json::from_value(p).map_err(|e| {
408                McpError::Validation(format!("Invalid set logging level params: {}", e))
409            })?,
410            None => {
411                return Err(McpError::Validation(
412                    "Missing logging level parameters".to_string(),
413                ))
414            }
415        };
416
417        // Logging level management feature planned for future implementation
418        // This would typically integrate with a logging framework like tracing
419
420        Ok(SetLoggingLevelResult { meta: None })
421    }
422}
423
424/// Handler for ping requests
425pub struct PingHandler;
426
427impl PingHandler {
428    /// Handle ping request
429    pub async fn handle(_params: Option<Value>) -> McpResult<PingResult> {
430        Ok(PingResult { meta: None })
431    }
432}
433
434/// Helper functions for common validation patterns
435pub mod validation {
436    use super::*;
437
438    /// Validate that required parameters are present
439    pub fn require_params<T>(params: Option<Value>, error_msg: &str) -> McpResult<T>
440    where
441        T: serde::de::DeserializeOwned,
442    {
443        match params {
444            Some(p) => serde_json::from_value(p)
445                .map_err(|e| McpError::Validation(format!("{}: {}", error_msg, e))),
446            None => Err(McpError::Validation(error_msg.to_string())),
447        }
448    }
449
450    /// Validate that a string parameter is not empty
451    pub fn require_non_empty_string(value: &str, field_name: &str) -> McpResult<()> {
452        if value.is_empty() {
453            Err(McpError::Validation(format!(
454                "{} cannot be empty",
455                field_name
456            )))
457        } else {
458            Ok(())
459        }
460    }
461
462    /// Validate URI format
463    pub fn validate_uri_format(uri: &str) -> McpResult<()> {
464        if uri.is_empty() {
465            return Err(McpError::Validation("URI cannot be empty".to_string()));
466        }
467
468        // Basic URI validation - check for scheme or absolute path
469        if !uri.contains("://") && !uri.starts_with('/') && !uri.starts_with("file:") {
470            return Err(McpError::Validation(
471                "URI must have a scheme or be an absolute path".to_string(),
472            ));
473        }
474
475        Ok(())
476    }
477}
478
479/// Notification builders for common server events
480pub mod notifications {
481    use super::*;
482
483    /// Create a tools list changed notification
484    pub fn tools_list_changed() -> McpResult<JsonRpcNotification> {
485        Ok(JsonRpcNotification::new(
486            methods::TOOLS_LIST_CHANGED.to_string(),
487            Some(ToolListChangedParams { meta: None }),
488        )?)
489    }
490
491    /// Create a resources list changed notification
492    pub fn resources_list_changed() -> McpResult<JsonRpcNotification> {
493        Ok(JsonRpcNotification::new(
494            methods::RESOURCES_LIST_CHANGED.to_string(),
495            Some(ResourceListChangedParams { meta: None }),
496        )?)
497    }
498
499    /// Create a prompts list changed notification
500    pub fn prompts_list_changed() -> McpResult<JsonRpcNotification> {
501        Ok(JsonRpcNotification::new(
502            methods::PROMPTS_LIST_CHANGED.to_string(),
503            Some(PromptListChangedParams { meta: None }),
504        )?)
505    }
506
507    /// Create a resource updated notification
508    pub fn resource_updated(uri: String) -> McpResult<JsonRpcNotification> {
509        Ok(JsonRpcNotification::new(
510            methods::RESOURCES_UPDATED.to_string(),
511            Some(ResourceUpdatedParams { uri }),
512        )?)
513    }
514
515    /// Create a progress notification
516    pub fn progress(
517        progress_token: String,
518        progress: f32,
519        total: Option<f32>,
520    ) -> McpResult<JsonRpcNotification> {
521        Ok(JsonRpcNotification::new(
522            methods::PROGRESS.to_string(),
523            Some(ProgressParams {
524                progress_token: serde_json::Value::String(progress_token),
525                progress,
526                total,
527                message: None,
528            }),
529        )?)
530    }
531
532    /// Create a logging message notification
533    pub fn log_message(
534        level: LoggingLevel,
535        logger: Option<String>,
536        data: Value,
537    ) -> McpResult<JsonRpcNotification> {
538        Ok(JsonRpcNotification::new(
539            methods::LOGGING_MESSAGE.to_string(),
540            Some(LoggingMessageParams {
541                level,
542                logger,
543                data,
544            }),
545        )?)
546    }
547}
548
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A well-formed initialize request at the latest version succeeds and
    /// echoes the server identity and protocol version.
    #[tokio::test]
    async fn test_initialize_handler() {
        let server_info = ServerInfo {
            name: "test-server".to_string(),
            version: "1.0.0".to_string(),
        };
        let capabilities = ServerCapabilities::default();
        let request = json!({
            "clientInfo": { "name": "test-client", "version": "1.0.0" },
            "capabilities": {},
            "protocolVersion": LATEST_PROTOCOL_VERSION
        });

        let init_result = InitializeHandler::handle(&server_info, &capabilities, Some(request))
            .await
            .expect("initialize with the latest protocol version should succeed");

        assert_eq!(init_result.server_info.name, "test-server");
        assert_eq!(init_result.protocol_version, LATEST_PROTOCOL_VERSION);
    }

    /// Ping always succeeds, even without params.
    #[tokio::test]
    async fn test_ping_handler() {
        assert!(PingHandler::handle(None).await.is_ok());
    }

    #[test]
    fn test_validation_helpers() {
        // require_non_empty_string
        assert!(validation::require_non_empty_string("test", "field").is_ok());
        assert!(validation::require_non_empty_string("", "field").is_err());

        // validate_uri_format: inputs that must be accepted
        let accepted = ["https://example.com", "file:///path", "/absolute/path"];
        for uri in accepted.iter() {
            assert!(
                validation::validate_uri_format(uri).is_ok(),
                "expected {:?} to be accepted",
                uri
            );
        }

        // validate_uri_format: inputs that must be rejected
        let rejected = ["", "invalid"];
        for uri in rejected.iter() {
            assert!(
                validation::validate_uri_format(uri).is_err(),
                "expected {:?} to be rejected",
                uri
            );
        }
    }

    #[test]
    fn test_notification_builders() {
        let built = [
            notifications::tools_list_changed(),
            notifications::resources_list_changed(),
            notifications::prompts_list_changed(),
            notifications::resource_updated("file:///test".to_string()),
            notifications::progress("token".to_string(), 0.5, Some(100.0)),
            notifications::log_message(
                LoggingLevel::Info,
                Some("test".to_string()),
                json!({"message": "test log"}),
            ),
        ];
        for (index, outcome) in built.iter().enumerate() {
            assert!(outcome.is_ok(), "notification builder #{} failed", index);
        }
    }
}