rama_http/layer/
set_status.rs

//! Middleware to override status codes.
//!
//! # Example
//!
//! ```
//! use std::convert::Infallible;
//! use rama_http::layer::set_status::SetStatusLayer;
//! use rama_http::{Body, Request, Response, StatusCode};
//! use rama_core::service::service_fn;
//! use rama_core::{Context, Layer, Service};
//! use rama_core::error::BoxError;
//!
//! async fn handle(_req: Request) -> Result<Response, Infallible> {
//!     // ...
//!     # Ok(Response::new(Body::empty()))
//! }
//!
//! # #[tokio::main]
//! # async fn main() -> Result<(), BoxError> {
//! let service = (
//!     // change the status to `404 Not Found` regardless of what the inner service returns
//!     SetStatusLayer::new(StatusCode::NOT_FOUND),
//! ).into_layer(service_fn(handle));
//!
//! // Call the service.
//! let request = Request::builder().body(Body::empty())?;
//!
//! let response = service.serve(Context::default(), request).await?;
//!
//! assert_eq!(response.status(), StatusCode::NOT_FOUND);
//! #
//! # Ok(())
//! # }
//! ```
36
37use std::fmt;
38
39use crate::{Request, Response, StatusCode};
40use rama_core::{Context, Layer, Service};
41use rama_utils::macros::define_inner_service_accessors;
42
43/// Layer that applies [`SetStatus`] which overrides the status codes.
44#[derive(Debug, Clone)]
45pub struct SetStatusLayer {
46    status: StatusCode,
47}
48
49impl SetStatusLayer {
50    /// Create a new [`SetStatusLayer`].
51    ///
52    /// The response status code will be `status` regardless of what the inner service returns.
53    pub const fn new(status: StatusCode) -> Self {
54        SetStatusLayer { status }
55    }
56
57    /// Create a new [`SetStatusLayer`] layer which will create
58    /// a service that will always set the status code at [`StatusCode::OK`].
59    #[inline]
60    pub const fn ok() -> Self {
61        Self::new(StatusCode::OK)
62    }
63}
64
65impl<S> Layer<S> for SetStatusLayer {
66    type Service = SetStatus<S>;
67
68    fn layer(&self, inner: S) -> Self::Service {
69        SetStatus::new(inner, self.status)
70    }
71}
72
73/// Middleware to override status codes.
74///
75/// See the [module docs](self) for more details.
76pub struct SetStatus<S> {
77    inner: S,
78    status: StatusCode,
79}
80
81impl<S> SetStatus<S> {
82    /// Create a new [`SetStatus`].
83    ///
84    /// The response status code will be `status` regardless of what the inner service returns.
85    pub const fn new(inner: S, status: StatusCode) -> Self {
86        Self { status, inner }
87    }
88
89    /// Create a new [`SetStatus`] service which will always set the
90    /// status code at [`StatusCode::OK`].
91    #[inline]
92    pub const fn ok(inner: S) -> Self {
93        Self::new(inner, StatusCode::OK)
94    }
95
96    define_inner_service_accessors!();
97}
98
99impl<S: fmt::Debug> fmt::Debug for SetStatus<S> {
100    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
101        f.debug_struct("SetStatus")
102            .field("inner", &self.inner)
103            .field("status", &self.status)
104            .finish()
105    }
106}
107
108impl<S: Clone> Clone for SetStatus<S> {
109    fn clone(&self) -> Self {
110        SetStatus {
111            inner: self.inner.clone(),
112            status: self.status,
113        }
114    }
115}
116
117impl<S: Copy> Copy for SetStatus<S> {}
118
119impl<State, S, ReqBody, ResBody> Service<State, Request<ReqBody>> for SetStatus<S>
120where
121    State: Clone + Send + Sync + 'static,
122    S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
123    ReqBody: Send + 'static,
124    ResBody: Send + 'static,
125{
126    type Response = S::Response;
127    type Error = S::Error;
128
129    async fn serve(
130        &self,
131        ctx: Context<State>,
132        req: Request<ReqBody>,
133    ) -> Result<Self::Response, Self::Error> {
134        let mut response = self.inner.serve(ctx, req).await?;
135        *response.status_mut() = self.status;
136        Ok(response)
137    }
138}