pub use api_bones::error::{ApiError, ErrorCode, ProblemJson, ValidationError};
use axum::response::{IntoResponse, Response};
/// Error type returned from HTTP handlers.
///
/// A thin newtype around [`ApiError`]. Turning it into a response renders the
/// wrapped error through [`ProblemJson`].
#[derive(Debug)]
pub struct HandlerError(pub ApiError);

impl HandlerError {
    /// Creates an error from a machine-readable `code` and a human-readable `detail`.
    pub fn new(code: ErrorCode, detail: impl Into<String>) -> Self {
        let inner = ApiError::new(code, detail);
        Self(inner)
    }

    /// Attaches the request id `id` to the wrapped [`ApiError`].
    pub fn with_request_id(self, id: uuid::Uuid) -> Self {
        Self(self.0.with_request_id(id))
    }

    /// Attaches per-field validation `errors` to the wrapped [`ApiError`].
    pub fn with_errors(self, errors: Vec<ValidationError>) -> Self {
        Self(self.0.with_errors(errors))
    }

    /// Maps a `sqlx` error onto a client-safe [`HandlerError`].
    ///
    /// * `RowNotFound` becomes `ResourceNotFound`.
    /// * A database error whose code is `"23505"` becomes `ResourceAlreadyExists`.
    ///   That is PostgreSQL's SQLSTATE for a unique-constraint violation.
    /// * Anything else is logged via `tracing` and becomes a generic
    ///   `InternalServerError`, so driver details never reach the client.
    ///
    /// NOTE(review): only the PostgreSQL code is recognised. Other backends
    /// report unique violations with different codes; confirm which databases
    /// are in use.
    #[cfg(feature = "database")]
    pub fn from_sqlx(err: &sqlx::Error) -> Self {
        match err {
            sqlx::Error::RowNotFound => {
                Self::new(ErrorCode::ResourceNotFound, "resource not found")
            }
            sqlx::Error::Database(db_err) if db_err.code().as_deref() == Some("23505") => {
                Self::new(ErrorCode::ResourceAlreadyExists, "resource already exists")
            }
            // Every other database or driver failure is opaque to the client.
            _ => {
                tracing::error!(error = %err, "database error");
                Self::new(ErrorCode::InternalServerError, "internal server error")
            }
        }
    }
}
impl From<ApiError> for HandlerError {
fn from(e: ApiError) -> Self {
Self(e)
}
}
/// Renders the wrapped [`ApiError`] as a problem response via [`ProblemJson`].
impl IntoResponse for HandlerError {
    fn into_response(self) -> Response {
        let Self(api_error) = self;
        ProblemJson::from(api_error).into_response()
    }
}
/// An already-built axum [`Response`] passed through untouched.
///
/// Use this as an escape hatch when a handler must return something that
/// does not fit the enveloped response aliases.
pub struct UnconstrainedResponse(Response);

impl UnconstrainedResponse {
    /// Eagerly converts `r` into a [`Response`] and stores it.
    pub fn new(r: impl IntoResponse) -> Self {
        let response = r.into_response();
        Self(response)
    }
}

impl IntoResponse for UnconstrainedResponse {
    /// Returns the stored response unchanged.
    fn into_response(self) -> Response {
        let Self(response) = self;
        response
    }
}
#[cfg(feature = "rfc-types")]
pub use rfc_ok::RfcOk;
#[cfg(feature = "rfc-types")]
mod rfc_ok {
    use axum::{
        body::Body,
        http::{HeaderMap, HeaderValue, StatusCode, header},
        response::{IntoResponse, Response},
    };
    use std::marker::PhantomData;

    /// A successful response whose JSON body has already been serialized.
    ///
    /// `T` is the type that was serialized into `body`. It is tracked only at
    /// the type level, presumably so handler signatures still name the payload
    /// type. `PhantomData<fn() -> T>` keeps the struct covariant in `T` and
    /// does not tie the auto traits (`Send`/`Sync`) to `T`.
    pub struct RfcOk<T> {
        pub(super) status: StatusCode,
        pub(super) headers: HeaderMap,
        pub(super) body: Vec<u8>,
        pub(super) _data: PhantomData<fn() -> T>,
    }

    // No `unsafe impl Send/Sync` is needed. Every field (`StatusCode`,
    // `HeaderMap`, `Vec<u8>`, `PhantomData<fn() -> T>`) is `Send + Sync` for
    // all `T`, so the compiler derives both auto traits. This compile-time
    // check uses a deliberately non-`Send` payload type to keep that
    // guarantee honest.
    const _: fn() = || {
        fn assert_send_sync<X: Send + Sync>() {}
        assert_send_sync::<RfcOk<std::rc::Rc<()>>>();
    };

    impl<T> RfcOk<T> {
        /// Creates a response from its parts. `body` must be valid JSON,
        /// because [`RfcOk::body_json`] relies on it.
        pub(super) fn new(status: StatusCode, headers: HeaderMap, body: Vec<u8>) -> Self {
            Self {
                status,
                headers,
                body,
                _data: PhantomData,
            }
        }

        /// Status code the response will be sent with.
        pub fn status(&self) -> StatusCode {
            self.status
        }

        /// Headers the response will be sent with (`Content-Type` is added on send).
        pub fn headers(&self) -> &HeaderMap {
            &self.headers
        }

        /// Parses the serialized body back into a JSON value.
        ///
        /// # Panics
        ///
        /// Panics if the body is not valid JSON. The constructor is only
        /// reachable from this crate's builders, which serialize the body
        /// with `serde_json`.
        pub fn body_json(&self) -> serde_json::Value {
            serde_json::from_slice(&self.body).expect("body is always valid JSON")
        }
    }

    impl<T> IntoResponse for RfcOk<T> {
        /// Sends the stored status, headers and body. `Content-Type` is always
        /// forced to `application/json`, overriding any value in `headers`.
        fn into_response(self) -> Response {
            let mut resp = Response::new(Body::from(self.body));
            *resp.status_mut() = self.status;
            *resp.headers_mut() = self.headers;
            resp.headers_mut().insert(
                header::CONTENT_TYPE,
                HeaderValue::from_static("application/json"),
            );
            resp
        }
    }
}
#[cfg(feature = "rfc-types")]
mod type_aliases {
    //! Handler return types used when the `rfc-types` feature is enabled.
    //! Every success variant is an [`RfcOk`] carrying a pre-serialized body.

    use super::{HandlerError, RfcOk};

    /// Return type of `ok`.
    pub type HandlerResponse<T> = Result<RfcOk<T>, HandlerError>;
    /// Return type of `listed` / `listed_page`.
    pub type HandlerListResponse<T> =
        Result<RfcOk<api_bones::PaginatedResponse<T>>, HandlerError>;
    /// Return type of `created`.
    pub type CreatedResponse<T> = Result<RfcOk<T>, HandlerError>;
    /// Return type of `created_at` / `created_under`.
    pub type CreatedAtResponse<T> = Result<RfcOk<T>, HandlerError>;
    /// Return type of `etagged`.
    pub type EtaggedHandlerResponse<T> = Result<RfcOk<T>, HandlerError>;
}
#[cfg(not(feature = "rfc-types"))]
mod type_aliases {
use super::HandlerError;
pub type HandlerResponse<T> = Result<
(
axum::http::StatusCode,
axum::Json<api_bones::ApiResponse<T>>,
),
HandlerError,
>;
pub type HandlerListResponse<T> =
Result<axum::Json<api_bones::ApiResponse<api_bones::PaginatedResponse<T>>>, HandlerError>;
pub type CreatedResponse<T> = Result<
(
axum::http::StatusCode,
axum::Json<api_bones::ApiResponse<T>>,
),
HandlerError,
>;
pub type CreatedAtResponse<T> = Result<
(
axum::http::StatusCode,
axum::http::HeaderMap,
axum::Json<api_bones::ApiResponse<T>>,
),
HandlerError,
>;
pub type EtaggedHandlerResponse<T> = Result<
(
axum::http::StatusCode,
api_bones::etag::ETag,
axum::Json<api_bones::ApiResponse<T>>,
),
HandlerError,
>;
}
pub use type_aliases::{
CreatedAtResponse, CreatedResponse, EtaggedHandlerResponse, HandlerListResponse,
HandlerResponse,
};
/// Builds a `201 Created` response with `value` wrapped in the `ApiResponse` envelope.
#[cfg(feature = "rfc-types")]
pub fn created<T: serde::Serialize>(value: T) -> CreatedResponse<T> {
    let envelope = api_bones::ApiResponse::builder(value).build();
    let body = serde_json::to_vec(&envelope).expect("ApiResponse<T> is always serializable");
    let status = axum::http::StatusCode::CREATED;
    Ok(RfcOk::new(status, axum::http::HeaderMap::new(), body))
}
/// Builds a `201 Created` response with `value` wrapped in the `ApiResponse` envelope.
#[cfg(not(feature = "rfc-types"))]
pub fn created<T>(value: T) -> CreatedResponse<T> {
    let envelope = axum::Json(api_bones::ApiResponse::builder(value).build());
    Ok((axum::http::StatusCode::CREATED, envelope))
}
/// Builds a `201 Created` response with a `Location` header and `value`
/// wrapped in the `ApiResponse` envelope.
///
/// # Errors
///
/// Returns an `InternalServerError` [`HandlerError`] when `location` is not a
/// valid HTTP header value (e.g. it contains `\r`, `\n` or other control
/// characters). The cause is logged rather than panicking. `location` is
/// often built from caller-influenced data such as resource ids.
///
/// # Panics
///
/// Only if serializing the envelope fails, which the builders treat as
/// impossible for `T: Serialize`.
#[cfg(feature = "rfc-types")]
pub fn created_at<T: serde::Serialize>(location: &str, value: T) -> CreatedAtResponse<T> {
    let location_value = axum::http::HeaderValue::from_str(location).map_err(|err| {
        tracing::error!(error = %err, location = %location, "invalid Location header value");
        HandlerError::new(ErrorCode::InternalServerError, "internal server error")
    })?;
    let mut headers = axum::http::HeaderMap::new();
    headers.insert(axum::http::header::LOCATION, location_value);
    let body = serde_json::to_vec(&api_bones::ApiResponse::builder(value).build())
        .expect("ApiResponse<T> is always serializable");
    Ok(RfcOk::new(axum::http::StatusCode::CREATED, headers, body))
}
#[cfg(not(feature = "rfc-types"))]
pub fn created_at<T>(location: &str, value: T) -> CreatedAtResponse<T> {
let mut headers = axum::http::HeaderMap::new();
headers.insert(
axum::http::header::LOCATION,
location.parse().expect("valid Location URI"),
);
Ok((
axum::http::StatusCode::CREATED,
headers,
axum::Json(api_bones::ApiResponse::builder(value).build()),
))
}
/// Like `created_at`, with `Location` set to `{prefix}/{id}`.
///
/// Trailing slashes on `prefix` are stripped so exactly one `/` separates the
/// prefix from the id. NOTE(review): the id is inserted via its `Display`
/// output without percent-encoding. Ids containing `/`, `?` or `#` produce a
/// misleading URI; confirm ids are URL-safe.
#[cfg(feature = "rfc-types")]
pub fn created_under<T: api_bones::HasId + serde::Serialize>(
    prefix: &str,
    value: T,
) -> CreatedAtResponse<T> {
    let base = prefix.trim_end_matches('/');
    let location = format!("{base}/{}", value.id());
    created_at(&location, value)
}
/// Like `created_at`, with `Location` set to `{prefix}/{id}`.
///
/// Trailing slashes on `prefix` are stripped so exactly one `/` separates the
/// prefix from the id. NOTE(review): the id is inserted via its `Display`
/// output without percent-encoding; confirm ids are URL-safe.
#[cfg(not(feature = "rfc-types"))]
pub fn created_under<T: api_bones::HasId>(prefix: &str, value: T) -> CreatedAtResponse<T> {
    let base = prefix.trim_end_matches('/');
    let location = format!("{base}/{}", value.id());
    created_at(&location, value)
}
/// Builds a `200 OK` response with `value` wrapped in the `ApiResponse` envelope.
#[cfg(feature = "rfc-types")]
pub fn ok<T: serde::Serialize>(value: T) -> HandlerResponse<T> {
    let envelope = api_bones::ApiResponse::builder(value).build();
    let body = serde_json::to_vec(&envelope).expect("ApiResponse<T> is always serializable");
    let status = axum::http::StatusCode::OK;
    Ok(RfcOk::new(status, axum::http::HeaderMap::new(), body))
}
/// Builds a `200 OK` response with `value` wrapped in the `ApiResponse` envelope.
#[cfg(not(feature = "rfc-types"))]
pub fn ok<T>(value: T) -> HandlerResponse<T> {
    let envelope = axum::Json(api_bones::ApiResponse::builder(value).build());
    Ok((axum::http::StatusCode::OK, envelope))
}
/// Builds a `200 OK` response with a paginated `page` wrapped in the `ApiResponse` envelope.
#[cfg(feature = "rfc-types")]
pub fn listed<T: serde::Serialize>(
    page: api_bones::PaginatedResponse<T>,
) -> HandlerListResponse<T> {
    let envelope = api_bones::ApiResponse::builder(page).build();
    let body = serde_json::to_vec(&envelope)
        .expect("ApiResponse<PaginatedResponse<T>> is always serializable");
    let status = axum::http::StatusCode::OK;
    Ok(RfcOk::new(status, axum::http::HeaderMap::new(), body))
}
/// Wraps a paginated `page` in the `ApiResponse` envelope as a `Json` body
/// (axum's default status for `Json` applies).
#[cfg(not(feature = "rfc-types"))]
pub fn listed<T>(page: api_bones::PaginatedResponse<T>) -> HandlerListResponse<T> {
    let envelope = api_bones::ApiResponse::builder(page).build();
    Ok(axum::Json(envelope))
}
/// Paginates an in-memory collection, converts the selected items from `T`
/// to `U`, and wraps the page with `listed`.
///
/// `total` is the number of `items` *before* slicing. A missing `offset`
/// defaults to `0` and a missing `limit` to `20`.
///
/// Offset and limit are converted to `usize` with saturation rather than an
/// `as` cast. On targets where `usize` is narrower than the parameter type, an
/// out-of-range offset therefore yields an empty page instead of silently
/// wrapping to an earlier one.
pub fn listed_page<T, U>(
    items: Vec<T>,
    params: &api_bones::pagination::PaginationParams,
) -> HandlerListResponse<U>
where
    T: Into<U>,
    U: serde::Serialize,
{
    // `usize` is at most 64 bits on every supported target, so this never truncates.
    let total = items.len() as u64;
    let offset = usize::try_from(params.offset.unwrap_or(0)).unwrap_or(usize::MAX);
    let limit = usize::try_from(params.limit.unwrap_or(20)).unwrap_or(usize::MAX);
    let page: Vec<U> = items
        .into_iter()
        .skip(offset)
        .take(limit)
        .map(Into::into)
        .collect();
    listed(api_bones::PaginatedResponse::new(page, total, params))
}
/// Builds a `200 OK` response carrying an `ETag` header and `value` wrapped
/// in the `ApiResponse` envelope.
#[cfg(feature = "rfc-types")]
pub fn etagged<T: serde::Serialize>(
    etag: api_bones::etag::ETag,
    value: T,
) -> EtaggedHandlerResponse<T> {
    let etag_value = axum::http::HeaderValue::from_str(&etag.to_string())
        .expect("ETag is always a valid header value");
    let mut headers = axum::http::HeaderMap::new();
    headers.insert(axum::http::header::ETAG, etag_value);
    let envelope = api_bones::ApiResponse::builder(value).build();
    let body = serde_json::to_vec(&envelope).expect("ApiResponse<T> is always serializable");
    Ok(RfcOk::new(axum::http::StatusCode::OK, headers, body))
}
/// Builds a `200 OK` response carrying `etag` and `value` wrapped in the
/// `ApiResponse` envelope.
#[cfg(not(feature = "rfc-types"))]
pub fn etagged<T>(etag: api_bones::etag::ETag, value: T) -> EtaggedHandlerResponse<T> {
    let envelope = axum::Json(api_bones::ApiResponse::builder(value).build());
    Ok((axum::http::StatusCode::OK, etag, envelope))
}
/// Converts a caught panic payload into a generic `500` problem response.
///
/// The panic message is extracted when the payload is a `String` or a
/// `&'static str` (otherwise the placeholder `"panic"` is used). It is logged
/// via `tracing` but never sent to the client.
pub(crate) fn panic_handler(err: Box<dyn std::any::Any + Send + 'static>) -> Response {
    let detail: &str = err
        .downcast_ref::<String>()
        .map(String::as_str)
        .or_else(|| err.downcast_ref::<&str>().copied())
        .unwrap_or("panic");
    tracing::error!(panic = detail, "handler panicked");
    HandlerError::new(ErrorCode::InternalServerError, "internal server error").into_response()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn handler_error_into_response_returns_problem_json() {
        let err = HandlerError::new(ErrorCode::ResourceNotFound, "not found");
        let resp = err.into_response();
        assert_eq!(resp.status(), 404);
    }

    #[test]
    fn handler_error_from_api_error() {
        let api_err = ApiError::new(ErrorCode::InternalServerError, "oops");
        let handler_err = HandlerError::from(api_err);
        let resp = handler_err.into_response();
        assert_eq!(resp.status(), 500);
    }

    #[test]
    fn with_request_id_and_errors() {
        let id = uuid::Uuid::now_v7();
        let err = HandlerError::new(ErrorCode::ValidationFailed, "bad input")
            .with_request_id(id)
            .with_errors(vec![ValidationError {
                field: "name".into(),
                message: "required".into(),
                rule: None,
            }]);
        let resp = err.into_response();
        assert_eq!(resp.status(), 400);
    }

    #[cfg(feature = "database")]
    #[test]
    fn from_sqlx_maps_row_not_found_to_404() {
        let resp = HandlerError::from_sqlx(&sqlx::Error::RowNotFound).into_response();
        assert_eq!(resp.status(), 404);
    }

    #[cfg(feature = "database")]
    #[test]
    fn from_sqlx_maps_other_errors_to_500() {
        let resp = HandlerError::from_sqlx(&sqlx::Error::PoolTimedOut).into_response();
        assert_eq!(resp.status(), 500);
    }

    #[test]
    fn panic_handler_downcasts_string_payload() {
        let payload: Box<dyn std::any::Any + Send + 'static> = Box::new("boom".to_string());
        let resp = panic_handler(payload);
        assert_eq!(resp.status(), 500);
    }

    #[test]
    fn panic_handler_downcasts_static_str_payload() {
        let payload: Box<dyn std::any::Any + Send + 'static> = Box::new("static boom");
        let resp = panic_handler(payload);
        assert_eq!(resp.status(), 500);
    }

    #[test]
    fn panic_handler_handles_unknown_payload() {
        let payload: Box<dyn std::any::Any + Send + 'static> = Box::new(42u32);
        let resp = panic_handler(payload);
        assert_eq!(resp.status(), 500);
    }

    #[cfg(feature = "rfc-types")]
    mod rfc {
        use super::*;

        #[test]
        fn created_builds_201_with_envelope() {
            let resp = created("x").unwrap();
            assert_eq!(resp.status(), axum::http::StatusCode::CREATED);
            assert_eq!(resp.body_json()["data"], "x");
        }

        #[test]
        fn ok_builds_200_with_envelope() {
            let resp = ok(42u32).unwrap();
            assert_eq!(resp.status(), axum::http::StatusCode::OK);
            assert_eq!(resp.body_json()["data"], 42);
        }

        #[test]
        fn into_response_forces_json_content_type() {
            let resp = ok("x").unwrap().into_response();
            assert_eq!(
                resp.headers().get(axum::http::header::CONTENT_TYPE).unwrap(),
                "application/json",
            );
        }

        #[test]
        fn rfc_ok_is_send_and_sync_for_any_payload() {
            fn assert_send_sync<X: Send + Sync>() {}
            assert_send_sync::<RfcOk<std::rc::Rc<()>>>();
        }

        #[test]
        fn etagged_builds_200_with_etag_and_envelope() {
            use api_bones::etag::ETag;
            let etag = ETag::strong("abc123");
            let resp = etagged(etag.clone(), 99u32).unwrap();
            assert_eq!(resp.status(), axum::http::StatusCode::OK);
            assert_eq!(
                resp.headers()
                    .get(axum::http::header::ETAG)
                    .unwrap()
                    .to_str()
                    .unwrap(),
                etag.to_string(),
            );
            assert_eq!(resp.body_json()["data"], 99);
        }

        #[test]
        fn listed_wraps_paginated_response() {
            use api_bones::{PaginatedResponse, pagination::PaginationParams};
            let page: PaginatedResponse<u32> =
                PaginatedResponse::new(vec![1, 2], 2, &PaginationParams::default());
            let json = listed(page).unwrap().body_json();
            assert_eq!(json["data"]["items"], serde_json::json!([1, 2]));
        }

        #[test]
        fn listed_page_maps_and_paginates() {
            use api_bones::pagination::PaginationParams;
            let items: Vec<u32> = (1..=5).collect();
            let params = PaginationParams {
                offset: Some(1),
                limit: Some(2),
            };
            let json = listed_page::<u32, u64>(items, &params).unwrap().body_json();
            assert_eq!(json["data"]["items"], serde_json::json!([2, 3]));
        }

        #[test]
        fn listed_page_uses_defaults_when_params_are_none() {
            use api_bones::pagination::PaginationParams;
            let items: Vec<u32> = (1..=25).collect();
            let json = listed_page::<u32, u64>(items, &PaginationParams::default())
                .unwrap()
                .body_json();
            assert_eq!(json["data"]["items"].as_array().unwrap().len(), 20);
        }

        #[test]
        fn created_at_sets_location_and_envelope() {
            let resp = created_at("/v1/widgets/1", 5u32).unwrap();
            assert_eq!(resp.status(), axum::http::StatusCode::CREATED);
            assert_eq!(
                resp.headers().get(axum::http::header::LOCATION).unwrap(),
                "/v1/widgets/1",
            );
            assert_eq!(resp.body_json()["data"], 5);
        }

        #[test]
        fn created_under_composes_location() {
            struct R {
                id: u64,
            }
            impl api_bones::HasId for R {
                type Id = u64;
                fn id(&self) -> &u64 {
                    &self.id
                }
            }
            impl serde::Serialize for R {
                fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
                    s.serialize_u64(self.id)
                }
            }
            let resp = created_under("/v1/widgets", R { id: 7 }).unwrap();
            assert_eq!(resp.status(), axum::http::StatusCode::CREATED);
            assert_eq!(
                resp.headers().get(axum::http::header::LOCATION).unwrap(),
                "/v1/widgets/7",
            );
        }

        #[test]
        fn created_under_trims_trailing_slash() {
            struct R {
                id: String,
            }
            impl api_bones::HasId for R {
                type Id = String;
                fn id(&self) -> &String {
                    &self.id
                }
            }
            impl serde::Serialize for R {
                fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
                    s.serialize_str(&self.id)
                }
            }
            let resp = created_under("/v1/things/", R { id: "abc".into() }).unwrap();
            assert_eq!(
                resp.headers().get(axum::http::header::LOCATION).unwrap(),
                "/v1/things/abc",
            );
        }
    }

    #[test]
    fn unconstrained_response_passes_through() {
        use axum::http::StatusCode;
        let resp = UnconstrainedResponse::new(StatusCode::IM_A_TEAPOT).into_response();
        assert_eq!(resp.status(), StatusCode::IM_A_TEAPOT);
    }
}