// rama_http/layer/set_header/request/header.rs

1use crate::{HeaderName, HeaderValue, Request};
2use rama_core::Context;
3use std::{
4    future::{Future, ready},
5    marker::PhantomData,
6};
7
8/// Trait for producing header values.
9///
10/// Used by [`SetRequestHeader`] and [`SetResponseHeader`].
11///
12/// This trait is implemented for closures with the correct type signature. Typically users will
13/// not have to implement this trait for their own types.
14///
15/// It is also implemented directly for [`HeaderValue`]. When a fixed header value should be added
16/// to all responses, it can be supplied directly to the middleware.
17pub trait MakeHeaderValue<S, B>: Send + Sync + 'static {
18    /// Try to create a header value from the request or response.
19    fn make_header_value(
20        &self,
21        ctx: Context<S>,
22        req: Request<B>,
23    ) -> impl Future<Output = (Context<S>, Request<B>, Option<HeaderValue>)> + Send + '_;
24}
25
26/// Functional version of [`MakeHeaderValue`].
27pub trait MakeHeaderValueFn<S, B, A>: Send + Sync + 'static {
28    /// Try to create a header value from the request or response.
29    fn call(
30        &self,
31        ctx: Context<S>,
32        req: Request<B>,
33    ) -> impl Future<Output = (Context<S>, Request<B>, Option<HeaderValue>)> + Send + '_;
34}
35
36impl<F, Fut, S, B> MakeHeaderValueFn<S, B, ()> for F
37where
38    S: Clone + Send + Sync + 'static,
39    B: Send + 'static,
40    F: Fn() -> Fut + Send + Sync + 'static,
41    Fut: Future<Output = Option<HeaderValue>> + Send + 'static,
42{
43    async fn call(
44        &self,
45        ctx: Context<S>,
46        req: Request<B>,
47    ) -> (Context<S>, Request<B>, Option<HeaderValue>) {
48        let maybe_value = self().await;
49        (ctx, req, maybe_value)
50    }
51}
52
53impl<F, Fut, S, B> MakeHeaderValueFn<S, B, ((), B)> for F
54where
55    S: Clone + Send + Sync + 'static,
56    B: Send + 'static,
57    F: Fn(Request<B>) -> Fut + Send + Sync + 'static,
58    Fut: Future<Output = (Request<B>, Option<HeaderValue>)> + Send + 'static,
59{
60    async fn call(
61        &self,
62        ctx: Context<S>,
63        req: Request<B>,
64    ) -> (Context<S>, Request<B>, Option<HeaderValue>) {
65        let (req, maybe_value) = self(req).await;
66        (ctx, req, maybe_value)
67    }
68}
69
70impl<F, Fut, S, B> MakeHeaderValueFn<S, B, (Context<S>,)> for F
71where
72    S: Clone + Send + Sync + 'static,
73    B: Send + 'static,
74    F: Fn(Context<S>) -> Fut + Send + Sync + 'static,
75    Fut: Future<Output = (Context<S>, Option<HeaderValue>)> + Send + 'static,
76{
77    async fn call(
78        &self,
79        ctx: Context<S>,
80        req: Request<B>,
81    ) -> (Context<S>, Request<B>, Option<HeaderValue>) {
82        let (ctx, maybe_value) = self(ctx).await;
83        (ctx, req, maybe_value)
84    }
85}
86
87impl<F, Fut, S, B> MakeHeaderValueFn<S, B, (Context<S>, B)> for F
88where
89    F: Fn(Context<S>, Request<B>) -> Fut + Send + Sync + 'static,
90    Fut: Future<Output = (Context<S>, Request<B>, Option<HeaderValue>)> + Send + 'static,
91{
92    fn call(
93        &self,
94        ctx: Context<S>,
95        req: Request<B>,
96    ) -> impl Future<Output = (Context<S>, Request<B>, Option<HeaderValue>)> + Send + '_ {
97        self(ctx, req)
98    }
99}
100
/// Adapter that lets a [`MakeHeaderValueFn`] closure be used wherever a
/// [`MakeHeaderValue`] is expected.
///
/// The closure is stored inline (nothing is heap-allocated despite the name).
/// `A` is the marker picking which [`MakeHeaderValueFn`] impl applies; it is
/// tracked as `PhantomData<fn(A)>`, so the wrapper never owns an `A` and its
/// `Send`/`Sync` status does not depend on `A`.
pub struct BoxMakeHeaderValueFn<F, A> {
    f: F,
    _marker: PhantomData<fn(A)>,
}

impl<F, A> BoxMakeHeaderValueFn<F, A> {
    /// Wrap `f` so it can be used as a [`MakeHeaderValue`].
    pub const fn new(f: F) -> Self {
        BoxMakeHeaderValueFn {
            _marker: PhantomData,
            f,
        }
    }
}

// Written by hand instead of derived: a derive would also require `A: Clone`.
impl<F, A> Clone for BoxMakeHeaderValueFn<F, A>
where
    F: Clone,
{
    fn clone(&self) -> Self {
        Self::new(self.f.clone())
    }
}

// Written by hand instead of derived: a derive would also require `A: Debug`.
// Only the wrapped function is shown; the marker carries no data.
impl<F, A> std::fmt::Debug for BoxMakeHeaderValueFn<F, A>
where
    F: std::fmt::Debug,
{
    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = fmt.debug_struct("BoxMakeHeaderValueFn");
        out.field("f", &self.f);
        out.finish()
    }
}
139
140impl<S, B, A, F> MakeHeaderValue<S, B> for BoxMakeHeaderValueFn<F, A>
141where
142    A: Send + 'static,
143    F: MakeHeaderValueFn<S, B, A>,
144{
145    fn make_header_value(
146        &self,
147        ctx: Context<S>,
148        req: Request<B>,
149    ) -> impl Future<Output = (Context<S>, Request<B>, Option<HeaderValue>)> + Send + '_ {
150        self.f.call(ctx, req)
151    }
152}
153
154impl<S, B> MakeHeaderValue<S, B> for HeaderValue
155where
156    S: Clone + Send + Sync + 'static,
157    B: Send + 'static,
158{
159    fn make_header_value(
160        &self,
161        ctx: Context<S>,
162        req: Request<B>,
163    ) -> impl Future<Output = (Context<S>, Request<B>, Option<HeaderValue>)> + Send + '_ {
164        ready((ctx, req, Some(self.clone())))
165    }
166}
167
168impl<S, B> MakeHeaderValue<S, B> for Option<HeaderValue>
169where
170    S: Clone + Send + Sync + 'static,
171    B: Send + 'static,
172{
173    fn make_header_value(
174        &self,
175        ctx: Context<S>,
176        req: Request<B>,
177    ) -> impl Future<Output = (Context<S>, Request<B>, Option<HeaderValue>)> + Send + '_ {
178        ready((ctx, req, self.clone()))
179    }
180}
181
182#[derive(Debug, Clone, Copy)]
183pub(super) enum InsertHeaderMode {
184    Override,
185    Append,
186    IfNotPresent,
187}
188
189impl InsertHeaderMode {
190    pub(super) async fn apply<S, B, M>(
191        self,
192        header_name: &HeaderName,
193        ctx: Context<S>,
194        req: Request<B>,
195        make: &M,
196    ) -> (Context<S>, Request<B>)
197    where
198        B: Send + 'static,
199        M: MakeHeaderValue<S, B>,
200    {
201        match self {
202            InsertHeaderMode::Override => {
203                let (ctx, mut req, maybe_value) = make.make_header_value(ctx, req).await;
204                if let Some(value) = maybe_value {
205                    req.headers_mut().insert(header_name.clone(), value);
206                }
207                (ctx, req)
208            }
209            InsertHeaderMode::IfNotPresent => {
210                if !req.headers().contains_key(header_name) {
211                    let (ctx, mut req, maybe_value) = make.make_header_value(ctx, req).await;
212                    if let Some(value) = maybe_value {
213                        req.headers_mut().insert(header_name.clone(), value);
214                    }
215                    (ctx, req)
216                } else {
217                    (ctx, req)
218                }
219            }
220            InsertHeaderMode::Append => {
221                let (ctx, mut req, maybe_value) = make.make_header_value(ctx, req).await;
222                if let Some(value) = maybe_value {
223                    req.headers_mut().append(header_name.clone(), value);
224                }
225                (ctx, req)
226            }
227        }
228    }
229}