multipart_stream/
serializer.rs1use std::pin::Pin;
7use std::task::{Context, Poll};
8
9use bytes::{BufMut, Bytes, BytesMut};
10use futures::Stream;
11use http::HeaderMap;
12use pin_project::pin_project;
13
14use crate::Part;
15
16pub fn serialize<S, E>(parts: S, boundary: &str) -> impl Stream<Item = Result<Bytes, E>>
19where
20 S: Stream<Item = Result<Part, E>>,
21{
22 let mut b = BytesMut::with_capacity(boundary.len() + 4);
23 b.put(&b"--"[..]);
24 b.put(boundary.as_bytes());
25 b.put(&b"\r\n"[..]);
26
27 Serializer {
28 parts,
29 boundary: b.freeze(),
30 state: State::Waiting,
31 }
32}
33
34fn serialize_headers(headers: HeaderMap) -> Bytes {
36 let mut b = BytesMut::with_capacity(30 + 30 * headers.len());
38 for (name, value) in &headers {
39 b.put(name.as_str().as_bytes());
40 b.put(&b": "[..]);
41 b.put(value.as_bytes());
42 b.put(&b"\r\n"[..]);
43 }
44 b.put(&b"\r\n"[..]);
45 b.freeze()
46}
47
/// Progress of the [`Serializer`] through the part currently being sent.
///
/// Each part moves `Waiting` -> `SendHeaders` -> `SendBody` -> `Waiting`,
/// with exactly one chunk yielded per transition.
enum State {
    /// No part in progress; the next poll pulls a part from the inner stream
    /// (and yields the delimiter line if one arrives).
    Waiting,

    /// The delimiter line has been yielded; this part's headers (already
    /// carrying `Content-Length`) are next.
    SendHeaders(Part),

    /// The headers have been yielded; only this body remains.
    SendBody(Bytes),
}
59
/// Stream returned by [`serialize`].
///
/// Pulls [`Part`]s from `parts` and yields their serialized form one chunk at
/// a time, recording how far through the current part it is in `state`.
/// `E` is tied to `S` only through the `Stream::Item` bound.
#[pin_project]
struct Serializer<S, E>
where
    S: Stream<Item = Result<Part, E>>,
{
    /// Source of parts; structurally pinned because it is polled through
    /// `Pin<&mut Self>`.
    #[pin]
    parts: S,
    /// Pre-rendered delimiter line (`--{boundary}\r\n`), yielded before each part.
    boundary: Bytes,
    /// Which chunk of the current part is due next.
    state: State,
}
70
71impl<S, E> Stream for Serializer<S, E>
72where
73 S: Stream<Item = Result<Part, E>>,
74{
75 type Item = Result<Bytes, E>;
76
77 fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
78 let mut this = self.project();
79 match std::mem::replace(this.state, State::Waiting) {
80 State::Waiting => match this.parts.as_mut().poll_next(ctx) {
81 Poll::Pending => return Poll::Pending,
82 Poll::Ready(None) => return Poll::Ready(None),
83 Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
84 Poll::Ready(Some(Ok(mut p))) => {
85 p.headers.insert(
86 http::header::CONTENT_LENGTH,
87 http::HeaderValue::from(p.body.len()),
88 );
89 *this.state = State::SendHeaders(p);
90 return Poll::Ready(Some(Ok(this.boundary.clone())));
91 }
92 },
93 State::SendHeaders(part) => {
94 *this.state = State::SendBody(part.body);
95 let headers = serialize_headers(part.headers);
96 return Poll::Ready(Some(Ok(headers)));
97 }
98 State::SendBody(body) => {
99 return Poll::Ready(Some(Ok(body)));
100 }
101 }
102 }
103}
104
#[cfg(test)]
mod tests {
    use bytes::{BufMut, Bytes, BytesMut};
    use futures::{Stream, StreamExt};
    use http::HeaderMap;

    use super::{serialize, Part};

    /// Concatenates every chunk of `s`, returning the first error encountered.
    async fn collect<S, E>(mut s: S) -> Result<Bytes, E>
    where
        S: Stream<Item = Result<Bytes, E>> + Unpin,
    {
        let mut accum = BytesMut::new();
        while let Some(b) = s.next().await {
            accum.put(b?);
        }
        Ok(accum.freeze())
    }

    #[tokio::test]
    async fn success() {
        let input = futures::stream::iter(vec![
            Ok::<_, std::convert::Infallible>(Part {
                headers: HeaderMap::new(),
                body: "foo".into(),
            }),
            Ok::<_, std::convert::Infallible>(Part {
                headers: HeaderMap::new(),
                body: "bar".into(),
            }),
        ]);
        let collected = collect(serialize(input, "b")).await.unwrap();
        let collected = std::str::from_utf8(&collected[..]).unwrap();
        assert_eq!(
            collected,
            "--b\r\ncontent-length: 3\r\n\r\nfoo\
             --b\r\ncontent-length: 3\r\n\r\nbar"
        );
    }

    /// An input stream with no parts produces no output at all.
    #[tokio::test]
    async fn empty() {
        let input =
            futures::stream::iter(Vec::<Result<Part, std::convert::Infallible>>::new());
        let collected = collect(serialize(input, "b")).await.unwrap();
        assert!(collected.is_empty());
    }

    /// Caller-supplied headers are written out, and a stale `Content-Length`
    /// is replaced with the body's real length.
    #[tokio::test]
    async fn headers() {
        let mut headers = HeaderMap::new();
        headers.insert(
            http::header::CONTENT_TYPE,
            http::HeaderValue::from_static("image/jpeg"),
        );
        headers.insert(
            http::header::CONTENT_LENGTH,
            http::HeaderValue::from_static("99"),
        );
        let input = futures::stream::iter(vec![Ok::<_, std::convert::Infallible>(Part {
            headers,
            body: "hello".into(),
        })]);
        let collected = collect(serialize(input, "b")).await.unwrap();
        let collected = std::str::from_utf8(&collected[..]).unwrap();

        // HeaderMap iteration order is unspecified, so check lines individually.
        assert!(collected.starts_with("--b\r\n"));
        assert!(collected.ends_with("\r\n\r\nhello"));
        assert!(collected.contains("\r\ncontent-type: image/jpeg\r\n"));
        assert!(collected.contains("\r\ncontent-length: 5\r\n"));
        assert!(!collected.contains("99"));
    }

    #[tokio::test]
    async fn err() {
        let e: Box<dyn std::error::Error + Send + Sync> = "uh-oh".to_owned().into();
        let input = futures::stream::iter(vec![Err(e)]);
        assert_eq!(
            collect(serialize(input, "b"))
                .await
                .unwrap_err()
                .to_string(),
            "uh-oh"
        );
    }
}