use std::{
sync::{
Arc,
atomic::{AtomicBool, Ordering},
},
time::{Duration, SystemTime, UNIX_EPOCH},
};
use tokio::sync::Mutex;
use crate::cache_redis::{
RedisCacheConfig, RedisCacheError, RedisCacheResult, RedisClientFactory, RedisLuaScript,
};
/// Lua token-bucket step evaluated inside Redis against the hash stored at `KEYS[1]`.
///
/// Arguments (passed as strings by `RedisTokenLimiter::try_allow_n`):
/// * `ARGV[1]` `now` - caller's wall-clock time in milliseconds since the Unix epoch;
/// * `ARGV[2]` `rate` - tokens added per second;
/// * `ARGV[3]` `burst` - bucket capacity;
/// * `ARGV[4]` `requested` - tokens to take;
/// * `ARGV[5]` `ttl` - key expiry in milliseconds (applied with `PEXPIRE`).
///
/// Hash fields: `tokens` (tokens left) and `ts` (time of the last update). A missing key
/// starts full (`burst`) at `now`. Elapsed time is clamped at zero, so a caller whose clock
/// is behind `ts` refills nothing. Refill is capped at `burst`. The refreshed state and the
/// expiry are written back on every call, including rejected ones.
///
/// Returns `1` when `requested` tokens were taken, `0` otherwise.
///
/// NOTE(review): `now` comes from the client, so correctness assumes the clocks of all
/// clients sharing a key are roughly in sync - confirm this is acceptable for deployment.
const TOKEN_BUCKET_SCRIPT: &str = r#"
local key = KEYS[1]
local now = tonumber(ARGV[1])
local rate = tonumber(ARGV[2])
local burst = tonumber(ARGV[3])
local requested = tonumber(ARGV[4])
local ttl = tonumber(ARGV[5])
local values = redis.call("HMGET", key, "tokens", "ts")
local tokens = tonumber(values[1]) or burst
local ts = tonumber(values[2]) or now
local elapsed = math.max(0, now - ts) / 1000
tokens = math.min(burst, tokens + elapsed * rate)
local allowed = 0
if tokens >= requested then
tokens = tokens - requested
allowed = 1
end
redis.call("HMSET", key, "tokens", tokens, "ts", now)
redis.call("PEXPIRE", key, ttl)
return allowed
"#;
/// Controls the process-local fallback ("rescue") used when Redis cannot be reached.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TokenLimiterRescueConfig {
    /// When `true`, a failed Redis call is answered by an in-process token bucket instead of
    /// surfacing the error.
    pub enabled: bool,
    /// Delay between health-check `PING`s while Redis is considered down. Must be non-zero
    /// when `enabled` is set.
    pub ping_interval: Duration,
}

impl TokenLimiterRescueConfig {
    /// Health-check interval shared by `Default` and `go_zero_defaults`.
    const DEFAULT_PING_INTERVAL: Duration = Duration::from_millis(100);

    /// go-zero compatible settings: rescue switched on, 100 ms ping interval.
    pub fn go_zero_defaults() -> Self {
        Self {
            enabled: true,
            ..Default::default()
        }
    }
}

impl Default for TokenLimiterRescueConfig {
    /// Rescue switched off, 100 ms ping interval.
    fn default() -> Self {
        Self {
            enabled: false,
            ping_interval: Self::DEFAULT_PING_INTERVAL,
        }
    }
}
/// Settings for `RedisTokenLimiter`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RedisTokenLimiterConfig {
    /// Connection settings for the Redis instance that stores the buckets.
    pub redis: RedisCacheConfig,
    /// Prepended to every limiter key as `{key_prefix}:{key}`. Must not be blank.
    pub key_prefix: String,
    /// Tokens added to a bucket per second. Must be greater than zero.
    pub rate_per_second: u32,
    /// Maximum number of tokens a bucket can hold. Must be greater than zero.
    pub burst: u32,
    /// Grant requests when Redis fails. Only consulted when rescue is disabled.
    pub fail_open: bool,
    /// Process-local fallback settings.
    pub rescue: TokenLimiterRescueConfig,
    /// Expiry applied to each bucket key in Redis (sent in milliseconds, at least 1 ms).
    pub ttl: Duration,
}

impl RedisTokenLimiterConfig {
    const DEFAULT_KEY_PREFIX: &'static str = "rs-zero:limit:token";
    const DEFAULT_RATE_PER_SECOND: u32 = 100;
    const DEFAULT_BURST: u32 = 100;
    const DEFAULT_TTL: Duration = Duration::from_secs(60);

    /// Defaults with go-zero style rescue enabled (see
    /// `TokenLimiterRescueConfig::go_zero_defaults`).
    pub fn go_zero_defaults() -> Self {
        Self {
            rescue: TokenLimiterRescueConfig::go_zero_defaults(),
            ..Default::default()
        }
    }
}

impl Default for RedisTokenLimiterConfig {
    /// 100 tokens/s with a burst of 100, fail-closed, rescue disabled, 60 s key TTL.
    fn default() -> Self {
        Self {
            redis: Default::default(),
            key_prefix: String::from(Self::DEFAULT_KEY_PREFIX),
            rate_per_second: Self::DEFAULT_RATE_PER_SECOND,
            burst: Self::DEFAULT_BURST,
            fail_open: false,
            rescue: Default::default(),
            ttl: Self::DEFAULT_TTL,
        }
    }
}
/// Token-bucket rate limiter whose bucket state lives in Redis under `{key_prefix}:{key}`,
/// so every process using the same Redis and prefix draws from the same buckets.
///
/// When `config.rescue.enabled` is set, a failed Redis call switches the limiter to a
/// process-local bucket and starts a background task. That task pings Redis every
/// `ping_interval` and, once Redis answers, lets requests go back to Redis. Clones share
/// the local bucket and the health flags.
#[derive(Debug, Clone)]
pub struct RedisTokenLimiter {
    config: RedisTokenLimiterConfig,
    client: redis::Client,
    script: RedisLuaScript,
    /// Process-local bucket consulted while Redis is considered unavailable.
    /// NOTE(review): this is a single bucket shared by *all* keys, so in rescue mode the
    /// combined traffic of every key is limited to `rate_per_second` - confirm intended.
    rescue: Arc<Mutex<LocalTokenBucket>>,
    /// `false` from a failed Redis call until the health monitor sees a `PONG`.
    redis_alive: Arc<AtomicBool>,
    /// `true` while a health-monitor task is registered; prevents spawning duplicates.
    monitor_started: Arc<AtomicBool>,
}

impl RedisTokenLimiter {
    /// Builds a limiter from `config`.
    ///
    /// # Errors
    /// Returns an error when `config` fails validation (`validate_config`) or the Redis client
    /// cannot be created.
    pub fn new(config: RedisTokenLimiterConfig) -> RedisCacheResult<Self> {
        validate_config(&config)?;
        let client = RedisClientFactory::new(config.redis.clone())?.create_client()?;
        Ok(Self {
            rescue: Arc::new(Mutex::new(LocalTokenBucket::new(
                config.burst,
                now_millis(),
            ))),
            redis_alive: Arc::new(AtomicBool::new(true)),
            monitor_started: Arc::new(AtomicBool::new(false)),
            config,
            client,
            script: RedisLuaScript::new(TOKEN_BUCKET_SCRIPT),
        })
    }

    /// Whether Redis is currently considered reachable.
    ///
    /// Only changes when rescue is enabled: it is cleared by a failed Redis call and set
    /// again by the health monitor.
    pub fn redis_alive(&self) -> bool {
        self.redis_alive.load(Ordering::Acquire)
    }

    /// Shorthand for [`Self::allow_n`] with a single token.
    pub async fn allow(&self, key: &str) -> RedisCacheResult<bool> {
        self.allow_n(key, 1).await
    }

    /// Tries to take `requested` tokens from the bucket identified by `key`.
    ///
    /// Returns `Ok(true)` when the tokens were granted. A request for zero tokens is always
    /// granted without contacting Redis.
    ///
    /// Failure handling, in priority order:
    /// * rescue enabled: while Redis is marked down, or when this call fails, the local
    ///   bucket decides. A failure also starts the health monitor;
    /// * `fail_open`: a Redis error grants the request;
    /// * otherwise the Redis error is returned.
    ///
    /// # Errors
    /// Returns the Redis error only when neither rescue nor `fail_open` is enabled.
    pub async fn allow_n(&self, key: &str, requested: u32) -> RedisCacheResult<bool> {
        if requested == 0 {
            return Ok(true);
        }
        if self.config.rescue.enabled && !self.redis_alive() {
            return self.allow_rescue(requested).await;
        }
        let result = self.try_allow_n(key, requested).await;
        match result {
            Ok(allowed) => Ok(allowed),
            Err(_) if self.config.rescue.enabled => {
                self.start_monitor();
                self.allow_rescue(requested).await
            }
            Err(_) if self.config.fail_open => Ok(true),
            Err(error) => Err(error),
        }
    }

    /// Runs the token-bucket script in Redis with no fallback.
    ///
    /// NOTE(review): this calls `get_multiplexed_async_connection` on every request. Caching
    /// the connection would avoid a reconnect per call - verify against the redis crate docs.
    ///
    /// # Errors
    /// * `RedisCacheError::Timeout("connect")` if connecting exceeds `connect_timeout`;
    /// * `RedisCacheError::Connection` if the connection cannot be established;
    /// * any error returned by the script invocation (bounded by `command_timeout`).
    pub async fn try_allow_n(&self, key: &str, requested: u32) -> RedisCacheResult<bool> {
        let mut connection = tokio::time::timeout(
            self.config.redis.connect_timeout,
            self.client.get_multiplexed_async_connection(),
        )
        .await
        .map_err(|_| RedisCacheError::Timeout("connect".to_string()))?
        .map_err(|error| RedisCacheError::Connection(error.to_string()))?;
        let key = format!("{}:{key}", self.config.key_prefix);
        let allowed = self
            .script
            .invoke_async::<i64>(
                &mut connection,
                &[key],
                &[
                    now_millis().to_string(),
                    self.config.rate_per_second.to_string(),
                    self.config.burst.to_string(),
                    requested.to_string(),
                    // PEXPIRE needs a positive value; a zero TTL is clamped to 1 ms.
                    (self.config.ttl.as_millis().max(1) as u64).to_string(),
                ],
                self.config.redis.command_timeout,
            )
            .await?;
        Ok(allowed == 1)
    }

    /// Decides with the process-local bucket. Never returns an error.
    async fn allow_rescue(&self, requested: u32) -> RedisCacheResult<bool> {
        let mut bucket = self.rescue.lock().await;
        Ok(bucket.allow(
            now_millis(),
            self.config.rate_per_second,
            self.config.burst,
            requested,
        ))
    }

    /// Marks Redis as unavailable and, unless a monitor is already registered, spawns a task
    /// that pings Redis every `ping_interval` until it answers. The task then marks Redis
    /// available again.
    ///
    /// Must be called from within a Tokio runtime (uses `tokio::spawn`).
    fn start_monitor(&self) {
        self.redis_alive.store(false, Ordering::Release);
        if self
            .monitor_started
            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
            .is_err()
        {
            return;
        }
        let client = self.client.clone();
        let connect_timeout = self.config.redis.connect_timeout;
        let command_timeout = self.config.redis.command_timeout;
        let ping_interval = self.config.rescue.ping_interval;
        // Weak reference: once every limiter clone is dropped the task stops, instead of
        // pinging (and holding the client) forever if Redis never comes back.
        let alive_flag = Arc::downgrade(&self.redis_alive);
        let monitor_started = Arc::clone(&self.monitor_started);
        tokio::spawn(async move {
            let mut interval = tokio::time::interval(ping_interval);
            // A failed ping can take up to connect_timeout + command_timeout. With the default
            // `Burst` behaviour the missed ticks would fire back-to-back and the monitor would
            // ping continuously instead of every `ping_interval`.
            interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
            loop {
                interval.tick().await;
                let Some(redis_alive) = alive_flag.upgrade() else {
                    return;
                };
                if redis_ping(&client, connect_timeout, command_timeout).await {
                    // Unregister the monitor BEFORE publishing recovery. In the reverse order,
                    // a failure landing between the two stores would set `redis_alive = false`,
                    // see a monitor still registered and return. The limiter would then be
                    // stuck in rescue mode with no monitor left to bring it back.
                    monitor_started.store(false, Ordering::Release);
                    redis_alive.store(true, Ordering::Release);
                    return;
                }
            }
        });
    }
}
/// Health probe used by the rescue monitor.
///
/// Returns `true` only when a connection opens within `connect_timeout` and `PING` replies
/// `"PONG"` within `command_timeout`. Any timeout, connection error, command error or other
/// reply yields `false`.
async fn redis_ping(
    client: &redis::Client,
    connect_timeout: Duration,
    command_timeout: Duration,
) -> bool {
    let connect = client.get_multiplexed_async_connection();
    let mut connection = match tokio::time::timeout(connect_timeout, connect).await {
        Ok(Ok(established)) => established,
        _ => return false,
    };
    let ping = redis::cmd("PING");
    let reply = ping.query_async::<String>(&mut connection);
    match tokio::time::timeout(command_timeout, reply).await {
        Ok(Ok(text)) => text == "PONG",
        _ => false,
    }
}
/// Rejects limiter configurations that cannot work.
///
/// Checks, in order:
/// 1. the Redis settings (`RedisCacheConfig::validate`);
/// 2. a non-blank `key_prefix`;
/// 3. non-zero `rate_per_second` and `burst`;
/// 4. when rescue is enabled, a non-zero `rescue.ping_interval`.
///
/// # Errors
/// Propagates the Redis settings error unchanged. Otherwise returns
/// `RedisCacheError::InvalidConfig` describing the first failing check.
fn validate_config(config: &RedisTokenLimiterConfig) -> RedisCacheResult<()> {
    config.redis.validate()?;
    let reject = |reason: &str| -> RedisCacheResult<()> {
        Err(RedisCacheError::InvalidConfig(reason.to_string()))
    };
    if config.key_prefix.trim().is_empty() {
        return reject("token limiter key_prefix is required");
    }
    if [config.rate_per_second, config.burst].contains(&0) {
        return reject("token limiter rate_per_second and burst must be greater than zero");
    }
    let rescue = &config.rescue;
    if rescue.enabled && rescue.ping_interval.is_zero() {
        return reject("token limiter rescue ping_interval must be greater than zero");
    }
    Ok(())
}
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// A system clock set before 1970 yields `0` instead of failing.
fn now_millis() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
/// In-process token bucket used while Redis is unavailable.
///
/// Uses the same arithmetic as the Redis Lua script: refill `elapsed_seconds * rate`, cap at
/// `burst`, then take `requested` tokens if enough are available.
#[derive(Debug, Clone, Copy, PartialEq)]
struct LocalTokenBucket {
    /// Tokens currently available; fractional refills are kept.
    tokens: f64,
    /// Millisecond timestamp of the last refill (callers pass `now_millis()`).
    last_millis: u64,
}

impl LocalTokenBucket {
    /// Creates a full bucket (`burst` tokens) last refilled at `now_millis`.
    fn new(burst: u32, now_millis: u64) -> Self {
        let tokens = f64::from(burst);
        Self {
            tokens,
            last_millis: now_millis,
        }
    }

    /// Refills the bucket up to `now_millis`, then takes `requested` tokens if possible.
    ///
    /// A `now_millis` earlier than the last refill adds no tokens. It still becomes the new
    /// refill timestamp. A rejected request leaves the refilled balance untouched.
    /// Returns `true` when the tokens were taken.
    fn allow(&mut self, now_millis: u64, rate_per_second: u32, burst: u32, requested: u32) -> bool {
        let elapsed_secs = now_millis.saturating_sub(self.last_millis) as f64 / 1000.0;
        let available =
            (self.tokens + f64::from(rate_per_second) * elapsed_secs).min(f64::from(burst));
        self.last_millis = now_millis;
        let wanted = f64::from(requested);
        let granted = available >= wanted;
        self.tokens = if granted { available - wanted } else { available };
        granted
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the pieces that do not need a live Redis: the local rescue bucket,
    //! config defaults and config validation.
    use super::{
        LocalTokenBucket, RedisTokenLimiterConfig, TokenLimiterRescueConfig, validate_config,
    };
    use crate::cache_redis::RedisCacheError;

    #[test]
    fn local_token_bucket_refills_over_time() {
        let mut bucket = LocalTokenBucket {
            tokens: 2.0,
            last_millis: 0,
        };
        assert!(bucket.allow(0, 2, 2, 2));
        assert!(!bucket.allow(0, 2, 2, 1));
        assert!(bucket.allow(500, 2, 2, 1));
    }

    #[test]
    fn local_token_bucket_caps_refill_at_burst() {
        let mut bucket = LocalTokenBucket::new(2, 0);
        // Ten seconds at 2 tokens/s would be 20 tokens, but the bucket holds at most 2.
        assert!(bucket.allow(10_000, 2, 2, 2));
        assert!(!bucket.allow(10_000, 2, 2, 1));
    }

    #[test]
    fn local_token_bucket_rejects_requests_above_burst() {
        let mut bucket = LocalTokenBucket::new(2, 0);
        assert!(!bucket.allow(0, 2, 2, 3));
        // The rejected request consumed nothing.
        assert!(bucket.allow(0, 2, 2, 2));
    }

    #[test]
    fn go_zero_defaults_enable_rescue() {
        let config = RedisTokenLimiterConfig::go_zero_defaults();
        assert!(config.rescue.enabled);
        assert!(!config.fail_open);
    }

    #[test]
    fn default_configs_are_valid() {
        assert!(validate_config(&RedisTokenLimiterConfig::default()).is_ok());
        assert!(validate_config(&RedisTokenLimiterConfig::go_zero_defaults()).is_ok());
    }

    #[test]
    fn rescue_ping_interval_must_be_positive() {
        let config = RedisTokenLimiterConfig {
            rescue: TokenLimiterRescueConfig {
                enabled: true,
                ping_interval: std::time::Duration::ZERO,
            },
            ..RedisTokenLimiterConfig::default()
        };
        let error = validate_config(&config).expect_err("invalid config");
        assert!(
            matches!(error, RedisCacheError::InvalidConfig(message) if message.contains("ping_interval"))
        );
    }

    #[test]
    fn rate_and_burst_must_be_positive() {
        for (rate_per_second, burst) in [(0, 10), (10, 0)] {
            let config = RedisTokenLimiterConfig {
                rate_per_second,
                burst,
                ..RedisTokenLimiterConfig::default()
            };
            let error = validate_config(&config).expect_err("invalid config");
            assert!(matches!(
                error,
                RedisCacheError::InvalidConfig(message) if message.contains("rate_per_second")
            ));
        }
    }

    #[test]
    fn blank_key_prefix_is_rejected() {
        let config = RedisTokenLimiterConfig {
            key_prefix: "   ".to_string(),
            ..RedisTokenLimiterConfig::default()
        };
        let error = validate_config(&config).expect_err("invalid config");
        assert!(matches!(
            error,
            RedisCacheError::InvalidConfig(message) if message.contains("key_prefix")
        ));
    }
}