1#![doc = include_str!("../README.md")]
2
3use std::{
4 fmt,
5 marker::PhantomData,
6 mem,
7 pin::Pin,
8 task::{Context, Poll},
9};
10
11use bytes::{Bytes, BytesMut};
12use futures_core::{FusedStream, Stream};
13use pin_project_lite::pin_project;
14
pin_project! {
    /// Stream adapter that regroups the bytes yielded by an inner stream into
    /// chunks of `capacity` bytes.
    ///
    /// `P` is a type-level marker, stored only as `PhantomData`. It selects
    /// which `take` helper and which `Stream` impl apply, e.g. `Bytes`,
    /// `Vec<u8>`, or a `Result` type for fallible inner streams.
    #[derive(Debug)]
    pub struct BytesChunks<St: Stream, P> {
        // The wrapped stream; marked `#[pin]` so it is polled through a pinned projection.
        #[pin]
        stream: St,
        // Bytes received from `stream` that have not been emitted as a chunk yet.
        buffer: BytesMut,
        // Length of every full chunk that is emitted.
        capacity: usize,
        // Zero-sized tag for `P`; no value of type `P` is ever stored.
        marker: PhantomData<P>,
    }
}
25
/// `P` marker used by the fallible chunk streams.
///
/// NOTE(review): the `Ok` side is always `Bytes`, even when `T` is `Vec<u8>`.
/// The marker is only a type-level tag: the `Vec<u8>` `Stream` impl declares
/// its own `Item = Result<Vec<u8>, _>`. It works, but it is easy to misread.
type TryBytesChunksResult<T, E> = Result<Bytes, TryBytesChunksError<T, E>>;
/// [`BytesChunks`] tagged for a fallible inner stream: `T` is the chunk type
/// (`Bytes` or `Vec<u8>`) and `E` is the inner stream's error type.
type TryBytesChunks<St, T, E> = BytesChunks<St, TryBytesChunksResult<T, E>>;
28
/// Error yielded by the fallible chunk streams when the inner stream fails.
///
/// Field `0` holds the bytes that were buffered, but not yet emitted, when the
/// error arrived; it is always shorter than the configured capacity. Field `1`
/// is the error produced by the inner stream.
#[derive(Clone, PartialEq, Eq)]
pub struct TryBytesChunksError<T, E>(pub T, pub E);
31
32impl<St: Stream, B> BytesChunks<St, B> {
33 pub fn with_capacity(capacity: usize, stream: St) -> Self {
34 Self {
35 stream,
36 buffer: BytesMut::with_capacity(capacity),
37 capacity,
38 marker: PhantomData,
39 }
40 }
41
42 pub fn buffer(&self) -> &[u8] {
43 self.buffer.as_ref()
44 }
45}
46
47impl<St: Stream> BytesChunks<St, Bytes> {
48 fn take(self: Pin<&mut Self>) -> Bytes {
49 let cap = self.capacity;
50 self.project().buffer.split_to(cap).freeze()
51 }
52}
53
54impl<St: Stream> BytesChunks<St, Vec<u8>> {
55 fn take(self: Pin<&mut Self>) -> Vec<u8> {
56 let cap = self.capacity;
57 Vec::from(&self.project().buffer.split_to(cap).freeze()[..])
58 }
59}
60
61impl<St: Stream, E> BytesChunks<St, TryBytesChunksResult<Bytes, E>> {
62 fn take(self: Pin<&mut Self>) -> Bytes {
63 let cap = self.capacity.clamp(0, self.buffer.len());
64 self.project().buffer.split_to(cap).freeze()
65 }
66}
67
68impl<St: Stream, E> BytesChunks<St, TryBytesChunksResult<Vec<u8>, E>> {
69 fn take(self: Pin<&mut Self>) -> Vec<u8> {
70 let cap = self.capacity.clamp(0, self.buffer.len());
71 Vec::from(&self.project().buffer.split_to(cap).freeze()[..])
72 }
73}
74
75impl<T, E: fmt::Debug> fmt::Debug for TryBytesChunksError<T, E> {
76 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
77 self.1.fmt(f)
78 }
79}
80
81impl<T, E: fmt::Display> fmt::Display for TryBytesChunksError<T, E> {
82 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83 self.1.fmt(f)
84 }
85}
86
87impl<T, E: fmt::Debug + fmt::Display> std::error::Error for TryBytesChunksError<T, E> {}
88
89impl<T, E> TryBytesChunksError<T, E> {
90 pub fn into_inner(self) -> T {
118 self.0
119 }
120}
121
122pub trait BytesStream: Stream {
123 fn bytes_chunks<T>(self, capacity: usize) -> BytesChunks<Self, T>
149 where
150 Self: Sized,
151 {
152 BytesChunks::with_capacity(capacity, self)
153 }
154
155 fn try_bytes_chunks<T, E>(self, capacity: usize) -> TryBytesChunks<Self, T, E>
182 where
183 Self: Sized,
184 {
185 BytesChunks::with_capacity(capacity, self)
186 }
187}
188
189impl<T> BytesStream for T where T: Stream {}
190
191impl<E, St: Stream<Item = Result<Bytes, E>>> Stream for TryBytesChunks<St, Bytes, E> {
192 type Item = Result<Bytes, TryBytesChunksError<Bytes, E>>;
193
194 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
195 let mut this = self.as_mut().project();
196
197 if this.buffer.len() >= *this.capacity {
198 return Poll::Ready(Some(Ok(self.take())));
199 }
200
201 loop {
202 match this.stream.as_mut().poll_next(cx) {
203 Poll::Pending => return Poll::Pending,
204
205 Poll::Ready(Some(item)) => match item {
206 Ok(item) => {
207 this.buffer.extend_from_slice(&item[..]);
208
209 if this.buffer.len() >= *this.capacity {
210 return Poll::Ready(Some(Ok(self.take())));
211 }
212 }
213 Err(err) => {
214 let err = TryBytesChunksError(self.take(), err);
215 return Poll::Ready(Some(Err(err)));
216 }
217 },
218
219 Poll::Ready(None) => {
220 let last = if this.buffer.is_empty() {
221 None
222 } else {
223 Some(Ok(Bytes::from(mem::take(this.buffer))))
224 };
225
226 return Poll::Ready(last);
227 }
228 }
229 }
230 }
231
232 fn size_hint(&self) -> (usize, Option<usize>) {
233 let chunk_len = if self.buffer.is_empty() { 0 } else { 1 };
234 let (lower, upper) = self.stream.size_hint();
235 let lower = lower.saturating_add(chunk_len);
236 let upper = upper.and_then(|x| x.checked_add(chunk_len));
237 (lower, upper)
238 }
239}
240
241impl<St: Stream<Item = Bytes>> Stream for BytesChunks<St, Bytes> {
242 type Item = Bytes;
243
244 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
245 let mut this = self.as_mut().project();
246
247 if this.buffer.len() >= *this.capacity {
248 return Poll::Ready(Some(self.take()));
249 }
250
251 loop {
252 match this.stream.as_mut().poll_next(cx) {
253 Poll::Pending => return Poll::Pending,
254
255 Poll::Ready(Some(item)) => {
256 this.buffer.extend_from_slice(&item[..]);
257
258 if this.buffer.len() >= *this.capacity {
259 return Poll::Ready(Some(self.take()));
260 }
261 }
262
263 Poll::Ready(None) => {
264 let last = if this.buffer.is_empty() {
265 None
266 } else {
267 Some(Bytes::from(mem::take(this.buffer)))
268 };
269
270 return Poll::Ready(last);
271 }
272 }
273 }
274 }
275
276 fn size_hint(&self) -> (usize, Option<usize>) {
277 let chunk_len = if self.buffer.is_empty() { 0 } else { 1 };
278 let (lower, upper) = self.stream.size_hint();
279 let lower = lower.saturating_add(chunk_len);
280 let upper = upper.and_then(|x| x.checked_add(chunk_len));
281 (lower, upper)
282 }
283}
284
285impl<St: Stream<Item = Vec<u8>>> Stream for BytesChunks<St, Vec<u8>> {
286 type Item = Vec<u8>;
287
288 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
289 let mut this = self.as_mut().project();
290
291 if this.buffer.len() >= *this.capacity {
292 return Poll::Ready(Some(self.take()));
293 }
294
295 loop {
296 match this.stream.as_mut().poll_next(cx) {
297 Poll::Pending => return Poll::Pending,
298
299 Poll::Ready(Some(item)) => {
300 this.buffer.extend_from_slice(&item[..]);
301
302 if this.buffer.len() >= *this.capacity {
303 return Poll::Ready(Some(self.take()));
304 }
305 }
306
307 Poll::Ready(None) => {
308 let last = if this.buffer.is_empty() {
309 None
310 } else {
311 let buf = mem::take(this.buffer);
312 Some(Vec::from(&buf[..]))
313 };
314
315 return Poll::Ready(last);
316 }
317 }
318 }
319 }
320
321 fn size_hint(&self) -> (usize, Option<usize>) {
322 let chunk_len = if self.buffer.is_empty() { 0 } else { 1 };
323 let (lower, upper) = self.stream.size_hint();
324 let lower = lower.saturating_add(chunk_len);
325 let upper = upper.and_then(|x| x.checked_add(chunk_len));
326 (lower, upper)
327 }
328}
329
330impl<E, St: Stream<Item = Result<Vec<u8>, E>>> Stream for TryBytesChunks<St, Vec<u8>, E> {
331 type Item = Result<Vec<u8>, TryBytesChunksError<Vec<u8>, E>>;
332
333 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
334 let mut this = self.as_mut().project();
335
336 if this.buffer.len() >= *this.capacity {
337 return Poll::Ready(Some(Ok(self.take())));
338 }
339
340 loop {
341 match this.stream.as_mut().poll_next(cx) {
342 Poll::Pending => return Poll::Pending,
343
344 Poll::Ready(Some(item)) => match item {
345 Ok(item) => {
346 this.buffer.extend_from_slice(&item[..]);
347
348 if this.buffer.len() >= *this.capacity {
349 return Poll::Ready(Some(Ok(self.take())));
350 }
351 }
352 Err(err) => {
353 let err = TryBytesChunksError(self.take(), err);
354 return Poll::Ready(Some(Err(err)));
355 }
356 },
357
358 Poll::Ready(None) => {
359 let last = if this.buffer.is_empty() {
360 None
361 } else {
362 let buf = mem::take(this.buffer);
363 Some(Ok(Vec::from(&buf[..])))
364 };
365
366 return Poll::Ready(last);
367 }
368 }
369 }
370 }
371
372 fn size_hint(&self) -> (usize, Option<usize>) {
373 let chunk_len = if self.buffer.is_empty() { 0 } else { 1 };
374 let (lower, upper) = self.stream.size_hint();
375 let lower = lower.saturating_add(chunk_len);
376 let upper = upper.and_then(|x| x.checked_add(chunk_len));
377 (lower, upper)
378 }
379}
380
381impl<St: FusedStream<Item = Bytes>> FusedStream for BytesChunks<St, Bytes> {
382 fn is_terminated(&self) -> bool {
383 self.stream.is_terminated() && self.buffer.is_empty()
384 }
385}
386
387impl<E, St: FusedStream<Item = Result<Bytes, E>>> FusedStream for TryBytesChunks<St, Bytes, E> {
388 fn is_terminated(&self) -> bool {
389 self.stream.is_terminated() && self.buffer.is_empty()
390 }
391}
392
#[cfg(test)]
mod test {
    use std::convert::Infallible;

    use bytes::Bytes;
    use futures::{
        executor::block_on,
        stream::{self, StreamExt},
    };
    use futures_test::{assert_stream_done, assert_stream_next};

    use super::BytesStream;

    #[test]
    fn test_bytes_chunks_lengthen() {
        block_on(async {
            let stream = futures::stream::iter(vec![
                Bytes::from_static(&[1, 2, 3]),
                Bytes::from_static(&[4, 5, 6]),
                Bytes::from_static(&[7, 8, 9]),
            ]);

            let mut stream = stream.bytes_chunks(4);

            assert_stream_next!(stream, Bytes::from_static(&[1, 2, 3, 4]));
            assert_stream_next!(stream, Bytes::from_static(&[5, 6, 7, 8]));
            assert_stream_next!(stream, Bytes::from_static(&[9]));
            assert_stream_done!(stream);
        });
    }

    #[test]
    fn test_bytes_chunks_shorten() {
        block_on(async {
            let stream = futures::stream::iter(vec![
                Bytes::from_static(&[1, 2, 3]),
                Bytes::from_static(&[4, 5, 6]),
                Bytes::from_static(&[7, 8, 9]),
            ]);

            let mut stream = stream.bytes_chunks(2);

            assert_stream_next!(stream, Bytes::from_static(&[1, 2]));
            assert_stream_next!(stream, Bytes::from_static(&[3, 4]));
            assert_stream_next!(stream, Bytes::from_static(&[5, 6]));
            assert_stream_next!(stream, Bytes::from_static(&[7, 8]));
            assert_stream_next!(stream, Bytes::from_static(&[9]));
            assert_stream_done!(stream);
        });
    }

    /// A single item larger than `capacity` is split across several chunks.
    #[test]
    fn test_bytes_chunks_item_spans_chunks() {
        block_on(async {
            let stream = stream::iter(vec![Bytes::from_static(&[1, 2, 3, 4, 5])]);

            let mut stream = stream.bytes_chunks::<Bytes>(2);

            assert_stream_next!(stream, Bytes::from_static(&[1, 2]));
            assert_stream_next!(stream, Bytes::from_static(&[3, 4]));
            assert_stream_next!(stream, Bytes::from_static(&[5]));
            assert_stream_done!(stream);
        });
    }

    /// An empty inner stream produces no chunks at all.
    #[test]
    fn test_bytes_chunks_empty() {
        block_on(async {
            let mut stream = stream::iter(Vec::<Bytes>::new()).bytes_chunks::<Bytes>(4);

            assert_stream_done!(stream);
        });
    }

    #[test]
    fn test_vec_chunks_lengthen() {
        block_on(async {
            #[rustfmt::skip]
            let stream = futures::stream::iter(vec![
                vec![1, 2, 3],
                vec![4, 5, 6],
                vec![7, 8, 9],
            ]);

            let mut stream = stream.bytes_chunks(4);

            assert_stream_next!(stream, vec![1, 2, 3, 4]);
            assert_stream_next!(stream, vec![5, 6, 7, 8]);
            assert_stream_next!(stream, vec![9]);
            assert_stream_done!(stream);
        });
    }

    #[test]
    fn test_vec_chunks_shorten() {
        block_on(async {
            #[rustfmt::skip]
            let stream = futures::stream::iter(vec![
                vec![1, 2, 3],
                vec![4, 5, 6],
                vec![7, 8, 9],
            ]);

            let mut stream = stream.bytes_chunks(2);

            assert_stream_next!(stream, vec![1, 2]);
            assert_stream_next!(stream, vec![3, 4]);
            assert_stream_next!(stream, vec![5, 6]);
            assert_stream_next!(stream, vec![7, 8]);
            assert_stream_next!(stream, vec![9]);
            assert_stream_done!(stream);
        });
    }

    #[test]
    fn test_try_bytes_chunks_lengthen() {
        block_on(async {
            let stream: stream::Iter<std::vec::IntoIter<Result<Bytes, Infallible>>> =
                stream::iter(vec![
                    Ok(Bytes::from_static(&[1, 2, 3])),
                    Ok(Bytes::from_static(&[4, 5, 6])),
                    Ok(Bytes::from_static(&[7, 8, 9])),
                ]);

            let mut stream = stream.try_bytes_chunks(4);

            assert_stream_next!(stream, Ok(Bytes::from_static(&[1, 2, 3, 4])));
            assert_stream_next!(stream, Ok(Bytes::from_static(&[5, 6, 7, 8])));
            assert_stream_next!(stream, Ok(Bytes::from_static(&[9])));
            assert_stream_done!(stream);
        });
    }

    #[test]
    fn test_try_bytes_chunks_shorten() {
        block_on(async {
            let stream: stream::Iter<std::vec::IntoIter<Result<Bytes, Infallible>>> =
                stream::iter(vec![
                    Ok(Bytes::from_static(&[1, 2, 3])),
                    Ok(Bytes::from_static(&[4, 5, 6])),
                    Ok(Bytes::from_static(&[7, 8, 9])),
                ]);

            let mut stream = stream.try_bytes_chunks(2);

            assert_stream_next!(stream, Ok(Bytes::from_static(&[1, 2])));
            assert_stream_next!(stream, Ok(Bytes::from_static(&[3, 4])));
            assert_stream_next!(stream, Ok(Bytes::from_static(&[5, 6])));
            assert_stream_next!(stream, Ok(Bytes::from_static(&[7, 8])));
            assert_stream_next!(stream, Ok(Bytes::from_static(&[9])));
            assert_stream_done!(stream);
        });
    }

    #[test]
    fn test_try_bytes_chunks_leftovers() {
        block_on(async {
            let stream: stream::Iter<std::vec::IntoIter<Result<Bytes, &'static str>>> =
                stream::iter(vec![
                    Ok(Bytes::from_static(&[1, 2, 3])),
                    Ok(Bytes::from_static(&[4, 5, 6])),
                    Err("error"),
                ]);

            let mut stream = stream.try_bytes_chunks(4);

            assert_stream_next!(stream, Ok(Bytes::from_static(&[1, 2, 3, 4])));

            let err = stream.next().await.unwrap();
            assert!(err.is_err());
            let err = err.err().unwrap();
            // The error displays as the inner error and carries the partial chunk.
            assert_eq!(err.to_string(), "error");
            assert_eq!(err.into_inner(), Bytes::from_static(&[5, 6]));

            // Nothing is left once the inner stream ends after the error.
            assert_stream_done!(stream);
        });
    }

    #[test]
    fn test_try_vec_chunks_leftovers() {
        block_on(async {
            let stream: stream::Iter<std::vec::IntoIter<Result<Vec<u8>, &'static str>>> =
                stream::iter(vec![Ok(vec![1, 2, 3]), Ok(vec![4, 5, 6]), Err("error")]);

            let mut stream = stream.try_bytes_chunks::<Vec<u8>, &'static str>(4);

            assert_stream_next!(stream, Ok(vec![1, 2, 3, 4]));

            let err = stream.next().await.unwrap().unwrap_err();
            assert_eq!(err.to_string(), "error");
            assert_eq!(err.into_inner(), vec![5, 6]);

            assert_stream_done!(stream);
        });
    }
}