async_repeat/
future.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use futures::{ready, TryFuture};
6use pin_project::pin_project;
7use tokio::time::sleep;
8
9use crate::error::RetryError;
10use crate::retry_strategy::RetryStrategy;
11use crate::RetryPolicy;
12
13pub trait FutureFactory<E> {
14    type Future: TryFuture<Error = RetryPolicy<E>>;
15
16    fn spawn(&mut self) -> Self::Future;
17}
18
19impl<T, Fut, E> FutureFactory<E> for T
20where
21    T: Unpin + FnMut() -> Fut,
22    Fut: TryFuture<Error = RetryPolicy<E>>,
23{
24    type Future = Fut;
25
26    fn spawn(&mut self) -> Fut {
27        self()
28    }
29}
30
31#[pin_project(project = FutureStateProj)]
32enum FutureState<Fut> {
33    WaitingForFuture {
34        #[pin]
35        future: Fut,
36    },
37    TimerActive {
38        #[pin]
39        delay: tokio::time::Sleep,
40    },
41}
42
43/// A future which is trying to resolve inner future
44/// until it exits successfully or return an [error](crate::error::RetryError).
45///
46/// The main point is that you handle all the logic **inside** your future
47/// and construct a helper type or use one of existing which implements
48/// [RetryStrategy](crate::retry_strategy::RetryStrategy) trait
49/// which is responsible for configuring retry mechanism
50#[pin_project]
51pub struct AsyncRetry<F, E, RS>
52where
53    F: FutureFactory<E>
54{
55    factory: F,
56    retry_strategy: RS,
57    attempts_before: usize,
58    #[pin]
59    state: FutureState<F::Future>,
60    errors: Vec<RetryPolicy<E>>,
61}
62
63impl<F, E, RS> AsyncRetry<F, E, RS>
64where
65    F: FutureFactory<E>,
66{
67    /// [FutureFactory](FutureFactory) has a blanket implementation
68    /// for FnMut closures. This means that you can pass a closure instead
69    /// of implementing [FutureFactory](FutureFactory) for your type.
70    ///
71    /// See examples to understand how to use this.
72    pub fn new(mut factory: F, retry_strategy: RS) -> Self {
73        let future = factory.spawn();
74        Self {
75            factory,
76            retry_strategy,
77            state: FutureState::WaitingForFuture { future },
78            attempts_before: 0,
79            errors: Vec::new(),
80        }
81    }
82}
83
84impl<F, E, RS> Future for AsyncRetry<F, E, RS>
85where
86    F: FutureFactory<E>,
87    RS: RetryStrategy,
88{
89    type Output = Result<<<F as FutureFactory<E>>::Future as TryFuture>::Ok, RetryError<E>>;
90
91    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
92        loop {
93            let async_retry = self.as_mut().project();
94            let new_state = match async_retry.state.project() {
95                FutureStateProj::WaitingForFuture { future } => match ready!(future.try_poll(cx)) {
96                    Ok(t) => {
97                        *async_retry.attempts_before = 0;
98                        return Poll::Ready(Ok(t));
99                    }
100                    Err(err) => {
101                        async_retry.errors.push(err);
102                        let err = async_retry.errors.last().unwrap(); // cannot panic as we just pushed to vec
103                        let new_state = match err {
104                            RetryPolicy::Repeat(_) => {
105                                let check_attempt_result = async_retry
106                                    .retry_strategy
107                                    .check_attempt(*async_retry.attempts_before);
108                                match check_attempt_result {
109                                    Ok(duration) => {
110                                        FutureState::TimerActive { delay: sleep(duration) }
111                                    }
112                                    Err(_) => {
113                                        let errors =
114                                            std::mem::take(async_retry.errors);
115                                        return Poll::Ready(Err(RetryError { errors }));
116                                    }
117                                }
118                            }
119                            RetryPolicy::Fail(_) => {
120                                let errors = std::mem::take(async_retry.errors);
121                                return Poll::Ready(Err(RetryError { errors }));
122                            }
123                        };
124                        *async_retry.attempts_before += 1;
125                        new_state
126                    }
127                },
128                FutureStateProj::TimerActive { delay } => {
129                    ready!(delay.poll(cx));
130                    FutureState::WaitingForFuture { future: async_retry.factory.spawn() }
131                }
132            };
133
134            self.as_mut().project().state.set(new_state);
135        }
136    }
137}