Skip to main content

rok_core/
problem.rs

1//! RFC 9457 — Problem Details for HTTP APIs.
2//!
3//! # Example
4//!
5//! ```rust
6//! use rok_core::Problem;
7//!
8//! let p = Problem::not_found("User 42 does not exist.");
9//! assert_eq!(p.status, 404);
10//!
11//! let custom = Problem::new("https://example.com/errors/quota-exceeded", "Quota Exceeded", 429)
12//!     .detail("You have exceeded your daily upload quota.")
13//!     .extend("limit", 100)
14//!     .extend("remaining", 0);
15//! assert_eq!(custom.status, 429);
16//! ```
17
18use std::fmt;
19
20use serde::{Serialize, Serializer};
21use serde_json::{Map, Value};
22
23// ── Core struct ───────────────────────────────────────────────────────────────
24
25/// RFC 9457 problem details response.
26#[derive(Debug, Clone)]
27pub struct Problem {
28    pub type_uri: String,
29    pub title: String,
30    pub status: u16,
31    pub detail: Option<String>,
32    pub instance: Option<String>,
33    pub extensions: Map<String, Value>,
34}
35
36impl Serialize for Problem {
37    fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
38        use serde::ser::SerializeMap;
39
40        let extras = self.extensions.len()
41            + self.detail.is_some() as usize
42            + self.instance.is_some() as usize;
43        let mut map = s.serialize_map(Some(3 + extras))?;
44
45        map.serialize_entry("type", &self.type_uri)?;
46        map.serialize_entry("title", &self.title)?;
47        map.serialize_entry("status", &self.status)?;
48
49        if let Some(d) = &self.detail {
50            map.serialize_entry("detail", d)?;
51        }
52        if let Some(i) = &self.instance {
53            map.serialize_entry("instance", i)?;
54        }
55        for (k, v) in &self.extensions {
56            map.serialize_entry(k, v)?;
57        }
58
59        map.end()
60    }
61}
62
63impl fmt::Display for Problem {
64    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65        write!(f, "{} {} — {}", self.status, self.title, self.type_uri)?;
66        if let Some(d) = &self.detail {
67            write!(f, ": {d}")?;
68        }
69        Ok(())
70    }
71}
72
73// ── Builder ───────────────────────────────────────────────────────────────────
74
75impl Problem {
76    const BASE: &'static str = "https://docs.rok.rs/errors";
77
78    /// Create a problem with a fully-qualified `type_uri`.
79    pub fn new(type_uri: impl Into<String>, title: impl Into<String>, status: u16) -> Self {
80        Self {
81            type_uri: type_uri.into(),
82            title: title.into(),
83            status,
84            detail: None,
85            instance: None,
86            extensions: Map::new(),
87        }
88    }
89
90    /// Human-readable explanation of the specific occurrence.
91    pub fn detail(mut self, detail: impl Into<String>) -> Self {
92        self.detail = Some(detail.into());
93        self
94    }
95
96    /// URI reference that identifies the specific occurrence.
97    pub fn instance(mut self, instance: impl Into<String>) -> Self {
98        self.instance = Some(instance.into());
99        self
100    }
101
102    /// Attach a custom extension member (serialisable value).
103    pub fn extend(mut self, key: impl Into<String>, value: impl Into<Value>) -> Self {
104        self.extensions.insert(key.into(), value.into());
105        self
106    }
107
108    /// Override the title (fluent).
109    pub fn title(mut self, title: impl Into<String>) -> Self {
110        self.title = title.into();
111        self
112    }
113
114    // ── Predefined constructors ───────────────────────────────────────────────
115
116    /// 404 Not Found.
117    pub fn not_found(detail: impl Into<String>) -> Self {
118        Self::new(format!("{}/not-found", Self::BASE), "Not Found", 404).detail(detail)
119    }
120
121    /// 400 Bad Request.
122    pub fn bad_request(detail: impl Into<String>) -> Self {
123        Self::new(format!("{}/bad-request", Self::BASE), "Bad Request", 400).detail(detail)
124    }
125
126    /// 401 Unauthorized.
127    pub fn unauthorized(detail: impl Into<String>) -> Self {
128        Self::new(format!("{}/unauthorized", Self::BASE), "Unauthorized", 401).detail(detail)
129    }
130
131    /// 403 Forbidden.
132    pub fn forbidden(detail: impl Into<String>) -> Self {
133        Self::new(format!("{}/forbidden", Self::BASE), "Forbidden", 403).detail(detail)
134    }
135
136    /// 409 Conflict.
137    pub fn conflict(detail: impl Into<String>) -> Self {
138        Self::new(format!("{}/conflict", Self::BASE), "Conflict", 409).detail(detail)
139    }
140
141    /// 422 Unprocessable Entity.
142    pub fn unprocessable(detail: impl Into<String>) -> Self {
143        Self::new(
144            format!("{}/unprocessable-entity", Self::BASE),
145            "Unprocessable Entity",
146            422,
147        )
148        .detail(detail)
149    }
150
151    /// 429 Too Many Requests.
152    pub fn too_many_requests(detail: impl Into<String>) -> Self {
153        Self::new(
154            format!("{}/too-many-requests", Self::BASE),
155            "Too Many Requests",
156            429,
157        )
158        .detail(detail)
159    }
160
161    /// 500 Internal Server Error.
162    pub fn internal(detail: impl Into<String>) -> Self {
163        Self::new(
164            format!("{}/internal-server-error", Self::BASE),
165            "Internal Server Error",
166            500,
167        )
168        .detail(detail)
169    }
170
171    /// 503 Service Unavailable.
172    pub fn service_unavailable(detail: impl Into<String>) -> Self {
173        Self::new(
174            format!("{}/service-unavailable", Self::BASE),
175            "Service Unavailable",
176            503,
177        )
178        .detail(detail)
179    }
180
181    // ── Validation integration ────────────────────────────────────────────────
182
183    /// Build an unprocessable-entity problem from a map of field → errors.
184    pub fn from_validation(errors: std::collections::HashMap<String, Vec<String>>) -> Self {
185        let field_errors: Value = errors
186            .into_iter()
187            .map(|(k, v)| (k, Value::Array(v.into_iter().map(Value::String).collect())))
188            .collect::<Map<_, _>>()
189            .into();
190
191        Self::unprocessable("One or more fields failed validation.").extend("errors", field_errors)
192    }
193
194    // ── Serialisation helpers ─────────────────────────────────────────────────
195
196    /// Serialize to a JSON byte vec.
197    ///
198    /// On serialization failure, logs the error and returns `b"{}"` as a
199    /// minimal RFC 9457-compliant fallback.
200    pub fn to_json_bytes(&self) -> Vec<u8> {
201        match serde_json::to_vec(self) {
202            Ok(bytes) => bytes,
203            Err(_e) => {
204                #[cfg(feature = "app")]
205                tracing::error!(error = %_e, "Problem serialization failed");
206                b"{}"[..].to_vec()
207            }
208        }
209    }
210}
211
212// ── axum IntoResponse ─────────────────────────────────────────────────────────
213
#[cfg(feature = "axum")]
mod axum_impl {
    use super::Problem;
    use axum::{
        body::Body,
        response::{IntoResponse, Response},
    };
    use http::{
        header::{self, HeaderValue},
        StatusCode,
    };

    impl IntoResponse for Problem {
        /// Turns the problem into an `application/problem+json` response whose
        /// HTTP status is `self.status`.
        ///
        /// When `self.status` is not a value `StatusCode::from_u16` accepts,
        /// the response is sent as 500 and the serialized `status` member is
        /// rewritten to 500 as well, so that the body never advertises a
        /// status different from the one actually used.
        fn into_response(mut self) -> Response {
            let status = match StatusCode::from_u16(self.status) {
                Ok(code) => code,
                Err(_) => {
                    self.status = StatusCode::INTERNAL_SERVER_ERROR.as_u16();
                    StatusCode::INTERNAL_SERVER_ERROR
                }
            };

            // Assemble the response directly instead of going through the
            // fallible builder: every part is already validated, so no error
            // path (and no `unwrap`) is needed.
            let mut response = Response::new(Body::from(self.to_json_bytes()));
            *response.status_mut() = status;
            response.headers_mut().insert(
                header::CONTENT_TYPE,
                HeaderValue::from_static("application/problem+json"),
            );
            response
        }
    }
}
242
243// ── Tests ─────────────────────────────────────────────────────────────────────
244
#[cfg(test)]
mod tests {
    use super::*;

    /// Documentation root the predefined constructors are expected to use.
    const DOCS_ROOT: &str = "https://docs.rok.rs/errors";

    /// Checks the status code and title of a problem in one go.
    fn assert_shape(problem: &Problem, status: u16, title: &str) {
        assert_eq!(problem.status, status);
        assert_eq!(problem.title, title);
    }

    /// Serializes with `to_json_bytes` and parses the bytes back into JSON.
    fn to_json(problem: &Problem) -> serde_json::Value {
        serde_json::from_slice(&problem.to_json_bytes()).unwrap()
    }

    #[test]
    fn not_found_shape() {
        let detail = "User 42 does not exist.";
        let problem = Problem::not_found(detail);

        assert_shape(&problem, 404, "Not Found");
        assert_eq!(problem.type_uri, format!("{DOCS_ROOT}/not-found"));
        assert_eq!(problem.detail.as_deref(), Some(detail));
    }

    #[test]
    fn bad_request_shape() {
        let problem = Problem::bad_request("Missing required field `email`.");
        assert_shape(&problem, 400, "Bad Request");
    }

    #[test]
    fn unauthorized_shape() {
        let problem = Problem::unauthorized("No valid credentials were supplied.");
        assert_shape(&problem, 401, "Unauthorized");
    }

    #[test]
    fn forbidden_shape() {
        let problem = Problem::forbidden("You may not delete another user's posts.");
        assert_shape(&problem, 403, "Forbidden");
    }

    #[test]
    fn conflict_shape() {
        let problem = Problem::conflict("Email address already registered.");
        assert_shape(&problem, 409, "Conflict");
    }

    #[test]
    fn unprocessable_shape() {
        let problem = Problem::unprocessable("Validation failed.");
        assert_shape(&problem, 422, "Unprocessable Entity");
    }

    #[test]
    fn too_many_requests_shape() {
        let problem = Problem::too_many_requests("Slow down!");
        assert_shape(&problem, 429, "Too Many Requests");
    }

    #[test]
    fn internal_shape() {
        let problem = Problem::internal("An unexpected error occurred.");
        assert_shape(&problem, 500, "Internal Server Error");
    }

    #[test]
    fn service_unavailable_shape() {
        let problem = Problem::service_unavailable("The service is temporarily offline.");
        assert_shape(&problem, 503, "Service Unavailable");
    }

    #[test]
    fn custom_problem_with_extensions() {
        let quota = Problem::new(
            "https://example.com/errors/quota-exceeded",
            "Quota Exceeded",
            429,
        )
        .detail("You have exceeded your daily upload quota.")
        .extend("limit", 100u64)
        .extend("remaining", 0u64);

        assert_eq!(quota.status, 429);
        assert_eq!(
            quota.extensions.get("limit"),
            Some(&serde_json::json!(100u64))
        );
        assert_eq!(
            quota.extensions.get("remaining"),
            Some(&serde_json::json!(0u64))
        );
    }

    #[test]
    fn instance_field_roundtrip() {
        let problem = Problem::not_found("Order not found.").instance("/orders/99");
        assert_eq!(problem.instance.as_deref(), Some("/orders/99"));
    }

    #[test]
    fn title_override() {
        let problem = Problem::not_found("Custom detail.").title("Resource Not Found");
        assert_eq!(problem.title, "Resource Not Found");
    }

    #[test]
    fn serialize_mandatory_fields() {
        let json = to_json(&Problem::not_found("test"));
        for member in ["type", "title", "status"] {
            assert!(json.get(member).is_some());
        }
    }

    #[test]
    fn serialize_omits_none_fields() {
        let json = to_json(&Problem::not_found("test"));
        assert!(json.get("instance").is_none());
    }

    #[test]
    fn serialize_includes_optional_fields() {
        let json = to_json(&Problem::not_found("test").instance("/foo/1"));
        assert_eq!(
            json.get("instance").and_then(serde_json::Value::as_str),
            Some("/foo/1")
        );
    }

    #[test]
    fn display_format() {
        let rendered = Problem::not_found("User not found.").to_string();
        for needle in ["404", "Not Found", "User not found."] {
            assert!(rendered.contains(needle));
        }
    }

    #[test]
    fn from_validation_errors() {
        let errors = std::collections::HashMap::from([
            ("email".to_string(), vec!["is required".to_string()]),
            ("name".to_string(), vec!["is too short".to_string()]),
        ]);

        let problem = Problem::from_validation(errors);
        assert_eq!(problem.status, 422);

        let by_field = problem.extensions.get("errors").expect("errors extension");
        for field in ["email", "name"] {
            assert!(by_field.get(field).is_some());
        }
    }
}