use axum_core::response::{IntoResponse, Response};
use fastrand;
use http::{header, HeaderMap, StatusCode};
use mime::Mime;
/// A `multipart/form-data` response body.
///
/// Holds the [`Part`]s that are serialized, in order, when the form is
/// converted into a response through its `IntoResponse` implementation.
#[derive(Debug)]
#[must_use]
pub struct MultipartForm {
    parts: Vec<Part>,
}

impl MultipartForm {
    /// Builds a form from an already-collected list of parts.
    ///
    /// Parts are emitted in the same order as they appear in `parts`.
    pub fn with_parts(parts: Vec<Part>) -> Self {
        Self { parts }
    }
}
impl IntoResponse for MultipartForm {
fn into_response(self) -> Response {
let boundary = generate_boundary();
let mut headers = HeaderMap::new();
let mime_type: Mime = match format!("multipart/form-data; boundary={boundary}").parse() {
Ok(m) => m,
Err(_) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
"Invalid multipart boundary generated",
)
.into_response()
}
};
headers.insert(header::CONTENT_TYPE, mime_type.to_string().parse().unwrap());
let mut serialized_form: Vec<u8> = Vec::new();
for part in self.parts {
serialized_form.extend_from_slice(format!("--{boundary}\r\n").as_bytes());
serialized_form.extend_from_slice(&part.serialize());
}
serialized_form.extend_from_slice(format!("--{boundary}--").as_bytes());
(headers, serialized_form).into_response()
}
}
/// A single field of a `multipart/form-data` body.
#[derive(Debug)]
pub struct Part {
    /// Form field name, written to the `name` parameter of `Content-Disposition`.
    name: String,
    /// Optional file name, written to the `filename` parameter when present.
    filename: Option<String>,
    /// Value written to the part's `Content-Type` header.
    mime_type: Mime,
    /// Raw bytes of the part body.
    contents: Vec<u8>,
}

impl Part {
    /// Creates a text part with content type `text/plain; charset=utf-8`.
    #[must_use]
    pub fn text(name: String, contents: &str) -> Self {
        Self {
            name,
            filename: None,
            mime_type: mime::TEXT_PLAIN_UTF_8,
            contents: contents.as_bytes().to_vec(),
        }
    }

    /// Creates a file part with content type `application/octet-stream`.
    ///
    /// `field_name` is the form field name and `file_name` is sent as the
    /// `filename` parameter of `Content-Disposition`.
    #[must_use]
    pub fn file(field_name: &str, file_name: &str, contents: Vec<u8>) -> Self {
        Self {
            name: field_name.to_owned(),
            filename: Some(file_name.to_owned()),
            mime_type: mime::APPLICATION_OCTET_STREAM,
            contents,
        }
    }

    /// Creates a part with an arbitrary MIME type and an optional file name.
    ///
    /// # Errors
    ///
    /// Returns `Err("Invalid MIME type")` if `mime_type` cannot be parsed as a
    /// MIME type.
    pub fn raw_part(
        name: &str,
        mime_type: &str,
        contents: Vec<u8>,
        filename: Option<&str>,
    ) -> Result<Self, &'static str> {
        let mime_type = mime_type.parse().map_err(|_| "Invalid MIME type")?;
        Ok(Self {
            name: name.to_owned(),
            filename: filename.map(str::to_owned),
            mime_type,
            contents,
        })
    }

    /// Serializes this part's headers and body, without the leading boundary
    /// delimiter, terminated by CRLF.
    ///
    /// The field name and file name are escaped (`"` → `%22`, CR → `%0D`,
    /// LF → `%0A`, as in the HTML `multipart/form-data` encoding algorithm).
    /// This stops a value from ending the quoted parameter early or injecting
    /// extra header lines.
    pub(super) fn serialize(&self) -> Vec<u8> {
        /// Appends `value` to `out` with `"`, CR and LF percent-encoded.
        fn push_escaped(out: &mut String, value: &str) {
            for c in value.chars() {
                match c {
                    '"' => out.push_str("%22"),
                    '\r' => out.push_str("%0D"),
                    '\n' => out.push_str("%0A"),
                    other => out.push(other),
                }
            }
        }

        let mut header = String::from("Content-Disposition: form-data; name=\"");
        push_escaped(&mut header, &self.name);
        header.push('"');
        if let Some(filename) = &self.filename {
            header.push_str("; filename=\"");
            push_escaped(&mut header, filename);
            header.push('"');
        }
        header.push_str("\r\nContent-Type: ");
        header.push_str(self.mime_type.as_ref());
        // A blank line separates the part headers from the part body.
        header.push_str("\r\n\r\n");

        let mut part_bytes = Vec::with_capacity(header.len() + self.contents.len() + 2);
        part_bytes.extend_from_slice(header.as_bytes());
        part_bytes.extend_from_slice(&self.contents);
        part_bytes.extend_from_slice(b"\r\n");
        part_bytes
    }
}
impl FromIterator<Part> for MultipartForm {
fn from_iter<T: IntoIterator<Item = Part>>(iter: T) -> Self {
Self {
parts: iter.into_iter().collect(),
}
}
}
/// Generates a random multipart boundary.
///
/// The boundary is four random `u64` values, each rendered as 16 zero-padded
/// lowercase hex digits and joined with `-`, giving 67 characters. That fits
/// within RFC 2046's 70-character limit. The characters used (`0-9`, `a-f`,
/// `-`) are all valid boundary characters and need no quoting in the
/// `Content-Type` parameter. With 256 random bits, an accidental match inside
/// part contents is practically impossible.
fn generate_boundary() -> String {
    // `..` samples the full `u64` range, so `u64::MAX` is not excluded.
    let a = fastrand::u64(..);
    let b = fastrand::u64(..);
    let c = fastrand::u64(..);
    let d = fastrand::u64(..);
    format!("{a:016x}-{b:016x}-{c:016x}-{d:016x}")
}
#[cfg(test)]
mod tests {
    use super::{generate_boundary, MultipartForm, Part};
    use axum::{body::Body, http};
    use axum::{routing::get, Router};
    use http::{Request, Response};
    use http_body_util::BodyExt;
    use mime::Mime;
    use tower::ServiceExt;

    /// Handler returning a three-part form: plain text, a file, and a raw part.
    async fn handle() -> MultipartForm {
        let text = Part::text("part1".to_owned(), "basictext");
        let file = Part::file("part2", "file.txt", b"hi mom".to_vec());
        let raw = Part::raw_part("part3", "text/plain", b"rawpart".to_vec(), None).unwrap();
        vec![text, file, raw].into_iter().collect()
    }

    /// Round-trips the form through a router and checks the exact wire format.
    #[tokio::test]
    async fn process_form() -> Result<(), Box<dyn std::error::Error>> {
        let request = Request::builder().uri("/").body(Body::empty())?;
        let response: Response<_> = Router::new()
            .route("/", get(handle))
            .oneshot(request)
            .await?;

        // The boundary is random, so recover it from the Content-Type header.
        let content_type = response
            .headers()
            .get("content-type")
            .unwrap()
            .to_str()?;
        let boundary = content_type
            .split("boundary=")
            .nth(1)
            .unwrap()
            .to_owned();

        let collected = response.into_body().collect().await?.to_bytes();
        let body = std::str::from_utf8(&collected)?;
        let expected = format!(
            "--{boundary}\r\n\
            Content-Disposition: form-data; name=\"part1\"\r\n\
            Content-Type: text/plain; charset=utf-8\r\n\
            \r\n\
            basictext\r\n\
            --{boundary}\r\n\
            Content-Disposition: form-data; name=\"part2\"; filename=\"file.txt\"\r\n\
            Content-Type: application/octet-stream\r\n\
            \r\n\
            hi mom\r\n\
            --{boundary}\r\n\
            Content-Disposition: form-data; name=\"part3\"\r\n\
            Content-Type: text/plain\r\n\
            \r\n\
            rawpart\r\n\
            --{boundary}--",
        );
        assert_eq!(body, expected);
        Ok(())
    }

    /// Every generated boundary must yield a parseable multipart MIME type.
    #[test]
    fn valid_boundary_generation() {
        (0..256).for_each(|_| {
            let boundary = generate_boundary();
            let parsed = format!("multipart/form-data; boundary={boundary}").parse::<Mime>();
            assert!(
                parsed.is_ok(),
                "The generated boundary was unable to be parsed into a valid mime type."
            );
        });
    }
}