forge_runtime/rate_limit/
limiter.rs1use std::time::Duration;
2
3use chrono::{DateTime, Utc};
4use dashmap::DashMap;
5use sqlx::PgPool;
6
7use forge_core::rate_limit::{RateLimitConfig, RateLimitKey, RateLimitResult};
8use forge_core::{AuthContext, ForgeError, RequestMetadata, Result};
9
/// Database-backed token-bucket rate limiter.
///
/// Bucket state lives in the `forge_rate_limits` Postgres table, so limits
/// are shared by every process pointed at the same database.
pub struct RateLimiter {
    // Connection pool for the Postgres instance holding bucket rows.
    pool: PgPool,
}
16
17impl RateLimiter {
18 pub fn new(pool: PgPool) -> Self {
20 Self { pool }
21 }
22
23 pub async fn check(
25 &self,
26 bucket_key: &str,
27 config: &RateLimitConfig,
28 ) -> Result<RateLimitResult> {
29 let max_tokens = config.requests as f64;
30 let refill_rate = config.refill_rate();
31
32 let result: (f64, i32, DateTime<Utc>, bool) = sqlx::query_as(
34 r#"
35 INSERT INTO forge_rate_limits (bucket_key, tokens, last_refill, max_tokens, refill_rate)
36 VALUES ($1, $2 - 1, NOW(), $2, $3)
37 ON CONFLICT (bucket_key) DO UPDATE SET
38 tokens = LEAST(
39 forge_rate_limits.max_tokens::double precision,
40 forge_rate_limits.tokens +
41 (EXTRACT(EPOCH FROM (NOW() - forge_rate_limits.last_refill)) * forge_rate_limits.refill_rate)
42 ) - 1,
43 last_refill = NOW()
44 RETURNING tokens, max_tokens, last_refill, (tokens >= 0) as allowed
45 "#,
46 )
47 .bind(bucket_key)
48 .bind(max_tokens as i32)
49 .bind(refill_rate)
50 .fetch_one(&self.pool)
51 .await
52 .map_err(|e| ForgeError::Database(e.to_string()))?;
53
54 let (tokens, _max, last_refill, allowed) = result;
55
56 let remaining = tokens.max(0.0) as u32;
57 let reset_at =
58 last_refill + chrono::Duration::seconds(((max_tokens - tokens) / refill_rate) as i64);
59
60 if allowed {
61 Ok(RateLimitResult::allowed(remaining, reset_at))
62 } else {
63 let retry_after = Duration::from_secs_f64((1.0 - tokens) / refill_rate);
64 Ok(RateLimitResult::denied(remaining, reset_at, retry_after))
65 }
66 }
67
68 pub fn build_key(
70 &self,
71 key_type: RateLimitKey,
72 action_name: &str,
73 auth: &AuthContext,
74 request: &RequestMetadata,
75 ) -> String {
76 match key_type {
77 RateLimitKey::User => {
78 let user_id = auth
79 .user_id()
80 .map(|u| u.to_string())
81 .unwrap_or_else(|| "anonymous".to_string());
82 format!("user:{}:{}", user_id, action_name)
83 }
84 RateLimitKey::Ip => {
85 let ip = request.client_ip.as_deref().unwrap_or("unknown");
86 format!("ip:{}:{}", ip, action_name)
87 }
88 RateLimitKey::Tenant => {
89 let tenant_id = auth
90 .claim("tenant_id")
91 .and_then(|v| v.as_str())
92 .unwrap_or("none");
93 format!("tenant:{}:{}", tenant_id, action_name)
94 }
95 RateLimitKey::UserAction => {
96 let user_id = auth
97 .user_id()
98 .map(|u| u.to_string())
99 .unwrap_or_else(|| "anonymous".to_string());
100 format!("user_action:{}:{}", user_id, action_name)
101 }
102 RateLimitKey::Global => {
103 format!("global:{}", action_name)
104 }
105 }
106 }
107
108 pub async fn enforce(
110 &self,
111 bucket_key: &str,
112 config: &RateLimitConfig,
113 ) -> Result<RateLimitResult> {
114 let result = self.check(bucket_key, config).await?;
115 if !result.allowed {
116 return Err(ForgeError::RateLimitExceeded {
117 retry_after: result.retry_after.unwrap_or(Duration::from_secs(1)),
118 limit: config.requests,
119 remaining: result.remaining,
120 });
121 }
122 Ok(result)
123 }
124
125 pub async fn reset(&self, bucket_key: &str) -> Result<()> {
127 sqlx::query("DELETE FROM forge_rate_limits WHERE bucket_key = $1")
128 .bind(bucket_key)
129 .execute(&self.pool)
130 .await
131 .map_err(|e| ForgeError::Database(e.to_string()))?;
132 Ok(())
133 }
134
135 pub async fn cleanup(&self, older_than: DateTime<Utc>) -> Result<u64> {
137 let result = sqlx::query(
138 r#"
139 DELETE FROM forge_rate_limits
140 WHERE created_at < $1
141 "#,
142 )
143 .bind(older_than)
144 .execute(&self.pool)
145 .await
146 .map_err(|e| ForgeError::Database(e.to_string()))?;
147
148 Ok(result.rows_affected())
149 }
150}
151
/// In-memory token bucket used for per-process (node-local) limiting.
struct LocalBucket {
    // Current balance; fractional tokens accrue between requests.
    tokens: f64,
    // Bucket capacity — the refill clamp.
    max_tokens: f64,
    // Tokens credited per second of elapsed time.
    refill_rate: f64,
    // Monotonic timestamp of the last refill (also the last-touched time).
    last_refill: std::time::Instant,
}

impl LocalBucket {
    /// Builds a bucket that starts full.
    fn new(max_tokens: f64, refill_rate: f64) -> Self {
        LocalBucket {
            tokens: max_tokens,
            max_tokens,
            refill_rate,
            last_refill: std::time::Instant::now(),
        }
    }

    /// Credits tokens earned since the previous refill, capped at capacity,
    /// and stamps the bucket with "now".
    fn refill(&mut self) {
        let now = std::time::Instant::now();
        let earned = now.duration_since(self.last_refill).as_secs_f64() * self.refill_rate;
        self.tokens = self.max_tokens.min(self.tokens + earned);
        self.last_refill = now;
    }

    /// Refills, then attempts to take one token. Returns `true` when a
    /// whole token was available.
    fn try_consume(&mut self) -> bool {
        self.refill();
        let has_token = self.tokens >= 1.0;
        if has_token {
            self.tokens -= 1.0;
        }
        has_token
    }

    /// Whole tokens currently available (never negative).
    fn remaining(&self) -> u32 {
        if self.tokens > 0.0 {
            self.tokens as u32
        } else {
            0
        }
    }

    /// How long until a full token is available at the current refill rate.
    fn time_until_token(&self) -> Duration {
        if self.tokens >= 1.0 {
            return Duration::ZERO;
        }
        Duration::from_secs_f64((1.0 - self.tokens) / self.refill_rate)
    }
}
195
/// Two-tier limiter: most keys are throttled with fast in-process buckets,
/// while `Global` limits fall through to the database-backed limiter so
/// every node draws from the same shared budget.
///
/// NOTE(review): non-global limits are therefore approximate per-node, not
/// cluster-wide — each process keeps its own `local` map.
pub struct HybridRateLimiter {
    // In-process buckets keyed by bucket key.
    local: DashMap<String, LocalBucket>,
    // Shared database-backed limiter, also used for key construction.
    db_limiter: RateLimiter,
}
204
205impl HybridRateLimiter {
206 pub fn new(pool: PgPool) -> Self {
207 Self {
208 local: DashMap::new(),
209 db_limiter: RateLimiter::new(pool),
210 }
211 }
212
213 pub async fn check(
216 &self,
217 bucket_key: &str,
218 config: &RateLimitConfig,
219 ) -> Result<RateLimitResult> {
220 if config.key == RateLimitKey::Global {
221 return self.db_limiter.check(bucket_key, config).await;
222 }
223
224 let max_tokens = config.requests as f64;
225 let refill_rate = config.refill_rate();
226
227 let mut bucket = self
228 .local
229 .entry(bucket_key.to_string())
230 .or_insert_with(|| LocalBucket::new(max_tokens, refill_rate));
231
232 let allowed = bucket.try_consume();
233 let remaining = bucket.remaining();
234 let reset_at = Utc::now()
235 + chrono::Duration::seconds(((max_tokens - bucket.tokens) / refill_rate) as i64);
236
237 if allowed {
238 Ok(RateLimitResult::allowed(remaining, reset_at))
239 } else {
240 let retry_after = bucket.time_until_token();
241 Ok(RateLimitResult::denied(remaining, reset_at, retry_after))
242 }
243 }
244
245 pub fn build_key(
246 &self,
247 key_type: RateLimitKey,
248 action_name: &str,
249 auth: &AuthContext,
250 request: &RequestMetadata,
251 ) -> String {
252 self.db_limiter
253 .build_key(key_type, action_name, auth, request)
254 }
255
256 pub async fn enforce(
257 &self,
258 bucket_key: &str,
259 config: &RateLimitConfig,
260 ) -> Result<RateLimitResult> {
261 let result = self.check(bucket_key, config).await?;
262 if !result.allowed {
263 return Err(ForgeError::RateLimitExceeded {
264 retry_after: result.retry_after.unwrap_or(Duration::from_secs(1)),
265 limit: config.requests,
266 remaining: result.remaining,
267 });
268 }
269 Ok(result)
270 }
271
272 pub fn cleanup_local(&self, max_idle: Duration) {
274 let cutoff = std::time::Instant::now() - max_idle;
275 self.local.retain(|_, bucket| bucket.last_refill > cutoff);
276 }
277}
278
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a lazily-connecting pool; no database is actually contacted.
    fn lazy_pool() -> PgPool {
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/test")
            .expect("Failed to create mock pool")
    }

    #[tokio::test]
    async fn test_rate_limiter_creation() {
        let _limiter = RateLimiter::new(lazy_pool());
    }

    #[tokio::test]
    async fn test_build_key() {
        let limiter = RateLimiter::new(lazy_pool());
        let auth = AuthContext::unauthenticated();
        let request = RequestMetadata::default();

        let global_key = limiter.build_key(RateLimitKey::Global, "test_action", &auth, &request);
        assert_eq!(global_key, "global:test_action");

        let user_key = limiter.build_key(RateLimitKey::User, "test_action", &auth, &request);
        assert!(user_key.starts_with("user:"));
    }
}
310}