1#![doc = include_str!("../README.md")]
2
3use std::{
4 pin::Pin,
5 sync::{atomic::AtomicBool, Arc, Mutex},
6 task::{Poll, Waker},
7};
8
9use futures_util::{Future, Stream};
10use pin_project_lite::pin_project;
11use smallvec::SmallVec;
12
13pub struct StreamEmitter<T> {
15 inner: Arc<Mutex<Inner<T>>>,
16}
17
18pub struct TryStreamEmitter<T, E> {
20 inner: Arc<Mutex<Inner<Result<T, E>>>>,
21}
22
/// State shared (behind an `Arc<Mutex<_>>`) between a stream and its emitter.
struct Inner<T> {
    /// `true` while the stream is polling the user's future. Emit futures assert
    /// this so that they are only awaited from inside the stream's poll.
    /// It is only ever accessed while the surrounding `Mutex` is held.
    polling: AtomicBool,
    /// Waker passed to the most recent `poll_next` of the stream.
    stream_waker: Option<Waker>,
    /// Values emitted but not yet yielded by the stream. The stream takes them
    /// with `pop()`, i.e. values queued during the same poll are yielded
    /// last-in, first-out.
    pending_values: SmallVec<[T; 1]>,
    /// Wakers of emit futures (other than the stream's own waker) that wait for
    /// `pending_values` to drain; the stream wakes them before polling the
    /// user's future again.
    pending_wakers: SmallVec<[Waker; 1]>,
}
34
35pin_project! {
36 pub struct FnStream<T, Fut: Future<Output = ()>> {
38 #[pin]
39 fut: Fut,
40 inner: Arc<Mutex<Inner<T>>>,
41 }
42}
43
44pub fn fn_stream<T, Fut: Future<Output = ()>>(
64 func: impl FnOnce(StreamEmitter<T>) -> Fut,
65) -> FnStream<T, Fut> {
66 FnStream::new(func)
67}
68
69impl<T, Fut: Future<Output = ()>> FnStream<T, Fut> {
70 fn new<F: FnOnce(StreamEmitter<T>) -> Fut>(func: F) -> Self {
71 let inner = Arc::new(Mutex::new(Inner {
72 polling: AtomicBool::new(false),
73 stream_waker: None,
74 pending_values: SmallVec::new(),
75 pending_wakers: SmallVec::new(),
76 }));
77 let emitter = StreamEmitter {
78 inner: inner.clone(),
79 };
80 let fut = func(emitter);
81 Self { fut, inner }
82 }
83}
84
85impl<T, Fut: Future<Output = ()>> Stream for FnStream<T, Fut> {
86 type Item = T;
87
88 fn poll_next(
89 self: Pin<&mut Self>,
90 cx: &mut std::task::Context<'_>,
91 ) -> Poll<Option<Self::Item>> {
92 let this = self.project();
93
94 let mut inner_guard = this.inner.lock().expect("mutex was poisoned");
95 if let Some(value) = inner_guard.pending_values.pop() {
96 return Poll::Ready(Some(value));
97 }
98 if !inner_guard.pending_wakers.is_empty() {
99 for waker in inner_guard.pending_wakers.drain(..) {
100 if !waker.will_wake(cx.waker()) {
101 waker.wake();
102 }
103 }
104 }
105 if let Some(stream_waker) = inner_guard.stream_waker.as_mut() {
106 stream_waker.clone_from(cx.waker());
107 } else {
108 inner_guard.stream_waker = Some(cx.waker().clone());
109 }
110
111 let old_polling = inner_guard
112 .polling
113 .swap(true, std::sync::atomic::Ordering::Relaxed);
114 drop(inner_guard);
115 assert!(
116 !old_polling,
117 "async-fn-stream invariant violation: polling must be false before entering poll"
118 );
119 let r = this.fut.poll(cx);
120 let mut inner_guard = this.inner.lock().expect("mutex was poisoned");
121 inner_guard
122 .polling
123 .store(false, std::sync::atomic::Ordering::Relaxed);
124 match r {
125 std::task::Poll::Ready(()) => Poll::Ready(None),
126 std::task::Poll::Pending => {
127 if let Some(value) = inner_guard.pending_values.pop() {
128 Poll::Ready(Some(value))
129 } else {
130 Poll::Pending
131 }
132 }
133 }
134 }
135}
136
137pub fn try_fn_stream<T, E, Fut: Future<Output = Result<(), E>>>(
166 func: impl FnOnce(TryStreamEmitter<T, E>) -> Fut,
167) -> TryFnStream<T, E, Fut> {
168 TryFnStream::new(func)
169}
170
171pin_project! {
172 pub struct TryFnStream<T, E, Fut: Future<Output = Result<(), E>>> {
174 is_err: bool,
175 #[pin]
176 fut: Fut,
177 inner: Arc<Mutex<Inner<Result<T, E>>>>,
178 }
179}
180
181impl<T, E, Fut: Future<Output = Result<(), E>>> TryFnStream<T, E, Fut> {
182 fn new<F: FnOnce(TryStreamEmitter<T, E>) -> Fut>(func: F) -> Self {
183 let inner = Arc::new(Mutex::new(Inner {
184 polling: AtomicBool::new(false),
185 stream_waker: None,
186 pending_values: SmallVec::new(),
187 pending_wakers: SmallVec::new(),
188 }));
189 let emitter = TryStreamEmitter {
190 inner: inner.clone(),
191 };
192 let fut = func(emitter);
193 Self {
194 is_err: false,
195 fut,
196 inner,
197 }
198 }
199}
200
201impl<T, E, Fut: Future<Output = Result<(), E>>> Stream for TryFnStream<T, E, Fut> {
202 type Item = Result<T, E>;
203
204 fn poll_next(
205 self: Pin<&mut Self>,
206 cx: &mut std::task::Context<'_>,
207 ) -> Poll<Option<Self::Item>> {
208 if self.is_err {
210 return Poll::Ready(None);
211 }
212 let this = self.project();
213 let mut inner_guard = this.inner.lock().expect("mutex was poisoned");
214 if let Some(value) = inner_guard.pending_values.pop() {
215 return Poll::Ready(Some(value));
216 }
217 if !inner_guard.pending_wakers.is_empty() {
218 for waker in inner_guard.pending_wakers.drain(..) {
219 if !waker.will_wake(cx.waker()) {
220 waker.wake();
221 }
222 }
223 }
224 if let Some(stream_waker) = inner_guard.stream_waker.as_mut() {
225 stream_waker.clone_from(cx.waker());
226 } else {
227 inner_guard.stream_waker = Some(cx.waker().clone());
228 }
229
230 let old_polling = inner_guard
231 .polling
232 .swap(true, std::sync::atomic::Ordering::Relaxed);
233 drop(inner_guard);
234 assert!(
235 !old_polling,
236 "async-fn-stream invariant violation: polling must be false before entering poll"
237 );
238 let r = this.fut.poll(cx);
239 let mut inner_guard = this.inner.lock().expect("mutex was poisoned");
240 inner_guard
241 .polling
242 .store(false, std::sync::atomic::Ordering::Relaxed);
243 match r {
244 std::task::Poll::Ready(Ok(())) => Poll::Ready(None),
245 std::task::Poll::Ready(Err(e)) => {
246 *this.is_err = true;
247 Poll::Ready(Some(Err(e)))
248 }
249 std::task::Poll::Pending => {
250 if let Some(value) = inner_guard.pending_values.pop() {
251 Poll::Ready(Some(value))
252 } else {
253 Poll::Pending
254 }
255 }
256 }
257 }
258}
259
260impl<T> StreamEmitter<T> {
261 #[must_use = "Ensure that emit() is awaited"]
268 pub fn emit(&'_ self, value: T) -> EmitFuture<'_, T> {
269 EmitFuture::new(&self.inner, value)
270 }
271}
272
273impl<T, E> TryStreamEmitter<T, E> {
274 #[must_use = "Ensure that emit() is awaited"]
281 pub fn emit(&'_ self, value: T) -> EmitFuture<'_, Result<T, E>> {
282 EmitFuture::new(&self.inner, Ok(value))
283 }
284
285 #[must_use = "Ensure that emit_err() is awaited"]
292 pub fn emit_err(&'_ self, err: E) -> EmitFuture<'_, Result<T, E>> {
293 EmitFuture::new(&self.inner, Err(err))
294 }
295}
296
pin_project! {
    /// Future returned by [`StreamEmitter::emit`], [`TryStreamEmitter::emit`]
    /// and [`TryStreamEmitter::emit_err`].
    ///
    /// The value is only queued when this future is polled, which must happen
    /// from inside the owning stream's poll.
    pub struct EmitFuture<'a, T> {
        // State shared with the stream that owns the emitter.
        inner: &'a Mutex<Inner<T>>,
        // The value to emit; taken (and pushed to `pending_values`) on the first poll.
        value: Option<T>,
    }
}
304
305impl<'a, T> EmitFuture<'a, T> {
306 fn new(inner: &'a Mutex<Inner<T>>, value: T) -> Self {
307 Self {
308 inner,
309 value: Some(value),
310 }
311 }
312}
313
314impl<T> Future for EmitFuture<'_, T> {
315 type Output = ();
316
317 fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
318 let this = self.project();
319 let mut inner_guard = this.inner.lock().expect("Mutex was poisoned");
320 let inner = &mut *inner_guard;
321 assert!(
322 inner.polling.load(std::sync::atomic::Ordering::Relaxed),
323 "StreamEmitter::emit().await should only be called in context of `fn_stream()`/`try_fn_stream()`"
324 );
325
326 if let Some(value) = this.value.take() {
327 inner.pending_values.push(value);
328 let is_same_waker = if let Some(stream_waker) = inner.stream_waker.as_ref() {
329 stream_waker.will_wake(cx.waker())
330 } else {
331 false
332 };
333 if !is_same_waker {
334 inner.pending_wakers.push(cx.waker().clone());
335 }
336 Poll::Pending
337 } else if inner.pending_values.is_empty() {
338 Poll::Ready(())
342 } else {
343 Poll::Pending
344 }
345 }
346}
347
#[cfg(test)]
mod tests {
    use std::io::ErrorKind;

    use futures_util::{pin_mut, stream::FuturesUnordered, StreamExt};

    use super::*;

    #[test]
    fn infallible_works() {
        futures_executor::block_on(async {
            let stream = fn_stream(|emitter| async move {
                eprintln!("stream 1");
                emitter.emit(1).await;
                eprintln!("stream 2");
                emitter.emit(2).await;
                eprintln!("stream 3");
            });
            pin_mut!(stream);
            assert_eq!(Some(1), stream.next().await);
            assert_eq!(Some(2), stream.next().await);
            assert_eq!(None, stream.next().await);
        });
    }

    #[test]
    fn infallible_lifetime() {
        let a = 1;
        futures_executor::block_on(async {
            let b = 2;
            let a = &a;
            let b = &b;
            let stream = fn_stream(|emitter| async move {
                eprintln!("stream 1");
                emitter.emit(a).await;
                eprintln!("stream 2");
                emitter.emit(b).await;
                eprintln!("stream 3");
            });
            pin_mut!(stream);
            assert_eq!(Some(a), stream.next().await);
            assert_eq!(Some(b), stream.next().await);
            assert_eq!(None, stream.next().await);
        });
    }

    #[test]
    fn infallible_unawaited_emit_is_ignored() {
        futures_executor::block_on(async {
            #[expect(
                unused_must_use,
                reason = "this code intentionally does not await emitter.emit()"
            )]
            let stream = fn_stream(|emitter| async move {
                emitter.emit(1);
                emitter.emit(2);
                emitter.emit(3).await;
            });
            pin_mut!(stream);
            assert_eq!(Some(3), stream.next().await);
            assert_eq!(None, stream.next().await);
        });
    }

    #[test]
    fn fallible_works() {
        futures_executor::block_on(async {
            let stream = try_fn_stream(|emitter| async move {
                eprintln!("try stream 1");
                emitter.emit(1).await;
                eprintln!("try stream 2");
                emitter.emit(2).await;
                eprintln!("try stream 3");
                Err(std::io::Error::from(ErrorKind::Other))
            });
            pin_mut!(stream);
            assert_eq!(1, stream.next().await.unwrap().unwrap());
            assert_eq!(2, stream.next().await.unwrap().unwrap());
            assert!(stream.next().await.unwrap().is_err());
            assert!(stream.next().await.is_none());
        });
    }

    #[test]
    fn fallible_emit_err_works() {
        futures_executor::block_on(async {
            let stream = try_fn_stream(|emitter| async move {
                eprintln!("try stream 1");
                emitter.emit(1).await;
                eprintln!("try stream 2");
                emitter.emit(2).await;
                eprintln!("try stream 3");
                emitter
                    .emit_err(std::io::Error::from(ErrorKind::Other))
                    .await;
                eprintln!("try stream 4");
                Err(std::io::Error::from(ErrorKind::Other))
            });
            pin_mut!(stream);
            assert_eq!(1, stream.next().await.unwrap().unwrap());
            assert_eq!(2, stream.next().await.unwrap().unwrap());
            assert!(stream.next().await.unwrap().is_err());
            assert!(stream.next().await.unwrap().is_err());
            assert!(stream.next().await.is_none());
        });
    }

    #[test]
    fn method_async() {
        struct St {
            a: String,
        }

        impl St {
            async fn f1(&self) -> impl Stream<Item = &str> {
                self.f2().await
            }

            #[allow(
                clippy::unused_async,
                reason = "`f2` is async on purpose: the test awaits it from another async method"
            )]
            async fn f2(&self) -> impl Stream<Item = &str> {
                fn_stream(|emitter| async move {
                    emitter.emit(self.a.as_str()).await;
                    emitter.emit(self.a.as_str()).await;
                    emitter.emit(self.a.as_str()).await;
                })
            }
        }

        futures_executor::block_on(async {
            let l = St {
                a: "qwe".to_owned(),
            };
            let s = l.f1().await;
            let z: Vec<&str> = s.collect().await;
            assert_eq!(z, ["qwe", "qwe", "qwe"]);
        });
    }

    #[test]
    fn tokio_join_one_works() {
        futures_executor::block_on(async {
            let stream = fn_stream(|emitter| async move {
                tokio::join!(async { emitter.emit(1).await },);
                emitter.emit(2).await;
            });
            pin_mut!(stream);
            assert_eq!(Some(1), stream.next().await);
            assert_eq!(Some(2), stream.next().await);
            assert_eq!(None, stream.next().await);
        });
    }

    #[test]
    fn tokio_join_many_works() {
        futures_executor::block_on(async {
            let stream = fn_stream(|emitter| async move {
                eprintln!("try stream 1");
                tokio::join!(
                    async { emitter.emit(1).await },
                    async { emitter.emit(2).await },
                    async { emitter.emit(3).await },
                );
                emitter.emit(4).await;
            });
            pin_mut!(stream);
            // Concurrent emits may arrive in any order, but each exactly once.
            let mut concurrent: Vec<i32> = Vec::with_capacity(3);
            for _ in 0..3 {
                concurrent.push(stream.next().await.expect("stream ended too early"));
            }
            concurrent.sort_unstable();
            assert_eq!(concurrent, [1, 2, 3]);
            assert_eq!(Some(4), stream.next().await);
            assert_eq!(None, stream.next().await);
        });
    }

    #[test]
    fn tokio_futures_unordered_one_works() {
        futures_executor::block_on(async {
            let stream = fn_stream(|emitter| async move {
                let mut futs: FuturesUnordered<_> = (1..=1)
                    .map(|i| {
                        let emitter = &emitter;
                        async move { emitter.emit(i).await }
                    })
                    .collect();
                while futs.next().await.is_some() {}
                emitter.emit(2).await;
            });
            pin_mut!(stream);
            assert_eq!(Some(1), stream.next().await);
            assert_eq!(Some(2), stream.next().await);
            assert_eq!(None, stream.next().await);
        });
    }

    #[test]
    fn tokio_futures_unordered_many_works() {
        futures_executor::block_on(async {
            let stream = fn_stream(|emitter| async move {
                let mut futs: FuturesUnordered<_> = (1..=3)
                    .map(|i| {
                        let emitter = &emitter;
                        async move { emitter.emit(i).await }
                    })
                    .collect();
                while futs.next().await.is_some() {}
                emitter.emit(4).await;
            });
            pin_mut!(stream);
            // Concurrent emits may arrive in any order, but each exactly once.
            let mut concurrent: Vec<i32> = Vec::with_capacity(3);
            for _ in 1..=3 {
                concurrent.push(stream.next().await.expect("stream ended too early"));
            }
            concurrent.sort_unstable();
            assert_eq!(concurrent, [1, 2, 3]);
            assert_eq!(Some(4), stream.next().await);
            assert_eq!(None, stream.next().await);
        });
    }
}