// resident_utils/retry.rs

1use std::time::Duration;
2use tokio::time::timeout;
3
/// Outcome of a retry loop.
///
/// `success` holds the value of the first successful attempt (or `None` if no attempt
/// succeeded), `errors` collects the `Err` values of failed attempts in the order they
/// occurred, and `timeout_count` counts attempts that were abandoned by the timeout.
#[derive(Debug, Clone, PartialEq)]
pub struct RetryResult<T, E> {
    /// Value returned by the first successful attempt, if any.
    pub success: Option<T>,
    /// Errors returned by failed attempts, oldest first (timed-out attempts are not included).
    pub errors: Vec<E>,
    /// Number of attempts that exceeded the per-attempt timeout.
    pub timeout_count: u64,
}
9
10pub async fn execute_retry<T, E, Fut>(
11    max_try_count: u64,
12    retry_duration: Duration,
13    timeout_duration: Duration,
14    inner: impl Fn(u64) -> Fut,
15) -> RetryResult<T, E>
16where
17    Fut: std::future::Future<Output = Result<T, E>>,
18{
19    execute_retry_with_exponential_backoff(
20        max_try_count,
21        retry_duration,
22        timeout_duration,
23        inner,
24        false,
25    )
26    .await
27}
28
29pub async fn execute_retry_with_exponential_backoff<T, E, Fut>(
30    max_try_count: u64,
31    retry_duration: Duration,
32    timeout_duration: Duration,
33    inner: impl Fn(u64) -> Fut,
34    exponential_backoff: bool,
35) -> RetryResult<T, E>
36where
37    Fut: std::future::Future<Output = Result<T, E>>,
38{
39    let mut try_count = 0;
40    let mut timeout_count = 0;
41    let mut errors = vec![];
42    loop {
43        try_count += 1;
44        if timeout_duration.is_zero() {
45            match inner(try_count).await {
46                Ok(res) => {
47                    return RetryResult {
48                        success: Some(res),
49                        errors,
50                        timeout_count,
51                    }
52                }
53                Err(err) => {
54                    errors.push(err);
55                }
56            }
57        } else {
58            match timeout(timeout_duration, inner(try_count)).await {
59                Ok(res) => match res {
60                    Ok(res) => {
61                        return RetryResult {
62                            success: Some(res),
63                            errors,
64                            timeout_count,
65                        }
66                    }
67                    Err(err) => {
68                        errors.push(err);
69                    }
70                },
71                Err(_) => {
72                    timeout_count += 1;
73                }
74            }
75        }
76        if try_count >= max_try_count {
77            return RetryResult {
78                success: None,
79                errors,
80                timeout_count,
81            };
82        }
83        if !retry_duration.is_zero() {
84            let duration = if exponential_backoff {
85                retry_duration.mul_f64(2_i32.pow(try_count as u32) as f64)
86            } else {
87                retry_duration
88            };
89            tokio::time::sleep(duration).await;
90        }
91    }
92}
93
#[cfg(test)]
mod tests {
    use std::time::Instant;

    use tokio::time::sleep;

    use super::*;
    // REALM_CODE=test cargo test test_retry -- --nocapture --test-threads=1

    async fn inner_success() -> Result<usize, String> {
        Ok(1)
    }

    async fn inner_fail(n: u64) -> Result<usize, String> {
        println!("inner_fail {}", n);
        Err("error".to_string())
    }

    async fn inner_later() -> Result<usize, String> {
        sleep(Duration::from_millis(100)).await;
        Ok(1)
    }

    async fn inner_complex(n: u64) -> Result<usize, String> {
        if n == 3 {
            Ok(1)
        } else {
            Err("error".to_string())
        }
    }

    #[tokio::test]
    async fn test_retry_success() {
        let res = execute_retry(
            3,
            Duration::from_secs(0),
            Duration::from_secs(0),
            |_n| async { inner_success().await },
        )
        .await;
        assert_eq!(res.success, Some(1));
        assert_eq!(res.errors.len(), 0);
        assert_eq!(res.timeout_count, 0);
    }

    #[tokio::test]
    async fn test_retry_failure_with_exponential_backoff() {
        // Small base delay keeps the test fast: waits 20ms + 40ms between the 3 attempts.
        let started = Instant::now();
        let res = execute_retry_with_exponential_backoff(
            3,
            Duration::from_millis(10),
            Duration::from_secs(0),
            |n| async move { inner_fail(n).await },
            true,
        )
        .await;
        assert_eq!(res.success, None);
        assert_eq!(
            res.errors,
            vec!["error".to_owned(), "error".to_owned(), "error".to_owned()]
        );
        assert_eq!(res.timeout_count, 0);
        // Loose lower bound: the backoff must actually have slept between attempts.
        assert!(started.elapsed() >= Duration::from_millis(50));
    }

    #[tokio::test]
    async fn test_retry_timeout() {
        let res = execute_retry(
            3,
            Duration::from_secs(0),
            Duration::from_millis(10),
            |_n| async { inner_later().await },
        )
        .await;
        assert_eq!(res.success, None);
        assert_eq!(res.errors.len(), 0);
        assert_eq!(res.timeout_count, 3);
    }

    #[tokio::test]
    async fn test_retry_complex() {
        let res = execute_retry(
            3,
            Duration::from_secs(0),
            Duration::from_secs(0),
            |n| async move { inner_complex(n).await },
        )
        .await;
        assert_eq!(res.success, Some(1));
        assert_eq!(res.errors, vec!["error".to_owned(), "error".to_owned()]);
        assert_eq!(res.timeout_count, 0);
    }
}