Skip to main content

aster/network/
retry.rs

//! 网络请求重试策略
//!
//! 支持指数退避和抖动
4
5use rand::Rng;
6use serde::{Deserialize, Serialize};
7use std::future::Future;
8use std::time::Duration;
9use tokio::time::sleep;
10
/// Retry configuration.
///
/// Every field carries a serde default, so any subset of fields may be omitted
/// when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryConfig {
    /// Maximum number of retries; the operation is attempted at most
    /// `max_retries + 1` times by `with_retry`.
    #[serde(default = "default_max_retries")]
    pub max_retries: u32,
    /// Base delay in milliseconds.
    #[serde(default = "default_base_delay")]
    pub base_delay: u64,
    /// Upper bound on a single retry delay, in milliseconds (applied after jitter).
    #[serde(default = "default_max_delay")]
    pub max_delay: u64,
    /// Whether to use exponential backoff (`base_delay * 2^attempt`).
    #[serde(default = "default_exponential_backoff")]
    pub exponential_backoff: bool,
    /// Jitter factor, expected in 0.0-1.0 (not validated here); the delay is
    /// randomly shifted by up to `±delay * jitter`.
    #[serde(default = "default_jitter")]
    pub jitter: f64,
    /// Retryable error markers, matched as substrings of the error message.
    /// An empty list makes `is_retryable_error` use the built-in defaults.
    #[serde(default = "default_retryable_errors")]
    pub retryable_errors: Vec<String>,
    /// Retryable HTTP status codes. An empty list makes `is_retryable_error`
    /// use the built-in defaults.
    #[serde(default = "default_retryable_status_codes")]
    pub retryable_status_codes: Vec<u16>,
}
36
37fn default_max_retries() -> u32 {
38    4
39}
40fn default_base_delay() -> u64 {
41    1000
42}
43fn default_max_delay() -> u64 {
44    30000
45}
46fn default_exponential_backoff() -> bool {
47    true
48}
49fn default_jitter() -> f64 {
50    0.1
51}
52
/// Built-in retryable error markers, used when `RetryConfig::retryable_errors`
/// is omitted or empty.
///
/// The list mixes connection-level error codes (`ECONNRESET`, `ETIMEDOUT`, ...)
/// with API error type names; callers match them as substrings of an error
/// message.
fn default_retryable_errors() -> Vec<String> {
    const MARKERS: [&str; 9] = [
        "ECONNRESET",
        "ETIMEDOUT",
        "ENOTFOUND",
        "ECONNREFUSED",
        "ENETUNREACH",
        "overloaded_error",
        "rate_limit_error",
        "api_error",
        "timeout",
    ];
    MARKERS.iter().map(|&marker| String::from(marker)).collect()
}
66
/// Built-in retryable HTTP status codes, used when
/// `RetryConfig::retryable_status_codes` is omitted or empty:
/// 408 (Request Timeout), 429 (Too Many Requests), and 500/502/503/504.
fn default_retryable_status_codes() -> Vec<u16> {
    let codes: [u16; 6] = [408, 429, 500, 502, 503, 504];
    codes.to_vec()
}
70
71impl Default for RetryConfig {
72    fn default() -> Self {
73        DEFAULT_RETRY_CONFIG.clone()
74    }
75}
76
77/// 默认重试配置
78pub const DEFAULT_RETRY_CONFIG: RetryConfig = RetryConfig {
79    max_retries: 4,
80    base_delay: 1000,
81    max_delay: 30000,
82    exponential_backoff: true,
83    jitter: 0.1,
84    retryable_errors: Vec::new(), // 使用 default_retryable_errors()
85    retryable_status_codes: Vec::new(), // 使用 default_retryable_status_codes()
86};
87
88/// 计算重试延迟
89pub fn calculate_retry_delay(attempt: u32, config: &RetryConfig) -> u64 {
90    let mut delay = config.base_delay;
91
92    if config.exponential_backoff {
93        delay = config.base_delay * 2u64.pow(attempt);
94    }
95
96    // 应用抖动(避免惊群效应)
97    if config.jitter > 0.0 {
98        let jitter_amount = (delay as f64 * config.jitter) as i64;
99        let random_jitter = rand::thread_rng().gen_range(-jitter_amount..=jitter_amount);
100        delay = (delay as i64 + random_jitter).max(0) as u64;
101    }
102
103    // 限制最大延迟
104    delay.min(config.max_delay)
105}
106
107/// 判断错误是否可重试
108pub fn is_retryable_error(error: &str, status_code: Option<u16>, config: &RetryConfig) -> bool {
109    let retryable_errors = if config.retryable_errors.is_empty() {
110        default_retryable_errors()
111    } else {
112        config.retryable_errors.clone()
113    };
114
115    let retryable_status_codes = if config.retryable_status_codes.is_empty() {
116        default_retryable_status_codes()
117    } else {
118        config.retryable_status_codes.clone()
119    };
120
121    // 检查错误消息
122    for code in &retryable_errors {
123        if error.contains(code) {
124            return true;
125        }
126    }
127
128    // 检查 HTTP 状态码
129    if let Some(status) = status_code {
130        if retryable_status_codes.contains(&status) {
131            return true;
132        }
133    }
134
135    false
136}
137
/// Error returned when a retried operation ultimately fails.
///
/// Carries the error from the final attempt together with the total number of
/// attempts made, including the initial call.
#[derive(Debug, Clone)]
pub struct RetryError<E> {
    /// Error produced by the final attempt.
    pub last_error: E,
    /// Total number of attempts made, including the initial call.
    pub attempts: u32,
}

impl<E> std::fmt::Display for RetryError<E>
where
    E: std::fmt::Display,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let RetryError {
            last_error,
            attempts,
        } = self;
        write!(f, "Failed after {} attempts: {}", attempts, last_error)
    }
}

impl<E> std::error::Error for RetryError<E>
where
    E: std::error::Error + 'static,
{
    /// Exposes the final attempt's error as the underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let cause: &(dyn std::error::Error + 'static) = &self.last_error;
        Some(cause)
    }
}
162
163/// 执行带重试的操作
164pub async fn with_retry<T, E, F, Fut>(
165    operation: F,
166    config: &RetryConfig,
167    is_retryable: impl Fn(&E) -> bool,
168    on_retry: Option<impl Fn(u32, &E, u64)>,
169) -> Result<T, RetryError<E>>
170where
171    F: Fn() -> Fut,
172    Fut: Future<Output = Result<T, E>>,
173{
174    let mut last_error: Option<E> = None;
175
176    for attempt in 0..=config.max_retries {
177        match operation().await {
178            Ok(result) => return Ok(result),
179            Err(error) => {
180                // 最后一次尝试失败
181                if attempt == config.max_retries {
182                    return Err(RetryError {
183                        last_error: error,
184                        attempts: attempt + 1,
185                    });
186                }
187
188                // 检查是否可重试
189                if !is_retryable(&error) {
190                    return Err(RetryError {
191                        last_error: error,
192                        attempts: attempt + 1,
193                    });
194                }
195
196                // 计算延迟
197                let delay = calculate_retry_delay(attempt, config);
198
199                // 调用回调
200                if let Some(ref callback) = on_retry {
201                    callback(attempt + 1, &error, delay);
202                }
203
204                last_error = Some(error);
205
206                // 等待后重试
207                sleep(Duration::from_millis(delay)).await;
208            }
209        }
210    }
211
212    Err(RetryError {
213        last_error: last_error.unwrap(),
214        attempts: config.max_retries + 1,
215    })
216}
217
218/// 简化的重试函数
219pub async fn retry<T, E, F, Fut>(operation: F, config: &RetryConfig) -> Result<T, RetryError<E>>
220where
221    F: Fn() -> Fut,
222    Fut: Future<Output = Result<T, E>>,
223    E: std::fmt::Display,
224{
225    with_retry(
226        operation,
227        config,
228        |e| is_retryable_error(&e.to_string(), None, config),
229        None::<fn(u32, &E, u64)>,
230    )
231    .await
232}