rama_http/layer/set_header/response/
mod.rs

1//! Set a header on the response.
2//!
3//! The header value to be set may be provided as a fixed value when the
4//! middleware is constructed, or determined dynamically based on the response
5//! by a closure. See the [`MakeHeaderValue`] trait for details.
6//!
7//! # Example
8//!
9//! Setting a header from a fixed value provided when the middleware is constructed:
10//!
11//! ```
12//! use rama_http::layer::set_header::SetResponseHeaderLayer;
13//! use rama_http::{Body, Request, Response, header::{self, HeaderValue}};
14//! use rama_core::service::service_fn;
15//! use rama_core::{Context, Service, Layer};
16//! use rama_core::error::BoxError;
17//!
18//! # #[tokio::main]
19//! # async fn main() -> Result<(), BoxError> {
20//! # let render_html = service_fn(async |request: Request| {
21//! #     Ok::<_, std::convert::Infallible>(Response::new(request.into_body()))
22//! # });
23//! #
//! let svc = (
25//!     // Layer that sets `Content-Type: text/html` on responses.
26//!     //
27//!     // `if_not_present` will only insert the header if it does not already
28//!     // have a value.
29//!     SetResponseHeaderLayer::if_not_present(
30//!         header::CONTENT_TYPE,
31//!         HeaderValue::from_static("text/html"),
32//!     ),
33//! ).into_layer(render_html);
34//!
35//! let request = Request::new(Body::empty());
36//!
37//! let response = svc.serve(Context::default(), request).await?;
38//!
39//! assert_eq!(response.headers()["content-type"], "text/html");
40//! #
41//! # Ok(())
42//! # }
43//! ```
44//!
45//! Setting a header based on a value determined dynamically from the response:
46//!
47//! ```
48//! use rama_core::error::BoxError;
49//! use rama_core::service::service_fn;
50//! use rama_core::{Context, Layer, Service};
51//! use rama_http::dep::http_body::Body as _;
52//! use rama_http::layer::set_header::SetResponseHeaderLayer;
53//! use rama_http::{
54//!     header::{self, HeaderValue},
55//!     Body, Request, Response,
56//! };
57//!
58//! #[tokio::main]
59//! async fn main() -> Result<(), BoxError> {
60//!     let render_html = service_fn(async |_request: Request| {
61//!         Ok::<_, std::convert::Infallible>(Response::new(Body::from("1234567890")))
62//!     });
63//!
64//!     let svc = (
65//!         // Layer that sets `Content-Length` if the body has a known size.
//!         // Streaming response bodies won't have a known size.
67//!         //
68//!         // `overriding` will insert the header and override any previous values it
69//!         // may have.
70//!         SetResponseHeaderLayer::overriding_fn(
71//!             header::CONTENT_LENGTH,
72//!             async |response: Response| {
73//!                 let value = if let Some(size) = response.body().size_hint().exact() {
74//!                     // If the response body has a known size, returning `Some` will
75//!                     // set the `Content-Length` header to that value.
76//!                     Some(HeaderValue::from_str(&size.to_string()).unwrap())
77//!                 } else {
78//!                     // If the response body doesn't have a known size, return `None`
79//!                     // to skip setting the header on this response.
80//!                     None
81//!                 };
82//!                 (response, value)
83//!             },
84//!         ),
85//!     )
86//!         .into_layer(render_html);
87//!
88//!     let request = Request::new(Body::empty());
89//!
90//!     let response = svc.serve(Context::default(), request).await?;
91//!
92//!     assert_eq!(response.headers()["content-length"], "10");
93//!
94//!     Ok(())
95//! }
96//! ```
97//!
//! Setting a header based on both the incoming [`Context`] and the response:
99//!
100//! ```
101//! use rama_core::{service::service_fn, Context, Service};
102//! use rama_http::{
103//!     layer::set_header::{response::BoxMakeHeaderValueFn, SetResponseHeader},
104//!     Body, HeaderName, HeaderValue, Request, Response,
105//!     service::web::response::IntoResponse,
106//! };
107//! use std::convert::Infallible;
108//!
109//! #[tokio::main]
110//! async fn main() {
111//!     #[derive(Debug, Clone)]
112//!     struct RequestID(String);
113//!
114//!     #[derive(Debug, Clone)]
115//!     struct Success;
116//!
117//!     let svc = SetResponseHeader::overriding_fn(
118//!         service_fn(async || {
119//!             let mut res = ().into_response();
120//!             res.extensions_mut().insert(Success);
121//!             Ok::<_, Infallible>(res)
122//!         }),
123//!         HeaderName::from_static("x-used-request-id"),
124//!         async |ctx: Context<()>| {
125//!             let factory = ctx.get::<RequestID>().cloned().map(|id| {
126//!                 BoxMakeHeaderValueFn::new(async move |res: Response| {
127//!                     let header_value = res.extensions().get::<Success>().map(|_| {
128//!                         HeaderValue::from_str(id.0.as_str()).unwrap()
129//!                     });
130//!                     (res, header_value)
131//!                 })
132//!             });
133//!             (ctx, factory)
134//!         },
135//!     );
136//!
137//!     const FAKE_USER_ID: &str = "abc123";
138//!
139//!     let mut ctx = Context::default();
140//!     ctx.insert(RequestID(FAKE_USER_ID.to_owned()));
141//!
142//!     let res = svc.serve(ctx, Request::new(Body::empty())).await.unwrap();
143//!
144//!     let mut values = res
145//!         .headers()
146//!         .get_all(HeaderName::from_static("x-used-request-id"))
147//!         .iter();
148//!     assert_eq!(values.next().unwrap(), FAKE_USER_ID);
149//!     assert_eq!(values.next(), None);
150//! }
151//! ```
152
153use crate::{HeaderValue, Request, Response, header::HeaderName, headers::Header};
154use rama_core::{Context, Layer, Service};
155use rama_utils::macros::define_inner_service_accessors;
156use std::fmt;
157
158mod header;
159use header::InsertHeaderMode;
160
161pub use header::{
162    BoxMakeHeaderValueFactoryFn, BoxMakeHeaderValueFn, MakeHeaderValue, MakeHeaderValueFactory,
163};
164
165/// Layer that applies [`SetResponseHeader`] which adds a response header.
166///
167/// See [`SetResponseHeader`] for more details.
168pub struct SetResponseHeaderLayer<M> {
169    header_name: HeaderName,
170    make: M,
171    mode: InsertHeaderMode,
172}
173
174impl<M> fmt::Debug for SetResponseHeaderLayer<M> {
175    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
176        f.debug_struct("SetResponseHeaderLayer")
177            .field("header_name", &self.header_name)
178            .field("mode", &self.mode)
179            .field("make", &std::any::type_name::<M>())
180            .finish()
181    }
182}
183
184impl SetResponseHeaderLayer<HeaderValue> {
185    /// Create a new [`SetResponseHeaderLayer`] from a typed [`Header`].
186    ///
187    /// See [`SetResponseHeaderLayer::overriding`] for more details.
188    pub fn overriding_typed<H: Header>(header: H) -> Self {
189        Self::overriding(H::name().clone(), header.encode_to_value())
190    }
191
192    /// Create a new [`SetResponseHeaderLayer`] from a typed [`Header`].
193    ///
194    /// See [`SetResponseHeaderLayer::appending`] for more details.
195    pub fn appending_typed<H: Header>(header: H) -> Self {
196        Self::appending(H::name().clone(), header.encode_to_value())
197    }
198
199    /// Create a new [`SetResponseHeaderLayer`] from a typed [`Header`].
200    ///
201    /// See [`SetResponseHeaderLayer::if_not_present`] for more details.
202    pub fn if_not_present_typed<H: Header>(header: H) -> Self {
203        Self::if_not_present(H::name().clone(), header.encode_to_value())
204    }
205}
206
207impl<M> SetResponseHeaderLayer<M> {
208    /// Create a new [`SetResponseHeaderLayer`].
209    ///
210    /// If a previous value exists for the same header, it is removed and replaced with the new
211    /// header value.
212    pub fn overriding(header_name: HeaderName, make: M) -> Self {
213        Self::new(header_name, make, InsertHeaderMode::Override)
214    }
215
216    /// Create a new [`SetResponseHeaderLayer`].
217    ///
218    /// The new header is always added, preserving any existing values. If previous values exist,
219    /// the header will have multiple values.
220    pub fn appending(header_name: HeaderName, make: M) -> Self {
221        Self::new(header_name, make, InsertHeaderMode::Append)
222    }
223
224    /// Create a new [`SetResponseHeaderLayer`].
225    ///
226    /// If a previous value exists for the header, the new value is not inserted.
227    pub fn if_not_present(header_name: HeaderName, make: M) -> Self {
228        Self::new(header_name, make, InsertHeaderMode::IfNotPresent)
229    }
230
231    fn new(header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self {
232        Self {
233            make,
234            header_name,
235            mode,
236        }
237    }
238}
239
240impl<F, A> SetResponseHeaderLayer<BoxMakeHeaderValueFactoryFn<F, A>> {
241    /// Create a new [`SetResponseHeaderLayer`] from a [`super::MakeHeaderValueFn`].
242    ///
243    /// See [`SetResponseHeaderLayer::overriding`] for more details.
244    pub fn overriding_fn(header_name: HeaderName, make_fn: F) -> Self {
245        Self::new(
246            header_name,
247            BoxMakeHeaderValueFactoryFn::new(make_fn),
248            InsertHeaderMode::Override,
249        )
250    }
251
252    /// Create a new [`SetResponseHeaderLayer`] from a [`super::MakeHeaderValueFn`].
253    ///
254    /// See [`SetResponseHeaderLayer::appending`] for more details.
255    pub fn appending_fn(header_name: HeaderName, make_fn: F) -> Self {
256        Self::new(
257            header_name,
258            BoxMakeHeaderValueFactoryFn::new(make_fn),
259            InsertHeaderMode::Append,
260        )
261    }
262
263    /// Create a new [`SetResponseHeaderLayer`] from a [`super::MakeHeaderValueFn`].
264    ///
265    /// See [`SetResponseHeaderLayer::if_not_present`] for more details.
266    pub fn if_not_present_fn(header_name: HeaderName, make_fn: F) -> Self {
267        Self::new(
268            header_name,
269            BoxMakeHeaderValueFactoryFn::new(make_fn),
270            InsertHeaderMode::IfNotPresent,
271        )
272    }
273}
274
275impl<S, M> Layer<S> for SetResponseHeaderLayer<M>
276where
277    M: Clone,
278{
279    type Service = SetResponseHeader<S, M>;
280
281    fn layer(&self, inner: S) -> Self::Service {
282        SetResponseHeader {
283            inner,
284            header_name: self.header_name.clone(),
285            make: self.make.clone(),
286            mode: self.mode,
287        }
288    }
289
290    fn into_layer(self, inner: S) -> Self::Service {
291        SetResponseHeader {
292            inner,
293            header_name: self.header_name,
294            make: self.make,
295            mode: self.mode,
296        }
297    }
298}
299
300impl<M> Clone for SetResponseHeaderLayer<M>
301where
302    M: Clone,
303{
304    fn clone(&self) -> Self {
305        Self {
306            make: self.make.clone(),
307            header_name: self.header_name.clone(),
308            mode: self.mode,
309        }
310    }
311}
312
313/// Middleware that sets a header on the response.
314#[derive(Clone)]
315pub struct SetResponseHeader<S, M> {
316    inner: S,
317    header_name: HeaderName,
318    make: M,
319    mode: InsertHeaderMode,
320}
321
322impl<S, M> SetResponseHeader<S, M> {
323    /// Create a new [`SetResponseHeader`].
324    ///
325    /// If a previous value exists for the same header, it is removed and replaced with the new
326    /// header value.
327    pub fn overriding(inner: S, header_name: HeaderName, make: M) -> Self {
328        Self::new(inner, header_name, make, InsertHeaderMode::Override)
329    }
330
331    /// Create a new [`SetResponseHeader`].
332    ///
333    /// The new header is always added, preserving any existing values. If previous values exist,
334    /// the header will have multiple values.
335    pub fn appending(inner: S, header_name: HeaderName, make: M) -> Self {
336        Self::new(inner, header_name, make, InsertHeaderMode::Append)
337    }
338
339    /// Create a new [`SetResponseHeader`].
340    ///
341    /// If a previous value exists for the header, the new value is not inserted.
342    pub fn if_not_present(inner: S, header_name: HeaderName, make: M) -> Self {
343        Self::new(inner, header_name, make, InsertHeaderMode::IfNotPresent)
344    }
345
346    fn new(inner: S, header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self {
347        Self {
348            inner,
349            header_name,
350            make,
351            mode,
352        }
353    }
354
355    define_inner_service_accessors!();
356}
357
358impl<S, F, A> SetResponseHeader<S, BoxMakeHeaderValueFactoryFn<F, A>> {
359    /// Create a new [`SetResponseHeader`] from a [`super::MakeHeaderValueFn`].
360    ///
361    /// See [`SetResponseHeader::overriding`] for more details.
362    pub fn overriding_fn(inner: S, header_name: HeaderName, make_fn: F) -> Self {
363        Self::new(
364            inner,
365            header_name,
366            BoxMakeHeaderValueFactoryFn::new(make_fn),
367            InsertHeaderMode::Override,
368        )
369    }
370
371    /// Create a new [`SetResponseHeader`] from a [`super::MakeHeaderValueFn`].
372    ///
373    /// See [`SetResponseHeader::appending`] for more details.
374    pub fn appending_fn(inner: S, header_name: HeaderName, make_fn: F) -> Self {
375        Self::new(
376            inner,
377            header_name,
378            BoxMakeHeaderValueFactoryFn::new(make_fn),
379            InsertHeaderMode::Append,
380        )
381    }
382
383    /// Create a new [`SetResponseHeader`] from a [`super::MakeHeaderValueFn`].
384    ///
385    /// See [`SetResponseHeader::if_not_present`] for more details.
386    pub fn if_not_present_fn(inner: S, header_name: HeaderName, make_fn: F) -> Self {
387        Self::new(
388            inner,
389            header_name,
390            BoxMakeHeaderValueFactoryFn::new(make_fn),
391            InsertHeaderMode::IfNotPresent,
392        )
393    }
394}
395
396impl<S, M> fmt::Debug for SetResponseHeader<S, M>
397where
398    S: fmt::Debug,
399{
400    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
401        f.debug_struct("SetResponseHeader")
402            .field("inner", &self.inner)
403            .field("header_name", &self.header_name)
404            .field("mode", &self.mode)
405            .field("make", &std::any::type_name::<M>())
406            .finish()
407    }
408}
409
410impl<ReqBody, ResBody, State, S, M> Service<State, Request<ReqBody>> for SetResponseHeader<S, M>
411where
412    ReqBody: Send + 'static,
413    ResBody: Send + 'static,
414    State: Clone + Send + Sync + 'static,
415    S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
416    M: MakeHeaderValueFactory<State, ReqBody, ResBody>,
417{
418    type Response = S::Response;
419    type Error = S::Error;
420
421    async fn serve(
422        &self,
423        ctx: Context<State>,
424        req: Request<ReqBody>,
425    ) -> Result<Self::Response, Self::Error> {
426        let (ctx, req, header_maker) = self.make.make_header_value_maker(ctx, req).await;
427        let res = self.inner.serve(ctx, req).await?;
428        let res = self.mode.apply(&self.header_name, res, header_maker).await;
429        Ok(res)
430    }
431}
432
#[cfg(test)]
mod tests {
    use super::*;

    use crate::{Body, HeaderValue, Request, Response, header};
    use rama_core::service::service_fn;
    use std::convert::Infallible;

    /// Response returned by inner services that already set their own `Content-Type`.
    fn response_with_content_type() -> Response {
        Response::builder()
            .header(header::CONTENT_TYPE, "good-content")
            .body(Body::empty())
            .unwrap()
    }

    /// Serves an empty request through `svc` and collects all `Content-Type` values in order.
    async fn content_type_values<S>(svc: S) -> Vec<String>
    where
        S: Service<(), Request, Response = Response, Error = Infallible>,
    {
        let res = svc
            .serve(Context::default(), Request::new(Body::empty()))
            .await
            .unwrap();
        res.headers()
            .get_all(header::CONTENT_TYPE)
            .iter()
            .map(|value| value.to_str().unwrap().to_owned())
            .collect()
    }

    #[tokio::test]
    async fn test_override_mode() {
        let svc = SetResponseHeader::overriding(
            service_fn(async || Ok::<_, Infallible>(response_with_content_type())),
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );

        assert_eq!(content_type_values(svc).await, ["text/html"]);
    }

    #[tokio::test]
    async fn test_append_mode() {
        let svc = SetResponseHeader::appending(
            service_fn(async || Ok::<_, Infallible>(response_with_content_type())),
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );

        assert_eq!(
            content_type_values(svc).await,
            ["good-content", "text/html"]
        );
    }

    #[tokio::test]
    async fn test_skip_if_present_mode() {
        let svc = SetResponseHeader::if_not_present(
            service_fn(async || Ok::<_, Infallible>(response_with_content_type())),
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );

        assert_eq!(content_type_values(svc).await, ["good-content"]);
    }

    #[tokio::test]
    async fn test_skip_if_present_mode_when_not_present() {
        let svc = SetResponseHeader::if_not_present(
            service_fn(async || {
                Ok::<_, Infallible>(Response::builder().body(Body::empty()).unwrap())
            }),
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );

        assert_eq!(content_type_values(svc).await, ["text/html"]);
    }
}