// forge_runtime/rate_limit/limiter.rs

1use std::time::Duration;
2
3use chrono::{DateTime, Utc};
4use dashmap::DashMap;
5use sqlx::PgPool;
6
7use forge_core::rate_limit::{RateLimitConfig, RateLimitKey, RateLimitResult};
8use forge_core::{AuthContext, ForgeError, RequestMetadata, Result};
9
/// Rate limiter using PostgreSQL for state storage.
///
/// Implements the token bucket algorithm with atomic upserts, so multiple
/// application nodes sharing the same database enforce one shared limit.
pub struct RateLimiter {
    // Connection pool for the `forge_rate_limits` table.
    pool: PgPool,
}
16
17impl RateLimiter {
18    /// Create a new rate limiter.
19    pub fn new(pool: PgPool) -> Self {
20        Self { pool }
21    }
22
23    /// Check rate limit for a bucket key.
24    pub async fn check(
25        &self,
26        bucket_key: &str,
27        config: &RateLimitConfig,
28    ) -> Result<RateLimitResult> {
29        let max_tokens = config.requests as f64;
30        let refill_rate = config.refill_rate();
31
32        // Atomic upsert with token bucket logic
33        let result = sqlx::query!(
34            r#"
35            INSERT INTO forge_rate_limits (bucket_key, tokens, last_refill, max_tokens, refill_rate)
36            VALUES ($1, $2 - 1, NOW(), $2, $3)
37            ON CONFLICT (bucket_key) DO UPDATE SET
38                tokens = LEAST(
39                    forge_rate_limits.max_tokens::double precision,
40                    forge_rate_limits.tokens +
41                        (EXTRACT(EPOCH FROM (NOW() - forge_rate_limits.last_refill)) * forge_rate_limits.refill_rate)
42                ) - 1,
43                last_refill = NOW()
44            RETURNING tokens, max_tokens, last_refill, (tokens >= 0) as "allowed!"
45            "#,
46            bucket_key,
47            max_tokens as i32,
48            refill_rate
49        )
50        .fetch_one(&self.pool)
51        .await
52        .map_err(|e| ForgeError::Database(e.to_string()))?;
53
54        let tokens = result.tokens;
55        let last_refill = result.last_refill;
56        let allowed = result.allowed;
57
58        let remaining = tokens.max(0.0) as u32;
59        let reset_at =
60            last_refill + chrono::Duration::seconds(((max_tokens - tokens) / refill_rate) as i64);
61
62        if allowed {
63            Ok(RateLimitResult::allowed(remaining, reset_at))
64        } else {
65            let retry_after = Duration::from_secs_f64((1.0 - tokens) / refill_rate);
66            Ok(RateLimitResult::denied(remaining, reset_at, retry_after))
67        }
68    }
69
70    /// Build a bucket key for the given parameters.
71    pub fn build_key(
72        &self,
73        key_type: RateLimitKey,
74        action_name: &str,
75        auth: &AuthContext,
76        request: &RequestMetadata,
77    ) -> String {
78        match key_type {
79            RateLimitKey::User => {
80                let user_id = auth
81                    .user_id()
82                    .map(|u| u.to_string())
83                    .unwrap_or_else(|| "anonymous".to_string());
84                format!("user:{}:{}", user_id, action_name)
85            }
86            RateLimitKey::Ip => {
87                let ip = request.client_ip.as_deref().unwrap_or("unknown");
88                format!("ip:{}:{}", ip, action_name)
89            }
90            RateLimitKey::Tenant => {
91                let tenant_id = auth
92                    .claim("tenant_id")
93                    .and_then(|v| v.as_str())
94                    .unwrap_or("none");
95                format!("tenant:{}:{}", tenant_id, action_name)
96            }
97            RateLimitKey::UserAction => {
98                let user_id = auth
99                    .user_id()
100                    .map(|u| u.to_string())
101                    .unwrap_or_else(|| "anonymous".to_string());
102                format!("user_action:{}:{}", user_id, action_name)
103            }
104            RateLimitKey::Global => {
105                format!("global:{}", action_name)
106            }
107        }
108    }
109
110    /// Check rate limit and return an error if exceeded.
111    pub async fn enforce(
112        &self,
113        bucket_key: &str,
114        config: &RateLimitConfig,
115    ) -> Result<RateLimitResult> {
116        let result = self.check(bucket_key, config).await?;
117        if !result.allowed {
118            return Err(ForgeError::RateLimitExceeded {
119                retry_after: result.retry_after.unwrap_or(Duration::from_secs(1)),
120                limit: config.requests,
121                remaining: result.remaining,
122            });
123        }
124        Ok(result)
125    }
126
127    /// Reset a rate limit bucket.
128    pub async fn reset(&self, bucket_key: &str) -> Result<()> {
129        sqlx::query!(
130            "DELETE FROM forge_rate_limits WHERE bucket_key = $1",
131            bucket_key
132        )
133        .execute(&self.pool)
134        .await
135        .map_err(|e| ForgeError::Database(e.to_string()))?;
136        Ok(())
137    }
138
139    /// Clean up old rate limit entries.
140    pub async fn cleanup(&self, older_than: DateTime<Utc>) -> Result<u64> {
141        let result = sqlx::query!(
142            r#"
143            DELETE FROM forge_rate_limits
144            WHERE created_at < $1
145            "#,
146            older_than,
147        )
148        .execute(&self.pool)
149        .await
150        .map_err(|e| ForgeError::Database(e.to_string()))?;
151
152        Ok(result.rows_affected())
153    }
154}
155
/// Process-local token bucket used by the in-memory fast path.
struct LocalBucket {
    tokens: f64,
    max_tokens: f64,
    refill_rate: f64,
    last_refill: std::time::Instant,
}

impl LocalBucket {
    /// Start a bucket at full capacity.
    fn new(max_tokens: f64, refill_rate: f64) -> Self {
        Self {
            tokens: max_tokens,
            max_tokens,
            refill_rate,
            last_refill: std::time::Instant::now(),
        }
    }

    /// Credit tokens earned since the last refill, capped at capacity.
    fn refill(&mut self) {
        let now = std::time::Instant::now();
        let earned = now.duration_since(self.last_refill).as_secs_f64() * self.refill_rate;
        self.tokens = (self.tokens + earned).min(self.max_tokens);
        self.last_refill = now;
    }

    /// Refill, then spend one token if a whole token is available.
    /// Returns whether the request is allowed.
    fn try_consume(&mut self) -> bool {
        self.refill();
        let allowed = self.tokens >= 1.0;
        if allowed {
            self.tokens -= 1.0;
        }
        allowed
    }

    /// Whole tokens currently available (never reported negative).
    fn remaining(&self) -> u32 {
        self.tokens.max(0.0) as u32
    }

    /// How long until at least one whole token is available.
    fn time_until_token(&self) -> Duration {
        match self.tokens {
            t if t >= 1.0 => Duration::ZERO,
            t => Duration::from_secs_f64((1.0 - t) / self.refill_rate),
        }
    }
}
199
/// Maximum number of local rate limit buckets to prevent unbounded memory growth.
/// When exceeded, `HybridRateLimiter::check` triggers a cleanup that evicts
/// entries idle for more than five minutes.
const MAX_LOCAL_BUCKETS: usize = 100_000;
203
/// Hybrid rate limiter with an in-memory fast path.
///
/// Per-user/per-IP checks use a local DashMap for sub-microsecond decisions.
/// Global keys always hit the database for strict cross-node consistency.
///
/// NOTE(review): local buckets are process-local only — no code here syncs
/// them back to the database, so per-user/per-IP limits are enforced per
/// node, not cluster-wide. Confirm that is the intended trade-off.
pub struct HybridRateLimiter {
    // Fast-path token buckets keyed by bucket key (process-local state).
    local: DashMap<String, LocalBucket>,
    // DB-backed limiter used for Global keys and for key construction.
    db_limiter: RateLimiter,
}
212
213impl HybridRateLimiter {
214    pub fn new(pool: PgPool) -> Self {
215        Self {
216            local: DashMap::new(),
217            db_limiter: RateLimiter::new(pool),
218        }
219    }
220
221    /// Check rate limit. Uses local fast path for per-user/per-IP keys,
222    /// database for global keys.
223    pub async fn check(
224        &self,
225        bucket_key: &str,
226        config: &RateLimitConfig,
227    ) -> Result<RateLimitResult> {
228        if config.key == RateLimitKey::Global {
229            return self.db_limiter.check(bucket_key, config).await;
230        }
231
232        let max_tokens = config.requests as f64;
233        let refill_rate = config.refill_rate();
234
235        // Evict idle buckets when the map gets too large to prevent memory exhaustion
236        if self.local.len() > MAX_LOCAL_BUCKETS {
237            self.cleanup_local(Duration::from_secs(300)); // evict entries idle > 5 min
238        }
239
240        let mut bucket = self
241            .local
242            .entry(bucket_key.to_string())
243            .or_insert_with(|| LocalBucket::new(max_tokens, refill_rate));
244
245        let allowed = bucket.try_consume();
246        let remaining = bucket.remaining();
247        let reset_at = Utc::now()
248            + chrono::Duration::seconds(((max_tokens - bucket.tokens) / refill_rate) as i64);
249
250        if allowed {
251            Ok(RateLimitResult::allowed(remaining, reset_at))
252        } else {
253            let retry_after = bucket.time_until_token();
254            Ok(RateLimitResult::denied(remaining, reset_at, retry_after))
255        }
256    }
257
258    pub fn build_key(
259        &self,
260        key_type: RateLimitKey,
261        action_name: &str,
262        auth: &AuthContext,
263        request: &RequestMetadata,
264    ) -> String {
265        self.db_limiter
266            .build_key(key_type, action_name, auth, request)
267    }
268
269    pub async fn enforce(
270        &self,
271        bucket_key: &str,
272        config: &RateLimitConfig,
273    ) -> Result<RateLimitResult> {
274        let result = self.check(bucket_key, config).await?;
275        if !result.allowed {
276            return Err(ForgeError::RateLimitExceeded {
277                retry_after: result.retry_after.unwrap_or(Duration::from_secs(1)),
278                limit: config.requests,
279                remaining: result.remaining,
280            });
281        }
282        Ok(result)
283    }
284
285    /// Clean up expired local buckets (call periodically).
286    pub fn cleanup_local(&self, max_idle: Duration) {
287        let cutoff = std::time::Instant::now() - max_idle;
288        self.local.retain(|_, bucket| bucket.last_refill > cutoff);
289    }
290}
291
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a pool without touching the network (`connect_lazy` defers I/O).
    fn lazy_pool() -> PgPool {
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/test")
            .expect("Failed to create mock pool")
    }

    #[tokio::test]
    async fn test_rate_limiter_creation() {
        let _limiter = RateLimiter::new(lazy_pool());
    }

    #[tokio::test]
    async fn test_build_key() {
        let limiter = RateLimiter::new(lazy_pool());
        let auth = AuthContext::unauthenticated();
        let request = RequestMetadata::default();

        let key = limiter.build_key(RateLimitKey::Global, "test_action", &auth, &request);
        assert_eq!(key, "global:test_action");

        let key = limiter.build_key(RateLimitKey::User, "test_action", &auth, &request);
        assert!(key.starts_with("user:"));
    }
}
323}