1use core::future::Future;
2use core::pin::Pin;
3use core::task::{Context, Poll};
4use pin_project::pin_project;
5
6pub trait AsyncMapExt<T, E> {
7 fn async_map<TFn, TFuture>(self, f: TFn) -> AsyncMap<T, E, TFn, TFuture>
32 where
33 TFn: FnOnce(T) -> TFuture,
34 TFuture: Future;
35
36 fn async_and_then<U, TFn, TFuture>(self, f: TFn) -> AsyncAndThen<T, E, TFn, TFuture>
73 where
74 TFn: FnOnce(T) -> TFuture,
75 TFuture: Future<Output = Result<U, E>>;
76}
77
78#[doc(hidden)]
79#[pin_project(project = AsyncMapProj)]
80pub enum AsyncMap<T, E, TFn, TFuture> {
81 Err(Option<E>),
82 Pending(Option<(T, TFn)>),
83 Polling(#[pin] TFuture),
84}
85
86impl<T, U, E, TFn, TFuture> Future for AsyncMap<T, E, TFn, TFuture>
87where
88 TFn: FnOnce(T) -> TFuture,
89 TFuture: Future<Output = U>,
90{
91 type Output = Result<U, E>;
92
93 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
94 use AsyncMapProj::*;
95
96 match self.as_mut().project() {
97 Err(e) => Poll::Ready(Result::Err(e.take().expect("AsyncMap::Err polled twice"))),
98
99 Pending(payload) => {
100 let (x, f) = payload.take().expect("AsyncMap::Pending polled twice");
101 let future = f(x);
102 self.set(AsyncMap::Polling(future));
103 self.poll(cx)
104 }
105
106 Polling(future) => future.poll(cx).map(Ok),
107 }
108 }
109}
110
111#[doc(hidden)]
112#[pin_project(project = AsyncAndThenProj)]
113pub enum AsyncAndThen<T, E, TFn, TFuture> {
114 Err(Option<E>),
115 Pending(Option<(T, TFn)>),
116 Polling(#[pin] TFuture),
117}
118
119impl<T, U, E, TFn, TFuture> Future for AsyncAndThen<T, E, TFn, TFuture>
120where
121 TFn: FnOnce(T) -> TFuture,
122 TFuture: Future<Output = Result<U, E>>,
123{
124 type Output = Result<U, E>;
125
126 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
127 use AsyncAndThenProj::*;
128
129 match self.as_mut().project() {
130 Err(e) => Poll::Ready(Result::Err(
131 e.take().expect("AsyncAndThen::Err polled twice"),
132 )),
133
134 Pending(payload) => {
135 let (x, f) = payload.take().expect("AsyncAndThen::Pending polled twice");
136 let future = f(x);
137 self.set(AsyncAndThen::Polling(future));
138 self.poll(cx)
139 }
140
141 Polling(future) => match future.poll(cx) {
142 Poll::Ready(Result::Ok(v)) => Poll::Ready(Result::Ok(v)),
143 Poll::Ready(Result::Err(e)) => Poll::Ready(Result::Err(e)),
144 Poll::Pending => Poll::Pending,
145 },
146 }
147 }
148}
149
150impl<T, E> AsyncMapExt<T, E> for Result<T, E> {
151 fn async_map<TFn, TFuture>(self, f: TFn) -> AsyncMap<T, E, TFn, TFuture>
152 where
153 TFn: FnOnce(T) -> TFuture,
154 TFuture: Future,
155 {
156 match self {
157 Ok(v) => AsyncMap::Pending(Some((v, f))),
158 Err(e) => AsyncMap::Err(Some(e)),
159 }
160 }
161
162 fn async_and_then<U, TFn, TFuture>(self, f: TFn) -> AsyncAndThen<T, E, TFn, TFuture>
163 where
164 TFn: FnOnce(T) -> TFuture,
165 TFuture: Future<Output = Result<U, E>>,
166 {
167 match self {
168 Ok(v) => AsyncAndThen::Pending(Some((v, f))),
169 Err(e) => AsyncAndThen::Err(Some(e)),
170 }
171 }
172}
173
#[cfg(test)]
mod test {
    use super::AsyncMapExt;
    use core::cell::Cell;

    type Result = core::result::Result<i32, i32>;

    #[tokio::test]
    async fn map() {
        assert_eq!(
            Result::Ok(1).async_map(|x: i32| async move { x + 1 }).await,
            Result::Ok(2),
        );

        assert_eq!(
            Result::Err(4)
                .async_map(|x: i32| async move { x + 1 })
                .await,
            Result::Err(4),
        );
    }

    #[tokio::test]
    async fn and_then() {
        assert_eq!(
            Result::Ok(1)
                .async_and_then(|x: i32| async move { Ok(x + 1) })
                .await,
            Result::Ok(2),
        );

        assert_eq!(
            Result::Ok(1)
                .async_and_then(|x: i32| async move { Err(x + 1) })
                .await,
            Result::Err(2),
        );

        assert_eq!(
            Result::Err(4)
                .async_and_then(|x: i32| async move { Ok(x + 1) })
                .await,
            Result::Err(4),
        );
    }

    /// The closure must not run before the first poll, and never for `Err`.
    #[tokio::test]
    async fn map_calls_closure_lazily_and_only_on_ok() {
        let called = Cell::new(false);
        let future = Result::Ok(1).async_map(|x: i32| {
            called.set(true);
            async move { x + 1 }
        });
        assert!(!called.get());
        assert_eq!(future.await, Result::Ok(2));
        assert!(called.get());

        let called = Cell::new(false);
        let out = Result::Err(4)
            .async_map(|x: i32| {
                called.set(true);
                async move { x + 1 }
            })
            .await;
        assert_eq!(out, Result::Err(4));
        assert!(!called.get());
    }

    /// The closure must not run before the first poll, and never for `Err`.
    #[tokio::test]
    async fn and_then_calls_closure_lazily_and_only_on_ok() {
        let called = Cell::new(false);
        let future = Result::Ok(1).async_and_then(|x: i32| {
            called.set(true);
            async move { Ok(x + 1) }
        });
        assert!(!called.get());
        assert_eq!(future.await, Result::Ok(2));
        assert!(called.get());

        let called = Cell::new(false);
        let out = Result::Err(4)
            .async_and_then(|x: i32| {
                called.set(true);
                async move { Ok(x + 1) }
            })
            .await;
        assert_eq!(out, Result::Err(4));
        assert!(!called.get());
    }
}