gcloud_gax/
retry.rs

1use std::future::Future;
2use std::iter::Take;
3use std::time::Duration;
4
5pub use tokio_retry2::strategy::ExponentialBackoff;
6use tokio_retry2::{Action, RetryIf};
7pub use tokio_retry2::{Condition, MapErr};
8
9use crate::grpc::{Code, Status};
10
/// Fallible, borrowing view of a value as some concrete type `T`.
///
/// Implementors return `Some(&T)` when `self` can be viewed as a `T`, and `None`
/// otherwise.
pub trait TryAs<T> {
    /// Borrows `self` as a `T`, if such a view exists.
    fn try_as(&self) -> Option<&T>;
}
14
15impl TryAs<Status> for Status {
16    fn try_as(&self) -> Option<&Status> {
17        Some(self)
18    }
19}
20
21pub trait Retry<E: TryAs<Status>, T: Condition<E>> {
22    fn strategy(&self) -> Take<ExponentialBackoff>;
23    fn condition(&self) -> T;
24    fn notify(error: &E, duration: Duration);
25}
26
27pub struct CodeCondition {
28    codes: Vec<Code>,
29}
30
31impl CodeCondition {
32    pub fn new(codes: Vec<Code>) -> Self {
33        Self { codes }
34    }
35}
36
37impl<E> Condition<E> for CodeCondition
38where
39    E: TryAs<Status>,
40{
41    fn should_retry(&mut self, error: &E) -> bool {
42        if let Some(status) = error.try_as() {
43            for code in &self.codes {
44                if *code == status.code() {
45                    return true;
46                }
47            }
48        }
49        false
50    }
51}
52
53#[derive(Clone, Debug)]
54pub struct RetrySetting {
55    pub from_millis: u64,
56    pub max_delay: Option<Duration>,
57    pub factor: u64,
58    pub take: usize,
59    pub codes: Vec<Code>,
60}
61
62impl Retry<Status, CodeCondition> for RetrySetting {
63    fn strategy(&self) -> Take<ExponentialBackoff> {
64        let mut st = ExponentialBackoff::from_millis(self.from_millis);
65        if let Some(max_delay) = self.max_delay {
66            st = st.max_delay(max_delay);
67        }
68        st.take(self.take)
69    }
70
71    fn condition(&self) -> CodeCondition {
72        CodeCondition::new(self.codes.clone())
73    }
74
75    fn notify(_error: &Status, _duration: Duration) {
76        tracing::trace!("retry fn");
77    }
78}
79
80impl Default for RetrySetting {
81    fn default() -> Self {
82        Self {
83            from_millis: 10,
84            max_delay: Some(Duration::from_secs(1)),
85            factor: 1u64,
86            take: 5,
87            codes: vec![Code::Unavailable, Code::Unknown, Code::Aborted],
88        }
89    }
90}
91
92pub async fn invoke<A, R, RT, C, E>(retry: Option<RT>, action: A) -> Result<R, E>
93where
94    E: TryAs<Status> + From<Status>,
95    A: Action<Item = R, Error = E>,
96    C: Condition<E>,
97    RT: Retry<E, C> + Default,
98{
99    let retry = retry.unwrap_or_default();
100    RetryIf::spawn(retry.strategy(), action, retry.condition(), RT::notify).await
101}
102/// Repeats retries when the specified error is detected.
103/// The argument specified by 'v' can be reused for each retry.
104pub async fn invoke_fn<R, V, A, RT, C, E>(retry: Option<RT>, mut f: impl FnMut(V) -> A, mut v: V) -> Result<R, E>
105where
106    E: TryAs<Status> + From<Status>,
107    A: Future<Output = Result<R, (E, V)>>,
108    C: Condition<E>,
109    RT: Retry<E, C> + Default,
110{
111    let retry = retry.unwrap_or_default();
112    let mut strategy = retry.strategy();
113    loop {
114        let result = f(v).await;
115        let status = match result {
116            Ok(s) => return Ok(s),
117            Err(e) => {
118                v = e.1;
119                e.0
120            }
121        };
122        if retry.condition().should_retry(&status) {
123            let duration = strategy.next().ok_or(status)?;
124            tokio::time::sleep(duration).await;
125        } else {
126            return Err(status);
127        }
128    }
129}
130
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    use tokio_retry2::MapErr;
    use tonic::{Code, Status};

    use crate::retry::{invoke, RetrySetting};

    /// An action that always fails with a retryable `Aborted` status is attempted once
    /// plus `take` retries; the final error is then returned.
    #[tokio::test]
    async fn test_retry() {
        // Keep the default retry budget and codes, but use 1 ms delays so the test avoids
        // the multi-second sleeps of the default 10 ms-base backoff capped at 1 s.
        let retry = RetrySetting {
            from_millis: 1,
            max_delay: Some(Duration::from_millis(1)),
            ..RetrySetting::default()
        };
        let counter = AtomicUsize::new(0);
        let action = || async {
            counter.fetch_add(1, Ordering::SeqCst);
            let result: Result<i32, Status> = Err(Status::new(Code::Aborted, "error"));
            result.map_transient_err()
        };
        let actual = invoke(Some(retry), action).await.unwrap_err();
        let expected = Status::new(Code::Aborted, "error");
        assert_eq!(actual.code(), expected.code());
        // 1 initial attempt + 5 retries (default `take`).
        assert_eq!(counter.load(Ordering::SeqCst), 6);
    }
}