1use crate::error::{AptosError, AptosResult};
26use std::collections::HashSet;
27use std::future::Future;
28use std::time::Duration;
29use tokio::time::sleep;
30
/// Tunable parameters that decide whether and how long to wait before
/// retrying a failed operation.
///
/// The delay before a retry starts at `initial_delay_ms`, is multiplied by
/// `exponential_base` for every further retry, is capped at `max_delay_ms`,
/// and is optionally randomised by `jitter` / `jitter_factor`.
#[derive(Clone, Debug)]
pub struct RetryConfig {
    /// Number of retries permitted after the initial attempt (`0` disables retrying).
    pub max_retries: u32,
    /// Delay, in milliseconds, used for the first retry.
    pub initial_delay_ms: u64,
    /// Cap, in milliseconds, applied to the exponentially grown delay.
    pub max_delay_ms: u64,
    /// Factor by which the delay grows for each additional retry.
    pub exponential_base: f64,
    /// Whether a random offset is applied to each computed delay.
    pub jitter: bool,
    /// Fraction of the capped delay used as the jitter half-range
    /// (for example `0.5` spreads the delay by up to ±50 %).
    pub jitter_factor: f64,
    /// HTTP status codes for which an API error is treated as retryable.
    pub retryable_status_codes: HashSet<u16>,
}
50
51impl Default for RetryConfig {
52 fn default() -> Self {
53 Self {
54 max_retries: 3,
55 initial_delay_ms: 100,
56 max_delay_ms: 10_000,
57 exponential_base: 2.0,
58 jitter: true,
59 jitter_factor: 0.5,
60 retryable_status_codes: [
61 408, 429, 500, 502, 503, 504, ]
68 .into_iter()
69 .collect(),
70 }
71 }
72}
73
74impl RetryConfig {
75 pub fn builder() -> RetryConfigBuilder {
77 RetryConfigBuilder::default()
78 }
79
80 pub fn no_retry() -> Self {
82 Self {
83 max_retries: 0,
84 ..Default::default()
85 }
86 }
87
88 pub fn aggressive() -> Self {
90 Self {
91 max_retries: 5,
92 initial_delay_ms: 50,
93 max_delay_ms: 5_000,
94 exponential_base: 1.5,
95 jitter: true,
96 jitter_factor: 0.3,
97 ..Default::default()
98 }
99 }
100
101 pub fn conservative() -> Self {
103 Self {
104 max_retries: 3,
105 initial_delay_ms: 500,
106 max_delay_ms: 30_000,
107 exponential_base: 2.0,
108 jitter: true,
109 jitter_factor: 0.5,
110 ..Default::default()
111 }
112 }
113
114 #[allow(clippy::cast_possible_truncation)] pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
117 if attempt == 0 {
118 return Duration::from_millis(0);
119 }
120
121 #[allow(clippy::cast_precision_loss)] let base_delay = self.initial_delay_ms as f64
124 * self
125 .exponential_base
126 .powi(attempt.saturating_sub(1).cast_signed());
127
128 #[allow(clippy::cast_precision_loss)] let capped_delay = base_delay.min(self.max_delay_ms as f64);
131
132 let final_delay = if self.jitter {
134 let jitter_range = capped_delay * self.jitter_factor;
135 let jitter = rand::random::<f64>() * jitter_range * 2.0 - jitter_range;
136 (capped_delay + jitter).max(0.0)
137 } else {
138 capped_delay
139 };
140
141 #[allow(clippy::cast_sign_loss)] Duration::from_millis(final_delay as u64)
143 }
144
145 #[inline]
147 pub fn is_retryable_status(&self, status_code: u16) -> bool {
148 self.retryable_status_codes.contains(&status_code)
149 }
150
151 #[inline]
153 pub fn is_retryable_error(&self, error: &AptosError) -> bool {
154 match error {
155 AptosError::Http(_) | AptosError::RateLimited { .. } => true,
157 AptosError::Api { status_code, .. } => self.is_retryable_status(*status_code),
159 _ => false,
161 }
162 }
163}
164
/// Incremental builder for [`RetryConfig`].
///
/// Every option starts out as `None`; `build` substitutes the corresponding
/// `RetryConfig::default()` value for anything left unset.
#[derive(Clone, Debug, Default)]
pub struct RetryConfigBuilder {
    /// Override for `RetryConfig::max_retries`.
    max_retries: Option<u32>,
    /// Override for `RetryConfig::initial_delay_ms`.
    initial_delay_ms: Option<u64>,
    /// Override for `RetryConfig::max_delay_ms`.
    max_delay_ms: Option<u64>,
    /// Override for `RetryConfig::exponential_base`.
    exponential_base: Option<f64>,
    /// Override for `RetryConfig::jitter`.
    jitter: Option<bool>,
    /// Override for `RetryConfig::jitter_factor`.
    jitter_factor: Option<f64>,
    /// Override for `RetryConfig::retryable_status_codes`.
    retryable_status_codes: Option<HashSet<u16>>,
}
176
177impl RetryConfigBuilder {
178 #[must_use]
180 pub fn max_retries(mut self, max_retries: u32) -> Self {
181 self.max_retries = Some(max_retries);
182 self
183 }
184
185 #[must_use]
187 pub fn initial_delay_ms(mut self, initial_delay_ms: u64) -> Self {
188 self.initial_delay_ms = Some(initial_delay_ms);
189 self
190 }
191
192 #[must_use]
194 pub fn max_delay_ms(mut self, max_delay_ms: u64) -> Self {
195 self.max_delay_ms = Some(max_delay_ms);
196 self
197 }
198
199 #[must_use]
201 pub fn exponential_base(mut self, base: f64) -> Self {
202 self.exponential_base = Some(base);
203 self
204 }
205
206 #[must_use]
208 pub fn jitter(mut self, jitter: bool) -> Self {
209 self.jitter = Some(jitter);
210 self
211 }
212
213 #[must_use]
215 pub fn jitter_factor(mut self, factor: f64) -> Self {
216 self.jitter_factor = Some(factor.clamp(0.0, 1.0));
217 self
218 }
219
220 #[must_use]
222 pub fn retryable_status_codes(mut self, codes: impl IntoIterator<Item = u16>) -> Self {
223 self.retryable_status_codes = Some(codes.into_iter().collect());
224 self
225 }
226
227 #[must_use]
229 pub fn add_retryable_status_code(mut self, code: u16) -> Self {
230 let mut codes = self.retryable_status_codes.unwrap_or_default();
231 codes.insert(code);
232 self.retryable_status_codes = Some(codes);
233 self
234 }
235
236 #[must_use]
238 pub fn build(self) -> RetryConfig {
239 let default = RetryConfig::default();
240 RetryConfig {
241 max_retries: self.max_retries.unwrap_or(default.max_retries),
242 initial_delay_ms: self.initial_delay_ms.unwrap_or(default.initial_delay_ms),
243 max_delay_ms: self.max_delay_ms.unwrap_or(default.max_delay_ms),
244 exponential_base: self.exponential_base.unwrap_or(default.exponential_base),
245 jitter: self.jitter.unwrap_or(default.jitter),
246 jitter_factor: self.jitter_factor.unwrap_or(default.jitter_factor),
247 retryable_status_codes: self
248 .retryable_status_codes
249 .unwrap_or(default.retryable_status_codes),
250 }
251 }
252}
253
254#[derive(Debug, Clone)]
256pub struct RetryExecutor {
257 config: RetryConfig,
258}
259
260impl RetryExecutor {
261 pub fn new(config: RetryConfig) -> Self {
263 Self { config }
264 }
265
266 pub fn with_defaults() -> Self {
268 Self::new(RetryConfig::default())
269 }
270
271 pub async fn execute<F, Fut, T>(&self, operation: F) -> AptosResult<T>
282 where
283 F: Fn() -> Fut,
284 Fut: Future<Output = AptosResult<T>>,
285 {
286 let mut attempt = 0;
287
288 loop {
289 match operation().await {
290 Ok(result) => return Ok(result),
291 Err(error) => {
292 if attempt >= self.config.max_retries || !self.config.is_retryable_error(&error)
294 {
295 return Err(error);
296 }
297
298 attempt += 1;
299
300 let delay = if let AptosError::RateLimited {
303 retry_after_secs: Some(secs),
304 } = &error
305 {
306 let capped_secs = (*secs).min(300); Duration::from_secs(capped_secs)
309 } else {
310 self.config.delay_for_attempt(attempt)
311 };
312
313 if !delay.is_zero() {
314 sleep(delay).await;
315 }
316 }
317 }
318 }
319 }
320
321 pub async fn execute_with_predicate<F, Fut, T, P>(
328 &self,
329 operation: F,
330 should_retry: P,
331 ) -> AptosResult<T>
332 where
333 F: Fn() -> Fut,
334 Fut: Future<Output = AptosResult<T>>,
335 P: Fn(&AptosError) -> bool,
336 {
337 let mut attempt = 0;
338
339 loop {
340 match operation().await {
341 Ok(result) => return Ok(result),
342 Err(error) => {
343 if attempt >= self.config.max_retries || !should_retry(&error) {
344 return Err(error);
345 }
346
347 attempt += 1;
348
349 let delay = if let AptosError::RateLimited {
351 retry_after_secs: Some(secs),
352 } = &error
353 {
354 let capped_secs = (*secs).min(300);
355 Duration::from_secs(capped_secs)
356 } else {
357 self.config.delay_for_attempt(attempt)
358 };
359
360 if !delay.is_zero() {
361 sleep(delay).await;
362 }
363 }
364 }
365 }
366 }
367}
368
369pub trait RetryExt<T> {
371 fn with_retry(self, config: &RetryConfig) -> impl Future<Output = AptosResult<T>>;
373}
374
375pub async fn retry<F, Fut, T>(operation: F) -> AptosResult<T>
383where
384 F: Fn() -> Fut,
385 Fut: Future<Output = AptosResult<T>>,
386{
387 RetryExecutor::with_defaults().execute(operation).await
388}
389
390pub async fn retry_with_config<F, Fut, T>(config: &RetryConfig, operation: F) -> AptosResult<T>
398where
399 F: Fn() -> Fut,
400 Fut: Future<Output = AptosResult<T>>,
401{
402 RetryExecutor::new(config.clone()).execute(operation).await
403}
404
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Builds an `AptosError::Api` with the given status and message and no
    /// error codes attached.
    fn api_error(status_code: u16, message: &str) -> AptosError {
        AptosError::Api {
            status_code,
            message: message.to_string(),
            error_code: None,
            vm_error_code: None,
        }
    }

    #[test]
    fn test_default_config() {
        let config = RetryConfig::default();
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.initial_delay_ms, 100);
        assert!(config.jitter);
    }

    #[test]
    fn test_no_retry_config() {
        let config = RetryConfig::no_retry();
        assert_eq!(config.max_retries, 0);
    }

    #[test]
    fn test_builder() {
        let config = RetryConfig::builder()
            .max_retries(5)
            .initial_delay_ms(200)
            .max_delay_ms(5000)
            .exponential_base(1.5)
            .jitter(false)
            .build();

        assert_eq!(config.max_retries, 5);
        assert_eq!(config.initial_delay_ms, 200);
        assert_eq!(config.max_delay_ms, 5000);
        assert!((config.exponential_base - 1.5).abs() < f64::EPSILON);
        assert!(!config.jitter);
    }

    #[test]
    fn test_delay_calculation_no_jitter() {
        let config = RetryConfig::builder()
            .initial_delay_ms(100)
            .exponential_base(2.0)
            .jitter(false)
            .build();

        // Attempt 0 is the initial try: no delay.
        assert_eq!(config.delay_for_attempt(0), Duration::from_millis(0));

        // First retry: 100 ms * 2^0.
        assert_eq!(config.delay_for_attempt(1), Duration::from_millis(100));

        // Second retry: 100 ms * 2^1.
        assert_eq!(config.delay_for_attempt(2), Duration::from_millis(200));

        // Third retry: 100 ms * 2^2.
        assert_eq!(config.delay_for_attempt(3), Duration::from_millis(400));
    }

    #[test]
    fn test_delay_capped_at_max() {
        let config = RetryConfig::builder()
            .initial_delay_ms(1000)
            .max_delay_ms(2000)
            .exponential_base(2.0)
            .jitter(false)
            .build();

        // 1000 ms * 2^2 = 4000 ms, which must be capped at 2000 ms.
        assert_eq!(config.delay_for_attempt(3), Duration::from_millis(2000));
    }

    #[test]
    fn test_retryable_status_codes() {
        let config = RetryConfig::default();

        assert!(config.is_retryable_status(429));
        assert!(config.is_retryable_status(503));
        assert!(!config.is_retryable_status(400));
        assert!(!config.is_retryable_status(404));
    }

    #[test]
    fn test_retryable_errors() {
        let config = RetryConfig::default();

        // API error with a retryable status code.
        let api_error_503 = api_error(503, "Service Unavailable");
        assert!(config.is_retryable_error(&api_error_503));

        // Rate limiting is always retryable.
        let rate_limited = AptosError::RateLimited {
            retry_after_secs: Some(30),
        };
        assert!(config.is_retryable_error(&rate_limited));

        // API error with a non-retryable status code.
        let api_error_400 = api_error(400, "Bad Request");
        assert!(!config.is_retryable_error(&api_error_400));

        // Any other variant is not retryable.
        let not_found = AptosError::NotFound("resource".to_string());
        assert!(!config.is_retryable_error(&not_found));
    }

    #[tokio::test]
    async fn test_retry_succeeds_on_first_try() {
        let executor = RetryExecutor::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let result = executor
            .execute(|| {
                let counter = counter_clone.clone();
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Ok::<_, AptosError>(42)
                }
            })
            .await;

        assert_eq!(result.unwrap(), 42);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_succeeds_after_failures() {
        let config = RetryConfig::builder()
            .max_retries(3)
            .initial_delay_ms(1) // keep the test fast
            .jitter(false)
            .build();
        let executor = RetryExecutor::new(config);
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let result = executor
            .execute(|| {
                let counter = counter_clone.clone();
                async move {
                    let count = counter.fetch_add(1, Ordering::SeqCst);
                    if count < 2 {
                        Err(api_error(503, "Service Unavailable"))
                    } else {
                        Ok(42)
                    }
                }
            })
            .await;

        assert_eq!(result.unwrap(), 42);
        // Two failures followed by one success.
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_exhausted() {
        let config = RetryConfig::builder()
            .max_retries(2)
            .initial_delay_ms(1)
            .jitter(false)
            .build();
        let executor = RetryExecutor::new(config);
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let result = executor
            .execute(|| {
                let counter = counter_clone.clone();
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err::<i32, _>(api_error(503, "Always fails"))
                }
            })
            .await;

        assert!(result.is_err());
        // Initial attempt plus two retries.
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_no_retry_on_non_retryable_error() {
        let config = RetryConfig::builder()
            .max_retries(3)
            .initial_delay_ms(1)
            .build();
        let executor = RetryExecutor::new(config);
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let result = executor
            .execute(|| {
                let counter = counter_clone.clone();
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    // 400 is not in the retryable status set.
                    Err::<i32, _>(api_error(400, "Bad Request"))
                }
            })
            .await;

        assert!(result.is_err());
        // Only the initial attempt runs.
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_aggressive_config() {
        let config = RetryConfig::aggressive();
        assert_eq!(config.max_retries, 5);
        assert_eq!(config.initial_delay_ms, 50);
        assert_eq!(config.max_delay_ms, 5_000);
        assert!((config.exponential_base - 1.5).abs() < f64::EPSILON);
        assert!(config.jitter);
    }

    #[test]
    fn test_conservative_config() {
        let config = RetryConfig::conservative();
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.initial_delay_ms, 500);
        assert_eq!(config.max_delay_ms, 30_000);
        assert!((config.exponential_base - 2.0).abs() < f64::EPSILON);
        assert!(config.jitter);
    }

    #[test]
    fn test_builder_jitter_factor() {
        let config = RetryConfig::builder().jitter_factor(0.25).build();

        assert!((config.jitter_factor - 0.25).abs() < f64::EPSILON);
    }

    #[test]
    fn test_builder_retryable_status_codes() {
        let config = RetryConfig::builder()
            .retryable_status_codes([500, 502])
            .build();

        assert!(config.is_retryable_status(500));
        assert!(config.is_retryable_status(502));
        // The explicit set replaces the defaults, so 503 is gone.
        assert!(!config.is_retryable_status(503));
    }

    #[test]
    fn test_delay_with_jitter() {
        let config = RetryConfig::builder()
            .initial_delay_ms(100)
            .jitter(true)
            .jitter_factor(0.5)
            .build();

        // With ±50 % jitter on 100 ms the delay must land in 50..=150 ms.
        let delay1 = config.delay_for_attempt(1);
        assert!(delay1 >= Duration::from_millis(50));
        assert!(delay1 <= Duration::from_millis(150));
    }

    #[test]
    fn test_delay_zero_for_first_attempt() {
        let config = RetryConfig::default();
        assert_eq!(config.delay_for_attempt(0), Duration::from_millis(0));
    }

    #[test]
    fn test_retryable_error_transaction_error() {
        let config = RetryConfig::default();

        let txn_error = AptosError::Transaction("failed".to_string());
        assert!(!config.is_retryable_error(&txn_error));
    }

    #[test]
    fn test_retryable_error_invalid_address() {
        let config = RetryConfig::default();

        let addr_error = AptosError::InvalidAddress("bad".to_string());
        assert!(!config.is_retryable_error(&addr_error));
    }

    #[tokio::test]
    async fn test_retry_with_no_retry_config() {
        let config = RetryConfig::no_retry();
        let executor = RetryExecutor::new(config);
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let result = executor
            .execute(|| {
                let counter = counter_clone.clone();
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err::<i32, _>(api_error(503, "Service Unavailable"))
                }
            })
            .await;

        assert!(result.is_err());
        // A retryable error is still not retried when max_retries is 0.
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_retry_config_clone() {
        let config = RetryConfig::builder()
            .max_retries(5)
            .initial_delay_ms(200)
            .build();

        let cloned = config.clone();
        assert_eq!(config.max_retries, cloned.max_retries);
        assert_eq!(config.initial_delay_ms, cloned.initial_delay_ms);
    }

    #[test]
    fn test_retry_config_debug() {
        let config = RetryConfig::default();
        let debug = format!("{config:?}");
        assert!(debug.contains("RetryConfig"));
        assert!(debug.contains("max_retries"));
    }

    #[test]
    fn test_builder_add_retryable_status_code() {
        let config = RetryConfig::builder()
            .add_retryable_status_code(599)
            .build();

        assert!(config.is_retryable_status(599));
    }

    #[test]
    fn test_builder_add_duplicate_status_code() {
        let config = RetryConfig::builder()
            .add_retryable_status_code(500)
            .add_retryable_status_code(500) // duplicate on purpose
            .build();

        assert!(config.is_retryable_status(500));
        // The set deduplicates, and adding starts from an empty set.
        assert_eq!(config.retryable_status_codes.len(), 1);
    }

    #[test]
    fn test_builder_jitter_factor_clamped() {
        // Values above 1.0 are clamped down to 1.0.
        let config = RetryConfig::builder().jitter_factor(2.0).build();

        assert!((config.jitter_factor - 1.0).abs() < f64::EPSILON);

        // Negative values are clamped up to 0.0.
        let config2 = RetryConfig::builder().jitter_factor(-1.0).build();

        assert!(config2.jitter_factor.abs() < f64::EPSILON);
    }

    #[test]
    fn test_retry_executor_new() {
        let executor = RetryExecutor::new(RetryConfig::default());

        let debug = format!("{executor:?}");
        assert!(debug.contains("RetryExecutor"));
    }

    #[tokio::test]
    async fn test_retry_with_custom_predicate() {
        let config = RetryConfig::builder()
            .max_retries(3)
            .initial_delay_ms(1)
            .jitter(false)
            .build();
        let executor = RetryExecutor::new(config);
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let result = executor
            .execute_with_predicate(
                || {
                    let counter = counter_clone.clone();
                    async move {
                        let count = counter.fetch_add(1, Ordering::SeqCst);
                        if count < 2 {
                            // Not retryable by default; the predicate overrides that.
                            Err(AptosError::NotFound("test".to_string()))
                        } else {
                            Ok(42)
                        }
                    }
                },
                |_| true, // retry on every error
            )
            .await;

        assert_eq!(result.unwrap(), 42);
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_with_predicate_no_retry() {
        let config = RetryConfig::builder()
            .max_retries(3)
            .initial_delay_ms(1)
            .build();
        let executor = RetryExecutor::new(config);
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let result = executor
            .execute_with_predicate(
                || {
                    let counter = counter_clone.clone();
                    async move {
                        counter.fetch_add(1, Ordering::SeqCst);
                        Err::<i32, _>(api_error(503, "Fail"))
                    }
                },
                |_| false, // never retry
            )
            .await;

        assert!(result.is_err());
        // The predicate vetoes retrying, so only the initial attempt runs.
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_convenience_function() {
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let result = retry(|| {
            let counter = counter_clone.clone();
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok::<_, AptosError>(42)
            }
        })
        .await;

        assert_eq!(result.unwrap(), 42);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_with_config_convenience_function() {
        let config = RetryConfig::builder()
            .max_retries(1)
            .initial_delay_ms(1)
            .jitter(false)
            .build();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let result = retry_with_config(&config, || {
            let counter = counter_clone.clone();
            async move {
                let count = counter.fetch_add(1, Ordering::SeqCst);
                if count < 1 {
                    Err(api_error(503, "Service Unavailable"))
                } else {
                    Ok(42)
                }
            }
        })
        .await;

        assert_eq!(result.unwrap(), 42);
        // One failure, then success on the single allowed retry.
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_retryable_rate_limited_error() {
        let config = RetryConfig::default();

        let rate_limited = AptosError::RateLimited {
            retry_after_secs: Some(5),
        };
        assert!(config.is_retryable_error(&rate_limited));
    }

    #[test]
    fn test_builder_default_debug() {
        let builder = RetryConfigBuilder::default();
        let debug = format!("{builder:?}");
        assert!(debug.contains("RetryConfigBuilder"));
    }
}