// forge_runtime/rate_limit/limiter.rs
use std::time::Duration;

use chrono::{DateTime, Utc};
use dashmap::DashMap;
use sqlx::PgPool;

use forge_core::rate_limit::{RateLimitConfig, RateLimitKey, RateLimitResult};
use forge_core::{AuthContext, ForgeError, RequestMetadata, Result};
9
10/// Rate limiter using PostgreSQL for state storage.
11///
12/// Implements token bucket algorithm with atomic updates.
13pub struct RateLimiter {
14    pool: PgPool,
15}
16
17impl RateLimiter {
18    /// Create a new rate limiter.
19    pub fn new(pool: PgPool) -> Self {
20        Self { pool }
21    }
22
23    /// Check rate limit for a bucket key.
24    pub async fn check(
25        &self,
26        bucket_key: &str,
27        config: &RateLimitConfig,
28    ) -> Result<RateLimitResult> {
29        let max_tokens = config.requests as f64;
30        let refill_rate = config.refill_rate();
31
32        // Atomic upsert with token bucket logic
33        let result: (f64, i32, DateTime<Utc>, bool) = sqlx::query_as(
34            r#"
35            INSERT INTO forge_rate_limits (bucket_key, tokens, last_refill, max_tokens, refill_rate)
36            VALUES ($1, $2 - 1, NOW(), $2, $3)
37            ON CONFLICT (bucket_key) DO UPDATE SET
38                tokens = LEAST(
39                    forge_rate_limits.max_tokens::double precision,
40                    forge_rate_limits.tokens +
41                        (EXTRACT(EPOCH FROM (NOW() - forge_rate_limits.last_refill)) * forge_rate_limits.refill_rate)
42                ) - 1,
43                last_refill = NOW()
44            RETURNING tokens, max_tokens, last_refill, (tokens >= 0) as allowed
45            "#,
46        )
47        .bind(bucket_key)
48        .bind(max_tokens as i32)
49        .bind(refill_rate)
50        .fetch_one(&self.pool)
51        .await
52        .map_err(|e| ForgeError::Database(e.to_string()))?;
53
54        let (tokens, _max, last_refill, allowed) = result;
55
56        let remaining = tokens.max(0.0) as u32;
57        let reset_at =
58            last_refill + chrono::Duration::seconds(((max_tokens - tokens) / refill_rate) as i64);
59
60        if allowed {
61            Ok(RateLimitResult::allowed(remaining, reset_at))
62        } else {
63            let retry_after = Duration::from_secs_f64((1.0 - tokens) / refill_rate);
64            Ok(RateLimitResult::denied(remaining, reset_at, retry_after))
65        }
66    }
67
68    /// Build a bucket key for the given parameters.
69    pub fn build_key(
70        &self,
71        key_type: RateLimitKey,
72        action_name: &str,
73        auth: &AuthContext,
74        request: &RequestMetadata,
75    ) -> String {
76        match key_type {
77            RateLimitKey::User => {
78                let user_id = auth
79                    .user_id()
80                    .map(|u| u.to_string())
81                    .unwrap_or_else(|| "anonymous".to_string());
82                format!("user:{}:{}", user_id, action_name)
83            }
84            RateLimitKey::Ip => {
85                let ip = request.client_ip.as_deref().unwrap_or("unknown");
86                format!("ip:{}:{}", ip, action_name)
87            }
88            RateLimitKey::Tenant => {
89                let tenant_id = auth
90                    .claim("tenant_id")
91                    .and_then(|v| v.as_str())
92                    .unwrap_or("none");
93                format!("tenant:{}:{}", tenant_id, action_name)
94            }
95            RateLimitKey::UserAction => {
96                let user_id = auth
97                    .user_id()
98                    .map(|u| u.to_string())
99                    .unwrap_or_else(|| "anonymous".to_string());
100                format!("user_action:{}:{}", user_id, action_name)
101            }
102            RateLimitKey::Global => {
103                format!("global:{}", action_name)
104            }
105        }
106    }
107
108    /// Check rate limit and return an error if exceeded.
109    pub async fn enforce(
110        &self,
111        bucket_key: &str,
112        config: &RateLimitConfig,
113    ) -> Result<RateLimitResult> {
114        let result = self.check(bucket_key, config).await?;
115        if !result.allowed {
116            return Err(ForgeError::RateLimitExceeded {
117                retry_after: result.retry_after.unwrap_or(Duration::from_secs(1)),
118                limit: config.requests,
119                remaining: result.remaining,
120            });
121        }
122        Ok(result)
123    }
124
125    /// Reset a rate limit bucket.
126    pub async fn reset(&self, bucket_key: &str) -> Result<()> {
127        sqlx::query("DELETE FROM forge_rate_limits WHERE bucket_key = $1")
128            .bind(bucket_key)
129            .execute(&self.pool)
130            .await
131            .map_err(|e| ForgeError::Database(e.to_string()))?;
132        Ok(())
133    }
134
135    /// Clean up old rate limit entries.
136    pub async fn cleanup(&self, older_than: DateTime<Utc>) -> Result<u64> {
137        let result = sqlx::query(
138            r#"
139            DELETE FROM forge_rate_limits
140            WHERE created_at < $1
141            "#,
142        )
143        .bind(older_than)
144        .execute(&self.pool)
145        .await
146        .map_err(|e| ForgeError::Database(e.to_string()))?;
147
148        Ok(result.rows_affected())
149    }
150}
151
/// In-memory token bucket used by the hybrid limiter's fast path.
struct LocalBucket {
    /// Current token balance; may be fractional or (transiently) negative.
    tokens: f64,
    /// Bucket capacity in tokens.
    max_tokens: f64,
    /// Tokens added per second of wall-clock time.
    refill_rate: f64,
    /// Instant at which `tokens` was last brought up to date.
    last_refill: std::time::Instant,
}

impl LocalBucket {
    /// Create a bucket that starts out full.
    fn new(max_tokens: f64, refill_rate: f64) -> Self {
        Self {
            tokens: max_tokens,
            max_tokens,
            refill_rate,
            last_refill: std::time::Instant::now(),
        }
    }

    /// Bring the balance up to date, then attempt to take one token.
    /// Returns `true` if a token was consumed.
    fn try_consume(&mut self) -> bool {
        self.refill();
        if self.tokens < 1.0 {
            return false;
        }
        self.tokens -= 1.0;
        true
    }

    /// Credit tokens earned since the last refill, capped at capacity.
    fn refill(&mut self) {
        let now = std::time::Instant::now();
        let earned = now.duration_since(self.last_refill).as_secs_f64() * self.refill_rate;
        self.last_refill = now;
        self.tokens = self.max_tokens.min(self.tokens + earned);
    }

    /// Number of whole tokens currently available.
    fn remaining(&self) -> u32 {
        if self.tokens <= 0.0 {
            0
        } else {
            self.tokens as u32
        }
    }

    /// Time until at least one full token is available.
    fn time_until_token(&self) -> Duration {
        match self.tokens {
            t if t >= 1.0 => Duration::ZERO,
            t => Duration::from_secs_f64((1.0 - t) / self.refill_rate),
        }
    }
}
195
196/// Hybrid rate limiter with in-memory fast path and periodic DB sync.
197///
198/// Per-user/per-IP checks use a local DashMap for sub-microsecond decisions.
199/// Global keys always hit the database for strict cross-node consistency.
200pub struct HybridRateLimiter {
201    local: DashMap<String, LocalBucket>,
202    db_limiter: RateLimiter,
203}
204
205impl HybridRateLimiter {
206    pub fn new(pool: PgPool) -> Self {
207        Self {
208            local: DashMap::new(),
209            db_limiter: RateLimiter::new(pool),
210        }
211    }
212
213    /// Check rate limit. Uses local fast path for per-user/per-IP keys,
214    /// database for global keys.
215    pub async fn check(
216        &self,
217        bucket_key: &str,
218        config: &RateLimitConfig,
219    ) -> Result<RateLimitResult> {
220        if config.key == RateLimitKey::Global {
221            return self.db_limiter.check(bucket_key, config).await;
222        }
223
224        let max_tokens = config.requests as f64;
225        let refill_rate = config.refill_rate();
226
227        let mut bucket = self
228            .local
229            .entry(bucket_key.to_string())
230            .or_insert_with(|| LocalBucket::new(max_tokens, refill_rate));
231
232        let allowed = bucket.try_consume();
233        let remaining = bucket.remaining();
234        let reset_at = Utc::now()
235            + chrono::Duration::seconds(((max_tokens - bucket.tokens) / refill_rate) as i64);
236
237        if allowed {
238            Ok(RateLimitResult::allowed(remaining, reset_at))
239        } else {
240            let retry_after = bucket.time_until_token();
241            Ok(RateLimitResult::denied(remaining, reset_at, retry_after))
242        }
243    }
244
245    pub fn build_key(
246        &self,
247        key_type: RateLimitKey,
248        action_name: &str,
249        auth: &AuthContext,
250        request: &RequestMetadata,
251    ) -> String {
252        self.db_limiter
253            .build_key(key_type, action_name, auth, request)
254    }
255
256    pub async fn enforce(
257        &self,
258        bucket_key: &str,
259        config: &RateLimitConfig,
260    ) -> Result<RateLimitResult> {
261        let result = self.check(bucket_key, config).await?;
262        if !result.allowed {
263            return Err(ForgeError::RateLimitExceeded {
264                retry_after: result.retry_after.unwrap_or(Duration::from_secs(1)),
265                limit: config.requests,
266                remaining: result.remaining,
267            });
268        }
269        Ok(result)
270    }
271
272    /// Clean up expired local buckets (call periodically).
273    pub fn cleanup_local(&self, max_idle: Duration) {
274        let cutoff = std::time::Instant::now() - max_idle;
275        self.local.retain(|_, bucket| bucket.last_refill > cutoff);
276    }
277}
278
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a lazily-connecting pool; no database is actually contacted.
    fn lazy_pool() -> PgPool {
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/test")
            .expect("Failed to create mock pool")
    }

    #[tokio::test]
    async fn test_rate_limiter_creation() {
        let _limiter = RateLimiter::new(lazy_pool());
    }

    #[tokio::test]
    async fn test_build_key() {
        let limiter = RateLimiter::new(lazy_pool());
        let auth = AuthContext::unauthenticated();
        let request = RequestMetadata::default();

        let key = limiter.build_key(RateLimitKey::Global, "test_action", &auth, &request);
        assert_eq!(key, "global:test_action");

        let key = limiter.build_key(RateLimitKey::User, "test_action", &auth, &request);
        assert!(key.starts_with("user:"));
    }
}