1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
use crate::error::{Error, ErrorKind, HttpError};
use crate::policies::{Policy, PolicyResult, Request};
use crate::sleep::sleep;
use crate::{Context, StatusCode};

use async_trait::async_trait;
use time::OffsetDateTime;

use std::sync::Arc;
use std::time::Duration;

/// Strategy that decides whether, and after how long, a failed request is retried.
///
/// In the common case an implementor only has to answer two questions: when to
/// give up ([`RetryPolicy::is_expired`]) and how long to pause before the next
/// attempt ([`RetryPolicy::sleep_duration`]).
///
/// [`RetryPolicy::wait`] may be overridden when a plain timed pause is not
/// sufficient, for example when the delay should depend on the error that
/// triggered the retry.
#[async_trait]
pub trait RetryPolicy: std::fmt::Debug + Send + Sync {
    /// Report whether retrying should stop.
    ///
    /// Returns `true` when no further retries should be attempted.
    ///
    /// The blanket [`Policy`] implementation in this module passes the time
    /// elapsed since the first attempt completed as `duration_since_start` and
    /// the number of retries already performed as `retry_count`.
    fn is_expired(&self, duration_since_start: Duration, retry_count: u32) -> bool;
    /// Compute how long to pause before the next retry is attempted.
    ///
    /// The default [`RetryPolicy::wait`] forwards its own `retry_count` here
    /// unchanged.
    fn sleep_duration(&self, retry_count: u32) -> Duration;
    /// A future that completes once the request may be retried.
    ///
    /// `error` is the [`Error`] value that led to the retry attempt. The
    /// default implementation ignores it and simply sleeps for
    /// [`RetryPolicy::sleep_duration`].
    async fn wait(&self, _error: &Error, retry_count: u32) {
        let delay = self.sleep_duration(retry_count);
        sleep(delay).await;
    }
}

/// The status codes for which a retry should be attempted.
///
/// Any other non-success status code is not retried; the response is turned
/// into an error and returned to the caller immediately.
const RETRY_STATUSES: &[StatusCode] = &[
    StatusCode::RequestTimeout,
    StatusCode::TooManyRequests,
    StatusCode::InternalServerError,
    StatusCode::BadGateway,
    StatusCode::ServiceUnavailable,
    StatusCode::GatewayTimeout,
];

#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
impl<T> Policy for T
where
    T: RetryPolicy,
{
    async fn send(
        &self,
        ctx: &Context,
        request: &mut Request,
        next: &[Arc<dyn Policy>],
    ) -> PolicyResult {
        let mut retry_count = 0;
        let mut start = None;

        loop {
            let result = next[0].send(ctx, request, &next[1..]).await;
            // only start keeping track of time after the first request is made
            let start = start.get_or_insert_with(OffsetDateTime::now_utc);
            let last_error = match result {
                Ok(response) if response.status().is_success() => {
                    log::trace!(
                        "Successful response. Request={:?} response={:?}",
                        request,
                        response
                    );
                    return Ok(response);
                }
                Ok(response) => {
                    // Error status code
                    let status = response.status();
                    let http_error = HttpError::new(response).await;

                    let error_kind = ErrorKind::http_response(
                        status,
                        http_error.error_code().map(std::borrow::ToOwned::to_owned),
                    );

                    if !RETRY_STATUSES.contains(&status) {
                        log::debug!(
                            "server returned error status which will not be retried: {}",
                            status
                        );
                        // Server didn't return a status we retry on so return early
                        let error = Error::full(
                            error_kind,
                            http_error,
                            format!(
                                "server returned error status which will not be retried: {status}"
                            ),
                        );
                        return Err(error);
                    }
                    log::debug!(
                        "server returned error status which requires retry: {}",
                        status
                    );
                    Error::new(error_kind, http_error)
                }
                Err(error) => {
                    if error.kind() == &ErrorKind::Io {
                        log::debug!(
                            "io error occurred when making request which will be retried: {}",
                            error
                        );
                        error
                    } else {
                        return Err(
                            error.context("non-io error occurred which will not be retried")
                        );
                    }
                }
            };

            let time_since_start = (OffsetDateTime::now_utc() - *start)
                .try_into()
                .unwrap_or_default();
            if self.is_expired(time_since_start, retry_count) {
                return Err(last_error
                    .context("retry policy expired and the request will no longer be retried"));
            }
            retry_count += 1;

            self.wait(&last_error, retry_count).await;
        }
    }
}