//! Rate-limiting types (`forge_core/rate_limit/mod.rs`): bucketing keys, limit
//! configuration, check results, and HTTP header values.

use std::str::FromStr;
use std::time::Duration;

use chrono::{DateTime, Utc};

use crate::ForgeError;

/// Rate limit key type for bucketing.
///
/// Each variant names the scope whose requests share one bucket. The default
/// is [`RateLimitKey::User`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum RateLimitKey {
    /// Per-user rate limiting.
    #[default]
    User,
    /// Per-IP rate limiting.
    Ip,
    /// Per-tenant rate limiting.
    Tenant,
    /// Per-user-action rate limiting (combines user and action).
    UserAction,
    /// Global rate limiting (single bucket for all requests).
    Global,
}

impl RateLimitKey {
    /// Returns the canonical lowercase, snake_case name of this key type.
    ///
    /// Every name returned here parses back to the same variant through the
    /// [`FromStr`] implementation.
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::User => "user",
            Self::Ip => "ip",
            Self::Tenant => "tenant",
            Self::UserAction => "user_action",
            Self::Global => "global",
        }
    }
}

impl FromStr for RateLimitKey {
    type Err = std::convert::Infallible;

    /// Parses a key type from its canonical name (see [`RateLimitKey::as_str`]).
    ///
    /// Matching is exact (case-sensitive, no trimming). Parsing never fails:
    /// any unrecognized input yields [`RateLimitKey::User`].
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(match s {
            "user" => Self::User,
            "ip" => Self::Ip,
            "tenant" => Self::Tenant,
            "user_action" => Self::UserAction,
            "global" => Self::Global,
            // Unknown names map to the per-user key instead of an error,
            // which is why `Err` is `Infallible`.
            _ => Self::User,
        })
    }
}

52/// Rate limit configuration.
53#[derive(Debug, Clone)]
54pub struct RateLimitConfig {
55    /// Maximum requests allowed.
56    pub requests: u32,
57    /// Time window for the limit.
58    pub per: Duration,
59    /// Key type for bucketing.
60    pub key: RateLimitKey,
61}
62
63impl RateLimitConfig {
64    /// Create a new rate limit config.
65    pub fn new(requests: u32, per: Duration) -> Self {
66        Self {
67            requests,
68            per,
69            key: RateLimitKey::default(),
70        }
71    }
72
73    /// Create with a specific key type.
74    pub fn with_key(mut self, key: RateLimitKey) -> Self {
75        self.key = key;
76        self
77    }
78
79    /// Calculate the refill rate (tokens per second).
80    pub fn refill_rate(&self) -> f64 {
81        self.requests as f64 / self.per.as_secs_f64()
82    }
83
84    /// Parse duration from string like "1m", "1h", "1d".
85    pub fn parse_duration(s: &str) -> Option<Duration> {
86        let s = s.trim();
87        if s.is_empty() {
88            return None;
89        }
90
91        let (num_str, unit) = s.split_at(s.len() - 1);
92        let num: u64 = num_str.parse().ok()?;
93
94        match unit {
95            "s" => Some(Duration::from_secs(num)),
96            "m" => Some(Duration::from_secs(num * 60)),
97            "h" => Some(Duration::from_secs(num * 3600)),
98            "d" => Some(Duration::from_secs(num * 86400)),
99            _ => None,
100        }
101    }
102}
103
104impl Default for RateLimitConfig {
105    fn default() -> Self {
106        Self {
107            requests: 100,
108            per: Duration::from_secs(60),
109            key: RateLimitKey::User,
110        }
111    }
112}
113
114/// Result of a rate limit check.
115#[derive(Debug, Clone)]
116pub struct RateLimitResult {
117    /// Whether the request is allowed.
118    pub allowed: bool,
119    /// Remaining requests in the current window.
120    pub remaining: u32,
121    /// When the limit resets.
122    pub reset_at: DateTime<Utc>,
123    /// Time to wait before retrying (if not allowed).
124    pub retry_after: Option<Duration>,
125}
126
127impl RateLimitResult {
128    /// Create a result for an allowed request.
129    pub fn allowed(remaining: u32, reset_at: DateTime<Utc>) -> Self {
130        Self {
131            allowed: true,
132            remaining,
133            reset_at,
134            retry_after: None,
135        }
136    }
137
138    /// Create a result for a denied request.
139    pub fn denied(remaining: u32, reset_at: DateTime<Utc>, retry_after: Duration) -> Self {
140        Self {
141            allowed: false,
142            remaining,
143            reset_at,
144            retry_after: Some(retry_after),
145        }
146    }
147
148    /// Convert to a ForgeError if rate limited.
149    pub fn to_error(&self, limit: u32) -> Option<ForgeError> {
150        if self.allowed {
151            None
152        } else {
153            Some(ForgeError::RateLimitExceeded {
154                retry_after: self.retry_after.unwrap_or(Duration::from_secs(1)),
155                limit,
156                remaining: self.remaining,
157            })
158        }
159    }
160}
161
162/// HTTP headers for rate limiting.
163#[derive(Debug, Clone)]
164pub struct RateLimitHeaders {
165    /// X-RateLimit-Limit header value.
166    pub limit: u32,
167    /// X-RateLimit-Remaining header value.
168    pub remaining: u32,
169    /// X-RateLimit-Reset header value (Unix timestamp).
170    pub reset: i64,
171    /// Retry-After header value (seconds).
172    pub retry_after: Option<u64>,
173}
174
175impl RateLimitHeaders {
176    /// Create headers from a rate limit result.
177    pub fn from_result(result: &RateLimitResult, limit: u32) -> Self {
178        Self {
179            limit,
180            remaining: result.remaining,
181            reset: result.reset_at.timestamp(),
182            retry_after: result.retry_after.map(|d| d.as_secs()),
183        }
184    }
185}
186
#[cfg(test)]
// Tests may unwrap and index freely: a panic is the desired failure mode here.
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// Key names map to variants and back; unknown names fall back to `User`.
    #[test]
    fn test_rate_limit_key() {
        assert_eq!(RateLimitKey::User.as_str(), "user");
        assert_eq!(RateLimitKey::Ip.as_str(), "ip");
        assert_eq!(RateLimitKey::Tenant.as_str(), "tenant");
        assert_eq!(RateLimitKey::UserAction.as_str(), "user_action");
        assert_eq!(RateLimitKey::Global.as_str(), "global");

        assert_eq!("user".parse::<RateLimitKey>().unwrap(), RateLimitKey::User);
        assert_eq!("ip".parse::<RateLimitKey>().unwrap(), RateLimitKey::Ip);
        assert_eq!(
            "user_action".parse::<RateLimitKey>().unwrap(),
            RateLimitKey::UserAction
        );
        assert_eq!(
            "unknown".parse::<RateLimitKey>().unwrap(),
            RateLimitKey::User
        );
    }

    /// `new` stores its arguments and the refill rate is requests per second.
    #[test]
    fn test_rate_limit_config() {
        let config = RateLimitConfig::new(100, Duration::from_secs(60));
        assert_eq!(config.requests, 100);
        assert_eq!(config.per, Duration::from_secs(60));
        assert_eq!(config.key, RateLimitKey::User);
        assert!((config.refill_rate() - 1.666666).abs() < 0.01);

        let keyed = config.with_key(RateLimitKey::Global);
        assert_eq!(keyed.key, RateLimitKey::Global);
    }

    /// Each supported unit converts to seconds; malformed input is rejected.
    #[test]
    fn test_parse_duration() {
        assert_eq!(
            RateLimitConfig::parse_duration("1s"),
            Some(Duration::from_secs(1))
        );
        assert_eq!(
            RateLimitConfig::parse_duration("1m"),
            Some(Duration::from_secs(60))
        );
        assert_eq!(
            RateLimitConfig::parse_duration("1h"),
            Some(Duration::from_secs(3600))
        );
        assert_eq!(
            RateLimitConfig::parse_duration("1d"),
            Some(Duration::from_secs(86400))
        );
        assert_eq!(RateLimitConfig::parse_duration("invalid"), None);
        assert_eq!(RateLimitConfig::parse_duration(""), None);
        assert_eq!(RateLimitConfig::parse_duration("10"), None);
    }

    /// An allowed result carries no wait time and converts to no error.
    #[test]
    fn test_rate_limit_result_allowed() {
        let result = RateLimitResult::allowed(99, Utc::now());
        assert!(result.allowed);
        assert!(result.retry_after.is_none());
        assert!(result.to_error(100).is_none());
    }

    /// A denied result carries a wait time and converts to an error.
    #[test]
    fn test_rate_limit_result_denied() {
        let result = RateLimitResult::denied(0, Utc::now(), Duration::from_secs(30));
        assert!(!result.allowed);
        assert!(result.retry_after.is_some());
        assert!(result.to_error(100).is_some());
    }

    /// Header values are copied from the result and the supplied limit.
    #[test]
    fn test_rate_limit_headers() {
        let reset_at = Utc::now();
        let result = RateLimitResult::denied(0, reset_at, Duration::from_secs(30));
        let headers = RateLimitHeaders::from_result(&result, 100);
        assert_eq!(headers.limit, 100);
        assert_eq!(headers.remaining, 0);
        assert_eq!(headers.reset, reset_at.timestamp());
        assert_eq!(headers.retry_after, Some(30));
    }
}