1use std::future::Future;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use futures::{ready, TryFuture};
6use pin_project::pin_project;
7use tokio::time::sleep;
8
9use crate::error::RetryError;
10use crate::retry_strategy::RetryStrategy;
11use crate::RetryPolicy;
12
13pub trait FutureFactory<E> {
14 type Future: TryFuture<Error = RetryPolicy<E>>;
15
16 fn spawn(&mut self) -> Self::Future;
17}
18
19impl<T, Fut, E> FutureFactory<E> for T
20where
21 T: Unpin + FnMut() -> Fut,
22 Fut: TryFuture<Error = RetryPolicy<E>>,
23{
24 type Future = Fut;
25
26 fn spawn(&mut self) -> Fut {
27 self()
28 }
29}
30
31#[pin_project(project = FutureStateProj)]
32enum FutureState<Fut> {
33 WaitingForFuture {
34 #[pin]
35 future: Fut,
36 },
37 TimerActive {
38 #[pin]
39 delay: tokio::time::Sleep,
40 },
41}
42
43#[pin_project]
51pub struct AsyncRetry<F, E, RS>
52where
53 F: FutureFactory<E>
54{
55 factory: F,
56 retry_strategy: RS,
57 attempts_before: usize,
58 #[pin]
59 state: FutureState<F::Future>,
60 errors: Vec<RetryPolicy<E>>,
61}
62
63impl<F, E, RS> AsyncRetry<F, E, RS>
64where
65 F: FutureFactory<E>,
66{
67 pub fn new(mut factory: F, retry_strategy: RS) -> Self {
73 let future = factory.spawn();
74 Self {
75 factory,
76 retry_strategy,
77 state: FutureState::WaitingForFuture { future },
78 attempts_before: 0,
79 errors: Vec::new(),
80 }
81 }
82}
83
84impl<F, E, RS> Future for AsyncRetry<F, E, RS>
85where
86 F: FutureFactory<E>,
87 RS: RetryStrategy,
88{
89 type Output = Result<<<F as FutureFactory<E>>::Future as TryFuture>::Ok, RetryError<E>>;
90
91 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
92 loop {
93 let async_retry = self.as_mut().project();
94 let new_state = match async_retry.state.project() {
95 FutureStateProj::WaitingForFuture { future } => match ready!(future.try_poll(cx)) {
96 Ok(t) => {
97 *async_retry.attempts_before = 0;
98 return Poll::Ready(Ok(t));
99 }
100 Err(err) => {
101 async_retry.errors.push(err);
102 let err = async_retry.errors.last().unwrap(); let new_state = match err {
104 RetryPolicy::Repeat(_) => {
105 let check_attempt_result = async_retry
106 .retry_strategy
107 .check_attempt(*async_retry.attempts_before);
108 match check_attempt_result {
109 Ok(duration) => {
110 FutureState::TimerActive { delay: sleep(duration) }
111 }
112 Err(_) => {
113 let errors =
114 std::mem::take(async_retry.errors);
115 return Poll::Ready(Err(RetryError { errors }));
116 }
117 }
118 }
119 RetryPolicy::Fail(_) => {
120 let errors = std::mem::take(async_retry.errors);
121 return Poll::Ready(Err(RetryError { errors }));
122 }
123 };
124 *async_retry.attempts_before += 1;
125 new_state
126 }
127 },
128 FutureStateProj::TimerActive { delay } => {
129 ready!(delay.poll(cx));
130 FutureState::WaitingForFuture { future: async_retry.factory.spawn() }
131 }
132 };
133
134 self.as_mut().project().state.set(new_state);
135 }
136 }
137}