proxy_protocol_codec/v2/codec/
decode.rs

1//! PROXY Protocol v2 header decoder
2
3use core::cmp::min;
4use core::net::{Ipv4Addr, Ipv6Addr};
5use core::num::NonZeroUsize;
6
7use slicur::Reader;
8
9use crate::v2::model::{
10    AddressPair, Command, ExtensionRef, Family, Protocol, ADDR_INET6_SIZE, ADDR_INET_SIZE, ADDR_UNIX_SIZE,
11    BYTE_VERSION, HEADER_SIZE,
12};
13use crate::v2::Header;
14
/// Stateless decoder for PROXY Protocol v2 headers.
///
/// The type carries no data; all of the work is done by
/// [`decode`](Self::decode).
#[derive(Debug)]
pub struct HeaderDecoder;
20
21// Masks the right 4-bits so only the left 4-bits are
22// present.
23const MASK_HI: u8 = 0xF0;
24
25// Masks the left 4-bits so only the right 4-bits are
26// present.
27const MASK_LO: u8 = 0x0F;
28
29/// See [`Command`].
30const COMMAND_LOCAL: u8 = Command::Local as u8;
31
32/// See [`Command`].
33const COMMAND_PROXY: u8 = Command::Proxy as u8;
34
35/// See [`Family`].
36const FAMILY_UNSPECIFIED: u8 = Family::Unspecified as u8;
37
38/// See [`Family`].
39const FAMILY_INET: u8 = Family::Inet as u8;
40
41/// See [`Family`].
42const FAMILY_INET6: u8 = Family::Inet6 as u8;
43
44/// See [`Family`].
45const FAMILY_UNIX: u8 = Family::Unix as u8;
46
47/// See [`Protocol`].
48const PROTOCOL_UNSPECIFIED: u8 = Protocol::Unspecified as u8;
49
50/// See [`Protocol`].
51const PROTOCOL_STREAM: u8 = Protocol::Stream as u8;
52
53/// See [`Protocol`].
54const PROTOCOL_DGRAM: u8 = Protocol::Dgram as u8;
55
56impl HeaderDecoder {
57    #[allow(clippy::missing_panics_doc)]
58    /// Try to decode the PROXY Protocol v2 header from the given buffer.
59    ///
60    /// The caller SHOULD first **peek** exactly **12** bytes from the network
61    /// input into a buffer and [`decode`](Self::decode) it, to detect the
62    /// presence of a PROXY Protocol v2 header.
63    ///
64    /// When the buffer is not prefixed with PROXY Protocol v2 header
65    /// [`MAGIC`](Header::MAGIC), this method returns [`Decoded::None`]. The
66    /// caller MAY reject the connection, or treat the connection as a
67    /// normal one w/o PROXY Protocol v2 header.
68    ///
69    /// When a PROXY protocol v2 header is detected, [`Decoded::Partial`] is
70    /// returned (this is what we expect, since we only have the MAGIC bytes
71    /// peeked). The caller SHOULD then **read** exactly `16 +
72    /// remaining_bytes` bytes into a buffer (may reuse the buffer peeking
73    /// the MAGIC bytes) then [`decode`](Self::decode) it.
74    ///
75    /// When any error is returned, the caller SHOULD reject the connection.
76    ///
77    /// When there're extensions in the PROXY Protocol v2 header, the caller
78    /// SHOULD read the extensions to check if they are malformed or not.
79    /// See [`DecodedExtensions`] for more details.
80    pub fn decode<'a>(buf: &'a [u8]) -> Result<Decoded<'a>, DecodeError> {
81        // 1. Magic bytes
82        {
83            let magic_length = min(Header::MAGIC.len(), buf.len());
84
85            if buf[..magic_length] != Header::MAGIC[..magic_length] {
86                return Ok(Decoded::None);
87            }
88        }
89
90        // 2. Read header
91        match HEADER_SIZE.checked_sub(buf.len()).and_then(NonZeroUsize::new) {
92            None => {}
93            Some(remaining_bytes) => {
94                // The caller should read 16 bytes first, in fact.
95                #[cfg(feature = "feat-nightly")]
96                core::hint::cold_path();
97
98                return Ok(Decoded::Partial(remaining_bytes));
99            }
100        }
101
102        // 2.1. version
103        match buf[12] & MASK_HI {
104            BYTE_VERSION => {}
105            v => {
106                #[cfg(feature = "feat-nightly")]
107                core::hint::cold_path();
108
109                return Err(DecodeError::InvalidVersion(v));
110            }
111        };
112
113        // 2.2. command
114        let command = match buf[12] & MASK_LO {
115            COMMAND_LOCAL => Command::Local,
116            COMMAND_PROXY => Command::Proxy,
117            c => {
118                #[cfg(feature = "feat-nightly")]
119                core::hint::cold_path();
120
121                return Err(DecodeError::InvalidCommand(c));
122            }
123        };
124
125        // 3.1. addr_family
126        let addr_family = match buf[13] & MASK_HI {
127            FAMILY_UNSPECIFIED => Family::Unspecified,
128            FAMILY_INET => Family::Inet,
129            FAMILY_INET6 => Family::Inet6,
130            FAMILY_UNIX => Family::Unix,
131            f => {
132                #[cfg(feature = "feat-nightly")]
133                core::hint::cold_path();
134
135                return Err(DecodeError::InvalidFamily(f));
136            }
137        };
138
139        // 3.2. protocol
140        let protocol = match buf[13] & MASK_LO {
141            PROTOCOL_UNSPECIFIED => Protocol::Unspecified,
142            PROTOCOL_STREAM => Protocol::Stream,
143            PROTOCOL_DGRAM => Protocol::Dgram,
144            p => {
145                #[cfg(feature = "feat-nightly")]
146                core::hint::cold_path();
147
148                return Err(DecodeError::InvalidProtocol(p));
149            }
150        };
151
152        // 4. remaining_len
153        let remaining_len = u16::from_be_bytes([buf[14], buf[15]]);
154
155        // Check if the buffer has enough data for the the payload
156        let payload = match HEADER_SIZE
157            .checked_add(remaining_len as usize)
158            .ok_or(DecodeError::MalformedData)?
159            .checked_sub(buf.len())
160            .map(NonZeroUsize::new)
161        {
162            Some(None) => &buf[HEADER_SIZE..],
163            Some(Some(remaining_bytes)) => return Ok(Decoded::Partial(remaining_bytes)),
164            None => {
165                #[cfg(feature = "feat-nightly")]
166                core::hint::cold_path();
167
168                // HEADER_SIZE + remaining_len < buf.len(), reject trailing data
169                return Err(DecodeError::TrailingData);
170            }
171        };
172
173        let (address_pair, extensions) = match addr_family {
174            Family::Unspecified => (AddressPair::Unspecified, payload),
175            Family::Inet => {
176                if payload.len() < ADDR_INET_SIZE {
177                    #[cfg(feature = "feat-nightly")]
178                    core::hint::cold_path();
179
180                    return Err(DecodeError::MalformedData);
181                }
182
183                (
184                    AddressPair::Inet {
185                        src_ip: Ipv4Addr::from(TryInto::<[u8; 4]>::try_into(&payload[0..4]).unwrap()),
186                        dst_ip: Ipv4Addr::from(TryInto::<[u8; 4]>::try_into(&payload[4..8]).unwrap()),
187                        src_port: u16::from_be_bytes([payload[8], payload[9]]),
188                        dst_port: u16::from_be_bytes([payload[10], payload[11]]),
189                    },
190                    &payload[ADDR_INET_SIZE..],
191                )
192            }
193            Family::Inet6 => {
194                if payload.len() < ADDR_INET6_SIZE {
195                    #[cfg(feature = "feat-nightly")]
196                    core::hint::cold_path();
197
198                    return Err(DecodeError::MalformedData);
199                }
200
201                (
202                    AddressPair::Inet6 {
203                        src_ip: Ipv6Addr::from(TryInto::<[u8; 16]>::try_into(&payload[0..16]).unwrap()),
204                        dst_ip: Ipv6Addr::from(TryInto::<[u8; 16]>::try_into(&payload[16..32]).unwrap()),
205                        src_port: u16::from_be_bytes([payload[32], payload[33]]),
206                        dst_port: u16::from_be_bytes([payload[34], payload[35]]),
207                    },
208                    &payload[ADDR_INET6_SIZE..],
209                )
210            }
211            Family::Unix => {
212                if payload.len() < ADDR_UNIX_SIZE {
213                    #[cfg(feature = "feat-nightly")]
214                    core::hint::cold_path();
215
216                    return Err(DecodeError::MalformedData);
217                }
218
219                (
220                    AddressPair::Unix {
221                        src_addr: payload[0..108].try_into().unwrap(),
222                        dst_addr: payload[108..216].try_into().unwrap(),
223                    },
224                    &payload[ADDR_UNIX_SIZE..],
225                )
226            }
227        };
228
229        match command {
230            Command::Local => Ok(Decoded::Some(DecodedHeader {
231                header: Header::new_local(),
232                extensions: DecodedExtensions::const_from(extensions),
233            })),
234            Command::Proxy => Ok(Decoded::Some(DecodedHeader {
235                header: Header::new_proxy(protocol, address_pair),
236                extensions: DecodedExtensions::const_from(extensions),
237            })),
238        }
239    }
240}
241
242#[allow(clippy::large_enum_variant)]
243#[derive(Debug)]
244/// The result of decoding a PROXY Protocol v2 header.
245pub enum Decoded<'a> {
246    /// The PROXY Protocol v2 header and its extensions decoded.
247    Some(DecodedHeader<'a>),
248
249    /// Partial data, the caller should read more data.
250    Partial(NonZeroUsize),
251
252    /// Not a PROXY Protocol v2 header.
253    None,
254}
255
256#[derive(Debug)]
257/// A wrapper around the PROXY Protocol v2 header and its extensions.
258pub struct DecodedHeader<'a> {
259    /// The PROXY Protocol v2 header.
260    pub header: Header,
261
262    /// Extensions of the PROXY Protocol v2 header.
263    pub extensions: DecodedExtensions<'a>,
264}
265
wrapper_lite::wrapper! {
    // NOTE(review): `wrapper_impl(Deref)` presumably lets callers use this as
    // the wrapped `&[u8]` directly — see the `wrapper_lite` docs.
    #[wrapper_impl(Deref)]
    #[derive(Debug)]
    /// A wrapper around a slice of bytes representing the encoded extensions
    /// of the PROXY Protocol v2 header.
    ///
    /// The bytes are borrowed from the buffer given to
    /// [`HeaderDecoder::decode`] and are not validated when the header is
    /// decoded. This implements `IntoIterator`, which decodes the extensions
    /// lazily, one per iteration step. See [`DecodedExtensionsIter`] for more
    /// details.
    pub DecodedExtensions<'a>(&'a [u8])
}
276
277impl<'a> IntoIterator for DecodedExtensions<'a> {
278    type IntoIter = DecodedExtensionsIter<'a>;
279    type Item = Result<ExtensionRef<'a>, DecodeError>;
280
281    fn into_iter(self) -> Self::IntoIter {
282        DecodedExtensionsIter {
283            inner: Reader::init(self.inner),
284        }
285    }
286}
287
288#[derive(Debug)]
289/// Iterator over the extensions of the PROXY Protocol v2 header.
290pub struct DecodedExtensionsIter<'a> {
291    inner: Reader<'a>,
292}
293
294impl<'a> Iterator for DecodedExtensionsIter<'a> {
295    type Item = Result<ExtensionRef<'a>, DecodeError>;
296
297    fn next(&mut self) -> Option<Self::Item> {
298        ExtensionRef::decode(&mut self.inner).transpose()
299    }
300}
301
302#[derive(Debug)]
303#[derive(thiserror::Error)]
304/// Errors that can occur while decoding a PROXY Protocol v2 header.
305pub enum DecodeError {
306    #[error("Invalid PROXY Protocol version: {0}")]
307    /// Invalid PROXY Protocol version
308    InvalidVersion(u8),
309
310    #[error("Invalid PROXY Protocol command: {0}")]
311    /// Invalid PROXY Protocol command
312    InvalidCommand(u8),
313
314    #[error("Invalid proxy address family: {0}")]
315    /// Invalid proxy address family
316    InvalidFamily(u8),
317
318    #[error("Invalid proxy transport protocol: {0}")]
319    /// Invalid proxy transport protocol
320    InvalidProtocol(u8),
321
322    #[error("Trailing data after the header")]
323    /// The buffer contains trailing data after the PROXY Protocol v2 header.
324    TrailingData,
325
326    #[error("Malformed data")]
327    /// The data is malformed, e.g. the length of an extension does not match
328    /// the actual data length.
329    MalformedData,
330}