use std::sync::atomic::{AtomicU64, Ordering};
use bytes::{BufMut, Bytes, BytesMut};
use http::HeaderValue;
use http_body_util::Full;
use crate::body::{
Body,
codec::{BodyContentType, BodyEncoder},
};
/// One field of a `multipart/form-data` body.
///
/// When rendered, a part gets a `Content-Disposition: form-data` header carrying `name`
/// (plus `filename` when present), an optional `Content-Type` header, and then `body`
/// verbatim.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Part {
    /// Form field name (the `name` parameter of `Content-Disposition`).
    pub name: String,
    /// Optional file name (the `filename` parameter of `Content-Disposition`).
    pub filename: Option<String>,
    /// Optional `Content-Type` header for this part; no header is written when `None`.
    pub content_type: Option<HeaderValue>,
    /// Payload bytes, written as-is.
    pub body: Bytes,
}

impl Part {
    /// Creates a text field with content type `text/plain; charset=utf-8` and no filename.
    pub fn text<N, V>(name: N, value: V) -> Self
    where
        N: Into<String>,
        V: Into<String>,
    {
        let name: String = name.into();
        let value: String = value.into();
        Part {
            name,
            filename: None,
            content_type: Some(HeaderValue::from_static("text/plain; charset=utf-8")),
            body: Bytes::from(value),
        }
    }

    /// Creates a file field with the given `filename` and `content_type`.
    pub fn file<N, F>(
        name: N,
        filename: F,
        content_type: HeaderValue,
        body: impl Into<Bytes>,
    ) -> Self
    where
        N: Into<String>,
        F: Into<String>,
    {
        let name: String = name.into();
        let filename: String = filename.into();
        let body: Bytes = body.into();
        Part {
            name,
            filename: Some(filename),
            content_type: Some(content_type),
            body,
        }
    }

    /// Creates a field with an explicit `content_type` and no filename.
    pub fn raw<N>(name: N, content_type: HeaderValue, body: impl Into<Bytes>) -> Self
    where
        N: Into<String>,
    {
        let name: String = name.into();
        let body: Bytes = body.into();
        Part {
            name,
            filename: None,
            content_type: Some(content_type),
            body,
        }
    }
}
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct MultipartForm {
parts: Vec<Part>,
}
impl MultipartForm {
pub fn builder() -> MultipartFormBuilder {
MultipartFormBuilder::default()
}
pub fn from_parts(parts: Vec<Part>) -> Self {
Self { parts }
}
pub fn parts(&self) -> &[Part] {
&self.parts
}
}
/// Incremental builder for a [`MultipartForm`]; obtain one with [`MultipartForm::builder`].
///
/// Every method consumes the builder and returns it with one more part appended, so calls
/// chain. Parts keep the order in which they were added.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
// Dropping the returned builder silently discards every part added so far.
#[must_use = "builder methods return the updated builder; call `build` to obtain the form"]
pub struct MultipartFormBuilder {
    parts: Vec<Part>,
}

impl MultipartFormBuilder {
    /// Appends an already-constructed [`Part`].
    pub fn part(mut self, part: Part) -> Self {
        self.parts.push(part);
        self
    }

    /// Appends a text field (see [`Part::text`]).
    pub fn text<N, V>(self, name: N, value: V) -> Self
    where
        N: Into<String>,
        V: Into<String>,
    {
        self.part(Part::text(name, value))
    }

    /// Appends a file field (see [`Part::file`]).
    pub fn file<N, F>(
        self,
        name: N,
        filename: F,
        content_type: HeaderValue,
        body: impl Into<Bytes>,
    ) -> Self
    where
        N: Into<String>,
        F: Into<String>,
    {
        self.part(Part::file(name, filename, content_type, body))
    }

    /// Appends a field with an explicit content type and no filename (see [`Part::raw`]).
    ///
    /// Provided so that every `Part` constructor has a matching builder shortcut.
    pub fn raw<N>(self, name: N, content_type: HeaderValue, body: impl Into<Bytes>) -> Self
    where
        N: Into<String>,
    {
        self.part(Part::raw(name, content_type, body))
    }

    /// Finishes building, returning a [`MultipartForm`] with the parts in insertion order.
    #[must_use]
    pub fn build(self) -> MultipartForm {
        MultipartForm { parts: self.parts }
    }
}
/// Encodes a [`MultipartForm`] as a `multipart/form-data` body using one fixed boundary.
///
/// The same boundary string is used both for the `Content-Type` header and for every
/// delimiter in the encoded body.
#[derive(Clone, Debug)]
pub struct MultipartEncoder {
    boundary: String,
}

impl Default for MultipartEncoder {
    /// Equivalent to [`MultipartEncoder::new`].
    fn default() -> Self {
        MultipartEncoder::new()
    }
}

impl MultipartEncoder {
    /// Creates an encoder with a freshly generated boundary.
    pub fn new() -> Self {
        Self::with_boundary(next_boundary())
    }

    /// Creates an encoder that uses `boundary` exactly as given.
    ///
    /// The boundary is not validated here. It must not occur inside any part body
    /// (RFC 2046 §5.1.1). It must also be usable in an HTTP header value: the
    /// `Content-Type` header is built via `HeaderValue::try_from(..).expect(..)`, so
    /// invalid header bytes panic at that point.
    pub fn with_boundary(boundary: impl Into<String>) -> Self {
        let boundary: String = boundary.into();
        MultipartEncoder { boundary }
    }

    /// Returns the boundary string, without the `--` delimiter prefix.
    pub fn boundary(&self) -> &str {
        self.boundary.as_str()
    }
}
/// Process-wide sequence number, so two boundaries generated in one process always differ.
static BOUNDARY_COUNTER: AtomicU64 = AtomicU64::new(0);

/// Generates a fresh multipart boundary of the form `----toac-boundary-<32 lowercase hex>`.
///
/// The first 16 hex digits come from hashing the sequence number with a newly created,
/// randomly keyed [`RandomState`](std::collections::hash_map::RandomState). The last 16
/// are the sequence number itself.
///
/// The random half matters because RFC 2046 requires the boundary never to appear inside
/// a part body. A purely sequential boundary (always `...0000000000000000` for the first
/// encoder) is predictable, so anyone who controls a field value could embed it and forge
/// extra parts. The counter half guarantees uniqueness within the process.
///
/// The result is 50 ASCII characters (RFC 2046 allows up to 70) made only of `-`, digits
/// and `a`-`f`, so it is a valid bare header token. The hash is not a CSPRNG, but it is
/// not guessable from outside the process.
fn next_boundary() -> String {
    let n = BOUNDARY_COUNTER.fetch_add(1, Ordering::Relaxed);
    // Each `RandomState::new()` yields differently keyed hashers, giving an unpredictable
    // 64-bit value using only std. Fully qualified calls avoid needing extra `use` items.
    let state = std::collections::hash_map::RandomState::new();
    let mut hasher = std::hash::BuildHasher::build_hasher(&state);
    std::hash::Hasher::write_u64(&mut hasher, n);
    let random = std::hash::Hasher::finish(&hasher);
    format!("----toac-boundary-{random:016x}{n:016x}")
}
impl BodyContentType for MultipartEncoder {
    /// Returns `multipart/form-data; boundary=<boundary>`.
    ///
    /// The boundary is written as a bare token when possible, and as an RFC 2045
    /// quoted-string otherwise. RFC 2046 permits boundary characters such as space, `(`,
    /// `)`, `,`, `/`, `:`, `=` and `?`. These are `tspecials`, and left unquoted they would
    /// end or corrupt the parameter value when the header is parsed. Boundaries generated
    /// by [`MultipartEncoder::new`] are plain tokens and are emitted unchanged.
    ///
    /// # Panics
    ///
    /// Panics if the boundary contains bytes that are not valid in an HTTP header value
    /// (such as control characters). Only a boundary supplied through
    /// [`MultipartEncoder::with_boundary`] can contain them.
    fn content_type(&self) -> HeaderValue {
        let raw = format!(
            "multipart/form-data; boundary={}",
            boundary_param_value(&self.boundary)
        );
        HeaderValue::try_from(raw)
            .expect("multipart boundary must consist of bytes valid in an HTTP header value")
    }
}

/// Formats `boundary` as a `Content-Type` parameter value.
///
/// A non-empty RFC 2045 token is returned unchanged. Anything else is wrapped in double
/// quotes, with `"` and `\` backslash-escaped as required inside a quoted-string.
fn boundary_param_value(boundary: &str) -> std::borrow::Cow<'_, str> {
    let is_token = !boundary.is_empty()
        && boundary.bytes().all(|b| {
            b.is_ascii_graphic()
                && !matches!(
                    b,
                    b'(' | b')'
                        | b'<'
                        | b'>'
                        | b'@'
                        | b','
                        | b';'
                        | b':'
                        | b'\\'
                        | b'"'
                        | b'/'
                        | b'['
                        | b']'
                        | b'?'
                        | b'='
                )
        });
    if is_token {
        return std::borrow::Cow::Borrowed(boundary);
    }
    let mut quoted = String::with_capacity(boundary.len() + 2);
    quoted.push('"');
    for c in boundary.chars() {
        if matches!(c, '"' | '\\') {
            quoted.push('\\');
        }
        quoted.push(c);
    }
    quoted.push('"');
    std::borrow::Cow::Owned(quoted)
}
impl BodyEncoder<&MultipartForm> for MultipartEncoder {
    /// Rendering an in-memory form never fails.
    type Error = std::convert::Infallible;

    /// Renders every part of `data`, using this encoder's boundary, into one
    /// fully buffered body.
    fn encode(&self, data: &MultipartForm) -> Result<Body, Self::Error> {
        let rendered = render_parts(self.boundary(), data.parts());
        let full = Full::new(rendered);
        Ok(Body::new(full))
    }
}
/// Serializes `parts` into a `multipart/form-data` body delimited by `boundary`.
///
/// Each part is written as:
///
/// ```text
/// --<boundary>\r\n
/// Content-Disposition: form-data; name="<name>"[; filename="<filename>"]\r\n
/// [Content-Type: <content type>\r\n]
/// \r\n
/// <body>\r\n
/// ```
///
/// After the last part comes the closing delimiter `--<boundary>--\r\n`. Names and
/// filenames go through `escape_quoted`. Bodies are copied verbatim and are not checked
/// for occurrences of the boundary.
fn render_parts(boundary: &str, parts: &[Part]) -> Bytes {
    // Fixed slack for delimiters/headers plus room for every payload byte.
    let payload_len: usize = parts.iter().map(|part| part.body.len()).sum();
    let mut buf = BytesMut::with_capacity(256 + payload_len);

    for part in parts {
        // Opening delimiter line.
        buf.extend_from_slice(b"--");
        buf.extend_from_slice(boundary.as_bytes());
        buf.extend_from_slice(b"\r\n");

        // Content-Disposition header; the filename parameter only when one is set.
        buf.extend_from_slice(b"Content-Disposition: form-data; name=\"");
        buf.extend_from_slice(escape_quoted(&part.name).as_bytes());
        buf.extend_from_slice(b"\"");
        if let Some(filename) = &part.filename {
            buf.extend_from_slice(b"; filename=\"");
            buf.extend_from_slice(escape_quoted(filename).as_bytes());
            buf.extend_from_slice(b"\"");
        }
        buf.extend_from_slice(b"\r\n");

        // Optional per-part Content-Type header.
        if let Some(content_type) = &part.content_type {
            buf.extend_from_slice(b"Content-Type: ");
            buf.extend_from_slice(content_type.as_bytes());
            buf.extend_from_slice(b"\r\n");
        }

        // Blank line ends the headers; the payload follows and is terminated by CRLF.
        buf.extend_from_slice(b"\r\n");
        buf.extend_from_slice(&part.body);
        buf.extend_from_slice(b"\r\n");
    }

    // Closing delimiter.
    buf.extend_from_slice(b"--");
    buf.extend_from_slice(boundary.as_bytes());
    buf.extend_from_slice(b"--\r\n");
    buf.freeze()
}
/// Percent-encodes the characters that would break out of a quoted `Content-Disposition`
/// parameter: `"` becomes `%22`, CR becomes `%0D`, and LF becomes `%0A`.
///
/// Every other character, including non-ASCII characters, `\` and `%`, is copied unchanged.
fn escape_quoted(raw: &str) -> String {
    raw.chars()
        .fold(String::with_capacity(raw.len()), |mut escaped, ch| {
            match ch {
                '"' => escaped.push_str("%22"),
                '\r' => escaped.push_str("%0D"),
                '\n' => escaped.push_str("%0A"),
                other => escaped.push(other),
            }
            escaped
        })
}