proxy_protocol_codec/v2/codec/
decode.rs

1//! PROXY Protocol v2 header decoder
2
3#[cfg(feature = "feat-alloc")]
4use alloc::vec::Vec;
5use core::cmp::min;
6use core::iter::FusedIterator;
7use core::net::{Ipv4Addr, Ipv6Addr};
8use core::num::NonZeroUsize;
9
10use slicur::Reader;
11
12use crate::v2::model::{
13    AddressPair, Command, ExtensionRef, Family, Protocol, ADDR_INET6_SIZE, ADDR_INET_SIZE, ADDR_UNIX_SIZE,
14    BYTE_VERSION, HEADER_SIZE,
15};
16use crate::v2::Header;
17
/// Stateless decoder for PROXY Protocol v2 headers.
///
/// The type carries no data; all of the work happens in
/// [`decode`](Self::decode), whose documentation describes the expected
/// peek / read / decode sequence.
#[derive(Debug)]
pub struct HeaderDecoder;
23
24// Masks the right 4-bits so only the left 4-bits are
25// present.
26const MASK_HI: u8 = 0xF0;
27
28// Masks the left 4-bits so only the right 4-bits are
29// present.
30const MASK_LO: u8 = 0x0F;
31
32/// See [`Command`].
33const COMMAND_LOCAL: u8 = Command::Local as u8;
34
35/// See [`Command`].
36const COMMAND_PROXY: u8 = Command::Proxy as u8;
37
38/// See [`Family`].
39const FAMILY_UNSPECIFIED: u8 = Family::Unspecified as u8;
40
41/// See [`Family`].
42const FAMILY_INET: u8 = Family::Inet as u8;
43
44/// See [`Family`].
45const FAMILY_INET6: u8 = Family::Inet6 as u8;
46
47/// See [`Family`].
48const FAMILY_UNIX: u8 = Family::Unix as u8;
49
50/// See [`Protocol`].
51const PROTOCOL_UNSPECIFIED: u8 = Protocol::Unspecified as u8;
52
53/// See [`Protocol`].
54const PROTOCOL_STREAM: u8 = Protocol::Stream as u8;
55
56/// See [`Protocol`].
57const PROTOCOL_DGRAM: u8 = Protocol::Dgram as u8;
58
59impl HeaderDecoder {
60    #[allow(clippy::missing_panics_doc)]
61    /// Attempts to decode the PROXY Protocol v2 header from its bytes
62    /// representation.
63    ///
64    /// The caller MAY first **peek** exactly **[`HEADER_SIZE`]** bytes from the
65    /// network input into a buffer and then [`decode`](Self::decode) it, to
66    /// detect the presence of a PROXY Protocol v2 header. If less than 16
67    /// bytes are peeked, the caller MAY reject the connection, or treat the
68    /// connection as a normal one w/o PROXY Protocol v2 header.
69    ///
70    /// When the buffer is not prefixed with PROXY Protocol v2 header
71    /// [`MAGIC`](Header::MAGIC), this method returns [`Decoded::None`]. The
72    /// caller MAY reject the connection, or treat the connection as a
73    /// normal one w/o PROXY Protocol v2 header.
74    ///
75    /// When a PROXY protocol v2 header is detected, [`Decoded::Partial`] is
76    /// returned (this is what we expect, since we only have the MAGIC bytes
77    /// peeked). The caller SHOULD then **read** exactly [`HEADER_SIZE`] +
78    /// `remaining_bytes` bytes into a buffer (may reuse the buffer peeking
79    /// the MAGIC bytes) and [`decode`](Self::decode) it again.
80    ///
81    /// When any error is returned, the caller SHOULD reject the connection.
82    ///
83    /// When there're extensions in the PROXY Protocol v2 header, the caller
84    /// SHOULD read the extensions to check if they are malformed or not.
85    /// See [`DecodedExtensions`] for more details.
86    pub fn decode<'a>(buf: &'a [u8]) -> Result<Decoded<'a>, DecodeError> {
87        // 1. Magic bytes
88        {
89            let magic_length = min(Header::MAGIC.len(), buf.len());
90
91            if buf[..magic_length] != Header::MAGIC[..magic_length] {
92                return Ok(Decoded::None);
93            }
94        }
95
96        // 2. Read header
97        match HEADER_SIZE.checked_sub(buf.len()).and_then(NonZeroUsize::new) {
98            None => {}
99            Some(remaining_bytes) => {
100                // The caller should read 16 bytes first, in fact.
101                #[cfg(feature = "feat-nightly")]
102                core::hint::cold_path();
103
104                return Ok(Decoded::Partial(remaining_bytes));
105            }
106        }
107
108        // 2.1. version
109        match buf[12] & MASK_HI {
110            BYTE_VERSION => {}
111            v => {
112                #[cfg(feature = "feat-nightly")]
113                core::hint::cold_path();
114
115                return Err(DecodeError::InvalidVersion(v));
116            }
117        };
118
119        // 2.2. command
120        let command = match buf[12] & MASK_LO {
121            COMMAND_LOCAL => Command::Local,
122            COMMAND_PROXY => Command::Proxy,
123            c => {
124                #[cfg(feature = "feat-nightly")]
125                core::hint::cold_path();
126
127                return Err(DecodeError::InvalidCommand(c));
128            }
129        };
130
131        // 3.1. addr_family
132        let addr_family = match buf[13] & MASK_HI {
133            FAMILY_UNSPECIFIED => Family::Unspecified,
134            FAMILY_INET => Family::Inet,
135            FAMILY_INET6 => Family::Inet6,
136            FAMILY_UNIX => Family::Unix,
137            f => {
138                #[cfg(feature = "feat-nightly")]
139                core::hint::cold_path();
140
141                return Err(DecodeError::InvalidFamily(f));
142            }
143        };
144
145        // 3.2. protocol
146        let protocol = match buf[13] & MASK_LO {
147            PROTOCOL_UNSPECIFIED => Protocol::Unspecified,
148            PROTOCOL_STREAM => Protocol::Stream,
149            PROTOCOL_DGRAM => Protocol::Dgram,
150            p => {
151                #[cfg(feature = "feat-nightly")]
152                core::hint::cold_path();
153
154                return Err(DecodeError::InvalidProtocol(p));
155            }
156        };
157
158        // 4. remaining_len
159        let remaining_len = u16::from_be_bytes([buf[14], buf[15]]);
160
161        // Check if the buffer has enough data for the the payload
162        let payload = match HEADER_SIZE
163            .checked_add(remaining_len as usize)
164            .ok_or(DecodeError::MalformedData)?
165            .checked_sub(buf.len())
166            .map(NonZeroUsize::new)
167        {
168            Some(None) => &buf[HEADER_SIZE..],
169            Some(Some(remaining_bytes)) => return Ok(Decoded::Partial(remaining_bytes)),
170            None => {
171                #[cfg(feature = "feat-nightly")]
172                core::hint::cold_path();
173
174                // HEADER_SIZE + remaining_len < buf.len(), reject trailing data
175                return Err(DecodeError::TrailingData);
176            }
177        };
178
179        let (address_pair, extensions) = match addr_family {
180            Family::Unspecified => (AddressPair::Unspecified, payload),
181            Family::Inet => {
182                if payload.len() < ADDR_INET_SIZE {
183                    #[cfg(feature = "feat-nightly")]
184                    core::hint::cold_path();
185
186                    return Err(DecodeError::MalformedData);
187                }
188
189                (
190                    AddressPair::Inet {
191                        src_ip: Ipv4Addr::from(TryInto::<[u8; 4]>::try_into(&payload[0..4]).unwrap()),
192                        dst_ip: Ipv4Addr::from(TryInto::<[u8; 4]>::try_into(&payload[4..8]).unwrap()),
193                        src_port: u16::from_be_bytes([payload[8], payload[9]]),
194                        dst_port: u16::from_be_bytes([payload[10], payload[11]]),
195                    },
196                    &payload[ADDR_INET_SIZE..],
197                )
198            }
199            Family::Inet6 => {
200                if payload.len() < ADDR_INET6_SIZE {
201                    #[cfg(feature = "feat-nightly")]
202                    core::hint::cold_path();
203
204                    return Err(DecodeError::MalformedData);
205                }
206
207                (
208                    AddressPair::Inet6 {
209                        src_ip: Ipv6Addr::from(TryInto::<[u8; 16]>::try_into(&payload[0..16]).unwrap()),
210                        dst_ip: Ipv6Addr::from(TryInto::<[u8; 16]>::try_into(&payload[16..32]).unwrap()),
211                        src_port: u16::from_be_bytes([payload[32], payload[33]]),
212                        dst_port: u16::from_be_bytes([payload[34], payload[35]]),
213                    },
214                    &payload[ADDR_INET6_SIZE..],
215                )
216            }
217            Family::Unix => {
218                if payload.len() < ADDR_UNIX_SIZE {
219                    #[cfg(feature = "feat-nightly")]
220                    core::hint::cold_path();
221
222                    return Err(DecodeError::MalformedData);
223                }
224
225                (
226                    AddressPair::Unix {
227                        src_addr: payload[0..108].try_into().unwrap(),
228                        dst_addr: payload[108..216].try_into().unwrap(),
229                    },
230                    &payload[ADDR_UNIX_SIZE..],
231                )
232            }
233        };
234
235        match command {
236            Command::Local => Ok(Decoded::Some(DecodedHeader {
237                header: Header::new_local(),
238                extensions: DecodedExtensions::const_from(extensions),
239            })),
240            Command::Proxy => Ok(Decoded::Some(DecodedHeader {
241                header: Header::new_proxy(protocol, address_pair),
242                extensions: DecodedExtensions::const_from(extensions),
243            })),
244        }
245    }
246}
247
248#[allow(clippy::large_enum_variant)]
249#[derive(Debug)]
250/// The result of decoding a PROXY Protocol v2 header.
251pub enum Decoded<'a> {
252    /// The PROXY Protocol v2 header and its extensions decoded.
253    Some(DecodedHeader<'a>),
254
255    /// Partial data, the caller should read more data.
256    Partial(NonZeroUsize),
257
258    /// Not a PROXY Protocol v2 header.
259    None,
260}
261
262#[derive(Debug)]
263/// A wrapper around the PROXY Protocol v2 header and its extensions.
264pub struct DecodedHeader<'a> {
265    /// The PROXY Protocol v2 header.
266    pub header: Header,
267
268    /// Extensions of the PROXY Protocol v2 header.
269    pub extensions: DecodedExtensions<'a>,
270}
271
wrapper_lite::wrapper! {
    #[wrapper_impl(Deref)]
    #[derive(Debug)]
    /// A wrapper around a slice of bytes holding the still-encoded extensions
    /// of a PROXY Protocol v2 header.
    ///
    /// The slice is borrowed from the buffer passed to
    /// [`HeaderDecoder::decode`], which does not validate it; extensions are
    /// only decoded lazily while iterating.
    ///
    /// This implements `IntoIterator` to iterate over the extensions. See
    /// [`DecodedExtensionsIter`] for more details.
    pub DecodedExtensions<'a>(&'a [u8])
}
282
283impl<'a> DecodedExtensions<'a> {
284    #[cfg(feature = "feat-alloc")]
285    /// Iterates over the extensions of the PROXY Protocol v2 header and
286    /// collects them into a `Vec<ExtensionRef>`.
287    pub fn collect(self) -> Result<Vec<ExtensionRef<'a>>, DecodeError> {
288        self.into_iter().collect()
289    }
290}
291
292impl<'a> IntoIterator for DecodedExtensions<'a> {
293    type IntoIter = DecodedExtensionsIter<'a>;
294    type Item = Result<ExtensionRef<'a>, DecodeError>;
295
296    fn into_iter(self) -> Self::IntoIter {
297        DecodedExtensionsIter {
298            inner: Some(Reader::init(self.inner)),
299        }
300    }
301}
302
303#[derive(Debug)]
304/// Iterator over the extensions of the PROXY Protocol v2 header.
305///
306/// This iterator yields [`ExtensionRef`]s, which are references to the
307/// decoded extensions. If an error occurs while decoding an extension, the
308/// iterator will yield an `Err(DecodeError)` instead.
309///
310/// The iterator is fused, meaning that once it returns `None`, it will
311/// continue to return `None` on subsequent calls.
312pub struct DecodedExtensionsIter<'a> {
313    inner: Option<Reader<'a>>,
314}
315
316impl<'a> Iterator for DecodedExtensionsIter<'a> {
317    type Item = Result<ExtensionRef<'a>, DecodeError>;
318
319    fn next(&mut self) -> Option<Self::Item> {
320        match self.inner.as_mut() {
321            Some(reader) => match ExtensionRef::decode(reader) {
322                Ok(Some(extension)) => Some(Ok(extension)),
323                Ok(None) => {
324                    // No more extensions, stop iterating.
325                    self.inner = None;
326
327                    None
328                }
329                Err(err) => {
330                    // An error occurred while decoding an extension, return the error.
331                    self.inner = None;
332
333                    Some(Err(err))
334                }
335            },
336            None => None,
337        }
338    }
339}
340
341impl<'a> FusedIterator for DecodedExtensionsIter<'a> {}
342
343#[derive(Debug)]
344#[derive(thiserror::Error)]
345/// Errors that can occur while decoding a PROXY Protocol v2 header.
346pub enum DecodeError {
347    #[error("Invalid PROXY Protocol version: {0}")]
348    /// Invalid PROXY Protocol version
349    InvalidVersion(u8),
350
351    #[error("Invalid PROXY Protocol command: {0}")]
352    /// Invalid PROXY Protocol command
353    InvalidCommand(u8),
354
355    #[error("Invalid proxy address family: {0}")]
356    /// Invalid proxy address family
357    InvalidFamily(u8),
358
359    #[error("Invalid proxy transport protocol: {0}")]
360    /// Invalid proxy transport protocol
361    InvalidProtocol(u8),
362
363    #[error("Trailing data after the header")]
364    /// The buffer contains trailing data after the PROXY Protocol v2 header.
365    TrailingData,
366
367    #[error("Malformed data")]
368    /// The data is malformed, e.g. the length of an extension does not match
369    /// the actual data length.
370    MalformedData,
371}