rama_http/layer/set_header/response/
header.rs

1use crate::{HeaderName, HeaderValue, Request, Response};
2use rama_core::Context;
3use std::{
4    future::{Future, ready},
5    marker::PhantomData,
6};
7
8/// Trait for preparing a maker ([`MakeHeaderValue`]) that will be used
9/// to actually create the [`HeaderValue`] when desired.
10///
11/// The reason why this is split in two parts for responses is because
12/// the context is consumed by the inner service producting the response
13/// to which the header (maybe) will be attached to. In order to not
14/// clone the entire `Context` and its `State` it is therefore better
15/// to let the implementer decide what state is to be cloned and which not.
16///
17/// E.g. for a static Header value one might not need any state or context at all,
18/// which would make it pretty wastefull if we would for such cases clone
19/// these stateful datastructures anyhow.
20///
21/// Most users will however not have to worry about this Trait or why it is there,
22/// as the trait is implemented already for functions, closures and HeaderValues.
23pub trait MakeHeaderValueFactory<S, ReqBody, ResBody>: Send + Sync + 'static {
24    /// Maker that _can_ be produced by this Factory.
25    type Maker: MakeHeaderValue<ResBody>;
26
27    /// Try to create a header value from the request or response.
28    fn make_header_value_maker(
29        &self,
30        ctx: Context<S>,
31        request: Request<ReqBody>,
32    ) -> impl Future<Output = (Context<S>, Request<ReqBody>, Self::Maker)> + Send + '_;
33}
34
35/// Trait for producing header values, created by a `MakeHeaderValueFactory`.
36///
37/// Used by [`SetRequestHeader`] and [`SetResponseHeader`].
38///
39/// This trait is implemented for closures with the correct type signature. Typically users will
40/// not have to implement this trait for their own types.
41///
42/// It is also implemented directly for [`HeaderValue`]. When a fixed header value should be added
43/// to all responses, it can be supplied directly to the middleware.
44pub trait MakeHeaderValue<B>: Send + Sync + 'static {
45    /// Try to create a header value from the request or response.
46    fn make_header_value(
47        self,
48        response: Response<B>,
49    ) -> impl Future<Output = (Response<B>, Option<HeaderValue>)> + Send;
50}
51
52impl<B, M> MakeHeaderValue<B> for Option<M>
53where
54    M: MakeHeaderValue<B> + Clone,
55    B: Send + 'static,
56{
57    async fn make_header_value(self, response: Response<B>) -> (Response<B>, Option<HeaderValue>) {
58        match self {
59            Some(m) => m.make_header_value(response).await,
60            None => (response, None),
61        }
62    }
63}
64
65impl<B> MakeHeaderValue<B> for HeaderValue
66where
67    B: Send + 'static,
68{
69    fn make_header_value(
70        self,
71        response: Response<B>,
72    ) -> impl Future<Output = (Response<B>, Option<HeaderValue>)> + Send {
73        ready((response, Some(self)))
74    }
75}
76
77impl<S, ReqBody, ResBody> MakeHeaderValueFactory<S, ReqBody, ResBody> for HeaderValue
78where
79    S: Clone + Send + Sync + 'static,
80    ReqBody: Send + 'static,
81    ResBody: Send + 'static,
82{
83    type Maker = Self;
84
85    fn make_header_value_maker(
86        &self,
87        ctx: Context<S>,
88        req: Request<ReqBody>,
89    ) -> impl Future<Output = (Context<S>, Request<ReqBody>, Self::Maker)> + Send + '_ {
90        ready((ctx, req, self.clone()))
91    }
92}
93
94impl<S, ReqBody, ResBody> MakeHeaderValueFactory<S, ReqBody, ResBody> for Option<HeaderValue>
95where
96    S: Clone + Send + Sync + 'static,
97    ReqBody: Send + 'static,
98    ResBody: Send + 'static,
99{
100    type Maker = Self;
101
102    fn make_header_value_maker(
103        &self,
104        ctx: Context<S>,
105        req: Request<ReqBody>,
106    ) -> impl Future<Output = (Context<S>, Request<ReqBody>, Self::Maker)> + Send + '_ {
107        ready((ctx, req, self.clone()))
108    }
109}
110
111/// Functional version of [`MakeHeaderValue`].
112pub trait MakeHeaderValueFactoryFn<S, ReqBody, ResBody, A>: Send + Sync + 'static {
113    type Maker: MakeHeaderValue<ResBody>;
114
115    /// Try to create a header value from the request or response.
116    fn call(
117        &self,
118        ctx: Context<S>,
119        request: Request<ReqBody>,
120    ) -> impl Future<Output = (Context<S>, Request<ReqBody>, Self::Maker)> + Send + '_;
121}
122
123impl<F, Fut, S, ReqBody, ResBody, M> MakeHeaderValueFactoryFn<S, ReqBody, ResBody, ()> for F
124where
125    S: Clone + Send + Sync + 'static,
126    ReqBody: Send + 'static,
127    ResBody: Send + 'static,
128    M: MakeHeaderValue<ResBody>,
129    F: Fn() -> Fut + Send + Sync + 'static,
130    Fut: Future<Output = M> + Send + 'static,
131    M: MakeHeaderValue<ResBody>,
132{
133    type Maker = M;
134
135    async fn call(
136        &self,
137        ctx: Context<S>,
138        request: Request<ReqBody>,
139    ) -> (Context<S>, Request<ReqBody>, M) {
140        let maker = self().await;
141        (ctx, request, maker)
142    }
143}
144
145impl<F, Fut, S, ReqBody, ResBody, M>
146    MakeHeaderValueFactoryFn<S, ReqBody, ResBody, ((), Request<ReqBody>)> for F
147where
148    S: Clone + Send + Sync + 'static,
149    ReqBody: Send + 'static,
150    ResBody: Send + 'static,
151    M: MakeHeaderValue<ResBody>,
152    F: Fn(Request<ReqBody>) -> Fut + Send + Sync + 'static,
153    Fut: Future<Output = (Request<ReqBody>, M)> + Send + 'static,
154    M: MakeHeaderValue<ResBody>,
155{
156    type Maker = M;
157
158    async fn call(
159        &self,
160        ctx: Context<S>,
161        request: Request<ReqBody>,
162    ) -> (Context<S>, Request<ReqBody>, M) {
163        let (request, maker) = self(request).await;
164        (ctx, request, maker)
165    }
166}
167
168impl<F, Fut, S, ReqBody, ResBody, M> MakeHeaderValueFactoryFn<S, ReqBody, ResBody, (Context<S>,)>
169    for F
170where
171    S: Clone + Send + Sync + 'static,
172    ReqBody: Send + 'static,
173    ResBody: Send + 'static,
174    M: MakeHeaderValue<ResBody>,
175    F: Fn(Context<S>) -> Fut + Send + Sync + 'static,
176    Fut: Future<Output = (Context<S>, M)> + Send + 'static,
177    M: MakeHeaderValue<ResBody>,
178{
179    type Maker = M;
180
181    async fn call(
182        &self,
183        ctx: Context<S>,
184        request: Request<ReqBody>,
185    ) -> (Context<S>, Request<ReqBody>, M) {
186        let (ctx, maker) = self(ctx).await;
187        (ctx, request, maker)
188    }
189}
190
191impl<F, Fut, S, ReqBody, ResBody, M> MakeHeaderValueFactoryFn<S, ReqBody, ResBody, (Context<S>, M)>
192    for F
193where
194    S: Clone + Send + Sync + 'static,
195    ReqBody: Send + 'static,
196    ResBody: Send + 'static,
197    M: MakeHeaderValue<ResBody>,
198    F: Fn(Context<S>, Request<ReqBody>) -> Fut + Send + Sync + 'static,
199    Fut: Future<Output = (Context<S>, Request<ReqBody>, M)> + Send + 'static,
200    M: MakeHeaderValue<ResBody>,
201{
202    type Maker = M;
203
204    fn call(
205        &self,
206        ctx: Context<S>,
207        request: Request<ReqBody>,
208    ) -> impl Future<Output = (Context<S>, Request<ReqBody>, M)> + Send + '_ {
209        self(ctx, request)
210    }
211}
212
/// The public wrapper type for [`MakeHeaderValueFactoryFn`].
///
/// `A` is the marker type that selects which function shape `F` is used as.
pub struct BoxMakeHeaderValueFactoryFn<F, A> {
    /// The wrapped function.
    f: F,
    /// Ties the marker `A` to the wrapper without storing a value of it.
    /// A `fn` pointer type is `Send + Sync` regardless of `A`.
    _marker: PhantomData<fn(A) -> ()>,
}

impl<F, A> BoxMakeHeaderValueFactoryFn<F, A> {
    /// Create a new [`BoxMakeHeaderValueFactoryFn`].
    pub const fn new(f: F) -> Self {
        Self {
            f,
            _marker: PhantomData,
        }
    }
}

// Implemented by hand so that only `F` needs to be `Clone`, not the marker `A`.
impl<F, A> Clone for BoxMakeHeaderValueFactoryFn<F, A>
where
    F: Clone,
{
    fn clone(&self) -> Self {
        Self::new(self.f.clone())
    }
}

// Implemented by hand so that only `F` needs to be `Debug`, not the marker `A`.
impl<F, A> std::fmt::Debug for BoxMakeHeaderValueFactoryFn<F, A>
where
    F: std::fmt::Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Report this type's own name (it used to print `BoxMakeHeaderValueFn`).
        f.debug_struct("BoxMakeHeaderValueFactoryFn")
            .field("f", &self.f)
            .finish()
    }
}
251
252impl<S, ReqBody, ResBody, A, F> MakeHeaderValueFactory<S, ReqBody, ResBody>
253    for BoxMakeHeaderValueFactoryFn<F, A>
254where
255    A: Send + 'static,
256    F: MakeHeaderValueFactoryFn<S, ReqBody, ResBody, A>,
257{
258    type Maker = F::Maker;
259
260    fn make_header_value_maker(
261        &self,
262        ctx: Context<S>,
263        request: Request<ReqBody>,
264    ) -> impl Future<Output = (Context<S>, Request<ReqBody>, Self::Maker)> + Send + '_ {
265        self.f.call(ctx, request)
266    }
267}
268
269/// Functional version of [`MakeHeaderValue`],
270/// to make it easier to create a (response) header maker
271/// directly from a response.
272pub trait MakeHeaderValueFn<B, A>: Send + Sync + 'static {
273    /// Try to create a header value from the request or response.
274    fn call(
275        self,
276        response: Response<B>,
277    ) -> impl Future<Output = (Response<B>, Option<HeaderValue>)> + Send;
278}
279
280impl<F, Fut, B> MakeHeaderValueFn<B, ()> for F
281where
282    B: Send + 'static,
283    F: FnOnce() -> Fut + Send + Sync + 'static,
284    Fut: Future<Output = Option<HeaderValue>> + Send + 'static,
285{
286    async fn call(self, response: Response<B>) -> (Response<B>, Option<HeaderValue>) {
287        let maybe_value = self().await;
288        (response, maybe_value)
289    }
290}
291
292impl<F, Fut, B> MakeHeaderValueFn<B, Response<B>> for F
293where
294    B: Send + 'static,
295    F: FnOnce(Response<B>) -> Fut + Send + Sync + 'static,
296    Fut: Future<Output = (Response<B>, Option<HeaderValue>)> + Send + 'static,
297{
298    async fn call(self, response: Response<B>) -> (Response<B>, Option<HeaderValue>) {
299        let (response, maybe_value) = self(response).await;
300        (response, maybe_value)
301    }
302}
303
/// The public wrapper type for [`MakeHeaderValueFn`].
///
/// `A` is the marker type that selects which function shape `F` is used as.
pub struct BoxMakeHeaderValueFn<F, A> {
    /// The wrapped function.
    f: F,
    /// Ties the marker `A` to the wrapper without storing a value of it.
    /// A `fn` pointer type is `Send + Sync` regardless of `A`.
    _marker: PhantomData<fn(A) -> ()>,
}

impl<F, A> BoxMakeHeaderValueFn<F, A> {
    /// Create a new [`BoxMakeHeaderValueFn`].
    pub const fn new(f: F) -> Self {
        let _marker = PhantomData;
        Self { f, _marker }
    }
}

// Implemented by hand so that only `F` needs to be `Clone`, not the marker `A`.
impl<F, A> Clone for BoxMakeHeaderValueFn<F, A>
where
    F: Clone,
{
    fn clone(&self) -> Self {
        Self::new(self.f.clone())
    }
}

// Implemented by hand so that only `F` needs to be `Debug`, not the marker `A`.
impl<F, A> std::fmt::Debug for BoxMakeHeaderValueFn<F, A>
where
    F: std::fmt::Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("BoxMakeHeaderValueFn");
        builder.field("f", &self.f);
        builder.finish()
    }
}
342
343impl<B, A, F> MakeHeaderValue<B> for BoxMakeHeaderValueFn<F, A>
344where
345    A: Send + 'static,
346    F: MakeHeaderValueFn<B, A>,
347{
348    fn make_header_value(
349        self,
350        response: Response<B>,
351    ) -> impl Future<Output = (Response<B>, Option<HeaderValue>)> + Send {
352        self.f.call(response)
353    }
354}
355
356impl<F, Fut, S, ReqBody, ResBody> MakeHeaderValueFactoryFn<S, ReqBody, ResBody, ((), (), ())> for F
357where
358    S: Clone + Send + Sync + 'static,
359    ReqBody: Send + 'static,
360    ResBody: Send + 'static,
361    F: FnOnce() -> Fut + Clone + Send + Sync + 'static,
362    Fut: Future<Output = Option<HeaderValue>> + Send + 'static,
363{
364    type Maker = BoxMakeHeaderValueFn<F, ()>;
365
366    async fn call(
367        &self,
368        ctx: Context<S>,
369        request: Request<ReqBody>,
370    ) -> (Context<S>, Request<ReqBody>, Self::Maker) {
371        let maker = self.clone();
372        (ctx, request, BoxMakeHeaderValueFn::new(maker))
373    }
374}
375
376impl<F, Fut, S, ReqBody, ResBody>
377    MakeHeaderValueFactoryFn<S, ReqBody, ResBody, ((), (), Response<ResBody>)> for F
378where
379    S: Clone + Send + Sync + 'static,
380    ReqBody: Send + 'static,
381    ResBody: Send + 'static,
382    F: FnOnce(Response<ResBody>) -> Fut + Clone + Send + Sync + 'static,
383    Fut: Future<Output = (Response<ResBody>, Option<HeaderValue>)> + Send + 'static,
384{
385    type Maker = BoxMakeHeaderValueFn<F, Response<ResBody>>;
386
387    async fn call(
388        &self,
389        ctx: Context<S>,
390        request: Request<ReqBody>,
391    ) -> (Context<S>, Request<ReqBody>, Self::Maker) {
392        let maker = self.clone();
393        (ctx, request, BoxMakeHeaderValueFn::new(maker))
394    }
395}
396
397#[derive(Debug, Clone, Copy)]
398pub(super) enum InsertHeaderMode {
399    Override,
400    Append,
401    IfNotPresent,
402}
403
404impl InsertHeaderMode {
405    pub(super) async fn apply<B, M>(
406        self,
407        header_name: &HeaderName,
408        response: Response<B>,
409        make: M,
410    ) -> Response<B>
411    where
412        B: Send + 'static,
413        M: MakeHeaderValue<B>,
414    {
415        match self {
416            InsertHeaderMode::Override => {
417                let (mut response, maybe_value) = make.make_header_value(response).await;
418                if let Some(value) = maybe_value {
419                    response.headers_mut().insert(header_name.clone(), value);
420                }
421                response
422            }
423            InsertHeaderMode::IfNotPresent => {
424                if !response.headers().contains_key(header_name) {
425                    let (mut response, maybe_value) = make.make_header_value(response).await;
426                    if let Some(value) = maybe_value {
427                        response.headers_mut().insert(header_name.clone(), value);
428                    }
429                    response
430                } else {
431                    response
432                }
433            }
434            InsertHeaderMode::Append => {
435                let (mut response, maybe_value) = make.make_header_value(response).await;
436                if let Some(value) = maybe_value {
437                    response.headers_mut().append(header_name.clone(), value);
438                }
439                response
440            }
441        }
442    }
443}