ostium_rust_sdk/retry.rs

use crate::error::OstiumError;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::time::{sleep, Instant};
use tracing::{debug, warn};

/// Configuration for retry behavior
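///
/// # Example
///
/// A minimal sketch of building a config (the values are illustrative, not
/// recommendations); the preset constructors (`network`, `contract`,
/// `graphql`) cover the common cases:
///
/// ```ignore
/// use std::time::Duration;
///
/// // Start from the defaults and tighten the attempt budget.
/// let config = RetryConfig {
///     max_attempts: 2,
///     initial_delay: Duration::from_millis(50),
///     ..Default::default()
/// };
/// assert_eq!(config.max_attempts, 2);
///
/// // Or use a preset tuned for a workload.
/// let gql = RetryConfig::graphql();
/// assert_eq!(gql.max_attempts, 4);
/// ```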
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Maximum number of retry attempts
    pub max_attempts: u32,
    /// Initial delay between retries
    pub initial_delay: Duration,
    /// Maximum delay between retries
    pub max_delay: Duration,
    /// Multiplier for exponential backoff
    pub backoff_multiplier: f64,
    /// Maximum jitter factor (0.0 to 1.0)
    pub jitter_factor: f64,
    /// Timeout for individual operations
    pub operation_timeout: Duration,
}

impl Default for RetryConfig {
    fn default() -> Self {
        Self {
            max_attempts: 3,
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(30),
            backoff_multiplier: 2.0,
            jitter_factor: 0.1,
            operation_timeout: Duration::from_secs(30),
        }
    }
}

impl RetryConfig {
    /// Create a config optimized for network operations
    pub fn network() -> Self {
        Self {
            max_attempts: 5,
            initial_delay: Duration::from_millis(200),
            max_delay: Duration::from_secs(10),
            backoff_multiplier: 1.5,
            jitter_factor: 0.2,
            operation_timeout: Duration::from_secs(30),
        }
    }

    /// Create a config optimized for contract interactions
    pub fn contract() -> Self {
        Self {
            max_attempts: 3,
            initial_delay: Duration::from_millis(500),
            max_delay: Duration::from_secs(20),
            backoff_multiplier: 2.0,
            jitter_factor: 0.1,
            operation_timeout: Duration::from_secs(60),
        }
    }

    /// Create a config optimized for GraphQL queries
    pub fn graphql() -> Self {
        Self {
            max_attempts: 4,
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(5),
            backoff_multiplier: 1.8,
            jitter_factor: 0.15,
            operation_timeout: Duration::from_secs(15),
        }
    }
}

/// Circuit breaker state
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CircuitState {
    /// Circuit is closed - normal operation
    Closed,
    /// Circuit is open - blocking requests due to failures
    Open,
    /// Circuit is half-open - allowing limited requests to test recovery
    HalfOpen,
}

/// Circuit breaker for preventing cascading failures
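///
/// The breaker starts `Closed`, trips to `Open` once `failure_threshold`
/// consecutive failures are recorded, and probes recovery through `HalfOpen`
/// after `recovery_timeout` has elapsed.
///
/// # Example
///
/// A minimal sketch, run inside an async context (the error message and
/// thresholds are illustrative):
///
/// ```ignore
/// use std::time::Duration;
///
/// let breaker = CircuitBreaker::new(2, Duration::from_secs(30));
///
/// // Run a fallible async operation through the breaker.
/// let result: Result<String, OstiumError> = breaker
///     .call(|| async { Err(OstiumError::Network("node unreachable".to_string())) })
///     .await;
/// assert!(result.is_err());
/// assert_eq!(breaker.state(), CircuitState::Closed); // one failure, threshold is 2
/// ```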
#[derive(Debug)]
pub struct CircuitBreaker {
    // Packed: state (8 bits) | failure_count (24 bits) | last_failure_time in ms since `created_at` (32 bits)
    state: Arc<AtomicU64>,
    // Reference instant for the packed millisecond timestamps (wraps after ~49 days)
    created_at: Instant,
    failure_threshold: u32,
    recovery_timeout: Duration,
    success_threshold: u32, // Currently unused
}

impl CircuitBreaker {
    /// Create a new circuit breaker with the specified failure threshold and recovery timeout
    pub fn new(failure_threshold: u32, recovery_timeout: Duration) -> Self {
        Self {
            state: Arc::new(AtomicU64::new(0)), // Initial state: Closed (0)
            created_at: Instant::now(),
            failure_threshold,
            recovery_timeout,
            success_threshold: 3,
        }
    }

    /// Execute an operation through the circuit breaker
    pub fn call<F, Fut, T>(
        &self,
        operation: F,
    ) -> impl std::future::Future<Output = Result<T, OstiumError>>
    where
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<T, OstiumError>>,
    {
        let state = self.state.clone();
        let created_at = self.created_at;
        let failure_threshold = self.failure_threshold;
        let recovery_timeout = self.recovery_timeout;
        let _success_threshold = self.success_threshold;

        async move {
            // Milliseconds since this breaker was created; used as the packed timestamp
            let now_ms = created_at.elapsed().as_millis() as u32;
            let current_state = Self::decode_state(state.load(Ordering::Acquire));

            match current_state.0 {
                CircuitState::Open => {
                    let time_since_failure =
                        Duration::from_millis(u64::from(now_ms.saturating_sub(current_state.2)));

                    if time_since_failure >= recovery_timeout {
                        // Transition to half-open
                        let new_packed = Self::encode_state(CircuitState::HalfOpen, 0, 0);
                        state.store(new_packed, Ordering::Release);
                        debug!("Circuit breaker transitioning to half-open");
                    } else {
                        return Err(OstiumError::Network(
                            "Circuit breaker is open - too many recent failures".to_string(),
                        ));
                    }
                }
                CircuitState::HalfOpen => {
                    // Allow limited requests through
                }
                CircuitState::Closed => {
                    // Normal operation
                }
            }

            match operation().await {
                Ok(result) => {
                    // Success - reset failure count or close circuit
                    match current_state.0 {
                        CircuitState::HalfOpen => {
                            let new_packed = Self::encode_state(CircuitState::Closed, 0, 0);
                            state.store(new_packed, Ordering::Release);
                            debug!("Circuit breaker closed after successful recovery");
                        }
                        _ => {
                            let new_packed = Self::encode_state(CircuitState::Closed, 0, 0);
                            state.store(new_packed, Ordering::Release);
                        }
                    }
                    Ok(result)
                }
                Err(error) => {
                    // Failure - increment count and possibly open circuit
                    let new_failure_count = current_state.1 + 1;
                    let current_time = now_ms;

                    if new_failure_count >= failure_threshold {
                        let new_packed =
                            Self::encode_state(CircuitState::Open, new_failure_count, current_time);
                        state.store(new_packed, Ordering::Release);
                        warn!(
                            "Circuit breaker opened after {} failures",
                            new_failure_count
                        );
                    } else {
                        let new_packed =
                            Self::encode_state(current_state.0, new_failure_count, current_time);
                        state.store(new_packed, Ordering::Release);
                    }

                    Err(error)
                }
            }
        }
    }

    fn encode_state(state: CircuitState, failure_count: u32, last_failure_time: u32) -> u64 {
        let state_bits = match state {
            CircuitState::Closed => 0u64,
            CircuitState::Open => 1u64,
            CircuitState::HalfOpen => 2u64,
        };

        (state_bits << 56) | ((failure_count as u64 & 0xFFFFFF) << 32) | (last_failure_time as u64)
    }

    fn decode_state(packed: u64) -> (CircuitState, u32, u32) {
        let state = match (packed >> 56) & 0xFF {
            0 => CircuitState::Closed,
            1 => CircuitState::Open,
            2 => CircuitState::HalfOpen,
            _ => CircuitState::Closed,
        };
        let failure_count = ((packed >> 32) & 0xFFFFFF) as u32;
        let last_failure_time = (packed & 0xFFFFFFFF) as u32;

        (state, failure_count, last_failure_time)
    }

    /// Get the current state of the circuit breaker
    pub fn state(&self) -> CircuitState {
        Self::decode_state(self.state.load(Ordering::Acquire)).0
    }
}

/// Retry executor with exponential backoff and jitter
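///
/// # Example
///
/// A minimal sketch of retrying a fallible async operation, run inside an
/// async context (the closure body is illustrative; any
/// `Fn() -> Future<Output = Result<T, OstiumError>>` works):
///
/// ```ignore
/// let executor = RetryExecutor::new(RetryConfig::network())
///     // Optionally guard the operation with a circuit breaker.
///     .with_circuit_breaker(5, std::time::Duration::from_secs(60));
///
/// let block_number: Result<u64, OstiumError> = executor
///     .execute(|| async {
///         // Hypothetical fallible call; replace with a real RPC/GraphQL request.
///         Ok(42u64)
///     })
///     .await;
/// ```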
pub struct RetryExecutor {
    config: RetryConfig,
    circuit_breaker: Option<CircuitBreaker>,
}

impl RetryExecutor {
    /// Create a new retry executor with the specified configuration
    pub fn new(config: RetryConfig) -> Self {
        Self {
            config,
            circuit_breaker: None,
        }
    }

    /// Add a circuit breaker to the retry executor
    pub fn with_circuit_breaker(
        mut self,
        failure_threshold: u32,
        recovery_timeout: Duration,
    ) -> Self {
        self.circuit_breaker = Some(CircuitBreaker::new(failure_threshold, recovery_timeout));
        self
    }

    /// Execute an operation with retry logic
    pub async fn execute<F, Fut, T>(&self, operation: F) -> Result<T, OstiumError>
    where
        F: Fn() -> Fut,
        Fut: std::future::Future<Output = Result<T, OstiumError>>,
    {
        let mut attempt = 0;
        let mut delay = self.config.initial_delay;

        loop {
            attempt += 1;

            debug!(
                "Executing operation attempt {}/{}",
                attempt, self.config.max_attempts
            );

            // Use circuit breaker if configured, and bound each attempt by the
            // configured operation timeout
            let raw = if let Some(ref circuit_breaker) = self.circuit_breaker {
                tokio::time::timeout(
                    self.config.operation_timeout,
                    circuit_breaker.call(&operation),
                )
                .await
            } else {
                tokio::time::timeout(self.config.operation_timeout, operation()).await
            };
            let result = match raw {
                Ok(inner) => inner,
                Err(_elapsed) => Err(OstiumError::Network(format!(
                    "Operation timed out after {:?}",
                    self.config.operation_timeout
                ))),
            };

            match result {
                Ok(value) => {
                    if attempt > 1 {
                        debug!("Operation succeeded after {} attempts", attempt);
                    }
                    return Ok(value);
                }
                Err(error) => {
                    if !self.should_retry(&error) || attempt >= self.config.max_attempts {
                        warn!("Operation failed after {} attempts: {}", attempt, error);
                        return Err(error);
                    }

                    debug!(
                        "Operation failed on attempt {}, retrying after {:?}: {}",
                        attempt, delay, error
                    );

                    // Sleep with jitter
                    let jittered_delay = self.add_jitter(delay);
                    sleep(jittered_delay).await;

                    // Calculate next delay with exponential backoff
                    delay = std::cmp::min(
                        Duration::from_millis(
                            (delay.as_millis() as f64 * self.config.backoff_multiplier) as u64,
                        ),
                        self.config.max_delay,
                    );
                }
            }
        }
    }

    /// Determine if an error should trigger a retry
    fn should_retry(&self, error: &OstiumError) -> bool {
        match error {
            // Always retry network errors
            OstiumError::Network(_) => true,

            // Retry HTTP errors that might be transient
            OstiumError::Http(e) => e.is_timeout() || e.is_connect() || e.is_request(),

            // Retry specific contract errors
            OstiumError::Contract(msg) => {
                msg.contains("timeout")
                    || msg.contains("connection")
                    || msg.contains("temporarily unavailable")
                    || msg.contains("rate limit")
            }

            // Retry GraphQL errors that might be transient
            OstiumError::GraphQL(msg) => {
                msg.contains("timeout")
                    || msg.contains("server error")
                    || msg.contains("503")
                    || msg.contains("502")
                    || msg.contains("504")
            }

            // Retry provider errors that might be transient
            OstiumError::Provider(msg) => {
                msg.contains("timeout") || msg.contains("connection") || msg.contains("rate limit")
            }

            // Don't retry these errors as they're likely permanent
            OstiumError::Validation(_) => false,
            OstiumError::Wallet(_) => false,
            OstiumError::Config(_) => false,
            OstiumError::Json(_) => false,
            OstiumError::Decimal(_) => false,
            OstiumError::Other(_) => false,
        }
    }

    /// Add jitter to delay to prevent thundering herd
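    ///
    /// Worked example of the bound: with `jitter_factor = 0.1` and a 100 ms
    /// base delay, `jitter_range` is 10 ms, so the returned delay falls in
    /// `[100, 110]` ms (inclusive).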
    fn add_jitter(&self, delay: Duration) -> Duration {
        if self.config.jitter_factor <= 0.0 {
            return delay;
        }

        let jitter_range = (delay.as_millis() as f64 * self.config.jitter_factor) as u64;
        let jitter = fastrand::u64(0..=jitter_range);

        Duration::from_millis(delay.as_millis() as u64 + jitter)
    }
}

// Convenience macros for different operation types

/// Execute an operation with network-optimized retry settings
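///
/// A minimal usage sketch, inside an async context (the expression is
/// illustrative; it must evaluate to a `Result<_, OstiumError>` and may
/// contain `.await`):
///
/// ```ignore
/// // Expands to RetryExecutor::new(RetryConfig::network()).execute(...).await
/// let body: Result<String, OstiumError> = retry_network!(
///     // Hypothetical fallible call.
///     Ok::<_, OstiumError>("response".to_string())
/// );
/// ```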
#[macro_export]
macro_rules! retry_network {
    ($operation:expr) => {
        $crate::retry::RetryExecutor::new($crate::retry::RetryConfig::network())
            .execute(|| async { $operation })
            .await
    };
}

/// Execute an operation with contract-optimized retry settings and circuit breaker
#[macro_export]
macro_rules! retry_contract {
    ($operation:expr) => {
        $crate::retry::RetryExecutor::new($crate::retry::RetryConfig::contract())
            .with_circuit_breaker(5, std::time::Duration::from_secs(60))
            .execute(|| async { $operation })
            .await
    };
}

/// Execute an operation with GraphQL-optimized retry settings
#[macro_export]
macro_rules! retry_graphql {
    ($operation:expr) => {
        $crate::retry::RetryExecutor::new($crate::retry::RetryConfig::graphql())
            .execute(|| async { $operation })
            .await
    };
}

// Re-export macros for convenience
pub use retry_contract;
pub use retry_graphql;
pub use retry_network;

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    #[tokio::test]
    async fn test_retry_success_after_failures() {
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let config = RetryConfig {
            max_attempts: 3,
            initial_delay: Duration::from_millis(10),
            ..Default::default()
        };

        let executor = RetryExecutor::new(config);

        let result = executor
            .execute(|| {
                let counter = counter_clone.clone();
                async move {
                    let count = counter.fetch_add(1, Ordering::SeqCst);
                    if count < 2 {
                        Err(OstiumError::Network("Temporary failure".to_string()))
                    } else {
                        Ok("Success".to_string())
                    }
                }
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "Success");
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_exhaustion() {
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let config = RetryConfig {
            max_attempts: 2,
            initial_delay: Duration::from_millis(10),
            ..Default::default()
        };

        let executor = RetryExecutor::new(config);

        let result: Result<String, OstiumError> = executor
            .execute(|| {
                let counter = counter_clone.clone();
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err(OstiumError::Network("Permanent failure".to_string()))
                }
            })
            .await;

        assert!(result.is_err());
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_circuit_breaker() {
        let circuit_breaker = CircuitBreaker::new(2, Duration::from_millis(100));

        // First failure
        let result1: Result<String, OstiumError> = circuit_breaker
            .call(|| async { Err(OstiumError::Network("Failure".to_string())) })
            .await;
        assert!(result1.is_err());
        assert_eq!(circuit_breaker.state(), CircuitState::Closed);

        // Second failure - should open circuit
        let result2: Result<String, OstiumError> = circuit_breaker
            .call(|| async { Err(OstiumError::Network("Failure".to_string())) })
            .await;
        assert!(result2.is_err());
        assert_eq!(circuit_breaker.state(), CircuitState::Open);

        // Third call should be rejected immediately
        let result3 = circuit_breaker
            .call(|| async { Ok("Should not execute".to_string()) })
            .await;
        assert!(result3.is_err());
        assert!(result3
            .unwrap_err()
            .to_string()
            .contains("Circuit breaker is open"));
    }
}