Skip to main content

forge_runtime/rate_limit/
limiter.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::time::Duration;
4
5use chrono::{DateTime, Utc};
6use dashmap::DashMap;
7use sqlx::PgPool;
8
9use forge_core::rate_limit::{RateLimitConfig, RateLimitKey, RateLimitResult, RateLimiterBackend};
10use forge_core::{AuthContext, ForgeError, RequestMetadata, Result};
11
12/// Strict rate limiter backed entirely by PostgreSQL.
13///
14/// Every check round-trips to PG, so limits are cluster-wide correct at the
15/// cost of one query per rate-limited request. Right for billing-grade
16/// quotas; for DDoS protection prefer [`HybridRateLimiter`].
17///
18/// Implements token bucket algorithm with atomic updates.
19pub struct StrictRateLimiter {
20    pool: PgPool,
21}
22
23impl StrictRateLimiter {
24    pub fn new(pool: PgPool) -> Self {
25        Self { pool }
26    }
27
28    pub async fn check(
29        &self,
30        bucket_key: &str,
31        config: &RateLimitConfig,
32    ) -> Result<RateLimitResult> {
33        let max_tokens = config.requests as f64;
34        let refill_rate = config.refill_rate();
35
36        // Unconditional subtraction drove `tokens` arbitrarily negative under
37        // sustained overload, inflating `retry_after` into multi-minute waits.
38        // GREATEST(refilled - 1, 0) keeps `retry_after` proportional to the
39        // actual single-token refill time.
40        let result = sqlx::query!(
41            r#"
42            INSERT INTO forge_rate_limits (bucket_key, tokens, last_refill, max_tokens, refill_rate)
43            VALUES ($1, $2 - 1, NOW(), $2, $3)
44            ON CONFLICT (bucket_key) DO UPDATE SET
45                tokens = GREATEST(
46                    LEAST(
47                        forge_rate_limits.max_tokens::double precision,
48                        forge_rate_limits.tokens +
49                            (EXTRACT(EPOCH FROM (NOW() - forge_rate_limits.last_refill)) * forge_rate_limits.refill_rate)
50                    ) - 1,
51                    -1.0
52                ),
53                last_refill = NOW()
54            RETURNING tokens, max_tokens, last_refill, (tokens >= 0) as "allowed!"
55            "#,
56            bucket_key,
57            max_tokens as i32,
58            refill_rate
59        )
60        .fetch_one(&self.pool)
61        .await
62        .map_err(ForgeError::Database)?;
63
64        let tokens = result.tokens;
65        let last_refill = result.last_refill;
66        let allowed = result.allowed;
67
68        let remaining = tokens.max(0.0) as u32;
69        let reset_at =
70            last_refill + chrono::Duration::seconds(((max_tokens - tokens) / refill_rate) as i64);
71
72        if allowed {
73            Ok(RateLimitResult::allowed(remaining, reset_at))
74        } else {
75            // tokens is clamped to >= -1, so retry_after is bounded by
76            // (1 - (-1)) / refill_rate = 2 / refill_rate — proportional to
77            // one refill interval rather than runaway.
78            let retry_after = Duration::from_secs_f64((1.0 - tokens) / refill_rate);
79            Ok(RateLimitResult::denied(remaining, reset_at, retry_after))
80        }
81    }
82
83    pub fn build_key(
84        &self,
85        key_type: RateLimitKey,
86        action_name: &str,
87        auth: &AuthContext,
88        request: &RequestMetadata,
89    ) -> String {
90        match key_type {
91            RateLimitKey::User => {
92                let user_id = auth.user_id().map(|u| u.to_string()).unwrap_or_else(|| {
93                    let ip = request.client_ip().unwrap_or("unknown");
94                    format!("anon-{ip}")
95                });
96                format!("user:{}:{}", user_id, action_name)
97            }
98            RateLimitKey::Ip => {
99                let ip = request.client_ip().unwrap_or("unknown");
100                format!("ip:{}:{}", ip, action_name)
101            }
102            RateLimitKey::Tenant => {
103                let tenant_id = auth
104                    .claim("tenant_id")
105                    .and_then(|v| v.as_str())
106                    .unwrap_or("none");
107                format!("tenant:{}:{}", tenant_id, action_name)
108            }
109            RateLimitKey::UserAction => {
110                let user_id = auth
111                    .user_id()
112                    .map(|u| u.to_string())
113                    .unwrap_or_else(|| "anonymous".to_string());
114                format!("user_action:{}:{}", user_id, action_name)
115            }
116            RateLimitKey::Global => {
117                format!("global:{}", action_name)
118            }
119            RateLimitKey::Custom(claim_name) => {
120                let value = auth
121                    .claim(&claim_name)
122                    .and_then(|v| v.as_str())
123                    .unwrap_or("unknown");
124                format!("custom:{}:{}:{}", claim_name, value, action_name)
125            }
126            // RateLimitKey is #[non_exhaustive]; future keys collapse to a
127            // global bucket until the runtime adds an explicit handler.
128            _ => format!("global:{}", action_name),
129        }
130    }
131
132    pub async fn enforce(
133        &self,
134        bucket_key: &str,
135        config: &RateLimitConfig,
136    ) -> Result<RateLimitResult> {
137        let result = self.check(bucket_key, config).await?;
138        if !result.allowed {
139            #[cfg(feature = "gateway")]
140            crate::signals::emit_diagnostic(
141                "rate_limit.exceeded",
142                serde_json::json!({
143                    "bucket": bucket_key,
144                    "limit": config.requests,
145                    "remaining": result.remaining,
146                    "retry_after_ms": result
147                        .retry_after
148                        .unwrap_or(Duration::from_secs(1))
149                        .as_millis() as u64,
150                }),
151                None,
152                None,
153                None,
154                None,
155                false,
156            );
157            return Err(ForgeError::RateLimitExceeded {
158                retry_after: result.retry_after.unwrap_or(Duration::from_secs(1)),
159                limit: config.requests,
160                remaining: result.remaining,
161            });
162        }
163        Ok(result)
164    }
165
166    pub async fn cleanup(&self, older_than: DateTime<Utc>) -> Result<u64> {
167        let result = sqlx::query!(
168            r#"
169            DELETE FROM forge_rate_limits
170            WHERE created_at < $1
171            "#,
172            older_than,
173        )
174        .execute(&self.pool)
175        .await
176        .map_err(ForgeError::Database)?;
177
178        Ok(result.rows_affected())
179    }
180}
181
/// In-process token bucket: holds up to `max_tokens` tokens and regains
/// `refill_rate` tokens per second of elapsed time.
struct LocalBucket {
    /// Tokens currently available; fractional amounts accumulate between calls.
    tokens: f64,
    /// Capacity of the bucket; refills never exceed it.
    max_tokens: f64,
    /// Tokens regained per second.
    refill_rate: f64,
    /// When the balance was last brought up to date (every `try_consume`).
    last_refill: std::time::Instant,
}

impl LocalBucket {
    /// Wait reported when the bucket can never yield another token
    /// (~100 years). Keeping it finite avoids the panic that
    /// `Duration::from_secs_f64` raises for infinite, NaN or negative input.
    const MAX_WAIT_SECS: f64 = 100.0 * 365.0 * 86_400.0;

    /// Creates a full bucket with the given capacity and refill rate
    /// (tokens per second).
    fn new(max_tokens: f64, refill_rate: f64) -> Self {
        Self {
            tokens: max_tokens,
            max_tokens,
            refill_rate,
            last_refill: std::time::Instant::now(),
        }
    }

    /// Refills for the time elapsed since the last call, then takes one token
    /// if a whole one is available. Returns whether a token was taken.
    fn try_consume(&mut self) -> bool {
        let now = std::time::Instant::now();
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
        self.tokens = (self.tokens + elapsed * self.refill_rate).min(self.max_tokens);
        self.last_refill = now;

        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }

    /// Whole tokens currently available, never negative.
    fn remaining(&self) -> u32 {
        self.tokens.max(0.0) as u32
    }

    /// Time until one whole token is available at the current balance.
    ///
    /// Returns `Duration::ZERO` when a token is already available. When the
    /// refill rate cannot produce a token (zero, negative or NaN), returns the
    /// bounded `MAX_WAIT_SECS` horizon instead of panicking.
    fn time_until_token(&self) -> Duration {
        if self.tokens >= 1.0 {
            return Duration::ZERO;
        }

        let secs = (1.0 - self.tokens) / self.refill_rate;
        if secs.is_finite() && secs >= 0.0 {
            Duration::from_secs_f64(secs.min(Self::MAX_WAIT_SECS))
        } else {
            Duration::from_secs_f64(Self::MAX_WAIT_SECS)
        }
    }
}
225
226/// Hybrid rate limiter with in-memory fast path and periodic DB sync.
227///
228/// Per-user/per-IP checks use a local DashMap for sub-microsecond decisions,
229/// so a `100 req/min` limit becomes `100 × N` across an N-node cluster. Right
230/// for DDoS protection where the threshold is approximate. For cluster-wide
231/// correctness (e.g. billing quotas) use [`StrictRateLimiter`].
232///
233/// `Global` keys always hit the database for cross-node consistency.
234///
235/// DESIGN: Per-node rate limiting. Cluster-wide consistency trades latency
236/// for accuracy. With N nodes, effective limit is N× per-key. Keep per-node
237/// budgets low.
238pub struct HybridRateLimiter {
239    local: DashMap<String, LocalBucket>,
240    db_limiter: StrictRateLimiter,
241    max_local_buckets: usize,
242}
243
244impl HybridRateLimiter {
245    pub fn new(pool: PgPool) -> Self {
246        Self::with_max_buckets(pool, 100_000)
247    }
248
249    /// Create a hybrid rate limiter with a custom local bucket limit.
250    pub fn with_max_buckets(pool: PgPool, max_local_buckets: usize) -> Self {
251        Self {
252            local: DashMap::new(),
253            db_limiter: StrictRateLimiter::new(pool),
254            max_local_buckets,
255        }
256    }
257
258    pub async fn check(
259        &self,
260        bucket_key: &str,
261        config: &RateLimitConfig,
262    ) -> Result<RateLimitResult> {
263        if config.key == RateLimitKey::Global {
264            return self.db_limiter.check(bucket_key, config).await;
265        }
266
267        let max_tokens = config.requests as f64;
268        let refill_rate = config.refill_rate();
269
270        if self.local.len() > self.max_local_buckets {
271            self.cleanup_local(Duration::from_secs(300)); // evict entries idle > 5 min
272        }
273
274        let mut bucket = self
275            .local
276            .entry(bucket_key.to_string())
277            .or_insert_with(|| LocalBucket::new(max_tokens, refill_rate));
278
279        let allowed = bucket.try_consume();
280        let remaining = bucket.remaining();
281        let reset_at = Utc::now()
282            + chrono::Duration::seconds(((max_tokens - bucket.tokens) / refill_rate) as i64);
283
284        if allowed {
285            Ok(RateLimitResult::allowed(remaining, reset_at))
286        } else {
287            let retry_after = bucket.time_until_token();
288            Ok(RateLimitResult::denied(remaining, reset_at, retry_after))
289        }
290    }
291
292    pub fn build_key(
293        &self,
294        key_type: RateLimitKey,
295        action_name: &str,
296        auth: &AuthContext,
297        request: &RequestMetadata,
298    ) -> String {
299        self.db_limiter
300            .build_key(key_type, action_name, auth, request)
301    }
302
303    pub async fn enforce(
304        &self,
305        bucket_key: &str,
306        config: &RateLimitConfig,
307    ) -> Result<RateLimitResult> {
308        let result = self.check(bucket_key, config).await?;
309        if !result.allowed {
310            return Err(ForgeError::RateLimitExceeded {
311                retry_after: result.retry_after.unwrap_or(Duration::from_secs(1)),
312                limit: config.requests,
313                remaining: result.remaining,
314            });
315        }
316        Ok(result)
317    }
318
319    /// Clean up expired local buckets (call periodically).
320    pub fn cleanup_local(&self, max_idle: Duration) {
321        let cutoff = std::time::Instant::now()
322            .checked_sub(max_idle)
323            .unwrap_or(std::time::Instant::now());
324        self.local.retain(|_, bucket| bucket.last_refill > cutoff);
325    }
326}
327
328impl RateLimiterBackend for StrictRateLimiter {
329    fn check<'a>(
330        &'a self,
331        bucket_key: &'a str,
332        config: &'a RateLimitConfig,
333    ) -> Pin<Box<dyn Future<Output = Result<RateLimitResult>> + Send + 'a>> {
334        Box::pin(StrictRateLimiter::check(self, bucket_key, config))
335    }
336
337    fn build_key(
338        &self,
339        key_type: RateLimitKey,
340        action_name: &str,
341        auth: &AuthContext,
342        request: &RequestMetadata,
343    ) -> String {
344        StrictRateLimiter::build_key(self, key_type, action_name, auth, request)
345    }
346
347    fn enforce<'a>(
348        &'a self,
349        bucket_key: &'a str,
350        config: &'a RateLimitConfig,
351    ) -> Pin<Box<dyn Future<Output = Result<RateLimitResult>> + Send + 'a>> {
352        Box::pin(StrictRateLimiter::enforce(self, bucket_key, config))
353    }
354}
355
356impl RateLimiterBackend for HybridRateLimiter {
357    fn check<'a>(
358        &'a self,
359        bucket_key: &'a str,
360        config: &'a RateLimitConfig,
361    ) -> Pin<Box<dyn Future<Output = Result<RateLimitResult>> + Send + 'a>> {
362        Box::pin(HybridRateLimiter::check(self, bucket_key, config))
363    }
364
365    fn build_key(
366        &self,
367        key_type: RateLimitKey,
368        action_name: &str,
369        auth: &AuthContext,
370        request: &RequestMetadata,
371    ) -> String {
372        HybridRateLimiter::build_key(self, key_type, action_name, auth, request)
373    }
374
375    fn enforce<'a>(
376        &'a self,
377        bucket_key: &'a str,
378        config: &'a RateLimitConfig,
379    ) -> Pin<Box<dyn Future<Output = Result<RateLimitResult>> + Send + 'a>> {
380        Box::pin(HybridRateLimiter::enforce(self, bucket_key, config))
381    }
382}
383
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::indexing_slicing,
    clippy::panic,
    clippy::disallowed_methods
)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Builds a pool that never opens a connection. None of these tests
    /// query it: non-Global keys are served from the local DashMap before any
    /// DB call is made.
    fn lazy_pool() -> PgPool {
        let options = sqlx::postgres::PgPoolOptions::new().max_connections(1);
        options
            .connect_lazy("postgres://localhost/test")
            .expect("connect_lazy never fails for a syntactically valid URL")
    }

    /// Shorthand for a config allowing `requests` per `window_ms` milliseconds.
    fn cfg(requests: u32, window_ms: u64) -> RateLimitConfig {
        let window = Duration::from_millis(window_ms);
        RateLimitConfig::new(requests, window)
    }

    #[test]
    fn local_bucket_consumes_then_denies() {
        let mut bucket = LocalBucket::new(3.0, 1.0);
        for _ in 0..3 {
            assert!(bucket.try_consume());
        }
        // The 4th request must be denied: less than one token is left.
        assert!(!bucket.try_consume());
        assert_eq!(bucket.remaining(), 0);
    }

    #[test]
    fn local_bucket_refill_does_not_exceed_max() {
        let mut bucket = LocalBucket::new(5.0, 1000.0);
        // Empty the bucket first.
        (0..5).for_each(|_| {
            bucket.try_consume();
        });
        // Pretend the last refill was long ago so the refill would overshoot.
        let ten_seconds_ago = std::time::Instant::now() - Duration::from_secs(10);
        bucket.last_refill = ten_seconds_ago;
        // Refill caps at max, then one token is consumed: max - 1 remain.
        let consumed = bucket.try_consume();
        assert!(consumed);
        assert_eq!(bucket.remaining(), 4);
    }

    #[test]
    fn local_bucket_time_until_token_is_zero_when_available() {
        let full = LocalBucket::new(5.0, 1.0);
        let wait = full.time_until_token();
        assert_eq!(wait, Duration::ZERO);
    }

    #[test]
    fn local_bucket_time_until_token_reflects_refill_rate() {
        // Half a token at 1.0 tokens/s means another 0.5s for a whole one.
        let mut bucket = LocalBucket::new(1.0, 1.0);
        bucket.tokens = 0.5;
        let wait = bucket.time_until_token();
        // Tolerate a little floating-point slack.
        let secs = wait.as_secs_f64();
        assert!(secs > 0.45 && secs < 0.55, "expected ~0.5s, got {wait:?}",);
    }

    #[tokio::test]
    async fn hybrid_denies_after_quota_exhausted() {
        let limiter = HybridRateLimiter::new(lazy_pool());
        let config = cfg(3, 60_000);
        let key = "user:alice:hit";

        for i in 0..3 {
            let outcome = limiter.check(key, &config).await.unwrap();
            assert!(outcome.allowed, "request {i} should be allowed within quota");
        }

        let fourth = limiter.check(key, &config).await.unwrap();
        assert!(!fourth.allowed, "4th request should be denied");
        assert!(fourth.retry_after.is_some());
    }

    #[tokio::test]
    async fn hybrid_isolates_keys() {
        let limiter = HybridRateLimiter::new(lazy_pool());
        let config = cfg(2, 60_000);

        // Exhaust alice's bucket.
        let first = limiter.check("alice", &config).await.unwrap();
        assert!(first.allowed);
        let second = limiter.check("alice", &config).await.unwrap();
        assert!(second.allowed);
        let third = limiter.check("alice", &config).await.unwrap();
        assert!(!third.allowed);

        // bob has his own bucket and must be unaffected.
        let bob = limiter.check("bob", &config).await.unwrap();
        assert!(bob.allowed);
    }

    #[tokio::test]
    async fn hybrid_concurrent_consumers_respect_quota() {
        let limiter = Arc::new(HybridRateLimiter::new(lazy_pool()));
        let config = Arc::new(cfg(10, 60_000));

        let handles: Vec<_> = (0..50)
            .map(|_| {
                let limiter = Arc::clone(&limiter);
                let config = Arc::clone(&config);
                tokio::spawn(async move {
                    let outcome = limiter.check("user:shared", &config).await.unwrap();
                    outcome.allowed
                })
            })
            .collect();

        let mut allowed = 0;
        for handle in handles {
            if handle.await.unwrap() {
                allowed += 1;
            }
        }
        // 10 tokens against 50 concurrent requests: exactly 10 may pass
        // (DashMap entry-or-insert serializes access per key).
        assert_eq!(
            allowed, 10,
            "exactly quota worth of requests should pass under contention"
        );
    }

    #[tokio::test]
    async fn hybrid_enforce_returns_typed_error() {
        let limiter = HybridRateLimiter::new(lazy_pool());
        let config = cfg(1, 60_000);

        let first = limiter.enforce("k", &config).await;
        assert!(first.is_ok());

        let second = limiter.enforce("k", &config).await;
        match second {
            Err(ForgeError::RateLimitExceeded {
                retry_after, limit, ..
            }) => {
                assert_eq!(limit, 1);
                assert!(retry_after > Duration::ZERO);
            }
            other => panic!("expected RateLimitExceeded, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn hybrid_cleanup_evicts_idle_buckets() {
        let limiter = HybridRateLimiter::new(lazy_pool());
        // Insert two buckets by hand that differ only in their last_refill.
        let now = std::time::Instant::now();
        let bucket_refilled_at = |last_refill: std::time::Instant| LocalBucket {
            tokens: 1.0,
            max_tokens: 1.0,
            refill_rate: 1.0,
            last_refill,
        };
        limiter
            .local
            .insert("fresh".to_string(), bucket_refilled_at(now));
        limiter.local.insert(
            "stale".to_string(),
            bucket_refilled_at(now - Duration::from_secs(600)),
        );

        limiter.cleanup_local(Duration::from_secs(300));

        assert!(limiter.local.contains_key("fresh"));
        assert!(!limiter.local.contains_key("stale"));
    }

    #[tokio::test]
    async fn build_key_covers_all_variants() {
        let limiter = StrictRateLimiter::new(lazy_pool());
        let anon = AuthContext::unauthenticated();
        let req = RequestMetadata::default();

        let global_key = limiter.build_key(RateLimitKey::Global, "act", &anon, &req);
        assert_eq!(global_key, "global:act");

        let ip_key = limiter.build_key(RateLimitKey::Ip, "act", &anon, &req);
        assert!(ip_key.starts_with("ip:"));
        assert!(ip_key.ends_with(":act"));

        // Without a user the key falls back to anon-<ip>, and the tenant
        // claim is absent.
        let user_key = limiter.build_key(RateLimitKey::User, "act", &anon, &req);
        assert!(user_key.starts_with("user:anon-"));
        let tenant_key = limiter.build_key(RateLimitKey::Tenant, "act", &anon, &req);
        assert_eq!(tenant_key, "tenant:none:act");

        let custom_key =
            limiter.build_key(RateLimitKey::Custom("org".to_string()), "act", &anon, &req);
        assert_eq!(custom_key, "custom:org:unknown:act");
    }
}