use axum::{
extract::{Request, State},
http::{HeaderMap, Method},
middleware::Next,
response::{IntoResponse, Response},
};
use uuid::Uuid;
use crate::{
error::ApiError,
middleware::resolve_org_context,
models::{Subscription, SubscriptionStatus},
AppState,
};
/// Seconds a `PastDue` subscription keeps write access before writes are rejected (24 hours).
const PAST_DUE_GRACE_SECONDS: i64 = 60 * 60 * 24;

/// Path prefixes whose write requests are exempt from the past-due block.
///
/// Every entry ends in `/`, so a prefix match stops at a path-segment boundary
/// and cannot pick up lookalike segments such as `/api/v1/billing-summary`.
const WRITE_ALLOWLIST_PREFIXES: &[&str] = &[
    "/api/v1/billing/",
    "/api/v1/auth/",
    "/api/v1/support/",
    "/api/v1/legal/",
    "/api/v1/waitlist/",
];
/// Returns `true` for the methods treated as reads: `GET`, `HEAD` and `OPTIONS`.
///
/// Requests using these methods are never blocked by the past-due check.
fn is_read_method(method: &Method) -> bool {
    [Method::GET, Method::HEAD, Method::OPTIONS].contains(method)
}
/// Returns `true` when `path` lies inside one of the namespaces listed in
/// [`WRITE_ALLOWLIST_PREFIXES`], meaning writes to it are never blocked for
/// past-due organisations.
///
/// A path matches a namespace when it either starts with the prefix (every
/// prefix ends in `/`) or equals the prefix without its trailing slash. The
/// second case covers collection roots such as `/api/v1/waitlist`, which a
/// plain `starts_with` check would miss. Both checks stop at a `/` boundary,
/// so lookalikes such as `/api/v1/billing-summary` are still rejected.
fn is_write_allowlisted(path: &str) -> bool {
    WRITE_ALLOWLIST_PREFIXES.iter().any(|prefix| {
        path.starts_with(prefix) || prefix.strip_suffix('/') == Some(path)
    })
}
/// Axum middleware that rejects write requests from organisations whose
/// subscription has been `PastDue` for longer than [`PAST_DUE_GRACE_SECONDS`].
///
/// The request is forwarded to `next` unchanged when any of these hold:
/// - the method is a read (see [`is_read_method`]);
/// - the path is allowlisted (see [`is_write_allowlisted`]);
/// - there is no `String` request extension that parses as a [`Uuid`] user id;
/// - `resolve_org_context` returns an error;
/// - the org has no subscription, or the lookup fails (logged at `error` level);
/// - the subscription is not `PastDue`, or less than the grace period has
///   elapsed since its `updated_at`.
///
/// Otherwise it logs a warning and returns an [`ApiError::PaymentRequired`]
/// response without calling `next`.
pub async fn past_due_writes_blocked_middleware(
    State(state): State<AppState>,
    headers: HeaderMap,
    request: Request,
    next: Next,
) -> Result<Response, Response> {
    // Work out the verdict first. Every "let it through" path breaks out with
    // `false`, so the request is forwarded from exactly one place below.
    let blocked = 'decide: {
        if is_read_method(request.method()) || is_write_allowlisted(request.uri().path()) {
            break 'decide false;
        }

        let Some(user_id) = request
            .extensions()
            .get::<String>()
            .and_then(|raw| Uuid::parse_str(raw).ok())
        else {
            break 'decide false;
        };

        let Ok(org_ctx) =
            resolve_org_context(&state, user_id, &headers, Some(request.extensions())).await
        else {
            break 'decide false;
        };

        let lookup = Subscription::find_by_org(state.db.pool(), org_ctx.org_id).await;
        let subscription = match lookup {
            Ok(Some(found)) => found,
            Ok(None) => break 'decide false,
            Err(err) => {
                // Fail open: log the lookup failure and forward the request
                // instead of rejecting it.
                tracing::error!(
                    org_id = %org_ctx.org_id,
                    "past_due middleware: subscription lookup failed: {}",
                    err,
                );
                break 'decide false;
            }
        };

        if subscription.status() != SubscriptionStatus::PastDue {
            break 'decide false;
        }

        // NOTE(review): the grace window is measured from `updated_at`. If that
        // column is bumped by unrelated changes to the subscription row, the
        // window restarts. Confirm it tracks the transition into `PastDue`.
        let overdue_seconds = (chrono::Utc::now() - subscription.updated_at).num_seconds();
        if overdue_seconds <= PAST_DUE_GRACE_SECONDS {
            break 'decide false;
        }

        tracing::warn!(
            org_id = %org_ctx.org_id,
            method = %request.method(),
            path = request.uri().path(),
            past_due_seconds = overdue_seconds,
            "blocking write: subscription past_due beyond 24h grace",
        );
        true
    };

    if !blocked {
        return Ok(next.run(request).await);
    }

    Err(ApiError::PaymentRequired(
        "Subscription is past due. Update your payment method in the billing portal to resume \
         deploys and other write operations. Reads remain available."
            .to_string(),
    )
    .into_response())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// GET, HEAD and OPTIONS always bypass the past-due check.
    #[test]
    fn read_methods_pass() {
        for method in [Method::GET, Method::HEAD, Method::OPTIONS] {
            assert!(is_read_method(&method), "{method} should be treated as a read");
        }
    }

    /// Mutating methods are not reads, so they fall through to the allowlist
    /// and subscription checks.
    #[test]
    fn write_methods_are_not_reads() {
        for method in [Method::POST, Method::PUT, Method::PATCH, Method::DELETE] {
            assert!(!is_read_method(&method), "{method} should not be treated as a read");
        }
    }

    #[test]
    fn billing_path_is_allowlisted() {
        assert!(is_write_allowlisted("/api/v1/billing/checkout"));
        assert!(is_write_allowlisted("/api/v1/billing/portal"));
    }

    #[test]
    fn auth_path_is_allowlisted() {
        assert!(is_write_allowlisted("/api/v1/auth/2fa/setup"));
        assert!(is_write_allowlisted("/api/v1/auth/change-password"));
    }

    /// Covers the support, legal and waitlist prefixes, which the other tests
    /// do not exercise.
    #[test]
    fn support_legal_waitlist_paths_are_allowlisted() {
        assert!(is_write_allowlisted("/api/v1/support/tickets"));
        assert!(is_write_allowlisted("/api/v1/legal/accept-terms"));
        assert!(is_write_allowlisted("/api/v1/waitlist/join"));
    }

    /// Every configured prefix must allowlist paths nested beneath it.
    #[test]
    fn every_prefix_covers_its_subpaths() {
        for prefix in WRITE_ALLOWLIST_PREFIXES {
            let path = format!("{prefix}anything");
            assert!(is_write_allowlisted(&path), "{path} should be allowlisted");
        }
    }

    #[test]
    fn deploy_path_not_allowlisted() {
        assert!(!is_write_allowlisted("/api/v1/hosted-mocks"));
        assert!(!is_write_allowlisted("/api/v1/workspaces"));
        assert!(!is_write_allowlisted("/api/v1/organizations/abc/members"));
    }

    /// Paths that merely contain or extend an allowlisted segment name must
    /// not be allowlisted.
    #[test]
    fn billing_lookalike_not_allowlisted() {
        assert!(!is_write_allowlisted("/api/v1/orgs/123/billing-summary"));
        assert!(!is_write_allowlisted("/api/v1/billing-summary"));
        assert!(!is_write_allowlisted("/api/v1/billingx/checkout"));
        assert!(!is_write_allowlisted("/api/v2/billing/checkout"));
    }

    /// The block log message says "24h grace"; keep the constant in sync with it.
    #[test]
    fn grace_period_is_24_hours() {
        assert_eq!(PAST_DUE_GRACE_SECONDS, 86_400);
    }
}