use std::{
borrow::Cow,
fmt::Display,
io::Cursor,
pin::Pin,
task::{Context, Poll},
};
use rocket::{
futures::Stream,
http::{ContentType, Status},
response::{Responder, Result},
tokio::io::{AsyncRead, ReadBuf},
Request, Response,
};
/// One part of a multipart response body.
///
/// Each section is written as a boundary line, its optional headers, a blank line and
/// then the bytes produced by `content`.
pub struct MultipartSection<'r> {
/// Value for this part's `Content-Type` header; the header is omitted when `None`.
pub content_type: Option<ContentType>,
/// Value for this part's `Content-Encoding` header; the header is omitted when `None`.
pub content_encoding: Option<Cow<'r, str>>,
/// Reader supplying the part's body. It is read until a read adds no further bytes.
pub content: Pin<Box<dyn AsyncRead + Send + 'r>>,
}
/// A responder that sends a stream of sections as a `multipart/*` body.
///
/// The subtype defaults to `mixed` and can be changed with
/// [`with_subtype`](MultipartStream::with_subtype).
pub struct MultipartStream<T> {
    /// Boundary token placed between sections and advertised in the `Content-Type` header.
    boundary: String,
    /// Source of the sections to emit.
    stream: T,
    /// Subtype of the `multipart` media type, e.g. `mixed`.
    sub_type: &'static str,
}

impl<T> MultipartStream<T> {
    /// Creates a `multipart/mixed` stream that separates sections with `boundary`.
    pub fn new(boundary: impl Into<String>, stream: T) -> Self {
        let boundary = boundary.into();
        Self {
            boundary,
            stream,
            sub_type: "mixed",
        }
    }

    /// Creates a `multipart/mixed` stream with a random 15-character alphanumeric boundary.
    #[cfg(feature = "rand")]
    pub fn new_random(stream: T) -> Self {
        use rand::{distributions::Alphanumeric, Rng};

        let boundary: String = rand::thread_rng()
            .sample_iter(Alphanumeric)
            .take(15)
            .map(char::from)
            .collect();
        Self::new(boundary, stream)
    }

    /// Replaces the multipart subtype (the part after `multipart/`) and returns the stream.
    pub fn with_subtype(self, sub_type: &'static str) -> Self {
        Self { sub_type, ..self }
    }
}
impl<'r, 'o: 'r, T: Stream<Item = MultipartSection<'o>> + Send + 'o> Responder<'r, 'o>
for MultipartStream<T>
{
fn respond_to(self, _r: &'r Request<'_>) -> Result<'o> {
Response::build()
.status(Status::Ok)
.header(
ContentType::new("multipart", self.sub_type)
.with_params(("boundary", self.boundary.clone())),
)
.streamed_body(MultipartStreamInner(
self.boundary,
self.stream,
StreamState::Waiting,
))
.ok()
}
}
/// Body reader for a multipart response.
///
/// Field `.0` is the boundary string, `.1` the stream of sections and `.2` the current
/// state of the output state machine.
struct MultipartStreamInner<'r, T>(String, T, StreamState<'r>);
impl<'r, T> MultipartStreamInner<'r, T> {
/// Splits a pinned `Self` into its parts: the boundary as `&str`, the section stream
/// still pinned, and the state as a plain mutable reference.
fn inner(self: Pin<&mut Self>) -> (&str, Pin<&mut T>, &mut StreamState<'r>) {
// SAFETY: `this` is only used to borrow the three fields below; nothing is moved out of
// the pinned value.
let this = unsafe { self.get_unchecked_mut() };
(
&this.0,
// SAFETY: field 1 lives inside a value that is already pinned, and in the visible code it
// is only reached through this projection, so it is never moved after being pinned.
// NOTE(review): this relies on no `Drop` impl or other accessor moving field 1 — confirm
// if more impls exist outside this view.
unsafe { Pin::new_unchecked(&mut this.1) },
&mut this.2,
)
}
}
/// Position of the body writer within the multipart output.
enum StreamState<'r> {
/// Between sections: the next section has to be pulled from the stream.
Waiting,
/// Writing a section's boundary line and headers from the cursor; the reader is the
/// section body that follows once the cursor is exhausted.
Header(Cursor<Vec<u8>>, Pin<Box<dyn AsyncRead + Send + 'r>>),
/// Copying a section's body from its reader.
Raw(Pin<Box<dyn AsyncRead + Send + 'r>>),
/// Writing the closing boundary after the stream has ended.
Footer(Cursor<Vec<u8>>),
}
/// An optional header: field `.0` is the header name, field `.1` its value.
///
/// Displays as `Name: value\r\n` when a value is present and as nothing otherwise.
struct HV<T>(&'static str, Option<T>);

impl<T: Display> Display for HV<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let HV(name, value) = self;
        if let Some(value) = value {
            write!(f, "{}: {}\r\n", name, value)?;
        }
        Ok(())
    }
}
/// Serialises the section stream as a multipart body.
///
/// Each section is written as `\r\n--<boundary>\r\n`, its optional `Content-Type` and
/// `Content-Encoding` header lines, a blank line, and then the section's bytes. Once the
/// stream ends, the closing delimiter `\r\n--<boundary>--\r\n` is written, after which
/// every read reports end of file.
impl<'r, T: Stream<Item = MultipartSection<'r>> + Send + 'r> AsyncRead
    for MultipartStreamInner<'r, T>
{
    /// Fills `buf` from whichever piece of the output is current, advancing the state
    /// machine when a piece is exhausted.
    ///
    /// Returns `Poll::Pending` while the section stream or a section body is not ready,
    /// and passes through any I/O error reported by a section body.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        // "No bytes were added" is used below as the end-of-piece signal. With a full
        // buffer nothing can be added, which would be mistaken for exhaustion and would
        // skip headers and bodies, so answer such reads without touching the state.
        if buf.remaining() == 0 {
            return Poll::Ready(Ok(()));
        }

        let (boundary, mut stream, state) = self.inner();
        loop {
            match state {
                StreamState::Waiting => match stream.as_mut().poll_next(cx) {
                    Poll::Ready(Some(section)) => {
                        let header = format!(
                            "\r\n--{boundary}\r\n{}{}\r\n",
                            HV("Content-Type", section.content_type),
                            HV("Content-Encoding", section.content_encoding),
                        );
                        *state = StreamState::Header(
                            Cursor::new(header.into_bytes()),
                            section.content,
                        );
                    }
                    Poll::Ready(None) => {
                        *state = StreamState::Footer(Cursor::new(
                            format!("\r\n--{boundary}--\r\n").into_bytes(),
                        ));
                    }
                    Poll::Pending => return Poll::Pending,
                },
                StreamState::Header(cursor, _) => {
                    let before = buf.filled().len();
                    match Pin::new(cursor).poll_read(cx, buf) {
                        Poll::Ready(Ok(())) => (),
                        other => return other,
                    }
                    if buf.filled().len() != before {
                        return Poll::Ready(Ok(()));
                    }
                    // Header fully written: take ownership of the body reader and move on.
                    match std::mem::replace(state, StreamState::Waiting) {
                        StreamState::Header(_, body) => *state = StreamState::Raw(body),
                        _ => unreachable!(),
                    }
                }
                StreamState::Raw(body) => {
                    let before = buf.filled().len();
                    match body.as_mut().poll_read(cx, buf) {
                        Poll::Ready(Ok(())) => (),
                        other => return other,
                    }
                    if buf.filled().len() != before {
                        return Poll::Ready(Ok(()));
                    }
                    // Section body reached end of file: fetch the next section.
                    *state = StreamState::Waiting;
                }
                // The footer is the last piece; once its cursor is drained, a read adds
                // nothing, which the caller sees as end of file.
                StreamState::Footer(cursor) => return Pin::new(cursor).poll_read(cx, buf),
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use async_stream::stream;
    use rocket::{get, local::blocking::Client, routes, Build, Rocket};

    use super::*;

    /// Builds a section with the given content type, no content encoding and a static body.
    fn part(content_type: ContentType, body: &'static [u8]) -> MultipartSection<'static> {
        MultipartSection {
            content_type: Some(content_type),
            content_encoding: None,
            content: Box::pin(body),
        }
    }

    /// Route producing two text sections followed by one binary section.
    #[get("/mixed")]
    fn multipart_route() -> MultipartStream<impl Stream<Item = MultipartSection<'static>>> {
        MultipartStream::new(
            "Sep",
            stream! {
                yield part(ContentType::Text, b"How can I help you");
                yield part(ContentType::Text, b"today?");
                yield part(ContentType::Binary, &[0xFFu8, 0xFE, 0xF0]);
            },
        )
    }

    fn rocket() -> Rocket<Build> {
        rocket::build().mount("/", routes![multipart_route])
    }

    /// The full body must consist of each section behind its boundary and headers,
    /// terminated by the closing boundary.
    #[test]
    fn simple_it_works() {
        /// Bytes expected on the wire for one section using the `Sep` boundary.
        fn expected_part(content_type: &str, body: &[u8]) -> Vec<u8> {
            let mut bytes =
                format!("\r\n--Sep\r\nContent-Type: {content_type}\r\n\r\n").into_bytes();
            bytes.extend_from_slice(body);
            bytes
        }

        let client = Client::untracked(rocket()).unwrap();
        let res = client.get("/mixed").dispatch();

        assert_eq!(res.status(), Status::Ok);
        assert_eq!(res.content_type(), Some(ContentType::new("multipart", "mixed")));

        let expected_contents: Vec<u8> = [
            expected_part("text/plain; charset=utf-8", b"How can I help you"),
            expected_part("text/plain; charset=utf-8", b"today?"),
            expected_part("application/octet-stream", &[0xFF, 0xFE, 0xF0]),
            b"\r\n--Sep--\r\n".to_vec(),
        ]
        .concat();
        assert_eq!(res.into_bytes(), Some(expected_contents));
    }
}