ai_sdk_core/
retry.rs

1use backoff::{backoff::Backoff, ExponentialBackoff};
2use std::time::Duration;
3
/// Retry policy for API calls.
///
/// Configures how many times a failing operation is retried and how long to
/// wait between attempts. Build one with [`RetryPolicy::default`] /
/// `RetryPolicy::new` and adjust it with the `with_*` builder methods.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetryPolicy {
    /// Maximum number of retry attempts (not counting the initial attempt).
    pub max_retries: u32,
    /// Delay before the first retry.
    pub initial_delay: Duration,
    /// Upper bound on the delay between any two attempts.
    pub max_delay: Duration,
}
14
15impl Default for RetryPolicy {
16    fn default() -> Self {
17        Self {
18            max_retries: 2,
19            initial_delay: Duration::from_millis(500),
20            max_delay: Duration::from_secs(32),
21        }
22    }
23}
24
25impl RetryPolicy {
26    /// Creates a new RetryPolicy with default settings
27    pub fn new() -> Self {
28        Self::default()
29    }
30
31    /// Sets the maximum number of retry attempts
32    pub fn with_max_retries(mut self, max_retries: u32) -> Self {
33        self.max_retries = max_retries;
34        self
35    }
36
37    /// Sets the initial delay before first retry
38    pub fn with_initial_delay(mut self, delay: Duration) -> Self {
39        self.initial_delay = delay;
40        self
41    }
42
43    /// Sets the maximum delay between retries
44    pub fn with_max_delay(mut self, delay: Duration) -> Self {
45        self.max_delay = delay;
46        self
47    }
48
49    /// Executes the given function with retry logic using exponential backoff
50    pub async fn retry<F, Fut, T, E>(&self, mut f: F) -> Result<T, E>
51    where
52        F: FnMut() -> Fut,
53        Fut: std::future::Future<Output = Result<T, E>>,
54        E: std::fmt::Debug,
55    {
56        let mut backoff = ExponentialBackoff {
57            initial_interval: self.initial_delay,
58            max_interval: self.max_delay,
59            max_elapsed_time: None,
60            ..Default::default()
61        };
62
63        let mut attempt = 0;
64        loop {
65            match f().await {
66                Ok(result) => return Ok(result),
67                Err(err) => {
68                    attempt += 1;
69                    if attempt > self.max_retries {
70                        return Err(err);
71                    }
72
73                    if let Some(delay) = backoff.next_backoff() {
74                        tracing::debug!(
75                            "Retrying after {:?}, attempt {}/{}",
76                            delay,
77                            attempt,
78                            self.max_retries
79                        );
80                        tokio::time::sleep(delay).await;
81                    } else {
82                        return Err(err);
83                    }
84                }
85            }
86        }
87    }
88}
89
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// Policy with millisecond-scale delays so failure-path tests do not
    /// spend real seconds sleeping.
    fn fast_policy() -> RetryPolicy {
        RetryPolicy::default()
            .with_initial_delay(Duration::from_millis(1))
            .with_max_delay(Duration::from_millis(5))
    }

    #[test]
    fn test_default_values() {
        let policy = RetryPolicy::new();
        assert_eq!(policy.max_retries, 2);
        assert_eq!(policy.initial_delay, Duration::from_millis(500));
        assert_eq!(policy.max_delay, Duration::from_secs(32));
    }

    #[test]
    fn test_builders_override_fields() {
        let policy = RetryPolicy::default()
            .with_max_retries(7)
            .with_initial_delay(Duration::from_millis(10))
            .with_max_delay(Duration::from_secs(1));
        assert_eq!(policy.max_retries, 7);
        assert_eq!(policy.initial_delay, Duration::from_millis(10));
        assert_eq!(policy.max_delay, Duration::from_secs(1));
    }

    #[tokio::test]
    async fn test_retry_succeeds_on_first_attempt() {
        let policy = RetryPolicy::default();
        let result = policy
            .retry(|| async { Ok::<i32, String>(42) })
            .await
            .unwrap();
        assert_eq!(result, 42);
    }

    #[tokio::test]
    async fn test_retry_succeeds_after_failures() {
        let policy = fast_policy();
        let attempts = Arc::new(AtomicU32::new(0));
        let attempts_clone = Arc::clone(&attempts);

        let result = policy
            .retry(move || {
                let attempts = Arc::clone(&attempts_clone);
                async move {
                    let count = attempts.fetch_add(1, Ordering::SeqCst) + 1;
                    if count < 2 {
                        Err("transient error")
                    } else {
                        Ok(42)
                    }
                }
            })
            .await
            .unwrap();

        assert_eq!(result, 42);
        assert_eq!(attempts.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_retry_fails_after_max_retries() {
        let policy = fast_policy().with_max_retries(1);
        let attempts = Arc::new(AtomicU32::new(0));
        let attempts_clone = Arc::clone(&attempts);

        let result = policy
            .retry(move || {
                let attempts = Arc::clone(&attempts_clone);
                async move {
                    attempts.fetch_add(1, Ordering::SeqCst);
                    Err::<i32, _>("permanent error")
                }
            })
            .await;

        assert_eq!(result, Err("permanent error"));
        assert_eq!(attempts.load(Ordering::SeqCst), 2); // Initial attempt + 1 retry
    }

    #[tokio::test]
    async fn test_zero_retries_makes_single_attempt() {
        let policy = fast_policy().with_max_retries(0);
        let attempts = Arc::new(AtomicU32::new(0));
        let attempts_clone = Arc::clone(&attempts);

        let result = policy
            .retry(move || {
                let attempts = Arc::clone(&attempts_clone);
                async move {
                    attempts.fetch_add(1, Ordering::SeqCst);
                    Err::<i32, _>("fail")
                }
            })
            .await;

        assert!(result.is_err());
        assert_eq!(attempts.load(Ordering::SeqCst), 1); // No retries at all
    }
}