haproxy_protocol/
lib.rs

1#![deny(warnings)]
2#![warn(unused_extern_crates)]
3#![deny(clippy::todo)]
4#![deny(clippy::unimplemented)]
5#![deny(clippy::unwrap_used)]
6#![deny(clippy::expect_used)]
7#![deny(clippy::panic)]
8#![deny(clippy::unreachable)]
9#![deny(clippy::await_holding_lock)]
10#![deny(clippy::needless_pass_by_value)]
11#![deny(clippy::trivially_copy_pass_by_ref)]
12
13use crate::parse::{parse_proxy_hdr_v1, parse_proxy_hdr_v2};
14use std::num::NonZeroUsize;
15
16#[cfg(any(test, feature = "tokio"))]
17use crate::parse::{V1_MAX_LEN, V1_MIN_LEN};
18
/// The smallest possible read request (one byte). Used when the v1 parser knows it needs more
/// input but cannot say how much.
const NZ_ONE: NonZeroUsize = NonZeroUsize::MIN;
20
21mod parse;
22
/// Transport protocol (and IP version) announced by a proxy header.
///
/// Each variant's discriminant is the `u8` value it is represented by.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
enum Protocol {
    /// No protocol was specified.
    Unspec = 0x00,
    /// TCP over IPv4.
    TcpV4 = 0x11,
    /// UDP over IPv4.
    UdpV4 = 0x12,
    /// TCP over IPv6.
    TcpV6 = 0x21,
    /// UDP over IPv6.
    UdpV6 = 0x22,
    // Unix sockets are not supported yet:
    // UnixStream = 0x31,
    // UnixDgram = 0x32,
}
34
/// Command carried by a v2 proxy header.
///
/// Each variant's discriminant is the `u8` value it is represented by.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
enum Command {
    /// Mapped to [`RemoteAddress::Local`] regardless of the rest of the header.
    Local = 0x00,
    /// The header's addresses describe the proxied connection.
    Proxy = 0x01,
}
41
/// Source/destination address pair stored in a parsed proxy header.
#[derive(Clone, Debug, Eq, PartialEq)]
enum Address {
    /// The header carried no address information.
    None,
    /// A pair of IPv4 socket addresses.
    V4 {
        src: std::net::SocketAddrV4,
        dst: std::net::SocketAddrV4,
    },
    /// A pair of IPv6 socket addresses.
    V6 {
        src: std::net::SocketAddrV6,
        dst: std::net::SocketAddrV6,
    },
    // Unix socket paths are not supported yet:
    // Unix {
    //     src: PathBuf,
    //     dst: PathBuf,
    // }
}
58
/// Public view of the connection described by a proxy header, as produced by the
/// `to_remote_addr` methods of [`ProxyHdrV1`] and [`ProxyHdrV2`].
#[derive(Clone, Debug)]
pub enum RemoteAddress {
    /// The v2 header carried the `Local` command.
    Local,
    /// The header's protocol and address did not form a supported combination.
    Invalid,
    /// TCP over IPv4 between `src` and `dst`.
    TcpV4 {
        src: std::net::SocketAddrV4,
        dst: std::net::SocketAddrV4,
    },
    /// UDP over IPv4 between `src` and `dst`.
    UdpV4 {
        src: std::net::SocketAddrV4,
        dst: std::net::SocketAddrV4,
    },
    /// TCP over IPv6 between `src` and `dst`.
    TcpV6 {
        src: std::net::SocketAddrV6,
        dst: std::net::SocketAddrV6,
    },
    /// UDP over IPv6 between `src` and `dst`.
    UdpV6 {
        src: std::net::SocketAddrV6,
        dst: std::net::SocketAddrV6,
    },
}
80
/// Errors returned by the in-memory parsers [`ProxyHdrV1::parse`] and [`ProxyHdrV2::parse`].
#[derive(Debug)]
pub enum Error {
    /// The input ended before the header did. `need` is the number of additional bytes the
    /// parser asked for.
    Incomplete { need: NonZeroUsize },
    /// The input is not a valid proxy header.
    Invalid,
    /// The parser needs more input but could not determine how much.
    UnableToComplete,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Incomplete { need } => {
                write!(f, "proxy header is incomplete, {need} more byte(s) needed")
            }
            Self::Invalid => f.write_str("proxy header is invalid"),
            Self::UnableToComplete => f.write_str("proxy header could not be completed"),
        }
    }
}

// Lets callers use this type with `Box<dyn Error>`, `?` conversions and error-reporting crates.
impl std::error::Error for Error {}
87
88#[derive(Debug, Clone)]
89pub struct ProxyHdrV2 {
90    command: Command,
91    protocol: Protocol,
92    // address_family: AddressFamily,
93    // length: u16,
94    address: Address,
95}
96
97impl ProxyHdrV2 {
98    pub fn parse(input_data: &[u8]) -> Result<(usize, Self), Error> {
99        match parse_proxy_hdr_v2(input_data) {
100            Ok((remainder, hdr)) => {
101                let took = input_data.len() - remainder.len();
102                Ok((took, hdr))
103            }
104            Err(nom::Err::Incomplete(nom::Needed::Size(need))) => Err(Error::Incomplete { need }),
105            // We always know exactly how much is needed for hdr v2
106            Err(nom::Err::Incomplete(nom::Needed::Unknown)) => Err(Error::UnableToComplete),
107
108            Err(nom::Err::Error(err)) => {
109                tracing::error!(?err);
110                Err(Error::Invalid)
111            }
112            Err(nom::Err::Failure(err)) => {
113                tracing::error!(?err, "parser failure handling proxy v2 header");
114                Err(Error::Invalid)
115            }
116        }
117    }
118
119    pub fn to_remote_addr(self) -> RemoteAddress {
120        match (self.command, self.protocol, self.address) {
121            (Command::Local, _, _) => RemoteAddress::Local,
122            (Command::Proxy, Protocol::TcpV4, Address::V4 { src, dst }) => {
123                RemoteAddress::TcpV4 { src, dst }
124            }
125            (Command::Proxy, Protocol::UdpV4, Address::V4 { src, dst }) => {
126                RemoteAddress::UdpV4 { src, dst }
127            }
128            (Command::Proxy, Protocol::TcpV6, Address::V6 { src, dst }) => {
129                RemoteAddress::TcpV6 { src, dst }
130            }
131            (Command::Proxy, Protocol::UdpV6, Address::V6 { src, dst }) => {
132                RemoteAddress::UdpV6 { src, dst }
133            }
134            _ => RemoteAddress::Invalid,
135        }
136    }
137}
138
139#[derive(Debug, Clone)]
140pub struct ProxyHdrV1 {
141    protocol: Protocol,
142    address: Address,
143}
144
145impl ProxyHdrV1 {
146    pub fn parse(input_data: &[u8]) -> Result<(usize, Self), Error> {
147        match parse_proxy_hdr_v1(input_data) {
148            Ok((remainder, hdr)) => {
149                let took = input_data.len() - remainder.len();
150                Ok((took, hdr))
151            }
152            Err(nom::Err::Incomplete(nom::Needed::Size(need))) => Err(Error::Incomplete { need }),
153            // We aren't sure how much we need but we need *something*.
154            Err(nom::Err::Incomplete(nom::Needed::Unknown)) => {
155                Err(Error::Incomplete { need: NZ_ONE })
156            }
157
158            Err(nom::Err::Error(err)) => {
159                tracing::error!(?err);
160                Err(Error::Invalid)
161            }
162            Err(nom::Err::Failure(err)) => {
163                tracing::error!(?err, "parser failure handling proxy v1 header");
164                Err(Error::Invalid)
165            }
166        }
167    }
168
169    pub fn to_remote_addr(self) -> RemoteAddress {
170        match (self.protocol, self.address) {
171            (Protocol::TcpV4, Address::V4 { src, dst }) => RemoteAddress::TcpV4 { src, dst },
172            (Protocol::UdpV4, Address::V4 { src, dst }) => RemoteAddress::UdpV4 { src, dst },
173            (Protocol::TcpV6, Address::V6 { src, dst }) => RemoteAddress::TcpV6 { src, dst },
174            (Protocol::UdpV6, Address::V6 { src, dst }) => RemoteAddress::UdpV6 { src, dst },
175            _ => RemoteAddress::Invalid,
176        }
177    }
178}
179
/// Errors returned by the `parse_from_read` stream readers.
#[cfg(any(feature = "tokio", test))]
#[derive(Debug)]
pub enum AsyncReadError {
    /// Reading from the underlying stream failed.
    Io(std::io::Error),
    /// The bytes read do not form a valid proxy header.
    Invalid,
    /// The header could not be completed from the bytes that were read.
    UnableToComplete,
    /// The header asked for more bytes than the reader is willing to buffer.
    RequestTooLarge,
    /// The parser consumed a different number of bytes than were read from the stream.
    InconsistentRead,
}

#[cfg(any(feature = "tokio", test))]
impl std::fmt::Display for AsyncReadError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            // The underlying error is exposed through `source()` rather than repeated here.
            Self::Io(_) => "i/o error while reading proxy header",
            Self::Invalid => "proxy header is invalid",
            Self::UnableToComplete => "proxy header could not be completed",
            Self::RequestTooLarge => "proxy header exceeds the permitted size",
            Self::InconsistentRead => "proxy header length does not match the bytes read",
        };
        f.write_str(message)
    }
}

#[cfg(any(feature = "tokio", test))]
impl std::error::Error for AsyncReadError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Io(err) => Some(err),
            Self::Invalid
            | Self::UnableToComplete
            | Self::RequestTooLarge
            | Self::InconsistentRead => None,
        }
    }
}
189
#[cfg(any(feature = "tokio", test))]
impl ProxyHdrV2 {
    /// Reads exactly one version 2 proxy header from `stream` and parses it.
    ///
    /// The fixed 16-byte prefix is read first; the parser then reports how many further bytes
    /// the header needs and exactly that many are read. On success the stream is handed back
    /// together with the header.
    ///
    /// # Errors
    ///
    /// * [`AsyncReadError::Io`] when reading from the stream fails.
    /// * [`AsyncReadError::RequestTooLarge`] when the full header would exceed 512 bytes.
    /// * [`AsyncReadError::InconsistentRead`] when the parser consumed a different number of
    ///   bytes than were read.
    /// * [`AsyncReadError::Invalid`] / [`AsyncReadError::UnableToComplete`] when the parser
    ///   rejects the bytes or cannot finish the header.
    pub async fn parse_from_read<S>(mut stream: S) -> Result<(S, Self), AsyncReadError>
    where
        S: tokio::io::AsyncReadExt + std::marker::Unpin,
    {
        use tracing::{debug, error};

        const HDR_SIZE_LIMIT: usize = 512;
        // Number of bytes up to and including the length field.
        const PREFIX_LEN: usize = 16;

        let mut buf = vec![0; PREFIX_LEN];

        // Read the prefix so the parser can tell us how much more is required.
        let mut took = stream
            .read_exact(&mut buf)
            .await
            .map_err(AsyncReadError::Io)?;

        let need = match ProxyHdrV2::parse(&buf) {
            // The prefix alone was a complete header - this can occur for local conditions.
            Ok((_, hdr)) => return Ok((stream, hdr)),
            // The parser tells us precisely how many more bytes it wants.
            Err(Error::Incomplete { need }) => need,
            Err(Error::Invalid) => {
                debug!(proxy_binary_dump = %hex::encode(&buf));
                error!("proxy v2 header was invalid");
                return Err(AsyncReadError::Invalid);
            }
            Err(Error::UnableToComplete) => {
                debug!(proxy_binary_dump = %hex::encode(&buf));
                error!("proxy v2 header was incomplete");
                return Err(AsyncReadError::UnableToComplete);
            }
        };

        // Cap the allocation so a hostile peer cannot make us grow the buffer arbitrarily.
        let resize_to = buf.len() + usize::from(need);
        if resize_to > HDR_SIZE_LIMIT {
            error!(
                "proxy v2 header request was larger than {} bytes, refusing to proceed.",
                HDR_SIZE_LIMIT
            );
            return Err(AsyncReadError::RequestTooLarge);
        }
        buf.resize(resize_to, 0);

        // Fill the newly added tail with the rest of the header.
        took += stream
            .read_exact(&mut buf[PREFIX_LEN..])
            .await
            .map_err(AsyncReadError::Io)?;

        match ProxyHdrV2::parse(&buf) {
            Ok((hdr_took, hdr)) => {
                if hdr_took == took {
                    Ok((stream, hdr))
                } else {
                    // The parser and the reader disagree on the header size.
                    error!("proxy v2 header read an inconsistent amount from stream.");
                    Err(AsyncReadError::InconsistentRead)
                }
            }
            Err(Error::Incomplete { .. }) => {
                error!("proxy v2 header could not be read to the end.");
                Err(AsyncReadError::UnableToComplete)
            }
            Err(Error::Invalid) => {
                debug!(proxy_binary_dump = %hex::encode(&buf));
                error!("proxy v2 header was invalid");
                Err(AsyncReadError::Invalid)
            }
            Err(Error::UnableToComplete) => {
                debug!(proxy_binary_dump = %hex::encode(&buf));
                error!("proxy v2 header was incomplete");
                Err(AsyncReadError::UnableToComplete)
            }
        }
    }
}
273
#[cfg(any(feature = "tokio", test))]
impl ProxyHdrV1 {
    /// Reads exactly one version 1 proxy header from `stream` and parses it.
    ///
    /// `V1_MIN_LEN` bytes are read first. While the parser reports that the header is
    /// incomplete, exactly the number of bytes it asks for is read and parsing is retried.
    /// On success the stream is handed back together with the header.
    ///
    /// # Errors
    ///
    /// * [`AsyncReadError::Io`] when reading from the stream fails.
    /// * [`AsyncReadError::RequestTooLarge`] when the parser asks for more bytes than fit in
    ///   the `V1_MAX_LEN + 1` byte buffer.
    /// * [`AsyncReadError::InconsistentRead`] when the parser consumed a different number of
    ///   bytes than were read.
    /// * [`AsyncReadError::Invalid`] / [`AsyncReadError::UnableToComplete`] when the parser
    ///   rejects the bytes or cannot finish the header.
    pub async fn parse_from_read<S>(mut stream: S) -> Result<(S, Self), AsyncReadError>
    where
        S: tokio::io::AsyncReadExt + std::marker::Unpin,
    {
        use tracing::{debug, error};

        // This is the maximum size of the buffer we could possibly need.
        let mut buf = [0u8; V1_MAX_LEN + 1];

        // Start with the shortest possible header; the parser then tells us whether, and how
        // much, more input is required.
        let mut took = stream
            .read_exact(&mut buf[..V1_MIN_LEN])
            .await
            .map_err(AsyncReadError::Io)?;

        loop {
            // Limit the view to the bytes actually read so far.
            let Some(received) = buf.get(..took) else {
                error!("proxy v1 header read over ran the buffer allocation.");
                return Err(AsyncReadError::Invalid);
            };

            match ProxyHdrV1::parse(received) {
                Ok((hdr_took, _)) if hdr_took != took => {
                    // The parser and the reader disagree on the header size.
                    error!("proxy v1 header read an inconsistent amount from stream.");
                    return Err(AsyncReadError::InconsistentRead);
                }
                Ok((_, hdr)) => return Ok((stream, hdr)),
                Err(Error::Incomplete { need }) => {
                    // Check the requested window BEFORE slicing: indexing out of range would
                    // panic, and the amount is driven by untrusted input.
                    let end = took.saturating_add(need.get());
                    let Some(window) = buf.get_mut(took..end) else {
                        error!(
                            "proxy v1 header exceeded the maximum length of {} bytes, refusing to proceed.",
                            V1_MAX_LEN
                        );
                        return Err(AsyncReadError::RequestTooLarge);
                    };

                    took += stream
                        .read_exact(window)
                        .await
                        .map_err(AsyncReadError::Io)?;
                }
                Err(Error::Invalid) => {
                    // Dump only the bytes that were read, not the zeroed tail of the buffer.
                    debug!(proxy_binary_dump = %hex::encode(received));
                    error!("proxy v1 header was invalid");
                    return Err(AsyncReadError::Invalid);
                }
                Err(Error::UnableToComplete) => {
                    debug!(proxy_binary_dump = %hex::encode(received));
                    error!("proxy v1 header was incomplete");
                    return Err(AsyncReadError::UnableToComplete);
                }
            }
        } // end loop
    }
}
335
#[cfg(test)]
mod tests {
    use crate::{Address, Command, Protocol, ProxyHdrV1, ProxyHdrV2};
    use std::net::SocketAddrV4;
    use std::str::FromStr;

    /// A complete v1 header supplied as a byte slice parses into the expected addresses.
    #[tokio::test]
    async fn proxyv1_stream_parse() {
        let _ = tracing_subscriber::fmt::try_init();

        let data = "PROXY TCP4 91.221.138.33 91.221.138.106 47780 636\r\n";

        let (_, hdr) = ProxyHdrV1::parse_from_read(data.as_bytes())
            .await
            .expect("should parse v1 header");

        tracing::debug!(?hdr);

        assert_eq!(hdr.protocol, Protocol::TcpV4);
        assert_eq!(
            hdr.address,
            Address::V4 {
                src: SocketAddrV4::from_str("91.221.138.33:47780").expect("valid addr"),
                dst: SocketAddrV4::from_str("91.221.138.106:636").expect("valid addr"),
            }
        );
    }

    /// A complete v2 header supplied as a byte slice parses into the expected addresses.
    #[tokio::test]
    async fn proxyv2_stream_parse() {
        let _ = tracing_subscriber::fmt::try_init();

        let sample = hex::decode("0d0a0d0a000d0a515549540a2111000cac180c76ac180b8fcdcb027d")
            .expect("valid hex");

        let (_, hdr) = ProxyHdrV2::parse_from_read(sample.as_slice())
            .await
            .expect("should parse v4 addr");

        tracing::debug!(?hdr);

        assert_eq!(hdr.command, Command::Proxy);
        assert_eq!(hdr.protocol, Protocol::TcpV4);
        assert_eq!(
            hdr.address,
            Address::V4 {
                src: SocketAddrV4::from_str("172.24.12.118:52683").expect("valid addr"),
                dst: SocketAddrV4::from_str("172.24.11.143:637").expect("valid addr"),
            }
        );
    }

    // Already inside a `cfg(test)` module, so only the feature needs checking here.
    #[cfg(feature = "tokio")]
    mod async_stream_tests {
        use super::*;
        use std::net::{SocketAddrV4, SocketAddrV6};
        use std::str::FromStr;
        use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt};

        /// Writes `data` to `writer` in pieces of the given `chunk_sizes`, yielding to the
        /// scheduler after each piece, then writes whatever is left in one final write.
        #[allow(clippy::expect_used)] // because test helper
        async fn write_in_chunks<W>(mut writer: W, data: &[u8], chunk_sizes: &[usize])
        where
            W: AsyncWrite + Unpin,
        {
            let mut offset = 0;
            for &size in chunk_sizes {
                if offset >= data.len() {
                    break;
                }
                let end = (offset + size).min(data.len());
                writer
                    .write_all(&data[offset..end])
                    .await
                    .expect("chunk write should succeed");
                tokio::task::yield_now().await;
                offset = end;
            }

            if offset < data.len() {
                writer
                    .write_all(&data[offset..])
                    .await
                    .expect("final write should succeed");
            }
        }

        /// A v2 header delivered in small chunks is parsed, and the payload that follows it
        /// is left in the stream.
        #[tokio::test]
        async fn tokio_stream_parse_v2_chunks() {
            let _ = tracing_subscriber::fmt::try_init();

            let sample = hex::decode("0d0a0d0a000d0a515549540a2111000cac180c76ac180b8fcdcb027d")
                .expect("valid hex");
            let payload = b"hello";
            let mut full = sample.clone();
            full.extend_from_slice(payload);

            let (client, server) = tokio::io::duplex(32);

            let writer = tokio::spawn(async move {
                write_in_chunks(server, &full, &[5, 3, 1, 7, 2]).await;
            });

            let (mut stream, hdr) = ProxyHdrV2::parse_from_read(client)
                .await
                .expect("should parse v2 from stream");

            let mut extra = vec![0; payload.len()];
            stream
                .read_exact(&mut extra)
                .await
                .expect("should read extra payload");

            writer.await.expect("writer task should finish");

            assert_eq!(extra.as_slice(), payload);
            assert_eq!(hdr.command, Command::Proxy);
            assert_eq!(hdr.protocol, Protocol::TcpV4);
            assert_eq!(
                hdr.address,
                Address::V4 {
                    src: SocketAddrV4::from_str("172.24.12.118:52683").expect("valid addr"),
                    dst: SocketAddrV4::from_str("172.24.11.143:637").expect("valid addr"),
                }
            );
        }

        /// A v1 header delivered in small chunks is parsed, and the payload that follows it
        /// is left in the stream.
        #[tokio::test]
        async fn tokio_stream_parse_v1_chunks() {
            let _ = tracing_subscriber::fmt::try_init();

            let header = b"PROXY TCP6 ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff 65535 65535\r\n";
            let payload = b"more_data";
            let mut full = header.to_vec();
            full.extend_from_slice(payload);

            let (client, server) = tokio::io::duplex(64);

            let writer = tokio::spawn(async move {
                write_in_chunks(server, &full, &[4, 1, 8, 2, 3, 5, 1]).await;
            });

            let (mut stream, hdr) = ProxyHdrV1::parse_from_read(client)
                .await
                .expect("should parse v1 from stream");

            let mut extra = vec![0; payload.len()];
            stream
                .read_exact(&mut extra)
                .await
                .expect("should read extra payload");

            writer.await.expect("writer task should finish");

            assert_eq!(extra.as_slice(), payload);
            assert_eq!(hdr.protocol, Protocol::TcpV6);
            assert_eq!(
                hdr.address,
                Address::V6 {
                    src: SocketAddrV6::from_str("[ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff]:65535")
                        .expect("valid addr"),
                    dst: SocketAddrV6::from_str("[ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff]:65535")
                        .expect("valid addr"),
                }
            );
        }
    }
}
499        }
500    }
501}