proxy_protocol_codec/
lib.rs

1#![doc = include_str!("../README.md")]
2#![no_std]
3#![allow(clippy::must_use_candidate, reason = "XXX")]
4#![allow(clippy::return_self_not_must_use, reason = "XXX")]
5
6#[cfg(feature = "feat-codec-v1")]
7pub mod v1;
8#[cfg(feature = "feat-codec-v2")]
9pub mod v2;
10
11#[cfg(any(test, feature = "feat-alloc"))]
12extern crate alloc;
13
14#[cfg(any(test, feature = "feat-std"))]
15extern crate std;
16
/// Identifies which revision of the PROXY Protocol a header belongs to.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Version {
    /// Version 1 of the PROXY Protocol.
    V1,

    /// Version 2 of the PROXY Protocol.
    V2,
}
26
27impl Version {
28    /// The magic bytes that indicate a PROXY Protocol v1 header.
29    pub const MAGIC_V1: &'static str = "PROXY";
30    /// The magic bytes that indicate a PROXY Protocol v2 header.
31    pub const MAGIC_V2: &'static [u8; 12] = b"\r\n\r\n\x00\r\nQUIT\n";
32
33    #[allow(clippy::result_unit_err, reason = "XXX")]
34    #[allow(clippy::missing_errors_doc, reason = "XXX")]
35    #[inline]
36    /// Peeks into the given buffer to determine if it contains a valid PROXY
37    /// Protocol version magic.
38    ///
39    /// ## Behaviours
40    ///
41    /// If the buffer is too short to determine the version, `Ok(None)` is
42    /// returned. If the buffer contains a valid version magic,
43    /// `Ok(Some(Version))` is returned. Otherwise, `Err(())` is returned.
44    ///
45    /// ```
46    /// # use proxy_protocol_codec::Version;
47    /// let v1_magic = Version::MAGIC_V1.as_bytes();
48    /// let v2_magic = Version::MAGIC_V2;
49    /// assert_eq!(Version::peek(v1_magic), Ok(Some(Version::V1)));
50    /// assert_eq!(Version::peek(&v1_magic[..3]), Ok(None));
51    /// assert_eq!(Version::peek(v2_magic), Ok(Some(Version::V2)));
52    /// assert_eq!(Version::peek(&v2_magic[..6]), Ok(None));
53    /// # assert_eq!(Version::peek(&[0]), Err(()));
54    /// ```
55    pub fn peek(buf: &[u8]) -> Result<Option<Self>, ()> {
56        const V1_MAGIC_LEN: usize = Version::MAGIC_V1.len();
57        const V2_MAGIC_LEN: usize = Version::MAGIC_V2.len();
58
59        match buf.len() {
60            0 => Ok(None),
61            V2_MAGIC_LEN.. if buf.starts_with(Self::MAGIC_V2) => Ok(Some(Self::V2)),
62            1..V2_MAGIC_LEN if Self::MAGIC_V2.starts_with(buf) => Ok(None),
63            V1_MAGIC_LEN.. if buf.starts_with(Self::MAGIC_V1.as_bytes()) => Ok(Some(Self::V1)),
64            1..V1_MAGIC_LEN if Self::MAGIC_V1.as_bytes().starts_with(buf) => Ok(None),
65            _ => Err(()),
66        }
67    }
68}
69
#[cfg(test)]
mod smoking {
    /// Smoke test: each v1 address pair variant encodes to the expected
    /// header line.
    #[test]
    fn test_v1() {
        use crate::v1::{AddressPair, Header};

        // PROXY Protocol v1 (text format), TCP over IPv4.
        let tcp4 = Header::new(AddressPair::Inet {
            src_ip: "127.0.0.1".parse().unwrap(),
            dst_ip: "127.0.0.2".parse().unwrap(),
            src_port: 8080,
            dst_port: 80,
        });
        assert_eq!(tcp4.encode(), "PROXY TCP4 127.0.0.1 127.0.0.2 8080 80\r\n");

        // PROXY Protocol v1 (text format), TCP over IPv6.
        let tcp6 = Header::new(AddressPair::Inet6 {
            src_ip: "::1".parse().unwrap(),
            dst_ip: "::2".parse().unwrap(),
            src_port: 8080,
            dst_port: 80,
        });
        assert_eq!(tcp6.encode(), "PROXY TCP6 ::1 ::2 8080 80\r\n");

        // PROXY Protocol v1 (text format), unknown transport.
        let unknown = Header::new(AddressPair::Unspecified);
        assert_eq!(unknown.encode(), "PROXY UNKNOWN\r\n");
    }

    /// Smoke test: a v2 header written with several extensions decodes
    /// back to a header equal to the one that was encoded.
    #[test]
    fn test_v2() {
        use crate::v2::{AddressPair, Decoded, DecodedHeader, Header, Protocol};

        // PROXY Protocol v2 (binary format)
        let original = Header::new_proxy(
            Protocol::Stream,
            AddressPair::Inet {
                src_ip: "127.0.0.1".parse().unwrap(),
                dst_ip: "127.0.0.2".parse().unwrap(),
                src_port: 8080,
                dst_port: 80,
            },
        );

        // Append one extension of each kind exercised here, then finalize.
        let encoded = original
            .encode()
            .write_ext_alpn(b"h2")
            .unwrap()
            .write_ext_authority(b"example.com")
            .unwrap()
            .write_ext_no_op(0)
            .unwrap()
            .write_ext_unique_id(b"unique_id")
            .unwrap()
            .write_ext_network_namespace(b"network_namespace")
            .unwrap()
            .finish()
            .unwrap();

        // Only the header itself is compared; the extensions are ignored.
        let roundtripped = match Header::decode(&encoded).unwrap() {
            Decoded::Some(DecodedHeader { header, .. }) => header,
            _ => panic!("failed to decode v2 header"),
        };

        assert_eq!(original, roundtripped);
    }
}