use std::time::{Duration, SystemTime, UNIX_EPOCH};
use crate::cache_redis::{
RedisCacheConfig, RedisCacheError, RedisCacheResult, RedisClientFactory, RedisLuaScript,
};
/// Lua script that counts one request against a fixed window.
///
/// * `KEYS[1]` — counter key for the current window.
/// * `ARGV[1]` — quota: the maximum number of requests allowed in the window.
/// * `ARGV[2]` — TTL for the counter, in milliseconds (passed to `PEXPIRE`).
///
/// The counter is bumped with `INCR`. The TTL is applied only when `INCR` returns
/// `1`, which is normally the first request of the window. Returns `1` while the
/// count is within the quota and `0` once the count exceeds it.
// NOTE(review): a counter key that already exists without a TTL is never given one
// by this script; confirm nothing else writes keys under this prefix.
const PERIOD_LIMIT_SCRIPT: &str = r#"
local key = KEYS[1]
local quota = tonumber(ARGV[1])
local ttl = tonumber(ARGV[2])
local current = redis.call("INCR", key)
if current == 1 then
redis.call("PEXPIRE", key, ttl)
end
if current > quota then
return 0
end
return 1
"#;
/// Selects how a window's counter TTL is computed and which clock picks the window.
///
/// The window index is always derived from epoch milliseconds divided by the
/// period. Only `UtcOffset` shifts that clock; `None` and `Natural` differ only in
/// how long the counter key is kept.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum PeriodAlignment {
    /// Default. The counter key expires one full period after it is created.
    #[default]
    None,
    /// The counter key expires at the next multiple of the period on the Unix
    /// epoch clock (milliseconds).
    Natural,
    /// Like `Natural`, but window selection and expiry both use the epoch clock
    /// shifted by `seconds`.
    UtcOffset {
        /// Offset added to the epoch clock, in seconds; negative values shift it
        /// backwards.
        seconds: i32,
    },
}
/// Settings for [`RedisPeriodLimiter`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RedisPeriodLimiterConfig {
    /// Redis connection settings. Its `connect_timeout` and `command_timeout`
    /// bound each limiter check.
    pub redis: RedisCacheConfig,
    /// Prefix of every counter key; must not be blank.
    pub key_prefix: String,
    /// Length of one window; must be non-zero (sub-millisecond periods are
    /// treated as 1 ms).
    pub period: Duration,
    /// Requests allowed per key per window; must be greater than zero.
    pub quota: u32,
    /// When `true`, `allow` treats Redis errors as "allowed" instead of
    /// returning them.
    pub fail_open: bool,
    /// How windows and counter TTLs are aligned in time.
    pub alignment: PeriodAlignment,
}
impl Default for RedisPeriodLimiterConfig {
    /// The defaults are:
    ///
    /// * default Redis settings;
    /// * prefix `rs-zero:limit:period`;
    /// * 100 requests per 60-second window;
    /// * fail-closed, so errors are returned;
    /// * no alignment.
    fn default() -> Self {
        let redis = RedisCacheConfig::default();
        let key_prefix = String::from("rs-zero:limit:period");
        let period = Duration::from_secs(60);
        let quota = 100;
        Self {
            redis,
            key_prefix,
            period,
            quota,
            fail_open: false,
            alignment: PeriodAlignment::None,
        }
    }
}
impl RedisPeriodLimiterConfig {
    /// Returns this config with [`PeriodAlignment::Natural`] alignment.
    ///
    /// With this alignment, each counter key expires at the next period boundary
    /// on the Unix epoch clock. All other fields are left unchanged.
    #[must_use = "`aligned` consumes the config and returns the updated one"]
    pub fn aligned(mut self) -> Self {
        self.alignment = PeriodAlignment::Natural;
        self
    }

    /// Returns this config with [`PeriodAlignment::UtcOffset`] alignment.
    ///
    /// Windows are then computed on the epoch clock shifted by `seconds`; negative
    /// values shift it backwards. All other fields are left unchanged.
    #[must_use = "`aligned_to_utc_offset` consumes the config and returns the updated one"]
    pub fn aligned_to_utc_offset(mut self, seconds: i32) -> Self {
        self.alignment = PeriodAlignment::UtcOffset { seconds };
        self
    }
}
/// Fixed-window ("period") rate limiter that keeps one Redis counter per key and window.
#[derive(Clone, Debug)]
pub struct RedisPeriodLimiter {
    /// Settings the limiter was built from; validated in `new`.
    config: RedisPeriodLimiterConfig,
    /// Client used to obtain a connection for each check.
    client: redis::Client,
    /// Wrapper around the INCR/PEXPIRE Lua script that does the counting.
    script: RedisLuaScript,
}
impl RedisPeriodLimiter {
    /// Creates a limiter from `config`.
    ///
    /// The configuration is validated first. Then a Redis client is built through
    /// [`RedisClientFactory`] and the period-limit Lua script is wrapped.
    ///
    /// # Errors
    ///
    /// Returns the validation error (e.g. [`RedisCacheError::InvalidConfig`] for a
    /// blank key prefix or a zero quota/period). Also returns any error produced
    /// while building the Redis client.
    pub fn new(config: RedisPeriodLimiterConfig) -> RedisCacheResult<Self> {
        validate_config(&config)?;
        let client = RedisClientFactory::new(config.redis.clone())?.create_client()?;
        Ok(Self {
            config,
            client,
            script: RedisLuaScript::new(PERIOD_LIMIT_SCRIPT),
        })
    }

    /// Counts one request for `key` and reports whether it fits in the current window.
    ///
    /// Behaves like [`Self::try_allow`], except that when `fail_open` is set, any
    /// error is swallowed and the request is reported as allowed.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`Self::try_allow`] when `fail_open` is `false`.
    pub async fn allow(&self, key: &str) -> RedisCacheResult<bool> {
        match self.try_allow(key).await {
            // Deliberate fail-open: a Redis failure lets traffic through.
            Err(_) if self.config.fail_open => Ok(true),
            other => other,
        }
    }

    /// Counts one request for `key` in Redis and reports whether it is within quota.
    ///
    /// Returns `Ok(true)` when the script answers `1` (allowed), `Ok(false)` otherwise.
    ///
    /// # Errors
    ///
    /// * [`RedisCacheError::Timeout`] if connecting takes longer than
    ///   `redis.connect_timeout`.
    /// * [`RedisCacheError::Connection`] if the connection attempt fails.
    /// * Any error returned by the script invocation.
    pub async fn try_allow(&self, key: &str) -> RedisCacheResult<bool> {
        // NOTE(review): a connection is requested from the client on every call;
        // if this opens a fresh socket each time, consider caching a shared
        // connection. Confirm against the redis-rs `Client` documentation.
        let mut connection = tokio::time::timeout(
            self.config.redis.connect_timeout,
            self.client.get_multiplexed_async_connection(),
        )
        .await
        .map_err(|_| RedisCacheError::Timeout("connect".to_string()))?
        .map_err(|error| RedisCacheError::Connection(error.to_string()))?;

        // Sample the clock once so the window key and its TTL describe the same
        // window. Reading it twice can straddle a boundary and give one window's
        // key the TTL computed for the next window.
        let now = now_millis();
        let window_key = self.window_key(key, now);
        let ttl = self.ttl_millis(now);

        let allowed = self
            .script
            .invoke_async::<i64>(
                &mut connection,
                &[window_key],
                &[self.config.quota.to_string(), ttl.to_string()],
                self.config.redis.command_timeout,
            )
            .await?;
        Ok(allowed == 1)
    }

    /// Builds the Redis key `{key_prefix}:{key}:{window}` for the window containing `now`.
    ///
    /// `window` is the alignment-adjusted epoch millisecond count divided by the
    /// period.
    fn window_key(&self, key: &str, now: u64) -> String {
        // NOTE(review): the window index is epoch-based for every alignment,
        // including `PeriodAlignment::None`; `None` only changes the counter TTL.
        let period_ms = self.config.period.as_millis().max(1);
        let window = u128::from(aligned_epoch_millis(now, self.config.alignment)) / period_ms;
        format!("{}:{key}:{window}", self.config.key_prefix)
    }

    /// TTL in milliseconds for a counter created at `now_millis`, per the configured
    /// period and alignment.
    fn ttl_millis(&self, now_millis: u64) -> u64 {
        ttl_millis(self.config.period, self.config.alignment, now_millis)
    }
}
/// Checks that `config` can drive a period limiter.
///
/// The Redis settings are validated first. Then the key prefix must be non-blank,
/// and the quota and period must both be non-zero.
///
/// # Errors
///
/// Propagates the Redis validation error. Otherwise returns
/// `RedisCacheError::InvalidConfig` describing the first failed check.
fn validate_config(config: &RedisPeriodLimiterConfig) -> RedisCacheResult<()> {
    config.redis.validate()?;

    let problem = if config.key_prefix.trim().is_empty() {
        Some("period limiter key_prefix is required")
    } else if config.quota == 0 || config.period.is_zero() {
        Some("period limiter quota and period must be greater than zero")
    } else {
        None
    };

    match problem {
        Some(message) => Err(RedisCacheError::InvalidConfig(message.to_string())),
        None => Ok(()),
    }
}
/// Milliseconds a counter key created at `now_millis` should live.
///
/// * `PeriodAlignment::None`: always the full period.
/// * `Natural` / `UtcOffset`: the time left until the next period boundary of the
///   (offset-adjusted) epoch clock. Exactly on a boundary, this is the full period.
///
/// Sub-millisecond periods are treated as 1 ms.
fn ttl_millis(period: Duration, alignment: PeriodAlignment, now_millis: u64) -> u64 {
    let period_ms = period.as_millis().max(1) as u64;
    match alignment {
        PeriodAlignment::None => period_ms,
        PeriodAlignment::Natural | PeriodAlignment::UtcOffset { .. } => {
            // The remainder lies in `0..period_ms`, so this yields `1..=period_ms`.
            // A zero remainder (on a boundary) naturally gives the full period.
            let into_window = aligned_epoch_millis(now_millis, alignment) % period_ms;
            period_ms - into_window
        }
    }
}
/// Returns `now_millis` on the clock used to place window boundaries.
///
/// `None` and `Natural` use the Unix epoch clock unchanged. `UtcOffset { seconds }`
/// adds `seconds * 1000` milliseconds, so negative offsets move the clock backwards.
///
/// The shifted value saturates at `0` and `u64::MAX` instead of wrapping.
fn aligned_epoch_millis(now_millis: u64, alignment: PeriodAlignment) -> u64 {
    match alignment {
        PeriodAlignment::None | PeriodAlignment::Natural => now_millis,
        PeriodAlignment::UtcOffset { seconds } => {
            // Widen to i128 so neither the multiplication nor the addition can overflow.
            let shifted = i128::from(now_millis) + i128::from(seconds) * 1000;
            // `as u64` would wrap values above `u64::MAX` to small numbers, so
            // clamp negatives to 0 and saturate the upper end instead.
            u64::try_from(shifted.max(0)).unwrap_or(u64::MAX)
        }
    }
}
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// A system clock set before the epoch yields `0`. The `u128` millisecond count is
/// narrowed to `u64` with an `as` cast.
fn now_millis() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |since_epoch| since_epoch.as_millis() as u64)
}
/// Test-only, in-process counter for a single window.
///
/// It follows the same increment-then-compare rule as `PERIOD_LIMIT_SCRIPT`.
#[cfg(test)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
struct LocalPeriodWindow {
    /// Requests counted so far; saturates at `u32::MAX`.
    count: u32,
    /// Maximum number of requests allowed in the window.
    quota: u32,
}

#[cfg(test)]
impl LocalPeriodWindow {
    /// Counts one request; returns `true` while the running count is within `quota`.
    fn allow(&mut self) -> bool {
        let next = self.count.saturating_add(1);
        self.count = next;
        self.quota >= next
    }
}
#[cfg(test)]
mod tests {
    use super::{LocalPeriodWindow, PeriodAlignment, RedisPeriodLimiterConfig, ttl_millis};
    use std::time::Duration;

    /// One minute: the period used by every TTL case below.
    const MINUTE: Duration = Duration::from_secs(60);

    #[test]
    fn local_period_window_rejects_after_quota() {
        let mut window = LocalPeriodWindow { count: 0, quota: 2 };
        // Array elements are evaluated left to right, so this is three
        // consecutive calls.
        let decisions = [window.allow(), window.allow(), window.allow()];
        assert_eq!(decisions, [true, true, false]);
    }

    #[test]
    fn default_ttl_is_fixed_period() {
        let ttl = ttl_millis(MINUTE, PeriodAlignment::None, 12_345);
        assert_eq!(ttl, 60_000);
    }

    #[test]
    fn natural_alignment_returns_remaining_period() {
        // 61 s is 1 s into the second minute, leaving 59 s.
        let ttl = ttl_millis(MINUTE, PeriodAlignment::Natural, 61_000);
        assert_eq!(ttl, 59_000);
    }

    #[test]
    fn utc_offset_alignment_shifts_boundary() {
        // 20 s shifted by +30 s is 50 s into the minute, leaving 10 s.
        let alignment = PeriodAlignment::UtcOffset { seconds: 30 };
        assert_eq!(ttl_millis(MINUTE, alignment, 20_000), 10_000);
    }

    #[test]
    fn aligned_builder_preserves_other_fields() {
        let base = RedisPeriodLimiterConfig {
            fail_open: true,
            ..Default::default()
        };
        let config = base.aligned();
        assert!(config.fail_open);
        assert_eq!(config.alignment, PeriodAlignment::Natural);
    }
}