use crate::error::OstiumError;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::time::sleep;
use tracing::{debug, warn};

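/// Retry policy: bounded exponential backoff with jitter and a per-attempt
/// timeout budget.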
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Maximum number of attempts before giving up.
    pub max_attempts: u32,
    /// Delay before the first retry.
    pub initial_delay: Duration,
    /// Upper bound on the backoff delay.
    pub max_delay: Duration,
    /// Factor applied to the delay after each failed attempt.
    pub backoff_multiplier: f64,
    /// Fraction of the delay added as random jitter (0.0 disables it).
    pub jitter_factor: f64,
    /// Time budget for a single attempt.
    pub operation_timeout: Duration,
}

impl Default for RetryConfig {
    fn default() -> Self {
        Self {
            max_attempts: 3,
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(30),
            backoff_multiplier: 2.0,
            jitter_factor: 0.1,
            operation_timeout: Duration::from_secs(30),
        }
    }
}

impl RetryConfig {
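    /// Preset for transient network failures: more attempts, short delays.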
    pub fn network() -> Self {
        Self {
            max_attempts: 5,
            initial_delay: Duration::from_millis(200),
            max_delay: Duration::from_secs(10),
            backoff_multiplier: 1.5,
            jitter_factor: 0.2,
            operation_timeout: Duration::from_secs(30),
        }
    }

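    /// Preset for contract calls: fewer attempts, longer delays and timeout.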
    pub fn contract() -> Self {
        Self {
            max_attempts: 3,
            initial_delay: Duration::from_millis(500),
            max_delay: Duration::from_secs(20),
            backoff_multiplier: 2.0,
            jitter_factor: 0.1,
            operation_timeout: Duration::from_secs(60),
        }
    }

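    /// Preset for GraphQL queries: quick retries with a tight delay cap.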
    pub fn graphql() -> Self {
        Self {
            max_attempts: 4,
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(5),
            backoff_multiplier: 1.8,
            jitter_factor: 0.15,
            operation_timeout: Duration::from_secs(15),
        }
    }
}

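/// Observable states of a [`CircuitBreaker`].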
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CircuitState {
    /// Requests flow through normally.
    Closed,
    /// Requests are rejected until the recovery timeout elapses.
    Open,
    /// A trial request is allowed through to probe recovery.
    HalfOpen,
}

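/// A circuit breaker whose entire state machine (state, failure count,
/// last-failure timestamp) is packed into a single `AtomicU64`, so state
/// transitions are plain atomic loads and stores rather than mutex locks.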
#[derive(Debug)]
pub struct CircuitBreaker {
    /// Packed (state, failure count, last-failure time); see `encode_state`.
    state: Arc<AtomicU64>,
    failure_threshold: u32,
    recovery_timeout: Duration,
    /// Successes required to close from half-open (not yet enforced by `call`).
    success_threshold: u32,
}

impl CircuitBreaker {
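    /// Creates a closed breaker that opens after `failure_threshold`
    /// consecutive failures and probes recovery after `recovery_timeout`.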
    pub fn new(failure_threshold: u32, recovery_timeout: Duration) -> Self {
        Self {
            state: Arc::new(AtomicU64::new(0)),
            failure_threshold,
            recovery_timeout,
            success_threshold: 3,
        }
    }

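    /// Runs `operation` through the breaker: rejects immediately while open,
    /// admits a trial request when half-open, and records the outcome.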
    pub fn call<F, Fut, T>(
        &self,
        operation: F,
    ) -> impl std::future::Future<Output = Result<T, OstiumError>>
    where
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<T, OstiumError>>,
    {
        let state = self.state.clone();
        let failure_threshold = self.failure_threshold;
        let recovery_timeout = self.recovery_timeout;
        let _success_threshold = self.success_threshold;

        async move {
            let current_state = Self::decode_state(state.load(Ordering::Acquire));

            match current_state.0 {
                CircuitState::Open => {
                    // Compare wall-clock time against the recorded last-failure
                    // timestamp; wrapping_sub keeps the arithmetic correct if
                    // the u32 millisecond counter has wrapped.
                    let elapsed_ms = Self::now_millis().wrapping_sub(current_state.2);
                    let time_since_failure = Duration::from_millis(elapsed_ms as u64);

                    if time_since_failure >= recovery_timeout {
                        let new_packed = Self::encode_state(CircuitState::HalfOpen, 0, 0);
                        state.store(new_packed, Ordering::Release);
                        debug!("Circuit breaker transitioning to half-open");
                    } else {
                        return Err(OstiumError::Network(
                            "Circuit breaker is open - too many recent failures".to_string(),
                        ));
                    }
                }
                // Half-open admits a trial request; closed is normal operation.
                CircuitState::HalfOpen | CircuitState::Closed => {}
            }

            match operation().await {
                Ok(result) => {
                    if current_state.0 == CircuitState::HalfOpen {
                        debug!("Circuit breaker closed after successful recovery");
                    }
                    // Any success resets the breaker to closed with a clean slate.
                    let new_packed = Self::encode_state(CircuitState::Closed, 0, 0);
                    state.store(new_packed, Ordering::Release);
                    Ok(result)
                }
                Err(error) => {
                    let new_failure_count = current_state.1 + 1;
                    let current_time = Self::now_millis();

                    // A failure while half-open, or one that crosses the
                    // threshold, opens the circuit.
                    if current_state.0 == CircuitState::HalfOpen
                        || new_failure_count >= failure_threshold
                    {
                        let new_packed =
                            Self::encode_state(CircuitState::Open, new_failure_count, current_time);
                        state.store(new_packed, Ordering::Release);
                        warn!(
                            "Circuit breaker opened after {} failures",
                            new_failure_count
                        );
                    } else {
                        let new_packed =
                            Self::encode_state(current_state.0, new_failure_count, current_time);
                        state.store(new_packed, Ordering::Release);
                    }

                    Err(error)
                }
            }
        }
    }

    /// Milliseconds since the Unix epoch, truncated to `u32` for packing
    /// (wraps roughly every 49.7 days; see `call` for the wrap handling).
    fn now_millis() -> u32 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_millis() as u32
    }

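    /// Packs `(state, failure_count, last_failure_time)` into a single `u64`:
    /// bits 56..=63 hold the state, bits 32..=55 a 24-bit failure count, and
    /// bits 0..=31 the last-failure time in milliseconds (`now_millis`).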
    fn encode_state(state: CircuitState, failure_count: u32, last_failure_time: u32) -> u64 {
        let state_bits = match state {
            CircuitState::Closed => 0u64,
            CircuitState::Open => 1u64,
            CircuitState::HalfOpen => 2u64,
        };

        (state_bits << 56) | ((failure_count as u64 & 0xFFFFFF) << 32) | (last_failure_time as u64)
    }

    fn decode_state(packed: u64) -> (CircuitState, u32, u32) {
        let state = match (packed >> 56) & 0xFF {
            0 => CircuitState::Closed,
            1 => CircuitState::Open,
            2 => CircuitState::HalfOpen,
            _ => CircuitState::Closed,
        };
        let failure_count = ((packed >> 32) & 0xFFFFFF) as u32;
        let last_failure_time = (packed & 0xFFFFFFFF) as u32;

        (state, failure_count, last_failure_time)
    }

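    /// Returns the current circuit state (primarily for inspection and tests).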
    pub fn state(&self) -> CircuitState {
        Self::decode_state(self.state.load(Ordering::Acquire)).0
    }
}

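/// Drives async operations through a [`RetryConfig`] policy, optionally
/// guarded by a [`CircuitBreaker`].
///
/// A minimal usage sketch (illustrative; `fetch_positions` is a hypothetical
/// async fn returning `Result<_, OstiumError>`):
///
/// ```ignore
/// let executor = RetryExecutor::new(RetryConfig::network())
///     .with_circuit_breaker(5, std::time::Duration::from_secs(30));
/// let positions = executor.execute(|| fetch_positions()).await?;
/// ```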
pub struct RetryExecutor {
    config: RetryConfig,
    circuit_breaker: Option<CircuitBreaker>,
}

impl RetryExecutor {
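    /// Creates an executor with the given retry policy and no circuit breaker.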
    pub fn new(config: RetryConfig) -> Self {
        Self {
            config,
            circuit_breaker: None,
        }
    }

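    /// Adds a circuit breaker shared by every operation run through this
    /// executor.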
    pub fn with_circuit_breaker(
        mut self,
        failure_threshold: u32,
        recovery_timeout: Duration,
    ) -> Self {
        self.circuit_breaker = Some(CircuitBreaker::new(failure_threshold, recovery_timeout));
        self
    }

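    /// Runs `operation`, retrying retryable errors with jittered exponential
    /// backoff until it succeeds or `max_attempts` is reached.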
    pub async fn execute<F, Fut, T>(&self, operation: F) -> Result<T, OstiumError>
    where
        F: Fn() -> Fut,
        Fut: std::future::Future<Output = Result<T, OstiumError>>,
    {
        let mut attempt = 0;
        let mut delay = self.config.initial_delay;

        loop {
            attempt += 1;

            debug!(
                "Executing operation attempt {}/{}",
                attempt, self.config.max_attempts
            );

            let result = if let Some(ref circuit_breaker) = self.circuit_breaker {
                circuit_breaker.call(&operation).await
            } else {
                operation().await
            };

            match result {
                Ok(value) => {
                    if attempt > 1 {
                        debug!("Operation succeeded after {} attempts", attempt);
                    }
                    return Ok(value);
                }
                Err(error) => {
                    // Give up on non-retryable errors or once attempts are exhausted.
                    if !self.should_retry(&error) || attempt >= self.config.max_attempts {
                        warn!("Operation failed after {} attempts: {}", attempt, error);
                        return Err(error);
                    }

                    debug!(
                        "Operation failed on attempt {}, retrying after {:?}: {}",
                        attempt, delay, error
                    );

                    let jittered_delay = self.add_jitter(delay);
                    sleep(jittered_delay).await;

                    // Grow the delay geometrically, capped at max_delay.
                    delay = std::cmp::min(
                        Duration::from_millis(
                            (delay.as_millis() as f64 * self.config.backoff_multiplier) as u64,
                        ),
                        self.config.max_delay,
                    );
                }
            }
        }
    }

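    /// Classifies an error as transient (retryable) or permanent. Matching on
    /// message substrings is a heuristic; unknown errors are not retried.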
    fn should_retry(&self, error: &OstiumError) -> bool {
        match error {
            OstiumError::Network(_) => true,

            OstiumError::Http(e) => e.is_timeout() || e.is_connect() || e.is_request(),

            OstiumError::Contract(msg) => {
                msg.contains("timeout")
                    || msg.contains("connection")
                    || msg.contains("temporarily unavailable")
                    || msg.contains("rate limit")
            }

            OstiumError::GraphQL(msg) => {
                msg.contains("timeout")
                    || msg.contains("server error")
                    || msg.contains("503")
                    || msg.contains("502")
                    || msg.contains("504")
            }

            OstiumError::Provider(msg) => {
                msg.contains("timeout") || msg.contains("connection") || msg.contains("rate limit")
            }

            OstiumError::Validation(_) => false,
            OstiumError::Wallet(_) => false,
            OstiumError::Config(_) => false,
            OstiumError::Json(_) => false,
            OstiumError::Decimal(_) => false,
            OstiumError::Other(_) => false,
        }
    }

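    /// Adds up to `jitter_factor` of random extra delay so that concurrent
    /// clients do not retry in lockstep.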
    fn add_jitter(&self, delay: Duration) -> Duration {
        if self.config.jitter_factor <= 0.0 {
            return delay;
        }

        let jitter_range = (delay.as_millis() as f64 * self.config.jitter_factor) as u64;
        let jitter = fastrand::u64(0..=jitter_range);

        Duration::from_millis(delay.as_millis() as u64 + jitter)
    }
}

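/// Retries an expression with the [`RetryConfig::network`] preset. The
/// expression is re-evaluated on every attempt.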
#[macro_export]
macro_rules! retry_network {
    ($operation:expr) => {
        $crate::retry::RetryExecutor::new($crate::retry::RetryConfig::network())
            .execute(|| async { $operation })
            .await
    };
}

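/// Retries an expression with the [`RetryConfig::contract`] preset, guarded
/// by a circuit breaker (opens after 5 failures, 60 s recovery).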
#[macro_export]
macro_rules! retry_contract {
    ($operation:expr) => {
        $crate::retry::RetryExecutor::new($crate::retry::RetryConfig::contract())
            .with_circuit_breaker(5, std::time::Duration::from_secs(60))
            .execute(|| async { $operation })
            .await
    };
}

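/// Retries an expression with the [`RetryConfig::graphql`] preset.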
#[macro_export]
macro_rules! retry_graphql {
    ($operation:expr) => {
        $crate::retry::RetryExecutor::new($crate::retry::RetryConfig::graphql())
            .execute(|| async { $operation })
            .await
    };
}

pub use retry_contract;
pub use retry_graphql;
pub use retry_network;

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    #[tokio::test]
    async fn test_retry_success_after_failures() {
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let config = RetryConfig {
            max_attempts: 3,
            initial_delay: Duration::from_millis(10),
            ..Default::default()
        };

        let executor = RetryExecutor::new(config);

        let result = executor
            .execute(|| {
                let counter = counter_clone.clone();
                async move {
                    let count = counter.fetch_add(1, Ordering::SeqCst);
                    if count < 2 {
                        Err(OstiumError::Network("Temporary failure".to_string()))
                    } else {
                        Ok("Success".to_string())
                    }
                }
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "Success");
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_exhaustion() {
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let config = RetryConfig {
            max_attempts: 2,
            initial_delay: Duration::from_millis(10),
            ..Default::default()
        };

        let executor = RetryExecutor::new(config);

        let result: Result<String, OstiumError> = executor
            .execute(|| {
                let counter = counter_clone.clone();
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err(OstiumError::Network("Permanent failure".to_string()))
                }
            })
            .await;

        assert!(result.is_err());
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_circuit_breaker() {
        let circuit_breaker = CircuitBreaker::new(2, Duration::from_millis(100));

        // First failure: below the threshold, so the circuit stays closed.
        let result1: Result<String, OstiumError> = circuit_breaker
            .call(|| async { Err(OstiumError::Network("Failure".to_string())) })
            .await;
        assert!(result1.is_err());
        assert_eq!(circuit_breaker.state(), CircuitState::Closed);

        // Second failure hits the threshold and opens the circuit.
        let result2: Result<String, OstiumError> = circuit_breaker
            .call(|| async { Err(OstiumError::Network("Failure".to_string())) })
            .await;
        assert!(result2.is_err());
        assert_eq!(circuit_breaker.state(), CircuitState::Open);

        // While open, calls are rejected without running the operation.
        let result3 = circuit_breaker
            .call(|| async { Ok("Should not execute".to_string()) })
            .await;
        assert!(result3.is_err());
        assert!(result3
            .unwrap_err()
            .to_string()
            .contains("Circuit breaker is open"));
    }
}