use thiserror::Error;
use unicode_normalization::UnicodeNormalization;
/// Errors reported by the input-validation helpers in this file.
///
/// The `Display` text of each variant comes from its `#[error(...)]` attribute.
#[derive(Error, Debug)]
pub enum ValidationError {
/// The measured byte length (`actual`) is greater than the permitted `max`.
#[error("Input exceeds maximum length ({max} bytes, got {actual})")]
TooLong {
/// Largest permitted length, in bytes.
max: usize,
/// Length that was actually measured, in bytes.
actual: usize,
},
/// Input is not valid UTF-8.
/// NOTE(review): no code visible in this file constructs this variant.
#[error("Invalid UTF-8 encoding")]
InvalidUtf8,
/// Input contains a forbidden sequence; returned by `validate_path`.
#[error("Disallowed characters in input")]
DisallowedChars,
/// The value does not satisfy a structural rule (expected type, required
/// field, nesting depth); the payload is the human-readable reason.
#[error("Input failed schema validation: {0}")]
SchemaViolation(String),
/// A `serde_json` failure, converted automatically by `?` through `#[from]`.
#[error("JSON error: {0}")]
JsonError(#[from] serde_json::Error),
}
/// Size and depth ceilings used when validating untrusted input.
///
/// All size limits are expressed in bytes.
pub mod limits {
    /// Bytes in one kibibyte.
    const KIB: usize = 1 << 10;
    /// Bytes in one mebibyte.
    const MIB: usize = KIB * KIB;

    /// 64 KiB; by its name, the cap for message content.
    pub const MAX_MESSAGE_LENGTH: usize = 64 * KIB;
    /// 1 MiB; cap on the serialized size of tool parameters.
    pub const MAX_TOOL_PARAMS_SIZE: usize = MIB;
    /// 256 KiB; by its name, the cap for a skill file.
    pub const MAX_SKILL_FILE_SIZE: usize = 256 * KIB;
    /// 1 MiB; by its name, the cap for a configuration file.
    pub const MAX_CONFIG_FILE_SIZE: usize = MIB;
    /// 50 MiB; by its name, the cap for an attachment.
    pub const MAX_ATTACHMENT_SIZE: usize = 50 * MIB;
    /// Deepest nesting level accepted in a JSON value.
    pub const MAX_JSON_DEPTH: usize = 32;
}
pub fn validate_message_content(input: &str, max_len: usize) -> Result<String, ValidationError> {
if input.len() > max_len {
return Err(ValidationError::TooLong {
max: max_len,
actual: input.len(),
});
}
let sanitized: String = input
.chars()
.filter(|c| !c.is_control() || *c == '\n' || *c == '\t' || *c == '\r')
.collect();
let normalized: String = sanitized.nfkc().collect();
Ok(normalized)
}
/// Validates tool-call parameters against a size cap, a nesting cap and a schema.
///
/// The checks run in this order and stop at the first failure:
/// 1. the JSON serialization of `params` must not exceed
///    `limits::MAX_TOOL_PARAMS_SIZE` bytes;
/// 2. nesting must not exceed `limits::MAX_JSON_DEPTH` (root counted as level 0);
/// 3. `params` must satisfy `schema` as far as `validate_json_structure` checks it.
///
/// # Errors
///
/// * [`ValidationError::JsonError`] if `params` cannot be serialized.
/// * [`ValidationError::TooLong`] if the serialized form is too large.
/// * [`ValidationError::SchemaViolation`] for depth or schema failures.
pub fn validate_tool_params(
    params: &serde_json::Value,
    schema: &serde_json::Value,
) -> Result<(), ValidationError> {
    let cap = limits::MAX_TOOL_PARAMS_SIZE;

    // Only the length is needed; the serialized text is dropped immediately.
    let encoded_len = serde_json::to_string(params).map(|json| json.len())?;
    if encoded_len > cap {
        return Err(ValidationError::TooLong {
            max: cap,
            actual: encoded_len,
        });
    }

    check_json_depth(params, 0, limits::MAX_JSON_DEPTH)
        .and_then(|()| validate_json_structure(params, schema))
}
fn check_json_depth(
value: &serde_json::Value,
depth: usize,
max: usize,
) -> Result<(), ValidationError> {
if depth > max {
return Err(ValidationError::SchemaViolation(format!(
"JSON nesting depth exceeds maximum ({max})"
)));
}
match value {
serde_json::Value::Array(arr) => {
for item in arr {
check_json_depth(item, depth + 1, max)?;
}
}
serde_json::Value::Object(obj) => {
for (_, item) in obj {
check_json_depth(item, depth + 1, max)?;
}
}
_ => {}
}
Ok(())
}
/// Performs a shallow structural check of `params` against `schema`.
///
/// Only two schema keywords are inspected, both at the top level:
///
/// * `"type"` — one of `object`, `array`, `string`, `number`, `integer`,
///   `boolean`. A missing, non-string or unrecognised type imposes no
///   constraint.
/// * `"required"` — for `object` schemas, every string entry must be present
///   as a key of `params`. Non-string entries are ignored.
///
/// `integer` accepts any JSON number with a zero fractional part (so `3` and
/// `3.0` pass, `3.5` does not). `number` accepts every JSON number.
///
/// Nested properties and array items are not validated.
///
/// # Errors
///
/// Returns [`ValidationError::SchemaViolation`] describing the first mismatch.
fn validate_json_structure(
    params: &serde_json::Value,
    schema: &serde_json::Value,
) -> Result<(), ValidationError> {
    /// True when `value` is a JSON number without a fractional part.
    fn is_integer(value: &serde_json::Value) -> bool {
        value.is_i64()
            || value.is_u64()
            || value.as_f64().map_or(false, |f| f.fract() == 0.0)
    }

    let type_error =
        |expected: &str| ValidationError::SchemaViolation(format!("Expected {expected}"));

    let schema_type = schema.get("type").and_then(|t| t.as_str());
    match schema_type {
        Some("object") => {
            // `as_object` doubles as the type check, so no `unwrap` is needed.
            let obj = params.as_object().ok_or_else(|| type_error("object"))?;

            let required = schema.get("required").and_then(|r| r.as_array());
            for field in required
                .into_iter()
                .flatten()
                .filter_map(|req| req.as_str())
            {
                if !obj.contains_key(field) {
                    return Err(ValidationError::SchemaViolation(format!(
                        "Missing required field: {field}"
                    )));
                }
            }
        }
        Some("array") if !params.is_array() => return Err(type_error("array")),
        Some("string") if !params.is_string() => return Err(type_error("string")),
        Some("number") if !params.is_number() => return Err(type_error("number")),
        // Previously any number (e.g. 1.5) satisfied "integer".
        Some("integer") if !is_integer(params) => return Err(type_error("integer")),
        Some("boolean") if !params.is_boolean() => return Err(type_error("boolean")),
        _ => {}
    }

    Ok(())
}
/// Rejects paths that contain a `..` sequence or a NUL character.
///
/// The `..` test is a plain substring match, so it is deliberately
/// conservative: a name such as `a..b` is rejected as well as a real
/// parent-directory component. No other property of the path is checked.
///
/// # Errors
///
/// Returns [`ValidationError::DisallowedChars`] when either pattern is found.
pub fn validate_path(path: &str) -> Result<(), ValidationError> {
    let has_dot_dot = path.contains("..");
    let has_nul = path.contains('\0');

    match (has_dot_dot, has_nul) {
        (false, false) => Ok(()),
        _ => Err(ValidationError::DisallowedChars),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain text is unchanged, NUL is stripped, permitted whitespace
    /// controls survive, and over-long input is rejected.
    #[test]
    fn test_validate_message_content() {
        let result = validate_message_content("Hello, world!", 100);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "Hello, world!");

        let result = validate_message_content("Hello\x00World", 100);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "HelloWorld");

        let result = validate_message_content("Line1\nLine2", 100);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "Line1\nLine2");

        let result = validate_message_content("a\tb\r\n", 100);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "a\tb\r\n");

        let result = validate_message_content(&"x".repeat(200), 100);
        assert!(matches!(result, Err(ValidationError::TooLong { .. })));
    }

    /// NFKC folds compatibility characters into their plain equivalents.
    #[test]
    fn test_unicode_normalization() {
        // U+FB01 LATIN SMALL LIGATURE FI must come back as the two letters
        // "fi". Written as an escape so the ligature cannot be silently
        // replaced by ASCII, which would make this test vacuous.
        let result = validate_message_content("\u{FB01}", 100);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "fi");
    }

    /// Ordinary paths pass; `..` and NUL are rejected.
    #[test]
    fn test_validate_path() {
        assert!(validate_path("/home/user/file.txt").is_ok());
        assert!(validate_path("../etc/passwd").is_err());
        assert!(validate_path("/home/user/\0file").is_err());
    }

    /// Shallow values pass; values nested beyond the limit fail.
    #[test]
    fn test_json_depth() {
        let shallow = serde_json::json!({"a": {"b": "c"}});
        assert!(check_json_depth(&shallow, 0, 10).is_ok());
        // The leaf "c" sits at level 2: allowed at max = 2, rejected at max = 1.
        assert!(check_json_depth(&shallow, 0, 2).is_ok());
        assert!(check_json_depth(&shallow, 0, 1).is_err());

        let mut deep = serde_json::json!("leaf");
        for _ in 0..50 {
            deep = serde_json::json!({"nested": deep});
        }
        assert!(check_json_depth(&deep, 0, 32).is_err());
    }
}