1use std::str::FromStr;
2use std::time::Duration;
3
4use chrono::{DateTime, Utc};
5
6use crate::ForgeError;
7
/// Dimension along which rate-limit buckets are partitioned.
///
/// Every variant has a lowercase string form (see `RateLimitKey::as_str`)
/// that parses back to the same variant via `FromStr`.
///
/// `Hash` is derived so a key kind can be used directly in hash-based
/// collections alongside the existing `Eq` derive.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum RateLimitKey {
    /// Limit per user (`"user"`). This is the default variant.
    #[default]
    User,
    /// Limit per IP address (`"ip"`).
    Ip,
    /// Limit per tenant (`"tenant"`).
    Tenant,
    /// Limit per user and action (`"user_action"`).
    UserAction,
    /// A single limit shared by everyone (`"global"`).
    Global,
}
23
24impl RateLimitKey {
25 pub fn as_str(&self) -> &'static str {
27 match self {
28 Self::User => "user",
29 Self::Ip => "ip",
30 Self::Tenant => "tenant",
31 Self::UserAction => "user_action",
32 Self::Global => "global",
33 }
34 }
35}
36
37impl FromStr for RateLimitKey {
38 type Err = std::convert::Infallible;
39
40 fn from_str(s: &str) -> Result<Self, Self::Err> {
41 Ok(match s {
42 "user" => Self::User,
43 "ip" => Self::Ip,
44 "tenant" => Self::Tenant,
45 "user_action" => Self::UserAction,
46 "global" => Self::Global,
47 _ => Self::User,
48 })
49 }
50}
51
52#[derive(Debug, Clone)]
54pub struct RateLimitConfig {
55 pub requests: u32,
57 pub per: Duration,
59 pub key: RateLimitKey,
61}
62
63impl RateLimitConfig {
64 pub fn new(requests: u32, per: Duration) -> Self {
66 Self {
67 requests,
68 per,
69 key: RateLimitKey::default(),
70 }
71 }
72
73 pub fn with_key(mut self, key: RateLimitKey) -> Self {
75 self.key = key;
76 self
77 }
78
79 pub fn refill_rate(&self) -> f64 {
81 self.requests as f64 / self.per.as_secs_f64()
82 }
83
84 pub fn parse_duration(s: &str) -> Option<Duration> {
86 let s = s.trim();
87 if s.is_empty() {
88 return None;
89 }
90
91 let (num_str, unit) = s.split_at(s.len() - 1);
92 let num: u64 = num_str.parse().ok()?;
93
94 match unit {
95 "s" => Some(Duration::from_secs(num)),
96 "m" => Some(Duration::from_secs(num * 60)),
97 "h" => Some(Duration::from_secs(num * 3600)),
98 "d" => Some(Duration::from_secs(num * 86400)),
99 _ => None,
100 }
101 }
102}
103
104impl Default for RateLimitConfig {
105 fn default() -> Self {
106 Self {
107 requests: 100,
108 per: Duration::from_secs(60),
109 key: RateLimitKey::User,
110 }
111 }
112}
113
114#[derive(Debug, Clone)]
116pub struct RateLimitResult {
117 pub allowed: bool,
119 pub remaining: u32,
121 pub reset_at: DateTime<Utc>,
123 pub retry_after: Option<Duration>,
125}
126
127impl RateLimitResult {
128 pub fn allowed(remaining: u32, reset_at: DateTime<Utc>) -> Self {
130 Self {
131 allowed: true,
132 remaining,
133 reset_at,
134 retry_after: None,
135 }
136 }
137
138 pub fn denied(remaining: u32, reset_at: DateTime<Utc>, retry_after: Duration) -> Self {
140 Self {
141 allowed: false,
142 remaining,
143 reset_at,
144 retry_after: Some(retry_after),
145 }
146 }
147
148 pub fn to_error(&self, limit: u32) -> Option<ForgeError> {
150 if self.allowed {
151 None
152 } else {
153 Some(ForgeError::RateLimitExceeded {
154 retry_after: self.retry_after.unwrap_or(Duration::from_secs(1)),
155 limit,
156 remaining: self.remaining,
157 })
158 }
159 }
160}
161
/// Plain values for rate-limit response headers, built from a
/// `RateLimitResult` by `RateLimitHeaders::from_result`.
///
/// The exact header names these values map to are decided by the caller
/// (presumably `X-RateLimit-*` / `Retry-After`; verify at the HTTP layer).
#[derive(Debug, Clone)]
pub struct RateLimitHeaders {
    /// Configured request budget.
    pub limit: u32,
    /// Requests remaining, copied from the result.
    pub remaining: u32,
    /// Reset time as Unix seconds (from `DateTime::timestamp`).
    pub reset: i64,
    /// Retry hint in whole seconds, if the result carried one.
    pub retry_after: Option<u64>,
}
174
175impl RateLimitHeaders {
176 pub fn from_result(result: &RateLimitResult, limit: u32) -> Self {
178 Self {
179 limit,
180 remaining: result.remaining,
181 reset: result.reset_at.timestamp(),
182 retry_after: result.retry_after.map(|d| d.as_secs()),
183 }
184 }
185}
186
#[cfg(test)]
// A failed unwrap or index inside a test is simply a test failure, so these
// lints are relaxed for this module.
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    #[test]
    fn test_rate_limit_key() {
        let expected_names = [
            (RateLimitKey::User, "user"),
            (RateLimitKey::Ip, "ip"),
            (RateLimitKey::Global, "global"),
        ];
        for &(key, name) in &expected_names {
            assert_eq!(key.as_str(), name);
        }

        let parsed: Vec<RateLimitKey> = ["user", "ip"]
            .iter()
            .map(|raw| raw.parse::<RateLimitKey>().unwrap())
            .collect();
        assert_eq!(parsed, vec![RateLimitKey::User, RateLimitKey::Ip]);
    }

    #[test]
    fn test_rate_limit_config() {
        let one_minute = Duration::from_secs(60);
        let config = RateLimitConfig::new(100, one_minute);
        assert_eq!(config.requests, 100);
        assert_eq!(config.per, one_minute);
        let drift = (config.refill_rate() - 1.666666).abs();
        assert!(drift < 0.01);
    }

    #[test]
    fn test_parse_duration() {
        let cases = [
            ("1s", Some(Duration::from_secs(1))),
            ("1m", Some(Duration::from_secs(60))),
            ("1h", Some(Duration::from_secs(3600))),
            ("1d", Some(Duration::from_secs(86400))),
            ("invalid", None),
        ];
        for &(input, expected) in &cases {
            assert_eq!(
                RateLimitConfig::parse_duration(input),
                expected,
                "input: {input:?}"
            );
        }
    }

    #[test]
    fn test_rate_limit_result_allowed() {
        let now = Utc::now();
        let result = RateLimitResult::allowed(99, now);
        assert!(result.allowed);
        assert!(result.retry_after.is_none());
        assert!(result.to_error(100).is_none());
    }

    #[test]
    fn test_rate_limit_result_denied() {
        let wait = Duration::from_secs(30);
        let result = RateLimitResult::denied(0, Utc::now(), wait);
        assert!(!result.allowed);
        assert!(result.retry_after.is_some());
        assert!(result.to_error(100).is_some());
    }
}