1#![doc = include_str!("../README.md")]
2
3use std::{
4 cell::{Cell, UnsafeCell},
5 panic::{RefUnwindSafe, UnwindSafe},
6 pin::Pin,
7 sync::Arc,
8 task::{Poll, Waker},
9};
10
11use futures_util::{Future, Stream};
12use pin_project_lite::pin_project;
13use smallvec::SmallVec;
14
/// Handle given to the closure passed to [`fn_stream`]; items of the resulting
/// stream are produced by awaiting [`StreamEmitter::emit`].
pub struct StreamEmitter<T> {
    /// State shared with the [`FnStream`] that created this emitter.
    inner: Arc<UnsafeCell<Inner<T>>>,
}
19
/// Handle given to the closure passed to [`try_fn_stream`]; items of the resulting
/// stream are produced by awaiting [`TryStreamEmitter::emit`] or
/// [`TryStreamEmitter::emit_err`].
pub struct TryStreamEmitter<T, E> {
    /// State shared with the [`TryFnStream`] that created this emitter.
    inner: Arc<UnsafeCell<Inner<Result<T, E>>>>,
}
24
thread_local! {
    /// Address of the `Inner` owned by the stream whose generator future is being
    /// polled on this thread right now; null while no such poll is in progress.
    static ACTIVE_STREAM_INNER: Cell<*const ()> = const { Cell::new(std::ptr::null()) };
}

/// Scope guard that publishes a stream's `Inner` address in `ACTIVE_STREAM_INNER`
/// and restores the previously published address when dropped, so that polling a
/// nested stream hands the slot back to its parent afterwards.
struct ActiveStreamPointerGuard {
    /// Value `ACTIVE_STREAM_INNER` held before this guard overwrote it.
    old_ptr: *const (),
}

impl ActiveStreamPointerGuard {
    /// Makes `ptr` the active stream pointer for the current thread until the
    /// returned guard is dropped.
    fn set_active_ptr(ptr: *const ()) -> Self {
        Self {
            old_ptr: ACTIVE_STREAM_INNER.replace(ptr),
        }
    }
}

impl Drop for ActiveStreamPointerGuard {
    /// Puts back whatever pointer was active before this guard was created.
    fn drop(&mut self) {
        ACTIVE_STREAM_INNER.set(self.old_ptr);
    }
}
47
/// State shared between a stream (`FnStream` / `TryFnStream`) and its emitter.
struct Inner<T> {
    /// Waker from the most recent `poll_next` call that went on to poll the
    /// generator future (calls that return a buffered value early do not update it).
    stream_waker: Option<Waker>,
    /// Values pushed by `EmitFuture`s that the stream has not yielded yet.
    /// NOTE(review): read with `pop()`, so several buffered values come out last-in-first-out.
    pending_values: SmallVec<[T; 1]>,
    /// Wakers of `EmitFuture`s that are waiting for `pending_values` to drain;
    /// `poll_next` wakes and clears them once the buffer is empty.
    pending_wakers: SmallVec<[Waker; 1]>,
}
64
// The `UnsafeCell` inside the shared `Arc` makes the compiler-derived auto traits
// unavailable, so they are implemented by hand below.
//
// Every mutable access to `Inner` happens either in `poll_next` (which needs
// `Pin<&mut Self>`, i.e. exclusive access to the stream) or in `EmitFuture::poll`,
// which asserts that `ACTIVE_STREAM_INNER` on the *current thread* points at this
// very `Inner`. That pointer is only set for the duration of the stream's own
// `fut.poll(cx)` call. Hence all accesses are serialized on the thread that is
// currently polling the stream.

// SAFETY: moving the stream moves `fut` (requires `Fut: Send`) and may move or drop
// buffered `T` values on the new thread (requires `T: Send`).
unsafe impl<T: Send, Fut: Future<Output = ()> + Send> Send for FnStream<T, Fut> {}
// SAFETY: nothing in this file dereferences `Inner` through `&FnStream`; only `fut`
// is reachable through a shared reference, hence the `Fut: Sync` bound.
unsafe impl<T: Send, Fut: Future<Output = ()> + Sync> Sync for FnStream<T, Fut> {}
impl<T: UnwindSafe, Fut: Future<Output = ()> + UnwindSafe> UnwindSafe for FnStream<T, Fut> {}
impl<T: RefUnwindSafe, Fut: Future<Output = ()> + RefUnwindSafe> RefUnwindSafe
    for FnStream<T, Fut>
{
}
// SAFETY: same reasoning as for `FnStream`, with both `T` and `E` values buffered.
unsafe impl<T: Send, E: Send, Fut: Future<Output = Result<(), E>> + Send> Send
    for TryFnStream<T, E, Fut>
{
}
// SAFETY: same reasoning as for `FnStream`'s `Sync` impl.
unsafe impl<T: Send, E: Send, Fut: Future<Output = Result<(), E>> + Sync> Sync
    for TryFnStream<T, E, Fut>
{
}
impl<T: UnwindSafe, E: UnwindSafe, Fut: Future<Output = Result<(), E>> + UnwindSafe> UnwindSafe
    for TryFnStream<T, E, Fut>
{
}
impl<T: RefUnwindSafe, E: RefUnwindSafe, Fut: Future<Output = Result<(), E>> + RefUnwindSafe>
    RefUnwindSafe for TryFnStream<T, E, Fut>
{
}
// SAFETY: the emitter may end up on another thread (and may drop the last `Arc`,
// dropping buffered `T`s there), but `emit().await` can only touch `Inner` on the
// thread currently polling the owning stream, as enforced by the assertion in
// `EmitFuture::poll`.
unsafe impl<T: Send> Send for StreamEmitter<T> {}
// SAFETY: `&StreamEmitter` only allows creating `EmitFuture`s, whose polls are
// guarded by the same thread-local assertion.
unsafe impl<T: Send> Sync for StreamEmitter<T> {}
impl<T: UnwindSafe> UnwindSafe for StreamEmitter<T> {}
impl<T: RefUnwindSafe> RefUnwindSafe for StreamEmitter<T> {}
// SAFETY: same reasoning as for `StreamEmitter`.
unsafe impl<T: Send, E: Send> Send for TryStreamEmitter<T, E> {}
// SAFETY: same reasoning as for `StreamEmitter`.
unsafe impl<T: Send, E: Send> Sync for TryStreamEmitter<T, E> {}
impl<T: UnwindSafe, E: UnwindSafe> UnwindSafe for TryStreamEmitter<T, E> {}
impl<T: RefUnwindSafe, E: RefUnwindSafe> RefUnwindSafe for TryStreamEmitter<T, E> {}
// SAFETY: an `EmitFuture` owns at most one `T` and a reference to `Inner`; it only
// dereferences `Inner` after asserting it is polled inside the owning stream's
// `poll_next` on the current thread.
unsafe impl<T: Send> Send for EmitFuture<'_, T> {}
// SAFETY: `EmitFuture` has no methods usable through `&EmitFuture`.
unsafe impl<T: Send> Sync for EmitFuture<'_, T> {}
impl<T: UnwindSafe> UnwindSafe for EmitFuture<'_, T> {}
impl<T: RefUnwindSafe> RefUnwindSafe for EmitFuture<'_, T> {}
110
pin_project! {
    /// Stream returned by [`fn_stream`].
    pub struct FnStream<T, Fut: Future<Output = ()>> {
        // Generator future returned by the closure given to `fn_stream`; it owns the
        // `StreamEmitter` (unless the closure moved it elsewhere).
        #[pin]
        fut: Fut,
        // The stream's handle on the state shared with the emitter.
        inner: Arc<UnsafeCell<Inner<T>>>,
    }
}
119
120pub fn fn_stream<T, Fut: Future<Output = ()>>(
140 func: impl FnOnce(StreamEmitter<T>) -> Fut,
141) -> FnStream<T, Fut> {
142 FnStream::new(func)
143}
144
145impl<T, Fut: Future<Output = ()>> FnStream<T, Fut> {
146 fn new<F: FnOnce(StreamEmitter<T>) -> Fut>(func: F) -> Self {
147 let inner = Arc::new(UnsafeCell::new(Inner {
148 stream_waker: None,
149 pending_values: SmallVec::new(),
150 pending_wakers: SmallVec::new(),
151 }));
152 let emitter = StreamEmitter {
153 inner: inner.clone(),
154 };
155 let fut = func(emitter);
156 Self { fut, inner }
157 }
158}
159
160impl<T, Fut: Future<Output = ()>> Stream for FnStream<T, Fut> {
161 type Item = T;
162
163 fn poll_next(
164 self: Pin<&mut Self>,
165 cx: &mut std::task::Context<'_>,
166 ) -> Poll<Option<Self::Item>> {
167 let this = self.project();
168
169 let inner = unsafe { &mut *this.inner.get() };
174 if let Some(value) = inner.pending_values.pop() {
175 return Poll::Ready(Some(value));
176 }
177 if !inner.pending_wakers.is_empty() {
178 for waker in inner.pending_wakers.drain(..) {
179 if !waker.will_wake(cx.waker()) {
180 waker.wake();
181 }
182 }
183 }
184 if let Some(stream_waker) = inner.stream_waker.as_mut() {
185 stream_waker.clone_from(cx.waker());
186 } else {
187 inner.stream_waker = Some(cx.waker().clone());
188 }
189
190 _ = inner;
192
193 let polling_ptr_guard =
197 ActiveStreamPointerGuard::set_active_ptr(Arc::as_ptr(&*this.inner).cast());
198 let r = this.fut.poll(cx);
199 drop(polling_ptr_guard);
200
201 let inner = unsafe { &mut *this.inner.get() };
208
209 match r {
210 std::task::Poll::Ready(()) => Poll::Ready(None),
211 std::task::Poll::Pending => {
212 if let Some(value) = inner.pending_values.pop() {
213 Poll::Ready(Some(value))
214 } else {
215 Poll::Pending
216 }
217 }
218 }
219 }
220}
221
222pub fn try_fn_stream<T, E, Fut: Future<Output = Result<(), E>>>(
251 func: impl FnOnce(TryStreamEmitter<T, E>) -> Fut,
252) -> TryFnStream<T, E, Fut> {
253 TryFnStream::new(func)
254}
255
pin_project! {
    /// Stream returned by [`try_fn_stream`].
    pub struct TryFnStream<T, E, Fut: Future<Output = Result<(), E>>> {
        // Set once the generator future has returned `Err`; every later poll yields `None`.
        is_err: bool,
        // Generator future returned by the closure given to `try_fn_stream`.
        #[pin]
        fut: Fut,
        // The stream's handle on the state shared with the emitter.
        inner: Arc<UnsafeCell<Inner<Result<T, E>>>>,
    }
}
265
266impl<T, E, Fut: Future<Output = Result<(), E>>> TryFnStream<T, E, Fut> {
267 fn new<F: FnOnce(TryStreamEmitter<T, E>) -> Fut>(func: F) -> Self {
268 let inner = Arc::new(UnsafeCell::new(Inner {
269 stream_waker: None,
270 pending_values: SmallVec::new(),
271 pending_wakers: SmallVec::new(),
272 }));
273 let emitter = TryStreamEmitter {
274 inner: inner.clone(),
275 };
276 let fut = func(emitter);
277 Self {
278 is_err: false,
279 fut,
280 inner,
281 }
282 }
283}
284
285impl<T, E, Fut: Future<Output = Result<(), E>>> Stream for TryFnStream<T, E, Fut> {
286 type Item = Result<T, E>;
287
288 fn poll_next(
289 self: Pin<&mut Self>,
290 cx: &mut std::task::Context<'_>,
291 ) -> Poll<Option<Self::Item>> {
292 if self.is_err {
294 return Poll::Ready(None);
295 }
296 let this = self.project();
297 let inner = unsafe { &mut *this.inner.get() };
302 if let Some(value) = inner.pending_values.pop() {
303 return Poll::Ready(Some(value));
304 }
305 if !inner.pending_wakers.is_empty() {
306 for waker in inner.pending_wakers.drain(..) {
307 if !waker.will_wake(cx.waker()) {
308 waker.wake();
309 }
310 }
311 }
312 if let Some(stream_waker) = inner.stream_waker.as_mut() {
313 stream_waker.clone_from(cx.waker());
314 } else {
315 inner.stream_waker = Some(cx.waker().clone());
316 }
317
318 _ = inner;
320
321 let polling_ptr_guard =
325 ActiveStreamPointerGuard::set_active_ptr(Arc::as_ptr(&*this.inner).cast());
326 let r = this.fut.poll(cx);
327 drop(polling_ptr_guard);
328
329 let inner = unsafe { &mut *this.inner.get() };
336 match r {
337 std::task::Poll::Ready(Ok(())) => Poll::Ready(None),
338 std::task::Poll::Ready(Err(e)) => {
339 *this.is_err = true;
340 Poll::Ready(Some(Err(e)))
341 }
342 std::task::Poll::Pending => {
343 if let Some(value) = inner.pending_values.pop() {
344 Poll::Ready(Some(value))
345 } else {
346 Poll::Pending
347 }
348 }
349 }
350 }
351}
352
353impl<T> StreamEmitter<T> {
354 #[must_use = "Ensure that emit() is awaited"]
361 pub fn emit(&'_ self, value: T) -> EmitFuture<'_, T> {
362 EmitFuture::new(&self.inner, value)
363 }
364}
365
366impl<T, E> TryStreamEmitter<T, E> {
367 #[must_use = "Ensure that emit() is awaited"]
374 pub fn emit(&'_ self, value: T) -> EmitFuture<'_, Result<T, E>> {
375 EmitFuture::new(&self.inner, Ok(value))
376 }
377
378 #[must_use = "Ensure that emit_err() is awaited"]
385 pub fn emit_err(&'_ self, err: E) -> EmitFuture<'_, Result<T, E>> {
386 EmitFuture::new(&self.inner, Err(err))
387 }
388}
389
pin_project! {
    /// Future returned by `StreamEmitter::emit`, `TryStreamEmitter::emit` and
    /// `TryStreamEmitter::emit_err`.
    pub struct EmitFuture<'a, T> {
        // State shared with the stream that should yield `value`.
        inner: &'a UnsafeCell<Inner<T>>,
        // Value still to be handed over; `None` once it has been pushed into
        // `pending_values` by the first poll.
        value: Option<T>,
    }
}
397
398impl<'a, T> EmitFuture<'a, T> {
399 fn new(inner: &'a UnsafeCell<Inner<T>>, value: T) -> Self {
400 Self {
401 inner,
402 value: Some(value),
403 }
404 }
405}
406
407impl<T> Future for EmitFuture<'_, T> {
408 type Output = ();
409
410 fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
411 let this = self.project();
412 assert!(
413 ACTIVE_STREAM_INNER.get() == std::ptr::from_ref(*this.inner).cast::<()>(),
414 "StreamEmitter::emit().await should only be called in the context of the corresponding `fn_stream()`/`try_fn_stream()`"
415 );
416 let inner = unsafe { &mut *this.inner.get() };
422
423 if let Some(value) = this.value.take() {
424 inner.pending_values.push(value);
425 let is_same_waker = if let Some(stream_waker) = inner.stream_waker.as_ref() {
426 stream_waker.will_wake(cx.waker())
427 } else {
428 false
429 };
430 if !is_same_waker {
431 inner.pending_wakers.push(cx.waker().clone());
432 }
433 Poll::Pending
434 } else if inner.pending_values.is_empty() {
435 Poll::Ready(())
439 } else {
440 Poll::Pending
441 }
442 }
443}
444
// Behavioural tests for `fn_stream` / `try_fn_stream`, driven by `futures_executor::block_on`.
#[cfg(test)]
mod tests {
    use std::{io::ErrorKind, pin::pin};

    use futures_util::{pin_mut, stream::FuturesUnordered, StreamExt};

    use super::*;

    // Sequentially emitted values come out in order; the stream ends when the
    // generator future completes.
    #[test]
    fn infallible_works() {
        futures_executor::block_on(async {
            let stream = fn_stream(|emitter| async move {
                eprintln!("stream 1");
                emitter.emit(1).await;
                eprintln!("stream 2");
                emitter.emit(2).await;
                eprintln!("stream 3");
            });
            pin_mut!(stream);
            assert_eq!(Some(1), stream.next().await);
            assert_eq!(Some(2), stream.next().await);
            assert_eq!(None, stream.next().await);
        });
    }

    // Items may borrow from the enclosing scope (non-'static `T`).
    #[test]
    fn infallible_lifetime() {
        let a = 1;
        futures_executor::block_on(async {
            let b = 2;
            let a = &a;
            let b = &b;
            let stream = fn_stream(|emitter| async move {
                eprintln!("stream 1");
                emitter.emit(a).await;
                eprintln!("stream 2");
                emitter.emit(b).await;
                eprintln!("stream 3");
            });
            pin_mut!(stream);
            assert_eq!(Some(a), stream.next().await);
            assert_eq!(Some(b), stream.next().await);
            assert_eq!(None, stream.next().await);
        });
    }

    // An `emit()` future that is dropped without being polled never reaches the stream.
    #[test]
    fn infallible_unawaited_emit_is_ignored() {
        futures_executor::block_on(async {
            #[expect(
                unused_must_use,
                reason = "this code intentionally does not await emitter.emit()"
            )]
            let stream = fn_stream(|emitter| async move {
                emitter.emit(1);
                emitter.emit(2);
                emitter.emit(3).await;
            });
            pin_mut!(stream);
            assert_eq!(Some(3), stream.next().await);
            assert_eq!(None, stream.next().await);
        });
    }

    // Returning `Err` from the generator yields one final `Err` item, then `None`.
    #[test]
    fn fallible_works() {
        futures_executor::block_on(async {
            let stream = try_fn_stream(|emitter| async move {
                eprintln!("try stream 1");
                emitter.emit(1).await;
                eprintln!("try stream 2");
                emitter.emit(2).await;
                eprintln!("try stream 3");
                Err(std::io::Error::from(ErrorKind::Other))
            });
            pin_mut!(stream);
            assert_eq!(1, stream.next().await.unwrap().unwrap());
            assert_eq!(2, stream.next().await.unwrap().unwrap());
            assert!(stream.next().await.unwrap().is_err());
            assert!(stream.next().await.is_none());
        });
    }

    // `emit_err` yields an `Err` item without ending the stream; the generator keeps running.
    #[test]
    fn fallible_emit_err_works() {
        futures_executor::block_on(async {
            let stream = try_fn_stream(|emitter| async move {
                eprintln!("try stream 1");
                emitter.emit(1).await;
                eprintln!("try stream 2");
                emitter.emit(2).await;
                eprintln!("try stream 3");
                emitter
                    .emit_err(std::io::Error::from(ErrorKind::Other))
                    .await;
                eprintln!("try stream 4");
                Err(std::io::Error::from(ErrorKind::Other))
            });
            pin_mut!(stream);
            assert_eq!(1, stream.next().await.unwrap().unwrap());
            assert_eq!(2, stream.next().await.unwrap().unwrap());
            assert!(stream.next().await.unwrap().is_err());
            assert!(stream.next().await.unwrap().is_err());
            assert!(stream.next().await.is_none());
        });
    }

    // A stream can borrow `self` and be returned from async methods.
    #[test]
    fn method_async() {
        struct St {
            a: String,
        }

        impl St {
            async fn f1(&self) -> impl Stream<Item = &str> {
                self.f2().await
            }

            #[allow(clippy::unused_async)]
            async fn f2(&self) -> impl Stream<Item = &str> {
                fn_stream(|emitter| async move {
                    emitter.emit(self.a.as_str()).await;
                    emitter.emit(self.a.as_str()).await;
                    emitter.emit(self.a.as_str()).await;
                })
            }
        }

        futures_executor::block_on(async {
            let l = St {
                a: "qwe".to_owned(),
            };
            let s = l.f1().await;
            let z: Vec<&str> = s.collect().await;
            assert_eq!(z, ["qwe", "qwe", "qwe"]);
        });
    }

    // `emit()` awaited inside `tokio::join!` (which polls children with the parent's waker).
    #[test]
    fn tokio_join_one_works() {
        futures_executor::block_on(async {
            let stream = fn_stream(|emitter| async move {
                tokio::join!(async { emitter.emit(1).await },);
                emitter.emit(2).await;
            });
            pin_mut!(stream);
            assert_eq!(Some(1), stream.next().await);
            assert_eq!(Some(2), stream.next().await);
            assert_eq!(None, stream.next().await);
        });
    }

    // Several concurrent emits inside one `join!`: all values arrive, in unspecified order.
    #[test]
    fn tokio_join_many_works() {
        futures_executor::block_on(async {
            let stream = fn_stream(|emitter| async move {
                eprintln!("try stream 1");
                tokio::join!(
                    async { emitter.emit(1).await },
                    async { emitter.emit(2).await },
                    async { emitter.emit(3).await },
                );
                emitter.emit(4).await;
            });
            pin_mut!(stream);
            for _ in 0..3 {
                let item = stream.next().await;
                assert!(matches!(item, Some(1..=3)));
            }
            assert_eq!(Some(4), stream.next().await);
            assert_eq!(None, stream.next().await);
        });
    }

    // `emit()` inside `FuturesUnordered`, which polls each child with its own waker.
    #[test]
    fn tokio_futures_unordered_one_works() {
        futures_executor::block_on(async {
            let stream = fn_stream(|emitter| async move {
                let mut futs: FuturesUnordered<_> = (1..=1)
                    .map(|i| {
                        let emitter = &emitter;
                        async move { emitter.emit(i).await }
                    })
                    .collect();
                while futs.next().await.is_some() {}
                emitter.emit(2).await;
            });
            pin_mut!(stream);
            assert_eq!(Some(1), stream.next().await);
            assert_eq!(Some(2), stream.next().await);
            assert_eq!(None, stream.next().await);
        });
    }

    // Several concurrent emits inside `FuturesUnordered`: all values arrive, in unspecified order.
    #[test]
    fn tokio_futures_unordered_many_works() {
        futures_executor::block_on(async {
            let stream = fn_stream(|emitter| async move {
                let mut futs: FuturesUnordered<_> = (1..=3)
                    .map(|i| {
                        let emitter = &emitter;
                        async move { emitter.emit(i).await }
                    })
                    .collect();
                while futs.next().await.is_some() {}
                emitter.emit(4).await;
            });
            pin_mut!(stream);
            for _ in 1..=3 {
                let item = stream.next().await;
                assert!(matches!(item, Some(1..=3)));
            }
            assert_eq!(Some(4), stream.next().await);
            assert_eq!(None, stream.next().await);
        });
    }

    // A stream polled from inside another stream's generator: the active-stream
    // pointer must be restored to the outer stream after each inner poll.
    #[test]
    fn infallible_nested_streams_work() {
        futures_executor::block_on(async {
            let mut stream = pin!(fn_stream(|emitter| async move {
                for i in 0..3 {
                    let mut stream_2 = pin!(fn_stream(|emitter| async move {
                        for j in 0..3 {
                            emitter.emit(j).await;
                        }
                    }));
                    while let Some(item) = stream_2.next().await {
                        emitter.emit(3 * i + item).await;
                    }
                }
            }));
            let mut sum = 0;
            while let Some(item) = stream.next().await {
                sum += item;
            }
            assert_eq!(sum, 36);
        });
    }

    // Fallible variant of the nested-streams scenario.
    #[test]
    fn fallible_nested_streams_work() {
        futures_executor::block_on(async {
            let mut stream = pin!(try_fn_stream(|emitter| async move {
                for i in 0..3 {
                    let mut stream_2 = pin!(try_fn_stream(|emitter| async move {
                        for j in 0..3 {
                            emitter.emit(j).await;
                        }
                        Ok::<_, ()>(())
                    }));
                    while let Some(Ok(item)) = stream_2.next().await {
                        emitter.emit(3 * i + item).await;
                    }
                }
                Ok::<_, ()>(())
            }));
            let mut sum = 0;
            while let Some(Ok(item)) = stream.next().await {
                sum += item;
            }
            assert_eq!(sum, 36);
        });
    }

    // Awaiting the OUTER emitter from inside the INNER stream's generator must be
    // caught by the active-stream assertion in `EmitFuture::poll`.
    #[test]
    #[should_panic(
        expected = "StreamEmitter::emit().await should only be called in the context of the corresponding `fn_stream()`/`try_fn_stream()`"
    )]
    fn infallible_bad_nested_emit_detected() {
        futures_executor::block_on(async {
            let mut stream = pin!(fn_stream(|emitter| async move {
                for i in 0..3 {
                    let emitter_ref = &emitter;
                    let mut stream_2 = pin!(fn_stream(|emitter_2| async move {
                        emitter_2.emit(0).await;
                        for j in 0..3 {
                            emitter_ref.emit(j).await;
                        }
                    }));
                    while let Some(item) = stream_2.next().await {
                        emitter.emit(3 * i + item).await;
                    }
                }
            }));

            let mut sum = 0;
            while let Some(item) = stream.next().await {
                sum += item;
            }
            assert_eq!(sum, 36);
        });
    }

    // Fallible variant of the misuse-detection scenario.
    #[test]
    #[should_panic(
        expected = "StreamEmitter::emit().await should only be called in the context of the corresponding `fn_stream()`/`try_fn_stream()`"
    )]
    fn fallible_bad_nested_emit_detected() {
        futures_executor::block_on(async {
            let mut stream = pin!(try_fn_stream(|emitter| async move {
                for i in 0..3 {
                    let emitter_ref = &emitter;
                    let mut stream_2 = pin!(try_fn_stream(|emitter_2| async move {
                        emitter_2.emit(0).await;
                        for j in 0..3 {
                            emitter_ref.emit(j).await;
                        }
                        Ok::<_, ()>(())
                    }));
                    while let Some(Ok(item)) = stream_2.next().await {
                        emitter.emit(3 * i + item).await;
                    }
                }
                Ok::<_, ()>(())
            }));

            let mut sum = 0;
            while let Some(Ok(item)) = stream.next().await {
                sum += item;
            }
            assert_eq!(sum, 36);
        });
    }
}