use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use axum::{Json, body::Body, body::to_bytes, extract::Request, response::IntoResponse};
use http::StatusCode;
use serde::{Deserialize, Serialize};
use tower::{Layer, Service};
use crate::context::RequestContext;
/// An error a handler can return to produce a JSON error response.
///
/// It carries the HTTP status, a machine-readable `code` and a human-readable
/// `message`. Its `Display` output is the message alone.
#[derive(Clone, Debug, Eq, PartialEq, thiserror::Error)]
#[error("{message}")]
pub struct HttpError {
    /// HTTP status sent with the response.
    status: StatusCode,
    /// Machine-readable identifier, e.g. `"not_found"`.
    code: &'static str,
    /// Human-readable description of the failure.
    message: String,
}

impl HttpError {
    /// Builds an error from an arbitrary status, code and message.
    pub fn new(status: StatusCode, code: &'static str, message: impl Into<String>) -> Self {
        let message = message.into();
        Self {
            status,
            code,
            message,
        }
    }

    /// `400 Bad Request` with code `"bad_request"`.
    pub fn bad_request(message: impl Into<String>) -> Self {
        Self::new(StatusCode::BAD_REQUEST, "bad_request", message)
    }

    /// `401 Unauthorized` with code `"unauthorized"`.
    pub fn unauthorized(message: impl Into<String>) -> Self {
        Self::new(StatusCode::UNAUTHORIZED, "unauthorized", message)
    }

    /// `403 Forbidden` with code `"forbidden"`.
    pub fn forbidden(message: impl Into<String>) -> Self {
        Self::new(StatusCode::FORBIDDEN, "forbidden", message)
    }

    /// `404 Not Found` with code `"not_found"`.
    pub fn not_found(message: impl Into<String>) -> Self {
        Self::new(StatusCode::NOT_FOUND, "not_found", message)
    }

    /// `409 Conflict` with code `"conflict"`.
    pub fn conflict(message: impl Into<String>) -> Self {
        Self::new(StatusCode::CONFLICT, "conflict", message)
    }

    /// `429 Too Many Requests` with code `"too_many_requests"`.
    pub fn too_many_requests(message: impl Into<String>) -> Self {
        Self::new(StatusCode::TOO_MANY_REQUESTS, "too_many_requests", message)
    }

    /// `422 Unprocessable Entity` with code `"unprocessable_entity"`.
    pub fn unprocessable_entity(message: impl Into<String>) -> Self {
        Self::new(
            StatusCode::UNPROCESSABLE_ENTITY,
            "unprocessable_entity",
            message,
        )
    }

    /// `500 Internal Server Error` with code `"internal_server_error"` and a
    /// fixed, generic message (no caller-supplied detail).
    pub fn internal_server_error() -> Self {
        Self::new(
            StatusCode::INTERNAL_SERVER_ERROR,
            "internal_server_error",
            "internal server error",
        )
    }

    /// The HTTP status this error maps to.
    pub fn status(&self) -> StatusCode {
        self.status
    }

    /// The machine-readable error code.
    pub fn code(&self) -> &'static str {
        self.code
    }

    /// The human-readable message.
    pub fn message(&self) -> &str {
        self.message.as_str()
    }
}
/// Renders the error as `{"error": {"code": ..., "message": ...}}` with the
/// error's status. Before rendering, the error is logged: 5xx at `error` level,
/// everything else at `warn` level.
impl IntoResponse for HttpError {
    fn into_response(self) -> axum::response::Response {
        let Self {
            status,
            code,
            message,
        } = self;

        if !status.is_server_error() {
            tracing::warn!(
                http.status = status.as_u16(),
                error.code = code,
                error.message = %message,
                "http error response"
            );
        } else {
            tracing::error!(
                http.status = status.as_u16(),
                error.code = code,
                error.message = %message,
                "http error response"
            );
        }

        let payload = ErrorBody {
            error: ErrorDetails { code, message },
        };
        (status, Json(payload)).into_response()
    }
}
/// Fallback handler for requests that match no route.
///
/// Always yields a `404 Not Found` [`HttpError`] with code `"not_found"` and
/// the message `"route not found"`.
pub async fn not_found_fallback() -> HttpError {
    HttpError::not_found("route not found")
}
/// JSON body rendered by [`HttpError`]: `{"error": {...}}`.
#[derive(Serialize, Debug)]
struct ErrorBody {
    /// The wrapped error object.
    error: ErrorDetails,
}

/// Inner object of [`ErrorBody`].
///
/// Keep the field order: it determines the key order of the serialized JSON.
#[derive(Serialize, Debug)]
struct ErrorDetails {
    /// Machine-readable error code.
    code: &'static str,
    /// Human-readable error message.
    message: String,
}
/// Tower layer that wraps a service in [`ErrorEnvelopeService`], which rewrites
/// 4xx/5xx responses into the production JSON error envelope.
#[derive(Clone, Copy, Debug, Default)]
pub struct ErrorEnvelopeLayer;

impl ErrorEnvelopeLayer {
    /// Creates the layer; equivalent to `ErrorEnvelopeLayer::default()`.
    pub fn new() -> Self {
        Self::default()
    }
}

impl<S> Layer<S> for ErrorEnvelopeLayer {
    type Service = ErrorEnvelopeService<S>;

    /// Wraps `service` so that its error responses get enveloped.
    fn layer(&self, service: S) -> Self::Service {
        ErrorEnvelopeService { inner: service }
    }
}
/// Service produced by [`ErrorEnvelopeLayer`].
///
/// Responses whose status is neither a client error (4xx) nor a server error
/// (5xx) pass through untouched. Every other response is rebuilt by
/// `envelope_response`.
#[derive(Clone, Debug)]
pub struct ErrorEnvelopeService<S> {
    /// The wrapped service.
    inner: S,
}

impl<S> Service<Request> for ErrorEnvelopeService<S>
where
    S: Service<Request, Response = axum::response::Response> + Send + 'static,
    S::Future: Send + 'static,
    S::Error: Send + 'static,
{
    type Response = axum::response::Response;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Readiness is fully delegated; this wrapper adds no back-pressure.
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, request: Request) -> Self::Future {
        // Capture what the envelope needs before the request moves into the inner
        // service. `uri().path()` excludes the query string.
        // NOTE(review): the context is read from the *request* extensions before the
        // inner call, so a RequestContext inserted by an inner layer is not seen
        // here — confirm the layer ordering.
        let context = request.extensions().get::<RequestContext>().cloned();
        let path = String::from(request.uri().path());
        let pending = self.inner.call(request);

        Box::pin(async move {
            let response = pending.await?;
            let status = response.status();
            if status.is_client_error() || status.is_server_error() {
                Ok(envelope_response(response, context, path).await)
            } else {
                Ok(response)
            }
        })
    }
}
async fn envelope_response(
response: axum::response::Response,
context: Option<RequestContext>,
path: String,
) -> axum::response::Response {
let (mut parts, body) = response.into_parts();
let status = parts.status;
let extracted = read_legacy_error_body(body).await;
let mut code = extracted
.as_ref()
.map(|body| body.error.code.clone())
.unwrap_or_else(|| default_code(status).to_owned());
let mut message = extracted
.as_ref()
.map(|body| body.error.message.clone())
.unwrap_or_else(|| status.canonical_reason().unwrap_or("error").to_owned());
let mut details = extracted
.map(|body| {
if body.error.details.is_empty() {
serde_json::Value::Null
} else {
serde_json::Value::Object(body.error.details)
}
})
.unwrap_or(serde_json::Value::Null);
if status.is_server_error() {
tracing::error!(
http.status = status.as_u16(),
error.code = %code,
request.id = context.as_ref().map(RequestContext::request_id).unwrap_or(""),
http.path = %path,
"http error envelope"
);
message = "internal server error".to_owned();
details = serde_json::Value::Null;
code = default_code(status).to_owned();
}
let envelope = ProductionErrorBody {
error: ProductionErrorDetails {
status_code: status.as_u16(),
code,
message,
details,
timestamp: timestamp_now(),
path,
request_id: context
.as_ref()
.map(RequestContext::request_id)
.unwrap_or("")
.to_owned(),
},
};
let body = serde_json::to_vec(&envelope).expect("error envelope should serialize");
parts.headers.insert(
http::header::CONTENT_TYPE,
http::HeaderValue::from_static("application/json"),
);
axum::response::Response::from_parts(parts, Body::from(body))
}
/// Maximum number of bytes buffered when inspecting an error response body.
/// Bodies over this limit are not parsed.
const MAX_ERROR_ENVELOPE_BODY_BYTES: usize = 64 * 1024;

/// Buffers `body` (up to [`MAX_ERROR_ENVELOPE_BODY_BYTES`]) and tries to parse
/// it as a [`LegacyErrorBody`].
///
/// Returns `None` in three cases: the body cannot be read, it exceeds the limit,
/// or it is not JSON of the expected shape.
async fn read_legacy_error_body(body: Body) -> Option<LegacyErrorBody> {
    match to_bytes(body, MAX_ERROR_ENVELOPE_BODY_BYTES).await {
        Ok(raw) => serde_json::from_slice(&raw).ok(),
        Err(_) => None,
    }
}
/// Current UTC time formatted as an RFC 3339 string.
///
/// # Panics
///
/// Only if the current UTC time cannot be rendered as RFC 3339, which the
/// `expect` message treats as an invariant violation.
pub(crate) fn timestamp_now() -> String {
    use time::format_description::well_known::Rfc3339;

    let now = time::OffsetDateTime::now_utc();
    now.format(&Rfc3339)
        .expect("UTC timestamp should format as RFC3339")
}
/// Maps an HTTP status to a fallback machine-readable error code.
///
/// All 5xx statuses collapse to `"internal_server_error"`. A few common 4xx
/// statuses have dedicated codes. Anything else becomes `"http_error"`.
fn default_code(status: StatusCode) -> &'static str {
    // None of the explicitly mapped statuses below is a 5xx, so checking the
    // server-error class first does not shadow any of them.
    if status.is_server_error() {
        return "internal_server_error";
    }

    match status {
        StatusCode::BAD_REQUEST => "bad_request",
        StatusCode::UNAUTHORIZED => "unauthorized",
        StatusCode::FORBIDDEN => "forbidden",
        StatusCode::NOT_FOUND => "not_found",
        StatusCode::CONFLICT => "conflict",
        StatusCode::UNPROCESSABLE_ENTITY => "unprocessable_entity",
        StatusCode::TOO_MANY_REQUESTS => "too_many_requests",
        _ => "http_error",
    }
}
/// Error body shape accepted from inner handlers:
/// `{"error": {"code": ..., "message": ..., <extra keys>}}`.
#[derive(Deserialize, Debug)]
struct LegacyErrorBody {
    /// The wrapped error object.
    error: LegacyErrorDetails,
}

/// Inner object of [`LegacyErrorBody`].
#[derive(Deserialize, Debug)]
struct LegacyErrorDetails {
    /// Required machine-readable code.
    code: String,
    /// Required human-readable message.
    message: String,
    /// Every other key of the `error` object, gathered by `#[serde(flatten)]`.
    #[serde(flatten)]
    details: serde_json::Map<String, serde_json::Value>,
}
/// Outgoing error envelope: `{"error": {...}}`.
#[derive(Serialize, Debug)]
struct ProductionErrorBody {
    /// The wrapped error object.
    error: ProductionErrorDetails,
}

/// Inner object of [`ProductionErrorBody`].
///
/// Keys are emitted in camelCase (`statusCode`, `requestId`, ...) and in
/// declaration order, so keep the field order stable.
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
struct ProductionErrorDetails {
    /// Numeric HTTP status.
    status_code: u16,
    /// Machine-readable error code.
    code: String,
    /// Human-readable error message.
    message: String,
    /// Extra error fields, or `null` when there are none.
    details: serde_json::Value,
    /// Timestamp string of when the envelope was built.
    timestamp: String,
    /// Request path the error belongs to.
    path: String,
    /// Request id; may be empty when none was available.
    request_id: String,
}
/// Error for a route path containing a `:` parameter marker with no name after it.
#[derive(Clone, Debug, Eq, PartialEq, thiserror::Error)]
#[error("route path `{path}` contains a parameter segment without a name after ':'")]
pub struct RoutePathError {
    /// The offending route path, exactly as supplied.
    path: String,
}

impl RoutePathError {
    /// Creates the error for `path`, whose parameter segment has an empty name.
    pub fn empty_parameter(path: impl Into<String>) -> Self {
        let path = path.into();
        Self { path }
    }

    /// The route path that failed validation.
    pub fn path(&self) -> &str {
        self.path.as_str()
    }
}