async_stream_lite/
try.rs

1use alloc::sync::Arc;
2use core::{
3	future::Future,
4	marker::PhantomData,
5	pin::Pin,
6	task::{Context, Poll}
7};
8
9use futures_core::stream::{FusedStream, Stream};
10
11use crate::{SharedStore, Yielder, enter};
12
pin_project_lite::pin_project! {
	/// A [`Stream`] created from a fallible, asynchronous generator-like function.
	///
	/// To create a [`TryAsyncStream`], use the [`try_async_stream`] function. See also [`crate::AsyncStream`].
	///
	/// Values handed out through the generator's `Yielder` are produced as `Ok(value)`. If the generator's future
	/// resolves to `Err(e)`, the stream produces `Err(e)`. Once the generator's future has completed, the stream only
	/// returns `None`.
	pub struct TryAsyncStream<T, E, U> {
		// Slot shared with the generator. `try_async_stream` gives the generator's `Yielder` a weak reference to
		// it, and `poll_next` takes yielded values back out of it.
		store: Arc<SharedStore<T>>,
		// Set as soon as the generator future resolves (to `Ok` or `Err`). From then on `poll_next` returns
		// `None` without polling the generator again.
		done: bool,
		// The future returned by the user's generator closure. It is structurally pinned because `poll_next`
		// polls it in place through the pin projection.
		#[pin]
		generator: U,
		// Carries the error type `E`. It appears in the stream's `Item` type but is not stored in any field.
		_p: PhantomData<E>
	}
}
25
26impl<T, E, U> FusedStream for TryAsyncStream<T, E, U>
27where
28	U: Future<Output = Result<(), E>>
29{
30	fn is_terminated(&self) -> bool {
31		self.done
32	}
33}
34
35impl<T, E, U> Stream for TryAsyncStream<T, E, U>
36where
37	U: Future<Output = Result<(), E>>
38{
39	type Item = Result<T, E>;
40
41	fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
42		let me = self.project();
43		if *me.done {
44			return Poll::Ready(None);
45		}
46
47		let res = {
48			let _enter = enter(&me.store);
49			me.generator.poll(cx)
50		};
51
52		*me.done = res.is_ready();
53
54		if let Poll::Ready(Err(e)) = res {
55			return Poll::Ready(Some(Err(e)));
56		} else if me.store.has_value() {
57			return Poll::Ready(me.store.cell.take().map(Ok));
58		}
59
60		if *me.done { Poll::Ready(None) } else { Poll::Pending }
61	}
62
63	fn size_hint(&self) -> (usize, Option<usize>) {
64		if self.done { (0, Some(0)) } else { (0, None) }
65	}
66}
67
68/// Create an asynchronous [`Stream`] from a fallible asynchronous generator function.
69///
70/// Usage is similar to [`crate::async_stream`], however the future returned by `generator` is assumed to return
71/// `Result<(), E>` instead of `()`.
72///
73/// ```
74/// use std::{io, net::SocketAddr};
75///
76/// use async_stream_lite::try_async_stream;
77/// use futures::stream::Stream;
78/// use tokio::net::{TcpListener, TcpStream};
79///
80/// fn bind_and_accept(addr: SocketAddr) -> impl Stream<Item = io::Result<TcpStream>> {
81/// 	try_async_stream(|yielder| async move {
82/// 		let mut listener = TcpListener::bind(addr).await?;
83/// 		loop {
84/// 			let (stream, addr) = listener.accept().await?;
85/// 			println!("received on {addr:?}");
86/// 			yielder.r#yield(stream).await;
87/// 		}
88/// 	})
89/// }
90/// ```
91///
92/// The resulting stream yields `Result<T, E>`. The yielder function will cause the stream to yield `Ok(T)`. When an
93/// error is encountered, the stream yields `Err(E)` and is subsequently terminated.
94pub fn try_async_stream<T, E, F, U>(generator: F) -> TryAsyncStream<T, E, U>
95where
96	F: FnOnce(Yielder<T>) -> U,
97	U: Future<Output = Result<(), E>>
98{
99	let store = Arc::new(SharedStore::default());
100	let generator = generator(Yielder { store: Arc::downgrade(&store) });
101	TryAsyncStream {
102		store,
103		done: false,
104		generator,
105		_p: PhantomData
106	}
107}
108
#[cfg(test)]
mod tests {
	use futures::{Stream, StreamExt, stream::FusedStream};

	use super::try_async_stream;

	/// An error returned before anything is yielded becomes the stream's only item.
	#[tokio::test]
	async fn single_err() {
		let s = try_async_stream(|yielder| async move {
			if true {
				Err("hello")?;
			} else {
				yielder.r#yield("world").await;
			}
			Ok(())
		});

		let values: Vec<_> = s.collect().await;
		assert_eq!(1, values.len());
		assert_eq!(Err("hello"), values[0]);
	}

	/// Values yielded before an error are delivered first, followed by the error itself.
	#[tokio::test]
	async fn yield_then_err() {
		let s = try_async_stream(|yielder| async move {
			yielder.r#yield("hello").await;
			Err("world")?;
			unreachable!();
		});

		let values: Vec<_> = s.collect().await;
		assert_eq!(2, values.len());
		assert_eq!(Ok("hello"), values[0]);
		assert_eq!(Err("world"), values[1]);
	}

	/// `?` inside the generator converts the error into the stream's error type via `From`.
	#[tokio::test]
	async fn convert_err() {
		struct ErrorA(u8);
		#[derive(PartialEq, Debug)]
		struct ErrorB(u8);
		impl From<ErrorA> for ErrorB {
			fn from(a: ErrorA) -> Self {
				ErrorB(a.0)
			}
		}

		fn test() -> impl Stream<Item = Result<&'static str, ErrorB>> {
			try_async_stream(|yielder| async move {
				if true {
					Err(ErrorA(1))?;
				} else {
					Err(ErrorB(2))?;
				}
				yielder.r#yield("unreachable").await;
				Ok(())
			})
		}

		let values: Vec<_> = test().collect().await;
		assert_eq!(1, values.len());
		assert_eq!(Err(ErrorB(1)), values[0]);
	}

	/// A generator finishing with `Ok(())` ends the stream after its last value.
	/// The stream then reports itself terminated and keeps returning `None`.
	#[tokio::test]
	async fn yield_then_ok_is_fused() {
		let mut stream = Box::pin(try_async_stream(|yielder| async move {
			yielder.r#yield(1u8).await;
			yielder.r#yield(2u8).await;
			Ok::<(), &str>(())
		}));

		assert!(!stream.is_terminated());
		assert_eq!(Some(Ok(1)), stream.next().await);
		assert_eq!(Some(Ok(2)), stream.next().await);
		assert_eq!(None, stream.next().await);
		assert!(stream.is_terminated());
		assert_eq!(None, stream.next().await);
	}

	/// After an error is yielded the stream is terminated.
	/// Code after the failing `?` in the generator never runs.
	#[tokio::test]
	async fn err_terminates_stream() {
		let mut stream = Box::pin(try_async_stream(|yielder| async move {
			yielder.r#yield(1u8).await;
			Err("boom")?;
			yielder.r#yield(2u8).await;
			Ok::<(), &str>(())
		}));

		assert_eq!(Some(Ok(1)), stream.next().await);
		assert_eq!(Some(Err("boom")), stream.next().await);
		assert!(stream.is_terminated());
		assert_eq!(None, stream.next().await);
	}
}