use crate::{
db::errors::DbError,
db::handlers::{Credits, api_keys::ApiKeys},
errors::Error,
types::UserId,
};
use axum::{
body::Body,
extract::State,
http::{Request, Response, StatusCode},
middleware::Next,
response::IntoResponse,
};
use rust_decimal::Decimal;
use serde::Deserialize;
use sqlx::PgPool;
use tracing::{debug, instrument};
/// Minimal view of an incoming JSON request body.
///
/// Only `model` is declared; any other fields in the payload (for example
/// `messages`) are not captured, so a full chat request still parses into this
/// struct and the middleware can learn which model the rejected request targeted.
#[derive(Deserialize, Debug)]
struct ChatRequest {
    /// Model alias exactly as the client sent it.
    model: String,
}
#[instrument(name = "dwctl.error_enrichment", skip_all, fields(http.request.method = %request.method(), url.path = %request.uri().path(), url.query = request.uri().query().unwrap_or("")))]
pub async fn error_enrichment_middleware(State(pool): State<PgPool>, request: Request<Body>, next: Next) -> Response<Body> {
let api_key = request
.headers()
.get("authorization")
.and_then(|h| h.to_str().ok())
.and_then(|auth| auth.strip_prefix("Bearer ").or_else(|| auth.strip_prefix("bearer ")))
.map(|token| token.trim().to_string());
let (parts, body) = request.into_parts();
let bytes = match axum::body::to_bytes(body, usize::MAX).await {
Ok(b) => b,
Err(_) => {
let reconstructed = Request::from_parts(parts, Body::empty());
return next.run(reconstructed).await;
}
};
let model_name = serde_json::from_slice::<ChatRequest>(&bytes).ok().map(|req| req.model);
let reconstructed = Request::from_parts(parts, Body::from(bytes));
let response = next.run(reconstructed).await;
if response.status() == StatusCode::FORBIDDEN
&& let Some(key) = api_key
{
debug!("Intercepted 403 response on AI proxy path, attempting enrichment");
if let Some(model) = model_name
&& let Ok(user_id) = get_user_id_of_api_key(pool.clone(), &key).await
&& let Ok(has_access) = check_user_has_model_access(pool.clone(), user_id, &model).await
&& !has_access
{
return Error::ModelAccessDenied {
model_name: model.clone(),
message: format!(
"You do not have access to '{}'. Please contact your administrator to request access.",
model
),
}
.into_response();
}
if let Ok(balance) = get_balance_of_api_key(pool.clone(), &key).await
&& balance <= Decimal::ZERO
{
return Error::InsufficientCredits {
current_balance: balance,
message: "Account balance too low. Please add credits to continue.".to_string(),
}
.into_response();
}
}
response
}
/// Resolves the id of the user who owns the API key whose secret is `api_key`.
///
/// # Errors
///
/// Returns a [`DbError`] when a pooled connection cannot be acquired, when the
/// lookup query fails, or when no user is associated with the key.
#[instrument(skip_all, name = "dwctl.get_user_id_of_api_key")]
pub async fn get_user_id_of_api_key(pool: PgPool, api_key: &str) -> Result<UserId, DbError> {
    let mut connection = pool.acquire().await?;
    let mut repo = ApiKeys::new(&mut connection);
    let maybe_owner = repo.get_user_id_by_secret(api_key).await?;

    let Some(owner) = maybe_owner else {
        return Err(anyhow::anyhow!("API key not found or associated user doesn't exist").into());
    };
    Ok(owner)
}
/// Looks up the current credit balance of the user who owns `api_key`.
///
/// # Errors
///
/// Propagates any [`DbError`] raised while resolving the key's owner (including
/// an unknown key), acquiring a connection, or reading the balance.
#[instrument(skip_all, name = "dwctl.get_balance_of_api_key")]
pub async fn get_balance_of_api_key(pool: PgPool, api_key: &str) -> Result<Decimal, DbError> {
    let owner = get_user_id_of_api_key(pool.clone(), api_key).await?;
    debug!("Found user_id for API key: {}", owner);

    let mut connection = pool.acquire().await?;
    let mut repo = Credits::new(&mut connection);
    let balance = repo.get_user_balance(owner).await?;
    Ok(balance)
}
#[instrument(skip_all, name = "dwctl.check_user_has_model_access")]
pub async fn check_user_has_model_access(pool: PgPool, user_id: UserId, model_alias: &str) -> Result<bool, DbError> {
let mut conn = pool.acquire().await?;
let result = sqlx::query_scalar!(
r#"
SELECT EXISTS(
SELECT 1
FROM deployed_models d
JOIN deployment_groups dg ON dg.deployment_id = d.id
WHERE d.alias = $1
AND d.deleted = false
AND dg.group_id IN (
SELECT ug.group_id FROM user_groups ug WHERE ug.user_id = $2
UNION
SELECT '00000000-0000-0000-0000-000000000000'::uuid
WHERE $2 != '00000000-0000-0000-0000-000000000000'
)
) as "has_access!"
"#,
model_alias,
user_id
)
.fetch_one(&mut *conn)
.await?;
Ok(result)
}
#[cfg(test)]
mod tests {
    // Integration tests for `error_enrichment_middleware`. `#[sqlx::test]`
    // gives each test its own database pool; the middleware is mounted in
    // front of a stub handler returning a fixed response, so only the
    // middleware's rewriting behaviour is exercised.
    use super::*;
    use crate::db::{
        handlers::{Credits, Repository as _, api_keys::ApiKeys},
        models::{
            api_keys::{ApiKeyCreateDBRequest, ApiKeyPurpose},
            credits::{CreditTransactionCreateDBRequest, CreditTransactionType},
        },
    };
    use crate::{api::models::users::Role, test::utils::create_test_user};
    use rust_decimal::Decimal;

    /// A 403 is rewritten to a model-access error while the user cannot see the
    /// requested model, and to a 402 "balance too low" error once the user can
    /// see the model but their balance is no longer positive.
    #[sqlx::test]
    #[test_log::test]
    async fn test_error_enrichment_middleware_enriches_403_with_balance(pool: PgPool) {
        use crate::test::utils::{add_deployment_to_group, add_user_to_group, create_test_group};
        let user = create_test_user(&pool, Role::StandardUser).await;
        let mut api_key_conn = pool.acquire().await.unwrap();
        let mut api_keys_repo = ApiKeys::new(&mut api_key_conn);
        let api_key = api_keys_repo
            .create(&ApiKeyCreateDBRequest {
                user_id: user.id,
                name: "Test Key".to_string(),
                description: None,
                purpose: ApiKeyPurpose::Realtime,
                requests_per_second: None,
                burst_size: None,
                created_by: user.id,
            })
            .await
            .unwrap();
        // Grant 50.00 credits so the first request can only trip the model-access check.
        let mut credits_conn = pool.acquire().await.unwrap();
        let mut credits_repo = Credits::new(&mut credits_conn);
        credits_repo
            .create_transaction(&CreditTransactionCreateDBRequest {
                user_id: user.id,
                transaction_type: CreditTransactionType::AdminGrant,
                amount: Decimal::new(5000, 2),
                source_id: uuid::Uuid::new_v4().to_string(),
                description: Some("Initial credits".to_string()),
                fusillade_batch_id: None,
                api_key_id: None,
            })
            .await
            .unwrap();
        // Stub upstream that always rejects with a bare 403.
        let router = axum::Router::new()
            .route(
                "/ai/v1/chat/completions",
                axum::routing::post(|| async {
                    axum::response::Response::builder()
                        .status(StatusCode::FORBIDDEN)
                        .body(axum::body::Body::from("Forbidden"))
                        .unwrap()
                }),
            )
            .layer(axum::middleware::from_fn_with_state(
                pool.clone(),
                crate::error_enrichment::error_enrichment_middleware,
            ));
        let server = axum_test::TestServer::new(router).expect("Failed to create test server");
        // "test-model" is not deployed anywhere, so the user has no access to it.
        let response = server
            .post("/ai/v1/chat/completions")
            .add_header("authorization", &format!("Bearer {}", api_key.secret))
            .json(&serde_json::json!({
                "model": "test-model",
                "messages": [{"role": "user", "content": "Hello"}]
            }))
            .await;
        assert_eq!(response.status_code().as_u16(), 403);
        let body = response.text();
        assert!(body.contains("do not have access to 'test-model'"));
        // Record 100.00 of usage; the 402 assertion below relies on this leaving
        // the balance non-positive.
        let mut credits_conn = pool.acquire().await.unwrap();
        let mut credits_repo = Credits::new(&mut credits_conn);
        credits_repo
            .create_transaction(&CreditTransactionCreateDBRequest {
                user_id: user.id,
                transaction_type: CreditTransactionType::Usage,
                amount: Decimal::new(10000, 2),
                source_id: uuid::Uuid::new_v4().to_string(),
                description: Some("Usage".to_string()),
                fusillade_batch_id: None,
                api_key_id: None,
            })
            .await
            .unwrap();
        // Give the user access to "authorized-model" through a group so the
        // access check passes and the balance check is what fires.
        let endpoint_id = crate::test::utils::create_test_endpoint(&pool, "test-endpoint", user.id).await;
        let deployment_id =
            crate::test::utils::create_test_model(&pool, "authorized-model-name", "authorized-model", endpoint_id, user.id).await;
        let group = create_test_group(&pool).await;
        add_user_to_group(&pool, user.id, group.id).await;
        add_deployment_to_group(&pool, deployment_id, group.id, user.id).await;
        let response = server
            .post("/ai/v1/chat/completions")
            .add_header("authorization", &format!("Bearer {}", api_key.secret))
            .json(&serde_json::json!({
                "model": "authorized-model",
                "messages": [{"role": "user", "content": "Hello"}]
            }))
            .await;
        assert_eq!(response.status_code().as_u16(), 402);
        let body = response.text();
        println!("Enriched response body: {}", body);
        assert!(body.contains("balance too low"));
    }

    /// A 403 for a user who has both model access and a positive balance is
    /// passed through unchanged (e.g. an upstream rate-limit rejection).
    #[sqlx::test]
    #[test_log::test]
    async fn test_error_enrichment_middleware_passes_through_legitimate_403(pool: PgPool) {
        use crate::test::utils::{add_deployment_to_group, add_user_to_group, create_test_group};
        let user = create_test_user(&pool, Role::StandardUser).await;
        let group = create_test_group(&pool).await;
        add_user_to_group(&pool, user.id, group.id).await;
        let mut api_key_conn = pool.acquire().await.unwrap();
        let mut api_keys_repo = ApiKeys::new(&mut api_key_conn);
        let api_key = api_keys_repo
            .create(&ApiKeyCreateDBRequest {
                user_id: user.id,
                name: "Test Key".to_string(),
                description: None,
                purpose: ApiKeyPurpose::Realtime,
                requests_per_second: None,
                burst_size: None,
                created_by: user.id,
            })
            .await
            .unwrap();
        let mut credits_conn = pool.acquire().await.unwrap();
        let mut credits_repo = Credits::new(&mut credits_conn);
        credits_repo
            .create_transaction(&CreditTransactionCreateDBRequest {
                user_id: user.id,
                transaction_type: CreditTransactionType::AdminGrant,
                amount: Decimal::new(5000, 2),
                source_id: uuid::Uuid::new_v4().to_string(),
                description: Some("Initial credits".to_string()),
                fusillade_batch_id: None,
                api_key_id: None,
            })
            .await
            .unwrap();
        let endpoint_id = crate::test::utils::create_test_endpoint(&pool, "test-endpoint", user.id).await;
        let deployment_id =
            crate::test::utils::create_test_model(&pool, "authorized-model-name", "authorized-model", endpoint_id, user.id).await;
        add_deployment_to_group(&pool, deployment_id, group.id, user.id).await;
        let router = axum::Router::new()
            .route(
                "/ai/v1/chat/completions",
                axum::routing::post(|| async {
                    axum::response::Response::builder()
                        .status(StatusCode::FORBIDDEN)
                        .body(axum::body::Body::from("Rate limit exceeded"))
                        .unwrap()
                }),
            )
            .layer(axum::middleware::from_fn_with_state(
                pool.clone(),
                crate::error_enrichment::error_enrichment_middleware,
            ));
        let server = axum_test::TestServer::new(router).expect("Failed to create test server");
        let response = server
            .post("/ai/v1/chat/completions")
            .add_header("authorization", &format!("Bearer {}", api_key.secret))
            .json(&serde_json::json!({
                "model": "authorized-model",
                "messages": [{"role": "user", "content": "Hello"}]
            }))
            .await;
        assert_eq!(response.status_code().as_u16(), 403);
        let body = response.text();
        assert!(body.contains("Rate limit exceeded"));
        assert!(!body.contains("balance"));
        assert!(!body.contains("access"));
    }

    /// A 403 on a non-AI route is returned untouched.
    ///
    /// The middleware itself does not inspect the request path: this passes
    /// because the request has no `Authorization` header. Path scoping is
    /// presumably provided by where the layer is mounted — confirm.
    #[sqlx::test]
    #[test_log::test]
    async fn test_error_enrichment_middleware_ignores_non_ai_paths(pool: PgPool) {
        let router = axum::Router::new()
            .route(
                "/admin/api/v1/users",
                axum::routing::get(|| async {
                    axum::response::Response::builder()
                        .status(StatusCode::FORBIDDEN)
                        .body(axum::body::Body::from("Admin Forbidden"))
                        .unwrap()
                }),
            )
            .layer(axum::middleware::from_fn_with_state(
                pool.clone(),
                crate::error_enrichment::error_enrichment_middleware,
            ));
        let server = axum_test::TestServer::new(router).expect("Failed to create test server");
        let response = server.get("/admin/api/v1/users").await;
        assert_eq!(response.status_code().as_u16(), 403);
        assert_eq!(response.text(), "Admin Forbidden");
    }

    /// Non-403 statuses are never enriched, even when a bearer token is present.
    #[sqlx::test]
    #[test_log::test]
    async fn test_error_enrichment_middleware_ignores_non_403_errors(pool: PgPool) {
        let router = axum::Router::new()
            .route(
                "/ai/v1/chat/completions",
                axum::routing::post(|| async {
                    axum::response::Response::builder()
                        .status(StatusCode::NOT_FOUND)
                        .body(axum::body::Body::from("Not Found"))
                        .unwrap()
                }),
            )
            .layer(axum::middleware::from_fn_with_state(
                pool.clone(),
                crate::error_enrichment::error_enrichment_middleware,
            ));
        let server = axum_test::TestServer::new(router).expect("Failed to create test server");
        let response = server
            .post("/ai/v1/chat/completions")
            .add_header("authorization", "Bearer test-key")
            .json(&serde_json::json!({"model": "test"}))
            .await;
        assert_eq!(response.status_code().as_u16(), 404);
        assert_eq!(response.text(), "Not Found");
    }

    /// Without an `Authorization` header there is no key to look up, so a 403
    /// passes through unchanged.
    #[sqlx::test]
    #[test_log::test]
    async fn test_error_enrichment_middleware_without_auth_header(pool: PgPool) {
        let router = axum::Router::new()
            .route(
                "/ai/v1/chat/completions",
                axum::routing::post(|| async {
                    axum::response::Response::builder()
                        .status(StatusCode::FORBIDDEN)
                        .body(axum::body::Body::from("No Auth"))
                        .unwrap()
                }),
            )
            .layer(axum::middleware::from_fn_with_state(
                pool.clone(),
                crate::error_enrichment::error_enrichment_middleware,
            ));
        let server = axum_test::TestServer::new(router).expect("Failed to create test server");
        let response = server
            .post("/ai/v1/chat/completions")
            .json(&serde_json::json!({"model": "test"}))
            .await;
        assert_eq!(response.status_code().as_u16(), 403);
        assert_eq!(response.text(), "No Auth");
    }
}