rama_http/layer/required_header/
response.rs

//! Set required headers on the response if they are missing
//! (or always, when overwriting is enabled).
//!
//! For now this only sets the `Server` and `Date` headers.
4
5use crate::{
6    HeaderValue, Request, Response,
7    header::{self, DATE, RAMA_ID_HEADER_VALUE, SERVER},
8    headers::{Date, HeaderMapExt},
9};
10use rama_core::{Context, Layer, Service};
11use rama_utils::macros::define_inner_service_accessors;
12use std::{fmt, time::SystemTime};
13
14/// Layer that applies [`AddRequiredResponseHeaders`] which adds a request header.
15///
16/// See [`AddRequiredResponseHeaders`] for more details.
17#[derive(Debug, Clone, Default)]
18pub struct AddRequiredResponseHeadersLayer {
19    overwrite: bool,
20    server_header_value: Option<HeaderValue>,
21}
22
23impl AddRequiredResponseHeadersLayer {
24    /// Create a new [`AddRequiredResponseHeadersLayer`].
25    pub const fn new() -> Self {
26        Self {
27            overwrite: false,
28            server_header_value: None,
29        }
30    }
31
32    /// Set whether to overwrite the existing headers.
33    /// If set to `true`, the headers will be overwritten.
34    ///
35    /// Default is `false`.
36    pub const fn overwrite(mut self, overwrite: bool) -> Self {
37        self.overwrite = overwrite;
38        self
39    }
40
41    /// Set whether to overwrite the existing headers.
42    /// If set to `true`, the headers will be overwritten.
43    ///
44    /// Default is `false`.
45    pub fn set_overwrite(&mut self, overwrite: bool) -> &mut Self {
46        self.overwrite = overwrite;
47        self
48    }
49
50    /// Set a custom [`SERVER`] header value.
51    ///
52    /// By default a versioned `rama` value is used.
53    pub fn server_header_value(mut self, value: HeaderValue) -> Self {
54        self.server_header_value = Some(value);
55        self
56    }
57
58    /// Maybe set a custom [`SERVER`] header value.
59    ///
60    /// By default a versioned `rama` value is used.
61    pub fn maybe_server_header_value(mut self, value: Option<HeaderValue>) -> Self {
62        self.server_header_value = value;
63        self
64    }
65
66    /// Set a custom [`SERVER`] header value.
67    ///
68    /// By default a versioned `rama` value is used.
69    pub fn set_server_header_value(&mut self, value: HeaderValue) -> &mut Self {
70        self.server_header_value = Some(value);
71        self
72    }
73}
74
75impl<S> Layer<S> for AddRequiredResponseHeadersLayer {
76    type Service = AddRequiredResponseHeaders<S>;
77
78    fn layer(&self, inner: S) -> Self::Service {
79        AddRequiredResponseHeaders {
80            inner,
81            overwrite: self.overwrite,
82            server_header_value: self.server_header_value.clone(),
83        }
84    }
85
86    fn into_layer(self, inner: S) -> Self::Service {
87        AddRequiredResponseHeaders {
88            inner,
89            overwrite: self.overwrite,
90            server_header_value: self.server_header_value,
91        }
92    }
93}
94
95/// Middleware that sets a header on the request.
96#[derive(Clone)]
97pub struct AddRequiredResponseHeaders<S> {
98    inner: S,
99    overwrite: bool,
100    server_header_value: Option<HeaderValue>,
101}
102
103impl<S> AddRequiredResponseHeaders<S> {
104    /// Create a new [`AddRequiredResponseHeaders`].
105    pub const fn new(inner: S) -> Self {
106        Self {
107            inner,
108            overwrite: false,
109            server_header_value: None,
110        }
111    }
112
113    /// Set whether to overwrite the existing headers.
114    /// If set to `true`, the headers will be overwritten.
115    ///
116    /// Default is `false`.
117    pub const fn overwrite(mut self, overwrite: bool) -> Self {
118        self.overwrite = overwrite;
119        self
120    }
121
122    /// Set whether to overwrite the existing headers.
123    /// If set to `true`, the headers will be overwritten.
124    ///
125    /// Default is `false`.
126    pub fn set_overwrite(&mut self, overwrite: bool) -> &mut Self {
127        self.overwrite = overwrite;
128        self
129    }
130
131    /// Set a custom [`SERVER`] header value.
132    ///
133    /// By default a versioned `rama` value is used.
134    pub fn server_header_value(mut self, value: HeaderValue) -> Self {
135        self.server_header_value = Some(value);
136        self
137    }
138
139    /// Maybe set a custom [`SERVER`] header value.
140    ///
141    /// By default a versioned `rama` value is used.
142    pub fn maybe_server_header_value(mut self, value: Option<HeaderValue>) -> Self {
143        self.server_header_value = value;
144        self
145    }
146
147    /// Set a custom [`SERVER`] header value.
148    ///
149    /// By default a versioned `rama` value is used.
150    pub fn set_server_header_value(&mut self, value: HeaderValue) -> &mut Self {
151        self.server_header_value = Some(value);
152        self
153    }
154
155    define_inner_service_accessors!();
156}
157
158impl<S> fmt::Debug for AddRequiredResponseHeaders<S>
159where
160    S: fmt::Debug,
161{
162    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
163        f.debug_struct("AddRequiredResponseHeaders")
164            .field("inner", &self.inner)
165            .field("server_header_value", &self.server_header_value)
166            .finish()
167    }
168}
169
170impl<ReqBody, ResBody, State, S> Service<State, Request<ReqBody>> for AddRequiredResponseHeaders<S>
171where
172    ReqBody: Send + 'static,
173    ResBody: Send + 'static,
174    State: Clone + Send + Sync + 'static,
175    S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
176{
177    type Response = S::Response;
178    type Error = S::Error;
179
180    async fn serve(
181        &self,
182        ctx: Context<State>,
183        req: Request<ReqBody>,
184    ) -> Result<Self::Response, Self::Error> {
185        let mut resp = self.inner.serve(ctx, req).await?;
186
187        if self.overwrite {
188            resp.headers_mut().insert(
189                SERVER,
190                self.server_header_value
191                    .as_ref()
192                    .unwrap_or(&RAMA_ID_HEADER_VALUE)
193                    .clone(),
194            );
195        } else if let header::Entry::Vacant(header) = resp.headers_mut().entry(SERVER) {
196            header.insert(
197                self.server_header_value
198                    .as_ref()
199                    .unwrap_or(&RAMA_ID_HEADER_VALUE)
200                    .clone(),
201            );
202        }
203
204        if self.overwrite || !resp.headers().contains_key(DATE) {
205            resp.headers_mut()
206                .typed_insert(Date::from(SystemTime::now()));
207        }
208
209        Ok(resp)
210    }
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Body;
    use rama_core::{Layer, service::service_fn};
    use std::convert::Infallible;

    #[tokio::test]
    async fn add_required_response_headers() {
        let svc = AddRequiredResponseHeadersLayer::default().into_layer(service_fn(
            async |_ctx: Context<()>, req: Request| {
                assert!(!req.headers().contains_key(SERVER));
                assert!(!req.headers().contains_key(DATE));
                Ok::<_, Infallible>(Response::new(Body::empty()))
            },
        ));

        let req = Request::new(Body::empty());
        let resp = svc.serve(Context::default(), req).await.unwrap();

        assert_eq!(
            resp.headers().get(SERVER).unwrap(),
            RAMA_ID_HEADER_VALUE.as_ref()
        );
        assert!(resp.headers().contains_key(DATE));
    }

    #[tokio::test]
    async fn add_required_response_headers_custom_server() {
        let svc = AddRequiredResponseHeadersLayer::default()
            .server_header_value(HeaderValue::from_static("foo"))
            .into_layer(service_fn(async |_ctx: Context<()>, req: Request| {
                assert!(!req.headers().contains_key(SERVER));
                assert!(!req.headers().contains_key(DATE));
                Ok::<_, Infallible>(Response::new(Body::empty()))
            }));

        let req = Request::new(Body::empty());
        let resp = svc.serve(Context::default(), req).await.unwrap();

        assert_eq!(
            resp.headers().get(SERVER).and_then(|v| v.to_str().ok()),
            Some("foo")
        );
        assert!(resp.headers().contains_key(DATE));
    }

    /// Without `overwrite`, headers set by the inner service must be kept as-is,
    /// even when a custom `Server` value is configured.
    #[tokio::test]
    async fn add_required_response_headers_keep_existing() {
        let svc = AddRequiredResponseHeadersLayer::new()
            .server_header_value(HeaderValue::from_static("custom"))
            .into_layer(service_fn(async |_ctx: Context<()>, _req: Request| {
                Ok::<_, Infallible>(
                    Response::builder()
                        .header(SERVER, "foo")
                        .header(DATE, "bar")
                        .body(Body::empty())
                        .unwrap(),
                )
            }));

        let req = Request::new(Body::empty());
        let resp = svc.serve(Context::default(), req).await.unwrap();

        assert_eq!(
            resp.headers().get(SERVER).and_then(|v| v.to_str().ok()),
            Some("foo")
        );
        assert_eq!(
            resp.headers().get(DATE).and_then(|v| v.to_str().ok()),
            Some("bar")
        );
    }

    #[tokio::test]
    async fn add_required_response_headers_overwrite() {
        let svc = AddRequiredResponseHeadersLayer::new()
            .overwrite(true)
            .into_layer(service_fn(async |_ctx: Context<()>, req: Request| {
                assert!(!req.headers().contains_key(SERVER));
                assert!(!req.headers().contains_key(DATE));
                Ok::<_, Infallible>(
                    Response::builder()
                        .header(SERVER, "foo")
                        .header(DATE, "bar")
                        .body(Body::empty())
                        .unwrap(),
                )
            }));

        let req = Request::new(Body::empty());
        let resp = svc.serve(Context::default(), req).await.unwrap();

        assert_eq!(
            resp.headers().get(SERVER).unwrap(),
            RAMA_ID_HEADER_VALUE.to_str().unwrap()
        );
        assert_ne!(resp.headers().get(DATE).unwrap(), "bar");
    }

    /// The inner service uses a `Server` value that differs from the configured one,
    /// so this test fails if overwriting does not actually happen.
    #[tokio::test]
    async fn add_required_response_headers_overwrite_custom_server() {
        let svc = AddRequiredResponseHeadersLayer::new()
            .overwrite(true)
            .server_header_value(HeaderValue::from_static("foo"))
            .into_layer(service_fn(async |_ctx: Context<()>, req: Request| {
                assert!(!req.headers().contains_key(SERVER));
                assert!(!req.headers().contains_key(DATE));
                Ok::<_, Infallible>(
                    Response::builder()
                        .header(SERVER, "inner-server")
                        .header(DATE, "bar")
                        .body(Body::empty())
                        .unwrap(),
                )
            }));

        let req = Request::new(Body::empty());
        let resp = svc.serve(Context::default(), req).await.unwrap();

        assert_eq!(
            resp.headers().get(SERVER).and_then(|v| v.to_str().ok()),
            Some("foo")
        );
        assert_ne!(resp.headers().get(DATE).unwrap(), "bar");
    }
}