mcp_protocol_sdk/core/validation.rs

1//! Advanced tool validation system for MCP SDK
2//!
3//! This module provides comprehensive parameter validation, type checking,
4//! and coercion capabilities for tool arguments according to JSON Schema specifications.
5
6use crate::core::error::{McpError, McpResult};
7use serde_json::{Map, Value};
8use std::collections::{HashMap, HashSet};
9
10/// Helper function to get a human-readable type name for a JSON value
11fn get_value_type_name(value: &Value) -> &'static str {
12    match value {
13        Value::Null => "null",
14        Value::Bool(_) => "boolean",
15        Value::Number(_) => "number",
16        Value::String(_) => "string",
17        Value::Array(_) => "array",
18        Value::Object(_) => "object",
19    }
20}
21
/// Knobs controlling how strictly tool arguments are validated.
#[derive(Debug, Clone)]
pub struct ValidationConfig {
    /// Accept arguments that have no entry in the schema's `properties`.
    pub allow_additional: bool,
    /// Convert mismatched values where possible (e.g. the string `"5"` into the number `5`).
    pub coerce_types: bool,
    /// Request detailed validation errors.
    /// NOTE(review): not read by `ParameterValidator` in this file.
    pub detailed_errors: bool,
    /// Upper bound on the length of any string argument (`None` disables the check).
    pub max_string_length: Option<usize>,
    /// Upper bound on the number of elements in any array argument (`None` disables it).
    pub max_array_length: Option<usize>,
    /// Upper bound on the number of properties in an object (`None` disables it).
    pub max_object_properties: Option<usize>,
}

impl Default for ValidationConfig {
    /// Lenient defaults: extra arguments allowed, coercion on, detailed errors on,
    /// with generous size limits on strings, arrays and objects.
    fn default() -> Self {
        let (max_string_length, max_array_length, max_object_properties) =
            (Some(10_000), Some(1_000), Some(100));
        Self {
            max_string_length,
            max_array_length,
            max_object_properties,
            allow_additional: true,
            coerce_types: true,
            detailed_errors: true,
        }
    }
}
51
52/// Enhanced JSON Schema validator for tool parameters
53#[derive(Debug, Clone)]
54pub struct ParameterValidator {
55    /// JSON Schema for validation
56    pub schema: Value,
57    /// Validation configuration
58    pub config: ValidationConfig,
59}
60
61impl ParameterValidator {
62    /// Create a new parameter validator with schema
63    pub fn new(schema: Value) -> Self {
64        Self {
65            schema,
66            config: ValidationConfig::default(),
67        }
68    }
69
70    /// Create validator with custom configuration
71    pub fn with_config(schema: Value, config: ValidationConfig) -> Self {
72        Self { schema, config }
73    }
74
75    /// Validate and optionally coerce parameters
76    pub fn validate_and_coerce(&self, params: &mut HashMap<String, Value>) -> McpResult<()> {
77        let schema_obj = self
78            .schema
79            .as_object()
80            .ok_or_else(|| McpError::validation("Schema must be an object"))?;
81
82        // Check type
83        if let Some(schema_type) = schema_obj.get("type") {
84            if schema_type.as_str() != Some("object") {
85                return Err(McpError::validation("Tool schema type must be 'object'"));
86            }
87        }
88
89        // Validate required properties
90        if let Some(required) = schema_obj.get("required") {
91            self.validate_required_properties(params, required)?;
92        }
93
94        // Validate individual properties
95        if let Some(properties) = schema_obj.get("properties") {
96            self.validate_properties(params, properties)?;
97        }
98
99        // Check additional properties
100        if !self.config.allow_additional {
101            self.check_additional_properties(params, schema_obj)?;
102        }
103
104        // Check object size limits
105        if let Some(max_props) = self.config.max_object_properties {
106            if params.len() > max_props {
107                return Err(McpError::validation(format!(
108                    "Too many properties: {} > {}",
109                    params.len(),
110                    max_props
111                )));
112            }
113        }
114
115        Ok(())
116    }
117
118    /// Validate required properties are present
119    fn validate_required_properties(
120        &self,
121        params: &HashMap<String, Value>,
122        required: &Value,
123    ) -> McpResult<()> {
124        let required_array = required
125            .as_array()
126            .ok_or_else(|| McpError::validation("Required field must be an array"))?;
127
128        for req in required_array {
129            let prop_name = req
130                .as_str()
131                .ok_or_else(|| McpError::validation("Required property names must be strings"))?;
132
133            if !params.contains_key(prop_name) {
134                return Err(McpError::validation(format!(
135                    "Missing required parameter: '{prop_name}'"
136                )));
137            }
138        }
139
140        Ok(())
141    }
142
143    /// Validate and coerce individual properties
144    fn validate_properties(
145        &self,
146        params: &mut HashMap<String, Value>,
147        properties: &Value,
148    ) -> McpResult<()> {
149        let props_obj = properties
150            .as_object()
151            .ok_or_else(|| McpError::validation("Properties must be an object"))?;
152
153        for (prop_name, value) in params.iter_mut() {
154            if let Some(prop_schema) = props_obj.get(prop_name) {
155                self.validate_and_coerce_value(value, prop_schema, prop_name)?;
156            }
157        }
158
159        Ok(())
160    }
161
162    /// Validate and coerce a single value according to its schema
163    fn validate_and_coerce_value(
164        &self,
165        value: &mut Value,
166        schema: &Value,
167        field_name: &str,
168    ) -> McpResult<()> {
169        let schema_obj = schema.as_object().ok_or_else(|| {
170            McpError::validation(format!("Schema for '{field_name}' must be an object"))
171        })?;
172
173        // Get expected type
174        let expected_type = schema_obj
175            .get("type")
176            .and_then(|t| t.as_str())
177            .unwrap_or("any");
178
179        match expected_type {
180            "string" => self.validate_string(value, schema_obj, field_name)?,
181            "number" | "integer" => self.validate_number(value, schema_obj, field_name)?,
182            "boolean" => self.validate_boolean(value, field_name)?,
183            "array" => self.validate_array(value, schema_obj, field_name)?,
184            "object" => self.validate_object(value, schema_obj, field_name)?,
185            "null" => self.validate_null(value, field_name)?,
186            _ => {} // Allow any type
187        }
188
189        // Validate enum constraints
190        if let Some(enum_values) = schema_obj.get("enum") {
191            self.validate_enum(value, enum_values, field_name)?;
192        }
193
194        Ok(())
195    }
196
197    /// Validate and coerce string values
198    fn validate_string(
199        &self,
200        value: &mut Value,
201        schema: &Map<String, Value>,
202        field_name: &str,
203    ) -> McpResult<()> {
204        // Type coercion
205        if self.config.coerce_types && !value.is_string() {
206            if let Some(coerced) = self.coerce_to_string(value) {
207                *value = coerced;
208            } else {
209                return Err(McpError::validation(format!(
210                    "Parameter '{}' must be a string, got {}",
211                    field_name,
212                    get_value_type_name(value)
213                )));
214            }
215        }
216
217        let string_val = value.as_str().ok_or_else(|| {
218            McpError::validation(format!("Parameter '{field_name}' must be a string"))
219        })?;
220
221        // Length validation
222        if let Some(max_len) = self.config.max_string_length {
223            if string_val.len() > max_len {
224                return Err(McpError::validation(format!(
225                    "String '{}' too long: {} > {}",
226                    field_name,
227                    string_val.len(),
228                    max_len
229                )));
230            }
231        }
232
233        // Schema-specific length constraints
234        if let Some(min_len) = schema.get("minLength").and_then(|v| v.as_u64()) {
235            if string_val.len() < min_len as usize {
236                return Err(McpError::validation(format!(
237                    "String '{}' too short: {} < {}",
238                    field_name,
239                    string_val.len(),
240                    min_len
241                )));
242            }
243        }
244
245        if let Some(max_len) = schema.get("maxLength").and_then(|v| v.as_u64()) {
246            if string_val.len() > max_len as usize {
247                return Err(McpError::validation(format!(
248                    "String '{}' too long: {} > {}",
249                    field_name,
250                    string_val.len(),
251                    max_len
252                )));
253            }
254        }
255
256        // Pattern validation
257        if let Some(pattern) = schema.get("pattern").and_then(|v| v.as_str()) {
258            // Note: Full regex validation would require the regex crate
259            // For now, we'll do basic validation checks
260            if pattern.contains("^") && !string_val.starts_with(&pattern[1..pattern.len().min(2)]) {
261                return Err(McpError::validation(format!(
262                    "String '{field_name}' does not match pattern"
263                )));
264            }
265        }
266
267        Ok(())
268    }
269
270    /// Validate and coerce number values
271    fn validate_number(
272        &self,
273        value: &mut Value,
274        schema: &Map<String, Value>,
275        field_name: &str,
276    ) -> McpResult<()> {
277        // Type coercion
278        if self.config.coerce_types && !value.is_number() {
279            if let Some(coerced) = self.coerce_to_number(value) {
280                *value = coerced;
281            } else {
282                return Err(McpError::validation(format!(
283                    "Parameter '{}' must be a number, got {}",
284                    field_name,
285                    get_value_type_name(value)
286                )));
287            }
288        }
289
290        let num_val = value.as_f64().ok_or_else(|| {
291            McpError::validation(format!("Parameter '{field_name}' must be a number"))
292        })?;
293
294        // Range validation
295        if let Some(minimum) = schema.get("minimum").and_then(|v| v.as_f64()) {
296            if num_val < minimum {
297                return Err(McpError::validation(format!(
298                    "Number '{field_name}' too small: {num_val} < {minimum}"
299                )));
300            }
301        }
302
303        if let Some(maximum) = schema.get("maximum").and_then(|v| v.as_f64()) {
304            if num_val > maximum {
305                return Err(McpError::validation(format!(
306                    "Number '{field_name}' too large: {num_val} > {maximum}"
307                )));
308            }
309        }
310
311        // Integer validation
312        if schema.get("type").and_then(|v| v.as_str()) == Some("integer") {
313            if num_val.fract() != 0.0 {
314                if self.config.coerce_types {
315                    *value = Value::Number(serde_json::Number::from(num_val.round() as i64));
316                } else {
317                    return Err(McpError::validation(format!(
318                        "Parameter '{field_name}' must be an integer"
319                    )));
320                }
321            } else {
322                // Convert float to integer even if it has no fractional part
323                *value = Value::Number(serde_json::Number::from(num_val as i64));
324            }
325        }
326
327        Ok(())
328    }
329
330    /// Validate and coerce boolean values
331    fn validate_boolean(&self, value: &mut Value, field_name: &str) -> McpResult<()> {
332        // Type coercion
333        if self.config.coerce_types && !value.is_boolean() {
334            if let Some(coerced) = self.coerce_to_boolean(value) {
335                *value = coerced;
336            } else {
337                return Err(McpError::validation(format!(
338                    "Parameter '{}' must be a boolean, got {}",
339                    field_name,
340                    get_value_type_name(value)
341                )));
342            }
343        }
344
345        if !value.is_boolean() {
346            return Err(McpError::validation(format!(
347                "Parameter '{field_name}' must be a boolean"
348            )));
349        }
350
351        Ok(())
352    }
353
354    /// Validate array values
355    fn validate_array(
356        &self,
357        value: &mut Value,
358        schema: &Map<String, Value>,
359        field_name: &str,
360    ) -> McpResult<()> {
361        let array = value.as_array_mut().ok_or_else(|| {
362            McpError::validation(format!("Parameter '{field_name}' must be an array"))
363        })?;
364
365        // Length validation
366        if let Some(max_len) = self.config.max_array_length {
367            if array.len() > max_len {
368                return Err(McpError::validation(format!(
369                    "Array '{}' too long: {} > {}",
370                    field_name,
371                    array.len(),
372                    max_len
373                )));
374            }
375        }
376
377        if let Some(min_items) = schema.get("minItems").and_then(|v| v.as_u64()) {
378            if array.len() < min_items as usize {
379                return Err(McpError::validation(format!(
380                    "Array '{}' too short: {} < {}",
381                    field_name,
382                    array.len(),
383                    min_items
384                )));
385            }
386        }
387
388        if let Some(max_items) = schema.get("maxItems").and_then(|v| v.as_u64()) {
389            if array.len() > max_items as usize {
390                return Err(McpError::validation(format!(
391                    "Array '{}' too long: {} > {}",
392                    field_name,
393                    array.len(),
394                    max_items
395                )));
396            }
397        }
398
399        // Validate each item if items schema is provided
400        if let Some(items_schema) = schema.get("items") {
401            for (i, item) in array.iter_mut().enumerate() {
402                let item_field = format!("{field_name}[{i}]");
403                self.validate_and_coerce_value(item, items_schema, &item_field)?;
404            }
405        }
406
407        Ok(())
408    }
409
410    /// Validate object values
411    fn validate_object(
412        &self,
413        value: &mut Value,
414        _schema: &Map<String, Value>,
415        field_name: &str,
416    ) -> McpResult<()> {
417        let obj = value.as_object().ok_or_else(|| {
418            McpError::validation(format!("Parameter '{field_name}' must be an object"))
419        })?;
420
421        // Object size validation
422        if let Some(max_props) = self.config.max_object_properties {
423            if obj.len() > max_props {
424                return Err(McpError::validation(format!(
425                    "Object '{}' has too many properties: {} > {}",
426                    field_name,
427                    obj.len(),
428                    max_props
429                )));
430            }
431        }
432
433        Ok(())
434    }
435
436    /// Validate null values
437    fn validate_null(&self, value: &Value, field_name: &str) -> McpResult<()> {
438        if !value.is_null() {
439            return Err(McpError::validation(format!(
440                "Parameter '{field_name}' must be null"
441            )));
442        }
443        Ok(())
444    }
445
446    /// Validate enum constraints
447    fn validate_enum(&self, value: &Value, enum_values: &Value, field_name: &str) -> McpResult<()> {
448        let enum_array = enum_values
449            .as_array()
450            .ok_or_else(|| McpError::validation("Enum must be an array"))?;
451
452        if !enum_array.contains(value) {
453            return Err(McpError::validation(format!(
454                "Parameter '{field_name}' must be one of: {enum_array:?}"
455            )));
456        }
457
458        Ok(())
459    }
460
461    /// Check for disallowed additional properties
462    fn check_additional_properties(
463        &self,
464        params: &HashMap<String, Value>,
465        schema: &Map<String, Value>,
466    ) -> McpResult<()> {
467        if let Some(properties) = schema.get("properties").and_then(|p| p.as_object()) {
468            let allowed_props: HashSet<_> = properties.keys().collect();
469            let actual_props: HashSet<_> = params.keys().collect();
470            let additional: Vec<_> = actual_props.difference(&allowed_props).collect();
471
472            if !additional.is_empty() {
473                return Err(McpError::validation(format!(
474                    "Additional properties not allowed: {additional:?}"
475                )));
476            }
477        }
478
479        Ok(())
480    }
481
482    /// Type coercion helpers
483    fn coerce_to_string(&self, value: &Value) -> Option<Value> {
484        match value {
485            Value::Number(n) => Some(Value::String(n.to_string())),
486            Value::Bool(b) => Some(Value::String(b.to_string())),
487            Value::Null => Some(Value::String("null".to_string())),
488            _ => None,
489        }
490    }
491
492    fn coerce_to_number(&self, value: &Value) -> Option<Value> {
493        match value {
494            Value::String(s) => {
495                if let Ok(f) = s.parse::<f64>() {
496                    serde_json::Number::from_f64(f).map(Value::Number)
497                } else {
498                    None
499                }
500            }
501            Value::Bool(true) => Some(Value::Number(serde_json::Number::from(1))),
502            Value::Bool(false) => Some(Value::Number(serde_json::Number::from(0))),
503            _ => None,
504        }
505    }
506
507    fn coerce_to_boolean(&self, value: &Value) -> Option<Value> {
508        match value {
509            Value::String(s) => match s.to_lowercase().as_str() {
510                "true" | "1" | "yes" | "on" => Some(Value::Bool(true)),
511                "false" | "0" | "no" | "off" | "" => Some(Value::Bool(false)),
512                _ => None,
513            },
514            Value::Number(n) => {
515                if let Some(i) = n.as_i64() {
516                    Some(Value::Bool(i != 0))
517                } else {
518                    Some(Value::Bool(n.as_f64().unwrap_or(0.0) != 0.0))
519                }
520            }
521            Value::Null => Some(Value::Bool(false)),
522            _ => None,
523        }
524    }
525}
526
527/// Helper trait for creating typed parameter validators
528pub trait ParameterType {
529    /// Create a JSON schema for this parameter type
530    fn to_schema() -> Value;
531
532    /// Validate and extract value from parameters
533    fn from_params(params: &HashMap<String, Value>, name: &str) -> McpResult<Self>
534    where
535        Self: Sized;
536}
537
538/// Implementation for basic types
539impl ParameterType for String {
540    fn to_schema() -> Value {
541        serde_json::json!({
542            "type": "string"
543        })
544    }
545
546    fn from_params(params: &HashMap<String, Value>, name: &str) -> McpResult<Self> {
547        params
548            .get(name)
549            .and_then(|v| v.as_str())
550            .map(|s| s.to_string())
551            .ok_or_else(|| McpError::validation(format!("Missing string parameter: {name}")))
552    }
553}
554
555impl ParameterType for i64 {
556    fn to_schema() -> Value {
557        serde_json::json!({
558            "type": "integer"
559        })
560    }
561
562    fn from_params(params: &HashMap<String, Value>, name: &str) -> McpResult<Self> {
563        params
564            .get(name)
565            .and_then(|v| v.as_i64())
566            .ok_or_else(|| McpError::validation(format!("Missing integer parameter: {name}")))
567    }
568}
569
570impl ParameterType for f64 {
571    fn to_schema() -> Value {
572        serde_json::json!({
573            "type": "number"
574        })
575    }
576
577    fn from_params(params: &HashMap<String, Value>, name: &str) -> McpResult<Self> {
578        params
579            .get(name)
580            .and_then(|v| v.as_f64())
581            .ok_or_else(|| McpError::validation(format!("Missing number parameter: {name}")))
582    }
583}
584
585impl ParameterType for bool {
586    fn to_schema() -> Value {
587        serde_json::json!({
588            "type": "boolean"
589        })
590    }
591
592    fn from_params(params: &HashMap<String, Value>, name: &str) -> McpResult<Self> {
593        params
594            .get(name)
595            .and_then(|v| v.as_bool())
596            .ok_or_else(|| McpError::validation(format!("Missing boolean parameter: {name}")))
597    }
598}
599
/// Build a `(name, schema)` pair describing one tool parameter.
///
/// The leading keyword selects the JSON Schema type: `string`, `number`, `integer`,
/// `boolean`, `array` (with `items:`) or `enum` (with `values: [...]`). Optional
/// `min:` / `max:` arguments become `minLength` / `maxLength` for strings and
/// `minimum` / `maximum` for numbers and integers. `enum` always produces a
/// `"type": "string"` schema. The pairs can be passed to `create_tool_schema`.
///
/// The expansion calls `serde_json::json!`, so the calling crate must have
/// `serde_json` available under that name.
#[macro_export]
macro_rules! param_schema {
    // ---- strings (optional minLength / maxLength) ----
    (string $key:expr_2021) => {
        ($key, serde_json::json!({"type": "string"}))
    };

    (string $key:expr_2021, min: $lo:expr_2021) => {
        ($key, serde_json::json!({"type": "string", "minLength": $lo}))
    };

    (string $key:expr_2021, max: $hi:expr_2021) => {
        ($key, serde_json::json!({"type": "string", "maxLength": $hi}))
    };

    (string $key:expr_2021, min: $lo:expr_2021, max: $hi:expr_2021) => {
        ($key, serde_json::json!({"type": "string", "minLength": $lo, "maxLength": $hi}))
    };

    // ---- numbers (optional minimum / maximum) ----
    (number $key:expr_2021) => {
        ($key, serde_json::json!({"type": "number"}))
    };

    (number $key:expr_2021, min: $lo:expr_2021) => {
        ($key, serde_json::json!({"type": "number", "minimum": $lo}))
    };

    (number $key:expr_2021, max: $hi:expr_2021) => {
        ($key, serde_json::json!({"type": "number", "maximum": $hi}))
    };

    (number $key:expr_2021, min: $lo:expr_2021, max: $hi:expr_2021) => {
        ($key, serde_json::json!({"type": "number", "minimum": $lo, "maximum": $hi}))
    };

    // ---- integers (optional minimum / maximum) ----
    (integer $key:expr_2021) => {
        ($key, serde_json::json!({"type": "integer"}))
    };

    (integer $key:expr_2021, min: $lo:expr_2021) => {
        ($key, serde_json::json!({"type": "integer", "minimum": $lo}))
    };

    (integer $key:expr_2021, max: $hi:expr_2021) => {
        ($key, serde_json::json!({"type": "integer", "maximum": $hi}))
    };

    (integer $key:expr_2021, min: $lo:expr_2021, max: $hi:expr_2021) => {
        ($key, serde_json::json!({"type": "integer", "minimum": $lo, "maximum": $hi}))
    };

    // ---- booleans ----
    (boolean $key:expr_2021) => {
        ($key, serde_json::json!({"type": "boolean"}))
    };

    // ---- arrays: `items` is an arbitrary schema expression ----
    (array $key:expr_2021, items: $item_schema:expr_2021) => {
        ($key, serde_json::json!({"type": "array", "items": $item_schema}))
    };

    // ---- string enums ----
    (enum $key:expr_2021, values: [$($choice:expr_2021),*]) => {
        ($key, serde_json::json!({"type": "string", "enum": [$($choice),*]}))
    };
}
670
671/// Helper function to create tool schemas from parameter definitions
672pub fn create_tool_schema(params: Vec<(&str, Value)>, required: Vec<&str>) -> Value {
673    let mut properties = Map::new();
674
675    for (name, schema) in params {
676        properties.insert(name.to_string(), schema);
677    }
678
679    serde_json::json!({
680        "type": "object",
681        "properties": properties,
682        "required": required
683    })
684}
685
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Build a parameter map holding a single `key -> value` entry.
    fn one_param(key: &str, value: Value) -> HashMap<String, Value> {
        HashMap::from([(key.to_string(), value)])
    }

    #[test]
    fn test_string_validation() {
        let validator = ParameterValidator::new(json!({
            "type": "object",
            "properties": {
                "name": {"type": "string", "minLength": 2, "maxLength": 10}
            },
            "required": ["name"]
        }));

        // Within bounds.
        assert!(validator
            .validate_and_coerce(&mut one_param("name", json!("test")))
            .is_ok());

        // Below minLength, then above maxLength.
        for rejected in ["a", "this_is_too_long"] {
            assert!(validator
                .validate_and_coerce(&mut one_param("name", json!(rejected)))
                .is_err());
        }
    }

    #[test]
    fn test_number_validation() {
        let validator = ParameterValidator::new(json!({
            "type": "object",
            "properties": {
                "age": {"type": "integer", "minimum": 0, "maximum": 150}
            },
            "required": ["age"]
        }));

        // (input, expected to validate)
        for (age, should_pass) in [(25, true), (-5, false), (200, false)] {
            let outcome = validator.validate_and_coerce(&mut one_param("age", json!(age)));
            assert_eq!(outcome.is_ok(), should_pass, "age = {age}");
        }
    }

    #[test]
    fn test_type_coercion() {
        let validator = ParameterValidator::new(json!({
            "type": "object",
            "properties": {
                "count": {"type": "integer"},
                "flag": {"type": "boolean"},
                "name": {"type": "string"}
            }
        }));

        let mut params = HashMap::from([
            ("count".to_string(), json!("42")), // string -> integer
            ("flag".to_string(), json!("true")), // string -> boolean
            ("name".to_string(), json!(123)),    // number -> string
        ]);

        assert!(validator.validate_and_coerce(&mut params).is_ok());

        assert_eq!(params["count"].as_i64(), Some(42));
        assert_eq!(params["flag"].as_bool(), Some(true));
        assert_eq!(params["name"].as_str(), Some("123"));
    }

    #[test]
    fn test_param_schema_macro() {
        let (name, schema) = param_schema!(string "username", min: 3, max: 20);
        assert_eq!(name, "username");
        assert_eq!(
            schema,
            json!({"type": "string", "minLength": 3, "maxLength": 20})
        );
    }

    #[test]
    fn test_create_tool_schema() {
        let schema = create_tool_schema(
            vec![
                param_schema!(string "name"),
                param_schema!(integer "age", min: 0),
                param_schema!(boolean "active"),
            ],
            vec!["name", "age"],
        );

        assert_eq!(schema["type"], "object");
        for (prop, expected_type) in [("name", "string"), ("age", "integer"), ("active", "boolean")] {
            assert_eq!(schema["properties"][prop]["type"], expected_type);
        }
        assert_eq!(schema["required"], json!(["name", "age"]));
    }
}