1use std::fmt::Debug;
4use std::time::Duration;
5
6use tokio::time;
7use tracing::warn;
8
/// A retry policy for fallible async operations.
///
/// Configure it builder-style (`Retry::new(..).attempts(..).base_delay(..)`)
/// and execute an operation with [`Retry::run`].
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct Retry {
    /// Name of the operation, included in the warning logged before each retry.
    pub name: &'static str,

    /// Total number of attempts, including the first one; must be greater than 0.
    pub attempts: u32,

    /// Delay slept after the first failed attempt.
    pub base_delay: Duration,

    /// Factor by which the delay is multiplied after each failed attempt;
    /// must be non-negative (and not NaN).
    pub delay_factor: f64,

    /// Whether each sleep is randomly scaled down to spread out retries.
    pub enable_jitter: bool,
}
32
33impl Retry {
34 pub const fn new(name: &'static str) -> Self {
36 Self {
37 name,
38 attempts: 3,
39 base_delay: Duration::ZERO,
40 delay_factor: 1.0,
41 enable_jitter: false,
42 }
43 }
44
45 pub const fn attempts(mut self, attempts: u32) -> Self {
47 self.attempts = attempts;
48 self
49 }
50
51 pub const fn base_delay(mut self, base_delay: Duration) -> Self {
53 self.base_delay = base_delay;
54 self
55 }
56
57 pub const fn delay_factor(mut self, delay_factor: f64) -> Self {
59 self.delay_factor = delay_factor;
60 self
61 }
62
63 pub const fn jitter(mut self, enabled: bool) -> Self {
65 self.enable_jitter = enabled;
66 self
67 }
68
69 fn apply_jitter(&self, delay: Duration) -> Duration {
70 if self.enable_jitter {
71 delay.mul_f64(0.5 + fastrand::f64() / 2.0)
73 } else {
74 delay
75 }
76 }
77
78 pub async fn run<T, E: Debug>(
83 self,
84 mut func: impl AsyncFnMut() -> Result<T, E>,
85 ) -> Result<T, E> {
86 assert!(self.attempts > 0, "attempts must be greater than 0");
87 assert!(
88 self.base_delay >= Duration::ZERO && self.delay_factor >= 0.0,
89 "retry delay cannot be negative"
90 );
91 let mut delay = self.base_delay;
92 for i in 0..self.attempts {
93 match func().await {
94 Ok(value) => return Ok(value),
95 Err(err) if i == self.attempts - 1 => return Err(err),
96 Err(err) => {
97 warn!(?err, "failed retryable operation {}, retrying", self.name);
98 time::sleep(self.apply_jitter(delay)).await;
99 delay = delay.mul_f64(self.delay_factor);
100 }
101 }
102 }
103 unreachable!();
104 }
105}
106
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio::time::Instant;

    use super::Retry;

    // Pin the panic message so the test cannot pass on an unrelated panic.
    #[tokio::test]
    #[should_panic(expected = "attempts must be greater than 0")]
    async fn zero_retry_attempts() {
        let _ = Retry::new("test")
            .attempts(0)
            .run(async || Ok::<_, std::io::Error>(()))
            .await;
    }

    #[tokio::test]
    #[should_panic(expected = "retry delay cannot be negative")]
    async fn negative_delay_factor() {
        let _ = Retry::new("test")
            .delay_factor(-1.0)
            .run(async || Ok::<_, std::io::Error>(()))
            .await;
    }

    #[tokio::test]
    async fn successful_retry() {
        let mut count = 0;
        let task = Retry::new("test").run(async || {
            count += 1;
            Ok::<_, std::io::Error>(())
        });
        let result = task.await;
        assert_eq!(count, 1);
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn failed_retry() {
        let mut count = 0;
        let retry = Retry::new("test");
        let task = retry.run(async || {
            count += 1;
            Err::<(), ()>(())
        });
        let result = task.await;
        // `Retry` is `Copy`, so `retry` is still usable after `run` consumed a copy.
        assert_eq!(count, retry.attempts);
        assert!(result.is_err());
    }

    #[tokio::test(start_paused = true)]
    async fn delayed_retry() {
        let start = Instant::now();

        let mut count = 0;
        // Sleeps of 1s, 2s, 4s put the attempts at t = 0s, 1s, 3s, 7s;
        // the fourth (t = 7s >= 5s) is the first to succeed.
        let task = Retry::new("test")
            .attempts(5)
            .base_delay(Duration::from_secs(1))
            .delay_factor(2.0)
            .run(async || {
                count += 1;
                println!("elapsed = {:?}", start.elapsed());
                if start.elapsed() < Duration::from_secs(5) {
                    Err::<(), ()>(())
                } else {
                    Ok(())
                }
            });
        let result = task.await;
        assert_eq!(count, 4);
        assert!(result.is_ok());
    }

    #[tokio::test(start_paused = true)]
    async fn delayed_retry_with_jitter() {
        let start = Instant::now();

        let mut count = 0;
        // Jitter scales each sleep into [delay / 2, delay): the first sleep is
        // in [50ms, 100ms) and the second in [500ms, 1000ms), so the second
        // attempt always fails and the third (t >= 550ms) always succeeds.
        let task = Retry::new("test_jitter")
            .attempts(4)
            .base_delay(Duration::from_millis(100))
            .delay_factor(10.0)
            .jitter(true)
            .run(async || {
                count += 1;
                println!("elapsed = {:?}", start.elapsed());
                if start.elapsed() < Duration::from_millis(500) {
                    Err::<(), ()>(())
                } else {
                    Ok(())
                }
            });
        let result = task.await;
        assert_eq!(count, 3);
        assert!(result.is_ok());
    }
}