1use std::future::Future;
13use std::time::Duration;
14
15use osproxy_spi::SpiError;
16
/// Bounds and paces repeated attempts of a fallible operation run through
/// `with_retry`.
///
/// The delay before retry number `k` (zero-based) is `base_backoff * 2^k`,
/// computed with saturating arithmetic and then capped at `max_backoff`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RetryPolicy {
    /// Total number of invocations allowed, counting the first one. A value of
    /// 0 still permits the single initial invocation.
    pub max_attempts: u32,
    /// Delay before the first retry; each subsequent retry doubles it.
    pub base_backoff: Duration,
    /// Ceiling applied to every computed delay.
    pub max_backoff: Duration,
}
27
28impl Default for RetryPolicy {
29 fn default() -> Self {
30 Self {
31 max_attempts: 3,
32 base_backoff: Duration::from_millis(50),
33 max_backoff: Duration::from_secs(1),
34 }
35 }
36}
37
38impl RetryPolicy {
39 fn backoff(self, attempt: u32) -> Duration {
42 let factor = 1u32.checked_shl(attempt).unwrap_or(u32::MAX);
43 self.base_backoff
44 .saturating_mul(factor)
45 .min(self.max_backoff)
46 }
47}
48
49pub(crate) async fn with_retry<T, F, Fut>(policy: RetryPolicy, mut op: F) -> Result<T, SpiError>
53where
54 F: FnMut() -> Fut,
55 Fut: Future<Output = Result<T, SpiError>>,
56{
57 let mut attempt = 0;
58 loop {
59 match op().await {
60 Ok(value) => return Ok(value),
61 Err(err) if err.retryable() && attempt + 1 < policy.max_attempts => {
62 let backoff = policy.backoff(attempt);
63 if !backoff.is_zero() {
64 tokio::time::sleep(backoff).await;
65 }
66 attempt += 1;
67 }
68 Err(err) => return Err(err),
69 }
70 }
71}
72
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::Cell;

    /// A placement-backend failure flagged as transient, so it may be retried.
    fn backend_unavailable() -> SpiError {
        SpiError::PlacementBackend { retryable: true }
    }

    /// A definitive (non-retryable) placement failure.
    fn placement_missing() -> SpiError {
        SpiError::PlacementMissing {
            partition: osproxy_core::PartitionId::from("p"),
        }
    }

    /// A policy allowing `max_attempts` invocations with no delay between them,
    /// so these tests never sleep.
    fn no_delay(max_attempts: u32) -> RetryPolicy {
        RetryPolicy {
            max_attempts,
            base_backoff: Duration::ZERO,
            max_backoff: Duration::ZERO,
        }
    }

    #[tokio::test]
    async fn retries_a_transient_backend_then_succeeds() {
        let calls = Cell::new(0);
        let out: Result<u8, SpiError> = with_retry(no_delay(3), || {
            calls.set(calls.get() + 1);
            let this_call = calls.get();
            async move {
                // The first two invocations fail transiently; the third succeeds.
                if this_call >= 3 {
                    Ok(7)
                } else {
                    Err(backend_unavailable())
                }
            }
        })
        .await;
        assert_eq!(out.unwrap(), 7);
        assert_eq!(calls.get(), 3);
    }

    #[tokio::test]
    async fn gives_up_after_max_attempts_with_the_retryable_error() {
        let calls = Cell::new(0);
        let out: Result<u8, SpiError> = with_retry(no_delay(2), || {
            calls.set(calls.get() + 1);
            async { Err(backend_unavailable()) }
        })
        .await;
        assert!(matches!(out, Err(_)));
        assert_eq!(calls.get(), 2, "exactly max_attempts tries");
    }

    #[tokio::test]
    async fn does_not_retry_a_non_retryable_error() {
        let calls = Cell::new(0);
        let out: Result<u8, SpiError> = with_retry(RetryPolicy::default(), || {
            calls.set(calls.get() + 1);
            async { Err(placement_missing()) }
        })
        .await;
        assert!(matches!(out, Err(_)));
        assert_eq!(calls.get(), 1, "a definitive error is not retried");
    }
}