1use std::{
4 future::Future,
5 pin::Pin,
6 task::{Context, Poll},
7};
8
9use axum::{Json, body::Body, body::to_bytes, extract::Request, response::IntoResponse};
10use http::StatusCode;
11use serde::{Deserialize, Serialize};
12use tower::{Layer, Service};
13
14use crate::context::RequestContext;
15
16#[derive(Clone, Debug, Eq, PartialEq, thiserror::Error)]
29#[error("{message}")]
30pub struct HttpError {
31 status: StatusCode,
32 code: &'static str,
33 message: String,
34}
35
36impl HttpError {
37 pub fn new(status: StatusCode, code: &'static str, message: impl Into<String>) -> Self {
43 Self {
44 status,
45 code,
46 message: message.into(),
47 }
48 }
49
50 pub fn bad_request(message: impl Into<String>) -> Self {
52 Self::new(StatusCode::BAD_REQUEST, "bad_request", message)
53 }
54
55 pub fn unauthorized(message: impl Into<String>) -> Self {
57 Self::new(StatusCode::UNAUTHORIZED, "unauthorized", message)
58 }
59
60 pub fn forbidden(message: impl Into<String>) -> Self {
62 Self::new(StatusCode::FORBIDDEN, "forbidden", message)
63 }
64
65 pub fn not_found(message: impl Into<String>) -> Self {
67 Self::new(StatusCode::NOT_FOUND, "not_found", message)
68 }
69
70 pub fn conflict(message: impl Into<String>) -> Self {
72 Self::new(StatusCode::CONFLICT, "conflict", message)
73 }
74
75 pub fn too_many_requests(message: impl Into<String>) -> Self {
77 Self::new(StatusCode::TOO_MANY_REQUESTS, "too_many_requests", message)
78 }
79
80 pub fn unprocessable_entity(message: impl Into<String>) -> Self {
82 Self::new(
83 StatusCode::UNPROCESSABLE_ENTITY,
84 "unprocessable_entity",
85 message,
86 )
87 }
88
89 pub fn internal_server_error() -> Self {
94 Self::new(
95 StatusCode::INTERNAL_SERVER_ERROR,
96 "internal_server_error",
97 "internal server error",
98 )
99 }
100
101 pub fn status(&self) -> StatusCode {
103 self.status
104 }
105
106 pub fn code(&self) -> &'static str {
108 self.code
109 }
110
111 pub fn message(&self) -> &str {
113 &self.message
114 }
115}
116
117impl IntoResponse for HttpError {
118 fn into_response(self) -> axum::response::Response {
119 let status = self.status;
120 let code = self.code;
121 let message = self.message;
122
123 if status.is_server_error() {
124 tracing::error!(
125 http.status = status.as_u16(),
126 error.code = code,
127 error.message = %message,
128 "http error response"
129 );
130 } else {
131 tracing::warn!(
132 http.status = status.as_u16(),
133 error.code = code,
134 error.message = %message,
135 "http error response"
136 );
137 }
138
139 let body = Json(ErrorBody {
140 error: ErrorDetails { code, message },
141 });
142 (status, body).into_response()
143 }
144}
145
146pub async fn not_found_fallback() -> HttpError {
152 HttpError::not_found("route not found")
153}
154
155#[derive(Debug, Serialize)]
156struct ErrorBody {
157 error: ErrorDetails,
158}
159
160#[derive(Debug, Serialize)]
161struct ErrorDetails {
162 code: &'static str,
163 message: String,
164}
165
/// Tower layer that wraps a service in [`ErrorEnvelopeService`].
///
/// The layer is a stateless unit struct.
#[derive(Debug, Default, Clone, Copy)]
pub struct ErrorEnvelopeLayer;
201
202impl ErrorEnvelopeLayer {
203 pub fn new() -> Self {
205 Self
206 }
207}
208
209impl<S> Layer<S> for ErrorEnvelopeLayer {
210 type Service = ErrorEnvelopeService<S>;
211
212 fn layer(&self, inner: S) -> Self::Service {
213 ErrorEnvelopeService { inner }
214 }
215}
216
/// Service produced by [`ErrorEnvelopeLayer`]; holds the wrapped service.
#[derive(Debug, Clone)]
pub struct ErrorEnvelopeService<S> {
    /// The wrapped service whose responses are inspected.
    inner: S,
}
222
223impl<S> Service<Request> for ErrorEnvelopeService<S>
224where
225 S: Service<Request, Response = axum::response::Response> + Send + 'static,
226 S::Future: Send + 'static,
227 S::Error: Send + 'static,
228{
229 type Response = axum::response::Response;
230 type Error = S::Error;
231 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
232
233 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
234 self.inner.poll_ready(cx)
235 }
236
237 fn call(&mut self, request: Request) -> Self::Future {
238 let path = request.uri().path().to_owned();
239 let context = request.extensions().get::<RequestContext>().cloned();
240 let future = self.inner.call(request);
241
242 Box::pin(async move {
243 let response = future.await?;
244 if !response.status().is_client_error() && !response.status().is_server_error() {
245 return Ok(response);
246 }
247 Ok(envelope_response(response, context, path).await)
248 })
249 }
250}
251
252async fn envelope_response(
253 response: axum::response::Response,
254 context: Option<RequestContext>,
255 path: String,
256) -> axum::response::Response {
257 let (mut parts, body) = response.into_parts();
258 let status = parts.status;
259 let extracted = read_legacy_error_body(body).await;
260 let mut code = extracted
261 .as_ref()
262 .map(|body| body.error.code.clone())
263 .unwrap_or_else(|| default_code(status).to_owned());
264 let mut message = extracted
265 .as_ref()
266 .map(|body| body.error.message.clone())
267 .unwrap_or_else(|| status.canonical_reason().unwrap_or("error").to_owned());
268 let mut details = extracted
269 .map(|body| {
270 if body.error.details.is_empty() {
271 serde_json::Value::Null
272 } else {
273 serde_json::Value::Object(body.error.details)
274 }
275 })
276 .unwrap_or(serde_json::Value::Null);
277 if status.is_server_error() {
278 tracing::error!(
279 http.status = status.as_u16(),
280 error.code = %code,
281 request.id = context.as_ref().map(RequestContext::request_id).unwrap_or(""),
282 http.path = %path,
283 "http error envelope"
284 );
285 message = "internal server error".to_owned();
288 details = serde_json::Value::Null;
289 code = default_code(status).to_owned();
290 }
291
292 let envelope = ProductionErrorBody {
293 error: ProductionErrorDetails {
294 status_code: status.as_u16(),
295 code,
296 message,
297 details,
298 timestamp: timestamp_now(),
299 path,
300 request_id: context
301 .as_ref()
302 .map(RequestContext::request_id)
303 .unwrap_or("")
304 .to_owned(),
305 },
306 };
307 let body = serde_json::to_vec(&envelope).expect("error envelope should serialize");
308 parts.headers.insert(
309 http::header::CONTENT_TYPE,
310 http::HeaderValue::from_static("application/json"),
311 );
312 axum::response::Response::from_parts(parts, Body::from(body))
313}
314
315const MAX_ERROR_ENVELOPE_BODY_BYTES: usize = 64 * 1024;
316
317async fn read_legacy_error_body(body: Body) -> Option<LegacyErrorBody> {
318 let bytes = to_bytes(body, MAX_ERROR_ENVELOPE_BODY_BYTES).await.ok()?;
319 serde_json::from_slice::<LegacyErrorBody>(&bytes).ok()
320}
321
322pub(crate) fn timestamp_now() -> String {
324 time::OffsetDateTime::now_utc()
325 .format(&time::format_description::well_known::Rfc3339)
326 .expect("UTC timestamp should format as RFC3339")
327}
328
329fn default_code(status: StatusCode) -> &'static str {
330 match status {
331 StatusCode::BAD_REQUEST => "bad_request",
332 StatusCode::UNAUTHORIZED => "unauthorized",
333 StatusCode::FORBIDDEN => "forbidden",
334 StatusCode::NOT_FOUND => "not_found",
335 StatusCode::CONFLICT => "conflict",
336 StatusCode::UNPROCESSABLE_ENTITY => "unprocessable_entity",
337 StatusCode::TOO_MANY_REQUESTS => "too_many_requests",
338 status if status.is_server_error() => "internal_server_error",
339 _ => "http_error",
340 }
341}
342
343#[derive(Debug, Deserialize)]
344struct LegacyErrorBody {
345 error: LegacyErrorDetails,
346}
347
348#[derive(Debug, Deserialize)]
349struct LegacyErrorDetails {
350 code: String,
351 message: String,
352 #[serde(flatten)]
353 details: serde_json::Map<String, serde_json::Value>,
354}
355
356#[derive(Debug, Serialize)]
357struct ProductionErrorBody {
358 error: ProductionErrorDetails,
359}
360
361#[derive(Debug, Serialize)]
362#[serde(rename_all = "camelCase")]
363struct ProductionErrorDetails {
364 status_code: u16,
365 code: String,
366 message: String,
367 details: serde_json::Value,
368 timestamp: String,
369 path: String,
370 request_id: String,
371}
372
373#[derive(Clone, Debug, Eq, PartialEq, thiserror::Error)]
375#[error("route path `{path}` contains a parameter segment without a name after ':'")]
376pub struct RoutePathError {
377 path: String,
378}
379
380impl RoutePathError {
381 pub fn empty_parameter(path: impl Into<String>) -> Self {
383 Self { path: path.into() }
384 }
385
386 pub fn path(&self) -> &str {
388 &self.path
389 }
390}