1use std::fmt::Debug;
4use std::future::Future;
5use std::time::Duration;
6
7use tokio::time;
8use tracing::warn;
9
/// Policy describing how a fallible asynchronous operation is retried.
///
/// The policy is a small `Copy` value, so it can be stored in a constant and
/// reused for any number of operations.
#[derive(Clone, Copy, Debug)]
pub struct Retry {
    /// Label for the operation; it is included in the warning logged before
    /// each retry.
    pub name: &'static str,

    /// Total number of times the operation may be invoked, counting the first
    /// try. Running a policy with zero attempts panics.
    pub attempts: u32,

    /// Time to wait before the first retry.
    pub base_delay: Duration,

    /// Multiplier applied to the delay after every retry (`1.0` keeps the
    /// delay constant). Must not be negative.
    pub delay_factor: f64,

    /// When `true`, every wait is shortened by a random factor so that it
    /// lasts between half of the computed delay and the full delay.
    pub enable_jitter: bool,
}
33
34impl Retry {
35 pub const fn new(name: &'static str) -> Self {
37 Self {
38 name,
39 attempts: 3,
40 base_delay: Duration::ZERO,
41 delay_factor: 1.0,
42 enable_jitter: false,
43 }
44 }
45
46 pub const fn attempts(mut self, attempts: u32) -> Self {
48 self.attempts = attempts;
49 self
50 }
51
52 pub const fn base_delay(mut self, base_delay: Duration) -> Self {
54 self.base_delay = base_delay;
55 self
56 }
57
58 pub const fn delay_factor(mut self, delay_factor: f64) -> Self {
60 self.delay_factor = delay_factor;
61 self
62 }
63
64 pub const fn jitter(mut self, enabled: bool) -> Self {
66 self.enable_jitter = enabled;
67 self
68 }
69
70 fn apply_jitter(&self, delay: Duration) -> Duration {
71 if self.enable_jitter {
72 delay.mul_f64(0.5 + fastrand::f64() / 2.0)
74 } else {
75 delay
76 }
77 }
78
79 pub async fn run<T, E: Debug, Fut>(self, mut func: impl FnMut() -> Fut) -> Result<T, E>
84 where
85 Fut: Future<Output = Result<T, E>>,
86 {
87 assert!(self.attempts > 0, "attempts must be greater than 0");
88 assert!(
89 self.base_delay >= Duration::ZERO && self.delay_factor >= 0.0,
90 "retry delay cannot be negative"
91 );
92 let mut delay = self.base_delay;
93 for i in 0..self.attempts {
94 match func().await {
95 Ok(value) => return Ok(value),
96 Err(err) if i == self.attempts - 1 => return Err(err),
97 Err(err) => {
98 warn!(?err, "failed retryable operation {}, retrying", self.name);
99 time::sleep(self.apply_jitter(delay)).await;
100 delay = delay.mul_f64(self.delay_factor);
101 }
102 }
103 }
104 unreachable!();
105 }
106}
107
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio::time::Instant;

    use super::Retry;

    /// Prints the time elapsed since `start` and fails until at least
    /// `threshold` has passed.
    fn fail_until(start: Instant, threshold: Duration) -> Result<(), ()> {
        println!("elapsed = {:?}", start.elapsed());
        if start.elapsed() >= threshold {
            Ok(())
        } else {
            Err(())
        }
    }

    /// A policy without any attempts is a programming error and must panic.
    #[tokio::test]
    #[should_panic]
    async fn zero_retry_attempts() {
        let policy = Retry::new("test").attempts(0);
        let _ = policy.run(|| async { Ok::<_, std::io::Error>(()) }).await;
    }

    /// An operation that succeeds immediately is invoked exactly once.
    #[tokio::test]
    async fn successful_retry() {
        let mut calls = 0;
        let outcome = Retry::new("test")
            .run(|| {
                calls += 1;
                async { Ok::<_, std::io::Error>(()) }
            })
            .await;
        assert_eq!(calls, 1);
        assert!(outcome.is_ok());
    }

    /// An operation that always fails is invoked once per configured attempt
    /// and the error is returned.
    #[tokio::test]
    async fn failed_retry() {
        let policy = Retry::new("test");
        let mut calls = 0;
        let outcome = policy
            .run(|| {
                calls += 1;
                async { Err::<(), ()>(()) }
            })
            .await;
        assert_eq!(calls, policy.attempts);
        assert!(outcome.is_err());
    }

    /// With delays of 1s, 2s and 4s the operation first sees five seconds of
    /// elapsed (paused) time on its fourth call.
    #[tokio::test(start_paused = true)]
    async fn delayed_retry() {
        let start = Instant::now();

        let mut calls = 0;
        let policy = Retry::new("test")
            .attempts(5)
            .base_delay(Duration::from_secs(1))
            .delay_factor(2.0);
        let outcome = policy
            .run(|| {
                calls += 1;
                async { fail_until(start, Duration::from_secs(5)) }
            })
            .await;
        assert_eq!(calls, 4);
        assert!(outcome.is_ok());
    }

    /// Jitter only shortens a delay to at most half, so the second retry
    /// still lands past the 500ms threshold and the third call succeeds.
    #[tokio::test(start_paused = true)]
    async fn delayed_retry_with_jitter() {
        let start = Instant::now();

        let mut calls = 0;
        let policy = Retry::new("test_jitter")
            .attempts(4)
            .base_delay(Duration::from_millis(100))
            .delay_factor(10.0)
            .jitter(true);
        let outcome = policy
            .run(|| {
                calls += 1;
                async { fail_until(start, Duration::from_millis(500)) }
            })
            .await;
        assert_eq!(calls, 3);
        assert!(outcome.is_ok());
    }
}