async_fn_stream/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::{
4    pin::Pin,
5    sync::{Arc, Mutex},
6    task::{Poll, Waker},
7};
8
9use futures_util::{Future, FutureExt, Stream};
10use pin_project_lite::pin_project;
11
12/// An intermediary that transfers values from stream to its consumer
13pub struct StreamEmitter<T> {
14    inner: Arc<Mutex<Inner<T>>>,
15}
16
17/// An intermediary that transfers values from stream to its consumer
18pub struct TryStreamEmitter<T, E> {
19    inner: Arc<Mutex<Inner<Result<T, E>>>>,
20}
21
/// State shared between a stream and the emitter handed to its body.
struct Inner<T> {
    /// Item produced by `emit` that the stream's `poll_next` has not taken yet.
    value: Option<T>,
    /// Waker registered by the stream's most recent `poll_next`; taken and woken by `emit`.
    waker: Option<Waker>,
}
26
27pin_project! {
28    /// Implementation of [`Stream`] trait created by [`fn_stream`].
29    pub struct FnStream<T, Fut: Future<Output = ()>> {
30        #[pin]
31        fut: Fut,
32        inner: Arc<Mutex<Inner<T>>>,
33    }
34}
35
36/// Create a new infallible stream which is implemented by `func`.
37///
38/// Caller should pass an async function which will return successive stream elements via [`StreamEmitter::emit`].
39///
40/// # Example
41///
42/// ```rust
43/// use async_fn_stream::fn_stream;
44/// use futures_util::Stream;
45///
46/// fn build_stream() -> impl Stream<Item = i32> {
47///     fn_stream(|emitter| async move {
48///         for i in 0..3 {
49///             // yield elements from stream via `emitter`
50///             emitter.emit(i).await;
51///         }
52///     })
53/// }
54/// ```
55pub fn fn_stream<T, Fut: Future<Output = ()>>(
56    func: impl FnOnce(StreamEmitter<T>) -> Fut,
57) -> FnStream<T, Fut> {
58    FnStream::new(func)
59}
60
61impl<T, Fut: Future<Output = ()>> FnStream<T, Fut> {
62    fn new<F: FnOnce(StreamEmitter<T>) -> Fut>(func: F) -> Self {
63        let inner = Arc::new(Mutex::new(Inner {
64            value: None,
65            waker: None,
66        }));
67        let emitter = StreamEmitter {
68            inner: inner.clone(),
69        };
70        let fut = func(emitter);
71        Self { fut, inner }
72    }
73}
74
75impl<T, Fut: Future<Output = ()>> Stream for FnStream<T, Fut> {
76    type Item = T;
77
78    fn poll_next(
79        self: Pin<&mut Self>,
80        cx: &mut std::task::Context<'_>,
81    ) -> Poll<Option<Self::Item>> {
82        let mut this = self.project();
83
84        this.inner.lock().expect("Mutex was poisoned").waker = Some(cx.waker().clone());
85        let r = this.fut.poll_unpin(cx);
86        match r {
87            std::task::Poll::Ready(()) => Poll::Ready(None),
88            std::task::Poll::Pending => {
89                let value = this.inner.lock().expect("Mutex was poisoned").value.take();
90                match value {
91                    None => Poll::Pending,
92                    Some(value) => Poll::Ready(Some(value)),
93                }
94            }
95        }
96    }
97}
98
99/// Create a new fallible stream which is implemented by `func`.
100///
101/// Caller should pass an async function which can:
102///
103/// - return successive stream elements via [`StreamEmitter::emit`]
104/// - return transient errors via [`StreamEmitter::emit_err`]
105/// - return fatal errors as [`Result::Err`]
106///
107/// # Example
108/// ```rust
109/// use async_fn_stream::try_fn_stream;
110/// use futures_util::Stream;
111///
112/// fn build_stream() -> impl Stream<Item = Result<i32, anyhow::Error>> {
113///     try_fn_stream(|emitter| async move {
114///         for i in 0..3 {
115///             // yield elements from stream via `emitter`
116///             emitter.emit(i).await;
117///         }
118///
119///         // return errors view emitter without ending the stream
120///         emitter.emit_err(anyhow::anyhow!("An error happened"));
121///
122///         // return errors from stream, ending the stream
123///         Err(anyhow::anyhow!("An error happened"))
124///     })
125/// }
126/// ```
127pub fn try_fn_stream<T, E, Fut: Future<Output = Result<(), E>>>(
128    func: impl FnOnce(TryStreamEmitter<T, E>) -> Fut,
129) -> TryFnStream<T, E, Fut> {
130    TryFnStream::new(func)
131}
132
133pin_project! {
134    /// Implementation of [`Stream`] trait created by [`try_fn_stream`].
135    pub struct TryFnStream<T, E, Fut: Future<Output = Result<(), E>>> {
136        is_err: bool,
137        #[pin]
138        fut: Fut,
139        inner: Arc<Mutex<Inner<Result<T, E>>>>,
140    }
141}
142
143impl<T, E, Fut: Future<Output = Result<(), E>>> TryFnStream<T, E, Fut> {
144    fn new<F: FnOnce(TryStreamEmitter<T, E>) -> Fut>(func: F) -> Self {
145        let inner = Arc::new(Mutex::new(Inner {
146            value: None,
147            waker: None,
148        }));
149        let emitter = TryStreamEmitter {
150            inner: inner.clone(),
151        };
152        let fut = func(emitter);
153        Self {
154            is_err: false,
155            fut,
156            inner,
157        }
158    }
159}
160
161impl<T, E, Fut: Future<Output = Result<(), E>>> Stream for TryFnStream<T, E, Fut> {
162    type Item = Result<T, E>;
163
164    fn poll_next(
165        self: Pin<&mut Self>,
166        cx: &mut std::task::Context<'_>,
167    ) -> Poll<Option<Self::Item>> {
168        if self.is_err {
169            return Poll::Ready(None);
170        }
171        let mut this = self.project();
172        this.inner.lock().expect("Mutex was poisoned").waker = Some(cx.waker().clone());
173        let r = this.fut.poll_unpin(cx);
174        match r {
175            std::task::Poll::Ready(Ok(())) => Poll::Ready(None),
176            std::task::Poll::Ready(Err(e)) => {
177                *this.is_err = true;
178                Poll::Ready(Some(Err(e)))
179            }
180            std::task::Poll::Pending => {
181                let value = this.inner.lock().expect("Mutex was poisoned").value.take();
182                match value {
183                    None => Poll::Pending,
184                    Some(value) => Poll::Ready(Some(value)),
185                }
186            }
187        }
188    }
189}
190
191impl<T> StreamEmitter<T> {
192    /// Emit value from a stream and wait until stream consumer calls [`futures_util::StreamExt::next`] again.
193    ///
194    /// # Panics
195    /// Will panic if:
196    /// * `emit` is called twice without awaiting result of first call
197    /// * `emit` is called not in context of polling the stream
198    #[must_use = "Ensure that emit() is awaited"]
199    pub fn emit(&self, value: T) -> CollectFuture {
200        let mut inner = self.inner.lock().expect("Mutex was poisoned");
201        let inner = &mut *inner;
202        if inner.value.is_some() {
203            panic!("StreamEmitter::emit() was called without `.await`'ing result of previous emit")
204        }
205        inner.value = Some(value);
206        inner
207            .waker
208            .take()
209            .expect("StreamEmitter::emit() should only be called in context of Future::poll()")
210            .wake();
211        CollectFuture { polled: false }
212    }
213}
214
215impl<T, E> TryStreamEmitter<T, E> {
216    fn internal_emit(&self, res: Result<T, E>) -> CollectFuture {
217        let mut inner = self.inner.lock().expect("Mutex was poisoned");
218        let inner = &mut *inner;
219        if inner.value.is_some() {
220            panic!(
221                "TreStreamEmitter::emit/emit_err() was called without `.await`'ing result of previous collect"
222            )
223        }
224        inner.value = Some(res);
225        inner
226            .waker
227            .take()
228            .expect("TreStreamEmitter::emit/emit_err() should only be called in context of Future::poll()")
229            .wake();
230        CollectFuture { polled: false }
231    }
232
233    /// Emit value from a stream and wait until stream consumer calls [`futures_util::StreamExt::next`] again.
234    ///
235    /// # Panics
236    /// Will panic if:
237    /// * `emit`/`emit_err` is called twice without awaiting result of the first call
238    /// * `emit` is called not in context of polling the stream
239    #[must_use = "Ensure that emit() is awaited"]
240    pub fn emit(&self, value: T) -> CollectFuture {
241        self.internal_emit(Ok(value))
242    }
243
244    /// Emit error from a stream and wait until stream consumer calls [`futures_util::StreamExt::next`] again.
245    ///
246    /// # Panics
247    /// Will panic if:
248    /// * `emit`/`emit_err` is called twice without awaiting result of the first call
249    /// * `emit_err` is called not in context of polling the stream
250    #[must_use = "Ensure that emit_err() is awaited"]
251    pub fn emit_err(&self, err: E) -> CollectFuture {
252        self.internal_emit(Err(err))
253    }
254}
255
/// Future returned from [`StreamEmitter::emit`], [`TryStreamEmitter::emit`] and
/// [`TryStreamEmitter::emit_err`].
///
/// It returns `Pending` the first time it is polled, which hands control back to the stream
/// so it can pass the emitted item to its consumer. It returns `Ready(())` on every later poll.
pub struct CollectFuture {
    polled: bool,
}
260
261impl Future for CollectFuture {
262    type Output = ();
263
264    fn poll(self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
265        if self.polled {
266            Poll::Ready(())
267        } else {
268            self.get_mut().polled = true;
269            Poll::Pending
270        }
271    }
272}
273
#[cfg(test)]
mod tests {
    use std::io::ErrorKind;

    use futures_util::{pin_mut, StreamExt};

    use super::*;

    #[test]
    fn infallible_works() {
        futures_executor::block_on(async {
            let stream = fn_stream(|collector| async move {
                eprintln!("stream 1");
                collector.emit(1).await;
                eprintln!("stream 2");
                collector.emit(2).await;
                eprintln!("stream 3");
            });
            pin_mut!(stream);
            assert_eq!(Some(1), stream.next().await);
            assert_eq!(Some(2), stream.next().await);
            assert_eq!(None, stream.next().await);
        });
    }

    #[test]
    fn infallible_lifetime() {
        let a = 1;
        futures_executor::block_on(async {
            let b = 2;
            let a = &a;
            let b = &b;
            let stream = fn_stream(|collector| async move {
                eprintln!("stream 1");
                collector.emit(a).await;
                eprintln!("stream 2");
                collector.emit(b).await;
                eprintln!("stream 3");
            });
            pin_mut!(stream);
            assert_eq!(Some(a), stream.next().await);
            assert_eq!(Some(b), stream.next().await);
            assert_eq!(None, stream.next().await);
        });
    }

    // The expected message is pinned so an unrelated panic cannot make the test pass.
    #[test]
    #[should_panic(expected = "was called without `.await`'ing result of previous")]
    fn infallible_panics_on_multiple_collects() {
        futures_executor::block_on(async {
            #[allow(unused_must_use)]
            let stream = fn_stream(|collector| async move {
                eprintln!("stream 1");
                collector.emit(1);
                collector.emit(2);
                eprintln!("stream 3");
            });
            pin_mut!(stream);
            assert_eq!(Some(1), stream.next().await);
            assert_eq!(Some(2), stream.next().await);
            assert_eq!(None, stream.next().await);
        });
    }

    // Fallible counterpart of `infallible_panics_on_multiple_collects`: mixing `emit` and
    // `emit_err` without awaiting must panic too.
    #[test]
    #[should_panic(expected = "was called without `.await`'ing result of previous")]
    fn fallible_panics_on_multiple_collects() {
        futures_executor::block_on(async {
            #[allow(unused_must_use)]
            let stream = try_fn_stream(|collector| async move {
                collector.emit(1);
                collector.emit_err(std::io::Error::from(ErrorKind::Other));
                Ok::<(), std::io::Error>(())
            });
            pin_mut!(stream);
            let _ = stream.next().await;
        });
    }

    #[test]
    fn fallible_works() {
        futures_executor::block_on(async {
            let stream = try_fn_stream(|collector| async move {
                eprintln!("try stream 1");
                collector.emit(1).await;
                eprintln!("try stream 2");
                collector.emit(2).await;
                eprintln!("try stream 3");
                Err(std::io::Error::from(ErrorKind::Other))
            });
            pin_mut!(stream);
            assert_eq!(1, stream.next().await.unwrap().unwrap());
            assert_eq!(2, stream.next().await.unwrap().unwrap());
            assert!(stream.next().await.unwrap().is_err());
            assert!(stream.next().await.is_none());
        });
    }

    #[test]
    fn fallible_emit_err_works() {
        futures_executor::block_on(async {
            let stream = try_fn_stream(|collector| async move {
                eprintln!("try stream 1");
                collector.emit(1).await;
                eprintln!("try stream 2");
                collector.emit(2).await;
                eprintln!("try stream 3");
                collector
                    .emit_err(std::io::Error::from(ErrorKind::Other))
                    .await;
                eprintln!("try stream 4");
                Err(std::io::Error::from(ErrorKind::Other))
            });
            pin_mut!(stream);
            assert_eq!(1, stream.next().await.unwrap().unwrap());
            assert_eq!(2, stream.next().await.unwrap().unwrap());
            assert!(stream.next().await.unwrap().is_err());
            assert!(stream.next().await.unwrap().is_err());
            assert!(stream.next().await.is_none());
        });
    }

    #[test]
    fn method_async() {
        struct St {
            a: String,
        }

        impl St {
            async fn f1(&self) -> impl Stream<Item = &str> {
                self.f2().await
            }

            async fn f2(&self) -> impl Stream<Item = &str> {
                fn_stream(|collector| async move {
                    collector.emit(self.a.as_str()).await;
                    collector.emit(self.a.as_str()).await;
                    collector.emit(self.a.as_str()).await;
                })
            }
        }

        futures_executor::block_on(async {
            let l = St {
                a: "qwe".to_owned(),
            };
            let s = l.f1().await;
            let z: Vec<&str> = s.collect().await;
            assert_eq!(z, ["qwe", "qwe", "qwe"]);
        })
    }
}