rama_http/layer/required_header/
request.rs1use crate::{
7 HeaderValue, Request, Response,
8 header::{self, HOST, RAMA_ID_HEADER_VALUE, USER_AGENT},
9 headers::HeaderMapExt,
10};
11use rama_core::{
12 Context, Layer, Service,
13 error::{BoxError, ErrorContext},
14};
15use rama_http_headers::Host;
16use rama_net::http::RequestContext;
17use rama_utils::macros::define_inner_service_accessors;
18use std::fmt;
19
20#[derive(Debug, Clone, Default)]
24pub struct AddRequiredRequestHeadersLayer {
25 overwrite: bool,
26 user_agent_header_value: Option<HeaderValue>,
27}
28
29impl AddRequiredRequestHeadersLayer {
30 pub const fn new() -> Self {
32 Self {
33 overwrite: false,
34 user_agent_header_value: None,
35 }
36 }
37
38 pub const fn overwrite(mut self, overwrite: bool) -> Self {
43 self.overwrite = overwrite;
44 self
45 }
46
47 pub fn set_overwrite(&mut self, overwrite: bool) -> &mut Self {
52 self.overwrite = overwrite;
53 self
54 }
55
56 pub fn user_agent_header_value(mut self, value: HeaderValue) -> Self {
60 self.user_agent_header_value = Some(value);
61 self
62 }
63
64 pub fn maybe_user_agent_header_value(mut self, value: Option<HeaderValue>) -> Self {
68 self.user_agent_header_value = value;
69 self
70 }
71
72 pub fn set_user_agent_header_value(&mut self, value: HeaderValue) -> &mut Self {
76 self.user_agent_header_value = Some(value);
77 self
78 }
79}
80
81impl<S> Layer<S> for AddRequiredRequestHeadersLayer {
82 type Service = AddRequiredRequestHeaders<S>;
83
84 fn layer(&self, inner: S) -> Self::Service {
85 AddRequiredRequestHeaders {
86 inner,
87 overwrite: self.overwrite,
88 user_agent_header_value: self.user_agent_header_value.clone(),
89 }
90 }
91
92 fn into_layer(self, inner: S) -> Self::Service {
93 AddRequiredRequestHeaders {
94 inner,
95 overwrite: self.overwrite,
96 user_agent_header_value: self.user_agent_header_value,
97 }
98 }
99}
100
101#[derive(Clone)]
103pub struct AddRequiredRequestHeaders<S> {
104 inner: S,
105 overwrite: bool,
106 user_agent_header_value: Option<HeaderValue>,
107}
108
109impl<S> AddRequiredRequestHeaders<S> {
110 pub const fn new(inner: S) -> Self {
112 Self {
113 inner,
114 overwrite: false,
115 user_agent_header_value: None,
116 }
117 }
118
119 pub const fn overwrite(mut self, overwrite: bool) -> Self {
124 self.overwrite = overwrite;
125 self
126 }
127
128 pub fn set_overwrite(&mut self, overwrite: bool) -> &mut Self {
133 self.overwrite = overwrite;
134 self
135 }
136
137 pub fn user_agent_header_value(mut self, value: HeaderValue) -> Self {
141 self.user_agent_header_value = Some(value);
142 self
143 }
144
145 pub fn maybe_user_agent_header_value(mut self, value: Option<HeaderValue>) -> Self {
149 self.user_agent_header_value = value;
150 self
151 }
152
153 pub fn set_user_agent_header_value(&mut self, value: HeaderValue) -> &mut Self {
157 self.user_agent_header_value = Some(value);
158 self
159 }
160
161 define_inner_service_accessors!();
162}
163
164impl<S> fmt::Debug for AddRequiredRequestHeaders<S>
165where
166 S: fmt::Debug,
167{
168 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
169 f.debug_struct("AddRequiredRequestHeaders")
170 .field("inner", &self.inner)
171 .field("user_agent_header_value", &self.user_agent_header_value)
172 .finish()
173 }
174}
175
176impl<ReqBody, ResBody, State, S> Service<State, Request<ReqBody>> for AddRequiredRequestHeaders<S>
177where
178 ReqBody: Send + 'static,
179 ResBody: Send + 'static,
180 State: Clone + Send + Sync + 'static,
181 S: Service<State, Request<ReqBody>, Response = Response<ResBody>, Error: Into<BoxError>>,
182{
183 type Response = S::Response;
184 type Error = BoxError;
185
186 async fn serve(
187 &self,
188 mut ctx: Context<State>,
189 mut req: Request<ReqBody>,
190 ) -> Result<Self::Response, Self::Error> {
191 if self.overwrite || !req.headers().contains_key(HOST) {
192 let request_ctx: &mut RequestContext = ctx
193 .get_or_try_insert_with_ctx(|ctx| (ctx, &req).try_into())
194 .context(
195 "AddRequiredRequestHeaders: get/compute RequestContext to set authority",
196 )?;
197 if request_ctx.authority_has_default_port() {
198 let host = request_ctx.authority.host().clone();
199 tracing::trace!(
200 %host,
201 "add missing host from authority as host header"
202 );
203 req.headers_mut().typed_insert(Host::from(host));
204 } else {
205 let authority = request_ctx.authority.clone();
206 tracing::trace!(
207 %authority,
208 "add missing authority as host header"
209 );
210 req.headers_mut().typed_insert(Host::from(authority));
211 }
212 }
213
214 if self.overwrite {
215 req.headers_mut().insert(
216 USER_AGENT,
217 self.user_agent_header_value
218 .as_ref()
219 .unwrap_or(&RAMA_ID_HEADER_VALUE)
220 .clone(),
221 );
222 } else if let header::Entry::Vacant(header) = req.headers_mut().entry(USER_AGENT) {
223 header.insert(
224 self.user_agent_header_value
225 .as_ref()
226 .unwrap_or(&RAMA_ID_HEADER_VALUE)
227 .clone(),
228 );
229 }
230
231 self.inner.serve(ctx, req).await.map_err(Into::into)
232 }
233}
234
#[cfg(test)]
mod test {
    use super::*;
    use crate::{Body, Request};
    use rama_core::service::service_fn;
    use rama_core::{Context, Layer, Service};
    use std::convert::Infallible;

    /// Missing `Host` and `User-Agent` headers are added; the response is untouched.
    #[tokio::test]
    async fn add_required_request_headers() {
        let svc = AddRequiredRequestHeadersLayer::default().into_layer(service_fn(
            async |_ctx: Context<()>, req: Request| {
                assert!(req.headers().contains_key(HOST));
                assert!(req.headers().contains_key(USER_AGENT));
                Ok::<_, Infallible>(rama_http_types::Response::new(Body::empty()))
            },
        ));

        let req = Request::builder()
            .uri("http://www.example.com/")
            .body(Body::empty())
            .unwrap();
        let resp = svc.serve(Context::default(), req).await.unwrap();

        assert!(!resp.headers().contains_key(HOST));
        assert!(!resp.headers().contains_key(USER_AGENT));
    }

    /// A configured `User-Agent` value is used for the missing header.
    #[tokio::test]
    async fn add_required_request_headers_custom_ua() {
        let svc = AddRequiredRequestHeadersLayer::default()
            .user_agent_header_value(HeaderValue::from_static("foo"))
            .into_layer(service_fn(async |_ctx: Context<()>, req: Request| {
                assert!(req.headers().contains_key(HOST));
                assert_eq!(
                    req.headers().get(USER_AGENT).and_then(|v| v.to_str().ok()),
                    Some("foo")
                );
                Ok::<_, Infallible>(rama_http_types::Response::new(Body::empty()))
            }));

        let req = Request::builder()
            .uri("http://www.example.com/")
            .body(Body::empty())
            .unwrap();
        let resp = svc.serve(Context::default(), req).await.unwrap();

        assert!(!resp.headers().contains_key(HOST));
        assert!(!resp.headers().contains_key(USER_AGENT));
    }

    /// Without `overwrite`, headers already present on the request are kept as-is.
    #[tokio::test]
    async fn add_required_request_headers_keeps_existing() {
        let svc = AddRequiredRequestHeadersLayer::default().into_layer(service_fn(
            async |_ctx: Context<()>, req: Request| {
                assert_eq!(req.headers().get(HOST).unwrap(), "example.com");
                assert_eq!(req.headers().get(USER_AGENT).unwrap(), "test");
                Ok::<_, Infallible>(rama_http_types::Response::new(Body::empty()))
            },
        ));

        let req = Request::builder()
            .uri("http://127.0.0.1/")
            .header(HOST, "example.com")
            .header(USER_AGENT, "test")
            .body(Body::empty())
            .unwrap();

        let resp = svc.serve(Context::default(), req).await.unwrap();

        assert!(!resp.headers().contains_key(HOST));
        assert!(!resp.headers().contains_key(USER_AGENT));
    }

    /// A non-default port in the authority must be preserved in the `Host` header.
    #[tokio::test]
    async fn add_required_request_headers_non_default_port() {
        let svc = AddRequiredRequestHeadersLayer::default().into_layer(service_fn(
            async |_ctx: Context<()>, req: Request| {
                assert_eq!(req.headers().get(HOST).unwrap(), "www.example.com:8080");
                assert!(req.headers().contains_key(USER_AGENT));
                Ok::<_, Infallible>(rama_http_types::Response::new(Body::empty()))
            },
        ));

        let req = Request::builder()
            .uri("http://www.example.com:8080/")
            .body(Body::empty())
            .unwrap();
        let resp = svc.serve(Context::default(), req).await.unwrap();

        assert!(!resp.headers().contains_key(HOST));
        assert!(!resp.headers().contains_key(USER_AGENT));
    }

    /// With `overwrite`, existing headers are replaced by the computed/default values.
    #[tokio::test]
    async fn add_required_request_headers_overwrite() {
        let svc = AddRequiredRequestHeadersLayer::new()
            .overwrite(true)
            .into_layer(service_fn(async |_ctx: Context<()>, req: Request| {
                assert_eq!(req.headers().get(HOST).unwrap(), "127.0.0.1");
                assert_eq!(
                    req.headers().get(USER_AGENT).unwrap(),
                    RAMA_ID_HEADER_VALUE.to_str().unwrap()
                );
                Ok::<_, Infallible>(rama_http_types::Response::new(Body::empty()))
            }));

        let req = Request::builder()
            .uri("http://127.0.0.1/")
            .header(HOST, "example.com")
            .header(USER_AGENT, "test")
            .body(Body::empty())
            .unwrap();

        let resp = svc.serve(Context::default(), req).await.unwrap();

        assert!(!resp.headers().contains_key(HOST));
        assert!(!resp.headers().contains_key(USER_AGENT));
    }

    /// With `overwrite` and a custom `User-Agent`, the custom value replaces the existing one.
    #[tokio::test]
    async fn add_required_request_headers_overwrite_custom_ua() {
        let svc = AddRequiredRequestHeadersLayer::new()
            .overwrite(true)
            .user_agent_header_value(HeaderValue::from_static("foo"))
            .into_layer(service_fn(async |_ctx: Context<()>, req: Request| {
                assert_eq!(req.headers().get(HOST).unwrap(), "127.0.0.1");
                assert_eq!(
                    req.headers().get(USER_AGENT).and_then(|v| v.to_str().ok()),
                    Some("foo")
                );
                Ok::<_, Infallible>(rama_http_types::Response::new(Body::empty()))
            }));

        let req = Request::builder()
            .uri("http://127.0.0.1/")
            .header(HOST, "example.com")
            .header(USER_AGENT, "test")
            .body(Body::empty())
            .unwrap();

        let resp = svc.serve(Context::default(), req).await.unwrap();

        assert!(!resp.headers().contains_key(HOST));
        assert!(!resp.headers().contains_key(USER_AGENT));
    }
}