Skip to main content

async_stream_lite/
try.rs

1use alloc::sync::Arc;
2use core::{
3	future::Future,
4	marker::PhantomData,
5	pin::Pin,
6	task::{Context, Poll}
7};
8
9use futures_core::stream::{FusedStream, Stream};
10
11use crate::{SharedStore, Yielder, enter};
12
13pin_project_lite::pin_project! {
14	/// A [`Stream`] created from a fallible, asynchronous generator-like function.
15	///
16	/// To create a [`TryAsyncStream`], use the [`try_async_stream`] function. See also [`crate::AsyncStream`].
17	pub struct TryAsyncStream<T, E, U> {
18		store: Arc<SharedStore<T>>,
19		done: bool,
20		#[pin]
21		generator: U,
22		_p: PhantomData<E>
23	}
24}
25
26impl<T, E, U> FusedStream for TryAsyncStream<T, E, U>
27where
28	U: Future<Output = Result<(), E>>
29{
30	#[inline(always)]
31	fn is_terminated(&self) -> bool {
32		self.done
33	}
34}
35
36impl<T, E, U> Stream for TryAsyncStream<T, E, U>
37where
38	U: Future<Output = Result<(), E>>
39{
40	type Item = Result<T, E>;
41
42	#[inline]
43	fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
44		let me = self.project();
45		if *me.done {
46			return Poll::Ready(None);
47		}
48
49		let res = {
50			let _enter = enter(&me.store);
51			me.generator.poll(cx)
52		};
53
54		*me.done = res.is_ready();
55
56		if let Poll::Ready(Err(e)) = res {
57			return Poll::Ready(Some(Err(e)));
58		} else if me.store.has_value() {
59			return Poll::Ready(me.store.cell.take().map(Ok));
60		}
61
62		if *me.done { Poll::Ready(None) } else { Poll::Pending }
63	}
64
65	#[inline]
66	fn size_hint(&self) -> (usize, Option<usize>) {
67		if self.done { (0, Some(0)) } else { (0, None) }
68	}
69}
70
71/// Create an asynchronous [`Stream`] from a fallible asynchronous generator function.
72///
73/// Usage is similar to [`crate::async_stream`], however the future returned by `generator` is assumed to return
74/// `Result<(), E>` instead of `()`.
75///
76/// ```
77/// use std::{io, net::SocketAddr};
78///
79/// use async_stream_lite::try_async_stream;
80/// use futures::stream::Stream;
81/// use tokio::net::{TcpListener, TcpStream};
82///
83/// fn bind_and_accept(addr: SocketAddr) -> impl Stream<Item = io::Result<TcpStream>> {
84/// 	try_async_stream(|yielder| async move {
85/// 		let mut listener = TcpListener::bind(addr).await?;
86/// 		loop {
87/// 			let (stream, addr) = listener.accept().await?;
88/// 			println!("received on {addr:?}");
89/// 			yielder.y(stream).await;
90/// 		}
91/// 	})
92/// }
93/// ```
94///
95/// The resulting stream yields `Result<T, E>`. The yielder function will cause the stream to yield `Ok(T)`. When an
96/// error is encountered, the stream yields `Err(E)` and is subsequently terminated.
97#[inline]
98pub fn try_async_stream<T, E, F, U>(generator: F) -> TryAsyncStream<T, E, U>
99where
100	F: FnOnce(Yielder<T>) -> U,
101	U: Future<Output = Result<(), E>>
102{
103	let store = Arc::new(SharedStore::default());
104	let generator = generator(Yielder { store: Arc::downgrade(&store) });
105	TryAsyncStream {
106		store,
107		done: false,
108		generator,
109		_p: PhantomData
110	}
111}
112
#[cfg(test)]
mod tests {
	use futures::{Stream, StreamExt};

	use super::try_async_stream;
	use crate::tests::run_test;

	/// An error raised before anything is yielded becomes the only item of the stream.
	#[test]
	fn single_err() {
		run_test(async {
			// `if true` keeps the `yielder.y` call type-checked (so `T` can be inferred) while always taking
			// the error path at runtime.
			let stream = try_async_stream(|yielder| async move {
				if true {
					Err("hello")?;
				} else {
					yielder.y("world").await;
				}
				Ok(())
			});

			let items = stream.collect::<Vec<_>>().await;
			assert_eq!(1, items.len());
			assert_eq!(Err("hello"), items[0]);
		})
	}

	/// Values yielded before an error are delivered as `Ok`, followed by the error itself.
	#[test]
	fn yield_then_err() {
		run_test(async {
			let stream = try_async_stream(|yielder| async move {
				yielder.y("hello").await;
				Err("world")?;
				// Diverges, so the async block's output type comes from the stream's `Result<(), E>` bound.
				unreachable!();
			});

			let items = stream.collect::<Vec<_>>().await;
			assert_eq!(2, items.len());
			assert_eq!(Ok("hello"), items[0]);
			assert_eq!(Err("world"), items[1]);
		})
	}

	/// `?` converts the generator's error type into the stream's error type through `From`.
	#[test]
	fn convert_err() {
		run_test(async {
			struct ErrorA(u8);
			#[derive(PartialEq, Debug)]
			struct ErrorB(u8);
			impl From<ErrorA> for ErrorB {
				fn from(a: ErrorA) -> Self {
					ErrorB(a.0)
				}
			}

			fn fallible_stream() -> impl Stream<Item = Result<&'static str, ErrorB>> {
				try_async_stream(|yielder| async move {
					// Both error types appear so that `?` has to convert `ErrorA` into `ErrorB`.
					if true {
						Err(ErrorA(1))?;
					} else {
						Err(ErrorB(2))?;
					}
					yielder.y("unreachable").await;
					Ok(())
				})
			}

			let items: Vec<_> = fallible_stream().collect().await;
			assert_eq!(1, items.len());
			assert_eq!(Err(ErrorB(1)), items[0]);
		})
	}
}