1use futures_channel::mpsc;
10use futures_core::Stream;
11use std::io::{self, Write};
12use std::mem;
13
/// A `std::io::Write` adapter that turns written bytes into chunks sent over an
/// unbounded channel.
///
/// Bytes are accumulated in `buf`; whenever the buffer fills up (or the writer is
/// flushed or dropped) its contents are converted into a `D` and sent as an `Ok`
/// item. Errors of type `E` can be injected into the channel with `abort`.
///
/// The bounds live on the struct itself because the `Drop` impl flushes the
/// buffer and therefore needs `D: From<Vec<u8>>`.
pub(crate) struct BodyWriter<D, E>
where
    D: From<Vec<u8>> + Send + 'static,
    E: Send + 'static,
{
    /// Sending half of the channel; each item is either a chunk or an error.
    sender: mpsc::UnboundedSender<Result<D, E>>,

    /// Bytes written but not yet sent. Its capacity doubles as the chunk size:
    /// `write` never grows it past `capacity()`.
    buf: Vec<u8>,
}
36
37impl<D, E> BodyWriter<D, E>
38where
39 D: From<Vec<u8>> + Send + 'static,
40 E: Send + 'static,
41{
42 pub(crate) fn with_chunk_size(
43 cap: usize,
44 ) -> (Self, Box<dyn Stream<Item = Result<D, E>> + Send>) {
45 assert!(cap > 0);
46 let (snd, rcv) = mpsc::unbounded();
47 let body = Box::new(rcv);
48 (
49 BodyWriter {
50 sender: snd,
51 buf: Vec::with_capacity(cap),
52 },
53 body,
54 )
55 }
56
57 pub(crate) fn abort(&mut self, error: E) {
59 let _ = self.sender.unbounded_send(Err(error));
61 }
62
63 #[cfg(test)]
65 fn truncate(&mut self) {
66 self.buf.clear()
67 }
68}
69
70impl<D, E> Write for BodyWriter<D, E>
71where
72 D: From<Vec<u8>> + Send + 'static,
73 E: Send + 'static,
74{
75 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
76 let remaining = self.buf.capacity() - self.buf.len();
77 let full = remaining <= buf.len();
78 let bytes = if full { remaining } else { buf.len() };
79 self.buf.extend_from_slice(&buf[0..bytes]);
80 if full {
81 self.flush()?;
82 }
83 Ok(bytes)
84 }
85
86 fn flush(&mut self) -> io::Result<()> {
87 if !self.buf.is_empty() {
88 let cap = self.buf.capacity();
89 let full_buf = mem::replace(&mut self.buf, Vec::with_capacity(cap));
90 if self.sender.unbounded_send(Ok(full_buf.into())).is_err() {
91 return Err(io::Error::new(
96 io::ErrorKind::BrokenPipe,
97 "receiver was dropped",
98 ));
99 }
100 }
101 Ok(())
102 }
103}
104
105impl<D, E> Drop for BodyWriter<D, E>
106where
107 D: From<Vec<u8>> + Send + 'static,
108 E: Send + 'static,
109{
110 fn drop(&mut self) {
111 let _ = self.flush();
112 }
113}
114
#[cfg(test)]
mod tests {
    use super::BodyWriter;
    use futures_core::Stream;
    use futures_util::stream::{StreamExt, TryStreamExt};
    use std::io::Write;
    use std::pin::Pin;

    type BoxedError = Box<dyn std::error::Error + 'static + Send + Sync>;
    type BodyStream = Box<dyn Stream<Item = Result<Vec<u8>, BoxedError>> + Send>;

    /// Builds a writer with a 4-byte chunk size and its matching stream.
    fn new_writer() -> (BodyWriter<Vec<u8>, BoxedError>, BodyStream) {
        BodyWriter::with_chunk_size(4)
    }

    /// Drains `s`, concatenating every chunk; panics if the stream yields an error.
    async fn to_vec(s: BodyStream) -> Vec<u8> {
        let pinned = Pin::from(s);
        pinned.try_concat().await.unwrap()
    }

    /// Buffered bytes that are truncated before drop never reach the stream.
    #[tokio::test]
    async fn small_no_flush() {
        let (mut writer, stream) = new_writer();
        assert_eq!(writer.write(b"1").unwrap(), 1);
        writer.truncate();
        drop(writer);
        assert_eq!(b"", &to_vec(stream).await[..]);
    }

    /// An explicit flush sends a partially filled chunk.
    #[tokio::test]
    async fn small_flush() {
        let (mut writer, stream) = new_writer();
        assert_eq!(writer.write(b"1").unwrap(), 1);
        writer.flush().unwrap();
        drop(writer);
        assert_eq!(b"1", &to_vec(stream).await[..]);
    }

    /// A write of exactly one chunk is accepted in full.
    #[tokio::test]
    async fn chunk_write() {
        let (mut writer, stream) = new_writer();
        assert_eq!(writer.write(b"1234").unwrap(), 4);
        writer.flush().unwrap();
        drop(writer);
        assert_eq!(b"1234", &to_vec(stream).await[..]);
    }

    /// Two chunk-sized writes arrive in order.
    #[tokio::test]
    async fn chunk_double_write() {
        let (mut writer, stream) = new_writer();
        assert_eq!(writer.write(b"1234").unwrap(), 4);
        assert_eq!(writer.write(b"5678").unwrap(), 4);
        writer.flush().unwrap();
        drop(writer);
        assert_eq!(b"12345678", &to_vec(stream).await[..]);
    }

    /// A write larger than the chunk size is a short write of one chunk.
    #[tokio::test]
    async fn large_write() {
        let (mut writer, stream) = new_writer();
        assert_eq!(writer.write(b"123456").unwrap(), 4);
        drop(writer);
        assert_eq!(b"1234", &to_vec(stream).await[..]);
    }

    /// A write that overflows a partly filled chunk only takes the free space.
    #[tokio::test]
    async fn small_large_write() {
        let (mut writer, stream) = new_writer();
        assert_eq!(writer.write(b"1").unwrap(), 1);
        assert_eq!(writer.write(b"2345").unwrap(), 3);
        drop(writer);
        assert_eq!(b"1234", &to_vec(stream).await[..]);
    }

    /// After an abort the stream holds the already-sent chunk followed by the error.
    #[tokio::test]
    async fn abort() {
        let (mut writer, stream) = new_writer();
        writer.write_all(b"12345").unwrap();
        // Throw away the buffered fifth byte so drop has nothing left to flush.
        writer.truncate();
        let failure: BoxedError = Box::new(std::io::Error::new(
            std::io::ErrorKind::Other,
            "asdf",
        ));
        writer.abort(failure);
        drop(writer);

        let items: Vec<Result<Vec<u8>, BoxedError>> = Pin::from(stream).collect().await;
        assert_eq!(items.len(), 2);
        assert_eq!(b"1234", &items[0].as_ref().unwrap()[..]);
        assert!(items[1].is_err());
    }
}