1use std::collections::HashMap;
4use chrono::Utc;
5
6use crate::types::{CircuitBreaker, CircuitState, FailureClass, RateLimitWindow, RetryStrategy};
7
/// Tracks per-endpoint circuit breakers, rate-limit windows, and a bounded history
/// of recent failures used to drive retry decisions.
pub struct RetryEngine {
    /// Circuit breaker per endpoint key, created lazily on first lookup or failure.
    circuit_breakers: HashMap<String, CircuitBreaker>,
    /// Most recently reported rate-limit window per endpoint key.
    rate_limits: HashMap<String, RateLimitWindow>,
    /// Recorded failures in insertion order (oldest first); `record_failure` trims it
    /// to the newest 500 entries.
    failure_history: Vec<FailureRecord>,
}
14
/// A single failure recorded against an endpoint, as kept in the engine's history.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct FailureRecord {
    /// Endpoint key the failure was recorded against.
    pub endpoint: String,
    /// Classification supplied by the caller when the failure was recorded.
    pub failure_class: FailureClass,
    /// UTC time at which the failure was recorded by the engine.
    pub timestamp: chrono::DateTime<Utc>,
    /// Optional HTTP status code supplied by the caller.
    pub http_status: Option<u16>,
    /// Caller-supplied description of the failure.
    pub message: String,
}
24
25impl RetryEngine {
26 pub fn new() -> Self {
27 Self {
28 circuit_breakers: HashMap::new(),
29 rate_limits: HashMap::new(),
30 failure_history: Vec::new(),
31 }
32 }
33
34 pub fn classify_http_status(status: u16) -> FailureClass {
36 match status {
37 429 => FailureClass::RateLimit,
38 401 | 403 => FailureClass::AuthFailure,
39 404 | 400 | 405 | 410 | 422 => FailureClass::Permanent,
40 500..=599 => FailureClass::ServerError,
41 _ => FailureClass::Transient,
42 }
43 }
44
45 pub fn classify_error(error: &str) -> FailureClass {
47 let lower = error.to_lowercase();
48 if lower.contains("timeout") || lower.contains("timed out") {
49 FailureClass::Transient
50 } else if lower.contains("refused") || lower.contains("dns") || lower.contains("resolve") {
51 FailureClass::NetworkError
52 } else if lower.contains("unauthorized") || lower.contains("forbidden") || lower.contains("auth") {
53 FailureClass::AuthFailure
54 } else if lower.contains("rate") || lower.contains("throttl") || lower.contains("429") {
55 FailureClass::RateLimit
56 } else if lower.contains("not found") || lower.contains("404") {
57 FailureClass::Permanent
58 } else {
59 FailureClass::Transient
60 }
61 }
62
63 pub fn strategy_for(class: FailureClass) -> RetryStrategy {
65 match class {
66 FailureClass::Transient | FailureClass::ServerError => RetryStrategy::ExponentialBackoff {
67 base_ms: 1000,
68 max_ms: 30_000,
69 max_attempts: 3,
70 },
71 FailureClass::RateLimit => RetryStrategy::WaitRetryAfter,
72 FailureClass::AuthFailure => RetryStrategy::RefreshAndRetry,
73 FailureClass::Permanent => RetryStrategy::FailFast,
74 FailureClass::NetworkError => RetryStrategy::ExponentialBackoff {
75 base_ms: 2000,
76 max_ms: 60_000,
77 max_attempts: 5,
78 },
79 }
80 }
81
82 pub fn get_circuit(&mut self, endpoint: &str) -> &CircuitBreaker {
84 self.circuit_breakers
85 .entry(endpoint.to_string())
86 .or_insert_with(|| CircuitBreaker::new(endpoint, 5, 60))
87 }
88
89 pub fn record_failure(&mut self, endpoint: &str, class: FailureClass, message: &str, status: Option<u16>) {
91 let cb = self.circuit_breakers
93 .entry(endpoint.to_string())
94 .or_insert_with(|| CircuitBreaker::new(endpoint, 5, 60));
95 cb.record_failure();
96
97 self.failure_history.push(FailureRecord {
99 endpoint: endpoint.to_string(),
100 failure_class: class,
101 timestamp: Utc::now(),
102 http_status: status,
103 message: message.to_string(),
104 });
105
106 if self.failure_history.len() > 500 {
108 self.failure_history.drain(..self.failure_history.len() - 500);
109 }
110 }
111
112 pub fn record_success(&mut self, endpoint: &str) {
114 if let Some(cb) = self.circuit_breakers.get_mut(endpoint) {
115 cb.record_success();
116 }
117 }
118
119 pub fn should_allow(&self, endpoint: &str) -> bool {
121 match self.circuit_breakers.get(endpoint) {
122 Some(cb) => cb.should_allow(),
123 None => true,
124 }
125 }
126
127 pub fn update_rate_limit(&mut self, endpoint: &str, limit: u32, remaining: u32, reset_epoch: i64) {
129 let resets_at = chrono::DateTime::from_timestamp(reset_epoch, 0)
130 .map(|dt| dt.with_timezone(&Utc))
131 .unwrap_or_else(Utc::now);
132 self.rate_limits.insert(endpoint.to_string(), RateLimitWindow {
133 endpoint: endpoint.to_string(),
134 limit,
135 remaining,
136 resets_at,
137 window_secs: 60,
138 });
139 }
140
141 pub fn rate_limit_status(&self, endpoint: &str) -> Option<&RateLimitWindow> {
143 self.rate_limits.get(endpoint)
144 }
145
146 pub fn all_circuits(&self) -> &HashMap<String, CircuitBreaker> {
148 &self.circuit_breakers
149 }
150
151 pub fn reset_circuit(&mut self, endpoint: &str) -> bool {
153 if let Some(cb) = self.circuit_breakers.get_mut(endpoint) {
154 cb.record_success();
155 true
156 } else {
157 false
158 }
159 }
160
161 pub fn failure_patterns(&self, endpoint: &str) -> Vec<&FailureRecord> {
163 self.failure_history.iter().filter(|f| f.endpoint == endpoint).collect()
164 }
165
166 pub fn recent_failures(&self, limit: usize) -> Vec<&FailureRecord> {
168 self.failure_history.iter().rev().take(limit).collect()
169 }
170}
171
172impl Default for RetryEngine {
173 fn default() -> Self { Self::new() }
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_classify_http_status() {
        assert_eq!(RetryEngine::classify_http_status(429), FailureClass::RateLimit);
        assert_eq!(RetryEngine::classify_http_status(401), FailureClass::AuthFailure);
        assert_eq!(RetryEngine::classify_http_status(403), FailureClass::AuthFailure);
        assert_eq!(RetryEngine::classify_http_status(404), FailureClass::Permanent);
        assert_eq!(RetryEngine::classify_http_status(422), FailureClass::Permanent);
        assert_eq!(RetryEngine::classify_http_status(500), FailureClass::ServerError);
        assert_eq!(RetryEngine::classify_http_status(503), FailureClass::ServerError);
        assert_eq!(RetryEngine::classify_http_status(599), FailureClass::ServerError);
        // Unlisted 4xx codes fall through to the transient default.
        assert_eq!(RetryEngine::classify_http_status(408), FailureClass::Transient);
    }

    #[test]
    fn test_strategy_for() {
        assert!(matches!(
            RetryEngine::strategy_for(FailureClass::Permanent),
            RetryStrategy::FailFast
        ));
        assert!(matches!(
            RetryEngine::strategy_for(FailureClass::RateLimit),
            RetryStrategy::WaitRetryAfter
        ));
        assert!(matches!(
            RetryEngine::strategy_for(FailureClass::AuthFailure),
            RetryStrategy::RefreshAndRetry
        ));
        assert!(matches!(
            RetryEngine::strategy_for(FailureClass::ServerError),
            RetryStrategy::ExponentialBackoff { max_attempts: 3, .. }
        ));
        assert!(matches!(
            RetryEngine::strategy_for(FailureClass::NetworkError),
            RetryStrategy::ExponentialBackoff { max_attempts: 5, .. }
        ));
    }

    #[test]
    fn test_circuit_breaker_opens() {
        let mut engine = RetryEngine::new();
        let ep = "https://api.example.com";
        for _ in 0..5 {
            engine.record_failure(ep, FailureClass::Transient, "timeout", None);
        }
        assert!(!engine.should_allow(ep));
    }

    #[test]
    fn test_circuit_breaker_resets() {
        let mut engine = RetryEngine::new();
        let ep = "https://api.example.com";
        for _ in 0..5 {
            engine.record_failure(ep, FailureClass::Transient, "timeout", None);
        }
        assert!(!engine.should_allow(ep));
        assert!(engine.reset_circuit(ep));
        assert!(engine.should_allow(ep));
    }

    #[test]
    fn test_unknown_endpoint_defaults() {
        let mut engine = RetryEngine::new();
        let ep = "https://unknown.example.com";
        assert!(engine.should_allow(ep));
        assert!(!engine.reset_circuit(ep));
        assert!(engine.rate_limit_status(ep).is_none());
        assert!(engine.all_circuits().is_empty());
    }

    #[test]
    fn test_classify_error_string() {
        assert_eq!(RetryEngine::classify_error("connection timed out"), FailureClass::Transient);
        assert_eq!(RetryEngine::classify_error("connection refused"), FailureClass::NetworkError);
        assert_eq!(RetryEngine::classify_error("dns lookup failed"), FailureClass::NetworkError);
        assert_eq!(RetryEngine::classify_error("rate limit exceeded"), FailureClass::RateLimit);
        assert_eq!(RetryEngine::classify_error("request throttled"), FailureClass::RateLimit);
        assert_eq!(RetryEngine::classify_error("unauthorized"), FailureClass::AuthFailure);
        assert_eq!(RetryEngine::classify_error("Forbidden"), FailureClass::AuthFailure);
        assert_eq!(RetryEngine::classify_error("404 not found"), FailureClass::Permanent);
        assert_eq!(RetryEngine::classify_error("something odd"), FailureClass::Transient);
    }

    #[test]
    fn test_update_rate_limit() {
        let mut engine = RetryEngine::new();
        let ep = "https://api.example.com";
        engine.update_rate_limit(ep, 100, 42, 1_700_000_000);
        let window = engine.rate_limit_status(ep).expect("window was just inserted");
        assert_eq!(window.limit, 100);
        assert_eq!(window.remaining, 42);
        assert_eq!(window.resets_at.timestamp(), 1_700_000_000);
    }

    #[test]
    fn test_recent_failures_and_patterns() {
        let mut engine = RetryEngine::new();
        engine.record_failure("a", FailureClass::Transient, "first", None);
        engine.record_failure("b", FailureClass::ServerError, "second", Some(503));
        engine.record_failure("a", FailureClass::Permanent, "third", Some(404));

        let recent: Vec<&str> = engine
            .recent_failures(2)
            .into_iter()
            .map(|f| f.message.as_str())
            .collect();
        assert_eq!(recent, ["third", "second"]);

        let for_a: Vec<&str> = engine
            .failure_patterns("a")
            .into_iter()
            .map(|f| f.message.as_str())
            .collect();
        assert_eq!(for_a, ["first", "third"]);
        assert_eq!(engine.failure_patterns("b")[0].http_status, Some(503));
    }

    #[test]
    fn test_failure_history_capped() {
        let mut engine = RetryEngine::new();
        for i in 0..600 {
            engine.record_failure(&format!("ep{}", i), FailureClass::Transient, "err", None);
        }
        // Exactly the newest 500 records survive; the first 100 are evicted.
        assert_eq!(engine.failure_history.len(), 500);
        assert_eq!(engine.failure_history[0].endpoint, "ep100");
        assert_eq!(engine.failure_history[499].endpoint, "ep599");
    }
}