rs-zero 0.2.8

Rust-first microservice framework inspired by go-zero engineering practices
Documentation
use std::sync::Arc;

use axum::{
    body::Body,
    extract::{MatchedPath, Request},
    middleware::Next,
    response::{IntoResponse, Response},
};

use crate::{
    resil::{RedisPeriodLimiter, RedisTokenLimiter},
    rest::{
        RestError, RestRateLimiterConfig,
        middleware::{MiddlewareMetrics, record_resilience_event},
    },
};

/// Rate limiter applied to REST requests by the default middleware stack.
///
/// Depending on [`RestRateLimiterConfig`], requests are checked against a Redis
/// token limiter, a Redis period limiter, or not limited at all. Clones are cheap
/// and share one underlying mode through an [`Arc`].
#[derive(Debug, Clone)]
pub struct RestRateLimiter {
    mode: Arc<RestRateLimiterMode>,
}

/// Limiting strategy chosen when the [`RestRateLimiter`] is built.
///
/// Failure to construct a Redis limiter is kept as its string form rather than
/// failing construction; it is reported per request, where `fail_open` decides
/// whether the request may still proceed.
#[derive(Debug)]
enum RestRateLimiterMode {
    /// Limiting is off; every request is allowed.
    Disabled,
    /// Redis token limiter, or the error text that prevented building it.
    Token {
        limiter: Result<RedisTokenLimiter, String>,
        fail_open: bool,
    },
    /// Redis period limiter, or the error text that prevented building it.
    Period {
        limiter: Result<RedisPeriodLimiter, String>,
        fail_open: bool,
    },
}

impl RestRateLimiter {
    /// Builds a limiter from `config`.
    ///
    /// This never fails: if the Redis-backed limiter cannot be created, the error
    /// text is retained and surfaces when requests are checked.
    pub fn new(config: RestRateLimiterConfig) -> Self {
        let mode = match config {
            RestRateLimiterConfig::Disabled => RestRateLimiterMode::Disabled,
            RestRateLimiterConfig::RedisToken(token_config) => RestRateLimiterMode::Token {
                // Struct fields are evaluated in written order, so the flag is
                // copied out before `token_config` is moved into the limiter.
                fail_open: token_config.fail_open,
                limiter: RedisTokenLimiter::new(token_config).map_err(|err| err.to_string()),
            },
            RestRateLimiterConfig::RedisPeriod(period_config) => RestRateLimiterMode::Period {
                fail_open: period_config.fail_open,
                limiter: RedisPeriodLimiter::new(period_config).map_err(|err| err.to_string()),
            },
        };

        Self { mode: Arc::new(mode) }
    }

    /// Reports whether limiting is turned off for this limiter.
    pub fn is_disabled(&self) -> bool {
        matches!(*self.mode, RestRateLimiterMode::Disabled)
    }

    /// Checks one request identified by `key` against the configured strategy.
    async fn allow(&self, key: &str) -> RateLimitOutcome {
        match &*self.mode {
            RestRateLimiterMode::Token { limiter, fail_open } => {
                allow_token(limiter, *fail_open, key).await
            }
            RestRateLimiterMode::Period { limiter, fail_open } => {
                allow_period(limiter, *fail_open, key).await
            }
            RestRateLimiterMode::Disabled => RateLimitOutcome::Allowed,
        }
    }
}

/// Result of consulting the configured REST limiter for one request.
#[derive(Debug, Clone, PartialEq, Eq)]
enum RateLimitOutcome {
    /// The limiter granted the request.
    Allowed,
    /// The limiter denied the request.
    Rejected,
    /// The limiter failed (at runtime or at construction) and `fail_open` let the
    /// request through.
    ErrorOpen,
    /// The limiter failed (at runtime or at construction) with `fail_open` off;
    /// carries the error text.
    ErrorClosed(String),
}

/// Maps a limiter decision to a [`RateLimitOutcome`], applying the `fail_open` policy.
///
/// `Ok(true)` maps to `Allowed` and `Ok(false)` to `Rejected`. An `Err` maps to `ErrorOpen`
/// when `fail_open` is set; the error itself is intentionally discarded. Otherwise it
/// maps to `ErrorClosed` carrying the error's string form. The error is only
/// stringified on the fail-closed path, so fail-open errors allocate nothing here.
fn outcome_from<E: ToString>(decision: Result<bool, E>, fail_open: bool) -> RateLimitOutcome {
    match decision {
        Ok(true) => RateLimitOutcome::Allowed,
        Ok(false) => RateLimitOutcome::Rejected,
        Err(_) if fail_open => RateLimitOutcome::ErrorOpen,
        Err(error) => RateLimitOutcome::ErrorClosed(error.to_string()),
    }
}

/// Asks the Redis token limiter for a single permit (`try_allow_n(key, 1)`).
///
/// `limiter` is `Err` when the limiter could not be constructed. That case is
/// handled exactly like a runtime limiter error and is subject to `fail_open`.
async fn allow_token(
    limiter: &Result<RedisTokenLimiter, String>,
    fail_open: bool,
    key: &str,
) -> RateLimitOutcome {
    match limiter {
        Ok(limiter) => outcome_from(limiter.try_allow_n(key, 1).await, fail_open),
        Err(reason) => outcome_from(Err::<bool, _>(reason), fail_open),
    }
}

/// Asks the Redis period limiter whether `key` may proceed (`try_allow(key)`).
///
/// `limiter` is `Err` when the limiter could not be constructed. That case is
/// handled exactly like a runtime limiter error and is subject to `fail_open`.
async fn allow_period(
    limiter: &Result<RedisPeriodLimiter, String>,
    fail_open: bool,
    key: &str,
) -> RateLimitOutcome {
    match limiter {
        Ok(limiter) => outcome_from(limiter.try_allow(key).await, fail_open),
        Err(reason) => outcome_from(Err::<bool, _>(reason), fail_open),
    }
}

/// Applies REST Redis rate limiting to one request.
///
/// The request is checked against `limiter` under a `"<METHOD>:<route>"` key.
/// The outcome is recorded as a `"limiter"` resilience event on `metrics` before
/// anything else happens. Then:
///
/// * allowed, or a limiter error with `fail_open` set: the request is forwarded to `next`;
/// * rejected: the response is [`RestError::RateLimited`];
/// * a limiter error without `fail_open`: the response is
///   [`RestError::ServiceUnavailable`], whose message embeds the limiter's error text.
pub async fn rate_limiter_middleware(
    limiter: RestRateLimiter,
    metrics: MiddlewareMetrics,
    request: Request<Body>,
    next: Next,
) -> Response {
    let key = limiter_key(&request);
    let outcome = limiter.allow(&key).await;

    let event = match &outcome {
        RateLimitOutcome::Allowed => "allowed",
        RateLimitOutcome::Rejected => "rejected",
        RateLimitOutcome::ErrorOpen => "error_open",
        RateLimitOutcome::ErrorClosed(_) => "error_closed",
    };
    record_resilience_event(&metrics, "limiter", event);

    match outcome {
        RateLimitOutcome::Allowed | RateLimitOutcome::ErrorOpen => next.run(request).await,
        RateLimitOutcome::Rejected => RestError::RateLimited.into_response(),
        RateLimitOutcome::ErrorClosed(error) => {
            RestError::ServiceUnavailable(format!("rate limiter unavailable: {error}"))
                .into_response()
        }
    }
}

/// Builds the limiter key `"<METHOD>:<route>"` for a request.
///
/// The route comes from the [`MatchedPath`] request extension. Requests without one
/// share the `"unknown"` route. The key contains no client identity, so every caller
/// of the same method and route draws from the same limit.
fn limiter_key(request: &Request<Body>) -> String {
    let route = match request.extensions().get::<MatchedPath>() {
        Some(matched) => matched.as_str(),
        None => "unknown",
    };
    format!("{}:{}", request.method().as_str(), route)
}