1use std::{
4 future::Future,
5 pin::Pin,
6 task::{Context, Poll},
7};
8
9use axum::{Json, body::Body, body::to_bytes, extract::Request, response::IntoResponse};
10use http::StatusCode;
11use serde::{Deserialize, Serialize};
12use tower::{Layer, Service};
13
14use crate::context::RequestContext;
15
16#[derive(Clone, Debug, Eq, PartialEq, thiserror::Error)]
29#[error("{message}")]
30pub struct HttpError {
31 status: StatusCode,
32 code: &'static str,
33 message: String,
34}
35
36impl HttpError {
37 pub fn new(status: StatusCode, code: &'static str, message: impl Into<String>) -> Self {
43 Self {
44 status,
45 code,
46 message: message.into(),
47 }
48 }
49
50 pub fn bad_request(message: impl Into<String>) -> Self {
52 Self::new(StatusCode::BAD_REQUEST, "bad_request", message)
53 }
54
55 pub fn unauthorized(message: impl Into<String>) -> Self {
57 Self::new(StatusCode::UNAUTHORIZED, "unauthorized", message)
58 }
59
60 pub fn forbidden(message: impl Into<String>) -> Self {
62 Self::new(StatusCode::FORBIDDEN, "forbidden", message)
63 }
64
65 pub fn not_found(message: impl Into<String>) -> Self {
67 Self::new(StatusCode::NOT_FOUND, "not_found", message)
68 }
69
70 pub fn conflict(message: impl Into<String>) -> Self {
72 Self::new(StatusCode::CONFLICT, "conflict", message)
73 }
74
75 pub fn too_many_requests(message: impl Into<String>) -> Self {
77 Self::new(StatusCode::TOO_MANY_REQUESTS, "too_many_requests", message)
78 }
79
80 pub fn unprocessable_entity(message: impl Into<String>) -> Self {
82 Self::new(
83 StatusCode::UNPROCESSABLE_ENTITY,
84 "unprocessable_entity",
85 message,
86 )
87 }
88
89 pub fn internal_server_error() -> Self {
94 Self::new(
95 StatusCode::INTERNAL_SERVER_ERROR,
96 "internal_server_error",
97 "internal server error",
98 )
99 }
100
101 pub fn status(&self) -> StatusCode {
103 self.status
104 }
105
106 pub fn code(&self) -> &'static str {
108 self.code
109 }
110
111 pub fn message(&self) -> &str {
113 &self.message
114 }
115}
116
117impl IntoResponse for HttpError {
118 fn into_response(self) -> axum::response::Response {
119 let status = self.status;
120 let code = self.code;
121 let message = self.message;
122
123 if status.is_server_error() {
124 tracing::error!(
125 http.status = status.as_u16(),
126 error.code = code,
127 error.message = %message,
128 "http error response"
129 );
130 } else {
131 tracing::warn!(
132 http.status = status.as_u16(),
133 error.code = code,
134 error.message = %message,
135 "http error response"
136 );
137 }
138
139 let body = Json(ErrorBody {
140 error: ErrorDetails { code, message },
141 });
142 (status, body).into_response()
143 }
144}
145
146#[derive(Debug, Serialize)]
147struct ErrorBody {
148 error: ErrorDetails,
149}
150
151#[derive(Debug, Serialize)]
152struct ErrorDetails {
153 code: &'static str,
154 message: String,
155}
156
/// Tower layer that wraps a service in [`ErrorEnvelopeService`], which rewrites 4xx/5xx
/// responses into the production JSON error envelope.
#[derive(Debug, Default, Clone, Copy)]
pub struct ErrorEnvelopeLayer;

impl ErrorEnvelopeLayer {
    /// Creates the layer; equivalent to `ErrorEnvelopeLayer::default()`.
    pub fn new() -> Self {
        ErrorEnvelopeLayer
    }
}
199
200impl<S> Layer<S> for ErrorEnvelopeLayer {
201 type Service = ErrorEnvelopeService<S>;
202
203 fn layer(&self, inner: S) -> Self::Service {
204 ErrorEnvelopeService { inner }
205 }
206}
207
/// Service produced by [`ErrorEnvelopeLayer`]: delegates to `inner` and rewrites its 4xx/5xx
/// responses into the production error envelope.
#[derive(Debug, Clone)]
pub struct ErrorEnvelopeService<S> {
    /// The wrapped service.
    inner: S,
}
213
214impl<S> Service<Request> for ErrorEnvelopeService<S>
215where
216 S: Service<Request, Response = axum::response::Response> + Send + 'static,
217 S::Future: Send + 'static,
218 S::Error: Send + 'static,
219{
220 type Response = axum::response::Response;
221 type Error = S::Error;
222 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
223
224 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
225 self.inner.poll_ready(cx)
226 }
227
228 fn call(&mut self, request: Request) -> Self::Future {
229 let path = request.uri().path().to_owned();
230 let context = request.extensions().get::<RequestContext>().cloned();
231 let future = self.inner.call(request);
232
233 Box::pin(async move {
234 let response = future.await?;
235 if !response.status().is_client_error() && !response.status().is_server_error() {
236 return Ok(response);
237 }
238 Ok(envelope_response(response, context, path).await)
239 })
240 }
241}
242
243async fn envelope_response(
244 response: axum::response::Response,
245 context: Option<RequestContext>,
246 path: String,
247) -> axum::response::Response {
248 let (mut parts, body) = response.into_parts();
249 let status = parts.status;
250 let extracted = read_legacy_error_body(body).await;
251 let mut code = extracted
252 .as_ref()
253 .map(|body| body.error.code.clone())
254 .unwrap_or_else(|| default_code(status).to_owned());
255 let mut message = extracted
256 .as_ref()
257 .map(|body| body.error.message.clone())
258 .unwrap_or_else(|| status.canonical_reason().unwrap_or("error").to_owned());
259 let mut details = extracted
260 .map(|body| {
261 if body.error.details.is_empty() {
262 serde_json::Value::Null
263 } else {
264 serde_json::Value::Object(body.error.details)
265 }
266 })
267 .unwrap_or(serde_json::Value::Null);
268 if status.is_server_error() {
269 tracing::error!(
270 http.status = status.as_u16(),
271 error.code = %code,
272 request.id = context.as_ref().map(RequestContext::request_id).unwrap_or(""),
273 http.path = %path,
274 "http error envelope"
275 );
276 message = "internal server error".to_owned();
279 details = serde_json::Value::Null;
280 code = default_code(status).to_owned();
281 }
282
283 let envelope = ProductionErrorBody {
284 error: ProductionErrorDetails {
285 status_code: status.as_u16(),
286 code,
287 message,
288 details,
289 timestamp: timestamp_now(),
290 path,
291 request_id: context
292 .as_ref()
293 .map(RequestContext::request_id)
294 .unwrap_or("")
295 .to_owned(),
296 },
297 };
298 let body = serde_json::to_vec(&envelope).expect("error envelope should serialize");
299 parts.headers.insert(
300 http::header::CONTENT_TYPE,
301 http::HeaderValue::from_static("application/json"),
302 );
303 axum::response::Response::from_parts(parts, Body::from(body))
304}
305
306const MAX_ERROR_ENVELOPE_BODY_BYTES: usize = 64 * 1024;
307
308async fn read_legacy_error_body(body: Body) -> Option<LegacyErrorBody> {
309 let bytes = to_bytes(body, MAX_ERROR_ENVELOPE_BODY_BYTES).await.ok()?;
310 serde_json::from_slice::<LegacyErrorBody>(&bytes).ok()
311}
312
313pub(crate) fn timestamp_now() -> String {
315 time::OffsetDateTime::now_utc()
316 .format(&time::format_description::well_known::Rfc3339)
317 .expect("UTC timestamp should format as RFC3339")
318}
319
320fn default_code(status: StatusCode) -> &'static str {
321 match status {
322 StatusCode::BAD_REQUEST => "bad_request",
323 StatusCode::UNAUTHORIZED => "unauthorized",
324 StatusCode::FORBIDDEN => "forbidden",
325 StatusCode::NOT_FOUND => "not_found",
326 StatusCode::CONFLICT => "conflict",
327 StatusCode::UNPROCESSABLE_ENTITY => "unprocessable_entity",
328 StatusCode::TOO_MANY_REQUESTS => "too_many_requests",
329 status if status.is_server_error() => "internal_server_error",
330 _ => "http_error",
331 }
332}
333
334#[derive(Debug, Deserialize)]
335struct LegacyErrorBody {
336 error: LegacyErrorDetails,
337}
338
339#[derive(Debug, Deserialize)]
340struct LegacyErrorDetails {
341 code: String,
342 message: String,
343 #[serde(flatten)]
344 details: serde_json::Map<String, serde_json::Value>,
345}
346
347#[derive(Debug, Serialize)]
348struct ProductionErrorBody {
349 error: ProductionErrorDetails,
350}
351
352#[derive(Debug, Serialize)]
353#[serde(rename_all = "camelCase")]
354struct ProductionErrorDetails {
355 status_code: u16,
356 code: String,
357 message: String,
358 details: serde_json::Value,
359 timestamp: String,
360 path: String,
361 request_id: String,
362}
363
364#[derive(Clone, Debug, Eq, PartialEq, thiserror::Error)]
366#[error("route path `{path}` contains a parameter segment without a name after ':'")]
367pub struct RoutePathError {
368 path: String,
369}
370
371impl RoutePathError {
372 pub fn empty_parameter(path: impl Into<String>) -> Self {
374 Self { path: path.into() }
375 }
376
377 pub fn path(&self) -> &str {
379 &self.path
380 }
381}