1use std::time::Duration;
2use tokio::time::timeout;
3
/// Outcome of a retry loop driven by [`execute_retry`] or
/// [`execute_retry_with_exponential_backoff`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetryResult<T, E> {
    /// Value produced by the first successful attempt, or `None` when every
    /// attempt failed or timed out.
    pub success: Option<T>,
    /// Errors returned by failed attempts, in the order they occurred.
    /// Attempts that timed out do not add an entry here.
    pub errors: Vec<E>,
    /// Number of attempts that were cancelled because they exceeded the
    /// per-attempt timeout.
    pub timeout_count: u64,
}
9
10pub async fn execute_retry<T, E, Fut>(
11 max_try_count: u64,
12 retry_duration: Duration,
13 timeout_duration: Duration,
14 inner: impl Fn(u64) -> Fut,
15) -> RetryResult<T, E>
16where
17 Fut: std::future::Future<Output = Result<T, E>>,
18{
19 execute_retry_with_exponential_backoff(
20 max_try_count,
21 retry_duration,
22 timeout_duration,
23 inner,
24 false,
25 )
26 .await
27}
28
29pub async fn execute_retry_with_exponential_backoff<T, E, Fut>(
30 max_try_count: u64,
31 retry_duration: Duration,
32 timeout_duration: Duration,
33 inner: impl Fn(u64) -> Fut,
34 exponential_backoff: bool,
35) -> RetryResult<T, E>
36where
37 Fut: std::future::Future<Output = Result<T, E>>,
38{
39 let mut try_count = 0;
40 let mut timeout_count = 0;
41 let mut errors = vec![];
42 loop {
43 try_count += 1;
44 if timeout_duration.is_zero() {
45 match inner(try_count).await {
46 Ok(res) => {
47 return RetryResult {
48 success: Some(res),
49 errors,
50 timeout_count,
51 }
52 }
53 Err(err) => {
54 errors.push(err);
55 }
56 }
57 } else {
58 match timeout(timeout_duration, inner(try_count)).await {
59 Ok(res) => match res {
60 Ok(res) => {
61 return RetryResult {
62 success: Some(res),
63 errors,
64 timeout_count,
65 }
66 }
67 Err(err) => {
68 errors.push(err);
69 }
70 },
71 Err(_) => {
72 timeout_count += 1;
73 }
74 }
75 }
76 if try_count >= max_try_count {
77 return RetryResult {
78 success: None,
79 errors,
80 timeout_count,
81 };
82 }
83 if !retry_duration.is_zero() {
84 let duration = if exponential_backoff {
85 retry_duration.mul_f64(2_i32.pow(try_count as u32) as f64)
86 } else {
87 retry_duration
88 };
89 tokio::time::sleep(duration).await;
90 }
91 }
92}
93
#[cfg(test)]
mod tests {
    use tokio::time::sleep;

    use super::*;

    async fn inner_success() -> Result<usize, String> {
        Ok(1)
    }

    async fn inner_fail(n: u64) -> Result<usize, String> {
        println!("inner_fail {}", n);
        Err("error".to_string())
    }

    async fn inner_later() -> Result<usize, String> {
        sleep(Duration::from_millis(100)).await;
        Ok(1)
    }

    /// Fails on every attempt except the third.
    async fn inner_complex(n: u64) -> Result<usize, String> {
        if n == 3 {
            Ok(1)
        } else {
            Err("error".to_string())
        }
    }

    #[tokio::test]
    async fn succeeds_on_first_try() {
        let res = execute_retry(
            3,
            Duration::from_secs(0),
            Duration::from_secs(0),
            |_| inner_success(),
        )
        .await;
        assert_eq!(res.success, Some(1));
        assert_eq!(res.errors.len(), 0);
        assert_eq!(res.timeout_count, 0);
    }

    #[tokio::test]
    async fn collects_every_error_with_exponential_backoff() {
        // A millisecond base keeps the backoff pauses (2ms + 4ms) from
        // slowing the test suite down.
        let res = execute_retry_with_exponential_backoff(
            3,
            Duration::from_millis(1),
            Duration::from_secs(0),
            inner_fail,
            true,
        )
        .await;
        assert_eq!(res.success, None);
        assert_eq!(
            res.errors,
            vec!["error".to_owned(), "error".to_owned(), "error".to_owned()]
        );
        assert_eq!(res.timeout_count, 0);
    }

    #[tokio::test]
    async fn counts_timeouts_without_recording_errors() {
        let res = execute_retry(
            3,
            Duration::from_secs(0),
            Duration::from_millis(10),
            |_| inner_later(),
        )
        .await;
        assert_eq!(res.success, None);
        assert_eq!(res.errors.len(), 0);
        assert_eq!(res.timeout_count, 3);
    }

    #[tokio::test]
    async fn succeeds_after_earlier_failures() {
        let res = execute_retry(
            3,
            Duration::from_secs(0),
            Duration::from_secs(0),
            inner_complex,
        )
        .await;
        assert_eq!(res.success, Some(1));
        assert_eq!(res.errors, vec!["error".to_owned(), "error".to_owned()]);
        assert_eq!(res.timeout_count, 0);
    }

    #[tokio::test]
    async fn zero_max_try_count_still_makes_one_attempt() {
        let res = execute_retry(
            0,
            Duration::from_secs(0),
            Duration::from_secs(0),
            inner_fail,
        )
        .await;
        assert_eq!(res.success, None);
        assert_eq!(res.errors, vec!["error".to_owned()]);
        assert_eq!(res.timeout_count, 0);
    }
}