use async_trait::async_trait;
use bytes::Bytes;
use super::parser::{ParseResult, ParsedData};
/// Validation hooks for a request body and its parsed form.
///
/// Implementations return `Ok(())` to accept and an error to reject. The
/// call sites live outside this file; judging by the method names, one hook
/// is meant to run before the body is parsed and the other afterwards.
#[async_trait]
pub trait ParserValidator: Send + Sync {
/// Inspects the raw request.
///
/// `content_type` is the `Content-Type` header value when one is present,
/// and `body` is the unparsed payload.
async fn before_parse(&self, content_type: Option<&str>, body: &Bytes) -> ParseResult<()>;
/// Inspects the parsed representation of the request.
async fn after_parse(&self, data: &ParsedData) -> ParseResult<()>;
}
/// Validator that caps the size of a raw request body.
#[derive(Debug, Clone)]
pub struct SizeLimitValidator {
    /// Largest accepted body length in bytes; a body of exactly this length
    /// is still accepted.
    max_size: usize,
}

impl SizeLimitValidator {
    /// Builds a validator that accepts bodies of at most `max_size` bytes.
    pub fn new(max_size: usize) -> Self {
        SizeLimitValidator { max_size }
    }
}
#[async_trait]
impl ParserValidator for SizeLimitValidator {
async fn before_parse(&self, _content_type: Option<&str>, body: &Bytes) -> ParseResult<()> {
use crate::exception::Error;
if body.len() > self.max_size {
return Err(Error::Validation(format!(
"Request body size {} exceeds maximum allowed size {}",
body.len(),
self.max_size
)));
}
Ok(())
}
async fn after_parse(&self, _data: &ParsedData) -> ParseResult<()> {
Ok(())
}
}
/// Validator that only lets through requests whose media type is on an
/// allow-list.
#[derive(Debug, Clone)]
pub struct ContentTypeValidator {
    /// Media types (for example `application/json`) that are accepted.
    allowed_types: Vec<String>,
}

impl ContentTypeValidator {
    /// Builds a validator from the list of accepted media types.
    ///
    /// The list is stored as given; no normalisation happens here.
    pub fn new(allowed_types: Vec<String>) -> Self {
        ContentTypeValidator { allowed_types }
    }
}
#[async_trait]
impl ParserValidator for ContentTypeValidator {
    /// Accepts the request when the media type of `content_type` equals one
    /// of the allowed types. The body is ignored.
    ///
    /// Everything from the first `;` onwards (parameters such as `charset`)
    /// is dropped and surrounding whitespace is trimmed before comparing.
    /// The comparison is an exact, ASCII case-insensitive match, so
    /// `Application/JSON` matches `application/json` while
    /// `application/not-json-at-all` does not.
    ///
    /// # Errors
    ///
    /// Returns `Error::Validation` when `content_type` is `None` or when the
    /// media type is not in the allow-list.
    async fn before_parse(&self, content_type: Option<&str>, _body: &Bytes) -> ParseResult<()> {
        use crate::exception::Error;

        let ct = match content_type {
            Some(ct) => ct,
            None => {
                return Err(Error::Validation(
                    "Content-Type header is required".to_string(),
                ));
            }
        };

        // `split` always yields at least one item, so the fallback is only a
        // formality; the result is the part before any `;parameter`.
        let media_type = ct.split(';').next().unwrap_or(ct).trim();

        // Media types are ASCII tokens, so ASCII case folding is sufficient.
        // Unlike `to_lowercase`, it allocates nothing per request and does
        // not let non-ASCII characters that Unicode-lowercase to ASCII
        // letters (e.g. U+212A KELVIN SIGN -> `k`) match an allowed type.
        let is_allowed = self
            .allowed_types
            .iter()
            .any(|allowed| allowed.eq_ignore_ascii_case(media_type));
        if is_allowed {
            return Ok(());
        }

        Err(Error::Validation(format!(
            "Content-Type '{}' is not allowed. Allowed types: {:?}",
            ct, self.allowed_types
        )))
    }

    /// Performs no checks on the parsed data; always returns `Ok(())`.
    async fn after_parse(&self, _data: &ParsedData) -> ParseResult<()> {
        Ok(())
    }
}
/// Validator that delegates to an ordered list of other validators.
#[derive(Default)]
pub struct CompositeValidator {
    /// Child validators, kept in the order they were added.
    validators: Vec<Box<dyn ParserValidator>>,
}

impl CompositeValidator {
    /// Creates a composite with no child validators.
    pub fn new() -> Self {
        Self {
            validators: Vec::new(),
        }
    }

    /// Appends `validator` to the end of the list and returns the composite,
    /// so calls can be chained builder-style.
    // The method name `add` matches `std::ops::Add::add`, which is what the
    // `should_implement_trait` lint flags; the lint is silenced here.
    #[allow(clippy::should_implement_trait)]
    pub fn add<V: ParserValidator + 'static>(mut self, validator: V) -> Self {
        let boxed: Box<dyn ParserValidator> = Box::new(validator);
        self.validators.push(boxed);
        self
    }
}
#[async_trait]
impl ParserValidator for CompositeValidator {
    /// Runs every child's `before_parse` in insertion order, one after the
    /// other, and stops at the first error.
    ///
    /// Returns `Ok(())` when all children accept, including when there are
    /// no children at all.
    async fn before_parse(&self, content_type: Option<&str>, body: &Bytes) -> ParseResult<()> {
        for child in self.validators.iter() {
            let outcome = child.before_parse(content_type, body).await;
            outcome?;
        }
        Ok(())
    }

    /// Runs every child's `after_parse` in insertion order, one after the
    /// other, and stops at the first error.
    async fn after_parse(&self, data: &ParsedData) -> ParseResult<()> {
        for child in self.validators.iter() {
            let outcome = child.after_parse(data).await;
            outcome?;
        }
        Ok(())
    }
}
// Unit tests for the validators defined in this file. Every test only checks
// whether the result is `Ok` or `Err`; error messages are not inspected.
#[cfg(test)]
mod tests {
use super::*;
use rstest::rstest;
use serde_json::json;
// A body shorter than the limit is accepted.
#[rstest]
#[tokio::test]
async fn test_size_limit_validator_within_limit() {
let validator = SizeLimitValidator::new(100);
let body = Bytes::from("small body");
let result = validator.before_parse(None, &body).await;
assert!(result.is_ok());
}
// A body longer than the 10-byte limit is rejected.
#[rstest]
#[tokio::test]
async fn test_size_limit_validator_exceeds_limit() {
let validator = SizeLimitValidator::new(10);
let body = Bytes::from("this is a very long body that exceeds the limit");
let result = validator.before_parse(None, &body).await;
assert!(result.is_err());
}
// The size validator has no post-parse check, so `after_parse` succeeds.
#[rstest]
#[tokio::test]
async fn test_size_limit_validator_after_parse() {
let validator = SizeLimitValidator::new(100);
let data = ParsedData::Json(json!({"key": "value"}));
let result = validator.after_parse(&data).await;
assert!(result.is_ok());
}
// An exact match against the single allowed type is accepted.
#[rstest]
#[tokio::test]
async fn test_content_type_validator_allowed() {
let validator = ContentTypeValidator::new(vec!["application/json".to_string()]);
let body = Bytes::new();
let result = validator
.before_parse(Some("application/json"), &body)
.await;
assert!(result.is_ok());
}
// A media type that is not on the allow-list is rejected.
#[rstest]
#[tokio::test]
async fn test_content_type_validator_not_allowed() {
let validator = ContentTypeValidator::new(vec!["application/json".to_string()]);
let body = Bytes::new();
let result = validator.before_parse(Some("text/plain"), &body).await;
assert!(result.is_err());
}
// A request without any content type is rejected.
#[rstest]
#[tokio::test]
async fn test_content_type_validator_missing() {
let validator = ContentTypeValidator::new(vec!["application/json".to_string()]);
let body = Bytes::new();
let result = validator.before_parse(None, &body).await;
assert!(result.is_err());
}
// A `charset` parameter after the media type does not prevent a match.
#[rstest]
#[tokio::test]
async fn test_content_type_validator_with_charset() {
let validator = ContentTypeValidator::new(vec!["application/json".to_string()]);
let body = Bytes::new();
let result = validator
.before_parse(Some("application/json; charset=utf-8"), &body)
.await;
assert!(result.is_ok());
}
// Sharing fragments with an allowed type is not enough; the match is exact.
#[rstest]
#[tokio::test]
async fn test_content_type_validator_rejects_substring_match() {
let validator = ContentTypeValidator::new(vec!["application/json".to_string()]);
let body = Bytes::new();
let result = validator
.before_parse(Some("application/not-json-at-all"), &body)
.await;
assert!(result.is_err());
}
// A media type that merely starts with an allowed type is rejected.
#[rstest]
#[tokio::test]
async fn test_content_type_validator_rejects_prefix_substring() {
let validator = ContentTypeValidator::new(vec!["text/plain".to_string()]);
let body = Bytes::new();
let result = validator.before_parse(Some("text/plaintext"), &body).await;
assert!(result.is_err());
}
// `application/soap+xml` is rejected when only `application/xml` is allowed.
#[rstest]
#[tokio::test]
async fn test_content_type_validator_rejects_suffix_substring() {
let validator = ContentTypeValidator::new(vec!["application/xml".to_string()]);
let body = Bytes::new();
let result = validator
.before_parse(Some("application/soap+xml"), &body)
.await;
assert!(result.is_err());
}
// Letter case of the incoming media type does not matter.
#[rstest]
#[tokio::test]
async fn test_content_type_validator_case_insensitive() {
let validator = ContentTypeValidator::new(vec!["application/json".to_string()]);
let body = Bytes::new();
let result = validator
.before_parse(Some("Application/JSON"), &body)
.await;
assert!(result.is_ok());
}
// Each entry of a multi-entry allow-list is accepted; anything else is not.
#[rstest]
#[tokio::test]
async fn test_content_type_validator_multiple_allowed_types() {
let validator = ContentTypeValidator::new(vec![
"application/json".to_string(),
"application/xml".to_string(),
"text/plain".to_string(),
]);
let body = Bytes::new();
assert!(
validator
.before_parse(Some("application/json"), &body)
.await
.is_ok()
);
assert!(
validator
.before_parse(Some("application/xml"), &body)
.await
.is_ok()
);
assert!(
validator
.before_parse(Some("text/plain"), &body)
.await
.is_ok()
);
assert!(
validator
.before_parse(Some("text/html"), &body)
.await
.is_err()
);
}
// Several `;`-separated parameters after the media type are all ignored.
#[rstest]
#[tokio::test]
async fn test_content_type_validator_with_multiple_parameters() {
let validator = ContentTypeValidator::new(vec!["application/json".to_string()]);
let body = Bytes::new();
let result = validator
.before_parse(
Some("application/json; charset=utf-8; boundary=something"),
&body,
)
.await;
assert!(result.is_ok());
}
// Whitespace around the media type (before the `;`) is trimmed away.
#[rstest]
#[tokio::test]
async fn test_content_type_validator_whitespace_handling() {
let validator = ContentTypeValidator::new(vec!["application/json".to_string()]);
let body = Bytes::new();
let result = validator
.before_parse(Some(" application/json ; charset=utf-8"), &body)
.await;
assert!(result.is_ok());
}
// Both children accept, so the composite accepts.
#[rstest]
#[tokio::test]
async fn test_composite_validator_all_pass() {
let validator = CompositeValidator::new()
.add(SizeLimitValidator::new(100))
.add(ContentTypeValidator::new(vec![
"application/json".to_string(),
]));
let body = Bytes::from("small");
let result = validator
.before_parse(Some("application/json"), &body)
.await;
assert!(result.is_ok());
}
// The body exceeds the 3-byte limit of the first child while the content
// type is allowed; the composite reports an error.
#[rstest]
#[tokio::test]
async fn test_composite_validator_first_fails() {
let validator = CompositeValidator::new()
.add(SizeLimitValidator::new(3))
.add(ContentTypeValidator::new(vec![
"application/json".to_string(),
]));
let body = Bytes::from("this is too long");
let result = validator
.before_parse(Some("application/json"), &body)
.await;
assert!(result.is_err());
}
// The size check passes but the content type is not allowed; the composite
// reports an error.
#[rstest]
#[tokio::test]
async fn test_composite_validator_second_fails() {
let validator = CompositeValidator::new()
.add(SizeLimitValidator::new(100))
.add(ContentTypeValidator::new(vec![
"application/json".to_string(),
]));
let body = Bytes::from("small");
let result = validator.before_parse(Some("text/plain"), &body).await;
assert!(result.is_err());
}
// Neither child checks parsed data, so the composite `after_parse` succeeds.
#[rstest]
#[tokio::test]
async fn test_composite_validator_after_parse() {
let validator = CompositeValidator::new()
.add(SizeLimitValidator::new(100))
.add(ContentTypeValidator::new(vec![
"application/json".to_string(),
]));
let data = ParsedData::Json(json!({"key": "value"}));
let result = validator.after_parse(&data).await;
assert!(result.is_ok());
}
// A composite without children accepts in both hooks.
#[rstest]
#[tokio::test]
async fn test_composite_validator_empty() {
let validator = CompositeValidator::new();
let body = Bytes::from("test");
let result = validator.before_parse(None, &body).await;
assert!(result.is_ok());
let data = ParsedData::Json(json!({"key": "value"}));
let result = validator.after_parse(&data).await;
assert!(result.is_ok());
}
}