1use super::schema_cache::{CachedSchema, ToolSchemaCache};
28use super::ValidationConfig;
29
/// Outcome of validating a tool call's parameters against its cached input schema.
///
/// Derives `PartialEq`/`Eq` (like [`ValidationErrorKind`]) so whole results can be
/// compared directly, e.g. with `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationResult {
    /// `true` when no schema violations were recorded. Validation that was skipped
    /// (disabled, or no cached schema) is also reported as valid.
    pub is_valid: bool,

    /// Every violation found; empty whenever `is_valid` is `true`.
    pub errors: Vec<ValidationError>,
}

/// A single schema violation, translated from the underlying JSON Schema error.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationError {
    /// Location of the offending value inside the parameters (the validator's
    /// instance path rendered as a string).
    pub path: String,

    /// Coarse classification of the violation.
    pub kind: ValidationErrorKind,

    /// Human-readable description of the violation.
    pub message: String,
}

/// Category of a [`ValidationError`], carrying the details needed to explain it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValidationErrorKind {
    /// A property listed in the schema's `required` array is absent.
    MissingRequired { field: String },

    /// A value does not have the JSON type the schema expects.
    TypeMismatch { expected: String, actual: String },

    /// A property that the schema does not allow was supplied.
    UnknownField {
        /// The unexpected property name.
        field: String,
        /// Known schema properties with a similar spelling ("did you mean" hints).
        suggestions: Vec<String>,
    },

    /// Any other violation; `reason` carries the validator's message.
    InvalidValue { reason: String },

    /// A value is not one of the schema's allowed `enum` options.
    InvalidEnum { value: String, allowed: Vec<String> },
}
74
75pub struct McpValidator {
77 cache: ToolSchemaCache,
78 config: ValidationConfig,
79}
80
81impl McpValidator {
82 pub fn new(config: ValidationConfig) -> Self {
84 Self {
85 cache: ToolSchemaCache::new(),
86 config,
87 }
88 }
89
90 pub fn cache(&self) -> &ToolSchemaCache {
92 &self.cache
93 }
94
95 pub fn config(&self) -> &ValidationConfig {
97 &self.config
98 }
99
100 pub fn validate(
102 &self,
103 server: &str,
104 tool: &str,
105 params: &serde_json::Value,
106 ) -> ValidationResult {
107 if !self.config.pre_validate {
109 return ValidationResult {
110 is_valid: true,
111 errors: vec![],
112 };
113 }
114
115 let Some(schema_ref) = self.cache.get(server, tool) else {
117 tracing::debug!(
119 server = %server,
120 tool = %tool,
121 "No cached schema, skipping validation"
122 );
123 return ValidationResult {
124 is_valid: true,
125 errors: vec![],
126 };
127 };
128
129 let schema = schema_ref.value();
130 let mut errors = Vec::new();
131
132 let validation = schema.validator.iter_errors(params);
134
135 for error in validation {
136 let path = error.instance_path.to_string();
137 let kind = self.classify_error(&error, schema);
138 let message = self.format_error(&error, schema);
139
140 errors.push(ValidationError {
141 path,
142 kind,
143 message,
144 });
145 }
146
147 ValidationResult {
148 is_valid: errors.is_empty(),
149 errors,
150 }
151 }
152
153 fn classify_error(
155 &self,
156 error: &jsonschema::ValidationError,
157 schema: &CachedSchema,
158 ) -> ValidationErrorKind {
159 let error_kind = format!("{:?}", error.kind);
160 let error_msg = error.to_string();
161
162 if error_kind.contains("Required") {
163 let field = self.extract_missing_field(&error_msg);
165 ValidationErrorKind::MissingRequired { field }
166 } else if error_kind.contains("Type") {
167 ValidationErrorKind::TypeMismatch {
168 expected: self.extract_expected_type(&error_msg),
169 actual: self.extract_actual_type(&error_msg),
170 }
171 } else if error_kind.contains("AdditionalProperties") {
172 let field = Self::extract_additional_property_field(&error_msg);
176 let suggestions = self.find_suggestions(&field, &schema.properties);
177 ValidationErrorKind::UnknownField { field, suggestions }
178 } else if error_kind.contains("Enum") {
179 ValidationErrorKind::InvalidEnum {
180 value: format!("{}", error.instance),
181 allowed: vec![], }
183 } else {
184 ValidationErrorKind::InvalidValue { reason: error_msg }
185 }
186 }
187
188 fn extract_missing_field(&self, error_msg: &str) -> String {
190 if let Some(start) = error_msg.find('"') {
192 if let Some(end) = error_msg[start + 1..].find('"') {
193 return error_msg[start + 1..start + 1 + end].to_string();
194 }
195 }
196 if let Some(start) = error_msg.find('\'') {
198 if let Some(end) = error_msg[start + 1..].find('\'') {
199 return error_msg[start + 1..start + 1 + end].to_string();
200 }
201 }
202 "unknown".to_string()
203 }
204
205 fn extract_additional_property_field(error_msg: &str) -> String {
209 if let Some(start) = error_msg.find('\'') {
211 if let Some(end) = error_msg[start + 1..].find('\'') {
212 return error_msg[start + 1..start + 1 + end].to_string();
213 }
214 }
215 if let Some(start) = error_msg.find('"') {
217 if let Some(end) = error_msg[start + 1..].find('"') {
218 return error_msg[start + 1..start + 1 + end].to_string();
219 }
220 }
221 "unknown".to_string()
222 }
223
224 fn extract_expected_type(&self, error_msg: &str) -> String {
226 if error_msg.contains("string") {
228 "string".to_string()
229 } else if error_msg.contains("integer") {
230 "integer".to_string()
231 } else if error_msg.contains("number") {
232 "number".to_string()
233 } else if error_msg.contains("boolean") {
234 "boolean".to_string()
235 } else if error_msg.contains("array") {
236 "array".to_string()
237 } else if error_msg.contains("object") {
238 "object".to_string()
239 } else {
240 "expected".to_string()
241 }
242 }
243
244 fn extract_actual_type(&self, error_msg: &str) -> String {
249 let msg = error_msg.to_lowercase();
252 if msg.contains("null") && !msg.contains("not of type \"null\"") {
253 "null".to_string()
254 } else if msg.contains("true") || msg.contains("false") {
255 "boolean".to_string()
256 } else if msg.contains('[') {
257 "array".to_string()
258 } else if msg.contains('{') {
259 "object".to_string()
260 } else if msg.contains("\"\"") || msg.contains("''") {
261 "string".to_string()
262 } else {
263 if msg.contains("not of type \"string\"") {
265 "number".to_string()
266 } else if msg.contains("not of type \"integer\"")
267 || msg.contains("not of type \"number\"")
268 {
269 "string".to_string()
270 } else {
271 "unknown".to_string()
272 }
273 }
274 }
275
276 fn format_error(&self, error: &jsonschema::ValidationError, schema: &CachedSchema) -> String {
278 let base = error.to_string();
279
280 if !schema.required.is_empty() {
282 format!(
283 "{}. Required fields: [{}]",
284 base,
285 schema.required.join(", ")
286 )
287 } else {
288 base
289 }
290 }
291
292 pub fn find_suggestions(&self, field: &str, properties: &[String]) -> Vec<String> {
294 properties
295 .iter()
296 .filter(|p| Self::edit_distance(field, p) <= self.config.suggestion_distance)
297 .cloned()
298 .collect()
299 }
300
301 pub fn edit_distance(a: &str, b: &str) -> usize {
303 let a = a.to_lowercase();
304 let b = b.to_lowercase();
305
306 if a.is_empty() {
307 return b.len();
308 }
309 if b.is_empty() {
310 return a.len();
311 }
312
313 let a_chars: Vec<char> = a.chars().collect();
314 let b_chars: Vec<char> = b.chars().collect();
315
316 let mut matrix = vec![vec![0usize; b_chars.len() + 1]; a_chars.len() + 1];
317
318 for (i, row) in matrix.iter_mut().enumerate().take(a_chars.len() + 1) {
319 row[0] = i;
320 }
321 for (j, val) in matrix[0].iter_mut().enumerate() {
322 *val = j;
323 }
324
325 for i in 1..=a_chars.len() {
326 for j in 1..=b_chars.len() {
327 let cost = if a_chars[i - 1] == b_chars[j - 1] {
328 0
329 } else {
330 1
331 };
332 matrix[i][j] = std::cmp::min(
333 std::cmp::min(
334 matrix[i - 1][j] + 1, matrix[i][j - 1] + 1, ),
337 matrix[i - 1][j - 1] + cost, );
339 }
340 }
341
342 matrix[a_chars.len()][b_chars.len()]
343 }
344}
345
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ToolDefinition;
    use serde_json::json;

    /// Builds a default-config validator whose cache holds `schema` as the input
    /// schema of `tool` on `server`.
    fn validator_with_schema(server: &str, tool: &str, schema: serde_json::Value) -> McpValidator {
        let validator = McpValidator::new(ValidationConfig::default());
        let tools = [ToolDefinition::new(tool).with_input_schema(schema)];
        validator.cache().populate(server, &tools).unwrap();
        validator
    }

    #[test]
    fn test_validate_missing_required_field() {
        let validator = validator_with_schema(
            "novanet",
            "novanet_context",
            json!({
                "type": "object",
                "properties": {
                    "entity": { "type": "string" },
                    "locale": { "type": "string" }
                },
                "required": ["entity"]
            }),
        );

        let params = json!({ "locale": "fr-FR" });
        let result = validator.validate("novanet", "novanet_context", &params);

        assert!(!result.is_valid);
        assert_eq!(result.errors.len(), 1);

        let ValidationErrorKind::MissingRequired { field } = &result.errors[0].kind else {
            panic!("Expected MissingRequired, got {:?}", result.errors[0].kind);
        };
        assert_eq!(field, "entity");
    }

    #[test]
    fn test_validate_valid_params_passes() {
        let validator = validator_with_schema(
            "novanet",
            "novanet_context",
            json!({
                "type": "object",
                "properties": {
                    "entity": { "type": "string" }
                },
                "required": ["entity"]
            }),
        );

        let params = json!({ "entity": "qr-code" });
        let result = validator.validate("novanet", "novanet_context", &params);

        assert!(result.is_valid);
        assert!(result.errors.is_empty());
    }

    #[test]
    fn test_validate_disabled_always_passes() {
        let validator = McpValidator::new(ValidationConfig {
            pre_validate: false,
            ..Default::default()
        });

        assert!(validator.validate("any", "tool", &json!({})).is_valid);
    }

    #[test]
    fn test_validate_no_cached_schema_passes() {
        let validator = McpValidator::new(ValidationConfig::default());
        let params = json!({ "anything": "goes" });

        assert!(validator.validate("unknown", "tool", &params).is_valid);
    }

    #[test]
    fn test_validate_type_mismatch() {
        let validator = validator_with_schema(
            "s",
            "t",
            json!({
                "type": "object",
                "properties": {
                    "count": { "type": "integer" }
                }
            }),
        );

        let result = validator.validate("s", "t", &json!({ "count": "not-an-integer" }));

        assert!(!result.is_valid);
        assert!(matches!(
            &result.errors[0].kind,
            ValidationErrorKind::TypeMismatch { .. }
        ));
    }

    #[test]
    fn test_edit_distance_exact_match() {
        assert_eq!(McpValidator::edit_distance("entity", "entity"), 0);
    }

    #[test]
    fn test_edit_distance_one_char_diff() {
        assert_eq!(McpValidator::edit_distance("entity", "entityy"), 1);
        assert_eq!(McpValidator::edit_distance("entty", "entity"), 1);
    }

    #[test]
    fn test_edit_distance_case_insensitive() {
        assert_eq!(McpValidator::edit_distance("Entity", "ENTITY"), 0);
    }

    #[test]
    fn test_find_suggestions_within_distance() {
        let validator = validator_with_schema(
            "s",
            "t",
            json!({
                "type": "object",
                "properties": {
                    "entity": {},
                    "locale": {},
                    "forms": {}
                }
            }),
        );

        let schema = validator.cache().get("s", "t").unwrap();
        let suggestions = validator.find_suggestions("entiy", &schema.properties);

        assert!(suggestions.contains(&"entity".to_string()));
    }

    #[test]
    fn test_edit_distance_empty_strings() {
        assert_eq!(McpValidator::edit_distance("", ""), 0);
        assert_eq!(McpValidator::edit_distance("abc", ""), 3);
        assert_eq!(McpValidator::edit_distance("", "xyz"), 3);
    }

    #[test]
    fn test_edit_distance_completely_different() {
        assert_eq!(McpValidator::edit_distance("abc", "xyz"), 3);
    }

    #[test]
    fn test_multiple_validation_errors() {
        let validator = validator_with_schema(
            "s",
            "t",
            json!({
                "type": "object",
                "properties": {
                    "a": { "type": "string" },
                    "b": { "type": "integer" }
                },
                "required": ["a", "b"]
            }),
        );

        let result = validator.validate("s", "t", &json!({}));

        assert!(!result.is_valid);
        assert_eq!(result.errors.len(), 2);
    }

    #[test]
    fn test_error_message_includes_required_fields() {
        let validator = validator_with_schema(
            "s",
            "t",
            json!({
                "type": "object",
                "properties": {
                    "entity": { "type": "string" },
                    "locale": { "type": "string" }
                },
                "required": ["entity"]
            }),
        );

        let result = validator.validate("s", "t", &json!({}));
        assert!(!result.is_valid);

        let message = &result.errors[0].message;
        assert!(message.contains("Required fields"));
        assert!(message.contains("entity"));
    }

    #[test]
    fn test_suggestion_distance_config() {
        let validator = McpValidator::new(ValidationConfig {
            suggestion_distance: 1,
            ..Default::default()
        });

        let candidates = ["entity".to_string(), "completely_different".to_string()];
        let suggestions = validator.find_suggestions("entiy", &candidates);

        assert!(suggestions.contains(&"entity".to_string()));
        assert!(!suggestions.contains(&"completely_different".to_string()));
    }
}