async_fn_stream/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::{
4    pin::Pin,
5    sync::{atomic::AtomicBool, Arc, Mutex},
6    task::{Poll, Waker},
7};
8
9use futures_util::{Future, Stream};
10use pin_project_lite::pin_project;
11use smallvec::SmallVec;
12
13/// An intermediary that transfers values from stream to its consumer
14pub struct StreamEmitter<T> {
15    inner: Arc<Mutex<Inner<T>>>,
16}
17
18/// An intermediary that transfers values from stream to its consumer
19pub struct TryStreamEmitter<T, E> {
20    inner: Arc<Mutex<Inner<Result<T, E>>>>,
21}
22
23struct Inner<T> {
24    // polling is `true` only for duration of a single `FnStream::poll()` call, which helps detecting invalid usage (cross-thread or cross-task)
25    polling: AtomicBool,
26    // `stream_waker` is used to compare the waker of the stream future with the waker of the `Emit` future.
27    // If the stream implementation does not have sub-executors, we don't have to push wakers to `pending_wakers`, which avoids cloning it.
28    stream_waker: Option<Waker>,
29    // Due to internal concurrency, a single call to stream's future may yield multiple elements.
30    // All elements are stored here and yielded from the stream before polling the future again.
31    pending_values: SmallVec<[T; 1]>,
32    pending_wakers: SmallVec<[Waker; 1]>,
33}
34
35pin_project! {
36    /// Implementation of [`Stream`] trait created by [`fn_stream`].
37    pub struct FnStream<T, Fut: Future<Output = ()>> {
38        #[pin]
39        fut: Fut,
40        inner: Arc<Mutex<Inner<T>>>,
41    }
42}
43
44/// Create a new infallible stream which is implemented by `func`.
45///
46/// Caller should pass an async function which will return successive stream elements via [`StreamEmitter::emit`].
47///
48/// # Example
49///
50/// ```rust
51/// use async_fn_stream::fn_stream;
52/// use futures_util::Stream;
53///
54/// fn build_stream() -> impl Stream<Item = i32> {
55///     fn_stream(|emitter| async move {
56///         for i in 0..3 {
57///             // yield elements from stream via `emitter`
58///             emitter.emit(i).await;
59///         }
60///     })
61/// }
62/// ```
63pub fn fn_stream<T, Fut: Future<Output = ()>>(
64    func: impl FnOnce(StreamEmitter<T>) -> Fut,
65) -> FnStream<T, Fut> {
66    FnStream::new(func)
67}
68
69impl<T, Fut: Future<Output = ()>> FnStream<T, Fut> {
70    fn new<F: FnOnce(StreamEmitter<T>) -> Fut>(func: F) -> Self {
71        let inner = Arc::new(Mutex::new(Inner {
72            polling: AtomicBool::new(false),
73            stream_waker: None,
74            pending_values: SmallVec::new(),
75            pending_wakers: SmallVec::new(),
76        }));
77        let emitter = StreamEmitter {
78            inner: inner.clone(),
79        };
80        let fut = func(emitter);
81        Self { fut, inner }
82    }
83}
84
85impl<T, Fut: Future<Output = ()>> Stream for FnStream<T, Fut> {
86    type Item = T;
87
88    fn poll_next(
89        self: Pin<&mut Self>,
90        cx: &mut std::task::Context<'_>,
91    ) -> Poll<Option<Self::Item>> {
92        let this = self.project();
93
94        let mut inner_guard = this.inner.lock().expect("mutex was poisoned");
95        if let Some(value) = inner_guard.pending_values.pop() {
96            return Poll::Ready(Some(value));
97        }
98        if !inner_guard.pending_wakers.is_empty() {
99            for waker in inner_guard.pending_wakers.drain(..) {
100                if !waker.will_wake(cx.waker()) {
101                    waker.wake();
102                }
103            }
104        }
105        if let Some(stream_waker) = inner_guard.stream_waker.as_mut() {
106            stream_waker.clone_from(cx.waker());
107        } else {
108            inner_guard.stream_waker = Some(cx.waker().clone());
109        }
110
111        let old_polling = inner_guard
112            .polling
113            .swap(true, std::sync::atomic::Ordering::Relaxed);
114        drop(inner_guard);
115        assert!(
116            !old_polling,
117            "async-fn-stream invariant violation: polling must be false before entering poll"
118        );
119        let r = this.fut.poll(cx);
120        let mut inner_guard = this.inner.lock().expect("mutex was poisoned");
121        inner_guard
122            .polling
123            .store(false, std::sync::atomic::Ordering::Relaxed);
124        match r {
125            std::task::Poll::Ready(()) => Poll::Ready(None),
126            std::task::Poll::Pending => {
127                if let Some(value) = inner_guard.pending_values.pop() {
128                    Poll::Ready(Some(value))
129                } else {
130                    Poll::Pending
131                }
132            }
133        }
134    }
135}
136
137/// Create a new fallible stream which is implemented by `func`.
138///
139/// Caller should pass an async function which can:
140///
141/// - return successive stream elements via [`StreamEmitter::emit`]
142/// - return transient errors via [`StreamEmitter::emit_err`]
143/// - return fatal errors as [`Result::Err`]
144///
145/// # Example
146/// ```rust
147/// use async_fn_stream::try_fn_stream;
148/// use futures_util::Stream;
149///
150/// fn build_stream() -> impl Stream<Item = Result<i32, anyhow::Error>> {
151///     try_fn_stream(|emitter| async move {
152///         for i in 0..3 {
153///             // yield elements from stream via `emitter`
154///             emitter.emit(i).await;
155///         }
156///
157///         // return errors view emitter without ending the stream
158///         emitter.emit_err(anyhow::anyhow!("An error happened"));
159///
160///         // return errors from stream, ending the stream
161///         Err(anyhow::anyhow!("An error happened"))
162///     })
163/// }
164/// ```
165pub fn try_fn_stream<T, E, Fut: Future<Output = Result<(), E>>>(
166    func: impl FnOnce(TryStreamEmitter<T, E>) -> Fut,
167) -> TryFnStream<T, E, Fut> {
168    TryFnStream::new(func)
169}
170
171pin_project! {
172    /// Implementation of [`Stream`] trait created by [`try_fn_stream`].
173    pub struct TryFnStream<T, E, Fut: Future<Output = Result<(), E>>> {
174        is_err: bool,
175        #[pin]
176        fut: Fut,
177        inner: Arc<Mutex<Inner<Result<T, E>>>>,
178    }
179}
180
181impl<T, E, Fut: Future<Output = Result<(), E>>> TryFnStream<T, E, Fut> {
182    fn new<F: FnOnce(TryStreamEmitter<T, E>) -> Fut>(func: F) -> Self {
183        let inner = Arc::new(Mutex::new(Inner {
184            polling: AtomicBool::new(false),
185            stream_waker: None,
186            pending_values: SmallVec::new(),
187            pending_wakers: SmallVec::new(),
188        }));
189        let emitter = TryStreamEmitter {
190            inner: inner.clone(),
191        };
192        let fut = func(emitter);
193        Self {
194            is_err: false,
195            fut,
196            inner,
197        }
198    }
199}
200
201impl<T, E, Fut: Future<Output = Result<(), E>>> Stream for TryFnStream<T, E, Fut> {
202    type Item = Result<T, E>;
203
204    fn poll_next(
205        self: Pin<&mut Self>,
206        cx: &mut std::task::Context<'_>,
207    ) -> Poll<Option<Self::Item>> {
208        // TODO: merge the implementation with `FnStream`
209        if self.is_err {
210            return Poll::Ready(None);
211        }
212        let this = self.project();
213        let mut inner_guard = this.inner.lock().expect("mutex was poisoned");
214        if let Some(value) = inner_guard.pending_values.pop() {
215            return Poll::Ready(Some(value));
216        }
217        if !inner_guard.pending_wakers.is_empty() {
218            for waker in inner_guard.pending_wakers.drain(..) {
219                if !waker.will_wake(cx.waker()) {
220                    waker.wake();
221                }
222            }
223        }
224        if let Some(stream_waker) = inner_guard.stream_waker.as_mut() {
225            stream_waker.clone_from(cx.waker());
226        } else {
227            inner_guard.stream_waker = Some(cx.waker().clone());
228        }
229
230        let old_polling = inner_guard
231            .polling
232            .swap(true, std::sync::atomic::Ordering::Relaxed);
233        drop(inner_guard);
234        assert!(
235            !old_polling,
236            "async-fn-stream invariant violation: polling must be false before entering poll"
237        );
238        let r = this.fut.poll(cx);
239        let mut inner_guard = this.inner.lock().expect("mutex was poisoned");
240        inner_guard
241            .polling
242            .store(false, std::sync::atomic::Ordering::Relaxed);
243        match r {
244            std::task::Poll::Ready(Ok(())) => Poll::Ready(None),
245            std::task::Poll::Ready(Err(e)) => {
246                *this.is_err = true;
247                Poll::Ready(Some(Err(e)))
248            }
249            std::task::Poll::Pending => {
250                if let Some(value) = inner_guard.pending_values.pop() {
251                    Poll::Ready(Some(value))
252                } else {
253                    Poll::Pending
254                }
255            }
256        }
257    }
258}
259
260impl<T> StreamEmitter<T> {
261    /// Emit value from a stream and wait until stream consumer calls [`futures_util::StreamExt::next`] again.
262    ///
263    /// # Panics
264    /// Will panic if:
265    /// * `emit` is called twice without awaiting result of first call
266    /// * `emit` is called not in context of polling the stream
267    #[must_use = "Ensure that emit() is awaited"]
268    pub fn emit(&'_ self, value: T) -> EmitFuture<'_, T> {
269        EmitFuture::new(&self.inner, value)
270    }
271}
272
273impl<T, E> TryStreamEmitter<T, E> {
274    /// Emit value from a stream and wait until stream consumer calls [`futures_util::StreamExt::next`] again.
275    ///
276    /// # Panics
277    /// Will panic if:
278    /// * `emit`/`emit_err` is called twice without awaiting result of the first call
279    /// * `emit` is called not in context of polling the stream
280    #[must_use = "Ensure that emit() is awaited"]
281    pub fn emit(&'_ self, value: T) -> EmitFuture<'_, Result<T, E>> {
282        EmitFuture::new(&self.inner, Ok(value))
283    }
284
285    /// Emit error from a stream and wait until stream consumer calls [`futures_util::StreamExt::next`] again.
286    ///
287    /// # Panics
288    /// Will panic if:
289    /// * `emit`/`emit_err` is called twice without awaiting result of the first call
290    /// * `emit_err` is called not in context of polling the stream
291    #[must_use = "Ensure that emit_err() is awaited"]
292    pub fn emit_err(&'_ self, err: E) -> EmitFuture<'_, Result<T, E>> {
293        EmitFuture::new(&self.inner, Err(err))
294    }
295}
296
297pin_project! {
298    /// Future returned from [`StreamEmitter::emit`].
299    pub struct EmitFuture<'a, T> {
300        inner: &'a Mutex<Inner<T>>,
301        value: Option<T>,
302    }
303}
304
305impl<'a, T> EmitFuture<'a, T> {
306    fn new(inner: &'a Mutex<Inner<T>>, value: T) -> Self {
307        Self {
308            inner,
309            value: Some(value),
310        }
311    }
312}
313
314impl<T> Future for EmitFuture<'_, T> {
315    type Output = ();
316
317    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
318        let this = self.project();
319        let mut inner_guard = this.inner.lock().expect("Mutex was poisoned");
320        let inner = &mut *inner_guard;
321        assert!(
322            inner.polling.load(std::sync::atomic::Ordering::Relaxed),
323            "StreamEmitter::emit().await should only be called in context of `fn_stream()`/`try_fn_stream()`"
324        );
325
326        if let Some(value) = this.value.take() {
327            inner.pending_values.push(value);
328            let is_same_waker = if let Some(stream_waker) = inner.stream_waker.as_ref() {
329                stream_waker.will_wake(cx.waker())
330            } else {
331                false
332            };
333            if !is_same_waker {
334                inner.pending_wakers.push(cx.waker().clone());
335            }
336            Poll::Pending
337        } else if inner.pending_values.is_empty() {
338            // stream only polls the future after draining `inner.pending_values`, so this check should not be necessary in theory;
339            // this is just a safeguard against misuses; e.g. if a future calls `.emit().poll()` in a loop without yielding on `Poll::Pending`,
340            // this would lead to overflow of `inner.pending_values`
341            Poll::Ready(())
342        } else {
343            Poll::Pending
344        }
345    }
346}
347
#[cfg(test)]
mod tests {
    use std::io::ErrorKind;

    use futures_util::{pin_mut, stream::FuturesUnordered, StreamExt};

    use super::*;

    /// Sequential emits are yielded in order and the stream then ends.
    #[test]
    fn infallible_works() {
        futures_executor::block_on(async {
            let numbers = fn_stream(|tx| async move {
                eprintln!("stream 1");
                tx.emit(1).await;
                eprintln!("stream 2");
                tx.emit(2).await;
                eprintln!("stream 3");
            });
            pin_mut!(numbers);
            for expected in [Some(1), Some(2), None] {
                assert_eq!(expected, numbers.next().await);
            }
        });
    }

    /// Borrowed items with different lifetimes can be emitted.
    #[test]
    fn infallible_lifetime() {
        let outer = 1;
        futures_executor::block_on(async {
            let inner = 2;
            let (a, b) = (&outer, &inner);
            let refs = fn_stream(|tx| async move {
                eprintln!("stream 1");
                tx.emit(a).await;
                eprintln!("stream 2");
                tx.emit(b).await;
                eprintln!("stream 3");
            });
            pin_mut!(refs);
            for expected in [Some(a), Some(b), None] {
                assert_eq!(expected, refs.next().await);
            }
        });
    }

    /// Emit futures that are never polled contribute nothing to the stream.
    #[test]
    fn infallible_unawaited_emit_is_ignored() {
        futures_executor::block_on(async {
            #[expect(
                unused_must_use,
                reason = "this code intentionally does not await emitter.emit()"
            )]
            let numbers = fn_stream(|tx| async move {
                tx.emit(1)/* .await */;
                tx.emit(2)/* .await */;
                tx.emit(3).await;
            });
            pin_mut!(numbers);
            assert_eq!(Some(3), numbers.next().await);
            assert_eq!(None, numbers.next().await);
        });
    }

    /// A fatal error is yielded after the values and then the stream ends.
    #[test]
    fn fallible_works() {
        futures_executor::block_on(async {
            let results = try_fn_stream(|tx| async move {
                eprintln!("try stream 1");
                tx.emit(1).await;
                eprintln!("try stream 2");
                tx.emit(2).await;
                eprintln!("try stream 3");
                Err(std::io::Error::from(ErrorKind::Other))
            });
            pin_mut!(results);
            for expected in [1, 2] {
                assert_eq!(expected, results.next().await.unwrap().unwrap());
            }
            assert!(results.next().await.unwrap().is_err());
            assert!(results.next().await.is_none());
        });
    }

    /// `emit_err` yields an error without ending the stream.
    #[test]
    fn fallible_emit_err_works() {
        futures_executor::block_on(async {
            let results = try_fn_stream(|tx| async move {
                eprintln!("try stream 1");
                tx.emit(1).await;
                eprintln!("try stream 2");
                tx.emit(2).await;
                eprintln!("try stream 3");
                tx.emit_err(std::io::Error::from(ErrorKind::Other)).await;
                eprintln!("try stream 4");
                Err(std::io::Error::from(ErrorKind::Other))
            });
            pin_mut!(results);
            for expected in [1, 2] {
                assert_eq!(expected, results.next().await.unwrap().unwrap());
            }
            for _ in 0..2 {
                assert!(results.next().await.unwrap().is_err());
            }
            assert!(results.next().await.is_none());
        });
    }

    /// Streams borrowing from `self` can be returned from async methods.
    #[test]
    fn method_async() {
        struct St {
            a: String,
        }

        impl St {
            async fn f1(&self) -> impl Stream<Item = &str> {
                self.f2().await
            }

            #[allow(clippy::unused_async)]
            async fn f2(&self) -> impl Stream<Item = &str> {
                fn_stream(|tx| async move {
                    for _ in 0..3 {
                        tx.emit(self.a.as_str()).await;
                    }
                })
            }
        }

        futures_executor::block_on(async {
            let st = St {
                a: "qwe".to_owned(),
            };
            let collected: Vec<&str> = st.f1().await.collect().await;
            assert_eq!(collected, ["qwe", "qwe", "qwe"]);
        });
    }

    /// An emit inside a single-branch `join!` behaves like a plain emit.
    #[test]
    fn tokio_join_one_works() {
        futures_executor::block_on(async {
            let numbers = fn_stream(|tx| async move {
                tokio::join!(async { tx.emit(1).await },);
                tx.emit(2).await;
            });
            pin_mut!(numbers);
            for expected in [Some(1), Some(2), None] {
                assert_eq!(expected, numbers.next().await);
            }
        });
    }

    /// Concurrent emits under `join!` are all delivered (in unspecified order).
    #[test]
    fn tokio_join_many_works() {
        futures_executor::block_on(async {
            let numbers = fn_stream(|tx| async move {
                eprintln!("try stream 1");
                tokio::join!(
                    async { tx.emit(1).await },
                    async { tx.emit(2).await },
                    async { tx.emit(3).await },
                );
                tx.emit(4).await;
            });
            pin_mut!(numbers);
            for _ in 0..3 {
                assert!(matches!(numbers.next().await, Some(1..=3)));
            }
            assert_eq!(Some(4), numbers.next().await);
            assert_eq!(None, numbers.next().await);
        });
    }

    /// A single emit driven by a nested `FuturesUnordered` executor is delivered.
    #[test]
    fn tokio_futures_unordered_one_works() {
        futures_executor::block_on(async {
            let numbers = fn_stream(|tx| async move {
                let mut futs: FuturesUnordered<_> = (1..=1)
                    .map(|i| {
                        let tx = &tx;
                        async move { tx.emit(i).await }
                    })
                    .collect();
                while futs.next().await.is_some() {}
                tx.emit(2).await;
            });
            pin_mut!(numbers);
            for expected in [Some(1), Some(2), None] {
                assert_eq!(expected, numbers.next().await);
            }
        });
    }

    /// Several emits driven by a nested `FuturesUnordered` executor are delivered.
    #[test]
    fn tokio_futures_unordered_many_works() {
        futures_executor::block_on(async {
            let numbers = fn_stream(|tx| async move {
                let mut futs: FuturesUnordered<_> = (1..=3)
                    .map(|i| {
                        let tx = &tx;
                        async move { tx.emit(i).await }
                    })
                    .collect();
                while futs.next().await.is_some() {}
                tx.emit(4).await;
            });
            pin_mut!(numbers);
            for _ in 1..=3 {
                assert!(matches!(numbers.next().await, Some(1..=3)));
            }
            assert_eq!(Some(4), numbers.next().await);
            assert_eq!(None, numbers.next().await);
        });
    }
}
564}