simple_agents_router/
retry.rs1use simple_agent_type::prelude::{ProviderError, SimpleAgentsError};
6use std::future::Future;
7use std::time::Duration;
8
9#[derive(Debug, Clone, Copy)]
11pub struct RetryPolicy {
12 pub max_attempts: u32,
14 pub initial_backoff: Duration,
16 pub max_backoff: Duration,
18 pub backoff_multiplier: f32,
20 pub jitter: bool,
22}
23
24impl Default for RetryPolicy {
25 fn default() -> Self {
26 Self {
27 max_attempts: 3,
28 initial_backoff: Duration::from_millis(100),
29 max_backoff: Duration::from_secs(10),
30 backoff_multiplier: 2.0,
31 jitter: true,
32 }
33 }
34}
35
36impl RetryPolicy {
37 fn backoff(&self, attempt: u32) -> Duration {
38 let base =
39 self.initial_backoff.as_millis() as f32 * self.backoff_multiplier.powi(attempt as i32);
40 let capped = base.min(self.max_backoff.as_millis() as f32);
41
42 let duration_ms = if self.jitter {
43 let jitter_factor = 0.5 + (random_f32() * 0.5);
44 capped * jitter_factor
45 } else {
46 capped
47 };
48
49 Duration::from_millis(duration_ms as u64).min(self.max_backoff)
50 }
51}
52
53pub async fn execute_with_retry<F, Fut, T>(
57 policy: RetryPolicy,
58 operation: F,
59) -> Result<T, SimpleAgentsError>
60where
61 F: Fn() -> Fut,
62 Fut: Future<Output = Result<T, SimpleAgentsError>>,
63{
64 if policy.max_attempts == 0 {
65 return Err(SimpleAgentsError::Config(
66 "retry max_attempts must be >= 1".to_string(),
67 ));
68 }
69
70 let mut last_error: Option<SimpleAgentsError> = None;
71
72 for attempt in 0..policy.max_attempts {
73 match operation().await {
74 Ok(result) => return Ok(result),
75 Err(error) => {
76 if !is_retryable(&error) {
77 return Err(error);
78 }
79
80 if attempt >= policy.max_attempts - 1 {
81 last_error = Some(error);
82 break;
83 }
84
85 tokio::time::sleep(policy.backoff(attempt)).await;
86 last_error = Some(error);
87 }
88 }
89 }
90
91 Err(last_error.unwrap_or_else(|| {
92 SimpleAgentsError::Config("retry loop exhausted without attempts".to_string())
93 }))
94}
95
96fn is_retryable(error: &SimpleAgentsError) -> bool {
97 matches!(
98 error,
99 SimpleAgentsError::Provider(
100 ProviderError::RateLimit { .. }
101 | ProviderError::Timeout(_)
102 | ProviderError::ServerError(_)
103 ) | SimpleAgentsError::Network(_)
104 )
105}
106
107fn random_f32() -> f32 {
108 use rand::Rng;
109 rand::thread_rng().gen()
110}
111
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Policy with millisecond-scale, jitter-free delays so retry tests run quickly and
    /// deterministically.
    fn fast_policy(max_attempts: u32) -> RetryPolicy {
        RetryPolicy {
            max_attempts,
            initial_backoff: Duration::from_millis(1),
            max_backoff: Duration::from_millis(5),
            backoff_multiplier: 2.0,
            jitter: false,
        }
    }

    /// A retryable provider error used by several tests.
    fn timeout_error() -> SimpleAgentsError {
        SimpleAgentsError::Provider(ProviderError::Timeout(Duration::from_secs(1)))
    }

    #[tokio::test]
    async fn succeeds_without_retry() {
        let result =
            execute_with_retry(fast_policy(3), || async { Ok::<_, SimpleAgentsError>("ok") })
                .await;
        assert_eq!(result.unwrap(), "ok");
    }

    #[tokio::test]
    async fn retries_on_retryable_error() {
        let attempts = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&attempts);

        let result = execute_with_retry(fast_policy(2), move || {
            let counter = Arc::clone(&counter);
            async move {
                if counter.fetch_add(1, Ordering::Relaxed) == 0 {
                    Err(timeout_error())
                } else {
                    Ok("ok")
                }
            }
        })
        .await;

        assert_eq!(result.unwrap(), "ok");
        assert_eq!(attempts.load(Ordering::Relaxed), 2);
    }

    #[tokio::test]
    async fn returns_last_error_after_exhausting_attempts() {
        let attempts = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&attempts);

        let result = execute_with_retry(fast_policy(3), move || {
            let counter = Arc::clone(&counter);
            async move {
                counter.fetch_add(1, Ordering::Relaxed);
                Err::<&str, _>(timeout_error())
            }
        })
        .await;

        assert!(matches!(
            result,
            Err(SimpleAgentsError::Provider(ProviderError::Timeout(_)))
        ));
        assert_eq!(attempts.load(Ordering::Relaxed), 3);
    }

    #[tokio::test]
    async fn fails_on_non_retryable_error() {
        let attempts = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&attempts);

        let result = execute_with_retry(fast_policy(3), move || {
            let counter = Arc::clone(&counter);
            async move {
                counter.fetch_add(1, Ordering::Relaxed);
                Err::<&str, _>(SimpleAgentsError::Provider(ProviderError::InvalidApiKey))
            }
        })
        .await;

        assert!(matches!(
            result,
            Err(SimpleAgentsError::Provider(ProviderError::InvalidApiKey))
        ));
        assert_eq!(attempts.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn zero_attempts_returns_config_error() {
        let attempts = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&attempts);

        let result = execute_with_retry(fast_policy(0), move || {
            let counter = Arc::clone(&counter);
            async move {
                counter.fetch_add(1, Ordering::Relaxed);
                Ok::<_, SimpleAgentsError>("ok")
            }
        })
        .await;

        assert!(matches!(result, Err(SimpleAgentsError::Config(_))));
        // The operation must never be invoked when no attempts are allowed.
        assert_eq!(attempts.load(Ordering::Relaxed), 0);
    }
}