mx-core 0.1.0

Core utilities for MultiversX Rust services.
Documentation
//! Circuit breaker trait and state types for fault tolerance.

use std::time::Duration;

/// Circuit breaker states.
///
/// The explicit `u8` discriminants are the values `AtomicCircuitBreaker`
/// keeps in its atomic `state_val`; the `From<u8>` impl decodes them back.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum CircuitState {
    /// Circuit is closed, requests flow normally.
    Closed = 0,
    /// Circuit is open, requests are rejected immediately.
    Open = 1,
    /// Circuit is probing for recovery: requests are let through again and
    /// their recorded outcomes decide whether it closes or re-opens.
    HalfOpen = 2,
}

impl From<u8> for CircuitState {
    /// Decodes a raw discriminant back into a [`CircuitState`].
    ///
    /// The discriminants of `Open` and `HalfOpen` map to those variants; every
    /// other byte, including values that match no variant, decodes as `Closed`.
    fn from(value: u8) -> Self {
        if value == Self::Open as u8 {
            Self::Open
        } else if value == Self::HalfOpen as u8 {
            Self::HalfOpen
        } else {
            // Covers both the `Closed` discriminant and unknown bytes.
            Self::Closed
        }
    }
}

/// Configuration for the circuit breaker.
///
/// Build one with [`CircuitBreakerConfig::new`] or [`Default::default`], then
/// optionally adjust it with [`CircuitBreakerConfig::with_success_threshold`].
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Number of consecutive failures before opening the circuit.
    ///
    /// `AtomicCircuitBreaker` opens once its failure counter is greater than
    /// or equal to this value; a success while closed resets that counter.
    pub failure_threshold: u32,
    /// Duration to wait before transitioning from Open to HalfOpen.
    ///
    /// `AtomicCircuitBreaker` measures it from the most recently recorded
    /// failure or manual trip.
    pub recovery_timeout: Duration,
    /// Number of successes required in half-open state to close (default: 1).
    pub success_threshold: u32,
}

impl Default for CircuitBreakerConfig {
    /// Returns the stock settings: open after 5 consecutive failures, wait
    /// 30 seconds before probing, and close again after a single success.
    fn default() -> Self {
        const FAILURES_BEFORE_OPEN: u32 = 5;
        const RECOVERY_SECS: u64 = 30;
        const SUCCESSES_BEFORE_CLOSE: u32 = 1;

        Self {
            success_threshold: SUCCESSES_BEFORE_CLOSE,
            recovery_timeout: Duration::from_secs(RECOVERY_SECS),
            failure_threshold: FAILURES_BEFORE_OPEN,
        }
    }
}

impl CircuitBreakerConfig {
    /// Builds a config from a failure threshold and a recovery timeout.
    ///
    /// The half-open success threshold starts at 1; change it with
    /// [`CircuitBreakerConfig::with_success_threshold`].
    pub fn new(failure_threshold: u32, recovery_timeout: Duration) -> Self {
        let success_threshold = 1;
        Self {
            recovery_timeout,
            failure_threshold,
            success_threshold,
        }
    }

    /// Returns the config with the half-open success threshold replaced by
    /// `threshold`; the other fields are carried over unchanged.
    pub fn with_success_threshold(self, threshold: u32) -> Self {
        Self {
            success_threshold: threshold,
            ..self
        }
    }
}

/// Trait for circuit breaker pattern implementation.
///
/// A circuit breaker guards callers against a failing dependency:
/// 1. It tracks failures while requests flow normally.
/// 2. It opens, rejecting requests, once failures reach a threshold.
/// 3. After a waiting period it lets requests through again as a probe.
/// 4. It closes again when the probe confirms recovery.
pub trait CircuitBreaker: Send + Sync {
    /// Reports the state the circuit is in right now.
    fn state(&self) -> CircuitState;

    /// `true` exactly when [`state`](Self::state) reports `Open`, i.e. when
    /// requests are being rejected.
    fn is_open(&self) -> bool {
        matches!(self.state(), CircuitState::Open)
    }

    /// `true` when [`state`](Self::state) reports anything other than `Open`,
    /// so both `Closed` and `HalfOpen` let requests through.
    fn allows_request(&self) -> bool {
        match self.state() {
            CircuitState::Open => false,
            CircuitState::Closed | CircuitState::HalfOpen => true,
        }
    }

    /// Reports that an operation succeeded.
    fn record_success(&self);

    /// Reports that an operation failed.
    fn record_failure(&self);

    /// Forces the circuit open.
    fn trip(&self);

    /// Forces the circuit closed.
    fn reset(&self);

    /// Time remaining before the circuit moves to half-open.
    ///
    /// Returns `None` if the circuit is not open.
    fn time_until_half_open(&self) -> Option<Duration>;

    /// Current value of the failure counter.
    fn failure_count(&self) -> u32;
}

/// Thread-safe circuit breaker using atomic operations.
///
/// State and counters live in atomics; the timestamp that drives the
/// Open -> HalfOpen transition is kept behind a `parking_lot::Mutex`.
pub struct AtomicCircuitBreaker {
    /// Current [`CircuitState`], stored as its `u8` discriminant.
    // NOTE(review): this is the only `pub` field. Writing to it directly skips
    // the counter and timestamp bookkeeping done by the `CircuitBreaker`
    // methods - confirm whether external access is actually required.
    pub state_val: std::sync::atomic::AtomicU8,
    /// Failures counted while the circuit is closed; zeroed by a success in
    /// the closed state, when the circuit closes again, and by `reset`.
    failure_count_val: std::sync::atomic::AtomicU32,
    /// Successes counted while half-open, compared against
    /// `config.success_threshold`.
    success_count_val: std::sync::atomic::AtomicU32,
    /// When the most recent failure or manual trip was recorded; `None` after
    /// the circuit closes or is reset.
    last_failure: parking_lot::Mutex<Option<std::time::Instant>>,
    /// Thresholds and recovery timeout supplied at construction.
    config: CircuitBreakerConfig,
}

impl AtomicCircuitBreaker {
    /// Builds a breaker governed by `config`.
    ///
    /// It starts in the `Closed` state with both counters at zero and no
    /// recorded failure.
    pub fn new(config: CircuitBreakerConfig) -> Self {
        use std::sync::atomic::{AtomicU32, AtomicU8};

        let initial_state = CircuitState::Closed as u8;
        Self {
            config,
            last_failure: parking_lot::Mutex::new(None),
            success_count_val: AtomicU32::new(0),
            failure_count_val: AtomicU32::new(0),
            state_val: AtomicU8::new(initial_state),
        }
    }

    /// Builds a breaker from [`CircuitBreakerConfig::default`].
    pub fn with_defaults() -> Self {
        let config = CircuitBreakerConfig::default();
        Self::new(config)
    }
}

// Locking discipline for this impl: every transition that changes both the
// state and the failure timestamp runs while holding the `last_failure` lock.
// That keeps the pair consistent: the circuit can no longer end up `Open`
// with a stale or missing timestamp because two methods interleaved.
// Reads of the state stay lock-free unless the circuit is `Open`.
// (Direct writes to the public `state_val` field are outside this scheme.)
impl CircuitBreaker for AtomicCircuitBreaker {
    /// Returns the current state.
    ///
    /// If the circuit is `Open` and `recovery_timeout` has elapsed since the
    /// recorded timestamp, this call performs the `Open -> HalfOpen`
    /// transition, zeroes the half-open success counter and returns
    /// `HalfOpen`. If another thread changes the state while this call is
    /// evaluating the timeout, the originally observed `Open` is returned.
    fn state(&self) -> CircuitState {
        use std::sync::atomic::Ordering;

        let observed = CircuitState::from(self.state_val.load(Ordering::Acquire));
        if observed != CircuitState::Open {
            // Nothing time-based to evaluate, so no lock is needed.
            return observed;
        }

        // Hold the timestamp lock across the compare-exchange so the
        // timestamp cannot change between the check and the transition.
        let last_failure = self.last_failure.lock();
        let recovery_due = match *last_failure {
            Some(at) => at.elapsed() >= self.config.recovery_timeout,
            None => false,
        };

        if recovery_due
            && self
                .state_val
                .compare_exchange(
                    CircuitState::Open as u8,
                    CircuitState::HalfOpen as u8,
                    Ordering::AcqRel,
                    Ordering::Acquire,
                )
                .is_ok()
        {
            self.success_count_val.store(0, Ordering::Release);
            return CircuitState::HalfOpen;
        }

        observed
    }

    /// Records a successful operation.
    ///
    /// * `Closed`: zeroes the failure counter.
    /// * `HalfOpen`: counts the success; once `success_threshold` is reached
    ///   the circuit closes, both counters are zeroed and the timestamp is
    ///   cleared.
    /// * `Open`: ignored.
    ///
    /// This reads the stored state directly, so an `Open` circuit whose
    /// timeout has expired is still treated as `Open` until `state()` has
    /// performed the transition.
    fn record_success(&self) {
        use std::sync::atomic::Ordering;

        match CircuitState::from(self.state_val.load(Ordering::Acquire)) {
            CircuitState::Closed => {
                self.failure_count_val.store(0, Ordering::Release);
            }
            CircuitState::HalfOpen => {
                let successes = self.success_count_val.fetch_add(1, Ordering::AcqRel) + 1;
                if successes < self.config.success_threshold {
                    return;
                }

                // Close only if the circuit is still half-open. A failure or
                // trip that re-opened it in the meantime must win; otherwise
                // its timestamp would be cleared and the circuit could stay
                // `Open` with nothing to time the recovery from.
                let mut last_failure = self.last_failure.lock();
                if self
                    .state_val
                    .compare_exchange(
                        CircuitState::HalfOpen as u8,
                        CircuitState::Closed as u8,
                        Ordering::AcqRel,
                        Ordering::Acquire,
                    )
                    .is_ok()
                {
                    self.failure_count_val.store(0, Ordering::Release);
                    self.success_count_val.store(0, Ordering::Release);
                    *last_failure = None;
                }
            }
            CircuitState::Open => {}
        }
    }

    /// Records a failed operation.
    ///
    /// * `Closed`: increments the failure counter and opens the circuit once
    ///   it reaches `failure_threshold`.
    /// * `HalfOpen`: re-opens the circuit and zeroes the success counter.
    /// * `Open`: no state change.
    ///
    /// The timestamp is refreshed on every call, including while the circuit
    /// is already `Open`, so failures reported during the open period push
    /// the half-open probe further out.
    fn record_failure(&self) {
        use std::sync::atomic::Ordering;

        // The guard is held until the end of the function so the timestamp
        // and any resulting state change are published together.
        let mut last_failure = self.last_failure.lock();
        *last_failure = Some(std::time::Instant::now());

        match CircuitState::from(self.state_val.load(Ordering::Acquire)) {
            CircuitState::Closed => {
                let count = self.failure_count_val.fetch_add(1, Ordering::AcqRel) + 1;
                if count >= self.config.failure_threshold {
                    self.state_val
                        .store(CircuitState::Open as u8, Ordering::Release);
                }
            }
            CircuitState::HalfOpen => {
                self.state_val
                    .store(CircuitState::Open as u8, Ordering::Release);
                self.success_count_val.store(0, Ordering::Release);
            }
            CircuitState::Open => {}
        }
    }

    /// Manually opens the circuit and restarts the recovery timer.
    fn trip(&self) {
        use std::sync::atomic::Ordering;

        // Refresh the timestamp before publishing `Open`. Otherwise a
        // concurrent `state()` could pair the new `Open` with an old, already
        // expired timestamp and flip straight to `HalfOpen`.
        let mut last_failure = self.last_failure.lock();
        *last_failure = Some(std::time::Instant::now());
        self.state_val
            .store(CircuitState::Open as u8, Ordering::Release);
    }

    /// Manually closes the circuit, zeroing both counters and clearing the
    /// timestamp.
    fn reset(&self) {
        use std::sync::atomic::Ordering;

        let mut last_failure = self.last_failure.lock();
        self.state_val
            .store(CircuitState::Closed as u8, Ordering::Release);
        self.failure_count_val.store(0, Ordering::Release);
        self.success_count_val.store(0, Ordering::Release);
        *last_failure = None;
    }

    /// Returns the time left until the circuit becomes eligible for
    /// half-open.
    ///
    /// Returns `None` when the stored state is not `Open` or no timestamp is
    /// recorded, and `Some(Duration::ZERO)` once the timeout has elapsed.
    /// This method never performs the transition itself.
    fn time_until_half_open(&self) -> Option<Duration> {
        use std::sync::atomic::Ordering;

        if CircuitState::from(self.state_val.load(Ordering::Acquire)) != CircuitState::Open {
            return None;
        }

        let last = (*self.last_failure.lock())?;
        // `saturating_sub` clamps to zero when the timeout has already passed.
        Some(self.config.recovery_timeout.saturating_sub(last.elapsed()))
    }

    /// Returns the current value of the failure counter.
    fn failure_count(&self) -> u32 {
        self.failure_count_val
            .load(std::sync::atomic::Ordering::Acquire)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Recovery timeout short enough for a test to outlast with a brief sleep.
    const RECOVERY: Duration = Duration::from_millis(10);
    /// Sleep that exceeds `RECOVERY`.
    const PAST_RECOVERY: Duration = Duration::from_millis(15);

    /// Opens a breaker with a single failure, then waits out the recovery
    /// timeout so the next `state()` call can report `HalfOpen`.
    ///
    /// `config` is expected to use a failure threshold of 1 and `RECOVERY`.
    fn breaker_awaiting_probe(config: CircuitBreakerConfig) -> AtomicCircuitBreaker {
        let breaker = AtomicCircuitBreaker::new(config);
        breaker.record_failure();
        thread::sleep(PAST_RECOVERY);
        breaker
    }

    #[test]
    fn test_starts_closed() {
        let breaker = AtomicCircuitBreaker::with_defaults();

        assert_eq!(breaker.state(), CircuitState::Closed);
        assert!(breaker.allows_request());
    }

    #[test]
    fn test_opens_after_threshold() {
        let breaker =
            AtomicCircuitBreaker::new(CircuitBreakerConfig::new(3, Duration::from_secs(30)));

        // Two failures stay below the threshold of three.
        for _ in 0..2 {
            breaker.record_failure();
        }
        assert_eq!(breaker.state(), CircuitState::Closed);

        // The third failure reaches it.
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitState::Open);
        assert!(!breaker.allows_request());
    }

    #[test]
    fn test_transitions_to_half_open() {
        let breaker = AtomicCircuitBreaker::new(CircuitBreakerConfig::new(1, RECOVERY));

        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitState::Open);

        thread::sleep(PAST_RECOVERY);
        assert_eq!(breaker.state(), CircuitState::HalfOpen);
    }

    #[test]
    fn test_closes_on_success_in_half_open() {
        let breaker = breaker_awaiting_probe(CircuitBreakerConfig::new(1, RECOVERY));
        assert_eq!(breaker.state(), CircuitState::HalfOpen);

        breaker.record_success();
        assert_eq!(breaker.state(), CircuitState::Closed);
    }

    #[test]
    fn test_success_threshold() {
        let config = CircuitBreakerConfig::new(1, RECOVERY).with_success_threshold(3);
        let breaker = breaker_awaiting_probe(config);
        assert_eq!(breaker.state(), CircuitState::HalfOpen);

        // The first two successes are not enough to close the circuit.
        for _ in 0..2 {
            breaker.record_success();
            assert_eq!(breaker.state(), CircuitState::HalfOpen);
        }

        // The third one is.
        breaker.record_success();
        assert_eq!(breaker.state(), CircuitState::Closed);
    }

    #[test]
    fn test_trip_and_reset() {
        let breaker = AtomicCircuitBreaker::with_defaults();

        breaker.trip();
        assert_eq!(breaker.state(), CircuitState::Open);

        breaker.reset();
        assert_eq!(breaker.state(), CircuitState::Closed);
    }

    #[test]
    fn test_thread_safety() {
        use std::sync::Arc;

        const WORKERS: usize = 10;
        const FAILURES_PER_WORKER: usize = 10;

        // 10 workers x 10 failures reaches the threshold of 100.
        let shared = Arc::new(AtomicCircuitBreaker::new(CircuitBreakerConfig::new(
            100,
            Duration::from_secs(30),
        )));

        let mut workers = Vec::with_capacity(WORKERS);
        for _ in 0..WORKERS {
            let breaker = Arc::clone(&shared);
            workers.push(thread::spawn(move || {
                (0..FAILURES_PER_WORKER).for_each(|_| breaker.record_failure());
            }));
        }
        for worker in workers {
            worker.join().unwrap();
        }

        assert_eq!(shared.state(), CircuitState::Open);
    }
}