// forge_core/rate_limit/mod.rs

1use std::str::FromStr;
2use std::time::Duration;
3
4use chrono::{DateTime, Utc};
5
6use crate::ForgeError;
7use crate::util::parse_duration;
8
9mod backend;
10
11pub use backend::RateLimiterBackend;
12
/// Rate limit key type for bucketing.
///
/// Selects which request attribute partitions the limiter's buckets. The
/// default variant is [`RateLimitKey::User`], so limits are scoped per
/// principal unless a key is chosen explicitly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
#[non_exhaustive]
pub enum RateLimitKey {
    /// Per-user rate limiting.
    #[default]
    User,
    /// Per-IP rate limiting.
    Ip,
    /// Per-tenant rate limiting.
    Tenant,
    /// Per-user-action rate limiting (combines user and action).
    UserAction,
    /// Global rate limiting (single bucket for all requests).
    Global,
    /// Custom claim-based bucketing. The inner string is the JWT claim name
    /// whose value is used as the bucket discriminator. Backends receive this
    /// variant via `build_key` and may further customise the resolution logic.
    ///
    /// **Macro syntax**: `#[query(rate_limit(requests = 100, per = "1m", key = "custom:claim_name"))]`
    Custom(String),
}

impl RateLimitKey {
    /// Return the lowercase label for this key kind.
    ///
    /// Every arm yields a string literal, so the label is `'static` and may
    /// outlive the key it came from. For the standard variants the label is
    /// also the spelling accepted by `str::parse`. For [`Self::Custom`] this
    /// returns `"custom"` — use [`Self::custom_name`] to retrieve the inner
    /// claim name.
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::User => "user",
            Self::Ip => "ip",
            Self::Tenant => "tenant",
            Self::UserAction => "user_action",
            Self::Global => "global",
            Self::Custom(_) => "custom",
        }
    }

    /// Return the inner claim name for [`Self::Custom`], or `None` for the
    /// standard variants.
    #[must_use]
    pub fn custom_name(&self) -> Option<&str> {
        match self {
            Self::Custom(name) => Some(name.as_str()),
            _ => None,
        }
    }
}
61
/// Failure produced when a string does not name any [`RateLimitKey`].
///
/// The wrapped string is the rejected input, echoed verbatim in the
/// `Display` message alongside the list of accepted spellings.
#[derive(Debug, Clone)]
pub struct ParseRateLimitKeyError(pub String);

impl std::fmt::Display for ParseRateLimitKeyError {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(rejected) = self;
        write!(
            formatter,
            "unknown rate limit key \"{}\". Expected one of: user, ip, tenant, user_action, global, custom:<name>",
            rejected
        )
    }
}

// No underlying cause to expose, so the default `source()` (None) is kept.
impl std::error::Error for ParseRateLimitKeyError {}
77
78impl FromStr for RateLimitKey {
79    type Err = ParseRateLimitKeyError;
80
81    fn from_str(s: &str) -> Result<Self, Self::Err> {
82        match s {
83            "user" => Ok(Self::User),
84            "ip" => Ok(Self::Ip),
85            "tenant" => Ok(Self::Tenant),
86            "user_action" => Ok(Self::UserAction),
87            "global" => Ok(Self::Global),
88            _ if s.starts_with("custom:") => {
89                Ok(Self::Custom(s.trim_start_matches("custom:").to_string()))
90            }
91            _ => Err(ParseRateLimitKeyError(s.to_string())),
92        }
93    }
94}
95
96/// Rate limit configuration.
97#[derive(Debug, Clone)]
98#[non_exhaustive]
99pub struct RateLimitConfig {
100    /// Maximum requests allowed.
101    pub requests: u32,
102    /// Time window for the limit.
103    pub per: Duration,
104    /// Key type for bucketing.
105    pub key: RateLimitKey,
106}
107
108impl RateLimitConfig {
109    /// Create a new rate limit config.
110    pub fn new(requests: u32, per: Duration) -> Self {
111        Self {
112            requests,
113            per,
114            key: RateLimitKey::default(),
115        }
116    }
117
118    /// Create with a specific key type.
119    pub fn with_key(mut self, key: RateLimitKey) -> Self {
120        self.key = key;
121        self
122    }
123
124    /// Calculate the refill rate (tokens per second).
125    pub fn refill_rate(&self) -> f64 {
126        self.requests as f64 / self.per.as_secs_f64()
127    }
128
129    /// Parse duration from string like "1m", "1h", "1d", "100ms".
130    /// Delegates to [`crate::util::parse_duration`].
131    pub fn parse_duration(s: &str) -> Option<Duration> {
132        parse_duration(s)
133    }
134}
135
136impl Default for RateLimitConfig {
137    fn default() -> Self {
138        Self {
139            requests: 100,
140            per: Duration::from_secs(60),
141            key: RateLimitKey::User,
142        }
143    }
144}
145
146/// Result of a rate limit check.
147#[derive(Debug, Clone)]
148pub struct RateLimitResult {
149    /// Whether the request is allowed.
150    pub allowed: bool,
151    /// Remaining requests in the current window.
152    pub remaining: u32,
153    /// When the limit resets.
154    pub reset_at: DateTime<Utc>,
155    /// Time to wait before retrying (if not allowed).
156    pub retry_after: Option<Duration>,
157}
158
159impl RateLimitResult {
160    /// Create a result for an allowed request.
161    pub fn allowed(remaining: u32, reset_at: DateTime<Utc>) -> Self {
162        Self {
163            allowed: true,
164            remaining,
165            reset_at,
166            retry_after: None,
167        }
168    }
169
170    /// Create a result for a denied request.
171    pub fn denied(remaining: u32, reset_at: DateTime<Utc>, retry_after: Duration) -> Self {
172        Self {
173            allowed: false,
174            remaining,
175            reset_at,
176            retry_after: Some(retry_after),
177        }
178    }
179
180    /// Convert to a ForgeError if rate limited.
181    pub fn to_error(&self, limit: u32) -> Option<ForgeError> {
182        if self.allowed {
183            None
184        } else {
185            Some(ForgeError::RateLimitExceeded {
186                retry_after: self.retry_after.unwrap_or(Duration::from_secs(1)),
187                limit,
188                remaining: self.remaining,
189            })
190        }
191    }
192}
193
194/// HTTP headers for rate limiting.
195#[derive(Debug, Clone)]
196pub struct RateLimitHeaders {
197    /// X-RateLimit-Limit header value.
198    pub limit: u32,
199    /// X-RateLimit-Remaining header value.
200    pub remaining: u32,
201    /// X-RateLimit-Reset header value (Unix timestamp).
202    pub reset: i64,
203    /// Retry-After header value (seconds).
204    pub retry_after: Option<u64>,
205}
206
207impl RateLimitHeaders {
208    /// Create headers from a rate limit result.
209    pub fn from_result(result: &RateLimitResult, limit: u32) -> Self {
210        Self {
211            limit,
212            remaining: result.remaining,
213            reset: result.reset_at.timestamp(),
214            retry_after: result.retry_after.map(|d| d.as_secs()),
215        }
216    }
217}
218
/// Unit tests for key parsing, configuration construction, result-to-error
/// conversion, and header derivation in this module.
#[cfg(test)]
// Tests fail by panicking (`unwrap`, `panic!`), so these lints — presumably
// denied at crate level; confirm in the crate root — are relaxed here.
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    #[test]
    fn test_rate_limit_key() {
        assert_eq!(RateLimitKey::User.as_str(), "user");
        assert_eq!(RateLimitKey::Ip.as_str(), "ip");
        assert_eq!(RateLimitKey::Global.as_str(), "global");
        // `as_str` collapses every custom key to the same label; the claim
        // name is only reachable through `custom_name`.
        assert_eq!(
            RateLimitKey::Custom("tenant_id".to_string()).as_str(),
            "custom"
        );
        assert_eq!(
            RateLimitKey::Custom("tenant_id".to_string()).custom_name(),
            Some("tenant_id")
        );

        assert_eq!("user".parse::<RateLimitKey>().unwrap(), RateLimitKey::User);
        assert_eq!("ip".parse::<RateLimitKey>().unwrap(), RateLimitKey::Ip);
        assert_eq!(
            "custom:tenant_id".parse::<RateLimitKey>().unwrap(),
            RateLimitKey::Custom("tenant_id".to_string())
        );
    }

    #[test]
    fn test_rate_limit_config() {
        let config = RateLimitConfig::new(100, Duration::from_secs(60));
        assert_eq!(config.requests, 100);
        assert_eq!(config.per, Duration::from_secs(60));
        // 100 requests / 60 s ≈ 1.667 tokens per second.
        assert!((config.refill_rate() - 1.666666).abs() < 0.01);
    }

    // Exercises the delegating wrapper; unit parsing itself lives in
    // `crate::util::parse_duration`.
    #[test]
    fn test_parse_duration() {
        assert_eq!(
            RateLimitConfig::parse_duration("1s"),
            Some(Duration::from_secs(1))
        );
        assert_eq!(
            RateLimitConfig::parse_duration("1m"),
            Some(Duration::from_secs(60))
        );
        assert_eq!(
            RateLimitConfig::parse_duration("1h"),
            Some(Duration::from_secs(3600))
        );
        assert_eq!(
            RateLimitConfig::parse_duration("1d"),
            Some(Duration::from_secs(86400))
        );
        assert_eq!(RateLimitConfig::parse_duration("invalid"), None);
    }

    #[test]
    fn test_rate_limit_result_allowed() {
        let result = RateLimitResult::allowed(99, Utc::now());
        assert!(result.allowed);
        assert!(result.retry_after.is_none());
        assert!(result.to_error(100).is_none());
    }

    #[test]
    fn test_rate_limit_result_denied() {
        let result = RateLimitResult::denied(0, Utc::now(), Duration::from_secs(30));
        assert!(!result.allowed);
        assert!(result.retry_after.is_some());
        assert!(result.to_error(100).is_some());
    }

    #[test]
    fn rate_limit_key_default_is_user() {
        // Defaulting to User means handlers without an explicit key are scoped
        // per-principal — flipping this to Global would silently make limits
        // shared across all callers.
        assert_eq!(RateLimitKey::default(), RateLimitKey::User);
    }

    #[test]
    fn rate_limit_key_as_str_covers_all_standard_variants() {
        assert_eq!(RateLimitKey::Tenant.as_str(), "tenant");
        assert_eq!(RateLimitKey::UserAction.as_str(), "user_action");
    }

    #[test]
    fn rate_limit_key_custom_name_is_none_for_standard_variants() {
        for variant in [
            RateLimitKey::User,
            RateLimitKey::Ip,
            RateLimitKey::Tenant,
            RateLimitKey::UserAction,
            RateLimitKey::Global,
        ] {
            assert_eq!(variant.custom_name(), None);
        }
    }

    #[test]
    fn rate_limit_key_parse_covers_all_named_variants() {
        assert_eq!(
            "tenant".parse::<RateLimitKey>().unwrap(),
            RateLimitKey::Tenant
        );
        assert_eq!(
            "user_action".parse::<RateLimitKey>().unwrap(),
            RateLimitKey::UserAction
        );
        assert_eq!(
            "global".parse::<RateLimitKey>().unwrap(),
            RateLimitKey::Global
        );
    }

    #[test]
    fn rate_limit_key_parse_custom_extracts_inner_name() {
        let parsed = "custom:org_id".parse::<RateLimitKey>().unwrap();
        assert_eq!(parsed, RateLimitKey::Custom("org_id".to_string()));
        assert_eq!(parsed.custom_name(), Some("org_id"));

        // Empty inner name is still accepted at parse time; the backend decides
        // what to do with it.
        let empty = "custom:".parse::<RateLimitKey>().unwrap();
        assert_eq!(empty, RateLimitKey::Custom(String::new()));
    }

    #[test]
    fn rate_limit_key_parse_unknown_returns_descriptive_error() {
        let err = "bogus".parse::<RateLimitKey>().unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("bogus"), "error should echo input: {msg}");
        assert!(
            msg.contains("user, ip, tenant, user_action, global, custom:<name>"),
            "error should list valid keys: {msg}"
        );
    }

    #[test]
    fn rate_limit_config_default_matches_documented_values() {
        let cfg = RateLimitConfig::default();
        assert_eq!(cfg.requests, 100);
        assert_eq!(cfg.per, Duration::from_secs(60));
        assert_eq!(cfg.key, RateLimitKey::User);
    }

    #[test]
    fn rate_limit_config_with_key_overrides_default() {
        let cfg = RateLimitConfig::new(10, Duration::from_secs(1)).with_key(RateLimitKey::Ip);
        assert_eq!(cfg.key, RateLimitKey::Ip);
        // `with_key` must leave the other fields untouched.
        assert_eq!(cfg.requests, 10);
        assert_eq!(cfg.per, Duration::from_secs(1));
    }

    #[test]
    fn rate_limit_config_refill_rate_handles_burst_window() {
        // 60 requests/30s = 2 tokens/sec
        let cfg = RateLimitConfig::new(60, Duration::from_secs(30));
        assert!((cfg.refill_rate() - 2.0).abs() < 1e-9);
    }

    #[test]
    fn to_error_carries_retry_after_and_limit_metadata() {
        let result = RateLimitResult::denied(3, Utc::now(), Duration::from_secs(42));
        let err = result.to_error(100).expect("denied result yields error");
        match err {
            ForgeError::RateLimitExceeded {
                retry_after,
                limit,
                remaining,
            } => {
                assert_eq!(retry_after, Duration::from_secs(42));
                assert_eq!(limit, 100);
                assert_eq!(remaining, 3);
            }
            other => panic!("expected RateLimitExceeded, got {other:?}"),
        }
    }

    #[test]
    fn to_error_falls_back_to_1s_when_retry_after_missing() {
        // Construct a denied-shaped result with no retry_after to confirm the
        // documented 1-second fallback in to_error.
        let result = RateLimitResult {
            allowed: false,
            remaining: 0,
            reset_at: Utc::now(),
            retry_after: None,
        };
        match result.to_error(5).expect("denied yields error") {
            ForgeError::RateLimitExceeded { retry_after, .. } => {
                assert_eq!(retry_after, Duration::from_secs(1));
            }
            other => panic!("expected RateLimitExceeded, got {other:?}"),
        }
    }

    #[test]
    fn rate_limit_headers_from_allowed_result_omits_retry_after() {
        let reset = Utc::now();
        let result = RateLimitResult::allowed(7, reset);
        let headers = RateLimitHeaders::from_result(&result, 10);
        assert_eq!(headers.limit, 10);
        assert_eq!(headers.remaining, 7);
        assert_eq!(headers.reset, reset.timestamp());
        assert_eq!(headers.retry_after, None);
    }

    #[test]
    fn rate_limit_headers_from_denied_result_carries_retry_after_seconds() {
        let reset = Utc::now();
        // A whole-second wait maps to the same number of header seconds.
        let result = RateLimitResult::denied(0, reset, Duration::from_secs(15));
        let headers = RateLimitHeaders::from_result(&result, 10);
        assert_eq!(headers.retry_after, Some(15));
        assert_eq!(headers.remaining, 0);
    }
}