rs-zero 0.2.6

Rust-first microservice framework inspired by go-zero engineering practices
Documentation
use std::{
    sync::{
        Arc,
        atomic::{AtomicBool, Ordering},
    },
    time::{Duration, SystemTime, UNIX_EPOCH},
};

use tokio::sync::Mutex;

use crate::cache_redis::{
    RedisCacheConfig, RedisCacheError, RedisCacheResult, RedisClientFactory, RedisLuaScript,
};

/// Token bucket script executed by Redis through [`RedisLuaScript`].
///
/// Arguments, as passed by `RedisTokenLimiter::try_allow_n`:
/// * `KEYS[1]` — hash storing the bucket fields `tokens` and `ts`.
/// * `ARGV[1]` — caller's current time in milliseconds.
/// * `ARGV[2]` — refill rate in tokens per second.
/// * `ARGV[3]` — bucket capacity (`burst`); a missing key starts full.
/// * `ARGV[4]` — tokens requested.
/// * `ARGV[5]` — key TTL in milliseconds, applied with `PEXPIRE` on every call.
///
/// Returns `1` when the tokens were taken and `0` otherwise.
///
/// The stored `ts` never moves backwards. Callers on different hosts, or a wall clock that
/// steps back, can send a `now` older than the last write. Such a call gets no refill. Keeping
/// the newer `ts` prevents that same interval from being credited a second time later.
const TOKEN_BUCKET_SCRIPT: &str = r#"
local key = KEYS[1]
local now = tonumber(ARGV[1])
local rate = tonumber(ARGV[2])
local burst = tonumber(ARGV[3])
local requested = tonumber(ARGV[4])
local ttl = tonumber(ARGV[5])
local values = redis.call("HMGET", key, "tokens", "ts")
local tokens = tonumber(values[1]) or burst
local ts = tonumber(values[2]) or now
local elapsed = math.max(0, now - ts) / 1000
tokens = math.min(burst, tokens + elapsed * rate)
local allowed = 0
if tokens >= requested then
  tokens = tokens - requested
  allowed = 1
end
redis.call("HMSET", key, "tokens", tokens, "ts", math.max(ts, now))
redis.call("PEXPIRE", key, ttl)
return allowed
"#;

/// Settings for the in-process fallback bucket of [`RedisTokenLimiter`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TokenLimiterRescueConfig {
    /// When `true`, a failed Redis call routes requests to a local token bucket.
    pub enabled: bool,
    /// Delay between recovery probes issued by the background monitor.
    pub ping_interval: Duration,
}

impl Default for TokenLimiterRescueConfig {
    /// Rescue switched off, with a 100ms probe interval ready if it is turned on.
    fn default() -> Self {
        const DEFAULT_PING_INTERVAL: Duration = Duration::from_millis(100);
        Self {
            enabled: false,
            ping_interval: DEFAULT_PING_INTERVAL,
        }
    }
}

impl TokenLimiterRescueConfig {
    /// go-zero compatible settings: rescue switched on, default probe interval.
    pub fn go_zero_defaults() -> Self {
        let base = Self::default();
        Self {
            enabled: true,
            ..base
        }
    }
}

/// Settings for [`RedisTokenLimiter`], whose bucket state is stored in Redis.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RedisTokenLimiterConfig {
    /// Connection settings for the Redis server that holds bucket state.
    pub redis: RedisCacheConfig,
    /// Prefix for every bucket key; the limiter stores `"{key_prefix}:{key}"`.
    pub key_prefix: String,
    /// Tokens added to each bucket per second.
    pub rate_per_second: u32,
    /// Capacity of each bucket, i.e. the largest number of tokens available at once.
    pub burst: u32,
    /// On a Redis error with rescue disabled: allow the request (`true`) or return the
    /// error (`false`).
    pub fail_open: bool,
    /// Local fallback used while Redis is unavailable. When enabled, it takes precedence
    /// over `fail_open`.
    pub rescue: TokenLimiterRescueConfig,
    /// Expiry applied to each bucket key whenever it is updated.
    pub ttl: Duration,
}

impl Default for RedisTokenLimiterConfig {
    /// 100 tokens/s with a burst of 100, fail-closed, rescue disabled, 60s key TTL.
    fn default() -> Self {
        Self {
            redis: Default::default(),
            key_prefix: String::from("rs-zero:limit:token"),
            rate_per_second: 100,
            burst: 100,
            fail_open: false,
            rescue: Default::default(),
            ttl: Duration::from_secs(60),
        }
    }
}

impl RedisTokenLimiterConfig {
    /// go-zero style profile: the [`Default`] settings plus an enabled local rescue bucket.
    pub fn go_zero_defaults() -> Self {
        let rescue = TokenLimiterRescueConfig::go_zero_defaults();
        Self {
            rescue,
            ..Default::default()
        }
    }
}

/// Redis-backed token bucket limiter.
///
/// Clones share the rescue bucket and the Redis liveness/monitor flags through `Arc`s.
/// As a result, every clone sees the same rescue state.
#[derive(Debug, Clone)]
pub struct RedisTokenLimiter {
    /// Limiter settings, checked by `validate_config` in `new`.
    config: RedisTokenLimiterConfig,
    /// Client used both for the token bucket script and for recovery pings.
    client: redis::Client,
    /// Wrapper around `TOKEN_BUCKET_SCRIPT`.
    script: RedisLuaScript,
    /// In-process bucket consulted while Redis is considered down. There is one bucket
    /// for all keys, not one per key.
    rescue: Arc<Mutex<LocalTokenBucket>>,
    /// Whether Redis is believed usable. Cleared when a Redis call fails with rescue
    /// enabled; set again by the monitor task once a `PING` succeeds.
    redis_alive: Arc<AtomicBool>,
    /// Set while a recovery monitor task is running, so at most one is spawned.
    monitor_started: Arc<AtomicBool>,
}

impl RedisTokenLimiter {
    /// Creates a limiter and validates local configuration.
    pub fn new(config: RedisTokenLimiterConfig) -> RedisCacheResult<Self> {
        validate_config(&config)?;
        let client = RedisClientFactory::new(config.redis.clone())?.create_client()?;
        Ok(Self {
            rescue: Arc::new(Mutex::new(LocalTokenBucket::new(
                config.burst,
                now_millis(),
            ))),
            redis_alive: Arc::new(AtomicBool::new(true)),
            monitor_started: Arc::new(AtomicBool::new(false)),
            config,
            client,
            script: RedisLuaScript::new(TOKEN_BUCKET_SCRIPT),
        })
    }

    /// Returns whether the limiter currently believes Redis is usable.
    pub fn redis_alive(&self) -> bool {
        self.redis_alive.load(Ordering::Acquire)
    }

    /// Allows one token for the given key.
    pub async fn allow(&self, key: &str) -> RedisCacheResult<bool> {
        self.allow_n(key, 1).await
    }

    /// Allows `requested` tokens for the given key.
    pub async fn allow_n(&self, key: &str, requested: u32) -> RedisCacheResult<bool> {
        if requested == 0 {
            return Ok(true);
        }

        if self.config.rescue.enabled && !self.redis_alive() {
            return self.allow_rescue(requested).await;
        }

        let result = self.try_allow_n(key, requested).await;
        match result {
            Ok(allowed) => Ok(allowed),
            Err(_) if self.config.rescue.enabled => {
                self.start_monitor();
                self.allow_rescue(requested).await
            }
            Err(_) if self.config.fail_open => Ok(true),
            Err(error) => Err(error),
        }
    }

    /// Checks `requested` tokens without applying fail-open or rescue behavior.
    pub async fn try_allow_n(&self, key: &str, requested: u32) -> RedisCacheResult<bool> {
        let mut connection = tokio::time::timeout(
            self.config.redis.connect_timeout,
            self.client.get_multiplexed_async_connection(),
        )
        .await
        .map_err(|_| RedisCacheError::Timeout("connect".to_string()))?
        .map_err(|error| RedisCacheError::Connection(error.to_string()))?;
        let key = format!("{}:{key}", self.config.key_prefix);
        let allowed = self
            .script
            .invoke_async::<i64>(
                &mut connection,
                &[key],
                &[
                    now_millis().to_string(),
                    self.config.rate_per_second.to_string(),
                    self.config.burst.to_string(),
                    requested.to_string(),
                    (self.config.ttl.as_millis().max(1) as u64).to_string(),
                ],
                self.config.redis.command_timeout,
            )
            .await?;
        Ok(allowed == 1)
    }

    async fn allow_rescue(&self, requested: u32) -> RedisCacheResult<bool> {
        let mut bucket = self.rescue.lock().await;
        Ok(bucket.allow(
            now_millis(),
            self.config.rate_per_second,
            self.config.burst,
            requested,
        ))
    }

    fn start_monitor(&self) {
        self.redis_alive.store(false, Ordering::Release);
        if self
            .monitor_started
            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
            .is_err()
        {
            return;
        }

        let client = self.client.clone();
        let connect_timeout = self.config.redis.connect_timeout;
        let command_timeout = self.config.redis.command_timeout;
        let ping_interval = self.config.rescue.ping_interval;
        let redis_alive = self.redis_alive.clone();
        let monitor_started = self.monitor_started.clone();

        tokio::spawn(async move {
            let mut interval = tokio::time::interval(ping_interval);
            loop {
                interval.tick().await;
                if redis_ping(&client, connect_timeout, command_timeout).await {
                    redis_alive.store(true, Ordering::Release);
                    monitor_started.store(false, Ordering::Release);
                    return;
                }
            }
        });
    }
}

/// Probes Redis with a single `PING`.
///
/// Returns `true` only when a connection is obtained within `connect_timeout` and the
/// server replies `PONG` within `command_timeout`. Any timeout or error yields `false`.
async fn redis_ping(
    client: &redis::Client,
    connect_timeout: Duration,
    command_timeout: Duration,
) -> bool {
    let mut connection =
        match tokio::time::timeout(connect_timeout, client.get_multiplexed_async_connection())
            .await
        {
            Ok(Ok(connection)) => connection,
            Ok(Err(_)) | Err(_) => return false,
        };

    let ping = redis::cmd("PING");
    match tokio::time::timeout(command_timeout, ping.query_async::<String>(&mut connection)).await
    {
        Ok(Ok(reply)) => reply == "PONG",
        _ => false,
    }
}

/// Rejects limiter settings that cannot work.
///
/// Checks run in order: the Redis settings, a non-blank `key_prefix`, non-zero
/// `rate_per_second` and `burst`, and a non-zero `rescue.ping_interval` (only when rescue
/// is enabled). The first failing check determines the returned error.
fn validate_config(config: &RedisTokenLimiterConfig) -> RedisCacheResult<()> {
    config.redis.validate()?;

    let invalid = |message: &str| -> RedisCacheResult<()> {
        Err(RedisCacheError::InvalidConfig(message.to_string()))
    };

    if config.key_prefix.trim().is_empty() {
        return invalid("token limiter key_prefix is required");
    }
    // For unsigned values, the smaller one is zero exactly when either one is zero.
    if config.rate_per_second.min(config.burst) == 0 {
        return invalid("token limiter rate_per_second and burst must be greater than zero");
    }
    if config.rescue.enabled && config.rescue.ping_interval.is_zero() {
        return invalid("token limiter rescue ping_interval must be greater than zero");
    }

    Ok(())
}

/// Wall-clock time in milliseconds since the Unix epoch.
///
/// Returns `0` if the system clock is set before the epoch.
fn now_millis() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}

/// In-process token bucket used by `RedisTokenLimiter` while Redis is unavailable.
///
/// Uses the same refill arithmetic as `TOKEN_BUCKET_SCRIPT`. Tokens accrue at
/// `rate_per_second` per elapsed second and are capped at `burst`.
#[derive(Debug, Clone, Copy, PartialEq)]
struct LocalTokenBucket {
    /// Tokens currently available; fractional refills are kept.
    tokens: f64,
    /// Latest timestamp in milliseconds (callers pass `now_millis()`) already credited.
    last_millis: u64,
}

impl LocalTokenBucket {
    /// Creates a full bucket whose refill clock starts at `now_millis`.
    fn new(burst: u32, now_millis: u64) -> Self {
        Self {
            tokens: f64::from(burst),
            last_millis: now_millis,
        }
    }

    /// Refills for the time elapsed since the last call, then tries to take `requested`.
    ///
    /// Returns `true` and deducts the tokens when enough are available. Otherwise returns
    /// `false` and leaves the balance as refilled. A request larger than `burst` never succeeds.
    ///
    /// `now_millis` comes from the wall clock, which can step backwards. Such a step credits
    /// no time. The stored timestamp is never moved back, so the same interval cannot be
    /// credited twice once the clock catches up again.
    fn allow(&mut self, now_millis: u64, rate_per_second: u32, burst: u32, requested: u32) -> bool {
        let elapsed_secs = now_millis.saturating_sub(self.last_millis) as f64 / 1000.0;
        let refilled = self.tokens + elapsed_secs * f64::from(rate_per_second);
        self.tokens = refilled.min(f64::from(burst));
        self.last_millis = self.last_millis.max(now_millis);

        let needed = f64::from(requested);
        if self.tokens >= needed {
            self.tokens -= needed;
            true
        } else {
            false
        }
    }
}

#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::{
        LocalTokenBucket, RedisTokenLimiterConfig, TokenLimiterRescueConfig, validate_config,
    };
    use crate::cache_redis::RedisCacheError;

    #[test]
    fn local_token_bucket_refills_over_time() {
        // 2 tokens/s with room for 2: drain it, confirm it is empty, then wait half a second.
        let (rate, burst) = (2, 2);
        let mut bucket = LocalTokenBucket {
            tokens: 2.0,
            last_millis: 0,
        };

        assert!(bucket.allow(0, rate, burst, 2), "full bucket serves the whole burst");
        assert!(!bucket.allow(0, rate, burst, 1), "no time passed, nothing refilled");
        assert!(bucket.allow(500, rate, burst, 1), "500ms at 2/s refills one token");
    }

    #[test]
    fn go_zero_defaults_enable_rescue() {
        let RedisTokenLimiterConfig {
            rescue, fail_open, ..
        } = RedisTokenLimiterConfig::go_zero_defaults();

        assert!(rescue.enabled);
        assert!(!fail_open);
    }

    #[test]
    fn rescue_ping_interval_must_be_positive() {
        let config = RedisTokenLimiterConfig {
            rescue: TokenLimiterRescueConfig {
                enabled: true,
                ping_interval: Duration::ZERO,
            },
            ..RedisTokenLimiterConfig::default()
        };

        let error = validate_config(&config).expect_err("invalid config");
        let RedisCacheError::InvalidConfig(message) = error else {
            panic!("expected RedisCacheError::InvalidConfig");
        };
        assert!(message.contains("ping_interval"));
    }
}