1use axum::response::IntoResponse;
2use schemars::JsonSchema;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6#[cfg(feature = "openapi")]
8use aide::openapi::{MediaType, Operation, ReferenceOr, Response, SchemaObject, StatusCode};
9
#[cfg(feature = "openapi")]
/// Maps the variant names of an error type to the data needed to document each variant as an
/// OpenAPI error response.
///
/// Consumed by `register_error_response_by_variant`, which looks variants up by the last `::`
/// segment of the path it is given.
pub trait ProblemDetailsVariantInfo {
    /// Returns `(status_code, description, schema)` for `variant_name`, or `None` when the type
    /// has no variant with that name.
    ///
    /// `status_code` becomes the OpenAPI response key and `description` the response text.
    /// NOTE(review): `register_error_response_by_variant` currently ignores the schema element —
    /// confirm whether it is meant to override the generic ProblemDetails schema.
    fn get_variant_info(variant_name: &str) -> Option<(u16, String, Option<schemars::Schema>)>;
}
15
/// Produces the JSON Schema describing [`ProblemDetails`], generated with a fresh
/// default `SchemaGenerator` (so no shared definitions are accumulated between calls).
#[cfg(feature = "openapi")]
pub fn problem_details_schema() -> schemars::Schema {
    let mut generator = schemars::SchemaGenerator::default();
    <crate::problem_details::ProblemDetails as schemars::JsonSchema>::json_schema(&mut generator)
}
22
/// Documents one variant of the error type `T` as an error response of `operation`.
///
/// `variant_path` may be a bare variant name (`NotFound`) or a path (`ApiError::NotFound`);
/// only the last `::` segment is passed to [`ProblemDetailsVariantInfo::get_variant_info`].
/// Unknown variants are reported with `tracing::warn!` and otherwise ignored.
///
/// The response is documented under both `application/problem+json` and `application/json`,
/// using the generic [`ProblemDetails`] schema plus a synthetic example built from the variant
/// name. If the operation already has an inline response for the same status code, that
/// response is kept and the new variant's description is appended as a `- ` bullet line,
/// unless the description is already listed there. A `$ref` response at that status code is
/// left untouched.
#[cfg(feature = "openapi")]
pub fn register_error_response_by_variant<T>(
    _ctx: &mut aide::generate::GenContext,
    operation: &mut Operation,
    variant_path: &str,
) where
    T: ProblemDetailsVariantInfo,
{
    // `rsplit` always yields at least one segment, so the fallback is never taken in practice.
    let variant_name = variant_path.rsplit("::").next().unwrap_or(variant_path);

    // NOTE(review): the variant-specific schema (third tuple element) is not used yet; every
    // variant is documented with the generic ProblemDetails schema.
    let Some((status_code, description, _schema_opt)) = T::get_variant_info(variant_name) else {
        tracing::warn!(
            "Variant '{}' not found in error type '{}' when registering OpenAPI responses",
            variant_name,
            std::any::type_name::<T>()
        );
        return;
    };

    // `variant_name` is the last `::` segment, so it cannot contain `::`; lowercasing suffices.
    // NOTE(review): RFC 7807 gives meaning only to `about:blank` itself, and the built-in
    // constructors emit plain `about:blank`; confirm `about:blank/<variant>` is intended here.
    let problem_type = format!("about:blank/{}", variant_name.to_lowercase());
    let example = serde_json::json!({
        "type": problem_type,
        "title": format!("{} Error", variant_name),
        "status": status_code,
        "detail": format!("{} occurred", variant_name)
    });

    let media_type = MediaType {
        schema: Some(SchemaObject {
            json_schema: problem_details_schema(),
            example: Some(example),
            external_docs: None,
        }),
        ..Default::default()
    };
    let mut content = indexmap::IndexMap::new();
    content.insert("application/problem+json".to_string(), media_type.clone());
    content.insert("application/json".to_string(), media_type);

    let response = Response {
        description,
        content,
        ..Default::default()
    };

    let responses = operation.responses.get_or_insert_with(Default::default);
    let status_code_key = StatusCode::Code(status_code);

    match responses.responses.get_mut(&status_code_key) {
        Some(ReferenceOr::Item(existing)) => {
            // Several variants can share a status code. List each description once, even if
            // the same variant is registered repeatedly (the old whole-string comparison
            // re-appended descriptions that were already present as bullet lines).
            let already_listed = existing.description == response.description
                || existing
                    .description
                    .lines()
                    .any(|line| line.strip_prefix("- ").unwrap_or(line) == response.description);
            if !already_listed {
                existing.description =
                    format!("{}\n- {}", existing.description, response.description);
            }
        }
        // A `$ref` response has no inline description to extend; leave it as-is.
        Some(_) => {}
        None => {
            responses
                .responses
                .insert(status_code_key, ReferenceOr::Item(response));
        }
    }
}
99
100#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
105pub struct ProblemDetails {
106 #[serde(rename = "type")]
108 pub problem_type: String,
109
110 pub title: String,
112
113 pub status: u16,
115
116 #[serde(skip_serializing_if = "Option::is_none")]
118 pub detail: Option<String>,
119
120 #[serde(skip_serializing_if = "Option::is_none")]
122 pub instance: Option<String>,
123
124 #[serde(flatten)]
126 pub extensions: HashMap<String, serde_json::Value>,
127}
128
129impl ProblemDetails {
130 pub fn new(problem_type: impl Into<String>, title: impl Into<String>, status: u16) -> Self {
132 Self {
133 problem_type: problem_type.into(),
134 title: title.into(),
135 status,
136 detail: None,
137 instance: None,
138 extensions: HashMap::new(),
139 }
140 }
141
142 pub fn with_detail(mut self, detail: impl Into<String>) -> Self {
144 self.detail = Some(detail.into());
145 self
146 }
147
148 pub fn with_instance(mut self, instance: impl Into<String>) -> Self {
150 self.instance = Some(instance.into());
151 self
152 }
153
154 pub fn with_extension(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
156 self.extensions.insert(key.into(), value);
157 self
158 }
159
160 pub fn validation_error(detail: impl Into<String>) -> Self {
162 Self::new("about:blank", "Validation Error", 400).with_detail(detail)
163 }
164
165 pub fn authentication_error() -> Self {
167 Self::new("about:blank", "Authentication Required", 401)
168 .with_detail("Authentication credentials are required to access this resource")
169 }
170
171 pub fn authorization_error() -> Self {
173 Self::new("about:blank", "Insufficient Permissions", 403)
174 .with_detail("You don't have permission to access this resource")
175 }
176
177 pub fn not_found(resource: impl Into<String>) -> Self {
179 Self::new("about:blank", "Resource Not Found", 404)
180 .with_detail(format!("The requested {} was not found", resource.into()))
181 }
182
183 pub fn internal_server_error() -> Self {
185 Self::new("about:blank", "Internal Server Error", 500)
186 .with_detail("An unexpected error occurred while processing your request")
187 }
188
189 pub fn service_unavailable() -> Self {
191 Self::new("about:blank", "Service Unavailable", 503)
192 .with_detail("The service is temporarily unavailable")
193 }
194
195 pub fn custom_problem(
197 problem_type: impl Into<String>,
198 title: impl Into<String>,
199 status: u16,
200 ) -> Self {
201 Self::new(problem_type, title, status)
202 }
203}
204
205impl IntoResponse for ProblemDetails {
206 fn into_response(mut self) -> axum::response::Response {
207 let status = axum::http::StatusCode::from_u16(self.status)
208 .unwrap_or(axum::http::StatusCode::INTERNAL_SERVER_ERROR);
209
210 if self.instance.is_none() {
212 if let Some(uri) = get_current_request_uri() {
213 self.instance = Some(uri);
214 }
215 }
216
217 (
219 status,
220 [("content-type", "application/problem+json")],
221 axum::Json(self),
222 )
223 .into_response()
224 }
225}
226
// Task-local slot for the URI of the request being handled. It holds a value only while
// running inside a `CURRENT_REQUEST_URI.scope(..)` future (see
// `capture_request_uri_middleware`); `try_with` fails outside such a scope.
tokio::task_local! {
    static CURRENT_REQUEST_URI: String;
}
231
232fn get_current_request_uri() -> Option<String> {
234 CURRENT_REQUEST_URI.try_with(|uri| uri.clone()).ok()
235}
236
237pub fn set_current_request_uri(uri: String) {
239 CURRENT_REQUEST_URI.scope(uri, async {
240 });
242}
243
244pub async fn capture_request_uri_middleware(
246 req: axum::http::Request<axum::body::Body>,
247 next: axum::middleware::Next,
248) -> axum::response::Response {
249 let uri = req.uri().to_string();
250
251 CURRENT_REQUEST_URI
253 .scope(uri, async move { next.run(req).await })
254 .await
255}
256
257impl ProblemDetails {
259 pub fn status_code(&self) -> axum::http::StatusCode {
260 axum::http::StatusCode::from_u16(self.status)
261 .unwrap_or(axum::http::StatusCode::INTERNAL_SERVER_ERROR)
262 }
263}
264
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_problem_details_creation() {
        let expected_code = serde_json::Value::String("TEST_001".to_string());
        let built = ProblemDetails::new("https://example.com/problems/test", "Test Problem", 400)
            .with_detail("This is a test problem")
            .with_instance("/test/123")
            .with_extension("code", expected_code.clone());

        assert_eq!(built.problem_type, "https://example.com/problems/test");
        assert_eq!(built.title, "Test Problem");
        assert_eq!(built.status, 400);
        assert_eq!(built.detail.as_deref(), Some("This is a test problem"));
        assert_eq!(built.instance.as_deref(), Some("/test/123"));
        assert_eq!(built.extensions.get("code"), Some(&expected_code));
    }

    #[test]
    fn test_validation_error() {
        let built = ProblemDetails::validation_error("Name is required");

        assert_eq!(built.status, 400);
        assert_eq!(built.title, "Validation Error");
        assert_eq!(built.problem_type, "about:blank");
    }

    #[test]
    fn test_into_response() {
        let rendered = ProblemDetails::not_found("user").into_response();

        assert_eq!(rendered.status(), axum::http::StatusCode::NOT_FOUND);
    }

    #[test]
    fn test_status_code() {
        let built = ProblemDetails::validation_error("Test error");

        assert_eq!(built.status_code(), axum::http::StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn test_automatic_uri_capture() {
        let expected = String::from("/test/path");

        CURRENT_REQUEST_URI
            .scope(expected.clone(), async move {
                assert_eq!(get_current_request_uri(), Some(expected));
            })
            .await;
    }
}