1use std::sync::Arc;
10use std::time::Duration;
11use tokio::time::{sleep, Instant};
12
13use crate::core::{error::LarkAPIError, error_helper::RetryStrategy, SDKResult};
14
15#[derive(Clone)]
17pub struct RetryConfig {
18 pub default_strategy: RetryStrategy,
20 pub enabled: bool,
22 pub on_retry: Option<Arc<dyn Fn(&RetryAttempt) + Send + Sync>>,
24 pub retry_filter: Option<Arc<dyn Fn(&LarkAPIError) -> bool + Send + Sync>>,
26}
27
28impl std::fmt::Debug for RetryConfig {
29 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30 f.debug_struct("RetryConfig")
31 .field("default_strategy", &self.default_strategy)
32 .field("enabled", &self.enabled)
33 .field(
34 "on_retry",
35 &self.on_retry.as_ref().map(|_| "Fn(&RetryAttempt)"),
36 )
37 .field(
38 "retry_filter",
39 &self
40 .retry_filter
41 .as_ref()
42 .map(|_| "Fn(&LarkAPIError) -> bool"),
43 )
44 .finish()
45 }
46}
47
48impl Default for RetryConfig {
49 fn default() -> Self {
50 Self {
51 default_strategy: RetryStrategy::default(),
52 enabled: true,
53 on_retry: None,
54 retry_filter: None,
55 }
56 }
57}
58
59impl RetryConfig {
60 pub fn new() -> Self {
62 Self::default()
63 }
64
65 pub fn enabled(mut self, enabled: bool) -> Self {
67 self.enabled = enabled;
68 self
69 }
70
71 pub fn default_strategy(mut self, strategy: RetryStrategy) -> Self {
73 self.default_strategy = strategy;
74 self
75 }
76
77 pub fn on_retry<F>(mut self, callback: F) -> Self
79 where
80 F: Fn(&RetryAttempt) + Send + Sync + 'static,
81 {
82 self.on_retry = Some(Arc::new(callback));
83 self
84 }
85
86 pub fn retry_filter<F>(mut self, filter: F) -> Self
88 where
89 F: Fn(&LarkAPIError) -> bool + Send + Sync + 'static,
90 {
91 self.retry_filter = Some(Arc::new(filter));
92 self
93 }
94
95 pub fn server_errors_only(mut self) -> Self {
97 self.retry_filter = Some(Arc::new(|error| match error {
98 LarkAPIError::ApiError { code, .. } => {
99 matches!(*code, 500..=599)
100 }
101 LarkAPIError::RequestError(req_err) => {
102 req_err.contains("timeout")
103 || req_err.contains("timed out")
104 || req_err.contains("connect")
105 || req_err.contains("connection")
106 }
107 _ => false,
108 }));
109 self
110 }
111
112 pub fn aggressive(mut self) -> Self {
114 self.default_strategy = RetryStrategy {
115 max_attempts: 5,
116 base_delay: Duration::from_millis(500),
117 use_exponential_backoff: true,
118 max_delay: Duration::from_secs(30),
119 };
120 self
121 }
122
123 pub fn conservative(mut self) -> Self {
125 self.default_strategy = RetryStrategy {
126 max_attempts: 2,
127 base_delay: Duration::from_secs(2),
128 use_exponential_backoff: false,
129 max_delay: Duration::from_secs(10),
130 };
131 self
132 }
133}
134
135#[derive(Debug, Clone)]
137pub struct RetryAttempt {
138 pub attempt: u32,
140 pub max_attempts: u32,
142 pub delay: Duration,
144 pub error: LarkAPIError,
146 pub started_at: Instant,
148 pub elapsed: Duration,
150}
151
152impl RetryAttempt {
153 pub fn is_final_attempt(&self) -> bool {
155 self.attempt >= self.max_attempts
156 }
157
158 pub fn remaining_attempts(&self) -> u32 {
160 self.max_attempts.saturating_sub(self.attempt)
161 }
162
163 pub fn print_info(&self) {
165 let percentage = (self.attempt as f32 / self.max_attempts as f32 * 100.0) as u32;
166 println!(
167 "🔄 重试 {}/{} ({}%) - 延迟 {:?} - 耗时 {:?}",
168 self.attempt, self.max_attempts, percentage, self.delay, self.elapsed
169 );
170 }
171}
172
173pub struct RetryMiddleware {
175 config: RetryConfig,
176}
177
178impl Default for RetryMiddleware {
179 fn default() -> Self {
180 Self::new(RetryConfig::default())
181 }
182}
183
184impl RetryMiddleware {
185 pub fn new(config: RetryConfig) -> Self {
187 Self { config }
188 }
189
190 pub async fn execute<F, T, Fut>(&self, operation: F) -> SDKResult<T>
192 where
193 F: Fn() -> Fut,
194 Fut: std::future::Future<Output = SDKResult<T>>,
195 {
196 if !self.config.enabled {
197 return operation().await;
198 }
199
200 let started_at = Instant::now();
201 let mut last_error = None;
202
203 for attempt in 1..=self.config.default_strategy.max_attempts {
204 let result = operation().await;
205
206 match result {
207 Ok(value) => return Ok(value),
208 Err(error) => {
209 last_error = Some(error.clone());
210
211 if !self.should_retry(&error, attempt) {
213 return Err(error);
214 }
215
216 let delay = self.calculate_delay(attempt - 1);
218 let elapsed = started_at.elapsed();
219
220 let retry_attempt = RetryAttempt {
222 attempt,
223 max_attempts: self.config.default_strategy.max_attempts,
224 delay,
225 error: error.clone(),
226 started_at,
227 elapsed,
228 };
229
230 if let Some(callback) = &self.config.on_retry {
232 callback(&retry_attempt);
233 }
234
235 if !retry_attempt.is_final_attempt() {
237 sleep(delay).await;
238 }
239 }
240 }
241 }
242
243 Err(last_error.unwrap())
245 }
246
247 fn should_retry(&self, error: &LarkAPIError, attempt: u32) -> bool {
249 if attempt >= self.config.default_strategy.max_attempts {
251 return false;
252 }
253
254 if let Some(filter) = &self.config.retry_filter {
256 return filter(error);
257 }
258
259 error.is_retryable()
261 }
262
263 fn calculate_delay(&self, attempt: u32) -> Duration {
265 self.config.default_strategy.calculate_delay(attempt)
266 }
267}
268
/// Aggregate retry counters, as collected by [`RetryMiddlewareWithStats`].
#[derive(Debug, Default)]
pub struct RetryStats {
    /// Number of attempts recorded.
    pub total_attempts: u32,
    /// Number of successful attempts recorded.
    pub successful_attempts: u32,
    /// Number of retries recorded.
    pub retry_count: u32,
    /// Accumulated duration.
    pub total_duration: Duration,
    /// Average delay between retries.
    pub average_delay: Duration,
}

impl RetryStats {
    /// Returns `successful_attempts / total_attempts`, or `0.0` when nothing
    /// has been recorded yet.
    pub fn success_rate(&self) -> f32 {
        match self.total_attempts {
            0 => 0.0,
            total => self.successful_attempts as f32 / total as f32,
        }
    }

    /// Prints a multi-line human-readable summary to stdout.
    pub fn print_summary(&self) {
        let rate_percent = self.success_rate() * 100.0;

        println!("📊 重试统计:");
        println!(" 总尝试次数: {}", self.total_attempts);
        println!(" 成功次数: {}", self.successful_attempts);
        println!(" 重试次数: {}", self.retry_count);
        println!(" 成功率: {:.1}%", rate_percent);
        println!(" 总耗时: {:?}", self.total_duration);
        println!(" 平均延迟: {:?}", self.average_delay);
    }
}
305
306pub struct RetryMiddlewareWithStats {
308 middleware: RetryMiddleware,
309 stats: Arc<std::sync::Mutex<RetryStats>>,
310}
311
312impl RetryMiddlewareWithStats {
313 pub fn new(config: RetryConfig) -> Self {
315 let stats = Arc::new(std::sync::Mutex::new(RetryStats::default()));
316 let stats_clone = Arc::clone(&stats);
317
318 let config_with_stats = config.on_retry(move |attempt| {
320 if let Ok(mut stats) = stats_clone.lock() {
321 stats.total_attempts += 1;
322 stats.retry_count += 1;
323 stats.total_duration += attempt.elapsed;
324 }
325 });
326
327 Self {
328 middleware: RetryMiddleware::new(config_with_stats),
329 stats,
330 }
331 }
332
333 pub async fn execute<F, T, Fut>(&self, operation: F) -> SDKResult<T>
335 where
336 F: Fn() -> Fut,
337 Fut: std::future::Future<Output = SDKResult<T>>,
338 {
339 let result = self.middleware.execute(operation).await;
340
341 if let Ok(mut stats) = self.stats.lock() {
343 if result.is_ok() {
344 stats.successful_attempts += 1;
345 }
346 }
347
348 result
349 }
350
351 pub fn get_stats(&self) -> RetryStats {
353 let stats = self.stats.lock().unwrap();
354 RetryStats {
355 total_attempts: stats.total_attempts,
356 successful_attempts: stats.successful_attempts,
357 retry_count: stats.retry_count,
358 total_duration: stats.total_duration,
359 average_delay: stats.average_delay,
360 }
361 }
362
363 pub fn reset_stats(&self) {
365 if let Ok(mut stats) = self.stats.lock() {
366 *stats = RetryStats::default();
367 }
368 }
369}
370
371pub struct RetryStrategyBuilder {
373 strategy: RetryStrategy,
374}
375
376impl RetryStrategyBuilder {
377 pub fn new() -> Self {
379 Self {
380 strategy: RetryStrategy::default(),
381 }
382 }
383
384 pub fn max_attempts(mut self, max_attempts: u32) -> Self {
386 self.strategy.max_attempts = max_attempts;
387 self
388 }
389
390 pub fn base_delay(mut self, delay: Duration) -> Self {
392 self.strategy.base_delay = delay;
393 self
394 }
395
396 pub fn max_delay(mut self, delay: Duration) -> Self {
398 self.strategy.max_delay = delay;
399 self
400 }
401
402 pub fn exponential_backoff(mut self, enabled: bool) -> Self {
404 self.strategy.use_exponential_backoff = enabled;
405 self
406 }
407
408 pub fn build(self) -> RetryStrategy {
410 self.strategy
411 }
412
413 pub fn linear(max_attempts: u32, delay: Duration) -> RetryStrategy {
415 Self::new()
416 .max_attempts(max_attempts)
417 .base_delay(delay)
418 .exponential_backoff(false)
419 .build()
420 }
421
422 pub fn exponential(
424 max_attempts: u32,
425 base_delay: Duration,
426 max_delay: Duration,
427 ) -> RetryStrategy {
428 Self::new()
429 .max_attempts(max_attempts)
430 .base_delay(base_delay)
431 .max_delay(max_delay)
432 .exponential_backoff(true)
433 .build()
434 }
435}
436
437impl Default for RetryStrategyBuilder {
438 fn default() -> Self {
439 Self::new()
440 }
441}
442
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// The HTTP 500 API error used by every failure-path test.
    fn server_error() -> LarkAPIError {
        LarkAPIError::api_error(500, "Server Error", None)
    }

    #[test]
    fn test_retry_config_builder() {
        let config = RetryConfig::new().enabled(true).aggressive();

        assert!(config.enabled);
        assert_eq!(config.default_strategy.max_attempts, 5);
    }

    #[test]
    fn test_retry_strategy_builder() {
        let one_second = Duration::from_secs(1);
        let strategy = RetryStrategyBuilder::new()
            .max_attempts(3)
            .base_delay(one_second)
            .exponential_backoff(true)
            .build();

        assert_eq!(strategy.max_attempts, 3);
        assert_eq!(strategy.base_delay, one_second);
        assert!(strategy.use_exponential_backoff);
    }

    #[test]
    fn test_linear_strategy() {
        let delay = Duration::from_secs(2);
        let strategy = RetryStrategyBuilder::linear(3, delay);

        assert_eq!(strategy.max_attempts, 3);
        assert_eq!(strategy.base_delay, delay);
        assert!(!strategy.use_exponential_backoff);
    }

    #[test]
    fn test_exponential_strategy() {
        let (base, cap) = (Duration::from_millis(500), Duration::from_secs(30));
        let strategy = RetryStrategyBuilder::exponential(5, base, cap);

        assert_eq!(strategy.max_attempts, 5);
        assert_eq!(strategy.base_delay, base);
        assert_eq!(strategy.max_delay, cap);
        assert!(strategy.use_exponential_backoff);
    }

    #[test]
    fn test_retry_attempt_info() {
        let attempt = RetryAttempt {
            attempt: 2,
            max_attempts: 3,
            delay: Duration::from_secs(2),
            error: server_error(),
            started_at: Instant::now(),
            elapsed: Duration::from_secs(5),
        };

        assert!(!attempt.is_final_attempt());
        assert_eq!(attempt.remaining_attempts(), 1);
    }

    #[test]
    fn test_retry_stats() {
        let stats = RetryStats {
            total_attempts: 10,
            successful_attempts: 8,
            retry_count: 5,
            total_duration: Duration::from_secs(30),
            average_delay: Duration::from_secs(2),
        };

        assert_eq!(stats.success_rate(), 0.8);
    }

    #[tokio::test]
    async fn test_retry_middleware_success() {
        let calls = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&calls);
        let middleware = RetryMiddleware::default();

        // Fail only on the first call, so exactly one retry is needed.
        let result: Result<&str, LarkAPIError> = middleware
            .execute(move || {
                let is_first_call = counter.fetch_add(1, Ordering::SeqCst) == 0;
                async move {
                    if is_first_call {
                        Err(server_error())
                    } else {
                        Ok("Success")
                    }
                }
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "Success");
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_retry_middleware_failure() {
        let strategy = RetryStrategyBuilder::linear(2, Duration::from_millis(1));
        let middleware = RetryMiddleware::new(RetryConfig::new().default_strategy(strategy));
        let calls = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&calls);

        // Always fail; the 2-attempt limit bounds the number of calls.
        let result: Result<&str, LarkAPIError> = middleware
            .execute(move || {
                counter.fetch_add(1, Ordering::SeqCst);
                async move { Err(server_error()) }
            })
            .await;

        assert!(result.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }
}
573}