Skip to main content

mockforge_registry_server/middleware/
org_rate_limit.rs

1//! Organization-aware rate limiting middleware
2//!
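//! This middleware enforces rate limits based on organization plan limits.
//! It checks the organization's usage counters (stored in Postgres) against its plan limits
//! before allowing a request, and records usage in Postgres (and in Redis, when configured)
//! after successful responses.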
5
6use axum::{
7    extract::{Request, State},
8    http::{HeaderMap, StatusCode},
9    middleware::Next,
10    response::{IntoResponse, Response},
11    Json,
12};
13use chrono::Datelike;
14use serde_json::json;
15use uuid::Uuid;
16
17use crate::{
18    middleware::resolve_org_context,
19    models::{Organization, UsageCounter},
20    redis::{current_month_period, RedisPool},
21    AppState,
22};
23
24/// Check if organization has exceeded plan limits.
25///
26/// `_redis` and `_user_id` are accepted as part of the public signature so
27/// the auth-route middleware can pass them in; current implementation reads
28/// the authoritative counter from Postgres, but a Redis fast-path is intended
29/// for a follow-up.
30pub async fn check_org_limits(
31    pool: &sqlx::PgPool,
32    _redis: Option<&RedisPool>,
33    org: &Organization,
34    _user_id: Uuid,
35) -> Result<(), RateLimitError> {
36    let limits = &org.limits_json;
37
38    // Get current month period
39    let period = current_month_period();
40
41    // Get or create usage counter
42    let usage = UsageCounter::get_or_create_current(pool, org.id)
43        .await
44        .map_err(|_| RateLimitError::Database)?;
45
46    // Check request limit
47    let requests_limit = limits.get("requests_per_30d").and_then(|v| v.as_i64()).unwrap_or(10000);
48
49    if usage.requests >= requests_limit {
50        return Err(RateLimitError::LimitExceeded {
51            limit_type: "requests".to_string(),
52            limit: requests_limit,
53            used: usage.requests,
54            reset_period: period.clone(),
55        });
56    }
57
58    // Check storage limit
59    let storage_limit_gb = limits.get("storage_gb").and_then(|v| v.as_i64()).unwrap_or(1);
60    let storage_limit_bytes = storage_limit_gb * 1_000_000_000;
61
62    if usage.storage_bytes >= storage_limit_bytes {
63        return Err(RateLimitError::LimitExceeded {
64            limit_type: "storage".to_string(),
65            limit: storage_limit_bytes,
66            used: usage.storage_bytes,
67            reset_period: period.clone(),
68        });
69    }
70
71    Ok(())
72}
73
74/// Increment usage counter for a request
75pub async fn increment_usage(
76    pool: &sqlx::PgPool,
77    redis: Option<&RedisPool>,
78    org_id: Uuid,
79    request_size_bytes: i64,
80) -> Result<(), RateLimitError> {
81    // Increment in Redis first (fast path)
82    if let Some(redis_pool) = redis {
83        let period = current_month_period();
84        let requests_key = format!("usage:{}:{}:requests", org_id, period);
85        let _ = redis_pool.increment_with_expiry(&requests_key, 2592000).await; // 30 days
86        if request_size_bytes > 0 {
87            let egress_key = format!("usage:{}:{}:egress", org_id, period);
88            let _ = redis_pool.increment_with_expiry(&egress_key, 2592000).await;
89        }
90    }
91
92    // Increment in database (slower, but persistent)
93    UsageCounter::increment_requests(pool, org_id, 1)
94        .await
95        .map_err(|_| RateLimitError::Database)?;
96
97    if request_size_bytes > 0 {
98        UsageCounter::increment_egress(pool, org_id, request_size_bytes)
99            .await
100            .map_err(|_| RateLimitError::Database)?;
101    }
102
103    Ok(())
104}
105
106/// Organization-aware rate limiting middleware
107///
108/// This middleware:
109/// 1. Resolves organization context from request
110/// 2. Checks plan limits (requests, storage, etc.)
111/// 3. Increments usage counters
112/// 4. Returns 429 if limits exceeded
113///
114/// Note: This should be applied AFTER auth_middleware
115pub async fn org_rate_limit_middleware(
116    State(state): State<AppState>,
117    headers: HeaderMap,
118    request: Request,
119    next: Next,
120) -> Result<Response, impl IntoResponse> {
121    // Try to get user_id from auth middleware (set in extensions)
122    let user_id_str = request.extensions().get::<String>().cloned();
123
124    // If no user_id, this might be a public endpoint - skip org rate limiting
125    // (but still apply global rate limiting if configured)
126    let user_id = if let Some(id_str) = user_id_str {
127        match Uuid::parse_str(&id_str) {
128            Ok(id) => id,
129            Err(_) => {
130                // Invalid user_id, skip org rate limiting
131                return Ok(next.run(request).await);
132            }
133        }
134    } else {
135        return Ok(next.run(request).await);
136    };
137
138    // Resolve org context (pass request extensions for API token org_id lookup)
139    let org_ctx =
140        match resolve_org_context(&state, user_id, &headers, Some(request.extensions())).await {
141            Ok(ctx) => ctx,
142            Err(_) => {
143                // No org context, skip org rate limiting
144                return Ok(next.run(request).await);
145            }
146        };
147
148    let pool = state.db.pool();
149
150    // Check org limits
151    if let Err(e) = check_org_limits(pool, state.redis.as_ref(), &org_ctx.org, user_id).await {
152        return Err(rate_limit_error_response(e));
153    }
154
155    // Get usage info for rate limit headers
156    let usage = UsageCounter::get_or_create_current(pool, org_ctx.org_id).await.ok();
157
158    let limits = &org_ctx.org.limits_json;
159    let requests_limit = limits.get("requests_per_30d").and_then(|v| v.as_i64()).unwrap_or(10000);
160
161    let requests_remaining = usage
162        .as_ref()
163        .map(|u| (requests_limit - u.requests).max(0))
164        .unwrap_or(requests_limit);
165
166    // Calculate reset time (end of current month)
167    let now = chrono::Utc::now();
168    let next_month = if now.month() == 12 {
169        chrono::NaiveDate::from_ymd_opt(now.year() + 1, 1, 1)
170    } else {
171        chrono::NaiveDate::from_ymd_opt(now.year(), now.month() + 1, 1)
172    }
173    .and_then(|d| d.and_hms_opt(0, 0, 0))
174    .map(|dt| chrono::DateTime::<chrono::Utc>::from_naive_utc_and_offset(dt, chrono::Utc));
175
176    let reset_timestamp =
177        next_month.map(|dt| dt.timestamp()).unwrap_or_else(|| now.timestamp() + 2592000); // Fallback: 30 days from now
178
179    // Capture request body size before it's consumed
180    let request_body_bytes: i64 = request
181        .headers()
182        .get("content-length")
183        .and_then(|h| h.to_str().ok())
184        .and_then(|s| s.parse::<i64>().ok())
185        .unwrap_or(0);
186
187    // Process request
188    let mut response = next.run(request).await;
189
190    // Add rate limit headers
191    let headers = response.headers_mut();
192    headers.insert(
193        "X-RateLimit-Limit",
194        axum::http::HeaderValue::from_str(&requests_limit.to_string())
195            .unwrap_or_else(|_| axum::http::HeaderValue::from_static("10000")),
196    );
197    headers.insert(
198        "X-RateLimit-Remaining",
199        axum::http::HeaderValue::from_str(&requests_remaining.to_string())
200            .unwrap_or_else(|_| axum::http::HeaderValue::from_static("0")),
201    );
202    headers.insert(
203        "X-RateLimit-Reset",
204        axum::http::HeaderValue::from_str(&reset_timestamp.to_string())
205            .unwrap_or_else(|_| axum::http::HeaderValue::from_static("0")),
206    );
207
208    // Increment usage (only for successful requests, 2xx status)
209    let status = response.status();
210    if status.is_success() {
211        // Total egress = response body size + request body size (both count as data transfer)
212        let response_size = estimate_response_size(&response);
213        let total_egress = response_size + request_body_bytes;
214
215        // Increment usage asynchronously (don't block response)
216        let pool_clone = pool.clone();
217        let redis_clone = state.redis.clone();
218        let org_id = org_ctx.org_id;
219        tokio::spawn(async move {
220            let _ = increment_usage(&pool_clone, redis_clone.as_ref(), org_id, total_egress).await;
221        });
222    }
223
224    Ok(response)
225}
226
227/// Estimate response size from headers.
228///
229/// Uses Content-Length when available, falls back to a conservative 256-byte
230/// estimate (typical small JSON response) rather than the old 1KB default.
231fn estimate_response_size(response: &Response) -> i64 {
232    response
233        .headers()
234        .get("content-length")
235        .and_then(|h| h.to_str().ok())
236        .and_then(|s| s.parse::<i64>().ok())
237        .unwrap_or(256)
238}
239
/// Error returned by [`check_org_limits`] and [`increment_usage`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RateLimitError {
    /// Reading or updating the usage counters failed; the underlying error is
    /// not preserved.
    Database,
    /// A plan limit has been reached.
    LimitExceeded {
        /// Which limit was hit (`check_org_limits` produces `"requests"` or
        /// `"storage"`).
        limit_type: String,
        /// The configured limit (request count, or bytes for storage).
        limit: i64,
        /// Current usage, in the same unit as `limit`.
        used: i64,
        /// The period the usage belongs to, as returned by `current_month_period()`.
        reset_period: String,
    },
}

impl std::fmt::Display for RateLimitError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            RateLimitError::Database => f.write_str("database error while checking usage limits"),
            RateLimitError::LimitExceeded {
                limit_type,
                limit,
                used,
                reset_period,
            } => write!(
                f,
                "{limit_type} limit exceeded: used {used} of {limit} (period {reset_period})"
            ),
        }
    }
}

impl std::error::Error for RateLimitError {}
251
252/// Convert rate limit error to HTTP response
253fn rate_limit_error_response(error: RateLimitError) -> impl IntoResponse {
254    match error {
255        RateLimitError::Database => (
256            StatusCode::INTERNAL_SERVER_ERROR,
257            Json(json!({
258                "error": "Internal server error",
259                "message": "Failed to check rate limits"
260            })),
261        ),
262        RateLimitError::LimitExceeded {
263            limit_type,
264            limit,
265            used,
266            reset_period,
267        } => {
268            let limit_type_display = match limit_type.as_str() {
269                "requests" => "Monthly request limit",
270                "storage" => "Storage limit",
271                _ => "Usage limit",
272            };
273
274            (
275                StatusCode::TOO_MANY_REQUESTS,
276                Json(json!({
277                    "error": "Rate limit exceeded",
278                    "message": format!("{} exceeded. Used {}/{}", limit_type_display, used, limit),
279                    "limit_type": limit_type,
280                    "limit": limit,
281                    "used": used,
282                    "reset_period": reset_period,
283                    "upgrade_url": "/billing/upgrade"
284                })),
285            )
286        }
287    }
288}