#[cfg(any(feature = "axum", feature = "actix-web"))]
use serde::de::DeserializeOwned;
#[cfg(any(feature = "axum", feature = "actix-web"))]
use crate::{
attack_signal::{BoundaryViolation, ViolationKind},
error::BoundaryRejection,
limits::RequestLimits,
validate::{SecureValidate, ValidationContext},
};
/// Wrapper around a JSON request body of type `T`.
///
/// The axum extractor in this module builds it via `validate_json_bytes`, i.e. only
/// after the body passed the size, structure, syntax and semantic checks. The field
/// is private, so code outside this module cannot build one directly.
pub struct SecureJson<T>(T);

impl<T> SecureJson<T> {
    /// Wraps a payload that has already been through validation.
    ///
    /// Crate-visible only, so callers outside the crate cannot skip validation.
    #[cfg(any(feature = "axum", feature = "actix-web"))]
    #[must_use]
    pub(crate) fn from_validated(value: T) -> Self {
        SecureJson(value)
    }

    /// Consumes the wrapper and hands back the validated payload.
    #[must_use]
    pub fn into_inner(self) -> T {
        let SecureJson(payload) = self;
        payload
    }
}
/// Wrapper around query-string parameters of type `T`.
///
/// The axum `FromRequestParts` impl in this module deserializes through `Query<T>`
/// and runs `validate_parsed` before wrapping the value.
pub struct SecureQuery<T>(T);

impl<T> SecureQuery<T> {
    /// Consumes the wrapper and hands back the validated parameters.
    #[must_use]
    pub fn into_inner(self) -> T {
        let SecureQuery(params) = self;
        params
    }
}
/// Wrapper around path parameters of type `T`.
///
/// The axum `FromRequestParts` impl in this module deserializes through `Path<T>`
/// and runs `validate_parsed` before wrapping the value.
pub struct SecurePath<T>(T);

impl<T> SecurePath<T> {
    /// Consumes the wrapper and hands back the validated path parameters.
    #[must_use]
    pub fn into_inner(self) -> T {
        let SecurePath(segments) = self;
        segments
    }
}
/// Runs the complete JSON boundary pipeline over a raw request body.
///
/// The steps run in order, and the first failure stops the pipeline:
/// 1. body length is compared against `limits.max_body_bytes`;
/// 2. `check_json_limits` pre-scans nesting depth and field count;
/// 3. the bytes are deserialized into `T` with `serde_json`;
/// 4. `T::validate_syntax`, then `T::validate_semantics`, are applied.
///
/// Steps 1, 3 and 4 emit a `BoundaryViolation` before rejecting. Step 2 emits its own.
///
/// # Errors
/// `BoundaryRejection::BodyTooLarge`, any rejection from `check_json_limits`,
/// `BoundaryRejection::MalformedBody`, `BoundaryRejection::SyntaxViolation { code }`,
/// or `BoundaryRejection::SemanticViolation { code }`.
#[cfg(any(feature = "axum", feature = "actix-web"))]
pub(crate) fn validate_json_bytes<T>(
    bytes: &[u8],
    limits: &RequestLimits,
    ctx: &ValidationContext,
) -> Result<T, BoundaryRejection>
where
    T: DeserializeOwned + SecureValidate,
{
    let fits = bytes.len() <= limits.max_body_bytes;
    if !fits {
        BoundaryViolation::new(ViolationKind::BodyTooLarge, "body_too_large").emit();
        return Err(BoundaryRejection::BodyTooLarge);
    }

    check_json_limits(bytes, limits.max_nesting_depth, limits.max_field_count)?;

    let parsed: T = match serde_json::from_slice(bytes) {
        Ok(parsed) => parsed,
        Err(_) => {
            BoundaryViolation::new(ViolationKind::SyntaxViolation, "malformed_json").emit();
            return Err(BoundaryRejection::MalformedBody);
        }
    };

    if let Err(code) = parsed.validate_syntax(ctx) {
        BoundaryViolation::new(ViolationKind::SyntaxViolation, code).emit();
        return Err(BoundaryRejection::SyntaxViolation { code });
    }
    if let Err(code) = parsed.validate_semantics(ctx) {
        BoundaryViolation::new(ViolationKind::SemanticViolation, code).emit();
        return Err(BoundaryRejection::SemanticViolation { code });
    }

    Ok(parsed)
}
/// Applies both `SecureValidate` passes to a value that has already been deserialized.
///
/// `validate_syntax` runs first. `validate_semantics` runs only if the syntax pass
/// succeeded. A failing pass emits a `BoundaryViolation` tagged with the code it returned.
///
/// # Errors
/// `BoundaryRejection::SyntaxViolation { code }` or
/// `BoundaryRejection::SemanticViolation { code }`, carrying the failing pass's code.
#[cfg(feature = "axum")]
pub(crate) fn validate_parsed<T>(value: T, ctx: &ValidationContext) -> Result<T, BoundaryRejection>
where
    T: SecureValidate,
{
    if let Err(code) = value.validate_syntax(ctx) {
        BoundaryViolation::new(ViolationKind::SyntaxViolation, code).emit();
        return Err(BoundaryRejection::SyntaxViolation { code });
    }
    if let Err(code) = value.validate_semantics(ctx) {
        BoundaryViolation::new(ViolationKind::SemanticViolation, code).emit();
        return Err(BoundaryRejection::SemanticViolation { code });
    }
    Ok(value)
}
/// axum integration: `FromRequest` / `FromRequestParts` impls for the secure wrappers.
#[cfg(feature = "axum")]
mod axum_impl {
    use super::{
        validate_json_bytes, validate_parsed, BoundaryRejection, BoundaryViolation,
        DeserializeOwned, RequestLimits, SecureJson, SecurePath, SecureQuery, SecureValidate,
        ValidationContext, ViolationKind,
    };
    use axum::{
        extract::{FromRequest, FromRequestParts, Path, Query, Request},
        http::{header::CONTENT_TYPE, request::Parts},
    };
    use http_body_util::{BodyExt, LengthLimitError, Limited};

    /// Returns `true` when a `Content-Type` header value denotes JSON.
    ///
    /// Accepts `application/json` and structured-syntax `application/*+json` types
    /// (e.g. `application/json-patch+json`). Parameters after `;` such as
    /// `charset=utf-8` are ignored. The comparison is ASCII case-insensitive, because
    /// media type and subtype names are case-insensitive (RFC 9110 §8.3.1). Unlike a
    /// bare prefix check, look-alikes such as `application/jsonp` are rejected.
    fn is_json_content_type(content_type: &str) -> bool {
        let essence = content_type
            .split(';')
            .next()
            .unwrap_or("")
            .trim()
            .to_ascii_lowercase();
        match essence.strip_prefix("application/") {
            Some(subtype) => {
                subtype == "json" || (subtype.len() > "+json".len() && subtype.ends_with("+json"))
            }
            None => false,
        }
    }

    impl<T, S> FromRequest<S> for SecureJson<T>
    where
        T: DeserializeOwned + SecureValidate,
        S: Send + Sync,
    {
        type Rejection = BoundaryRejection;

        /// Extracts and validates a JSON body.
        ///
        /// Limits come from a `RequestLimits` request extension, or
        /// `RequestLimits::default()` when none is present.
        ///
        /// # Errors
        /// * `InvalidContentType` if the `Content-Type` is not JSON.
        /// * `BodyTooLarge` if the streamed body exceeds `max_body_bytes`.
        /// * `MalformedBody` if reading the body fails for another reason.
        /// * Any rejection produced by `validate_json_bytes`.
        async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
            let limits = req
                .extensions()
                .get::<RequestLimits>()
                .cloned()
                .unwrap_or_default();
            let ctx = ValidationContext::new();

            let content_type = req
                .headers()
                .get(CONTENT_TYPE)
                .and_then(|v| v.to_str().ok())
                .unwrap_or("");
            if !is_json_content_type(content_type) {
                BoundaryViolation::new(ViolationKind::InvalidContentType, "invalid_content_type")
                    .emit();
                return Err(BoundaryRejection::InvalidContentType);
            }

            // Cap buffering at the configured limit. Without `Limited`, the whole body
            // would be read into memory before the size check in
            // `validate_json_bytes`, letting a client force unbounded allocation.
            let bytes = Limited::new(req.into_body(), limits.max_body_bytes)
                .collect()
                .await
                .map_err(|err| {
                    if err.is::<LengthLimitError>() {
                        BoundaryViolation::new(ViolationKind::BodyTooLarge, "body_too_large")
                            .emit();
                        BoundaryRejection::BodyTooLarge
                    } else {
                        BoundaryRejection::MalformedBody
                    }
                })?
                .to_bytes();

            let value = validate_json_bytes::<T>(&bytes, &limits, &ctx)?;
            Ok(SecureJson::from_validated(value))
        }
    }

    impl<T, S> FromRequestParts<S> for SecureQuery<T>
    where
        T: DeserializeOwned + SecureValidate,
        S: Send + Sync,
    {
        type Rejection = BoundaryRejection;

        /// Deserializes the query string via `Query<T>`, then runs `validate_parsed`.
        ///
        /// # Errors
        /// `InvalidParameter` when deserialization fails, otherwise any rejection
        /// from `validate_parsed`.
        async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
            let ctx = ValidationContext::new();
            let query = Query::<T>::from_request_parts(parts, state)
                .await
                .map_err(|_| {
                    BoundaryViolation::new(ViolationKind::InvalidQueryParam, "invalid_query")
                        .emit();
                    BoundaryRejection::InvalidParameter
                })?;
            let value = validate_parsed(query.0, &ctx)?;
            Ok(SecureQuery(value))
        }
    }

    impl<T, S> FromRequestParts<S> for SecurePath<T>
    where
        T: DeserializeOwned + SecureValidate + Send,
        S: Send + Sync,
    {
        type Rejection = BoundaryRejection;

        /// Deserializes path parameters via `Path<T>`, then runs `validate_parsed`.
        ///
        /// # Errors
        /// `InvalidParameter` when deserialization fails, otherwise any rejection
        /// from `validate_parsed`.
        async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
            let ctx = ValidationContext::new();
            let path = Path::<T>::from_request_parts(parts, state)
                .await
                .map_err(|_| {
                    BoundaryViolation::new(ViolationKind::InvalidPathParam, "invalid_path").emit();
                    BoundaryRejection::InvalidParameter
                })?;
            let value = validate_parsed(path.0, &ctx)?;
            Ok(SecurePath(value))
        }
    }
}
/// Cheap byte-level pre-scan that bounds JSON nesting depth and field count
/// before the body reaches `serde_json`.
///
/// Only bytes outside string literals are inspected. String boundaries are tracked
/// through `"` and backslash escapes, so brackets or colons inside strings are ignored.
/// * Depth rises on `{` / `[` and falls on `}` / `]`. It never drops below zero.
/// * The field count is the number of `:` separators in the whole document, i.e.
///   object members in total, not per object.
///
/// This does not check JSON syntax. Malformed input may pass and is left to the
/// deserializer to reject.
///
/// # Errors
/// `BoundaryRejection::NestingTooDeep` once depth exceeds `max_depth`, or
/// `BoundaryRejection::TooManyFields` once the count exceeds `max_fields`.
/// A matching `BoundaryViolation` is emitted first.
#[cfg(any(feature = "axum", feature = "actix-web"))]
fn check_json_limits(
    bytes: &[u8],
    max_depth: usize,
    max_fields: usize,
) -> Result<(), BoundaryRejection> {
    /// Where the scanner currently sits relative to JSON string literals.
    #[derive(Clone, Copy)]
    enum Lex {
        /// Outside any string: structural bytes count.
        Structural,
        /// Inside a string literal.
        InString,
        /// Inside a string, immediately after a backslash.
        AfterBackslash,
    }

    let mut lex = Lex::Structural;
    let mut nesting: usize = 0;
    let mut colons: usize = 0;

    for &byte in bytes {
        match lex {
            // The escaped byte is skipped regardless of its value.
            Lex::AfterBackslash => lex = Lex::InString,
            Lex::InString => match byte {
                b'\\' => lex = Lex::AfterBackslash,
                b'"' => lex = Lex::Structural,
                _ => {}
            },
            Lex::Structural => match byte {
                b'"' => lex = Lex::InString,
                b'{' | b'[' => {
                    nesting += 1;
                    if nesting > max_depth {
                        BoundaryViolation::new(ViolationKind::NestingTooDeep, "nesting_too_deep")
                            .emit();
                        return Err(BoundaryRejection::NestingTooDeep);
                    }
                }
                b'}' | b']' => nesting = nesting.saturating_sub(1),
                b':' => {
                    colons += 1;
                    if colons > max_fields {
                        BoundaryViolation::new(ViolationKind::TooManyFields, "too_many_fields")
                            .emit();
                        return Err(BoundaryRejection::TooManyFields);
                    }
                }
                _ => {}
            },
        }
    }

    Ok(())
}