rs-zero 0.2.6

Rust-first microservice framework inspired by go-zero engineering practices
Documentation
use std::sync::Arc;

use crate::{
    resil::{RedisPeriodLimiter, RedisTokenLimiter},
    rpc::RpcRateLimiterConfig,
};

/// Rate limiter consulted by the RPC layer.
///
/// Cloning is cheap: every clone shares the same limiter state through an
/// [`Arc`], so all handles observe the same configuration.
#[derive(Clone, Debug)]
pub(crate) struct RpcRateLimiter {
    /// Limiting strategy fixed at construction time.
    mode: Arc<RpcRateLimiterMode>,
}

/// Limiting strategy selected from `RpcRateLimiterConfig` by `RpcRateLimiter::new`.
///
/// Construction of the Redis-backed limiters is fallible. Instead of failing
/// `new`, the stringified construction error is stored in `limiter` and
/// surfaced on every `allow` call.
#[derive(Debug)]
enum RpcRateLimiterMode {
    /// Limiting is turned off; every request is reported as allowed.
    Disabled,
    /// Redis token limiter; each request asks for one token (`try_allow_n(key, 1)`).
    Token {
        /// The constructed limiter, or the error message from constructing it.
        limiter: Result<RedisTokenLimiter, String>,
        /// When true, limiter failures yield `ErrorOpen`; otherwise `ErrorClosed`.
        fail_open: bool,
    },
    /// Redis period limiter; each request is checked with `try_allow(key)`.
    Period {
        /// The constructed limiter, or the error message from constructing it.
        limiter: Result<RedisPeriodLimiter, String>,
        /// When true, limiter failures yield `ErrorOpen`; otherwise `ErrorClosed`.
        fail_open: bool,
    },
}

impl RpcRateLimiter {
    pub(crate) fn new(config: RpcRateLimiterConfig) -> Self {
        let mode = match config {
            RpcRateLimiterConfig::Disabled => RpcRateLimiterMode::Disabled,
            RpcRateLimiterConfig::RedisToken(config) => {
                let fail_open = config.fail_open;
                RpcRateLimiterMode::Token {
                    limiter: RedisTokenLimiter::new(config).map_err(|error| error.to_string()),
                    fail_open,
                }
            }
            RpcRateLimiterConfig::RedisPeriod(config) => {
                let fail_open = config.fail_open;
                RpcRateLimiterMode::Period {
                    limiter: RedisPeriodLimiter::new(config).map_err(|error| error.to_string()),
                    fail_open,
                }
            }
        };

        Self {
            mode: Arc::new(mode),
        }
    }

    pub(crate) async fn allow(&self, key: &str) -> RpcRateLimitOutcome {
        match self.mode.as_ref() {
            RpcRateLimiterMode::Disabled => RpcRateLimitOutcome::Allowed,
            RpcRateLimiterMode::Token { limiter, fail_open } => {
                allow_token(limiter, *fail_open, key).await
            }
            RpcRateLimiterMode::Period { limiter, fail_open } => {
                allow_period(limiter, *fail_open, key).await
            }
        }
    }
}

/// Result of asking an `RpcRateLimiter` whether a request may proceed.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum RpcRateLimitOutcome {
    /// The limiter admitted the request, or limiting is disabled.
    Allowed,
    /// The limiter denied the request.
    Rejected,
    /// The limiter failed (at construction or during the check) and is
    /// configured to fail open. The underlying error is not carried.
    ErrorOpen,
    /// The limiter failed and is configured to fail closed; holds the error message.
    ErrorClosed(String),
}

async fn allow_token(
    limiter: &Result<RedisTokenLimiter, String>,
    fail_open: bool,
    key: &str,
) -> RpcRateLimitOutcome {
    match limiter {
        Ok(limiter) => match limiter.try_allow_n(key, 1).await {
            Ok(true) => RpcRateLimitOutcome::Allowed,
            Ok(false) => RpcRateLimitOutcome::Rejected,
            Err(_) if fail_open => RpcRateLimitOutcome::ErrorOpen,
            Err(error) => RpcRateLimitOutcome::ErrorClosed(error.to_string()),
        },
        Err(_) if fail_open => RpcRateLimitOutcome::ErrorOpen,
        Err(error) => RpcRateLimitOutcome::ErrorClosed(error.clone()),
    }
}

async fn allow_period(
    limiter: &Result<RedisPeriodLimiter, String>,
    fail_open: bool,
    key: &str,
) -> RpcRateLimitOutcome {
    match limiter {
        Ok(limiter) => match limiter.try_allow(key).await {
            Ok(true) => RpcRateLimitOutcome::Allowed,
            Ok(false) => RpcRateLimitOutcome::Rejected,
            Err(_) if fail_open => RpcRateLimitOutcome::ErrorOpen,
            Err(error) => RpcRateLimitOutcome::ErrorClosed(error.to_string()),
        },
        Err(_) if fail_open => RpcRateLimitOutcome::ErrorOpen,
        Err(error) => RpcRateLimitOutcome::ErrorClosed(error.clone()),
    }
}