use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use axum::response::IntoResponse;
#[cfg(feature = "openapi")]
use aide::openapi::{MediaType, Operation, ReferenceOr, Response, SchemaObject, StatusCode};
#[cfg(feature = "openapi")]
/// Per-variant HTTP metadata for an error type, consulted when registering
/// OpenAPI error responses (see `register_error_response_by_variant`).
pub trait ProblemDetailsVariantInfo {
    /// Looks up the variant named `variant_name` and returns
    /// `(status_code, description, schema)`, or `None` if the type has no such variant.
    ///
    /// `variant_name` is the bare variant name: `register_error_response_by_variant`
    /// strips any `Type::` path prefix before calling this. That function uses the
    /// status code as the OpenAPI response key and the description as the response
    /// description; it currently ignores the optional schema.
    fn get_variant_info(variant_name: &str) -> Option<(u16, String, Option<schemars::Schema>)>;
}
#[cfg(feature = "openapi")]
/// Produces the JSON Schema describing a [`ProblemDetails`] body.
///
/// A fresh default `SchemaGenerator` is used for every call, so the result is
/// the type's own (non-root) schema rather than a document with `$defs`.
pub fn problem_details_schema() -> schemars::Schema {
    let mut generator = schemars::SchemaGenerator::default();
    <crate::problem_details::ProblemDetails as schemars::JsonSchema>::json_schema(&mut generator)
}
#[cfg(feature = "openapi")]
/// Adds (or merges into) the OpenAPI response documenting one error variant of `T`.
///
/// `variant_path` may be a bare variant name (`NotFound`) or a path
/// (`ApiError::NotFound`). Only the segment after the last `::` is passed to
/// [`ProblemDetailsVariantInfo::get_variant_info`]. If `T` does not know the
/// variant, a warning is logged and `operation` is left unchanged.
///
/// The response is documented under both `application/problem+json` and
/// `application/json`, using the [`ProblemDetails`] schema plus a synthetic
/// example. When an inline response already exists for the same status code,
/// only its description is extended: this variant's description is appended as
/// a `- ` bullet unless it is already listed. Existing `$ref` responses are left
/// untouched.
pub fn register_error_response_by_variant<T>(
    _ctx: &mut aide::generate::GenContext,
    operation: &mut Operation,
    variant_path: &str,
) where
    T: ProblemDetailsVariantInfo,
{
    // `split` always yields at least one item, so the fallback is never actually taken.
    let variant_name = variant_path.split("::").last().unwrap_or(variant_path);
    let Some((status_code, description, _schema_opt)) = T::get_variant_info(variant_name) else {
        tracing::warn!(
            "Variant '{}' not found in error type '{}' when registering OpenAPI responses",
            variant_name,
            std::any::type_name::<T>()
        );
        return;
    };

    // `variant_name` is already the last `::` segment, so it contains no `::` to replace.
    // NOTE(review): RFC 9457 gives special meaning only to exactly `about:blank`;
    // confirm this example matches the `type` that `T` actually emits.
    let problem_type = format!("about:blank/{}", variant_name.to_lowercase());
    let example = serde_json::json!({
        "type": problem_type,
        "title": format!("{variant_name} Error"),
        "status": status_code,
        "detail": format!("{variant_name} occurred")
    });

    let media_type = MediaType {
        schema: Some(SchemaObject {
            json_schema: problem_details_schema(),
            example: Some(example),
            external_docs: None,
        }),
        ..Default::default()
    };
    let mut content = indexmap::IndexMap::new();
    content.insert("application/problem+json".to_string(), media_type.clone());
    content.insert("application/json".to_string(), media_type);
    let response = Response {
        description,
        content,
        ..Default::default()
    };

    let responses = operation.responses.get_or_insert_with(Default::default);
    let status_code_key = StatusCode::Code(status_code);
    if let Some(existing) = responses.responses.get_mut(&status_code_key) {
        // Several variants can share a status code: fold their descriptions together.
        if let ReferenceOr::Item(existing_response) = existing {
            merge_response_description(&mut existing_response.description, &response.description);
        }
    } else {
        responses.responses.insert(status_code_key, ReferenceOr::Item(response));
    }
}

/// Appends `addition` to a response description as a `- ` bullet line, unless it is
/// already present (as the whole text, as a plain line, or as a `- ` bullet line).
///
/// Checking individual lines matters once several descriptions have been merged.
/// Comparing only the whole string would re-append a description that is already
/// one of the listed entries.
#[cfg(feature = "openapi")]
fn merge_response_description(existing: &mut String, addition: &str) {
    let already_listed = existing.as_str() == addition
        || existing
            .lines()
            .any(|line| line == addition || line.strip_prefix("- ") == Some(addition));
    if !already_listed {
        existing.push_str("\n- ");
        existing.push_str(addition);
    }
}
// RFC 9457 (formerly RFC 7807) "problem details" object. The `IntoResponse`
// impl in this file sends it with the `application/problem+json` media type.
//
// Plain `//` comments are used here on purpose: `schemars` turns `///` doc
// comments into `title`/`description` entries of the derived JSON Schema,
// which would change the OpenAPI output generated from this type.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ProblemDetails {
    // Problem type identifier (e.g. `about:blank` or an absolute URI);
    // serialized under the JSON key `"type"`.
    #[serde(rename = "type")]
    pub problem_type: String,
    // Short, human-readable summary of the problem.
    pub title: String,
    // Numeric HTTP status code. `status_code()` and `IntoResponse` fall back to
    // 500 when this is not a valid `http::StatusCode`.
    pub status: u16,
    // Occurrence-specific explanation; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
    // Identifier of this specific occurrence; omitted from the JSON when `None`.
    // `IntoResponse` fills it from the task-local request URI if still unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub instance: Option<String>,
    // Extension members, flattened into the top-level JSON object. On
    // deserialization, every member that is not one of the fields above lands here.
    // NOTE(review): a key equal to a standard member name ("type", "status", ...)
    // would be serialized as a duplicate JSON key.
    #[serde(flatten)]
    pub extensions: HashMap<String, serde_json::Value>,
}
impl ProblemDetails {
    /// Creates a problem from its `type`, `title` and numeric `status`.
    ///
    /// `detail` and `instance` start unset and there are no extension members.
    pub fn new(problem_type: impl Into<String>, title: impl Into<String>, status: u16) -> Self {
        Self {
            problem_type: problem_type.into(),
            title: title.into(),
            status,
            detail: None,
            instance: None,
            extensions: HashMap::default(),
        }
    }

    /// Returns the problem with `detail` set, replacing any previous value.
    pub fn with_detail(self, detail: impl Into<String>) -> Self {
        Self {
            detail: Some(detail.into()),
            ..self
        }
    }

    /// Returns the problem with `instance` set, replacing any previous value.
    pub fn with_instance(self, instance: impl Into<String>) -> Self {
        Self {
            instance: Some(instance.into()),
            ..self
        }
    }

    /// Returns the problem with the extension member `key` set to `value`.
    ///
    /// An existing member with the same key is overwritten. Keys should not reuse
    /// standard member names, because extensions are flattened into the same JSON object.
    pub fn with_extension(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        self.extensions.insert(key.into(), value);
        self
    }

    /// `400` "Validation Error" carrying the caller-supplied `detail`.
    pub fn validation_error(detail: impl Into<String>) -> Self {
        Self::about_blank("Validation Error", 400, detail)
    }

    /// `401` "Authentication Required" with a fixed explanatory detail.
    pub fn authentication_error() -> Self {
        Self::about_blank(
            "Authentication Required",
            401,
            "Authentication credentials are required to access this resource",
        )
    }

    /// `403` "Insufficient Permissions" with a fixed explanatory detail.
    pub fn authorization_error() -> Self {
        Self::about_blank(
            "Insufficient Permissions",
            403,
            "You don't have permission to access this resource",
        )
    }

    /// `404` "Resource Not Found"; `resource` names what was missing in the detail text.
    pub fn not_found(resource: impl Into<String>) -> Self {
        let resource: String = resource.into();
        Self::about_blank(
            "Resource Not Found",
            404,
            format!("The requested {resource} was not found"),
        )
    }

    /// `500` "Internal Server Error" with a generic detail that reveals nothing internal.
    pub fn internal_server_error() -> Self {
        Self::about_blank(
            "Internal Server Error",
            500,
            "An unexpected error occurred while processing your request",
        )
    }

    /// `503` "Service Unavailable" with a fixed explanatory detail.
    pub fn service_unavailable() -> Self {
        Self::about_blank(
            "Service Unavailable",
            503,
            "The service is temporarily unavailable",
        )
    }

    /// Same as [`ProblemDetails::new`]; kept as a descriptively named alias.
    pub fn custom_problem(problem_type: impl Into<String>, title: impl Into<String>, status: u16) -> Self {
        Self::new(problem_type, title, status)
    }

    /// Shared shape of the canned constructors: an `about:blank` problem with a detail.
    fn about_blank(title: &str, status: u16, detail: impl Into<String>) -> Self {
        Self::new("about:blank", title, status).with_detail(detail)
    }
}
impl IntoResponse for ProblemDetails {
    /// Converts the problem into an HTTP response with an `application/problem+json` body.
    ///
    /// - The status line comes from [`ProblemDetails::status_code`]. An invalid `status`
    ///   falls back to `500 Internal Server Error`, and the body's `status` member is
    ///   rewritten to match. RFC 9457 §3.1.2 requires the advisory `status` member to
    ///   equal the code actually used in the HTTP response.
    /// - If `instance` is unset, it is filled from the request URI captured by
    ///   [`capture_request_uri_middleware`], when that scope is active for the current task.
    fn into_response(mut self) -> axum::response::Response {
        let status = self.status_code();
        // Keep the body in sync with the status actually sent (only differs on fallback).
        self.status = status.as_u16();
        if self.instance.is_none() {
            self.instance = get_current_request_uri();
        }
        (
            status,
            // Header parts are applied after `Json`'s own `application/json`, so this wins.
            [(axum::http::header::CONTENT_TYPE, "application/problem+json")],
            axum::Json(self),
        )
            .into_response()
    }
}
// Per-task storage for the URI of the request currently being handled.
// It only holds a value inside `CURRENT_REQUEST_URI.scope(...)`, which
// `capture_request_uri_middleware` sets up around the downstream handler.
// `get_current_request_uri` reads it to default `ProblemDetails::instance`.
// NOTE: a scope applies only to the future it wraps; work moved onto separately
// spawned tasks will not see the value.
tokio::task_local! {
    static CURRENT_REQUEST_URI: String;
}
/// Returns a copy of the task-local request URI.
///
/// Returns `None` when called outside any `CURRENT_REQUEST_URI.scope(...)`,
/// e.g. when the capture middleware is not installed.
fn get_current_request_uri() -> Option<String> {
    match CURRENT_REQUEST_URI.try_with(String::clone) {
        Ok(captured) => Some(captured),
        Err(_) => None,
    }
}
/// Intended to set the URI reported as a problem's `instance`, but it **has no effect**.
///
/// A Tokio task-local only holds a value for the duration of the future or closure
/// passed to `LocalKey::scope` / `LocalKey::sync_scope`. A plain synchronous call
/// cannot assign it, because any scope created here would end before this function
/// returns. The signature is kept only for backward compatibility.
///
/// To populate the URI, install [`capture_request_uri_middleware`] (e.g. with
/// `axum::middleware::from_fn`). Alternatively, run the work inside
/// `CURRENT_REQUEST_URI.scope(uri, future)` from within this module.
pub fn set_current_request_uri(uri: String) {
    // Deliberately discarded; see the doc comment for why it cannot be stored.
    drop(uri);
}
/// Axum middleware (shape suitable for `axum::middleware::from_fn`) that makes the
/// request URI available to [`ProblemDetails`] responses built further down the stack.
///
/// The rest of the chain runs inside a `CURRENT_REQUEST_URI` scope holding the
/// request's URI as this middleware sees it. Any `ProblemDetails` converted into a
/// response during that time gets it as its default `instance`.
/// NOTE(review): under a nested router, the URI seen here may have had a prefix
/// stripped, depending on where the layer is installed. Confirm this is the desired value.
pub async fn capture_request_uri_middleware(
    req: axum::http::Request<axum::body::Body>,
    next: axum::middleware::Next,
) -> axum::response::Response {
    // Read the URI before `req` is moved into the downstream handler.
    let request_uri: String = req.uri().to_string();
    let downstream = async move { next.run(req).await };
    CURRENT_REQUEST_URI.scope(request_uri, downstream).await
}
impl ProblemDetails {
    /// The HTTP status corresponding to `status`.
    ///
    /// Falls back to `500 Internal Server Error` when the stored number is rejected
    /// by `StatusCode::from_u16`, i.e. it lies outside 100–999.
    pub fn status_code(&self) -> axum::http::StatusCode {
        match axum::http::StatusCode::from_u16(self.status) {
            Ok(code) => code,
            Err(_) => axum::http::StatusCode::INTERNAL_SERVER_ERROR,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_problem_details_creation() {
        let problem = ProblemDetails::new("https://example.com/problems/test", "Test Problem", 400)
            .with_detail("This is a test problem")
            .with_instance("/test/123")
            .with_extension("code", serde_json::Value::String("TEST_001".to_string()));
        assert_eq!(problem.problem_type, "https://example.com/problems/test");
        assert_eq!(problem.title, "Test Problem");
        assert_eq!(problem.status, 400);
        assert_eq!(problem.detail, Some("This is a test problem".to_string()));
        assert_eq!(problem.instance, Some("/test/123".to_string()));
        assert_eq!(problem.extensions.get("code"), Some(&serde_json::Value::String("TEST_001".to_string())));
    }

    #[test]
    fn test_validation_error() {
        let problem = ProblemDetails::validation_error("Name is required");
        assert_eq!(problem.status, 400);
        assert_eq!(problem.title, "Validation Error");
        assert_eq!(problem.problem_type, "about:blank");
    }

    #[test]
    fn test_into_response() {
        let problem = ProblemDetails::not_found("user");
        let response = problem.into_response();
        assert_eq!(response.status(), axum::http::StatusCode::NOT_FOUND);
        // The problem+json media type must override the `application/json` set by `axum::Json`.
        let content_type = response
            .headers()
            .get(axum::http::header::CONTENT_TYPE)
            .and_then(|value| value.to_str().ok());
        assert_eq!(content_type, Some("application/problem+json"));
    }

    #[test]
    fn test_status_code() {
        let problem = ProblemDetails::validation_error("Test error");
        assert_eq!(problem.status_code(), axum::http::StatusCode::BAD_REQUEST);
    }

    #[test]
    fn test_status_code_falls_back_to_500_for_invalid_status() {
        let problem = ProblemDetails::new("about:blank", "Odd Status", 42);
        assert_eq!(problem.status_code(), axum::http::StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[test]
    fn test_serialization_omits_unset_members() {
        let json = serde_json::to_value(ProblemDetails::new("about:blank", "Bad Request", 400))
            .expect("ProblemDetails always serializes");
        assert_eq!(
            json,
            serde_json::json!({"type": "about:blank", "title": "Bad Request", "status": 400})
        );
    }

    #[test]
    fn test_serialization_flattens_extensions() {
        let problem = ProblemDetails::validation_error("Name is required")
            .with_extension("field", serde_json::json!("name"));
        let json = serde_json::to_value(&problem).expect("ProblemDetails always serializes");
        assert_eq!(json["type"], "about:blank");
        assert_eq!(json["detail"], "Name is required");
        assert_eq!(json["field"], "name");
        assert!(json.get("extensions").is_none());
        assert!(json.get("instance").is_none());
    }

    #[test]
    fn test_deserialization_collects_unknown_members_as_extensions() {
        let problem: ProblemDetails = serde_json::from_value(serde_json::json!({
            "type": "https://example.com/probs/out-of-credit",
            "title": "You do not have enough credit.",
            "status": 403,
            "balance": 30
        }))
        .expect("valid problem details JSON");
        assert_eq!(problem.problem_type, "https://example.com/probs/out-of-credit");
        assert_eq!(problem.status, 403);
        assert_eq!(problem.detail, None);
        assert_eq!(problem.extensions.len(), 1);
        assert_eq!(problem.extensions.get("balance"), Some(&serde_json::json!(30)));
    }

    #[test]
    fn test_uri_is_absent_outside_scope() {
        assert_eq!(get_current_request_uri(), None);
    }

    #[tokio::test]
    async fn test_automatic_uri_capture() {
        let test_uri = "/test/path".to_string();
        CURRENT_REQUEST_URI.scope(test_uri.clone(), async {
            let uri = get_current_request_uri();
            assert_eq!(uri, Some(test_uri));
        }).await;
    }
}