mockforge_registry_server/middleware/
org_rate_limit.rs1use axum::{
7 extract::{Request, State},
8 http::{HeaderMap, StatusCode},
9 middleware::Next,
10 response::{IntoResponse, Response},
11 Json,
12};
13use chrono::Datelike;
14use serde_json::json;
15use uuid::Uuid;
16
17use crate::{
18 middleware::resolve_org_context,
19 models::{Organization, UsageCounter},
20 redis::{current_month_period, RedisPool},
21 AppState,
22};
23
24pub async fn check_org_limits(
31 pool: &sqlx::PgPool,
32 _redis: Option<&RedisPool>,
33 org: &Organization,
34 _user_id: Uuid,
35) -> Result<(), RateLimitError> {
36 let limits = &org.limits_json;
37
38 let period = current_month_period();
40
41 let usage = UsageCounter::get_or_create_current(pool, org.id)
43 .await
44 .map_err(|_| RateLimitError::Database)?;
45
46 let requests_limit = limits.get("requests_per_30d").and_then(|v| v.as_i64()).unwrap_or(10000);
48
49 if usage.requests >= requests_limit {
50 return Err(RateLimitError::LimitExceeded {
51 limit_type: "requests".to_string(),
52 limit: requests_limit,
53 used: usage.requests,
54 reset_period: period.clone(),
55 });
56 }
57
58 let storage_limit_gb = limits.get("storage_gb").and_then(|v| v.as_i64()).unwrap_or(1);
60 let storage_limit_bytes = storage_limit_gb * 1_000_000_000;
61
62 if usage.storage_bytes >= storage_limit_bytes {
63 return Err(RateLimitError::LimitExceeded {
64 limit_type: "storage".to_string(),
65 limit: storage_limit_bytes,
66 used: usage.storage_bytes,
67 reset_period: period.clone(),
68 });
69 }
70
71 Ok(())
72}
73
74pub async fn increment_usage(
76 pool: &sqlx::PgPool,
77 redis: Option<&RedisPool>,
78 org_id: Uuid,
79 request_size_bytes: i64,
80) -> Result<(), RateLimitError> {
81 if let Some(redis_pool) = redis {
83 let period = current_month_period();
84 let requests_key = format!("usage:{}:{}:requests", org_id, period);
85 let _ = redis_pool.increment_with_expiry(&requests_key, 2592000).await; if request_size_bytes > 0 {
87 let egress_key = format!("usage:{}:{}:egress", org_id, period);
88 let _ = redis_pool.increment_with_expiry(&egress_key, 2592000).await;
89 }
90 }
91
92 UsageCounter::increment_requests(pool, org_id, 1)
94 .await
95 .map_err(|_| RateLimitError::Database)?;
96
97 if request_size_bytes > 0 {
98 UsageCounter::increment_egress(pool, org_id, request_size_bytes)
99 .await
100 .map_err(|_| RateLimitError::Database)?;
101 }
102
103 Ok(())
104}
105
106pub async fn org_rate_limit_middleware(
116 State(state): State<AppState>,
117 headers: HeaderMap,
118 request: Request,
119 next: Next,
120) -> Result<Response, impl IntoResponse> {
121 let user_id_str = request.extensions().get::<String>().cloned();
123
124 let user_id = if let Some(id_str) = user_id_str {
127 match Uuid::parse_str(&id_str) {
128 Ok(id) => id,
129 Err(_) => {
130 return Ok(next.run(request).await);
132 }
133 }
134 } else {
135 return Ok(next.run(request).await);
136 };
137
138 let org_ctx =
140 match resolve_org_context(&state, user_id, &headers, Some(request.extensions())).await {
141 Ok(ctx) => ctx,
142 Err(_) => {
143 return Ok(next.run(request).await);
145 }
146 };
147
148 let pool = state.db.pool();
149
150 if let Err(e) = check_org_limits(pool, state.redis.as_ref(), &org_ctx.org, user_id).await {
152 return Err(rate_limit_error_response(e));
153 }
154
155 let usage = UsageCounter::get_or_create_current(pool, org_ctx.org_id).await.ok();
157
158 let limits = &org_ctx.org.limits_json;
159 let requests_limit = limits.get("requests_per_30d").and_then(|v| v.as_i64()).unwrap_or(10000);
160
161 let requests_remaining = usage
162 .as_ref()
163 .map(|u| (requests_limit - u.requests).max(0))
164 .unwrap_or(requests_limit);
165
166 let now = chrono::Utc::now();
168 let next_month = if now.month() == 12 {
169 chrono::NaiveDate::from_ymd_opt(now.year() + 1, 1, 1)
170 } else {
171 chrono::NaiveDate::from_ymd_opt(now.year(), now.month() + 1, 1)
172 }
173 .and_then(|d| d.and_hms_opt(0, 0, 0))
174 .map(|dt| chrono::DateTime::<chrono::Utc>::from_naive_utc_and_offset(dt, chrono::Utc));
175
176 let reset_timestamp =
177 next_month.map(|dt| dt.timestamp()).unwrap_or_else(|| now.timestamp() + 2592000); let request_body_bytes: i64 = request
181 .headers()
182 .get("content-length")
183 .and_then(|h| h.to_str().ok())
184 .and_then(|s| s.parse::<i64>().ok())
185 .unwrap_or(0);
186
187 let mut response = next.run(request).await;
189
190 let headers = response.headers_mut();
192 headers.insert(
193 "X-RateLimit-Limit",
194 axum::http::HeaderValue::from_str(&requests_limit.to_string())
195 .unwrap_or_else(|_| axum::http::HeaderValue::from_static("10000")),
196 );
197 headers.insert(
198 "X-RateLimit-Remaining",
199 axum::http::HeaderValue::from_str(&requests_remaining.to_string())
200 .unwrap_or_else(|_| axum::http::HeaderValue::from_static("0")),
201 );
202 headers.insert(
203 "X-RateLimit-Reset",
204 axum::http::HeaderValue::from_str(&reset_timestamp.to_string())
205 .unwrap_or_else(|_| axum::http::HeaderValue::from_static("0")),
206 );
207
208 let status = response.status();
210 if status.is_success() {
211 let response_size = estimate_response_size(&response);
213 let total_egress = response_size + request_body_bytes;
214
215 let pool_clone = pool.clone();
217 let redis_clone = state.redis.clone();
218 let org_id = org_ctx.org_id;
219 tokio::spawn(async move {
220 let _ = increment_usage(&pool_clone, redis_clone.as_ref(), org_id, total_egress).await;
221 });
222 }
223
224 Ok(response)
225}
226
227fn estimate_response_size(response: &Response) -> i64 {
232 response
233 .headers()
234 .get("content-length")
235 .and_then(|h| h.to_str().ok())
236 .and_then(|s| s.parse::<i64>().ok())
237 .unwrap_or(256)
238}
239
/// Failure reasons for the org usage checks and updates in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RateLimitError {
    /// A usage-counter database operation failed. Callers in this module
    /// discard the underlying error (`map_err(|_| ...)`).
    Database,
    /// An organization has reached one of its quotas.
    LimitExceeded {
        /// Which quota was hit; this module produces `"requests"` or `"storage"`.
        limit_type: String,
        /// The configured limit (a request count, or bytes for storage).
        limit: i64,
        /// Current usage, in the same unit as `limit`.
        used: i64,
        /// Billing period identifier, as returned by `current_month_period()`.
        reset_period: String,
    },
}

impl std::fmt::Display for RateLimitError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Database => f.write_str("database error while checking usage limits"),
            Self::LimitExceeded {
                limit_type,
                limit,
                used,
                reset_period,
            } => write!(
                f,
                "{limit_type} limit exceeded: used {used}/{limit} (period {reset_period})"
            ),
        }
    }
}

impl std::error::Error for RateLimitError {}
251
252fn rate_limit_error_response(error: RateLimitError) -> impl IntoResponse {
254 match error {
255 RateLimitError::Database => (
256 StatusCode::INTERNAL_SERVER_ERROR,
257 Json(json!({
258 "error": "Internal server error",
259 "message": "Failed to check rate limits"
260 })),
261 ),
262 RateLimitError::LimitExceeded {
263 limit_type,
264 limit,
265 used,
266 reset_period,
267 } => {
268 let limit_type_display = match limit_type.as_str() {
269 "requests" => "Monthly request limit",
270 "storage" => "Storage limit",
271 _ => "Usage limit",
272 };
273
274 (
275 StatusCode::TOO_MANY_REQUESTS,
276 Json(json!({
277 "error": "Rate limit exceeded",
278 "message": format!("{} exceeded. Used {}/{}", limit_type_display, used, limit),
279 "limit_type": limit_type,
280 "limit": limit,
281 "used": used,
282 "reset_period": reset_period,
283 "upgrade_url": "/billing/upgrade"
284 })),
285 )
286 }
287 }
288}