1use crate::core::error::{McpError, McpResult};
7use serde_json::{Map, Value};
8use std::collections::{HashMap, HashSet};
9
10fn get_value_type_name(value: &Value) -> &'static str {
12 match value {
13 Value::Null => "null",
14 Value::Bool(_) => "boolean",
15 Value::Number(_) => "number",
16 Value::String(_) => "string",
17 Value::Array(_) => "array",
18 Value::Object(_) => "object",
19 }
20}
21
/// Options controlling how [`ParameterValidator`] checks and coerces parameters.
#[derive(Debug, Clone)]
pub struct ValidationConfig {
    /// When `false`, parameters that are not listed under the schema's
    /// `properties` are rejected.
    pub allow_additional: bool,
    /// When `true`, string, number and boolean parameters of the wrong JSON
    /// type are converted to the expected type where a conversion exists.
    pub coerce_types: bool,
    /// NOTE(review): not read by any code visible in this file; presumably
    /// toggles the verbosity of error messages elsewhere — confirm.
    pub detailed_errors: bool,
    /// Upper bound on the byte length of any string parameter; `None` disables
    /// the check.
    pub max_string_length: Option<usize>,
    /// Upper bound on the number of elements in any array parameter; `None`
    /// disables the check.
    pub max_array_length: Option<usize>,
    /// Upper bound on the number of entries in the parameter map and in any
    /// object-typed parameter; `None` disables the check.
    pub max_object_properties: Option<usize>,
}
38
39impl Default for ValidationConfig {
40 fn default() -> Self {
41 Self {
42 allow_additional: true,
43 coerce_types: true,
44 detailed_errors: true,
45 max_string_length: Some(10_000),
46 max_array_length: Some(1_000),
47 max_object_properties: Some(100),
48 }
49 }
50}
51
/// Validates a map of tool parameters against a JSON-Schema-style `schema`.
#[derive(Debug, Clone)]
pub struct ParameterValidator {
    /// Schema describing the expected parameters. Validation fails unless this
    /// is a JSON object.
    pub schema: Value,
    /// Options governing coercion, additional-property handling and size limits.
    pub config: ValidationConfig,
}
60
61impl ParameterValidator {
62 pub fn new(schema: Value) -> Self {
64 Self {
65 schema,
66 config: ValidationConfig::default(),
67 }
68 }
69
70 pub fn with_config(schema: Value, config: ValidationConfig) -> Self {
72 Self { schema, config }
73 }
74
75 pub fn validate_and_coerce(&self, params: &mut HashMap<String, Value>) -> McpResult<()> {
77 let schema_obj = self
78 .schema
79 .as_object()
80 .ok_or_else(|| McpError::validation("Schema must be an object"))?;
81
82 if let Some(schema_type) = schema_obj.get("type") {
84 if schema_type.as_str() != Some("object") {
85 return Err(McpError::validation("Tool schema type must be 'object'"));
86 }
87 }
88
89 if let Some(required) = schema_obj.get("required") {
91 self.validate_required_properties(params, required)?;
92 }
93
94 if let Some(properties) = schema_obj.get("properties") {
96 self.validate_properties(params, properties)?;
97 }
98
99 if !self.config.allow_additional {
101 self.check_additional_properties(params, schema_obj)?;
102 }
103
104 if let Some(max_props) = self.config.max_object_properties {
106 if params.len() > max_props {
107 return Err(McpError::validation(format!(
108 "Too many properties: {} > {}",
109 params.len(),
110 max_props
111 )));
112 }
113 }
114
115 Ok(())
116 }
117
118 fn validate_required_properties(
120 &self,
121 params: &HashMap<String, Value>,
122 required: &Value,
123 ) -> McpResult<()> {
124 let required_array = required
125 .as_array()
126 .ok_or_else(|| McpError::validation("Required field must be an array"))?;
127
128 for req in required_array {
129 let prop_name = req
130 .as_str()
131 .ok_or_else(|| McpError::validation("Required property names must be strings"))?;
132
133 if !params.contains_key(prop_name) {
134 return Err(McpError::validation(format!(
135 "Missing required parameter: '{prop_name}'"
136 )));
137 }
138 }
139
140 Ok(())
141 }
142
143 fn validate_properties(
145 &self,
146 params: &mut HashMap<String, Value>,
147 properties: &Value,
148 ) -> McpResult<()> {
149 let props_obj = properties
150 .as_object()
151 .ok_or_else(|| McpError::validation("Properties must be an object"))?;
152
153 for (prop_name, value) in params.iter_mut() {
154 if let Some(prop_schema) = props_obj.get(prop_name) {
155 self.validate_and_coerce_value(value, prop_schema, prop_name)?;
156 }
157 }
158
159 Ok(())
160 }
161
162 fn validate_and_coerce_value(
164 &self,
165 value: &mut Value,
166 schema: &Value,
167 field_name: &str,
168 ) -> McpResult<()> {
169 let schema_obj = schema.as_object().ok_or_else(|| {
170 McpError::validation(format!("Schema for '{field_name}' must be an object"))
171 })?;
172
173 let expected_type = schema_obj
175 .get("type")
176 .and_then(|t| t.as_str())
177 .unwrap_or("any");
178
179 match expected_type {
180 "string" => self.validate_string(value, schema_obj, field_name)?,
181 "number" | "integer" => self.validate_number(value, schema_obj, field_name)?,
182 "boolean" => self.validate_boolean(value, field_name)?,
183 "array" => self.validate_array(value, schema_obj, field_name)?,
184 "object" => self.validate_object(value, schema_obj, field_name)?,
185 "null" => self.validate_null(value, field_name)?,
186 _ => {} }
188
189 if let Some(enum_values) = schema_obj.get("enum") {
191 self.validate_enum(value, enum_values, field_name)?;
192 }
193
194 Ok(())
195 }
196
197 fn validate_string(
199 &self,
200 value: &mut Value,
201 schema: &Map<String, Value>,
202 field_name: &str,
203 ) -> McpResult<()> {
204 if self.config.coerce_types && !value.is_string() {
206 if let Some(coerced) = self.coerce_to_string(value) {
207 *value = coerced;
208 } else {
209 return Err(McpError::validation(format!(
210 "Parameter '{}' must be a string, got {}",
211 field_name,
212 get_value_type_name(value)
213 )));
214 }
215 }
216
217 let string_val = value.as_str().ok_or_else(|| {
218 McpError::validation(format!("Parameter '{field_name}' must be a string"))
219 })?;
220
221 if let Some(max_len) = self.config.max_string_length {
223 if string_val.len() > max_len {
224 return Err(McpError::validation(format!(
225 "String '{}' too long: {} > {}",
226 field_name,
227 string_val.len(),
228 max_len
229 )));
230 }
231 }
232
233 if let Some(min_len) = schema.get("minLength").and_then(|v| v.as_u64()) {
235 if string_val.len() < min_len as usize {
236 return Err(McpError::validation(format!(
237 "String '{}' too short: {} < {}",
238 field_name,
239 string_val.len(),
240 min_len
241 )));
242 }
243 }
244
245 if let Some(max_len) = schema.get("maxLength").and_then(|v| v.as_u64()) {
246 if string_val.len() > max_len as usize {
247 return Err(McpError::validation(format!(
248 "String '{}' too long: {} > {}",
249 field_name,
250 string_val.len(),
251 max_len
252 )));
253 }
254 }
255
256 if let Some(pattern) = schema.get("pattern").and_then(|v| v.as_str()) {
258 if pattern.contains("^") && !string_val.starts_with(&pattern[1..pattern.len().min(2)]) {
261 return Err(McpError::validation(format!(
262 "String '{field_name}' does not match pattern"
263 )));
264 }
265 }
266
267 Ok(())
268 }
269
270 fn validate_number(
272 &self,
273 value: &mut Value,
274 schema: &Map<String, Value>,
275 field_name: &str,
276 ) -> McpResult<()> {
277 if self.config.coerce_types && !value.is_number() {
279 if let Some(coerced) = self.coerce_to_number(value) {
280 *value = coerced;
281 } else {
282 return Err(McpError::validation(format!(
283 "Parameter '{}' must be a number, got {}",
284 field_name,
285 get_value_type_name(value)
286 )));
287 }
288 }
289
290 let num_val = value.as_f64().ok_or_else(|| {
291 McpError::validation(format!("Parameter '{field_name}' must be a number"))
292 })?;
293
294 if let Some(minimum) = schema.get("minimum").and_then(|v| v.as_f64()) {
296 if num_val < minimum {
297 return Err(McpError::validation(format!(
298 "Number '{field_name}' too small: {num_val} < {minimum}"
299 )));
300 }
301 }
302
303 if let Some(maximum) = schema.get("maximum").and_then(|v| v.as_f64()) {
304 if num_val > maximum {
305 return Err(McpError::validation(format!(
306 "Number '{field_name}' too large: {num_val} > {maximum}"
307 )));
308 }
309 }
310
311 if schema.get("type").and_then(|v| v.as_str()) == Some("integer") {
313 if num_val.fract() != 0.0 {
314 if self.config.coerce_types {
315 *value = Value::Number(serde_json::Number::from(num_val.round() as i64));
316 } else {
317 return Err(McpError::validation(format!(
318 "Parameter '{field_name}' must be an integer"
319 )));
320 }
321 } else {
322 *value = Value::Number(serde_json::Number::from(num_val as i64));
324 }
325 }
326
327 Ok(())
328 }
329
330 fn validate_boolean(&self, value: &mut Value, field_name: &str) -> McpResult<()> {
332 if self.config.coerce_types && !value.is_boolean() {
334 if let Some(coerced) = self.coerce_to_boolean(value) {
335 *value = coerced;
336 } else {
337 return Err(McpError::validation(format!(
338 "Parameter '{}' must be a boolean, got {}",
339 field_name,
340 get_value_type_name(value)
341 )));
342 }
343 }
344
345 if !value.is_boolean() {
346 return Err(McpError::validation(format!(
347 "Parameter '{field_name}' must be a boolean"
348 )));
349 }
350
351 Ok(())
352 }
353
354 fn validate_array(
356 &self,
357 value: &mut Value,
358 schema: &Map<String, Value>,
359 field_name: &str,
360 ) -> McpResult<()> {
361 let array = value.as_array_mut().ok_or_else(|| {
362 McpError::validation(format!("Parameter '{field_name}' must be an array"))
363 })?;
364
365 if let Some(max_len) = self.config.max_array_length {
367 if array.len() > max_len {
368 return Err(McpError::validation(format!(
369 "Array '{}' too long: {} > {}",
370 field_name,
371 array.len(),
372 max_len
373 )));
374 }
375 }
376
377 if let Some(min_items) = schema.get("minItems").and_then(|v| v.as_u64()) {
378 if array.len() < min_items as usize {
379 return Err(McpError::validation(format!(
380 "Array '{}' too short: {} < {}",
381 field_name,
382 array.len(),
383 min_items
384 )));
385 }
386 }
387
388 if let Some(max_items) = schema.get("maxItems").and_then(|v| v.as_u64()) {
389 if array.len() > max_items as usize {
390 return Err(McpError::validation(format!(
391 "Array '{}' too long: {} > {}",
392 field_name,
393 array.len(),
394 max_items
395 )));
396 }
397 }
398
399 if let Some(items_schema) = schema.get("items") {
401 for (i, item) in array.iter_mut().enumerate() {
402 let item_field = format!("{field_name}[{i}]");
403 self.validate_and_coerce_value(item, items_schema, &item_field)?;
404 }
405 }
406
407 Ok(())
408 }
409
410 fn validate_object(
412 &self,
413 value: &mut Value,
414 _schema: &Map<String, Value>,
415 field_name: &str,
416 ) -> McpResult<()> {
417 let obj = value.as_object().ok_or_else(|| {
418 McpError::validation(format!("Parameter '{field_name}' must be an object"))
419 })?;
420
421 if let Some(max_props) = self.config.max_object_properties {
423 if obj.len() > max_props {
424 return Err(McpError::validation(format!(
425 "Object '{}' has too many properties: {} > {}",
426 field_name,
427 obj.len(),
428 max_props
429 )));
430 }
431 }
432
433 Ok(())
434 }
435
436 fn validate_null(&self, value: &Value, field_name: &str) -> McpResult<()> {
438 if !value.is_null() {
439 return Err(McpError::validation(format!(
440 "Parameter '{field_name}' must be null"
441 )));
442 }
443 Ok(())
444 }
445
446 fn validate_enum(&self, value: &Value, enum_values: &Value, field_name: &str) -> McpResult<()> {
448 let enum_array = enum_values
449 .as_array()
450 .ok_or_else(|| McpError::validation("Enum must be an array"))?;
451
452 if !enum_array.contains(value) {
453 return Err(McpError::validation(format!(
454 "Parameter '{field_name}' must be one of: {enum_array:?}"
455 )));
456 }
457
458 Ok(())
459 }
460
461 fn check_additional_properties(
463 &self,
464 params: &HashMap<String, Value>,
465 schema: &Map<String, Value>,
466 ) -> McpResult<()> {
467 if let Some(properties) = schema.get("properties").and_then(|p| p.as_object()) {
468 let allowed_props: HashSet<_> = properties.keys().collect();
469 let actual_props: HashSet<_> = params.keys().collect();
470 let additional: Vec<_> = actual_props.difference(&allowed_props).collect();
471
472 if !additional.is_empty() {
473 return Err(McpError::validation(format!(
474 "Additional properties not allowed: {additional:?}"
475 )));
476 }
477 }
478
479 Ok(())
480 }
481
482 fn coerce_to_string(&self, value: &Value) -> Option<Value> {
484 match value {
485 Value::Number(n) => Some(Value::String(n.to_string())),
486 Value::Bool(b) => Some(Value::String(b.to_string())),
487 Value::Null => Some(Value::String("null".to_string())),
488 _ => None,
489 }
490 }
491
492 fn coerce_to_number(&self, value: &Value) -> Option<Value> {
493 match value {
494 Value::String(s) => {
495 if let Ok(f) = s.parse::<f64>() {
496 serde_json::Number::from_f64(f).map(Value::Number)
497 } else {
498 None
499 }
500 }
501 Value::Bool(true) => Some(Value::Number(serde_json::Number::from(1))),
502 Value::Bool(false) => Some(Value::Number(serde_json::Number::from(0))),
503 _ => None,
504 }
505 }
506
507 fn coerce_to_boolean(&self, value: &Value) -> Option<Value> {
508 match value {
509 Value::String(s) => match s.to_lowercase().as_str() {
510 "true" | "1" | "yes" | "on" => Some(Value::Bool(true)),
511 "false" | "0" | "no" | "off" | "" => Some(Value::Bool(false)),
512 _ => None,
513 },
514 Value::Number(n) => {
515 if let Some(i) = n.as_i64() {
516 Some(Value::Bool(i != 0))
517 } else {
518 Some(Value::Bool(n.as_f64().unwrap_or(0.0) != 0.0))
519 }
520 }
521 Value::Null => Some(Value::Bool(false)),
522 _ => None,
523 }
524 }
525}
526
/// A Rust type that can describe itself as a JSON schema fragment and be read
/// out of a parameter map.
pub trait ParameterType {
    /// Returns the JSON schema fragment describing this type.
    fn to_schema() -> Value;

    /// Extracts the parameter called `name` from `params` as `Self`.
    ///
    /// # Errors
    /// The implementations in this file return a validation error when the
    /// key is absent or its value cannot be read as the requested type.
    fn from_params(params: &HashMap<String, Value>, name: &str) -> McpResult<Self>
    where
        Self: Sized;
}
537
538impl ParameterType for String {
540 fn to_schema() -> Value {
541 serde_json::json!({
542 "type": "string"
543 })
544 }
545
546 fn from_params(params: &HashMap<String, Value>, name: &str) -> McpResult<Self> {
547 params
548 .get(name)
549 .and_then(|v| v.as_str())
550 .map(|s| s.to_string())
551 .ok_or_else(|| McpError::validation(format!("Missing string parameter: {name}")))
552 }
553}
554
555impl ParameterType for i64 {
556 fn to_schema() -> Value {
557 serde_json::json!({
558 "type": "integer"
559 })
560 }
561
562 fn from_params(params: &HashMap<String, Value>, name: &str) -> McpResult<Self> {
563 params
564 .get(name)
565 .and_then(|v| v.as_i64())
566 .ok_or_else(|| McpError::validation(format!("Missing integer parameter: {name}")))
567 }
568}
569
570impl ParameterType for f64 {
571 fn to_schema() -> Value {
572 serde_json::json!({
573 "type": "number"
574 })
575 }
576
577 fn from_params(params: &HashMap<String, Value>, name: &str) -> McpResult<Self> {
578 params
579 .get(name)
580 .and_then(|v| v.as_f64())
581 .ok_or_else(|| McpError::validation(format!("Missing number parameter: {name}")))
582 }
583}
584
585impl ParameterType for bool {
586 fn to_schema() -> Value {
587 serde_json::json!({
588 "type": "boolean"
589 })
590 }
591
592 fn from_params(params: &HashMap<String, Value>, name: &str) -> McpResult<Self> {
593 params
594 .get(name)
595 .and_then(|v| v.as_bool())
596 .ok_or_else(|| McpError::validation(format!("Missing boolean parameter: {name}")))
597 }
598}
599
/// Builds a `(name, schema)` pair describing a single tool parameter.
///
/// The first token selects the parameter kind, followed by the name
/// expression and optional constraints:
///
/// * `string name [, min: a] [, max: b]` — `minLength` / `maxLength`
/// * `number name [, min: a] [, max: b]` — `minimum` / `maximum`
/// * `integer name [, min: a] [, max: b]` — `minimum` / `maximum`
/// * `boolean name`
/// * `array name, items: schema`
/// * `enum name, values: [a, b, ...]` — a string schema restricted to the
///   listed values
///
/// When both bounds are given, `min` must come before `max`.
/// The expansion calls `serde_json::json!` by path, so `serde_json` must be
/// resolvable at the call site.
#[macro_export]
macro_rules! param_schema {
    // --- string ---
    (string $name:expr_2021) => {
        ($name, serde_json::json!({"type": "string"}))
    };

    (string $name:expr_2021, min: $min:expr_2021) => {
        ($name, serde_json::json!({"type": "string", "minLength": $min}))
    };

    (string $name:expr_2021, max: $max:expr_2021) => {
        ($name, serde_json::json!({"type": "string", "maxLength": $max}))
    };

    (string $name:expr_2021, min: $min:expr_2021, max: $max:expr_2021) => {
        ($name, serde_json::json!({"type": "string", "minLength": $min, "maxLength": $max}))
    };

    // --- number ---
    (number $name:expr_2021) => {
        ($name, serde_json::json!({"type": "number"}))
    };

    (number $name:expr_2021, min: $min:expr_2021) => {
        ($name, serde_json::json!({"type": "number", "minimum": $min}))
    };

    (number $name:expr_2021, max: $max:expr_2021) => {
        ($name, serde_json::json!({"type": "number", "maximum": $max}))
    };

    (number $name:expr_2021, min: $min:expr_2021, max: $max:expr_2021) => {
        ($name, serde_json::json!({"type": "number", "minimum": $min, "maximum": $max}))
    };

    // --- integer ---
    (integer $name:expr_2021) => {
        ($name, serde_json::json!({"type": "integer"}))
    };

    (integer $name:expr_2021, min: $min:expr_2021) => {
        ($name, serde_json::json!({"type": "integer", "minimum": $min}))
    };

    (integer $name:expr_2021, max: $max:expr_2021) => {
        ($name, serde_json::json!({"type": "integer", "maximum": $max}))
    };

    (integer $name:expr_2021, min: $min:expr_2021, max: $max:expr_2021) => {
        ($name, serde_json::json!({"type": "integer", "minimum": $min, "maximum": $max}))
    };

    // --- boolean ---
    (boolean $name:expr_2021) => {
        ($name, serde_json::json!({"type": "boolean"}))
    };

    // --- array: `$items` is the schema applied to every element ---
    (array $name:expr_2021, items: $items:expr_2021) => {
        ($name, serde_json::json!({"type": "array", "items": $items}))
    };

    // --- enum: a string restricted to the listed values ---
    (enum $name:expr_2021, values: [$($val:expr_2021),*]) => {
        ($name, serde_json::json!({"type": "string", "enum": [$($val),*]}))
    };
}
670
671pub fn create_tool_schema(params: Vec<(&str, Value)>, required: Vec<&str>) -> Value {
673 let mut properties = Map::new();
674
675 for (name, schema) in params {
676 properties.insert(name.to_string(), schema);
677 }
678
679 serde_json::json!({
680 "type": "object",
681 "properties": properties,
682 "required": required
683 })
684}
685
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a parameter map holding a single `key`/`value` entry.
    fn single_param(key: &str, value: Value) -> HashMap<String, Value> {
        HashMap::from([(key.to_string(), value)])
    }

    /// `minLength` / `maxLength` accept in-range strings and reject the rest.
    #[test]
    fn test_string_validation() {
        let validator = ParameterValidator::new(json!({
            "type": "object",
            "properties": {
                "name": {"type": "string", "minLength": 2, "maxLength": 10}
            },
            "required": ["name"]
        }));

        // Within bounds.
        let mut params = single_param("name", json!("test"));
        assert!(validator.validate_and_coerce(&mut params).is_ok());

        // Too short.
        let mut params = single_param("name", json!("a"));
        assert!(validator.validate_and_coerce(&mut params).is_err());

        // Too long.
        let mut params = single_param("name", json!("this_is_too_long"));
        assert!(validator.validate_and_coerce(&mut params).is_err());
    }

    /// `minimum` / `maximum` accept in-range integers and reject the rest.
    #[test]
    fn test_number_validation() {
        let validator = ParameterValidator::new(json!({
            "type": "object",
            "properties": {
                "age": {"type": "integer", "minimum": 0, "maximum": 150}
            },
            "required": ["age"]
        }));

        // Within bounds.
        let mut params = single_param("age", json!(25));
        assert!(validator.validate_and_coerce(&mut params).is_ok());

        // Below the minimum.
        let mut params = single_param("age", json!(-5));
        assert!(validator.validate_and_coerce(&mut params).is_err());

        // Above the maximum.
        let mut params = single_param("age", json!(200));
        assert!(validator.validate_and_coerce(&mut params).is_err());
    }

    /// Default configuration coerces mismatched types in place.
    #[test]
    fn test_type_coercion() {
        let validator = ParameterValidator::new(json!({
            "type": "object",
            "properties": {
                "count": {"type": "integer"},
                "flag": {"type": "boolean"},
                "name": {"type": "string"}
            }
        }));

        let mut params = HashMap::from([
            ("count".to_string(), json!("42")),
            ("flag".to_string(), json!("true")),
            ("name".to_string(), json!(123)),
        ]);
        assert!(validator.validate_and_coerce(&mut params).is_ok());

        assert_eq!(params["count"].as_i64(), Some(42));
        assert_eq!(params["flag"].as_bool(), Some(true));
        assert_eq!(params["name"].as_str(), Some("123"));
    }

    /// The macro yields the name plus a schema carrying both length bounds.
    #[test]
    fn test_param_schema_macro() {
        let (name, schema) = param_schema!(string "username", min: 3, max: 20);
        assert_eq!(name, "username");
        assert_eq!(schema["type"], "string");
        assert_eq!(schema["minLength"], 3);
        assert_eq!(schema["maxLength"], 20);
    }

    /// The assembled schema exposes every property and the required list.
    #[test]
    fn test_create_tool_schema() {
        let declared = vec![
            param_schema!(string "name"),
            param_schema!(integer "age", min: 0),
            param_schema!(boolean "active"),
        ];
        let schema = create_tool_schema(declared, vec!["name", "age"]);

        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["name"]["type"] == "string");
        assert!(schema["properties"]["age"]["type"] == "integer");
        assert!(schema["properties"]["active"]["type"] == "boolean");
        assert_eq!(schema["required"], json!(["name", "age"]));
    }
}