forge_runtime/rate_limit/
limiter.rs1use std::time::Duration;
2
3use chrono::{DateTime, Utc};
4use sqlx::PgPool;
5
6use forge_core::rate_limit::{RateLimitConfig, RateLimitKey, RateLimitResult};
7use forge_core::{AuthContext, ForgeError, RequestMetadata, Result};
8
9pub struct RateLimiter {
13 pool: PgPool,
14}
15
16impl RateLimiter {
17 pub fn new(pool: PgPool) -> Self {
19 Self { pool }
20 }
21
22 pub async fn check(
24 &self,
25 bucket_key: &str,
26 config: &RateLimitConfig,
27 ) -> Result<RateLimitResult> {
28 let max_tokens = config.requests as f64;
29 let refill_rate = config.refill_rate();
30
31 let result: (f64, i32, DateTime<Utc>, bool) = sqlx::query_as(
33 r#"
34 INSERT INTO forge_rate_limits (bucket_key, tokens, last_refill, max_tokens, refill_rate)
35 VALUES ($1, $2 - 1, NOW(), $2, $3)
36 ON CONFLICT (bucket_key) DO UPDATE SET
37 tokens = LEAST(
38 forge_rate_limits.max_tokens::double precision,
39 forge_rate_limits.tokens +
40 (EXTRACT(EPOCH FROM (NOW() - forge_rate_limits.last_refill)) * forge_rate_limits.refill_rate)
41 ) - 1,
42 last_refill = NOW()
43 RETURNING tokens, max_tokens, last_refill, (tokens >= 0) as allowed
44 "#,
45 )
46 .bind(bucket_key)
47 .bind(max_tokens as i32)
48 .bind(refill_rate)
49 .fetch_one(&self.pool)
50 .await
51 .map_err(|e| ForgeError::Database(e.to_string()))?;
52
53 let (tokens, _max, last_refill, allowed) = result;
54
55 let remaining = tokens.max(0.0) as u32;
56 let reset_at =
57 last_refill + chrono::Duration::seconds(((max_tokens - tokens) / refill_rate) as i64);
58
59 if allowed {
60 Ok(RateLimitResult::allowed(remaining, reset_at))
61 } else {
62 let retry_after = Duration::from_secs_f64((1.0 - tokens) / refill_rate);
63 Ok(RateLimitResult::denied(remaining, reset_at, retry_after))
64 }
65 }
66
67 pub fn build_key(
69 &self,
70 key_type: RateLimitKey,
71 action_name: &str,
72 auth: &AuthContext,
73 request: &RequestMetadata,
74 ) -> String {
75 match key_type {
76 RateLimitKey::User => {
77 let user_id = auth
78 .user_id()
79 .map(|u| u.to_string())
80 .unwrap_or_else(|| "anonymous".to_string());
81 format!("user:{}:{}", user_id, action_name)
82 }
83 RateLimitKey::Ip => {
84 let ip = request.client_ip.as_deref().unwrap_or("unknown");
85 format!("ip:{}:{}", ip, action_name)
86 }
87 RateLimitKey::Tenant => {
88 let tenant_id = auth
89 .claim("tenant_id")
90 .and_then(|v| v.as_str())
91 .unwrap_or("none");
92 format!("tenant:{}:{}", tenant_id, action_name)
93 }
94 RateLimitKey::UserAction => {
95 let user_id = auth
96 .user_id()
97 .map(|u| u.to_string())
98 .unwrap_or_else(|| "anonymous".to_string());
99 format!("user_action:{}:{}", user_id, action_name)
100 }
101 RateLimitKey::Global => {
102 format!("global:{}", action_name)
103 }
104 }
105 }
106
107 pub async fn enforce(
109 &self,
110 bucket_key: &str,
111 config: &RateLimitConfig,
112 ) -> Result<RateLimitResult> {
113 let result = self.check(bucket_key, config).await?;
114 if !result.allowed {
115 return Err(ForgeError::RateLimitExceeded {
116 retry_after: result.retry_after.unwrap_or(Duration::from_secs(1)),
117 limit: config.requests,
118 remaining: result.remaining,
119 });
120 }
121 Ok(result)
122 }
123
124 pub async fn reset(&self, bucket_key: &str) -> Result<()> {
126 sqlx::query("DELETE FROM forge_rate_limits WHERE bucket_key = $1")
127 .bind(bucket_key)
128 .execute(&self.pool)
129 .await
130 .map_err(|e| ForgeError::Database(e.to_string()))?;
131 Ok(())
132 }
133
134 pub async fn cleanup(&self, older_than: DateTime<Utc>) -> Result<u64> {
136 let result = sqlx::query(
137 r#"
138 DELETE FROM forge_rate_limits
139 WHERE created_at < $1
140 "#,
141 )
142 .bind(older_than)
143 .execute(&self.pool)
144 .await
145 .map_err(|e| ForgeError::Database(e.to_string()))?;
146
147 Ok(result.rows_affected())
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a pool with `connect_lazy`, which does not open a connection until first use.
    /// That is enough to construct a `RateLimiter` without a running database.
    fn lazy_pool() -> PgPool {
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/test")
            .expect("Failed to create mock pool")
    }

    #[tokio::test]
    async fn test_rate_limiter_creation() {
        let _limiter = RateLimiter::new(lazy_pool());
    }

    #[tokio::test]
    async fn test_build_key() {
        let limiter = RateLimiter::new(lazy_pool());
        let (auth, request) = (AuthContext::unauthenticated(), RequestMetadata::default());

        let global_key = limiter.build_key(RateLimitKey::Global, "test_action", &auth, &request);
        assert_eq!(global_key, "global:test_action");

        let user_key = limiter.build_key(RateLimitKey::User, "test_action", &auth, &request);
        assert!(user_key.starts_with("user:"));
    }
}