bytes_stream/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::{
4    fmt,
5    marker::PhantomData,
6    mem,
7    pin::Pin,
8    task::{Context, Poll},
9};
10
11use bytes::{Bytes, BytesMut};
12use futures_core::{FusedStream, Stream};
13use pin_project_lite::pin_project;
14
pin_project! {
    /// Stream adapter that regroups the bytes of an underlying stream into
    /// chunks of `capacity` bytes.
    ///
    /// Created by [`BytesStream::bytes_chunks`] and
    /// [`BytesStream::try_bytes_chunks`]. The `P` parameter is only a marker
    /// (stored as `PhantomData`) that selects which `Stream` implementation
    /// applies, and therefore which item type is yielded.
    #[derive(Debug)]
    pub struct BytesChunks<St: Stream, P> {
        // Underlying stream; structurally pinned so it can be polled through `Pin`.
        #[pin]
        stream: St,
        // Bytes received from `stream` that have not been emitted as a chunk yet.
        buffer: BytesMut,
        // Target chunk size in bytes.
        capacity: usize,
        marker: PhantomData<P>,
    }
}
25
26type TryBytesChunksResult<T, E> = Result<Bytes, TryBytesChunksError<T, E>>;
27type TryBytesChunks<St, T, E> = BytesChunks<St, TryBytesChunksResult<T, E>>;
28
29#[derive(PartialEq, Eq)]
30pub struct TryBytesChunksError<T, E>(pub T, pub E);
31
32impl<St: Stream, B> BytesChunks<St, B> {
33    pub fn with_capacity(capacity: usize, stream: St) -> Self {
34        Self {
35            stream,
36            buffer: BytesMut::with_capacity(capacity),
37            capacity,
38            marker: PhantomData,
39        }
40    }
41
42    pub fn buffer(&self) -> &[u8] {
43        self.buffer.as_ref()
44    }
45}
46
47impl<St: Stream> BytesChunks<St, Bytes> {
48    fn take(self: Pin<&mut Self>) -> Bytes {
49        let cap = self.capacity;
50        self.project().buffer.split_to(cap).freeze()
51    }
52}
53
54impl<St: Stream> BytesChunks<St, Vec<u8>> {
55    fn take(self: Pin<&mut Self>) -> Vec<u8> {
56        let cap = self.capacity;
57        Vec::from(&self.project().buffer.split_to(cap).freeze()[..])
58    }
59}
60
61impl<St: Stream, E> BytesChunks<St, TryBytesChunksResult<Bytes, E>> {
62    fn take(self: Pin<&mut Self>) -> Bytes {
63        let cap = self.capacity.clamp(0, self.buffer.len());
64        self.project().buffer.split_to(cap).freeze()
65    }
66}
67
68impl<St: Stream, E> BytesChunks<St, TryBytesChunksResult<Vec<u8>, E>> {
69    fn take(self: Pin<&mut Self>) -> Vec<u8> {
70        let cap = self.capacity.clamp(0, self.buffer.len());
71        Vec::from(&self.project().buffer.split_to(cap).freeze()[..])
72    }
73}
74
75impl<T, E: fmt::Debug> fmt::Debug for TryBytesChunksError<T, E> {
76    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
77        self.1.fmt(f)
78    }
79}
80
81impl<T, E: fmt::Display> fmt::Display for TryBytesChunksError<T, E> {
82    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83        self.1.fmt(f)
84    }
85}
86
87impl<T, E: fmt::Debug + fmt::Display> std::error::Error for TryBytesChunksError<T, E> {}
88
89impl<T, E> TryBytesChunksError<T, E> {
90    /// Returns the buffered data stored before the error.
91    /// ```
92    /// # use std::{convert::Infallible, vec::IntoIter};
93    /// # use bytes::Bytes;
94    /// # use futures::{
95    /// #     executor::block_on,
96    /// #     stream::{self, StreamExt},
97    /// # };
98    /// # use bytes_stream::BytesStream;
99    /// # fn main() {
100    /// # block_on(async {
101    /// let stream: stream::Iter<IntoIter<Result<Bytes, &'static str>>> =
102    ///     stream::iter(vec![
103    ///         Ok(Bytes::from_static(&[1, 2, 3])),
104    ///         Ok(Bytes::from_static(&[4, 5, 6])),
105    ///         Err("failure"),
106    ///     ]);
107    ///
108    /// let mut stream = stream.try_bytes_chunks(4);
109    ///
110    /// assert_eq!(stream.next().await, Some(Ok(Bytes::from_static(&[1, 2, 3, 4]))));
111    ///
112    /// let err = stream.next().await.unwrap().err().unwrap();
113    /// assert_eq!(err.into_inner(), Bytes::from_static(&[5, 6]));
114    /// # });
115    /// # }
116    /// ```
117    pub fn into_inner(self) -> T {
118        self.0
119    }
120}
121
122pub trait BytesStream: Stream {
123    /// Group bytes in chunks of `capacity`.
124    /// ```
125    /// # use bytes::Bytes;
126    /// # use futures::{
127    /// #     executor::block_on,
128    /// #     stream::{self, StreamExt},
129    /// # };
130    /// # use bytes_stream::BytesStream;
131    /// # fn main() {
132    /// # block_on(async {
133    /// let stream = futures::stream::iter(vec![
134    ///     Bytes::from_static(&[1, 2, 3]),
135    ///     Bytes::from_static(&[4, 5, 6]),
136    ///     Bytes::from_static(&[7, 8, 9]),
137    /// ]);
138    ///
139    /// let mut stream = stream.bytes_chunks(4);
140    ///
141    /// assert_eq!(stream.next().await, Some(Bytes::from_static(&[1, 2, 3, 4])));
142    /// assert_eq!(stream.next().await, Some(Bytes::from_static(&[5, 6, 7, 8])));
143    /// assert_eq!(stream.next().await, Some(Bytes::from_static(&[9])));
144    /// assert_eq!(stream.next().await, None);
145    /// # });
146    /// # }
147    /// ```
148    fn bytes_chunks<T>(self, capacity: usize) -> BytesChunks<Self, T>
149    where
150        Self: Sized,
151    {
152        BytesChunks::with_capacity(capacity, self)
153    }
154
155    /// Group result of bytes in chunks of `capacity`.
156    /// ```
157    /// # use std::convert::Infallible;
158    /// # use bytes::Bytes;
159    /// # use futures::{
160    /// #     executor::block_on,
161    /// #     stream::{self, StreamExt},
162    /// # };
163    /// # use bytes_stream::BytesStream;
164    /// # fn main() {
165    /// # block_on(async {
166    /// let stream = futures::stream::iter(vec![
167    ///     Ok::<_, Infallible>(Bytes::from_static(&[1, 2, 3])),
168    ///     Ok::<_, Infallible>(Bytes::from_static(&[4, 5, 6])),
169    ///     Ok::<_, Infallible>(Bytes::from_static(&[7, 8, 9])),
170    /// ]);
171    ///
172    /// let mut stream = stream.try_bytes_chunks(4);
173    ///
174    /// assert_eq!(stream.next().await, Some(Ok(Bytes::from_static(&[1, 2, 3, 4]))));
175    /// assert_eq!(stream.next().await, Some(Ok(Bytes::from_static(&[5, 6, 7, 8]))));
176    /// assert_eq!(stream.next().await, Some(Ok(Bytes::from_static(&[9]))));
177    /// assert_eq!(stream.next().await, None);
178    /// # });
179    /// # }
180    /// ```
181    fn try_bytes_chunks<T, E>(self, capacity: usize) -> TryBytesChunks<Self, T, E>
182    where
183        Self: Sized,
184    {
185        BytesChunks::with_capacity(capacity, self)
186    }
187}
188
189impl<T> BytesStream for T where T: Stream {}
190
191impl<E, St: Stream<Item = Result<Bytes, E>>> Stream for TryBytesChunks<St, Bytes, E> {
192    type Item = Result<Bytes, TryBytesChunksError<Bytes, E>>;
193
194    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
195        let mut this = self.as_mut().project();
196
197        if this.buffer.len() >= *this.capacity {
198            return Poll::Ready(Some(Ok(self.take())));
199        }
200
201        loop {
202            match this.stream.as_mut().poll_next(cx) {
203                Poll::Pending => return Poll::Pending,
204
205                Poll::Ready(Some(item)) => match item {
206                    Ok(item) => {
207                        this.buffer.extend_from_slice(&item[..]);
208
209                        if this.buffer.len() >= *this.capacity {
210                            return Poll::Ready(Some(Ok(self.take())));
211                        }
212                    }
213                    Err(err) => {
214                        let err = TryBytesChunksError(self.take(), err);
215                        return Poll::Ready(Some(Err(err)));
216                    }
217                },
218
219                Poll::Ready(None) => {
220                    let last = if this.buffer.is_empty() {
221                        None
222                    } else {
223                        Some(Ok(Bytes::from(mem::take(this.buffer))))
224                    };
225
226                    return Poll::Ready(last);
227                }
228            }
229        }
230    }
231
232    fn size_hint(&self) -> (usize, Option<usize>) {
233        let chunk_len = if self.buffer.is_empty() { 0 } else { 1 };
234        let (lower, upper) = self.stream.size_hint();
235        let lower = lower.saturating_add(chunk_len);
236        let upper = upper.and_then(|x| x.checked_add(chunk_len));
237        (lower, upper)
238    }
239}
240
241impl<St: Stream<Item = Bytes>> Stream for BytesChunks<St, Bytes> {
242    type Item = Bytes;
243
244    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
245        let mut this = self.as_mut().project();
246
247        if this.buffer.len() >= *this.capacity {
248            return Poll::Ready(Some(self.take()));
249        }
250
251        loop {
252            match this.stream.as_mut().poll_next(cx) {
253                Poll::Pending => return Poll::Pending,
254
255                Poll::Ready(Some(item)) => {
256                    this.buffer.extend_from_slice(&item[..]);
257
258                    if this.buffer.len() >= *this.capacity {
259                        return Poll::Ready(Some(self.take()));
260                    }
261                }
262
263                Poll::Ready(None) => {
264                    let last = if this.buffer.is_empty() {
265                        None
266                    } else {
267                        Some(Bytes::from(mem::take(this.buffer)))
268                    };
269
270                    return Poll::Ready(last);
271                }
272            }
273        }
274    }
275
276    fn size_hint(&self) -> (usize, Option<usize>) {
277        let chunk_len = if self.buffer.is_empty() { 0 } else { 1 };
278        let (lower, upper) = self.stream.size_hint();
279        let lower = lower.saturating_add(chunk_len);
280        let upper = upper.and_then(|x| x.checked_add(chunk_len));
281        (lower, upper)
282    }
283}
284
285impl<St: Stream<Item = Vec<u8>>> Stream for BytesChunks<St, Vec<u8>> {
286    type Item = Vec<u8>;
287
288    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
289        let mut this = self.as_mut().project();
290
291        if this.buffer.len() >= *this.capacity {
292            return Poll::Ready(Some(self.take()));
293        }
294
295        loop {
296            match this.stream.as_mut().poll_next(cx) {
297                Poll::Pending => return Poll::Pending,
298
299                Poll::Ready(Some(item)) => {
300                    this.buffer.extend_from_slice(&item[..]);
301
302                    if this.buffer.len() >= *this.capacity {
303                        return Poll::Ready(Some(self.take()));
304                    }
305                }
306
307                Poll::Ready(None) => {
308                    let last = if this.buffer.is_empty() {
309                        None
310                    } else {
311                        let buf = mem::take(this.buffer);
312                        Some(Vec::from(&buf[..]))
313                    };
314
315                    return Poll::Ready(last);
316                }
317            }
318        }
319    }
320
321    fn size_hint(&self) -> (usize, Option<usize>) {
322        let chunk_len = if self.buffer.is_empty() { 0 } else { 1 };
323        let (lower, upper) = self.stream.size_hint();
324        let lower = lower.saturating_add(chunk_len);
325        let upper = upper.and_then(|x| x.checked_add(chunk_len));
326        (lower, upper)
327    }
328}
329
330impl<E, St: Stream<Item = Result<Vec<u8>, E>>> Stream for TryBytesChunks<St, Vec<u8>, E> {
331    type Item = Result<Vec<u8>, TryBytesChunksError<Vec<u8>, E>>;
332
333    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
334        let mut this = self.as_mut().project();
335
336        if this.buffer.len() >= *this.capacity {
337            return Poll::Ready(Some(Ok(self.take())));
338        }
339
340        loop {
341            match this.stream.as_mut().poll_next(cx) {
342                Poll::Pending => return Poll::Pending,
343
344                Poll::Ready(Some(item)) => match item {
345                    Ok(item) => {
346                        this.buffer.extend_from_slice(&item[..]);
347
348                        if this.buffer.len() >= *this.capacity {
349                            return Poll::Ready(Some(Ok(self.take())));
350                        }
351                    }
352                    Err(err) => {
353                        let err = TryBytesChunksError(self.take(), err);
354                        return Poll::Ready(Some(Err(err)));
355                    }
356                },
357
358                Poll::Ready(None) => {
359                    let last = if this.buffer.is_empty() {
360                        None
361                    } else {
362                        let buf = mem::take(this.buffer);
363                        Some(Ok(Vec::from(&buf[..])))
364                    };
365
366                    return Poll::Ready(last);
367                }
368            }
369        }
370    }
371
372    fn size_hint(&self) -> (usize, Option<usize>) {
373        let chunk_len = if self.buffer.is_empty() { 0 } else { 1 };
374        let (lower, upper) = self.stream.size_hint();
375        let lower = lower.saturating_add(chunk_len);
376        let upper = upper.and_then(|x| x.checked_add(chunk_len));
377        (lower, upper)
378    }
379}
380
381impl<St: FusedStream<Item = Bytes>> FusedStream for BytesChunks<St, Bytes> {
382    fn is_terminated(&self) -> bool {
383        self.stream.is_terminated() && self.buffer.is_empty()
384    }
385}
386
387impl<E, St: FusedStream<Item = Result<Bytes, E>>> FusedStream for TryBytesChunks<St, Bytes, E> {
388    fn is_terminated(&self) -> bool {
389        self.stream.is_terminated() && self.buffer.is_empty()
390    }
391}
392
#[cfg(test)]
mod test {
    use std::convert::Infallible;

    use bytes::Bytes;
    use futures::{
        executor::block_on,
        stream::{self, StreamExt},
    };
    use futures_test::{assert_stream_done, assert_stream_next};

    use super::BytesStream;

    #[test]
    fn test_bytes_chunks_lengthen() {
        block_on(async {
            let stream = futures::stream::iter(vec![
                Bytes::from_static(&[1, 2, 3]),
                Bytes::from_static(&[4, 5, 6]),
                Bytes::from_static(&[7, 8, 9]),
            ]);

            let mut stream = stream.bytes_chunks(4);

            assert_stream_next!(stream, Bytes::from_static(&[1, 2, 3, 4]));
            assert_stream_next!(stream, Bytes::from_static(&[5, 6, 7, 8]));
            assert_stream_next!(stream, Bytes::from_static(&[9]));
            assert_stream_done!(stream);
        });
    }

    #[test]
    fn test_bytes_chunks_shorten() {
        block_on(async {
            let stream = futures::stream::iter(vec![
                Bytes::from_static(&[1, 2, 3]),
                Bytes::from_static(&[4, 5, 6]),
                Bytes::from_static(&[7, 8, 9]),
            ]);

            let mut stream = stream.bytes_chunks(2);

            assert_stream_next!(stream, Bytes::from_static(&[1, 2]));
            assert_stream_next!(stream, Bytes::from_static(&[3, 4]));
            assert_stream_next!(stream, Bytes::from_static(&[5, 6]));
            assert_stream_next!(stream, Bytes::from_static(&[7, 8]));
            assert_stream_next!(stream, Bytes::from_static(&[9]));
            assert_stream_done!(stream);
        });
    }

    // One item larger than several chunks is split across consecutive polls.
    #[test]
    fn test_bytes_chunks_splits_large_item() {
        block_on(async {
            let stream = stream::iter(vec![Bytes::from_static(&[1, 2, 3, 4, 5])]);

            let mut stream = stream.bytes_chunks(2);

            assert_stream_next!(stream, Bytes::from_static(&[1, 2]));
            assert_stream_next!(stream, Bytes::from_static(&[3, 4]));
            assert_stream_next!(stream, Bytes::from_static(&[5]));
            assert_stream_done!(stream);
        });
    }

    // Empty items contribute nothing and never produce empty chunks.
    #[test]
    fn test_bytes_chunks_skips_empty_items() {
        block_on(async {
            let stream = stream::iter(vec![
                Bytes::new(),
                Bytes::from_static(&[1]),
                Bytes::new(),
            ]);

            let mut stream = stream.bytes_chunks(2);

            assert_stream_next!(stream, Bytes::from_static(&[1]));
            assert_stream_done!(stream);
        });
    }

    // `buffer()` exposes the bytes left over after a chunk was emitted.
    #[test]
    fn test_bytes_chunks_buffer() {
        block_on(async {
            let stream = stream::iter(vec![
                Bytes::from_static(&[1, 2, 3]),
                Bytes::from_static(&[4, 5, 6]),
            ]);

            let mut stream = stream.bytes_chunks(4);

            assert_stream_next!(stream, Bytes::from_static(&[1, 2, 3, 4]));
            assert_eq!(stream.buffer(), &[5u8, 6][..]);
        });
    }

    #[test]
    fn test_vec_chunks_lengthen() {
        block_on(async {
            #[rustfmt::skip]
            let stream = futures::stream::iter(vec![
                vec![1, 2, 3],
                vec![4, 5, 6],
                vec![7, 8, 9],
            ]);

            let mut stream = stream.bytes_chunks(4);

            assert_stream_next!(stream, vec![1, 2, 3, 4]);
            assert_stream_next!(stream, vec![5, 6, 7, 8]);
            assert_stream_next!(stream, vec![9]);
            assert_stream_done!(stream);
        });
    }

    #[test]
    fn test_vec_chunks_shorten() {
        block_on(async {
            #[rustfmt::skip]
            let stream = futures::stream::iter(vec![
                vec![1, 2, 3],
                vec![4, 5, 6],
                vec![7, 8, 9],
            ]);

            let mut stream = stream.bytes_chunks(2);

            assert_stream_next!(stream, vec![1, 2]);
            assert_stream_next!(stream, vec![3, 4]);
            assert_stream_next!(stream, vec![5, 6]);
            assert_stream_next!(stream, vec![7, 8]);
            assert_stream_next!(stream, vec![9]);
            assert_stream_done!(stream);
        });
    }

    #[test]
    fn test_try_bytes_chunks_lengthen() {
        block_on(async {
            let stream: stream::Iter<std::vec::IntoIter<Result<Bytes, Infallible>>> =
                stream::iter(vec![
                    Ok(Bytes::from_static(&[1, 2, 3])),
                    Ok(Bytes::from_static(&[4, 5, 6])),
                    Ok(Bytes::from_static(&[7, 8, 9])),
                ]);

            let mut stream = stream.try_bytes_chunks(4);

            assert_stream_next!(stream, Ok(Bytes::from_static(&[1, 2, 3, 4])));
            assert_stream_next!(stream, Ok(Bytes::from_static(&[5, 6, 7, 8])));
            assert_stream_next!(stream, Ok(Bytes::from_static(&[9])));
            assert_stream_done!(stream);
        });
    }

    #[test]
    fn test_try_bytes_chunks_shorten() {
        block_on(async {
            let stream: stream::Iter<std::vec::IntoIter<Result<Bytes, Infallible>>> =
                stream::iter(vec![
                    Ok(Bytes::from_static(&[1, 2, 3])),
                    Ok(Bytes::from_static(&[4, 5, 6])),
                    Ok(Bytes::from_static(&[7, 8, 9])),
                ]);

            let mut stream = stream.try_bytes_chunks(2);

            assert_stream_next!(stream, Ok(Bytes::from_static(&[1, 2])));
            assert_stream_next!(stream, Ok(Bytes::from_static(&[3, 4])));
            assert_stream_next!(stream, Ok(Bytes::from_static(&[5, 6])));
            assert_stream_next!(stream, Ok(Bytes::from_static(&[7, 8])));
            assert_stream_next!(stream, Ok(Bytes::from_static(&[9])));
            assert_stream_done!(stream);
        });
    }

    #[test]
    fn test_try_bytes_chunks_leftovers() {
        block_on(async {
            let stream: stream::Iter<std::vec::IntoIter<Result<Bytes, &'static str>>> =
                stream::iter(vec![
                    Ok(Bytes::from_static(&[1, 2, 3])),
                    Ok(Bytes::from_static(&[4, 5, 6])),
                    Err("error"),
                ]);

            let mut stream = stream.try_bytes_chunks(4);

            assert_stream_next!(stream, Ok(Bytes::from_static(&[1, 2, 3, 4])));

            let err = stream.next().await.unwrap();
            assert!(err.is_err());
            let err = err.err().unwrap();
            assert_eq!(err.into_inner(), Bytes::from_static(&[5, 6]));
        });
    }

    #[test]
    fn test_try_vec_chunks_shorten() {
        block_on(async {
            let stream: stream::Iter<std::vec::IntoIter<Result<Vec<u8>, Infallible>>> =
                stream::iter(vec![Ok(vec![1, 2, 3]), Ok(vec![4, 5, 6])]);

            let mut stream = stream.try_bytes_chunks(4);

            assert_stream_next!(stream, Ok(vec![1, 2, 3, 4]));
            assert_stream_next!(stream, Ok(vec![5, 6]));
            assert_stream_done!(stream);
        });
    }

    #[test]
    fn test_try_vec_chunks_leftovers() {
        block_on(async {
            let stream: stream::Iter<std::vec::IntoIter<Result<Vec<u8>, &'static str>>> =
                stream::iter(vec![Ok(vec![1, 2, 3]), Ok(vec![4, 5, 6]), Err("error")]);

            let mut stream = stream.try_bytes_chunks(4);

            assert_stream_next!(stream, Ok(vec![1, 2, 3, 4]));

            let err = stream.next().await.unwrap().err().unwrap();
            assert_eq!(err.into_inner(), vec![5, 6]);
            assert_stream_done!(stream);
        });
    }
}