use std::sync::Arc;
use crate::{
resil::{RedisPeriodLimiter, RedisTokenLimiter},
rpc::RpcRateLimiterConfig,
};
/// Handle to the RPC rate-limiting strategy chosen by configuration.
///
/// Cloning is cheap: every clone shares the same underlying mode through an [`Arc`].
#[derive(Clone, Debug)]
pub(crate) struct RpcRateLimiter {
    /// The active limiting strategy, shared by all clones of this handle.
    mode: Arc<RpcRateLimiterMode>,
}
/// Internal strategy selected from `RpcRateLimiterConfig`.
///
/// Limiter construction errors are kept as their string form instead of aborting
/// startup. Every request then reports that stored error, and `fail_open` decides
/// whether the request is let through.
#[derive(Debug)]
enum RpcRateLimiterMode {
    /// No limiting: every request is allowed.
    Disabled,
    /// Redis token limiter; each request asks for one unit.
    Token {
        /// The built limiter, or the message of the error that prevented building it.
        limiter: Result<RedisTokenLimiter, String>,
        /// Whether errors should let the request through (`ErrorOpen`) rather than block it.
        fail_open: bool,
    },
    /// Redis period limiter.
    Period {
        /// The built limiter, or the message of the error that prevented building it.
        limiter: Result<RedisPeriodLimiter, String>,
        /// Whether errors should let the request through (`ErrorOpen`) rather than block it.
        fail_open: bool,
    },
}
impl RpcRateLimiter {
    /// Builds a limiter handle from `config`.
    ///
    /// This never fails. If the Redis limiter cannot be constructed, its error is
    /// converted to a `String` and stored. [`Self::allow`] then reports it on every call,
    /// as `ErrorOpen` or `ErrorClosed` depending on the config's `fail_open` flag.
    pub(crate) fn new(config: RpcRateLimiterConfig) -> Self {
        let mode = Arc::new(match config {
            RpcRateLimiterConfig::Disabled => RpcRateLimiterMode::Disabled,
            RpcRateLimiterConfig::RedisToken(token_config) => {
                // Read the flag before `token_config` is moved into the constructor.
                let fail_open = token_config.fail_open;
                let limiter = RedisTokenLimiter::new(token_config).map_err(|err| err.to_string());
                RpcRateLimiterMode::Token { limiter, fail_open }
            }
            RpcRateLimiterConfig::RedisPeriod(period_config) => {
                // Read the flag before `period_config` is moved into the constructor.
                let fail_open = period_config.fail_open;
                let limiter =
                    RedisPeriodLimiter::new(period_config).map_err(|err| err.to_string());
                RpcRateLimiterMode::Period { limiter, fail_open }
            }
        });
        Self { mode }
    }

    /// Decides whether one request identified by `key` may proceed.
    ///
    /// The disabled mode always returns [`RpcRateLimitOutcome::Allowed`]. The other modes
    /// delegate to their Redis-backed limiter.
    pub(crate) async fn allow(&self, key: &str) -> RpcRateLimitOutcome {
        match &*self.mode {
            RpcRateLimiterMode::Disabled => RpcRateLimitOutcome::Allowed,
            RpcRateLimiterMode::Token { limiter, fail_open } => {
                allow_token(limiter, *fail_open, key).await
            }
            RpcRateLimiterMode::Period { limiter, fail_open } => {
                allow_period(limiter, *fail_open, key).await
            }
        }
    }
}
/// Result of asking the RPC rate limiter whether a request may proceed.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum RpcRateLimitOutcome {
    /// The request is within limits, or limiting is disabled.
    Allowed,
    /// The limiter reports the request exceeds the limit.
    Rejected,
    /// The limiter failed, but is configured to fail open, so the request may proceed.
    /// The underlying error is not carried in this variant.
    ErrorOpen,
    /// The limiter failed and is configured to fail closed. Carries the error message.
    ErrorClosed(String),
}
/// Asks the Redis token limiter to admit one request under `key`.
///
/// * `limiter` is the limiter built at startup, or the message of the error that
///   prevented building it.
/// * `fail_open` applies to any error, whether from construction or from the Redis call.
///   When it is `true` the error yields [`RpcRateLimitOutcome::ErrorOpen`]. Otherwise it
///   yields [`RpcRateLimitOutcome::ErrorClosed`] with the error message.
async fn allow_token(
    limiter: &Result<RedisTokenLimiter, String>,
    fail_open: bool,
    key: &str,
) -> RpcRateLimitOutcome {
    match limiter {
        Ok(ready) => classify_outcome(ready.try_allow_n(key, 1).await, fail_open),
        // Construction failed earlier; surface the stored message on every request.
        Err(error) => classify_outcome(Err(error), fail_open),
    }
}

/// Asks the Redis period limiter to admit one request under `key`.
///
/// * `limiter` is the limiter built at startup, or the message of the error that
///   prevented building it.
/// * `fail_open` applies to any error, whether from construction or from the Redis call.
///   When it is `true` the error yields [`RpcRateLimitOutcome::ErrorOpen`]. Otherwise it
///   yields [`RpcRateLimitOutcome::ErrorClosed`] with the error message.
async fn allow_period(
    limiter: &Result<RedisPeriodLimiter, String>,
    fail_open: bool,
    key: &str,
) -> RpcRateLimitOutcome {
    match limiter {
        Ok(ready) => classify_outcome(ready.try_allow(key).await, fail_open),
        // Construction failed earlier; surface the stored message on every request.
        Err(error) => classify_outcome(Err(error), fail_open),
    }
}

/// Maps a limiter verdict, or a failure, onto an [`RpcRateLimitOutcome`].
///
/// * `Ok(true)` admits the request and `Ok(false)` rejects it.
/// * Any error becomes `ErrorOpen` when `fail_open` is set.
/// * Otherwise the error becomes `ErrorClosed` carrying the error's text.
///
/// Both limiter kinds share this helper, so the fail-open/closed policy is defined once.
fn classify_outcome<E: ToString>(
    verdict: Result<bool, E>,
    fail_open: bool,
) -> RpcRateLimitOutcome {
    match verdict {
        Ok(true) => RpcRateLimitOutcome::Allowed,
        Ok(false) => RpcRateLimitOutcome::Rejected,
        Err(_) if fail_open => RpcRateLimitOutcome::ErrorOpen,
        Err(error) => RpcRateLimitOutcome::ErrorClosed(error.to_string()),
    }
}