rama_http/layer/set_header/response/mod.rs
1//! Set a header on the response.
2//!
3//! The header value to be set may be provided as a fixed value when the
4//! middleware is constructed, or determined dynamically based on the response
5//! by a closure. See the [`MakeHeaderValue`] trait for details.
6//!
7//! # Example
8//!
9//! Setting a header from a fixed value provided when the middleware is constructed:
10//!
11//! ```
12//! use rama_http::layer::set_header::SetResponseHeaderLayer;
13//! use rama_http::{Body, Request, Response, header::{self, HeaderValue}};
14//! use rama_core::service::service_fn;
15//! use rama_core::{Context, Service, Layer};
16//! use rama_core::error::BoxError;
17//!
18//! # #[tokio::main]
19//! # async fn main() -> Result<(), BoxError> {
20//! # let render_html = service_fn(async |request: Request| {
21//! # Ok::<_, std::convert::Infallible>(Response::new(request.into_body()))
22//! # });
23//! #
24//! let mut svc = (
25//! // Layer that sets `Content-Type: text/html` on responses.
26//! //
27//! // `if_not_present` will only insert the header if it does not already
28//! // have a value.
29//! SetResponseHeaderLayer::if_not_present(
30//! header::CONTENT_TYPE,
31//! HeaderValue::from_static("text/html"),
32//! ),
33//! ).into_layer(render_html);
34//!
35//! let request = Request::new(Body::empty());
36//!
37//! let response = svc.serve(Context::default(), request).await?;
38//!
39//! assert_eq!(response.headers()["content-type"], "text/html");
40//! #
41//! # Ok(())
42//! # }
43//! ```
44//!
45//! Setting a header based on a value determined dynamically from the response:
46//!
47//! ```
48//! use rama_core::error::BoxError;
49//! use rama_core::service::service_fn;
50//! use rama_core::{Context, Layer, Service};
51//! use rama_http::dep::http_body::Body as _;
52//! use rama_http::layer::set_header::SetResponseHeaderLayer;
53//! use rama_http::{
54//! header::{self, HeaderValue},
55//! Body, Request, Response,
56//! };
57//!
58//! #[tokio::main]
59//! async fn main() -> Result<(), BoxError> {
60//! let render_html = service_fn(async |_request: Request| {
61//! Ok::<_, std::convert::Infallible>(Response::new(Body::from("1234567890")))
62//! });
63//!
64//! let svc = (
65//! // Layer that sets `Content-Length` if the body has a known size.
//! // Bodies with streaming responses won't have a known size.
67//! //
//! // `overriding_fn` will insert the header and override any previous values it
69//! // may have.
70//! SetResponseHeaderLayer::overriding_fn(
71//! header::CONTENT_LENGTH,
72//! async |response: Response| {
73//! let value = if let Some(size) = response.body().size_hint().exact() {
74//! // If the response body has a known size, returning `Some` will
75//! // set the `Content-Length` header to that value.
76//! Some(HeaderValue::from_str(&size.to_string()).unwrap())
77//! } else {
78//! // If the response body doesn't have a known size, return `None`
79//! // to skip setting the header on this response.
80//! None
81//! };
82//! (response, value)
83//! },
84//! ),
85//! )
86//! .into_layer(render_html);
87//!
88//! let request = Request::new(Body::empty());
89//!
90//! let response = svc.serve(Context::default(), request).await?;
91//!
92//! assert_eq!(response.headers()["content-length"], "10");
93//!
94//! Ok(())
95//! }
96//! ```
97//!
//! Setting a header based on both the incoming `Context` and the response:
99//!
100//! ```
101//! use rama_core::{service::service_fn, Context, Service};
102//! use rama_http::{
103//! layer::set_header::{response::BoxMakeHeaderValueFn, SetResponseHeader},
104//! Body, HeaderName, HeaderValue, Request, Response,
105//! service::web::response::IntoResponse,
106//! };
107//! use std::convert::Infallible;
108//!
109//! #[tokio::main]
110//! async fn main() {
111//! #[derive(Debug, Clone)]
112//! struct RequestID(String);
113//!
114//! #[derive(Debug, Clone)]
115//! struct Success;
116//!
117//! let svc = SetResponseHeader::overriding_fn(
118//! service_fn(async || {
119//! let mut res = ().into_response();
120//! res.extensions_mut().insert(Success);
121//! Ok::<_, Infallible>(res)
122//! }),
123//! HeaderName::from_static("x-used-request-id"),
124//! async |ctx: Context<()>| {
125//! let factory = ctx.get::<RequestID>().cloned().map(|id| {
126//! BoxMakeHeaderValueFn::new(async move |res: Response| {
127//! let header_value = res.extensions().get::<Success>().map(|_| {
128//! HeaderValue::from_str(id.0.as_str()).unwrap()
129//! });
130//! (res, header_value)
131//! })
132//! });
133//! (ctx, factory)
134//! },
135//! );
136//!
137//! const FAKE_USER_ID: &str = "abc123";
138//!
139//! let mut ctx = Context::default();
140//! ctx.insert(RequestID(FAKE_USER_ID.to_owned()));
141//!
142//! let res = svc.serve(ctx, Request::new(Body::empty())).await.unwrap();
143//!
144//! let mut values = res
145//! .headers()
146//! .get_all(HeaderName::from_static("x-used-request-id"))
147//! .iter();
148//! assert_eq!(values.next().unwrap(), FAKE_USER_ID);
149//! assert_eq!(values.next(), None);
150//! }
151//! ```
152
153use crate::{HeaderValue, Request, Response, header::HeaderName, headers::Header};
154use rama_core::{Context, Layer, Service};
155use rama_utils::macros::define_inner_service_accessors;
156use std::fmt;
157
158mod header;
159use header::InsertHeaderMode;
160
161pub use header::{
162 BoxMakeHeaderValueFactoryFn, BoxMakeHeaderValueFn, MakeHeaderValue, MakeHeaderValueFactory,
163};
164
165/// Layer that applies [`SetResponseHeader`] which adds a response header.
166///
167/// See [`SetResponseHeader`] for more details.
168pub struct SetResponseHeaderLayer<M> {
169 header_name: HeaderName,
170 make: M,
171 mode: InsertHeaderMode,
172}
173
174impl<M> fmt::Debug for SetResponseHeaderLayer<M> {
175 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
176 f.debug_struct("SetResponseHeaderLayer")
177 .field("header_name", &self.header_name)
178 .field("mode", &self.mode)
179 .field("make", &std::any::type_name::<M>())
180 .finish()
181 }
182}
183
184impl SetResponseHeaderLayer<HeaderValue> {
185 /// Create a new [`SetResponseHeaderLayer`] from a typed [`Header`].
186 ///
187 /// See [`SetResponseHeaderLayer::overriding`] for more details.
188 pub fn overriding_typed<H: Header>(header: H) -> Self {
189 Self::overriding(H::name().clone(), header.encode_to_value())
190 }
191
192 /// Create a new [`SetResponseHeaderLayer`] from a typed [`Header`].
193 ///
194 /// See [`SetResponseHeaderLayer::appending`] for more details.
195 pub fn appending_typed<H: Header>(header: H) -> Self {
196 Self::appending(H::name().clone(), header.encode_to_value())
197 }
198
199 /// Create a new [`SetResponseHeaderLayer`] from a typed [`Header`].
200 ///
201 /// See [`SetResponseHeaderLayer::if_not_present`] for more details.
202 pub fn if_not_present_typed<H: Header>(header: H) -> Self {
203 Self::if_not_present(H::name().clone(), header.encode_to_value())
204 }
205}
206
207impl<M> SetResponseHeaderLayer<M> {
208 /// Create a new [`SetResponseHeaderLayer`].
209 ///
210 /// If a previous value exists for the same header, it is removed and replaced with the new
211 /// header value.
212 pub fn overriding(header_name: HeaderName, make: M) -> Self {
213 Self::new(header_name, make, InsertHeaderMode::Override)
214 }
215
216 /// Create a new [`SetResponseHeaderLayer`].
217 ///
218 /// The new header is always added, preserving any existing values. If previous values exist,
219 /// the header will have multiple values.
220 pub fn appending(header_name: HeaderName, make: M) -> Self {
221 Self::new(header_name, make, InsertHeaderMode::Append)
222 }
223
224 /// Create a new [`SetResponseHeaderLayer`].
225 ///
226 /// If a previous value exists for the header, the new value is not inserted.
227 pub fn if_not_present(header_name: HeaderName, make: M) -> Self {
228 Self::new(header_name, make, InsertHeaderMode::IfNotPresent)
229 }
230
231 fn new(header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self {
232 Self {
233 make,
234 header_name,
235 mode,
236 }
237 }
238}
239
240impl<F, A> SetResponseHeaderLayer<BoxMakeHeaderValueFactoryFn<F, A>> {
241 /// Create a new [`SetResponseHeaderLayer`] from a [`super::MakeHeaderValueFn`].
242 ///
243 /// See [`SetResponseHeaderLayer::overriding`] for more details.
244 pub fn overriding_fn(header_name: HeaderName, make_fn: F) -> Self {
245 Self::new(
246 header_name,
247 BoxMakeHeaderValueFactoryFn::new(make_fn),
248 InsertHeaderMode::Override,
249 )
250 }
251
252 /// Create a new [`SetResponseHeaderLayer`] from a [`super::MakeHeaderValueFn`].
253 ///
254 /// See [`SetResponseHeaderLayer::appending`] for more details.
255 pub fn appending_fn(header_name: HeaderName, make_fn: F) -> Self {
256 Self::new(
257 header_name,
258 BoxMakeHeaderValueFactoryFn::new(make_fn),
259 InsertHeaderMode::Append,
260 )
261 }
262
263 /// Create a new [`SetResponseHeaderLayer`] from a [`super::MakeHeaderValueFn`].
264 ///
265 /// See [`SetResponseHeaderLayer::if_not_present`] for more details.
266 pub fn if_not_present_fn(header_name: HeaderName, make_fn: F) -> Self {
267 Self::new(
268 header_name,
269 BoxMakeHeaderValueFactoryFn::new(make_fn),
270 InsertHeaderMode::IfNotPresent,
271 )
272 }
273}
274
275impl<S, M> Layer<S> for SetResponseHeaderLayer<M>
276where
277 M: Clone,
278{
279 type Service = SetResponseHeader<S, M>;
280
281 fn layer(&self, inner: S) -> Self::Service {
282 SetResponseHeader {
283 inner,
284 header_name: self.header_name.clone(),
285 make: self.make.clone(),
286 mode: self.mode,
287 }
288 }
289
290 fn into_layer(self, inner: S) -> Self::Service {
291 SetResponseHeader {
292 inner,
293 header_name: self.header_name,
294 make: self.make,
295 mode: self.mode,
296 }
297 }
298}
299
300impl<M> Clone for SetResponseHeaderLayer<M>
301where
302 M: Clone,
303{
304 fn clone(&self) -> Self {
305 Self {
306 make: self.make.clone(),
307 header_name: self.header_name.clone(),
308 mode: self.mode,
309 }
310 }
311}
312
313/// Middleware that sets a header on the response.
314#[derive(Clone)]
315pub struct SetResponseHeader<S, M> {
316 inner: S,
317 header_name: HeaderName,
318 make: M,
319 mode: InsertHeaderMode,
320}
321
322impl<S, M> SetResponseHeader<S, M> {
323 /// Create a new [`SetResponseHeader`].
324 ///
325 /// If a previous value exists for the same header, it is removed and replaced with the new
326 /// header value.
327 pub fn overriding(inner: S, header_name: HeaderName, make: M) -> Self {
328 Self::new(inner, header_name, make, InsertHeaderMode::Override)
329 }
330
331 /// Create a new [`SetResponseHeader`].
332 ///
333 /// The new header is always added, preserving any existing values. If previous values exist,
334 /// the header will have multiple values.
335 pub fn appending(inner: S, header_name: HeaderName, make: M) -> Self {
336 Self::new(inner, header_name, make, InsertHeaderMode::Append)
337 }
338
339 /// Create a new [`SetResponseHeader`].
340 ///
341 /// If a previous value exists for the header, the new value is not inserted.
342 pub fn if_not_present(inner: S, header_name: HeaderName, make: M) -> Self {
343 Self::new(inner, header_name, make, InsertHeaderMode::IfNotPresent)
344 }
345
346 fn new(inner: S, header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self {
347 Self {
348 inner,
349 header_name,
350 make,
351 mode,
352 }
353 }
354
355 define_inner_service_accessors!();
356}
357
358impl<S, F, A> SetResponseHeader<S, BoxMakeHeaderValueFactoryFn<F, A>> {
359 /// Create a new [`SetResponseHeader`] from a [`super::MakeHeaderValueFn`].
360 ///
361 /// See [`SetResponseHeader::overriding`] for more details.
362 pub fn overriding_fn(inner: S, header_name: HeaderName, make_fn: F) -> Self {
363 Self::new(
364 inner,
365 header_name,
366 BoxMakeHeaderValueFactoryFn::new(make_fn),
367 InsertHeaderMode::Override,
368 )
369 }
370
371 /// Create a new [`SetResponseHeader`] from a [`super::MakeHeaderValueFn`].
372 ///
373 /// See [`SetResponseHeader::appending`] for more details.
374 pub fn appending_fn(inner: S, header_name: HeaderName, make_fn: F) -> Self {
375 Self::new(
376 inner,
377 header_name,
378 BoxMakeHeaderValueFactoryFn::new(make_fn),
379 InsertHeaderMode::Append,
380 )
381 }
382
383 /// Create a new [`SetResponseHeader`] from a [`super::MakeHeaderValueFn`].
384 ///
385 /// See [`SetResponseHeader::if_not_present`] for more details.
386 pub fn if_not_present_fn(inner: S, header_name: HeaderName, make_fn: F) -> Self {
387 Self::new(
388 inner,
389 header_name,
390 BoxMakeHeaderValueFactoryFn::new(make_fn),
391 InsertHeaderMode::IfNotPresent,
392 )
393 }
394}
395
396impl<S, M> fmt::Debug for SetResponseHeader<S, M>
397where
398 S: fmt::Debug,
399{
400 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
401 f.debug_struct("SetResponseHeader")
402 .field("inner", &self.inner)
403 .field("header_name", &self.header_name)
404 .field("mode", &self.mode)
405 .field("make", &std::any::type_name::<M>())
406 .finish()
407 }
408}
409
410impl<ReqBody, ResBody, State, S, M> Service<State, Request<ReqBody>> for SetResponseHeader<S, M>
411where
412 ReqBody: Send + 'static,
413 ResBody: Send + 'static,
414 State: Clone + Send + Sync + 'static,
415 S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
416 M: MakeHeaderValueFactory<State, ReqBody, ResBody>,
417{
418 type Response = S::Response;
419 type Error = S::Error;
420
421 async fn serve(
422 &self,
423 ctx: Context<State>,
424 req: Request<ReqBody>,
425 ) -> Result<Self::Response, Self::Error> {
426 let (ctx, req, header_maker) = self.make.make_header_value_maker(ctx, req).await;
427 let res = self.inner.serve(ctx, req).await?;
428 let res = self.mode.apply(&self.header_name, res, header_maker).await;
429 Ok(res)
430 }
431}
432
#[cfg(test)]
mod tests {
    use super::*;

    use crate::{Body, HeaderValue, Request, Response, header};
    use rama_core::service::service_fn;
    use std::convert::Infallible;

    /// Response produced by the inner test service, optionally pre-populated
    /// with a `Content-Type` header.
    fn inner_response(content_type: Option<&'static str>) -> Result<Response, Infallible> {
        let mut builder = Response::builder();
        if let Some(value) = content_type {
            builder = builder.header(header::CONTENT_TYPE, value);
        }
        Ok(builder.body(Body::empty()).unwrap())
    }

    /// All `Content-Type` values present on `res`, in insertion order.
    fn content_types(res: &Response) -> Vec<&HeaderValue> {
        res.headers().get_all(header::CONTENT_TYPE).iter().collect()
    }

    #[tokio::test]
    async fn test_override_mode() {
        let svc = SetResponseHeader::overriding(
            service_fn(async || inner_response(Some("good-content"))),
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );

        let res = svc
            .serve(Context::default(), Request::new(Body::empty()))
            .await
            .unwrap();

        assert_eq!(content_types(&res), ["text/html"]);
    }

    #[tokio::test]
    async fn test_append_mode() {
        let svc = SetResponseHeader::appending(
            service_fn(async || inner_response(Some("good-content"))),
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );

        let res = svc
            .serve(Context::default(), Request::new(Body::empty()))
            .await
            .unwrap();

        assert_eq!(content_types(&res), ["good-content", "text/html"]);
    }

    #[tokio::test]
    async fn test_skip_if_present_mode() {
        let svc = SetResponseHeader::if_not_present(
            service_fn(async || inner_response(Some("good-content"))),
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );

        let res = svc
            .serve(Context::default(), Request::new(Body::empty()))
            .await
            .unwrap();

        assert_eq!(content_types(&res), ["good-content"]);
    }

    #[tokio::test]
    async fn test_skip_if_present_mode_when_not_present() {
        let svc = SetResponseHeader::if_not_present(
            service_fn(async || inner_response(None)),
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );

        let res = svc
            .serve(Context::default(), Request::new(Body::empty()))
            .await
            .unwrap();

        assert_eq!(content_types(&res), ["text/html"]);
    }
}