rs-zero 0.2.6

Rust-first microservice framework inspired by go-zero engineering practices
Documentation
use std::time::{Duration, SystemTime, UNIX_EPOCH};

use crate::cache_redis::{
    RedisCacheConfig, RedisCacheError, RedisCacheResult, RedisClientFactory, RedisLuaScript,
};

/// Lua script implementing one fixed-window check.
///
/// * `KEYS[1]` — the window key (see `RedisPeriodLimiter::window_key`).
/// * `ARGV[1]` — the per-window quota.
/// * `ARGV[2]` — the TTL in milliseconds applied to a freshly created window key.
///
/// The counter is `INCR`ed on every call, including calls that end up rejected. `PEXPIRE` is
/// only issued when the counter was just created (`current == 1`), so later hits never extend
/// the window. Returns `1` when the new count is within the quota and `0` otherwise. Doing
/// `INCR` and `PEXPIRE` inside one script relies on Redis running a script without
/// interleaving other clients' commands (documented Redis scripting behavior).
const PERIOD_LIMIT_SCRIPT: &str = r#"
local key = KEYS[1]
local quota = tonumber(ARGV[1])
local ttl = tonumber(ARGV[2])
local current = redis.call("INCR", key)
if current == 1 then
  redis.call("PEXPIRE", key, ttl)
end
if current > quota then
  return 0
end
return 1
"#;

/// Fixed-window alignment behavior for [`RedisPeriodLimiter`].
///
/// For every variant, window keys are indexed by `clock_millis / period_millis`. The variant
/// controls the TTL given to a freshly created window key. For `UtcOffset`, it also selects
/// the shifted clock used for both the index and the TTL.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum PeriodAlignment {
    /// Give each new window key a TTL equal to the full configured period.
    #[default]
    None,
    /// Expire the current window key at the next multiple of the period since the Unix epoch.
    Natural,
    /// Like `Natural`, but boundaries are computed on the epoch clock shifted by `seconds`.
    /// For example, `28_800` (UTC+8) with a one-day period rolls over at local midnight in
    /// UTC+8. A shifted clock that would fall before the epoch is clamped to zero.
    UtcOffset { seconds: i32 },
}

/// Redis-backed fixed-window period limiter configuration.
///
/// Usually built with struct-update syntax over [`Default`], optionally followed by the
/// [`aligned`](Self::aligned) or [`aligned_to_utc_offset`](Self::aligned_to_utc_offset) builders.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RedisPeriodLimiterConfig {
    /// Redis connection configuration; it also supplies the connect and command timeouts.
    pub redis: RedisCacheConfig,
    /// Namespace prefix for limiter keys. Keys take the form `{key_prefix}:{key}:{window}`.
    /// Must not be blank.
    pub key_prefix: String,
    /// Window length, used at whole-millisecond granularity (truncated, minimum 1ms).
    pub period: Duration,
    /// Maximum allowed requests per window; must be greater than zero.
    pub quota: u32,
    /// When `true`, [`RedisPeriodLimiter::allow`] returns `Ok(true)` on any error instead of
    /// failing closed. [`RedisPeriodLimiter::try_allow`] never applies this flag.
    pub fail_open: bool,
    /// How window TTLs align to period boundaries; see [`PeriodAlignment`].
    pub alignment: PeriodAlignment,
}

impl Default for RedisPeriodLimiterConfig {
    fn default() -> Self {
        Self {
            redis: RedisCacheConfig::default(),
            key_prefix: "rs-zero:limit:period".to_string(),
            period: Duration::from_secs(60),
            quota: 100,
            fail_open: false,
            alignment: PeriodAlignment::default(),
        }
    }
}

impl RedisPeriodLimiterConfig {
    /// Returns this config with [`PeriodAlignment::Natural`] alignment; all other fields are
    /// kept unchanged.
    ///
    /// `self` is consumed, so the returned value must be used; discarding it loses the change.
    #[must_use]
    pub fn aligned(self) -> Self {
        Self {
            alignment: PeriodAlignment::Natural,
            ..self
        }
    }

    /// Returns this config aligned on the epoch clock shifted by `seconds`
    /// ([`PeriodAlignment::UtcOffset`]); all other fields are kept unchanged.
    ///
    /// `self` is consumed, so the returned value must be used; discarding it loses the change.
    #[must_use]
    pub fn aligned_to_utc_offset(self, seconds: i32) -> Self {
        Self {
            alignment: PeriodAlignment::UtcOffset { seconds },
            ..self
        }
    }
}

/// Redis-backed fixed-window period limiter.
///
/// Construct it with [`RedisPeriodLimiter::new`], which validates the configuration first.
#[derive(Debug, Clone)]
pub struct RedisPeriodLimiter {
    /// Validated limiter configuration.
    config: RedisPeriodLimiterConfig,
    /// Redis client; `try_allow` requests a multiplexed async connection from it on each call.
    client: redis::Client,
    /// Wrapper around the fixed-window Lua counter script.
    script: RedisLuaScript,
}

impl RedisPeriodLimiter {
    /// Creates a period limiter after validating the local configuration.
    ///
    /// # Errors
    ///
    /// Returns an error when the configuration is invalid (blank key prefix, zero quota or
    /// period, or invalid Redis settings) or when the Redis client cannot be created.
    pub fn new(config: RedisPeriodLimiterConfig) -> RedisCacheResult<Self> {
        validate_config(&config)?;
        let client = RedisClientFactory::new(config.redis.clone())?.create_client()?;
        Ok(Self {
            config,
            client,
            script: RedisLuaScript::new(PERIOD_LIMIT_SCRIPT),
        })
    }

    /// Returns whether one request for `key` is allowed, honoring `fail_open`.
    ///
    /// # Errors
    ///
    /// When `fail_open` is `false`, returns the same errors as [`Self::try_allow`]. When it is
    /// `true`, any error is converted into `Ok(true)`.
    pub async fn allow(&self, key: &str) -> RedisCacheResult<bool> {
        match self.try_allow(key).await {
            Ok(allowed) => Ok(allowed),
            // Fail-open: a Redis problem lets the request through instead of blocking traffic.
            Err(_) if self.config.fail_open => Ok(true),
            Err(error) => Err(error),
        }
    }

    /// Counts one request for `key` against the current window, without fail-open handling.
    ///
    /// Returns `Ok(true)` when the request is within the quota. Rejected requests still
    /// increment the window counter.
    ///
    /// # Errors
    ///
    /// * [`RedisCacheError::Timeout`] when connecting exceeds `redis.connect_timeout`.
    /// * [`RedisCacheError::Connection`] when the connection attempt fails.
    /// * Any error returned by the script invocation, which is bounded by
    ///   `redis.command_timeout`.
    pub async fn try_allow(&self, key: &str) -> RedisCacheResult<bool> {
        // NOTE(review): a new multiplexed connection is requested on every call; caching one
        // on the limiter may be worth considering if this shows up in profiles.
        let mut connection = tokio::time::timeout(
            self.config.redis.connect_timeout,
            self.client.get_multiplexed_async_connection(),
        )
        .await
        .map_err(|_| RedisCacheError::Timeout("connect".to_string()))?
        .map_err(|error| RedisCacheError::Connection(error.to_string()))?;

        // Sample the clock once so the window key and its TTL describe the same window. Two
        // separate reads could straddle a boundary and pair window N's key with window N+1's TTL.
        let now = now_millis();
        let window_key = self.window_key(key, now);
        let allowed = self
            .script
            .invoke_async::<i64>(
                &mut connection,
                &[window_key],
                &[
                    self.config.quota.to_string(),
                    self.ttl_millis(now).to_string(),
                ],
                self.config.redis.command_timeout,
            )
            .await?;
        Ok(allowed == 1)
    }

    /// Builds the Redis key for the window that contains `now_millis`.
    ///
    /// The window index is the alignment-adjusted epoch clock divided by the period in whole
    /// milliseconds (at least 1ms), so keys roll over on period boundaries for every alignment.
    fn window_key(&self, key: &str, now_millis: u64) -> String {
        let period_ms = self.config.period.as_millis().max(1);
        let clock = u128::from(aligned_epoch_millis(now_millis, self.config.alignment));
        let window = clock / period_ms;
        format!("{}:{key}:{window}", self.config.key_prefix)
    }

    /// TTL in milliseconds for a window key created at `now_millis` under this config.
    fn ttl_millis(&self, now_millis: u64) -> u64 {
        ttl_millis(self.config.period, self.config.alignment, now_millis)
    }
}

/// Validates a limiter configuration before any Redis client is created.
///
/// # Errors
///
/// Returns the error from the Redis config's own `validate`. Otherwise returns
/// `RedisCacheError::InvalidConfig` when the key prefix is blank, `quota` or `period` is zero,
/// or `period` is shorter than one millisecond.
fn validate_config(config: &RedisPeriodLimiterConfig) -> RedisCacheResult<()> {
    config.redis.validate()?;
    if config.key_prefix.trim().is_empty() {
        return Err(RedisCacheError::InvalidConfig(
            "period limiter key_prefix is required".to_string(),
        ));
    }
    if config.quota == 0 || config.period.is_zero() {
        return Err(RedisCacheError::InvalidConfig(
            "period limiter quota and period must be greater than zero".to_string(),
        ));
    }
    // Window indices and PEXPIRE TTLs are computed in whole milliseconds. A sub-millisecond
    // period would silently be rounded up to 1ms and enforce a different rate than configured.
    if config.period.as_millis() == 0 {
        return Err(RedisCacheError::InvalidConfig(
            "period limiter period must be at least one millisecond".to_string(),
        ));
    }
    Ok(())
}

/// Computes the Redis `PEXPIRE` TTL, in milliseconds, for a window key created at `now_millis`.
///
/// * `PeriodAlignment::None` always returns the full period.
/// * `PeriodAlignment::Natural` and `PeriodAlignment::UtcOffset` return the time remaining
///   until the next period boundary of the alignment-adjusted clock. Exactly on a boundary,
///   the full period is returned.
///
/// The period is taken in whole milliseconds, saturating at `u64::MAX` and clamped to at least
/// 1ms, so the result is always in `1..=period_millis`.
fn ttl_millis(period: Duration, alignment: PeriodAlignment, now_millis: u64) -> u64 {
    // Narrow with saturation *before* clamping: `as u64` on the u128 count can wrap to 0,
    // which would make the `%` below panic.
    let period_millis = u64::try_from(period.as_millis())
        .unwrap_or(u64::MAX)
        .max(1);
    match alignment {
        PeriodAlignment::None => period_millis,
        PeriodAlignment::Natural | PeriodAlignment::UtcOffset { .. } => {
            let elapsed = aligned_epoch_millis(now_millis, alignment) % period_millis;
            // `elapsed < period_millis`, so this is never zero; at a boundary it is the full period.
            period_millis - elapsed
        }
    }
}

/// Returns the clock value, in milliseconds, used to place `now_millis` into a window.
///
/// `PeriodAlignment::None` and `PeriodAlignment::Natural` use the epoch clock unchanged.
/// `PeriodAlignment::UtcOffset { seconds }` shifts it by `seconds * 1000` milliseconds. The
/// shifted value saturates to the `u64` range: values before the epoch become `0` and values
/// past `u64::MAX` become `u64::MAX`.
fn aligned_epoch_millis(now_millis: u64, alignment: PeriodAlignment) -> u64 {
    match alignment {
        PeriodAlignment::None | PeriodAlignment::Natural => now_millis,
        PeriodAlignment::UtcOffset { seconds } => {
            // i128 holds any u64 plus any i32 offset in milliseconds without overflow.
            let shifted = i128::from(now_millis) + i128::from(seconds) * 1000;
            // Saturate rather than `as`-truncate, so a large positive shift cannot wrap around.
            u64::try_from(shifted.max(0)).unwrap_or(u64::MAX)
        }
    }
}

/// Returns the current wall-clock time in milliseconds since the Unix epoch.
///
/// If the system clock reports a time before the epoch, `0` is returned. The `u128`
/// millisecond count is narrowed to `u64` with an `as` cast.
fn now_millis() -> u64 {
    let since_epoch = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        Err(_) => Duration::ZERO,
    };
    since_epoch.as_millis() as u64
}

/// In-memory model of the Lua script's counting rule (increment, then compare with quota),
/// used by unit tests.
#[cfg(test)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct LocalPeriodWindow {
    count: u32,
    quota: u32,
}

#[cfg(test)]
impl LocalPeriodWindow {
    /// Records one request and reports whether the new count is still within the quota.
    /// The counter saturates at `u32::MAX` instead of overflowing.
    fn allow(&mut self) -> bool {
        let next = self.count.saturating_add(1);
        self.count = next;
        self.quota >= next
    }
}

#[cfg(test)]
mod tests {
    use super::{LocalPeriodWindow, PeriodAlignment, RedisPeriodLimiterConfig, ttl_millis};
    use std::time::Duration;

    #[test]
    fn local_period_window_rejects_after_quota() {
        let mut window = LocalPeriodWindow { count: 0, quota: 2 };

        assert!(window.allow());
        assert!(window.allow());
        assert!(!window.allow());
    }

    #[test]
    fn default_ttl_is_fixed_period() {
        assert_eq!(
            ttl_millis(Duration::from_secs(60), PeriodAlignment::None, 12_345),
            60_000
        );
    }

    #[test]
    fn natural_alignment_returns_remaining_period() {
        assert_eq!(
            ttl_millis(Duration::from_secs(60), PeriodAlignment::Natural, 61_000),
            59_000
        );
    }

    #[test]
    fn natural_alignment_on_boundary_returns_full_period() {
        assert_eq!(
            ttl_millis(Duration::from_secs(60), PeriodAlignment::Natural, 120_000),
            60_000
        );
    }

    #[test]
    fn utc_offset_alignment_shifts_boundary() {
        assert_eq!(
            ttl_millis(
                Duration::from_secs(60),
                PeriodAlignment::UtcOffset { seconds: 30 },
                20_000,
            ),
            10_000
        );
    }

    #[test]
    fn negative_utc_offset_shifts_boundary_later() {
        // Shifted clock: 100_000 - 30_000 = 70_000ms, which is 10_000ms into its window.
        assert_eq!(
            ttl_millis(
                Duration::from_secs(60),
                PeriodAlignment::UtcOffset { seconds: -30 },
                100_000,
            ),
            50_000
        );
    }

    #[test]
    fn default_config_is_unaligned() {
        assert_eq!(
            RedisPeriodLimiterConfig::default().alignment,
            PeriodAlignment::None
        );
    }

    #[test]
    fn aligned_builder_preserves_other_fields() {
        let config = RedisPeriodLimiterConfig {
            fail_open: true,
            ..RedisPeriodLimiterConfig::default()
        }
        .aligned();

        assert!(config.fail_open);
        assert_eq!(config.alignment, PeriodAlignment::Natural);
    }

    #[test]
    fn utc_offset_builder_preserves_other_fields() {
        let config = RedisPeriodLimiterConfig {
            quota: 7,
            ..RedisPeriodLimiterConfig::default()
        }
        .aligned_to_utc_offset(-3_600);

        assert_eq!(config.quota, 7);
        assert_eq!(
            config.alignment,
            PeriodAlignment::UtcOffset { seconds: -3_600 }
        );
    }
}