use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{fmt, ops};
use actix_web::dev;
use actix_web::http::header;
use actix_web::web::BytesMut;
use actix_web::Error as ActixError;
use actix_web::{FromRequest, HttpRequest};
use futures::future::{err, Either, LocalBoxFuture, Ready};
use futures::{FutureExt, StreamExt};
use serde::de::DeserializeOwned;
pub use crate::config::XmlConfig;
pub use crate::error::XMLPayloadError;
mod config;
mod error;
#[cfg(test)]
mod tests;
/// Wrapper around a value of type `T` that is carried as an XML payload.
///
/// The wrapped value is public and can be reached directly through `.0`.
pub struct Xml<T>(pub T);

impl<T> Xml<T> {
    /// Consumes the wrapper and hands back the wrapped value.
    pub fn into_inner(self) -> T {
        let Xml(inner) = self;
        inner
    }
}
impl<T> ops::Deref for Xml<T> {
type Target = T;
fn deref(&self) -> &T {
&self.0
}
}
impl<T> ops::DerefMut for Xml<T> {
fn deref_mut(&mut self) -> &mut T {
&mut self.0
}
}
impl<T> fmt::Debug for Xml<T>
where
T: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "XML: {:?}", self.0)
}
}
impl<T> fmt::Display for Xml<T>
where
T: fmt::Display,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Display::fmt(&self.0, f)
}
}
impl<T> FromRequest for Xml<T>
where
T: DeserializeOwned + 'static,
{
type Error = ActixError;
#[allow(clippy::type_complexity)]
type Future =
Either<LocalBoxFuture<'static, Result<Self, ActixError>>, Ready<Result<Self, ActixError>>>;
fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future {
let path = req.path().to_string();
let config = XmlConfig::from_req(req);
if let Err(e) = config.check_content_type(req) {
return Either::Right(err(e.into()));
}
Either::Left(
XmlBody::new(req, payload)
.limit(config.limit)
.map(move |res| match res {
Err(e) => {
log::debug!(
"Failed to deserialize XML from payload. \
Request path: {}",
path
);
Err(e.into())
}
Ok(data) => Ok(Xml(data)),
})
.boxed_local(),
)
}
}
/// Future that collects a request payload and deserializes it from XML into `U`.
///
/// Built by `XmlBody::new`; the size cap can be adjusted with `XmlBody::limit`
/// before the future is polled.
pub struct XmlBody<U> {
// Maximum accepted body size in bytes; the declared length and the running
// total of received chunks are both compared against it.
limit: usize,
// `Content-Length` header value, when present and parseable as `usize`.
length: Option<usize>,
// Request payload taken from the request; with the `__compress` feature it is
// wrapped in a decompressor built from the request headers.
#[cfg(feature = "__compress")]
stream: Option<dev::Decompress<dev::Payload>>,
#[cfg(not(feature = "__compress"))]
stream: Option<dev::Payload>,
// Pending error handed out by `poll` before any body is read, if one is set.
err: Option<XMLPayloadError>,
// Boxed future that reads and deserializes the body; created lazily on the
// first poll and polled on every call afterwards.
fut: Option<LocalBoxFuture<'static, Result<U, XMLPayloadError>>>,
}
impl<U> XmlBody<U>
where
    U: DeserializeOwned + 'static,
{
    /// Creates a body future for `req`, taking ownership of its payload.
    ///
    /// The `Content-Length` header is recorded when it is present and parses
    /// as a `usize`. The size limit starts at 262_144 bytes; use
    /// [`XmlBody::limit`] to change it.
    // NOTE(review): the allow presumably covers borrowing the
    // `header::CONTENT_LENGTH` constant below — confirm before removing.
    #[allow(clippy::borrow_interior_mutable_const)]
    pub fn new(req: &HttpRequest, payload: &mut dev::Payload) -> Self {
        let declared_len = req
            .headers()
            .get(&header::CONTENT_LENGTH)
            .and_then(|value| value.to_str().ok()?.parse::<usize>().ok());

        // With the `__compress` feature the payload is wrapped in a
        // decompressor configured from the request headers.
        #[cfg(feature = "__compress")]
        let stream = dev::Decompress::from_headers(payload.take(), req.headers());
        #[cfg(not(feature = "__compress"))]
        let stream = payload.take();

        Self {
            limit: 262_144,
            length: declared_len,
            stream: Some(stream),
            err: None,
            fut: None,
        }
    }

    /// Replaces the maximum accepted body size (in bytes) and returns the
    /// updated future.
    pub fn limit(self, limit: usize) -> Self {
        Self { limit, ..self }
    }
}
impl<U> Future for XmlBody<U>
where
    U: DeserializeOwned + 'static,
{
    type Output = Result<U, XMLPayloadError>;

    /// Drives body collection and deserialization.
    ///
    /// On the first poll a stored error or an over-limit declared length is
    /// reported immediately. Otherwise the reading future is created, stored,
    /// and polled; later polls go straight to that stored future.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();

        loop {
            // Once the reader exists, every poll is forwarded to it.
            if let Some(reader) = this.fut.as_mut() {
                return reader.as_mut().poll(cx);
            }

            if let Some(stored) = this.err.take() {
                return Poll::Ready(Err(stored));
            }

            let max_len = this.limit;

            // Reject early when the declared length already exceeds the limit.
            match this.length.take() {
                Some(declared) if declared > max_len => {
                    return Poll::Ready(Err(XMLPayloadError::Overflow));
                }
                _ => {}
            }

            let mut source = this.stream.take().unwrap();

            this.fut = Some(
                async move {
                    let mut collected = BytesMut::with_capacity(8192);
                    while let Some(next) = source.next().await {
                        let piece = next?;
                        // The declared length may be absent or wrong, so the
                        // limit is enforced on the bytes actually received.
                        if (collected.len() + piece.len()) > max_len {
                            return Err(XMLPayloadError::Overflow);
                        }
                        collected.extend_from_slice(&piece);
                    }
                    Ok(quick_xml::de::from_reader(&*collected)?)
                }
                .boxed_local(),
            );
            // Loop back so the freshly stored reader gets its first poll.
        }
    }
}