use crate::retry_result::RetryResult;
use rand::RngExt;
use std::sync::{Arc, Mutex};
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
pub enum Error {
#[error("the scaling factor ({0}) must be greater or equal than 0.0")]
ScalingOutOfRange(f64),
#[error(
"the minimum tokens ({min}) must be less than or equal to the initial token ({initial}) count"
)]
TooFewMinTokens { min: u64, initial: u64 },
}
/// Decides whether retry attempts should be throttled, based on the recorded
/// outcomes of earlier attempts.
///
/// This module provides two implementations: [`AdaptiveThrottler`] and
/// [`CircuitBreaker`].
pub trait RetryThrottler: std::fmt::Debug + Send + Sync {
    /// Returns `true` if the next retry attempt should be throttled.
    #[cfg_attr(not(feature = "_internal-semver"), doc(hidden))]
    fn throttle_retry_attempt(&self) -> bool;

    /// Records a failed attempt.
    ///
    /// `flow` says how the failure was classified: `Continue`, `Exhausted`, or
    /// `Permanent`.
    #[cfg_attr(not(feature = "_internal-semver"), doc(hidden))]
    fn on_retry_failure(&mut self, flow: &RetryResult);

    /// Records a successful attempt.
    #[cfg_attr(not(feature = "_internal-semver"), doc(hidden))]
    fn on_success(&mut self);
}
pub type SharedRetryThrottler = Arc<Mutex<dyn RetryThrottler>>;
/// Conversion target accepting either a concrete [`RetryThrottler`] or an existing
/// [`SharedRetryThrottler`]; either way, it holds a shared throttler.
#[derive(Debug, Clone)]
pub struct RetryThrottlerArg(SharedRetryThrottler);
impl<T: RetryThrottler + 'static> From<T> for RetryThrottlerArg {
fn from(value: T) -> Self {
Self(Arc::new(Mutex::new(value)))
}
}
impl From<SharedRetryThrottler> for RetryThrottlerArg {
fn from(value: SharedRetryThrottler) -> Self {
Self(value)
}
}
impl From<RetryThrottlerArg> for SharedRetryThrottler {
fn from(value: RetryThrottlerArg) -> SharedRetryThrottler {
value.0
}
}
/// Probabilistic retry throttler driven by request and accept counters.
///
/// Retries are rejected with probability
/// `max(0, (request_count - factor * accept_count) / (request_count + 1))`.
/// So, for a fixed number of requests, fewer accepts mean more throttling.
#[derive(Debug, Clone)]
pub struct AdaptiveThrottler {
    /// Attempts counted as accepted: successes and permanent errors.
    accept_count: f64,
    /// All recorded attempts, successful or not.
    request_count: f64,
    /// Multiplier applied to `accept_count` in the reject probability.
    factor: f64,
}
impl AdaptiveThrottler {
    /// Creates an adaptive throttler with the given scaling `factor` and empty counters.
    ///
    /// # Errors
    ///
    /// Returns [`Error::ScalingOutOfRange`] if `factor` is negative or NaN.
    pub fn new(factor: f64) -> Result<Self, Error> {
        // NaN compares false against everything, so `factor < 0.0` alone would let it
        // through. A NaN factor makes every reject probability NaN.
        if factor.is_nan() || factor < 0.0 {
            return Err(Error::ScalingOutOfRange(factor));
        }
        // Already validated, so `clamp` leaves `factor` unchanged.
        Ok(Self::clamp(factor))
    }

    /// Creates an adaptive throttler with empty counters, replacing a negative
    /// `factor` with `0.0` instead of failing.
    ///
    /// NaN is not negative and is passed through unchanged.
    pub fn clamp(factor: f64) -> Self {
        let factor = if factor < 0.0 { 0.0 } else { factor };
        Self {
            accept_count: 0.0,
            request_count: 0.0,
            factor,
        }
    }

    /// Decides whether to throttle, using `rng` as the source of randomness.
    ///
    /// The reject probability is
    /// `max(0, (request_count - factor * accept_count) / (request_count + 1))`.
    fn throttle<R: rand::Rng>(&self, rng: &mut R) -> bool {
        let reject_probability = ((self.request_count - self.factor * self.accept_count)
            / (self.request_count + 1.0))
            .max(0.0);
        // Strict `<`: the sampled range is inclusive, so the RNG may return exactly
        // 0.0. With `<=`, that sample would throttle even when the reject
        // probability is 0. With `<`, P(throttle) equals `reject_probability`.
        rng.random_range(0.0..=1.0) < reject_probability
    }
}
impl std::default::Default for AdaptiveThrottler {
fn default() -> Self {
Self::clamp(2.0)
}
}
impl RetryThrottler for AdaptiveThrottler {
    /// Decides probabilistically, using `rand::rng()` as the randomness source.
    fn throttle_retry_attempt(&self) -> bool {
        let mut source = rand::rng();
        self.throttle(&mut source)
    }

    /// Every failure counts as a request.
    ///
    /// Only `Permanent` failures also count as accepted. NOTE(review): presumably
    /// because such errors do not indicate overload; confirm.
    fn on_retry_failure(&mut self, flow: &RetryResult) {
        self.request_count += 1.0;
        let accepted = match flow {
            RetryResult::Permanent(_) => true,
            RetryResult::Continue(_) | RetryResult::Exhausted(_) => false,
        };
        if accepted {
            self.accept_count += 1.0;
        }
    }

    /// A success counts as both a request and an accept.
    fn on_success(&mut self) {
        self.accept_count += 1.0;
        self.request_count += 1.0;
    }
}
/// Retry throttler based on a token count: failures drain tokens and successes
/// refill them.
///
/// Retries are throttled while the count is at or below `min_tokens`.
#[derive(Debug, Clone)]
pub struct CircuitBreaker {
    /// Upper bound on `cur_tokens`; also the starting token count.
    max_tokens: u64,
    /// Retries are throttled while `cur_tokens <= min_tokens`.
    min_tokens: u64,
    /// Current number of tokens.
    cur_tokens: u64,
    /// Tokens removed for each `Continue` or `Exhausted` failure.
    error_cost: u64,
}
impl CircuitBreaker {
    /// Creates a circuit breaker starting with `tokens` tokens, which is also the cap.
    ///
    /// Retries are throttled once the count falls to `min_tokens` or below. Each
    /// `Continue`/`Exhausted` failure removes `error_cost` tokens.
    ///
    /// # Errors
    ///
    /// Returns [`Error::TooFewMinTokens`] if `min_tokens` is greater than `tokens`.
    pub fn new(tokens: u64, min_tokens: u64, error_cost: u64) -> Result<Self, Error> {
        if tokens < min_tokens {
            return Err(Error::TooFewMinTokens {
                min: min_tokens,
                initial: tokens,
            });
        }
        // Validated above, so `clamp` keeps `min_tokens` as given.
        Ok(Self::clamp(tokens, min_tokens, error_cost))
    }

    /// Like [`CircuitBreaker::new`], but lowers `min_tokens` to `tokens` instead of
    /// failing when it is too large.
    pub fn clamp(tokens: u64, min_tokens: u64, error_cost: u64) -> Self {
        Self {
            max_tokens: tokens,
            cur_tokens: tokens,
            min_tokens: min_tokens.min(tokens),
            error_cost,
        }
    }
}
impl std::default::Default for CircuitBreaker {
fn default() -> Self {
CircuitBreaker::clamp(100, 50, 10)
}
}
impl RetryThrottler for CircuitBreaker {
    /// Throttles once the token count has fallen to `min_tokens` or below.
    fn throttle_retry_attempt(&self) -> bool {
        self.min_tokens >= self.cur_tokens
    }

    /// Handles a failed attempt according to its classification:
    ///
    /// - `Continue` and `Exhausted` remove `error_cost` tokens, saturating at zero.
    /// - `Permanent` is treated like a success and adds one token.
    fn on_retry_failure(&mut self, flow: &RetryResult) {
        match flow {
            RetryResult::Permanent(_) => self.on_success(),
            RetryResult::Continue(_) | RetryResult::Exhausted(_) => {
                self.cur_tokens = self.cur_tokens.saturating_sub(self.error_cost);
            }
        }
    }

    /// Adds one token, never exceeding `max_tokens`.
    fn on_success(&mut self) {
        self.cur_tokens = self.cur_tokens.saturating_add(1).min(self.max_tokens);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mock_rng::MockRng;
    type TestResult = std::result::Result<(), Box<dyn std::error::Error>>;

    // Both conversions into `RetryThrottlerArg` compile and succeed: from a concrete
    // throttler, and from an already-shared `Arc<Mutex<dyn RetryThrottler>>`.
    #[test]
    fn retry_throttler_arg() {
        let throttler = AdaptiveThrottler::default();
        let _ = RetryThrottlerArg::from(throttler);
        let throttler: Arc<Mutex<dyn RetryThrottler>> =
            Arc::new(Mutex::new(CircuitBreaker::default()));
        let _ = RetryThrottlerArg::from(throttler);
    }

    // `AdaptiveThrottler::new` rejects a negative factor and accepts 0.0.
    #[test]
    fn adaptive_construction() {
        let throttler = AdaptiveThrottler::new(-2.0);
        assert!(
            matches!(throttler, Err(Error::ScalingOutOfRange { .. })),
            "{throttler:?}"
        );
        let throttler = AdaptiveThrottler::new(0.0);
        assert!(throttler.is_ok(), "{throttler:?}");
    }

    // Builds a service error with status code `Aborted`, for wrapping in
    // `RetryResult` variants.
    fn test_error() -> crate::error::Error {
        use crate::error::{
            Error,
            rpc::{Code, Status},
        };
        Error::service(Status::default().set_code(Code::Aborted))
    }

    #[test]
    fn adaptive() -> TestResult {
        // Default: factor 2.0 and empty counters.
        let mut throttler = AdaptiveThrottler::default();
        assert_eq!(throttler.request_count, 0.0);
        assert_eq!(throttler.accept_count, 0.0);
        assert_eq!(throttler.factor, 2.0);
        // With no requests, the reject probability is (0 - 2*0) / (0 + 1) = 0.
        assert!(!throttler.throttle_retry_attempt(), "{throttler:?}");
        // `Continue` failures count as requests but not as accepts.
        throttler.on_retry_failure(&RetryResult::Continue(test_error()));
        assert_eq!(throttler.request_count, 1.0);
        assert_eq!(throttler.accept_count, 0.0);
        throttler.on_retry_failure(&RetryResult::Continue(test_error()));
        assert_eq!(throttler.request_count, 2.0);
        assert_eq!(throttler.accept_count, 0.0);
        // A success counts as both a request and an accept.
        throttler.on_success();
        assert_eq!(throttler.request_count, 3.0);
        assert_eq!(throttler.accept_count, 1.0);
        // A `Permanent` failure also counts as an accept.
        throttler.on_retry_failure(&RetryResult::Permanent(test_error()));
        assert_eq!(throttler.request_count, 4.0);
        assert_eq!(throttler.accept_count, 2.0);
        // Fresh throttler after one failure: reject probability = (1 - 2*0) / (1 + 1) = 0.5.
        let mut throttler = AdaptiveThrottler::default();
        throttler.on_retry_failure(&RetryResult::Continue(test_error()));
        // This seed samples 0.0, which is below 0.5, so the attempt is throttled.
        let mut rng = MockRng::new(0);
        assert_eq!(rng.random_range(0.0..=1.0), 0.0);
        assert!(throttler.throttle(&mut rng), "{throttler:?}");
        // This seed samples a value above 0.5, so the attempt is not throttled.
        let mut rng = MockRng::new(u64::MAX - u64::MAX / 4);
        assert!(
            rng.random_range(0.0..=1.0) > 0.5,
            "{}",
            rng.random_range(0.0..=1.0)
        );
        assert!(!throttler.throttle(&mut rng), "{throttler:?}");
        // Large factor: (1 - 100*1) / (1 + 1) is negative and is clamped to 0,
        // so the attempt is not throttled.
        let mut throttler = AdaptiveThrottler::new(100.0)?;
        throttler.on_success();
        assert!(!throttler.throttle_retry_attempt(), "{throttler:?}");
        Ok(())
    }

    // `CircuitBreaker::new` rejects a minimum (200) above the initial tokens (100).
    #[test]
    fn circuit_breaker_validation() {
        let throttler = CircuitBreaker::new(100, 200, 1);
        assert!(
            matches!(throttler, Err(Error::TooFewMinTokens { .. })),
            "{throttler:?}"
        );
    }

    // Default breaker: 100 tokens, throttles at <= 50, each failure costs 10.
    #[test]
    fn circuit_breaker() {
        let mut throttler = CircuitBreaker::default();
        assert!(!throttler.throttle_retry_attempt(), "{throttler:?}");
        // 90, 80, 70, 60 tokens: all above 50, so not throttled.
        for _ in 0..4 {
            throttler.on_retry_failure(&RetryResult::Continue(test_error()));
            assert!(!throttler.throttle_retry_attempt(), "{throttler:?}");
        }
        // 50, then 40 tokens: throttled.
        throttler.on_retry_failure(&RetryResult::Continue(test_error()));
        throttler.on_retry_failure(&RetryResult::Continue(test_error()));
        assert!(throttler.throttle_retry_attempt(), "{throttler:?}");
        // Each success adds one token: 41..=50, still throttled (50 <= 50).
        for _ in 0..10 {
            throttler.on_success();
            assert!(throttler.throttle_retry_attempt(), "{throttler:?}");
        }
        // 51 tokens: no longer throttled.
        throttler.on_success();
        assert!(!throttler.throttle_retry_attempt(), "{throttler:?}");
        // Back down to 41; `Permanent` failures add one token each, like a success.
        throttler.on_retry_failure(&RetryResult::Continue(test_error()));
        // 42..=50: still throttled.
        for _ in 0..9 {
            throttler.on_retry_failure(&RetryResult::Permanent(test_error()));
            assert!(throttler.throttle_retry_attempt(), "{throttler:?}");
        }
        // 51 tokens: no longer throttled.
        throttler.on_retry_failure(&RetryResult::Permanent(test_error()));
        assert!(!throttler.throttle_retry_attempt(), "{throttler:?}");
    }
}