1use std::future::Future;
2use std::iter::Take;
3use std::time::Duration;
4
5pub use tokio_retry2::strategy::ExponentialBackoff;
6use tokio_retry2::{Action, RetryIf};
7pub use tokio_retry2::{Condition, MapErr};
8
9use crate::grpc::{Code, Status};
10
/// Fallible borrowing conversion: lets a value expose itself as a `&T`
/// when such a view exists.
pub trait TryAs<T> {
    /// Returns a reference to the `T` view of `self`, or `None` when the
    /// value cannot be viewed as a `T`.
    fn try_as(&self) -> Option<&T>;
}
14
15impl TryAs<Status> for Status {
16 fn try_as(&self) -> Option<&Status> {
17 Some(self)
18 }
19}
20
21pub trait Retry<E: TryAs<Status>, T: Condition<E>> {
22 fn strategy(&self) -> Take<ExponentialBackoff>;
23 fn condition(&self) -> T;
24 fn notify(error: &E, duration: Duration);
25}
26
27pub struct CodeCondition {
28 codes: Vec<Code>,
29}
30
31impl CodeCondition {
32 pub fn new(codes: Vec<Code>) -> Self {
33 Self { codes }
34 }
35}
36
37impl<E> Condition<E> for CodeCondition
38where
39 E: TryAs<Status>,
40{
41 fn should_retry(&mut self, error: &E) -> bool {
42 if let Some(status) = error.try_as() {
43 for code in &self.codes {
44 if *code == status.code() {
45 return true;
46 }
47 }
48 }
49 false
50 }
51}
52
53#[derive(Clone, Debug)]
54pub struct RetrySetting {
55 pub from_millis: u64,
56 pub max_delay: Option<Duration>,
57 pub factor: u64,
58 pub take: usize,
59 pub codes: Vec<Code>,
60}
61
62impl Retry<Status, CodeCondition> for RetrySetting {
63 fn strategy(&self) -> Take<ExponentialBackoff> {
64 let mut st = ExponentialBackoff::from_millis(self.from_millis);
65 if let Some(max_delay) = self.max_delay {
66 st = st.max_delay(max_delay);
67 }
68 st.take(self.take)
69 }
70
71 fn condition(&self) -> CodeCondition {
72 CodeCondition::new(self.codes.clone())
73 }
74
75 fn notify(_error: &Status, _duration: Duration) {
76 tracing::trace!("retry fn");
77 }
78}
79
80impl Default for RetrySetting {
81 fn default() -> Self {
82 Self {
83 from_millis: 10,
84 max_delay: Some(Duration::from_secs(1)),
85 factor: 1u64,
86 take: 5,
87 codes: vec![Code::Unavailable, Code::Unknown, Code::Aborted],
88 }
89 }
90}
91
92pub async fn invoke<A, R, RT, C, E>(retry: Option<RT>, action: A) -> Result<R, E>
93where
94 E: TryAs<Status> + From<Status>,
95 A: Action<Item = R, Error = E>,
96 C: Condition<E>,
97 RT: Retry<E, C> + Default,
98{
99 let retry = retry.unwrap_or_default();
100 RetryIf::spawn(retry.strategy(), action, retry.condition(), RT::notify).await
101}
102pub async fn invoke_fn<R, V, A, RT, C, E>(retry: Option<RT>, mut f: impl FnMut(V) -> A, mut v: V) -> Result<R, E>
105where
106 E: TryAs<Status> + From<Status>,
107 A: Future<Output = Result<R, (E, V)>>,
108 C: Condition<E>,
109 RT: Retry<E, C> + Default,
110{
111 let retry = retry.unwrap_or_default();
112 let mut strategy = retry.strategy();
113 loop {
114 let result = f(v).await;
115 let status = match result {
116 Ok(s) => return Ok(s),
117 Err(e) => {
118 v = e.1;
119 e.0
120 }
121 };
122 if retry.condition().should_retry(&status) {
123 let duration = strategy.next().ok_or(status)?;
124 tokio::time::sleep(duration).await;
125 } else {
126 return Err(status);
127 }
128 }
129}
130
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use tokio_retry2::MapErr;
    use tonic::{Code, Status};

    use crate::retry::{invoke, RetrySetting};

    /// An action that always fails with `Aborted` must surface that code and
    /// is expected to have been invoked six times under the default setting.
    #[tokio::test]
    async fn test_retry() {
        let attempts = Arc::new(Mutex::new(0));
        let always_aborted = || async {
            // Count every invocation of the action.
            *attempts.lock().unwrap() += 1;
            let outcome: Result<i32, Status> = Err(Status::new(Code::Aborted, "error"));
            outcome.map_transient_err()
        };
        let setting = RetrySetting::default();
        let actual = invoke(Some(setting), always_aborted).await.unwrap_err();
        assert_eq!(actual.code(), Code::Aborted);
        assert_eq!(*attempts.lock().unwrap(), 6);
    }
}