and_then_concurrent/
lib.rs

1//! Use on `impl Stream`s via the [`TryStreamAndThenExt`] trait.
2//!
3//! Why is this necessary? Consider the example below. We have a `Stream` from `try_unfold`, but
4//! this stream is splitting some larger stream into sub-streams, with each sub-stream
5//! represented by a channel. If we simply called `and_then`, that function's implementation,
6//! as an optimization, only keeps *one* "pending" future in its state. This means that it
7//! cannot poll the backing stream, because that might produce another future which it has no
8//! space for. So, it *must* run the pending future to completion before polling the stream
9//! again.
10//!
11//! Unfortunately, in this case, the backing stream has to be polled for our future to resolve!
12//! So using `and_then` will deadlock. Instead, this crate makes a tradeoff: it will hold a
13//! list of pending futures in a `FuturesUnordered`, so it is safe to
14//! poll the backing stream. This means that if the resulting futures don't resolve, we could
15//! have a large list of futures.
16//!
17//! ```rust
18//! # use and_then_concurrent::TryStreamAndThenExt;
19//! # use futures_util::stream::TryStreamExt;
20//! # use std::collections::HashMap;
21//! # use std::time::Duration;
22//! # use tokio::{sync::mpsc, time::sleep};
23//! # #[tokio::main]
24//! # async fn main() {
25//! let c = futures_util::stream::try_unfold(
26//!     (
27//!         0,
28//!         HashMap::<usize, mpsc::UnboundedSender<(usize, usize)>>::default(),
29//!     ),
30//!     move |(mut i, mut map)| async move {
31//!         loop {
32//!             sleep(Duration::from_millis(10)).await;
33//!             let (substream, message) = (i % 3, i);
34//!             i += 1;
35//!             if i > 25 {
36//!                 return Ok(None);
37//!             }
38//!
39//!             let mut new = None;
40//!             if map
41//!                 .entry(substream)
42//!                 .or_insert_with(|| {
43//!                     let (sub_s, sub_r) = mpsc::unbounded_channel();
44//!                     new = Some(sub_r);
45//!                     sub_s
46//!                 })
47//!                 .send((substream, message))
48//!                 .is_err()
49//!             {
50//!                 map.remove(&substream);
51//!             }
52//!
53//!             if let Some(new_sub_r) = new {
54//!                 return Ok::<_, String>(Some((new_sub_r, (i, map))));
55//!             }
56//!         }
57//!     },
58//! )
59//! // .and_then(...) would deadlock!
60//! .and_then_concurrent(|mut res| async move {
61//!     loop {
62//!         let (stream, val): (usize, usize) = match res.recv().await {
63//!             None => return Ok(()),
64//!             Some(s) => s,
65//!         };
66//!         println!("got {:?} on stream {:?}", val, stream);
67//!     }
68//! })
69//! .try_collect::<Vec<_>>();
70//! c.await.unwrap();
71//! # }
72//! ```
73
74use futures_util::{
75    future::TryFuture,
76    stream::{FuturesUnordered, Stream, TryStream},
77};
78use pin_project::pin_project;
79use std::future::Future;
80use std::pin::Pin;
81use std::task::{Context, Poll};
82
83/// Extension to [`futures_util::stream::TryStreamExt`]
84pub trait TryStreamAndThenExt: TryStream {
85    /// Chain a computation when a stream value is ready, passing `Ok` values to the closure `f`.
86    ///
87    /// This function is similar to [`futures_util::stream::TryStreamExt::and_then`], but the
88    /// stream is polled concurrently with the futures returned by `f`. An unbounded number of
89    /// futures corresponding to past stream values is kept via `FuturesUnordered`.
90    ///
91    /// See [crate-level docs](`crate`) for an explanation and usage example.
92    fn and_then_concurrent<Fut, F>(self, f: F) -> AndThenConcurrent<Self, Fut, F>
93    where
94        Self: Sized,
95        Fut: TryFuture<Error = Self::Error>,
96        F: FnMut(Self::Ok) -> Fut;
97}
98
99impl<S: TryStream> TryStreamAndThenExt for S {
100    fn and_then_concurrent<Fut, F>(self, f: F) -> AndThenConcurrent<Self, Fut, F>
101    where
102        Self: Sized,
103        Fut: TryFuture<Error = Self::Error>,
104        F: FnMut(Self::Ok) -> Fut,
105    {
106        AndThenConcurrent {
107            stream: self,
108            futs: FuturesUnordered::new(),
109            fun: f,
110        }
111    }
112}
113
114/// Stream type for [`TryStreamAndThenExt::and_then_concurrent`].
115#[pin_project(project = AndThenConcurrentProj)]
116pub struct AndThenConcurrent<St, Fut: TryFuture, F> {
117    #[pin]
118    stream: St,
119    #[pin]
120    futs: FuturesUnordered<Fut>,
121    fun: F,
122}
123
124impl<St, Fut, F, T> Stream for AndThenConcurrent<St, Fut, F>
125where
126    St: TryStream,
127    Fut: Future<Output = Result<T, St::Error>>,
128    F: FnMut(St::Ok) -> Fut,
129{
130    type Item = Result<T, St::Error>;
131
132    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
133        let AndThenConcurrentProj {
134            mut stream,
135            mut futs,
136            fun,
137        } = self.project();
138
139        match stream.as_mut().try_poll_next(cx) {
140            Poll::Ready(Some(Ok(n))) => {
141                futs.push(fun(n));
142            }
143            Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
144            Poll::Pending => {
145                if futs.is_empty() {
146                    return Poll::Pending;
147                }
148            }
149            _ => (),
150        }
151
152        let x = futs.as_mut().poll_next(cx);
153        if let Poll::Pending = x {
154            // check stream once more
155            match stream.as_mut().try_poll_next(cx) {
156                Poll::Ready(Some(Ok(n))) => {
157                    futs.push(fun(n));
158                }
159                _ => (),
160            }
161        }
162        x
163    }
164}