1use crate::Request;
2use crate::headers::{
3 ForwardHeader, HeaderMapExt, Via, XForwardedFor, XForwardedHost, XForwardedProto,
4};
5use rama_core::error::BoxError;
6use rama_core::{Context, Layer, Service};
7use rama_net::address::Domain;
8use rama_net::forwarded::{Forwarded, ForwardedElement, NodeId};
9use rama_net::http::RequestContext;
10use rama_net::stream::SocketInfo;
11use rama_utils::macros::all_the_tuples_no_last_special_case;
12use std::fmt;
13use std::marker::PhantomData;
14
15pub struct SetForwardedHeadersLayer<T = Forwarded> {
87 by_node: NodeId,
88 _headers: PhantomData<fn() -> T>,
89}
90
91impl<T: fmt::Debug> fmt::Debug for SetForwardedHeadersLayer<T> {
92 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
93 f.debug_struct("SetForwardedHeadersLayer")
94 .field("by_node", &self.by_node)
95 .field(
96 "_headers",
97 &format_args!("{}", std::any::type_name::<fn() -> T>()),
98 )
99 .finish()
100 }
101}
102
103impl<T: Clone> Clone for SetForwardedHeadersLayer<T> {
104 fn clone(&self) -> Self {
105 Self {
106 by_node: self.by_node.clone(),
107 _headers: PhantomData,
108 }
109 }
110}
111
112impl<T> SetForwardedHeadersLayer<T> {
113 pub fn forward_by(mut self, node_id: impl Into<NodeId>) -> Self {
117 self.by_node = node_id.into();
118 self
119 }
120
121 pub fn set_forward_by(&mut self, node_id: impl Into<NodeId>) -> &mut Self {
125 self.by_node = node_id.into();
126 self
127 }
128}
129
130impl<T> SetForwardedHeadersLayer<T> {
131 pub fn new() -> Self {
133 Self {
134 by_node: Domain::from_static("rama").into(),
135 _headers: PhantomData,
136 }
137 }
138}
139
140impl Default for SetForwardedHeadersLayer {
141 fn default() -> Self {
142 Self::forwarded()
143 }
144}
145
146impl SetForwardedHeadersLayer {
147 #[inline]
148 pub fn forwarded() -> Self {
150 Self::new()
151 }
152}
153
154impl SetForwardedHeadersLayer<Via> {
155 #[inline]
156 pub fn via() -> Self {
158 Self::new()
159 }
160}
161
162impl SetForwardedHeadersLayer<XForwardedFor> {
163 #[inline]
164 pub fn x_forwarded_for() -> Self {
166 Self::new()
167 }
168}
169
170impl SetForwardedHeadersLayer<XForwardedHost> {
171 #[inline]
172 pub fn x_forwarded_host() -> Self {
174 Self::new()
175 }
176}
177
178impl SetForwardedHeadersLayer<XForwardedProto> {
179 #[inline]
180 pub fn x_forwarded_proto() -> Self {
182 Self::new()
183 }
184}
185
186impl<H, S> Layer<S> for SetForwardedHeadersLayer<H> {
187 type Service = SetForwardedHeadersService<S, H>;
188
189 fn layer(&self, inner: S) -> Self::Service {
190 Self::Service {
191 inner,
192 by_node: self.by_node.clone(),
193 _headers: PhantomData,
194 }
195 }
196
197 fn into_layer(self, inner: S) -> Self::Service {
198 Self::Service {
199 inner,
200 by_node: self.by_node,
201 _headers: PhantomData,
202 }
203 }
204}
205
206pub struct SetForwardedHeadersService<S, T = Forwarded> {
211 inner: S,
212 by_node: NodeId,
213 _headers: PhantomData<fn() -> T>,
214}
215
216impl<S: fmt::Debug, T> fmt::Debug for SetForwardedHeadersService<S, T> {
217 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
218 f.debug_struct("SetForwardedHeadersService")
219 .field("inner", &self.inner)
220 .field("by_node", &self.by_node)
221 .field(
222 "_headers",
223 &format_args!("{}", std::any::type_name::<fn() -> T>()),
224 )
225 .finish()
226 }
227}
228
229impl<S: Clone, T> Clone for SetForwardedHeadersService<S, T> {
230 fn clone(&self) -> Self {
231 SetForwardedHeadersService {
232 inner: self.inner.clone(),
233 by_node: self.by_node.clone(),
234 _headers: PhantomData,
235 }
236 }
237}
238
239impl<S, T> SetForwardedHeadersService<S, T> {
240 pub fn forward_by(mut self, node_id: impl Into<NodeId>) -> Self {
244 self.by_node = node_id.into();
245 self
246 }
247
248 pub fn set_forward_by(&mut self, node_id: impl Into<NodeId>) -> &mut Self {
252 self.by_node = node_id.into();
253 self
254 }
255}
256
257impl<S, T> SetForwardedHeadersService<S, T> {
258 pub fn new(inner: S) -> Self {
260 Self {
261 inner,
262 by_node: Domain::from_static("rama").into(),
263 _headers: PhantomData,
264 }
265 }
266}
267
268impl<S> SetForwardedHeadersService<S> {
269 #[inline]
270 pub fn forwarded(inner: S) -> Self {
272 Self::new(inner)
273 }
274}
275
276impl<S> SetForwardedHeadersService<S, Via> {
277 #[inline]
278 pub fn via(inner: S) -> Self {
280 Self::new(inner)
281 }
282}
283
284impl<S> SetForwardedHeadersService<S, XForwardedFor> {
285 #[inline]
286 pub fn x_forwarded_for(inner: S) -> Self {
288 Self::new(inner)
289 }
290}
291
292impl<S> SetForwardedHeadersService<S, XForwardedHost> {
293 #[inline]
294 pub fn x_forwarded_host(inner: S) -> Self {
296 Self::new(inner)
297 }
298}
299
300impl<S> SetForwardedHeadersService<S, XForwardedProto> {
301 #[inline]
302 pub fn x_forwarded_proto(inner: S) -> Self {
304 Self::new(inner)
305 }
306}
307
308impl<S, H, State, Body> Service<State, Request<Body>> for SetForwardedHeadersService<S, H>
309where
310 S: Service<State, Request<Body>, Error: Into<BoxError>>,
311 H: ForwardHeader + Send + Sync + 'static,
312 Body: Send + 'static,
313 State: Clone + Send + Sync + 'static,
314{
315 type Response = S::Response;
316 type Error = BoxError;
317
318 async fn serve(
319 &self,
320 mut ctx: Context<State>,
321 mut req: Request<Body>,
322 ) -> Result<Self::Response, Self::Error> {
323 let forwarded: Option<Forwarded> = ctx.get().cloned();
324
325 let mut forwarded_element = ForwardedElement::forwarded_by(self.by_node.clone());
326
327 if let Some(peer_addr) = ctx.get::<SocketInfo>().map(|socket| *socket.peer_addr()) {
328 forwarded_element.set_forwarded_for(peer_addr);
329 }
330 let request_ctx: &mut RequestContext =
331 ctx.get_or_try_insert_with_ctx(|ctx| (ctx, &req).try_into())?;
332
333 forwarded_element.set_forwarded_host(request_ctx.authority.clone());
334
335 if let Ok(forwarded_proto) = (&request_ctx.protocol).try_into() {
336 forwarded_element.set_forwarded_proto(forwarded_proto);
337 }
338
339 let forwarded = match forwarded {
340 None => Some(Forwarded::new(forwarded_element)),
341 Some(mut forwarded) => {
342 forwarded.append(forwarded_element);
343 Some(forwarded)
344 }
345 };
346
347 if let Some(forwarded) = forwarded {
348 if let Some(header) = H::try_from_forwarded(forwarded.iter()) {
349 req.headers_mut().typed_insert(header);
350 }
351 }
352
353 self.inner.serve(ctx, req).await.map_err(Into::into)
354 }
355}
356
357macro_rules! set_forwarded_service_for_tuple {
358 ( $($ty:ident),* $(,)? ) => {
359 #[allow(non_snake_case)]
360 impl<S, $($ty),* , State, Body> Service<State, Request<Body>> for SetForwardedHeadersService<S, ($($ty,)*)>
361 where
362 $( $ty: ForwardHeader + Send + Sync + 'static, )*
363 S: Service<State, Request<Body>, Error: Into<BoxError>>,
364 Body: Send + 'static,
365 State: Clone + Send + Sync + 'static,
366 {
367 type Response = S::Response;
368 type Error = BoxError;
369
370 async fn serve(
371 &self,
372 mut ctx: Context<State>,
373 mut req: Request<Body>,
374 ) -> Result<Self::Response, Self::Error> {
375 let forwarded: Option<Forwarded> = ctx.get().cloned();
376
377 let mut forwarded_element = ForwardedElement::forwarded_by(self.by_node.clone());
378
379 if let Some(peer_addr) = ctx.get::<SocketInfo>().map(|socket| *socket.peer_addr()) {
380 forwarded_element.set_forwarded_for(peer_addr);
381 }
382
383 let request_ctx: &mut RequestContext =
384 ctx.get_or_try_insert_with_ctx(|ctx| (ctx, &req).try_into())?;
385
386 forwarded_element.set_forwarded_host(request_ctx.authority.clone());
387
388 if let Ok(forwarded_proto) = (&request_ctx.protocol).try_into() {
389 forwarded_element.set_forwarded_proto(forwarded_proto);
390 }
391
392 let forwarded = match forwarded {
393 None => Some(Forwarded::new(forwarded_element)),
394 Some(mut forwarded) => {
395 forwarded.append(forwarded_element);
396 Some(forwarded)
397 }
398 };
399
400 if let Some(forwarded) = forwarded {
401 $(
402 if let Some(header) = $ty::try_from_forwarded(forwarded.iter()) {
403 req.headers_mut().typed_insert(header);
404 }
405 )*
406 }
407
408 self.inner.serve(ctx, req).await.map_err(Into::into)
409 }
410 }
411 };
412}
413all_the_tuples_no_last_special_case!(set_forwarded_service_for_tuple);
414
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        IntoResponse, Response, StatusCode,
        headers::{TrueClientIp, XClientIp, XRealIp},
    };
    use rama_core::{Layer, error::OpaqueError, service::service_fn};
    use std::{convert::Infallible, net::IpAddr};

    fn assert_is_service<T: Service<(), Request<()>>>(_: T) {}

    async fn dummy_service_fn() -> Result<Response, OpaqueError> {
        Ok(StatusCode::OK.into_response())
    }

    #[test]
    fn test_set_forwarded_service_is_service() {
        assert_is_service(SetForwardedHeadersService::forwarded(service_fn(
            dummy_service_fn,
        )));
        assert_is_service(SetForwardedHeadersService::via(service_fn(
            dummy_service_fn,
        )));
        assert_is_service(SetForwardedHeadersService::x_forwarded_for(service_fn(
            dummy_service_fn,
        )));
        assert_is_service(SetForwardedHeadersService::x_forwarded_proto(service_fn(
            dummy_service_fn,
        )));
        assert_is_service(SetForwardedHeadersService::x_forwarded_host(service_fn(
            dummy_service_fn,
        )));
        assert_is_service(SetForwardedHeadersService::<_, TrueClientIp>::new(
            service_fn(dummy_service_fn),
        ));
        assert_is_service(SetForwardedHeadersService::<_, (TrueClientIp,)>::new(
            service_fn(dummy_service_fn),
        ));
        assert_is_service(
            SetForwardedHeadersService::<_, (TrueClientIp, XClientIp)>::new(service_fn(
                dummy_service_fn,
            )),
        );
        assert_is_service(SetForwardedHeadersLayer::via().into_layer(service_fn(dummy_service_fn)));
        assert_is_service(
            SetForwardedHeadersLayer::<XRealIp>::new().into_layer(service_fn(dummy_service_fn)),
        );
        assert_is_service(
            SetForwardedHeadersLayer::<(XRealIp, XForwardedProto)>::new()
                .into_layer(service_fn(dummy_service_fn)),
        );
    }

    #[tokio::test]
    async fn test_set_forwarded_service_forwarded() {
        async fn svc(request: Request<()>) -> Result<(), Infallible> {
            assert_eq!(
                request.headers().get("Forwarded").unwrap(),
                "by=rama;host=\"example.com:80\";proto=http"
            );
            Ok(())
        }

        let service = SetForwardedHeadersService::forwarded(service_fn(svc));
        let req = Request::builder().uri("example.com").body(()).unwrap();
        service.serve(Context::default(), req).await.unwrap();
    }

    #[tokio::test]
    async fn test_set_forwarded_service_forwarded_with_chain() {
        async fn svc(request: Request<()>) -> Result<(), Infallible> {
            assert_eq!(
                request.headers().get("Forwarded").unwrap(),
                "for=12.23.34.45,by=rama;for=\"127.0.0.1:62345\";host=\"www.example.com:443\";proto=https",
            );
            Ok(())
        }

        let service = SetForwardedHeadersService::forwarded(service_fn(svc));
        let req = Request::builder()
            .uri("https://www.example.com")
            .body(())
            .unwrap();
        let mut ctx = Context::default();
        ctx.insert(Forwarded::new(ForwardedElement::forwarded_for(
            IpAddr::from([12, 23, 34, 45]),
        )));
        ctx.insert(SocketInfo::new(None, "127.0.0.1:62345".parse().unwrap()));
        service.serve(ctx, req).await.unwrap();
    }

    #[tokio::test]
    async fn test_set_forwarded_service_x_forwarded_for_with_chain() {
        async fn svc(request: Request<()>) -> Result<(), Infallible> {
            assert_eq!(
                request.headers().get("X-Forwarded-For").unwrap(),
                "12.23.34.45, 127.0.0.1",
            );
            Ok(())
        }

        let service = SetForwardedHeadersService::x_forwarded_for(service_fn(svc));
        let req = Request::builder()
            .uri("https://www.example.com")
            .body(())
            .unwrap();
        let mut ctx = Context::default();
        ctx.insert(Forwarded::new(ForwardedElement::forwarded_for(
            IpAddr::from([12, 23, 34, 45]),
        )));
        ctx.insert(SocketInfo::new(None, "127.0.0.1:62345".parse().unwrap()));
        service.serve(ctx, req).await.unwrap();
    }

    #[tokio::test]
    async fn test_set_forwarded_service_forwarded_fully_defined() {
        async fn svc(request: Request<()>) -> Result<(), Infallible> {
            assert_eq!(
                request.headers().get("Forwarded").unwrap(),
                "by=12.23.34.45;for=\"127.0.0.1:62345\";host=\"www.example.com:443\";proto=https",
            );
            Ok(())
        }

        let service = SetForwardedHeadersService::forwarded(service_fn(svc))
            .forward_by(IpAddr::from([12, 23, 34, 45]));
        let req = Request::builder()
            .uri("https://www.example.com")
            .body(())
            .unwrap();
        let mut ctx = Context::default();
        ctx.insert(SocketInfo::new(None, "127.0.0.1:62345".parse().unwrap()));
        service.serve(ctx, req).await.unwrap();
    }

    /// A fully-defined hop appended after an existing chain.
    ///
    /// Previously this test never inserted a prior chain, so it only
    /// duplicated the non-chained case. It now seeds a previous hop with both
    /// `by` and `for` set.
    #[tokio::test]
    async fn test_set_forwarded_service_forwarded_fully_defined_with_chain() {
        async fn svc(request: Request<()>) -> Result<(), Infallible> {
            assert_eq!(
                request.headers().get("Forwarded").unwrap(),
                "by=10.0.0.1;for=12.23.34.45,by=rama;for=\"127.0.0.1:62345\";host=\"www.example.com:443\";proto=https",
            );
            Ok(())
        }

        let service = SetForwardedHeadersService::forwarded(service_fn(svc));
        let req = Request::builder()
            .uri("https://www.example.com")
            .body(())
            .unwrap();

        let mut previous_hop = ForwardedElement::forwarded_by(IpAddr::from([10, 0, 0, 1]));
        previous_hop.set_forwarded_for(IpAddr::from([12, 23, 34, 45]));

        let mut ctx = Context::default();
        ctx.insert(Forwarded::new(previous_hop));
        ctx.insert(SocketInfo::new(None, "127.0.0.1:62345".parse().unwrap()));
        service.serve(ctx, req).await.unwrap();
    }
}