// rs-zero 0.2.6
//
// Rust-first microservice framework inspired by go-zero engineering practices.
// Documentation
use std::{future::Future, sync::Arc, time::Duration};

use thiserror::Error;
use tokio::sync::Mutex;

use crate::resil::{WindowConfig, WindowSnapshot, breaker_state::CircuitBreakerState};

/// Lifecycle position of a circuit breaker.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BreakerState {
    /// Normal operation: calls pass through and failures are tallied.
    Closed,
    /// Tripped: calls are refused until the reset timeout has elapsed.
    Open,
    /// Probing: a bounded number of trial calls decide whether to close again.
    HalfOpen,
}

/// Basic thresholds that drive a circuit breaker.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct BreakerConfig {
    /// How many consecutive failures trip the breaker open.
    pub failure_threshold: u32,
    /// How long an open breaker waits before permitting a half-open trial call.
    pub reset_timeout: Duration,
}

impl Default for BreakerConfig {
    /// Five consecutive failures open the breaker, which then waits thirty
    /// seconds before allowing a half-open trial call.
    fn default() -> Self {
        let reset_timeout = Duration::from_secs(30);
        Self {
            reset_timeout,
            failure_threshold: 5,
        }
    }
}

/// Rolling-window policy knobs for production-oriented breakers.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct BreakerPolicyConfig {
    /// Shape of the rolling window that aggregate decisions are based on.
    pub window: WindowConfig,
    /// Sample count the window must reach before the aggregate failure ratio
    /// may open the breaker or drop calls.
    pub min_request_count: u64,
    /// Failure ratio, in percent, that opens the breaker once enough samples exist.
    pub failure_ratio_percent: u8,
    /// Percentage of calls deterministically dropped while the window is unhealthy.
    pub drop_ratio_percent: u8,
    /// Upper bound on concurrent trial calls while half-open.
    pub half_open_max_calls: u32,
    /// Shortest gap between forced trial calls while the breaker is open.
    pub force_pass_interval: Duration,
    /// Turns on Google SRE style client-side throttling in the closed state.
    pub sre_rejection_enabled: bool,
    /// SRE throttling multiplier expressed in millis; `1500` means `k = 1.5`.
    pub sre_k_millis: u32,
    /// Total sample count required before SRE throttling may reject calls.
    pub sre_protection: u64,
}

impl Default for BreakerPolicyConfig {
    /// Ratio-based policy with SRE throttling switched off.
    ///
    /// Uses the default rolling window, requires 20 samples, opens at a 50%
    /// failure ratio, drops 20% of calls while unhealthy, allows one half-open
    /// trial, and forces a trial at most every five seconds.
    fn default() -> Self {
        let window = WindowConfig::default();
        let force_pass_interval = Duration::from_secs(5);
        Self {
            window,
            min_request_count: 20,
            failure_ratio_percent: 50,
            drop_ratio_percent: 20,
            half_open_max_calls: 1,
            force_pass_interval,
            sre_rejection_enabled: false,
            sre_k_millis: 1500,
            sre_protection: 5,
        }
    }
}

impl BreakerPolicyConfig {
    /// Builds a policy that relies on Google SRE style adaptive rejection.
    ///
    /// Relative to the default policy, this enables SRE rejection, disables the
    /// deterministic drop ratio, and raises the opening failure ratio to 100%.
    /// Every other field keeps its default value.
    pub fn google_sre() -> Self {
        let base = Self::default();
        Self {
            failure_ratio_percent: 100,
            drop_ratio_percent: 0,
            sre_rejection_enabled: true,
            ..base
        }
    }
}

/// Point-in-time view of a breaker's state and its rolling statistics.
#[derive(Clone, Debug, PartialEq)]
pub struct BreakerSnapshot {
    /// State the breaker was in when the snapshot was taken.
    pub state: BreakerState,
    /// Number of backend failures observed back to back.
    pub consecutive_failures: u32,
    /// Half-open trial calls currently in progress.
    pub half_open_in_flight: u32,
    /// Statistics collected by the rolling window.
    pub window: WindowSnapshot,
}

/// Error returned when a breaker rejects a call.
#[derive(Debug, Error, Clone, Copy, PartialEq, Eq)]
pub enum BreakerError {
    /// The breaker is open.
    #[error("circuit breaker is open")]
    Open,
    /// The breaker probabilistically dropped the call while unhealthy.
    #[error("circuit breaker dropped request")]
    Dropped,
}

/// Error returned by protected breaker calls.
#[derive(Debug, Error, PartialEq, Eq)]
pub enum BreakerCallError<E> {
    /// The breaker rejected the call before the operation ran.
    #[error(transparent)]
    Rejected(#[from] BreakerError),
    /// The protected operation returned an error.
    #[error("protected call failed: {0}")]
    Inner(E),
}

/// Lightweight, single-owner circuit breaker for local protection and tests.
///
/// Every method takes `&mut self`; for a handle that can be cloned and shared
/// across tasks, use [`SharedCircuitBreaker`].
#[derive(Debug)]
pub struct CircuitBreaker {
    state: CircuitBreakerState,
}

impl CircuitBreaker {
    /// Builds a breaker that opens after `failure_threshold` failures and waits
    /// `reset_timeout` before allowing a half-open trial call.
    ///
    /// The rolling-window policy is left at [`BreakerPolicyConfig::default`].
    pub fn new(failure_threshold: u32, reset_timeout: Duration) -> Self {
        let config = BreakerConfig {
            failure_threshold,
            reset_timeout,
        };
        let policy = BreakerPolicyConfig::default();
        Self {
            state: CircuitBreakerState::new(config, policy),
        }
    }

    /// Reports the current state, first applying any pending reset-timeout
    /// transition.
    pub fn state(&mut self) -> BreakerState {
        let inner = &mut self.state;
        inner.state()
    }

    /// Asks whether the next call may proceed; `false` means it was rejected.
    pub fn allow(&mut self) -> bool {
        matches!(self.state.allow(), Ok(_))
    }

    /// Feeds a successful call into the breaker.
    pub fn record_success(&mut self) {
        let inner = &mut self.state;
        inner.record_success();
    }

    /// Feeds a failed call into the breaker.
    pub fn record_failure(&mut self) {
        let inner = &mut self.state;
        inner.record_failure();
    }
}

/// Cloneable circuit-breaker handle for async services.
///
/// Clones share one breaker state behind an `Arc<tokio::sync::Mutex<_>>`, so
/// every clone observes and updates the same state.
#[derive(Clone, Debug)]
pub struct SharedCircuitBreaker {
    state: Arc<Mutex<CircuitBreakerState>>,
}

impl SharedCircuitBreaker {
    /// Creates a shared circuit breaker from configuration, using
    /// [`BreakerPolicyConfig::default`] for the rolling-window policy.
    pub fn new(config: BreakerConfig) -> Self {
        Self::with_policy(config, BreakerPolicyConfig::default())
    }

    /// Creates a shared circuit breaker with an explicit rolling-window policy.
    pub fn with_policy(config: BreakerConfig, policy: BreakerPolicyConfig) -> Self {
        Self {
            state: Arc::new(Mutex::new(CircuitBreakerState::new(config, policy))),
        }
    }

    /// Attempts to enter the protected section.
    ///
    /// On success the returned [`BreakerGuard`] should be completed with
    /// `record_success` or `record_failure`. A guard dropped without either is
    /// treated as a failure by the guard's `Drop` implementation.
    ///
    /// # Errors
    ///
    /// Returns the [`BreakerError`] reported by the breaker state when the call
    /// is rejected.
    pub async fn allow(&self) -> Result<BreakerGuard, BreakerError> {
        // The state lock guard is a temporary and is released at the end of this
        // statement, before the breaker guard is constructed.
        self.state.lock().await.allow()?;
        Ok(BreakerGuard {
            breaker: self.clone(),
            completed: false,
        })
    }

    /// Runs a protected async operation and records its success or failure.
    ///
    /// # Errors
    ///
    /// [`BreakerCallError::Rejected`] if the breaker refuses the call (the
    /// operation is not run), or [`BreakerCallError::Inner`] with the
    /// operation's own error.
    pub async fn do_request<F, Fut, T, E>(&self, request: F) -> Result<T, BreakerCallError<E>>
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = Result<T, E>>,
    {
        self.do_with_acceptable(request, |_| false).await
    }

    /// Runs a protected operation with a fallback used only for breaker rejection.
    ///
    /// The fallback runs only when the breaker refuses the call. Errors produced
    /// by `request` itself are recorded as failures and returned unchanged.
    pub async fn do_with_fallback<F, Fut, Fb, FbFut, T, E>(
        &self,
        request: F,
        fallback: Fb,
    ) -> Result<T, E>
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = Result<T, E>>,
        Fb: FnOnce(BreakerError) -> FbFut,
        FbFut: Future<Output = Result<T, E>>,
    {
        let guard = match self.allow().await {
            Ok(guard) => guard,
            Err(error) => return fallback(error).await,
        };

        match request().await {
            Ok(value) => {
                guard.record_success().await;
                Ok(value)
            }
            Err(error) => {
                guard.record_failure().await;
                Err(error)
            }
        }
    }

    /// Runs a protected operation and lets callers mark some errors as acceptable.
    ///
    /// The `acceptable` predicate is called at most once, and only when the
    /// operation fails. That is why it only needs to be `FnOnce`, which accepts
    /// every closure the earlier `Fn` bound did. Errors it accepts are recorded
    /// as successes rather than failures, but they are still returned to the
    /// caller.
    ///
    /// # Errors
    ///
    /// [`BreakerCallError::Rejected`] if the breaker refuses the call, or
    /// [`BreakerCallError::Inner`] with the operation's error, whether or not
    /// that error was accepted.
    pub async fn do_with_acceptable<F, Fut, T, E, A>(
        &self,
        request: F,
        acceptable: A,
    ) -> Result<T, BreakerCallError<E>>
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = Result<T, E>>,
        A: FnOnce(&E) -> bool,
    {
        let guard = self.allow().await?;
        match request().await {
            Ok(value) => {
                guard.record_success().await;
                Ok(value)
            }
            Err(error) => {
                // Evaluated outside a match guard so an `FnOnce` predicate can
                // be consumed by the call.
                if acceptable(&error) {
                    guard.record_success().await;
                } else {
                    guard.record_failure().await;
                }
                Err(BreakerCallError::Inner(error))
            }
        }
    }

    /// Returns the current state, as reported by the underlying breaker state
    /// (which applies any pending reset-timeout transition).
    pub async fn state(&self) -> BreakerState {
        self.state.lock().await.state()
    }

    /// Returns a snapshot with state and rolling statistics.
    pub async fn snapshot(&self) -> BreakerSnapshot {
        self.state.lock().await.snapshot()
    }

    /// Records a successful call in the shared state.
    async fn record_success(&self) {
        self.state.lock().await.record_success();
    }

    /// Records a failed call in the shared state.
    async fn record_failure(&self) {
        self.state.lock().await.record_failure();
    }
}

/// Guard returned by [`SharedCircuitBreaker::allow`].
///
/// Complete it with [`BreakerGuard::record_success`] or
/// [`BreakerGuard::record_failure`]. Dropping it without either is treated as
/// a failure by its `Drop` implementation. For that reason, discarding the
/// guard (e.g. `breaker.allow().await?;`) is almost always a bug, and it is
/// marked `#[must_use]`.
#[must_use = "dropping a BreakerGuard without recording an outcome counts the call as a failure"]
#[derive(Debug)]
pub struct BreakerGuard {
    /// Breaker whose shared state receives the outcome.
    breaker: SharedCircuitBreaker,
    /// Set once an explicit outcome has been recorded; suppresses the failure
    /// recorded on drop.
    completed: bool,
}

impl BreakerGuard {
    /// Reports the protected operation as successful, consuming the guard.
    pub async fn record_success(self) {
        self.finish(true).await;
    }

    /// Reports the protected operation as failed, consuming the guard.
    pub async fn record_failure(self) {
        self.finish(false).await;
    }

    /// Forwards the outcome to the shared breaker, then marks the guard as
    /// completed so its `Drop` implementation does not also record a failure.
    async fn finish(mut self, succeeded: bool) {
        if succeeded {
            self.breaker.record_success().await;
        } else {
            self.breaker.record_failure().await;
        }
        self.completed = true;
    }
}

/// Records a failure for a guard dropped without an explicit outcome — for
/// example when the protected future is cancelled or the operation panics —
/// so an admitted call is not silently forgotten.
impl Drop for BreakerGuard {
    fn drop(&mut self) {
        if self.completed {
            return;
        }

        // Fast path: record synchronously when the state is uncontended. This
        // makes the update immediate and also works outside a Tokio runtime,
        // where the spawn below is impossible and the outcome used to be lost.
        if let Ok(mut state) = self.breaker.state.try_lock() {
            state.record_failure();
            return;
        }

        // The lock is held elsewhere and `Drop` cannot `.await`, so defer the
        // update to the current runtime when there is one. Without a runtime
        // the outcome is dropped (best effort).
        if let Ok(handle) = tokio::runtime::Handle::try_current() {
            let breaker = self.breaker.clone();
            handle.spawn(async move {
                breaker.record_failure().await;
            });
        }
    }
}