Skip to main content

forge_runtime/rate_limit/
limiter.rs

1use std::time::Duration;
2
3use chrono::{DateTime, Utc};
4use sqlx::PgPool;
5
6use forge_core::rate_limit::{RateLimitConfig, RateLimitKey, RateLimitResult};
7use forge_core::{AuthContext, ForgeError, RequestMetadata, Result};
8
9/// Rate limiter using PostgreSQL for state storage.
10///
11/// Implements token bucket algorithm with atomic updates.
12pub struct RateLimiter {
13    pool: PgPool,
14}
15
16impl RateLimiter {
17    /// Create a new rate limiter.
18    pub fn new(pool: PgPool) -> Self {
19        Self { pool }
20    }
21
22    /// Check rate limit for a bucket key.
23    pub async fn check(
24        &self,
25        bucket_key: &str,
26        config: &RateLimitConfig,
27    ) -> Result<RateLimitResult> {
28        let max_tokens = config.requests as f64;
29        let refill_rate = config.refill_rate();
30
31        // Atomic upsert with token bucket logic
32        let result: (f64, i32, DateTime<Utc>, bool) = sqlx::query_as(
33            r#"
34            INSERT INTO forge_rate_limits (bucket_key, tokens, last_refill, max_tokens, refill_rate)
35            VALUES ($1, $2 - 1, NOW(), $2, $3)
36            ON CONFLICT (bucket_key) DO UPDATE SET
37                tokens = LEAST(
38                    forge_rate_limits.max_tokens::double precision,
39                    forge_rate_limits.tokens +
40                        (EXTRACT(EPOCH FROM (NOW() - forge_rate_limits.last_refill)) * forge_rate_limits.refill_rate)
41                ) - 1,
42                last_refill = NOW()
43            RETURNING tokens, max_tokens, last_refill, (tokens >= 0) as allowed
44            "#,
45        )
46        .bind(bucket_key)
47        .bind(max_tokens as i32)
48        .bind(refill_rate)
49        .fetch_one(&self.pool)
50        .await
51        .map_err(|e| ForgeError::Database(e.to_string()))?;
52
53        let (tokens, _max, last_refill, allowed) = result;
54
55        let remaining = tokens.max(0.0) as u32;
56        let reset_at =
57            last_refill + chrono::Duration::seconds(((max_tokens - tokens) / refill_rate) as i64);
58
59        if allowed {
60            Ok(RateLimitResult::allowed(remaining, reset_at))
61        } else {
62            let retry_after = Duration::from_secs_f64((1.0 - tokens) / refill_rate);
63            Ok(RateLimitResult::denied(remaining, reset_at, retry_after))
64        }
65    }
66
67    /// Build a bucket key for the given parameters.
68    pub fn build_key(
69        &self,
70        key_type: RateLimitKey,
71        action_name: &str,
72        auth: &AuthContext,
73        request: &RequestMetadata,
74    ) -> String {
75        match key_type {
76            RateLimitKey::User => {
77                let user_id = auth
78                    .user_id()
79                    .map(|u| u.to_string())
80                    .unwrap_or_else(|| "anonymous".to_string());
81                format!("user:{}:{}", user_id, action_name)
82            }
83            RateLimitKey::Ip => {
84                let ip = request.client_ip.as_deref().unwrap_or("unknown");
85                format!("ip:{}:{}", ip, action_name)
86            }
87            RateLimitKey::Tenant => {
88                let tenant_id = auth
89                    .claim("tenant_id")
90                    .and_then(|v| v.as_str())
91                    .unwrap_or("none");
92                format!("tenant:{}:{}", tenant_id, action_name)
93            }
94            RateLimitKey::UserAction => {
95                let user_id = auth
96                    .user_id()
97                    .map(|u| u.to_string())
98                    .unwrap_or_else(|| "anonymous".to_string());
99                format!("user_action:{}:{}", user_id, action_name)
100            }
101            RateLimitKey::Global => {
102                format!("global:{}", action_name)
103            }
104        }
105    }
106
107    /// Check rate limit and return an error if exceeded.
108    pub async fn enforce(
109        &self,
110        bucket_key: &str,
111        config: &RateLimitConfig,
112    ) -> Result<RateLimitResult> {
113        let result = self.check(bucket_key, config).await?;
114        if !result.allowed {
115            return Err(ForgeError::RateLimitExceeded {
116                retry_after: result.retry_after.unwrap_or(Duration::from_secs(1)),
117                limit: config.requests,
118                remaining: result.remaining,
119            });
120        }
121        Ok(result)
122    }
123
124    /// Reset a rate limit bucket.
125    pub async fn reset(&self, bucket_key: &str) -> Result<()> {
126        sqlx::query("DELETE FROM forge_rate_limits WHERE bucket_key = $1")
127            .bind(bucket_key)
128            .execute(&self.pool)
129            .await
130            .map_err(|e| ForgeError::Database(e.to_string()))?;
131        Ok(())
132    }
133
134    /// Clean up old rate limit entries.
135    pub async fn cleanup(&self, older_than: DateTime<Utc>) -> Result<u64> {
136        let result = sqlx::query(
137            r#"
138            DELETE FROM forge_rate_limits
139            WHERE created_at < $1
140            "#,
141        )
142        .bind(older_than)
143        .execute(&self.pool)
144        .await
145        .map_err(|e| ForgeError::Database(e.to_string()))?;
146
147        Ok(result.rows_affected())
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    /// Single-connection pool that does not connect until first use, so
    /// these tests run without a live database.
    fn lazy_pool() -> PgPool {
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/test")
            .expect("Failed to create mock pool")
    }

    #[tokio::test]
    async fn test_rate_limiter_creation() {
        let _limiter = RateLimiter::new(lazy_pool());
    }

    #[tokio::test]
    async fn test_build_key() {
        let limiter = RateLimiter::new(lazy_pool());
        let (auth, request) = (AuthContext::unauthenticated(), RequestMetadata::default());

        let global_key = limiter.build_key(RateLimitKey::Global, "test_action", &auth, &request);
        assert_eq!(global_key, "global:test_action");

        let user_key = limiter.build_key(RateLimitKey::User, "test_action", &auth, &request);
        assert!(user_key.starts_with("user:"));
    }
}