use axum::{
extract::{Path, State},
http::{HeaderMap, Method, StatusCode, Uri},
response::Response,
routing::any,
Router,
};
use uuid::Uuid;
use crate::middleware::org_rate_limit::increment_usage;
use crate::models::{HostedMock, Organization, UsageCounter};
use crate::redis::RedisPool;
use crate::AppState;
// Fallbacks applied when an org's `limits_json` omits (or misconfigures) a limit.
/// Default 30-day request quota when `requests_per_30d` is missing/invalid.
const DEFAULT_REQUESTS_PER_30D: i64 = 10_000;
/// Default max request-body size in MB when `mock_request_body_mb` is missing/invalid.
const DEFAULT_MOCK_REQUEST_BODY_MB: i64 = 10;
/// Default per-deployment requests-per-second cap when `mock_rps_limit` is missing/invalid.
const DEFAULT_MOCK_RPS_LIMIT: i64 = 100;
/// Router serving hosted mocks under `/mocks/{org_id}/{slug}[/...]`, proxying
/// each request to the deployment's backing service.
pub struct MultitenantRouter;

impl MultitenantRouter {
    /// Build the mock-proxy router. Both routes accept any HTTP method; the
    /// `{*path}` variant captures and forwards the remainder of the path.
    pub fn create_router() -> Router<AppState> {
        Router::new()
            .route("/mocks/{org_id}/{slug}/{*path}", any(Self::route_request))
            .route("/mocks/{org_id}/{slug}", any(Self::route_request))
    }

    /// Resolve the deployment addressed by `{org_id}/{slug}`, enforce quota and
    /// rate limits, then proxy the request to its backing URL.
    ///
    /// Error mapping: 400 malformed org id · 404 unknown deployment ·
    /// 503 inactive deployment or missing backing URL · 429 monthly quota or
    /// RPS cap hit · 500 database failure · 502 upstream failure.
    async fn route_request(
        State(state): State<AppState>,
        method: Method,
        Path((org_id_str, slug)): Path<(String, String)>,
        uri: Uri,
        headers: HeaderMap,
        body: axum::body::Body,
    ) -> Result<Response, StatusCode> {
        let org_id = Uuid::parse_str(&org_id_str).map_err(|e| {
            tracing::warn!("Invalid org_id '{}': {}", org_id_str, e);
            StatusCode::BAD_REQUEST
        })?;
        let deployment = HostedMock::find_by_slug(state.db.pool(), org_id, &slug)
            .await
            .map_err(|e| {
                tracing::error!("Database error looking up deployment {}/{}: {}", org_id, slug, e);
                StatusCode::INTERNAL_SERVER_ERROR
            })?
            .ok_or(StatusCode::NOT_FOUND)?;
        // Only Active deployments are routable; everything else is temporarily unavailable.
        if !matches!(deployment.status(), crate::models::DeploymentStatus::Active) {
            return Err(StatusCode::SERVICE_UNAVAILABLE);
        }
        enforce_monthly_quota(&state, deployment.org_id).await?;
        let limits = resolve_proxy_limits(state.db.pool(), deployment.org_id).await;
        enforce_rps(state.redis.as_ref(), deployment.id, limits.rps).await?;
        // Prefer the cluster-internal URL; fall back to the public deployment URL.
        let base_url = deployment
            .internal_url
            .as_ref()
            .or(deployment.deployment_url.as_ref())
            .ok_or(StatusCode::SERVICE_UNAVAILABLE)?;
        // BUG FIX: for the exact-match route ("/mocks/{org}/{slug}", no trailing
        // path) `strip_prefix` succeeds with an EMPTY remainder, so the previous
        // `unwrap_or("/")` never applied and the upstream URL lost its "/" path.
        // Normalize both the no-match and empty-remainder cases to "/".
        let prefix = format!("/mocks/{}/{}", org_id_str, slug);
        let path_after_slug = match uri.path().strip_prefix(&prefix) {
            Some(rest) if !rest.is_empty() => rest,
            _ => "/",
        };
        let target_url = build_target_url(base_url, path_after_slug, uri.query());
        let response =
            proxy_request(method, headers, body, &target_url, limits.max_body_bytes).await?;
        bump_proxy_usage(&state, deployment.org_id, &response);
        Ok(response)
    }
}
/// Fallback handler that proxies requests arriving on `<slug>.<mocks-domain>`
/// hosts to the matching active deployment.
///
/// Returns 404 when the mocks domain env var is unset, the Host header is not
/// a single DNS label directly under that domain, or no active deployment owns
/// the slug; otherwise applies the same quota/RPS checks as the path router.
pub async fn custom_domain_fallback(
    State(state): State<AppState>,
    method: Method,
    uri: Uri,
    headers: HeaderMap,
    body: axum::body::Body,
) -> Result<Response, StatusCode> {
    let Ok(mocks_domain) = std::env::var("MOCKFORGE_MOCKS_DOMAIN") else {
        return Err(StatusCode::NOT_FOUND);
    };
    // Take the Host header and drop any `:port` suffix before matching.
    let raw_host = headers.get("host").and_then(|v| v.to_str().ok()).unwrap_or("");
    let host = raw_host.split(':').next().unwrap_or(raw_host);
    // Accept exactly one non-empty label under the mocks domain (no nested subdomains).
    let domain_suffix = format!(".{}", mocks_domain);
    let slug = match host.strip_suffix(&domain_suffix) {
        Some(label) if !label.is_empty() && !label.contains('.') => label,
        _ => return Err(StatusCode::NOT_FOUND),
    };
    tracing::debug!("Custom domain proxy: {} -> slug '{}'", host, slug);
    let deployment = match HostedMock::find_active_by_slug(state.db.pool(), slug).await {
        Ok(Some(found)) => found,
        Ok(None) => return Err(StatusCode::NOT_FOUND),
        Err(e) => {
            tracing::error!("Database error looking up deployment by slug '{}': {}", slug, e);
            return Err(StatusCode::INTERNAL_SERVER_ERROR);
        }
    };
    enforce_monthly_quota(&state, deployment.org_id).await?;
    let limits = resolve_proxy_limits(state.db.pool(), deployment.org_id).await;
    enforce_rps(state.redis.as_ref(), deployment.id, limits.rps).await?;
    // Prefer the internal URL, then the public one; neither set means not routable.
    let base_url = match (&deployment.internal_url, &deployment.deployment_url) {
        (Some(url), _) | (None, Some(url)) => url,
        (None, None) => return Err(StatusCode::SERVICE_UNAVAILABLE),
    };
    let target_url = build_target_url(base_url, uri.path(), uri.query());
    let proxied = proxy_request(method, headers, body, &target_url, limits.max_body_bytes).await?;
    bump_proxy_usage(&state, deployment.org_id, &proxied);
    Ok(proxied)
}
/// Join a deployment base URL, a request path, and an optional query string.
///
/// Fixes over the naive concatenation:
/// - a trailing '/' on `base_url` no longer produces a double-slash URL;
/// - an empty `path` (exact-match route) yields "/" instead of no path at all.
fn build_target_url(base_url: &str, path: &str, query: Option<&str>) -> String {
    let base = base_url.trim_end_matches('/');
    let path = if path.is_empty() { "/" } else { path };
    match query {
        Some(q) => format!("{}{}?{}", base, path, q),
        None => format!("{}{}", base, path),
    }
}
/// Forward an incoming request to `target_url` and translate the upstream reply
/// back into an axum `Response`.
///
/// - Bodies larger than `max_body_bytes` (declared via Content-Length or
///   discovered while buffering) are rejected with 413.
/// - Only GET/HEAD/POST/PUT/PATCH/DELETE are proxied; anything else is 405.
/// - A small allowlist of request headers is forwarded; upstream response
///   headers that round-trip cleanly are copied back verbatim.
/// - Upstream connect/read failures map to 502; a 30s total timeout applies.
async fn proxy_request(
    method: Method,
    headers: HeaderMap,
    body: axum::body::Body,
    target_url: &str,
    max_body_bytes: usize,
) -> Result<Response, StatusCode> {
    // PERF FIX: the original built a fresh reqwest::Client per request, which
    // discards the connection pool and re-does TCP/TLS setup every time.
    // Reuse one process-wide client instead.
    static CLIENT: std::sync::OnceLock<reqwest::Client> = std::sync::OnceLock::new();
    let client = CLIENT.get_or_init(reqwest::Client::new);
    // Fast-path rejection when the declared Content-Length already exceeds the cap.
    if let Some(declared) = headers
        .get("content-length")
        .and_then(|v| v.to_str().ok())
        .and_then(|s| s.parse::<usize>().ok())
    {
        if declared > max_body_bytes {
            tracing::warn!(
                "Rejecting oversized proxy body: declared={} max={}",
                declared,
                max_body_bytes
            );
            return Err(StatusCode::PAYLOAD_TOO_LARGE);
        }
    }
    // `to_bytes` enforces the cap while buffering, so an oversized body with no
    // (or a lying) Content-Length header is also rejected.
    let body_bytes = match axum::body::to_bytes(body, max_body_bytes).await {
        Ok(b) => b,
        Err(e) => {
            tracing::warn!("Proxy body read failed (cap={} bytes): {}", max_body_bytes, e);
            return Err(StatusCode::PAYLOAD_TOO_LARGE);
        }
    };
    let mut request_builder = match method.as_str() {
        "GET" => client.get(target_url),
        "HEAD" => client.head(target_url),
        "POST" => client.post(target_url),
        "PUT" => client.put(target_url),
        "PATCH" => client.patch(target_url),
        "DELETE" => client.delete(target_url),
        _ => return Err(StatusCode::METHOD_NOT_ALLOWED),
    };
    // Attach the buffered body only for methods that carry one. NOTE(review):
    // the original dropped DELETE bodies; that behavior is preserved — confirm
    // whether forwarding them is desired.
    if matches!(method.as_str(), "POST" | "PUT" | "PATCH") && !body_bytes.is_empty() {
        request_builder = request_builder.body(body_bytes.to_vec());
    }
    let mut request = request_builder.timeout(std::time::Duration::from_secs(30));
    // Forward only a safe allowlist of request headers to the upstream.
    for header_name in ["accept", "content-type", "authorization", "x-request-id"] {
        if let Some(value) = headers.get(header_name) {
            if let Ok(value_str) = value.to_str() {
                request = request.header(header_name, value_str);
            }
        }
    }
    let response = request.send().await.map_err(|e| {
        tracing::error!("Failed to proxy request to {}: {}", target_url, e);
        StatusCode::BAD_GATEWAY
    })?;
    // Convert status/headers value-by-value (reqwest and axum may pin different
    // `http` crate versions, so types are not directly interchangeable).
    let status = StatusCode::from_u16(response.status().as_u16())
        .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
    let mut response_headers = Vec::new();
    for (key, value) in response.headers() {
        if let (Ok(header_name), Ok(value_str)) =
            (key.as_str().parse::<axum::http::HeaderName>(), value.to_str())
        {
            if let Ok(header_value) = axum::http::HeaderValue::from_str(value_str) {
                response_headers.push((header_name, header_value));
            }
        }
    }
    let resp_body = response.bytes().await.map_err(|e| {
        tracing::error!("Failed to read proxy response body: {}", e);
        StatusCode::BAD_GATEWAY
    })?;
    let mut response_builder = Response::builder().status(status);
    for (header_name, header_value) in response_headers {
        response_builder = response_builder.header(header_name, header_value);
    }
    response_builder.body(axum::body::Body::from(resp_body.to_vec())).map_err(|e| {
        tracing::error!("Failed to build proxy response: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })
}
/// Resolve the 30-day request quota from an org's limits JSON.
///
/// `-1` is the "unlimited" sentinel and yields `None` (no cap). Any positive
/// value is used as-is; everything else (missing key, zero, other negatives,
/// wrong JSON type) falls back to `DEFAULT_REQUESTS_PER_30D`.
fn monthly_request_limit(limits_json: &serde_json::Value) -> Option<i64> {
    let configured = limits_json.get("requests_per_30d").and_then(serde_json::Value::as_i64);
    if configured == Some(-1) {
        // Explicit "unlimited" plan: no monthly cap.
        return None;
    }
    Some(configured.filter(|n| *n > 0).unwrap_or(DEFAULT_REQUESTS_PER_30D))
}
/// Reject the request with 429 once the org's monthly request quota is used up.
///
/// Deliberately fails OPEN: a missing org, a DB error, or an unreadable usage
/// counter is logged and the request is allowed through, so quota bookkeeping
/// problems can never take hosted mocks offline.
async fn enforce_monthly_quota(state: &AppState, org_id: Uuid) -> Result<(), StatusCode> {
    let org = match Organization::find_by_id(state.db.pool(), org_id).await {
        Ok(Some(org)) => org,
        Ok(None) => {
            tracing::warn!("Org {} not found while enforcing monthly quota", org_id);
            return Ok(());
        }
        Err(e) => {
            tracing::error!("DB error loading org {} for monthly quota check: {}", org_id, e);
            return Ok(());
        }
    };
    let limit = match monthly_request_limit(&org.limits_json) {
        Some(cap) => cap,
        None => return Ok(()), // unlimited plan: nothing to enforce
    };
    let used = match UsageCounter::get_or_create_current(state.db.pool(), org_id).await {
        Ok(counter) => counter.requests,
        Err(e) => {
            tracing::error!("Failed to read usage counter for org {}: {}", org_id, e);
            return Ok(());
        }
    };
    if used < limit {
        return Ok(());
    }
    tracing::info!("Monthly request quota exhausted for org {}: {}/{}", org_id, used, limit);
    Err(StatusCode::TOO_MANY_REQUESTS)
}
/// Record one successful proxied request against the org's usage counters.
///
/// Accounting runs on a detached task so the client response is never delayed;
/// failures are logged and otherwise ignored. Non-2xx responses are not
/// counted. When the upstream reply carries no parseable Content-Length the
/// billed size is approximated as 256 bytes.
fn bump_proxy_usage(state: &AppState, org_id: Uuid, response: &Response) {
    if !response.status().is_success() {
        return;
    }
    let billed_bytes: i64 = response
        .headers()
        .get("content-length")
        .and_then(|h| h.to_str().ok())
        .and_then(|s| s.parse().ok())
        .unwrap_or(256);
    let pool = state.db.pool().clone();
    let redis = state.redis.clone();
    tokio::spawn(async move {
        let outcome = increment_usage(&pool, redis.as_ref(), org_id, billed_bytes).await;
        if let Err(e) = outcome {
            tracing::error!("Failed to increment proxy usage for org {}: {:?}", org_id, e);
        }
    });
}
/// Per-request proxy limits resolved from an org's plan configuration.
#[derive(Debug, Clone, Copy)]
struct ProxyLimits {
// Maximum request-body size the proxy will buffer and forward, in bytes.
max_body_bytes: usize,
// Per-deployment requests-per-second cap enforced via Redis.
rps: i64,
}
/// Build per-request proxy limits from an org's limits JSON.
///
/// Missing, non-numeric, or non-positive entries fall back to the built-in
/// defaults. The body limit is converted MB -> bytes with saturating
/// multiplication so an absurdly large configured value cannot overflow.
fn proxy_limits_from_json(limits_json: &serde_json::Value) -> ProxyLimits {
    // Read a key as a strictly positive i64, else use `default`.
    fn positive_or(v: &serde_json::Value, key: &str, default: i64) -> i64 {
        v.get(key).and_then(|x| x.as_i64()).filter(|x| *x > 0).unwrap_or(default)
    }
    let body_mb = positive_or(limits_json, "mock_request_body_mb", DEFAULT_MOCK_REQUEST_BODY_MB);
    let rps = positive_or(limits_json, "mock_rps_limit", DEFAULT_MOCK_RPS_LIMIT);
    ProxyLimits {
        max_body_bytes: (body_mb as usize).saturating_mul(1024 * 1024),
        rps,
    }
}
/// Load the org's limits JSON and convert it into concrete proxy limits.
///
/// A missing org or a DB error is logged and treated as "no configuration",
/// which makes `proxy_limits_from_json` apply the built-in defaults — limit
/// resolution never blocks a request.
async fn resolve_proxy_limits(pool: &sqlx::PgPool, org_id: Uuid) -> ProxyLimits {
    let maybe_limits = match Organization::find_by_id(pool, org_id).await {
        Ok(Some(org)) => Some(org.limits_json),
        Ok(None) => {
            tracing::warn!("Org {} not found while resolving proxy limits", org_id);
            None
        }
        Err(e) => {
            tracing::error!("DB error resolving proxy limits for org {}: {}", org_id, e);
            None
        }
    };
    proxy_limits_from_json(&maybe_limits.unwrap_or(serde_json::Value::Null))
}
/// Enforce a per-deployment requests-per-second cap using a one-second
/// fixed-window counter in Redis.
///
/// Fails OPEN: with no Redis pool configured, or on any Redis error, the
/// request is allowed through and the condition is only logged.
async fn enforce_rps(
    redis: Option<&RedisPool>,
    deployment_id: Uuid,
    rps: i64,
) -> Result<(), StatusCode> {
    let pool = match redis {
        Some(p) => p,
        None => {
            tracing::debug!(
                "Redis not configured — skipping RPS enforcement for deployment {}",
                deployment_id
            );
            return Ok(());
        }
    };
    // One counter key per deployment per wall-clock second; the `2` argument is
    // presumably a ~2s expiry so stale window keys clean themselves up — verify
    // against RedisPool::increment_with_expiry.
    let window = chrono::Utc::now().timestamp();
    let key = format!("mock_rps:{}:{}", deployment_id, window);
    match pool.increment_with_expiry(&key, 2).await {
        Err(e) => {
            tracing::error!("Redis RPS check failed for deployment {}: {}", deployment_id, e);
            Ok(())
        }
        Ok(count) if count > rps => {
            tracing::info!("RPS cap hit for deployment {}: {}/{}", deployment_id, count, rps);
            Err(StatusCode::TOO_MANY_REQUESTS)
        }
        Ok(_) => Ok(()),
    }
}
// Unit tests for the pure helpers (limit parsing / proxy-limit resolution).
// Nothing here touches the DB, Redis, or the network.
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
// --- monthly_request_limit: configured positive values pass through ---
#[test]
fn monthly_limit_pro_plan_default() {
assert_eq!(monthly_request_limit(&json!({ "requests_per_30d": 250_000 })), Some(250_000));
}
#[test]
fn monthly_limit_team_plan_default() {
assert_eq!(
monthly_request_limit(&json!({ "requests_per_30d": 1_000_000 })),
Some(1_000_000)
);
}
// -1 is the "unlimited" sentinel and disables the cap entirely.
#[test]
fn monthly_limit_unlimited_sentinel() {
assert_eq!(monthly_request_limit(&json!({ "requests_per_30d": -1 })), None);
}
// --- monthly_request_limit: invalid/missing config falls back to the default ---
#[test]
fn monthly_limit_zero_falls_back_to_default() {
assert_eq!(
monthly_request_limit(&json!({ "requests_per_30d": 0 })),
Some(DEFAULT_REQUESTS_PER_30D)
);
}
#[test]
fn monthly_limit_missing_field_falls_back() {
assert_eq!(monthly_request_limit(&json!({})), Some(DEFAULT_REQUESTS_PER_30D));
}
#[test]
fn monthly_limit_null_json_falls_back() {
assert_eq!(monthly_request_limit(&serde_json::Value::Null), Some(DEFAULT_REQUESTS_PER_30D));
}
#[test]
fn monthly_limit_wrong_json_type_falls_back() {
assert_eq!(
monthly_request_limit(&json!({ "requests_per_30d": "250000" })),
Some(DEFAULT_REQUESTS_PER_30D)
);
}
// --- proxy_limits_from_json: configured values are honored (MB -> bytes) ---
#[test]
fn proxy_limits_pro_plan_defaults() {
let limits = proxy_limits_from_json(&json!({
"mock_request_body_mb": 10,
"mock_rps_limit": 100,
}));
assert_eq!(limits.max_body_bytes, 10 * 1024 * 1024);
assert_eq!(limits.rps, 100);
}
#[test]
fn proxy_limits_team_plan_defaults() {
let limits = proxy_limits_from_json(&json!({
"mock_request_body_mb": 50,
"mock_rps_limit": 1000,
}));
assert_eq!(limits.max_body_bytes, 50 * 1024 * 1024);
assert_eq!(limits.rps, 1000);
}
// --- proxy_limits_from_json: invalid/missing config falls back to defaults ---
#[test]
fn proxy_limits_missing_fields_fall_back_to_built_in_defaults() {
let limits = proxy_limits_from_json(&json!({}));
assert_eq!(limits.max_body_bytes, DEFAULT_MOCK_REQUEST_BODY_MB as usize * 1024 * 1024);
assert_eq!(limits.rps, DEFAULT_MOCK_RPS_LIMIT);
}
#[test]
fn proxy_limits_null_json_falls_back() {
let limits = proxy_limits_from_json(&serde_json::Value::Null);
assert_eq!(limits.max_body_bytes, DEFAULT_MOCK_REQUEST_BODY_MB as usize * 1024 * 1024);
assert_eq!(limits.rps, DEFAULT_MOCK_RPS_LIMIT);
}
#[test]
fn proxy_limits_non_positive_values_treated_as_missing() {
let limits = proxy_limits_from_json(&json!({
"mock_request_body_mb": -1,
"mock_rps_limit": 0,
}));
assert_eq!(limits.max_body_bytes, DEFAULT_MOCK_REQUEST_BODY_MB as usize * 1024 * 1024);
assert_eq!(limits.rps, DEFAULT_MOCK_RPS_LIMIT);
}
#[test]
fn proxy_limits_string_values_treated_as_missing() {
let limits = proxy_limits_from_json(&json!({
"mock_request_body_mb": "10",
"mock_rps_limit": "100",
}));
assert_eq!(limits.max_body_bytes, DEFAULT_MOCK_REQUEST_BODY_MB as usize * 1024 * 1024);
assert_eq!(limits.rps, DEFAULT_MOCK_RPS_LIMIT);
}
// MB -> bytes conversion must saturate rather than overflow usize.
#[test]
fn proxy_limits_extreme_body_mb_does_not_overflow() {
let limits = proxy_limits_from_json(&json!({
"mock_request_body_mb": i64::MAX,
"mock_rps_limit": 100,
}));
assert_eq!(limits.max_body_bytes, usize::MAX);
}
// Without Redis the RPS check is a deliberate no-op (fail-open).
#[tokio::test]
async fn enforce_rps_without_redis_is_allow_through() {
let result = enforce_rps(None, Uuid::new_v4(), 100).await;
assert!(result.is_ok());
}
}