async_fn_stream/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::{
4    cell::{Cell, UnsafeCell},
5    panic::{RefUnwindSafe, UnwindSafe},
6    pin::Pin,
7    sync::Arc,
8    task::{Poll, Waker},
9};
10
11use futures_util::{Future, Stream};
12use pin_project_lite::pin_project;
13use smallvec::SmallVec;
14
/// Handle passed to the closure given to [`fn_stream`]; values passed to
/// [`StreamEmitter::emit`] become the elements of the resulting [`FnStream`].
pub struct StreamEmitter<T> {
    inner: Arc<UnsafeCell<Inner<T>>>,
}
19
/// Handle passed to the closure given to [`try_fn_stream`]; values and errors passed to its
/// `emit*` methods become the `Ok` / `Err` elements of the resulting [`TryFnStream`].
pub struct TryStreamEmitter<T, E> {
    inner: Arc<UnsafeCell<Inner<Result<T, E>>>>,
}
24
thread_local! {
    /// Type-erased address of the `Inner<_>` whose stream is currently inside
    /// `Stream::poll_next` on this thread (the innermost one when streams are nested);
    /// null when no such call is in progress.
    static ACTIVE_STREAM_INNER: Cell<*const ()> = const { Cell::new(std::ptr::null()) };
}

/// RAII guard: installs a pointer into `ACTIVE_STREAM_INNER` and reinstalls whatever was
/// there before when dropped, including when the stack unwinds because of a panic.
struct ActiveStreamPointerGuard {
    /// The value `ACTIVE_STREAM_INNER` held when this guard was created.
    old_ptr: *const (),
}

impl ActiveStreamPointerGuard {
    /// Makes `ptr` the active pointer of the current thread; the returned guard undoes it.
    fn set_active_ptr(ptr: *const ()) -> Self {
        let previous = ACTIVE_STREAM_INNER.replace(ptr);
        Self { old_ptr: previous }
    }
}

impl Drop for ActiveStreamPointerGuard {
    fn drop(&mut self) {
        ACTIVE_STREAM_INNER.set(self.old_ptr);
    }
}
47
/// State shared between an emitter (`StreamEmitter` / `TryStreamEmitter`) and the stream it
/// feeds (`FnStream` / `TryFnStream`).
///
/// SAFETY:
/// `Inner<T>` is stored within `Arc<UnsafeCell<_>>` and is not protected by any lock.
/// Exclusive access to it is established by one of two mechanisms:
/// 1) the exclusive reference `Pin<&mut FnStream>` (or `Pin<&mut TryFnStream>`) passed to
///    `poll_next`, which only dereferences `inner` while it is *not* polling the user future;
/// 2) `ACTIVE_STREAM_INNER`, which is itself managed by `poll_next`:
///    - `ACTIVE_STREAM_INNER` is only set while the corresponding `Pin<&mut FnStream>` is on
///      the stack, i.e. for the duration of the user future's poll;
///    - `EmitFuture` may access its `inner` only if that address equals `ACTIVE_STREAM_INNER`;
///      every other access is invalid and is rejected with a panic.
struct Inner<T> {
    /// Waker of the task that last called `poll_next`. `EmitFuture` compares its own waker with
    /// it: when the stream implementation has no sub-executors the two match, so the waker does
    /// not have to be pushed to `pending_wakers` (which avoids cloning it).
    stream_waker: Option<Waker>,
    /// Due to internal concurrency, a single poll of the stream's future may emit multiple
    /// elements. All of them are stored here and yielded from the stream before the future is
    /// polled again.
    pending_values: SmallVec<[T; 1]>,
    /// Wakers of emitting sub-futures whose waker differs from `stream_waker`; `poll_next` wakes
    /// them once `pending_values` has been drained.
    pending_wakers: SmallVec<[Waker; 1]>,
}
64
// The shared state lives in `Arc<UnsafeCell<_>>` (or is borrowed as `&UnsafeCell<_>`).
// `UnsafeCell` is neither `Sync` nor `RefUnwindSafe`, so without the impls below none of the
// public types would be `Send`, `Sync`, `UnwindSafe` or `RefUnwindSafe`. Access to the cell is
// instead serialized by the protocol documented on `Inner<T>` (exclusive `Pin<&mut _>` in
// `poll_next`, plus the thread-local `ACTIVE_STREAM_INNER` check in `EmitFuture::poll`), so
// the impls below only require the payload types (`T`, `E`, `Fut`) to satisfy the respective
// trait.

// SAFETY: `FnStream` implements own synchronization (see `Inner<T>`)
unsafe impl<T: Send, Fut: Future<Output = ()> + Send> Send for FnStream<T, Fut> {}
// SAFETY: `FnStream` implements own synchronization (see `Inner<T>`)
unsafe impl<T: Send, Fut: Future<Output = ()> + Sync> Sync for FnStream<T, Fut> {}
impl<T: UnwindSafe, Fut: Future<Output = ()> + UnwindSafe> UnwindSafe for FnStream<T, Fut> {}
impl<T: RefUnwindSafe, Fut: Future<Output = ()> + RefUnwindSafe> RefUnwindSafe
    for FnStream<T, Fut>
{
}
// SAFETY: `TryFnStream` implements own synchronization (see `Inner<T>`)
unsafe impl<T: Send, E: Send, Fut: Future<Output = Result<(), E>> + Send> Send
    for TryFnStream<T, E, Fut>
{
}
// SAFETY: `TryFnStream` implements own synchronization (see `Inner<T>`)
unsafe impl<T: Send, E: Send, Fut: Future<Output = Result<(), E>> + Sync> Sync
    for TryFnStream<T, E, Fut>
{
}
impl<T: UnwindSafe, E: UnwindSafe, Fut: Future<Output = Result<(), E>> + UnwindSafe> UnwindSafe
    for TryFnStream<T, E, Fut>
{
}
impl<T: RefUnwindSafe, E: RefUnwindSafe, Fut: Future<Output = Result<(), E>> + RefUnwindSafe>
    RefUnwindSafe for TryFnStream<T, E, Fut>
{
}
// SAFETY: `StreamEmitter` implements own synchronization: its `EmitFuture`s only touch
// `Inner<T>` on the thread whose `ACTIVE_STREAM_INNER` matches (see `Inner<T>`)
unsafe impl<T: Send> Send for StreamEmitter<T> {}
// SAFETY: `StreamEmitter` implements own synchronization (see above)
unsafe impl<T: Send> Sync for StreamEmitter<T> {}
impl<T: UnwindSafe> UnwindSafe for StreamEmitter<T> {}
impl<T: RefUnwindSafe> RefUnwindSafe for StreamEmitter<T> {}
// SAFETY: `TryStreamEmitter` implements own synchronization (same protocol as `StreamEmitter`)
unsafe impl<T: Send, E: Send> Send for TryStreamEmitter<T, E> {}
// SAFETY: `TryStreamEmitter` implements own synchronization (same protocol as `StreamEmitter`)
unsafe impl<T: Send, E: Send> Sync for TryStreamEmitter<T, E> {}
impl<T: UnwindSafe, E: UnwindSafe> UnwindSafe for TryStreamEmitter<T, E> {}
impl<T: RefUnwindSafe, E: RefUnwindSafe> RefUnwindSafe for TryStreamEmitter<T, E> {}
// SAFETY: `EmitFuture` implements own synchronization: `poll` asserts that
// `ACTIVE_STREAM_INNER` equals its `inner` before dereferencing it.
unsafe impl<T: Send> Send for EmitFuture<'_, T> {}
// SAFETY: `EmitFuture` implements own synchronization (see above)
unsafe impl<T: Send> Sync for EmitFuture<'_, T> {}
impl<T: UnwindSafe> UnwindSafe for EmitFuture<'_, T> {}
impl<T: RefUnwindSafe> RefUnwindSafe for EmitFuture<'_, T> {}
110
pin_project! {
    /// Implementation of [`Stream`] trait created by [`fn_stream`].
    pub struct FnStream<T, Fut: Future<Output = ()>> {
        // The future returned by the closure passed to `fn_stream`; it owns the
        // `StreamEmitter` that pushes values into `inner`.
        #[pin]
        fut: Fut,
        // State shared with the `StreamEmitter` (see `Inner<T>`).
        inner: Arc<UnsafeCell<Inner<T>>>,
    }
}
119
120/// Create a new infallible stream which is implemented by `func`.
121///
122/// Caller should pass an async function which will return successive stream elements via [`StreamEmitter::emit`].
123///
124/// # Example
125///
126/// ```rust
127/// use async_fn_stream::fn_stream;
128/// use futures_util::Stream;
129///
130/// fn build_stream() -> impl Stream<Item = i32> {
131///     fn_stream(|emitter| async move {
132///         for i in 0..3 {
133///             // yield elements from stream via `emitter`
134///             emitter.emit(i).await;
135///         }
136///     })
137/// }
138/// ```
139pub fn fn_stream<T, Fut: Future<Output = ()>>(
140    func: impl FnOnce(StreamEmitter<T>) -> Fut,
141) -> FnStream<T, Fut> {
142    FnStream::new(func)
143}
144
145impl<T, Fut: Future<Output = ()>> FnStream<T, Fut> {
146    fn new<F: FnOnce(StreamEmitter<T>) -> Fut>(func: F) -> Self {
147        let inner = Arc::new(UnsafeCell::new(Inner {
148            stream_waker: None,
149            pending_values: SmallVec::new(),
150            pending_wakers: SmallVec::new(),
151        }));
152        let emitter = StreamEmitter {
153            inner: inner.clone(),
154        };
155        let fut = func(emitter);
156        Self { fut, inner }
157    }
158}
159
160impl<T, Fut: Future<Output = ()>> Stream for FnStream<T, Fut> {
161    type Item = T;
162
163    fn poll_next(
164        self: Pin<&mut Self>,
165        cx: &mut std::task::Context<'_>,
166    ) -> Poll<Option<Self::Item>> {
167        let this = self.project();
168
169        // SAFETY:
170        // (see safety comment for `Inner<T>`)
171        // 1) we have no aliasing, since we're holding unique reference to `Self`, and
172        // 2) `this.inner` is not deallocated for the duration of this method
173        let inner = unsafe { &mut *this.inner.get() };
174        if let Some(value) = inner.pending_values.pop() {
175            return Poll::Ready(Some(value));
176        }
177        if !inner.pending_wakers.is_empty() {
178            for waker in inner.pending_wakers.drain(..) {
179                if !waker.will_wake(cx.waker()) {
180                    waker.wake();
181                }
182            }
183        }
184        if let Some(stream_waker) = inner.stream_waker.as_mut() {
185            stream_waker.clone_from(cx.waker());
186        } else {
187            inner.stream_waker = Some(cx.waker().clone());
188        }
189
190        // SAFETY: ensure that we're not holding a reference to this.inner
191        _ = inner;
192
193        // SAFETY for Inner<T>:
194        // - `ACTIVE_STREAM_INNER` now contains valid pointer to `this.inner` during the call to `fut.poll`
195        // - `ACTIVE_STREAM_INNER` is restored after the call due to the use of guard
196        let polling_ptr_guard =
197            ActiveStreamPointerGuard::set_active_ptr(Arc::as_ptr(&*this.inner).cast());
198        let r = this.fut.poll(cx);
199        drop(polling_ptr_guard);
200
201        // SAFETY:
202        // (see safety comment for `Inner<T>`)
203        // 1) we have no aliasing, since:
204        //    - we're holding unique reference to `Self`, and
205        //    - we removed the pointer from `ACTIVE_STREAM_INNER`
206        // 2) `this.inner` is not deallocated for the duration of this method
207        let inner = unsafe { &mut *this.inner.get() };
208
209        match r {
210            std::task::Poll::Ready(()) => Poll::Ready(None),
211            std::task::Poll::Pending => {
212                if let Some(value) = inner.pending_values.pop() {
213                    Poll::Ready(Some(value))
214                } else {
215                    Poll::Pending
216                }
217            }
218        }
219    }
220}
221
222/// Create a new fallible stream which is implemented by `func`.
223///
224/// Caller should pass an async function which can:
225///
226/// - return successive stream elements via [`TryStreamEmitter::emit`]
227/// - return transient errors via [`TryStreamEmitter::emit_err`]
228/// - return fatal errors as [`Result::Err`]
229///
230/// # Example
231/// ```rust
232/// use async_fn_stream::try_fn_stream;
233/// use futures_util::Stream;
234///
235/// fn build_stream() -> impl Stream<Item = Result<i32, anyhow::Error>> {
236///     try_fn_stream(|emitter| async move {
237///         for i in 0..3 {
238///             // yield elements from stream via `emitter`
239///             emitter.emit(i).await;
240///         }
241///
242///         // return errors view emitter without ending the stream
243///         emitter.emit_err(anyhow::anyhow!("An error happened"));
244///
245///         // return errors from stream, ending the stream
246///         Err(anyhow::anyhow!("An error happened"))
247///     })
248/// }
249/// ```
250pub fn try_fn_stream<T, E, Fut: Future<Output = Result<(), E>>>(
251    func: impl FnOnce(TryStreamEmitter<T, E>) -> Fut,
252) -> TryFnStream<T, E, Fut> {
253    TryFnStream::new(func)
254}
255
pin_project! {
    /// Implementation of [`Stream`] trait created by [`try_fn_stream`].
    pub struct TryFnStream<T, E, Fut: Future<Output = Result<(), E>>> {
        // Set once `fut` has resolved to `Err`; from then on `poll_next` only yields `None`.
        is_err: bool,
        // The future returned by the closure passed to `try_fn_stream`; it owns the
        // `TryStreamEmitter` that pushes values into `inner`.
        #[pin]
        fut: Fut,
        // State shared with the `TryStreamEmitter` (see `Inner<T>`).
        inner: Arc<UnsafeCell<Inner<Result<T, E>>>>,
    }
}
265
266impl<T, E, Fut: Future<Output = Result<(), E>>> TryFnStream<T, E, Fut> {
267    fn new<F: FnOnce(TryStreamEmitter<T, E>) -> Fut>(func: F) -> Self {
268        let inner = Arc::new(UnsafeCell::new(Inner {
269            stream_waker: None,
270            pending_values: SmallVec::new(),
271            pending_wakers: SmallVec::new(),
272        }));
273        let emitter = TryStreamEmitter {
274            inner: inner.clone(),
275        };
276        let fut = func(emitter);
277        Self {
278            is_err: false,
279            fut,
280            inner,
281        }
282    }
283}
284
285impl<T, E, Fut: Future<Output = Result<(), E>>> Stream for TryFnStream<T, E, Fut> {
286    type Item = Result<T, E>;
287
288    fn poll_next(
289        self: Pin<&mut Self>,
290        cx: &mut std::task::Context<'_>,
291    ) -> Poll<Option<Self::Item>> {
292        // TODO: merge the implementation with `FnStream`
293        if self.is_err {
294            return Poll::Ready(None);
295        }
296        let this = self.project();
297        // SAFETY:
298        // (see safety comment for `Inner<T>`)
299        // 1) we have no aliasing, since we're holding unique reference to `Self`, and
300        // 2) `this.inner` is not deallocated for the duration of this method
301        let inner = unsafe { &mut *this.inner.get() };
302        if let Some(value) = inner.pending_values.pop() {
303            return Poll::Ready(Some(value));
304        }
305        if !inner.pending_wakers.is_empty() {
306            for waker in inner.pending_wakers.drain(..) {
307                if !waker.will_wake(cx.waker()) {
308                    waker.wake();
309                }
310            }
311        }
312        if let Some(stream_waker) = inner.stream_waker.as_mut() {
313            stream_waker.clone_from(cx.waker());
314        } else {
315            inner.stream_waker = Some(cx.waker().clone());
316        }
317
318        // SAFETY: ensure that we're not holding a reference to this.inner
319        _ = inner;
320
321        // SAFETY for Inner<T>:
322        // - `ACTIVE_STREAM_INNER` now contains valid pointer to `this.inner` during the call to `fut.poll`
323        // - `ACTIVE_STREAM_INNER` is restored after the call due to the use of guard
324        let polling_ptr_guard =
325            ActiveStreamPointerGuard::set_active_ptr(Arc::as_ptr(&*this.inner).cast());
326        let r = this.fut.poll(cx);
327        drop(polling_ptr_guard);
328
329        // SAFETY:
330        // (see safety comment for `Inner<T>`)
331        // 1) we have no aliasing, since:
332        //    - we're holding unique reference to `Self`, and
333        //    - we removed the pointer from `ACTIVE_STREAM_INNER`
334        // 2) `this.inner` is not deallocated for the duration of this method
335        let inner = unsafe { &mut *this.inner.get() };
336        match r {
337            std::task::Poll::Ready(Ok(())) => Poll::Ready(None),
338            std::task::Poll::Ready(Err(e)) => {
339                *this.is_err = true;
340                Poll::Ready(Some(Err(e)))
341            }
342            std::task::Poll::Pending => {
343                if let Some(value) = inner.pending_values.pop() {
344                    Poll::Ready(Some(value))
345                } else {
346                    Poll::Pending
347                }
348            }
349        }
350    }
351}
352
353impl<T> StreamEmitter<T> {
354    /// Emit value from a stream and wait until stream consumer calls [`futures_util::StreamExt::next`] again.
355    ///
356    /// # Panics
357    /// Will panic if:
358    /// * `emit` is called not in context of polling the stream
359    #[must_use = "Ensure that emit() is awaited"]
360    pub fn emit(&'_ self, value: T) -> EmitFuture<'_, T> {
361        EmitFuture::new(&self.inner, value)
362    }
363}
364
365impl<T, E> TryStreamEmitter<T, E> {
366    /// Emit value from a stream and wait until stream consumer calls [`futures_util::StreamExt::next`] again.
367    ///
368    /// # Panics
369    /// Will panic if:
370    /// * `emit` is called not in context of polling the stream
371    #[must_use = "Ensure that emit() is awaited"]
372    pub fn emit(&'_ self, value: T) -> EmitFuture<'_, Result<T, E>> {
373        EmitFuture::new(&self.inner, Ok(value))
374    }
375
376    /// Emit value from a stream and wait until stream consumer calls [`futures_util::StreamExt::next`] again.
377    ///
378    /// # Panics
379    /// Will panic if:
380    /// * `emit_result` is called not in context of polling the stream
381    #[must_use = "Ensure that emit() is awaited"]
382    pub fn emit_result(&'_ self, value: Result<T, E>) -> EmitFuture<'_, Result<T, E>> {
383        EmitFuture::new(&self.inner, value)
384    }
385
386    /// Emit error from a stream and wait until stream consumer calls [`futures_util::StreamExt::next`] again.
387    ///
388    /// # Panics
389    /// Will panic if:
390    /// * `emit_err` is called not in context of polling the stream
391    #[must_use = "Ensure that emit_err() is awaited"]
392    pub fn emit_err(&'_ self, err: E) -> EmitFuture<'_, Result<T, E>> {
393        EmitFuture::new(&self.inner, Err(err))
394    }
395}
396
pin_project! {
    /// Future returned from [`StreamEmitter::emit`] and from the `emit*` methods of
    /// [`TryStreamEmitter`].
    pub struct EmitFuture<'a, T> {
        // Shared state of the stream the value is emitted into.
        inner: &'a UnsafeCell<Inner<T>>,
        // The value to emit: `Some` until the first poll moves it into `inner.pending_values`.
        value: Option<T>,
    }
}
404
405impl<'a, T> EmitFuture<'a, T> {
406    fn new(inner: &'a UnsafeCell<Inner<T>>, value: T) -> Self {
407        Self {
408            inner,
409            value: Some(value),
410        }
411    }
412}
413
414impl<T> Future for EmitFuture<'_, T> {
415    type Output = ();
416
417    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
418        let this = self.project();
419        assert!(
420            ACTIVE_STREAM_INNER.get() == std::ptr::from_ref(*this.inner).cast::<()>(),
421            "StreamEmitter::emit().await should only be called in the context of the corresponding `fn_stream()`/`try_fn_stream()`"
422        );
423        // SAFETY:
424        // 1) we hold a unique reference to `this.inner` since we verified that ACTIVE_STREAM_INNER == self.inner
425        //    - we're calling `{Try,}FnStream::poll_next` in this thread, which holds the only other reference to the same instance of `Inner<T>`
426        //    - `{Try,}FnStream::poll_next` is not holding a reference to `self.inner` during the call to `fut.poll`
427        // 2) `this.inner` is not deallocated for the duration of this method
428        let inner = unsafe { &mut *this.inner.get() };
429
430        if let Some(value) = this.value.take() {
431            inner.pending_values.push(value);
432            let is_same_waker = if let Some(stream_waker) = inner.stream_waker.as_ref() {
433                stream_waker.will_wake(cx.waker())
434            } else {
435                false
436            };
437            if !is_same_waker {
438                inner.pending_wakers.push(cx.waker().clone());
439            }
440            Poll::Pending
441        } else if inner.pending_values.is_empty() {
442            // stream only polls the future after draining `inner.pending_values`, so this check should not be necessary in theory;
443            // this is just a safeguard against misuses; e.g. if a future calls `.emit().poll()` in a loop without yielding on `Poll::Pending`,
444            // this would lead to overflow of `inner.pending_values`
445            Poll::Ready(())
446        } else {
447            Poll::Pending
448        }
449    }
450}
451
#[cfg(test)]
mod tests {
    // Tests for `fn_stream` / `try_fn_stream`. Streams are driven with
    // `futures_executor::block_on`; `tokio::join!` and `FuturesUnordered` are used to drive
    // several emitting sub-futures within a single poll of the user future.
    use std::{io::ErrorKind, pin::pin};

    use futures_util::{pin_mut, stream::FuturesUnordered, StreamExt};

    use super::*;

    // Sequentially emitted values come out in order, then the stream ends with `None`.
    #[test]
    fn infallible_works() {
        futures_executor::block_on(async {
            let stream = fn_stream(|emitter| async move {
                eprintln!("stream 1");
                emitter.emit(1).await;
                eprintln!("stream 2");
                emitter.emit(2).await;
                eprintln!("stream 3");
            });
            pin_mut!(stream);
            assert_eq!(Some(1), stream.next().await);
            assert_eq!(Some(2), stream.next().await);
            assert_eq!(None, stream.next().await);
        });
    }

    // Emitted values may borrow from the enclosing scope (non-`'static` item type).
    #[test]
    fn infallible_lifetime() {
        let a = 1;
        futures_executor::block_on(async {
            let b = 2;
            let a = &a;
            let b = &b;
            let stream = fn_stream(|emitter| async move {
                eprintln!("stream 1");
                emitter.emit(a).await;
                eprintln!("stream 2");
                emitter.emit(b).await;
                eprintln!("stream 3");
            });
            pin_mut!(stream);
            assert_eq!(Some(a), stream.next().await);
            assert_eq!(Some(b), stream.next().await);
            assert_eq!(None, stream.next().await);
        });
    }

    // An `EmitFuture` that is never polled emits nothing; its value is dropped with it.
    #[test]
    fn infallible_unawaited_emit_is_ignored() {
        futures_executor::block_on(async {
            #[expect(
                unused_must_use,
                reason = "this code intentionally does not await emitter.emit()"
            )]
            let stream = fn_stream(|emitter| async move {
                emitter.emit(1)/* .await */;
                emitter.emit(2)/* .await */;
                emitter.emit(3).await;
            });
            pin_mut!(stream);
            assert_eq!(Some(3), stream.next().await);
            assert_eq!(None, stream.next().await);
        });
    }

    // Emitted values are yielded as `Ok`; the future's final `Err` is yielded once, then the
    // stream ends.
    #[test]
    fn fallible_works() {
        futures_executor::block_on(async {
            let stream = try_fn_stream(|emitter| async move {
                eprintln!("try stream 1");
                emitter.emit(1).await;
                eprintln!("try stream 2");
                emitter.emit(2).await;
                eprintln!("try stream 3");
                Err(std::io::Error::from(ErrorKind::Other))
            });
            pin_mut!(stream);
            assert_eq!(1, stream.next().await.unwrap().unwrap());
            assert_eq!(2, stream.next().await.unwrap().unwrap());
            assert!(stream.next().await.unwrap().is_err());
            assert!(stream.next().await.is_none());
        });
    }

    // `emit_err` and `emit_result(Err(..))` yield errors without ending the stream.
    #[test]
    fn fallible_emit_err_works() {
        futures_executor::block_on(async {
            let stream = try_fn_stream(|emitter| async move {
                eprintln!("try stream 1");
                emitter.emit(1).await;
                eprintln!("try stream 2");
                emitter.emit_result(Ok(2)).await;
                eprintln!("try stream 3");
                emitter
                    .emit_err(std::io::Error::from(ErrorKind::Other))
                    .await;
                eprintln!("try stream 4");
                emitter
                    .emit_result(Err(std::io::Error::from(ErrorKind::Other)))
                    .await;
                eprintln!("try stream 5");
                Err(std::io::Error::from(ErrorKind::Other))
            });
            pin_mut!(stream);
            assert_eq!(1, stream.next().await.unwrap().unwrap());
            assert_eq!(2, stream.next().await.unwrap().unwrap());
            assert!(stream.next().await.unwrap().is_err());
            assert!(stream.next().await.unwrap().is_err());
            assert!(stream.next().await.unwrap().is_err());
            assert!(stream.next().await.is_none());
        });
    }

    // A stream built inside an `async fn` method may borrow `self` and be returned from it.
    #[test]
    fn method_async() {
        struct St {
            a: String,
        }

        impl St {
            async fn f1(&self) -> impl Stream<Item = &str> {
                self.f2().await
            }

            #[allow(clippy::unused_async)]
            async fn f2(&self) -> impl Stream<Item = &str> {
                fn_stream(|emitter| async move {
                    emitter.emit(self.a.as_str()).await;
                    emitter.emit(self.a.as_str()).await;
                    emitter.emit(self.a.as_str()).await;
                })
            }
        }

        futures_executor::block_on(async {
            let l = St {
                a: "qwe".to_owned(),
            };
            let s = l.f1().await;
            let z: Vec<&str> = s.collect().await;
            assert_eq!(z, ["qwe", "qwe", "qwe"]);
        });
    }

    // An emit driven through `tokio::join!` (a sub-executor with its own polling logic).
    #[test]
    fn tokio_join_one_works() {
        futures_executor::block_on(async {
            let stream = fn_stream(|emitter| async move {
                tokio::join!(async { emitter.emit(1).await },);
                emitter.emit(2).await;
            });
            pin_mut!(stream);
            assert_eq!(Some(1), stream.next().await);
            assert_eq!(Some(2), stream.next().await);
            assert_eq!(None, stream.next().await);
        });
    }

    // Several concurrent emits inside one `tokio::join!`: all three values are yielded (their
    // relative order is not asserted) before the value emitted after the join.
    #[test]
    fn tokio_join_many_works() {
        futures_executor::block_on(async {
            let stream = fn_stream(|emitter| async move {
                eprintln!("try stream 1");
                tokio::join!(
                    async { emitter.emit(1).await },
                    async { emitter.emit(2).await },
                    async { emitter.emit(3).await },
                );
                emitter.emit(4).await;
            });
            pin_mut!(stream);
            for _ in 0..3 {
                let item = stream.next().await;
                assert!(matches!(item, Some(1..=3)));
            }
            assert_eq!(Some(4), stream.next().await);
            assert_eq!(None, stream.next().await);
        });
    }

    // An emit driven through `FuturesUnordered`, whose child wakers differ from the stream's.
    #[test]
    fn tokio_futures_unordered_one_works() {
        futures_executor::block_on(async {
            let stream = fn_stream(|emitter| async move {
                let mut futs: FuturesUnordered<_> = (1..=1)
                    .map(|i| {
                        let emitter = &emitter;
                        async move { emitter.emit(i).await }
                    })
                    .collect();
                while futs.next().await.is_some() {}
                emitter.emit(2).await;
            });
            pin_mut!(stream);
            assert_eq!(Some(1), stream.next().await);
            assert_eq!(Some(2), stream.next().await);
            assert_eq!(None, stream.next().await);
        });
    }

    // Several concurrent emits inside `FuturesUnordered` (relative order not asserted).
    #[test]
    fn tokio_futures_unordered_many_works() {
        futures_executor::block_on(async {
            let stream = fn_stream(|emitter| async move {
                let mut futs: FuturesUnordered<_> = (1..=3)
                    .map(|i| {
                        let emitter = &emitter;
                        async move { emitter.emit(i).await }
                    })
                    .collect();
                while futs.next().await.is_some() {}
                emitter.emit(4).await;
            });
            pin_mut!(stream);
            for _ in 1..=3 {
                let item = stream.next().await;
                assert!(matches!(item, Some(1..=3)));
            }
            assert_eq!(Some(4), stream.next().await);
            assert_eq!(None, stream.next().await);
        });
    }

    // A stream consumed inside another stream's future: each uses its own emitter.
    // Items are 3 * i + j for i, j in 0..3, i.e. 0..=8, which sum to 36.
    #[test]
    fn infallible_nested_streams_work() {
        futures_executor::block_on(async {
            let mut stream = pin!(fn_stream(|emitter| async move {
                for i in 0..3 {
                    let mut stream_2 = pin!(fn_stream(|emitter| async move {
                        for j in 0..3 {
                            emitter.emit(j).await;
                        }
                    }));
                    while let Some(item) = stream_2.next().await {
                        emitter.emit(3 * i + item).await;
                    }
                }
            }));
            let mut sum = 0;
            while let Some(item) = stream.next().await {
                sum += item;
            }
            assert_eq!(sum, 36);
        });
    }

    // Fallible counterpart of `infallible_nested_streams_work`.
    #[test]
    fn fallible_nested_streams_work() {
        futures_executor::block_on(async {
            let mut stream = pin!(try_fn_stream(|emitter| async move {
                for i in 0..3 {
                    let mut stream_2 = pin!(try_fn_stream(|emitter| async move {
                        for j in 0..3 {
                            emitter.emit(j).await;
                        }
                        Ok::<_, ()>(())
                    }));
                    while let Some(Ok(item)) = stream_2.next().await {
                        emitter.emit(3 * i + item).await;
                    }
                }
                Ok::<_, ()>(())
            }));
            let mut sum = 0;
            while let Some(Ok(item)) = stream.next().await {
                sum += item;
            }
            assert_eq!(sum, 36);
        });
    }

    // Awaiting the OUTER emitter from inside the inner stream's future must panic: while the
    // inner stream is being polled, `ACTIVE_STREAM_INNER` points at the inner stream's state.
    #[test]
    #[should_panic(
        expected = "StreamEmitter::emit().await should only be called in the context of the corresponding `fn_stream()`/`try_fn_stream()`"
    )]
    fn infallible_bad_nested_emit_detected() {
        futures_executor::block_on(async {
            let mut stream = pin!(fn_stream(|emitter| async move {
                for i in 0..3 {
                    let emitter_ref = &emitter;
                    let mut stream_2 = pin!(fn_stream(|emitter_2| async move {
                        emitter_2.emit(0).await;
                        for j in 0..3 {
                            emitter_ref.emit(j).await;
                        }
                    }));
                    while let Some(item) = stream_2.next().await {
                        emitter.emit(3 * i + item).await;
                    }
                }
            }));

            let mut sum = 0;
            while let Some(item) = stream.next().await {
                sum += item;
            }
            assert_eq!(sum, 36);
        });
    }

    // Fallible counterpart of `infallible_bad_nested_emit_detected`.
    #[test]
    #[should_panic(
        expected = "StreamEmitter::emit().await should only be called in the context of the corresponding `fn_stream()`/`try_fn_stream()`"
    )]
    fn fallible_bad_nested_emit_detected() {
        futures_executor::block_on(async {
            let mut stream = pin!(try_fn_stream(|emitter| async move {
                for i in 0..3 {
                    let emitter_ref = &emitter;
                    let mut stream_2 = pin!(try_fn_stream(|emitter_2| async move {
                        emitter_2.emit(0).await;
                        for j in 0..3 {
                            emitter_ref.emit(j).await;
                        }
                        Ok::<_, ()>(())
                    }));
                    while let Some(Ok(item)) = stream_2.next().await {
                        emitter.emit(3 * i + item).await;
                    }
                }
                Ok::<_, ()>(())
            }));

            let mut sum = 0;
            while let Some(Ok(item)) = stream.next().await {
                sum += item;
            }
            assert_eq!(sum, 36);
        });
    }
}