use std::fmt;
use serde::{Serialize, Serializer};
use serde_json::{Map, Value};
/// An RFC 9457 ("Problem Details for HTTP APIs") error document.
///
/// Serialized by the hand-written [`Serialize`] impl as a flat object: the standard members
/// first, followed by every entry of [`Problem::extensions`].
#[derive(Debug, Clone, PartialEq)]
pub struct Problem {
    /// URI identifying the problem type; serialized as the `type` member.
    pub type_uri: String,
    /// Short, human-readable summary of the problem type.
    pub title: String,
    /// HTTP status code associated with this occurrence.
    pub status: u16,
    /// Human-readable explanation of this occurrence; omitted from the output when `None`.
    pub detail: Option<String>,
    /// Reference identifying this specific occurrence; omitted from the output when `None`.
    pub instance: Option<String>,
    /// Additional members serialized alongside the standard ones.
    pub extensions: Map<String, Value>,
}
impl Serialize for Problem {
    /// Serializes the problem as a flat map.
    ///
    /// `type`, `title` and `status` are always written; `detail` and `instance` are written
    /// only when set; then every entry of `extensions` follows. An extension whose key names
    /// a standard member that is already being written is skipped. Otherwise the output
    /// would contain a duplicate key (e.g. two `status` members), and many JSON parsers
    /// (serde_json included) keep the last occurrence. That would let an extension silently
    /// override the typed field.
    fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
        use serde::ser::SerializeMap;

        // True when `key` collides with a standard member emitted below.
        let shadows_member = |key: &str| match key {
            "type" | "title" | "status" => true,
            "detail" => self.detail.is_some(),
            "instance" => self.instance.is_some(),
            _ => false,
        };

        // Exact entry count: length-prefixed formats rely on this hint being accurate.
        let len = 3
            + usize::from(self.detail.is_some())
            + usize::from(self.instance.is_some())
            + self
                .extensions
                .keys()
                .filter(|key| !shadows_member(key.as_str()))
                .count();

        let mut map = s.serialize_map(Some(len))?;
        map.serialize_entry("type", &self.type_uri)?;
        map.serialize_entry("title", &self.title)?;
        map.serialize_entry("status", &self.status)?;
        if let Some(detail) = &self.detail {
            map.serialize_entry("detail", detail)?;
        }
        if let Some(instance) = &self.instance {
            map.serialize_entry("instance", instance)?;
        }
        for (key, value) in &self.extensions {
            if !shadows_member(key.as_str()) {
                map.serialize_entry(key, value)?;
            }
        }
        map.end()
    }
}
impl fmt::Display for Problem {
    /// Renders `"<status> <title> — <type>"`, with `": <detail>"` appended when a detail is set.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            status,
            title,
            type_uri,
            detail,
            ..
        } = self;
        write!(f, "{status} {title} — {type_uri}")?;
        match detail {
            Some(text) => write!(f, ": {text}"),
            None => Ok(()),
        }
    }
}
impl Problem {
    /// Base URI under which the built-in problem `type`s are published.
    const BASE: &'static str = "https://docs.rok.rs/errors";

    /// Creates a problem with the three mandatory members and no `detail`, `instance`
    /// or extensions.
    #[must_use]
    pub fn new(type_uri: impl Into<String>, title: impl Into<String>, status: u16) -> Self {
        Self {
            type_uri: type_uri.into(),
            title: title.into(),
            status,
            detail: None,
            instance: None,
            extensions: Map::new(),
        }
    }

    /// Sets the `detail` explanation for this occurrence, replacing any previous one.
    #[must_use]
    pub fn detail(mut self, detail: impl Into<String>) -> Self {
        self.detail = Some(detail.into());
        self
    }

    /// Sets the `instance` reference for this occurrence, replacing any previous one.
    #[must_use]
    pub fn instance(mut self, instance: impl Into<String>) -> Self {
        self.instance = Some(instance.into());
        self
    }

    /// Adds an extension member, replacing any existing entry with the same key.
    #[must_use]
    pub fn extend(mut self, key: impl Into<String>, value: impl Into<Value>) -> Self {
        self.extensions.insert(key.into(), value.into());
        self
    }

    /// Replaces the `title`, e.g. to give a built-in problem a more specific summary.
    #[must_use]
    pub fn title(mut self, title: impl Into<String>) -> Self {
        self.title = title.into();
        self
    }

    /// Shared constructor for the built-in problems below: `type` becomes
    /// `{BASE}/{slug}` and `detail` is always set.
    fn standard(slug: &str, title: &str, status: u16, detail: impl Into<String>) -> Self {
        Self::new(format!("{}/{slug}", Self::BASE), title, status).detail(detail)
    }

    /// `404 Not Found` with the given detail.
    #[must_use]
    pub fn not_found(detail: impl Into<String>) -> Self {
        Self::standard("not-found", "Not Found", 404, detail)
    }

    /// `400 Bad Request` with the given detail.
    #[must_use]
    pub fn bad_request(detail: impl Into<String>) -> Self {
        Self::standard("bad-request", "Bad Request", 400, detail)
    }

    /// `401 Unauthorized` with the given detail.
    #[must_use]
    pub fn unauthorized(detail: impl Into<String>) -> Self {
        Self::standard("unauthorized", "Unauthorized", 401, detail)
    }

    /// `403 Forbidden` with the given detail.
    #[must_use]
    pub fn forbidden(detail: impl Into<String>) -> Self {
        Self::standard("forbidden", "Forbidden", 403, detail)
    }

    /// `409 Conflict` with the given detail.
    #[must_use]
    pub fn conflict(detail: impl Into<String>) -> Self {
        Self::standard("conflict", "Conflict", 409, detail)
    }

    /// `422 Unprocessable Entity` with the given detail.
    #[must_use]
    pub fn unprocessable(detail: impl Into<String>) -> Self {
        Self::standard("unprocessable-entity", "Unprocessable Entity", 422, detail)
    }

    /// `429 Too Many Requests` with the given detail.
    #[must_use]
    pub fn too_many_requests(detail: impl Into<String>) -> Self {
        Self::standard("too-many-requests", "Too Many Requests", 429, detail)
    }

    /// `500 Internal Server Error` with the given detail.
    #[must_use]
    pub fn internal(detail: impl Into<String>) -> Self {
        Self::standard("internal-server-error", "Internal Server Error", 500, detail)
    }

    /// `503 Service Unavailable` with the given detail.
    #[must_use]
    pub fn service_unavailable(detail: impl Into<String>) -> Self {
        Self::standard("service-unavailable", "Service Unavailable", 503, detail)
    }

    /// Builds a `422 Unprocessable Entity` problem from per-field validation messages.
    ///
    /// `errors` maps each field name to its messages. It is attached as the `errors`
    /// extension: a JSON object whose values are arrays of strings.
    #[must_use]
    pub fn from_validation(errors: std::collections::HashMap<String, Vec<String>>) -> Self {
        let field_errors: Value = errors
            .into_iter()
            .map(|(field, messages)| {
                (field, Value::Array(messages.into_iter().map(Value::String).collect()))
            })
            .collect::<Map<_, _>>()
            .into();
        Self::unprocessable("One or more fields failed validation.").extend("errors", field_errors)
    }

    /// Serializes the problem to JSON bytes.
    ///
    /// Never panics. If serialization fails, the error is logged (only when the `app`
    /// feature is enabled) and the bytes of `{}` are returned instead.
    pub fn to_json_bytes(&self) -> Vec<u8> {
        serde_json::to_vec(self).unwrap_or_else(|_e| {
            #[cfg(feature = "app")]
            tracing::error!(error = %_e, "Problem serialization failed");
            b"{}".to_vec()
        })
    }
}
#[cfg(feature = "axum")]
mod axum_impl {
    use super::Problem;
    use axum::{
        body::Body,
        response::{IntoResponse, Response},
    };
    use http::{header, HeaderValue, StatusCode};

    impl IntoResponse for Problem {
        /// Converts the problem into an `application/problem+json` response.
        ///
        /// If `status` is not accepted by [`StatusCode::from_u16`], the response is sent as
        /// 500 and the body's `status` member is rewritten to 500 as well. RFC 9457 requires
        /// the member to match the status of the actual HTTP response.
        fn into_response(mut self) -> Response {
            let status = match StatusCode::from_u16(self.status) {
                Ok(code) => code,
                Err(_) => {
                    let fallback = StatusCode::INTERNAL_SERVER_ERROR;
                    self.status = fallback.as_u16();
                    fallback
                }
            };

            // Built field by field: every input is already valid, so nothing here can fail
            // (unlike `Response::builder()`, which needs an unwrap).
            let mut response = Response::new(Body::from(self.to_json_bytes()));
            *response.status_mut() = status;
            response.headers_mut().insert(
                header::CONTENT_TYPE,
                HeaderValue::from_static("application/problem+json"),
            );
            response
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deliberately hard-coded copy of the base URI. A change to the published `type`
    /// URIs then fails these tests instead of being silently followed.
    fn base() -> &'static str {
        "https://docs.rok.rs/errors"
    }

    /// Checks every field a built-in constructor is expected to set.
    fn assert_standard(p: &Problem, status: u16, title: &str, slug: &str, detail: &str) {
        assert_eq!(p.status, status);
        assert_eq!(p.title, title);
        assert_eq!(p.type_uri, format!("{}/{slug}", base()));
        assert_eq!(p.detail.as_deref(), Some(detail));
        assert!(p.instance.is_none());
        assert!(p.extensions.is_empty());
    }

    /// Serializes `p` and parses the bytes back into a generic JSON value.
    fn to_json(p: &Problem) -> serde_json::Value {
        serde_json::from_slice(&p.to_json_bytes()).expect("problem serializes to valid JSON")
    }

    #[test]
    fn not_found_shape() {
        let p = Problem::not_found("User 42 does not exist.");
        assert_standard(&p, 404, "Not Found", "not-found", "User 42 does not exist.");
    }

    #[test]
    fn bad_request_shape() {
        let p = Problem::bad_request("Missing required field `email`.");
        assert_standard(&p, 400, "Bad Request", "bad-request", "Missing required field `email`.");
    }

    #[test]
    fn unauthorized_shape() {
        let p = Problem::unauthorized("No valid credentials were supplied.");
        assert_standard(&p, 401, "Unauthorized", "unauthorized", "No valid credentials were supplied.");
    }

    #[test]
    fn forbidden_shape() {
        let p = Problem::forbidden("You may not delete another user's posts.");
        assert_standard(&p, 403, "Forbidden", "forbidden", "You may not delete another user's posts.");
    }

    #[test]
    fn conflict_shape() {
        let p = Problem::conflict("Email address already registered.");
        assert_standard(&p, 409, "Conflict", "conflict", "Email address already registered.");
    }

    #[test]
    fn unprocessable_shape() {
        let p = Problem::unprocessable("Validation failed.");
        assert_standard(&p, 422, "Unprocessable Entity", "unprocessable-entity", "Validation failed.");
    }

    #[test]
    fn too_many_requests_shape() {
        let p = Problem::too_many_requests("Slow down!");
        assert_standard(&p, 429, "Too Many Requests", "too-many-requests", "Slow down!");
    }

    #[test]
    fn internal_shape() {
        let p = Problem::internal("An unexpected error occurred.");
        assert_standard(
            &p,
            500,
            "Internal Server Error",
            "internal-server-error",
            "An unexpected error occurred.",
        );
    }

    #[test]
    fn service_unavailable_shape() {
        let p = Problem::service_unavailable("The service is temporarily offline.");
        assert_standard(
            &p,
            503,
            "Service Unavailable",
            "service-unavailable",
            "The service is temporarily offline.",
        );
    }

    #[test]
    fn custom_problem_with_extensions() {
        let p = Problem::new(
            "https://example.com/errors/quota-exceeded",
            "Quota Exceeded",
            429,
        )
        .detail("You have exceeded your daily upload quota.")
        .extend("limit", 100u64)
        .extend("remaining", 0u64);
        assert_eq!(p.status, 429);
        assert_eq!(p.extensions["limit"], serde_json::json!(100u64));
        assert_eq!(p.extensions["remaining"], serde_json::json!(0u64));
    }

    #[test]
    fn serialize_includes_extensions() {
        let p = Problem::not_found("test").extend("limit", 100u64);
        let v = to_json(&p);
        assert_eq!(v["limit"], serde_json::json!(100u64));
    }

    #[test]
    fn instance_field_roundtrip() {
        let p = Problem::not_found("Order not found.").instance("/orders/99");
        assert_eq!(p.instance.as_deref(), Some("/orders/99"));
    }

    #[test]
    fn title_override() {
        let p = Problem::not_found("Custom detail.").title("Resource Not Found");
        assert_eq!(p.title, "Resource Not Found");
    }

    #[test]
    fn serialize_mandatory_fields() {
        let v = to_json(&Problem::not_found("test"));
        assert!(v.get("type").is_some());
        assert!(v.get("title").is_some());
        assert!(v.get("status").is_some());
    }

    #[test]
    fn serialize_omits_none_fields() {
        let v = to_json(&Problem::not_found("test"));
        assert!(v.get("instance").is_none());
    }

    #[test]
    fn serialize_omits_unset_detail() {
        let v = to_json(&Problem::new("about:blank", "Teapot", 418));
        assert!(v.get("detail").is_none());
        assert!(v.get("instance").is_none());
        assert_eq!(v["status"], serde_json::json!(418));
    }

    #[test]
    fn serialize_includes_optional_fields() {
        let v = to_json(&Problem::not_found("test").instance("/foo/1"));
        assert_eq!(v["instance"], "/foo/1");
    }

    #[test]
    fn display_format() {
        let s = Problem::not_found("User not found.").to_string();
        assert!(s.contains("404"));
        assert!(s.contains("Not Found"));
        assert!(s.contains("User not found."));
    }

    #[test]
    fn display_without_detail() {
        let s = Problem::new("about:blank", "Teapot", 418).to_string();
        assert_eq!(s, "418 Teapot — about:blank");
    }

    #[test]
    fn from_validation_errors() {
        let mut errors = std::collections::HashMap::new();
        errors.insert("email".to_string(), vec!["is required".to_string()]);
        errors.insert("name".to_string(), vec!["is too short".to_string()]);
        let p = Problem::from_validation(errors);
        assert_eq!(p.status, 422);
        assert_eq!(
            p.detail.as_deref(),
            Some("One or more fields failed validation.")
        );
        let errors_val = p.extensions.get("errors").expect("errors extension");
        assert_eq!(errors_val["email"], serde_json::json!(["is required"]));
        assert_eq!(errors_val["name"], serde_json::json!(["is too short"]));
    }
}