use bytes::Bytes;
use crate::web::{Error, RequestContext};
/// One section of a parsed `multipart/form-data` body.
///
/// A part counts as a file upload when it carries a file name
/// (see [`Part::is_file`]); otherwise it is an ordinary form field.
#[derive(Clone, Debug)]
pub struct Part {
    /// Form field name of this part.
    pub name: String,
    /// File name supplied for this part, if any. Its presence is what
    /// [`Part::is_file`] checks.
    pub file_name: Option<String>,
    /// Content type declared for this part, if any.
    pub content_type: Option<String>,
    /// Raw payload of the part.
    pub bytes: Bytes,
}

impl Part {
    /// Borrows the payload as UTF-8 text.
    ///
    /// Returns `None` when the payload is not valid UTF-8; no lossy
    /// replacement is performed.
    pub fn text(&self) -> Option<&str> {
        let raw: &[u8] = &self.bytes;
        std::str::from_utf8(raw).ok()
    }

    /// Reports whether this part is a file upload, i.e. has a file name.
    pub fn is_file(&self) -> bool {
        self.file_name.is_some()
    }
}
/// A fully buffered, parsed `multipart/form-data` body.
///
/// Parts are stored in the order they were read from the body. Name-based
/// lookups return the first matching part.
#[derive(Debug, Clone, Default)]
pub struct MultipartForm {
    parts: Vec<Part>,
}

impl MultipartForm {
    /// Parses the multipart body of the request held by `ctx`.
    ///
    /// The boundary is taken from the request's `content-type` header and
    /// the already-buffered request body is handed to [`MultipartForm::parse`].
    ///
    /// # Errors
    ///
    /// Returns `Error::BadRequest` when the request has no `content-type`
    /// header, or for any failure reported by [`MultipartForm::parse`].
    pub async fn from_ctx(ctx: &RequestContext) -> Result<Self, Error> {
        let content_type = ctx
            .header("content-type")
            .ok_or(Error::BadRequest("missing content-type for multipart body"))?;
        Self::parse(content_type, ctx.body().clone()).await
    }

    /// Parses an in-memory multipart body using the boundary from `content_type`.
    ///
    /// Every part is read fully into memory. A part without a name is stored
    /// with an empty `name`.
    ///
    /// # Errors
    ///
    /// Returns `Error::BadRequest` when no boundary can be extracted from
    /// `content_type`, when the body is not well-formed multipart, or when a
    /// part's payload cannot be read.
    pub async fn parse(content_type: &str, body: Bytes) -> Result<Self, Error> {
        let boundary = multer::parse_boundary(content_type)
            .map_err(|_| Error::BadRequest("not a multipart/form-data body"))?;
        // multer consumes a stream of chunks. The body is already buffered,
        // so it is fed as a single chunk whose error type can never occur.
        let stream =
            futures::stream::once(async move { Ok::<Bytes, std::convert::Infallible>(body) });
        let mut multipart = multer::Multipart::new(stream, boundary);
        let mut parts = Vec::new();
        while let Some(field) = multipart
            .next_field()
            .await
            .map_err(|_| Error::BadRequest("malformed multipart body"))?
        {
            // Copy the header metadata out first: reading the payload with
            // `bytes()` takes the field by value.
            let name = field.name().unwrap_or("").to_string();
            let file_name = field.file_name().map(str::to_string);
            let content_type = field.content_type().map(|m| m.to_string());
            let bytes = field
                .bytes()
                .await
                .map_err(|_| Error::BadRequest("failed to read multipart field"))?;
            parts.push(Part {
                name,
                file_name,
                content_type,
                bytes,
            });
        }
        Ok(Self { parts })
    }

    /// All parts, in the order they were parsed.
    pub fn parts(&self) -> &[Part] {
        &self.parts
    }

    /// First part named `name`, whether it is a text field or a file.
    pub fn part(&self, name: &str) -> Option<&Part> {
        self.parts.iter().find(|p| p.name == name)
    }

    /// Text value of the first non-file part named `name`.
    ///
    /// File parts are skipped during the search, mirroring
    /// [`MultipartForm::file`]. A file that precedes a same-named text field
    /// therefore does not hide it. Returns `None` when no such part exists or
    /// its payload is not valid UTF-8.
    pub fn text(&self, name: &str) -> Option<&str> {
        self.parts
            .iter()
            .find(|p| p.name == name && !p.is_file())
            .and_then(Part::text)
    }

    /// First file part named `name`. Non-file parts are skipped.
    pub fn file(&self, name: &str) -> Option<&Part> {
        self.parts.iter().find(|p| p.name == name && p.is_file())
    }

    /// Iterator over every file part, in parse order.
    pub fn files(&self) -> impl Iterator<Item = &Part> {
        self.parts.iter().filter(|p| p.is_file())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a two-part body: a text field `title` and a file field `avatar`.
    /// Returns the matching `content-type` header value and the body bytes.
    fn sample() -> (String, Bytes) {
        let b = "X-BOUNDARY";
        // `\` line continuations drop the newline and the next line's leading
        // whitespace, so the indentation below does not leak into the body.
        let body = format!(
            "--{b}\r\n\
             Content-Disposition: form-data; name=\"title\"\r\n\r\n\
             Hello World\r\n\
             --{b}\r\n\
             Content-Disposition: form-data; name=\"avatar\"; filename=\"a.png\"\r\n\
             Content-Type: image/png\r\n\r\n\
             PNGDATA\r\n\
             --{b}--\r\n"
        );
        (
            format!("multipart/form-data; boundary={b}"),
            Bytes::from(body),
        )
    }

    #[tokio::test]
    async fn parses_fields_and_files() {
        let (ct, body) = sample();
        let form = MultipartForm::parse(&ct, body).await.unwrap();
        assert_eq!(form.parts().len(), 2);
        assert_eq!(form.text("title"), Some("Hello World"));
        let file = form.file("avatar").unwrap();
        assert_eq!(file.file_name.as_deref(), Some("a.png"));
        assert_eq!(file.content_type.as_deref(), Some("image/png"));
        assert_eq!(&file.bytes[..], b"PNGDATA");
        assert_eq!(form.files().count(), 1);
    }

    #[tokio::test]
    async fn lookups_respect_part_kind() {
        let (ct, body) = sample();
        let form = MultipartForm::parse(&ct, body).await.unwrap();
        // A file part is never reported as text, nor a text field as a file.
        assert_eq!(form.text("avatar"), None);
        assert!(form.file("title").is_none());
        // `part` ignores the kind of the part.
        assert_eq!(form.part("avatar").map(Part::is_file), Some(true));
        assert_eq!(form.part("title").map(Part::is_file), Some(false));
        // Unknown names yield nothing from any lookup.
        assert!(form.part("missing").is_none());
        assert_eq!(form.text("missing"), None);
        assert!(form.file("missing").is_none());
    }

    #[tokio::test]
    async fn text_is_none_for_invalid_utf8() {
        let mut body =
            b"--B\r\nContent-Disposition: form-data; name=\"raw\"\r\n\r\n".to_vec();
        body.extend_from_slice(b"\xff\xfe");
        body.extend_from_slice(b"\r\n--B--\r\n");
        let form = MultipartForm::parse("multipart/form-data; boundary=B", Bytes::from(body))
            .await
            .unwrap();
        // The raw bytes are kept even though they are not valid UTF-8.
        assert_eq!(
            form.part("raw").map(|p| &p.bytes[..]),
            Some(&b"\xff\xfe"[..])
        );
        assert_eq!(form.text("raw"), None);
    }

    #[tokio::test]
    async fn rejects_non_multipart() {
        let err = MultipartForm::parse("application/json", Bytes::from_static(b"{}"))
            .await
            .unwrap_err();
        assert!(matches!(err, Error::BadRequest(_)));
    }

    #[tokio::test]
    async fn rejects_multipart_without_boundary() {
        let err = MultipartForm::parse("multipart/form-data", Bytes::from_static(b""))
            .await
            .unwrap_err();
        assert!(matches!(err, Error::BadRequest(_)));
    }
}