Skip to main content

simple_agents_router/
retry.rs

1//! Retry helper for routing operations.
2//!
3//! Provides exponential backoff with jitter for retryable errors.
4
5use simple_agent_type::prelude::{ProviderError, SimpleAgentsError};
6use std::future::Future;
7use std::time::Duration;
8
/// Tunable parameters that control how a failing operation is re-attempted.
#[derive(Clone, Copy, Debug)]
pub struct RetryPolicy {
    /// Upper bound on how many times the operation is invoked in total;
    /// the very first call counts toward this limit.
    pub max_attempts: u32,
    /// Delay used before the first re-attempt, prior to any growth or jitter.
    pub initial_backoff: Duration,
    /// Ceiling that no computed delay may exceed.
    pub max_backoff: Duration,
    /// Factor by which the delay grows with each successive attempt.
    pub backoff_multiplier: f32,
    /// When `true`, every delay is scaled by a random factor between 0.5 and 1.0.
    pub jitter: bool,
}
23
24impl Default for RetryPolicy {
25    fn default() -> Self {
26        Self {
27            max_attempts: 3,
28            initial_backoff: Duration::from_millis(100),
29            max_backoff: Duration::from_secs(10),
30            backoff_multiplier: 2.0,
31            jitter: true,
32        }
33    }
34}
35
36impl RetryPolicy {
37    fn backoff(&self, attempt: u32) -> Duration {
38        let base =
39            self.initial_backoff.as_millis() as f32 * self.backoff_multiplier.powi(attempt as i32);
40        let capped = base.min(self.max_backoff.as_millis() as f32);
41
42        let duration_ms = if self.jitter {
43            let jitter_factor = 0.5 + (random_f32() * 0.5);
44            capped * jitter_factor
45        } else {
46            capped
47        };
48
49        Duration::from_millis(duration_ms as u64).min(self.max_backoff)
50    }
51}
52
53/// Execute an async operation with retry logic.
54///
55/// Retries only on retryable provider or network errors.
56pub async fn execute_with_retry<F, Fut, T>(
57    policy: RetryPolicy,
58    operation: F,
59) -> Result<T, SimpleAgentsError>
60where
61    F: Fn() -> Fut,
62    Fut: Future<Output = Result<T, SimpleAgentsError>>,
63{
64    let mut last_error: Option<SimpleAgentsError> = None;
65
66    for attempt in 0..policy.max_attempts {
67        match operation().await {
68            Ok(result) => return Ok(result),
69            Err(error) => {
70                if !is_retryable(&error) {
71                    return Err(error);
72                }
73
74                if attempt >= policy.max_attempts - 1 {
75                    last_error = Some(error);
76                    break;
77                }
78
79                tokio::time::sleep(policy.backoff(attempt)).await;
80                last_error = Some(error);
81            }
82        }
83    }
84
85    Err(last_error.unwrap())
86}
87
88fn is_retryable(error: &SimpleAgentsError) -> bool {
89    matches!(
90        error,
91        SimpleAgentsError::Provider(
92            ProviderError::RateLimit { .. }
93                | ProviderError::Timeout(_)
94                | ProviderError::ServerError(_)
95        ) | SimpleAgentsError::Network(_)
96    )
97}
98
99fn random_f32() -> f32 {
100    use rand::Rng;
101    rand::thread_rng().gen()
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Jitter-free policy with millisecond-scale delays so the tests stay fast.
    fn quick_policy(max_attempts: u32) -> RetryPolicy {
        RetryPolicy {
            max_attempts,
            initial_backoff: Duration::from_millis(1),
            max_backoff: Duration::from_millis(5),
            backoff_multiplier: 2.0,
            jitter: false,
        }
    }

    /// An operation that succeeds on the first call yields its value.
    #[tokio::test]
    async fn succeeds_without_retry() {
        let outcome = execute_with_retry(quick_policy(3), || async {
            Ok::<_, SimpleAgentsError>("ok")
        })
        .await;

        assert_eq!(outcome.unwrap(), "ok");
    }

    /// A timeout on the first call is retried and the second call succeeds.
    #[tokio::test]
    async fn retries_on_retryable_error() {
        let calls = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&calls);

        let outcome = execute_with_retry(quick_policy(2), move || {
            let counter = Arc::clone(&counter);
            async move {
                match counter.fetch_add(1, Ordering::Relaxed) {
                    0 => Err(SimpleAgentsError::Provider(ProviderError::Timeout(
                        Duration::from_secs(1),
                    ))),
                    _ => Ok("ok"),
                }
            }
        })
        .await;

        assert_eq!(outcome.unwrap(), "ok");
        assert_eq!(calls.load(Ordering::Relaxed), 2);
    }

    /// An invalid API key is returned as-is after exactly one call.
    #[tokio::test]
    async fn fails_on_non_retryable_error() {
        let calls = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&calls);

        let outcome = execute_with_retry(quick_policy(3), move || {
            let counter = Arc::clone(&counter);
            async move {
                counter.fetch_add(1, Ordering::Relaxed);
                Err::<&str, _>(SimpleAgentsError::Provider(ProviderError::InvalidApiKey))
            }
        })
        .await;

        assert!(matches!(
            outcome,
            Err(SimpleAgentsError::Provider(ProviderError::InvalidApiKey))
        ));
        assert_eq!(calls.load(Ordering::Relaxed), 1);
    }
}