rama_http/layer/required_header/
request.rs

1//! Set required headers on the request, if they are missing.
2//!
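//! This sets the `Host` header (derived from the request authority, omitting the
//! port when it is the scheme's default) and the `User-Agent` header (a custom value
//! or a versioned `rama` default). Headers are only added when missing, unless
//! overwriting is enabled. The `Host` header is set independent of the HTTP version.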
5
6use crate::{
7    HeaderValue, Request, Response,
8    header::{self, HOST, RAMA_ID_HEADER_VALUE, USER_AGENT},
9    headers::HeaderMapExt,
10};
11use rama_core::{
12    Context, Layer, Service,
13    error::{BoxError, ErrorContext},
14};
15use rama_http_headers::Host;
16use rama_net::http::RequestContext;
17use rama_utils::macros::define_inner_service_accessors;
18use std::fmt;
19
20/// Layer that applies [`AddRequiredRequestHeaders`] which adds a request header.
21///
22/// See [`AddRequiredRequestHeaders`] for more details.
23#[derive(Debug, Clone, Default)]
24pub struct AddRequiredRequestHeadersLayer {
25    overwrite: bool,
26    user_agent_header_value: Option<HeaderValue>,
27}
28
29impl AddRequiredRequestHeadersLayer {
30    /// Create a new [`AddRequiredRequestHeadersLayer`].
31    pub const fn new() -> Self {
32        Self {
33            overwrite: false,
34            user_agent_header_value: None,
35        }
36    }
37
38    /// Set whether to overwrite the existing headers.
39    /// If set to `true`, the headers will be overwritten.
40    ///
41    /// Default is `false`.
42    pub const fn overwrite(mut self, overwrite: bool) -> Self {
43        self.overwrite = overwrite;
44        self
45    }
46
47    /// Set whether to overwrite the existing headers.
48    /// If set to `true`, the headers will be overwritten.
49    ///
50    /// Default is `false`.
51    pub fn set_overwrite(&mut self, overwrite: bool) -> &mut Self {
52        self.overwrite = overwrite;
53        self
54    }
55
56    /// Set a custom [`USER_AGENT`] header value.
57    ///
58    /// By default a versioned `rama` value is used.
59    pub fn user_agent_header_value(mut self, value: HeaderValue) -> Self {
60        self.user_agent_header_value = Some(value);
61        self
62    }
63
64    /// Maybe set a custom [`USER_AGENT`] header value.
65    ///
66    /// By default a versioned `rama` value is used.
67    pub fn maybe_user_agent_header_value(mut self, value: Option<HeaderValue>) -> Self {
68        self.user_agent_header_value = value;
69        self
70    }
71
72    /// Set a custom [`USER_AGENT`] header value.
73    ///
74    /// By default a versioned `rama` value is used.
75    pub fn set_user_agent_header_value(&mut self, value: HeaderValue) -> &mut Self {
76        self.user_agent_header_value = Some(value);
77        self
78    }
79}
80
81impl<S> Layer<S> for AddRequiredRequestHeadersLayer {
82    type Service = AddRequiredRequestHeaders<S>;
83
84    fn layer(&self, inner: S) -> Self::Service {
85        AddRequiredRequestHeaders {
86            inner,
87            overwrite: self.overwrite,
88            user_agent_header_value: self.user_agent_header_value.clone(),
89        }
90    }
91
92    fn into_layer(self, inner: S) -> Self::Service {
93        AddRequiredRequestHeaders {
94            inner,
95            overwrite: self.overwrite,
96            user_agent_header_value: self.user_agent_header_value,
97        }
98    }
99}
100
101/// Middleware that sets a header on the request.
102#[derive(Clone)]
103pub struct AddRequiredRequestHeaders<S> {
104    inner: S,
105    overwrite: bool,
106    user_agent_header_value: Option<HeaderValue>,
107}
108
109impl<S> AddRequiredRequestHeaders<S> {
110    /// Create a new [`AddRequiredRequestHeaders`].
111    pub const fn new(inner: S) -> Self {
112        Self {
113            inner,
114            overwrite: false,
115            user_agent_header_value: None,
116        }
117    }
118
119    /// Set whether to overwrite the existing headers.
120    /// If set to `true`, the headers will be overwritten.
121    ///
122    /// Default is `false`.
123    pub const fn overwrite(mut self, overwrite: bool) -> Self {
124        self.overwrite = overwrite;
125        self
126    }
127
128    /// Set whether to overwrite the existing headers.
129    /// If set to `true`, the headers will be overwritten.
130    ///
131    /// Default is `false`.
132    pub fn set_overwrite(&mut self, overwrite: bool) -> &mut Self {
133        self.overwrite = overwrite;
134        self
135    }
136
137    /// Set a custom [`USER_AGENT`] header value.
138    ///
139    /// By default a versioned `rama` value is used.
140    pub fn user_agent_header_value(mut self, value: HeaderValue) -> Self {
141        self.user_agent_header_value = Some(value);
142        self
143    }
144
145    /// Maybe set a custom [`USER_AGENT`] header value.
146    ///
147    /// By default a versioned `rama` value is used.
148    pub fn maybe_user_agent_header_value(mut self, value: Option<HeaderValue>) -> Self {
149        self.user_agent_header_value = value;
150        self
151    }
152
153    /// Set a custom [`USER_AGENT`] header value.
154    ///
155    /// By default a versioned `rama` value is used.
156    pub fn set_user_agent_header_value(&mut self, value: HeaderValue) -> &mut Self {
157        self.user_agent_header_value = Some(value);
158        self
159    }
160
161    define_inner_service_accessors!();
162}
163
164impl<S> fmt::Debug for AddRequiredRequestHeaders<S>
165where
166    S: fmt::Debug,
167{
168    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
169        f.debug_struct("AddRequiredRequestHeaders")
170            .field("inner", &self.inner)
171            .field("user_agent_header_value", &self.user_agent_header_value)
172            .finish()
173    }
174}
175
176impl<ReqBody, ResBody, State, S> Service<State, Request<ReqBody>> for AddRequiredRequestHeaders<S>
177where
178    ReqBody: Send + 'static,
179    ResBody: Send + 'static,
180    State: Clone + Send + Sync + 'static,
181    S: Service<State, Request<ReqBody>, Response = Response<ResBody>, Error: Into<BoxError>>,
182{
183    type Response = S::Response;
184    type Error = BoxError;
185
186    async fn serve(
187        &self,
188        mut ctx: Context<State>,
189        mut req: Request<ReqBody>,
190    ) -> Result<Self::Response, Self::Error> {
191        if self.overwrite || !req.headers().contains_key(HOST) {
192            let request_ctx: &mut RequestContext = ctx
193                .get_or_try_insert_with_ctx(|ctx| (ctx, &req).try_into())
194                .context(
195                    "AddRequiredRequestHeaders: get/compute RequestContext to set authority",
196                )?;
197            if request_ctx.authority_has_default_port() {
198                let host = request_ctx.authority.host().clone();
199                tracing::trace!(
200                    %host,
201                    "add missing host from authority as host header"
202                );
203                req.headers_mut().typed_insert(Host::from(host));
204            } else {
205                let authority = request_ctx.authority.clone();
206                tracing::trace!(
207                    %authority,
208                    "add missing authority as host header"
209                );
210                req.headers_mut().typed_insert(Host::from(authority));
211            }
212        }
213
214        if self.overwrite {
215            req.headers_mut().insert(
216                USER_AGENT,
217                self.user_agent_header_value
218                    .as_ref()
219                    .unwrap_or(&RAMA_ID_HEADER_VALUE)
220                    .clone(),
221            );
222        } else if let header::Entry::Vacant(header) = req.headers_mut().entry(USER_AGENT) {
223            header.insert(
224                self.user_agent_header_value
225                    .as_ref()
226                    .unwrap_or(&RAMA_ID_HEADER_VALUE)
227                    .clone(),
228            );
229        }
230
231        self.inner.serve(ctx, req).await.map_err(Into::into)
232    }
233}
234
#[cfg(test)]
mod test {
    use super::*;
    use crate::{Body, Request};
    use rama_core::service::service_fn;
    use rama_core::{Context, Layer, Service};
    use std::convert::Infallible;

    #[tokio::test]
    async fn add_required_request_headers() {
        let svc = AddRequiredRequestHeadersLayer::default().into_layer(service_fn(
            async |_ctx: Context<()>, req: Request| {
                assert!(req.headers().contains_key(HOST));
                assert!(req.headers().contains_key(USER_AGENT));
                Ok::<_, Infallible>(rama_http_types::Response::new(Body::empty()))
            },
        ));

        let req = Request::builder()
            .uri("http://www.example.com/")
            .body(Body::empty())
            .unwrap();
        let resp = svc.serve(Context::default(), req).await.unwrap();

        assert!(!resp.headers().contains_key(HOST));
        assert!(!resp.headers().contains_key(USER_AGENT));
    }

    #[tokio::test]
    async fn add_required_request_headers_custom_ua() {
        let svc = AddRequiredRequestHeadersLayer::default()
            .user_agent_header_value(HeaderValue::from_static("foo"))
            .into_layer(service_fn(async |_ctx: Context<()>, req: Request| {
                assert!(req.headers().contains_key(HOST));
                assert_eq!(
                    req.headers().get(USER_AGENT).and_then(|v| v.to_str().ok()),
                    Some("foo")
                );
                Ok::<_, Infallible>(rama_http_types::Response::new(Body::empty()))
            }));

        let req = Request::builder()
            .uri("http://www.example.com/")
            .body(Body::empty())
            .unwrap();
        let resp = svc.serve(Context::default(), req).await.unwrap();

        assert!(!resp.headers().contains_key(HOST));
        assert!(!resp.headers().contains_key(USER_AGENT));
    }

    /// Without `overwrite`, headers that are already present must be left untouched.
    #[tokio::test]
    async fn add_required_request_headers_keeps_existing_by_default() {
        let svc = AddRequiredRequestHeadersLayer::new()
            .user_agent_header_value(HeaderValue::from_static("foo"))
            .into_layer(service_fn(async |_ctx: Context<()>, req: Request| {
                assert_eq!(req.headers().get(HOST).unwrap(), "example.com");
                assert_eq!(req.headers().get(USER_AGENT).unwrap(), "test");
                Ok::<_, Infallible>(rama_http_types::Response::new(Body::empty()))
            }));

        let req = Request::builder()
            .uri("http://127.0.0.1/")
            .header(HOST, "example.com")
            .header(USER_AGENT, "test")
            .body(Body::empty())
            .unwrap();

        let resp = svc.serve(Context::default(), req).await.unwrap();

        assert!(!resp.headers().contains_key(HOST));
        assert!(!resp.headers().contains_key(USER_AGENT));
    }

    /// A non-default port must be kept in the generated `Host` header.
    #[tokio::test]
    async fn add_required_request_headers_non_default_port() {
        let svc = AddRequiredRequestHeadersLayer::new().into_layer(service_fn(
            async |_ctx: Context<()>, req: Request| {
                assert_eq!(req.headers().get(HOST).unwrap(), "127.0.0.1:8080");
                Ok::<_, Infallible>(rama_http_types::Response::new(Body::empty()))
            },
        ));

        let req = Request::builder()
            .uri("http://127.0.0.1:8080/")
            .body(Body::empty())
            .unwrap();

        let resp = svc.serve(Context::default(), req).await.unwrap();

        assert!(!resp.headers().contains_key(HOST));
    }

    #[tokio::test]
    async fn add_required_request_headers_overwrite() {
        let svc = AddRequiredRequestHeadersLayer::new()
            .overwrite(true)
            .into_layer(service_fn(async |_ctx: Context<()>, req: Request| {
                assert_eq!(req.headers().get(HOST).unwrap(), "127.0.0.1");
                assert_eq!(
                    req.headers().get(USER_AGENT).unwrap(),
                    RAMA_ID_HEADER_VALUE.to_str().unwrap()
                );
                Ok::<_, Infallible>(rama_http_types::Response::new(Body::empty()))
            }));

        let req = Request::builder()
            .uri("http://127.0.0.1/")
            .header(HOST, "example.com")
            .header(USER_AGENT, "test")
            .body(Body::empty())
            .unwrap();

        let resp = svc.serve(Context::default(), req).await.unwrap();

        assert!(!resp.headers().contains_key(HOST));
        assert!(!resp.headers().contains_key(USER_AGENT));
    }

    #[tokio::test]
    async fn add_required_request_headers_overwrite_custom_ua() {
        let svc = AddRequiredRequestHeadersLayer::new()
            .overwrite(true)
            .user_agent_header_value(HeaderValue::from_static("foo"))
            .into_layer(service_fn(async |_ctx: Context<()>, req: Request| {
                assert_eq!(req.headers().get(HOST).unwrap(), "127.0.0.1");
                assert_eq!(
                    req.headers().get(USER_AGENT).and_then(|v| v.to_str().ok()),
                    Some("foo")
                );
                Ok::<_, Infallible>(rama_http_types::Response::new(Body::empty()))
            }));

        let req = Request::builder()
            .uri("http://127.0.0.1/")
            .header(HOST, "example.com")
            .header(USER_AGENT, "test")
            .body(Body::empty())
            .unwrap();

        let resp = svc.serve(Context::default(), req).await.unwrap();

        assert!(!resp.headers().contains_key(HOST));
        assert!(!resp.headers().contains_key(USER_AGENT));
    }
}