1use crate::error::{CgCommonError, Result};
4use rand::Rng;
5use std::future::Future;
6use std::time::Duration;
7use tokio::time::sleep;
8use tracing::warn;
9
/// Configuration controlling how failed operations are retried: the retry
/// budget, the timing of delays between attempts, and which error classes
/// are eligible for retry.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Maximum number of retries after the initial attempt (0 = try once, never retry).
    pub max_retries: u32,
    /// Base delay in milliseconds fed into the backoff calculation.
    pub base_delay_ms: u64,
    /// Inclusive upper bound on the random jitter (ms) added to every delay;
    /// 0 disables jitter entirely.
    pub max_jitter_ms: u64,
    /// Whether `RateLimitExceeded` errors are considered retryable.
    pub retry_on_rate_limit: bool,
    /// Whether `ServerError` responses are considered retryable.
    pub retry_on_server_error: bool,
}
19
20impl Default for RetryConfig {
21 fn default() -> Self {
22 Self {
23 max_retries: 3,
24 base_delay_ms: 2000, max_jitter_ms: 1000,
26 retry_on_rate_limit: true,
27 retry_on_server_error: true,
28 }
29 }
30}
31
32impl RetryConfig {
33 pub fn new(max_retries: u32) -> Self {
34 Self {
35 max_retries,
36 ..Default::default()
37 }
38 }
39
40 pub fn no_retry() -> Self {
41 Self {
42 max_retries: 0,
43 ..Default::default()
44 }
45 }
46
47 pub fn with_base_delay(mut self, ms: u64) -> Self {
48 self.base_delay_ms = ms;
49 self
50 }
51
52 pub fn with_max_jitter(mut self, ms: u64) -> Self {
53 self.max_jitter_ms = ms;
54 self
55 }
56}
57
/// How the inter-attempt delay grows with the attempt number
/// (see `calculate_backoff`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum BackoffStrategy {
    /// `base * 2^attempt` — doubles on every retry.
    #[default]
    Exponential,
    /// `base * (attempt + 1)` — grows by `base` each retry.
    Linear,
    /// Always `base`, independent of the attempt number.
    Constant,
}
66
67pub fn calculate_backoff(
69 attempt: u32,
70 config: &RetryConfig,
71 strategy: BackoffStrategy,
72) -> Duration {
73 let base = config.base_delay_ms;
74 let delay_ms = match strategy {
75 BackoffStrategy::Exponential => base * 2u64.pow(attempt),
76 BackoffStrategy::Linear => base * (attempt as u64 + 1),
77 BackoffStrategy::Constant => base,
78 };
79 let jitter = if config.max_jitter_ms > 0 {
80 rand::thread_rng().gen_range(0..=config.max_jitter_ms)
81 } else {
82 0
83 };
84 Duration::from_millis(delay_ms + jitter)
85}
86
87pub fn is_retryable(error: &CgCommonError, config: &RetryConfig) -> bool {
89 match error {
90 CgCommonError::RateLimitExceeded(_) => config.retry_on_rate_limit,
91 CgCommonError::ServerError(_, _) => config.retry_on_server_error,
92 CgCommonError::TimeoutError(_) => true,
93 CgCommonError::RequestError(e) => e.is_timeout() || e.is_connect(),
94 _ => false,
95 }
96}
97
98pub async fn retry_with_backoff<F, Fut, T>(operation: F, config: &RetryConfig) -> Result<T>
100where
101 F: Fn() -> Fut,
102 Fut: Future<Output = Result<T>>,
103{
104 let mut last_error = None;
105
106 for attempt in 0..=config.max_retries {
107 match operation().await {
108 Ok(result) => return Ok(result),
109 Err(e) => {
110 if attempt == config.max_retries || !is_retryable(&e, config) {
111 return Err(e);
112 }
113 let delay = calculate_backoff(attempt, config, BackoffStrategy::Exponential);
114 warn!(
115 "Attempt {} failed: {}. Retrying in {:?}",
116 attempt + 1,
117 e,
118 delay
119 );
120 sleep(delay).await;
121 last_error = Some(e);
122 }
123 }
124 }
125
126 Err(last_error.unwrap_or(CgCommonError::MaxRetriesExceeded(config.max_retries)))
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    // Default config exposes the documented retry count and base delay.
    #[test]
    fn test_default_config() {
        let cfg = RetryConfig::default();
        assert_eq!(cfg.max_retries, 3);
        assert_eq!(cfg.base_delay_ms, 2000);
    }

    // With jitter disabled, exponential backoff doubles with each attempt.
    #[test]
    fn test_exponential_backoff() {
        let cfg = RetryConfig::default().with_max_jitter(0);
        for (attempt, expected_ms) in [(0u32, 2000u64), (1, 4000)] {
            assert_eq!(
                calculate_backoff(attempt, &cfg, BackoffStrategy::Exponential),
                Duration::from_millis(expected_ms)
            );
        }
    }

    // An operation that succeeds on the first call returns without retrying.
    #[tokio::test]
    async fn test_retry_immediate_success() {
        let cfg = RetryConfig::default();
        let outcome: Result<i32> = retry_with_backoff(|| async { Ok(42) }, &cfg).await;
        assert_eq!(outcome.unwrap(), 42);
    }
}