1#![doc = include_str!("../README.md")]
2
3use std::{
4 cell::{Cell, UnsafeCell},
5 panic::{RefUnwindSafe, UnwindSafe},
6 pin::Pin,
7 sync::Arc,
8 task::{Poll, Waker},
9};
10
11use futures_util::{Future, Stream};
12use pin_project_lite::pin_project;
13use smallvec::SmallVec;
14
/// Handle given to the closure passed to [`fn_stream`]; values reach the stream through
/// [`StreamEmitter::emit`].
pub struct StreamEmitter<T> {
    // State shared with the `FnStream` that was created together with this emitter.
    inner: Arc<UnsafeCell<Inner<T>>>,
}
19
/// Handle given to the closure passed to [`try_fn_stream`]; it can emit `Ok` values, `Err`
/// values, or ready-made `Result`s.
pub struct TryStreamEmitter<T, E> {
    // State shared with the `TryFnStream` that was created together with this emitter.
    inner: Arc<UnsafeCell<Inner<Result<T, E>>>>,
}
24
thread_local! {
    // Type-erased address of the shared state of the stream whose future is currently being
    // polled on this thread. It starts out null; `ActiveStreamPointerGuard` sets it for the
    // duration of a poll and restores the previous value afterwards. `EmitFuture::poll`
    // compares against it to reject emits made outside their own stream.
    static ACTIVE_STREAM_INNER: Cell<*const ()> = const { Cell::new(std::ptr::null()) };
}
29
/// Scope guard for the thread-local "currently polled stream" pointer.
///
/// The guard remembers the pointer that was active before it was created; its `Drop` impl puts
/// that pointer back. It is `#[must_use]` because a guard that is discarded immediately would
/// restore the old pointer before the stream's future is polled.
#[must_use = "the previous active-stream pointer is restored as soon as the guard is dropped"]
struct ActiveStreamPointerGuard {
    /// Pointer that was active before this guard replaced it.
    old_ptr: *const (),
}
34
35impl ActiveStreamPointerGuard {
36 fn set_active_ptr(ptr: *const ()) -> Self {
37 let old_ptr = ACTIVE_STREAM_INNER.with(|thread_ptr| thread_ptr.replace(ptr));
38 Self { old_ptr }
39 }
40}
41
42impl Drop for ActiveStreamPointerGuard {
43 fn drop(&mut self) {
44 ACTIVE_STREAM_INNER.with(|thread_ptr| thread_ptr.set(self.old_ptr));
45 }
46}
47
/// State shared between a stream and the emitter handed to its closure.
struct Inner<T> {
    /// Waker from the most recent `poll_next` call; `EmitFuture::poll` compares its own waker
    /// against it.
    stream_waker: Option<Waker>,
    /// Values pushed by `EmitFuture::poll` that the stream has not returned yet.
    pending_values: SmallVec<[T; 1]>,
    /// Wakers of emit futures that were polled with a waker other than `stream_waker`; the
    /// stream wakes them from `poll_next` once no values are pending.
    pending_wakers: SmallVec<[Waker; 1]>,
}
64
// Manual auto-trait impls.
//
// Every type below stores an `UnsafeCell<Inner<_>>` (behind an `Arc` or a reference), so the
// compiler does not derive `Send`/`Sync`/`UnwindSafe`/`RefUnwindSafe` for them on its own.
//
// SAFETY (applies to every `unsafe impl` in this group):
// NOTE(review): the argument below is reconstructed from this file and should be confirmed.
// `Inner` is only dereferenced in two places: inside `poll_next`, which holds `Pin<&mut Self>`
// on the stream, and inside `EmitFuture::poll`, which first asserts via `ACTIVE_STREAM_INNER`
// that it is running on the thread that is currently polling the owning stream. The `&mut`
// obtained in `poll_next` is not used while the stream's future is being polled.

// `FnStream`
unsafe impl<T: Send, Fut: Future<Output = ()> + Send> Send for FnStream<T, Fut> {}
unsafe impl<T: Send, Fut: Future<Output = ()> + Sync> Sync for FnStream<T, Fut> {}
impl<T: UnwindSafe, Fut: Future<Output = ()> + UnwindSafe> UnwindSafe for FnStream<T, Fut> {}
impl<T: RefUnwindSafe, Fut: Future<Output = ()> + RefUnwindSafe> RefUnwindSafe
    for FnStream<T, Fut>
{
}

// `TryFnStream`
unsafe impl<T: Send, E: Send, Fut: Future<Output = Result<(), E>> + Send> Send
    for TryFnStream<T, E, Fut>
{
}
unsafe impl<T: Send, E: Send, Fut: Future<Output = Result<(), E>> + Sync> Sync
    for TryFnStream<T, E, Fut>
{
}
impl<T: UnwindSafe, E: UnwindSafe, Fut: Future<Output = Result<(), E>> + UnwindSafe> UnwindSafe
    for TryFnStream<T, E, Fut>
{
}
impl<T: RefUnwindSafe, E: RefUnwindSafe, Fut: Future<Output = Result<(), E>> + RefUnwindSafe>
    RefUnwindSafe for TryFnStream<T, E, Fut>
{
}

// `StreamEmitter`
unsafe impl<T: Send> Send for StreamEmitter<T> {}
unsafe impl<T: Send> Sync for StreamEmitter<T> {}
impl<T: UnwindSafe> UnwindSafe for StreamEmitter<T> {}
impl<T: RefUnwindSafe> RefUnwindSafe for StreamEmitter<T> {}

// `TryStreamEmitter`
unsafe impl<T: Send, E: Send> Send for TryStreamEmitter<T, E> {}
unsafe impl<T: Send, E: Send> Sync for TryStreamEmitter<T, E> {}
impl<T: UnwindSafe, E: UnwindSafe> UnwindSafe for TryStreamEmitter<T, E> {}
impl<T: RefUnwindSafe, E: RefUnwindSafe> RefUnwindSafe for TryStreamEmitter<T, E> {}

// `EmitFuture`
unsafe impl<T: Send> Send for EmitFuture<'_, T> {}
unsafe impl<T: Send> Sync for EmitFuture<'_, T> {}
impl<T: UnwindSafe> UnwindSafe for EmitFuture<'_, T> {}
impl<T: RefUnwindSafe> RefUnwindSafe for EmitFuture<'_, T> {}
110
pin_project! {
    /// Stream returned by [`fn_stream`].
    pub struct FnStream<T, Fut: Future<Output = ()>> {
        // Future returned by the user's closure; `poll_next` drives it.
        #[pin]
        fut: Fut,
        // State shared with the `StreamEmitter` that was handed to the closure.
        inner: Arc<UnsafeCell<Inner<T>>>,
    }
}
119
120pub fn fn_stream<T, Fut: Future<Output = ()>>(
140 func: impl FnOnce(StreamEmitter<T>) -> Fut,
141) -> FnStream<T, Fut> {
142 FnStream::new(func)
143}
144
145impl<T, Fut: Future<Output = ()>> FnStream<T, Fut> {
146 fn new<F: FnOnce(StreamEmitter<T>) -> Fut>(func: F) -> Self {
147 let inner = Arc::new(UnsafeCell::new(Inner {
148 stream_waker: None,
149 pending_values: SmallVec::new(),
150 pending_wakers: SmallVec::new(),
151 }));
152 let emitter = StreamEmitter {
153 inner: inner.clone(),
154 };
155 let fut = func(emitter);
156 Self { fut, inner }
157 }
158}
159
160impl<T, Fut: Future<Output = ()>> Stream for FnStream<T, Fut> {
161 type Item = T;
162
163 fn poll_next(
164 self: Pin<&mut Self>,
165 cx: &mut std::task::Context<'_>,
166 ) -> Poll<Option<Self::Item>> {
167 let this = self.project();
168
169 let inner = unsafe { &mut *this.inner.get() };
174 if let Some(value) = inner.pending_values.pop() {
175 return Poll::Ready(Some(value));
176 }
177 if !inner.pending_wakers.is_empty() {
178 for waker in inner.pending_wakers.drain(..) {
179 if !waker.will_wake(cx.waker()) {
180 waker.wake();
181 }
182 }
183 }
184 if let Some(stream_waker) = inner.stream_waker.as_mut() {
185 stream_waker.clone_from(cx.waker());
186 } else {
187 inner.stream_waker = Some(cx.waker().clone());
188 }
189
190 _ = inner;
192
193 let polling_ptr_guard =
197 ActiveStreamPointerGuard::set_active_ptr(Arc::as_ptr(&*this.inner).cast());
198 let r = this.fut.poll(cx);
199 drop(polling_ptr_guard);
200
201 let inner = unsafe { &mut *this.inner.get() };
208
209 match r {
210 std::task::Poll::Ready(()) => Poll::Ready(None),
211 std::task::Poll::Pending => {
212 if let Some(value) = inner.pending_values.pop() {
213 Poll::Ready(Some(value))
214 } else {
215 Poll::Pending
216 }
217 }
218 }
219 }
220}
221
222pub fn try_fn_stream<T, E, Fut: Future<Output = Result<(), E>>>(
251 func: impl FnOnce(TryStreamEmitter<T, E>) -> Fut,
252) -> TryFnStream<T, E, Fut> {
253 TryFnStream::new(func)
254}
255
pin_project! {
    /// Stream returned by [`try_fn_stream`].
    pub struct TryFnStream<T, E, Fut: Future<Output = Result<(), E>>> {
        // Set once the future has returned `Err`; from then on `poll_next` returns `None`
        // without polling the future again.
        is_err: bool,
        // Future returned by the user's closure; `poll_next` drives it.
        #[pin]
        fut: Fut,
        // State shared with the `TryStreamEmitter` that was handed to the closure.
        inner: Arc<UnsafeCell<Inner<Result<T, E>>>>,
    }
}
265
266impl<T, E, Fut: Future<Output = Result<(), E>>> TryFnStream<T, E, Fut> {
267 fn new<F: FnOnce(TryStreamEmitter<T, E>) -> Fut>(func: F) -> Self {
268 let inner = Arc::new(UnsafeCell::new(Inner {
269 stream_waker: None,
270 pending_values: SmallVec::new(),
271 pending_wakers: SmallVec::new(),
272 }));
273 let emitter = TryStreamEmitter {
274 inner: inner.clone(),
275 };
276 let fut = func(emitter);
277 Self {
278 is_err: false,
279 fut,
280 inner,
281 }
282 }
283}
284
285impl<T, E, Fut: Future<Output = Result<(), E>>> Stream for TryFnStream<T, E, Fut> {
286 type Item = Result<T, E>;
287
288 fn poll_next(
289 self: Pin<&mut Self>,
290 cx: &mut std::task::Context<'_>,
291 ) -> Poll<Option<Self::Item>> {
292 if self.is_err {
294 return Poll::Ready(None);
295 }
296 let this = self.project();
297 let inner = unsafe { &mut *this.inner.get() };
302 if let Some(value) = inner.pending_values.pop() {
303 return Poll::Ready(Some(value));
304 }
305 if !inner.pending_wakers.is_empty() {
306 for waker in inner.pending_wakers.drain(..) {
307 if !waker.will_wake(cx.waker()) {
308 waker.wake();
309 }
310 }
311 }
312 if let Some(stream_waker) = inner.stream_waker.as_mut() {
313 stream_waker.clone_from(cx.waker());
314 } else {
315 inner.stream_waker = Some(cx.waker().clone());
316 }
317
318 _ = inner;
320
321 let polling_ptr_guard =
325 ActiveStreamPointerGuard::set_active_ptr(Arc::as_ptr(&*this.inner).cast());
326 let r = this.fut.poll(cx);
327 drop(polling_ptr_guard);
328
329 let inner = unsafe { &mut *this.inner.get() };
336 match r {
337 std::task::Poll::Ready(Ok(())) => Poll::Ready(None),
338 std::task::Poll::Ready(Err(e)) => {
339 *this.is_err = true;
340 Poll::Ready(Some(Err(e)))
341 }
342 std::task::Poll::Pending => {
343 if let Some(value) = inner.pending_values.pop() {
344 Poll::Ready(Some(value))
345 } else {
346 Poll::Pending
347 }
348 }
349 }
350 }
351}
352
353impl<T> StreamEmitter<T> {
354 #[must_use = "Ensure that emit() is awaited"]
360 pub fn emit(&'_ self, value: T) -> EmitFuture<'_, T> {
361 EmitFuture::new(&self.inner, value)
362 }
363}
364
365impl<T, E> TryStreamEmitter<T, E> {
366 #[must_use = "Ensure that emit() is awaited"]
372 pub fn emit(&'_ self, value: T) -> EmitFuture<'_, Result<T, E>> {
373 EmitFuture::new(&self.inner, Ok(value))
374 }
375
376 #[must_use = "Ensure that emit() is awaited"]
382 pub fn emit_result(&'_ self, value: Result<T, E>) -> EmitFuture<'_, Result<T, E>> {
383 EmitFuture::new(&self.inner, value)
384 }
385
386 #[must_use = "Ensure that emit_err() is awaited"]
392 pub fn emit_err(&'_ self, err: E) -> EmitFuture<'_, Result<T, E>> {
393 EmitFuture::new(&self.inner, Err(err))
394 }
395}
396
pin_project! {
    /// Future returned by the `emit*` methods of `StreamEmitter` and `TryStreamEmitter`.
    pub struct EmitFuture<'a, T> {
        // Shared state of the stream the value is destined for.
        inner: &'a UnsafeCell<Inner<T>>,
        // Value to emit; taken (leaving `None`) by the first poll.
        value: Option<T>,
    }
}
404
405impl<'a, T> EmitFuture<'a, T> {
406 fn new(inner: &'a UnsafeCell<Inner<T>>, value: T) -> Self {
407 Self {
408 inner,
409 value: Some(value),
410 }
411 }
412}
413
414impl<T> Future for EmitFuture<'_, T> {
415 type Output = ();
416
417 fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
418 let this = self.project();
419 assert!(
420 ACTIVE_STREAM_INNER.get() == std::ptr::from_ref(*this.inner).cast::<()>(),
421 "StreamEmitter::emit().await should only be called in the context of the corresponding `fn_stream()`/`try_fn_stream()`"
422 );
423 let inner = unsafe { &mut *this.inner.get() };
429
430 if let Some(value) = this.value.take() {
431 inner.pending_values.push(value);
432 let is_same_waker = if let Some(stream_waker) = inner.stream_waker.as_ref() {
433 stream_waker.will_wake(cx.waker())
434 } else {
435 false
436 };
437 if !is_same_waker {
438 inner.pending_wakers.push(cx.waker().clone());
439 }
440 Poll::Pending
441 } else if inner.pending_values.is_empty() {
442 Poll::Ready(())
446 } else {
447 Poll::Pending
448 }
449 }
450}
451
#[cfg(test)]
mod tests {
    use std::{io::ErrorKind, pin::pin};

    use futures_util::{pin_mut, stream::FuturesUnordered, StreamExt};

    use super::*;

    // Two sequential emits are yielded in order, then the stream ends.
    #[test]
    fn infallible_works() {
        futures_executor::block_on(async {
            let stream = fn_stream(|emitter| async move {
                eprintln!("stream 1");
                emitter.emit(1).await;
                eprintln!("stream 2");
                emitter.emit(2).await;
                eprintln!("stream 3");
            });
            pin_mut!(stream);
            assert_eq!(Some(1), stream.next().await);
            assert_eq!(Some(2), stream.next().await);
            assert_eq!(None, stream.next().await);
        });
    }

    // The stream may yield references to data borrowed from outside the closure.
    #[test]
    fn infallible_lifetime() {
        let a = 1;
        futures_executor::block_on(async {
            let b = 2;
            let a = &a;
            let b = &b;
            let stream = fn_stream(|emitter| async move {
                eprintln!("stream 1");
                emitter.emit(a).await;
                eprintln!("stream 2");
                emitter.emit(b).await;
                eprintln!("stream 3");
            });
            pin_mut!(stream);
            assert_eq!(Some(a), stream.next().await);
            assert_eq!(Some(b), stream.next().await);
            assert_eq!(None, stream.next().await);
        });
    }

    // An `emit()` future that is never polled emits nothing.
    #[test]
    fn infallible_unawaited_emit_is_ignored() {
        futures_executor::block_on(async {
            #[expect(
                unused_must_use,
                reason = "this code intentionally does not await emitter.emit()"
            )]
            let stream = fn_stream(|emitter| async move {
                emitter.emit(1);
                emitter.emit(2);
                emitter.emit(3).await;
            });
            pin_mut!(stream);
            assert_eq!(Some(3), stream.next().await);
            assert_eq!(None, stream.next().await);
        });
    }

    // An `Err` returned by the future becomes the last item; afterwards the stream is finished.
    #[test]
    fn fallible_works() {
        futures_executor::block_on(async {
            let stream = try_fn_stream(|emitter| async move {
                eprintln!("try stream 1");
                emitter.emit(1).await;
                eprintln!("try stream 2");
                emitter.emit(2).await;
                eprintln!("try stream 3");
                Err(std::io::Error::from(ErrorKind::Other))
            });
            pin_mut!(stream);
            assert_eq!(1, stream.next().await.unwrap().unwrap());
            assert_eq!(2, stream.next().await.unwrap().unwrap());
            assert!(stream.next().await.unwrap().is_err());
            assert!(stream.next().await.is_none());
        });
    }

    // Errors sent through `emit_err`/`emit_result` are ordinary items and do not end the stream;
    // only the error returned by the future does.
    #[test]
    fn fallible_emit_err_works() {
        futures_executor::block_on(async {
            let stream = try_fn_stream(|emitter| async move {
                eprintln!("try stream 1");
                emitter.emit(1).await;
                eprintln!("try stream 2");
                emitter.emit_result(Ok(2)).await;
                eprintln!("try stream 3");
                emitter
                    .emit_err(std::io::Error::from(ErrorKind::Other))
                    .await;
                eprintln!("try stream 4");
                emitter
                    .emit_result(Err(std::io::Error::from(ErrorKind::Other)))
                    .await;
                eprintln!("try stream 5");
                Err(std::io::Error::from(ErrorKind::Other))
            });
            pin_mut!(stream);
            assert_eq!(1, stream.next().await.unwrap().unwrap());
            assert_eq!(2, stream.next().await.unwrap().unwrap());
            assert!(stream.next().await.unwrap().is_err());
            assert!(stream.next().await.unwrap().is_err());
            assert!(stream.next().await.unwrap().is_err());
            assert!(stream.next().await.is_none());
        });
    }

    // A stream created inside an `async fn` method can borrow from `&self`.
    #[test]
    fn method_async() {
        struct St {
            a: String,
        }

        impl St {
            async fn f1(&self) -> impl Stream<Item = &str> {
                self.f2().await
            }

            #[allow(clippy::unused_async)]
            async fn f2(&self) -> impl Stream<Item = &str> {
                fn_stream(|emitter| async move {
                    emitter.emit(self.a.as_str()).await;
                    emitter.emit(self.a.as_str()).await;
                    emitter.emit(self.a.as_str()).await;
                })
            }
        }

        futures_executor::block_on(async {
            let l = St {
                a: "qwe".to_owned(),
            };
            let s = l.f1().await;
            let z: Vec<&str> = s.collect().await;
            assert_eq!(z, ["qwe", "qwe", "qwe"]);
        });
    }

    // Emitting from inside a single-branch `tokio::join!` works.
    #[test]
    fn tokio_join_one_works() {
        futures_executor::block_on(async {
            let stream = fn_stream(|emitter| async move {
                tokio::join!(async { emitter.emit(1).await },);
                emitter.emit(2).await;
            });
            pin_mut!(stream);
            assert_eq!(Some(1), stream.next().await);
            assert_eq!(Some(2), stream.next().await);
            assert_eq!(None, stream.next().await);
        });
    }

    // Several emits joined concurrently are all delivered (in any order) before the next
    // sequential emit.
    #[test]
    fn tokio_join_many_works() {
        futures_executor::block_on(async {
            let stream = fn_stream(|emitter| async move {
                eprintln!("try stream 1");
                tokio::join!(
                    async { emitter.emit(1).await },
                    async { emitter.emit(2).await },
                    async { emitter.emit(3).await },
                );
                emitter.emit(4).await;
            });
            pin_mut!(stream);
            for _ in 0..3 {
                let item = stream.next().await;
                assert!(matches!(item, Some(1..=3)));
            }
            assert_eq!(Some(4), stream.next().await);
            assert_eq!(None, stream.next().await);
        });
    }

    // Emitting from a single future driven by `FuturesUnordered` works.
    #[test]
    fn tokio_futures_unordered_one_works() {
        futures_executor::block_on(async {
            let stream = fn_stream(|emitter| async move {
                let mut futs: FuturesUnordered<_> = (1..=1)
                    .map(|i| {
                        let emitter = &emitter;
                        async move { emitter.emit(i).await }
                    })
                    .collect();
                while futs.next().await.is_some() {}
                emitter.emit(2).await;
            });
            pin_mut!(stream);
            assert_eq!(Some(1), stream.next().await);
            assert_eq!(Some(2), stream.next().await);
            assert_eq!(None, stream.next().await);
        });
    }

    // Several emits driven by `FuturesUnordered` are all delivered (in any order) before the
    // next sequential emit.
    #[test]
    fn tokio_futures_unordered_many_works() {
        futures_executor::block_on(async {
            let stream = fn_stream(|emitter| async move {
                let mut futs: FuturesUnordered<_> = (1..=3)
                    .map(|i| {
                        let emitter = &emitter;
                        async move { emitter.emit(i).await }
                    })
                    .collect();
                while futs.next().await.is_some() {}
                emitter.emit(4).await;
            });
            pin_mut!(stream);
            for _ in 1..=3 {
                let item = stream.next().await;
                assert!(matches!(item, Some(1..=3)));
            }
            assert_eq!(Some(4), stream.next().await);
            assert_eq!(None, stream.next().await);
        });
    }

    // A stream consumed inside another stream's future works when each emitter is used only
    // within its own stream.
    #[test]
    fn infallible_nested_streams_work() {
        futures_executor::block_on(async {
            let mut stream = pin!(fn_stream(|emitter| async move {
                for i in 0..3 {
                    let mut stream_2 = pin!(fn_stream(|emitter| async move {
                        for j in 0..3 {
                            emitter.emit(j).await;
                        }
                    }));
                    while let Some(item) = stream_2.next().await {
                        emitter.emit(3 * i + item).await;
                    }
                }
            }));
            let mut sum = 0;
            while let Some(item) = stream.next().await {
                sum += item;
            }
            assert_eq!(sum, 36);
        });
    }

    // Same as `infallible_nested_streams_work`, for `try_fn_stream`.
    #[test]
    fn fallible_nested_streams_work() {
        futures_executor::block_on(async {
            let mut stream = pin!(try_fn_stream(|emitter| async move {
                for i in 0..3 {
                    let mut stream_2 = pin!(try_fn_stream(|emitter| async move {
                        for j in 0..3 {
                            emitter.emit(j).await;
                        }
                        Ok::<_, ()>(())
                    }));
                    while let Some(Ok(item)) = stream_2.next().await {
                        emitter.emit(3 * i + item).await;
                    }
                }
                Ok::<_, ()>(())
            }));
            let mut sum = 0;
            while let Some(Ok(item)) = stream.next().await {
                sum += item;
            }
            assert_eq!(sum, 36);
        });
    }

    // Awaiting the outer emitter from inside the inner stream's future must panic.
    #[test]
    #[should_panic(
        expected = "StreamEmitter::emit().await should only be called in the context of the corresponding `fn_stream()`/`try_fn_stream()`"
    )]
    fn infallible_bad_nested_emit_detected() {
        futures_executor::block_on(async {
            let mut stream = pin!(fn_stream(|emitter| async move {
                for i in 0..3 {
                    let emitter_ref = &emitter;
                    let mut stream_2 = pin!(fn_stream(|emitter_2| async move {
                        emitter_2.emit(0).await;
                        for j in 0..3 {
                            emitter_ref.emit(j).await;
                        }
                    }));
                    while let Some(item) = stream_2.next().await {
                        emitter.emit(3 * i + item).await;
                    }
                }
            }));

            let mut sum = 0;
            while let Some(item) = stream.next().await {
                sum += item;
            }
            assert_eq!(sum, 36);
        });
    }

    // Same as `infallible_bad_nested_emit_detected`, for `try_fn_stream`.
    #[test]
    #[should_panic(
        expected = "StreamEmitter::emit().await should only be called in the context of the corresponding `fn_stream()`/`try_fn_stream()`"
    )]
    fn fallible_bad_nested_emit_detected() {
        futures_executor::block_on(async {
            let mut stream = pin!(try_fn_stream(|emitter| async move {
                for i in 0..3 {
                    let emitter_ref = &emitter;
                    let mut stream_2 = pin!(try_fn_stream(|emitter_2| async move {
                        emitter_2.emit(0).await;
                        for j in 0..3 {
                            emitter_ref.emit(j).await;
                        }
                        Ok::<_, ()>(())
                    }));
                    while let Some(Ok(item)) = stream_2.next().await {
                        emitter.emit(3 * i + item).await;
                    }
                }
                Ok::<_, ()>(())
            }));

            let mut sum = 0;
            while let Some(Ok(item)) = stream.next().await {
                sum += item;
            }
            assert_eq!(sum, 36);
        });
    }
}