Skip to main content

simple_agents_router/
retry.rs

//! Retry helper for routing operations.
//!
//! Provides exponential backoff with optional jitter for retryable errors.
4
5use simple_agent_type::prelude::{ProviderError, SimpleAgentsError};
6use std::future::Future;
7use std::time::Duration;
8
/// Tunable parameters that govern how failed operations are retried.
///
/// The policy is a small `Copy` value, so it can be passed around freely.
#[derive(Clone, Copy, Debug)]
pub struct RetryPolicy {
    /// Upper bound on how many times the operation is invoked in total,
    /// counting the very first call. `execute_with_retry` rejects `0`.
    pub max_attempts: u32,
    /// Delay used before the first retry (before jitter is applied).
    pub initial_backoff: Duration,
    /// Ceiling that no computed delay may exceed.
    pub max_backoff: Duration,
    /// Factor by which the delay grows for each subsequent attempt.
    pub backoff_multiplier: f32,
    /// When `true`, each delay is scaled by a random factor between 0.5 and 1.0.
    pub jitter: bool,
}
23
24impl Default for RetryPolicy {
25    fn default() -> Self {
26        Self {
27            max_attempts: 3,
28            initial_backoff: Duration::from_millis(100),
29            max_backoff: Duration::from_secs(10),
30            backoff_multiplier: 2.0,
31            jitter: true,
32        }
33    }
34}
35
36impl RetryPolicy {
37    fn backoff(&self, attempt: u32) -> Duration {
38        let base =
39            self.initial_backoff.as_millis() as f32 * self.backoff_multiplier.powi(attempt as i32);
40        let capped = base.min(self.max_backoff.as_millis() as f32);
41
42        let duration_ms = if self.jitter {
43            let jitter_factor = 0.5 + (random_f32() * 0.5);
44            capped * jitter_factor
45        } else {
46            capped
47        };
48
49        Duration::from_millis(duration_ms as u64).min(self.max_backoff)
50    }
51}
52
53/// Execute an async operation with retry logic.
54///
55/// Retries only on retryable provider or network errors.
56pub async fn execute_with_retry<F, Fut, T>(
57    policy: RetryPolicy,
58    operation: F,
59) -> Result<T, SimpleAgentsError>
60where
61    F: Fn() -> Fut,
62    Fut: Future<Output = Result<T, SimpleAgentsError>>,
63{
64    if policy.max_attempts == 0 {
65        return Err(SimpleAgentsError::Config(
66            "retry max_attempts must be >= 1".to_string(),
67        ));
68    }
69
70    let mut last_error: Option<SimpleAgentsError> = None;
71
72    for attempt in 0..policy.max_attempts {
73        match operation().await {
74            Ok(result) => return Ok(result),
75            Err(error) => {
76                if !is_retryable(&error) {
77                    return Err(error);
78                }
79
80                if attempt >= policy.max_attempts - 1 {
81                    last_error = Some(error);
82                    break;
83                }
84
85                tokio::time::sleep(policy.backoff(attempt)).await;
86                last_error = Some(error);
87            }
88        }
89    }
90
91    Err(last_error.unwrap_or_else(|| {
92        SimpleAgentsError::Config("retry loop exhausted without attempts".to_string())
93    }))
94}
95
96fn is_retryable(error: &SimpleAgentsError) -> bool {
97    matches!(
98        error,
99        SimpleAgentsError::Provider(
100            ProviderError::RateLimit { .. }
101                | ProviderError::Timeout(_)
102                | ProviderError::ServerError(_)
103        ) | SimpleAgentsError::Network(_)
104    )
105}
106
107fn random_f32() -> f32 {
108    use rand::Rng;
109    rand::thread_rng().gen()
110}
111
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Builds a fast, jitter-free policy so the tests run quickly and
    /// deterministically; only the attempt budget varies between tests.
    fn fast_policy(max_attempts: u32) -> RetryPolicy {
        RetryPolicy {
            max_attempts,
            initial_backoff: Duration::from_millis(1),
            max_backoff: Duration::from_millis(5),
            backoff_multiplier: 2.0,
            jitter: false,
        }
    }

    /// A first-try success is passed straight through.
    #[tokio::test]
    async fn succeeds_without_retry() {
        let outcome =
            execute_with_retry(fast_policy(3), || async { Ok::<_, SimpleAgentsError>("ok") })
                .await;

        assert_eq!(outcome.unwrap(), "ok");
    }

    /// A timeout on the first call is retried and the second call succeeds.
    #[tokio::test]
    async fn retries_on_retryable_error() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&call_count);

        let outcome = execute_with_retry(fast_policy(2), move || {
            let counter = Arc::clone(&counter);
            async move {
                match counter.fetch_add(1, Ordering::Relaxed) {
                    0 => Err(SimpleAgentsError::Provider(ProviderError::Timeout(
                        Duration::from_secs(1),
                    ))),
                    _ => Ok("ok"),
                }
            }
        })
        .await;

        assert_eq!(outcome.unwrap(), "ok");
        assert_eq!(call_count.load(Ordering::Relaxed), 2);
    }

    /// A non-retryable error is returned after exactly one call.
    #[tokio::test]
    async fn fails_on_non_retryable_error() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&call_count);

        let outcome = execute_with_retry(fast_policy(3), move || {
            let counter = Arc::clone(&counter);
            async move {
                counter.fetch_add(1, Ordering::Relaxed);
                Err::<&str, _>(SimpleAgentsError::Provider(ProviderError::InvalidApiKey))
            }
        })
        .await;

        assert!(matches!(
            outcome,
            Err(SimpleAgentsError::Provider(ProviderError::InvalidApiKey))
        ));
        assert_eq!(call_count.load(Ordering::Relaxed), 1);
    }

    /// A zero attempt budget is rejected as a configuration error.
    #[tokio::test]
    async fn zero_attempts_returns_config_error() {
        let outcome =
            execute_with_retry(fast_policy(0), || async { Ok::<_, SimpleAgentsError>("ok") })
                .await;

        assert!(matches!(outcome, Err(SimpleAgentsError::Config(_))));
    }
}