//! `amaters_core/error/recovery.rs`: recovery strategies, retry bookkeeping and a
//! circuit breaker for handling `AmateRSError`s.

use super::{AmateRSError, ErrorContext, Result};
use std::time::{Duration, Instant};

/// Describes how a failed operation should be retried, if at all.
#[derive(Debug, Clone)]
pub enum RecoveryStrategy {
    /// Never retry: no retry delay is ever offered for this strategy.
    FailFast,
    /// Retry with a delay that grows geometrically with the attempt number.
    ExponentialBackoff {
        /// Delay offered before the first retry (attempt 0).
        initial_delay: Duration,
        /// Upper bound applied to the computed delay.
        max_delay: Duration,
        /// Number of attempts after which no further retry is offered.
        max_attempts: usize,
        /// Factor by which the delay is scaled for each additional attempt
        /// (`initial_delay * multiplier^attempt`).
        multiplier: f64,
    },
    /// Retry with the same fixed delay before every attempt.
    LinearBackoff {
        /// Delay offered before each retry.
        delay: Duration,
        /// Number of attempts after which no further retry is offered.
        max_attempts: usize,
    },
    /// Circuit-breaker style recovery.
    CircuitBreaker {
        /// Failure count associated with this strategy.
        // NOTE(review): presumably forwarded to `CircuitBreaker::new`; confirm at call sites.
        failure_threshold: usize,
        /// Timeout associated with this strategy.
        // NOTE(review): presumably the open-state cool-down; confirm at call sites.
        timeout: Duration,
    },
}

impl RecoveryStrategy {
    /// Exponential backoff starting at 100 ms, doubling per attempt, capped at
    /// 30 s, with at most 5 attempts.
    #[must_use]
    pub fn default_exponential() -> Self {
        Self::ExponentialBackoff {
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(30),
            max_attempts: 5,
            multiplier: 2.0,
        }
    }

    /// Fixed 1 s delay between attempts, with at most 3 attempts.
    #[must_use]
    pub fn default_linear() -> Self {
        Self::LinearBackoff {
            delay: Duration::from_secs(1),
            max_attempts: 3,
        }
    }

    /// Circuit breaker with a failure threshold of 5 and a 60 s timeout.
    #[must_use]
    pub fn default_circuit_breaker() -> Self {
        Self::CircuitBreaker {
            failure_threshold: 5,
            timeout: Duration::from_secs(60),
        }
    }
}

63#[derive(Debug, Clone)]
65pub struct RecoverableError {
66 pub error: AmateRSError,
67 pub strategy: RecoveryStrategy,
68 pub recovery_hint: Option<String>,
69}
70
71impl RecoverableError {
72 pub fn new(error: AmateRSError, strategy: RecoveryStrategy) -> Self {
74 Self {
75 error,
76 strategy,
77 recovery_hint: None,
78 }
79 }
80
81 pub fn with_hint(mut self, hint: impl Into<String>) -> Self {
83 self.recovery_hint = Some(hint.into());
84 self
85 }
86
87 pub fn is_retryable(&self) -> bool {
89 !matches!(self.strategy, RecoveryStrategy::FailFast)
90 }
91}
92
93pub struct RetryExecutor {
95 strategy: RecoveryStrategy,
96 attempt: usize,
97}
98
99impl RetryExecutor {
100 pub fn new(strategy: RecoveryStrategy) -> Self {
102 Self {
103 strategy,
104 attempt: 0,
105 }
106 }
107
108 pub fn current_delay(&self) -> Option<Duration> {
110 match &self.strategy {
111 RecoveryStrategy::FailFast => None,
112 RecoveryStrategy::ExponentialBackoff {
113 initial_delay,
114 max_delay,
115 max_attempts,
116 multiplier,
117 } => {
118 if self.attempt >= *max_attempts {
119 return None;
120 }
121 let delay = initial_delay.as_secs_f64() * multiplier.powi(self.attempt as i32);
122 let delay = Duration::from_secs_f64(delay.min(max_delay.as_secs_f64()));
123 Some(delay)
124 }
125 RecoveryStrategy::LinearBackoff {
126 delay,
127 max_attempts,
128 } => {
129 if self.attempt >= *max_attempts {
130 None
131 } else {
132 Some(*delay)
133 }
134 }
135 RecoveryStrategy::CircuitBreaker { .. } => {
136 if self.attempt == 0 {
139 Some(Duration::from_millis(100))
140 } else {
141 None
142 }
143 }
144 }
145 }
146
147 pub fn increment(&mut self) {
149 self.attempt += 1;
150 }
151
152 pub fn should_retry(&self) -> bool {
154 self.current_delay().is_some()
155 }
156}
157
/// State of a [`CircuitBreaker`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Requests are allowed; this is the initial state.
    Closed,
    /// Requests are rejected until the breaker's timeout has elapsed since the
    /// last recorded failure.
    Open,
    /// The timeout has elapsed; requests are allowed again until a success
    /// closes the breaker or a failure re-opens it.
    HalfOpen,
}

166pub struct CircuitBreaker {
168 state: CircuitState,
169 failure_count: usize,
170 failure_threshold: usize,
171 last_failure_time: Option<std::time::Instant>,
172 timeout: Duration,
173}
174
175impl CircuitBreaker {
176 pub fn new(failure_threshold: usize, timeout: Duration) -> Self {
178 Self {
179 state: CircuitState::Closed,
180 failure_count: 0,
181 failure_threshold,
182 last_failure_time: None,
183 timeout,
184 }
185 }
186
187 pub fn record_success(&mut self) {
189 self.failure_count = 0;
190 self.state = CircuitState::Closed;
191 }
192
193 pub fn record_failure(&mut self) {
195 self.failure_count += 1;
196 self.last_failure_time = Some(std::time::Instant::now());
197
198 if self.failure_count >= self.failure_threshold {
199 self.state = CircuitState::Open;
200 }
201 }
202
203 pub fn is_allowed(&mut self) -> bool {
205 match self.state {
206 CircuitState::Closed => true,
207 CircuitState::Open => {
208 if let Some(last_failure) = self.last_failure_time {
209 if last_failure.elapsed() > self.timeout {
210 self.state = CircuitState::HalfOpen;
211 true
212 } else {
213 false
214 }
215 } else {
216 false
217 }
218 }
219 CircuitState::HalfOpen => true,
220 }
221 }
222
223 pub fn state(&self) -> CircuitState {
225 self.state
226 }
227}
228
229pub fn suggest_recovery_strategy(error: &AmateRSError) -> RecoveryStrategy {
231 match error {
232 AmateRSError::NetworkError(_) => RecoveryStrategy::default_exponential(),
233 AmateRSError::FheComputation(_) => RecoveryStrategy::default_linear(),
234 AmateRSError::IoError(_) => RecoveryStrategy::default_exponential(),
235 AmateRSError::ResourceExhausted(_) => RecoveryStrategy::default_circuit_breaker(),
236 AmateRSError::StorageIntegrity(_) => RecoveryStrategy::FailFast,
237 AmateRSError::SystemInvariantBroken(_) => RecoveryStrategy::FailFast,
238 _ => RecoveryStrategy::default_linear(),
239 }
240}
241
#[cfg(test)]
mod tests {
    use super::*;

    /// The exponential strategy offers a retry and its delay grows between attempts.
    #[test]
    fn test_retry_executor_exponential() -> Result<()> {
        let mut executor = RetryExecutor::new(RecoveryStrategy::default_exponential());

        assert!(executor.should_retry());
        let delay1 = executor.current_delay().expect("Should have delay");
        executor.increment();

        let delay2 = executor.current_delay().expect("Should have delay");
        assert!(delay2 > delay1, "Exponential backoff should increase delay");

        Ok(())
    }

    /// Walks the breaker through Closed -> Open -> HalfOpen -> Closed.
    #[test]
    fn test_circuit_breaker() -> Result<()> {
        let mut cb = CircuitBreaker::new(3, Duration::from_millis(100));

        assert_eq!(cb.state(), CircuitState::Closed);
        assert!(cb.is_allowed());

        // Two failures stay below the threshold of three.
        cb.record_failure();
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Closed);

        // The third failure opens the breaker.
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
        assert!(!cb.is_allowed());

        // Once the 100 ms timeout has passed, the breaker half-opens.
        std::thread::sleep(Duration::from_millis(150));
        assert!(cb.is_allowed());
        assert_eq!(cb.state(), CircuitState::HalfOpen);

        // A success closes the breaker again.
        cb.record_success();
        assert_eq!(cb.state(), CircuitState::Closed);

        Ok(())
    }

    /// A non-`FailFast` strategy is retryable, and `with_hint` attaches the hint.
    #[test]
    fn test_recoverable_error() -> Result<()> {
        let error = AmateRSError::NetworkError(ErrorContext::new("connection timeout"));
        let recoverable = RecoverableError::new(error, RecoveryStrategy::default_exponential())
            .with_hint("Check network connectivity");

        assert!(recoverable.is_retryable());
        assert!(recoverable.recovery_hint.is_some());

        Ok(())
    }
}