1use alloc::sync::Arc;
2use core::{
3 future::Future,
4 marker::PhantomData,
5 pin::Pin,
6 task::{Context, Poll}
7};
8
9use futures_core::stream::{FusedStream, Stream};
10
11use crate::{SharedStore, Yielder, enter};
12
13pin_project_lite::pin_project! {
14 pub struct TryAsyncStream<T, E, U> {
18 store: Arc<SharedStore<T>>,
19 done: bool,
20 #[pin]
21 generator: U,
22 _p: PhantomData<E>
23 }
24}
25
26impl<T, E, U> FusedStream for TryAsyncStream<T, E, U>
27where
28 U: Future<Output = Result<(), E>>
29{
30 fn is_terminated(&self) -> bool {
31 self.done
32 }
33}
34
35impl<T, E, U> Stream for TryAsyncStream<T, E, U>
36where
37 U: Future<Output = Result<(), E>>
38{
39 type Item = Result<T, E>;
40
41 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
42 let me = self.project();
43 if *me.done {
44 return Poll::Ready(None);
45 }
46
47 let res = {
48 let _enter = enter(&me.store);
49 me.generator.poll(cx)
50 };
51
52 *me.done = res.is_ready();
53
54 if let Poll::Ready(Err(e)) = res {
55 return Poll::Ready(Some(Err(e)));
56 } else if me.store.has_value() {
57 return Poll::Ready(me.store.cell.take().map(Ok));
58 }
59
60 if *me.done { Poll::Ready(None) } else { Poll::Pending }
61 }
62
63 fn size_hint(&self) -> (usize, Option<usize>) {
64 if self.done { (0, Some(0)) } else { (0, None) }
65 }
66}
67
68pub fn try_async_stream<T, E, F, U>(generator: F) -> TryAsyncStream<T, E, U>
95where
96 F: FnOnce(Yielder<T>) -> U,
97 U: Future<Output = Result<(), E>>
98{
99 let store = Arc::new(SharedStore::default());
100 let generator = generator(Yielder { store: Arc::downgrade(&store) });
101 TryAsyncStream {
102 store,
103 done: false,
104 generator,
105 _p: PhantomData
106 }
107}
108
#[cfg(test)]
mod tests {
    use futures::{Stream, StreamExt};

    use super::try_async_stream;

    /// A generator that fails before yielding anything produces exactly one `Err` item.
    #[tokio::test]
    async fn single_err() {
        let stream = try_async_stream(|yielder| async move {
            // `if true` keeps the yielding branch type-checked while never running it.
            if true {
                Err("hello")?;
            } else {
                yielder.r#yield("world").await;
            }
            Ok(())
        });

        let items: Vec<_> = stream.collect().await;
        assert_eq!(1, items.len());
        assert_eq!(Err("hello"), items[0]);
    }

    /// A yielded value is delivered first, followed by the error that ends the generator.
    #[tokio::test]
    async fn yield_then_err() {
        let stream = try_async_stream(|yielder| async move {
            yielder.r#yield("hello").await;
            Err("world")?;
            unreachable!();
        });

        let items: Vec<_> = stream.collect().await;
        assert_eq!(2, items.len());
        assert_eq!(Ok("hello"), items[0]);
        assert_eq!(Err("world"), items[1]);
    }

    /// `?` inside the generator converts the error through `From` into the stream's error type.
    #[tokio::test]
    async fn convert_err() {
        struct ErrorA(u8);
        #[derive(PartialEq, Debug)]
        struct ErrorB(u8);
        impl From<ErrorA> for ErrorB {
            fn from(a: ErrorA) -> Self {
                ErrorB(a.0)
            }
        }

        // The declared return type pins the stream's error type to `ErrorB`.
        fn failing_stream() -> impl Stream<Item = Result<&'static str, ErrorB>> {
            try_async_stream(|yielder| async move {
                if true {
                    Err(ErrorA(1))?;
                } else {
                    Err(ErrorB(2))?;
                }
                yielder.r#yield("unreachable").await;
                Ok(())
            })
        }

        let stream = failing_stream();
        let items: Vec<_> = stream.collect().await;
        assert_eq!(1, items.len());
        assert_eq!(Err(ErrorB(1)), items[0]);
    }
}