use super::schema_cache::{CachedSchema, ToolSchemaCache};
use super::ValidationConfig;
/// Outcome of validating one tool call's parameters against its cached input schema.
#[derive(Debug, Clone)]
pub struct ValidationResult {
    /// `true` when no errors were collected. Also reported as `true` when
    /// validation was skipped (pre-validation disabled, or no schema cached
    /// for the tool).
    pub is_valid: bool,
    /// Every schema violation that was found; empty whenever `is_valid` is `true`.
    pub errors: Vec<ValidationError>,
}
/// A single schema violation found while validating tool-call parameters.
#[derive(Debug, Clone)]
pub struct ValidationError {
    /// Location of the offending value inside the parameters, taken from the
    /// jsonschema error's instance path (rendered with its `Display` impl).
    pub path: String,
    /// Structured classification of the violation.
    pub kind: ValidationErrorKind,
    /// Human-readable description: the jsonschema message, followed by the
    /// schema's list of required fields when the schema declares any.
    pub message: String,
}
/// Structured category of a [`ValidationError`], intended for building
/// targeted feedback (e.g. "did you mean ...?" hints) rather than raw messages.
#[derive(Debug, Clone, PartialEq)]
pub enum ValidationErrorKind {
    /// Produced for jsonschema `Required` errors: a required property is absent.
    /// `field` is `"unknown"` when the name could not be recovered.
    MissingRequired { field: String },
    /// A value has the wrong JSON type. Either side may hold a placeholder
    /// (`"expected"` / `"actual"`) when the type could not be determined.
    TypeMismatch { expected: String, actual: String },
    /// A property that the schema does not allow was supplied.
    UnknownField {
        /// The unexpected property name (`"unknown"` if it could not be determined).
        field: String,
        /// Schema properties within the configured edit distance of `field`.
        suggestions: Vec<String>,
    },
    /// Any violation not covered by a more specific variant; `reason` carries
    /// the validator's message.
    InvalidValue { reason: String },
    /// A value is not one of the schema's `enum` options. `allowed` may be
    /// empty when the options could not be recovered.
    InvalidEnum { value: String, allowed: Vec<String> },
}
/// Validates MCP tool-call parameters against JSON input schemas cached per
/// `(server, tool)` pair.
pub struct McpValidator {
    /// Compiled input schemas, looked up by server and tool name.
    cache: ToolSchemaCache,
    /// Switches and tuning knobs (e.g. `pre_validate`, `suggestion_distance`).
    config: ValidationConfig,
}
impl McpValidator {
pub fn new(config: ValidationConfig) -> Self {
Self {
cache: ToolSchemaCache::new(),
config,
}
}
pub fn cache(&self) -> &ToolSchemaCache {
&self.cache
}
pub fn config(&self) -> &ValidationConfig {
&self.config
}
pub fn validate(
&self,
server: &str,
tool: &str,
params: &serde_json::Value,
) -> ValidationResult {
if !self.config.pre_validate {
return ValidationResult {
is_valid: true,
errors: vec![],
};
}
let Some(schema_ref) = self.cache.get(server, tool) else {
tracing::debug!(
server = %server,
tool = %tool,
"No cached schema, skipping validation"
);
return ValidationResult {
is_valid: true,
errors: vec![],
};
};
let schema = schema_ref.value();
let mut errors = Vec::new();
let validation = schema.validator.iter_errors(params);
for error in validation {
let path = error.instance_path.to_string();
let kind = self.classify_error(&error, schema);
let message = self.format_error(&error, schema);
errors.push(ValidationError {
path,
kind,
message,
});
}
ValidationResult {
is_valid: errors.is_empty(),
errors,
}
}
fn classify_error(
&self,
error: &jsonschema::ValidationError,
schema: &CachedSchema,
) -> ValidationErrorKind {
let error_kind = format!("{:?}", error.kind);
let error_msg = error.to_string();
if error_kind.contains("Required") {
let field = self.extract_missing_field(&error_msg);
ValidationErrorKind::MissingRequired { field }
} else if error_kind.contains("Type") {
ValidationErrorKind::TypeMismatch {
expected: self.extract_expected_type(&error_msg),
actual: self.extract_actual_type(&error_msg),
}
} else if error_kind.contains("AdditionalProperties") {
let path = error.instance_path.to_string();
let field = path
.rsplit('/')
.next()
.filter(|s| !s.is_empty())
.unwrap_or("unknown")
.to_string();
let suggestions = self.find_suggestions(&field, &schema.properties);
ValidationErrorKind::UnknownField { field, suggestions }
} else if error_kind.contains("Enum") {
ValidationErrorKind::InvalidEnum {
value: format!("{}", error.instance),
allowed: vec![], }
} else {
ValidationErrorKind::InvalidValue { reason: error_msg }
}
}
fn extract_missing_field(&self, error_msg: &str) -> String {
if let Some(start) = error_msg.find('"') {
if let Some(end) = error_msg[start + 1..].find('"') {
return error_msg[start + 1..start + 1 + end].to_string();
}
}
if let Some(start) = error_msg.find('\'') {
if let Some(end) = error_msg[start + 1..].find('\'') {
return error_msg[start + 1..start + 1 + end].to_string();
}
}
"unknown".to_string()
}
fn extract_expected_type(&self, error_msg: &str) -> String {
if error_msg.contains("string") {
"string".to_string()
} else if error_msg.contains("integer") {
"integer".to_string()
} else if error_msg.contains("number") {
"number".to_string()
} else if error_msg.contains("boolean") {
"boolean".to_string()
} else if error_msg.contains("array") {
"array".to_string()
} else if error_msg.contains("object") {
"object".to_string()
} else {
"expected".to_string()
}
}
fn extract_actual_type(&self, _error_msg: &str) -> String {
"actual".to_string()
}
fn format_error(&self, error: &jsonschema::ValidationError, schema: &CachedSchema) -> String {
let base = error.to_string();
if !schema.required.is_empty() {
format!(
"{}. Required fields: [{}]",
base,
schema.required.join(", ")
)
} else {
base
}
}
pub fn find_suggestions(&self, field: &str, properties: &[String]) -> Vec<String> {
properties
.iter()
.filter(|p| Self::edit_distance(field, p) <= self.config.suggestion_distance)
.cloned()
.collect()
}
pub fn edit_distance(a: &str, b: &str) -> usize {
let a = a.to_lowercase();
let b = b.to_lowercase();
if a.is_empty() {
return b.len();
}
if b.is_empty() {
return a.len();
}
let a_chars: Vec<char> = a.chars().collect();
let b_chars: Vec<char> = b.chars().collect();
let mut matrix = vec![vec![0usize; b_chars.len() + 1]; a_chars.len() + 1];
for (i, row) in matrix.iter_mut().enumerate().take(a_chars.len() + 1) {
row[0] = i;
}
for (j, val) in matrix[0].iter_mut().enumerate() {
*val = j;
}
for i in 1..=a_chars.len() {
for j in 1..=b_chars.len() {
let cost = if a_chars[i - 1] == b_chars[j - 1] {
0
} else {
1
};
matrix[i][j] = std::cmp::min(
std::cmp::min(
matrix[i - 1][j] + 1, matrix[i][j - 1] + 1, ),
matrix[i - 1][j - 1] + cost, );
}
}
matrix[a_chars.len()][b_chars.len()]
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mcp::types::ToolDefinition;
    use serde_json::json;

    /// Builds a default-configured validator whose cache holds exactly one
    /// tool schema, registered under `server` / `tool`.
    fn validator_with_schema(
        server: &'static str,
        tool: &'static str,
        schema: serde_json::Value,
    ) -> McpValidator {
        let validator = McpValidator::new(ValidationConfig::default());
        let tools = [ToolDefinition::new(tool).with_input_schema(schema)];
        validator.cache().populate(server, &tools).unwrap();
        validator
    }

    #[test]
    fn test_validate_missing_required_field() {
        let validator = validator_with_schema(
            "novanet",
            "novanet_context",
            json!({
                "type": "object",
                "properties": {
                    "entity": { "type": "string" },
                    "locale": { "type": "string" }
                },
                "required": ["entity"]
            }),
        );
        let params = json!({ "locale": "fr-FR" });
        let outcome = validator.validate("novanet", "novanet_context", &params);
        assert!(!outcome.is_valid);
        assert_eq!(outcome.errors.len(), 1);
        let kind = &outcome.errors[0].kind;
        if let ValidationErrorKind::MissingRequired { field } = kind {
            assert_eq!(field, "entity");
        } else {
            panic!("Expected MissingRequired, got {:?}", kind);
        }
    }

    #[test]
    fn test_validate_valid_params_passes() {
        let validator = validator_with_schema(
            "novanet",
            "novanet_context",
            json!({
                "type": "object",
                "properties": {
                    "entity": { "type": "string" }
                },
                "required": ["entity"]
            }),
        );
        let params = json!({ "entity": "qr-code" });
        let outcome = validator.validate("novanet", "novanet_context", &params);
        assert!(outcome.is_valid);
        assert!(outcome.errors.is_empty());
    }

    #[test]
    fn test_validate_disabled_always_passes() {
        let validator = McpValidator::new(ValidationConfig {
            pre_validate: false,
            ..Default::default()
        });
        let outcome = validator.validate("any", "tool", &json!({}));
        assert!(outcome.is_valid);
    }

    #[test]
    fn test_validate_no_cached_schema_passes() {
        let validator = McpValidator::new(ValidationConfig::default());
        let params = json!({ "anything": "goes" });
        let outcome = validator.validate("unknown", "tool", &params);
        assert!(outcome.is_valid);
    }

    #[test]
    fn test_validate_type_mismatch() {
        let validator = validator_with_schema(
            "s",
            "t",
            json!({
                "type": "object",
                "properties": {
                    "count": { "type": "integer" }
                }
            }),
        );
        let params = json!({ "count": "not-an-integer" });
        let outcome = validator.validate("s", "t", &params);
        assert!(!outcome.is_valid);
        let first_kind = &outcome.errors[0].kind;
        assert!(matches!(
            first_kind,
            ValidationErrorKind::TypeMismatch { .. }
        ));
    }

    #[test]
    fn test_edit_distance_exact_match() {
        let distance = McpValidator::edit_distance("entity", "entity");
        assert_eq!(distance, 0);
    }

    #[test]
    fn test_edit_distance_one_char_diff() {
        let extra_char = McpValidator::edit_distance("entity", "entityy");
        let missing_char = McpValidator::edit_distance("entty", "entity");
        assert_eq!(extra_char, 1);
        assert_eq!(missing_char, 1);
    }

    #[test]
    fn test_edit_distance_case_insensitive() {
        let distance = McpValidator::edit_distance("Entity", "ENTITY");
        assert_eq!(distance, 0);
    }

    #[test]
    fn test_find_suggestions_within_distance() {
        let validator = validator_with_schema(
            "s",
            "t",
            json!({
                "type": "object",
                "properties": {
                    "entity": {},
                    "locale": {},
                    "forms": {}
                }
            }),
        );
        let cached = validator.cache().get("s", "t").unwrap();
        let hints = validator.find_suggestions("entiy", &cached.properties);
        let expected = "entity".to_string();
        assert!(hints.contains(&expected));
    }

    #[test]
    fn test_edit_distance_empty_strings() {
        let cases = [("", "", 0), ("abc", "", 3), ("", "xyz", 3)];
        for (left, right, expected) in cases {
            assert_eq!(McpValidator::edit_distance(left, right), expected);
        }
    }

    #[test]
    fn test_edit_distance_completely_different() {
        let distance = McpValidator::edit_distance("abc", "xyz");
        assert_eq!(distance, 3);
    }

    #[test]
    fn test_multiple_validation_errors() {
        let validator = validator_with_schema(
            "s",
            "t",
            json!({
                "type": "object",
                "properties": {
                    "a": { "type": "string" },
                    "b": { "type": "integer" }
                },
                "required": ["a", "b"]
            }),
        );
        let outcome = validator.validate("s", "t", &json!({}));
        assert!(!outcome.is_valid);
        assert_eq!(outcome.errors.len(), 2);
    }

    #[test]
    fn test_error_message_includes_required_fields() {
        let validator = validator_with_schema(
            "s",
            "t",
            json!({
                "type": "object",
                "properties": {
                    "entity": { "type": "string" },
                    "locale": { "type": "string" }
                },
                "required": ["entity"]
            }),
        );
        let outcome = validator.validate("s", "t", &json!({}));
        assert!(!outcome.is_valid);
        let message = &outcome.errors[0].message;
        assert!(message.contains("Required fields"));
        assert!(message.contains("entity"));
    }

    #[test]
    fn test_suggestion_distance_config() {
        let validator = McpValidator::new(ValidationConfig {
            suggestion_distance: 1,
            ..Default::default()
        });
        let candidates = ["entity".to_string(), "completely_different".to_string()];
        let hints = validator.find_suggestions("entiy", &candidates);
        assert!(hints.contains(&candidates[0]));
        assert!(!hints.contains(&candidates[1]));
    }
}