1use alloc::{boxed::Box, vec, vec::Vec};
2use core::{
3 fmt,
4 future::Future,
5 marker::PhantomData,
6 ops::ControlFlow,
7 pin::Pin,
8 task::{Context, Poll},
9};
10
11use async_sleep::{sleep, Sleepble};
12use futures_util::{future::FusedFuture, FutureExt as _};
13use pin_project_lite::pin_project;
14use retry_policy::RetryPolicy;
15
16use crate::error::Error;
17
/// Boxed, type-erased future for a single attempt, resolving to that
/// attempt's `Result`.
type BoxedAttempt<T, E> = Pin<Box<dyn Future<Output = Result<T, E>> + Send>>;

/// Boxed factory invoked to produce a fresh [`BoxedAttempt`] each time a new
/// attempt is started.
type RetryFutureRepeater<T, E> = Box<dyn FnMut() -> BoxedAttempt<T, E> + Send>;
21
22pin_project! {
24 pub struct Retry<SLEEP, POL, T, E> {
25 policy: POL,
26 future_repeater: RetryFutureRepeater<T, E>,
27 state: State<T, E>,
29 attempts: usize,
30 errors: Option<Vec<E>>,
31 phantom: PhantomData<(SLEEP, T, E)>,
33 }
34}
35
36impl<SLEEP, POL, T, E> fmt::Debug for Retry<SLEEP, POL, T, E>
37where
38 POL: fmt::Debug,
39{
40 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41 f.debug_struct("Retry")
42 .field("policy", &self.policy)
43 .field("future_repeater", &"")
44 .finish()
45 }
46}
47
48impl<SLEEP, POL, T, E> Retry<SLEEP, POL, T, E> {
49 pub(crate) fn new(policy: POL, future_repeater: RetryFutureRepeater<T, E>) -> Self {
50 Self {
51 policy,
52 future_repeater,
53 state: State::Pending,
55 attempts: 0,
56 errors: Some(vec![]),
57 phantom: PhantomData,
59 }
60 }
61}
62
/// Where a `Retry` currently is in its attempt / back-off cycle.
enum State<T, E> {
    /// No attempt in flight; the next poll creates one from the repeater.
    Pending,
    /// An attempt future is running.
    Fut(Pin<Box<dyn Future<Output = Result<T, E>> + Send>>),
    /// Waiting out the back-off chosen by the policy before the next attempt.
    Sleep(Pin<Box<dyn Future<Output = ()> + Send>>),
    /// The final output has been produced.
    Done,
}

impl<T, E> fmt::Debug for State<T, E> {
    /// Prints only the variant name; the boxed futures are opaque.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::Pending => "Pending",
            Self::Fut(_) => "Fut",
            Self::Sleep(_) => "Sleep",
            Self::Done => "Done",
        };
        f.write_str(name)
    }
}
80
81pub fn retry<SLEEP, POL, F, Fut, T, E>(policy: POL, future_repeater: F) -> Retry<SLEEP, POL, T, E>
83where
84 SLEEP: Sleepble + 'static,
85 POL: RetryPolicy<E>,
86 F: Fn() -> Fut + Send + 'static,
87 Fut: Future<Output = Result<T, E>> + Send + 'static,
88{
89 Retry::new(policy, Box::new(move || Box::pin(future_repeater())))
90}
91
92impl<SLEEP, POL, T, E> FusedFuture for Retry<SLEEP, POL, T, E>
94where
95 SLEEP: Sleepble + 'static,
96 POL: RetryPolicy<E>,
97{
98 fn is_terminated(&self) -> bool {
99 matches!(self.state, State::Done)
100 }
101}
102
103impl<SLEEP, POL, T, E> Future for Retry<SLEEP, POL, T, E>
105where
106 SLEEP: Sleepble + 'static,
107 POL: RetryPolicy<E>,
108{
109 type Output = Result<T, Error<E>>;
110
111 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
112 let this = self.project();
113
114 loop {
115 match this.state {
116 State::Pending => {
117 let future = (this.future_repeater)();
118
119 *this.state = State::Fut(future);
121
122 continue;
123 }
124 State::Fut(future) => {
125 match future.poll_unpin(cx) {
126 Poll::Ready(Ok(x)) => {
127 *this.state = State::Done;
129 *this.attempts = 0;
130 *this.errors = Some(Vec::new());
131
132 break Poll::Ready(Ok(x));
133 }
134 Poll::Ready(Err(err)) => {
135 *this.attempts += 1;
137
138 let ret = this.policy.next_step(&err, *this.attempts);
140
141 if let Some(errors) = this.errors.as_mut() {
143 errors.push(err)
144 } else {
145 unreachable!()
146 }
147
148 match ret {
149 ControlFlow::Continue(dur) => {
150 *this.state = State::Sleep(Box::pin(sleep::<SLEEP>(dur)));
152
153 continue;
154 }
155 ControlFlow::Break(stop_reason) => {
156 let errors = this.errors.take().expect("unreachable!()");
157
158 *this.state = State::Done;
160 *this.attempts = 0;
161 *this.errors = Some(Vec::new());
162
163 break Poll::Ready(Err(Error::new(stop_reason, errors)));
164 }
165 }
166 }
167 Poll::Pending => {
168 break Poll::Pending;
169 }
170 }
171 }
172 State::Sleep(future) => match future.poll_unpin(cx) {
173 Poll::Ready(_) => {
174 *this.state = State::Pending;
176
177 continue;
178 }
179 Poll::Pending => {
180 break Poll::Pending;
181 }
182 },
183 State::Done => panic!("cannot poll Retry twice"),
184 }
185 }
186 }
187}
188
#[cfg(feature = "std")]
#[cfg(test)]
mod tests {
    use super::*;

    use core::{
        sync::atomic::{AtomicUsize, Ordering},
        time::Duration,
    };

    use async_sleep::impl_tokio::Sleep;
    use retry_policy::{
        policies::SimplePolicy,
        retry_backoff::backoffs::FnBackoff,
        retry_predicate::predicates::{AlwaysPredicate, FnPredicate},
        StopReason,
    };

    /// Extra time allowed on top of the total back-off before a timing check
    /// fails. Scheduler jitter on a loaded CI machine easily exceeds a few
    /// milliseconds, so a +5 ms bound is flaky.
    const TIMING_SLACK_MS: u128 = 100;

    /// Asserts that `elapsed` covers at least the expected total back-off and
    /// is not much longer than it.
    fn assert_elapsed_about(elapsed: Duration, expected_ms: u128) {
        let ms = elapsed.as_millis();
        assert!(
            ms >= expected_ms && ms <= expected_ms + TIMING_SLACK_MS,
            "elapsed {} ms, expected {}..={} ms",
            ms,
            expected_ms,
            expected_ms + TIMING_SLACK_MS
        );
    }

    #[tokio::test]
    async fn test_retry_with_max_retries_reached() {
        #[derive(Debug, PartialEq)]
        struct FError(usize);
        async fn f(n: usize) -> Result<(), FError> {
            Err(FError(n))
        }

        let policy = SimplePolicy::new(
            AlwaysPredicate,
            3,
            FnBackoff::from(|_| Duration::from_millis(100)),
        );

        let now = std::time::Instant::now();

        match retry::<Sleep, _, _, _, _, _>(policy, || f(0)).await {
            Ok(_) => panic!("expected retry to give up after max retries"),
            Err(err) => {
                assert_eq!(&err.stop_reason, &StopReason::MaxRetriesReached);
                assert_eq!(err.errors(), &[FError(0), FError(0), FError(0), FError(0)]);
            }
        }

        // 3 retries x 100 ms back-off.
        assert_elapsed_about(now.elapsed(), 300);
    }

    #[tokio::test]
    async fn test_retry_with_max_retries_reached_for_tokio_spawn() {
        #[derive(Debug, PartialEq)]
        struct FError(usize);
        async fn f(n: usize) -> Result<(), FError> {
            Err(FError(n))
        }

        let policy = SimplePolicy::new(
            AlwaysPredicate,
            3,
            FnBackoff::from(|_| Duration::from_millis(100)),
        );

        // `tokio::spawn` requires the `Retry` future to be `Send + 'static`.
        // The handle must be awaited: otherwise the test returns before the task
        // runs and the assertions inside it are never checked.
        tokio::spawn(async move {
            let now = std::time::Instant::now();

            match retry::<Sleep, _, _, _, _, _>(policy, || f(0)).await {
                Ok(_) => panic!("expected retry to give up after max retries"),
                Err(err) => {
                    assert_eq!(&err.stop_reason, &StopReason::MaxRetriesReached);
                    assert_eq!(err.errors(), &[FError(0), FError(0), FError(0), FError(0)]);
                }
            }

            // 3 retries x 100 ms back-off.
            assert_elapsed_about(now.elapsed(), 300);
        })
        .await
        .expect("spawned retry task failed");
    }

    #[tokio::test]
    async fn test_retry_with_predicate_failed() {
        #[derive(Debug, PartialEq)]
        struct FError(usize);
        async fn f(n: usize) -> Result<(), FError> {
            Err(FError(n))
        }

        // `AtomicUsize::new` is `const`, so no lazy initialisation is needed.
        static N: AtomicUsize = AtomicUsize::new(0);

        let policy = SimplePolicy::new(
            FnPredicate::from(|FError(n): &FError| [0, 1].contains(n)),
            3,
            FnBackoff::from(|_| Duration::from_millis(100)),
        );

        let now = std::time::Instant::now();

        match retry::<Sleep, _, _, _, _, _>(policy, || f(N.fetch_add(1, Ordering::SeqCst))).await {
            Ok(_) => panic!("expected retry to stop on predicate failure"),
            Err(err) => {
                assert_eq!(&err.stop_reason, &StopReason::PredicateFailed);
                assert_eq!(err.errors(), &[FError(0), FError(1), FError(2)]);
            }
        }

        // Two retries (after errors 0 and 1) x 100 ms back-off.
        assert_elapsed_about(now.elapsed(), 200);
    }
}