1use std::fmt;
19
20use serde::{Serialize, Serializer};
21use serde_json::{Map, Value};
22
23#[derive(Debug, Clone)]
27pub struct Problem {
28 pub type_uri: String,
29 pub title: String,
30 pub status: u16,
31 pub detail: Option<String>,
32 pub instance: Option<String>,
33 pub extensions: Map<String, Value>,
34}
35
36impl Serialize for Problem {
37 fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
38 use serde::ser::SerializeMap;
39
40 let extras = self.extensions.len()
41 + self.detail.is_some() as usize
42 + self.instance.is_some() as usize;
43 let mut map = s.serialize_map(Some(3 + extras))?;
44
45 map.serialize_entry("type", &self.type_uri)?;
46 map.serialize_entry("title", &self.title)?;
47 map.serialize_entry("status", &self.status)?;
48
49 if let Some(d) = &self.detail {
50 map.serialize_entry("detail", d)?;
51 }
52 if let Some(i) = &self.instance {
53 map.serialize_entry("instance", i)?;
54 }
55 for (k, v) in &self.extensions {
56 map.serialize_entry(k, v)?;
57 }
58
59 map.end()
60 }
61}
62
63impl fmt::Display for Problem {
64 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65 write!(f, "{} {} — {}", self.status, self.title, self.type_uri)?;
66 if let Some(d) = &self.detail {
67 write!(f, ": {d}")?;
68 }
69 Ok(())
70 }
71}
72
73impl Problem {
76 const BASE: &'static str = "https://docs.rok.rs/errors";
77
78 pub fn new(type_uri: impl Into<String>, title: impl Into<String>, status: u16) -> Self {
80 Self {
81 type_uri: type_uri.into(),
82 title: title.into(),
83 status,
84 detail: None,
85 instance: None,
86 extensions: Map::new(),
87 }
88 }
89
90 pub fn detail(mut self, detail: impl Into<String>) -> Self {
92 self.detail = Some(detail.into());
93 self
94 }
95
96 pub fn instance(mut self, instance: impl Into<String>) -> Self {
98 self.instance = Some(instance.into());
99 self
100 }
101
102 pub fn extend(mut self, key: impl Into<String>, value: impl Into<Value>) -> Self {
104 self.extensions.insert(key.into(), value.into());
105 self
106 }
107
108 pub fn title(mut self, title: impl Into<String>) -> Self {
110 self.title = title.into();
111 self
112 }
113
114 pub fn not_found(detail: impl Into<String>) -> Self {
118 Self::new(format!("{}/not-found", Self::BASE), "Not Found", 404).detail(detail)
119 }
120
121 pub fn bad_request(detail: impl Into<String>) -> Self {
123 Self::new(format!("{}/bad-request", Self::BASE), "Bad Request", 400).detail(detail)
124 }
125
126 pub fn unauthorized(detail: impl Into<String>) -> Self {
128 Self::new(format!("{}/unauthorized", Self::BASE), "Unauthorized", 401).detail(detail)
129 }
130
131 pub fn forbidden(detail: impl Into<String>) -> Self {
133 Self::new(format!("{}/forbidden", Self::BASE), "Forbidden", 403).detail(detail)
134 }
135
136 pub fn conflict(detail: impl Into<String>) -> Self {
138 Self::new(format!("{}/conflict", Self::BASE), "Conflict", 409).detail(detail)
139 }
140
141 pub fn unprocessable(detail: impl Into<String>) -> Self {
143 Self::new(
144 format!("{}/unprocessable-entity", Self::BASE),
145 "Unprocessable Entity",
146 422,
147 )
148 .detail(detail)
149 }
150
151 pub fn too_many_requests(detail: impl Into<String>) -> Self {
153 Self::new(
154 format!("{}/too-many-requests", Self::BASE),
155 "Too Many Requests",
156 429,
157 )
158 .detail(detail)
159 }
160
161 pub fn internal(detail: impl Into<String>) -> Self {
163 Self::new(
164 format!("{}/internal-server-error", Self::BASE),
165 "Internal Server Error",
166 500,
167 )
168 .detail(detail)
169 }
170
171 pub fn service_unavailable(detail: impl Into<String>) -> Self {
173 Self::new(
174 format!("{}/service-unavailable", Self::BASE),
175 "Service Unavailable",
176 503,
177 )
178 .detail(detail)
179 }
180
181 pub fn from_validation(errors: std::collections::HashMap<String, Vec<String>>) -> Self {
185 let field_errors: Value = errors
186 .into_iter()
187 .map(|(k, v)| (k, Value::Array(v.into_iter().map(Value::String).collect())))
188 .collect::<Map<_, _>>()
189 .into();
190
191 Self::unprocessable("One or more fields failed validation.").extend("errors", field_errors)
192 }
193
194 pub fn to_json_bytes(&self) -> Vec<u8> {
201 match serde_json::to_vec(self) {
202 Ok(bytes) => bytes,
203 Err(_e) => {
204 #[cfg(feature = "app")]
205 tracing::error!(error = %_e, "Problem serialization failed");
206 b"{}"[..].to_vec()
207 }
208 }
209 }
210}
211
#[cfg(feature = "axum")]
mod axum_impl {
    use super::Problem;
    use axum::{
        body::Body,
        response::{IntoResponse, Response},
    };
    use http::{header, StatusCode};

    impl IntoResponse for Problem {
        /// Renders the problem as an `application/problem+json` response.
        ///
        /// A `status` that [`StatusCode::from_u16`] rejects is replaced by 500.
        /// The replacement applies to both the response status and the JSON
        /// body's `status` member, so the two always agree; RFC 9457 requires
        /// generators to use the same code in both.
        fn into_response(mut self) -> Response {
            let status =
                StatusCode::from_u16(self.status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
            self.status = status.as_u16();

            // Built directly instead of through `Response::builder()`: every part
            // here is already valid, so there is no error path to unwrap.
            let mut response = Response::new(Body::from(self.to_json_bytes()));
            *response.status_mut() = status;
            response.headers_mut().insert(
                header::CONTENT_TYPE,
                header::HeaderValue::from_static("application/problem+json"),
            );
            response
        }
    }
}
242
#[cfg(test)]
mod tests {
    use super::*;

    /// Published base URI for the built-in problem types.
    ///
    /// Spelled out rather than read from `Problem::BASE`, so that an accidental
    /// change to the constant fails these tests.
    const BASE: &str = "https://docs.rok.rs/errors";

    /// Asserts every member a canned constructor is expected to set.
    fn assert_canned(p: &Problem, status: u16, title: &str, slug: &str, detail: &str) {
        assert_eq!(p.status, status);
        assert_eq!(p.title, title);
        assert_eq!(p.type_uri, format!("{BASE}/{slug}"));
        assert_eq!(p.detail.as_deref(), Some(detail));
        assert!(p.instance.is_none());
        assert!(p.extensions.is_empty());
    }

    #[test]
    fn not_found_shape() {
        let p = Problem::not_found("User 42 does not exist.");
        assert_canned(&p, 404, "Not Found", "not-found", "User 42 does not exist.");
    }

    #[test]
    fn bad_request_shape() {
        let p = Problem::bad_request("Missing required field `email`.");
        assert_canned(
            &p,
            400,
            "Bad Request",
            "bad-request",
            "Missing required field `email`.",
        );
    }

    #[test]
    fn unauthorized_shape() {
        let p = Problem::unauthorized("No valid credentials were supplied.");
        assert_canned(
            &p,
            401,
            "Unauthorized",
            "unauthorized",
            "No valid credentials were supplied.",
        );
    }

    #[test]
    fn forbidden_shape() {
        let p = Problem::forbidden("You may not delete another user's posts.");
        assert_canned(
            &p,
            403,
            "Forbidden",
            "forbidden",
            "You may not delete another user's posts.",
        );
    }

    #[test]
    fn conflict_shape() {
        let p = Problem::conflict("Email address already registered.");
        assert_canned(&p, 409, "Conflict", "conflict", "Email address already registered.");
    }

    #[test]
    fn unprocessable_shape() {
        let p = Problem::unprocessable("Validation failed.");
        assert_canned(
            &p,
            422,
            "Unprocessable Entity",
            "unprocessable-entity",
            "Validation failed.",
        );
    }

    #[test]
    fn too_many_requests_shape() {
        let p = Problem::too_many_requests("Slow down!");
        assert_canned(&p, 429, "Too Many Requests", "too-many-requests", "Slow down!");
    }

    #[test]
    fn internal_shape() {
        let p = Problem::internal("An unexpected error occurred.");
        assert_canned(
            &p,
            500,
            "Internal Server Error",
            "internal-server-error",
            "An unexpected error occurred.",
        );
    }

    #[test]
    fn service_unavailable_shape() {
        let p = Problem::service_unavailable("The service is temporarily offline.");
        assert_canned(
            &p,
            503,
            "Service Unavailable",
            "service-unavailable",
            "The service is temporarily offline.",
        );
    }

    #[test]
    fn custom_problem_with_extensions() {
        let p = Problem::new(
            "https://example.com/errors/quota-exceeded",
            "Quota Exceeded",
            429,
        )
        .detail("You have exceeded your daily upload quota.")
        .extend("limit", 100u64)
        .extend("remaining", 0u64);

        assert_eq!(p.status, 429);
        assert_eq!(p.extensions["limit"], serde_json::json!(100u64));
        assert_eq!(p.extensions["remaining"], serde_json::json!(0u64));
    }

    #[test]
    fn extend_replaces_existing_key() {
        let p = Problem::new("about:blank", "Teapot", 418)
            .extend("retry", 1u64)
            .extend("retry", 2u64);
        assert_eq!(p.extensions.len(), 1);
        assert_eq!(p.extensions["retry"], serde_json::json!(2u64));
    }

    #[test]
    fn instance_field_roundtrip() {
        let p = Problem::not_found("Order not found.").instance("/orders/99");
        assert_eq!(p.instance.as_deref(), Some("/orders/99"));
    }

    #[test]
    fn title_override() {
        let p = Problem::not_found("Custom detail.").title("Resource Not Found");
        assert_eq!(p.title, "Resource Not Found");
    }

    #[test]
    fn serialize_mandatory_fields() {
        let p = Problem::not_found("test");
        let v: serde_json::Value = serde_json::from_slice(&p.to_json_bytes()).unwrap();
        assert!(v.get("type").is_some());
        assert!(v.get("title").is_some());
        assert!(v.get("status").is_some());
    }

    #[test]
    fn serialize_member_order() {
        let p = Problem::new("about:blank", "Teapot", 418);
        let json = String::from_utf8(p.to_json_bytes()).unwrap();
        assert_eq!(json, r#"{"type":"about:blank","title":"Teapot","status":418}"#);
    }

    #[test]
    fn serialize_omits_none_fields() {
        let p = Problem::not_found("test");
        let v: serde_json::Value = serde_json::from_slice(&p.to_json_bytes()).unwrap();
        assert!(v.get("instance").is_none());
    }

    #[test]
    fn serialize_includes_optional_fields() {
        let p = Problem::not_found("test").instance("/foo/1");
        let v: serde_json::Value = serde_json::from_slice(&p.to_json_bytes()).unwrap();
        assert_eq!(v["instance"], "/foo/1");
    }

    #[test]
    fn display_format() {
        let p = Problem::not_found("User not found.");
        let s = p.to_string();
        assert!(s.contains("404"));
        assert!(s.contains("Not Found"));
        assert!(s.contains("User not found."));
    }

    #[test]
    fn display_without_detail() {
        let p = Problem::new("about:blank", "Teapot", 418);
        assert_eq!(p.to_string(), "418 Teapot — about:blank");
    }

    #[test]
    fn from_validation_errors() {
        let mut errors = std::collections::HashMap::new();
        errors.insert("email".to_string(), vec!["is required".to_string()]);
        errors.insert("name".to_string(), vec!["is too short".to_string()]);

        let p = Problem::from_validation(errors);
        assert_eq!(p.status, 422);
        let errors_val = p.extensions.get("errors").expect("errors extension");
        assert_eq!(errors_val["email"], serde_json::json!(["is required"]));
        assert_eq!(errors_val["name"], serde_json::json!(["is too short"]));
    }
}