//! rok-rate-limit 0.3.0
//!
//! Rate-limiting Tower middleware and a programmatic `Limiter` API for the
//! rok ecosystem.
use std::{sync::Arc, time::Duration};

use crate::backend::MemoryBackend;

// ── LimitResult ───────────────────────────────────────────────────────────────

/// Result of a single rate-limit check.
///
/// Dropping this value discards the rate-limit decision, so it is
/// `#[must_use]`: callers should branch on it (e.g. reply 429 when
/// [`LimitResult::Exceeded`]).
#[must_use = "a rate-limit decision is useless unless acted upon (e.g. reject with 429 when exceeded)"]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LimitResult {
    /// Request is within the limit.
    Allowed {
        /// Requests still permitted in the current window, not counting the
        /// request that produced this result.
        remaining: u64,
        /// Unix timestamp (seconds) when the window resets.
        reset_epoch: u64,
    },
    /// Limit exceeded — caller should return 429.
    Exceeded {
        /// Seconds until the current window resets. Values produced by
        /// `LimitBuilder::check` are always at least 1.
        retry_after_secs: u64,
    },
}

impl LimitResult {
    /// Returns `true` if the request was within the limit
    /// ([`LimitResult::Allowed`]).
    #[must_use]
    pub fn is_allowed(&self) -> bool {
        matches!(self, Self::Allowed { .. })
    }

    /// Returns `true` if the limit was exceeded ([`LimitResult::Exceeded`]).
    /// Always the negation of [`LimitResult::is_allowed`].
    #[must_use]
    pub fn is_exceeded(&self) -> bool {
        matches!(self, Self::Exceeded { .. })
    }
}

// ── Limiter ───────────────────────────────────────────────────────────────────

/// Shared rate-limiter handle. Cheap to clone — all state lives behind an
/// `Arc`.
///
/// # Example
///
/// ```rust,ignore
/// let limiter = Limiter::memory();
///
/// let result = limiter.for_key("user:123")
///     .requests(10)
///     .per(Duration::from_secs(60))
///     .check();
///
/// match result {
///     LimitResult::Allowed { remaining, .. } => { /* proceed */ }
///     LimitResult::Exceeded { retry_after_secs } => { /* 429 */ }
/// }
/// ```
#[derive(Clone)]
pub struct Limiter {
    /// Primary counter store, incremented by every check that is not
    /// short-circuited by the block guard.
    backend: Arc<MemoryBackend>,
    /// Secondary in-memory guard used by `block_in_memory_after`.  Kept
    /// separate so that the primary backend can be swapped for Redis in future
    /// while the block guard always stays in-process.
    block_backend: Arc<MemoryBackend>,
}

// Manual, opaque `Debug` (prints `Limiter { .. }`) so types embedding a
// `Limiter` can derive `Debug`. Deriving here would additionally require
// `MemoryBackend: Debug`.
impl std::fmt::Debug for Limiter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Limiter").finish_non_exhaustive()
    }
}

impl Default for Limiter {
    fn default() -> Self {
        Self::memory()
    }
}

impl Limiter {
    /// Request limit a fresh [`LimitBuilder`] starts with (override with
    /// [`LimitBuilder::requests`]).
    const DEFAULT_REQUESTS: u64 = 60;

    /// Window length, in seconds, a fresh [`LimitBuilder`] starts with
    /// (override with [`LimitBuilder::per`]).
    const DEFAULT_WINDOW_SECS: u64 = 60;

    /// Create a new limiter backed by an in-memory fixed-window counter.
    ///
    /// The primary counter and the `block_in_memory_after` guard each get
    /// their own, independent backend instance.
    #[must_use]
    pub fn memory() -> Self {
        Self {
            backend: Arc::new(MemoryBackend::new()),
            block_backend: Arc::new(MemoryBackend::new()),
        }
    }

    /// Start building a rate-limit check for `key`.
    ///
    /// The builder starts at 60 requests per 60-second window with no
    /// in-memory block guard; adjust it with [`LimitBuilder::requests`],
    /// [`LimitBuilder::per`] and [`LimitBuilder::block_in_memory_after`].
    /// Nothing is counted until [`LimitBuilder::check`] is called. The builder
    /// holds its own clone of this handle (two `Arc` reference-count bumps).
    pub fn for_key(&self, key: impl Into<String>) -> LimitBuilder {
        LimitBuilder {
            limiter: self.clone(),
            key: key.into(),
            requests: Self::DEFAULT_REQUESTS,
            window_secs: Self::DEFAULT_WINDOW_SECS,
            block_after: None,
        }
    }
}

// ── LimitBuilder ──────────────────────────────────────────────────────────────

/// Fluent builder for a single rate-limit check produced by [`Limiter::for_key`].
///
/// Configure it with [`requests`](Self::requests), [`per`](Self::per) and
/// optionally [`block_in_memory_after`](Self::block_in_memory_after), then call
/// [`check`](Self::check). Building alone records nothing, so a builder that is
/// dropped without `check()` is almost certainly a bug.
#[must_use = "a LimitBuilder does nothing until `.check()` is called"]
pub struct LimitBuilder {
    /// Limiter whose backends `check` increments.
    limiter: Limiter,
    /// Counter key identifying the limited subject (e.g. `"user:123"`).
    key: String,
    /// Maximum number of requests allowed per window.
    requests: u64,
    /// Window length in whole seconds, passed to the backend.
    window_secs: u64,
    /// Threshold for the in-process block guard; `None` disables the guard.
    block_after: Option<u64>,
}

impl LimitBuilder {
    /// Maximum number of requests allowed in the window.
    ///
    /// A request is allowed while the window's count (including it) is
    /// `<= n`, so `n = 0` rejects every request.
    pub fn requests(mut self, n: u64) -> Self {
        self.requests = n;
        self
    }

    /// Window duration.
    ///
    /// The fixed-window counter works in whole seconds, so `d` is rounded
    /// **up** to the next whole second, with a minimum of 1 s. Rounding up keeps
    /// the effective limit from ever being looser than requested: 1.5 s becomes
    /// a 2 s window, not a 1 s one (which would allow ~50% more traffic).
    pub fn per(mut self, d: Duration) -> Self {
        let partial_second = u64::from(d.subsec_nanos() > 0);
        self.window_secs = d.as_secs().saturating_add(partial_second).max(1);
        self
    }

    /// Enable a secondary, always in-process guard that rejects requests
    /// once more than `n` checks for this key have been counted in the
    /// current window.
    ///
    /// The guard counts every `check()` for the key, including checks the
    /// primary limit rejects, and uses the same window length as the primary
    /// limit. When it trips, `check()` returns [`LimitResult::Exceeded`]
    /// without touching the primary backend. This is intended to shield a
    /// remote primary backend (such as a future Redis backend) from clearly
    /// abusive traffic.
    pub fn block_in_memory_after(mut self, n: u64) -> Self {
        self.block_after = Some(n);
        self
    }

    /// Record this request and return a [`LimitResult`].  Synchronous — the
    /// in-memory backend requires no I/O.
    ///
    /// Every call increments the key's counter(s), so call it exactly once per
    /// request.
    pub fn check(self) -> LimitResult {
        let now_epoch = epoch_secs();

        // Fast path: the in-process block guard. Its key is namespaced with
        // `block:` and lives in a separate backend from the primary counter.
        if let Some(block_limit) = self.block_after {
            let guard_key = format!("block:{}", self.key);
            let (block_count, reset_epoch) = self
                .limiter
                .block_backend
                .increment(&guard_key, self.window_secs);

            if block_count > block_limit {
                return LimitResult::Exceeded {
                    retry_after_secs: Self::retry_after(reset_epoch, now_epoch),
                };
            }
        }

        let (count, reset_epoch) = self.limiter.backend.increment(&self.key, self.window_secs);

        if count > self.requests {
            LimitResult::Exceeded {
                retry_after_secs: Self::retry_after(reset_epoch, now_epoch),
            }
        } else {
            LimitResult::Allowed {
                remaining: self.requests.saturating_sub(count),
                reset_epoch,
            }
        }
    }

    /// Seconds until `reset_epoch` as seen from `now_epoch`, clamped to at
    /// least 1 so a rejected caller is never told to retry after 0 seconds
    /// (e.g. when the window resets within the current second).
    fn retry_after(reset_epoch: u64, now_epoch: u64) -> u64 {
        reset_epoch.saturating_sub(now_epoch).max(1)
    }
}

/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// If the system clock reads earlier than 1970-01-01, this returns `0`
/// rather than panicking.
fn epoch_secs() -> u64 {
    std::time::UNIX_EPOCH
        .elapsed()
        .map_or(0, |since_epoch| since_epoch.as_secs())
}