1use std::{fmt, marker::PhantomData, net::IpAddr};
2
3use crate::protocol::{v1, v2};
4use rama_core::{
5 Context, Layer, Service,
6 error::{BoxError, ErrorContext, OpaqueError},
7};
8use rama_net::{
9 client::{ConnectorService, EstablishedClientConnection},
10 forwarded::Forwarded,
11 stream::{Socket, SocketInfo, Stream},
12};
13use tokio::io::AsyncWriteExt;
14
15#[derive(Debug, Clone)]
21pub struct HaProxyLayer<P = protocol::Tcp, V = version::Two> {
22 version: V,
23 _phantom: PhantomData<fn(P)>,
24}
25
26impl HaProxyLayer {
27 pub fn tcp() -> Self {
35 HaProxyLayer {
36 version: Default::default(),
37 _phantom: PhantomData,
38 }
39 }
40
41 pub fn v1(self) -> HaProxyLayer<protocol::Tcp, version::One> {
49 HaProxyLayer {
50 version: Default::default(),
51 _phantom: PhantomData,
52 }
53 }
54}
55
56impl HaProxyLayer<protocol::Udp> {
57 pub fn udp() -> Self {
64 HaProxyLayer {
65 version: Default::default(),
66 _phantom: PhantomData,
67 }
68 }
69}
70
71impl<P> HaProxyLayer<P> {
72 pub fn payload(mut self, payload: Vec<u8>) -> Self {
78 self.version.payload = Some(payload);
79 self
80 }
81
82 pub fn set_payload(&mut self, payload: Vec<u8>) -> &mut Self {
88 self.version.payload = Some(payload);
89 self
90 }
91}
92
93impl<S, P, V: Clone> Layer<S> for HaProxyLayer<P, V> {
94 type Service = HaProxyService<S, P, V>;
95
96 fn layer(&self, inner: S) -> Self::Service {
97 HaProxyService {
98 inner,
99 version: self.version.clone(),
100 _phantom: PhantomData,
101 }
102 }
103
104 fn into_layer(self, inner: S) -> Self::Service {
105 HaProxyService {
106 inner,
107 version: self.version,
108 _phantom: PhantomData,
109 }
110 }
111}
112
113pub struct HaProxyService<S, P = protocol::Tcp, V = version::Two> {
119 inner: S,
120 version: V,
121 _phantom: PhantomData<fn(P)>,
122}
123
124impl<S> HaProxyService<S> {
125 pub fn tcp(inner: S) -> Self {
133 HaProxyService {
134 inner,
135 version: Default::default(),
136 _phantom: PhantomData,
137 }
138 }
139
140 pub fn v1(self) -> HaProxyService<S, protocol::Tcp, version::One> {
148 HaProxyService {
149 inner: self.inner,
150 version: Default::default(),
151 _phantom: PhantomData,
152 }
153 }
154}
155
156impl<S> HaProxyService<S, protocol::Udp> {
157 pub fn udp(inner: S) -> Self {
164 HaProxyService {
165 inner,
166 version: Default::default(),
167 _phantom: PhantomData,
168 }
169 }
170}
171
172impl<S, P> HaProxyService<S, P> {
173 pub fn payload(mut self, payload: Vec<u8>) -> Self {
179 self.version.payload = Some(payload);
180 self
181 }
182
183 pub fn set_payload(&mut self, payload: Vec<u8>) -> &mut Self {
189 self.version.payload = Some(payload);
190 self
191 }
192}
193
194impl<S: fmt::Debug, P, V: fmt::Debug> fmt::Debug for HaProxyService<S, P, V> {
195 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
196 f.debug_struct("HaProxyService")
197 .field("inner", &self.inner)
198 .field("version", &self.version)
199 .field(
200 "_phantom",
201 &format_args!("{}", std::any::type_name::<fn(P)>()),
202 )
203 .finish()
204 }
205}
206
207impl<S: Clone, P, V: Clone> Clone for HaProxyService<S, P, V> {
208 fn clone(&self) -> Self {
209 HaProxyService {
210 inner: self.inner.clone(),
211 version: self.version.clone(),
212 _phantom: PhantomData,
213 }
214 }
215}
216
217impl<S, P, State, Request> Service<State, Request> for HaProxyService<S, P, version::One>
218where
219 S: ConnectorService<State, Request, Connection: Stream + Socket + Unpin, Error: Into<BoxError>>,
220 P: Send + 'static,
221 State: Clone + Send + Sync + 'static,
222 Request: Send + 'static,
223{
224 type Response = EstablishedClientConnection<S::Connection, State, Request>;
225 type Error = BoxError;
226
227 async fn serve(
228 &self,
229 ctx: Context<State>,
230 req: Request,
231 ) -> Result<Self::Response, Self::Error> {
232 let EstablishedClientConnection { ctx, req, mut conn } =
233 self.inner.connect(ctx, req).await.map_err(Into::into)?;
234
235 let src = ctx
236 .get::<Forwarded>()
237 .and_then(|f| f.client_socket_addr())
238 .or_else(|| ctx.get::<SocketInfo>().map(|info| *info.peer_addr()))
239 .ok_or_else(|| {
240 OpaqueError::from_display("PROXY client (v1): missing src socket address")
241 })?;
242
243 let peer_addr = conn.peer_addr()?;
244 let addresses = match (src.ip(), peer_addr.ip()) {
245 (IpAddr::V4(src_ip), IpAddr::V4(dst_ip)) => {
246 v1::Addresses::new_tcp4(src_ip, dst_ip, src.port(), peer_addr.port())
247 }
248 (IpAddr::V6(src_ip), IpAddr::V6(dst_ip)) => {
249 v1::Addresses::new_tcp6(src_ip, dst_ip, src.port(), peer_addr.port())
250 }
251 (_, _) => {
252 return Err(OpaqueError::from_display(
253 "PROXY client (v1): IP version mismatch between src and dest",
254 )
255 .into());
256 }
257 };
258
259 conn.write_all(addresses.to_string().as_bytes())
260 .await
261 .context("PROXY client (v1): write addresses")?;
262
263 Ok(EstablishedClientConnection { ctx, req, conn })
264 }
265}
266
267impl<S, P, State, Request, T> Service<State, Request> for HaProxyService<S, P, version::Two>
268where
269 S: Service<
270 State,
271 Request,
272 Response = EstablishedClientConnection<T, State, Request>,
273 Error: Into<BoxError>,
274 >,
275 P: protocol::Protocol + Send + 'static,
276 State: Clone + Send + Sync + 'static,
277 Request: Send + 'static,
278 T: Stream + Socket + Unpin,
279{
280 type Response = EstablishedClientConnection<T, State, Request>;
281 type Error = BoxError;
282
283 async fn serve(
284 &self,
285 ctx: Context<State>,
286 req: Request,
287 ) -> Result<Self::Response, Self::Error> {
288 let EstablishedClientConnection { ctx, req, mut conn } =
289 self.inner.serve(ctx, req).await.map_err(Into::into)?;
290
291 let src = ctx
292 .get::<Forwarded>()
293 .and_then(|f| f.client_socket_addr())
294 .or_else(|| ctx.get::<SocketInfo>().map(|info| *info.peer_addr()))
295 .ok_or_else(|| {
296 OpaqueError::from_display("PROXY client (v2): missing src socket address")
297 })?;
298
299 let peer_addr = conn.peer_addr()?;
300 let builder = match (src.ip(), peer_addr.ip()) {
301 (IpAddr::V4(src_ip), IpAddr::V4(dst_ip)) => v2::Builder::with_addresses(
302 v2::Version::Two | v2::Command::Proxy,
303 P::v2_protocol(),
304 v2::IPv4::new(src_ip, dst_ip, src.port(), peer_addr.port()),
305 ),
306 (IpAddr::V6(src_ip), IpAddr::V6(dst_ip)) => v2::Builder::with_addresses(
307 v2::Version::Two | v2::Command::Proxy,
308 P::v2_protocol(),
309 v2::IPv6::new(src_ip, dst_ip, src.port(), peer_addr.port()),
310 ),
311 (_, _) => {
312 return Err(OpaqueError::from_display(
313 "PROXY client (v2): IP version mismatch between src and dest",
314 )
315 .into());
316 }
317 };
318
319 let builder = if let Some(payload) = self.version.payload.as_deref() {
320 builder
321 .write_payload(payload)
322 .context("PROXY client (v2): write custom binary payload to to header")?
323 } else {
324 builder
325 };
326
327 let header = builder
328 .build()
329 .context("PROXY client (v2): encode header")?;
330 conn.write_all(&header[..])
331 .await
332 .context("PROXY client (v2): write header")?;
333
334 Ok(EstablishedClientConnection { ctx, req, conn })
335 }
336}
337
pub mod version {
    //! Types selecting (and configuring) the PROXY protocol version.

    /// Marker for version one of the PROXY protocol; it carries no
    /// configuration.
    #[non_exhaustive]
    #[derive(Clone, Debug, Default)]
    pub struct One;

    /// Configuration for version two of the PROXY protocol.
    #[derive(Clone, Debug, Default)]
    pub struct Two {
        /// Optional custom binary payload; `None` by default.
        pub(crate) payload: Option<Vec<u8>>,
    }
}
356
357pub mod protocol {
358 use crate::protocol::v2;
361
362 #[derive(Debug, Clone)]
363 pub struct Tcp;
367
368 #[derive(Debug, Clone)]
369 pub struct Udp;
373
374 pub(super) trait Protocol {
375 fn v2_protocol() -> v2::Protocol;
377 }
378
379 impl Protocol for Tcp {
380 fn v2_protocol() -> v2::Protocol {
381 v2::Protocol::Stream
382 }
383 }
384
385 impl Protocol for Udp {
386 fn v2_protocol() -> v2::Protocol {
387 v2::Protocol::Datagram
388 }
389 }
390}
391
#[cfg(test)]
mod tests {
    use super::*;
    use rama_core::{Layer, service::service_fn};
    use rama_net::forwarded::{ForwardedElement, NodeId};
    use std::{
        convert::Infallible,
        io,
        net::SocketAddr,
        pin::Pin,
        task::{self, Poll},
    };
    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
    use tokio_test::io::{Builder, Mock};

    // Addresses shared by the v1 and mismatch scenarios.
    const CLIENT_V4: &str = "127.0.1.2:80";
    const CLIENT_V6_80: &str = "[1234:5678:90ab:cdef:fedc:ba09:8765:4321]:80";
    const CLIENT_V6_443: &str = "[1234:5678:90ab:cdef:fedc:ba09:8765:4321]:443";
    const TARGET_V4: &str = "192.168.1.101:443";
    const TARGET_V6: &str = "[4321:8765:ba09:fedc:cdef:90ab:5678:1234]:65535";

    // Addresses used by the v2 header scenarios.
    const V2_CLIENT_V4: &str = "127.0.0.1:80";

    // Expected v2 headers: 12-byte signature, version/command, family/protocol,
    // length, addresses, ports and the one-byte payload `42`.
    const V2_TCP4_HEADER: &[u8] = &[
        b'\r', b'\n', b'\r', b'\n', b'\0', b'\r', b'\n', b'Q', b'U', b'I', b'T', b'\n', 0x21, 0x11,
        0, 13, 127, 0, 0, 1, 192, 168, 1, 1, 0, 80, 1, 187, 42,
    ];
    const V2_UDP4_HEADER: &[u8] = &[
        b'\r', b'\n', b'\r', b'\n', b'\0', b'\r', b'\n', b'Q', b'U', b'I', b'T', b'\n', 0x21, 0x12,
        0, 13, 127, 0, 0, 1, 192, 168, 1, 1, 0, 80, 1, 187, 42,
    ];
    const V2_TCP6_HEADER: &[u8] = &[
        b'\r', b'\n', b'\r', b'\n', b'\0', b'\r', b'\n', b'Q', b'U', b'I', b'T', b'\n', 0x21, 0x21,
        0, 37, 0x12, 0x34, 0x56, 0x78, 0x90, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x09, 0x87, 0x65,
        0x43, 0x21, 0x43, 0x21, 0x87, 0x65, 0xba, 0x09, 0xfe, 0xdc, 0xcd, 0xef, 0x90, 0xab, 0x56,
        0x78, 0x12, 0x34, 0, 80, 1, 187, 42,
    ];
    const V2_UDP6_HEADER: &[u8] = &[
        b'\r', b'\n', b'\r', b'\n', b'\0', b'\r', b'\n', b'Q', b'U', b'I', b'T', b'\n', 0x21, 0x22,
        0, 37, 0x12, 0x34, 0x56, 0x78, 0x90, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x09, 0x87, 0x65,
        0x43, 0x21, 0x43, 0x21, 0x87, 0x65, 0xba, 0x09, 0xfe, 0xdc, 0xcd, 0xef, 0x90, 0xab, 0x56,
        0x78, 0x12, 0x34, 0, 80, 1, 187, 42,
    ];

    /// Scripted I/O mock paired with a fixed address, reported as both the
    /// local and the peer address.
    struct SocketConnection {
        conn: Mock,
        socket: SocketAddr,
    }

    impl Socket for SocketConnection {
        fn local_addr(&self) -> io::Result<SocketAddr> {
            Ok(self.socket)
        }

        fn peer_addr(&self) -> io::Result<SocketAddr> {
            Ok(self.socket)
        }
    }

    // All I/O is delegated to the wrapped mock.
    impl AsyncWrite for SocketConnection {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut task::Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            Pin::new(&mut self.get_mut().conn).poll_write(cx, buf)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
            Pin::new(&mut self.get_mut().conn).poll_flush(cx)
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            cx: &mut task::Context<'_>,
        ) -> Poll<io::Result<()>> {
            Pin::new(&mut self.get_mut().conn).poll_shutdown(cx)
        }
    }

    impl AsyncRead for SocketConnection {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut task::Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            Pin::new(&mut self.get_mut().conn).poll_read(cx, buf)
        }
    }

    /// Context holding only a [`SocketInfo`] with the given peer address.
    fn ctx_with_socket(peer: &'static str) -> Context<()> {
        let mut ctx = Context::default();
        ctx.insert(SocketInfo::new(None, peer.parse().unwrap()));
        ctx
    }

    /// Context holding a [`SocketInfo`] with the given peer address plus a
    /// [`Forwarded`] value whose "for" node is `forwarded_for`.
    fn ctx_with_forwarded(peer: &'static str, forwarded_for: &'static str) -> Context<()> {
        let mut ctx = ctx_with_socket(peer);
        ctx.insert(Forwarded::new(ForwardedElement::forwarded_for(
            NodeId::try_from(forwarded_for).unwrap(),
        )));
        ctx
    }

    /// Context / target pairs whose source and destination IP versions differ.
    fn ip_version_mismatch_cases() -> [(Context<()>, &'static str); 4] {
        [
            (ctx_with_socket(CLIENT_V6_80), TARGET_V4),
            (ctx_with_forwarded(CLIENT_V4, CLIENT_V6_80), TARGET_V4),
            (ctx_with_socket(CLIENT_V4), TARGET_V6),
            (ctx_with_forwarded(CLIENT_V6_80, CLIENT_V4), TARGET_V6),
        ]
    }

    /// Empty contexts (no source address available) paired with a target.
    fn missing_src_cases() -> [(Context<()>, &'static str); 2] {
        [
            (Context::default(), "192.168.1.101:443"),
            (
                Context::default(),
                "[1234:5678:90ab:cdef:fedc:ba09:8765:4321]:443",
            ),
        ]
    }

    #[tokio::test]
    async fn test_v1_tcp() {
        const TCP4_LINE: &str = "PROXY TCP4 127.0.1.2 192.168.1.101 80 443\r\n";
        const TCP6_LINE: &str = "PROXY TCP6 1234:5678:90ab:cdef:fedc:ba09:8765:4321 4321:8765:ba09:fedc:cdef:90ab:5678:1234 443 65535\r\n";

        let cases = [
            (TCP4_LINE, ctx_with_socket(CLIENT_V4), TARGET_V4),
            // the forwarded address wins over the socket info
            (
                TCP4_LINE,
                ctx_with_forwarded(CLIENT_V6_443, CLIENT_V4),
                TARGET_V4,
            ),
            (TCP6_LINE, ctx_with_socket(CLIENT_V6_443), TARGET_V6),
            (
                TCP6_LINE,
                ctx_with_forwarded(CLIENT_V4, CLIENT_V6_443),
                TARGET_V6,
            ),
        ];

        for (expected_line, input_ctx, target_addr) in cases {
            let svc = HaProxyLayer::tcp()
                .v1()
                .layer(service_fn(async move |ctx, req| {
                    Ok::<_, Infallible>(EstablishedClientConnection {
                        ctx,
                        req,
                        conn: SocketConnection {
                            socket: target_addr.parse().unwrap(),
                            conn: Builder::new().write(expected_line.as_bytes()).build(),
                        },
                    })
                }));
            svc.serve(input_ctx, ()).await.unwrap();
        }
    }

    #[tokio::test]
    async fn test_v1_tcp_ip_version_mismatch() {
        for (input_ctx, target_addr) in ip_version_mismatch_cases() {
            let svc = HaProxyLayer::tcp()
                .v1()
                .layer(service_fn(async move |ctx, req| {
                    Ok::<_, Infallible>(EstablishedClientConnection {
                        ctx,
                        req,
                        conn: SocketConnection {
                            socket: target_addr.parse().unwrap(),
                            conn: Builder::new().build(),
                        },
                    })
                }));
            assert!(svc.serve(input_ctx, ()).await.is_err());
        }
    }

    #[tokio::test]
    async fn test_v1_tcp_missing_src() {
        for (input_ctx, target_addr) in missing_src_cases() {
            let svc = HaProxyLayer::tcp()
                .v1()
                .layer(service_fn(async move |ctx, req| {
                    Ok::<_, Infallible>(EstablishedClientConnection {
                        ctx,
                        req,
                        conn: SocketConnection {
                            socket: target_addr.parse().unwrap(),
                            conn: Builder::new().build(),
                        },
                    })
                }));
            assert!(svc.serve(input_ctx, ()).await.is_err());
        }
    }

    #[tokio::test]
    async fn test_v2_tcp4() {
        let contexts = [
            ctx_with_socket(V2_CLIENT_V4),
            ctx_with_forwarded(CLIENT_V6_443, V2_CLIENT_V4),
        ];

        for input_ctx in contexts {
            let svc = HaProxyLayer::tcp()
                .payload(vec![42])
                .layer(service_fn(async move |ctx, req| {
                    Ok::<_, Infallible>(EstablishedClientConnection {
                        ctx,
                        req,
                        conn: SocketConnection {
                            socket: "192.168.1.1:443".parse().unwrap(),
                            conn: Builder::new().write(V2_TCP4_HEADER).build(),
                        },
                    })
                }));
            svc.serve(input_ctx, ()).await.unwrap();
        }
    }

    #[tokio::test]
    async fn test_v2_udp4() {
        let contexts = [
            ctx_with_socket(V2_CLIENT_V4),
            ctx_with_forwarded(CLIENT_V6_443, V2_CLIENT_V4),
        ];

        for input_ctx in contexts {
            let svc = HaProxyLayer::udp()
                .payload(vec![42])
                .layer(service_fn(async move |ctx, req| {
                    Ok::<_, Infallible>(EstablishedClientConnection {
                        ctx,
                        req,
                        conn: SocketConnection {
                            socket: "192.168.1.1:443".parse().unwrap(),
                            conn: Builder::new().write(V2_UDP4_HEADER).build(),
                        },
                    })
                }));
            svc.serve(input_ctx, ()).await.unwrap();
        }
    }

    #[tokio::test]
    async fn test_v2_tcp6() {
        let contexts = [
            ctx_with_socket(CLIENT_V6_80),
            ctx_with_forwarded(V2_CLIENT_V4, CLIENT_V6_80),
        ];

        for input_ctx in contexts {
            let svc = HaProxyLayer::tcp()
                .payload(vec![42])
                .layer(service_fn(async move |ctx, req| {
                    Ok::<_, Infallible>(EstablishedClientConnection {
                        ctx,
                        req,
                        conn: SocketConnection {
                            socket: "[4321:8765:ba09:fedc:cdef:90ab:5678:1234]:443"
                                .parse()
                                .unwrap(),
                            conn: Builder::new().write(V2_TCP6_HEADER).build(),
                        },
                    })
                }));
            svc.serve(input_ctx, ()).await.unwrap();
        }
    }

    #[tokio::test]
    async fn test_v2_udp6() {
        let contexts = [
            ctx_with_socket(CLIENT_V6_80),
            ctx_with_forwarded(V2_CLIENT_V4, CLIENT_V6_80),
        ];

        for input_ctx in contexts {
            let svc = HaProxyLayer::udp()
                .payload(vec![42])
                .layer(service_fn(async move |ctx, req| {
                    Ok::<_, Infallible>(EstablishedClientConnection {
                        ctx,
                        req,
                        conn: SocketConnection {
                            socket: "[4321:8765:ba09:fedc:cdef:90ab:5678:1234]:443"
                                .parse()
                                .unwrap(),
                            conn: Builder::new().write(V2_UDP6_HEADER).build(),
                        },
                    })
                }));
            svc.serve(input_ctx, ()).await.unwrap();
        }
    }

    #[tokio::test]
    async fn test_v2_ip_version_mismatch() {
        for (input_ctx, target_addr) in ip_version_mismatch_cases() {
            let tcp_svc = HaProxyLayer::tcp().layer(service_fn(async move |ctx, req| {
                Ok::<_, Infallible>(EstablishedClientConnection {
                    ctx,
                    req,
                    conn: SocketConnection {
                        socket: target_addr.parse().unwrap(),
                        conn: Builder::new().build(),
                    },
                })
            }));
            assert!(tcp_svc.serve(input_ctx.clone(), ()).await.is_err());

            let udp_svc = HaProxyLayer::udp().layer(service_fn(async move |ctx, req| {
                Ok::<_, Infallible>(EstablishedClientConnection {
                    ctx,
                    req,
                    conn: SocketConnection {
                        socket: target_addr.parse().unwrap(),
                        conn: Builder::new().build(),
                    },
                })
            }));
            assert!(udp_svc.serve(input_ctx, ()).await.is_err());
        }
    }

    #[tokio::test]
    async fn test_v2_missing_src() {
        for (input_ctx, target_addr) in missing_src_cases() {
            let tcp_svc = HaProxyLayer::tcp().layer(service_fn(async move |ctx, req| {
                Ok::<_, Infallible>(EstablishedClientConnection {
                    ctx,
                    req,
                    conn: SocketConnection {
                        socket: target_addr.parse().unwrap(),
                        conn: Builder::new().build(),
                    },
                })
            }));
            assert!(tcp_svc.serve(input_ctx.clone(), ()).await.is_err());

            let udp_svc = HaProxyLayer::udp().layer(service_fn(async move |ctx, req| {
                Ok::<_, Infallible>(EstablishedClientConnection {
                    ctx,
                    req,
                    conn: SocketConnection {
                        socket: target_addr.parse().unwrap(),
                        conn: Builder::new().build(),
                    },
                })
            }));
            assert!(udp_svc.serve(input_ctx, ()).await.is_err());
        }
    }
}