1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
pub use tokio_retry::strategy::ExponentialBackoff;
pub use tokio_retry::Condition;

use crate::cancel::CancellationToken;
use crate::grpc::{Code, Status};
use std::future::Future;
use std::iter::Take;

use std::time::Duration;
use tokio::select;

use tokio_retry::Action;
use tokio_retry::RetryIf;

/// Fallible, borrowing view of `self` as a `T`.
///
/// Implementors return `Some(&T)` when the value can be inspected as a `T`
/// and `None` otherwise; no conversion or allocation is implied.
pub trait TryAs<T> {
    /// Borrows `self` as a `T` if possible.
    fn try_as(&self) -> Option<&T>;
}

/// A `Status` is trivially viewable as itself.
impl TryAs<Status> for Status {
    /// Always succeeds, returning a borrow of `self`.
    fn try_as(&self) -> Option<&Status> {
        let me: &Status = self;
        Option::Some(me)
    }
}

/// A retry policy: a backoff schedule plus a predicate that decides which
/// errors of type `E` are worth retrying.
pub trait Retry<E, T>
where
    E: TryAs<Status>,
    T: Condition<E>,
{
    /// Delays to wait between attempts; once exhausted, no further retry happens.
    fn strategy(&self) -> Take<ExponentialBackoff>;

    /// Builds the predicate consulted after each failed attempt.
    fn condition(&self) -> T;
}

/// Retry predicate keyed on gRPC status codes: an error is retried when it
/// can be viewed as a `Status` whose code appears in `codes`.
pub struct CodeCondition {
    /// Status codes that make an error retryable.
    codes: Vec<Code>,
}

impl CodeCondition {
    /// Creates a condition that retries on any of the given `codes`.
    pub fn new(codes: Vec<Code>) -> Self {
        CodeCondition { codes }
    }
}

/// Retries exactly those errors that expose a `Status` whose code is listed
/// in this condition; errors without a `Status` view are never retried.
impl<E> Condition<E> for CodeCondition
where
    E: TryAs<Status>,
{
    fn should_retry(&mut self, error: &E) -> bool {
        match error.try_as() {
            Some(status) => self.codes.iter().any(|code| *code == status.code()),
            None => false,
        }
    }
}

/// User-facing retry configuration for gRPC calls.
///
/// Implements `Retry<Status, CodeCondition>`: the numeric fields describe an
/// exponential backoff schedule and `codes` selects which status codes are retried.
#[derive(Clone, Debug)]
pub struct RetrySetting {
    /// Base delay in milliseconds, passed to `ExponentialBackoff::from_millis`.
    pub from_millis: u64,
    /// Optional upper bound on any single backoff delay.
    pub max_delay: Option<Duration>,
    /// Backoff multiplier (intended for `ExponentialBackoff::factor`).
    pub factor: u64,
    /// Maximum number of backoff delays, i.e. the maximum number of retries.
    pub take: usize,
    /// Status codes that are considered retryable.
    pub codes: Vec<Code>,
}

/// Turns a `RetrySetting` into a concrete backoff schedule and status-code predicate.
impl Retry<Status, CodeCondition> for RetrySetting {
    /// Exponential backoff starting from `from_millis`, with every delay
    /// multiplied by `factor`, optionally capped at `max_delay`, and limited
    /// to `take` delays.
    fn strategy(&self) -> Take<ExponentialBackoff> {
        // `factor` was previously ignored, so user-supplied values had no effect.
        // tokio_retry's own default factor is 1, matching `RetrySetting::default()`.
        let mut st = tokio_retry::strategy::ExponentialBackoff::from_millis(self.from_millis)
            .factor(self.factor);
        if let Some(max_delay) = self.max_delay {
            st = st.max_delay(max_delay);
        }
        st.take(self.take)
    }

    /// Retries any `Status` whose code is listed in `codes`.
    fn condition(&self) -> CodeCondition {
        CodeCondition::new(self.codes.clone())
    }
}

impl Default for RetrySetting {
    fn default() -> Self {
        Self {
            from_millis: 10,
            max_delay: Some(Duration::from_secs(1)),
            factor: 1u64,
            take: 5,
            codes: vec![Code::Unavailable, Code::Unknown, Code::Aborted],
        }
    }
}

/// Runs `action` through `RetryIf`, using the strategy and condition of
/// `retry`, or of `RT::default()` when `retry` is `None`.
///
/// If `cancel` is supplied and is cancelled before the retry loop finishes,
/// the loop is abandoned and `Status::cancelled("client cancel")` is returned
/// as an `E`.
pub async fn invoke<A, R, RT, C, E>(cancel: Option<CancellationToken>, retry: Option<RT>, action: A) -> Result<R, E>
where
    E: TryAs<Status> + From<Status>,
    A: Action<Item = R, Error = E>,
    C: Condition<E>,
    RT: Retry<E, C> + Default,
{
    let policy = retry.unwrap_or_default();
    let token = match cancel {
        Some(token) => token,
        // No cancellation token: simply drive the retry loop to completion.
        None => return RetryIf::spawn(policy.strategy(), action, policy.condition()).await,
    };
    select! {
        _ = token.cancelled() => Err(Status::cancelled("client cancel").into()),
        v = RetryIf::spawn(policy.strategy(), action, policy.condition()) => v
    }
}
/// Repeats retries when the specified error is detected.
/// The argument specified by 'v' can be reused for each retry.
///
/// `f` is called with `v`. On failure it must return `v` alongside the error,
/// so the next attempt can reuse it. After each error the retry condition is
/// consulted:
/// - if it rejects the error, or the strategy has no delays left, that error
///   is returned;
/// - otherwise the next delay is slept and `f` is called again.
///
/// The policy comes from `retry`, or from `RT::default()` when `retry` is `None`.
/// One condition instance is used for the whole sequence, as `RetryIf` does
/// in `invoke`, so stateful conditions observe every attempt.
///
/// If `cancel` is supplied and is cancelled first, the loop is abandoned and
/// `Status::cancelled("client cancel")` is returned as an `E`.
pub async fn invoke_fn<R, V, A, RT, C, E>(
    cancel: Option<CancellationToken>,
    retry: Option<RT>,
    mut f: impl FnMut(V) -> A,
    mut v: V,
) -> Result<R, E>
where
    E: TryAs<Status> + From<Status>,
    A: Future<Output = Result<R, (E, V)>>,
    C: Condition<E>,
    RT: Retry<E, C> + Default,
{
    let fn_loop = async {
        let retry = retry.unwrap_or_default();
        let mut strategy = retry.strategy();
        // Built once (not per attempt) so `should_retry(&mut self)` state persists
        // across attempts and the code list is not re-cloned on every failure.
        let mut condition = retry.condition();
        loop {
            let status = match f(v).await {
                Ok(s) => return Ok(s),
                Err((e, returned)) => {
                    // Recover the argument so it can be passed to the next attempt.
                    v = returned;
                    e
                }
            };
            if !condition.should_retry(&status) {
                return Err(status);
            }
            let duration = match strategy.next() {
                Some(d) => d,
                // Strategy exhausted: give up with the last error.
                None => return Err(status),
            };
            tokio::time::sleep(duration).await;
            tracing::trace!("retry fn");
        }
    };
    match cancel {
        Some(cancel) => {
            select! {
                _ = cancel.cancelled() => Err(Status::cancelled("client cancel").into()),
                result = fn_loop => result
            }
        }
        None => fn_loop.await,
    }
}