use crate::api_description::ApiEndpointParameter;
use crate::api_description::ApiSchemaGenerator;
use crate::api_description::{ApiEndpointBodyContentType, ExtensionMode};
use crate::error::HttpError;
use crate::http_util::http_dump_body;
use crate::http_util::CONTENT_TYPE_JSON;
use crate::schema_util::make_subschema_for;
use crate::server::ServerContext;
use crate::ExclusiveExtractor;
use crate::ExtractorMetadata;
use crate::RequestContext;
use async_trait::async_trait;
use bytes::BufMut;
use bytes::Bytes;
use bytes::BytesMut;
use futures::Stream;
use futures::TryStreamExt;
use http_body_util::BodyExt;
use schemars::schema::InstanceType;
use schemars::schema::SchemaObject;
use schemars::JsonSchema;
use serde::de::DeserializeOwned;
use std::fmt::Debug;
/// Extractor for a request body deserialized into `BodyType`.
///
/// Depending on the endpoint's declared body content type, the buffered body
/// is parsed either as JSON or as URL-encoded form data, and a schema
/// generated from `BodyType` is used to describe the body in the API
/// description.
#[derive(Debug)]
pub struct TypedBody<BodyType>
where
    BodyType: JsonSchema + DeserializeOwned + Send + Sync,
{
    inner: BodyType,
}
impl<BodyType> TypedBody<BodyType>
where
    BodyType: JsonSchema + DeserializeOwned + Send + Sync,
{
    /// Consumes the wrapper and returns the deserialized body value.
    pub fn into_inner(self) -> BodyType {
        let Self { inner } = self;
        inner
    }

    /// Transforms the wrapped value with `f`, keeping it wrapped in a
    /// `TypedBody`.
    pub fn map<T, F>(self, f: F) -> TypedBody<T>
    where
        T: JsonSchema + DeserializeOwned + Send + Sync,
        F: FnOnce(BodyType) -> T,
    {
        let mapped = f(self.into_inner());
        TypedBody { inner: mapped }
    }

    /// Fallible counterpart of [`TypedBody::map`].
    ///
    /// # Errors
    ///
    /// Returns whatever error `f` returns, unchanged.
    pub fn try_map<T, E, F>(self, f: F) -> Result<TypedBody<T>, E>
    where
        T: JsonSchema + DeserializeOwned + Send + Sync,
        F: FnOnce(BodyType) -> Result<T, E>,
    {
        f(self.into_inner()).map(|inner| TypedBody { inner })
    }
}
impl<BodyType> From<BodyType> for TypedBody<BodyType>
where
    BodyType: JsonSchema + DeserializeOwned + Send + Sync,
{
    /// Wraps an already-constructed value without any deserialization, e.g.
    /// when building a `TypedBody` by hand in tests.
    fn from(inner: BodyType) -> Self {
        Self { inner }
    }
}
/// Extractor exposing a multipart request body as a [`multer::Multipart`]
/// stream of fields.
///
/// The body is not buffered up front: the request's data stream is handed
/// directly to `multer`, which reads it as the handler consumes fields.
#[derive(Debug)]
pub struct MultipartBody {
    pub content: multer::Multipart<'static>,
}

/// Extracts the `boundary` parameter from a multipart `Content-Type` value.
///
/// Parameters are the `;`-separated `name=value` pairs after the media type.
/// Parameter names are matched case-insensitively, surrounding whitespace is
/// ignored, and a boundary given as a quoted-string has its quotes removed.
/// Parameters after the boundary (e.g. `charset`) do not leak into the
/// result. Returns `None` if there is no non-empty boundary parameter.
fn multipart_boundary(content_type: &str) -> Option<&str> {
    content_type
        .split(';')
        // The first segment is the media type itself, not a parameter.
        .skip(1)
        .filter_map(|param| param.split_once('='))
        .find(|(name, _)| name.trim().eq_ignore_ascii_case("boundary"))
        .map(|(_, value)| {
            let value = value.trim();
            value
                .strip_prefix('"')
                .and_then(|unquoted| unquoted.strip_suffix('"'))
                .unwrap_or(value)
        })
        .filter(|boundary| !boundary.is_empty())
}

#[async_trait]
impl ExclusiveExtractor for MultipartBody {
    /// Builds a multipart reader from the request's `Content-Type` boundary.
    ///
    /// # Errors
    ///
    /// Returns a bad-request [`HttpError`] if the `Content-Type` header is
    /// missing, is not valid visible ASCII, or has no usable `boundary`
    /// parameter.
    ///
    /// NOTE(review): unlike the other body extractors in this file, this one
    /// does not apply `request_body_max_bytes` (the request context is
    /// unused) — confirm whether that is intended.
    async fn from_request<Context: ServerContext>(
        _rqctx: &RequestContext<Context>,
        request: hyper::Request<crate::Body>,
    ) -> Result<Self, HttpError> {
        let (parts, body) = request.into_parts();
        let content_type = parts
            .headers
            .get(http::header::CONTENT_TYPE)
            .ok_or_else(|| {
                HttpError::for_bad_request(
                    None,
                    "missing content-type header".to_string(),
                )
            })?
            .to_str()
            .map_err(|e| {
                HttpError::for_bad_request(
                    None,
                    format!("invalid content type: {}", e),
                )
            })?;
        let boundary = multipart_boundary(content_type).ok_or_else(|| {
            HttpError::for_bad_request(
                None,
                "missing boundary in content-type header".to_string(),
            )
        })?;
        Ok(MultipartBody {
            content: multer::Multipart::new(
                body.into_data_stream(),
                boundary.to_string(),
            ),
        })
    }

    /// Describes a `multipart/form-data` body whose schema is a
    /// binary-format string.
    fn metadata(
        _content_type: ApiEndpointBodyContentType,
    ) -> ExtractorMetadata {
        let body = ApiEndpointParameter::new_body(
            ApiEndpointBodyContentType::MultipartFormData,
            true,
            ApiSchemaGenerator::Static {
                schema: Box::new(
                    SchemaObject {
                        instance_type: Some(InstanceType::String.into()),
                        format: Some(String::from("binary")),
                        ..Default::default()
                    }
                    .into(),
                ),
                dependencies: indexmap::IndexMap::default(),
            },
            vec![],
        );
        ExtractorMetadata {
            extension_mode: ExtensionMode::None,
            parameters: vec![body],
        }
    }
}
/// Buffers the body of `request` (at most `request_body_max_bytes` bytes) and
/// deserializes it into `BodyType`.
///
/// The request's `Content-Type` (treated as JSON when the header is absent)
/// must resolve to the same [`ApiEndpointBodyContentType`] as
/// `expected_body_content_type`. Only JSON and URL-encoded bodies are
/// deserialized; any other pairing is rejected.
///
/// # Errors
///
/// Returns a bad-request [`HttpError`] in any of these cases:
/// - the body cannot be read or exceeds the size limit;
/// - the `Content-Type` header is not valid text, is unsupported, or does not
///   match the expected type;
/// - the body fails to deserialize, including non-whitespace trailing data
///   after a JSON value.
async fn http_request_load_body<BodyType>(
    request: hyper::Request<crate::Body>,
    request_body_max_bytes: usize,
    expected_body_content_type: &ApiEndpointBodyContentType,
) -> Result<TypedBody<BodyType>, HttpError>
where
    BodyType: JsonSchema + DeserializeOwned + Send + Sync,
{
    let (parts, body) = request.into_parts();
    let body = StreamingBody::new(body, request_body_max_bytes)
        .into_bytes_mut()
        .await?;

    let content_type = parts
        .headers
        .get(http::header::CONTENT_TYPE)
        .map(|hv| {
            hv.to_str().map_err(|e| {
                HttpError::for_bad_request(
                    None,
                    format!("invalid content type: {}", e),
                )
            })
        })
        .unwrap_or(Ok(CONTENT_TYPE_JSON))?;
    // Only the media type before any `;`-separated parameters (such as
    // `charset`) selects the parser; media types are case-insensitive.
    let mime_type = content_type
        .split_once(';')
        .map_or(content_type, |(essence, _params)| essence)
        .trim()
        .to_lowercase();
    let body_content_type = ApiEndpointBodyContentType::from_mime_type(
        &mime_type,
    )
    .map_err(|e| {
        HttpError::for_bad_request(
            None,
            format!("unsupported content-type: {}", e),
        )
    })?;

    use ApiEndpointBodyContentType::*;
    let content = match (expected_body_content_type, body_content_type) {
        (Json, Json) => {
            let mut jd = serde_json::Deserializer::from_slice(&body);
            let value: BodyType = serde_path_to_error::deserialize(&mut jd)
                .map_err(|e| {
                    HttpError::for_bad_request(
                        None,
                        format!("unable to parse JSON body: {}", e),
                    )
                })?;
            // Deserializing through an explicit `Deserializer` does not check
            // for trailing input the way `serde_json::from_slice` does, so a
            // body like `{} junk` would otherwise be silently accepted.
            jd.end().map_err(|e| {
                HttpError::for_bad_request(
                    None,
                    format!("unable to parse JSON body: {}", e),
                )
            })?;
            value
        }
        (UrlEncoded, UrlEncoded) => {
            let ud = serde_urlencoded::Deserializer::new(
                form_urlencoded::parse(&body),
            );
            serde_path_to_error::deserialize(ud).map_err(|e| {
                HttpError::for_bad_request(
                    None,
                    format!("unable to parse URL-encoded body: {}", e),
                )
            })?
        }
        (expected, requested) => {
            return Err(HttpError::for_bad_request(
                None,
                format!(
                    "expected content type \"{}\", got \"{}\"",
                    expected.mime_type(),
                    requested.mime_type()
                ),
            ));
        }
    };
    Ok(TypedBody { inner: content })
}
#[async_trait]
impl<BodyType> ExclusiveExtractor for TypedBody<BodyType>
where
    BodyType: JsonSchema + DeserializeOwned + Send + Sync + 'static,
{
    /// Buffers the body, bounded by the server's configured maximum, and
    /// deserializes it according to the endpoint's declared content type.
    async fn from_request<Context: ServerContext>(
        rqctx: &RequestContext<Context>,
        request: hyper::Request<crate::Body>,
    ) -> Result<TypedBody<BodyType>, HttpError> {
        let max_bytes = rqctx.request_body_max_bytes();
        let expected = &rqctx.endpoint.body_content_type;
        http_request_load_body(request, max_bytes, expected).await
    }

    /// Describes a body parameter of `content_type` whose schema is generated
    /// from `BodyType`.
    fn metadata(content_type: ApiEndpointBodyContentType) -> ExtractorMetadata {
        let schema = ApiSchemaGenerator::Gen {
            name: BodyType::schema_name,
            schema: make_subschema_for::<BodyType>,
        };
        let parameter =
            ApiEndpointParameter::new_body(content_type, true, schema, vec![]);
        ExtractorMetadata {
            extension_mode: ExtensionMode::None,
            parameters: vec![parameter],
        }
    }
}
/// Extractor for the raw bytes of a request body, with no parsing applied.
#[derive(Debug)]
pub struct UntypedBody {
    content: Bytes,
}

impl UntypedBody {
    /// Borrows the body as a byte slice.
    pub fn as_bytes(&self) -> &[u8] {
        &self.content[..]
    }

    /// Borrows the body as UTF-8 text.
    ///
    /// # Errors
    ///
    /// Returns a bad-request [`HttpError`] if the bytes are not valid UTF-8.
    pub fn as_str(&self) -> Result<&str, HttpError> {
        match std::str::from_utf8(self.as_bytes()) {
            Ok(text) => Ok(text),
            Err(e) => Err(HttpError::for_bad_request(
                None,
                format!("failed to parse body as UTF-8 string: {}", e),
            )),
        }
    }
}
#[async_trait]
impl ExclusiveExtractor for UntypedBody {
    /// Buffers the whole body, subject to the server's maximum body size.
    async fn from_request<Context: ServerContext>(
        rqctx: &RequestContext<Context>,
        request: hyper::Request<crate::Body>,
    ) -> Result<UntypedBody, HttpError> {
        let cap = rqctx.request_body_max_bytes();
        let buffered =
            StreamingBody::new(request.into_body(), cap).into_bytes_mut().await?;
        Ok(UntypedBody { content: buffered.freeze() })
    }

    /// Raw bodies share the generic binary-body metadata.
    fn metadata(
        _content_type: ApiEndpointBodyContentType,
    ) -> ExtractorMetadata {
        untyped_metadata()
    }
}
/// A request body that is read incrementally rather than buffered up front.
///
/// The total number of data bytes that may be read is capped; see
/// [`StreamingBody::into_stream`].
#[derive(Debug)]
pub struct StreamingBody {
    body: crate::Body,
    /// Maximum number of body data bytes that may be read.
    cap: usize,
}

impl StreamingBody {
    /// Wraps `body`, allowing at most `cap` data bytes to be read from it.
    fn new(body: crate::Body, cap: usize) -> Self {
        Self { body, cap }
    }

    /// Builds a `StreamingBody` over an in-memory buffer, with the cap set to
    /// exactly the buffer's length.
    #[doc(hidden)]
    pub fn __from_bytes(data: Bytes) -> Self {
        let cap = data.len();
        let body = crate::Body::from(data);
        Self { body, cap }
    }

    /// Converts the body into a stream of data chunks.
    ///
    /// Non-data frames are skipped. If the next chunk would push the total
    /// past the cap, the rest of the body is drained (via `http_dump_body`)
    /// and the stream yields a bad-request error instead of that chunk.
    /// Errors reading a frame are also reported as bad-request errors.
    pub fn into_stream(
        mut self,
    ) -> impl Stream<Item = Result<Bytes, HttpError>> + Send {
        async_stream::try_stream! {
            // Invariant: `bytes_read <= self.cap` at the top of every
            // iteration.
            let mut bytes_read: usize = 0;
            while let Some(frame_res) = self.body.frame().await {
                let frame = frame_res.map_err(|e| HttpError::for_bad_request(
                    None,
                    format!("error streaming request body: {}", e),
                ))?;
                // Frames that carry no data contribute nothing to the body.
                let Ok(buf) = frame.into_data() else { continue };
                let len = buf.len();
                // Compare against the remaining budget rather than computing
                // `bytes_read + len`, which could overflow for a cap near
                // `usize::MAX`. The subtraction cannot underflow because of
                // the invariant above.
                if len > self.cap - bytes_read {
                    http_dump_body(&mut self.body).await.map_err(|e| {
                        HttpError::for_bad_request(
                            None,
                            format!("error streaming request body: {}", e),
                        )
                    })?;
                    Err(HttpError::for_bad_request(
                        None,
                        format!("request body exceeded maximum size of {} bytes", self.cap),
                    ))?;
                }
                bytes_read += len;
                yield buf;
            }
        }
    }

    /// Collects the whole stream into one contiguous buffer.
    ///
    /// # Errors
    ///
    /// Propagates the first error produced by [`StreamingBody::into_stream`].
    async fn into_bytes_mut(self) -> Result<BytesMut, HttpError> {
        self.into_stream()
            .try_fold(BytesMut::new(), |mut out, chunk| {
                out.put(chunk);
                futures::future::ok(out)
            })
            .await
    }
}
#[async_trait]
impl ExclusiveExtractor for StreamingBody {
    /// Hands the body to the handler unread. The server's configured maximum
    /// body size becomes the cap, which is enforced as data is pulled through
    /// [`StreamingBody::into_stream`].
    async fn from_request<Context: ServerContext>(
        rqctx: &RequestContext<Context>,
        request: hyper::Request<crate::Body>,
    ) -> Result<Self, HttpError> {
        let cap = rqctx.request_body_max_bytes();
        Ok(Self::new(request.into_body(), cap))
    }

    /// Streaming bodies share the generic binary-body metadata.
    fn metadata(
        _content_type: ApiEndpointBodyContentType,
    ) -> ExtractorMetadata {
        untyped_metadata()
    }
}
/// Extractor metadata shared by the raw-body extractors.
///
/// Describes a single body parameter with content type
/// [`ApiEndpointBodyContentType::Bytes`], whose schema is a string with
/// `binary` format.
fn untyped_metadata() -> ExtractorMetadata {
    let binary_string = SchemaObject {
        instance_type: Some(InstanceType::String.into()),
        format: Some(String::from("binary")),
        ..Default::default()
    };
    let body = ApiEndpointParameter::new_body(
        ApiEndpointBodyContentType::Bytes,
        true,
        ApiSchemaGenerator::Static {
            schema: Box::new(binary_string.into()),
            dependencies: indexmap::IndexMap::default(),
        },
        vec![],
    );
    ExtractorMetadata {
        extension_mode: ExtensionMode::None,
        parameters: vec![body],
    }
}
#[cfg(test)]
mod tests {
    use schemars::JsonSchema;
    use serde::Deserialize;

    use crate::extractor::body::http_request_load_body;
    use crate::extractor::body::TypedBody;

    /// A structured-syntax `+json` media type is accepted for a JSON endpoint.
    #[tokio::test]
    async fn test_content_plus_json() {
        #[derive(Deserialize, JsonSchema)]
        struct TheRealScimShady {}

        let body = "{}";
        let request = hyper::Request::builder()
            .header(http::header::CONTENT_TYPE, "application/scim+json")
            .body(crate::Body::with_content(body))
            .unwrap();
        let r = http_request_load_body::<TheRealScimShady>(
            request,
            9000,
            &crate::ApiEndpointBodyContentType::Json,
        )
        .await;
        assert!(r.is_ok())
    }

    /// `From` wraps a value that `into_inner` hands back unchanged.
    #[test]
    fn test_typed_body_from() {
        #[derive(Deserialize, JsonSchema, Clone, Debug, PartialEq, Eq)]
        struct SampleBody {
            field: String,
        }

        let sample = SampleBody { field: "value".to_string() };
        let typed_body: TypedBody<SampleBody> = sample.clone().into();
        assert_eq!(typed_body.into_inner(), sample);
    }

    /// `map` applies the closure to the wrapped value.
    #[test]
    fn test_typed_body_map() {
        let typed_body: TypedBody<String> = String::from("abc").into();
        let mapped = typed_body.map(|s| s.len());
        assert_eq!(mapped.into_inner(), 3);
    }

    /// `try_map` propagates both the success value and the closure's error.
    #[test]
    fn test_typed_body_try_map() {
        let ok: TypedBody<String> = String::from("42").into();
        let parsed = ok
            .try_map(|s| s.parse::<u32>())
            .expect("\"42\" should parse as u32");
        assert_eq!(parsed.into_inner(), 42);

        let bad: TypedBody<String> = String::from("forty-two").into();
        assert!(bad.try_map(|s| s.parse::<u32>()).is_err());
    }
}