use std::future::Future;
use std::time::Duration;
use rskit_errors::{AppError, AppResult};
use crate::{
Bulkhead, BulkheadConfig, CbConfig, CircuitBreaker, RateLimiter, RateLimiterConfig, RetryPolicy,
};
/// A composable resilience policy built from optional guards.
///
/// Every guard is optional. When [`Policy::execute`] runs an operation, the
/// configured guards are applied from the outside in:
///
/// 1. rate limiter (checked once per `execute` call, before anything else runs),
/// 2. bulkhead,
/// 3. circuit breaker,
/// 4. timeout (one deadline around the whole retry loop, not per attempt),
/// 5. retry (innermost, re-invokes the operation).
///
/// `Policy` derives `Clone`; whether clones share guard state (for example the
/// circuit breaker's failure count) depends on the guard types' own `Clone`
/// implementations — NOTE(review): confirm before relying on it.
#[derive(Clone, Default)]
pub struct Policy {
    /// Innermost layer: re-invokes the operation on retryable failures.
    retry: Option<RetryPolicy>,
    /// Wraps the timeout + retry layers; its state persists across `execute` calls.
    circuit_breaker: Option<CircuitBreaker>,
    /// Wraps the circuit-breaker layer to limit concurrent executions.
    bulkhead: Option<Bulkhead>,
    /// Outermost check, performed once per `execute` call.
    rate_limiter: Option<RateLimiter>,
    /// Overall deadline for the whole retry loop (all attempts plus backoff sleeps).
    timeout: Option<Duration>,
}

impl Policy {
    /// Creates a policy with no guards; `execute` then simply runs the operation once.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets (or replaces) the retry policy.
    ///
    /// The policy is validated on every [`Policy::execute`] call, before any other
    /// guard is consulted.
    #[must_use]
    pub fn with_retry(mut self, config: RetryPolicy) -> Self {
        self.retry = Some(config);
        self
    }

    /// Builds a [`CircuitBreaker`] from `config` and installs it.
    ///
    /// Equivalent to [`Policy::try_with_circuit_breaker`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `CircuitBreaker::new` for `config`.
    #[must_use = "builder methods return an updated policy; use the returned value"]
    pub fn with_circuit_breaker(self, config: CbConfig) -> AppResult<Self> {
        self.try_with_circuit_breaker(config)
    }

    /// Builds a [`CircuitBreaker`] from `config` and installs it.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `CircuitBreaker::new` for `config`.
    #[must_use = "builder methods return an updated policy; use the returned value"]
    pub fn try_with_circuit_breaker(mut self, config: CbConfig) -> AppResult<Self> {
        self.circuit_breaker = Some(CircuitBreaker::new(config)?);
        Ok(self)
    }

    /// Builds a [`Bulkhead`] from `config` and installs it.
    ///
    /// Equivalent to [`Policy::try_with_bulkhead`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `Bulkhead::new` for `config`.
    #[must_use = "builder methods return an updated policy; use the returned value"]
    pub fn with_bulkhead(self, config: BulkheadConfig) -> AppResult<Self> {
        self.try_with_bulkhead(config)
    }

    /// Builds a [`Bulkhead`] from `config` and installs it.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `Bulkhead::new` for `config`.
    #[must_use = "builder methods return an updated policy; use the returned value"]
    pub fn try_with_bulkhead(mut self, config: BulkheadConfig) -> AppResult<Self> {
        self.bulkhead = Some(Bulkhead::new(config)?);
        Ok(self)
    }

    /// Installs an already-constructed [`RateLimiter`].
    #[must_use]
    pub fn with_rate_limiter(mut self, rate_limiter: RateLimiter) -> Self {
        self.rate_limiter = Some(rate_limiter);
        self
    }

    /// Builds a [`RateLimiter`] from `config` and installs it.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `RateLimiter::from_config` for `config`.
    #[must_use = "builder methods return an updated policy; use the returned value"]
    pub fn try_with_rate_limiter_config(mut self, config: RateLimiterConfig) -> AppResult<Self> {
        self.rate_limiter = Some(RateLimiter::from_config(config)?);
        Ok(self)
    }

    /// Sets an overall deadline for the retry loop.
    ///
    /// The deadline covers all attempts and backoff sleeps together, not each
    /// attempt individually.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    /// Runs `f` under every configured guard and returns its result.
    ///
    /// Layers are applied from the outside in: rate limiter → bulkhead → circuit
    /// breaker → timeout → retry → `f`. `f` is an `FnMut` factory so the retry
    /// layer can create a fresh future for each attempt.
    ///
    /// When a bulkhead or circuit breaker is configured, the operation's error is
    /// converted into [`AppError`] on the way through that guard and back into `E`
    /// afterwards. The retry layer does the same when it is configured. Any detail
    /// of `E` that does not survive the `E → AppError → E` round trip is therefore
    /// lost.
    ///
    /// # Errors
    ///
    /// - the validation error of the configured [`RetryPolicy`], if it is invalid;
    /// - the rate limiter's error, if its `check()` rejects the call;
    /// - any rejection reported by the bulkhead or the circuit breaker;
    /// - `AppError::timeout("resilience policy")`, if the overall timeout elapses;
    /// - otherwise, the error of the last attempt of `f`.
    pub async fn execute<F, Fut, T, E>(&self, mut f: F) -> Result<T, E>
    where
        F: FnMut() -> Fut,
        Fut: Future<Output = Result<T, E>>,
        E: From<AppError> + Into<AppError>,
    {
        // Reject an invalid retry configuration before any guard is touched.
        // Otherwise a pure configuration error would consume a rate-limiter permit,
        // occupy a bulkhead slot, and reach the circuit breaker as if the operation
        // had failed (possibly opening it and masking the real cause).
        if let Some(retry) = &self.retry {
            retry.validate().map_err(E::from)?;
        }
        if let Some(rate_limiter) = &self.rate_limiter {
            rate_limiter.check().map_err(E::from)?;
        }
        if let Some(bulkhead) = &self.bulkhead {
            bulkhead
                .execute(|| async {
                    execute_circuit_breaker(
                        self.circuit_breaker.as_ref(),
                        self.timeout,
                        self.retry.as_ref(),
                        &mut f,
                    )
                    .await
                    .map_err(Into::into)
                })
                .await
                .map_err(E::from)
        } else {
            execute_circuit_breaker(
                self.circuit_breaker.as_ref(),
                self.timeout,
                self.retry.as_ref(),
                &mut f,
            )
            .await
        }
    }
}
/// Runs the timeout/retry layers through the circuit breaker, when one is configured.
///
/// With a breaker, the inner result's error is converted into [`AppError`] for the
/// breaker and converted back into `E` afterwards. Without one, the inner layers
/// run directly and their result is returned as-is.
async fn execute_circuit_breaker<F, Fut, T, E>(
    circuit_breaker: Option<&CircuitBreaker>,
    timeout: Option<Duration>,
    retry: Option<&RetryPolicy>,
    f: &mut F,
) -> Result<T, E>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<T, E>>,
    E: From<AppError> + Into<AppError>,
{
    match circuit_breaker {
        // No breaker: nothing to convert, call straight through.
        None => execute_timeout(timeout, retry, f).await,
        // Breaker present: it observes the result as an `AppResult`.
        Some(breaker) => breaker
            .execute(|| async { execute_timeout(timeout, retry, f).await.map_err(Into::into) })
            .await
            .map_err(E::from),
    }
}
/// Applies the optional overall deadline around the retry layer.
///
/// A single deadline covers every attempt and every backoff sleep together; it
/// is not a per-attempt timeout. If the deadline elapses first, the pending retry
/// future is dropped and `AppError::timeout("resilience policy")` is returned,
/// converted into `E`.
async fn execute_timeout<F, Fut, T, E>(
    timeout: Option<Duration>,
    retry: Option<&RetryPolicy>,
    f: &mut F,
) -> Result<T, E>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<T, E>>,
    E: From<AppError> + Into<AppError>,
{
    let Some(limit) = timeout else {
        return execute_retry(retry, f).await;
    };
    match tokio::time::timeout(limit, execute_retry(retry, f)).await {
        // Finished in time: pass the retry layer's own result through unchanged.
        Ok(outcome) => outcome,
        Err(_elapsed) => Err(E::from(AppError::timeout("resilience policy"))),
    }
}
async fn execute_retry<F, Fut, T, E>(retry: Option<&RetryPolicy>, f: &mut F) -> Result<T, E>
where
F: FnMut() -> Fut,
Fut: Future<Output = Result<T, E>>,
E: From<AppError> + Into<AppError>,
{
if let Some(retry) = retry {
retry.validate().map_err(E::from)?;
let mut attempt = 0usize;
loop {
attempt += 1;
match f().await {
Ok(value) => return Ok(value),
Err(err) => {
let retryable_error: AppError = err.into();
let should_retry = retry
.retry_if
.as_ref()
.map(|predicate| predicate(&retryable_error))
.unwrap_or_else(|| retryable_error.is_retryable());
if attempt >= retry.max_attempts || !should_retry {
return Err(E::from(retryable_error));
}
if let Some(callback) = &retry.on_retry {
callback(attempt as u32, &retryable_error);
}
tokio::time::sleep(retry.backoff(attempt)).await;
}
}
}
} else {
f().await
}
}
#[cfg(test)]
mod tests {
// Unit tests for `Policy::execute` and the fallible `Policy` builder methods.
// Several tests rely on real (short) tokio sleeps, so their timing constants are
// deliberately chosen relative to each other; change them together.
use std::sync::{
Arc,
atomic::{AtomicUsize, Ordering},
};
use rskit_errors::{AppError, ErrorCode};
use super::*;
use crate::{BulkheadConfig, CbConfig, ConstantBackoff, LinearBackoff, RateLimiterConfig};
// The first attempt fails with a connection error, which this test expects to be
// retried; the second attempt succeeds. Checks the value and that exactly two
// attempts were made.
#[tokio::test]
async fn policy_retries_until_success() {
let mut attempts = 0usize;
let policy = Policy::new().with_retry(
RetryPolicy::new()
.with_max_attempts(3)
.with_constant_backoff(ConstantBackoff::new(Duration::from_millis(1)))
.with_jitter(false),
);
let result = policy
.execute(|| {
attempts += 1;
let attempt = attempts;
async move {
if attempt == 1 {
Err::<u32, AppError>(AppError::connection_failed("upstream"))
} else {
Ok(7)
}
}
})
.await;
assert_eq!(result.unwrap(), 7);
assert_eq!(attempts, 2);
}
// With `max_failures(1)`, one failing call should leave the breaker rejecting the
// next call on the same `Policy` without invoking the operation (attempts stays 0).
#[tokio::test]
async fn policy_circuit_breaker_state_is_shared_across_calls() {
let attempts = Arc::new(AtomicUsize::new(0));
let policy = Policy::new()
.with_circuit_breaker(CbConfig::new("shared").with_max_failures(1))
.unwrap();
let first = policy
.execute(|| async { Err::<(), AppError>(AppError::new(ErrorCode::Internal, "boom")) })
.await;
assert!(first.is_err());
let second = policy
.execute(|| {
let attempts = attempts.clone();
async move {
attempts.fetch_add(1, Ordering::SeqCst);
Ok::<(), AppError>(())
}
})
.await;
assert!(second.is_err());
assert_eq!(attempts.load(Ordering::SeqCst), 0);
}
// Each attempt takes 15 ms and the backoff is 5 ms, so five attempts would need far
// more than the 20 ms timeout. Expects a `Timeout` error and fewer than 5 attempts,
// i.e. the timeout bounds the whole retry loop rather than each attempt.
#[tokio::test]
async fn policy_timeout_wraps_retry_block() {
let attempts = Arc::new(AtomicUsize::new(0));
let policy = Policy::new()
.with_retry(
RetryPolicy::new()
.with_max_attempts(5)
.with_constant_backoff(ConstantBackoff::new(Duration::from_millis(5)))
.with_jitter(false),
)
.with_timeout(Duration::from_millis(20));
let result = policy
.execute(|| {
let attempts = attempts.clone();
async move {
attempts.fetch_add(1, Ordering::SeqCst);
tokio::time::sleep(Duration::from_millis(15)).await;
Err::<(), AppError>(AppError::connection_failed("slow"))
}
})
.await;
assert!(result.is_err());
assert_eq!(result.unwrap_err().code(), ErrorCode::Timeout);
assert!(attempts.load(Ordering::SeqCst) < 5);
}
// Smoke test: rate limiter, bulkhead, circuit breaker and retry all configured at
// once; a succeeding operation's value passes through every layer.
#[tokio::test]
async fn policy_combines_outer_guards() {
let policy = Policy::new()
.try_with_rate_limiter_config(RateLimiterConfig::new("policy", 10, 2))
.unwrap()
.with_bulkhead(BulkheadConfig::new("policy", 1))
.unwrap()
.with_circuit_breaker(CbConfig::new("policy"))
.unwrap()
.with_retry(RetryPolicy::new().with_linear_backoff(LinearBackoff::new(
Duration::from_millis(1),
Duration::from_millis(1),
Duration::from_millis(5),
)));
let result = policy.execute(|| async { Ok::<_, AppError>(42) }).await;
assert_eq!(result.unwrap(), 42);
}
// The fallible builders surface invalid configurations as `Err`: zero max
// failures, a zero-valued bulkhead argument (presumably its capacity — confirm
// against `BulkheadConfig::new`), and a zero-valued rate-limiter argument.
#[tokio::test]
async fn policy_rejects_invalid_builder_configs() {
assert!(
Policy::new()
.with_circuit_breaker(CbConfig::new("bad-cb").with_max_failures(0))
.is_err()
);
assert!(
Policy::new()
.with_bulkhead(BulkheadConfig::new("bad-bulkhead", 0))
.is_err()
);
assert!(
Policy::new()
.try_with_rate_limiter_config(RateLimiterConfig::new("bad-rate", 0, 1))
.is_err()
);
}
// With no guards at all, the operation runs exactly once and its error code is
// returned unchanged.
#[tokio::test]
async fn policy_without_retry_calls_operation_once_and_propagates_error() {
let attempts = Arc::new(AtomicUsize::new(0));
let policy = Policy::new();
let result = policy
.execute(|| {
let attempts = Arc::clone(&attempts);
async move {
attempts.fetch_add(1, Ordering::SeqCst);
Err::<(), AppError>(AppError::connection_failed("upstream"))
}
})
.await;
assert_eq!(attempts.load(Ordering::SeqCst), 1);
assert_eq!(result.unwrap_err().code(), ErrorCode::ConnectionFailed);
}
}