1use std::str::FromStr;
2use std::time::Duration;
3
4use chrono::{DateTime, Utc};
5
6use crate::ForgeError;
7use crate::util::parse_duration;
8
9mod backend;
10
11pub use backend::RateLimiterBackend;
12
/// Dimension along which a rate limit is partitioned.
///
/// Every variant has a canonical lowercase string form, produced by
/// `RateLimitKey::as_str` and accepted by the `FromStr` impl; the `Custom`
/// variant is written as `custom:<name>`. The default variant is `User`.
#[derive(Debug, Clone)]
#[derive(Default, PartialEq, Eq)]
#[non_exhaustive]
pub enum RateLimitKey {
    /// String form `"user"`. This is the default variant.
    #[default]
    User,
    /// String form `"ip"`.
    Ip,
    /// String form `"tenant"`.
    Tenant,
    /// String form `"user_action"`.
    UserAction,
    /// String form `"global"`.
    Global,
    /// Caller-named key, written as `custom:<name>`; the payload is `<name>`.
    Custom(String),
}
35
36impl RateLimitKey {
37 pub fn as_str(&self) -> &str {
42 match self {
43 Self::User => "user",
44 Self::Ip => "ip",
45 Self::Tenant => "tenant",
46 Self::UserAction => "user_action",
47 Self::Global => "global",
48 Self::Custom(_) => "custom",
49 }
50 }
51
52 pub fn custom_name(&self) -> Option<&str> {
55 match self {
56 Self::Custom(name) => Some(name.as_str()),
57 _ => None,
58 }
59 }
60}
61
/// Error produced when a string does not name a known `RateLimitKey`.
///
/// The wrapped `String` is the rejected input; the `Display` message echoes
/// it back and lists the accepted forms.
#[derive(Debug, Clone)]
pub struct ParseRateLimitKeyError(pub String);

impl std::fmt::Display for ParseRateLimitKeyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(input) = self;
        f.write_fmt(format_args!(
            "unknown rate limit key \"{}\". Expected one of: user, ip, tenant, user_action, global, custom:<name>",
            input
        ))
    }
}

// Marker impl so the error can be boxed as `dyn std::error::Error` and used with `?`.
impl std::error::Error for ParseRateLimitKeyError {}
77
78impl FromStr for RateLimitKey {
79 type Err = ParseRateLimitKeyError;
80
81 fn from_str(s: &str) -> Result<Self, Self::Err> {
82 match s {
83 "user" => Ok(Self::User),
84 "ip" => Ok(Self::Ip),
85 "tenant" => Ok(Self::Tenant),
86 "user_action" => Ok(Self::UserAction),
87 "global" => Ok(Self::Global),
88 _ if s.starts_with("custom:") => {
89 Ok(Self::Custom(s.trim_start_matches("custom:").to_string()))
90 }
91 _ => Err(ParseRateLimitKeyError(s.to_string())),
92 }
93 }
94}
95
96#[derive(Debug, Clone)]
98#[non_exhaustive]
99pub struct RateLimitConfig {
100 pub requests: u32,
102 pub per: Duration,
104 pub key: RateLimitKey,
106}
107
108impl RateLimitConfig {
109 pub fn new(requests: u32, per: Duration) -> Self {
111 Self {
112 requests,
113 per,
114 key: RateLimitKey::default(),
115 }
116 }
117
118 pub fn with_key(mut self, key: RateLimitKey) -> Self {
120 self.key = key;
121 self
122 }
123
124 pub fn refill_rate(&self) -> f64 {
126 self.requests as f64 / self.per.as_secs_f64()
127 }
128
129 pub fn parse_duration(s: &str) -> Option<Duration> {
132 parse_duration(s)
133 }
134}
135
136impl Default for RateLimitConfig {
137 fn default() -> Self {
138 Self {
139 requests: 100,
140 per: Duration::from_secs(60),
141 key: RateLimitKey::User,
142 }
143 }
144}
145
146#[derive(Debug, Clone)]
148pub struct RateLimitResult {
149 pub allowed: bool,
151 pub remaining: u32,
153 pub reset_at: DateTime<Utc>,
155 pub retry_after: Option<Duration>,
157}
158
159impl RateLimitResult {
160 pub fn allowed(remaining: u32, reset_at: DateTime<Utc>) -> Self {
162 Self {
163 allowed: true,
164 remaining,
165 reset_at,
166 retry_after: None,
167 }
168 }
169
170 pub fn denied(remaining: u32, reset_at: DateTime<Utc>, retry_after: Duration) -> Self {
172 Self {
173 allowed: false,
174 remaining,
175 reset_at,
176 retry_after: Some(retry_after),
177 }
178 }
179
180 pub fn to_error(&self, limit: u32) -> Option<ForgeError> {
182 if self.allowed {
183 None
184 } else {
185 Some(ForgeError::RateLimitExceeded {
186 retry_after: self.retry_after.unwrap_or(Duration::from_secs(1)),
187 limit,
188 remaining: self.remaining,
189 })
190 }
191 }
192}
193
/// Header values derived from a rate-limit check result.
#[derive(Debug)]
#[derive(Clone)]
pub struct RateLimitHeaders {
    /// Request budget advertised to the client.
    pub limit: u32,
    /// Remaining request budget.
    pub remaining: u32,
    /// Window reset time as a Unix timestamp in seconds.
    pub reset: i64,
    /// Whole seconds to wait before retrying; `None` when no delay was set.
    pub retry_after: Option<u64>,
}
206
207impl RateLimitHeaders {
208 pub fn from_result(result: &RateLimitResult, limit: u32) -> Self {
210 Self {
211 limit,
212 remaining: result.remaining,
213 reset: result.reset_at.timestamp(),
214 retry_after: result.retry_after.map(|d| d.as_secs()),
215 }
216 }
217}
218
#[cfg(test)]
// Tests may unwrap, index and panic freely: a panic is a test failure, which
// is exactly the signal wanted here, so these lints are relaxed for this module.
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    // Smoke test: string names of the core variants plus a few `FromStr` round-trips.
    #[test]
    fn test_rate_limit_key() {
        assert_eq!(RateLimitKey::User.as_str(), "user");
        assert_eq!(RateLimitKey::Ip.as_str(), "ip");
        assert_eq!(RateLimitKey::Global.as_str(), "global");
        // Every custom key reports the fixed kind name, not its payload.
        assert_eq!(
            RateLimitKey::Custom("tenant_id".to_string()).as_str(),
            "custom"
        );
        assert_eq!(
            RateLimitKey::Custom("tenant_id".to_string()).custom_name(),
            Some("tenant_id")
        );

        assert_eq!("user".parse::<RateLimitKey>().unwrap(), RateLimitKey::User);
        assert_eq!("ip".parse::<RateLimitKey>().unwrap(), RateLimitKey::Ip);
        assert_eq!(
            "custom:tenant_id".parse::<RateLimitKey>().unwrap(),
            RateLimitKey::Custom("tenant_id".to_string())
        );
    }

    // `new` stores its arguments verbatim; 100 per 60 s refills at ~1.67 req/s.
    #[test]
    fn test_rate_limit_config() {
        let config = RateLimitConfig::new(100, Duration::from_secs(60));
        assert_eq!(config.requests, 100);
        assert_eq!(config.per, Duration::from_secs(60));
        assert!((config.refill_rate() - 1.666666).abs() < 0.01);
    }

    // The associated `parse_duration` accepts s/m/h/d suffixes and rejects junk.
    #[test]
    fn test_parse_duration() {
        assert_eq!(
            RateLimitConfig::parse_duration("1s"),
            Some(Duration::from_secs(1))
        );
        assert_eq!(
            RateLimitConfig::parse_duration("1m"),
            Some(Duration::from_secs(60))
        );
        assert_eq!(
            RateLimitConfig::parse_duration("1h"),
            Some(Duration::from_secs(3600))
        );
        assert_eq!(
            RateLimitConfig::parse_duration("1d"),
            Some(Duration::from_secs(86400))
        );
        assert_eq!(RateLimitConfig::parse_duration("invalid"), None);
    }

    // An allowed result has no retry delay and converts to no error.
    #[test]
    fn test_rate_limit_result_allowed() {
        let result = RateLimitResult::allowed(99, Utc::now());
        assert!(result.allowed);
        assert!(result.retry_after.is_none());
        assert!(result.to_error(100).is_none());
    }

    // A denied result carries a retry delay and converts to an error.
    #[test]
    fn test_rate_limit_result_denied() {
        let result = RateLimitResult::denied(0, Utc::now(), Duration::from_secs(30));
        assert!(!result.allowed);
        assert!(result.retry_after.is_some());
        assert!(result.to_error(100).is_some());
    }

    // Pins the `#[default]` variant.
    #[test]
    fn rate_limit_key_default_is_user() {
        assert_eq!(RateLimitKey::default(), RateLimitKey::User);
    }

    // Covers the `as_str` arms not exercised by `test_rate_limit_key`.
    #[test]
    fn rate_limit_key_as_str_covers_all_standard_variants() {
        assert_eq!(RateLimitKey::Tenant.as_str(), "tenant");
        assert_eq!(RateLimitKey::UserAction.as_str(), "user_action");
    }

    // Only `Custom` carries a name.
    #[test]
    fn rate_limit_key_custom_name_is_none_for_standard_variants() {
        for variant in [
            RateLimitKey::User,
            RateLimitKey::Ip,
            RateLimitKey::Tenant,
            RateLimitKey::UserAction,
            RateLimitKey::Global,
        ] {
            assert_eq!(variant.custom_name(), None);
        }
    }

    // Covers the `FromStr` arms not exercised by `test_rate_limit_key`.
    #[test]
    fn rate_limit_key_parse_covers_all_named_variants() {
        assert_eq!(
            "tenant".parse::<RateLimitKey>().unwrap(),
            RateLimitKey::Tenant
        );
        assert_eq!(
            "user_action".parse::<RateLimitKey>().unwrap(),
            RateLimitKey::UserAction
        );
        assert_eq!(
            "global".parse::<RateLimitKey>().unwrap(),
            RateLimitKey::Global
        );
    }

    // The `custom:` prefix is removed and the remainder becomes the name.
    #[test]
    fn rate_limit_key_parse_custom_extracts_inner_name() {
        let parsed = "custom:org_id".parse::<RateLimitKey>().unwrap();
        assert_eq!(parsed, RateLimitKey::Custom("org_id".to_string()));
        assert_eq!(parsed.custom_name(), Some("org_id"));

        let empty = "custom:".parse::<RateLimitKey>().unwrap();
        // An empty name after the prefix is accepted rather than rejected;
        // this pins that current behavior.
        assert_eq!(empty, RateLimitKey::Custom(String::new()));
    }

    // The error message must echo the bad input and list the accepted forms.
    #[test]
    fn rate_limit_key_parse_unknown_returns_descriptive_error() {
        let err = "bogus".parse::<RateLimitKey>().unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("bogus"), "error should echo input: {msg}");
        assert!(
            msg.contains("user, ip, tenant, user_action, global, custom:<name>"),
            "error should list valid keys: {msg}"
        );
    }

    // Pins the values produced by `RateLimitConfig::default()`.
    #[test]
    fn rate_limit_config_default_matches_documented_values() {
        let cfg = RateLimitConfig::default();
        assert_eq!(cfg.requests, 100);
        assert_eq!(cfg.per, Duration::from_secs(60));
        assert_eq!(cfg.key, RateLimitKey::User);
    }

    // `with_key` replaces only the key and leaves the other fields untouched.
    #[test]
    fn rate_limit_config_with_key_overrides_default() {
        let cfg = RateLimitConfig::new(10, Duration::from_secs(1)).with_key(RateLimitKey::Ip);
        assert_eq!(cfg.key, RateLimitKey::Ip);
        assert_eq!(cfg.requests, 10);
        assert_eq!(cfg.per, Duration::from_secs(1));
    }

    // 60 requests per 30 s is exactly 2 req/s.
    #[test]
    fn rate_limit_config_refill_rate_handles_burst_window() {
        let cfg = RateLimitConfig::new(60, Duration::from_secs(30));
        assert!((cfg.refill_rate() - 2.0).abs() < 1e-9);
    }

    // `to_error` forwards retry_after and remaining, and reports the caller's limit.
    #[test]
    fn to_error_carries_retry_after_and_limit_metadata() {
        let result = RateLimitResult::denied(3, Utc::now(), Duration::from_secs(42));
        let err = result.to_error(100).expect("denied result yields error");
        match err {
            ForgeError::RateLimitExceeded {
                retry_after,
                limit,
                remaining,
            } => {
                assert_eq!(retry_after, Duration::from_secs(42));
                assert_eq!(limit, 100);
                assert_eq!(remaining, 3);
            }
            other => panic!("expected RateLimitExceeded, got {other:?}"),
        }
    }

    // A denied result with no retry_after falls back to a 1 s delay.
    #[test]
    fn to_error_falls_back_to_1s_when_retry_after_missing() {
        let result = RateLimitResult {
            // Built by hand: neither constructor yields a denied result
            // without a retry_after.
            allowed: false,
            remaining: 0,
            reset_at: Utc::now(),
            retry_after: None,
        };
        match result.to_error(5).expect("denied yields error") {
            ForgeError::RateLimitExceeded { retry_after, .. } => {
                assert_eq!(retry_after, Duration::from_secs(1));
            }
            other => panic!("expected RateLimitExceeded, got {other:?}"),
        }
    }

    // Headers mirror an allowed result, with no Retry-After value.
    #[test]
    fn rate_limit_headers_from_allowed_result_omits_retry_after() {
        let reset = Utc::now();
        let result = RateLimitResult::allowed(7, reset);
        let headers = RateLimitHeaders::from_result(&result, 10);
        assert_eq!(headers.limit, 10);
        assert_eq!(headers.remaining, 7);
        assert_eq!(headers.reset, reset.timestamp());
        assert_eq!(headers.retry_after, None);
    }

    // A whole-second retry delay is surfaced unchanged as seconds.
    #[test]
    fn rate_limit_headers_from_denied_result_carries_retry_after_seconds() {
        let reset = Utc::now();
        let result = RateLimitResult::denied(0, reset, Duration::from_secs(15));
        let headers = RateLimitHeaders::from_result(&result, 10);
        assert_eq!(headers.retry_after, Some(15));
        assert_eq!(headers.remaining, 0);
    }
}
435}