rs-zero 0.2.6

Rust-first microservice framework inspired by go-zero engineering practices
Documentation
use std::{future::Future, sync::Arc};

use tokio::{sync::Semaphore, time::timeout};
use tonic::{Code, Status};

use crate::{
    resil::{
        AdaptiveShedderConfig, BreakerConfig, BreakerGuard, BreakerPolicyConfig, BreakerRegistry,
        ShedderGuard, ShedderRegistry, WindowConfig,
    },
    rpc::{RpcResilienceConfig, rpc_resilience_key},
};

#[cfg(feature = "cache-redis")]
use crate::rpc::limiter::{RpcRateLimitOutcome, RpcRateLimiter};

/// Metrics sink for resilience decisions: an optional registry when the
/// `observability` feature is enabled.
#[cfg(feature = "observability")]
type RpcMetrics = Option<crate::observability::MetricsRegistry>;
/// Placeholder so the layer's `metrics` field exists even when the
/// `observability` feature is disabled.
#[cfg(not(feature = "observability"))]
type RpcMetrics = ();

/// Reusable RPC resilience helper for unary tonic calls.
///
/// Combines an optional concurrency cap, an optional rate limiter (`cache-redis`
/// feature), an adaptive load shedder and a circuit breaker, all keyed by
/// `rpc_resilience_key(service, method)`.
#[derive(Debug, Clone)]
pub struct RpcResilienceLayer {
    /// Service name; combined with each method name to form the resilience key.
    service: String,
    /// Settings consulted on every call (timeout, feature toggles, tuning values).
    config: RpcResilienceConfig,
    /// Circuit breakers, looked up per resilience key.
    breakers: BreakerRegistry,
    /// Adaptive load shedders, looked up per resilience key.
    shedders: ShedderRegistry,
    /// Concurrency cap; `None` when `max_concurrency` is unset. Behind an `Arc`
    /// so clones of the layer share the same permits.
    semaphore: Option<Arc<Semaphore>>,
    /// Per-key rate limiter (only compiled with the `cache-redis` feature).
    #[cfg(feature = "cache-redis")]
    limiter: RpcRateLimiter,
    /// Where resilience decisions are reported; `()` without `observability`.
    metrics: RpcMetrics,
}

impl RpcResilienceLayer {
    /// Creates a helper with shared breaker, shedder and concurrency state.
    pub fn new(service: impl Into<String>, config: RpcResilienceConfig) -> Self {
        let semaphore = config
            .max_concurrency
            .map(|max| Arc::new(Semaphore::new(max)));
        #[cfg(feature = "cache-redis")]
        let limiter = RpcRateLimiter::new(config.rate_limiter.clone());
        Self {
            service: service.into(),
            config,
            breakers: BreakerRegistry::new(),
            shedders: ShedderRegistry::new(),
            semaphore,
            #[cfg(feature = "cache-redis")]
            limiter,
            metrics: default_metrics(),
        }
    }

    /// Attaches a metrics registry to resilience decisions.
    #[cfg(feature = "observability")]
    pub fn with_metrics(mut self, metrics: crate::observability::MetricsRegistry) -> Self {
        self.metrics = Some(metrics);
        self
    }

    /// Runs one unary call through timeout, concurrency, shedding and breaker checks.
    pub async fn run_unary<F, Fut, T>(&self, method: &str, call: F) -> Result<T, Status>
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = Result<T, Status>>,
    {
        #[cfg(feature = "observability")]
        {
            return crate::observability::observe_rpc_unary(
                self.metrics.as_ref(),
                &self.service,
                method,
                None,
                self.run_unary_inner(method, call),
            )
            .await;
        }

        #[cfg(not(feature = "observability"))]
        {
            self.run_unary_inner(method, call).await
        }
    }

    /// Runs one unary call with request metadata available to observability.
    #[cfg(feature = "observability")]
    pub async fn run_unary_with_metadata<F, Fut, T>(
        &self,
        method: &str,
        metadata: &tonic::metadata::MetadataMap,
        call: F,
    ) -> Result<T, Status>
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = Result<T, Status>>,
    {
        crate::observability::observe_rpc_unary_with_metadata(
            self.metrics.as_ref(),
            &self.service,
            method,
            metadata,
            self.run_unary_inner(method, call),
        )
        .await
    }

    #[doc(hidden)]
    pub async fn run_unary_inner_public<F, Fut, T>(
        &self,
        method: &str,
        call: F,
    ) -> Result<T, Status>
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = Result<T, Status>>,
    {
        self.run_unary_inner(method, call).await
    }

    async fn run_unary_inner<F, Fut, T>(&self, method: &str, call: F) -> Result<T, Status>
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = Result<T, Status>>,
    {
        let _permit = self.acquire_concurrency().await?;
        let key = rpc_resilience_key(&self.service, method);
        self.acquire_limiter(&key).await?;
        let shedder = self.acquire_shedder(&key).await?;
        let breaker = self.acquire_breaker(&key).await?;

        let result = timeout(self.config.request_timeout, call()).await;
        match result {
            Ok(Ok(value)) => {
                record_success(breaker, shedder).await;
                self.record("rpc", "success");
                Ok(value)
            }
            Ok(Err(status)) if status_counts_as_failure(&status) => {
                record_failure(breaker, shedder).await;
                self.record("rpc", "failure");
                Err(status)
            }
            Ok(Err(status)) => {
                record_success(breaker, shedder).await;
                self.record("rpc", "acceptable_error");
                Err(status)
            }
            Err(_) => {
                record_failure(breaker, shedder).await;
                self.record("rpc", "timeout");
                Err(Status::deadline_exceeded("rpc request timed out"))
            }
        }
    }

    async fn acquire_concurrency(
        &self,
    ) -> Result<Option<tokio::sync::OwnedSemaphorePermit>, Status> {
        let Some(semaphore) = &self.semaphore else {
            return Ok(None);
        };

        match semaphore.clone().try_acquire_owned() {
            Ok(permit) => {
                self.record("concurrency", "allowed");
                Ok(Some(permit))
            }
            Err(_) => {
                self.record("concurrency", "rejected");
                Err(Status::resource_exhausted("rpc concurrency limit reached"))
            }
        }
    }

    async fn acquire_shedder(&self, key: &str) -> Result<Option<ShedderGuard>, Status> {
        if !self.config.shedding_enabled {
            return Ok(None);
        }

        let shedder = self
            .shedders
            .get_or_insert(key, shedder_config(&self.config))
            .await;
        match shedder.allow().await {
            Ok(guard) => {
                self.record("shedder", "pass");
                Ok(Some(guard))
            }
            Err(_) => {
                self.record("shedder", "drop");
                Err(Status::resource_exhausted("rpc service overloaded"))
            }
        }
    }

    #[cfg(feature = "cache-redis")]
    async fn acquire_limiter(&self, key: &str) -> Result<(), Status> {
        match self.limiter.allow(key).await {
            RpcRateLimitOutcome::Allowed => {
                self.record("limiter", "allowed");
                Ok(())
            }
            RpcRateLimitOutcome::Rejected => {
                self.record("limiter", "rejected");
                Err(Status::resource_exhausted("rpc rate limit exceeded"))
            }
            RpcRateLimitOutcome::ErrorOpen => {
                self.record("limiter", "error_open");
                Ok(())
            }
            RpcRateLimitOutcome::ErrorClosed(error) => {
                self.record("limiter", "error_closed");
                Err(Status::unavailable(format!(
                    "rpc rate limiter unavailable: {error}"
                )))
            }
        }
    }

    #[cfg(not(feature = "cache-redis"))]
    async fn acquire_limiter(&self, _key: &str) -> Result<(), Status> {
        Ok(())
    }

    async fn acquire_breaker(&self, key: &str) -> Result<Option<BreakerGuard>, Status> {
        if !self.config.breaker_enabled {
            return Ok(None);
        }

        let breaker = self
            .breakers
            .get_or_insert_with_policy(
                key,
                BreakerConfig {
                    failure_threshold: self.config.breaker_failure_threshold,
                    reset_timeout: self.config.breaker_reset_timeout,
                },
                breaker_policy(&self.config),
            )
            .await;
        match breaker.allow().await {
            Ok(guard) => {
                self.record("breaker", "allowed");
                Ok(Some(guard))
            }
            Err(error) => {
                self.record("breaker", "dropped");
                Err(Status::unavailable(error.to_string()))
            }
        }
    }

    fn record(&self, component: &str, outcome: &str) {
        record_metric(&self.metrics, component, outcome);
    }
}

/// Reports a successful (or acceptable) outcome to every guard that was acquired.
///
/// The breaker is notified before the shedder; guards that are `None` are skipped.
async fn record_success(breaker: Option<BreakerGuard>, shedder: Option<ShedderGuard>) {
    if let Some(breaker_guard) = breaker {
        breaker_guard.record_success().await;
    }

    if let Some(shedder_guard) = shedder {
        shedder_guard.record_success().await;
    }
}

/// Reports a failed outcome (failing status or timeout) to every guard that was acquired.
///
/// The breaker is notified before the shedder; guards that are `None` are skipped.
async fn record_failure(breaker: Option<BreakerGuard>, shedder: Option<ShedderGuard>) {
    if let Some(breaker_guard) = breaker {
        breaker_guard.record_failure().await;
    }

    if let Some(shedder_guard) = shedder {
        shedder_guard.record_failure().await;
    }
}

/// Derives adaptive-shedder settings from the RPC resilience config.
///
/// The in-flight limit prefers `shedding_max_in_flight`, then `max_concurrency`,
/// then falls back to 1024. Fields not set here come from
/// `AdaptiveShedderConfig::default()`.
fn shedder_config(config: &RpcResilienceConfig) -> AdaptiveShedderConfig {
    const FALLBACK_MAX_IN_FLIGHT: usize = 1024;

    let in_flight_limit = config
        .shedding_max_in_flight
        .or(config.max_concurrency)
        .unwrap_or(FALLBACK_MAX_IN_FLIGHT);
    let sampling_window = WindowConfig {
        buckets: config.shedding_window_buckets,
        bucket_duration: config.shedding_window_bucket_duration,
    };

    AdaptiveShedderConfig {
        max_in_flight: in_flight_limit,
        min_request_count: config.shedding_min_request_count,
        max_latency: config.shedding_max_latency,
        cpu_threshold_millis: config.shedding_cpu_threshold_millis,
        cool_off: config.shedding_cool_off,
        window: sampling_window,
        ..AdaptiveShedderConfig::default()
    }
}

/// Picks the breaker policy preset and applies the SRE tuning values.
///
/// `breaker_sre_enabled` selects `BreakerPolicyConfig::google_sre()` over the
/// default policy. The SRE `k` and protection values are copied from `config`
/// in both cases.
fn breaker_policy(config: &RpcResilienceConfig) -> BreakerPolicyConfig {
    let mut chosen = config
        .breaker_sre_enabled
        .then(BreakerPolicyConfig::google_sre)
        .unwrap_or_else(BreakerPolicyConfig::default);
    chosen.sre_protection = config.breaker_sre_protection;
    chosen.sre_k_millis = config.breaker_sre_k_millis;
    chosen
}

/// Returns whether a tonic status should count as a resilience failure.
pub fn status_counts_as_failure(status: &Status) -> bool {
    matches!(
        status.code(),
        Code::Unknown
            | Code::DeadlineExceeded
            | Code::ResourceExhausted
            | Code::Internal
            | Code::Unavailable
            | Code::DataLoss
    )
}

#[cfg(feature = "observability")]
fn default_metrics() -> RpcMetrics {
    None
}

#[cfg(not(feature = "observability"))]
fn default_metrics() -> RpcMetrics {}

/// Forwards one resilience decision to the observability layer, passing the literal
/// `"grpc"` along with the component and outcome labels.
#[cfg(feature = "observability")]
fn record_metric(metrics: &RpcMetrics, component: &str, outcome: &str) {
    let registry = metrics.as_ref();
    crate::observability::record_resilience_decision(registry, "grpc", component, outcome);
}

/// No-op stand-in used when the `observability` feature is disabled.
#[cfg(not(feature = "observability"))]
fn record_metric(_metrics: &RpcMetrics, _component: &str, _outcome: &str) {}

#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio::sync::oneshot;

    use super::{RpcResilienceLayer, status_counts_as_failure};
    use crate::rpc::RpcResilienceConfig;

    /// Builds a layer whose breaker opens after a single counted failure.
    fn breaker_layer() -> RpcResilienceLayer {
        RpcResilienceLayer::new(
            "hello",
            RpcResilienceConfig {
                breaker_enabled: true,
                breaker_failure_threshold: 1,
                breaker_reset_timeout: Duration::from_secs(60),
                ..RpcResilienceConfig::default()
            },
        )
    }

    #[tokio::test]
    async fn rpc_layer_maps_concurrency_rejection() {
        let layer = RpcResilienceLayer::new(
            "hello",
            RpcResilienceConfig {
                max_concurrency: Some(1),
                ..RpcResilienceConfig::default()
            },
        );
        // Park the first call inside the layer on channels instead of sleeping, so its
        // overlap with the second call does not depend on scheduler timing.
        let (entered_tx, entered_rx) = oneshot::channel::<()>();
        let (release_tx, release_rx) = oneshot::channel::<()>();
        let held = layer.clone();
        let handle = tokio::spawn(async move {
            held.run_unary("Say", move || async move {
                let _ = entered_tx.send(());
                let _ = release_rx.await;
                Ok::<_, tonic::Status>("ok")
            })
            .await
        });
        entered_rx
            .await
            .expect("first call should reach the handler");

        let rejected = layer.run_unary("Say", || async { Ok("ok") }).await;

        assert_eq!(
            rejected.expect_err("rejected").code(),
            tonic::Code::ResourceExhausted
        );
        release_tx
            .send(())
            .expect("first call should still be in flight");
        assert_eq!(handle.await.expect("join").expect("first"), "ok");
    }

    #[tokio::test]
    async fn rpc_layer_opens_breaker_after_failure() {
        let layer = breaker_layer();
        let _ = layer
            .run_unary("Say", || async {
                Err::<(), _>(tonic::Status::internal("down"))
            })
            .await;

        let rejected = layer.run_unary("Say", || async { Ok(()) }).await;

        assert_eq!(rejected.expect_err("open").code(), tonic::Code::Unavailable);
    }

    #[tokio::test]
    async fn rpc_layer_keeps_breaker_closed_on_acceptable_error() {
        let layer = breaker_layer();
        let first = layer
            .run_unary("Say", || async {
                Err::<(), _>(tonic::Status::invalid_argument("bad"))
            })
            .await;
        assert_eq!(
            first.expect_err("passthrough").code(),
            tonic::Code::InvalidArgument
        );

        let second = layer.run_unary("Say", || async { Ok(()) }).await;

        assert!(second.is_ok(), "acceptable errors must not open the breaker");
    }

    #[test]
    fn invalid_argument_is_acceptable_for_breaker() {
        assert!(!status_counts_as_failure(&tonic::Status::invalid_argument(
            "bad"
        )));
        assert!(status_counts_as_failure(&tonic::Status::unavailable(
            "down"
        )));
    }
}