multipart_stream/
serializer.rs

1// Copyright (C) 2021 Scott Lamb <slamb@slamb.org>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Serializes [crate::Part]s into a byte stream.
5
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9use bytes::{BufMut, Bytes, BytesMut};
10use futures::Stream;
11use http::HeaderMap;
12use pin_project::pin_project;
13
14use crate::Part;
15
16/// Serializes [Part]s into [Bytes].
17/// Sets the `Content-Length` header on each part rather than expecting the caller to do so.
18pub fn serialize<S, E>(parts: S, boundary: &str) -> impl Stream<Item = Result<Bytes, E>>
19where
20    S: Stream<Item = Result<Part, E>>,
21{
22    let mut b = BytesMut::with_capacity(boundary.len() + 4);
23    b.put(&b"--"[..]);
24    b.put(boundary.as_bytes());
25    b.put(&b"\r\n"[..]);
26
27    Serializer {
28        parts,
29        boundary: b.freeze(),
30        state: State::Waiting,
31    }
32}
33
34/// Serializes HTTP headers into the usual form, including a final empty line.
35fn serialize_headers(headers: HeaderMap) -> Bytes {
36    // This is the same reservation hyper uses. It calls it "totally scientific".
37    let mut b = BytesMut::with_capacity(30 + 30 * headers.len());
38    for (name, value) in &headers {
39        b.put(name.as_str().as_bytes());
40        b.put(&b": "[..]);
41        b.put(value.as_bytes());
42        b.put(&b"\r\n"[..]);
43    }
44    b.put(&b"\r\n"[..]);
45    b.freeze()
46}
47
/// State of the [Serializer]: which chunk of the current [Part] it must emit next.
///
/// `poll_next` emits each part as three chunks (boundary, headers, body) and moves through
/// these variants in that order.
enum State {
    /// Nothing is buffered. The next poll asks the inner stream for a fresh [Part] and,
    /// once one arrives, emits the boundary.
    Waiting,

    /// The boundary for this [Part] has been emitted; its headers go out next.
    SendHeaders(Part),

    /// The headers have been emitted; these body bytes go out next.
    SendBody(Bytes),
}
59
/// The [Stream] returned by [serialize]: wraps a stream of [Part]s and yields the
/// boundary, headers, and body of each one as separate [Bytes] chunks.
#[pin_project]
struct Serializer<S, E>
where
    S: Stream<Item = Result<Part, E>>,
{
    /// Inner stream supplying the parts; pin-projected so it can be polled in place.
    #[pin]
    parts: S,
    /// The full delimiter line (`--`, the boundary string, CRLF) emitted before each part.
    boundary: Bytes,
    /// Which chunk `poll_next` should emit next.
    state: State,
}
70
71impl<S, E> Stream for Serializer<S, E>
72where
73    S: Stream<Item = Result<Part, E>>,
74{
75    type Item = Result<Bytes, E>;
76
77    fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
78        let mut this = self.project();
79        match std::mem::replace(this.state, State::Waiting) {
80            State::Waiting => match this.parts.as_mut().poll_next(ctx) {
81                Poll::Pending => return Poll::Pending,
82                Poll::Ready(None) => return Poll::Ready(None),
83                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
84                Poll::Ready(Some(Ok(mut p))) => {
85                    p.headers.insert(
86                        http::header::CONTENT_LENGTH,
87                        http::HeaderValue::from(p.body.len()),
88                    );
89                    *this.state = State::SendHeaders(p);
90                    return Poll::Ready(Some(Ok(this.boundary.clone())));
91                }
92            },
93            State::SendHeaders(part) => {
94                *this.state = State::SendBody(part.body);
95                let headers = serialize_headers(part.headers);
96                return Poll::Ready(Some(Ok(headers)));
97            }
98            State::SendBody(body) => {
99                return Poll::Ready(Some(Ok(body)));
100            }
101        }
102    }
103}
104
#[cfg(test)]
mod tests {
    use bytes::{BufMut, Bytes, BytesMut};
    use futures::{Stream, StreamExt};
    use http::HeaderMap;

    use super::{serialize, Part};

    /// Drains `s`, concatenating every chunk; returns the first error encountered, if any.
    async fn collect<S, E>(mut s: S) -> Result<Bytes, E>
    where
        S: Stream<Item = Result<Bytes, E>> + Unpin,
    {
        let mut joined = BytesMut::new();
        loop {
            match s.next().await {
                Some(Ok(chunk)) => joined.put(chunk),
                Some(Err(e)) => return Err(e),
                None => return Ok(joined.freeze()),
            }
        }
    }

    /// Two header-less parts serialize to boundary, generated `content-length`, and body.
    #[tokio::test]
    async fn success() {
        let part = |body: &'static str| {
            Ok::<_, std::convert::Infallible>(Part {
                headers: HeaderMap::new(),
                body: body.into(),
            })
        };
        let input = futures::stream::iter(vec![part("foo"), part("bar")]);
        let output = collect(serialize(input, "b")).await.unwrap();
        let text = std::str::from_utf8(&output[..]).unwrap();
        assert_eq!(
            text,
            "--b\r\ncontent-length: 3\r\n\r\nfoo\
             --b\r\ncontent-length: 3\r\n\r\nbar"
        );
    }

    /// An error from the inner stream is surfaced unchanged.
    #[tokio::test]
    async fn err() {
        let failure: Box<dyn std::error::Error + Send + Sync> = "uh-oh".to_owned().into();
        let input = futures::stream::iter(vec![Err(failure)]);
        let surfaced = collect(serialize(input, "b")).await.unwrap_err();
        assert_eq!(surfaced.to_string(), "uh-oh");
    }
}