use crate::frameworks::actix::error::ActixError;
use crate::frameworks::payload_control::PayloadControl;
use crate::{Hateoas, HateoasResource};
use actix_web::dev::Payload;
use actix_web::http::header::CONTENT_LENGTH;
use actix_web::HttpRequest;
use bytes::BytesMut;
use futures_core::Stream as _;
use serde::de::DeserializeOwned;
use simple_serde::{ContentType, Decoded, SimpleDecoder};
use std::future::Future;
use std::marker::PhantomData;
use std::ops::Deref;
use std::pin::Pin;
use std::task::{Context, Poll};
const DEFAULT_LIMIT: usize = 41_943_040;
pub enum PayloadBody<T, O> {
Error(Option<ActixError>),
Body {
limit: usize,
length: Option<usize>,
content_type: ContentType,
payload: Payload,
buf: BytesMut,
_res: PhantomData<T>,
_payload_res: PhantomData<O>,
},
}
impl<T, O> Unpin for PayloadBody<T, O> {}
impl<T: DeserializeOwned, O: PayloadControl> PayloadBody<T, O> {
#[allow(clippy::borrow_interior_mutable_const)]
pub fn new(r: HttpRequest, payload: &mut Payload) -> Self {
let length = r
.headers()
.get(&CONTENT_LENGTH)
.ok_or(ActixError::NoPayloadSizeDefinitionInHeader)
.and_then(|l| l.to_str().map_err(ActixError::from))
.and_then(|s| s.parse::<usize>().map_err(ActixError::from));
let content_type = Ok(r
.headers()
.get_all("Content-Type")
.filter_map(|h| simple_serde::ContentType::try_from(h).ok())
.collect::<Vec<ContentType>>())
.and_then(|mut t: Vec<ContentType>| {
t.reverse();
t.pop().ok_or(ActixError::FailedToGetContentTypeFromHeader)
});
let payload = payload.take();
match (content_type, length) {
(Ok(c), Ok(l)) => PayloadBody::Body {
limit: O::MAX_PAYLOAD_SIZE.unwrap_or(DEFAULT_LIMIT),
content_type: c,
length: Some(l),
payload,
buf: BytesMut::with_capacity(O::BUFFER_CAPACITY.unwrap_or(8192)),
_res: PhantomData,
_payload_res: PhantomData,
},
(Ok(c), _) => PayloadBody::Body {
limit: O::MAX_PAYLOAD_SIZE.unwrap_or(DEFAULT_LIMIT),
content_type: c,
length: None,
payload,
buf: BytesMut::with_capacity(O::BUFFER_CAPACITY.unwrap_or(8192)),
_res: PhantomData,
_payload_res: PhantomData,
},
(_, _) => PayloadBody::Error(Some(ActixError::ContentType)),
}
}
pub fn limit(self, limit: usize) -> Self {
match self {
PayloadBody::Body {
length,
content_type,
payload,
buf,
..
} => {
if let Some(len) = length {
if len > limit {
return PayloadBody::Error(Some(ActixError::OverflowKnownLength {
length: len,
limit,
}));
}
}
PayloadBody::Body {
limit,
content_type,
length,
payload,
buf,
_res: PhantomData,
_payload_res: PhantomData,
}
}
PayloadBody::Error(e) => PayloadBody::Error(e),
}
}
}
impl<T, O> Future for PayloadBody<T, O>
where
T: DeserializeOwned + HateoasResource + Clone,
O: PayloadControl,
{
type Output = Result<Hateoas<T>, ActixError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
match this {
PayloadBody::Body {
limit,
buf,
payload,
content_type,
..
} => loop {
let res = match Pin::new(&mut *payload).poll_next(cx) {
std::task::Poll::Ready(t) => t,
std::task::Poll::Pending => {
return std::task::Poll::Pending;
}
};
match res {
Some(chunk) => {
let chunk = chunk?;
let buf_len = buf.len() + chunk.len();
if buf_len > *limit {
return Poll::Ready(Err(ActixError::Overflow { limit: *limit }));
} else {
buf.extend_from_slice(&chunk);
}
}
None => {
let json = buf
.to_vec()
.as_slice()
.decode(content_type.deref())
.map(|d: Decoded<T>| Hateoas::from(d.into()))
.or_else(|_e| {
buf.to_vec()
.as_slice()
.decode(content_type.deref())
.map(|d: Decoded<Hateoas<T>>| d.into())
})
.map_err(ActixError::SerializationDeserializationError)?;
return Poll::Ready(Ok(json));
}
}
},
PayloadBody::Error(e) => Poll::Ready(Err(e.take().unwrap())),
}
}
}