forge_runtime/rate_limit/
limiter.rs1use std::time::Duration;
2
3use chrono::{DateTime, Utc};
4use dashmap::DashMap;
5use sqlx::PgPool;
6
7use forge_core::rate_limit::{RateLimitConfig, RateLimitKey, RateLimitResult};
8use forge_core::{AuthContext, ForgeError, RequestMetadata, Result};
9
10pub struct RateLimiter {
14 pool: PgPool,
15}
16
17impl RateLimiter {
18 pub fn new(pool: PgPool) -> Self {
20 Self { pool }
21 }
22
23 pub async fn check(
25 &self,
26 bucket_key: &str,
27 config: &RateLimitConfig,
28 ) -> Result<RateLimitResult> {
29 let max_tokens = config.requests as f64;
30 let refill_rate = config.refill_rate();
31
32 let result = sqlx::query!(
34 r#"
35 INSERT INTO forge_rate_limits (bucket_key, tokens, last_refill, max_tokens, refill_rate)
36 VALUES ($1, $2 - 1, NOW(), $2, $3)
37 ON CONFLICT (bucket_key) DO UPDATE SET
38 tokens = LEAST(
39 forge_rate_limits.max_tokens::double precision,
40 forge_rate_limits.tokens +
41 (EXTRACT(EPOCH FROM (NOW() - forge_rate_limits.last_refill)) * forge_rate_limits.refill_rate)
42 ) - 1,
43 last_refill = NOW()
44 RETURNING tokens, max_tokens, last_refill, (tokens >= 0) as "allowed!"
45 "#,
46 bucket_key,
47 max_tokens as i32,
48 refill_rate
49 )
50 .fetch_one(&self.pool)
51 .await
52 .map_err(|e| ForgeError::Database(e.to_string()))?;
53
54 let tokens = result.tokens;
55 let last_refill = result.last_refill;
56 let allowed = result.allowed;
57
58 let remaining = tokens.max(0.0) as u32;
59 let reset_at =
60 last_refill + chrono::Duration::seconds(((max_tokens - tokens) / refill_rate) as i64);
61
62 if allowed {
63 Ok(RateLimitResult::allowed(remaining, reset_at))
64 } else {
65 let retry_after = Duration::from_secs_f64((1.0 - tokens) / refill_rate);
66 Ok(RateLimitResult::denied(remaining, reset_at, retry_after))
67 }
68 }
69
70 pub fn build_key(
72 &self,
73 key_type: RateLimitKey,
74 action_name: &str,
75 auth: &AuthContext,
76 request: &RequestMetadata,
77 ) -> String {
78 match key_type {
79 RateLimitKey::User => {
80 let user_id = auth
81 .user_id()
82 .map(|u| u.to_string())
83 .unwrap_or_else(|| "anonymous".to_string());
84 format!("user:{}:{}", user_id, action_name)
85 }
86 RateLimitKey::Ip => {
87 let ip = request.client_ip.as_deref().unwrap_or("unknown");
88 format!("ip:{}:{}", ip, action_name)
89 }
90 RateLimitKey::Tenant => {
91 let tenant_id = auth
92 .claim("tenant_id")
93 .and_then(|v| v.as_str())
94 .unwrap_or("none");
95 format!("tenant:{}:{}", tenant_id, action_name)
96 }
97 RateLimitKey::UserAction => {
98 let user_id = auth
99 .user_id()
100 .map(|u| u.to_string())
101 .unwrap_or_else(|| "anonymous".to_string());
102 format!("user_action:{}:{}", user_id, action_name)
103 }
104 RateLimitKey::Global => {
105 format!("global:{}", action_name)
106 }
107 }
108 }
109
110 pub async fn enforce(
112 &self,
113 bucket_key: &str,
114 config: &RateLimitConfig,
115 ) -> Result<RateLimitResult> {
116 let result = self.check(bucket_key, config).await?;
117 if !result.allowed {
118 return Err(ForgeError::RateLimitExceeded {
119 retry_after: result.retry_after.unwrap_or(Duration::from_secs(1)),
120 limit: config.requests,
121 remaining: result.remaining,
122 });
123 }
124 Ok(result)
125 }
126
127 pub async fn reset(&self, bucket_key: &str) -> Result<()> {
129 sqlx::query("DELETE FROM forge_rate_limits WHERE bucket_key = $1")
130 .bind(bucket_key)
131 .execute(&self.pool)
132 .await
133 .map_err(|e| ForgeError::Database(e.to_string()))?;
134 Ok(())
135 }
136
137 pub async fn cleanup(&self, older_than: DateTime<Utc>) -> Result<u64> {
139 let result = sqlx::query(
140 r#"
141 DELETE FROM forge_rate_limits
142 WHERE created_at < $1
143 "#,
144 )
145 .bind(older_than)
146 .execute(&self.pool)
147 .await
148 .map_err(|e| ForgeError::Database(e.to_string()))?;
149
150 Ok(result.rows_affected())
151 }
152}
153
/// In-process token bucket, used by `HybridRateLimiter` for non-global keys.
struct LocalBucket {
    /// Tokens currently available. Fractional, because refill is continuous.
    tokens: f64,
    /// Capacity: refills never raise `tokens` above this.
    max_tokens: f64,
    /// Tokens added per elapsed second.
    refill_rate: f64,
    /// When `tokens` was last brought up to date (updated by `try_consume`).
    last_refill: std::time::Instant,
}

impl LocalBucket {
    /// Upper bound (in seconds, roughly 136 years) returned by
    /// `time_until_token`.
    ///
    /// With a zero refill rate the real wait is infinite, and
    /// `Duration::from_secs_f64` panics on infinite input.
    const MAX_WAIT_SECS: f64 = u32::MAX as f64;

    /// Creates a full bucket holding `max_tokens` tokens.
    fn new(max_tokens: f64, refill_rate: f64) -> Self {
        Self {
            tokens: max_tokens,
            max_tokens,
            refill_rate,
            last_refill: std::time::Instant::now(),
        }
    }

    /// Adds the tokens earned since the last call (capped at `max_tokens`), then
    /// takes one if at least one whole token is available.
    ///
    /// Returns `true` if a token was taken. A denied call does not consume
    /// anything.
    fn try_consume(&mut self) -> bool {
        let now = std::time::Instant::now();
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
        self.tokens = (self.tokens + elapsed * self.refill_rate).min(self.max_tokens);
        self.last_refill = now;

        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }

    /// Whole tokens currently available, as of the last `try_consume`.
    fn remaining(&self) -> u32 {
        self.tokens.max(0.0) as u32
    }

    /// Time until one whole token is available, measured from the last refill.
    ///
    /// Returns `Duration::ZERO` if a token is already available. The result is
    /// capped at `MAX_WAIT_SECS` when `refill_rate` is zero (or not a number).
    fn time_until_token(&self) -> Duration {
        if self.tokens >= 1.0 {
            return Duration::ZERO;
        }
        // `(1 - tokens) / 0.0` is +inf. `f64::max(0.0)` also maps NaN to 0.0,
        // so the value passed on is always finite and non-negative.
        let secs = ((1.0 - self.tokens) / self.refill_rate)
            .max(0.0)
            .min(Self::MAX_WAIT_SECS);
        Duration::from_secs_f64(secs)
    }
}
197
198pub struct HybridRateLimiter {
203 local: DashMap<String, LocalBucket>,
204 db_limiter: RateLimiter,
205}
206
207impl HybridRateLimiter {
208 pub fn new(pool: PgPool) -> Self {
209 Self {
210 local: DashMap::new(),
211 db_limiter: RateLimiter::new(pool),
212 }
213 }
214
215 pub async fn check(
218 &self,
219 bucket_key: &str,
220 config: &RateLimitConfig,
221 ) -> Result<RateLimitResult> {
222 if config.key == RateLimitKey::Global {
223 return self.db_limiter.check(bucket_key, config).await;
224 }
225
226 let max_tokens = config.requests as f64;
227 let refill_rate = config.refill_rate();
228
229 let mut bucket = self
230 .local
231 .entry(bucket_key.to_string())
232 .or_insert_with(|| LocalBucket::new(max_tokens, refill_rate));
233
234 let allowed = bucket.try_consume();
235 let remaining = bucket.remaining();
236 let reset_at = Utc::now()
237 + chrono::Duration::seconds(((max_tokens - bucket.tokens) / refill_rate) as i64);
238
239 if allowed {
240 Ok(RateLimitResult::allowed(remaining, reset_at))
241 } else {
242 let retry_after = bucket.time_until_token();
243 Ok(RateLimitResult::denied(remaining, reset_at, retry_after))
244 }
245 }
246
247 pub fn build_key(
248 &self,
249 key_type: RateLimitKey,
250 action_name: &str,
251 auth: &AuthContext,
252 request: &RequestMetadata,
253 ) -> String {
254 self.db_limiter
255 .build_key(key_type, action_name, auth, request)
256 }
257
258 pub async fn enforce(
259 &self,
260 bucket_key: &str,
261 config: &RateLimitConfig,
262 ) -> Result<RateLimitResult> {
263 let result = self.check(bucket_key, config).await?;
264 if !result.allowed {
265 return Err(ForgeError::RateLimitExceeded {
266 retry_after: result.retry_after.unwrap_or(Duration::from_secs(1)),
267 limit: config.requests,
268 remaining: result.remaining,
269 });
270 }
271 Ok(result)
272 }
273
274 pub fn cleanup_local(&self, max_idle: Duration) {
276 let cutoff = std::time::Instant::now() - max_idle;
277 self.local.retain(|_, bucket| bucket.last_refill > cutoff);
278 }
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a single-connection pool with `connect_lazy`, which does not open
    /// a connection until the pool is first used. These tests therefore need no
    /// running database.
    fn lazy_pool() -> PgPool {
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/test")
            .expect("Failed to create mock pool")
    }

    #[tokio::test]
    async fn test_rate_limiter_creation() {
        let _limiter = RateLimiter::new(lazy_pool());
    }

    #[tokio::test]
    async fn test_build_key() {
        let limiter = RateLimiter::new(lazy_pool());
        let anonymous = AuthContext::unauthenticated();
        let metadata = RequestMetadata::default();

        let global_key =
            limiter.build_key(RateLimitKey::Global, "test_action", &anonymous, &metadata);
        assert_eq!(global_key, "global:test_action");

        let user_key = limiter.build_key(RateLimitKey::User, "test_action", &anonymous, &metadata);
        assert!(user_key.starts_with("user:"));
    }
}