1use super::errors::ProviderError;
2use crate::providers::base::Provider;
3use async_trait::async_trait;
4use std::future::Future;
5use std::time::Duration;
6use tokio::time::sleep;
7
/// Retry budget used by [`RetryConfig::default`].
pub const DEFAULT_MAX_RETRIES: usize = 3;
/// Delay before the first retry, in milliseconds, used by [`RetryConfig::default`].
pub const DEFAULT_INITIAL_RETRY_INTERVAL_MS: u64 = 1_000;
/// Factor by which the delay grows per retry, used by [`RetryConfig::default`].
pub const DEFAULT_BACKOFF_MULTIPLIER: f64 = 2.0;
/// Cap applied to the pre-jitter delay, in milliseconds, used by [`RetryConfig::default`].
pub const DEFAULT_MAX_RETRY_INTERVAL_MS: u64 = 30_000;
12
/// Tunable parameters for retrying a failed provider request with exponential backoff.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// How many times a failed operation may be retried after its first attempt.
    pub(crate) max_retries: usize,
    /// Delay before the first retry, in milliseconds.
    pub(crate) initial_interval_ms: u64,
    /// Factor the delay is multiplied by for each subsequent retry.
    pub(crate) backoff_multiplier: f64,
    /// Upper bound, in milliseconds, applied to the computed delay before jitter.
    pub(crate) max_interval_ms: u64,
}
24
25impl Default for RetryConfig {
26 fn default() -> Self {
27 Self {
28 max_retries: DEFAULT_MAX_RETRIES,
29 initial_interval_ms: DEFAULT_INITIAL_RETRY_INTERVAL_MS,
30 backoff_multiplier: DEFAULT_BACKOFF_MULTIPLIER,
31 max_interval_ms: DEFAULT_MAX_RETRY_INTERVAL_MS,
32 }
33 }
34}
35
36impl RetryConfig {
37 pub fn new(
38 max_retries: usize,
39 initial_interval_ms: u64,
40 backoff_multiplier: f64,
41 max_interval_ms: u64,
42 ) -> Self {
43 Self {
44 max_retries,
45 initial_interval_ms,
46 backoff_multiplier,
47 max_interval_ms,
48 }
49 }
50
51 pub fn max_retries(&self) -> usize {
52 self.max_retries
53 }
54
55 pub fn delay_for_attempt(&self, attempt: usize) -> Duration {
56 if attempt == 0 {
57 return Duration::from_millis(0);
58 }
59
60 let exponent = (attempt - 1) as u32;
61 let base_delay_ms = (self.initial_interval_ms as f64
62 * self.backoff_multiplier.powi(exponent as i32)) as u64;
63
64 let capped_delay_ms = std::cmp::min(base_delay_ms, self.max_interval_ms);
65
66 let jitter_factor_to_avoid_thundering_herd = 0.8 + (rand::random::<f64>() * 0.4);
67 let jitter_delay_ms =
68 (capped_delay_ms as f64 * jitter_factor_to_avoid_thundering_herd) as u64;
69
70 Duration::from_millis(jitter_delay_ms)
71 }
72}
73
74pub fn should_retry(error: &ProviderError) -> bool {
75 matches!(
76 error,
77 ProviderError::RateLimitExceeded { .. }
78 | ProviderError::ServerError(_)
79 | ProviderError::RequestFailed(_)
80 )
81}
82
83pub async fn retry_operation<F, Fut, T>(
84 config: &RetryConfig,
85 operation: F,
86) -> Result<T, ProviderError>
87where
88 F: Fn() -> Fut + Send,
89 Fut: Future<Output = Result<T, ProviderError>> + Send,
90 T: Send,
91{
92 let mut attempts = 0;
93
94 loop {
95 match operation().await {
96 Ok(result) => return Ok(result),
97 Err(error) => {
98 if should_retry(&error) && attempts < config.max_retries {
99 attempts += 1;
100 tracing::warn!(
101 "Request failed, retrying ({}/{}): {:?}",
102 attempts,
103 config.max_retries,
104 error
105 );
106
107 let delay = match &error {
108 ProviderError::RateLimitExceeded {
109 retry_delay: Some(d),
110 ..
111 } => *d,
112 _ => config.delay_for_attempt(attempts),
113 };
114
115 sleep(delay).await;
116 continue;
117 }
118 return Err(error);
119 }
120 }
121 }
122}
123
124#[async_trait]
126pub trait ProviderRetry {
127 fn retry_config(&self) -> RetryConfig {
128 RetryConfig::default()
129 }
130
131 async fn with_retry<F, Fut, T>(&self, operation: F) -> Result<T, ProviderError>
132 where
133 F: Fn() -> Fut + Send,
134 Fut: Future<Output = Result<T, ProviderError>> + Send,
135 T: Send,
136 {
137 let mut attempts = 0;
138 let config = self.retry_config();
139
140 loop {
141 return match operation().await {
142 Ok(result) => Ok(result),
143 Err(error) => {
144 if should_retry(&error) && attempts < config.max_retries {
145 attempts += 1;
146 tracing::warn!(
147 "Request failed, retrying ({}/{}): {:?}",
148 attempts,
149 config.max_retries,
150 error
151 );
152
153 let delay = match &error {
154 ProviderError::RateLimitExceeded {
155 retry_delay: Some(provider_delay),
156 ..
157 } => *provider_delay,
158 _ => config.delay_for_attempt(attempts),
159 };
160
161 let skip_backoff = std::env::var("ASTER_PROVIDER_SKIP_BACKOFF")
162 .unwrap_or_default()
163 .parse::<bool>()
164 .unwrap_or(false);
165
166 if skip_backoff {
167 tracing::info!("Skipping backoff due to ASTER_PROVIDER_SKIP_BACKOFF");
168 } else {
169 tracing::info!("Backing off for {:?} before retry", delay);
170 sleep(delay).await;
171 }
172 continue;
173 }
174
175 Err(error)
176 }
177 };
178 }
179 }
180}
181
182impl<P: Provider> ProviderRetry for P {
183 fn retry_config(&self) -> RetryConfig {
184 Provider::retry_config(self)
185 }
186}