rama_http/layer/set_header/response/mod.rs
1//! Set a header on the response.
2//!
3//! The header value to be set may be provided as a fixed value when the
4//! middleware is constructed, or determined dynamically based on the response
5//! by a closure. See the [`MakeHeaderValue`] trait for details.
6//!
7//! # Example
8//!
9//! Setting a header from a fixed value provided when the middleware is constructed:
10//!
11//! ```
12//! use rama_http::layer::set_header::SetResponseHeaderLayer;
13//! use rama_http::{Body, Request, Response, header::{self, HeaderValue}};
14//! use rama_core::service::service_fn;
15//! use rama_core::{Context, Service, Layer};
16//! use rama_core::error::BoxError;
17//!
18//! # #[tokio::main]
19//! # async fn main() -> Result<(), BoxError> {
20//! # let render_html = service_fn(async |request: Request| {
21//! # Ok::<_, std::convert::Infallible>(Response::new(request.into_body()))
22//! # });
23//! #
24//! let mut svc = (
25//! // Layer that sets `Content-Type: text/html` on responses.
26//! //
27//! // `if_not_present` will only insert the header if it does not already
28//! // have a value.
29//! SetResponseHeaderLayer::if_not_present(
30//! header::CONTENT_TYPE,
31//! HeaderValue::from_static("text/html"),
32//! ),
33//! ).into_layer(render_html);
34//!
35//! let request = Request::new(Body::empty());
36//!
37//! let response = svc.serve(Context::default(), request).await?;
38//!
39//! assert_eq!(response.headers()["content-type"], "text/html");
40//! #
41//! # Ok(())
42//! # }
43//! ```
44//!
45//! Setting a header based on a value determined dynamically from the response:
46//!
47//! ```
48//! use rama_core::error::BoxError;
49//! use rama_core::service::service_fn;
50//! use rama_core::{Context, Layer, Service};
51//! use rama_http::dep::http_body::Body as _;
52//! use rama_http::layer::set_header::SetResponseHeaderLayer;
53//! use rama_http::{
54//! header::{self, HeaderValue},
55//! Body, Request, Response,
56//! };
57//!
58//! #[tokio::main]
59//! async fn main() -> Result<(), BoxError> {
60//! let render_html = service_fn(async |_request: Request| {
61//! Ok::<_, std::convert::Infallible>(Response::new(Body::from("1234567890")))
62//! });
63//!
64//! let svc = (
65//! // Layer that sets `Content-Length` if the body has a known size.
//! // Bodies with streaming responses won't have a known size.
67//! //
68//! // `overriding` will insert the header and override any previous values it
69//! // may have.
70//! SetResponseHeaderLayer::overriding_fn(
71//! header::CONTENT_LENGTH,
72//! async |response: Response| {
73//! let value = if let Some(size) = response.body().size_hint().exact() {
74//! // If the response body has a known size, returning `Some` will
75//! // set the `Content-Length` header to that value.
76//! Some(HeaderValue::from_str(&size.to_string()).unwrap())
77//! } else {
78//! // If the response body doesn't have a known size, return `None`
79//! // to skip setting the header on this response.
80//! None
81//! };
82//! (response, value)
83//! },
84//! ),
85//! )
86//! .into_layer(render_html);
87//!
88//! let request = Request::new(Body::empty());
89//!
90//! let response = svc.serve(Context::default(), request).await?;
91//!
92//! assert_eq!(response.headers()["content-length"], "10");
93//!
94//! Ok(())
95//! }
96//! ```
97//!
//! Setting a header based on both the incoming `Context` and the response:
99//!
100//! ```
101//! use rama_core::{service::service_fn, Context, Service};
102//! use rama_http::{
103//! layer::set_header::{response::BoxMakeHeaderValueFn, SetResponseHeader},
104//! Body, HeaderName, HeaderValue, IntoResponse, Request, Response,
105//! };
106//! use std::convert::Infallible;
107//!
108//! #[tokio::main]
109//! async fn main() {
110//! #[derive(Debug, Clone)]
111//! struct RequestID(String);
112//!
113//! #[derive(Debug, Clone)]
114//! struct Success;
115//!
116//! let svc = SetResponseHeader::overriding_fn(
117//! service_fn(async || {
118//! let mut res = ().into_response();
119//! res.extensions_mut().insert(Success);
120//! Ok::<_, Infallible>(res)
121//! }),
122//! HeaderName::from_static("x-used-request-id"),
123//! async |ctx: Context<()>| {
124//! let factory = ctx.get::<RequestID>().cloned().map(|id| {
125//! BoxMakeHeaderValueFn::new(async move |res: Response| {
126//! let header_value = res.extensions().get::<Success>().map(|_| {
127//! HeaderValue::from_str(id.0.as_str()).unwrap()
128//! });
129//! (res, header_value)
130//! })
131//! });
132//! (ctx, factory)
133//! },
134//! );
135//!
136//! const FAKE_USER_ID: &str = "abc123";
137//!
138//! let mut ctx = Context::default();
139//! ctx.insert(RequestID(FAKE_USER_ID.to_owned()));
140//!
141//! let res = svc.serve(ctx, Request::new(Body::empty())).await.unwrap();
142//!
143//! let mut values = res
144//! .headers()
145//! .get_all(HeaderName::from_static("x-used-request-id"))
146//! .iter();
147//! assert_eq!(values.next().unwrap(), FAKE_USER_ID);
148//! assert_eq!(values.next(), None);
149//! }
150//! ```
151
152use crate::{
153 HeaderValue, Request, Response,
154 header::HeaderName,
155 headers::{Header, HeaderExt},
156};
157use rama_core::{Context, Layer, Service};
158use rama_utils::macros::define_inner_service_accessors;
159use std::fmt;
160
161mod header;
162use header::InsertHeaderMode;
163
164pub use header::{
165 BoxMakeHeaderValueFactoryFn, BoxMakeHeaderValueFn, MakeHeaderValue, MakeHeaderValueFactory,
166};
167
168/// Layer that applies [`SetResponseHeader`] which adds a response header.
169///
170/// See [`SetResponseHeader`] for more details.
171pub struct SetResponseHeaderLayer<M> {
172 header_name: HeaderName,
173 make: M,
174 mode: InsertHeaderMode,
175}
176
177impl<M> fmt::Debug for SetResponseHeaderLayer<M> {
178 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
179 f.debug_struct("SetResponseHeaderLayer")
180 .field("header_name", &self.header_name)
181 .field("mode", &self.mode)
182 .field("make", &std::any::type_name::<M>())
183 .finish()
184 }
185}
186
187impl SetResponseHeaderLayer<HeaderValue> {
188 /// Create a new [`SetResponseHeaderLayer`] from a typed [`Header`].
189 ///
190 /// See [`SetResponseHeaderLayer::overriding`] for more details.
191 pub fn overriding_typed<H: Header>(header: H) -> Self {
192 Self::overriding(H::name().clone(), header.encode_to_value())
193 }
194
195 /// Create a new [`SetResponseHeaderLayer`] from a typed [`Header`].
196 ///
197 /// See [`SetResponseHeaderLayer::appending`] for more details.
198 pub fn appending_typed<H: Header>(header: H) -> Self {
199 Self::appending(H::name().clone(), header.encode_to_value())
200 }
201
202 /// Create a new [`SetResponseHeaderLayer`] from a typed [`Header`].
203 ///
204 /// See [`SetResponseHeaderLayer::if_not_present`] for more details.
205 pub fn if_not_present_typed<H: Header>(header: H) -> Self {
206 Self::if_not_present(H::name().clone(), header.encode_to_value())
207 }
208}
209
210impl<M> SetResponseHeaderLayer<M> {
211 /// Create a new [`SetResponseHeaderLayer`].
212 ///
213 /// If a previous value exists for the same header, it is removed and replaced with the new
214 /// header value.
215 pub fn overriding(header_name: HeaderName, make: M) -> Self {
216 Self::new(header_name, make, InsertHeaderMode::Override)
217 }
218
219 /// Create a new [`SetResponseHeaderLayer`].
220 ///
221 /// The new header is always added, preserving any existing values. If previous values exist,
222 /// the header will have multiple values.
223 pub fn appending(header_name: HeaderName, make: M) -> Self {
224 Self::new(header_name, make, InsertHeaderMode::Append)
225 }
226
227 /// Create a new [`SetResponseHeaderLayer`].
228 ///
229 /// If a previous value exists for the header, the new value is not inserted.
230 pub fn if_not_present(header_name: HeaderName, make: M) -> Self {
231 Self::new(header_name, make, InsertHeaderMode::IfNotPresent)
232 }
233
234 fn new(header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self {
235 Self {
236 make,
237 header_name,
238 mode,
239 }
240 }
241}
242
243impl<F, A> SetResponseHeaderLayer<BoxMakeHeaderValueFactoryFn<F, A>> {
244 /// Create a new [`SetResponseHeaderLayer`] from a [`super::MakeHeaderValueFn`].
245 ///
246 /// See [`SetResponseHeaderLayer::overriding`] for more details.
247 pub fn overriding_fn(header_name: HeaderName, make_fn: F) -> Self {
248 Self::new(
249 header_name,
250 BoxMakeHeaderValueFactoryFn::new(make_fn),
251 InsertHeaderMode::Override,
252 )
253 }
254
255 /// Create a new [`SetResponseHeaderLayer`] from a [`super::MakeHeaderValueFn`].
256 ///
257 /// See [`SetResponseHeaderLayer::appending`] for more details.
258 pub fn appending_fn(header_name: HeaderName, make_fn: F) -> Self {
259 Self::new(
260 header_name,
261 BoxMakeHeaderValueFactoryFn::new(make_fn),
262 InsertHeaderMode::Append,
263 )
264 }
265
266 /// Create a new [`SetResponseHeaderLayer`] from a [`super::MakeHeaderValueFn`].
267 ///
268 /// See [`SetResponseHeaderLayer::if_not_present`] for more details.
269 pub fn if_not_present_fn(header_name: HeaderName, make_fn: F) -> Self {
270 Self::new(
271 header_name,
272 BoxMakeHeaderValueFactoryFn::new(make_fn),
273 InsertHeaderMode::IfNotPresent,
274 )
275 }
276}
277
278impl<S, M> Layer<S> for SetResponseHeaderLayer<M>
279where
280 M: Clone,
281{
282 type Service = SetResponseHeader<S, M>;
283
284 fn layer(&self, inner: S) -> Self::Service {
285 SetResponseHeader {
286 inner,
287 header_name: self.header_name.clone(),
288 make: self.make.clone(),
289 mode: self.mode,
290 }
291 }
292
293 fn into_layer(self, inner: S) -> Self::Service {
294 SetResponseHeader {
295 inner,
296 header_name: self.header_name,
297 make: self.make,
298 mode: self.mode,
299 }
300 }
301}
302
303impl<M> Clone for SetResponseHeaderLayer<M>
304where
305 M: Clone,
306{
307 fn clone(&self) -> Self {
308 Self {
309 make: self.make.clone(),
310 header_name: self.header_name.clone(),
311 mode: self.mode,
312 }
313 }
314}
315
316/// Middleware that sets a header on the response.
317#[derive(Clone)]
318pub struct SetResponseHeader<S, M> {
319 inner: S,
320 header_name: HeaderName,
321 make: M,
322 mode: InsertHeaderMode,
323}
324
325impl<S, M> SetResponseHeader<S, M> {
326 /// Create a new [`SetResponseHeader`].
327 ///
328 /// If a previous value exists for the same header, it is removed and replaced with the new
329 /// header value.
330 pub fn overriding(inner: S, header_name: HeaderName, make: M) -> Self {
331 Self::new(inner, header_name, make, InsertHeaderMode::Override)
332 }
333
334 /// Create a new [`SetResponseHeader`].
335 ///
336 /// The new header is always added, preserving any existing values. If previous values exist,
337 /// the header will have multiple values.
338 pub fn appending(inner: S, header_name: HeaderName, make: M) -> Self {
339 Self::new(inner, header_name, make, InsertHeaderMode::Append)
340 }
341
342 /// Create a new [`SetResponseHeader`].
343 ///
344 /// If a previous value exists for the header, the new value is not inserted.
345 pub fn if_not_present(inner: S, header_name: HeaderName, make: M) -> Self {
346 Self::new(inner, header_name, make, InsertHeaderMode::IfNotPresent)
347 }
348
349 fn new(inner: S, header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self {
350 Self {
351 inner,
352 header_name,
353 make,
354 mode,
355 }
356 }
357
358 define_inner_service_accessors!();
359}
360
361impl<S, F, A> SetResponseHeader<S, BoxMakeHeaderValueFactoryFn<F, A>> {
362 /// Create a new [`SetResponseHeader`] from a [`super::MakeHeaderValueFn`].
363 ///
364 /// See [`SetResponseHeader::overriding`] for more details.
365 pub fn overriding_fn(inner: S, header_name: HeaderName, make_fn: F) -> Self {
366 Self::new(
367 inner,
368 header_name,
369 BoxMakeHeaderValueFactoryFn::new(make_fn),
370 InsertHeaderMode::Override,
371 )
372 }
373
374 /// Create a new [`SetResponseHeader`] from a [`super::MakeHeaderValueFn`].
375 ///
376 /// See [`SetResponseHeader::appending`] for more details.
377 pub fn appending_fn(inner: S, header_name: HeaderName, make_fn: F) -> Self {
378 Self::new(
379 inner,
380 header_name,
381 BoxMakeHeaderValueFactoryFn::new(make_fn),
382 InsertHeaderMode::Append,
383 )
384 }
385
386 /// Create a new [`SetResponseHeader`] from a [`super::MakeHeaderValueFn`].
387 ///
388 /// See [`SetResponseHeader::if_not_present`] for more details.
389 pub fn if_not_present_fn(inner: S, header_name: HeaderName, make_fn: F) -> Self {
390 Self::new(
391 inner,
392 header_name,
393 BoxMakeHeaderValueFactoryFn::new(make_fn),
394 InsertHeaderMode::IfNotPresent,
395 )
396 }
397}
398
399impl<S, M> fmt::Debug for SetResponseHeader<S, M>
400where
401 S: fmt::Debug,
402{
403 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
404 f.debug_struct("SetResponseHeader")
405 .field("inner", &self.inner)
406 .field("header_name", &self.header_name)
407 .field("mode", &self.mode)
408 .field("make", &std::any::type_name::<M>())
409 .finish()
410 }
411}
412
413impl<ReqBody, ResBody, State, S, M> Service<State, Request<ReqBody>> for SetResponseHeader<S, M>
414where
415 ReqBody: Send + 'static,
416 ResBody: Send + 'static,
417 State: Clone + Send + Sync + 'static,
418 S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
419 M: MakeHeaderValueFactory<State, ReqBody, ResBody>,
420{
421 type Response = S::Response;
422 type Error = S::Error;
423
424 async fn serve(
425 &self,
426 ctx: Context<State>,
427 req: Request<ReqBody>,
428 ) -> Result<Self::Response, Self::Error> {
429 let (ctx, req, header_maker) = self.make.make_header_value_maker(ctx, req).await;
430 let res = self.inner.serve(ctx, req).await?;
431 let res = self.mode.apply(&self.header_name, res, header_maker).await;
432 Ok(res)
433 }
434}
435
#[cfg(test)]
mod tests {
    use super::*;

    use crate::{Body, HeaderValue, Request, Response, header};
    use rama_core::service::service_fn;
    use std::convert::Infallible;

    /// Builds an empty-bodied response, optionally carrying a `Content-Type` header.
    fn response_with_content_type(content_type: Option<&'static str>) -> Response {
        let mut builder = Response::builder();
        if let Some(value) = content_type {
            builder = builder.header(header::CONTENT_TYPE, value);
        }
        builder.body(Body::empty()).unwrap()
    }

    #[tokio::test]
    async fn test_override_mode() {
        let svc = SetResponseHeader::overriding(
            service_fn(async || {
                Ok::<_, Infallible>(response_with_content_type(Some("good-content")))
            }),
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );

        let res = svc
            .serve(Context::default(), Request::new(Body::empty()))
            .await
            .unwrap();

        // The pre-existing value is replaced, leaving exactly one value.
        let mut values = res.headers().get_all(header::CONTENT_TYPE).iter();
        assert_eq!(values.next().unwrap(), "text/html");
        assert_eq!(values.next(), None);
    }

    #[tokio::test]
    async fn test_append_mode() {
        let svc = SetResponseHeader::appending(
            service_fn(async || {
                Ok::<_, Infallible>(response_with_content_type(Some("good-content")))
            }),
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );

        let res = svc
            .serve(Context::default(), Request::new(Body::empty()))
            .await
            .unwrap();

        // The original value is kept and the new one follows it.
        let mut values = res.headers().get_all(header::CONTENT_TYPE).iter();
        assert_eq!(values.next().unwrap(), "good-content");
        assert_eq!(values.next().unwrap(), "text/html");
        assert_eq!(values.next(), None);
    }

    #[tokio::test]
    async fn test_skip_if_present_mode() {
        let svc = SetResponseHeader::if_not_present(
            service_fn(async || {
                Ok::<_, Infallible>(response_with_content_type(Some("good-content")))
            }),
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );

        let res = svc
            .serve(Context::default(), Request::new(Body::empty()))
            .await
            .unwrap();

        // The header already exists, so the new value must not be inserted.
        let mut values = res.headers().get_all(header::CONTENT_TYPE).iter();
        assert_eq!(values.next().unwrap(), "good-content");
        assert_eq!(values.next(), None);
    }

    #[tokio::test]
    async fn test_skip_if_present_mode_when_not_present() {
        let svc = SetResponseHeader::if_not_present(
            service_fn(async || Ok::<_, Infallible>(response_with_content_type(None))),
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );

        let res = svc
            .serve(Context::default(), Request::new(Body::empty()))
            .await
            .unwrap();

        // The header is absent, so the new value is inserted.
        let mut values = res.headers().get_all(header::CONTENT_TYPE).iter();
        assert_eq!(values.next().unwrap(), "text/html");
        assert_eq!(values.next(), None);
    }
}