//! `aster/providers/retry.rs`: retry policy ([`RetryConfig`]) and
//! exponential-backoff helpers for provider requests.

1use super::errors::ProviderError;
2use crate::providers::base::Provider;
3use async_trait::async_trait;
4use std::future::Future;
5use std::time::Duration;
6use tokio::time::sleep;
7
/// Number of retries performed after the first failed attempt when the
/// default retry policy is used.
pub const DEFAULT_MAX_RETRIES: usize = 3;
/// Default delay before the first retry, in milliseconds.
pub const DEFAULT_INITIAL_RETRY_INTERVAL_MS: u64 = 1_000;
/// Default factor by which the backoff delay grows for each further retry.
pub const DEFAULT_BACKOFF_MULTIPLIER: f64 = 2.0;
/// Default cap on the exponentially grown backoff delay, in milliseconds.
pub const DEFAULT_MAX_RETRY_INTERVAL_MS: u64 = 30_000;
12
/// Exponential-backoff policy for retrying failed provider requests.
///
/// The wait before retry `n` (1-based) starts at `initial_interval_ms`, is
/// multiplied by `backoff_multiplier` for every further retry, is capped at
/// `max_interval_ms`, and is randomly jittered.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// How many times a retryable failure may be retried before the error is
    /// handed back to the caller.
    pub(crate) max_retries: usize,
    /// Wait before the first retry, in milliseconds.
    pub(crate) initial_interval_ms: u64,
    /// Growth factor applied to the wait for each subsequent retry.
    pub(crate) backoff_multiplier: f64,
    /// Cap on the exponentially grown wait, in milliseconds.
    pub(crate) max_interval_ms: u64,
}
24
25impl Default for RetryConfig {
26    fn default() -> Self {
27        Self {
28            max_retries: DEFAULT_MAX_RETRIES,
29            initial_interval_ms: DEFAULT_INITIAL_RETRY_INTERVAL_MS,
30            backoff_multiplier: DEFAULT_BACKOFF_MULTIPLIER,
31            max_interval_ms: DEFAULT_MAX_RETRY_INTERVAL_MS,
32        }
33    }
34}
35
36impl RetryConfig {
37    pub fn new(
38        max_retries: usize,
39        initial_interval_ms: u64,
40        backoff_multiplier: f64,
41        max_interval_ms: u64,
42    ) -> Self {
43        Self {
44            max_retries,
45            initial_interval_ms,
46            backoff_multiplier,
47            max_interval_ms,
48        }
49    }
50
51    pub fn max_retries(&self) -> usize {
52        self.max_retries
53    }
54
55    pub fn delay_for_attempt(&self, attempt: usize) -> Duration {
56        if attempt == 0 {
57            return Duration::from_millis(0);
58        }
59
60        let exponent = (attempt - 1) as u32;
61        let base_delay_ms = (self.initial_interval_ms as f64
62            * self.backoff_multiplier.powi(exponent as i32)) as u64;
63
64        let capped_delay_ms = std::cmp::min(base_delay_ms, self.max_interval_ms);
65
66        let jitter_factor_to_avoid_thundering_herd = 0.8 + (rand::random::<f64>() * 0.4);
67        let jitter_delay_ms =
68            (capped_delay_ms as f64 * jitter_factor_to_avoid_thundering_herd) as u64;
69
70        Duration::from_millis(jitter_delay_ms)
71    }
72}
73
74pub fn should_retry(error: &ProviderError) -> bool {
75    matches!(
76        error,
77        ProviderError::RateLimitExceeded { .. }
78            | ProviderError::ServerError(_)
79            | ProviderError::RequestFailed(_)
80    )
81}
82
83pub async fn retry_operation<F, Fut, T>(
84    config: &RetryConfig,
85    operation: F,
86) -> Result<T, ProviderError>
87where
88    F: Fn() -> Fut + Send,
89    Fut: Future<Output = Result<T, ProviderError>> + Send,
90    T: Send,
91{
92    let mut attempts = 0;
93
94    loop {
95        match operation().await {
96            Ok(result) => return Ok(result),
97            Err(error) => {
98                if should_retry(&error) && attempts < config.max_retries {
99                    attempts += 1;
100                    tracing::warn!(
101                        "Request failed, retrying ({}/{}): {:?}",
102                        attempts,
103                        config.max_retries,
104                        error
105                    );
106
107                    let delay = match &error {
108                        ProviderError::RateLimitExceeded {
109                            retry_delay: Some(d),
110                            ..
111                        } => *d,
112                        _ => config.delay_for_attempt(attempts),
113                    };
114
115                    sleep(delay).await;
116                    continue;
117                }
118                return Err(error);
119            }
120        }
121    }
122}
123
124/// Trait for retry functionality to keep Provider dyn-compatible
125#[async_trait]
126pub trait ProviderRetry {
127    fn retry_config(&self) -> RetryConfig {
128        RetryConfig::default()
129    }
130
131    async fn with_retry<F, Fut, T>(&self, operation: F) -> Result<T, ProviderError>
132    where
133        F: Fn() -> Fut + Send,
134        Fut: Future<Output = Result<T, ProviderError>> + Send,
135        T: Send,
136    {
137        let mut attempts = 0;
138        let config = self.retry_config();
139
140        loop {
141            return match operation().await {
142                Ok(result) => Ok(result),
143                Err(error) => {
144                    if should_retry(&error) && attempts < config.max_retries {
145                        attempts += 1;
146                        tracing::warn!(
147                            "Request failed, retrying ({}/{}): {:?}",
148                            attempts,
149                            config.max_retries,
150                            error
151                        );
152
153                        let delay = match &error {
154                            ProviderError::RateLimitExceeded {
155                                retry_delay: Some(provider_delay),
156                                ..
157                            } => *provider_delay,
158                            _ => config.delay_for_attempt(attempts),
159                        };
160
161                        let skip_backoff = std::env::var("ASTER_PROVIDER_SKIP_BACKOFF")
162                            .unwrap_or_default()
163                            .parse::<bool>()
164                            .unwrap_or(false);
165
166                        if skip_backoff {
167                            tracing::info!("Skipping backoff due to ASTER_PROVIDER_SKIP_BACKOFF");
168                        } else {
169                            tracing::info!("Backing off for {:?} before retry", delay);
170                            sleep(delay).await;
171                        }
172                        continue;
173                    }
174
175                    Err(error)
176                }
177            };
178        }
179    }
180}
181
182impl<P: Provider> ProviderRetry for P {
183    fn retry_config(&self) -> RetryConfig {
184        Provider::retry_config(self)
185    }
186}