mcp_protocol_sdk/protocol/
validation.rs

1//! MCP protocol validation utilities
2//!
3//! This module provides validation functions for MCP protocol messages and types,
4//! ensuring that requests and responses conform to the protocol specification.
5
6use crate::core::error::{McpError, McpResult};
7use crate::protocol::messages::*;
8use crate::protocol::types::*;
9use serde_json::Value;
10
11/// Validates that a JSON-RPC message conforms to the specification
12pub fn validate_jsonrpc_message(message: &Value) -> McpResult<()> {
13    let obj = message
14        .as_object()
15        .ok_or_else(|| McpError::Validation("Message must be a JSON object".to_string()))?;
16
17    // Check required jsonrpc field
18    let jsonrpc = obj
19        .get("jsonrpc")
20        .and_then(|v| v.as_str())
21        .ok_or_else(|| McpError::Validation("Missing or invalid 'jsonrpc' field".to_string()))?;
22
23    if jsonrpc != "2.0" {
24        return Err(McpError::Validation("jsonrpc must be '2.0'".to_string()));
25    }
26
27    // Check that it has either 'method' (request/notification) or 'result'/'error' (response)
28    let has_method = obj.contains_key("method");
29    let has_result = obj.contains_key("result");
30    let has_error = obj.contains_key("error");
31    let has_id = obj.contains_key("id");
32
33    if has_method {
34        // Request or notification
35        if has_result || has_error {
36            return Err(McpError::Validation(
37                "Request/notification cannot have 'result' or 'error' fields".to_string(),
38            ));
39        }
40
41        // Requests must have an id, notifications must not
42        // We allow both for flexibility in parsing
43    } else if has_result || has_error {
44        // Response
45        if !has_id {
46            return Err(McpError::Validation(
47                "Response must have an 'id' field".to_string(),
48            ));
49        }
50
51        if has_result && has_error {
52            return Err(McpError::Validation(
53                "Response cannot have both 'result' and 'error' fields".to_string(),
54            ));
55        }
56    } else {
57        return Err(McpError::Validation(
58            "Message must be a request, response, or notification".to_string(),
59        ));
60    }
61
62    Ok(())
63}
64
65/// Validates a JSON-RPC request
66pub fn validate_jsonrpc_request(request: &JsonRpcRequest) -> McpResult<()> {
67    if request.jsonrpc != "2.0" {
68        return Err(McpError::Validation("jsonrpc must be '2.0'".to_string()));
69    }
70
71    if request.method.is_empty() {
72        return Err(McpError::Validation(
73            "Method name cannot be empty".to_string(),
74        ));
75    }
76
77    // Method names starting with "rpc." are reserved for JSON-RPC internal methods
78    if request.method.starts_with("rpc.") && !request.method.starts_with("rpc.discover") {
79        return Err(McpError::Validation(
80            "Method names starting with 'rpc.' are reserved".to_string(),
81        ));
82    }
83
84    Ok(())
85}
86
87/// Validates a JSON-RPC response
88pub fn validate_jsonrpc_response(response: &JsonRpcResponse) -> McpResult<()> {
89    if response.jsonrpc != "2.0" {
90        return Err(McpError::Validation("jsonrpc must be '2.0'".to_string()));
91    }
92
93    // Must have either result or error, but not both
94    match (&response.result, &response.error) {
95        (Some(_), Some(_)) => Err(McpError::Validation(
96            "Response cannot have both result and error".to_string(),
97        )),
98        (None, None) => Err(McpError::Validation(
99            "Response must have either result or error".to_string(),
100        )),
101        _ => Ok(()),
102    }
103}
104
105/// Validates a JSON-RPC notification
106pub fn validate_jsonrpc_notification(notification: &JsonRpcNotification) -> McpResult<()> {
107    if notification.jsonrpc != "2.0" {
108        return Err(McpError::Validation("jsonrpc must be '2.0'".to_string()));
109    }
110
111    if notification.method.is_empty() {
112        return Err(McpError::Validation(
113            "Method name cannot be empty".to_string(),
114        ));
115    }
116
117    Ok(())
118}
119
120/// Validates initialization parameters
121pub fn validate_initialize_params(params: &InitializeParams) -> McpResult<()> {
122    if params.client_info.name.is_empty() {
123        return Err(McpError::Validation(
124            "Client name cannot be empty".to_string(),
125        ));
126    }
127
128    if params.client_info.version.is_empty() {
129        return Err(McpError::Validation(
130            "Client version cannot be empty".to_string(),
131        ));
132    }
133
134    if params.protocol_version.is_empty() {
135        return Err(McpError::Validation(
136            "Protocol version cannot be empty".to_string(),
137        ));
138    }
139
140    Ok(())
141}
142
143/// Validates tool information
144pub fn validate_tool_info(tool: &ToolInfo) -> McpResult<()> {
145    if tool.name.is_empty() {
146        return Err(McpError::Validation(
147            "Tool name cannot be empty".to_string(),
148        ));
149    }
150
151    // Validate that input_schema is a valid JSON Schema object
152    if !tool.input_schema.is_object() {
153        return Err(McpError::Validation(
154            "Tool input_schema must be a JSON object".to_string(),
155        ));
156    }
157
158    Ok(())
159}
160
161/// Validates tool call parameters
162pub fn validate_call_tool_params(params: &CallToolParams) -> McpResult<()> {
163    if params.name.is_empty() {
164        return Err(McpError::Validation(
165            "Tool name cannot be empty".to_string(),
166        ));
167    }
168
169    Ok(())
170}
171
172/// Validates resource information
173pub fn validate_resource_info(resource: &ResourceInfo) -> McpResult<()> {
174    if resource.uri.is_empty() {
175        return Err(McpError::Validation(
176            "Resource URI cannot be empty".to_string(),
177        ));
178    }
179
180    if resource.name.is_empty() {
181        return Err(McpError::Validation(
182            "Resource name cannot be empty".to_string(),
183        ));
184    }
185
186    // Basic URI validation - check if it looks like a valid URI
187    validate_uri(&resource.uri)?;
188
189    Ok(())
190}
191
192/// Validates resource read parameters
193pub fn validate_read_resource_params(params: &ReadResourceParams) -> McpResult<()> {
194    if params.uri.is_empty() {
195        return Err(McpError::Validation(
196            "Resource URI cannot be empty".to_string(),
197        ));
198    }
199
200    validate_uri(&params.uri)?;
201
202    Ok(())
203}
204
205/// Validates resource content
206pub fn validate_resource_content(content: &ResourceContent) -> McpResult<()> {
207    if content.uri.is_empty() {
208        return Err(McpError::Validation(
209            "Resource content URI cannot be empty".to_string(),
210        ));
211    }
212
213    // Must have either text or blob content
214    match (&content.text, &content.blob) {
215        (Some(_), Some(_)) => Err(McpError::Validation(
216            "Resource content cannot have both text and blob".to_string(),
217        )),
218        (None, None) => Err(McpError::Validation(
219            "Resource content must have either text or blob".to_string(),
220        )),
221        _ => Ok(()),
222    }
223}
224
225/// Validates prompt information
226pub fn validate_prompt_info(prompt: &PromptInfo) -> McpResult<()> {
227    if prompt.name.is_empty() {
228        return Err(McpError::Validation(
229            "Prompt name cannot be empty".to_string(),
230        ));
231    }
232
233    if let Some(args) = &prompt.arguments {
234        for arg in args {
235            if arg.name.is_empty() {
236                return Err(McpError::Validation(
237                    "Prompt argument name cannot be empty".to_string(),
238                ));
239            }
240        }
241    }
242
243    Ok(())
244}
245
246/// Validates prompt get parameters
247pub fn validate_get_prompt_params(params: &GetPromptParams) -> McpResult<()> {
248    if params.name.is_empty() {
249        return Err(McpError::Validation(
250            "Prompt name cannot be empty".to_string(),
251        ));
252    }
253
254    Ok(())
255}
256
257/// Validates prompt messages
258pub fn validate_prompt_messages(messages: &[PromptMessage]) -> McpResult<()> {
259    if messages.is_empty() {
260        return Err(McpError::Validation(
261            "Prompt must have at least one message".to_string(),
262        ));
263    }
264
265    for message in messages {
266        if message.role.is_empty() {
267            return Err(McpError::Validation(
268                "Message role cannot be empty".to_string(),
269            ));
270        }
271    }
272
273    Ok(())
274}
275
276/// Validates sampling messages
277pub fn validate_sampling_messages(messages: &[SamplingMessage]) -> McpResult<()> {
278    if messages.is_empty() {
279        return Err(McpError::Validation(
280            "Sampling request must have at least one message".to_string(),
281        ));
282    }
283
284    for message in messages {
285        if message.role.is_empty() {
286            return Err(McpError::Validation(
287                "Message role cannot be empty".to_string(),
288            ));
289        }
290    }
291
292    Ok(())
293}
294
295/// Validates create message parameters
296pub fn validate_create_message_params(params: &CreateMessageParams) -> McpResult<()> {
297    validate_sampling_messages(&params.messages)?;
298
299    // Validate temperature range
300    if let Some(temp) = params.temperature {
301        if !(0.0..=2.0).contains(&temp) {
302            return Err(McpError::Validation(
303                "Temperature must be between 0.0 and 2.0".to_string(),
304            ));
305        }
306    }
307
308    // Validate top_p range
309    if let Some(top_p) = params.top_p {
310        if !(0.0..=1.0).contains(&top_p) {
311            return Err(McpError::Validation(
312                "top_p must be between 0.0 and 1.0".to_string(),
313            ));
314        }
315    }
316
317    // Validate max_tokens
318    if let Some(max_tokens) = params.max_tokens {
319        if max_tokens == 0 {
320            return Err(McpError::Validation(
321                "max_tokens must be greater than 0".to_string(),
322            ));
323        }
324    }
325
326    Ok(())
327}
328
329/// Validates content
330pub fn validate_content(content: &Content) -> McpResult<()> {
331    match content {
332        Content::Text { text } => {
333            if text.is_empty() {
334                return Err(McpError::Validation(
335                    "Text content cannot be empty".to_string(),
336                ));
337            }
338        }
339        Content::Image { data, mime_type } => {
340            if data.is_empty() {
341                return Err(McpError::Validation(
342                    "Image data cannot be empty".to_string(),
343                ));
344            }
345            if mime_type.is_empty() {
346                return Err(McpError::Validation(
347                    "Image MIME type cannot be empty".to_string(),
348                ));
349            }
350            if !mime_type.starts_with("image/") {
351                return Err(McpError::Validation(
352                    "Image MIME type must start with 'image/'".to_string(),
353                ));
354            }
355        }
356    }
357
358    Ok(())
359}
360
361/// Basic URI validation
362pub fn validate_uri(uri: &str) -> McpResult<()> {
363    if uri.is_empty() {
364        return Err(McpError::Validation("URI cannot be empty".to_string()));
365    }
366
367    // Basic check for scheme
368    if !uri.contains("://") && !uri.starts_with('/') && !uri.starts_with("file:") {
369        return Err(McpError::Validation(
370            "URI must have a scheme or be an absolute path".to_string(),
371        ));
372    }
373
374    Ok(())
375}
376
377/// Validates method name against MCP specification
378pub fn validate_method_name(method: &str) -> McpResult<()> {
379    if method.is_empty() {
380        return Err(McpError::Validation(
381            "Method name cannot be empty".to_string(),
382        ));
383    }
384
385    // Check for valid MCP method patterns
386    match method {
387        methods::INITIALIZE
388        | methods::PING
389        | methods::TOOLS_LIST
390        | methods::TOOLS_CALL
391        | methods::TOOLS_LIST_CHANGED
392        | methods::RESOURCES_LIST
393        | methods::RESOURCES_READ
394        | methods::RESOURCES_SUBSCRIBE
395        | methods::RESOURCES_UNSUBSCRIBE
396        | methods::RESOURCES_UPDATED
397        | methods::RESOURCES_LIST_CHANGED
398        | methods::PROMPTS_LIST
399        | methods::PROMPTS_GET
400        | methods::PROMPTS_LIST_CHANGED
401        | methods::SAMPLING_CREATE_MESSAGE
402        | methods::LOGGING_SET_LEVEL
403        | methods::LOGGING_MESSAGE
404        | methods::PROGRESS => Ok(()),
405        _ => {
406            // Allow custom methods if they follow naming conventions
407            if method.contains('/') || method.contains('.') {
408                Ok(())
409            } else {
410                Err(McpError::Validation(format!(
411                    "Unknown or invalid method name: {}",
412                    method
413                )))
414            }
415        }
416    }
417}
418
419/// Validates server capabilities
420pub fn validate_server_capabilities(_capabilities: &ServerCapabilities) -> McpResult<()> {
421    // All capability structures are currently valid if they exist
422    // Future versions might add validation for specific capability values
423    Ok(())
424}
425
426/// Validates client capabilities
427pub fn validate_client_capabilities(_capabilities: &ClientCapabilities) -> McpResult<()> {
428    // All capability structures are currently valid if they exist
429    // Future versions might add validation for specific capability values
430    Ok(())
431}
432
433/// Validates progress parameters
434pub fn validate_progress_params(params: &ProgressParams) -> McpResult<()> {
435    if params.progress_token.is_empty() {
436        return Err(McpError::Validation(
437            "Progress token cannot be empty".to_string(),
438        ));
439    }
440
441    if !(0.0..=1.0).contains(&params.progress) {
442        return Err(McpError::Validation(
443            "Progress must be between 0.0 and 1.0".to_string(),
444        ));
445    }
446
447    Ok(())
448}
449
450/// Validates logging message parameters
451pub fn validate_logging_message_params(params: &LoggingMessageParams) -> McpResult<()> {
452    // Logger name can be empty (optional), but data cannot be null
453    if params.data.is_null() {
454        return Err(McpError::Validation(
455            "Log message data cannot be null".to_string(),
456        ));
457    }
458
459    Ok(())
460}
461
462/// Comprehensive validation for any MCP request
463pub fn validate_mcp_request(method: &str, params: Option<&Value>) -> McpResult<()> {
464    validate_method_name(method)?;
465
466    if let Some(params_value) = params {
467        match method {
468            methods::INITIALIZE => {
469                let params: InitializeParams = serde_json::from_value(params_value.clone())
470                    .map_err(|e| {
471                        McpError::Validation(format!("Invalid initialize params: {}", e))
472                    })?;
473                validate_initialize_params(&params)?;
474            }
475            methods::TOOLS_CALL => {
476                let params: CallToolParams =
477                    serde_json::from_value(params_value.clone()).map_err(|e| {
478                        McpError::Validation(format!("Invalid call tool params: {}", e))
479                    })?;
480                validate_call_tool_params(&params)?;
481            }
482            methods::RESOURCES_READ => {
483                let params: ReadResourceParams = serde_json::from_value(params_value.clone())
484                    .map_err(|e| {
485                        McpError::Validation(format!("Invalid read resource params: {}", e))
486                    })?;
487                validate_read_resource_params(&params)?;
488            }
489            methods::PROMPTS_GET => {
490                let params: GetPromptParams = serde_json::from_value(params_value.clone())
491                    .map_err(|e| {
492                        McpError::Validation(format!("Invalid get prompt params: {}", e))
493                    })?;
494                validate_get_prompt_params(&params)?;
495            }
496            methods::SAMPLING_CREATE_MESSAGE => {
497                let params: CreateMessageParams = serde_json::from_value(params_value.clone())
498                    .map_err(|e| {
499                        McpError::Validation(format!("Invalid create message params: {}", e))
500                    })?;
501                validate_create_message_params(&params)?;
502            }
503            methods::PROGRESS => {
504                let params: ProgressParams = serde_json::from_value(params_value.clone())
505                    .map_err(|e| McpError::Validation(format!("Invalid progress params: {}", e)))?;
506                validate_progress_params(&params)?;
507            }
508            methods::LOGGING_MESSAGE => {
509                let params: LoggingMessageParams = serde_json::from_value(params_value.clone())
510                    .map_err(|e| {
511                        McpError::Validation(format!("Invalid logging message params: {}", e))
512                    })?;
513                validate_logging_message_params(&params)?;
514            }
515            _ => {
516                // For other methods, we just validate that params is a valid JSON object if present
517                if !params_value.is_object() && !params_value.is_null() {
518                    return Err(McpError::Validation(
519                        "Parameters must be a JSON object or null".to_string(),
520                    ));
521                }
522            }
523        }
524    }
525
526    Ok(())
527}
528
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a request with a fixed id and method and the given `jsonrpc` version.
    fn request_with_version(version: &str) -> JsonRpcRequest {
        JsonRpcRequest {
            jsonrpc: version.to_string(),
            id: json!(1),
            method: "test_method".to_string(),
            params: None,
        }
    }

    #[test]
    fn test_validate_jsonrpc_request() {
        assert!(validate_jsonrpc_request(&request_with_version("2.0")).is_ok());
        // Only JSON-RPC 2.0 is supported.
        assert!(validate_jsonrpc_request(&request_with_version("1.0")).is_err());
    }

    #[test]
    fn test_validate_uri() {
        let accepted = ["https://example.com", "file:///path/to/file", "/absolute/path"];
        for uri in accepted.iter() {
            assert!(validate_uri(uri).is_ok(), "expected {:?} to be accepted", uri);
        }

        let rejected = ["", "invalid"];
        for uri in rejected.iter() {
            assert!(validate_uri(uri).is_err(), "expected {:?} to be rejected", uri);
        }
    }

    #[test]
    fn test_validate_tool_info() {
        let described_tool = ToolInfo {
            name: "test_tool".to_string(),
            description: Some("A test tool".to_string()),
            input_schema: json!({
                "type": "object",
                "properties": {
                    "param": {"type": "string"}
                }
            }),
        };
        assert!(validate_tool_info(&described_tool).is_ok());

        // Both the empty name and the non-object schema are invalid.
        let broken_tool = ToolInfo {
            name: String::new(),
            description: None,
            input_schema: json!("not an object"),
        };
        assert!(validate_tool_info(&broken_tool).is_err());
    }

    #[test]
    fn test_validate_create_message_params() {
        let well_formed = CreateMessageParams {
            messages: vec![SamplingMessage::user("Hello")],
            model_preferences: None,
            system_prompt: None,
            include_context: None,
            max_tokens: Some(100),
            temperature: Some(0.7),
            top_p: Some(0.9),
            stop_sequences: None,
            metadata: None,
        };
        assert!(validate_create_message_params(&well_formed).is_ok());

        // No messages and an out-of-range temperature.
        let malformed = CreateMessageParams {
            messages: Vec::new(),
            model_preferences: None,
            system_prompt: None,
            include_context: None,
            max_tokens: None,
            temperature: Some(3.0),
            top_p: None,
            stop_sequences: None,
            metadata: None,
        };
        assert!(validate_create_message_params(&malformed).is_err());
    }

    #[test]
    fn test_validate_content() {
        assert!(validate_content(&Content::text("Hello, world!")).is_ok());
        assert!(validate_content(&Content::image("base64data", "image/png")).is_ok());

        let empty_text = Content::Text {
            text: String::new(),
        };
        assert!(validate_content(&empty_text).is_err());

        // A non-image MIME type on image content is rejected.
        let wrong_mime = Content::Image {
            data: "data".to_string(),
            mime_type: "text/plain".to_string(),
        };
        assert!(validate_content(&wrong_mime).is_err());
    }

    #[test]
    fn test_validate_method_name() {
        let accepted = [
            methods::INITIALIZE,
            methods::TOOLS_LIST,
            "custom/method",
            "custom.method",
        ];
        for method in accepted.iter() {
            assert!(
                validate_method_name(method).is_ok(),
                "expected {:?} to be accepted",
                method
            );
        }
        assert!(validate_method_name("").is_err());
    }

    #[test]
    fn test_validate_mcp_request() {
        let init_params = json!({
            "clientInfo": {
                "name": "test-client",
                "version": "1.0.0"
            },
            "capabilities": {},
            "protocolVersion": "2024-11-05"
        });

        assert!(validate_mcp_request(methods::INITIALIZE, Some(&init_params)).is_ok());
        assert!(validate_mcp_request(methods::PING, None).is_ok());
        assert!(validate_mcp_request("", None).is_err());
    }
}