forge_runtime/rate_limit/limiter.rs

1use std::time::Duration;
2
3use chrono::{DateTime, Utc};
4use dashmap::DashMap;
5use sqlx::PgPool;
6
7use forge_core::rate_limit::{RateLimitConfig, RateLimitKey, RateLimitResult};
8use forge_core::{AuthContext, ForgeError, RequestMetadata, Result};
9
/// Rate limiter using PostgreSQL for state storage.
///
/// Implements a token bucket algorithm with atomic updates: the bucket state
/// lives in the `forge_rate_limits` table and is refilled and consumed in a
/// single atomic upsert, so it is safe to use from multiple nodes.
pub struct RateLimiter {
    // Connection pool used for all bucket reads/writes.
    pool: PgPool,
}
16
17impl RateLimiter {
18    /// Create a new rate limiter.
19    pub fn new(pool: PgPool) -> Self {
20        Self { pool }
21    }
22
23    /// Check rate limit for a bucket key.
24    pub async fn check(
25        &self,
26        bucket_key: &str,
27        config: &RateLimitConfig,
28    ) -> Result<RateLimitResult> {
29        let max_tokens = config.requests as f64;
30        let refill_rate = config.refill_rate();
31
32        // Atomic upsert with token bucket logic
33        let result = sqlx::query!(
34            r#"
35            INSERT INTO forge_rate_limits (bucket_key, tokens, last_refill, max_tokens, refill_rate)
36            VALUES ($1, $2 - 1, NOW(), $2, $3)
37            ON CONFLICT (bucket_key) DO UPDATE SET
38                tokens = LEAST(
39                    forge_rate_limits.max_tokens::double precision,
40                    forge_rate_limits.tokens +
41                        (EXTRACT(EPOCH FROM (NOW() - forge_rate_limits.last_refill)) * forge_rate_limits.refill_rate)
42                ) - 1,
43                last_refill = NOW()
44            RETURNING tokens, max_tokens, last_refill, (tokens >= 0) as "allowed!"
45            "#,
46            bucket_key,
47            max_tokens as i32,
48            refill_rate
49        )
50        .fetch_one(&self.pool)
51        .await
52        .map_err(|e| ForgeError::Database(e.to_string()))?;
53
54        let tokens = result.tokens;
55        let last_refill = result.last_refill;
56        let allowed = result.allowed;
57
58        let remaining = tokens.max(0.0) as u32;
59        let reset_at =
60            last_refill + chrono::Duration::seconds(((max_tokens - tokens) / refill_rate) as i64);
61
62        if allowed {
63            Ok(RateLimitResult::allowed(remaining, reset_at))
64        } else {
65            let retry_after = Duration::from_secs_f64((1.0 - tokens) / refill_rate);
66            Ok(RateLimitResult::denied(remaining, reset_at, retry_after))
67        }
68    }
69
70    /// Build a bucket key for the given parameters.
71    pub fn build_key(
72        &self,
73        key_type: RateLimitKey,
74        action_name: &str,
75        auth: &AuthContext,
76        request: &RequestMetadata,
77    ) -> String {
78        match key_type {
79            RateLimitKey::User => {
80                let user_id = auth
81                    .user_id()
82                    .map(|u| u.to_string())
83                    .unwrap_or_else(|| "anonymous".to_string());
84                format!("user:{}:{}", user_id, action_name)
85            }
86            RateLimitKey::Ip => {
87                let ip = request.client_ip.as_deref().unwrap_or("unknown");
88                format!("ip:{}:{}", ip, action_name)
89            }
90            RateLimitKey::Tenant => {
91                let tenant_id = auth
92                    .claim("tenant_id")
93                    .and_then(|v| v.as_str())
94                    .unwrap_or("none");
95                format!("tenant:{}:{}", tenant_id, action_name)
96            }
97            RateLimitKey::UserAction => {
98                let user_id = auth
99                    .user_id()
100                    .map(|u| u.to_string())
101                    .unwrap_or_else(|| "anonymous".to_string());
102                format!("user_action:{}:{}", user_id, action_name)
103            }
104            RateLimitKey::Global => {
105                format!("global:{}", action_name)
106            }
107        }
108    }
109
110    /// Check rate limit and return an error if exceeded.
111    pub async fn enforce(
112        &self,
113        bucket_key: &str,
114        config: &RateLimitConfig,
115    ) -> Result<RateLimitResult> {
116        let result = self.check(bucket_key, config).await?;
117        if !result.allowed {
118            return Err(ForgeError::RateLimitExceeded {
119                retry_after: result.retry_after.unwrap_or(Duration::from_secs(1)),
120                limit: config.requests,
121                remaining: result.remaining,
122            });
123        }
124        Ok(result)
125    }
126
127    /// Reset a rate limit bucket.
128    pub async fn reset(&self, bucket_key: &str) -> Result<()> {
129        sqlx::query("DELETE FROM forge_rate_limits WHERE bucket_key = $1")
130            .bind(bucket_key)
131            .execute(&self.pool)
132            .await
133            .map_err(|e| ForgeError::Database(e.to_string()))?;
134        Ok(())
135    }
136
137    /// Clean up old rate limit entries.
138    pub async fn cleanup(&self, older_than: DateTime<Utc>) -> Result<u64> {
139        let result = sqlx::query(
140            r#"
141            DELETE FROM forge_rate_limits
142            WHERE created_at < $1
143            "#,
144        )
145        .bind(older_than)
146        .execute(&self.pool)
147        .await
148        .map_err(|e| ForgeError::Database(e.to_string()))?;
149
150        Ok(result.rows_affected())
151    }
152}
153
/// In-process token bucket used for node-local fast-path decisions.
struct LocalBucket {
    tokens: f64,
    max_tokens: f64,
    refill_rate: f64,
    last_refill: std::time::Instant,
}

impl LocalBucket {
    /// Create a bucket that starts full.
    fn new(max_tokens: f64, refill_rate: f64) -> Self {
        Self {
            tokens: max_tokens,
            max_tokens,
            refill_rate,
            last_refill: std::time::Instant::now(),
        }
    }

    /// Credit tokens for the time elapsed since the last refill, capped at
    /// the bucket's capacity.
    fn refill(&mut self) {
        let now = std::time::Instant::now();
        let elapsed_secs = now.duration_since(self.last_refill).as_secs_f64();
        self.tokens = self.max_tokens.min(self.tokens + elapsed_secs * self.refill_rate);
        self.last_refill = now;
    }

    /// Refill, then take one token if at least one is available.
    /// Returns whether the request should be allowed.
    fn try_consume(&mut self) -> bool {
        self.refill();
        let has_token = self.tokens >= 1.0;
        if has_token {
            self.tokens -= 1.0;
        }
        has_token
    }

    /// Whole tokens currently available (never negative).
    fn remaining(&self) -> u32 {
        if self.tokens > 0.0 { self.tokens as u32 } else { 0 }
    }

    /// How long until a full token is available again; zero if one already is.
    fn time_until_token(&self) -> Duration {
        let deficit = 1.0 - self.tokens;
        if deficit <= 0.0 {
            Duration::ZERO
        } else {
            Duration::from_secs_f64(deficit / self.refill_rate)
        }
    }
}
197
/// Hybrid rate limiter with an in-memory fast path.
///
/// Per-user/per-IP checks use a local DashMap for sub-microsecond decisions.
/// Global keys always hit the database for strict cross-node consistency.
///
/// NOTE(review): local buckets are node-local only — no code here syncs them
/// back to the database, so per-user/per-IP limits are enforced per node.
pub struct HybridRateLimiter {
    // Node-local token buckets keyed by bucket key (non-global keys only).
    local: DashMap<String, LocalBucket>,
    // Database-backed limiter used for global keys and key construction.
    db_limiter: RateLimiter,
}
206
207impl HybridRateLimiter {
208    pub fn new(pool: PgPool) -> Self {
209        Self {
210            local: DashMap::new(),
211            db_limiter: RateLimiter::new(pool),
212        }
213    }
214
215    /// Check rate limit. Uses local fast path for per-user/per-IP keys,
216    /// database for global keys.
217    pub async fn check(
218        &self,
219        bucket_key: &str,
220        config: &RateLimitConfig,
221    ) -> Result<RateLimitResult> {
222        if config.key == RateLimitKey::Global {
223            return self.db_limiter.check(bucket_key, config).await;
224        }
225
226        let max_tokens = config.requests as f64;
227        let refill_rate = config.refill_rate();
228
229        let mut bucket = self
230            .local
231            .entry(bucket_key.to_string())
232            .or_insert_with(|| LocalBucket::new(max_tokens, refill_rate));
233
234        let allowed = bucket.try_consume();
235        let remaining = bucket.remaining();
236        let reset_at = Utc::now()
237            + chrono::Duration::seconds(((max_tokens - bucket.tokens) / refill_rate) as i64);
238
239        if allowed {
240            Ok(RateLimitResult::allowed(remaining, reset_at))
241        } else {
242            let retry_after = bucket.time_until_token();
243            Ok(RateLimitResult::denied(remaining, reset_at, retry_after))
244        }
245    }
246
247    pub fn build_key(
248        &self,
249        key_type: RateLimitKey,
250        action_name: &str,
251        auth: &AuthContext,
252        request: &RequestMetadata,
253    ) -> String {
254        self.db_limiter
255            .build_key(key_type, action_name, auth, request)
256    }
257
258    pub async fn enforce(
259        &self,
260        bucket_key: &str,
261        config: &RateLimitConfig,
262    ) -> Result<RateLimitResult> {
263        let result = self.check(bucket_key, config).await?;
264        if !result.allowed {
265            return Err(ForgeError::RateLimitExceeded {
266                retry_after: result.retry_after.unwrap_or(Duration::from_secs(1)),
267                limit: config.requests,
268                remaining: result.remaining,
269            });
270        }
271        Ok(result)
272    }
273
274    /// Clean up expired local buckets (call periodically).
275    pub fn cleanup_local(&self, max_idle: Duration) {
276        let cutoff = std::time::Instant::now() - max_idle;
277        self.local.retain(|_, bucket| bucket.last_refill > cutoff);
278    }
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a lazily-connecting pool; no database connection is attempted.
    fn lazy_pool() -> PgPool {
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/test")
            .expect("Failed to create mock pool")
    }

    #[tokio::test]
    async fn test_rate_limiter_creation() {
        let _limiter = RateLimiter::new(lazy_pool());
    }

    #[tokio::test]
    async fn test_build_key() {
        let limiter = RateLimiter::new(lazy_pool());
        let auth = AuthContext::unauthenticated();
        let request = RequestMetadata::default();

        assert_eq!(
            limiter.build_key(RateLimitKey::Global, "test_action", &auth, &request),
            "global:test_action"
        );
        assert!(limiter
            .build_key(RateLimitKey::User, "test_action", &auth, &request)
            .starts_with("user:"));
    }
}