1use std::str::FromStr;
2use std::time::Duration;
3
4use chrono::{DateTime, Utc};
5
6use crate::ForgeError;
7
/// Identifies what a rate limit is keyed on.
///
/// NOTE(review): how each variant is turned into a concrete bucket is decided
/// by the limiter that consumes this value, which is not visible in this file.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum RateLimitKey {
    /// The `user` key. This is the variant produced by [`Default`].
    #[default]
    User,
    /// The `ip` key.
    Ip,
    /// The `tenant` key.
    Tenant,
    /// The `user_action` key.
    UserAction,
    /// The `global` key.
    Global,
}
23
24impl RateLimitKey {
25 pub fn as_str(&self) -> &'static str {
27 match self {
28 Self::User => "user",
29 Self::Ip => "ip",
30 Self::Tenant => "tenant",
31 Self::UserAction => "user_action",
32 Self::Global => "global",
33 }
34 }
35}
36
37impl FromStr for RateLimitKey {
38 type Err = std::convert::Infallible;
39
40 fn from_str(s: &str) -> Result<Self, Self::Err> {
41 Ok(match s {
42 "user" => Self::User,
43 "ip" => Self::Ip,
44 "tenant" => Self::Tenant,
45 "user_action" => Self::UserAction,
46 "global" => Self::Global,
47 _ => Self::User,
48 })
49 }
50}
51
52#[derive(Debug, Clone)]
54pub struct RateLimitConfig {
55 pub requests: u32,
57 pub per: Duration,
59 pub key: RateLimitKey,
61}
62
63impl RateLimitConfig {
64 pub fn new(requests: u32, per: Duration) -> Self {
66 Self {
67 requests,
68 per,
69 key: RateLimitKey::default(),
70 }
71 }
72
73 pub fn with_key(mut self, key: RateLimitKey) -> Self {
75 self.key = key;
76 self
77 }
78
79 pub fn refill_rate(&self) -> f64 {
81 self.requests as f64 / self.per.as_secs_f64()
82 }
83
84 pub fn parse_duration(s: &str) -> Option<Duration> {
86 let s = s.trim();
87 if s.is_empty() {
88 return None;
89 }
90
91 let (num_str, unit) = s.split_at(s.len() - 1);
92 let num: u64 = num_str.parse().ok()?;
93
94 match unit {
95 "s" => Some(Duration::from_secs(num)),
96 "m" => Some(Duration::from_secs(num * 60)),
97 "h" => Some(Duration::from_secs(num * 3600)),
98 "d" => Some(Duration::from_secs(num * 86400)),
99 _ => None,
100 }
101 }
102}
103
104impl Default for RateLimitConfig {
105 fn default() -> Self {
106 Self {
107 requests: 100,
108 per: Duration::from_secs(60),
109 key: RateLimitKey::User,
110 }
111 }
112}
113
114#[derive(Debug, Clone)]
116pub struct RateLimitResult {
117 pub allowed: bool,
119 pub remaining: u32,
121 pub reset_at: DateTime<Utc>,
123 pub retry_after: Option<Duration>,
125}
126
127impl RateLimitResult {
128 pub fn allowed(remaining: u32, reset_at: DateTime<Utc>) -> Self {
130 Self {
131 allowed: true,
132 remaining,
133 reset_at,
134 retry_after: None,
135 }
136 }
137
138 pub fn denied(remaining: u32, reset_at: DateTime<Utc>, retry_after: Duration) -> Self {
140 Self {
141 allowed: false,
142 remaining,
143 reset_at,
144 retry_after: Some(retry_after),
145 }
146 }
147
148 pub fn to_error(&self, limit: u32) -> Option<ForgeError> {
150 if self.allowed {
151 None
152 } else {
153 Some(ForgeError::RateLimitExceeded {
154 retry_after: self.retry_after.unwrap_or(Duration::from_secs(1)),
155 limit,
156 remaining: self.remaining,
157 })
158 }
159 }
160}
161
/// Plain-integer view of a rate-limit outcome.
///
/// NOTE(review): the name suggests these values feed HTTP response headers;
/// the code that emits them is not visible in this file.
#[derive(Clone, Debug)]
pub struct RateLimitHeaders {
    /// The configured request limit.
    pub limit: u32,
    /// Remaining request count.
    pub remaining: u32,
    /// Reset time as an integer timestamp (taken from `reset_at.timestamp()`).
    pub reset: i64,
    /// Wait before retrying, in whole seconds, when one was provided.
    pub retry_after: Option<u64>,
}
174
175impl RateLimitHeaders {
176 pub fn from_result(result: &RateLimitResult, limit: u32) -> Self {
178 Self {
179 limit,
180 remaining: result.remaining,
181 reset: result.reset_at.timestamp(),
182 retry_after: result.retry_after.map(|d| d.as_secs()),
183 }
184 }
185}
186
#[cfg(test)]
mod tests {
    use super::*;

    /// Spot-checks the string names of keys and parsing them back.
    #[test]
    fn test_rate_limit_key() {
        let names = [
            (RateLimitKey::User, "user"),
            (RateLimitKey::Ip, "ip"),
            (RateLimitKey::Global, "global"),
        ];
        for (key, name) in &names {
            assert_eq!(key.as_str(), *name);
        }

        let parsed = [("user", RateLimitKey::User), ("ip", RateLimitKey::Ip)];
        for (text, key) in &parsed {
            assert_eq!(text.parse::<RateLimitKey>().unwrap(), *key);
        }
    }

    /// `new` stores its arguments and `refill_rate` is requests per second.
    #[test]
    fn test_rate_limit_config() {
        let window = Duration::from_secs(60);
        let config = RateLimitConfig::new(100, window);

        assert_eq!(config.requests, 100);
        assert_eq!(config.per, window);

        let drift = (config.refill_rate() - 1.666666).abs();
        assert!(drift < 0.01);
    }

    /// Each supported unit suffix parses; garbage does not.
    #[test]
    fn test_parse_duration() {
        let cases = [
            ("1s", Some(Duration::from_secs(1))),
            ("1m", Some(Duration::from_secs(60))),
            ("1h", Some(Duration::from_secs(3600))),
            ("1d", Some(Duration::from_secs(86400))),
            ("invalid", None),
        ];
        for (input, expected) in &cases {
            assert_eq!(RateLimitConfig::parse_duration(input), *expected);
        }
    }

    /// An allowed result carries no retry hint and converts to no error.
    #[test]
    fn test_rate_limit_result_allowed() {
        let outcome = RateLimitResult::allowed(99, Utc::now());

        assert!(outcome.allowed);
        assert!(outcome.retry_after.is_none());
        assert!(outcome.to_error(100).is_none());
    }

    /// A denied result carries a retry hint and converts to an error.
    #[test]
    fn test_rate_limit_result_denied() {
        let outcome = RateLimitResult::denied(0, Utc::now(), Duration::from_secs(30));

        assert!(!outcome.allowed);
        assert!(outcome.retry_after.is_some());
        assert!(outcome.to_error(100).is_some());
    }
}
245}