use crate::error::Result;
use crate::rate_limit::TokenBucket;
use serde::{Deserialize, Serialize};
use std::time::Duration;
/// Limits applied to JSON payloads before they are accepted.
///
/// Used directly through [`ValidationConfig::validate_payload`] (no rate limiting) or as the
/// configuration of a [`PayloadValidator`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationConfig {
/// Maximum length, in bytes, of the payload's compact JSON serialization.
pub max_payload_size_bytes: usize,
/// Maximum byte length (`str::len`, not character count) of any string value or object key.
pub max_string_length: usize,
/// Maximum nesting depth. The top-level value is at depth 0 and every value inside an object
/// or array is one level deeper; any value (scalars included) deeper than this is rejected.
pub max_object_depth: usize,
/// Keys rejected wherever they appear in an object, at any nesting level (including objects
/// inside arrays).
pub forbidden_keys: Vec<String>,
/// Keys that must be present in the top-level JSON object.
pub required_keys: Vec<String>,
/// Rate passed to `TokenBucket::new` by `PayloadValidator::new` (presumably tokens per second,
/// per the field name — confirm against `TokenBucket`); `None` disables rate limiting.
pub max_enqueue_per_second: Option<u32>,
/// Burst size passed to `TokenBucket::new` by `PayloadValidator::new`; 50 is used when `None`.
pub max_enqueue_burst: Option<u32>,
}
impl Default for ValidationConfig {
    /// Builds the default limits:
    ///
    /// - a 1 MiB serialized-size cap;
    /// - 1024-byte strings and keys;
    /// - nesting up to depth 5;
    /// - `__proto__` and `constructor` forbidden (presumably to block JavaScript
    ///   prototype-pollution keys);
    /// - no required keys;
    /// - rate-limiter settings of 1000 per second with a burst of 50.
    fn default() -> Self {
        const ONE_MIB: usize = 1024 * 1024;

        let forbidden_keys: Vec<String> = ["__proto__", "constructor"]
            .iter()
            .map(|&key| String::from(key))
            .collect();

        Self {
            forbidden_keys,
            required_keys: Vec::new(),
            max_payload_size_bytes: ONE_MIB,
            max_string_length: 1024,
            max_object_depth: 5,
            max_enqueue_burst: Some(50),
            max_enqueue_per_second: Some(1000),
        }
    }
}
impl ValidationConfig {
    /// Checks `payload` against this configuration's structural and size limits.
    ///
    /// Runs the same structural pass and serialized-size check as [`PayloadValidator`], but
    /// never rate-limits: the temporary validator is built without a token bucket.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by the structural pass, or else by the size check.
    pub fn validate_payload(&self, payload: &serde_json::Value) -> Result<()> {
        // No token bucket here, so this path can never return a rate-limit error.
        let one_off = PayloadValidator {
            rate_limiter: None,
            config: self.clone(),
        };
        one_off
            .validate_single_pass(payload, 0, true)
            .and_then(|()| one_off.validate_size(payload))
    }
}
/// Validates JSON payloads against a [`ValidationConfig`], optionally rate-limiting calls with
/// a [`TokenBucket`].
// NOTE(review): `Clone` copies `rate_limiter`; whether clones share one bucket or get
// independent ones depends on `TokenBucket`'s `Clone` impl — confirm.
#[derive(Debug, Clone)]
pub struct PayloadValidator {
// Limits applied by the structural and size checks.
config: ValidationConfig,
// Set by `PayloadValidator::new` when `max_enqueue_per_second` is `Some`; `None` means no rate limiting.
rate_limiter: Option<TokenBucket>,
}
impl PayloadValidator {
pub fn new(config: ValidationConfig) -> Self {
let rate_limiter = config
.max_enqueue_per_second
.map(|rate| TokenBucket::new(rate, config.max_enqueue_burst.unwrap_or(50)));
Self {
config,
rate_limiter,
}
}
pub fn config(&self) -> &ValidationConfig {
&self.config
}
pub fn validate(&self, payload: &serde_json::Value) -> Result<()> {
self.validate_single_pass(payload, 0, true)?;
if let Some(ref limiter) = self.rate_limiter {
if !limiter.try_acquire() {
return Err(crate::error::Error::RateLimited {
retry_after: Duration::from_secs(1),
});
}
}
self.validate_size(payload)?;
Ok(())
}
pub fn validate_batch(&self, payloads: &[serde_json::Value]) -> Result<()> {
for (index, payload) in payloads.iter().enumerate() {
let res = self.validate_single_pass(payload, 0, true);
if payloads.len() > 1 {
res.map_err(|e| match e {
crate::error::Error::ValidationFailed { reason } => {
crate::error::Error::ValidationFailed {
reason: format!("Payload at index {}: {}", index, reason),
}
}
other => other,
})?;
} else {
res?;
}
}
if let Some(ref limiter) = self.rate_limiter {
if !limiter.try_acquire_multiple(payloads.len() as u32) {
return Err(crate::error::Error::RateLimited {
retry_after: Duration::from_secs(1),
});
}
}
for (index, payload) in payloads.iter().enumerate() {
let res = self.validate_size(payload);
if payloads.len() > 1 {
res.map_err(|e| match e {
crate::error::Error::PayloadTooLarge {
actual_bytes,
max_bytes,
} => crate::error::Error::ValidationFailed {
reason: format!(
"Payload at index {} too large: {} bytes exceeds limit {}",
index, actual_bytes, max_bytes
),
},
other => other,
})?;
} else {
res?;
}
}
Ok(())
}
fn validate_size(&self, payload: &serde_json::Value) -> Result<()> {
let serialized = serde_json::to_string(payload)?;
let size = serialized.len();
if size > self.config.max_payload_size_bytes {
return Err(crate::error::Error::PayloadTooLarge {
actual_bytes: size,
max_bytes: self.config.max_payload_size_bytes,
});
}
Ok(())
}
fn validate_single_pass(
&self,
payload: &serde_json::Value,
depth: usize,
is_top_level: bool,
) -> Result<()> {
if depth > self.config.max_object_depth {
return Err(crate::error::Error::ValidationFailed {
reason: format!(
"Object depth {} exceeds limit {}",
depth, self.config.max_object_depth
),
});
}
match payload {
serde_json::Value::Object(obj) => {
if !self.config.forbidden_keys.is_empty() {
for forbidden in &self.config.forbidden_keys {
if obj.contains_key(forbidden) {
return Err(crate::error::Error::ValidationFailed {
reason: format!("Forbidden key '{}' found in payload", forbidden),
});
}
}
}
if is_top_level && !self.config.required_keys.is_empty() {
for required in &self.config.required_keys {
if !obj.contains_key(required) {
return Err(crate::error::Error::ValidationFailed {
reason: format!("Required key '{}' missing from payload", required),
});
}
}
}
for (key, value) in obj {
if key.len() > self.config.max_string_length {
return Err(crate::error::Error::ValidationFailed {
reason: format!(
"Key '{}' length {} exceeds limit {}",
key,
key.len(),
self.config.max_string_length
),
});
}
self.validate_single_pass(value, depth + 1, false)?;
}
}
serde_json::Value::Array(arr) => {
for item in arr {
self.validate_single_pass(item, depth + 1, false)?;
}
}
serde_json::Value::String(s) => {
if s.len() > self.config.max_string_length {
return Err(crate::error::Error::ValidationFailed {
reason: format!(
"String length {} exceeds limit {}",
s.len(),
self.config.max_string_length
),
});
}
}
_ => {}
}
Ok(())
}
pub fn rate_limit_status(&self) -> Option<crate::rate_limit::RateLimitStatus> {
self.rate_limiter.as_ref().map(|limiter| limiter.status())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::Error;
    use serde_json::json;

    /// Returns the `reason` of a `ValidationFailed` error, panicking on `Ok` or any other variant.
    fn validation_reason(result: Result<()>) -> String {
        match result {
            Err(Error::ValidationFailed { reason }) => reason,
            Ok(()) => panic!("expected ValidationFailed, got Ok"),
            Err(_) => panic!("expected ValidationFailed, got another error variant"),
        }
    }

    /// Builds a validator with rate limiting disabled, so tests never depend on bucket state.
    fn unlimited_validator(config: ValidationConfig) -> PayloadValidator {
        PayloadValidator::new(ValidationConfig {
            max_enqueue_per_second: None,
            ..config
        })
    }

    #[test]
    fn test_default_config() {
        let config = ValidationConfig::default();
        assert_eq!(config.max_payload_size_bytes, 1024 * 1024);
        assert_eq!(config.max_string_length, 1024);
        assert_eq!(config.max_object_depth, 5);
        assert_eq!(config.forbidden_keys, vec!["__proto__", "constructor"]);
        assert!(config.required_keys.is_empty());
        assert_eq!(config.max_enqueue_per_second, Some(1000));
        assert_eq!(config.max_enqueue_burst, Some(50));
    }

    #[test]
    fn test_validate_payload_size() {
        let config = ValidationConfig {
            max_payload_size_bytes: 50,
            ..Default::default()
        };
        let small_payload = json!({"key": "value"});
        assert!(config.validate_payload(&small_payload).is_ok());

        let large_payload = json!({
            "very_long_key_that_exceeds_limit": "very_long_value_that_definitely_exceeds_the_50_byte_limit_we_set_for_testing_purposes"
        });
        let expected_len = serde_json::to_string(&large_payload).unwrap().len();
        // Must fail on the size check specifically, not on some other limit.
        match config.validate_payload(&large_payload) {
            Err(Error::PayloadTooLarge {
                actual_bytes,
                max_bytes,
            }) => {
                assert_eq!(actual_bytes, expected_len);
                assert_eq!(max_bytes, 50);
            }
            _ => panic!("expected PayloadTooLarge"),
        }
    }

    #[test]
    fn test_validate_string_length() {
        let config = ValidationConfig {
            max_string_length: 10,
            ..Default::default()
        };
        let valid_payload = json!({"key": "short"});
        assert!(config.validate_payload(&valid_payload).is_ok());

        let invalid_payload = json!({"key": "this_is_a_very_long_string"});
        let reason = validation_reason(config.validate_payload(&invalid_payload));
        assert!(reason.starts_with("String length"), "{}", reason);
        assert!(reason.ends_with("exceeds limit 10"), "{}", reason);
    }

    #[test]
    fn test_validate_key_length() {
        let config = ValidationConfig {
            max_string_length: 3,
            ..Default::default()
        };
        let reason = validation_reason(config.validate_payload(&json!({"abcd": 1})));
        assert_eq!(reason, "Key 'abcd' length 4 exceeds limit 3");
    }

    #[test]
    fn test_validate_object_depth() {
        let config = ValidationConfig {
            max_object_depth: 2,
            ..Default::default()
        };
        let valid_payload = json!({"level1": {"level2": "value"}});
        assert!(config.validate_payload(&valid_payload).is_ok());

        // The `level3` object sits at depth 3 (root is depth 0), the first level past the limit.
        let invalid_payload = json!({"level1": {"level2": {"level3": {"level4": "value"}}}});
        let reason = validation_reason(config.validate_payload(&invalid_payload));
        assert_eq!(reason, "Object depth 3 exceeds limit 2");
    }

    #[test]
    fn test_forbidden_keys() {
        let config = ValidationConfig {
            forbidden_keys: vec!["__proto__".to_string()],
            ..Default::default()
        };
        let expected = "Forbidden key '__proto__' found in payload";

        let valid_payload = json!({"data": "value"});
        assert!(config.validate_payload(&valid_payload).is_ok());

        let invalid_payload = json!({"__proto__": "malicious"});
        assert_eq!(validation_reason(config.validate_payload(&invalid_payload)), expected);

        let nested_invalid_payload = json!({"data": {"__proto__": "malicious"}});
        assert_eq!(
            validation_reason(config.validate_payload(&nested_invalid_payload)),
            expected
        );

        let deep_nested_invalid = json!({"level1": {"level2": {"__proto__": "malicious"}}});
        assert_eq!(
            validation_reason(config.validate_payload(&deep_nested_invalid)),
            expected
        );
    }

    #[test]
    fn test_default_forbidden_keys_inside_arrays() {
        // Objects nested in arrays are scanned as well.
        let payload = json!({"items": [{"constructor": 1}]});
        let reason = validation_reason(ValidationConfig::default().validate_payload(&payload));
        assert_eq!(reason, "Forbidden key 'constructor' found in payload");
    }

    #[test]
    fn test_required_keys() {
        let config = ValidationConfig {
            required_keys: vec!["user_id".to_string()],
            ..Default::default()
        };
        let valid_payload = json!({"user_id": "123", "data": "value"});
        assert!(config.validate_payload(&valid_payload).is_ok());

        let invalid_payload = json!({"data": "value"});
        assert_eq!(
            validation_reason(config.validate_payload(&invalid_payload)),
            "Required key 'user_id' missing from payload"
        );
    }

    #[test]
    fn test_validate_without_rate_limit() {
        let validator = unlimited_validator(ValidationConfig {
            max_payload_size_bytes: 10,
            ..Default::default()
        });
        assert!(validator.rate_limit_status().is_none());
        assert!(validator.validate(&json!(1)).is_ok());

        // `{"key":"value"}` serializes to 15 bytes.
        match validator.validate(&json!({"key": "value"})) {
            Err(Error::PayloadTooLarge {
                actual_bytes,
                max_bytes,
            }) => {
                assert_eq!(actual_bytes, 15);
                assert_eq!(max_bytes, 10);
            }
            _ => panic!("expected PayloadTooLarge"),
        }
    }

    #[test]
    fn test_validate_batch_prefixes_structural_errors() {
        let validator = unlimited_validator(ValidationConfig {
            max_string_length: 10,
            ..Default::default()
        });
        let batch = [json!({"k": "ok"}), json!({"k": "this_is_a_very_long_string"})];

        let reason = validation_reason(validator.validate_batch(&batch));
        assert!(
            reason.starts_with("Payload at index 1: String length"),
            "{}",
            reason
        );

        // A single-element batch reports the bare reason, without an index prefix.
        let reason = validation_reason(validator.validate_batch(&batch[1..]));
        assert!(reason.starts_with("String length"), "{}", reason);

        assert!(validator.validate_batch(&batch[..1]).is_ok());
        assert!(validator.validate_batch(&[]).is_ok());
    }

    #[test]
    fn test_validate_batch_rewrites_size_errors() {
        let validator = unlimited_validator(ValidationConfig {
            max_payload_size_bytes: 10,
            ..Default::default()
        });
        let batch = [json!(1), json!({"key": "value"})];
        assert_eq!(
            validation_reason(validator.validate_batch(&batch)),
            "Payload at index 1 too large: 15 bytes exceeds limit 10"
        );
    }
}