async_fn_stream/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::{
4    cell::{Cell, UnsafeCell},
5    panic::{RefUnwindSafe, UnwindSafe},
6    pin::Pin,
7    sync::Arc,
8    task::{Poll, Waker},
9};
10
11use futures_util::{Future, Stream};
12use pin_project_lite::pin_project;
13use smallvec::SmallVec;
14
15/// An intermediary that transfers values from stream to its consumer
16pub struct StreamEmitter<T> {
17    inner: Arc<UnsafeCell<Inner<T>>>,
18}
19
20/// An intermediary that transfers values from stream to its consumer
21pub struct TryStreamEmitter<T, E> {
22    inner: Arc<UnsafeCell<Inner<Result<T, E>>>>,
23}
24
thread_local! {
    /// Type-erased address of the `Inner<_>` belonging to the stream whose `poll_next` is
    /// currently polling its future on this thread (the innermost one when streams nest);
    /// null while no such poll is in progress.
    static ACTIVE_STREAM_INNER: Cell<*const ()> = const { Cell::new(std::ptr::null::<()>()) };
}
29
30/// A guard that ensures that `ACTIVE_STREAM_INNER` is returned to its previous value even in case of a panic.
31struct ActiveStreamPointerGuard {
32    old_ptr: *const (),
33}
34
35impl ActiveStreamPointerGuard {
36    fn set_active_ptr(ptr: *const ()) -> Self {
37        let old_ptr = ACTIVE_STREAM_INNER.with(|thread_ptr| thread_ptr.replace(ptr));
38        Self { old_ptr }
39    }
40}
41
42impl Drop for ActiveStreamPointerGuard {
43    fn drop(&mut self) {
44        ACTIVE_STREAM_INNER.with(|thread_ptr| thread_ptr.set(self.old_ptr));
45    }
46}
47
48/// SAFETY:
49/// `Inner<T>` stores the state shared between `StreamEmitter`/`TryStreamEmitter` and `FnStream`/`TryFnStream`.
50/// It is stored within `Arc<UnsafeCell<_>>`. Exclusive access to it is controlled using:
51/// 1) exclusive reference `Pin<&mut FnStream>` which is provided to `FnStream::poll_next`
52/// 2) `ACTIVE_STREAM_INNER`, which is itself managed by `FnStream::poll_next`.
53///    - `ACTIVE_STREAM_INNER` is only set when the corresponding `Pin<&mut FnStream>` is on the stack
54///    - `EmitFuture` can safely access its inner reference if it equals to `ACTIVE_STREAM_INNER` (all other accesses are invalid and lead to panics)
55struct Inner<T> {
56    // `stream_waker` is used to compare the waker of the stream future with the waker of the `Emit` future.
57    // If the stream implementation does not have sub-executors, we don't have to push wakers to `pending_wakers`, which avoids cloning it.
58    stream_waker: Option<Waker>,
59    // Due to internal concurrency, a single call to stream's future may yield multiple elements.
60    // All elements are stored here and yielded from the stream before polling the future again.
61    pending_values: SmallVec<[T; 1]>,
62    pending_wakers: SmallVec<[Waker; 1]>,
63}
64
// Manual auto-trait impls.
//
// `UnsafeCell` is neither `Sync` nor `RefUnwindSafe`, so every type below that holds an
// `Arc<UnsafeCell<Inner<_>>>` or `&UnsafeCell<Inner<_>>` would otherwise be `!Send`, `!Sync`
// and not unwind-safe regardless of its type parameters. The access protocol documented on
// `Inner<T>` is what makes these impls sound:
// - the stream touches `Inner` only from `poll_next` (through `Pin<&mut Self>`), and never
//   while its future is being polled;
// - an `EmitFuture` touches `Inner` only while the thread-local `ACTIVE_STREAM_INNER` points
//   at it, which holds only on the thread currently inside the owning stream's `poll_next`.
// Hence `Inner` is never accessed from two threads at once.

// SAFETY: the only field that is not `Send` under these bounds is the
// `Arc<UnsafeCell<Inner<T>>>`; per the protocol above, moving it to another thread cannot
// cause concurrent access. `Inner<T>` owns `T`s (hence `T: Send`) and `Waker`s (`Send`).
unsafe impl<T: Send, Fut: Future<Output = ()> + Send> Send for FnStream<T, Fut> {}
// SAFETY: nothing in this file reads `inner` or `fut` through a shared `&FnStream`;
// all access goes through `poll_next`, which takes `Pin<&mut Self>`.
unsafe impl<T: Send, Fut: Future<Output = ()> + Sync> Sync for FnStream<T, Fut> {}
// NOTE(review): the `UnwindSafe`/`RefUnwindSafe` impls assert unwind safety based only on
// the component bounds; `ActiveStreamPointerGuard` restores `ACTIVE_STREAM_INNER` during
// unwinding. Confirm that `Inner` cannot be observed in an inconsistent state after a
// caught panic.
impl<T: UnwindSafe, Fut: Future<Output = ()> + UnwindSafe> UnwindSafe for FnStream<T, Fut> {}
impl<T: RefUnwindSafe, Fut: Future<Output = ()> + RefUnwindSafe> RefUnwindSafe
    for FnStream<T, Fut>
{
}
// SAFETY: same reasoning as `Send for FnStream`; the queued items are `Result<T, E>`,
// hence both `T: Send` and `E: Send`.
unsafe impl<T: Send, E: Send, Fut: Future<Output = Result<(), E>> + Send> Send
    for TryFnStream<T, E, Fut>
{
}
// SAFETY: same reasoning as `Sync for FnStream`: no access through `&TryFnStream`.
unsafe impl<T: Send, E: Send, Fut: Future<Output = Result<(), E>> + Sync> Sync
    for TryFnStream<T, E, Fut>
{
}
impl<T: UnwindSafe, E: UnwindSafe, Fut: Future<Output = Result<(), E>> + UnwindSafe> UnwindSafe
    for TryFnStream<T, E, Fut>
{
}
impl<T: RefUnwindSafe, E: RefUnwindSafe, Fut: Future<Output = Result<(), E>> + RefUnwindSafe>
    RefUnwindSafe for TryFnStream<T, E, Fut>
{
}
// SAFETY: an emitter only accesses `Inner<T>` through `EmitFuture::poll`, which checks
// `ACTIVE_STREAM_INNER` first (see protocol above). Dropping the last `Arc` on another
// thread drops queued `T`s, which requires `T: Send`.
unsafe impl<T: Send> Send for StreamEmitter<T> {}
// SAFETY: `&StreamEmitter` only allows creating `EmitFuture`s, which are subject to the
// same `ACTIVE_STREAM_INNER` check; a thread not polling the owning stream panics instead of
// accessing `Inner<T>`.
unsafe impl<T: Send> Sync for StreamEmitter<T> {}
impl<T: UnwindSafe> UnwindSafe for StreamEmitter<T> {}
impl<T: RefUnwindSafe> RefUnwindSafe for StreamEmitter<T> {}
// SAFETY: same reasoning as `Send for StreamEmitter`, with items of type `Result<T, E>`.
unsafe impl<T: Send, E: Send> Send for TryStreamEmitter<T, E> {}
// SAFETY: same reasoning as `Sync for StreamEmitter`, with items of type `Result<T, E>`.
unsafe impl<T: Send, E: Send> Sync for TryStreamEmitter<T, E> {}
impl<T: UnwindSafe, E: UnwindSafe> UnwindSafe for TryStreamEmitter<T, E> {}
impl<T: RefUnwindSafe, E: RefUnwindSafe> RefUnwindSafe for TryStreamEmitter<T, E> {}
// SAFETY: `EmitFuture` holds a not-yet-queued `T` (hence `T: Send`) and a reference to the
// `Inner<T>`, which it dereferences only after the `ACTIVE_STREAM_INNER` check.
unsafe impl<T: Send> Send for EmitFuture<'_, T> {}
// SAFETY: nothing in this file accesses `EmitFuture`'s fields through a shared reference;
// `poll` takes `Pin<&mut Self>`.
unsafe impl<T: Send> Sync for EmitFuture<'_, T> {}
impl<T: UnwindSafe> UnwindSafe for EmitFuture<'_, T> {}
impl<T: RefUnwindSafe> RefUnwindSafe for EmitFuture<'_, T> {}
110
111pin_project! {
112    /// Implementation of [`Stream`] trait created by [`fn_stream`].
113    pub struct FnStream<T, Fut: Future<Output = ()>> {
114        #[pin]
115        fut: Fut,
116        inner: Arc<UnsafeCell<Inner<T>>>,
117    }
118}
119
120/// Create a new infallible stream which is implemented by `func`.
121///
122/// Caller should pass an async function which will return successive stream elements via [`StreamEmitter::emit`].
123///
124/// # Example
125///
126/// ```rust
127/// use async_fn_stream::fn_stream;
128/// use futures_util::Stream;
129///
130/// fn build_stream() -> impl Stream<Item = i32> {
131///     fn_stream(|emitter| async move {
132///         for i in 0..3 {
133///             // yield elements from stream via `emitter`
134///             emitter.emit(i).await;
135///         }
136///     })
137/// }
138/// ```
139pub fn fn_stream<T, Fut: Future<Output = ()>>(
140    func: impl FnOnce(StreamEmitter<T>) -> Fut,
141) -> FnStream<T, Fut> {
142    FnStream::new(func)
143}
144
145impl<T, Fut: Future<Output = ()>> FnStream<T, Fut> {
146    fn new<F: FnOnce(StreamEmitter<T>) -> Fut>(func: F) -> Self {
147        let inner = Arc::new(UnsafeCell::new(Inner {
148            stream_waker: None,
149            pending_values: SmallVec::new(),
150            pending_wakers: SmallVec::new(),
151        }));
152        let emitter = StreamEmitter {
153            inner: inner.clone(),
154        };
155        let fut = func(emitter);
156        Self { fut, inner }
157    }
158}
159
160impl<T, Fut: Future<Output = ()>> Stream for FnStream<T, Fut> {
161    type Item = T;
162
163    fn poll_next(
164        self: Pin<&mut Self>,
165        cx: &mut std::task::Context<'_>,
166    ) -> Poll<Option<Self::Item>> {
167        let this = self.project();
168
169        // SAFETY:
170        // (see safety comment for `Inner<T>`)
171        // 1) we have no aliasing, since we're holding unique reference to `Self`, and
172        // 2) `this.inner` is not deallocated for the duration of this method
173        let inner = unsafe { &mut *this.inner.get() };
174        if let Some(value) = inner.pending_values.pop() {
175            return Poll::Ready(Some(value));
176        }
177        if !inner.pending_wakers.is_empty() {
178            for waker in inner.pending_wakers.drain(..) {
179                if !waker.will_wake(cx.waker()) {
180                    waker.wake();
181                }
182            }
183        }
184        if let Some(stream_waker) = inner.stream_waker.as_mut() {
185            stream_waker.clone_from(cx.waker());
186        } else {
187            inner.stream_waker = Some(cx.waker().clone());
188        }
189
190        // SAFETY: ensure that we're not holding a reference to this.inner
191        _ = inner;
192
193        // SAFETY for Inner<T>:
194        // - `ACTIVE_STREAM_INNER` now contains valid pointer to `this.inner` during the call to `fut.poll`
195        // - `ACTIVE_STREAM_INNER` is restored after the call due to the use of guard
196        let polling_ptr_guard =
197            ActiveStreamPointerGuard::set_active_ptr(Arc::as_ptr(&*this.inner).cast());
198        let r = this.fut.poll(cx);
199        drop(polling_ptr_guard);
200
201        // SAFETY:
202        // (see safety comment for `Inner<T>`)
203        // 1) we have no aliasing, since:
204        //    - we're holding unique reference to `Self`, and
205        //    - we removed the pointer from `ACTIVE_STREAM_INNER`
206        // 2) `this.inner` is not deallocated for the duration of this method
207        let inner = unsafe { &mut *this.inner.get() };
208
209        match r {
210            std::task::Poll::Ready(()) => Poll::Ready(None),
211            std::task::Poll::Pending => {
212                if let Some(value) = inner.pending_values.pop() {
213                    Poll::Ready(Some(value))
214                } else {
215                    Poll::Pending
216                }
217            }
218        }
219    }
220}
221
222/// Create a new fallible stream which is implemented by `func`.
223///
224/// Caller should pass an async function which can:
225///
226/// - return successive stream elements via [`TryStreamEmitter::emit`]
227/// - return transient errors via [`TryStreamEmitter::emit_err`]
228/// - return fatal errors as [`Result::Err`]
229///
230/// # Example
231/// ```rust
232/// use async_fn_stream::try_fn_stream;
233/// use futures_util::Stream;
234///
235/// fn build_stream() -> impl Stream<Item = Result<i32, anyhow::Error>> {
236///     try_fn_stream(|emitter| async move {
237///         for i in 0..3 {
238///             // yield elements from stream via `emitter`
239///             emitter.emit(i).await;
240///         }
241///
242///         // return errors view emitter without ending the stream
243///         emitter.emit_err(anyhow::anyhow!("An error happened"));
244///
245///         // return errors from stream, ending the stream
246///         Err(anyhow::anyhow!("An error happened"))
247///     })
248/// }
249/// ```
250pub fn try_fn_stream<T, E, Fut: Future<Output = Result<(), E>>>(
251    func: impl FnOnce(TryStreamEmitter<T, E>) -> Fut,
252) -> TryFnStream<T, E, Fut> {
253    TryFnStream::new(func)
254}
255
256pin_project! {
257    /// Implementation of [`Stream`] trait created by [`try_fn_stream`].
258    pub struct TryFnStream<T, E, Fut: Future<Output = Result<(), E>>> {
259        is_err: bool,
260        #[pin]
261        fut: Fut,
262        inner: Arc<UnsafeCell<Inner<Result<T, E>>>>,
263    }
264}
265
266impl<T, E, Fut: Future<Output = Result<(), E>>> TryFnStream<T, E, Fut> {
267    fn new<F: FnOnce(TryStreamEmitter<T, E>) -> Fut>(func: F) -> Self {
268        let inner = Arc::new(UnsafeCell::new(Inner {
269            stream_waker: None,
270            pending_values: SmallVec::new(),
271            pending_wakers: SmallVec::new(),
272        }));
273        let emitter = TryStreamEmitter {
274            inner: inner.clone(),
275        };
276        let fut = func(emitter);
277        Self {
278            is_err: false,
279            fut,
280            inner,
281        }
282    }
283}
284
285impl<T, E, Fut: Future<Output = Result<(), E>>> Stream for TryFnStream<T, E, Fut> {
286    type Item = Result<T, E>;
287
288    fn poll_next(
289        self: Pin<&mut Self>,
290        cx: &mut std::task::Context<'_>,
291    ) -> Poll<Option<Self::Item>> {
292        // TODO: merge the implementation with `FnStream`
293        if self.is_err {
294            return Poll::Ready(None);
295        }
296        let this = self.project();
297        // SAFETY:
298        // (see safety comment for `Inner<T>`)
299        // 1) we have no aliasing, since we're holding unique reference to `Self`, and
300        // 2) `this.inner` is not deallocated for the duration of this method
301        let inner = unsafe { &mut *this.inner.get() };
302        if let Some(value) = inner.pending_values.pop() {
303            return Poll::Ready(Some(value));
304        }
305        if !inner.pending_wakers.is_empty() {
306            for waker in inner.pending_wakers.drain(..) {
307                if !waker.will_wake(cx.waker()) {
308                    waker.wake();
309                }
310            }
311        }
312        if let Some(stream_waker) = inner.stream_waker.as_mut() {
313            stream_waker.clone_from(cx.waker());
314        } else {
315            inner.stream_waker = Some(cx.waker().clone());
316        }
317
318        // SAFETY: ensure that we're not holding a reference to this.inner
319        _ = inner;
320
321        // SAFETY for Inner<T>:
322        // - `ACTIVE_STREAM_INNER` now contains valid pointer to `this.inner` during the call to `fut.poll`
323        // - `ACTIVE_STREAM_INNER` is restored after the call due to the use of guard
324        let polling_ptr_guard =
325            ActiveStreamPointerGuard::set_active_ptr(Arc::as_ptr(&*this.inner).cast());
326        let r = this.fut.poll(cx);
327        drop(polling_ptr_guard);
328
329        // SAFETY:
330        // (see safety comment for `Inner<T>`)
331        // 1) we have no aliasing, since:
332        //    - we're holding unique reference to `Self`, and
333        //    - we removed the pointer from `ACTIVE_STREAM_INNER`
334        // 2) `this.inner` is not deallocated for the duration of this method
335        let inner = unsafe { &mut *this.inner.get() };
336        match r {
337            std::task::Poll::Ready(Ok(())) => Poll::Ready(None),
338            std::task::Poll::Ready(Err(e)) => {
339                *this.is_err = true;
340                Poll::Ready(Some(Err(e)))
341            }
342            std::task::Poll::Pending => {
343                if let Some(value) = inner.pending_values.pop() {
344                    Poll::Ready(Some(value))
345                } else {
346                    Poll::Pending
347                }
348            }
349        }
350    }
351}
352
353impl<T> StreamEmitter<T> {
354    /// Emit value from a stream and wait until stream consumer calls [`futures_util::StreamExt::next`] again.
355    ///
356    /// # Panics
357    /// Will panic if:
358    /// * `emit` is called twice without awaiting result of first call
359    /// * `emit` is called not in context of polling the stream
360    #[must_use = "Ensure that emit() is awaited"]
361    pub fn emit(&'_ self, value: T) -> EmitFuture<'_, T> {
362        EmitFuture::new(&self.inner, value)
363    }
364}
365
366impl<T, E> TryStreamEmitter<T, E> {
367    /// Emit value from a stream and wait until stream consumer calls [`futures_util::StreamExt::next`] again.
368    ///
369    /// # Panics
370    /// Will panic if:
371    /// * `emit`/`emit_err` is called twice without awaiting result of the first call
372    /// * `emit` is called not in context of polling the stream
373    #[must_use = "Ensure that emit() is awaited"]
374    pub fn emit(&'_ self, value: T) -> EmitFuture<'_, Result<T, E>> {
375        EmitFuture::new(&self.inner, Ok(value))
376    }
377
378    /// Emit error from a stream and wait until stream consumer calls [`futures_util::StreamExt::next`] again.
379    ///
380    /// # Panics
381    /// Will panic if:
382    /// * `emit`/`emit_err` is called twice without awaiting result of the first call
383    /// * `emit_err` is called not in context of polling the stream
384    #[must_use = "Ensure that emit_err() is awaited"]
385    pub fn emit_err(&'_ self, err: E) -> EmitFuture<'_, Result<T, E>> {
386        EmitFuture::new(&self.inner, Err(err))
387    }
388}
389
390pin_project! {
391    /// Future returned from [`StreamEmitter::emit`].
392    pub struct EmitFuture<'a, T> {
393        inner: &'a UnsafeCell<Inner<T>>,
394        value: Option<T>,
395    }
396}
397
398impl<'a, T> EmitFuture<'a, T> {
399    fn new(inner: &'a UnsafeCell<Inner<T>>, value: T) -> Self {
400        Self {
401            inner,
402            value: Some(value),
403        }
404    }
405}
406
407impl<T> Future for EmitFuture<'_, T> {
408    type Output = ();
409
410    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
411        let this = self.project();
412        assert!(
413            ACTIVE_STREAM_INNER.get() == std::ptr::from_ref(*this.inner).cast::<()>(),
414            "StreamEmitter::emit().await should only be called in the context of the corresponding `fn_stream()`/`try_fn_stream()`"
415        );
416        // SAFETY:
417        // 1) we hold a unique reference to `this.inner` since we verified that ACTIVE_STREAM_INNER == self.inner
418        //    - we're calling `{Try,}FnStream::poll_next` in this thread, which holds the only other reference to the same instance of `Inner<T>`
419        //    - `{Try,}FnStream::poll_next` is not holding a reference to `self.inner` during the call to `fut.poll`
420        // 2) `this.inner` is not deallocated for the duration of this method
421        let inner = unsafe { &mut *this.inner.get() };
422
423        if let Some(value) = this.value.take() {
424            inner.pending_values.push(value);
425            let is_same_waker = if let Some(stream_waker) = inner.stream_waker.as_ref() {
426                stream_waker.will_wake(cx.waker())
427            } else {
428                false
429            };
430            if !is_same_waker {
431                inner.pending_wakers.push(cx.waker().clone());
432            }
433            Poll::Pending
434        } else if inner.pending_values.is_empty() {
435            // stream only polls the future after draining `inner.pending_values`, so this check should not be necessary in theory;
436            // this is just a safeguard against misuses; e.g. if a future calls `.emit().poll()` in a loop without yielding on `Poll::Pending`,
437            // this would lead to overflow of `inner.pending_values`
438            Poll::Ready(())
439        } else {
440            Poll::Pending
441        }
442    }
443}
444
// Unit tests: drive the streams with `futures_executor::block_on` and check the sequence of
// items each stream yields.
#[cfg(test)]
mod tests {
    use std::{io::ErrorKind, pin::pin};

    use futures_util::{pin_mut, stream::FuturesUnordered, StreamExt};

    use super::*;

    // Sequentially awaited emits are yielded in order, then the stream ends.
    #[test]
    fn infallible_works() {
        futures_executor::block_on(async {
            let stream = fn_stream(|emitter| async move {
                eprintln!("stream 1");
                emitter.emit(1).await;
                eprintln!("stream 2");
                emitter.emit(2).await;
                eprintln!("stream 3");
            });
            pin_mut!(stream);
            assert_eq!(Some(1), stream.next().await);
            assert_eq!(Some(2), stream.next().await);
            assert_eq!(None, stream.next().await);
        });
    }

    // The stream may emit references to data borrowed from the enclosing scopes.
    #[test]
    fn infallible_lifetime() {
        let a = 1;
        futures_executor::block_on(async {
            let b = 2;
            let a = &a;
            let b = &b;
            let stream = fn_stream(|emitter| async move {
                eprintln!("stream 1");
                emitter.emit(a).await;
                eprintln!("stream 2");
                emitter.emit(b).await;
                eprintln!("stream 3");
            });
            pin_mut!(stream);
            assert_eq!(Some(a), stream.next().await);
            assert_eq!(Some(b), stream.next().await);
            assert_eq!(None, stream.next().await);
        });
    }

    // An `emit` future that is created but never polled emits nothing (and does not panic).
    #[test]
    fn infallible_unawaited_emit_is_ignored() {
        futures_executor::block_on(async {
            #[expect(
                unused_must_use,
                reason = "this code intentionally does not await emitter.emit()"
            )]
            let stream = fn_stream(|emitter| async move {
                emitter.emit(1)/* .await */;
                emitter.emit(2)/* .await */;
                emitter.emit(3).await;
            });
            pin_mut!(stream);
            assert_eq!(Some(3), stream.next().await);
            assert_eq!(None, stream.next().await);
        });
    }

    // A fatal `Err` returned by the future is yielded once, after which the stream ends.
    #[test]
    fn fallible_works() {
        futures_executor::block_on(async {
            let stream = try_fn_stream(|emitter| async move {
                eprintln!("try stream 1");
                emitter.emit(1).await;
                eprintln!("try stream 2");
                emitter.emit(2).await;
                eprintln!("try stream 3");
                Err(std::io::Error::from(ErrorKind::Other))
            });
            pin_mut!(stream);
            assert_eq!(1, stream.next().await.unwrap().unwrap());
            assert_eq!(2, stream.next().await.unwrap().unwrap());
            assert!(stream.next().await.unwrap().is_err());
            assert!(stream.next().await.is_none());
        });
    }

    // `emit_err` yields an error without ending the stream; the fatal `Err` follows it.
    #[test]
    fn fallible_emit_err_works() {
        futures_executor::block_on(async {
            let stream = try_fn_stream(|emitter| async move {
                eprintln!("try stream 1");
                emitter.emit(1).await;
                eprintln!("try stream 2");
                emitter.emit(2).await;
                eprintln!("try stream 3");
                emitter
                    .emit_err(std::io::Error::from(ErrorKind::Other))
                    .await;
                eprintln!("try stream 4");
                Err(std::io::Error::from(ErrorKind::Other))
            });
            pin_mut!(stream);
            assert_eq!(1, stream.next().await.unwrap().unwrap());
            assert_eq!(2, stream.next().await.unwrap().unwrap());
            assert!(stream.next().await.unwrap().is_err());
            assert!(stream.next().await.unwrap().is_err());
            assert!(stream.next().await.is_none());
        });
    }

    // A stream borrowing `self` can be built and returned from (nested) async methods.
    #[test]
    fn method_async() {
        struct St {
            a: String,
        }

        impl St {
            async fn f1(&self) -> impl Stream<Item = &str> {
                self.f2().await
            }

            #[allow(clippy::unused_async)]
            async fn f2(&self) -> impl Stream<Item = &str> {
                fn_stream(|emitter| async move {
                    emitter.emit(self.a.as_str()).await;
                    emitter.emit(self.a.as_str()).await;
                    emitter.emit(self.a.as_str()).await;
                })
            }
        }

        futures_executor::block_on(async {
            let l = St {
                a: "qwe".to_owned(),
            };
            let s = l.f1().await;
            let z: Vec<&str> = s.collect().await;
            assert_eq!(z, ["qwe", "qwe", "qwe"]);
        });
    }

    // An emit inside a single-branch `tokio::join!` is delivered.
    #[test]
    fn tokio_join_one_works() {
        futures_executor::block_on(async {
            let stream = fn_stream(|emitter| async move {
                tokio::join!(async { emitter.emit(1).await },);
                emitter.emit(2).await;
            });
            pin_mut!(stream);
            assert_eq!(Some(1), stream.next().await);
            assert_eq!(Some(2), stream.next().await);
            assert_eq!(None, stream.next().await);
        });
    }

    // Emits from concurrent `tokio::join!` branches are all delivered before the next
    // sequential emit; the order among the concurrent ones is not asserted.
    #[test]
    fn tokio_join_many_works() {
        futures_executor::block_on(async {
            let stream = fn_stream(|emitter| async move {
                eprintln!("try stream 1");
                tokio::join!(
                    async { emitter.emit(1).await },
                    async { emitter.emit(2).await },
                    async { emitter.emit(3).await },
                );
                emitter.emit(4).await;
            });
            pin_mut!(stream);
            for _ in 0..3 {
                let item = stream.next().await;
                assert!(matches!(item, Some(1..=3)));
            }
            assert_eq!(Some(4), stream.next().await);
            assert_eq!(None, stream.next().await);
        });
    }

    // An emit polled through a sub-executor (`FuturesUnordered`) is delivered and completes.
    #[test]
    fn tokio_futures_unordered_one_works() {
        futures_executor::block_on(async {
            let stream = fn_stream(|emitter| async move {
                let mut futs: FuturesUnordered<_> = (1..=1)
                    .map(|i| {
                        let emitter = &emitter;
                        async move { emitter.emit(i).await }
                    })
                    .collect();
                while futs.next().await.is_some() {}
                emitter.emit(2).await;
            });
            pin_mut!(stream);
            assert_eq!(Some(1), stream.next().await);
            assert_eq!(Some(2), stream.next().await);
            assert_eq!(None, stream.next().await);
        });
    }

    // Several emits polled through `FuturesUnordered` are all delivered (order not asserted).
    #[test]
    fn tokio_futures_unordered_many_works() {
        futures_executor::block_on(async {
            let stream = fn_stream(|emitter| async move {
                let mut futs: FuturesUnordered<_> = (1..=3)
                    .map(|i| {
                        let emitter = &emitter;
                        async move { emitter.emit(i).await }
                    })
                    .collect();
                while futs.next().await.is_some() {}
                emitter.emit(4).await;
            });
            pin_mut!(stream);
            for _ in 1..=3 {
                let item = stream.next().await;
                assert!(matches!(item, Some(1..=3)));
            }
            assert_eq!(Some(4), stream.next().await);
            assert_eq!(None, stream.next().await);
        });
    }

    // A stream may consume another stream; each emitter is used only with its own stream.
    // The outer stream yields 3 * i + j for i, j in 0..3, i.e. 0..=8, which sums to 36.
    #[test]
    fn infallible_nested_streams_work() {
        futures_executor::block_on(async {
            let mut stream = pin!(fn_stream(|emitter| async move {
                for i in 0..3 {
                    let mut stream_2 = pin!(fn_stream(|emitter| async move {
                        for j in 0..3 {
                            emitter.emit(j).await;
                        }
                    }));
                    while let Some(item) = stream_2.next().await {
                        emitter.emit(3 * i + item).await;
                    }
                }
            }));
            let mut sum = 0;
            while let Some(item) = stream.next().await {
                sum += item;
            }
            assert_eq!(sum, 36);
        });
    }

    // Fallible variant of `infallible_nested_streams_work`.
    #[test]
    fn fallible_nested_streams_work() {
        futures_executor::block_on(async {
            let mut stream = pin!(try_fn_stream(|emitter| async move {
                for i in 0..3 {
                    let mut stream_2 = pin!(try_fn_stream(|emitter| async move {
                        for j in 0..3 {
                            emitter.emit(j).await;
                        }
                        Ok::<_, ()>(())
                    }));
                    while let Some(Ok(item)) = stream_2.next().await {
                        emitter.emit(3 * i + item).await;
                    }
                }
                Ok::<_, ()>(())
            }));
            let mut sum = 0;
            while let Some(Ok(item)) = stream.next().await {
                sum += item;
            }
            assert_eq!(sum, 36);
        });
    }

    // Awaiting the outer emitter (`emitter_ref`) from inside the inner stream's future must
    // panic: while the inner stream is polled, `ACTIVE_STREAM_INNER` points at the inner
    // stream's state. The assertion after the loop is not expected to be reached.
    #[test]
    #[should_panic(
        expected = "StreamEmitter::emit().await should only be called in the context of the corresponding `fn_stream()`/`try_fn_stream()`"
    )]
    fn infallible_bad_nested_emit_detected() {
        futures_executor::block_on(async {
            let mut stream = pin!(fn_stream(|emitter| async move {
                for i in 0..3 {
                    let emitter_ref = &emitter;
                    let mut stream_2 = pin!(fn_stream(|emitter_2| async move {
                        emitter_2.emit(0).await;
                        for j in 0..3 {
                            emitter_ref.emit(j).await;
                        }
                    }));
                    while let Some(item) = stream_2.next().await {
                        emitter.emit(3 * i + item).await;
                    }
                }
            }));

            let mut sum = 0;
            while let Some(item) = stream.next().await {
                sum += item;
            }
            assert_eq!(sum, 36);
        });
    }

    // Fallible variant of `infallible_bad_nested_emit_detected`.
    #[test]
    #[should_panic(
        expected = "StreamEmitter::emit().await should only be called in the context of the corresponding `fn_stream()`/`try_fn_stream()`"
    )]
    fn fallible_bad_nested_emit_detected() {
        futures_executor::block_on(async {
            let mut stream = pin!(try_fn_stream(|emitter| async move {
                for i in 0..3 {
                    let emitter_ref = &emitter;
                    let mut stream_2 = pin!(try_fn_stream(|emitter_2| async move {
                        emitter_2.emit(0).await;
                        for j in 0..3 {
                            emitter_ref.emit(j).await;
                        }
                        Ok::<_, ()>(())
                    }));
                    while let Some(Ok(item)) = stream_2.next().await {
                        emitter.emit(3 * i + item).await;
                    }
                }
                Ok::<_, ()>(())
            }));

            let mut sum = 0;
            while let Some(Ok(item)) = stream.next().await {
                sum += item;
            }
            assert_eq!(sum, 36);
        });
    }
}