1use crate::*;
2use duration_str::*;
3use serde::Deserialize;
4use std::time::Duration;
5
6#[derive(Debug, Clone, Copy, Deserialize, PartialEq)]
7#[serde(tag = "strategy")]
8pub enum BackoffConfig {
10 Constant {
12 #[serde(default = "defaults::delay", deserialize_with = "deserialize_duration")]
16 delay: Duration,
17
18 #[serde(default = "defaults::max_retries")]
22 max_retries: usize,
23
24 #[serde(default = "defaults::jitter_enabled")]
28 jitter_enabled: bool,
29
30 #[serde(default = "defaults::jitter_seed")]
34 jitter_seed: Option<u64>,
35 },
36
37 Exponential {
39 #[serde(default = "defaults::delay", deserialize_with = "deserialize_duration")]
43 initial_delay: Duration,
44
45 #[serde(default = "defaults::factor")]
49 factor: f32,
50
51 #[serde(
55 default = "defaults::max_delay",
56 deserialize_with = "deserialize_duration"
57 )]
58 max_delay: Duration,
59
60 #[serde(default = "defaults::max_retries")]
64 max_retries: usize,
65
66 #[serde(
70 default = "defaults::max_total_delay",
71 deserialize_with = "deserialize_duration"
72 )]
73 max_total_delay: Duration,
74
75 #[serde(default = "defaults::jitter_enabled")]
79 jitter_enabled: bool,
80
81 #[serde(default = "defaults::jitter_seed")]
85 jitter_seed: Option<u64>,
86 },
87
88 Fibonacci {
90 #[serde(default = "defaults::delay", deserialize_with = "deserialize_duration")]
94 initial_delay: Duration,
95
96 #[serde(
100 default = "defaults::max_delay",
101 deserialize_with = "deserialize_duration"
102 )]
103 max_delay: Duration,
104
105 #[serde(default = "defaults::max_retries")]
109 max_retries: usize,
110
111 #[serde(default = "defaults::jitter_enabled")]
115 jitter_enabled: bool,
116
117 #[serde(default = "defaults::jitter_seed")]
121 jitter_seed: Option<u64>,
122 },
123}
124
125impl backon::BackoffBuilder for BackoffConfig {
126 type Backoff = Backoff;
127
128 fn build(self) -> Backoff {
129 match self {
130 BackoffConfig::Constant {
131 delay,
132 max_retries,
133 jitter_enabled,
134 jitter_seed,
135 } => {
136 let mut builder = backon::ConstantBuilder::new()
137 .with_delay(delay)
138 .with_max_times(max_retries);
139
140 if jitter_enabled {
141 builder = builder.with_jitter();
142 }
143
144 if let Some(jitter_seed) = jitter_seed {
145 builder = builder.with_jitter_seed(jitter_seed);
146 }
147
148 Backoff::Constant(builder.build())
149 }
150
151 BackoffConfig::Exponential {
152 initial_delay,
153 factor,
154 max_delay,
155 max_retries,
156 max_total_delay,
157 jitter_enabled,
158 jitter_seed,
159 } => {
160 let mut builder = backon::ExponentialBuilder::new()
161 .with_min_delay(initial_delay)
162 .with_factor(factor)
163 .with_max_delay(max_delay)
164 .with_max_times(max_retries)
165 .with_total_delay(Some(max_total_delay));
166
167 if jitter_enabled {
168 builder = builder.with_jitter();
169 }
170
171 if let Some(jitter_seed) = jitter_seed {
172 builder = builder.with_jitter_seed(jitter_seed);
173 }
174
175 Backoff::Exponential(builder.build())
176 }
177
178 BackoffConfig::Fibonacci {
179 initial_delay,
180 max_delay,
181 max_retries,
182 jitter_enabled,
183 jitter_seed,
184 } => {
185 let mut builder = backon::FibonacciBuilder::new()
186 .with_min_delay(initial_delay)
187 .with_max_delay(max_delay)
188 .with_max_times(max_retries);
189
190 if jitter_enabled {
191 builder = builder.with_jitter();
192 }
193
194 if let Some(jitter_seed) = jitter_seed {
195 builder = builder.with_jitter_seed(jitter_seed);
196 }
197
198 Backoff::Fibonacci(builder.build())
199 }
200 }
201 }
202}
203
/// Fallback values referenced by the `#[serde(default = "...")]` attributes on
/// `BackoffConfig`; each function supplies the value used when the
/// corresponding field is absent from the input.
pub mod defaults {
    use std::time::Duration;

    const DELAY: Duration = Duration::from_millis(500);
    const MAX_RETRIES: usize = 4;
    const JITTER_ENABLED: bool = true;
    const JITTER_SEED: Option<u64> = None;
    const FACTOR: f32 = 2.0;
    const MAX_DELAY: Duration = Duration::from_secs(30);
    const MAX_TOTAL_DELAY: Duration = Duration::from_secs(60);

    /// Default (initial) delay: 500 milliseconds.
    pub const fn delay() -> Duration {
        DELAY
    }

    /// Default retry limit: 4.
    pub const fn max_retries() -> usize {
        MAX_RETRIES
    }

    /// Jitter is enabled by default.
    pub const fn jitter_enabled() -> bool {
        JITTER_ENABLED
    }

    /// No jitter seed by default.
    pub const fn jitter_seed() -> Option<u64> {
        JITTER_SEED
    }

    /// Default exponential factor: 2.0.
    pub const fn factor() -> f32 {
        FACTOR
    }

    /// Default cap on a single delay: 30 seconds.
    pub const fn max_delay() -> Duration {
        MAX_DELAY
    }

    /// Default budget for the total delay: 60 seconds.
    pub const fn max_total_delay() -> Duration {
        MAX_TOTAL_DELAY
    }
}
243
#[cfg(test)]
mod tests {
    use super::*;
    use backon::BackoffBuilder;
    use std::time::Duration;

    /// Pulls at most `limit` delays out of `backoff` and converts each one to
    /// whole milliseconds.
    fn delays_in_millis(backoff: impl Iterator<Item = Duration>, limit: usize) -> Vec<u128> {
        backoff.take(limit).map(|delay| delay.as_millis()).collect()
    }

    /// A constant config yields the same delay `max_retries` times.
    #[test]
    fn constant_backoff_config_to_backoff() {
        let backoff = BackoffConfig::Constant {
            delay: Duration::from_secs(1),
            max_retries: 3,
            jitter_enabled: false,
            jitter_seed: None,
        }
        .build();

        assert!(matches!(backoff, Backoff::Constant(_)));
        assert_eq!(delays_in_millis(backoff, 100), vec![1000; 3]);
    }

    /// An exponential config doubles the delay until it hits `max_delay`.
    #[test]
    fn exponential_backoff_config_to_backoff() {
        let backoff = BackoffConfig::Exponential {
            initial_delay: Duration::from_millis(100),
            factor: 2_f32,
            max_delay: Duration::from_millis(800),
            max_retries: 5,
            max_total_delay: Duration::from_secs(1000),
            jitter_enabled: false,
            jitter_seed: None,
        }
        .build();

        assert!(matches!(backoff, Backoff::Exponential(_)));
        assert_eq!(
            delays_in_millis(backoff, 100),
            vec![100, 200, 400, 800, 800]
        );
    }

    /// With a total-delay budget just above 1500 ms, the fifth delay is not
    /// produced.
    #[test]
    fn exponential_backoff_config_to_backoff_with_max_total_delay() {
        let backoff = BackoffConfig::Exponential {
            initial_delay: Duration::from_millis(100),
            factor: 2_f32,
            max_delay: Duration::from_millis(800),
            max_retries: 5,
            max_total_delay: Duration::from_millis(1500 + 1),
            jitter_enabled: false,
            jitter_seed: None,
        }
        .build();

        assert!(matches!(backoff, Backoff::Exponential(_)));
        assert_eq!(delays_in_millis(backoff, 100), vec![100, 200, 400, 800]);
    }

    /// A Fibonacci config yields Fibonacci multiples of the initial delay.
    #[test]
    fn fibonacci_backoff_config_to_backoff() {
        let backoff = BackoffConfig::Fibonacci {
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_millis(800),
            max_retries: 5,
            jitter_enabled: false,
            jitter_seed: None,
        }
        .build();

        assert!(matches!(backoff, Backoff::Fibonacci(_)));
        assert_eq!(
            delays_in_millis(backoff, usize::MAX),
            vec![100, 100, 200, 300, 500]
        );
    }
}