1use alloc::sync::Arc;
2use core::{
3 future::Future,
4 marker::PhantomData,
5 pin::Pin,
6 task::{Context, Poll}
7};
8
9use futures_core::stream::{FusedStream, Stream};
10
11use crate::{SharedStore, Yielder, enter};
12
13pin_project_lite::pin_project! {
14 pub struct TryAsyncStream<T, E, U> {
18 store: Arc<SharedStore<T>>,
19 done: bool,
20 #[pin]
21 generator: U,
22 _p: PhantomData<E>
23 }
24}
25
26impl<T, E, U> FusedStream for TryAsyncStream<T, E, U>
27where
28 U: Future<Output = Result<(), E>>
29{
30 #[inline(always)]
31 fn is_terminated(&self) -> bool {
32 self.done
33 }
34}
35
36impl<T, E, U> Stream for TryAsyncStream<T, E, U>
37where
38 U: Future<Output = Result<(), E>>
39{
40 type Item = Result<T, E>;
41
42 #[inline]
43 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
44 let me = self.project();
45 if *me.done {
46 return Poll::Ready(None);
47 }
48
49 let res = {
50 let _enter = enter(&me.store);
51 me.generator.poll(cx)
52 };
53
54 *me.done = res.is_ready();
55
56 if let Poll::Ready(Err(e)) = res {
57 return Poll::Ready(Some(Err(e)));
58 } else if me.store.has_value() {
59 return Poll::Ready(me.store.cell.take().map(Ok));
60 }
61
62 if *me.done { Poll::Ready(None) } else { Poll::Pending }
63 }
64
65 #[inline]
66 fn size_hint(&self) -> (usize, Option<usize>) {
67 if self.done { (0, Some(0)) } else { (0, None) }
68 }
69}
70
71#[inline]
98pub fn try_async_stream<T, E, F, U>(generator: F) -> TryAsyncStream<T, E, U>
99where
100 F: FnOnce(Yielder<T>) -> U,
101 U: Future<Output = Result<(), E>>
102{
103 let store = Arc::new(SharedStore::default());
104 let generator = generator(Yielder { store: Arc::downgrade(&store) });
105 TryAsyncStream {
106 store,
107 done: false,
108 generator,
109 _p: PhantomData
110 }
111}
112
#[cfg(test)]
mod tests {
    use futures::{stream::FusedStream, Stream, StreamExt};

    use super::try_async_stream;
    use crate::tests::run_test;

    #[test]
    fn single_err() {
        run_test(async {
            let s = try_async_stream(|yielder| async move {
                if true {
                    Err("hello")?;
                } else {
                    yielder.y("world").await;
                }
                Ok(())
            });

            let values: Vec<_> = s.collect().await;
            assert_eq!(1, values.len());
            assert_eq!(Err("hello"), values[0]);
        })
    }

    #[test]
    fn yield_then_err() {
        run_test(async {
            let s = try_async_stream(|yielder| async move {
                yielder.y("hello").await;
                Err("world")?;
                unreachable!();
            });

            let values: Vec<_> = s.collect().await;
            assert_eq!(2, values.len());
            assert_eq!(Ok("hello"), values[0]);
            assert_eq!(Err("world"), values[1]);
        })
    }

    #[test]
    fn convert_err() {
        run_test(async {
            struct ErrorA(u8);
            #[derive(PartialEq, Debug)]
            struct ErrorB(u8);
            impl From<ErrorA> for ErrorB {
                fn from(a: ErrorA) -> Self {
                    ErrorB(a.0)
                }
            }

            fn test() -> impl Stream<Item = Result<&'static str, ErrorB>> {
                try_async_stream(|yielder| async move {
                    if true {
                        Err(ErrorA(1))?;
                    } else {
                        Err(ErrorB(2))?;
                    }
                    yielder.y("unreachable").await;
                    Ok(())
                })
            }

            let values: Vec<_> = test().collect().await;
            assert_eq!(1, values.len());
            assert_eq!(Err(ErrorB(1)), values[0]);
        })
    }

    // Success path: every yielded value arrives as `Ok`, in order, and the
    // stream ends cleanly when the generator returns `Ok(())`.
    #[test]
    fn yields_all_ok_values() {
        run_test(async {
            let s = try_async_stream(|yielder| async move {
                for word in ["a", "b", "c"] {
                    yielder.y(word).await;
                }
                Ok::<(), &'static str>(())
            });

            let values: Vec<_> = s.collect().await;
            assert_eq!(vec![Ok("a"), Ok("b"), Ok("c")], values);
        })
    }

    // After the terminal error the stream must report itself terminated,
    // advertise an empty size hint, and keep returning `None`.
    #[test]
    fn fused_after_err() {
        run_test(async {
            let mut s = Box::pin(try_async_stream(|yielder| async move {
                yielder.y(1u8).await;
                Err::<(), _>("boom")
            }));

            assert!(!s.is_terminated());
            assert_eq!(Some(Ok(1)), s.next().await);
            assert!(!s.is_terminated());
            assert_eq!(Some(Err("boom")), s.next().await);
            assert!(s.is_terminated());
            assert_eq!((0, Some(0)), s.size_hint());
            assert_eq!(None, s.next().await);
        })
    }
}