Skip to main content

aster/ratelimit/
retry.rs

//! 重试策略 (retry policy)
//!
//! 指数退避重试和错误判断 — exponential-backoff retries and retryable-error detection.
4
5use std::future::Future;
6use std::time::Duration;
7
/// Configuration for retrying a fallible operation with exponential backoff.
///
/// The nominal delay before retry number `attempt` (0-based) is
/// `base_delay_ms * exponential_base^attempt`; it is optionally randomized
/// (`jitter`) and never allowed to exceed `max_delay_ms`.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// How many times to retry after the first failed call, so the operation
    /// runs at most `max_retries + 1` times in total.
    pub max_retries: u32,
    /// Nominal delay before the first retry, in milliseconds.
    pub base_delay_ms: u64,
    /// Upper bound on any single delay, in milliseconds.
    pub max_delay_ms: u64,
    /// Factor the nominal delay is multiplied by for each further attempt.
    pub exponential_base: f64,
    /// Whether each delay is randomized to spread out concurrent retries.
    pub jitter: bool,
}

impl Default for RetryPolicy {
    /// Three retries starting at one second, doubling each time, capped at one
    /// minute, with jitter enabled.
    fn default() -> Self {
        let one_second_ms: u64 = 1_000;
        RetryPolicy {
            max_retries: 3,
            base_delay_ms: one_second_ms,
            max_delay_ms: 60 * one_second_ms,
            exponential_base: 2.0,
            jitter: true,
        }
    }
}
34
35/// 计算重试延迟
36fn calculate_delay(policy: &RetryPolicy, attempt: u32) -> Duration {
37    let mut delay = policy.base_delay_ms as f64 * policy.exponential_base.powi(attempt as i32);
38
39    // 添加抖动
40    if policy.jitter {
41        use rand::Rng;
42        let mut rng = rand::thread_rng();
43        delay *= 0.5 + rng.gen::<f64>();
44    }
45
46    // 限制最大延迟
47    let delay_ms = (delay as u64).min(policy.max_delay_ms);
48    Duration::from_millis(delay_ms)
49}
50
51/// 带指数退避的重试
52pub async fn retry_with_backoff<T, E, F, Fut>(mut f: F, policy: RetryPolicy) -> Result<T, E>
53where
54    F: FnMut() -> Fut,
55    Fut: Future<Output = Result<T, E>>,
56    E: std::fmt::Debug,
57{
58    let mut last_error: Option<E> = None;
59
60    for attempt in 0..=policy.max_retries {
61        match f().await {
62            Ok(result) => return Ok(result),
63            Err(err) => {
64                last_error = Some(err);
65
66                if attempt < policy.max_retries {
67                    let delay = calculate_delay(&policy, attempt);
68                    tokio::time::sleep(delay).await;
69                }
70            }
71        }
72    }
73
74    Err(last_error.unwrap())
75}
76
/// HTTP status codes treated as retryable when the caller supplies no list:
/// 429 and the 5xx codes 500, 502, 503 and 504.
const DEFAULT_RETRYABLE_STATUS_CODES: &[u16] = &[429, 500, 502, 503, 504];

/// Heuristically decide whether an error message describes a transient failure.
///
/// Returns `true` when `error`:
/// - mentions a network failure (`ECONNREFUSED`, `ETIMEDOUT`, `ENOTFOUND`,
///   `connection refused`, `timeout`, `network error`), case-insensitively;
/// - mentions rate limiting (`rate limit`, case-insensitively, or the code `429`);
/// - contains one of `status_codes` (default: `DEFAULT_RETRYABLE_STATUS_CODES`).
///
/// Status codes only match as whole numbers, not inside a longer run of digits:
/// `"HTTP 503"` matches 503 but `"order 15034"` does not.
pub fn is_retryable_error(error: &str, status_codes: Option<&[u16]>) -> bool {
    // Already lowercase, so they can be compared against the lowercased message.
    const NETWORK_ERRORS: [&str; 6] = [
        "econnrefused",
        "etimedout",
        "enotfound",
        "connection refused",
        "timeout",
        "network error",
    ];

    let codes = status_codes.unwrap_or(DEFAULT_RETRYABLE_STATUS_CODES);
    // Lowercase once up front rather than once per pattern.
    let lowered = error.to_lowercase();

    if NETWORK_ERRORS.iter().any(|pattern| lowered.contains(*pattern)) {
        return true;
    }

    // Rate limiting is always retryable, regardless of `status_codes`.
    if lowered.contains("rate limit") || contains_status_code(error, 429) {
        return true;
    }

    codes.iter().any(|&code| contains_status_code(error, code))
}

/// Whether `code` appears in `text` as a whole number, i.e. not directly
/// preceded or followed by another ASCII digit.
fn contains_status_code(text: &str, code: u16) -> bool {
    let needle = code.to_string();
    let bytes = text.as_bytes();
    text.match_indices(needle.as_str()).any(|(start, matched)| {
        let end = start + matched.len();
        let digit_before = start > 0 && bytes[start - 1].is_ascii_digit();
        let digit_after = bytes.get(end).map_or(false, |b| b.is_ascii_digit());
        !digit_before && !digit_after
    })
}
114
115/// 解析 Retry-After 头
116pub fn parse_retry_after(header: &str) -> Option<u64> {
117    // 尝试解析为秒数
118    if let Ok(seconds) = header.parse::<u64>() {
119        return Some(seconds);
120    }
121
122    // 尝试解析为 HTTP 日期
123    if let Ok(date) = chrono::DateTime::parse_from_rfc2822(header) {
124        let now = chrono::Utc::now();
125        let diff = date.signed_duration_since(now);
126        if diff.num_seconds() > 0 {
127            return Some(diff.num_seconds() as u64);
128        }
129    }
130
131    None
132}
133
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_is_retryable_error() {
        assert!(is_retryable_error("rate limit exceeded", None));
        assert!(is_retryable_error("status code 429", None));
        assert!(is_retryable_error("connection refused", None));
        assert!(is_retryable_error("ETIMEDOUT", None));
        assert!(!is_retryable_error("invalid request", None));
    }

    #[test]
    fn test_is_retryable_error_network_errors_ignore_case() {
        assert!(is_retryable_error("Connection Refused by peer", None));
        assert!(is_retryable_error("request TIMEOUT", None));
    }

    #[test]
    fn test_is_retryable_error_status_codes() {
        assert!(is_retryable_error("HTTP 503 Service Unavailable", None));
        assert!(is_retryable_error("HTTP 418", Some(&[418u16][..])));
        // A custom list replaces the default codes (429 stays retryable separately).
        assert!(!is_retryable_error("HTTP 500", Some(&[418u16][..])));
    }

    #[test]
    fn test_parse_retry_after_seconds() {
        assert_eq!(parse_retry_after("60"), Some(60));
        assert_eq!(parse_retry_after("0"), Some(0));
    }

    #[test]
    fn test_parse_retry_after_rejects_invalid_and_past_values() {
        assert_eq!(parse_retry_after("soon"), None);
        assert_eq!(parse_retry_after("-5"), None);
        assert_eq!(parse_retry_after("Sun, 06 Nov 1994 08:49:37 GMT"), None);
    }

    #[test]
    fn test_parse_retry_after_future_date() {
        let when = chrono::Utc::now() + chrono::Duration::seconds(120);
        let secs = parse_retry_after(&when.to_rfc2822()).expect("future date should parse");
        // The header drops sub-second precision, so allow a little slack.
        assert!((118..=121).contains(&secs), "got {}", secs);
    }

    #[test]
    fn test_calculate_delay() {
        let policy = RetryPolicy {
            jitter: false,
            ..Default::default()
        };

        let delay0 = calculate_delay(&policy, 0);
        let delay1 = calculate_delay(&policy, 1);
        let delay2 = calculate_delay(&policy, 2);

        assert_eq!(delay0.as_millis(), 1000);
        assert_eq!(delay1.as_millis(), 2000);
        assert_eq!(delay2.as_millis(), 4000);
    }

    #[test]
    fn test_calculate_delay_is_capped() {
        let policy = RetryPolicy {
            jitter: false,
            ..Default::default()
        };
        assert_eq!(calculate_delay(&policy, 10).as_millis(), 60_000);
        assert_eq!(calculate_delay(&policy, 64).as_millis(), 60_000);
    }

    #[test]
    fn test_calculate_delay_jitter_stays_in_bounds() {
        let policy = RetryPolicy::default();
        for _ in 0..100 {
            let first = calculate_delay(&policy, 0).as_millis();
            assert!((500..=1500).contains(&first), "got {}", first);
            assert!(calculate_delay(&policy, 10).as_millis() <= 60_000);
        }
    }
}