rama_haproxy/client/
layer.rs

1use std::{fmt, marker::PhantomData, net::IpAddr};
2
3use crate::protocol::{v1, v2};
4use rama_core::{
5    Context, Layer, Service,
6    error::{BoxError, ErrorContext, OpaqueError},
7};
8use rama_net::{
9    client::{ConnectorService, EstablishedClientConnection},
10    forwarded::Forwarded,
11    stream::{Socket, SocketInfo, Stream},
12};
13use tokio::io::AsyncWriteExt;
14
15/// Layer to encode and write the HaProxy Protocol,
16/// as a client on the connected stream.
17///
18/// This connector should in most cases
19/// happen as the first thing after establishing the connection.
20#[derive(Debug, Clone)]
21pub struct HaProxyLayer<P = protocol::Tcp, V = version::Two> {
22    version: V,
23    _phantom: PhantomData<fn(P)>,
24}
25
26impl HaProxyLayer {
27    /// Create a new [`HaProxyLayer`] for the TCP protocol (default).
28    ///
29    /// This is in the PROXY spec referred to as:
30    ///
31    /// - TCP4 (for IPv4, v1)
32    /// - TCP6 (for IPv6, v1)
33    /// - Stream (v2)
34    pub fn tcp() -> Self {
35        HaProxyLayer {
36            version: Default::default(),
37            _phantom: PhantomData,
38        }
39    }
40
41    /// Use version one of PROXY protocol, instead of the
42    /// default version two.
43    ///
44    /// Version one makes use of a less advanced text protocol,
45    /// instead the more advanced binary v2 protocol.
46    ///
47    /// Use this only if you have no control over a v1-only server.
48    pub fn v1(self) -> HaProxyLayer<protocol::Tcp, version::One> {
49        HaProxyLayer {
50            version: Default::default(),
51            _phantom: PhantomData,
52        }
53    }
54}
55
56impl HaProxyLayer<protocol::Udp> {
57    /// Create a new [`HaProxyLayer`] for the UDP protocol,
58    /// instead of the default TCP protocol.
59    ///
60    /// This is in the PROXY spec referred to as:
61    ///
62    /// - Datagram (v2)
63    pub fn udp() -> Self {
64        HaProxyLayer {
65            version: Default::default(),
66            _phantom: PhantomData,
67        }
68    }
69}
70
71impl<P> HaProxyLayer<P> {
72    /// Attach a custom bytes payload to the PROXY header.
73    ///
74    /// NOTE this is only possible in Version two of the PROXY Protocol.
75    /// In case you downgrade this [`HaProxyLayer`] to version one later
76    /// using [`Self::v1`] this payload will be dropped.
77    pub fn payload(mut self, payload: Vec<u8>) -> Self {
78        self.version.payload = Some(payload);
79        self
80    }
81
82    /// Attach a custom bytes payload to the PROXY header.
83    ///
84    /// NOTE this is only possible in Version two of the PROXY Protocol.
85    /// In case you downgrade this [`HaProxyLayer`] to version one later
86    /// using [`Self::v1`] this payload will be dropped.
87    pub fn set_payload(&mut self, payload: Vec<u8>) -> &mut Self {
88        self.version.payload = Some(payload);
89        self
90    }
91}
92
93impl<S, P, V: Clone> Layer<S> for HaProxyLayer<P, V> {
94    type Service = HaProxyService<S, P, V>;
95
96    fn layer(&self, inner: S) -> Self::Service {
97        HaProxyService {
98            inner,
99            version: self.version.clone(),
100            _phantom: PhantomData,
101        }
102    }
103
104    fn into_layer(self, inner: S) -> Self::Service {
105        HaProxyService {
106            inner,
107            version: self.version,
108            _phantom: PhantomData,
109        }
110    }
111}
112
113/// Service to encode and write the HaProxy Protocol
114/// as a client on the connected stream.
115///
116/// This connector should in most cases
117/// happen as the first thing after establishing the connection.
118pub struct HaProxyService<S, P = protocol::Tcp, V = version::Two> {
119    inner: S,
120    version: V,
121    _phantom: PhantomData<fn(P)>,
122}
123
124impl<S> HaProxyService<S> {
125    /// Create a new [`HaProxyService`] for the TCP protocol (default).
126    ///
127    /// This is in the PROXY spec referred to as:
128    ///
129    /// - TCP4 (for IPv4, v1)
130    /// - TCP6 (for IPv6, v1)
131    /// - Stream (v2)
132    pub fn tcp(inner: S) -> Self {
133        HaProxyService {
134            inner,
135            version: Default::default(),
136            _phantom: PhantomData,
137        }
138    }
139
140    /// Use version one of PROXY protocol, instead of the
141    /// default version two.
142    ///
143    /// Version one makes use of a less advanced text protocol,
144    /// instead the more advanced binary v2 protocol.
145    ///
146    /// Use this only if you have no control over a v1-only server.
147    pub fn v1(self) -> HaProxyService<S, protocol::Tcp, version::One> {
148        HaProxyService {
149            inner: self.inner,
150            version: Default::default(),
151            _phantom: PhantomData,
152        }
153    }
154}
155
156impl<S> HaProxyService<S, protocol::Udp> {
157    /// Create a new [`HaProxyService`] for the UDP protocol,
158    /// instead of the default TCP protocol.
159    ///
160    /// This is in the PROXY spec referred to as:
161    ///
162    /// - Datagram (v2)
163    pub fn udp(inner: S) -> Self {
164        HaProxyService {
165            inner,
166            version: Default::default(),
167            _phantom: PhantomData,
168        }
169    }
170}
171
172impl<S, P> HaProxyService<S, P> {
173    /// Attach a custom bytes payload to the PROXY header.
174    ///
175    /// NOTE this is only possible in Version two of the PROXY Protocol.
176    /// In case you downgrade this [`HaProxyLayer`] to version one later
177    /// using [`Self::v1`] this payload will be dropped.
178    pub fn payload(mut self, payload: Vec<u8>) -> Self {
179        self.version.payload = Some(payload);
180        self
181    }
182
183    /// Attach a custom bytes payload to the PROXY header.
184    ///
185    /// NOTE this is only possible in Version two of the PROXY Protocol.
186    /// In case you downgrade this [`HaProxyLayer`] to version one later
187    /// using [`Self::v1`] this payload will be dropped.
188    pub fn set_payload(&mut self, payload: Vec<u8>) -> &mut Self {
189        self.version.payload = Some(payload);
190        self
191    }
192}
193
194impl<S: fmt::Debug, P, V: fmt::Debug> fmt::Debug for HaProxyService<S, P, V> {
195    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
196        f.debug_struct("HaProxyService")
197            .field("inner", &self.inner)
198            .field("version", &self.version)
199            .field(
200                "_phantom",
201                &format_args!("{}", std::any::type_name::<fn(P)>()),
202            )
203            .finish()
204    }
205}
206
207impl<S: Clone, P, V: Clone> Clone for HaProxyService<S, P, V> {
208    fn clone(&self) -> Self {
209        HaProxyService {
210            inner: self.inner.clone(),
211            version: self.version.clone(),
212            _phantom: PhantomData,
213        }
214    }
215}
216
217impl<S, P, State, Request> Service<State, Request> for HaProxyService<S, P, version::One>
218where
219    S: ConnectorService<State, Request, Connection: Stream + Socket + Unpin, Error: Into<BoxError>>,
220    P: Send + 'static,
221    State: Clone + Send + Sync + 'static,
222    Request: Send + 'static,
223{
224    type Response = EstablishedClientConnection<S::Connection, State, Request>;
225    type Error = BoxError;
226
227    async fn serve(
228        &self,
229        ctx: Context<State>,
230        req: Request,
231    ) -> Result<Self::Response, Self::Error> {
232        let EstablishedClientConnection { ctx, req, mut conn } =
233            self.inner.connect(ctx, req).await.map_err(Into::into)?;
234
235        let src = ctx
236            .get::<Forwarded>()
237            .and_then(|f| f.client_socket_addr())
238            .or_else(|| ctx.get::<SocketInfo>().map(|info| *info.peer_addr()))
239            .ok_or_else(|| {
240                OpaqueError::from_display("PROXY client (v1): missing src socket address")
241            })?;
242
243        let peer_addr = conn.peer_addr()?;
244        let addresses = match (src.ip(), peer_addr.ip()) {
245            (IpAddr::V4(src_ip), IpAddr::V4(dst_ip)) => {
246                v1::Addresses::new_tcp4(src_ip, dst_ip, src.port(), peer_addr.port())
247            }
248            (IpAddr::V6(src_ip), IpAddr::V6(dst_ip)) => {
249                v1::Addresses::new_tcp6(src_ip, dst_ip, src.port(), peer_addr.port())
250            }
251            (_, _) => {
252                return Err(OpaqueError::from_display(
253                    "PROXY client (v1): IP version mismatch between src and dest",
254                )
255                .into());
256            }
257        };
258
259        conn.write_all(addresses.to_string().as_bytes())
260            .await
261            .context("PROXY client (v1): write addresses")?;
262
263        Ok(EstablishedClientConnection { ctx, req, conn })
264    }
265}
266
267impl<S, P, State, Request, T> Service<State, Request> for HaProxyService<S, P, version::Two>
268where
269    S: Service<
270            State,
271            Request,
272            Response = EstablishedClientConnection<T, State, Request>,
273            Error: Into<BoxError>,
274        >,
275    P: protocol::Protocol + Send + 'static,
276    State: Clone + Send + Sync + 'static,
277    Request: Send + 'static,
278    T: Stream + Socket + Unpin,
279{
280    type Response = EstablishedClientConnection<T, State, Request>;
281    type Error = BoxError;
282
283    async fn serve(
284        &self,
285        ctx: Context<State>,
286        req: Request,
287    ) -> Result<Self::Response, Self::Error> {
288        let EstablishedClientConnection { ctx, req, mut conn } =
289            self.inner.serve(ctx, req).await.map_err(Into::into)?;
290
291        let src = ctx
292            .get::<Forwarded>()
293            .and_then(|f| f.client_socket_addr())
294            .or_else(|| ctx.get::<SocketInfo>().map(|info| *info.peer_addr()))
295            .ok_or_else(|| {
296                OpaqueError::from_display("PROXY client (v2): missing src socket address")
297            })?;
298
299        let peer_addr = conn.peer_addr()?;
300        let builder = match (src.ip(), peer_addr.ip()) {
301            (IpAddr::V4(src_ip), IpAddr::V4(dst_ip)) => v2::Builder::with_addresses(
302                v2::Version::Two | v2::Command::Proxy,
303                P::v2_protocol(),
304                v2::IPv4::new(src_ip, dst_ip, src.port(), peer_addr.port()),
305            ),
306            (IpAddr::V6(src_ip), IpAddr::V6(dst_ip)) => v2::Builder::with_addresses(
307                v2::Version::Two | v2::Command::Proxy,
308                P::v2_protocol(),
309                v2::IPv6::new(src_ip, dst_ip, src.port(), peer_addr.port()),
310            ),
311            (_, _) => {
312                return Err(OpaqueError::from_display(
313                    "PROXY client (v2): IP version mismatch between src and dest",
314                )
315                .into());
316            }
317        };
318
319        let builder = if let Some(payload) = self.version.payload.as_deref() {
320            builder
321                .write_payload(payload)
322                .context("PROXY client (v2): write custom binary payload to to header")?
323        } else {
324            builder
325        };
326
327        let header = builder
328            .build()
329            .context("PROXY client (v2): encode header")?;
330        conn.write_all(&header[..])
331            .await
332            .context("PROXY client (v2): write header")?;
333
334        Ok(EstablishedClientConnection { ctx, req, conn })
335    }
336}
337
pub mod version {
    //! Marker types for the HaProxy (PROXY) version to be used by the client layer (service).

    /// Marker selecting version 1 of the PROXY protocol.
    ///
    /// See [`crate::protocol`] for more information.
    #[derive(Clone, Debug, Default)]
    #[non_exhaustive]
    pub struct One;

    /// Marker selecting version 2 of the PROXY protocol.
    ///
    /// See [`crate::protocol`] for more information.
    #[derive(Clone, Debug, Default)]
    pub struct Two {
        /// Optional custom payload; `None` by default.
        pub(crate) payload: Option<Vec<u8>>,
    }
}
356
357pub mod protocol {
358    //! Marker traits for the HaProxy (PROXY) protocol to be used by client layer (service).
359
360    use crate::protocol::v2;
361
362    #[derive(Debug, Clone)]
363    /// Encode the data for the TCP protocol (possible in [`super::version::One`] and [`super::version::Two`]).
364    ///
365    /// See [`crate::protocol`] for more information.
366    pub struct Tcp;
367
368    #[derive(Debug, Clone)]
369    /// Encode the data for the UDP protocol (possible only in [`super::version::Two`]).
370    ///
371    /// See [`crate::protocol`] for more information.
372    pub struct Udp;
373
374    pub(super) trait Protocol {
375        /// Return the v2 PROXY protocol linked to the protocol implementation.
376        fn v2_protocol() -> v2::Protocol;
377    }
378
379    impl Protocol for Tcp {
380        fn v2_protocol() -> v2::Protocol {
381            v2::Protocol::Stream
382        }
383    }
384
385    impl Protocol for Udp {
386        fn v2_protocol() -> v2::Protocol {
387            v2::Protocol::Datagram
388        }
389    }
390}
391
392#[cfg(test)]
393mod tests {
394    use super::*;
395    use rama_core::{Layer, service::service_fn};
396    use rama_net::forwarded::{ForwardedElement, NodeId};
397    use std::{convert::Infallible, net::SocketAddr, pin::Pin};
398    use tokio::io::{AsyncRead, AsyncWrite};
399    use tokio_test::io::{Builder, Mock};
400
401    struct SocketConnection {
402        conn: Mock,
403        socket: SocketAddr,
404    }
405
406    impl Socket for SocketConnection {
407        fn local_addr(&self) -> std::io::Result<SocketAddr> {
408            Ok(self.socket)
409        }
410
411        fn peer_addr(&self) -> std::io::Result<SocketAddr> {
412            Ok(self.socket)
413        }
414    }
415
416    impl AsyncWrite for SocketConnection {
417        fn poll_write(
418            mut self: std::pin::Pin<&mut Self>,
419            cx: &mut std::task::Context<'_>,
420            buf: &[u8],
421        ) -> std::task::Poll<Result<usize, std::io::Error>> {
422            Pin::new(&mut self.conn).poll_write(cx, buf)
423        }
424
425        fn poll_flush(
426            mut self: std::pin::Pin<&mut Self>,
427            cx: &mut std::task::Context<'_>,
428        ) -> std::task::Poll<Result<(), std::io::Error>> {
429            Pin::new(&mut self.conn).poll_flush(cx)
430        }
431
432        fn poll_shutdown(
433            mut self: std::pin::Pin<&mut Self>,
434            cx: &mut std::task::Context<'_>,
435        ) -> std::task::Poll<Result<(), std::io::Error>> {
436            Pin::new(&mut self.conn).poll_shutdown(cx)
437        }
438    }
439
440    impl AsyncRead for SocketConnection {
441        fn poll_read(
442            mut self: Pin<&mut Self>,
443            cx: &mut std::task::Context<'_>,
444            buf: &mut tokio::io::ReadBuf<'_>,
445        ) -> std::task::Poll<std::io::Result<()>> {
446            Pin::new(&mut self.conn).poll_read(cx, buf)
447        }
448    }
449
450    #[tokio::test]
451    async fn test_v1_tcp() {
452        for (expected_line, input_ctx, target_addr) in [
453            (
454                "PROXY TCP4 127.0.1.2 192.168.1.101 80 443\r\n",
455                {
456                    let mut ctx = Context::default();
457                    ctx.insert(SocketInfo::new(None, "127.0.1.2:80".parse().unwrap()));
458                    ctx
459                },
460                "192.168.1.101:443",
461            ),
462            (
463                "PROXY TCP4 127.0.1.2 192.168.1.101 80 443\r\n",
464                {
465                    let mut ctx = Context::default();
466                    ctx.insert(SocketInfo::new(
467                        None,
468                        "[1234:5678:90ab:cdef:fedc:ba09:8765:4321]:443"
469                            .parse()
470                            .unwrap(),
471                    ));
472                    ctx.insert(Forwarded::new(ForwardedElement::forwarded_for(
473                        NodeId::try_from("127.0.1.2:80").unwrap(),
474                    )));
475                    ctx
476                },
477                "192.168.1.101:443",
478            ),
479            (
480                "PROXY TCP6 1234:5678:90ab:cdef:fedc:ba09:8765:4321 4321:8765:ba09:fedc:cdef:90ab:5678:1234 443 65535\r\n",
481                {
482                    let mut ctx = Context::default();
483                    ctx.insert(SocketInfo::new(
484                        None,
485                        "[1234:5678:90ab:cdef:fedc:ba09:8765:4321]:443"
486                            .parse()
487                            .unwrap(),
488                    ));
489                    ctx
490                },
491                "[4321:8765:ba09:fedc:cdef:90ab:5678:1234]:65535",
492            ),
493            (
494                "PROXY TCP6 1234:5678:90ab:cdef:fedc:ba09:8765:4321 4321:8765:ba09:fedc:cdef:90ab:5678:1234 443 65535\r\n",
495                {
496                    let mut ctx = Context::default();
497                    ctx.insert(SocketInfo::new(None, "127.0.1.2:80".parse().unwrap()));
498                    ctx.insert(Forwarded::new(ForwardedElement::forwarded_for(
499                        NodeId::try_from("[1234:5678:90ab:cdef:fedc:ba09:8765:4321]:443").unwrap(),
500                    )));
501                    ctx
502                },
503                "[4321:8765:ba09:fedc:cdef:90ab:5678:1234]:65535",
504            ),
505        ] {
506            let svc = HaProxyLayer::tcp()
507                .v1()
508                .layer(service_fn(async move |ctx, req| {
509                    Ok::<_, Infallible>(EstablishedClientConnection {
510                        ctx,
511                        req,
512                        conn: SocketConnection {
513                            socket: target_addr.parse().unwrap(),
514                            conn: Builder::new().write(expected_line.as_bytes()).build(),
515                        },
516                    })
517                }));
518            svc.serve(input_ctx, ()).await.unwrap();
519        }
520    }
521
522    #[tokio::test]
523    async fn test_v1_tcp_ip_version_mismatch() {
524        for (input_ctx, target_addr) in [
525            (
526                {
527                    let mut ctx = Context::default();
528                    ctx.insert(SocketInfo::new(
529                        None,
530                        "[1234:5678:90ab:cdef:fedc:ba09:8765:4321]:80"
531                            .parse()
532                            .unwrap(),
533                    ));
534                    ctx
535                },
536                "192.168.1.101:443",
537            ),
538            (
539                {
540                    let mut ctx = Context::default();
541                    ctx.insert(SocketInfo::new(None, "127.0.1.2:80".parse().unwrap()));
542                    ctx.insert(Forwarded::new(ForwardedElement::forwarded_for(
543                        NodeId::try_from("[1234:5678:90ab:cdef:fedc:ba09:8765:4321]:80").unwrap(),
544                    )));
545                    ctx
546                },
547                "192.168.1.101:443",
548            ),
549            (
550                {
551                    let mut ctx = Context::default();
552                    ctx.insert(SocketInfo::new(None, "127.0.1.2:80".parse().unwrap()));
553                    ctx
554                },
555                "[4321:8765:ba09:fedc:cdef:90ab:5678:1234]:65535",
556            ),
557            (
558                {
559                    let mut ctx = Context::default();
560                    ctx.insert(SocketInfo::new(
561                        None,
562                        "[1234:5678:90ab:cdef:fedc:ba09:8765:4321]:80"
563                            .parse()
564                            .unwrap(),
565                    ));
566                    ctx.insert(Forwarded::new(ForwardedElement::forwarded_for(
567                        NodeId::try_from("127.0.1.2:80").unwrap(),
568                    )));
569                    ctx
570                },
571                "[4321:8765:ba09:fedc:cdef:90ab:5678:1234]:65535",
572            ),
573        ] {
574            let svc = HaProxyLayer::tcp()
575                .v1()
576                .layer(service_fn(async move |ctx, req| {
577                    Ok::<_, Infallible>(EstablishedClientConnection {
578                        ctx,
579                        req,
580                        conn: SocketConnection {
581                            socket: target_addr.parse().unwrap(),
582                            conn: Builder::new().build(),
583                        },
584                    })
585                }));
586            assert!(svc.serve(input_ctx, ()).await.is_err());
587        }
588    }
589
590    #[tokio::test]
591    async fn test_v1_tcp_missing_src() {
592        for (input_ctx, target_addr) in [
593            (Context::default(), "192.168.1.101:443"),
594            (
595                Context::default(),
596                "[1234:5678:90ab:cdef:fedc:ba09:8765:4321]:443",
597            ),
598        ] {
599            let svc = HaProxyLayer::tcp()
600                .v1()
601                .layer(service_fn(async move |ctx, req| {
602                    Ok::<_, Infallible>(EstablishedClientConnection {
603                        ctx,
604                        req,
605                        conn: SocketConnection {
606                            socket: target_addr.parse().unwrap(),
607                            conn: Builder::new().build(),
608                        },
609                    })
610                }));
611            assert!(svc.serve(input_ctx, ()).await.is_err());
612        }
613    }
614
615    #[tokio::test]
616    async fn test_v2_tcp4() {
617        for input_ctx in [
618            {
619                let mut ctx = Context::default();
620                ctx.insert(SocketInfo::new(None, "127.0.0.1:80".parse().unwrap()));
621                ctx
622            },
623            {
624                let mut ctx = Context::default();
625                ctx.insert(SocketInfo::new(
626                    None,
627                    "[1234:5678:90ab:cdef:fedc:ba09:8765:4321]:443"
628                        .parse()
629                        .unwrap(),
630                ));
631                ctx.insert(Forwarded::new(ForwardedElement::forwarded_for(
632                    NodeId::try_from("127.0.0.1:80").unwrap(),
633                )));
634                ctx
635            },
636        ] {
637            let svc =
638                HaProxyLayer::tcp()
639                    .payload(vec![42])
640                    .layer(service_fn(async move |ctx, req| {
641                        Ok::<_, Infallible>(EstablishedClientConnection {
642                            ctx,
643                            req,
644                            conn: SocketConnection {
645                                socket: "192.168.1.1:443".parse().unwrap(),
646                                conn: Builder::new()
647                                    .write(&[
648                                        b'\r', b'\n', b'\r', b'\n', b'\0', b'\r', b'\n', b'Q',
649                                        b'U', b'I', b'T', b'\n', 0x21, 0x11, 0, 13, 127, 0, 0, 1,
650                                        192, 168, 1, 1, 0, 80, 1, 187, 42,
651                                    ])
652                                    .build(),
653                            },
654                        })
655                    }));
656            svc.serve(input_ctx, ()).await.unwrap();
657        }
658    }
659
660    #[tokio::test]
661    async fn test_v2_udp4() {
662        for input_ctx in [
663            {
664                let mut ctx = Context::default();
665                ctx.insert(SocketInfo::new(None, "127.0.0.1:80".parse().unwrap()));
666                ctx
667            },
668            {
669                let mut ctx = Context::default();
670                ctx.insert(SocketInfo::new(
671                    None,
672                    "[1234:5678:90ab:cdef:fedc:ba09:8765:4321]:443"
673                        .parse()
674                        .unwrap(),
675                ));
676                ctx.insert(Forwarded::new(ForwardedElement::forwarded_for(
677                    NodeId::try_from("127.0.0.1:80").unwrap(),
678                )));
679                ctx
680            },
681        ] {
682            let svc =
683                HaProxyLayer::udp()
684                    .payload(vec![42])
685                    .layer(service_fn(async move |ctx, req| {
686                        Ok::<_, Infallible>(EstablishedClientConnection {
687                            ctx,
688                            req,
689                            conn: SocketConnection {
690                                socket: "192.168.1.1:443".parse().unwrap(),
691                                conn: Builder::new()
692                                    .write(&[
693                                        b'\r', b'\n', b'\r', b'\n', b'\0', b'\r', b'\n', b'Q',
694                                        b'U', b'I', b'T', b'\n', 0x21, 0x12, 0, 13, 127, 0, 0, 1,
695                                        192, 168, 1, 1, 0, 80, 1, 187, 42,
696                                    ])
697                                    .build(),
698                            },
699                        })
700                    }));
701            svc.serve(input_ctx, ()).await.unwrap();
702        }
703    }
704
705    #[tokio::test]
706    async fn test_v2_tcp6() {
707        for input_ctx in [
708            {
709                let mut ctx = Context::default();
710                ctx.insert(SocketInfo::new(
711                    None,
712                    "[1234:5678:90ab:cdef:fedc:ba09:8765:4321]:80"
713                        .parse()
714                        .unwrap(),
715                ));
716                ctx
717            },
718            {
719                let mut ctx = Context::default();
720                ctx.insert(SocketInfo::new(None, "127.0.0.1:80".parse().unwrap()));
721                ctx.insert(Forwarded::new(ForwardedElement::forwarded_for(
722                    NodeId::try_from("[1234:5678:90ab:cdef:fedc:ba09:8765:4321]:80").unwrap(),
723                )));
724                ctx
725            },
726        ] {
727            let svc =
728                HaProxyLayer::tcp()
729                    .payload(vec![42])
730                    .layer(service_fn(async move |ctx, req| {
731                        Ok::<_, Infallible>(EstablishedClientConnection {
732                            ctx,
733                            req,
734                            conn: SocketConnection {
735                                socket: "[4321:8765:ba09:fedc:cdef:90ab:5678:1234]:443"
736                                    .parse()
737                                    .unwrap(),
738                                conn: Builder::new()
739                                    .write(&[
740                                        b'\r', b'\n', b'\r', b'\n', b'\0', b'\r', b'\n', b'Q',
741                                        b'U', b'I', b'T', b'\n', 0x21, 0x21, 0, 37, 0x12, 0x34,
742                                        0x56, 0x78, 0x90, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x09,
743                                        0x87, 0x65, 0x43, 0x21, 0x43, 0x21, 0x87, 0x65, 0xba, 0x09,
744                                        0xfe, 0xdc, 0xcd, 0xef, 0x90, 0xab, 0x56, 0x78, 0x12, 0x34,
745                                        0, 80, 1, 187, 42,
746                                    ])
747                                    .build(),
748                            },
749                        })
750                    }));
751            svc.serve(input_ctx, ()).await.unwrap();
752        }
753    }
754
755    #[tokio::test]
756    async fn test_v2_udp6() {
757        for input_ctx in [
758            {
759                let mut ctx = Context::default();
760                ctx.insert(SocketInfo::new(
761                    None,
762                    "[1234:5678:90ab:cdef:fedc:ba09:8765:4321]:80"
763                        .parse()
764                        .unwrap(),
765                ));
766                ctx
767            },
768            {
769                let mut ctx = Context::default();
770                ctx.insert(SocketInfo::new(None, "127.0.0.1:80".parse().unwrap()));
771                ctx.insert(Forwarded::new(ForwardedElement::forwarded_for(
772                    NodeId::try_from("[1234:5678:90ab:cdef:fedc:ba09:8765:4321]:80").unwrap(),
773                )));
774                ctx
775            },
776        ] {
777            let svc =
778                HaProxyLayer::udp()
779                    .payload(vec![42])
780                    .layer(service_fn(async move |ctx, req| {
781                        Ok::<_, Infallible>(EstablishedClientConnection {
782                            ctx,
783                            req,
784                            conn: SocketConnection {
785                                socket: "[4321:8765:ba09:fedc:cdef:90ab:5678:1234]:443"
786                                    .parse()
787                                    .unwrap(),
788                                conn: Builder::new()
789                                    .write(&[
790                                        b'\r', b'\n', b'\r', b'\n', b'\0', b'\r', b'\n', b'Q',
791                                        b'U', b'I', b'T', b'\n', 0x21, 0x22, 0, 37, 0x12, 0x34,
792                                        0x56, 0x78, 0x90, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x09,
793                                        0x87, 0x65, 0x43, 0x21, 0x43, 0x21, 0x87, 0x65, 0xba, 0x09,
794                                        0xfe, 0xdc, 0xcd, 0xef, 0x90, 0xab, 0x56, 0x78, 0x12, 0x34,
795                                        0, 80, 1, 187, 42,
796                                    ])
797                                    .build(),
798                            },
799                        })
800                    }));
801            svc.serve(input_ctx, ()).await.unwrap();
802        }
803    }
804
805    #[tokio::test]
806    async fn test_v2_ip_version_mismatch() {
807        for (input_ctx, target_addr) in [
808            (
809                {
810                    let mut ctx = Context::default();
811                    ctx.insert(SocketInfo::new(
812                        None,
813                        "[1234:5678:90ab:cdef:fedc:ba09:8765:4321]:80"
814                            .parse()
815                            .unwrap(),
816                    ));
817                    ctx
818                },
819                "192.168.1.101:443",
820            ),
821            (
822                {
823                    let mut ctx = Context::default();
824                    ctx.insert(SocketInfo::new(None, "127.0.1.2:80".parse().unwrap()));
825                    ctx.insert(Forwarded::new(ForwardedElement::forwarded_for(
826                        NodeId::try_from("[1234:5678:90ab:cdef:fedc:ba09:8765:4321]:80").unwrap(),
827                    )));
828                    ctx
829                },
830                "192.168.1.101:443",
831            ),
832            (
833                {
834                    let mut ctx = Context::default();
835                    ctx.insert(SocketInfo::new(None, "127.0.1.2:80".parse().unwrap()));
836                    ctx
837                },
838                "[4321:8765:ba09:fedc:cdef:90ab:5678:1234]:65535",
839            ),
840            (
841                {
842                    let mut ctx = Context::default();
843                    ctx.insert(SocketInfo::new(
844                        None,
845                        "[1234:5678:90ab:cdef:fedc:ba09:8765:4321]:80"
846                            .parse()
847                            .unwrap(),
848                    ));
849                    ctx.insert(Forwarded::new(ForwardedElement::forwarded_for(
850                        NodeId::try_from("127.0.1.2:80").unwrap(),
851                    )));
852                    ctx
853                },
854                "[4321:8765:ba09:fedc:cdef:90ab:5678:1234]:65535",
855            ),
856        ] {
857            // TCP
858
859            let svc = HaProxyLayer::tcp().layer(service_fn(async move |ctx, req| {
860                Ok::<_, Infallible>(EstablishedClientConnection {
861                    ctx,
862                    req,
863                    conn: SocketConnection {
864                        socket: target_addr.parse().unwrap(),
865                        conn: Builder::new().build(),
866                    },
867                })
868            }));
869            assert!(svc.serve(input_ctx.clone(), ()).await.is_err());
870
871            // UDP
872
873            let svc = HaProxyLayer::udp().layer(service_fn(async move |ctx, req| {
874                Ok::<_, Infallible>(EstablishedClientConnection {
875                    ctx,
876                    req,
877                    conn: SocketConnection {
878                        socket: target_addr.parse().unwrap(),
879                        conn: Builder::new().build(),
880                    },
881                })
882            }));
883            assert!(svc.serve(input_ctx, ()).await.is_err());
884        }
885    }
886
887    #[tokio::test]
888    async fn test_v2_missing_src() {
889        for (input_ctx, target_addr) in [
890            (Context::default(), "192.168.1.101:443"),
891            (
892                Context::default(),
893                "[1234:5678:90ab:cdef:fedc:ba09:8765:4321]:443",
894            ),
895        ] {
896            // TCP
897
898            let svc = HaProxyLayer::tcp().layer(service_fn(async move |ctx, req| {
899                Ok::<_, Infallible>(EstablishedClientConnection {
900                    ctx,
901                    req,
902                    conn: SocketConnection {
903                        socket: target_addr.parse().unwrap(),
904                        conn: Builder::new().build(),
905                    },
906                })
907            }));
908            assert!(svc.serve(input_ctx.clone(), ()).await.is_err());
909
910            // UDP
911
912            let svc = HaProxyLayer::udp().layer(service_fn(async move |ctx, req| {
913                Ok::<_, Infallible>(EstablishedClientConnection {
914                    ctx,
915                    req,
916                    conn: SocketConnection {
917                        socket: target_addr.parse().unwrap(),
918                        conn: Builder::new().build(),
919                    },
920                })
921            }));
922            assert!(svc.serve(input_ctx.clone(), ()).await.is_err());
923        }
924    }
925}