Skip to main content

forge_core/rate_limit/
backend.rs

1use std::future::Future;
2use std::pin::Pin;
3
4use crate::function::{AuthContext, RequestMetadata};
5use crate::rate_limit::{RateLimitConfig, RateLimitKey, RateLimitResult};
6use crate::{ForgeError, Result};
7
8/// Pluggable rate-limiter implementation.
9///
10/// The runtime exposes two implementations:
11/// - `HybridRateLimiter`: per-node DashMap fast path with PG fallback for
12///   `Global` keys. Approximate under multi-node deployments — user/IP limits
13///   multiply by the node count. Right for DDoS protection.
14/// - `StrictRateLimiter`: every check round-trips to PostgreSQL. Cluster-wide
15///   correct. Right for billing-grade or quota enforcement.
16///
17/// Both ship with the framework. Users implement this trait themselves only
18/// when their backing store sits outside the runtime's PG-only contract.
19pub trait RateLimiterBackend: Send + Sync + 'static {
20    /// Check whether a single token is available for the given bucket.
21    fn check<'a>(
22        &'a self,
23        bucket_key: &'a str,
24        config: &'a RateLimitConfig,
25    ) -> Pin<Box<dyn Future<Output = Result<RateLimitResult>> + Send + 'a>>;
26
27    /// Build the bucket key string for a (key kind, action, auth, request) tuple.
28    fn build_key(
29        &self,
30        key_type: RateLimitKey,
31        action_name: &str,
32        auth: &AuthContext,
33        request: &RequestMetadata,
34    ) -> String;
35
36    /// Check and convert a denial into a [`ForgeError::RateLimitExceeded`].
37    fn enforce<'a>(
38        &'a self,
39        bucket_key: &'a str,
40        config: &'a RateLimitConfig,
41    ) -> Pin<Box<dyn Future<Output = Result<RateLimitResult>> + Send + 'a>> {
42        Box::pin(async move {
43            let result = self.check(bucket_key, config).await?;
44            if !result.allowed {
45                return Err(ForgeError::RateLimitExceeded {
46                    retry_after: result
47                        .retry_after
48                        .unwrap_or(std::time::Duration::from_secs(1)),
49                    limit: config.requests,
50                    remaining: result.remaining,
51                });
52            }
53            Ok(result)
54        })
55    }
56}