futures_01_ext/
lib.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 *
4 * This source code is licensed under both the MIT license found in the
5 * LICENSE-MIT file in the root directory of this source tree and the Apache
6 * License, Version 2.0 found in the LICENSE-APACHE file in the root directory
7 * of this source tree.
8 */
9
10#![deny(warnings, missing_docs, clippy::all, rustdoc::broken_intra_doc_links)]
11#![feature(never_type)]
12
13//! Crate extending functionality of [`futures`] crate
14
15use std::fmt::Debug;
16use std::io as std_io;
17
18use bytes_old::Bytes;
19use futures::future;
20use futures::stream;
21use futures::sync::mpsc;
22use futures::sync::oneshot;
23use futures::try_ready;
24use futures::Async;
25use futures::AsyncSink;
26use futures::Future;
27use futures::Poll;
28use futures::Sink;
29use futures::Stream;
30use tokio_io::codec::Decoder;
31use tokio_io::codec::Encoder;
32use tokio_io::AsyncWrite;
33
34mod bytes_stream;
35pub mod decode;
36pub mod encode;
37mod futures_ordered;
38pub mod io;
39mod select_all;
40mod split_err;
41mod stream_wrappers;
42mod streamfork;
43
44// Re-exports. Those are used by the macros in this crate in order to reference a stable version of
45// what "futures" means.
46pub use futures as futures_reexport;
47
48pub use crate::bytes_stream::BytesStream;
49pub use crate::bytes_stream::BytesStreamFuture;
50pub use crate::futures_ordered::futures_ordered;
51pub use crate::futures_ordered::FuturesOrdered;
52pub use crate::select_all::select_all;
53pub use crate::select_all::SelectAll;
54pub use crate::split_err::split_err;
55pub use crate::stream_wrappers::CollectNoConsume;
56pub use crate::stream_wrappers::CollectTo;
57
58/// Map `Item` and `Error` to `()`
59///
60/// Adapt an existing `Future` to return unit `Item` and `Error`, while still
61/// waiting for the underlying `Future` to complete.
62#[must_use = "futures do nothing unless you `.await` or poll them"]
63pub struct Discard<F>(F);
64
65impl<F> Discard<F> {
66    /// Create instance wrapping `f`
67    pub fn new(f: F) -> Self {
68        Discard(f)
69    }
70}
71
72impl<F> Future for Discard<F>
73where
74    F: Future,
75{
76    type Item = ();
77    type Error = ();
78
79    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
80        match self.0.poll() {
81            Err(_) => Err(()),
82            Ok(Async::NotReady) => Ok(Async::NotReady),
83            Ok(Async::Ready(_)) => Ok(Async::Ready(())),
84        }
85    }
86}
87
88/// Send an item over an mpsc channel, discarding both the sender and receiver-closed errors. This
89/// should be used when the receiver being closed makes sending values moot, since no one is
90/// interested in the results any more.
91///
92/// `E` is an arbitrary error type useful for getting types to match up, but it will never be
93/// produced by the returned future.
94#[inline]
95pub fn send_discard<T, E>(
96    sender: mpsc::Sender<T>,
97    value: T,
98) -> impl Future<Item = (), Error = E> + Send
99where
100    T: Send,
101    E: Send,
102{
103    sender.send(value).then(|_| Ok(()))
104}
105
106/// Replacement for BoxFuture, deprecated in upstream futures-rs.
107pub type BoxFuture<T, E> = Box<dyn Future<Item = T, Error = E> + Send>;
108/// Replacement for BoxFutureNonSend, deprecated in upstream futures-rs.
109pub type BoxFutureNonSend<T, E> = Box<dyn Future<Item = T, Error = E>>;
110/// Replacement for BoxStream, deprecated in upstream futures-rs.
111pub type BoxStream<T, E> = Box<dyn Stream<Item = T, Error = E> + Send>;
112/// Replacement for BoxStreamNonSend, deprecated in upstream futures-rs.
113pub type BoxStreamNonSend<T, E> = Box<dyn Stream<Item = T, Error = E>>;
114
115/// Do something with an error if the future failed.
116///
117/// This is created by the `FutureExt::inspect_err` method.
118#[derive(Debug)]
119#[must_use = "futures do nothing unless polled"]
120pub struct InspectErr<A, F>
121where
122    A: Future,
123{
124    future: A,
125    f: Option<F>,
126}
127
128impl<A, F> Future for InspectErr<A, F>
129where
130    A: Future,
131    F: FnOnce(&A::Error),
132{
133    type Item = A::Item;
134    type Error = A::Error;
135
136    fn poll(&mut self) -> Poll<A::Item, A::Error> {
137        match self.future.poll() {
138            Ok(Async::NotReady) => Ok(Async::NotReady),
139            Ok(Async::Ready(e)) => Ok(Async::Ready(e)),
140            Err(e) => {
141                self.f.take().map_or_else(
142                    // Act like a fused future
143                    || Ok(Async::NotReady),
144                    |func| {
145                        func(&e);
146                        Err(e)
147                    },
148                )
149            }
150        }
151    }
152}
153
154/// Inspect the Result returned by a future
155///
156/// This is created by the `FutureExt::inspect_result` method.
157#[derive(Debug)]
158#[must_use = "futures do nothing unless polled"]
159pub struct InspectResult<A, F>
160where
161    A: Future,
162{
163    future: A,
164    f: Option<F>,
165}
166
167impl<A, F> Future for InspectResult<A, F>
168where
169    A: Future,
170    F: FnOnce(Result<&A::Item, &A::Error>),
171{
172    type Item = A::Item;
173    type Error = A::Error;
174
175    fn poll(&mut self) -> Poll<A::Item, A::Error> {
176        match self.future.poll() {
177            Ok(Async::NotReady) => Ok(Async::NotReady),
178            Ok(Async::Ready(i)) => self.f.take().map_or_else(
179                // Act like a fused future
180                || Ok(Async::NotReady),
181                |func| {
182                    func(Ok(&i));
183                    Ok(Async::Ready(i))
184                },
185            ),
186
187            Err(e) => self.f.take().map_or_else(
188                // Act like a fused future
189                || Ok(Async::NotReady),
190                |func| {
191                    func(Err(&e));
192                    Err(e)
193                },
194            ),
195        }
196    }
197}
198
199/// A trait implemented by default for all Futures which extends the standard
200/// functionality.
201pub trait FutureExt: Future + Sized {
202    /// Map a `Future` to have `Item=()` and `Error=()`. This is
203    /// useful when a future is being used to drive a computation
204    /// but the actual results aren't interesting (such as when used
205    /// with `Handle::spawn()`).
206    fn discard(self) -> Discard<Self> {
207        Discard(self)
208    }
209
210    /// Create a `Send`able boxed version of this `Future`.
211    #[inline]
212    fn boxify(self) -> BoxFuture<Self::Item, Self::Error>
213    where
214        Self: 'static + Send,
215    {
216        // TODO: (rain1) T21801845 rename to 'boxed' once gone from upstream.
217        Box::new(self)
218    }
219
220    /// Create a non-`Send`able boxed version of this `Future`.
221    #[inline]
222    fn boxify_nonsend(self) -> BoxFutureNonSend<Self::Item, Self::Error>
223    where
224        Self: 'static,
225    {
226        Box::new(self)
227    }
228
229    /// Shorthand for returning [`future::Either::A`]
230    fn left_future<B>(self) -> future::Either<Self, B> {
231        future::Either::A(self)
232    }
233
234    /// Shorthand for returning [`future::Either::B`]
235    fn right_future<A>(self) -> future::Either<A, Self> {
236        future::Either::B(self)
237    }
238
239    /// Similar to [`future::Future::inspect`], but runs the function on error
240    fn inspect_err<F>(self, f: F) -> InspectErr<Self, F>
241    where
242        F: FnOnce(&Self::Error),
243        Self: Sized,
244    {
245        InspectErr {
246            future: self,
247            f: Some(f),
248        }
249    }
250
251    /// Similar to [`future::Future::inspect`], but runs the function on both
252    /// output or error of the Future treating it as a regular [`Result`]
253    fn inspect_result<F>(self, f: F) -> InspectResult<Self, F>
254    where
255        F: FnOnce(Result<&Self::Item, &Self::Error>),
256        Self: Sized,
257    {
258        InspectResult {
259            future: self,
260            f: Some(f),
261        }
262    }
263}
264
265impl<T> FutureExt for T where T: Future {}
266
/// Params for [StreamExt::buffered_weight_limited] and [WeightLimitedBufferedStream]
///
/// This is a plain value type, so it is `Copy` and can be reused for several streams.
#[derive(Clone, Copy, Debug)]
pub struct BufferedParams {
    /// Limit for the sum of weights in the [WeightLimitedBufferedStream] stream
    pub weight_limit: u64,
    /// Limit for size of buffer in the [WeightLimitedBufferedStream] stream
    pub buffer_size: usize,
}
274
275/// A trait implemented by default for all Streams which extends the standard
276/// functionality.
277pub trait StreamExt: Stream {
278    /// Fork elements in a stream out to two sinks, depending on a predicate
279    ///
280    /// If the predicate returns false, send the item to `out1`, otherwise to
281    /// `out2`. `streamfork()` acts in a similar manner to `forward()` in that it
282    /// keeps operating until the input stream ends, and then returns everything
283    /// in the resulting Future.
284    ///
285    /// The predicate returns a `Result` so that it can fail (if there's a malformed
286    /// input that can't be assigned to either output).
287    fn streamfork<Out1, Out2, F, E>(
288        self,
289        out1: Out1,
290        out2: Out2,
291        pred: F,
292    ) -> streamfork::Forker<Self, Out1, Out2, F, E>
293    where
294        Self: Sized,
295        Out1: Sink<SinkItem = Self::Item>,
296        Out2: Sink<SinkItem = Self::Item, SinkError = Out1::SinkError>,
297        F: FnMut(&Self::Item) -> Result<bool, E>,
298        E: From<Self::Error> + From<Out1::SinkError> + From<Out2::SinkError>,
299    {
300        streamfork::streamfork(self, out1, out2, pred)
301    }
302
303    /// Returns a future that yields a `(Vec<<Self>::Item>, Self)`, where the
304    /// vector is a collections of all elements yielded by the Stream.
305    fn collect_no_consume(self) -> CollectNoConsume<Self>
306    where
307        Self: Sized,
308    {
309        stream_wrappers::collect_no_consume::new(self)
310    }
311
312    /// A shorthand for [encode::encode]
313    fn encode<Enc>(self, encoder: Enc) -> encode::LayeredEncoder<Self, Enc>
314    where
315        Self: Sized,
316        Enc: Encoder<Item = Self::Item>,
317    {
318        encode::encode(self, encoder)
319    }
320
321    /// Similar to [std::iter::Iterator::enumerate], returns a Stream that yields
322    /// `(usize, Self::Item)` where the first element of tuple is the iteration
323    /// count.
324    fn enumerate(self) -> Enumerate<Self>
325    where
326        Self: Sized,
327    {
328        Enumerate::new(self)
329    }
330
331    /// Creates a stream wrapper and a future. The future will resolve into the wrapped stream when
332    /// the stream wrapper returns None. It uses ConservativeReceiver to ensure that deadlocks are
333    /// easily caught when one tries to poll on the receiver before consuming the stream.
334    fn return_remainder(self) -> (ReturnRemainder<Self>, ConservativeReceiver<Self>)
335    where
336        Self: Sized,
337    {
338        ReturnRemainder::new(self)
339    }
340
341    /// Whether this stream is empty.
342    ///
343    /// This will consume one element from the stream if returned.
344    #[allow(clippy::wrong_self_convention)]
345    fn is_empty<'a>(self) -> Box<dyn Future<Item = bool, Error = Self::Error> + Send + 'a>
346    where
347        Self: 'a + Send + Sized,
348    {
349        Box::new(
350            self.into_future()
351                .map(|(first, _rest)| first.is_none())
352                .map_err(|(err, _rest)| err),
353        )
354    }
355
356    /// Whether this stream is not empty (has at least one element).
357    ///
358    /// This will consume one element from the stream if returned.
359    fn not_empty<'a>(self) -> Box<dyn Future<Item = bool, Error = Self::Error> + Send + 'a>
360    where
361        Self: 'a + Send + Sized,
362    {
363        Box::new(
364            self.into_future()
365                .map(|(first, _rest)| first.is_some())
366                .map_err(|(err, _rest)| err),
367        )
368    }
369
370    /// Create a `Send`able boxed version of this `Stream`.
371    #[inline]
372    fn boxify(self) -> BoxStream<Self::Item, Self::Error>
373    where
374        Self: 'static + Send + Sized,
375    {
376        // TODO: (rain1) T21801845 rename to 'boxed' once gone from upstream.
377        Box::new(self)
378    }
379
380    /// Create a non-`Send`able boxed version of this `Stream`.
381    #[inline]
382    fn boxify_nonsend(self) -> BoxStreamNonSend<Self::Item, Self::Error>
383    where
384        Self: 'static + Sized,
385    {
386        Box::new(self)
387    }
388
389    /// Shorthand for returning [`StreamEither::A`]
390    fn left_stream<B>(self) -> StreamEither<Self, B>
391    where
392        Self: Sized,
393    {
394        StreamEither::A(self)
395    }
396
397    /// Shorthand for returning [`StreamEither::B`]
398    fn right_stream<A>(self) -> StreamEither<A, Self>
399    where
400        Self: Sized,
401    {
402        StreamEither::B(self)
403    }
404
405    /// Similar to [Stream::chunks], but returns earlier if [futures::Async::NotReady]
406    /// was returned.
407    fn batch(self, limit: usize) -> BatchStream<Self>
408    where
409        Self: Sized,
410    {
411        BatchStream::new(self, limit)
412    }
413
414    /// Like [Stream::buffered] call, but can also limit number of futures in a buffer by "weight".
415    fn buffered_weight_limited<I, E, Fut>(
416        self,
417        params: BufferedParams,
418    ) -> WeightLimitedBufferedStream<Self, I, E>
419    where
420        Self: Sized + Send + 'static,
421        Self: Stream<Item = (Fut, u64), Error = E>,
422        Fut: Future<Item = I, Error = E>,
423    {
424        WeightLimitedBufferedStream::new(params, self)
425    }
426
427    /// Returns a Future that yields a collection `C` containing all `Self::Item`
428    /// yielded by the stream
429    fn collect_to<C: Default + Extend<Self::Item>>(self) -> CollectTo<Self, C>
430    where
431        Self: Sized,
432    {
433        CollectTo::new(self)
434    }
435}
436
437impl<T> StreamExt for T where T: Stream {}
438
439/// Like [stream::Buffered], but can also limit number of futures in a buffer by "weight".
440pub struct WeightLimitedBufferedStream<S, I, E> {
441    queue: stream::FuturesOrdered<BoxFuture<(I, u64), E>>,
442    current_weight: u64,
443    weight_limit: u64,
444    max_buffer_size: usize,
445    stream: stream::Fuse<S>,
446}
447
448impl<S, I, E> WeightLimitedBufferedStream<S, I, E>
449where
450    S: Stream,
451{
452    /// Create a new instance that will be configured using the `params` provided
453    pub fn new(params: BufferedParams, stream: S) -> Self {
454        Self {
455            queue: stream::FuturesOrdered::new(),
456            current_weight: 0,
457            weight_limit: params.weight_limit,
458            max_buffer_size: params.buffer_size,
459            stream: stream.fuse(),
460        }
461    }
462}
463
464impl<S, Fut, I: 'static, E: 'static> Stream for WeightLimitedBufferedStream<S, I, E>
465where
466    S: Stream<Item = (Fut, u64), Error = E>,
467    Fut: Future<Item = I, Error = E> + Send + 'static,
468{
469    type Item = I;
470    type Error = E;
471
472    fn poll(&mut self) -> Poll<Option<Self::Item>, E> {
473        // First up, try to spawn off as many futures as possible by filling up
474        // our slab of futures.
475        while self.queue.len() < self.max_buffer_size && self.current_weight < self.weight_limit {
476            let future = match self.stream.poll()? {
477                Async::Ready(Some((s, weight))) => {
478                    self.current_weight += weight;
479                    s.map(move |val| (val, weight)).boxify()
480                }
481                Async::Ready(None) | Async::NotReady => break,
482            };
483
484            self.queue.push(future);
485        }
486
487        // Try polling a new future
488        if let Some((val, weight)) = try_ready!(self.queue.poll()) {
489            self.current_weight -= weight;
490            return Ok(Async::Ready(Some(val)));
491        }
492
493        // If we've gotten this far, then there are no events for us to process
494        // and nothing was ready, so figure out if we're not done yet  or if
495        // we've reached the end.
496        if self.stream.is_done() {
497            Ok(Async::Ready(None))
498        } else {
499            Ok(Async::NotReady)
500        }
501    }
502}
503
504/// Trait that provides a function for making a decoding layer on top of Stream of Bytes
505pub trait StreamLayeredExt: Stream<Item = Bytes> {
506    /// Returnes a Stream that will yield decoded chunks of Bytes as they come
507    /// using provided [Decoder]
508    fn decode<Dec>(self, decoder: Dec) -> decode::LayeredDecode<Self, Dec>
509    where
510        Self: Sized,
511        Dec: Decoder;
512}
513
514impl<T> StreamLayeredExt for T
515where
516    T: Stream<Item = Bytes>,
517{
518    fn decode<Dec>(self, decoder: Dec) -> decode::LayeredDecode<Self, Dec>
519    where
520        Self: Sized,
521        Dec: Decoder,
522    {
523        decode::decode(self, decoder)
524    }
525}
526
527/// Like [std::iter::Enumerate], but for Stream
528pub struct Enumerate<In> {
529    inner: In,
530    count: usize,
531}
532
533impl<In> Enumerate<In> {
534    fn new(inner: In) -> Self {
535        Enumerate { inner, count: 0 }
536    }
537}
538
539impl<In: Stream> Stream for Enumerate<In> {
540    type Item = (usize, In::Item);
541    type Error = In::Error;
542
543    fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
544        match self.inner.poll() {
545            Err(err) => Err(err),
546            Ok(Async::NotReady) => Ok(Async::NotReady),
547            Ok(Async::Ready(None)) => Ok(Async::Ready(None)),
548            Ok(Async::Ready(Some(v))) => {
549                let c = self.count;
550                self.count += 1;
551                Ok(Async::Ready(Some((c, v))))
552            }
553        }
554    }
555}
556
557/// Like [future::Either], but for Stream
558pub enum StreamEither<A, B> {
559    /// First branch of the type
560    A(A),
561    /// Second branch of the type
562    B(B),
563}
564
565impl<A, B> Stream for StreamEither<A, B>
566where
567    A: Stream,
568    B: Stream<Item = A::Item, Error = A::Error>,
569{
570    type Item = A::Item;
571    type Error = A::Error;
572
573    fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
574        match self {
575            StreamEither::A(a) => a.poll(),
576            StreamEither::B(b) => b.poll(),
577        }
578    }
579}
580
581/// This is a wrapper around oneshot::Receiver that will return error when the receiver was polled
582/// and the result was not ready. This is a very strict way of preventing deadlocks in code when
583/// receiver is polled before the sender has send the result
584pub struct ConservativeReceiver<T>(oneshot::Receiver<T>);
585
/// Error that can be returned by [ConservativeReceiver]
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum ConservativeReceiverError {
    /// The underlying [oneshot::Receiver] returned [oneshot::Canceled]
    Canceled,
    /// The underlying [oneshot::Receiver] returned [Async::NotReady], which means it was
    /// polled before the [oneshot::Sender] sent any data
    ReceiveBeforeSend,
}

impl ConservativeReceiverError {
    /// Human-readable message for this error. It is shared by the `Display` and
    /// `Error::description` impls so the two cannot drift apart.
    fn message(&self) -> &'static str {
        match self {
            ConservativeReceiverError::Canceled => "oneshot canceled",
            ConservativeReceiverError::ReceiveBeforeSend => "recv called on channel before send",
        }
    }
}

impl ::std::error::Error for ConservativeReceiverError {
    fn description(&self) -> &str {
        self.message()
    }
}

impl ::std::fmt::Display for ConservativeReceiverError {
    fn fmt(&self, fmt: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
        fmt.write_str(self.message())
    }
}
615
616impl ::std::convert::From<oneshot::Canceled> for ConservativeReceiverError {
617    fn from(_: oneshot::Canceled) -> ConservativeReceiverError {
618        ConservativeReceiverError::Canceled
619    }
620}
621
622impl<T> ConservativeReceiver<T> {
623    /// Return an instance of [ConservativeReceiver] wrapping the [oneshot::Receiver]
624    pub fn new(recv: oneshot::Receiver<T>) -> Self {
625        ConservativeReceiver(recv)
626    }
627}
628
629impl<T> Future for ConservativeReceiver<T> {
630    type Item = T;
631    type Error = ConservativeReceiverError;
632
633    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
634        match self.0.poll()? {
635            Async::Ready(item) => Ok(Async::Ready(item)),
636            Async::NotReady => Err(ConservativeReceiverError::ReceiveBeforeSend),
637        }
638    }
639}
640
641/// A stream wrapper returned by [StreamExt::return_remainder]
642pub struct ReturnRemainder<In> {
643    inner: Option<In>,
644    send: Option<oneshot::Sender<In>>,
645}
646
647impl<In> ReturnRemainder<In> {
648    fn new(inner: In) -> (Self, ConservativeReceiver<In>) {
649        let (send, recv) = oneshot::channel();
650        (
651            Self {
652                inner: Some(inner),
653                send: Some(send),
654            },
655            ConservativeReceiver::new(recv),
656        )
657    }
658}
659
660impl<In: Stream> Stream for ReturnRemainder<In> {
661    type Item = In::Item;
662    type Error = In::Error;
663
664    fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
665        let maybe_item = match self.inner {
666            Some(ref mut inner) => try_ready!(inner.poll()),
667            None => return Ok(Async::Ready(None)),
668        };
669
670        if maybe_item.is_none() {
671            let inner = self
672                .inner
673                .take()
674                .expect("inner was just polled, should be some");
675            let send = self.send.take().expect("send is None iff inner is None");
676            // The Receiver will handle errors
677            let _ = send.send(inner);
678        }
679
680        Ok(Async::Ready(maybe_item))
681    }
682}
683
/// A convenience macro for working with `io::Result<T>` from the `Read` and
/// `Write` traits.
///
/// This macro takes `io::Result<T>` as input, and returns `Poll<T, io::Error>`
/// as the output. If the input is the `Err` variant, `Ok(Async::NotReady)` is
/// returned when the error kind is `WouldBlock`; any other error is returned as
/// `Err` unchanged.
///
/// `Async` is referenced through this crate's `futures_reexport`, so the macro
/// works regardless of which `futures` version (if any) the calling crate names
/// `futures`.
#[macro_export]
#[rustfmt::skip]
macro_rules! handle_nb {
    ($e:expr) => {
        match $e {
            Ok(t) => Ok($crate::futures_reexport::Async::Ready(t)),
            Err(ref e) if e.kind() == ::std::io::ErrorKind::WouldBlock => {
                Ok($crate::futures_reexport::Async::NotReady)
            }
            Err(e) => Err(e),
        }
    };
}
704
/// Early-return helper like the `?` operator, for functions whose return type is
/// [BoxFuture].
///
/// Evaluates to the `Ok` value of the given `Result`. On `Err`, the error is
/// converted with `.into()` and the enclosing function immediately returns it as a
/// failed boxed future.
#[macro_export]
#[rustfmt::skip]
macro_rules! try_boxfuture {
    ($e:expr) => {
        match $e {
            Err(err) => return $crate::FutureExt::boxify($crate::futures_reexport::future::err(err.into())),
            Ok(value) => value,
        }
    };
}
718
/// Early-return helper like the `?` operator, for functions whose return type is
/// [BoxStream].
///
/// Evaluates to the `Ok` value of the given `Result`. On `Err`, the error is
/// converted with `.into()` and the enclosing function immediately returns a boxed
/// stream whose only element is that error.
#[macro_export]
#[rustfmt::skip]
macro_rules! try_boxstream {
    ($e:expr) => {
        match $e {
            Err(err) => return $crate::StreamExt::boxify($crate::futures_reexport::stream::once(Err(err.into()))),
            Ok(value) => value,
        }
    };
}
732
/// Macro that can be used like the `ensure!` macro from the failure crate, but in a
/// context where the expected return type is [BoxFuture]. If the condition is not
/// satisfied, the enclosing function returns early with a failed boxed future
/// carrying `$e.into()`.
///
/// `future::err` is referenced through this crate's `futures_reexport`, so the macro
/// does not depend on what (if anything) the calling crate names `futures`.
#[macro_export]
#[rustfmt::skip]
macro_rules! ensure_boxfuture {
    ($cond:expr, $e:expr) => {
        if !($cond) {
            return $crate::FutureExt::boxify($crate::futures_reexport::future::err($e.into()));
        }
    };
}
745
/// Macro that can be used like the `ensure!` macro from the failure crate, but in a
/// context where the expected return type is [BoxStream]. If the condition is not
/// satisfied, the enclosing function returns early with a boxed stream whose only
/// element is the error `$e.into()`.
///
/// `stream::once` is referenced through this crate's `futures_reexport`, so the macro
/// does not depend on what (if anything) the calling crate names `futures`.
#[macro_export]
#[rustfmt::skip]
macro_rules! ensure_boxstream {
    ($cond:expr, $e:expr) => {
        if !($cond) {
            return $crate::StreamExt::boxify($crate::futures_reexport::stream::once(Err($e.into())));
        }
    };
}
758
/// Macro that can be used like the `?` operator, but in a context where the expected
/// return type is a left future (see [FutureExt::left_future]). It evaluates to the
/// `Ok` part of the `Result`, or immediately returns the `Err` part, converted with
/// `.into()`, as a failed left future.
///
/// `left_future` is called through its full `$crate` path, so callers do not need to
/// import [FutureExt] themselves.
#[macro_export]
#[rustfmt::skip]
macro_rules! try_left_future {
    ($e:expr) => {
        match $e {
            Ok(t) => t,
            Err(e) => return $crate::FutureExt::left_future($crate::futures_reexport::future::err(e.into())),
        }
    };
}
772
/// Simple adapter from the `Sink` interface to the `AsyncWrite` interface.
///
/// Useful when some code only knows how to write into an `AsyncWrite`, but the
/// written data should end up in a `Sink` of `Bytes`, so that it can be consumed
/// as a Stream on the other side.
pub struct SinkToAsyncWrite<S> {
    sink: S,
}

impl<S> SinkToAsyncWrite<S> {
    /// Return an instance of [SinkToAsyncWrite] wrapping a Sink
    pub fn new(sink: S) -> Self {
        Self { sink }
    }
}
786
/// Wrap an arbitrary error into an `io::Error` of kind `Other`, using its `Debug`
/// representation as the message.
fn create_std_error<E: Debug>(err: E) -> std_io::Error {
    let message = format!("{err:?}");
    std_io::Error::new(std_io::ErrorKind::Other, message)
}
790
791impl<E, S> std_io::Write for SinkToAsyncWrite<S>
792where
793    S: Sink<SinkItem = Bytes, SinkError = E>,
794    E: Debug,
795{
796    fn write(&mut self, buf: &[u8]) -> ::std::io::Result<usize> {
797        let bytes = Bytes::from(buf);
798        match self.sink.start_send(bytes) {
799            Ok(AsyncSink::Ready) => Ok(buf.len()),
800            Ok(AsyncSink::NotReady(_)) => Err(std_io::Error::new(
801                std_io::ErrorKind::WouldBlock,
802                "channel is busy",
803            )),
804            Err(err) => Err(create_std_error(err)),
805        }
806    }
807
808    fn flush(&mut self) -> std_io::Result<()> {
809        match self.sink.poll_complete() {
810            Ok(Async::Ready(())) => Ok(()),
811            Ok(Async::NotReady) => Err(std_io::Error::new(
812                std_io::ErrorKind::WouldBlock,
813                "channel is busy",
814            )),
815            Err(err) => Err(create_std_error(err)),
816        }
817    }
818}
819
820impl<E, S> AsyncWrite for SinkToAsyncWrite<S>
821where
822    S: Sink<SinkItem = Bytes, SinkError = E>,
823    E: Debug,
824{
825    fn shutdown(&mut self) -> Poll<(), std_io::Error> {
826        match self.sink.close() {
827            Ok(res) => Ok(res),
828            Err(err) => Err(create_std_error(err)),
829        }
830    }
831}
832
833/// It's a combinator that converts `Stream<A>` into `Stream<Vec<A>>`.
834/// So interface is similar to `.chunks()` method, but there's an important difference:
835/// BatchStream won't wait until the whole batch fills up i.e. as soon as underlying stream
836/// return NotReady, then new batch is returned from BatchStream
837pub struct BatchStream<S>
838where
839    S: Stream,
840{
841    inner: stream::Fuse<S>,
842    err: Option<S::Error>,
843    limit: usize,
844}
845
846impl<S: Stream> BatchStream<S> {
847    /// Return an instance of [BatchStream] wrapping a Stream with the provided limit set
848    pub fn new(s: S, limit: usize) -> Self {
849        Self {
850            inner: s.fuse(),
851            err: None,
852            limit,
853        }
854    }
855}
856
857impl<S: Stream> Stream for BatchStream<S> {
858    type Item = Vec<S::Item>;
859    type Error = S::Error;
860
861    fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
862        let mut batch = vec![];
863
864        if let Some(err) = self.err.take() {
865            return Err(err);
866        }
867
868        while batch.len() < self.limit {
869            match self.inner.poll() {
870                Ok(Async::Ready(Some(v))) => batch.push(v),
871                Ok(Async::NotReady) | Ok(Async::Ready(None)) => break,
872                Err(err) => {
873                    self.err = Some(err);
874                    break;
875                }
876            }
877        }
878
879        if batch.is_empty() {
880            if let Some(err) = self.err.take() {
881                return Err(err);
882            }
883
884            if self.inner.is_done() {
885                Ok(Async::Ready(None))
886            } else {
887                Ok(Async::NotReady)
888            }
889        } else {
890            Ok(Async::Ready(Some(batch)))
891        }
892    }
893}
894
895#[cfg(test)]
896mod test {
897    use std::sync::atomic::AtomicUsize;
898    use std::sync::atomic::Ordering;
899    use std::sync::Arc;
900
901    use anyhow::Result;
902    use assert_matches::assert_matches;
903    use cloned::cloned;
904    use futures::future::err;
905    use futures::future::ok;
906    use futures::stream;
907    use futures::sync::mpsc;
908    use futures::IntoFuture;
909    use futures::Stream;
910    use futures03::compat::Future01CompatExt;
911    use tokio::runtime::Runtime;
912
913    use super::*;
914    #[derive(Debug)]
915    struct MyErr;
916
917    impl<T> From<mpsc::SendError<T>> for MyErr {
918        fn from(_: mpsc::SendError<T>) -> Self {
919            MyErr
920        }
921    }
922
923    #[test]
924    fn discard() {
925        use futures::sync::mpsc;
926
927        let runtime = Runtime::new().unwrap();
928
929        let (tx, rx) = mpsc::channel(1);
930
931        let xfer = stream::iter_ok::<_, MyErr>(vec![123]).forward(tx);
932
933        runtime.spawn(xfer.discard().compat());
934
935        match runtime.block_on(rx.collect().compat()) {
936            Ok(v) => assert_eq!(v, vec![123]),
937            bad => panic!("bad {bad:?}"),
938        }
939    }
940
941    #[test]
942    fn inspect_err() {
943        let count = Arc::new(AtomicUsize::new(0));
944        cloned!(count as count_cloned);
945        let runtime = Runtime::new().unwrap();
946        let work = err::<i32, i32>(42).inspect_err(move |e| {
947            assert_eq!(42, *e);
948            count_cloned.fetch_add(1, Ordering::SeqCst);
949        });
950        if runtime.block_on(work.compat()).is_ok() {
951            panic!("future is supposed to fail");
952        }
953        assert_eq!(1, count.load(Ordering::SeqCst));
954    }
955
956    #[test]
957    fn inspect_ok() {
958        let count = Arc::new(AtomicUsize::new(0));
959        cloned!(count as count_cloned);
960        let runtime = Runtime::new().unwrap();
961        let work = ok::<i32, i32>(42).inspect_err(move |_| {
962            count_cloned.fetch_add(1, Ordering::SeqCst);
963        });
964        if runtime.block_on(work.compat()).is_err() {
965            panic!("future is supposed to succeed");
966        }
967        assert_eq!(0, count.load(Ordering::SeqCst));
968    }
969
970    #[test]
971    fn inspect_result() {
972        let count = Arc::new(AtomicUsize::new(0));
973        cloned!(count as count_cloned);
974        let runtime = Runtime::new().unwrap();
975        let work = err::<i32, i32>(42).inspect_result(move |res| {
976            if let Err(e) = res {
977                assert_eq!(42, *e);
978                count_cloned.fetch_add(1, Ordering::SeqCst);
979            } else {
980                count_cloned.fetch_add(2, Ordering::SeqCst);
981            }
982        });
983        if runtime.block_on(work.compat()).is_ok() {
984            panic!("future is supposed to fail");
985        }
986        assert_eq!(1, count.load(Ordering::SeqCst));
987    }
988
989    #[test]
990    fn enumerate() {
991        let s = stream::iter_ok::<_, ()>(vec!["hello", "there", "world"]);
992        let es = Enumerate::new(s);
993        let v = es.collect().wait();
994
995        assert_eq!(v, Ok(vec![(0, "hello"), (1, "there"), (2, "world")]));
996    }
997
998    #[test]
999    fn empty() {
1000        let mut s = stream::empty::<(), ()>();
1001        // Ensure that the stream doesn't have to be consumed.
1002        assert!(s.by_ref().is_empty().wait().unwrap());
1003        assert!(!s.not_empty().wait().unwrap());
1004
1005        let mut s = stream::once::<_, ()>(Ok("foo"));
1006        assert!(!s.by_ref().is_empty().wait().unwrap());
1007        // The above is_empty would consume the first element, so the stream has to be
1008        // reinitialized.
1009        let s = stream::once::<_, ()>(Ok("foo"));
1010        assert!(s.not_empty().wait().unwrap());
1011    }
1012
1013    #[test]
1014    fn return_remainder() {
1015        use futures::future::poll_fn;
1016
1017        let s = stream::iter_ok::<_, ()>(vec!["hello", "there", "world"]).fuse();
1018        let (mut s, mut remainder) = s.return_remainder();
1019
1020        let runtime = Runtime::new().unwrap();
1021        let res: Result<(), ()> = runtime.block_on(
1022            poll_fn(move || {
1023                assert_matches!(
1024                    remainder.poll(),
1025                    Err(ConservativeReceiverError::ReceiveBeforeSend)
1026                );
1027
1028                assert_eq!(s.poll(), Ok(Async::Ready(Some("hello"))));
1029                assert_matches!(
1030                    remainder.poll(),
1031                    Err(ConservativeReceiverError::ReceiveBeforeSend)
1032                );
1033
1034                assert_eq!(s.poll(), Ok(Async::Ready(Some("there"))));
1035                assert_matches!(
1036                    remainder.poll(),
1037                    Err(ConservativeReceiverError::ReceiveBeforeSend)
1038                );
1039
1040                assert_eq!(s.poll(), Ok(Async::Ready(Some("world"))));
1041                assert_matches!(
1042                    remainder.poll(),
1043                    Err(ConservativeReceiverError::ReceiveBeforeSend)
1044                );
1045
1046                assert_eq!(s.poll(), Ok(Async::Ready(None)));
1047                match remainder.poll() {
1048                    Ok(Async::Ready(s)) => assert!(s.is_done()),
1049                    bad => panic!("unexpected result: {bad:?}"),
1050                }
1051
1052                Ok(Async::Ready(()))
1053            })
1054            .compat(),
1055        );
1056
1057        assert_matches!(res, Ok(()));
1058    }
1059
1060    fn assert_flush<E, S>(sink: &mut SinkToAsyncWrite<S>)
1061    where
1062        S: Sink<SinkItem = Bytes, SinkError = E>,
1063        E: Debug,
1064    {
1065        use std::io::Write;
1066        loop {
1067            let flush_res = sink.flush();
1068            if flush_res.is_ok() {
1069                break;
1070            }
1071            if let Err(ref e) = flush_res {
1072                assert_eq!(e.kind(), std_io::ErrorKind::WouldBlock);
1073            }
1074        }
1075    }
1076
1077    fn assert_shutdown<E, S>(sink: &mut SinkToAsyncWrite<S>)
1078    where
1079        S: Sink<SinkItem = Bytes, SinkError = E>,
1080        E: Debug,
1081    {
1082        loop {
1083            let shutdown_res = sink.shutdown();
1084            if shutdown_res.is_ok() {
1085                break;
1086            }
1087            if let Err(ref e) = shutdown_res {
1088                assert_eq!(e.kind(), std_io::ErrorKind::WouldBlock);
1089            }
1090        }
1091    }
1092
1093    #[test]
1094    fn sink_to_async_write() {
1095        use std::io::Write;
1096
1097        use futures::sync::mpsc;
1098        let rt = tokio::runtime::Runtime::new().unwrap();
1099
1100        let (tx, rx) = mpsc::channel::<Bytes>(1);
1101
1102        let messages_num = 10;
1103
1104        rt.spawn(
1105            Ok::<_, ()>(())
1106                .into_future()
1107                .map(move |()| {
1108                    let mut async_write = SinkToAsyncWrite::new(tx);
1109                    for i in 0..messages_num {
1110                        loop {
1111                            let res = async_write.write(format!("{i}").as_bytes());
1112                            if let Err(ref e) = res {
1113                                assert_eq!(e.kind(), std_io::ErrorKind::WouldBlock);
1114                                assert_flush(&mut async_write);
1115                            } else {
1116                                break;
1117                            }
1118                        }
1119                    }
1120
1121                    assert_flush(&mut async_write);
1122                    assert_shutdown(&mut async_write);
1123                })
1124                .compat(),
1125        );
1126
1127        let res = rt.block_on(rx.collect().compat()).unwrap();
1128        assert_eq!(res.len(), messages_num);
1129    }
1130
1131    #[test]
1132    fn test_buffered() {
1133        type TestStream = BoxStream<(BoxFuture<(), ()>, u64), ()>;
1134
1135        fn create_stream() -> (Arc<AtomicUsize>, TestStream) {
1136            let s: TestStream = stream::iter_ok(vec![
1137                (future::ok(()).boxify(), 100),
1138                (future::ok(()).boxify(), 2),
1139            ])
1140            .boxify();
1141
1142            let counter = Arc::new(AtomicUsize::new(0));
1143
1144            (
1145                counter.clone(),
1146                s.inspect({
1147                    move |_val| {
1148                        counter.fetch_add(1, Ordering::SeqCst);
1149                    }
1150                })
1151                .boxify(),
1152            )
1153        }
1154
1155        let runtime = tokio::runtime::Runtime::new().unwrap();
1156
1157        let (counter, s) = create_stream();
1158        let params = BufferedParams {
1159            weight_limit: 10,
1160            buffer_size: 10,
1161        };
1162        let s = s.buffered_weight_limited(params);
1163        if let Ok((Some(()), s)) = runtime.block_on(s.into_future().compat()) {
1164            assert_eq!(counter.load(Ordering::SeqCst), 1);
1165            assert_eq!(runtime.block_on(s.collect().compat()).unwrap().len(), 1);
1166            assert_eq!(counter.load(Ordering::SeqCst), 2);
1167        } else {
1168            panic!("failed to block on a stream");
1169        }
1170
1171        let (counter, s) = create_stream();
1172        let params = BufferedParams {
1173            weight_limit: 200,
1174            buffer_size: 10,
1175        };
1176        let s = s.buffered_weight_limited(params);
1177        if let Ok((Some(()), s)) = runtime.block_on(s.into_future().compat()) {
1178            assert_eq!(counter.load(Ordering::SeqCst), 2);
1179            assert_eq!(runtime.block_on(s.collect().compat()).unwrap().len(), 1);
1180            assert_eq!(counter.load(Ordering::SeqCst), 2);
1181        } else {
1182            panic!("failed to block on a stream");
1183        }
1184    }
1185
1186    use std::collections::HashSet;
1187
1188    fn assert_same_elements<I, T>(src: Vec<I>, iter: T)
1189    where
1190        I: Copy + Debug + Ord,
1191        T: IntoIterator<Item = I>,
1192    {
1193        let mut dst_sorted: Vec<I> = iter.into_iter().collect();
1194        dst_sorted.sort();
1195
1196        let mut src_sorted = src;
1197        src_sorted.sort();
1198
1199        assert_eq!(src_sorted, dst_sorted);
1200    }
1201
1202    #[test]
1203    fn collect_into_vec() {
1204        let items = vec![1, 2, 3];
1205        let future = futures::stream::iter_ok::<_, ()>(items.clone()).collect_to::<Vec<i32>>();
1206        let runtime = Runtime::new().unwrap();
1207        match runtime.block_on(future.compat()) {
1208            Ok(collections) => assert_same_elements(items, collections),
1209            Err(()) => panic!("future is supposed to succeed"),
1210        }
1211    }
1212
1213    #[test]
1214    fn collect_into_set() {
1215        let items = vec![1, 2, 3];
1216        let future = futures::stream::iter_ok::<_, ()>(items.clone()).collect_to::<HashSet<i32>>();
1217        let runtime = Runtime::new().unwrap();
1218        match runtime.block_on(future.compat()) {
1219            Ok(collections) => assert_same_elements(items, collections),
1220            Err(()) => panic!("future is supposed to succeed"),
1221        }
1222    }
1223}