Skip to main content

agentic_connect/engine/
retry_engine.rs

1//! Retry engine — manages circuit breakers, rate limits, and failure classification.
2
3use std::collections::HashMap;
4use chrono::Utc;
5
6use crate::types::{CircuitBreaker, CircuitState, FailureClass, RateLimitWindow, RetryStrategy};
7
8/// Runtime retry state for all connections.
9pub struct RetryEngine {
10    circuit_breakers: HashMap<String, CircuitBreaker>,
11    rate_limits: HashMap<String, RateLimitWindow>,
12    failure_history: Vec<FailureRecord>,
13}
14
15/// Recorded failure for pattern learning.
16#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
17pub struct FailureRecord {
18    pub endpoint: String,
19    pub failure_class: FailureClass,
20    pub timestamp: chrono::DateTime<Utc>,
21    pub http_status: Option<u16>,
22    pub message: String,
23}
24
25impl RetryEngine {
26    pub fn new() -> Self {
27        Self {
28            circuit_breakers: HashMap::new(),
29            rate_limits: HashMap::new(),
30            failure_history: Vec::new(),
31        }
32    }
33
34    /// Classify an HTTP status code into a failure class.
35    pub fn classify_http_status(status: u16) -> FailureClass {
36        match status {
37            429 => FailureClass::RateLimit,
38            401 | 403 => FailureClass::AuthFailure,
39            404 | 400 | 405 | 410 | 422 => FailureClass::Permanent,
40            500..=599 => FailureClass::ServerError,
41            _ => FailureClass::Transient,
42        }
43    }
44
45    /// Classify a connection error string.
46    pub fn classify_error(error: &str) -> FailureClass {
47        let lower = error.to_lowercase();
48        if lower.contains("timeout") || lower.contains("timed out") {
49            FailureClass::Transient
50        } else if lower.contains("refused") || lower.contains("dns") || lower.contains("resolve") {
51            FailureClass::NetworkError
52        } else if lower.contains("unauthorized") || lower.contains("forbidden") || lower.contains("auth") {
53            FailureClass::AuthFailure
54        } else if lower.contains("rate") || lower.contains("throttl") || lower.contains("429") {
55            FailureClass::RateLimit
56        } else if lower.contains("not found") || lower.contains("404") {
57            FailureClass::Permanent
58        } else {
59            FailureClass::Transient
60        }
61    }
62
63    /// Get the retry strategy for a failure class.
64    pub fn strategy_for(class: FailureClass) -> RetryStrategy {
65        match class {
66            FailureClass::Transient | FailureClass::ServerError => RetryStrategy::ExponentialBackoff {
67                base_ms: 1000,
68                max_ms: 30_000,
69                max_attempts: 3,
70            },
71            FailureClass::RateLimit => RetryStrategy::WaitRetryAfter,
72            FailureClass::AuthFailure => RetryStrategy::RefreshAndRetry,
73            FailureClass::Permanent => RetryStrategy::FailFast,
74            FailureClass::NetworkError => RetryStrategy::ExponentialBackoff {
75                base_ms: 2000,
76                max_ms: 60_000,
77                max_attempts: 5,
78            },
79        }
80    }
81
82    /// Get or create a circuit breaker for an endpoint.
83    pub fn get_circuit(&mut self, endpoint: &str) -> &CircuitBreaker {
84        self.circuit_breakers
85            .entry(endpoint.to_string())
86            .or_insert_with(|| CircuitBreaker::new(endpoint, 5, 60))
87    }
88
89    /// Record a failure for an endpoint.
90    pub fn record_failure(&mut self, endpoint: &str, class: FailureClass, message: &str, status: Option<u16>) {
91        // Update circuit breaker
92        let cb = self.circuit_breakers
93            .entry(endpoint.to_string())
94            .or_insert_with(|| CircuitBreaker::new(endpoint, 5, 60));
95        cb.record_failure();
96
97        // Record for pattern learning
98        self.failure_history.push(FailureRecord {
99            endpoint: endpoint.to_string(),
100            failure_class: class,
101            timestamp: Utc::now(),
102            http_status: status,
103            message: message.to_string(),
104        });
105
106        // Keep last 500 failures
107        if self.failure_history.len() > 500 {
108            self.failure_history.drain(..self.failure_history.len() - 500);
109        }
110    }
111
112    /// Record a success for an endpoint.
113    pub fn record_success(&mut self, endpoint: &str) {
114        if let Some(cb) = self.circuit_breakers.get_mut(endpoint) {
115            cb.record_success();
116        }
117    }
118
119    /// Check if a request to this endpoint should be allowed.
120    pub fn should_allow(&self, endpoint: &str) -> bool {
121        match self.circuit_breakers.get(endpoint) {
122            Some(cb) => cb.should_allow(),
123            None => true,
124        }
125    }
126
127    /// Update rate limit tracking from response headers.
128    pub fn update_rate_limit(&mut self, endpoint: &str, limit: u32, remaining: u32, reset_epoch: i64) {
129        let resets_at = chrono::DateTime::from_timestamp(reset_epoch, 0)
130            .map(|dt| dt.with_timezone(&Utc))
131            .unwrap_or_else(Utc::now);
132        self.rate_limits.insert(endpoint.to_string(), RateLimitWindow {
133            endpoint: endpoint.to_string(),
134            limit,
135            remaining,
136            resets_at,
137            window_secs: 60,
138        });
139    }
140
141    /// Get rate limit status for an endpoint.
142    pub fn rate_limit_status(&self, endpoint: &str) -> Option<&RateLimitWindow> {
143        self.rate_limits.get(endpoint)
144    }
145
146    /// Get all circuit breaker states.
147    pub fn all_circuits(&self) -> &HashMap<String, CircuitBreaker> {
148        &self.circuit_breakers
149    }
150
151    /// Reset a circuit breaker.
152    pub fn reset_circuit(&mut self, endpoint: &str) -> bool {
153        if let Some(cb) = self.circuit_breakers.get_mut(endpoint) {
154            cb.record_success();
155            true
156        } else {
157            false
158        }
159    }
160
161    /// Get failure patterns for an endpoint.
162    pub fn failure_patterns(&self, endpoint: &str) -> Vec<&FailureRecord> {
163        self.failure_history.iter().filter(|f| f.endpoint == endpoint).collect()
164    }
165
166    /// Get all recent failures.
167    pub fn recent_failures(&self, limit: usize) -> Vec<&FailureRecord> {
168        self.failure_history.iter().rev().take(limit).collect()
169    }
170}
171
172impl Default for RetryEngine {
173    fn default() -> Self { Self::new() }
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    const EP: &str = "https://api.example.com";

    #[test]
    fn test_classify_http_status() {
        assert_eq!(RetryEngine::classify_http_status(429), FailureClass::RateLimit);
        assert_eq!(RetryEngine::classify_http_status(401), FailureClass::AuthFailure);
        assert_eq!(RetryEngine::classify_http_status(403), FailureClass::AuthFailure);
        assert_eq!(RetryEngine::classify_http_status(404), FailureClass::Permanent);
        assert_eq!(RetryEngine::classify_http_status(422), FailureClass::Permanent);
        assert_eq!(RetryEngine::classify_http_status(500), FailureClass::ServerError);
        assert_eq!(RetryEngine::classify_http_status(503), FailureClass::ServerError);
        assert_eq!(RetryEngine::classify_http_status(599), FailureClass::ServerError);
        // Unlisted codes (e.g. 408 Request Timeout) fall back to Transient.
        assert_eq!(RetryEngine::classify_http_status(408), FailureClass::Transient);
    }

    #[test]
    fn test_strategy_for() {
        assert!(matches!(
            RetryEngine::strategy_for(FailureClass::Permanent),
            RetryStrategy::FailFast
        ));
        assert!(matches!(
            RetryEngine::strategy_for(FailureClass::RateLimit),
            RetryStrategy::WaitRetryAfter
        ));
        assert!(matches!(
            RetryEngine::strategy_for(FailureClass::AuthFailure),
            RetryStrategy::RefreshAndRetry
        ));
        assert!(matches!(
            RetryEngine::strategy_for(FailureClass::ServerError),
            RetryStrategy::ExponentialBackoff { base_ms: 1000, max_attempts: 3, .. }
        ));
        assert!(matches!(
            RetryEngine::strategy_for(FailureClass::NetworkError),
            RetryStrategy::ExponentialBackoff { base_ms: 2000, max_attempts: 5, .. }
        ));
    }

    #[test]
    fn test_unknown_endpoint_is_allowed() {
        let mut engine = RetryEngine::new();
        assert!(engine.should_allow(EP));
        // A success on an endpoint that never failed must not create a circuit.
        engine.record_success(EP);
        assert!(engine.all_circuits().is_empty());
        assert!(!engine.reset_circuit(EP));
    }

    #[test]
    fn test_get_circuit_creates_entry() {
        let mut engine = RetryEngine::new();
        engine.get_circuit(EP);
        assert!(engine.all_circuits().contains_key(EP));
    }

    #[test]
    fn test_circuit_breaker_opens() {
        let mut engine = RetryEngine::new();
        for _ in 0..5 {
            engine.record_failure(EP, FailureClass::Transient, "timeout", None);
        }
        assert!(!engine.should_allow(EP));
    }

    #[test]
    fn test_circuit_breaker_resets() {
        let mut engine = RetryEngine::new();
        for _ in 0..5 {
            engine.record_failure(EP, FailureClass::Transient, "timeout", None);
        }
        assert!(!engine.should_allow(EP));
        assert!(engine.reset_circuit(EP));
        assert!(engine.should_allow(EP));
    }

    #[test]
    fn test_classify_error_string() {
        assert_eq!(RetryEngine::classify_error("connection timed out"), FailureClass::Transient);
        assert_eq!(RetryEngine::classify_error("connection refused"), FailureClass::NetworkError);
        assert_eq!(RetryEngine::classify_error("dns lookup failed"), FailureClass::NetworkError);
        assert_eq!(RetryEngine::classify_error("rate limit exceeded"), FailureClass::RateLimit);
        assert_eq!(RetryEngine::classify_error("Rate exceeded"), FailureClass::RateLimit);
        assert_eq!(RetryEngine::classify_error("request throttled"), FailureClass::RateLimit);
        assert_eq!(RetryEngine::classify_error("HTTP 429"), FailureClass::RateLimit);
        assert_eq!(RetryEngine::classify_error("unauthorized"), FailureClass::AuthFailure);
        assert_eq!(RetryEngine::classify_error("Forbidden"), FailureClass::AuthFailure);
        assert_eq!(RetryEngine::classify_error("resource not found"), FailureClass::Permanent);
        assert_eq!(RetryEngine::classify_error("unexpected eof"), FailureClass::Transient);
        // Timeout takes precedence over later keyword rules.
        assert_eq!(RetryEngine::classify_error("auth service timed out"), FailureClass::Transient);
    }

    #[test]
    fn test_failure_history_capped() {
        let mut engine = RetryEngine::new();
        for i in 0..600 {
            engine.record_failure(&format!("ep{}", i), FailureClass::Transient, "err", None);
        }
        // Exactly the newest 500 are kept, oldest first.
        assert_eq!(engine.failure_history.len(), 500);
        assert_eq!(
            engine.failure_history.first().map(|f| f.endpoint.as_str()),
            Some("ep100")
        );
        assert_eq!(
            engine.failure_history.last().map(|f| f.endpoint.as_str()),
            Some("ep599")
        );
    }

    #[test]
    fn test_recent_failures_newest_first() {
        let mut engine = RetryEngine::new();
        for ep in ["a", "b", "c"] {
            engine.record_failure(ep, FailureClass::Transient, "err", Some(503));
        }
        let newest: Vec<&str> = engine
            .recent_failures(2)
            .iter()
            .map(|f| f.endpoint.as_str())
            .collect();
        assert_eq!(newest, vec!["c", "b"]);
        assert_eq!(engine.recent_failures(10).len(), 3);
        assert!(engine.recent_failures(0).is_empty());
    }

    #[test]
    fn test_failure_patterns_filters_by_endpoint() {
        let mut engine = RetryEngine::new();
        engine.record_failure("a", FailureClass::Transient, "first", None);
        engine.record_failure("b", FailureClass::Transient, "other", None);
        engine.record_failure("a", FailureClass::ServerError, "second", Some(500));
        let patterns = engine.failure_patterns("a");
        assert_eq!(patterns.len(), 2);
        assert_eq!(patterns[0].message, "first");
        assert_eq!(patterns[1].http_status, Some(500));
        assert!(engine.failure_patterns("missing").is_empty());
    }

    #[test]
    fn test_rate_limit_tracking() {
        let mut engine = RetryEngine::new();
        assert!(engine.rate_limit_status(EP).is_none());
        engine.update_rate_limit(EP, 100, 42, 1_700_000_000);
        let window = engine
            .rate_limit_status(EP)
            .expect("window was just recorded");
        assert_eq!(window.endpoint, EP);
        assert_eq!(window.limit, 100);
        assert_eq!(window.remaining, 42);
        assert_eq!(window.resets_at.timestamp(), 1_700_000_000);
    }
}