google_cloud_gax/
retry.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
use std::future::Future;
use std::iter::Take;
use std::time::Duration;

pub use tokio_retry2::strategy::ExponentialBackoff;
use tokio_retry2::{Action, RetryIf};
pub use tokio_retry2::{Condition, MapErr};

use crate::grpc::{Code, Status};

/// Borrowing conversion used by the retry helpers to inspect an error as a `T`.
///
/// For `T = Status`, [`CodeCondition`] treats a `None` result as "not retryable".
pub trait TryAs<T> {
    /// Returns `self` viewed as a `&T`, or `None` when no such view exists.
    fn try_as(&self) -> Option<&T>;
}

/// A `Status` is trivially viewable as itself, so this always returns `Some`.
impl TryAs<Status> for Status {
    fn try_as(&self) -> Option<&Status> {
        Some(self)
    }
}

/// A retry policy for errors of type `E`.
///
/// It supplies three things:
/// - the backoff schedule;
/// - the condition (`T`) that decides whether an error is retryable;
/// - a notification hook that [`invoke`] hands to `tokio_retry2::RetryIf`.
pub trait Retry<E: TryAs<Status>, T: Condition<E>> {
    /// Delays to wait between attempts.
    ///
    /// The `Take` wrapper bounds how many retries can happen.
    fn strategy(&self) -> Take<ExponentialBackoff>;
    /// Builds the condition deciding whether a given error should be retried.
    fn condition(&self) -> T;
    /// Hook passed as the `notify` callback of `RetryIf::spawn` by [`invoke`].
    ///
    /// It receives the failing error and the delay before the next attempt.
    fn notify(error: &E, duration: Duration);
}

/// A retry [`Condition`] that accepts errors carrying a gRPC [`Status`] whose code
/// belongs to a fixed set.
///
/// Errors that cannot be viewed as a `Status` (`try_as` returns `None`) are never
/// retried.
pub struct CodeCondition {
    codes: Vec<Code>,
}

impl CodeCondition {
    /// Creates a condition that retries whenever the status code is one of `codes`.
    pub fn new(codes: Vec<Code>) -> Self {
        CodeCondition { codes }
    }
}

impl<E> Condition<E> for CodeCondition
where
    E: TryAs<Status>,
{
    /// Returns `true` iff `error` exposes a `Status` whose code is in the configured set.
    fn should_retry(&mut self, error: &E) -> bool {
        match error.try_as() {
            Some(status) => self.codes.iter().any(|candidate| *candidate == status.code()),
            None => false,
        }
    }
}

/// Retry policy for gRPC calls.
///
/// It combines an exponential backoff schedule with the set of status codes that
/// are considered retryable.
#[derive(Clone, Debug)]
pub struct RetrySetting {
    /// Passed to `ExponentialBackoff::from_millis`, which uses it as the base of the
    /// exponential schedule (in milliseconds).
    pub from_millis: u64,
    /// Cap applied to every backoff delay via `ExponentialBackoff::max_delay`.
    ///
    /// `None` applies no cap.
    pub max_delay: Option<Duration>,
    /// Backoff multiplier, in the sense of `ExponentialBackoff::factor`.
    ///
    /// NOTE(review): confirm that `strategy()` actually forwards this value.
    pub factor: u64,
    /// Maximum number of retries.
    ///
    /// This is the number of delays taken from the backoff schedule. The initial
    /// attempt is not counted.
    pub take: usize,
    /// gRPC status codes for which a failed call is retried.
    pub codes: Vec<Code>,
}

impl Retry<Status, CodeCondition> for RetrySetting {
    /// Builds the backoff schedule described by this setting.
    ///
    /// The n-th delay (1-based) is `from_millis^n * factor` milliseconds. Each delay
    /// is capped at `max_delay` when one is set. At most `take` delays are produced,
    /// i.e. at most `take` retries after the first attempt.
    fn strategy(&self) -> Take<ExponentialBackoff> {
        // `factor` used to be silently ignored. Forward it so the public field takes
        // effect. The default of 1 matches `ExponentialBackoff`'s own default, so the
        // default schedule is unchanged.
        let mut st = ExponentialBackoff::from_millis(self.from_millis).factor(self.factor);
        if let Some(max_delay) = self.max_delay {
            st = st.max_delay(max_delay);
        }
        st.take(self.take)
    }

    /// Retries only errors whose status code is listed in `codes`.
    fn condition(&self) -> CodeCondition {
        CodeCondition::new(self.codes.clone())
    }

    /// Emits a trace-level log line before a retry.
    ///
    /// The error and delay are not inspected.
    fn notify(_error: &Status, _duration: Duration) {
        tracing::trace!("retry fn");
    }
}

impl Default for RetrySetting {
    fn default() -> Self {
        Self {
            from_millis: 10,
            max_delay: Some(Duration::from_secs(1)),
            factor: 1u64,
            take: 5,
            codes: vec![Code::Unavailable, Code::Unknown, Code::Aborted],
        }
    }
}

/// Runs `action`, retrying it according to `retry`.
///
/// When `retry` is `None`, `RT::default()` is used. Scheduling is delegated to
/// `tokio_retry2::RetryIf` with the policy's strategy, condition and `RT::notify` hook.
///
/// # Errors
///
/// Returns the error reported by `RetryIf` once retries stop, either because the
/// error is not retryable or because the schedule is exhausted.
pub async fn invoke<A, R, RT, C, E>(retry: Option<RT>, action: A) -> Result<R, E>
where
    E: TryAs<Status> + From<Status>,
    A: Action<Item = R, Error = E>,
    C: Condition<E>,
    RT: Retry<E, C> + Default,
{
    let policy = match retry {
        Some(policy) => policy,
        None => RT::default(),
    };
    let schedule = policy.strategy();
    let should_retry = policy.condition();
    RetryIf::spawn(schedule, action, should_retry, RT::notify).await
}
/// Calls `f` repeatedly until it succeeds, fails with a non-retryable error, or the
/// retry schedule is exhausted.
///
/// Unlike [`invoke`], the closure takes an owned argument `v` and must hand it back
/// with the error (`Err((error, v))`). The same value is then reused for the next
/// attempt without cloning. When `retry` is `None`, `RT::default()` is used.
///
/// Before each retry, `RT::notify` is called with the error and the upcoming delay,
/// mirroring what [`invoke`] gets from `RetryIf`.
///
/// # Errors
///
/// Returns the last error produced by `f`. This happens either when the condition
/// rejects that error or when the backoff strategy has no delays left.
pub async fn invoke_fn<R, V, A, RT, C, E>(
    retry: Option<RT>,
    mut f: impl FnMut(V) -> A,
    mut v: V,
) -> Result<R, E>
where
    E: TryAs<Status> + From<Status>,
    A: Future<Output = Result<R, (E, V)>>,
    C: Condition<E>,
    RT: Retry<E, C> + Default,
{
    let retry = retry.unwrap_or_default();
    let mut strategy = retry.strategy();
    // Build the condition once. `should_retry` takes `&mut self`, so a stateful
    // condition must persist across attempts rather than being recreated each time.
    let mut condition = retry.condition();
    loop {
        let error = match f(v).await {
            Ok(item) => return Ok(item),
            Err((error, returned)) => {
                // Take the argument back so it can be passed to the next attempt.
                v = returned;
                error
            }
        };
        if !condition.should_retry(&error) {
            return Err(error);
        }
        let duration = match strategy.next() {
            Some(duration) => duration,
            None => return Err(error),
        };
        RT::notify(&error, duration);
        tokio::time::sleep(duration).await;
    }
}

#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use tokio_retry2::MapErr;
    use tonic::{Code, Status};

    use crate::retry::{invoke, RetrySetting};

    /// With the default setting (`take: 5`), an always-failing, retryable action runs
    /// once plus five retries. The final error is then surfaced to the caller.
    #[tokio::test]
    async fn test_retry() {
        let attempts = Arc::new(Mutex::new(0));
        let action = || async {
            *attempts.lock().unwrap() += 1;
            let failure: Result<i32, Status> = Err(Status::new(Code::Aborted, "error"));
            failure.map_transient_err()
        };
        let err = invoke(Some(RetrySetting::default()), action)
            .await
            .unwrap_err();
        let want = Status::new(Code::Aborted, "error");
        assert_eq!(err.code(), want.code());
        assert_eq!(*attempts.lock().unwrap(), 6);
    }
}