#![cfg(feature = "axum")]
use axum::extract::{FromRequest, Request};
use http_body_util::BodyExt;
use serde::de::DeserializeOwned;
use crate::{
attack_signal::{BoundaryViolation, ViolationKind},
error::BoundaryRejection,
limits::RequestLimits,
validate::{SecureValidate, ValidationContext},
};
/// Newtype wrapper around a value of type `T`.
///
/// The `FromRequest` implementation in this file builds it from an XML
/// request body once the payload has been deserialized and validated.
/// The wrapped value is public, so it can also be reached as `.0` or by
/// pattern matching.
pub struct SecureXml<T>(pub T);

impl<T> SecureXml<T> {
    /// Consumes the wrapper and hands back the wrapped value.
    #[must_use]
    pub fn into_inner(self) -> T {
        let Self(inner) = self;
        inner
    }
}
impl<T, S> FromRequest<S> for SecureXml<T>
where
T: DeserializeOwned + SecureValidate,
S: Send + Sync,
{
type Rejection = BoundaryRejection;
async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
let limits = RequestLimits::default();
let ctx = ValidationContext::new();
let content_type = req
.headers()
.get("content-type")
.and_then(|v| v.to_str().ok())
.unwrap_or("");
if !content_type.starts_with("application/xml") && !content_type.starts_with("text/xml") {
BoundaryViolation::new(ViolationKind::InvalidContentType, "invalid_content_type")
.emit();
return Err(BoundaryRejection::InvalidContentType);
}
let bytes = req
.into_body()
.collect()
.await
.map_err(|_| BoundaryRejection::MalformedBody)?
.to_bytes();
if bytes.len() > limits.max_body_bytes {
BoundaryViolation::new(ViolationKind::BodyTooLarge, "body_too_large").emit();
return Err(BoundaryRejection::BodyTooLarge);
}
let text = std::str::from_utf8(&bytes).map_err(|_| {
BoundaryViolation::new(ViolationKind::SyntaxViolation, "invalid_utf8").emit();
BoundaryRejection::MalformedBody
})?;
check_for_xxe(text)?;
let value: T = quick_xml::de::from_str(text).map_err(|_| {
BoundaryViolation::new(ViolationKind::SyntaxViolation, "malformed_xml").emit();
BoundaryRejection::MalformedBody
})?;
value.validate_syntax(&ctx).map_err(|code| {
BoundaryViolation::new(ViolationKind::SyntaxViolation, code).emit();
BoundaryRejection::SyntaxViolation { code }
})?;
value.validate_semantics(&ctx).map_err(|code| {
BoundaryViolation::new(ViolationKind::SemanticViolation, code).emit();
BoundaryRejection::SemanticViolation { code }
})?;
Ok(Self(value))
}
}
/// Rejects XML text that contains a `<!DOCTYPE` or `<!ENTITY` marker.
///
/// This is a plain substring screen over the raw text, not a parse: the
/// markers are matched anywhere in the input, ignoring ASCII case. The scan
/// works on the UTF-8 bytes directly, so it does not allocate a copy of the
/// body, and non-ASCII characters can never be case-mapped into a marker.
///
/// # Errors
///
/// Returns [`BoundaryRejection::XxeBlocked`] when either marker is present,
/// after emitting a `SyntaxViolation` [`BoundaryViolation`] with the code
/// `xxe_blocked`.
fn check_for_xxe(xml: &str) -> Result<(), BoundaryRejection> {
    const BLOCKED_MARKERS: [&[u8]; 2] = [b"<!DOCTYPE", b"<!ENTITY"];

    // The markers are pure ASCII, and ASCII bytes never occur inside a
    // multi-byte UTF-8 sequence, so a byte-window comparison is exact.
    let haystack = xml.as_bytes();
    let has_marker = BLOCKED_MARKERS.iter().any(|marker| {
        haystack
            .windows(marker.len())
            .any(|window| window.eq_ignore_ascii_case(marker))
    });

    if has_marker {
        BoundaryViolation::new(ViolationKind::SyntaxViolation, "xxe_blocked").emit();
        return Err(BoundaryRejection::XxeBlocked);
    }
    Ok(())
}