rama_http/layer/forwarded/
set_forwarded.rs1use crate::headers::{
2 ForwardHeader, HeaderMapExt, Via, XForwardedFor, XForwardedHost, XForwardedProto,
3};
4use crate::Request;
5use rama_core::error::BoxError;
6use rama_core::{Context, Layer, Service};
7use rama_net::address::Domain;
8use rama_net::forwarded::{Forwarded, ForwardedElement, NodeId};
9use rama_net::http::RequestContext;
10use rama_net::stream::SocketInfo;
11use rama_utils::macros::all_the_tuples_no_last_special_case;
12use std::fmt;
13use std::marker::PhantomData;
14
15pub struct SetForwardedHeadersLayer<T = Forwarded> {
87 by_node: NodeId,
88 _headers: PhantomData<fn() -> T>,
89}
90
91impl<T: fmt::Debug> fmt::Debug for SetForwardedHeadersLayer<T> {
92 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
93 f.debug_struct("SetForwardedHeadersLayer")
94 .field("by_node", &self.by_node)
95 .field(
96 "_headers",
97 &format_args!("{}", std::any::type_name::<fn() -> T>()),
98 )
99 .finish()
100 }
101}
102
103impl<T: Clone> Clone for SetForwardedHeadersLayer<T> {
104 fn clone(&self) -> Self {
105 Self {
106 by_node: self.by_node.clone(),
107 _headers: PhantomData,
108 }
109 }
110}
111
112impl<T> SetForwardedHeadersLayer<T> {
113 pub fn forward_by(mut self, node_id: impl Into<NodeId>) -> Self {
117 self.by_node = node_id.into();
118 self
119 }
120
121 pub fn set_forward_by(&mut self, node_id: impl Into<NodeId>) -> &mut Self {
125 self.by_node = node_id.into();
126 self
127 }
128}
129
130impl<T> SetForwardedHeadersLayer<T> {
131 pub fn new() -> Self {
133 Self {
134 by_node: Domain::from_static("rama").into(),
135 _headers: PhantomData,
136 }
137 }
138}
139
140impl Default for SetForwardedHeadersLayer {
141 fn default() -> Self {
142 Self::forwarded()
143 }
144}
145
146impl SetForwardedHeadersLayer {
147 #[inline]
148 pub fn forwarded() -> Self {
150 Self::new()
151 }
152}
153
154impl SetForwardedHeadersLayer<Via> {
155 #[inline]
156 pub fn via() -> Self {
158 Self::new()
159 }
160}
161
162impl SetForwardedHeadersLayer<XForwardedFor> {
163 #[inline]
164 pub fn x_forwarded_for() -> Self {
166 Self::new()
167 }
168}
169
170impl SetForwardedHeadersLayer<XForwardedHost> {
171 #[inline]
172 pub fn x_forwarded_host() -> Self {
174 Self::new()
175 }
176}
177
178impl SetForwardedHeadersLayer<XForwardedProto> {
179 #[inline]
180 pub fn x_forwarded_proto() -> Self {
182 Self::new()
183 }
184}
185
186impl<H, S> Layer<S> for SetForwardedHeadersLayer<H> {
187 type Service = SetForwardedHeadersService<S, H>;
188
189 fn layer(&self, inner: S) -> Self::Service {
190 Self::Service {
191 inner,
192 by_node: self.by_node.clone(),
193 _headers: PhantomData,
194 }
195 }
196}
197
198pub struct SetForwardedHeadersService<S, T = Forwarded> {
203 inner: S,
204 by_node: NodeId,
205 _headers: PhantomData<fn() -> T>,
206}
207
208impl<S: fmt::Debug, T> fmt::Debug for SetForwardedHeadersService<S, T> {
209 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
210 f.debug_struct("SetForwardedHeadersService")
211 .field("inner", &self.inner)
212 .field("by_node", &self.by_node)
213 .field(
214 "_headers",
215 &format_args!("{}", std::any::type_name::<fn() -> T>()),
216 )
217 .finish()
218 }
219}
220
221impl<S: Clone, T> Clone for SetForwardedHeadersService<S, T> {
222 fn clone(&self) -> Self {
223 SetForwardedHeadersService {
224 inner: self.inner.clone(),
225 by_node: self.by_node.clone(),
226 _headers: PhantomData,
227 }
228 }
229}
230
231impl<S, T> SetForwardedHeadersService<S, T> {
232 pub fn forward_by(mut self, node_id: impl Into<NodeId>) -> Self {
236 self.by_node = node_id.into();
237 self
238 }
239
240 pub fn set_forward_by(&mut self, node_id: impl Into<NodeId>) -> &mut Self {
244 self.by_node = node_id.into();
245 self
246 }
247}
248
249impl<S, T> SetForwardedHeadersService<S, T> {
250 pub fn new(inner: S) -> Self {
252 Self {
253 inner,
254 by_node: Domain::from_static("rama").into(),
255 _headers: PhantomData,
256 }
257 }
258}
259
260impl<S> SetForwardedHeadersService<S> {
261 #[inline]
262 pub fn forwarded(inner: S) -> Self {
264 Self::new(inner)
265 }
266}
267
268impl<S> SetForwardedHeadersService<S, Via> {
269 #[inline]
270 pub fn via(inner: S) -> Self {
272 Self::new(inner)
273 }
274}
275
276impl<S> SetForwardedHeadersService<S, XForwardedFor> {
277 #[inline]
278 pub fn x_forwarded_for(inner: S) -> Self {
280 Self::new(inner)
281 }
282}
283
284impl<S> SetForwardedHeadersService<S, XForwardedHost> {
285 #[inline]
286 pub fn x_forwarded_host(inner: S) -> Self {
288 Self::new(inner)
289 }
290}
291
292impl<S> SetForwardedHeadersService<S, XForwardedProto> {
293 #[inline]
294 pub fn x_forwarded_proto(inner: S) -> Self {
296 Self::new(inner)
297 }
298}
299
300impl<S, H, State, Body> Service<State, Request<Body>> for SetForwardedHeadersService<S, H>
301where
302 S: Service<State, Request<Body>, Error: Into<BoxError>>,
303 H: ForwardHeader + Send + Sync + 'static,
304 Body: Send + 'static,
305 State: Clone + Send + Sync + 'static,
306{
307 type Response = S::Response;
308 type Error = BoxError;
309
310 async fn serve(
311 &self,
312 mut ctx: Context<State>,
313 mut req: Request<Body>,
314 ) -> Result<Self::Response, Self::Error> {
315 let forwarded: Option<Forwarded> = ctx.get().cloned();
316
317 let mut forwarded_element = ForwardedElement::forwarded_by(self.by_node.clone());
318
319 if let Some(peer_addr) = ctx.get::<SocketInfo>().map(|socket| *socket.peer_addr()) {
320 forwarded_element.set_forwarded_for(peer_addr);
321 }
322 let request_ctx: &mut RequestContext =
323 ctx.get_or_try_insert_with_ctx(|ctx| (ctx, &req).try_into())?;
324
325 forwarded_element.set_forwarded_host(request_ctx.authority.clone());
326
327 if let Ok(forwarded_proto) = (&request_ctx.protocol).try_into() {
328 forwarded_element.set_forwarded_proto(forwarded_proto);
329 }
330
331 let forwarded = match forwarded {
332 None => Some(Forwarded::new(forwarded_element)),
333 Some(mut forwarded) => {
334 forwarded.append(forwarded_element);
335 Some(forwarded)
336 }
337 };
338
339 if let Some(forwarded) = forwarded {
340 if let Some(header) = H::try_from_forwarded(forwarded.iter()) {
341 req.headers_mut().typed_insert(header);
342 }
343 }
344
345 self.inner.serve(ctx, req).await.map_err(Into::into)
346 }
347}
348
349macro_rules! set_forwarded_service_for_tuple {
350 ( $($ty:ident),* $(,)? ) => {
351 #[allow(non_snake_case)]
352 impl<S, $($ty),* , State, Body> Service<State, Request<Body>> for SetForwardedHeadersService<S, ($($ty,)*)>
353 where
354 $( $ty: ForwardHeader + Send + Sync + 'static, )*
355 S: Service<State, Request<Body>, Error: Into<BoxError>>,
356 Body: Send + 'static,
357 State: Clone + Send + Sync + 'static,
358 {
359 type Response = S::Response;
360 type Error = BoxError;
361
362 async fn serve(
363 &self,
364 mut ctx: Context<State>,
365 mut req: Request<Body>,
366 ) -> Result<Self::Response, Self::Error> {
367 let forwarded: Option<Forwarded> = ctx.get().cloned();
368
369 let mut forwarded_element = ForwardedElement::forwarded_by(self.by_node.clone());
370
371 if let Some(peer_addr) = ctx.get::<SocketInfo>().map(|socket| *socket.peer_addr()) {
372 forwarded_element.set_forwarded_for(peer_addr);
373 }
374
375 let request_ctx: &mut RequestContext =
376 ctx.get_or_try_insert_with_ctx(|ctx| (ctx, &req).try_into())?;
377
378 forwarded_element.set_forwarded_host(request_ctx.authority.clone());
379
380 if let Ok(forwarded_proto) = (&request_ctx.protocol).try_into() {
381 forwarded_element.set_forwarded_proto(forwarded_proto);
382 }
383
384 let forwarded = match forwarded {
385 None => Some(Forwarded::new(forwarded_element)),
386 Some(mut forwarded) => {
387 forwarded.append(forwarded_element);
388 Some(forwarded)
389 }
390 };
391
392 if let Some(forwarded) = forwarded {
393 $(
394 if let Some(header) = $ty::try_from_forwarded(forwarded.iter()) {
395 req.headers_mut().typed_insert(header);
396 }
397 )*
398 }
399
400 self.inner.serve(ctx, req).await.map_err(Into::into)
401 }
402 }
403 };
404}
405all_the_tuples_no_last_special_case!(set_forwarded_service_for_tuple);
406
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        headers::{TrueClientIp, XClientIp, XRealIp},
        IntoResponse, Response, StatusCode,
    };
    use rama_core::{error::OpaqueError, service::service_fn, Layer};
    use std::{convert::Infallible, net::IpAddr};

    fn assert_is_service<T: Service<(), Request<()>>>(_: T) {}

    async fn dummy_service_fn() -> Result<Response, OpaqueError> {
        Ok(StatusCode::OK.into_response())
    }

    #[test]
    fn test_set_forwarded_service_is_service() {
        assert_is_service(SetForwardedHeadersService::forwarded(service_fn(
            dummy_service_fn,
        )));
        assert_is_service(SetForwardedHeadersService::via(service_fn(
            dummy_service_fn,
        )));
        assert_is_service(SetForwardedHeadersService::x_forwarded_for(service_fn(
            dummy_service_fn,
        )));
        assert_is_service(SetForwardedHeadersService::x_forwarded_proto(service_fn(
            dummy_service_fn,
        )));
        assert_is_service(SetForwardedHeadersService::x_forwarded_host(service_fn(
            dummy_service_fn,
        )));
        assert_is_service(SetForwardedHeadersService::<_, TrueClientIp>::new(
            service_fn(dummy_service_fn),
        ));
        assert_is_service(SetForwardedHeadersService::<_, (TrueClientIp,)>::new(
            service_fn(dummy_service_fn),
        ));
        assert_is_service(
            SetForwardedHeadersService::<_, (TrueClientIp, XClientIp)>::new(service_fn(
                dummy_service_fn,
            )),
        );
        assert_is_service(SetForwardedHeadersLayer::via().layer(service_fn(dummy_service_fn)));
        assert_is_service(
            SetForwardedHeadersLayer::<XRealIp>::new().layer(service_fn(dummy_service_fn)),
        );
        assert_is_service(
            SetForwardedHeadersLayer::<(XRealIp, XForwardedProto)>::new()
                .layer(service_fn(dummy_service_fn)),
        );
    }

    #[tokio::test]
    async fn test_set_forwarded_service_forwarded() {
        async fn svc(request: Request<()>) -> Result<(), Infallible> {
            assert_eq!(
                request.headers().get("Forwarded").unwrap(),
                "by=rama;host=\"example.com:80\";proto=http"
            );
            Ok(())
        }

        let service = SetForwardedHeadersService::forwarded(service_fn(svc));
        let req = Request::builder().uri("example.com").body(()).unwrap();
        service.serve(Context::default(), req).await.unwrap();
    }

    #[tokio::test]
    async fn test_set_forwarded_service_forwarded_with_chain() {
        async fn svc(request: Request<()>) -> Result<(), Infallible> {
            assert_eq!(
                request.headers().get("Forwarded").unwrap(),
                "for=12.23.34.45,by=rama;for=\"127.0.0.1:62345\";host=\"www.example.com:443\";proto=https",
            );
            Ok(())
        }

        let service = SetForwardedHeadersService::forwarded(service_fn(svc));
        let req = Request::builder()
            .uri("https://www.example.com")
            .body(())
            .unwrap();
        let mut ctx = Context::default();
        ctx.insert(Forwarded::new(ForwardedElement::forwarded_for(
            IpAddr::from([12, 23, 34, 45]),
        )));
        ctx.insert(SocketInfo::new(None, "127.0.0.1:62345".parse().unwrap()));
        service.serve(ctx, req).await.unwrap();
    }

    #[tokio::test]
    async fn test_set_forwarded_service_x_forwarded_for_with_chain() {
        async fn svc(request: Request<()>) -> Result<(), Infallible> {
            assert_eq!(
                request.headers().get("X-Forwarded-For").unwrap(),
                "12.23.34.45, 127.0.0.1",
            );
            Ok(())
        }

        let service = SetForwardedHeadersService::x_forwarded_for(service_fn(svc));
        let req = Request::builder()
            .uri("https://www.example.com")
            .body(())
            .unwrap();
        let mut ctx = Context::default();
        ctx.insert(Forwarded::new(ForwardedElement::forwarded_for(
            IpAddr::from([12, 23, 34, 45]),
        )));
        ctx.insert(SocketInfo::new(None, "127.0.0.1:62345".parse().unwrap()));
        service.serve(ctx, req).await.unwrap();
    }

    #[tokio::test]
    async fn test_set_forwarded_service_forwarded_fully_defined() {
        async fn svc(request: Request<()>) -> Result<(), Infallible> {
            assert_eq!(
                request.headers().get("Forwarded").unwrap(),
                "by=12.23.34.45;for=\"127.0.0.1:62345\";host=\"www.example.com:443\";proto=https",
            );
            Ok(())
        }

        let service = SetForwardedHeadersService::forwarded(service_fn(svc))
            .forward_by(IpAddr::from([12, 23, 34, 45]));
        let req = Request::builder()
            .uri("https://www.example.com")
            .body(())
            .unwrap();
        let mut ctx = Context::default();
        ctx.insert(SocketInfo::new(None, "127.0.0.1:62345".parse().unwrap()));
        service.serve(ctx, req).await.unwrap();
    }

    /// Custom `by` node combined with a pre-existing chain in the context.
    #[tokio::test]
    async fn test_set_forwarded_service_forwarded_fully_defined_with_chain() {
        async fn svc(request: Request<()>) -> Result<(), Infallible> {
            assert_eq!(
                request.headers().get("Forwarded").unwrap(),
                "for=12.23.34.45,by=34.45.56.67;for=\"127.0.0.1:62345\";host=\"www.example.com:443\";proto=https",
            );
            Ok(())
        }

        let service = SetForwardedHeadersService::forwarded(service_fn(svc))
            .forward_by(IpAddr::from([34, 45, 56, 67]));
        let req = Request::builder()
            .uri("https://www.example.com")
            .body(())
            .unwrap();
        let mut ctx = Context::default();
        ctx.insert(Forwarded::new(ForwardedElement::forwarded_for(
            IpAddr::from([12, 23, 34, 45]),
        )));
        ctx.insert(SocketInfo::new(None, "127.0.0.1:62345".parse().unwrap()));
        service.serve(ctx, req).await.unwrap();
    }

    /// Exercises the tuple (macro-generated) impl: every listed header is set.
    #[tokio::test]
    async fn test_set_forwarded_service_tuple_with_chain() {
        async fn svc(request: Request<()>) -> Result<(), Infallible> {
            assert_eq!(
                request.headers().get("Forwarded").unwrap(),
                "for=12.23.34.45,by=rama;for=\"127.0.0.1:62345\";host=\"www.example.com:443\";proto=https",
            );
            assert_eq!(
                request.headers().get("X-Forwarded-For").unwrap(),
                "12.23.34.45, 127.0.0.1",
            );
            Ok(())
        }

        let service =
            SetForwardedHeadersService::<_, (Forwarded, XForwardedFor)>::new(service_fn(svc));
        let req = Request::builder()
            .uri("https://www.example.com")
            .body(())
            .unwrap();
        let mut ctx = Context::default();
        ctx.insert(Forwarded::new(ForwardedElement::forwarded_for(
            IpAddr::from([12, 23, 34, 45]),
        )));
        ctx.insert(SocketInfo::new(None, "127.0.0.1:62345".parse().unwrap()));
        service.serve(ctx, req).await.unwrap();
    }
}
562}