// forge_runtime/rate_limit/limiter.rs
use std::time::Duration;
2
3use chrono::{DateTime, Utc};
4use dashmap::DashMap;
5use sqlx::PgPool;
6
7use forge_core::rate_limit::{RateLimitConfig, RateLimitKey, RateLimitResult};
8use forge_core::{AuthContext, ForgeError, RequestMetadata, Result};
9
/// Postgres-backed, distributed token-bucket rate limiter.
///
/// All bucket state lives in the `forge_rate_limits` table, so limits are
/// shared by every process pointing at the same database.
pub struct RateLimiter {
    // Connection pool for the database that stores bucket state.
    pool: PgPool,
}
16
17impl RateLimiter {
18 pub fn new(pool: PgPool) -> Self {
20 Self { pool }
21 }
22
23 pub async fn check(
25 &self,
26 bucket_key: &str,
27 config: &RateLimitConfig,
28 ) -> Result<RateLimitResult> {
29 let max_tokens = config.requests as f64;
30 let refill_rate = config.refill_rate();
31
32 let result = sqlx::query!(
34 r#"
35 INSERT INTO forge_rate_limits (bucket_key, tokens, last_refill, max_tokens, refill_rate)
36 VALUES ($1, $2 - 1, NOW(), $2, $3)
37 ON CONFLICT (bucket_key) DO UPDATE SET
38 tokens = LEAST(
39 forge_rate_limits.max_tokens::double precision,
40 forge_rate_limits.tokens +
41 (EXTRACT(EPOCH FROM (NOW() - forge_rate_limits.last_refill)) * forge_rate_limits.refill_rate)
42 ) - 1,
43 last_refill = NOW()
44 RETURNING tokens, max_tokens, last_refill, (tokens >= 0) as "allowed!"
45 "#,
46 bucket_key,
47 max_tokens as i32,
48 refill_rate
49 )
50 .fetch_one(&self.pool)
51 .await
52 .map_err(|e| ForgeError::Database(e.to_string()))?;
53
54 let tokens = result.tokens;
55 let last_refill = result.last_refill;
56 let allowed = result.allowed;
57
58 let remaining = tokens.max(0.0) as u32;
59 let reset_at =
60 last_refill + chrono::Duration::seconds(((max_tokens - tokens) / refill_rate) as i64);
61
62 if allowed {
63 Ok(RateLimitResult::allowed(remaining, reset_at))
64 } else {
65 let retry_after = Duration::from_secs_f64((1.0 - tokens) / refill_rate);
66 Ok(RateLimitResult::denied(remaining, reset_at, retry_after))
67 }
68 }
69
70 pub fn build_key(
72 &self,
73 key_type: RateLimitKey,
74 action_name: &str,
75 auth: &AuthContext,
76 request: &RequestMetadata,
77 ) -> String {
78 match key_type {
79 RateLimitKey::User => {
80 let user_id = auth
81 .user_id()
82 .map(|u| u.to_string())
83 .unwrap_or_else(|| "anonymous".to_string());
84 format!("user:{}:{}", user_id, action_name)
85 }
86 RateLimitKey::Ip => {
87 let ip = request.client_ip.as_deref().unwrap_or("unknown");
88 format!("ip:{}:{}", ip, action_name)
89 }
90 RateLimitKey::Tenant => {
91 let tenant_id = auth
92 .claim("tenant_id")
93 .and_then(|v| v.as_str())
94 .unwrap_or("none");
95 format!("tenant:{}:{}", tenant_id, action_name)
96 }
97 RateLimitKey::UserAction => {
98 let user_id = auth
99 .user_id()
100 .map(|u| u.to_string())
101 .unwrap_or_else(|| "anonymous".to_string());
102 format!("user_action:{}:{}", user_id, action_name)
103 }
104 RateLimitKey::Global => {
105 format!("global:{}", action_name)
106 }
107 }
108 }
109
110 pub async fn enforce(
112 &self,
113 bucket_key: &str,
114 config: &RateLimitConfig,
115 ) -> Result<RateLimitResult> {
116 let result = self.check(bucket_key, config).await?;
117 if !result.allowed {
118 return Err(ForgeError::RateLimitExceeded {
119 retry_after: result.retry_after.unwrap_or(Duration::from_secs(1)),
120 limit: config.requests,
121 remaining: result.remaining,
122 });
123 }
124 Ok(result)
125 }
126
127 pub async fn reset(&self, bucket_key: &str) -> Result<()> {
129 sqlx::query!(
130 "DELETE FROM forge_rate_limits WHERE bucket_key = $1",
131 bucket_key
132 )
133 .execute(&self.pool)
134 .await
135 .map_err(|e| ForgeError::Database(e.to_string()))?;
136 Ok(())
137 }
138
139 pub async fn cleanup(&self, older_than: DateTime<Utc>) -> Result<u64> {
141 let result = sqlx::query!(
142 r#"
143 DELETE FROM forge_rate_limits
144 WHERE created_at < $1
145 "#,
146 older_than,
147 )
148 .execute(&self.pool)
149 .await
150 .map_err(|e| ForgeError::Database(e.to_string()))?;
151
152 Ok(result.rows_affected())
153 }
154}
155
/// In-process token bucket used by `HybridRateLimiter` for non-global keys.
/// Tokens are fractional so refill can be applied continuously.
struct LocalBucket {
    tokens: f64,
    max_tokens: f64,
    refill_rate: f64,
    last_refill: std::time::Instant,
}

impl LocalBucket {
    /// Starts a bucket at full capacity.
    fn new(max_tokens: f64, refill_rate: f64) -> Self {
        Self {
            tokens: max_tokens,
            max_tokens,
            refill_rate,
            last_refill: std::time::Instant::now(),
        }
    }

    /// Refills the bucket for the time elapsed since the last call, then
    /// attempts to take one token. Returns whether the take succeeded;
    /// on failure the balance is left untouched (never goes negative).
    fn try_consume(&mut self) -> bool {
        let now = std::time::Instant::now();
        let refill = now.duration_since(self.last_refill).as_secs_f64() * self.refill_rate;
        self.tokens = self.max_tokens.min(self.tokens + refill);
        self.last_refill = now;

        let has_token = self.tokens >= 1.0;
        if has_token {
            self.tokens -= 1.0;
        }
        has_token
    }

    /// Whole tokens currently available (fraction truncated, floor of 0).
    fn remaining(&self) -> u32 {
        if self.tokens > 0.0 {
            self.tokens as u32
        } else {
            0
        }
    }

    /// How long until a full token is available at the current refill
    /// rate; zero if one is available right now.
    fn time_until_token(&self) -> Duration {
        match self.tokens {
            t if t >= 1.0 => Duration::ZERO,
            t => Duration::from_secs_f64((1.0 - t) / self.refill_rate),
        }
    }
}
199
/// Soft cap on in-process buckets; once exceeded, idle buckets are evicted
/// before a new one is inserted (see `HybridRateLimiter::check`).
const MAX_LOCAL_BUCKETS: usize = 100_000;

/// Rate limiter that serves per-user/IP/tenant keys from fast in-process
/// buckets and delegates `Global` keys to the shared DB-backed [`RateLimiter`].
///
/// NOTE(review): local buckets are per-process, so with N instances the
/// effective non-global limit is up to N x the configured rate — confirm
/// this trade-off is intended.
pub struct HybridRateLimiter {
    // In-memory buckets, keyed by the same bucket-key strings as the DB table.
    local: DashMap<String, LocalBucket>,
    // DB-backed limiter used for global keys and for key construction.
    db_limiter: RateLimiter,
}
212
213impl HybridRateLimiter {
214 pub fn new(pool: PgPool) -> Self {
215 Self {
216 local: DashMap::new(),
217 db_limiter: RateLimiter::new(pool),
218 }
219 }
220
221 pub async fn check(
224 &self,
225 bucket_key: &str,
226 config: &RateLimitConfig,
227 ) -> Result<RateLimitResult> {
228 if config.key == RateLimitKey::Global {
229 return self.db_limiter.check(bucket_key, config).await;
230 }
231
232 let max_tokens = config.requests as f64;
233 let refill_rate = config.refill_rate();
234
235 if self.local.len() > MAX_LOCAL_BUCKETS {
237 self.cleanup_local(Duration::from_secs(300)); }
239
240 let mut bucket = self
241 .local
242 .entry(bucket_key.to_string())
243 .or_insert_with(|| LocalBucket::new(max_tokens, refill_rate));
244
245 let allowed = bucket.try_consume();
246 let remaining = bucket.remaining();
247 let reset_at = Utc::now()
248 + chrono::Duration::seconds(((max_tokens - bucket.tokens) / refill_rate) as i64);
249
250 if allowed {
251 Ok(RateLimitResult::allowed(remaining, reset_at))
252 } else {
253 let retry_after = bucket.time_until_token();
254 Ok(RateLimitResult::denied(remaining, reset_at, retry_after))
255 }
256 }
257
258 pub fn build_key(
259 &self,
260 key_type: RateLimitKey,
261 action_name: &str,
262 auth: &AuthContext,
263 request: &RequestMetadata,
264 ) -> String {
265 self.db_limiter
266 .build_key(key_type, action_name, auth, request)
267 }
268
269 pub async fn enforce(
270 &self,
271 bucket_key: &str,
272 config: &RateLimitConfig,
273 ) -> Result<RateLimitResult> {
274 let result = self.check(bucket_key, config).await?;
275 if !result.allowed {
276 return Err(ForgeError::RateLimitExceeded {
277 retry_after: result.retry_after.unwrap_or(Duration::from_secs(1)),
278 limit: config.requests,
279 remaining: result.remaining,
280 });
281 }
282 Ok(result)
283 }
284
285 pub fn cleanup_local(&self, max_idle: Duration) {
287 let cutoff = std::time::Instant::now() - max_idle;
288 self.local.retain(|_, bucket| bucket.last_refill > cutoff);
289 }
290}
291
#[cfg(test)]
mod tests {
    use super::*;

    /// `connect_lazy` defers the actual connection, so constructing the
    /// limiter must succeed without a reachable database.
    #[tokio::test]
    async fn test_rate_limiter_creation() {
        let pool = sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/test")
            .expect("Failed to create mock pool");

        let _limiter = RateLimiter::new(pool);
    }

    /// Key construction is pure string formatting — no DB round-trip — so
    /// a lazy (unconnected) pool is sufficient here too.
    #[tokio::test]
    async fn test_build_key() {
        let pool = sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/test")
            .expect("Failed to create mock pool");

        let limiter = RateLimiter::new(pool);
        let auth = AuthContext::unauthenticated();
        let request = RequestMetadata::default();

        // Global keys carry only the action name.
        let key = limiter.build_key(RateLimitKey::Global, "test_action", &auth, &request);
        assert_eq!(key, "global:test_action");

        // Unauthenticated contexts fall back to the "anonymous" segment.
        let key = limiter.build_key(RateLimitKey::User, "test_action", &auth, &request);
        assert!(key.starts_with("user:"));
    }
}