// proxy_protocol/version1.rs

1use bytes::{Buf, BufMut as _, BytesMut};
2use snafu::{ensure, OptionExt as _, ResultExt as _, Snafu};
3use std::{
4    io::Write as _,
5    net::{AddrParseError, IpAddr, Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6},
6    str::{FromStr as _, Utf8Error},
7};
8
/// Carriage return: first byte of the CRLF that terminates a v1 header.
const CR: u8 = b'\r';
/// Line feed: second byte of the CRLF that terminates a v1 header.
const LF: u8 = b'\n';
11
12#[derive(Debug, Snafu)]
13#[cfg_attr(test, derive(PartialEq, Eq))]
14pub enum ParseError {
15    #[snafu(display("an unexpected eof was hit"))]
16    UnexpectedEof,
17
18    #[snafu(display("an illegal address family was presented"))]
19    IllegalAddressFamily,
20
21    #[snafu(display("the given input is not valid ascii text"))]
22    NonAscii { source: Utf8Error },
23
24    #[snafu(display("the given input misses an address"))]
25    MissingAddress,
26
27    #[snafu(display("invalid ip address"))]
28    InvalidAddress { source: AddrParseError },
29
30    #[snafu(display("invalid port"))]
31    InvalidPort,
32
33    #[snafu(display("illegal header ending"))]
34    IllegalHeaderEnding,
35}
36
37#[derive(Debug, Snafu)]
38pub enum EncodeError {
39    #[snafu(display("could not write to the buffer"))]
40    StdIo { source: std::io::Error },
41}
42
/// Address information carried by a version 1 PROXY header.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ProxyAddresses {
    /// `UNKNOWN` family: the header carries no usable addresses.
    Unknown,
    /// `TCP4` family: IPv4 source and destination socket addresses.
    Ipv4 {
        /// Address of the original client.
        source: SocketAddrV4,
        /// Address the client connected to.
        destination: SocketAddrV4,
    },
    /// `TCP6` family: IPv6 source and destination socket addresses.
    Ipv6 {
        /// Address of the original client.
        source: SocketAddrV6,
        /// Address the client connected to.
        destination: SocketAddrV6,
    },
}
55
/// Returns the index of the first occurrence of `needle` in `haystack`, i.e.
/// the number of bytes preceding it, or `None` if `needle` does not occur.
fn count_till_first(haystack: &[u8], needle: u8) -> Option<usize> {
    haystack.iter().position(|&b| b == needle)
}
65
66pub(crate) fn parse(buf: &mut impl Buf) -> Result<super::ProxyHeader, ParseError> {
67    ensure!(buf.remaining() >= 4, UnexpectedEof);
68
69    let step = buf.get_u8();
70
71    #[derive(PartialEq, Eq)]
72    enum ProxyAddressFamily {
73        Tcp4,
74        Tcp6,
75        Unknown,
76    }
77
78    let address_family = match step {
79        b'T' => {
80            // Tcp4 / Tcp6
81            buf.advance(2);
82            let version = buf.get_u8();
83            match version {
84                b'4' => ProxyAddressFamily::Tcp4,
85                b'6' => ProxyAddressFamily::Tcp6,
86                _ => return IllegalAddressFamily.fail(),
87            }
88        }
89        b'U' => {
90            // Unknown
91            ensure!(buf.remaining() >= 6, UnexpectedEof); // Not 7, we consumed 1.
92            buf.advance(6);
93            ProxyAddressFamily::Unknown
94        }
95        _ => return IllegalAddressFamily.fail(),
96    };
97
98    if address_family == ProxyAddressFamily::Unknown {
99        // Just consume up to the end.
100        let mut cr = false;
101        loop {
102            ensure!(buf.has_remaining(), UnexpectedEof);
103            let b = buf.get_u8();
104            if cr && b == LF {
105                break;
106            }
107            cr = b == CR;
108        }
109        return Ok(super::ProxyHeader::Version1 {
110            addresses: ProxyAddresses::Unknown,
111        });
112    }
113
114    // 1 space, 4 digits, 3 dots, absolute minimum for the source.
115    ensure!(buf.remaining() >= 8, UnexpectedEof);
116    buf.advance(1); // Space
117
118    let count = count_till_first(buf.chunk(), b' ').context(MissingAddress)?;
119    let source = &buf.chunk()[..count];
120    let source = std::str::from_utf8(source).context(NonAscii)?;
121    let source = match address_family {
122        ProxyAddressFamily::Tcp4 => IpAddr::V4(Ipv4Addr::from_str(source).context(InvalidAddress)?),
123        ProxyAddressFamily::Tcp6 => IpAddr::V6(Ipv6Addr::from_str(source).context(InvalidAddress)?),
124        ProxyAddressFamily::Unknown => unreachable!("unknown should have its own branch"),
125    };
126    buf.advance(count);
127
128    // Same as above, another address incoming.
129    ensure!(buf.remaining() >= 8, UnexpectedEof);
130    buf.advance(1); // Space
131
132    let count = count_till_first(buf.chunk(), b' ').context(MissingAddress)?;
133    let destination = &buf.chunk()[..count];
134    let destination = std::str::from_utf8(destination).context(NonAscii)?;
135    let destination = match address_family {
136        ProxyAddressFamily::Tcp4 => {
137            IpAddr::V4(Ipv4Addr::from_str(destination).context(InvalidAddress)?)
138        }
139        ProxyAddressFamily::Tcp6 => {
140            IpAddr::V6(Ipv6Addr::from_str(destination).context(InvalidAddress)?)
141        }
142        ProxyAddressFamily::Unknown => unreachable!("unknown should have its own branch"),
143    };
144    buf.advance(count);
145
146    // Space, then a port. 0 is minimum valid port, so 1 byte.
147    ensure!(buf.remaining() >= 2, UnexpectedEof);
148    buf.advance(1);
149
150    let count = count_till_first(buf.chunk(), b' ').context(InvalidPort)?;
151    let source_port = &buf.chunk()[..count];
152    let source_port = std::str::from_utf8(source_port).context(NonAscii)?;
153    ensure!(
154        // The port 0 is itself valid, but 01 is not.
155        !source_port.starts_with('0') || source_port == "0",
156        InvalidPort,
157    );
158    let source_port: u16 = source_port.parse().ok().context(InvalidPort)?;
159    buf.advance(count);
160
161    // Space, then a port, then CRLF. 0 is minimum valid port, so 1 byte.
162    ensure!(buf.remaining() >= 4, UnexpectedEof);
163    buf.advance(1);
164
165    // This is the last member of the string. Read until CR; that's next up.
166    let count = count_till_first(buf.chunk(), CR).context(InvalidPort)?;
167    let destination_port = &buf.chunk()[..count];
168    let destination_port = std::str::from_utf8(destination_port).context(NonAscii)?;
169    ensure!(
170        // The port 0 is itself valid, but 01 is not.
171        !destination_port.starts_with('0') || destination_port == "0",
172        InvalidPort,
173    );
174    let destination_port: u16 = destination_port.parse().ok().context(InvalidPort)?;
175    buf.advance(count);
176
177    ensure!(buf.get_u8() == CR, IllegalHeaderEnding);
178    ensure!(buf.get_u8() == LF, IllegalHeaderEnding);
179
180    let addresses = match (source, destination) {
181        (IpAddr::V4(source), IpAddr::V4(destination)) => ProxyAddresses::Ipv4 {
182            source: SocketAddrV4::new(source, source_port),
183            destination: SocketAddrV4::new(destination, destination_port),
184        },
185        (IpAddr::V6(source), IpAddr::V6(destination)) => ProxyAddresses::Ipv6 {
186            source: SocketAddrV6::new(source, source_port, 0, 0),
187            destination: SocketAddrV6::new(destination, destination_port, 0, 0),
188        },
189        // Mismatches are checked before reading ports.
190        _ => unreachable!(),
191    };
192
193    Ok(super::ProxyHeader::Version1 {
194        addresses,
195    })
196}
197
198pub(crate) fn encode(addresses: ProxyAddresses) -> Result<BytesMut, EncodeError> {
199    if let ProxyAddresses::Unknown = addresses {
200        return Ok(BytesMut::from(&b"PROXY UNKNOWN\r\n"[..]));
201    }
202
203    // Reserve as much data as we're gonna need -- at most.
204    let mut buf = BytesMut::with_capacity(107).writer();
205    buf.write_all(&b"PROXY TCP"[..]).context(StdIo)?;
206
207    match addresses {
208        ProxyAddresses::Ipv4 {
209            source,
210            destination,
211        } => {
212            buf.write(&b"4 "[..]).context(StdIo)?;
213            write!(
214                buf,
215                "{} {} {} {}\r\n",
216                source.ip(),
217                destination.ip(),
218                source.port(),
219                destination.port(),
220            )
221            .context(StdIo)?;
222        }
223        ProxyAddresses::Ipv6 {
224            source,
225            destination,
226        } => {
227            buf.write(&b"6 "[..]).context(StdIo)?;
228            write!(
229                buf,
230                "{} {} {} {}\r\n",
231                source.ip(),
232                destination.ip(),
233                source.port(),
234                destination.port(),
235            )
236            .context(StdIo)?;
237        }
238        ProxyAddresses::Unknown => unreachable!(),
239    }
240
241    Ok(buf.into_inner())
242}
243
#[cfg(test)]
mod parse_tests {
    use super::*;
    use crate::ProxyHeader;
    use bytes::Bytes;
    use pretty_assertions::assert_eq;
    use rand::prelude::*;
    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};

    #[test]
    fn test_valid_unknown_cases() {
        let unknown = Ok(ProxyHeader::Version1 {
            addresses: ProxyAddresses::Unknown,
        });
        assert_eq!(parse(&mut &b"UNKNOWN\r\n"[..]), unknown);
        assert_eq!(
            parse(&mut &b"UNKNOWN this is bogus data!\r\r\r\n"[..]),
            unknown,
        );
        assert_eq!(
            parse(&mut &b"UNKNOWN 192.168.0.1 192.168.1.1 123 321\r\n"[..]),
            unknown,
        );

        let mut random = [0u8; 128];
        rand::thread_rng().fill_bytes(&mut random);
        // A CRLF inside the random payload would legitimately end the header
        // early and leave bytes behind, making this test flaky. Removing every
        // LF guarantees the only CRLF is the terminator appended below.
        for b in random.iter_mut().filter(|b| **b == LF) {
            *b = b' ';
        }
        let mut header = b"UNKNOWN ".to_vec();
        header.extend(&random[..]);
        header.extend(b"\r\n");
        let mut buf = Bytes::from(header);
        assert_eq!(parse(&mut buf), unknown);
        assert!(!buf.has_remaining()); // Consume the ENTIRE header!
    }

    #[test]
    fn test_valid_ipv4_cases() {
        fn valid(
            (a, b, c, d): (u8, u8, u8, u8),
            e: u16,
            (f, g, h, i): (u8, u8, u8, u8),
            j: u16,
        ) -> ProxyHeader {
            ProxyHeader::Version1 {
                addresses: ProxyAddresses::Ipv4 {
                    source: SocketAddrV4::new(Ipv4Addr::new(a, b, c, d), e),
                    destination: SocketAddrV4::new(Ipv4Addr::new(f, g, h, i), j),
                },
            }
        }
        assert_eq!(
            parse(&mut &b"TCP4 192.168.201.102 1.2.3.4 0 65535\r\n"[..]),
            Ok(valid((192, 168, 201, 102), 0, (1, 2, 3, 4), 65535)),
        );
        assert_eq!(
            parse(&mut &b"TCP4 0.0.0.0 0.0.0.0 0 0\r\n"[..]),
            Ok(valid((0, 0, 0, 0), 0, (0, 0, 0, 0), 0)),
        );
        assert_eq!(
            parse(&mut &b"TCP4 255.255.255.255 255.255.255.255 65535 65535\r\n"[..]),
            Ok(valid(
                (255, 255, 255, 255),
                65535,
                (255, 255, 255, 255),
                65535,
            )),
        );
    }

    #[test]
    fn test_valid_ipv6_cases() {
        fn valid(
            (a, b, c, d, e, f, g, h): (u16, u16, u16, u16, u16, u16, u16, u16),
            i: u16,
            (j, k, l, m, n, o, p, q): (u16, u16, u16, u16, u16, u16, u16, u16),
            r: u16,
        ) -> ProxyHeader {
            ProxyHeader::Version1 {
                addresses: ProxyAddresses::Ipv6 {
                    source: SocketAddrV6::new(Ipv6Addr::new(a, b, c, d, e, f, g, h), i, 0, 0),
                    destination: SocketAddrV6::new(Ipv6Addr::new(j, k, l, m, n, o, p, q), r, 0, 0),
                },
            }
        }
        assert_eq!(
            parse(&mut &b"TCP6 ab:ce:ef:01:23:45:67:89 ::1 0 65535\r\n"[..]),
            Ok(valid(
                (0xAB, 0xCE, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89),
                0,
                (0, 0, 0, 0, 0, 0, 0, 1),
                65535,
            )),
        );
        assert_eq!(
            parse(&mut &b"TCP6 :: :: 0 0\r\n"[..]),
            Ok(valid(
                (0, 0, 0, 0, 0, 0, 0, 0),
                0,
                (0, 0, 0, 0, 0, 0, 0, 0),
                0,
            )),
        );
        assert_eq!(
            parse(
                &mut &b"TCP6 ff:ff:ff:ff:ff:ff:ff:ff ff:ff:ff:ff:ff:ff:ff:ff 65535 65535\r\n"[..],
            ),
            Ok(valid(
                (0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF),
                65535,
                (0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF),
                65535,
            )),
        );
    }

    #[test]
    fn test_invalid_cases() {
        assert_eq!(
            parse(&mut &b"UNKNOWN \r"[..]),
            Err(ParseError::UnexpectedEof)
        );
        assert_eq!(
            parse(&mut &b"UNKNOWN \r\t\t\r"[..]),
            Err(ParseError::UnexpectedEof),
        );
        assert_eq!(
            parse(&mut &b"UNKNOWN\r\r\r\r\rHello, world!"[..]),
            Err(ParseError::UnexpectedEof),
        );
        assert_eq!(
            parse(&mut &b"UNKNOWN\nGET /index.html HTTP/1.0"[..]),
            Err(ParseError::UnexpectedEof),
        );
        assert_eq!(
            parse(&mut &b"UNKNOWN\n"[..]),
            Err(ParseError::UnexpectedEof)
        );
    }

    #[test]
    fn test_crlf() {
        assert_eq!(CR, b'\r');
        assert_eq!(LF, b'\n');
    }
}
388
#[cfg(test)]
mod encode_tests {
    use super::*;
    use bytes::Bytes;
    use pretty_assertions::assert_eq;
    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};

    #[test]
    fn test_unknown() {
        let encoded = encode(ProxyAddresses::Unknown).expect("UNKNOWN must always encode");
        assert_eq!(encoded, &b"PROXY UNKNOWN\r\n"[..]);
    }

    #[test]
    fn test_tcp4() {
        let source = SocketAddrV4::new(Ipv4Addr::new(1, 2, 3, 4), 987);
        let destination = SocketAddrV4::new(Ipv4Addr::new(255, 254, 253, 252), 12345);

        let encoded = encode(ProxyAddresses::Ipv4 {
            source,
            destination,
        })
        .expect("TCP4 headers must always encode");

        assert_eq!(
            encoded,
            Bytes::from_static(&b"PROXY TCP4 1.2.3.4 255.254.253.252 987 12345\r\n"[..]),
        );
    }

    #[test]
    fn test_tcp6() {
        // Every case shares the same source; only the destination IP varies,
        // which exercises the different `::` compression forms.
        let source = SocketAddrV6::new(Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8), 987, 0, 0);
        let encode_to = |ip: Ipv6Addr| {
            encode(ProxyAddresses::Ipv6 {
                source,
                destination: SocketAddrV6::new(ip, 12345, 0, 0),
            })
            .expect("TCP6 headers must always encode")
        };

        assert_eq!(
            encode_to(Ipv6Addr::new(65535, 65534, 65533, 65532, 0, 1, 2, 3)),
            Bytes::from_static(
                &b"PROXY TCP6 1:2:3:4:5:6:7:8 ffff:fffe:fffd:fffc:0:1:2:3 987 12345\r\n"[..],
            ),
        );

        assert_eq!(
            encode_to(Ipv6Addr::new(65535, 65534, 0, 0, 0, 1, 2, 3)),
            Bytes::from_static(&b"PROXY TCP6 1:2:3:4:5:6:7:8 ffff:fffe::1:2:3 987 12345\r\n"[..]),
        );

        assert_eq!(
            encode_to(Ipv6Addr::new(0, 0, 0, 0, 0, 1, 2, 3)),
            Bytes::from_static(&b"PROXY TCP6 1:2:3:4:5:6:7:8 ::1:2:3 987 12345\r\n"[..]),
        );
    }
}
458        );
459    }
460}