1use std::future::Future;
2use std::pin::Pin;
3use std::time::Duration;
4
5use chrono::{DateTime, Utc};
6use dashmap::DashMap;
7use sqlx::PgPool;
8
9use forge_core::rate_limit::{RateLimitConfig, RateLimitKey, RateLimitResult, RateLimiterBackend};
10use forge_core::{AuthContext, ForgeError, RequestMetadata, Result};
11
12pub struct StrictRateLimiter {
20 pool: PgPool,
21}
22
23impl StrictRateLimiter {
24 pub fn new(pool: PgPool) -> Self {
25 Self { pool }
26 }
27
28 pub async fn check(
29 &self,
30 bucket_key: &str,
31 config: &RateLimitConfig,
32 ) -> Result<RateLimitResult> {
33 let max_tokens = config.requests as f64;
34 let refill_rate = config.refill_rate();
35
36 let result = sqlx::query!(
41 r#"
42 INSERT INTO forge_rate_limits (bucket_key, tokens, last_refill, max_tokens, refill_rate)
43 VALUES ($1, $2 - 1, NOW(), $2, $3)
44 ON CONFLICT (bucket_key) DO UPDATE SET
45 tokens = GREATEST(
46 LEAST(
47 forge_rate_limits.max_tokens::double precision,
48 forge_rate_limits.tokens +
49 (EXTRACT(EPOCH FROM (NOW() - forge_rate_limits.last_refill)) * forge_rate_limits.refill_rate)
50 ) - 1,
51 -1.0
52 ),
53 last_refill = NOW()
54 RETURNING tokens, max_tokens, last_refill, (tokens >= 0) as "allowed!"
55 "#,
56 bucket_key,
57 max_tokens as i32,
58 refill_rate
59 )
60 .fetch_one(&self.pool)
61 .await
62 .map_err(ForgeError::Database)?;
63
64 let tokens = result.tokens;
65 let last_refill = result.last_refill;
66 let allowed = result.allowed;
67
68 let remaining = tokens.max(0.0) as u32;
69 let reset_at =
70 last_refill + chrono::Duration::seconds(((max_tokens - tokens) / refill_rate) as i64);
71
72 if allowed {
73 Ok(RateLimitResult::allowed(remaining, reset_at))
74 } else {
75 let retry_after = Duration::from_secs_f64((1.0 - tokens) / refill_rate);
79 Ok(RateLimitResult::denied(remaining, reset_at, retry_after))
80 }
81 }
82
83 pub fn build_key(
84 &self,
85 key_type: RateLimitKey,
86 action_name: &str,
87 auth: &AuthContext,
88 request: &RequestMetadata,
89 ) -> String {
90 match key_type {
91 RateLimitKey::User => {
92 let user_id = auth.user_id().map(|u| u.to_string()).unwrap_or_else(|| {
93 let ip = request.client_ip().unwrap_or("unknown");
94 format!("anon-{ip}")
95 });
96 format!("user:{}:{}", user_id, action_name)
97 }
98 RateLimitKey::Ip => {
99 let ip = request.client_ip().unwrap_or("unknown");
100 format!("ip:{}:{}", ip, action_name)
101 }
102 RateLimitKey::Tenant => {
103 let tenant_id = auth
104 .claim("tenant_id")
105 .and_then(|v| v.as_str())
106 .unwrap_or("none");
107 format!("tenant:{}:{}", tenant_id, action_name)
108 }
109 RateLimitKey::UserAction => {
110 let user_id = auth
111 .user_id()
112 .map(|u| u.to_string())
113 .unwrap_or_else(|| "anonymous".to_string());
114 format!("user_action:{}:{}", user_id, action_name)
115 }
116 RateLimitKey::Global => {
117 format!("global:{}", action_name)
118 }
119 RateLimitKey::Custom(claim_name) => {
120 let value = auth
121 .claim(&claim_name)
122 .and_then(|v| v.as_str())
123 .unwrap_or("unknown");
124 format!("custom:{}:{}:{}", claim_name, value, action_name)
125 }
126 _ => format!("global:{}", action_name),
129 }
130 }
131
132 pub async fn enforce(
133 &self,
134 bucket_key: &str,
135 config: &RateLimitConfig,
136 ) -> Result<RateLimitResult> {
137 let result = self.check(bucket_key, config).await?;
138 if !result.allowed {
139 #[cfg(feature = "gateway")]
140 crate::signals::emit_diagnostic(
141 "rate_limit.exceeded",
142 serde_json::json!({
143 "bucket": bucket_key,
144 "limit": config.requests,
145 "remaining": result.remaining,
146 "retry_after_ms": result
147 .retry_after
148 .unwrap_or(Duration::from_secs(1))
149 .as_millis() as u64,
150 }),
151 None,
152 None,
153 None,
154 None,
155 false,
156 );
157 return Err(ForgeError::RateLimitExceeded {
158 retry_after: result.retry_after.unwrap_or(Duration::from_secs(1)),
159 limit: config.requests,
160 remaining: result.remaining,
161 });
162 }
163 Ok(result)
164 }
165
166 pub async fn cleanup(&self, older_than: DateTime<Utc>) -> Result<u64> {
167 let result = sqlx::query!(
168 r#"
169 DELETE FROM forge_rate_limits
170 WHERE created_at < $1
171 "#,
172 older_than,
173 )
174 .execute(&self.pool)
175 .await
176 .map_err(ForgeError::Database)?;
177
178 Ok(result.rows_affected())
179 }
180}
181
/// In-memory token bucket that `HybridRateLimiter` keeps for non-global keys.
struct LocalBucket {
    /// Current balance. It is fractional because refill is continuous in time.
    tokens: f64,
    /// Capacity the balance is clamped to whenever it is refilled.
    max_tokens: f64,
    /// Tokens regained per elapsed second.
    refill_rate: f64,
    /// Moment of the last refill; each consume attempt updates it.
    last_refill: std::time::Instant,
}
188
189impl LocalBucket {
190 fn new(max_tokens: f64, refill_rate: f64) -> Self {
191 Self {
192 tokens: max_tokens,
193 max_tokens,
194 refill_rate,
195 last_refill: std::time::Instant::now(),
196 }
197 }
198
199 fn try_consume(&mut self) -> bool {
200 let now = std::time::Instant::now();
201 let elapsed = now.duration_since(self.last_refill).as_secs_f64();
202 self.tokens = (self.tokens + elapsed * self.refill_rate).min(self.max_tokens);
203 self.last_refill = now;
204
205 if self.tokens >= 1.0 {
206 self.tokens -= 1.0;
207 true
208 } else {
209 false
210 }
211 }
212
213 fn remaining(&self) -> u32 {
214 self.tokens.max(0.0) as u32
215 }
216
217 fn time_until_token(&self) -> Duration {
218 if self.tokens >= 1.0 {
219 Duration::ZERO
220 } else {
221 Duration::from_secs_f64((1.0 - self.tokens) / self.refill_rate)
222 }
223 }
224}
225
226pub struct HybridRateLimiter {
239 local: DashMap<String, LocalBucket>,
240 db_limiter: StrictRateLimiter,
241 max_local_buckets: usize,
242}
243
244impl HybridRateLimiter {
245 pub fn new(pool: PgPool) -> Self {
246 Self::with_max_buckets(pool, 100_000)
247 }
248
249 pub fn with_max_buckets(pool: PgPool, max_local_buckets: usize) -> Self {
251 Self {
252 local: DashMap::new(),
253 db_limiter: StrictRateLimiter::new(pool),
254 max_local_buckets,
255 }
256 }
257
258 pub async fn check(
259 &self,
260 bucket_key: &str,
261 config: &RateLimitConfig,
262 ) -> Result<RateLimitResult> {
263 if config.key == RateLimitKey::Global {
264 return self.db_limiter.check(bucket_key, config).await;
265 }
266
267 let max_tokens = config.requests as f64;
268 let refill_rate = config.refill_rate();
269
270 if self.local.len() > self.max_local_buckets {
271 self.cleanup_local(Duration::from_secs(300)); }
273
274 let mut bucket = self
275 .local
276 .entry(bucket_key.to_string())
277 .or_insert_with(|| LocalBucket::new(max_tokens, refill_rate));
278
279 let allowed = bucket.try_consume();
280 let remaining = bucket.remaining();
281 let reset_at = Utc::now()
282 + chrono::Duration::seconds(((max_tokens - bucket.tokens) / refill_rate) as i64);
283
284 if allowed {
285 Ok(RateLimitResult::allowed(remaining, reset_at))
286 } else {
287 let retry_after = bucket.time_until_token();
288 Ok(RateLimitResult::denied(remaining, reset_at, retry_after))
289 }
290 }
291
292 pub fn build_key(
293 &self,
294 key_type: RateLimitKey,
295 action_name: &str,
296 auth: &AuthContext,
297 request: &RequestMetadata,
298 ) -> String {
299 self.db_limiter
300 .build_key(key_type, action_name, auth, request)
301 }
302
303 pub async fn enforce(
304 &self,
305 bucket_key: &str,
306 config: &RateLimitConfig,
307 ) -> Result<RateLimitResult> {
308 let result = self.check(bucket_key, config).await?;
309 if !result.allowed {
310 return Err(ForgeError::RateLimitExceeded {
311 retry_after: result.retry_after.unwrap_or(Duration::from_secs(1)),
312 limit: config.requests,
313 remaining: result.remaining,
314 });
315 }
316 Ok(result)
317 }
318
319 pub fn cleanup_local(&self, max_idle: Duration) {
321 let cutoff = std::time::Instant::now()
322 .checked_sub(max_idle)
323 .unwrap_or(std::time::Instant::now());
324 self.local.retain(|_, bucket| bucket.last_refill > cutoff);
325 }
326}
327
328impl RateLimiterBackend for StrictRateLimiter {
329 fn check<'a>(
330 &'a self,
331 bucket_key: &'a str,
332 config: &'a RateLimitConfig,
333 ) -> Pin<Box<dyn Future<Output = Result<RateLimitResult>> + Send + 'a>> {
334 Box::pin(StrictRateLimiter::check(self, bucket_key, config))
335 }
336
337 fn build_key(
338 &self,
339 key_type: RateLimitKey,
340 action_name: &str,
341 auth: &AuthContext,
342 request: &RequestMetadata,
343 ) -> String {
344 StrictRateLimiter::build_key(self, key_type, action_name, auth, request)
345 }
346
347 fn enforce<'a>(
348 &'a self,
349 bucket_key: &'a str,
350 config: &'a RateLimitConfig,
351 ) -> Pin<Box<dyn Future<Output = Result<RateLimitResult>> + Send + 'a>> {
352 Box::pin(StrictRateLimiter::enforce(self, bucket_key, config))
353 }
354}
355
356impl RateLimiterBackend for HybridRateLimiter {
357 fn check<'a>(
358 &'a self,
359 bucket_key: &'a str,
360 config: &'a RateLimitConfig,
361 ) -> Pin<Box<dyn Future<Output = Result<RateLimitResult>> + Send + 'a>> {
362 Box::pin(HybridRateLimiter::check(self, bucket_key, config))
363 }
364
365 fn build_key(
366 &self,
367 key_type: RateLimitKey,
368 action_name: &str,
369 auth: &AuthContext,
370 request: &RequestMetadata,
371 ) -> String {
372 HybridRateLimiter::build_key(self, key_type, action_name, auth, request)
373 }
374
375 fn enforce<'a>(
376 &'a self,
377 bucket_key: &'a str,
378 config: &'a RateLimitConfig,
379 ) -> Pin<Box<dyn Future<Output = Result<RateLimitResult>> + Send + 'a>> {
380 Box::pin(HybridRateLimiter::enforce(self, bucket_key, config))
381 }
382}
383
#[cfg(test)]
// Tests fail loudly via unwrap/panic; these lint allowances are scoped to this module.
#[allow(
    clippy::unwrap_used,
    clippy::indexing_slicing,
    clippy::panic,
    clippy::disallowed_methods
)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Pool that only connects when a query actually runs, so the in-memory
    /// code paths can be exercised without a live database.
    fn lazy_pool() -> PgPool {
        let options = sqlx::postgres::PgPoolOptions::new().max_connections(1);
        options
            .connect_lazy("postgres://localhost/test")
            .expect("connect_lazy never fails for a syntactically valid URL")
    }

    /// Config allowing `requests` per `window_ms` milliseconds.
    fn cfg(requests: u32, window_ms: u64) -> RateLimitConfig {
        let window = Duration::from_millis(window_ms);
        RateLimitConfig::new(requests, window)
    }

    #[test]
    fn local_bucket_consumes_then_denies() {
        let mut bucket = LocalBucket::new(3.0, 1.0);
        for _ in 0..3 {
            assert!(bucket.try_consume());
        }
        assert!(!bucket.try_consume());
        assert_eq!(bucket.remaining(), 0);
    }

    #[test]
    fn local_bucket_refill_does_not_exceed_max() {
        let mut bucket = LocalBucket::new(5.0, 1000.0);
        for _ in 0..5 {
            let _ = bucket.try_consume();
        }
        // Pretend ten seconds passed: at 1000 tokens/s the refill would overshoot
        // the capacity of 5 if it were not clamped.
        bucket.last_refill = std::time::Instant::now() - Duration::from_secs(10);
        assert!(bucket.try_consume());
        assert_eq!(bucket.remaining(), 4);
    }

    #[test]
    fn local_bucket_time_until_token_is_zero_when_available() {
        assert_eq!(LocalBucket::new(5.0, 1.0).time_until_token(), Duration::ZERO);
    }

    #[test]
    fn local_bucket_time_until_token_reflects_refill_rate() {
        let mut bucket = LocalBucket::new(1.0, 1.0);
        bucket.tokens = 0.5;
        let wait = bucket.time_until_token();
        let secs = wait.as_secs_f64();
        assert!(secs > 0.45 && secs < 0.55, "expected ~0.5s, got {wait:?}",);
    }

    #[tokio::test]
    async fn hybrid_denies_after_quota_exhausted() {
        let limiter = HybridRateLimiter::new(lazy_pool());
        let config = cfg(3, 60_000);

        for i in 0..3 {
            let outcome = limiter.check("user:alice:hit", &config).await.unwrap();
            assert!(outcome.allowed, "request {i} should be allowed within quota");
        }

        let denied = limiter.check("user:alice:hit", &config).await.unwrap();
        assert!(!denied.allowed, "4th request should be denied");
        assert!(denied.retry_after.is_some());
    }

    #[tokio::test]
    async fn hybrid_isolates_keys() {
        let limiter = HybridRateLimiter::new(lazy_pool());
        let config = cfg(2, 60_000);

        for expected in [true, true, false] {
            assert_eq!(limiter.check("alice", &config).await.unwrap().allowed, expected);
        }

        assert!(limiter.check("bob", &config).await.unwrap().allowed);
    }

    #[tokio::test]
    async fn hybrid_concurrent_consumers_respect_quota() {
        let limiter = Arc::new(HybridRateLimiter::new(lazy_pool()));
        let config = Arc::new(cfg(10, 60_000));

        let handles: Vec<_> = (0..50)
            .map(|_| {
                let shared_limiter = Arc::clone(&limiter);
                let shared_config = Arc::clone(&config);
                tokio::spawn(async move {
                    shared_limiter
                        .check("user:shared", &shared_config)
                        .await
                        .unwrap()
                        .allowed
                })
            })
            .collect();

        let mut allowed = 0usize;
        for handle in handles {
            allowed += usize::from(handle.await.unwrap());
        }
        assert_eq!(
            allowed, 10,
            "exactly quota worth of requests should pass under contention"
        );
    }

    #[tokio::test]
    async fn hybrid_enforce_returns_typed_error() {
        let limiter = HybridRateLimiter::new(lazy_pool());
        let config = cfg(1, 60_000);
        assert!(limiter.enforce("k", &config).await.is_ok());

        let second = limiter.enforce("k", &config).await;
        match second {
            Err(ForgeError::RateLimitExceeded {
                retry_after, limit, ..
            }) => {
                assert_eq!(limit, 1);
                assert!(retry_after > Duration::ZERO);
            }
            other => panic!("expected RateLimitExceeded, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn hybrid_cleanup_evicts_idle_buckets() {
        let limiter = HybridRateLimiter::new(lazy_pool());
        let now = std::time::Instant::now();
        let bucket_at = |at: std::time::Instant| LocalBucket {
            tokens: 1.0,
            max_tokens: 1.0,
            refill_rate: 1.0,
            last_refill: at,
        };

        limiter.local.insert("fresh".to_string(), bucket_at(now));
        limiter
            .local
            .insert("stale".to_string(), bucket_at(now - Duration::from_secs(600)));

        limiter.cleanup_local(Duration::from_secs(300));

        assert!(limiter.local.contains_key("fresh"));
        assert!(!limiter.local.contains_key("stale"));
    }

    #[tokio::test]
    async fn build_key_covers_all_variants() {
        let limiter = StrictRateLimiter::new(lazy_pool());
        let anon = AuthContext::unauthenticated();
        let req = RequestMetadata::default();
        let key = |kind: RateLimitKey| limiter.build_key(kind, "act", &anon, &req);

        assert_eq!(key(RateLimitKey::Global), "global:act");

        let ip_key = key(RateLimitKey::Ip);
        assert!(ip_key.starts_with("ip:"));
        assert!(ip_key.ends_with(":act"));

        // Anonymous callers are keyed by IP under the `anon-` prefix.
        assert!(key(RateLimitKey::User).starts_with("user:anon-"));
        assert_eq!(key(RateLimitKey::Tenant), "tenant:none:act");

        assert_eq!(
            key(RateLimitKey::Custom("org".to_string())),
            "custom:org:unknown:act"
        );
    }
}