rama_http/layer/set_status.rs
1//! Middleware to override status codes.
2//!
3//! # Example
4//!
5//! ```
6//! use std::{iter::once, convert::Infallible};
7//! use bytes::Bytes;
8//! use rama_http::layer::set_status::SetStatusLayer;
9//! use rama_http::{Body, Request, Response, StatusCode};
10//! use rama_core::service::service_fn;
11//! use rama_core::{Context, Layer, Service};
12//! use rama_core::error::BoxError;
13//!
14//! async fn handle(req: Request) -> Result<Response, Infallible> {
15//! // ...
16//! # Ok(Response::new(Body::empty()))
17//! }
18//!
19//! # #[tokio::main]
20//! # async fn main() -> Result<(), BoxError> {
21//! let mut service = (
//!     // change the status to `404 Not Found` regardless of what the inner service returns
23//! SetStatusLayer::new(StatusCode::NOT_FOUND),
24//! ).into_layer(service_fn(handle));
25//!
26//! // Call the service.
27//! let request = Request::builder().body(Body::empty())?;
28//!
29//! let response = service.serve(Context::default(), request).await?;
30//!
31//! assert_eq!(response.status(), StatusCode::NOT_FOUND);
32//! #
33//! # Ok(())
34//! # }
35//! ```
36
37use std::fmt;
38
39use crate::{Request, Response, StatusCode};
40use rama_core::{Context, Layer, Service};
41use rama_utils::macros::define_inner_service_accessors;
42
43/// Layer that applies [`SetStatus`] which overrides the status codes.
44#[derive(Debug, Clone)]
45pub struct SetStatusLayer {
46 status: StatusCode,
47}
48
49impl SetStatusLayer {
50 /// Create a new [`SetStatusLayer`].
51 ///
52 /// The response status code will be `status` regardless of what the inner service returns.
53 pub const fn new(status: StatusCode) -> Self {
54 SetStatusLayer { status }
55 }
56
57 /// Create a new [`SetStatusLayer`] layer which will create
58 /// a service that will always set the status code at [`StatusCode::OK`].
59 #[inline]
60 pub const fn ok() -> Self {
61 Self::new(StatusCode::OK)
62 }
63}
64
65impl<S> Layer<S> for SetStatusLayer {
66 type Service = SetStatus<S>;
67
68 fn layer(&self, inner: S) -> Self::Service {
69 SetStatus::new(inner, self.status)
70 }
71}
72
73/// Middleware to override status codes.
74///
75/// See the [module docs](self) for more details.
76pub struct SetStatus<S> {
77 inner: S,
78 status: StatusCode,
79}
80
81impl<S> SetStatus<S> {
82 /// Create a new [`SetStatus`].
83 ///
84 /// The response status code will be `status` regardless of what the inner service returns.
85 pub const fn new(inner: S, status: StatusCode) -> Self {
86 Self { status, inner }
87 }
88
89 /// Create a new [`SetStatus`] service which will always set the
90 /// status code at [`StatusCode::OK`].
91 #[inline]
92 pub const fn ok(inner: S) -> Self {
93 Self::new(inner, StatusCode::OK)
94 }
95
96 define_inner_service_accessors!();
97}
98
99impl<S: fmt::Debug> fmt::Debug for SetStatus<S> {
100 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
101 f.debug_struct("SetStatus")
102 .field("inner", &self.inner)
103 .field("status", &self.status)
104 .finish()
105 }
106}
107
108impl<S: Clone> Clone for SetStatus<S> {
109 fn clone(&self) -> Self {
110 SetStatus {
111 inner: self.inner.clone(),
112 status: self.status,
113 }
114 }
115}
116
117impl<S: Copy> Copy for SetStatus<S> {}
118
119impl<State, S, ReqBody, ResBody> Service<State, Request<ReqBody>> for SetStatus<S>
120where
121 State: Clone + Send + Sync + 'static,
122 S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
123 ReqBody: Send + 'static,
124 ResBody: Send + 'static,
125{
126 type Response = S::Response;
127 type Error = S::Error;
128
129 async fn serve(
130 &self,
131 ctx: Context<State>,
132 req: Request<ReqBody>,
133 ) -> Result<Self::Response, Self::Error> {
134 let mut response = self.inner.serve(ctx, req).await?;
135 *response.status_mut() = self.status;
136 Ok(response)
137 }
138}