use std::collections::BTreeMap;
use http::StatusCode;
use http::header::HeaderValue;
use serde::Deserialize;
use serde::Serialize;
use crate::body::TakoBody;
use crate::responder::Responder;
use crate::types::Response;
/// Media type written into the `Content-Type` header of problem responses.
pub const PROBLEM_JSON: &str = "application/problem+json";
/// A "problem details" error document that serializes to a flat JSON object.
///
/// Every optional member is left out of the serialized output when it is `None`;
/// `status` is always written. Entries of `extensions` are flattened into the
/// same top-level object as the named members.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Problem {
/// Identifier of the problem type; serialized under the JSON key `type`.
#[serde(rename = "type", skip_serializing_if = "Option::is_none")]
pub r#type: Option<String>,
/// Short summary of the problem.
#[serde(skip_serializing_if = "Option::is_none")]
pub title: Option<String>,
/// Numeric HTTP status code associated with the problem.
pub status: u16,
/// Explanation specific to this occurrence of the problem.
#[serde(skip_serializing_if = "Option::is_none")]
pub detail: Option<String>,
/// Identifier of this specific occurrence of the problem.
#[serde(skip_serializing_if = "Option::is_none")]
pub instance: Option<String>,
/// Additional members, emitted next to the named ones rather than nested.
// NOTE(review): a key that matches a named member (e.g. `status`) would
// presumably be emitted twice because of `flatten` — confirm and avoid.
#[serde(flatten)]
pub extensions: BTreeMap<String, serde_json::Value>,
}
impl Problem {
    /// Creates a problem for `status`.
    ///
    /// `title` is taken from the status code's canonical reason phrase and is
    /// `None` when the code has none. All other optional members start empty.
    #[must_use]
    pub fn from_status(status: StatusCode) -> Self {
        Self {
            r#type: None,
            title: status.canonical_reason().map(str::to_string),
            status: status.as_u16(),
            detail: None,
            instance: None,
            extensions: BTreeMap::new(),
        }
    }

    /// Overrides the `title` member, replacing the one derived by
    /// [`Problem::from_status`].
    #[must_use]
    pub fn with_title(mut self, title: impl Into<String>) -> Self {
        self.title = Some(title.into());
        self
    }

    /// Sets the `detail` member.
    #[must_use]
    pub fn with_detail(mut self, detail: impl Into<String>) -> Self {
        self.detail = Some(detail.into());
        self
    }

    /// Sets the `type` member.
    #[must_use]
    pub fn with_type(mut self, type_uri: impl Into<String>) -> Self {
        self.r#type = Some(type_uri.into());
        self
    }

    /// Sets the `instance` member.
    #[must_use]
    pub fn with_instance(mut self, instance: impl Into<String>) -> Self {
        self.instance = Some(instance.into());
        self
    }

    /// Adds an extension member, replacing any previous value stored under `key`.
    ///
    /// NOTE(review): extensions are flattened into the top-level JSON object, so
    /// a `key` equal to a standard member name (`type`, `title`, `status`,
    /// `detail`, `instance`) would presumably be emitted next to that member —
    /// avoid such keys.
    #[must_use]
    pub fn with_extension(
        mut self,
        key: impl Into<String>,
        value: impl Into<serde_json::Value>,
    ) -> Self {
        self.extensions.insert(key.into(), value.into());
        self
    }
}
impl Responder for Problem {
    /// Converts the problem into a response with an `application/problem+json` body.
    ///
    /// When `status` is not a value `StatusCode::from_u16` accepts, the response
    /// falls back to `500 Internal Server Error`. The `status` member of the
    /// body is rewritten to the code actually sent, so the document never
    /// disagrees with the response line.
    fn into_response(mut self) -> Response {
        let status = StatusCode::from_u16(self.status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
        // Keep the serialized `status` member in sync with the real status code,
        // including the fallback case above.
        self.status = status.as_u16();

        // Best effort: an unserializable problem degrades to an empty JSON object.
        let body = serde_json::to_vec(&self).unwrap_or_else(|_| b"{}".to_vec());

        let mut res = Response::new(TakoBody::from(body));
        *res.status_mut() = status;
        res.headers_mut().insert(
            http::header::CONTENT_TYPE,
            HeaderValue::from_static(PROBLEM_JSON),
        );
        res
    }
}
pub fn default_problem_responder(response: Response) -> Response {
let status = response.status();
if !status.is_client_error() && !status.is_server_error() {
return response;
}
if let Some(ct) = response.headers().get(http::header::CONTENT_TYPE)
&& let Ok(s) = ct.to_str()
{
let essence = s
.split(';')
.next()
.unwrap_or("")
.trim()
.to_ascii_lowercase();
if essence == "application/json" || essence == "application/problem+json" {
return response;
}
}
let problem = Problem::from_status(status);
problem.into_response()
}
#[cfg(test)]
mod tests {
    use http::Response as HttpResponse;
    use http_body_util::BodyExt;

    use super::*;

    /// Drains the response body on a throwaway current-thread runtime and
    /// returns it as UTF-8 text.
    fn body_string(resp: Response) -> String {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .build()
            .unwrap();
        let collected = runtime
            .block_on(async move { resp.into_body().collect().await })
            .unwrap();
        String::from_utf8(collected.to_bytes().to_vec()).unwrap()
    }

    /// Builds a response with the given status, `Content-Type` and body.
    fn response_with(
        status: StatusCode,
        content_type: &'static str,
        body: &'static str,
    ) -> Response {
        let mut resp = HttpResponse::new(TakoBody::from(body));
        *resp.status_mut() = status;
        resp.headers_mut().insert(
            http::header::CONTENT_TYPE,
            HeaderValue::from_static(content_type),
        );
        resp
    }

    /// Asserts that `resp` carries exactly the problem+json content type.
    fn assert_problem_content_type(resp: &Response) {
        let actual = resp.headers().get(http::header::CONTENT_TYPE).unwrap();
        let expected = HeaderValue::from_static(PROBLEM_JSON);
        assert_eq!(actual, &expected);
    }

    #[test]
    fn problem_from_status_uses_canonical_reason() {
        let problem = Problem::from_status(StatusCode::NOT_FOUND);

        assert_eq!(problem.status, 404);
        assert_eq!(problem.title.as_deref(), Some("Not Found"));
        assert!(problem.detail.is_none());
        assert!(problem.r#type.is_none());
    }

    #[test]
    fn problem_with_detail_setter() {
        let problem =
            Problem::from_status(StatusCode::BAD_REQUEST).with_detail("missing field 'name'");

        assert_eq!(problem.detail.as_deref(), Some("missing field 'name'"));
    }

    #[test]
    fn problem_with_type_and_instance() {
        let problem = Problem::from_status(StatusCode::CONFLICT)
            .with_type("https://example.com/probs/conflict")
            .with_instance("/orders/42");

        assert_eq!(
            problem.r#type.as_deref(),
            Some("https://example.com/probs/conflict")
        );
        assert_eq!(problem.instance.as_deref(), Some("/orders/42"));
    }

    #[test]
    fn problem_with_extension_round_trips_through_serde() {
        let problem = Problem::from_status(StatusCode::UNPROCESSABLE_ENTITY)
            .with_extension("invalid_params", serde_json::json!(["email", "age"]));

        let json = serde_json::to_string(&problem).unwrap();

        assert!(json.contains(r#""invalid_params":["email","age"]"#));
        assert!(json.contains(r#""status":422"#));
    }

    #[test]
    fn problem_into_response_writes_problem_json_content_type() {
        let resp = Problem::from_status(StatusCode::INTERNAL_SERVER_ERROR).into_response();

        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_problem_content_type(&resp);
    }

    #[test]
    fn problem_into_response_serializes_canonical_fields() {
        let problem = Problem::from_status(StatusCode::NOT_FOUND).with_detail("user 7 missing");

        let json = body_string(problem.into_response());

        assert!(json.contains(r#""title":"Not Found""#));
        assert!(json.contains(r#""status":404"#));
        assert!(json.contains(r#""detail":"user 7 missing""#));
    }

    #[test]
    fn default_problem_responder_replaces_plain_response() {
        let plain = response_with(StatusCode::INTERNAL_SERVER_ERROR, "text/plain", "oops");

        let upgraded = default_problem_responder(plain);

        assert_problem_content_type(&upgraded);
        let json = body_string(upgraded);
        assert!(json.contains(r#""status":500"#));
    }

    #[test]
    fn default_problem_responder_passes_through_existing_json() {
        let original = response_with(
            StatusCode::BAD_REQUEST,
            "application/json",
            r#"{"err":"x"}"#,
        );

        let unchanged = default_problem_responder(original);

        assert_eq!(body_string(unchanged), r#"{"err":"x"}"#);
    }

    #[test]
    fn default_problem_responder_passes_through_problem_json() {
        let original = response_with(StatusCode::IM_A_TEAPOT, PROBLEM_JSON, r#"{"status":418}"#);

        let unchanged = default_problem_responder(original);

        assert_eq!(body_string(unchanged), r#"{"status":418}"#);
    }
}