use garde::Validate;
use serde::de::DeserializeOwned;
use crate::error::{Error, Result};
use crate::extract::body::{
configured_body_limit, ensure_json_depth_within_limit, read_body_capped_with,
};
use crate::extract::{FromRequest, RequestContext};
/// Wrapper around a value of type `T` that the `FromRequest` implementation in this
/// module produces only after deserializing the request body from JSON and
/// running `T`'s `garde` validation rules successfully.
///
/// The field is public, so the value can be reached as `.0`, pattern-matched with
/// `Valid(value)`, or moved out with [`Valid::into_inner`].
#[derive(Debug, Clone)]
pub struct Valid<T>(pub T);

impl<T> Valid<T> {
    /// Consumes the wrapper and returns the wrapped value.
    pub fn into_inner(self) -> T {
        let Valid(inner) = self;
        inner
    }
}
impl<T> FromRequest for Valid<T>
where
T: DeserializeOwned + Validate<Context = ()> + Send,
{
fn from_request(
ctx: &RequestContext,
) -> impl std::future::Future<Output = Result<Self>> + Send {
let taken = ctx.take_body();
let limit = configured_body_limit(ctx);
async move {
let bytes = read_body_capped_with(taken?, limit).await?;
ensure_json_depth_within_limit(&bytes)?;
let value: T = serde_json::from_slice(&bytes)
.map_err(|_| Error::unprocessable("request body is not valid JSON"))?;
value.validate().map_err(Error::from_garde_report)?;
Ok(Valid(value))
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::body::box_body;
    use crate::error::ErrorKind;
    use crate::extract::body::MAX_JSON_NESTING;
    use crate::extract::PathParams;
    use crate::state::StateMap;
    use bytes::Bytes;
    use http_body_util::Full;
    use serde::Deserialize;
    use std::sync::Arc;

    /// Minimal payload with one `garde` rule: `count` must be at least 1.
    #[derive(Debug, Deserialize, garde::Validate)]
    struct Sample {
        #[garde(range(min = 1))]
        count: i64,
    }

    /// Builds a request context whose body is exactly the bytes of `json`.
    fn context_with_body(json: &str) -> RequestContext {
        let head = http::Request::new(()).into_parts().0;
        let body = box_body(Full::new(Bytes::copy_from_slice(json.as_bytes())));
        RequestContext::new(head, PathParams::new(), Arc::new(StateMap::new()), body)
    }

    /// Runs the `Valid<Sample>` extractor against a request whose body is `json`.
    async fn extract(json: &str) -> crate::error::Result<Valid<Sample>> {
        let ctx = context_with_body(json);
        <Valid<Sample> as FromRequest>::from_request(&ctx).await
    }

    #[tokio::test]
    async fn valid_body_is_accepted() {
        let valid = extract(r#"{"count": 5}"#).await.expect("should validate");
        assert_eq!(valid.into_inner().count, 5);
    }

    #[tokio::test]
    async fn invalid_body_is_unprocessable_with_details() {
        let error = extract(r#"{"count": 0}"#).await.unwrap_err();
        assert_eq!(error.kind(), ErrorKind::Unprocessable);
        assert!(
            !error.details().is_empty(),
            "should report the failing field"
        );
    }

    /// Well-formed JSON with the wrong field type fails deserialization, before
    /// any `garde` rule runs.
    #[tokio::test]
    async fn wrongly_typed_field_is_unprocessable() {
        let error = extract(r#"{"count": "five"}"#).await.unwrap_err();
        assert_eq!(error.kind(), ErrorKind::Unprocessable);
        assert_eq!(error.message(), "request body is not valid JSON");
    }

    /// A body missing the required field fails deserialization as well.
    #[tokio::test]
    async fn missing_field_is_unprocessable() {
        let error = extract("{}").await.unwrap_err();
        assert_eq!(error.kind(), ErrorKind::Unprocessable);
        assert_eq!(error.message(), "request body is not valid JSON");
    }

    #[tokio::test]
    async fn deeply_nested_body_is_rejected_before_validation() {
        // Produces {"count":[[...[0]...]]} with MAX_JSON_NESTING + 1 array levels
        // inside the object, which exceeds the nesting limit.
        let depth = MAX_JSON_NESTING + 1;
        let json = format!("{{\"count\":{}0{}}}", "[".repeat(depth), "]".repeat(depth));
        let error = extract(&json).await.unwrap_err();
        assert_eq!(error.kind(), ErrorKind::BadRequest);
        assert_eq!(error.message(), "request body is too deeply nested");
    }
}