use super::{rejection::*, BodyStream, FromRequest, RequestParts};
use crate::body::{Bytes, HttpBody};
use crate::BoxError;
use async_trait::async_trait;
use futures_util::stream::Stream;
use http::header::{HeaderMap, CONTENT_TYPE};
use mime::Mime;
use std::{
fmt,
pin::Pin,
task::{Context, Poll},
};
/// Extractor for `multipart/form-data` request bodies.
///
/// Wraps a [`multer::Multipart`] parser. Individual fields are read one at a
/// time through [`Multipart::next_field`].
#[cfg_attr(docsrs, doc(cfg(feature = "multipart")))]
#[derive(Debug)]
pub struct Multipart {
// Underlying `multer` parser that all field reads are delegated to.
inner: multer::Multipart<'static>,
}
#[async_trait]
impl<B> FromRequest<B> for Multipart
where
B: HttpBody<Data = Bytes> + Default + Unpin + Send + 'static,
B::Error: Into<BoxError>,
{
type Rejection = MultipartRejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let stream = BodyStream::from_request(req).await?;
let headers = req.headers().ok_or_else(HeadersAlreadyExtracted::default)?;
let boundary = parse_boundary(headers).ok_or(InvalidBoundary)?;
let multipart = multer::Multipart::new(stream, boundary);
Ok(Self { inner: multipart })
}
}
impl Multipart {
    /// Yields the next [`Field`] of the multipart body, if there is one.
    ///
    /// Returns `Ok(None)` when the underlying parser reports that no further
    /// field is available. The returned [`Field`] keeps `self` mutably
    /// borrowed for as long as it is alive.
    ///
    /// # Errors
    ///
    /// Any error produced by the underlying `multer` parser is wrapped in a
    /// [`MultipartError`].
    pub async fn next_field(&mut self) -> Result<Option<Field<'_>>, MultipartError> {
        // Bind the awaited result first so the borrow of `self.inner` has ended
        // before `self` is handed to the new `Field`.
        let outcome = self.inner.next_field().await;

        match outcome {
            Err(err) => Err(MultipartError::from_multer(err)),
            Ok(None) => Ok(None),
            Ok(Some(inner)) => Ok(Some(Field {
                inner,
                _multipart: self,
            })),
        }
    }
}
/// A single field in a multipart stream, obtained from
/// [`Multipart::next_field`].
///
/// Implements [`Stream`], yielding the field's data as chunks of [`Bytes`].
#[derive(Debug)]
pub struct Field<'a> {
// Underlying `multer` field that all accessors and reads delegate to.
inner: multer::Field<'static>,
// Never read. Holding this exclusive borrow keeps the parent `Multipart`
// mutably borrowed for `'a`, so `next_field` cannot be called again while
// this `Field` is still alive.
_multipart: &'a mut Multipart,
}
impl<'a> Stream for Field<'a> {
    type Item = Result<Bytes, MultipartError>;

    /// Polls the wrapped `multer` field for its next chunk, converting any
    /// error it yields into a [`MultipartError`].
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let inner = Pin::new(&mut self.inner);

        match inner.poll_next(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Ready(Some(Ok(chunk))) => Poll::Ready(Some(Ok(chunk))),
            Poll::Ready(Some(Err(err))) => {
                Poll::Ready(Some(Err(MultipartError::from_multer(err))))
            }
        }
    }
}
impl<'a> Field<'a> {
    /// Returns the field's name as reported by the underlying `multer` field,
    /// or `None` if it has none.
    pub fn name(&self) -> Option<&str> {
        self.inner.name()
    }

    /// Returns the field's file name as reported by the underlying `multer`
    /// field, or `None` if it has none.
    pub fn file_name(&self) -> Option<&str> {
        self.inner.file_name()
    }

    /// Returns the field's content type as reported by the underlying `multer`
    /// field, or `None` if it has none.
    pub fn content_type(&self) -> Option<&Mime> {
        self.inner.content_type()
    }

    /// Returns the header map of the underlying `multer` field.
    pub fn headers(&self) -> &HeaderMap {
        self.inner.headers()
    }

    /// Consumes the field and returns its full data as [`Bytes`].
    ///
    /// # Errors
    ///
    /// Any error from the underlying `multer` field is wrapped in a
    /// [`MultipartError`].
    pub async fn bytes(self) -> Result<Bytes, MultipartError> {
        match self.inner.bytes().await {
            Ok(data) => Ok(data),
            Err(err) => Err(MultipartError::from_multer(err)),
        }
    }

    /// Consumes the field and returns its full data as a `String`.
    ///
    /// # Errors
    ///
    /// Any error from the underlying `multer` field is wrapped in a
    /// [`MultipartError`].
    pub async fn text(self) -> Result<String, MultipartError> {
        match self.inner.text().await {
            Ok(text) => Ok(text),
            Err(err) => Err(MultipartError::from_multer(err)),
        }
    }
}
/// Error returned when reading from a [`Multipart`] or one of its [`Field`]s
/// fails.
///
/// The original `multer` error is available through
/// [`std::error::Error::source`].
#[derive(Debug)]
pub struct MultipartError {
// The `multer` error that caused this failure.
source: multer::Error,
}
impl MultipartError {
fn from_multer(multer: multer::Error) -> Self {
Self { source: multer }
}
}
impl fmt::Display for MultipartError {
    /// Writes a fixed, generic message; the detailed cause is exposed through
    /// `source()` rather than being repeated here.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Error parsing `multipart/form-data` request")
    }
}
impl std::error::Error for MultipartError {
    /// Exposes the wrapped `multer` error as the underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let cause: &(dyn std::error::Error + 'static) = &self.source;
        Some(cause)
    }
}
/// Extracts the multipart boundary from the `Content-Type` header.
///
/// Returns `None` when the header is absent, when its value cannot be viewed
/// as a string via `to_str`, or when `multer::parse_boundary` rejects it.
fn parse_boundary(headers: &HeaderMap) -> Option<String> {
    headers
        .get(CONTENT_TYPE)
        .and_then(|value| value.to_str().ok())
        .and_then(|content_type| multer::parse_boundary(content_type).ok())
}
// Rejection type used by the `Multipart` extractor (`FromRequest::Rejection`).
// It has one variant per failure that `from_request` can propagate.
// NOTE(review): `composite_rejection!` is defined outside this file; it
// presumably also generates the `From` conversions that the `?` operators in
// `from_request` rely on — confirm against the macro definition.
composite_rejection! {
pub enum MultipartRejection {
BodyAlreadyExtracted,
InvalidBoundary,
HeadersAlreadyExtracted,
}
}
// Rejection returned by `Multipart::from_request` when no boundary can be
// parsed from the `Content-Type` header.
// NOTE(review): `define_rejection!` is defined outside this file; the
// `status` and `body` attributes presumably set the HTTP status and body of
// the generated response — confirm against the macro definition.
define_rejection! {
#[status = BAD_REQUEST]
#[body = "Invalid `boundary` for `multipart/form-data` request"]
pub struct InvalidBoundary;
}