rama_http/layer/set_header/response/
mod.rs

1//! Set a header on the response.
2//!
3//! The header value to be set may be provided as a fixed value when the
4//! middleware is constructed, or determined dynamically based on the response
5//! by a closure. See the [`MakeHeaderValue`] trait for details.
6//!
7//! # Example
8//!
9//! Setting a header from a fixed value provided when the middleware is constructed:
10//!
11//! ```
12//! use rama_http::layer::set_header::SetResponseHeaderLayer;
13//! use rama_http::{Body, Request, Response, header::{self, HeaderValue}};
14//! use rama_core::service::service_fn;
15//! use rama_core::{Context, Service, Layer};
16//! use rama_core::error::BoxError;
17//!
18//! # #[tokio::main]
19//! # async fn main() -> Result<(), BoxError> {
20//! # let render_html = service_fn(async |request: Request| {
21//! #     Ok::<_, std::convert::Infallible>(Response::new(request.into_body()))
22//! # });
23//! #
24//! let mut svc = (
25//!     // Layer that sets `Content-Type: text/html` on responses.
26//!     //
27//!     // `if_not_present` will only insert the header if it does not already
28//!     // have a value.
29//!     SetResponseHeaderLayer::if_not_present(
30//!         header::CONTENT_TYPE,
31//!         HeaderValue::from_static("text/html"),
32//!     ),
33//! ).into_layer(render_html);
34//!
35//! let request = Request::new(Body::empty());
36//!
37//! let response = svc.serve(Context::default(), request).await?;
38//!
39//! assert_eq!(response.headers()["content-type"], "text/html");
40//! #
41//! # Ok(())
42//! # }
43//! ```
44//!
45//! Setting a header based on a value determined dynamically from the response:
46//!
47//! ```
48//! use rama_core::error::BoxError;
49//! use rama_core::service::service_fn;
50//! use rama_core::{Context, Layer, Service};
51//! use rama_http::dep::http_body::Body as _;
52//! use rama_http::layer::set_header::SetResponseHeaderLayer;
53//! use rama_http::{
54//!     header::{self, HeaderValue},
55//!     Body, Request, Response,
56//! };
57//!
58//! #[tokio::main]
59//! async fn main() -> Result<(), BoxError> {
60//!     let render_html = service_fn(async |_request: Request| {
61//!         Ok::<_, std::convert::Infallible>(Response::new(Body::from("1234567890")))
62//!     });
63//!
64//!     let svc = (
65//!         // Layer that sets `Content-Length` if the body has a known size.
//!         // Bodies with streaming responses won't have a known size.
67//!         //
68//!         // `overriding` will insert the header and override any previous values it
69//!         // may have.
70//!         SetResponseHeaderLayer::overriding_fn(
71//!             header::CONTENT_LENGTH,
72//!             async |response: Response| {
73//!                 let value = if let Some(size) = response.body().size_hint().exact() {
74//!                     // If the response body has a known size, returning `Some` will
75//!                     // set the `Content-Length` header to that value.
76//!                     Some(HeaderValue::from_str(&size.to_string()).unwrap())
77//!                 } else {
78//!                     // If the response body doesn't have a known size, return `None`
79//!                     // to skip setting the header on this response.
80//!                     None
81//!                 };
82//!                 (response, value)
83//!             },
84//!         ),
85//!     )
86//!         .into_layer(render_html);
87//!
88//!     let request = Request::new(Body::empty());
89//!
90//!     let response = svc.serve(Context::default(), request).await?;
91//!
92//!     assert_eq!(response.headers()["content-length"], "10");
93//!
94//!     Ok(())
95//! }
96//! ```
97//!
//! Setting a header based on both the incoming `Context` and the response:
99//!
100//! ```
101//! use rama_core::{service::service_fn, Context, Service};
102//! use rama_http::{
103//!     layer::set_header::{response::BoxMakeHeaderValueFn, SetResponseHeader},
104//!     Body, HeaderName, HeaderValue, IntoResponse, Request, Response,
105//! };
106//! use std::convert::Infallible;
107//!
108//! #[tokio::main]
109//! async fn main() {
110//!     #[derive(Debug, Clone)]
111//!     struct RequestID(String);
112//!
113//!     #[derive(Debug, Clone)]
114//!     struct Success;
115//!
116//!     let svc = SetResponseHeader::overriding_fn(
117//!         service_fn(async || {
118//!             let mut res = ().into_response();
119//!             res.extensions_mut().insert(Success);
120//!             Ok::<_, Infallible>(res)
121//!         }),
122//!         HeaderName::from_static("x-used-request-id"),
123//!         async |ctx: Context<()>| {
124//!             let factory = ctx.get::<RequestID>().cloned().map(|id| {
125//!                 BoxMakeHeaderValueFn::new(async move |res: Response| {
126//!                     let header_value = res.extensions().get::<Success>().map(|_| {
127//!                         HeaderValue::from_str(id.0.as_str()).unwrap()
128//!                     });
129//!                     (res, header_value)
130//!                 })
131//!             });
132//!             (ctx, factory)
133//!         },
134//!     );
135//!
136//!     const FAKE_USER_ID: &str = "abc123";
137//!
138//!     let mut ctx = Context::default();
139//!     ctx.insert(RequestID(FAKE_USER_ID.to_owned()));
140//!
141//!     let res = svc.serve(ctx, Request::new(Body::empty())).await.unwrap();
142//!
143//!     let mut values = res
144//!         .headers()
145//!         .get_all(HeaderName::from_static("x-used-request-id"))
146//!         .iter();
147//!     assert_eq!(values.next().unwrap(), FAKE_USER_ID);
148//!     assert_eq!(values.next(), None);
149//! }
150//! ```
151
152use crate::{
153    HeaderValue, Request, Response,
154    header::HeaderName,
155    headers::{Header, HeaderExt},
156};
157use rama_core::{Context, Layer, Service};
158use rama_utils::macros::define_inner_service_accessors;
159use std::fmt;
160
161mod header;
162use header::InsertHeaderMode;
163
164pub use header::{
165    BoxMakeHeaderValueFactoryFn, BoxMakeHeaderValueFn, MakeHeaderValue, MakeHeaderValueFactory,
166};
167
/// Layer that applies [`SetResponseHeader`] which adds a response header.
///
/// The layer only stores configuration: the header name, the value maker (`make`)
/// and the insertion mode. Services produced via [`Layer::layer`] receive clones of
/// this configuration, while [`Layer::into_layer`] moves it into the service.
///
/// See [`SetResponseHeader`] for more details.
pub struct SetResponseHeaderLayer<M> {
    /// Name of the header that is set on the response.
    header_name: HeaderName,
    /// Produces the header value: a fixed [`HeaderValue`] for the plain/typed
    /// constructors, or a [`BoxMakeHeaderValueFactoryFn`] for the `*_fn` constructors.
    make: M,
    /// How the value is inserted: override, append, or only if not already present.
    mode: InsertHeaderMode,
}
176
177impl<M> fmt::Debug for SetResponseHeaderLayer<M> {
178    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
179        f.debug_struct("SetResponseHeaderLayer")
180            .field("header_name", &self.header_name)
181            .field("mode", &self.mode)
182            .field("make", &std::any::type_name::<M>())
183            .finish()
184    }
185}
186
187impl SetResponseHeaderLayer<HeaderValue> {
188    /// Create a new [`SetResponseHeaderLayer`] from a typed [`Header`].
189    ///
190    /// See [`SetResponseHeaderLayer::overriding`] for more details.
191    pub fn overriding_typed<H: Header>(header: H) -> Self {
192        Self::overriding(H::name().clone(), header.encode_to_value())
193    }
194
195    /// Create a new [`SetResponseHeaderLayer`] from a typed [`Header`].
196    ///
197    /// See [`SetResponseHeaderLayer::appending`] for more details.
198    pub fn appending_typed<H: Header>(header: H) -> Self {
199        Self::appending(H::name().clone(), header.encode_to_value())
200    }
201
202    /// Create a new [`SetResponseHeaderLayer`] from a typed [`Header`].
203    ///
204    /// See [`SetResponseHeaderLayer::if_not_present`] for more details.
205    pub fn if_not_present_typed<H: Header>(header: H) -> Self {
206        Self::if_not_present(H::name().clone(), header.encode_to_value())
207    }
208}
209
210impl<M> SetResponseHeaderLayer<M> {
211    /// Create a new [`SetResponseHeaderLayer`].
212    ///
213    /// If a previous value exists for the same header, it is removed and replaced with the new
214    /// header value.
215    pub fn overriding(header_name: HeaderName, make: M) -> Self {
216        Self::new(header_name, make, InsertHeaderMode::Override)
217    }
218
219    /// Create a new [`SetResponseHeaderLayer`].
220    ///
221    /// The new header is always added, preserving any existing values. If previous values exist,
222    /// the header will have multiple values.
223    pub fn appending(header_name: HeaderName, make: M) -> Self {
224        Self::new(header_name, make, InsertHeaderMode::Append)
225    }
226
227    /// Create a new [`SetResponseHeaderLayer`].
228    ///
229    /// If a previous value exists for the header, the new value is not inserted.
230    pub fn if_not_present(header_name: HeaderName, make: M) -> Self {
231        Self::new(header_name, make, InsertHeaderMode::IfNotPresent)
232    }
233
234    fn new(header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self {
235        Self {
236            make,
237            header_name,
238            mode,
239        }
240    }
241}
242
243impl<F, A> SetResponseHeaderLayer<BoxMakeHeaderValueFactoryFn<F, A>> {
244    /// Create a new [`SetResponseHeaderLayer`] from a [`super::MakeHeaderValueFn`].
245    ///
246    /// See [`SetResponseHeaderLayer::overriding`] for more details.
247    pub fn overriding_fn(header_name: HeaderName, make_fn: F) -> Self {
248        Self::new(
249            header_name,
250            BoxMakeHeaderValueFactoryFn::new(make_fn),
251            InsertHeaderMode::Override,
252        )
253    }
254
255    /// Create a new [`SetResponseHeaderLayer`] from a [`super::MakeHeaderValueFn`].
256    ///
257    /// See [`SetResponseHeaderLayer::appending`] for more details.
258    pub fn appending_fn(header_name: HeaderName, make_fn: F) -> Self {
259        Self::new(
260            header_name,
261            BoxMakeHeaderValueFactoryFn::new(make_fn),
262            InsertHeaderMode::Append,
263        )
264    }
265
266    /// Create a new [`SetResponseHeaderLayer`] from a [`super::MakeHeaderValueFn`].
267    ///
268    /// See [`SetResponseHeaderLayer::if_not_present`] for more details.
269    pub fn if_not_present_fn(header_name: HeaderName, make_fn: F) -> Self {
270        Self::new(
271            header_name,
272            BoxMakeHeaderValueFactoryFn::new(make_fn),
273            InsertHeaderMode::IfNotPresent,
274        )
275    }
276}
277
278impl<S, M> Layer<S> for SetResponseHeaderLayer<M>
279where
280    M: Clone,
281{
282    type Service = SetResponseHeader<S, M>;
283
284    fn layer(&self, inner: S) -> Self::Service {
285        SetResponseHeader {
286            inner,
287            header_name: self.header_name.clone(),
288            make: self.make.clone(),
289            mode: self.mode,
290        }
291    }
292
293    fn into_layer(self, inner: S) -> Self::Service {
294        SetResponseHeader {
295            inner,
296            header_name: self.header_name,
297            make: self.make,
298            mode: self.mode,
299        }
300    }
301}
302
303impl<M> Clone for SetResponseHeaderLayer<M>
304where
305    M: Clone,
306{
307    fn clone(&self) -> Self {
308        Self {
309            make: self.make.clone(),
310            header_name: self.header_name.clone(),
311            mode: self.mode,
312        }
313    }
314}
315
/// Middleware that sets a header on the response.
///
/// The header-value maker is derived from `make` before the inner service is called;
/// the resulting value is applied to the inner service's response using `mode`.
#[derive(Clone)]
pub struct SetResponseHeader<S, M> {
    /// The wrapped service that produces the response.
    inner: S,
    /// Name of the header that is set on the response.
    header_name: HeaderName,
    /// Produces the header value, either fixed or computed per request/response.
    make: M,
    /// How the value is inserted: override, append, or only if not already present.
    mode: InsertHeaderMode,
}
324
325impl<S, M> SetResponseHeader<S, M> {
326    /// Create a new [`SetResponseHeader`].
327    ///
328    /// If a previous value exists for the same header, it is removed and replaced with the new
329    /// header value.
330    pub fn overriding(inner: S, header_name: HeaderName, make: M) -> Self {
331        Self::new(inner, header_name, make, InsertHeaderMode::Override)
332    }
333
334    /// Create a new [`SetResponseHeader`].
335    ///
336    /// The new header is always added, preserving any existing values. If previous values exist,
337    /// the header will have multiple values.
338    pub fn appending(inner: S, header_name: HeaderName, make: M) -> Self {
339        Self::new(inner, header_name, make, InsertHeaderMode::Append)
340    }
341
342    /// Create a new [`SetResponseHeader`].
343    ///
344    /// If a previous value exists for the header, the new value is not inserted.
345    pub fn if_not_present(inner: S, header_name: HeaderName, make: M) -> Self {
346        Self::new(inner, header_name, make, InsertHeaderMode::IfNotPresent)
347    }
348
349    fn new(inner: S, header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self {
350        Self {
351            inner,
352            header_name,
353            make,
354            mode,
355        }
356    }
357
358    define_inner_service_accessors!();
359}
360
361impl<S, F, A> SetResponseHeader<S, BoxMakeHeaderValueFactoryFn<F, A>> {
362    /// Create a new [`SetResponseHeader`] from a [`super::MakeHeaderValueFn`].
363    ///
364    /// See [`SetResponseHeader::overriding`] for more details.
365    pub fn overriding_fn(inner: S, header_name: HeaderName, make_fn: F) -> Self {
366        Self::new(
367            inner,
368            header_name,
369            BoxMakeHeaderValueFactoryFn::new(make_fn),
370            InsertHeaderMode::Override,
371        )
372    }
373
374    /// Create a new [`SetResponseHeader`] from a [`super::MakeHeaderValueFn`].
375    ///
376    /// See [`SetResponseHeader::appending`] for more details.
377    pub fn appending_fn(inner: S, header_name: HeaderName, make_fn: F) -> Self {
378        Self::new(
379            inner,
380            header_name,
381            BoxMakeHeaderValueFactoryFn::new(make_fn),
382            InsertHeaderMode::Append,
383        )
384    }
385
386    /// Create a new [`SetResponseHeader`] from a [`super::MakeHeaderValueFn`].
387    ///
388    /// See [`SetResponseHeader::if_not_present`] for more details.
389    pub fn if_not_present_fn(inner: S, header_name: HeaderName, make_fn: F) -> Self {
390        Self::new(
391            inner,
392            header_name,
393            BoxMakeHeaderValueFactoryFn::new(make_fn),
394            InsertHeaderMode::IfNotPresent,
395        )
396    }
397}
398
399impl<S, M> fmt::Debug for SetResponseHeader<S, M>
400where
401    S: fmt::Debug,
402{
403    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
404        f.debug_struct("SetResponseHeader")
405            .field("inner", &self.inner)
406            .field("header_name", &self.header_name)
407            .field("mode", &self.mode)
408            .field("make", &std::any::type_name::<M>())
409            .finish()
410    }
411}
412
413impl<ReqBody, ResBody, State, S, M> Service<State, Request<ReqBody>> for SetResponseHeader<S, M>
414where
415    ReqBody: Send + 'static,
416    ResBody: Send + 'static,
417    State: Clone + Send + Sync + 'static,
418    S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
419    M: MakeHeaderValueFactory<State, ReqBody, ResBody>,
420{
421    type Response = S::Response;
422    type Error = S::Error;
423
424    async fn serve(
425        &self,
426        ctx: Context<State>,
427        req: Request<ReqBody>,
428    ) -> Result<Self::Response, Self::Error> {
429        let (ctx, req, header_maker) = self.make.make_header_value_maker(ctx, req).await;
430        let res = self.inner.serve(ctx, req).await?;
431        let res = self.mode.apply(&self.header_name, res, header_maker).await;
432        Ok(res)
433    }
434}
435
#[cfg(test)]
mod tests {
    use super::*;

    use crate::{Body, HeaderValue, Request, Response, header};
    use rama_core::service::service_fn;
    use std::convert::Infallible;

    #[tokio::test]
    async fn test_override_mode() {
        let svc = SetResponseHeader::overriding(
            service_fn(async || {
                let res = Response::builder()
                    .header(header::CONTENT_TYPE, "good-content")
                    .body(Body::empty())
                    .unwrap();
                Ok::<_, Infallible>(res)
            }),
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );

        let res = svc
            .serve(Context::default(), Request::new(Body::empty()))
            .await
            .unwrap();

        let mut values = res.headers().get_all(header::CONTENT_TYPE).iter();
        assert_eq!(values.next().unwrap(), "text/html");
        assert_eq!(values.next(), None);
    }

    #[tokio::test]
    async fn test_append_mode() {
        let svc = SetResponseHeader::appending(
            service_fn(async || {
                let res = Response::builder()
                    .header(header::CONTENT_TYPE, "good-content")
                    .body(Body::empty())
                    .unwrap();
                Ok::<_, Infallible>(res)
            }),
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );

        let res = svc
            .serve(Context::default(), Request::new(Body::empty()))
            .await
            .unwrap();

        let mut values = res.headers().get_all(header::CONTENT_TYPE).iter();
        assert_eq!(values.next().unwrap(), "good-content");
        assert_eq!(values.next().unwrap(), "text/html");
        assert_eq!(values.next(), None);
    }

    #[tokio::test]
    async fn test_skip_if_present_mode() {
        let svc = SetResponseHeader::if_not_present(
            service_fn(async || {
                let res = Response::builder()
                    .header(header::CONTENT_TYPE, "good-content")
                    .body(Body::empty())
                    .unwrap();
                Ok::<_, Infallible>(res)
            }),
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );

        let res = svc
            .serve(Context::default(), Request::new(Body::empty()))
            .await
            .unwrap();

        let mut values = res.headers().get_all(header::CONTENT_TYPE).iter();
        assert_eq!(values.next().unwrap(), "good-content");
        assert_eq!(values.next(), None);
    }

    #[tokio::test]
    async fn test_skip_if_present_mode_when_not_present() {
        let svc = SetResponseHeader::if_not_present(
            service_fn(async || {
                let res = Response::builder().body(Body::empty()).unwrap();
                Ok::<_, Infallible>(res)
            }),
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );

        let res = svc
            .serve(Context::default(), Request::new(Body::empty()))
            .await
            .unwrap();

        let mut values = res.headers().get_all(header::CONTENT_TYPE).iter();
        assert_eq!(values.next().unwrap(), "text/html");
        assert_eq!(values.next(), None);
    }

    // Covers the closure-based constructors: a value computed from the response is
    // appended after the existing one.
    #[tokio::test]
    async fn test_append_fn_mode_with_value() {
        let svc = SetResponseHeaderLayer::appending_fn(
            header::CONTENT_TYPE,
            async |res: Response| (res, Some(HeaderValue::from_static("text/html"))),
        )
        .into_layer(service_fn(async || {
            let res = Response::builder()
                .header(header::CONTENT_TYPE, "good-content")
                .body(Body::empty())
                .unwrap();
            Ok::<_, Infallible>(res)
        }));

        let res = svc
            .serve(Context::default(), Request::new(Body::empty()))
            .await
            .unwrap();

        let mut values = res.headers().get_all(header::CONTENT_TYPE).iter();
        assert_eq!(values.next().unwrap(), "good-content");
        assert_eq!(values.next().unwrap(), "text/html");
        assert_eq!(values.next(), None);
    }

    // Returning `None` from the closure must leave the response without the header.
    #[tokio::test]
    async fn test_override_fn_mode_skips_when_none() {
        let svc = SetResponseHeaderLayer::overriding_fn(
            header::CONTENT_TYPE,
            async |res: Response| (res, None::<HeaderValue>),
        )
        .into_layer(service_fn(async || {
            let res = Response::builder().body(Body::empty()).unwrap();
            Ok::<_, Infallible>(res)
        }));

        let res = svc
            .serve(Context::default(), Request::new(Body::empty()))
            .await
            .unwrap();

        assert!(res.headers().get(header::CONTENT_TYPE).is_none());
    }
}