use crate::error::{Result, StreamingError};

use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tokio::time::sleep;

/// Strategy used to recover from failed operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RecoveryStrategy {
    /// Fail immediately without retrying.
    Fail,

    /// Retry with exponentially increasing delays between attempts.
    ExponentialBackoff,

    /// Retry with a fixed delay between attempts.
    FixedDelay,

    /// Skip the failed element and continue processing.
    Skip,

    /// Route the failed element to a dead letter queue.
    DeadLetter,
}

/// Configuration for the retry and failure-tracking behavior of a
/// [`RecoveryManager`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RecoveryConfig {
    /// Strategy applied when an operation fails.
    pub strategy: RecoveryStrategy,

    /// Maximum number of retries before the operation is considered failed.
    pub max_retries: usize,

    /// Delay before the first retry (and between all retries for
    /// [`RecoveryStrategy::FixedDelay`]).
    pub initial_delay: Duration,

    /// Upper bound on the delay between retries.
    pub max_delay: Duration,

    /// Factor by which the delay grows under exponential backoff.
    pub backoff_multiplier: f64,

    /// Whether failures are recorded in the failure history.
    pub track_failures: bool,

    /// Maximum number of failure records retained in the history.
    pub max_failure_history: usize,
}

impl Default for RecoveryConfig {
    fn default() -> Self {
        Self {
            strategy: RecoveryStrategy::ExponentialBackoff,
            max_retries: 3,
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(60),
            backoff_multiplier: 2.0,
            track_failures: true,
            max_failure_history: 1000,
        }
    }
}

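// A custom configuration can start from the defaults and override only the
// fields that matter. A minimal sketch (the values here are illustrative,
// not recommendations):
//
//     let config = RecoveryConfig {
//         strategy: RecoveryStrategy::FixedDelay,
//         max_retries: 5,
//         initial_delay: Duration::from_secs(1),
//         ..Default::default()
//     };
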
/// A single recorded failure, kept in the [`RecoveryManager`] history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FailureRecord {
    /// When the failure occurred.
    pub timestamp: DateTime<Utc>,

    /// The error message.
    pub error: String,

    /// Which attempt produced this failure (1-based).
    pub attempt: usize,

    /// Identifier of the element being processed, if known.
    pub element_id: Option<String>,

    /// The recovery action taken in response to this failure.
    pub action: String,
}

impl FailureRecord {
    /// Creates a record for the given error and attempt number, timestamped
    /// now, with no element id and a "pending" action.
    pub fn new(error: String, attempt: usize) -> Self {
        Self {
            timestamp: Utc::now(),
            error,
            attempt,
            element_id: None,
            action: "pending".to_string(),
        }
    }

    /// Sets the id of the element that failed.
    pub fn with_element_id(mut self, id: String) -> Self {
        self.element_id = Some(id);
        self
    }

    /// Sets the recovery action taken for this failure.
    pub fn with_action(mut self, action: String) -> Self {
        self.action = action;
        self
    }
}
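
// Records are built in a fluent style. For example (illustrative values;
// "chunk-42" is a hypothetical element id):
//
//     let record = FailureRecord::new("connection reset".to_string(), 2)
//         .with_element_id("chunk-42".to_string())
//         .with_action("retry after 200ms".to_string());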

/// Manages retries and failure tracking for fallible async operations.
pub struct RecoveryManager {
    config: RecoveryConfig,
    failure_history: Arc<RwLock<VecDeque<FailureRecord>>>,
    total_failures: Arc<RwLock<u64>>,
    total_retries: Arc<RwLock<u64>>,
    successful_recoveries: Arc<RwLock<u64>>,
}

impl RecoveryManager {
    /// Creates a manager with the given configuration and empty counters.
    pub fn new(config: RecoveryConfig) -> Self {
        Self {
            config,
            failure_history: Arc::new(RwLock::new(VecDeque::new())),
            total_failures: Arc::new(RwLock::new(0)),
            total_retries: Arc::new(RwLock::new(0)),
            successful_recoveries: Arc::new(RwLock::new(0)),
        }
    }

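    /// Runs `operation`, retrying according to the configured strategy until
    /// it succeeds or `max_retries` is exhausted.
    ///
    /// A minimal usage sketch (`fetch_next_batch` is a hypothetical fallible
    /// async operation, not part of this module):
    ///
    /// ```ignore
    /// let manager = RecoveryManager::new(RecoveryConfig::default());
    /// let batch = manager
    ///     .execute_with_recovery(|| async { fetch_next_batch().await })
    ///     .await?;
    /// ```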
    pub async fn execute_with_recovery<F, Fut, T>(&self, mut operation: F) -> Result<T>
    where
        F: FnMut() -> Fut,
        Fut: std::future::Future<Output = Result<T>>,
    {
        let mut attempt = 0;
        let mut last_error_msg: Option<String> = None;

        while attempt <= self.config.max_retries {
            match operation().await {
                Ok(result) => {
                    if attempt > 0 {
                        // The operation succeeded after at least one retry.
                        let mut recoveries = self.successful_recoveries.write().await;
                        *recoveries += 1;
                    }
                    return Ok(result);
                }
                Err(e) => {
                    let error_msg = e.to_string();
                    last_error_msg = Some(error_msg.clone());
                    attempt += 1;

                    if attempt > self.config.max_retries {
                        break;
                    }

                    let delay = self.calculate_delay(attempt);

                    if self.config.track_failures {
                        let record = FailureRecord::new(error_msg, attempt)
                            .with_action(format!("retry after {:?}", delay));
                        self.record_failure(record).await;
                    }

                    // The temporary write guard is dropped at the end of this
                    // statement, so the lock is not held across the sleep.
                    *self.total_retries.write().await += 1;

                    sleep(delay).await;
                }
            }
        }

        // Drop the counter guard immediately; holding it across the
        // `record_failure` call below would hold two locks at once.
        *self.total_failures.write().await += 1;

        if let Some(error_msg) = last_error_msg {
            if self.config.track_failures {
                let record = FailureRecord::new(error_msg.clone(), attempt)
                    .with_action("max retries exceeded".to_string());
                self.record_failure(record).await;
            }

            match self.config.strategy {
                RecoveryStrategy::Fail => Err(StreamingError::Other(error_msg)),
                RecoveryStrategy::Skip => {
                    tracing::warn!("Skipping failed operation after {} attempts", attempt);
                    Err(StreamingError::Other(error_msg))
                }
                RecoveryStrategy::DeadLetter => {
                    tracing::warn!("Moving to dead letter queue after {} attempts", attempt);
                    Err(StreamingError::Other(error_msg))
                }
                _ => Err(StreamingError::Other(error_msg)),
            }
        } else {
            Err(StreamingError::Other("Unknown error".to_string()))
        }
    }

    /// Computes the delay before the retry numbered `attempt` (1-based),
    /// according to the configured strategy.
    fn calculate_delay(&self, attempt: usize) -> Duration {
        match self.config.strategy {
            RecoveryStrategy::FixedDelay => self.config.initial_delay,
            RecoveryStrategy::ExponentialBackoff => {
                // delay = initial_delay * multiplier^(attempt - 1), capped at max_delay
                let multiplier = self.config.backoff_multiplier.powi(attempt as i32 - 1);
                let delay_ms = self.config.initial_delay.as_millis() as f64 * multiplier;
                let delay = Duration::from_millis(delay_ms as u64);
                delay.min(self.config.max_delay)
            }
            _ => Duration::ZERO,
        }
    }
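
    // With the default configuration (initial_delay = 100ms, multiplier = 2.0,
    // max_delay = 60s), successive retries wait:
    //
    //     attempt 1:  100ms * 2^0  = 100ms
    //     attempt 2:  100ms * 2^1  = 200ms
    //     attempt 3:  100ms * 2^2  = 400ms
    //     ...
    //     attempt 11: 100ms * 2^10 = 102.4s -> capped at 60s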

    /// Appends a record to the failure history, trimming the oldest entries
    /// to stay within `max_failure_history`.
    async fn record_failure(&self, record: FailureRecord) {
        let mut history = self.failure_history.write().await;

        history.push_back(record);

        while history.len() > self.config.max_failure_history {
            history.pop_front();
        }
    }

    /// Returns a snapshot of the failure history, oldest first.
    pub async fn get_failure_history(&self) -> Vec<FailureRecord> {
        self.failure_history.read().await.iter().cloned().collect()
    }

    /// Total number of operations that failed after exhausting all retries.
    pub async fn total_failures(&self) -> u64 {
        *self.total_failures.read().await
    }

    /// Total number of retry attempts performed.
    pub async fn total_retries(&self) -> u64 {
        *self.total_retries.read().await
    }

    /// Number of operations that succeeded after at least one retry.
    pub async fn successful_recoveries(&self) -> u64 {
        *self.successful_recoveries.read().await
    }

    /// Fraction of recovered operations among all operations that needed
    /// recovery; `1.0` when nothing has failed yet.
    pub async fn success_rate(&self) -> f64 {
        let failures = *self.total_failures.read().await;
        let recoveries = *self.successful_recoveries.read().await;

        if failures + recoveries == 0 {
            1.0
        } else {
            recoveries as f64 / (failures + recoveries) as f64
        }
    }

    /// Clears the failure history and resets all counters.
    pub async fn clear_history(&self) {
        // Scope the history guard so it is released before the counter locks
        // are taken; no two locks are ever held at once.
        {
            let mut history = self.failure_history.write().await;
            history.clear();
        }

        *self.total_failures.write().await = 0;
        *self.total_retries.write().await = 0;
        *self.successful_recoveries.write().await = 0;
    }

    /// Returns the `count` most recent failures, newest first.
    pub async fn recent_failures(&self, count: usize) -> Vec<FailureRecord> {
        let history = self.failure_history.read().await;
        history.iter().rev().take(count).cloned().collect()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    #[tokio::test]
    async fn test_recovery_manager_success() {
        let config = RecoveryConfig::default();
        let manager = RecoveryManager::new(config);

        let result = manager
            .execute_with_recovery(|| async { Ok::<_, StreamingError>(42) })
            .await
            .expect("recovery execution should succeed");

        assert_eq!(result, 42);
        assert_eq!(manager.total_failures().await, 0);
    }

    #[tokio::test]
    async fn test_recovery_manager_retry_success() {
        let config = RecoveryConfig::default();
        let manager = RecoveryManager::new(config);
        let counter = Arc::new(AtomicU32::new(0));

        let result = manager
            .execute_with_recovery(|| {
                let c = counter.clone();
                async move {
                    let count = c.fetch_add(1, Ordering::Relaxed);
                    if count < 2 {
                        Err(StreamingError::Other("temporary error".to_string()))
                    } else {
                        Ok(42)
                    }
                }
            })
            .await
            .expect("retry should eventually succeed");

        assert_eq!(result, 42);
        assert_eq!(manager.total_retries().await, 2);
        assert_eq!(manager.successful_recoveries().await, 1);
    }

    #[tokio::test]
    async fn test_recovery_manager_max_retries() {
        let config = RecoveryConfig {
            max_retries: 2,
            initial_delay: Duration::from_millis(10),
            ..Default::default()
        };

        let manager = RecoveryManager::new(config);

        let result = manager
            .execute_with_recovery(|| async {
                Err::<i32, _>(StreamingError::Other("persistent error".to_string()))
            })
            .await;

        assert!(result.is_err());
        assert_eq!(manager.total_failures().await, 1);
        assert_eq!(manager.total_retries().await, 2);
    }

    #[tokio::test]
    async fn test_exponential_backoff() {
        let config = RecoveryConfig {
            strategy: RecoveryStrategy::ExponentialBackoff,
            initial_delay: Duration::from_millis(100),
            backoff_multiplier: 2.0,
            max_delay: Duration::from_secs(1),
            ..Default::default()
        };

        let manager = RecoveryManager::new(config);

        let delay1 = manager.calculate_delay(1);
        let delay2 = manager.calculate_delay(2);
        let delay3 = manager.calculate_delay(3);

        assert_eq!(delay1, Duration::from_millis(100));
        assert_eq!(delay2, Duration::from_millis(200));
        assert_eq!(delay3, Duration::from_millis(400));
    }
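
    // Added checks (sketches exercising only the APIs defined above): the
    // fixed-delay strategy should ignore the attempt number, and exponential
    // backoff should be capped at max_delay.
    #[tokio::test]
    async fn test_fixed_delay_and_cap() {
        let fixed = RecoveryManager::new(RecoveryConfig {
            strategy: RecoveryStrategy::FixedDelay,
            initial_delay: Duration::from_millis(50),
            ..Default::default()
        });

        // FixedDelay returns initial_delay regardless of the attempt number.
        assert_eq!(fixed.calculate_delay(1), Duration::from_millis(50));
        assert_eq!(fixed.calculate_delay(10), Duration::from_millis(50));

        let capped = RecoveryManager::new(RecoveryConfig {
            strategy: RecoveryStrategy::ExponentialBackoff,
            initial_delay: Duration::from_millis(100),
            backoff_multiplier: 2.0,
            max_delay: Duration::from_millis(300),
            ..Default::default()
        });

        // 100ms * 2^2 = 400ms exceeds max_delay, so the cap applies.
        assert_eq!(capped.calculate_delay(3), Duration::from_millis(300));
    }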

    #[tokio::test]
    async fn test_failure_history() {
        let config = RecoveryConfig::default();
        let manager = RecoveryManager::new(config);

        let record = FailureRecord::new("test error".to_string(), 1);
        manager.record_failure(record).await;

        let history = manager.get_failure_history().await;
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].error, "test error");
    }
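
    // A sketch test for the trimming behavior of record_failure: the history
    // should never grow beyond max_failure_history, dropping oldest entries.
    #[tokio::test]
    async fn test_failure_history_trimming() {
        let config = RecoveryConfig {
            max_failure_history: 2,
            ..Default::default()
        };
        let manager = RecoveryManager::new(config);

        for i in 0..5 {
            let record = FailureRecord::new(format!("error {}", i), 1);
            manager.record_failure(record).await;
        }

        let history = manager.get_failure_history().await;
        assert_eq!(history.len(), 2);
        // Oldest entries were evicted; only the two most recent remain.
        assert_eq!(history[0].error, "error 3");
        assert_eq!(history[1].error, "error 4");
    }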

    #[tokio::test]
    async fn test_success_rate() {
        let config = RecoveryConfig::default();
        let manager = RecoveryManager::new(config);

        *manager.total_failures.write().await = 2;
        *manager.successful_recoveries.write().await = 8;

        let rate = manager.success_rate().await;
        assert!((rate - 0.8).abs() < 0.01);
    }
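
    // A sketch test for clear_history: both the history and every counter
    // should be reset to their initial state.
    #[tokio::test]
    async fn test_clear_history() {
        let manager = RecoveryManager::new(RecoveryConfig::default());

        manager
            .record_failure(FailureRecord::new("boom".to_string(), 1))
            .await;
        *manager.total_failures.write().await = 1;
        *manager.total_retries.write().await = 3;
        *manager.successful_recoveries.write().await = 2;

        manager.clear_history().await;

        assert!(manager.get_failure_history().await.is_empty());
        assert_eq!(manager.total_failures().await, 0);
        assert_eq!(manager.total_retries().await, 0);
        assert_eq!(manager.successful_recoveries().await, 0);
    }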
}