use std::{future::Future, sync::Arc};
use tokio::{sync::Semaphore, time::timeout};
use tonic::{Code, Status};
use crate::{
resil::{
AdaptiveShedderConfig, BreakerConfig, BreakerGuard, BreakerPolicyConfig, BreakerRegistry,
ShedderGuard, ShedderRegistry, WindowConfig,
},
rpc::{RpcResilienceConfig, rpc_resilience_key},
};
#[cfg(feature = "cache-redis")]
use crate::rpc::limiter::{RpcRateLimitOutcome, RpcRateLimiter};
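// Metrics handle when observability is compiled in; a unit type otherwise so
// call sites stay identical across feature combinations.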
#[cfg(feature = "observability")]
type RpcMetrics = Option<crate::observability::MetricsRegistry>;
#[cfg(not(feature = "observability"))]
type RpcMetrics = ();
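/// Wraps unary RPC handlers with layered resilience controls: a concurrency
/// cap, an optional distributed rate limiter, adaptive load shedding, and a
/// circuit breaker, plus a per-request timeout.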
#[derive(Debug, Clone)]
pub struct RpcResilienceLayer {
service: String,
config: RpcResilienceConfig,
breakers: BreakerRegistry,
shedders: ShedderRegistry,
semaphore: Option<Arc<Semaphore>>,
#[cfg(feature = "cache-redis")]
limiter: RpcRateLimiter,
metrics: RpcMetrics,
}
impl RpcResilienceLayer {
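    /// Builds a layer for `service`. The concurrency semaphore is only
    /// created when `max_concurrency` is configured.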
pub fn new(service: impl Into<String>, config: RpcResilienceConfig) -> Self {
let semaphore = config
.max_concurrency
.map(|max| Arc::new(Semaphore::new(max)));
#[cfg(feature = "cache-redis")]
let limiter = RpcRateLimiter::new(config.rate_limiter.clone());
Self {
service: service.into(),
config,
breakers: BreakerRegistry::new(),
shedders: ShedderRegistry::new(),
semaphore,
#[cfg(feature = "cache-redis")]
limiter,
metrics: default_metrics(),
}
}
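    /// Attaches a metrics registry so resilience decisions are recorded.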
#[cfg(feature = "observability")]
pub fn with_metrics(mut self, metrics: crate::observability::MetricsRegistry) -> Self {
self.metrics = Some(metrics);
self
}
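    /// Runs a unary call through the full resilience pipeline. With the
    /// `observability` feature enabled, the call is additionally wrapped in
    /// RPC metrics instrumentation.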
pub async fn run_unary<F, Fut, T>(&self, method: &str, call: F) -> Result<T, Status>
where
F: FnOnce() -> Fut,
Fut: Future<Output = Result<T, Status>>,
{
#[cfg(feature = "observability")]
{
return crate::observability::observe_rpc_unary(
self.metrics.as_ref(),
&self.service,
method,
None,
self.run_unary_inner(method, call),
)
.await;
}
#[cfg(not(feature = "observability"))]
{
self.run_unary_inner(method, call).await
}
}
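    /// Like [`run_unary`](Self::run_unary), but gives the observability
    /// wrapper access to the request metadata (e.g. propagated context).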
#[cfg(feature = "observability")]
pub async fn run_unary_with_metadata<F, Fut, T>(
&self,
method: &str,
metadata: &tonic::metadata::MetadataMap,
call: F,
) -> Result<T, Status>
where
F: FnOnce() -> Fut,
Fut: Future<Output = Result<T, Status>>,
{
crate::observability::observe_rpc_unary_with_metadata(
self.metrics.as_ref(),
&self.service,
method,
metadata,
self.run_unary_inner(method, call),
)
.await
}
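    /// Escape hatch exposing the raw pipeline without the observability
    /// wrapper; hidden from the public docs.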
#[doc(hidden)]
pub async fn run_unary_inner_public<F, Fut, T>(
&self,
method: &str,
call: F,
) -> Result<T, Status>
where
F: FnOnce() -> Fut,
Fut: Future<Output = Result<T, Status>>,
{
self.run_unary_inner(method, call).await
}
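    /// Admission order: concurrency permit, rate limiter, shedder, breaker,
    /// then the call itself bounded by `request_timeout`. Timeouts and
    /// server-side errors count as failures; caller errors do not.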
async fn run_unary_inner<F, Fut, T>(&self, method: &str, call: F) -> Result<T, Status>
where
F: FnOnce() -> Fut,
Fut: Future<Output = Result<T, Status>>,
{
let _permit = self.acquire_concurrency().await?;
let key = rpc_resilience_key(&self.service, method);
self.acquire_limiter(&key).await?;
let shedder = self.acquire_shedder(&key).await?;
let breaker = self.acquire_breaker(&key).await?;
let result = timeout(self.config.request_timeout, call()).await;
match result {
Ok(Ok(value)) => {
record_success(breaker, shedder).await;
self.record("rpc", "success");
Ok(value)
}
Ok(Err(status)) if status_counts_as_failure(&status) => {
record_failure(breaker, shedder).await;
self.record("rpc", "failure");
Err(status)
}
Ok(Err(status)) => {
record_success(breaker, shedder).await;
self.record("rpc", "acceptable_error");
Err(status)
}
Err(_) => {
record_failure(breaker, shedder).await;
self.record("rpc", "timeout");
Err(Status::deadline_exceeded("rpc request timed out"))
}
}
}
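    /// Tries to take a semaphore permit without waiting; a saturated
    /// semaphore maps to `ResourceExhausted`.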
async fn acquire_concurrency(
&self,
) -> Result<Option<tokio::sync::OwnedSemaphorePermit>, Status> {
let Some(semaphore) = &self.semaphore else {
return Ok(None);
};
match semaphore.clone().try_acquire_owned() {
Ok(permit) => {
self.record("concurrency", "allowed");
Ok(Some(permit))
}
Err(_) => {
self.record("concurrency", "rejected");
Err(Status::resource_exhausted("rpc concurrency limit reached"))
}
}
}
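    /// Consults the per-key adaptive shedder when shedding is enabled; a
    /// shed request maps to `ResourceExhausted`.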
async fn acquire_shedder(&self, key: &str) -> Result<Option<ShedderGuard>, Status> {
if !self.config.shedding_enabled {
return Ok(None);
}
let shedder = self
.shedders
.get_or_insert(key, shedder_config(&self.config))
.await;
match shedder.allow().await {
Ok(guard) => {
self.record("shedder", "pass");
Ok(Some(guard))
}
Err(_) => {
self.record("shedder", "drop");
Err(Status::resource_exhausted("rpc service overloaded"))
}
}
}
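    /// Checks the distributed rate limiter. Limiter backend errors are
    /// surfaced as either fail-open (allow) or fail-closed (reject) outcomes.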
#[cfg(feature = "cache-redis")]
async fn acquire_limiter(&self, key: &str) -> Result<(), Status> {
match self.limiter.allow(key).await {
RpcRateLimitOutcome::Allowed => {
self.record("limiter", "allowed");
Ok(())
}
RpcRateLimitOutcome::Rejected => {
self.record("limiter", "rejected");
Err(Status::resource_exhausted("rpc rate limit exceeded"))
}
RpcRateLimitOutcome::ErrorOpen => {
self.record("limiter", "error_open");
Ok(())
}
RpcRateLimitOutcome::ErrorClosed(error) => {
self.record("limiter", "error_closed");
Err(Status::unavailable(format!(
"rpc rate limiter unavailable: {error}"
)))
}
}
}
#[cfg(not(feature = "cache-redis"))]
async fn acquire_limiter(&self, _key: &str) -> Result<(), Status> {
Ok(())
}
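    /// Consults the per-key circuit breaker when enabled; an open breaker
    /// maps to `Unavailable`.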
async fn acquire_breaker(&self, key: &str) -> Result<Option<BreakerGuard>, Status> {
if !self.config.breaker_enabled {
return Ok(None);
}
let breaker = self
.breakers
.get_or_insert_with_policy(
key,
BreakerConfig {
failure_threshold: self.config.breaker_failure_threshold,
reset_timeout: self.config.breaker_reset_timeout,
},
breaker_policy(&self.config),
)
.await;
match breaker.allow().await {
Ok(guard) => {
self.record("breaker", "allowed");
Ok(Some(guard))
}
Err(error) => {
self.record("breaker", "dropped");
Err(Status::unavailable(error.to_string()))
}
}
}
fn record(&self, component: &str, outcome: &str) {
record_metric(&self.metrics, component, outcome);
}
}
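/// Reports a successful outcome to whichever guards were acquired.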
async fn record_success(breaker: Option<BreakerGuard>, shedder: Option<ShedderGuard>) {
if let Some(guard) = breaker {
guard.record_success().await;
}
if let Some(guard) = shedder {
guard.record_success().await;
}
}
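/// Reports a failed outcome to whichever guards were acquired.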
async fn record_failure(breaker: Option<BreakerGuard>, shedder: Option<ShedderGuard>) {
if let Some(guard) = breaker {
guard.record_failure().await;
}
if let Some(guard) = shedder {
guard.record_failure().await;
}
}
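/// Derives the adaptive shedder configuration from the RPC config, falling
/// back to `max_concurrency` (or 1024) for the in-flight ceiling.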
fn shedder_config(config: &RpcResilienceConfig) -> AdaptiveShedderConfig {
AdaptiveShedderConfig {
max_in_flight: config
.shedding_max_in_flight
.or(config.max_concurrency)
.unwrap_or(1024),
min_request_count: config.shedding_min_request_count,
max_latency: config.shedding_max_latency,
cpu_threshold_millis: config.shedding_cpu_threshold_millis,
cool_off: config.shedding_cool_off,
window: WindowConfig {
buckets: config.shedding_window_buckets,
bucket_duration: config.shedding_window_bucket_duration,
},
..AdaptiveShedderConfig::default()
}
}
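/// Chooses the breaker policy: the Google SRE policy when enabled, otherwise
/// the default; the SRE tuning knobs are applied either way.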
fn breaker_policy(config: &RpcResilienceConfig) -> BreakerPolicyConfig {
let mut policy = if config.breaker_sre_enabled {
BreakerPolicyConfig::google_sre()
} else {
BreakerPolicyConfig::default()
};
policy.sre_k_millis = config.breaker_sre_k_millis;
policy.sre_protection = config.breaker_sre_protection;
policy
}
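/// Returns `true` for status codes that indicate server-side or transport
/// trouble and should trip the breaker/shedder; codes caused by the caller
/// (e.g. `InvalidArgument`) are treated as acceptable.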
pub fn status_counts_as_failure(status: &Status) -> bool {
matches!(
status.code(),
Code::Unknown
| Code::DeadlineExceeded
| Code::ResourceExhausted
| Code::Internal
| Code::Unavailable
| Code::DataLoss
)
}
#[cfg(feature = "observability")]
fn default_metrics() -> RpcMetrics {
None
}
#[cfg(not(feature = "observability"))]
fn default_metrics() -> RpcMetrics {}
#[cfg(feature = "observability")]
fn record_metric(metrics: &RpcMetrics, component: &str, outcome: &str) {
crate::observability::record_resilience_decision(metrics.as_ref(), "grpc", component, outcome);
}
#[cfg(not(feature = "observability"))]
fn record_metric(_metrics: &RpcMetrics, _component: &str, _outcome: &str) {}
#[cfg(test)]
mod tests {
use std::time::Duration;
use super::{RpcResilienceLayer, status_counts_as_failure};
use crate::rpc::RpcResilienceConfig;
#[tokio::test]
async fn rpc_layer_maps_concurrency_rejection() {
let layer = RpcResilienceLayer::new(
"hello",
RpcResilienceConfig {
max_concurrency: Some(1),
..RpcResilienceConfig::default()
},
);
let held = layer.clone();
        // Hold the single permit with a slow in-flight call.
let handle = tokio::spawn(async move {
held.run_unary("Say", || async {
tokio::time::sleep(Duration::from_millis(40)).await;
Ok::<_, tonic::Status>("ok")
})
.await
});
        // Give the spawned call a moment to claim the only permit.
        tokio::time::sleep(Duration::from_millis(5)).await;
let rejected = layer.run_unary("Say", || async { Ok("ok") }).await;
assert_eq!(
rejected.expect_err("rejected").code(),
tonic::Code::ResourceExhausted
);
assert_eq!(handle.await.expect("join").expect("first"), "ok");
}
#[tokio::test]
async fn rpc_layer_opens_breaker_after_failure() {
let layer = RpcResilienceLayer::new(
"hello",
RpcResilienceConfig {
breaker_enabled: true,
breaker_failure_threshold: 1,
breaker_reset_timeout: Duration::from_secs(60),
..RpcResilienceConfig::default()
},
);
        // A single counted failure is enough to trip the breaker (threshold = 1).
        let _ = layer
.run_unary("Say", || async {
Err::<(), _>(tonic::Status::internal("down"))
})
.await;
let rejected = layer.run_unary("Say", || async { Ok(()) }).await;
assert_eq!(rejected.expect_err("open").code(), tonic::Code::Unavailable);
}
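    // A sketch of the acceptable-error path: `invalid_argument` is not
    // classified as a failure, so it is recorded as a success and the breaker
    // should stay closed even with a threshold of 1. Assumes the breaker only
    // trips on statuses that `status_counts_as_failure` accepts.
    #[tokio::test]
    async fn acceptable_error_keeps_breaker_closed() {
        let layer = RpcResilienceLayer::new(
            "hello",
            RpcResilienceConfig {
                breaker_enabled: true,
                breaker_failure_threshold: 1,
                breaker_reset_timeout: Duration::from_secs(60),
                ..RpcResilienceConfig::default()
            },
        );
        let _ = layer
            .run_unary("Say", || async {
                Err::<(), _>(tonic::Status::invalid_argument("bad"))
            })
            .await;
        // The follow-up call is still admitted because no failure was counted.
        let admitted = layer.run_unary("Say", || async { Ok("ok") }).await;
        assert_eq!(admitted.expect("breaker closed"), "ok");
    }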
#[test]
fn invalid_argument_is_acceptable_for_breaker() {
assert!(!status_counts_as_failure(&tonic::Status::invalid_argument(
"bad"
)));
assert!(status_counts_as_failure(&tonic::Status::unavailable(
"down"
)));
}
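    // A sketch of the timeout path: the handler sleeps far past
    // `request_timeout`, so the layer should convert the elapsed timer into
    // `DeadlineExceeded`. Assumes `request_timeout` is constructible in-crate
    // via struct update, like the other config fields used above.
    #[tokio::test]
    async fn rpc_layer_maps_timeout_to_deadline_exceeded() {
        let layer = RpcResilienceLayer::new(
            "hello",
            RpcResilienceConfig {
                request_timeout: Duration::from_millis(10),
                ..RpcResilienceConfig::default()
            },
        );
        let timed_out = layer
            .run_unary("Say", || async {
                tokio::time::sleep(Duration::from_secs(5)).await;
                Ok::<_, tonic::Status>("late")
            })
            .await;
        assert_eq!(
            timed_out.expect_err("timeout").code(),
            tonic::Code::DeadlineExceeded
        );
    }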
}