1use axum::{
7 extract::{Path, State},
8 http::{HeaderMap, Method, StatusCode, Uri},
9 response::{IntoResponse, Response},
10 routing::any,
11 Router,
12};
13use uuid::Uuid;
14
15use crate::error::ApiError;
16use crate::middleware::org_rate_limit::increment_usage;
17use crate::models::{HostedMock, Organization, UsageCounter};
18use crate::redis::{current_month_period, RedisPool};
19use crate::AppState;
20
/// Fallback monthly request quota returned by `monthly_request_limit` when an
/// org's limits JSON has no usable `requests_per_30d` value.
const DEFAULT_REQUESTS_PER_30D: i64 = 10_000;

/// Fallback request-body cap (multiplied by 1024 * 1024 to get bytes) used by
/// `proxy_limits_from_json` when `mock_request_body_mb` is missing or not a
/// positive integer.
const DEFAULT_MOCK_REQUEST_BODY_MB: i64 = 10;

/// Fallback requests-per-second cap used by `proxy_limits_from_json` when
/// `mock_rps_limit` is missing or not a positive integer.
const DEFAULT_MOCK_RPS_LIMIT: i64 = 100;
35
36pub struct MultitenantRouter;
38
39impl MultitenantRouter {
40 pub fn create_router() -> Router<AppState> {
43 Router::new()
44 .route("/mocks/{org_id}/{slug}/{*path}", any(Self::route_request))
45 .route("/mocks/{org_id}/{slug}", any(Self::route_request))
46 }
47
48 async fn route_request(
50 State(state): State<AppState>,
51 method: Method,
52 Path((org_id_str, slug)): Path<(String, String)>,
53 uri: Uri,
54 headers: HeaderMap,
55 body: axum::body::Body,
56 ) -> Result<Response, StatusCode> {
57 let org_id = Uuid::parse_str(&org_id_str).map_err(|e| {
59 tracing::warn!("Invalid org_id '{}': {}", org_id_str, e);
60 StatusCode::BAD_REQUEST
61 })?;
62
63 let deployment = HostedMock::find_by_slug(state.db.pool(), org_id, &slug)
65 .await
66 .map_err(|e| {
67 tracing::error!("Database error looking up deployment {}/{}: {}", org_id, slug, e);
68 StatusCode::INTERNAL_SERVER_ERROR
69 })?
70 .ok_or(StatusCode::NOT_FOUND)?;
71
72 if !matches!(deployment.status(), crate::models::DeploymentStatus::Active) {
74 return Err(StatusCode::SERVICE_UNAVAILABLE);
75 }
76
77 if let Err(response) = enforce_monthly_quota(&state, deployment.org_id).await {
81 return Ok(response);
82 }
83
84 let limits = resolve_proxy_limits(state.db.pool(), deployment.org_id).await;
88 enforce_rps(state.redis.as_ref(), deployment.id, limits.rps).await?;
89
90 let base_url = deployment
92 .internal_url
93 .as_ref()
94 .or(deployment.deployment_url.as_ref())
95 .ok_or(StatusCode::SERVICE_UNAVAILABLE)?;
96
97 let path = uri.path();
99 let path_after_slug =
100 path.strip_prefix(&format!("/mocks/{}/{}", org_id_str, slug)).unwrap_or("/");
101
102 let target_url = build_target_url(base_url, path_after_slug, uri.query());
104
105 let response =
106 proxy_request(method, headers, body, &target_url, limits.max_body_bytes).await?;
107 bump_proxy_usage(&state, deployment.org_id, &response);
108 Ok(response)
109 }
110}
111
112pub async fn custom_domain_fallback(
121 State(state): State<AppState>,
122 method: Method,
123 uri: Uri,
124 headers: HeaderMap,
125 body: axum::body::Body,
126) -> Result<Response, StatusCode> {
127 let mocks_domain = match std::env::var("MOCKFORGE_MOCKS_DOMAIN") {
128 Ok(d) => d,
129 Err(_) => return Err(StatusCode::NOT_FOUND),
130 };
131
132 let host = headers.get("host").and_then(|v| v.to_str().ok()).unwrap_or("");
134 let host = host.split(':').next().unwrap_or(host);
135
136 let slug = match host.strip_suffix(&format!(".{}", mocks_domain)) {
138 Some(s) if !s.is_empty() && !s.contains('.') => s,
139 _ => return Err(StatusCode::NOT_FOUND),
140 };
141
142 tracing::debug!("Custom domain proxy: {} -> slug '{}'", host, slug);
143
144 let deployment = HostedMock::find_active_by_slug(state.db.pool(), slug)
146 .await
147 .map_err(|e| {
148 tracing::error!("Database error looking up deployment by slug '{}': {}", slug, e);
149 StatusCode::INTERNAL_SERVER_ERROR
150 })?
151 .ok_or(StatusCode::NOT_FOUND)?;
152
153 if let Err(response) = enforce_monthly_quota(&state, deployment.org_id).await {
155 return Ok(response);
156 }
157
158 let limits = resolve_proxy_limits(state.db.pool(), deployment.org_id).await;
161 enforce_rps(state.redis.as_ref(), deployment.id, limits.rps).await?;
162
163 let base_url = deployment
165 .internal_url
166 .as_ref()
167 .or(deployment.deployment_url.as_ref())
168 .ok_or(StatusCode::SERVICE_UNAVAILABLE)?;
169
170 let target_url = build_target_url(base_url, uri.path(), uri.query());
171
172 let response = proxy_request(method, headers, body, &target_url, limits.max_body_bytes).await?;
173 bump_proxy_usage(&state, deployment.org_id, &response);
174 Ok(response)
175}
176
/// Joins a deployment base URL, a request path and an optional query string
/// into the upstream URL.
///
/// Exactly one `/` separates `base_url` and a non-empty `path`: a doubled
/// slash (base ends with `/`, path starts with `/`) is collapsed and a missing
/// one is inserted. An empty `path` leaves `base_url` untouched. When `query`
/// is `Some`, it is appended verbatim after a `?`.
fn build_target_url(base_url: &str, path: &str, query: Option<&str>) -> String {
    let query_len = query.map_or(0, |q| q.len() + 1);
    let mut url = String::with_capacity(base_url.len() + path.len() + query_len + 1);

    match (base_url.strip_suffix('/'), path.starts_with('/')) {
        // "http://host/" + "/x" -> "http://host/x"
        (Some(trimmed), true) => url.push_str(trimmed),
        // "http://host" + "x" -> "http://host/x"
        (None, false) if !path.is_empty() => {
            url.push_str(base_url);
            url.push('/');
        }
        _ => url.push_str(base_url),
    }
    url.push_str(path);

    if let Some(q) = query {
        url.push('?');
        url.push_str(q);
    }
    url
}
185
186async fn proxy_request(
193 method: Method,
194 headers: HeaderMap,
195 body: axum::body::Body,
196 target_url: &str,
197 max_body_bytes: usize,
198) -> Result<Response, StatusCode> {
199 let client = reqwest::Client::new();
200
201 if let Some(declared) = headers
203 .get("content-length")
204 .and_then(|v| v.to_str().ok())
205 .and_then(|s| s.parse::<usize>().ok())
206 {
207 if declared > max_body_bytes {
208 tracing::warn!(
209 "Rejecting oversized proxy body: declared={} max={}",
210 declared,
211 max_body_bytes
212 );
213 return Err(StatusCode::PAYLOAD_TOO_LARGE);
214 }
215 }
216
217 let body_bytes = match axum::body::to_bytes(body, max_body_bytes).await {
220 Ok(b) => b,
221 Err(e) => {
222 tracing::warn!("Proxy body read failed (cap={} bytes): {}", max_body_bytes, e);
223 return Err(StatusCode::PAYLOAD_TOO_LARGE);
224 }
225 };
226
227 let request_builder = match method.as_str() {
229 "GET" => client.get(target_url),
230 "HEAD" => client.head(target_url),
231 "POST" => {
232 let mut req = client.post(target_url);
233 if !body_bytes.is_empty() {
234 req = req.body(body_bytes.to_vec());
235 }
236 req
237 }
238 "PUT" => {
239 let mut req = client.put(target_url);
240 if !body_bytes.is_empty() {
241 req = req.body(body_bytes.to_vec());
242 }
243 req
244 }
245 "PATCH" => {
246 let mut req = client.patch(target_url);
247 if !body_bytes.is_empty() {
248 req = req.body(body_bytes.to_vec());
249 }
250 req
251 }
252 "DELETE" => client.delete(target_url),
253 _ => return Err(StatusCode::METHOD_NOT_ALLOWED),
254 };
255
256 let mut request = request_builder.timeout(std::time::Duration::from_secs(30));
257
258 for header_name in ["accept", "content-type", "authorization", "x-request-id"] {
260 if let Some(value) = headers.get(header_name) {
261 if let Ok(value_str) = value.to_str() {
262 request = request.header(header_name, value_str);
263 }
264 }
265 }
266
267 let response = request.send().await.map_err(|e| {
268 tracing::error!("Failed to proxy request to {}: {}", target_url, e);
269 StatusCode::BAD_GATEWAY
270 })?;
271
272 let status = StatusCode::from_u16(response.status().as_u16())
274 .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
275
276 let mut response_headers = Vec::new();
277 for (key, value) in response.headers() {
278 if let (Ok(header_name), Ok(value_str)) =
279 (key.as_str().parse::<axum::http::HeaderName>(), value.to_str())
280 {
281 if let Ok(header_value) = axum::http::HeaderValue::from_str(value_str) {
282 response_headers.push((header_name, header_value));
283 }
284 }
285 }
286
287 let resp_body = response.bytes().await.map_err(|e| {
288 tracing::error!("Failed to read proxy response body: {}", e);
289 StatusCode::BAD_GATEWAY
290 })?;
291
292 let mut response_builder = Response::builder().status(status);
293 for (header_name, header_value) in response_headers {
294 response_builder = response_builder.header(header_name, header_value);
295 }
296
297 response_builder.body(axum::body::Body::from(resp_body.to_vec())).map_err(|e| {
298 tracing::error!("Failed to build proxy response: {}", e);
299 StatusCode::INTERNAL_SERVER_ERROR
300 })
301}
302
303fn monthly_request_limit(limits_json: &serde_json::Value) -> Option<i64> {
308 match limits_json.get("requests_per_30d").and_then(|v| v.as_i64()) {
309 Some(-1) => None, Some(n) if n > 0 => Some(n),
311 _ => Some(DEFAULT_REQUESTS_PER_30D),
314 }
315}
316
317async fn enforce_monthly_quota(state: &AppState, org_id: Uuid) -> Result<(), Response> {
326 let org = match Organization::find_by_id(state.db.pool(), org_id).await {
327 Ok(Some(org)) => org,
328 Ok(None) => {
329 tracing::warn!("Org {} not found while enforcing monthly quota", org_id);
330 return Ok(());
331 }
332 Err(e) => {
333 tracing::error!("DB error loading org {} for monthly quota check: {}", org_id, e);
334 return Ok(());
335 }
336 };
337
338 let Some(limit) = monthly_request_limit(&org.limits_json) else {
339 return Ok(()); };
341
342 let used = match UsageCounter::get_or_create_current(state.db.pool(), org_id).await {
343 Ok(counter) => counter.requests,
344 Err(e) => {
345 tracing::error!("Failed to read usage counter for org {}: {}", org_id, e);
346 return Ok(()); }
348 };
349
350 if used >= limit {
351 tracing::info!("Monthly request quota exhausted for org {}: {}/{}", org_id, used, limit);
352 Err(ApiError::UsageLimitExceeded {
353 limit_type: "requests".to_string(),
354 current: used,
355 max: limit,
356 period: current_month_period(),
357 }
358 .into_response())
359 } else {
360 Ok(())
361 }
362}
363
364fn bump_proxy_usage(state: &AppState, org_id: Uuid, response: &Response) {
375 if !response.status().is_success() {
376 return;
377 }
378
379 let response_size = response
380 .headers()
381 .get("content-length")
382 .and_then(|h| h.to_str().ok())
383 .and_then(|s| s.parse::<i64>().ok())
384 .unwrap_or(256);
385
386 let pool = state.db.pool().clone();
387 let redis = state.redis.clone();
388 tokio::spawn(async move {
389 if let Err(e) = increment_usage(&pool, redis.as_ref(), org_id, response_size).await {
390 tracing::error!("Failed to increment proxy usage for org {}: {:?}", org_id, e);
391 }
392 });
393}
394
/// Per-org limits applied to proxied mock traffic, as produced by
/// `proxy_limits_from_json`.
#[derive(Debug, Clone, Copy)]
struct ProxyLimits {
    /// Maximum request body size in bytes; handed to `proxy_request` as its
    /// body cap.
    max_body_bytes: usize,
    /// Requests-per-second cap; handed to `enforce_rps`.
    rps: i64,
}
405
406fn proxy_limits_from_json(limits_json: &serde_json::Value) -> ProxyLimits {
416 let body_mb = limits_json
417 .get("mock_request_body_mb")
418 .and_then(|v| v.as_i64())
419 .filter(|v| *v > 0)
420 .unwrap_or(DEFAULT_MOCK_REQUEST_BODY_MB);
421 let rps = limits_json
422 .get("mock_rps_limit")
423 .and_then(|v| v.as_i64())
424 .filter(|v| *v > 0)
425 .unwrap_or(DEFAULT_MOCK_RPS_LIMIT);
426
427 ProxyLimits {
428 max_body_bytes: (body_mb as usize).saturating_mul(1024 * 1024),
429 rps,
430 }
431}
432
433async fn resolve_proxy_limits(pool: &sqlx::PgPool, org_id: Uuid) -> ProxyLimits {
439 let limits_json = match Organization::find_by_id(pool, org_id).await {
440 Ok(Some(org)) => org.limits_json,
441 Ok(None) => {
442 tracing::warn!("Org {} not found while resolving proxy limits", org_id);
443 serde_json::Value::Null
444 }
445 Err(e) => {
446 tracing::error!("DB error resolving proxy limits for org {}: {}", org_id, e);
447 serde_json::Value::Null
448 }
449 };
450
451 proxy_limits_from_json(&limits_json)
452}
453
454async fn enforce_rps(
460 redis: Option<&RedisPool>,
461 deployment_id: Uuid,
462 rps: i64,
463) -> Result<(), StatusCode> {
464 let Some(pool) = redis else {
465 tracing::debug!(
466 "Redis not configured — skipping RPS enforcement for deployment {}",
467 deployment_id
468 );
469 return Ok(());
470 };
471
472 let bucket = chrono::Utc::now().timestamp();
473 let key = format!("mock_rps:{}:{}", deployment_id, bucket);
474
475 match pool.increment_with_expiry(&key, 2).await {
478 Ok(count) if count > rps => {
479 tracing::info!("RPS cap hit for deployment {}: {}/{}", deployment_id, count, rps);
480 Err(StatusCode::TOO_MANY_REQUESTS)
481 }
482 Ok(_) => Ok(()),
483 Err(e) => {
484 tracing::error!("Redis RPS check failed for deployment {}: {}", deployment_id, e);
485 Ok(())
487 }
488 }
489}
490
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Byte cap expected whenever the body limit falls back to the default.
    fn default_body_bytes() -> usize {
        DEFAULT_MOCK_REQUEST_BODY_MB as usize * 1024 * 1024
    }

    /// Asserts that both limits are the built-in defaults.
    fn assert_built_in_defaults(limits: ProxyLimits) {
        assert_eq!(limits.max_body_bytes, default_body_bytes());
        assert_eq!(limits.rps, DEFAULT_MOCK_RPS_LIMIT);
    }

    /// Runs `monthly_request_limit` on `{ "requests_per_30d": <value> }`.
    fn monthly_limit_for(value: serde_json::Value) -> Option<i64> {
        monthly_request_limit(&json!({ "requests_per_30d": value }))
    }

    // --- monthly_request_limit -------------------------------------------

    #[test]
    fn monthly_limit_pro_plan_default() {
        assert_eq!(monthly_limit_for(json!(250_000)), Some(250_000));
    }

    #[test]
    fn monthly_limit_team_plan_default() {
        assert_eq!(monthly_limit_for(json!(1_000_000)), Some(1_000_000));
    }

    #[test]
    fn monthly_limit_unlimited_sentinel() {
        assert_eq!(monthly_limit_for(json!(-1)), None);
    }

    #[test]
    fn monthly_limit_zero_falls_back_to_default() {
        assert_eq!(monthly_limit_for(json!(0)), Some(DEFAULT_REQUESTS_PER_30D));
    }

    #[test]
    fn monthly_limit_missing_field_falls_back() {
        let limits = json!({});
        assert_eq!(monthly_request_limit(&limits), Some(DEFAULT_REQUESTS_PER_30D));
    }

    #[test]
    fn monthly_limit_null_json_falls_back() {
        let limits = serde_json::Value::Null;
        assert_eq!(monthly_request_limit(&limits), Some(DEFAULT_REQUESTS_PER_30D));
    }

    #[test]
    fn monthly_limit_wrong_json_type_falls_back() {
        assert_eq!(monthly_limit_for(json!("250000")), Some(DEFAULT_REQUESTS_PER_30D));
    }

    // --- proxy_limits_from_json ------------------------------------------

    #[test]
    fn proxy_limits_pro_plan_defaults() {
        let input = json!({ "mock_request_body_mb": 10, "mock_rps_limit": 100 });
        let limits = proxy_limits_from_json(&input);
        assert_eq!(limits.max_body_bytes, 10 * 1024 * 1024);
        assert_eq!(limits.rps, 100);
    }

    #[test]
    fn proxy_limits_team_plan_defaults() {
        let input = json!({ "mock_request_body_mb": 50, "mock_rps_limit": 1000 });
        let limits = proxy_limits_from_json(&input);
        assert_eq!(limits.max_body_bytes, 50 * 1024 * 1024);
        assert_eq!(limits.rps, 1000);
    }

    #[test]
    fn proxy_limits_missing_fields_fall_back_to_built_in_defaults() {
        assert_built_in_defaults(proxy_limits_from_json(&json!({})));
    }

    #[test]
    fn proxy_limits_null_json_falls_back() {
        assert_built_in_defaults(proxy_limits_from_json(&serde_json::Value::Null));
    }

    #[test]
    fn proxy_limits_non_positive_values_treated_as_missing() {
        let input = json!({ "mock_request_body_mb": -1, "mock_rps_limit": 0 });
        assert_built_in_defaults(proxy_limits_from_json(&input));
    }

    #[test]
    fn proxy_limits_string_values_treated_as_missing() {
        let input = json!({ "mock_request_body_mb": "10", "mock_rps_limit": "100" });
        assert_built_in_defaults(proxy_limits_from_json(&input));
    }

    #[test]
    fn proxy_limits_extreme_body_mb_does_not_overflow() {
        let input = json!({ "mock_request_body_mb": i64::MAX, "mock_rps_limit": 100 });
        let limits = proxy_limits_from_json(&input);
        assert_eq!(limits.max_body_bytes, usize::MAX);
    }

    // --- enforce_rps -----------------------------------------------------

    #[tokio::test]
    async fn enforce_rps_without_redis_is_allow_through() {
        let outcome = enforce_rps(None, Uuid::new_v4(), 100).await;
        assert!(outcome.is_ok());
    }
}