proxy_protocol_codec/v2/codec/
encode.rs

1//! PROXY Protocol v2 header encoder
2
3pub mod stage;
4
5use alloc::vec::Vec;
6use core::marker::PhantomData;
7
8use crate::v2::codec::encode::stage::{Addr, FamProto, Finished, Len, Magic, VerCmd};
9use crate::v2::model::{
10    AddressPair, Command, ExtensionRef, ExtensionType, Family, Protocol, ADDR_INET6_SIZE, ADDR_INET_SIZE,
11    ADDR_UNIX_SIZE, BYTE_VERSION, HEADER_SIZE,
12};
13use crate::v2::Header;
14
15#[derive(Debug)]
16/// Encoder for a PROXY Protocol v2 header.
17pub struct HeaderEncoder<Stage = Magic> {
18    inner: Vec<u8>,
19
20    /// Marker to indicate the encoding / encoding stage.
21    _stage: PhantomData<Stage>,
22}
23
24/// The encoded header type.
25pub type Encoded = HeaderEncoder<Finished>;
26
27impl HeaderEncoder<Magic> {
28    /// Encodes a PROXY Protocol v2 header from the given `Header`.
29    ///
30    /// This returns a [`Encoded`] which can be finalized to produce the encoded
31    /// header bytes. To write some extensions to the header, use the
32    /// methods like [`Encoded::write_ext_alpn`], etc.
33    pub fn encode(header: &Header) -> Encoded {
34        let this = Self {
35            inner: Vec::with_capacity(HEADER_SIZE),
36            _stage: PhantomData,
37        };
38
39        match header.command() {
40            Command::Local => this
41                .write_magic()
42                .write_ver_cmd(Command::Local)
43                .write_fam_proto(Family::Unspecified, Protocol::Unspecified)
44                .write_len(0)
45                .write_addr(&AddressPair::Unspecified),
46            Command::Proxy => this
47                .write_magic()
48                .write_ver_cmd(Command::Proxy)
49                .write_fam_proto(header.address_pair().address_family(), *header.protocol())
50                .write_len(0)
51                .write_addr(header.address_pair()),
52        }
53    }
54
55    #[inline]
56    fn write_magic(mut self) -> HeaderEncoder<VerCmd> {
57        self.inner.extend(Header::MAGIC);
58
59        HeaderEncoder {
60            inner: self.inner,
61            _stage: PhantomData,
62        }
63    }
64}
65
66impl HeaderEncoder<VerCmd> {
67    #[inline]
68    fn write_ver_cmd(mut self, command: Command) -> HeaderEncoder<FamProto> {
69        self.inner.push(BYTE_VERSION | command as u8);
70
71        HeaderEncoder {
72            inner: self.inner,
73            _stage: PhantomData,
74        }
75    }
76}
77
78impl HeaderEncoder<FamProto> {
79    #[inline]
80    fn write_fam_proto(mut self, family: Family, protocol: Protocol) -> HeaderEncoder<Len> {
81        self.inner.push(family as u8 | protocol as u8);
82
83        HeaderEncoder {
84            inner: self.inner,
85            _stage: PhantomData,
86        }
87    }
88}
89
90impl HeaderEncoder<Len> {
91    #[inline]
92    fn write_len(mut self, len: u16) -> HeaderEncoder<Addr> {
93        self.inner.extend(len.to_be_bytes());
94
95        HeaderEncoder {
96            inner: self.inner,
97            _stage: PhantomData,
98        }
99    }
100}
101
102impl HeaderEncoder<Addr> {
103    #[inline]
104    fn write_addr(mut self, address_pair: &AddressPair) -> HeaderEncoder<Finished> {
105        match address_pair {
106            AddressPair::Unspecified => HeaderEncoder {
107                inner: self.inner,
108                _stage: PhantomData,
109            },
110            AddressPair::Inet {
111                src_ip,
112                dst_ip,
113                src_port,
114                dst_port,
115            } => {
116                self.inner.reserve(ADDR_INET_SIZE);
117                self.inner.extend(src_ip.octets());
118                self.inner.extend(dst_ip.octets());
119                self.inner.extend(src_port.to_be_bytes());
120                self.inner.extend(dst_port.to_be_bytes());
121
122                HeaderEncoder {
123                    inner: self.inner,
124                    _stage: PhantomData,
125                }
126            }
127            AddressPair::Inet6 {
128                src_ip,
129                dst_ip,
130                src_port,
131                dst_port,
132            } => {
133                self.inner.reserve(ADDR_INET6_SIZE);
134                self.inner.extend(src_ip.octets());
135                self.inner.extend(dst_ip.octets());
136                self.inner.extend(src_port.to_be_bytes());
137                self.inner.extend(dst_port.to_be_bytes());
138
139                HeaderEncoder {
140                    inner: self.inner,
141                    _stage: PhantomData,
142                }
143            }
144            AddressPair::Unix { src_addr, dst_addr } => {
145                self.inner.reserve(ADDR_UNIX_SIZE);
146                self.inner.extend(src_addr);
147                self.inner.extend(dst_addr);
148
149                HeaderEncoder {
150                    inner: self.inner,
151                    _stage: PhantomData,
152                }
153            }
154        }
155    }
156}
157
158impl HeaderEncoder<Finished> {
159    #[inline]
160    /// Writes the `ALPN` extension bytes to the header.
161    ///
162    /// See [`ExtensionType::ALPN`].
163    ///
164    /// # Errors
165    ///
166    /// See [`EncodeError`].
167    pub fn write_ext_alpn(self, alpn: &[u8]) -> Result<Self, EncodeError> {
168        Ok(self.write_ext_custom(ExtensionRef::new(ExtensionType::ALPN, alpn).ok_or(EncodeError::ExtensionTooLong)?))
169    }
170
171    #[inline]
172    /// Writes the `Authority` extension bytes to the header.
173    ///
174    /// See [`ExtensionType::Authority`].
175    ///
176    /// # Errors
177    ///
178    /// See [`EncodeError`].
179    pub fn write_ext_authority(self, authority: &[u8]) -> Result<Self, EncodeError> {
180        Ok(self.write_ext_custom(
181            ExtensionRef::new(ExtensionType::Authority, authority).ok_or(EncodeError::ExtensionTooLong)?,
182        ))
183    }
184
185    #[inline]
186    /// Writes padding zeros to the header, the total size is `3 + padding`.
187    ///
188    /// See [`ExtensionType::NoOp`].
189    ///
190    /// # Errors
191    ///
192    /// See [`EncodeError`].
193    pub fn write_ext_no_op(mut self, padding: u16) -> Result<Self, EncodeError> {
194        self.inner.push(ExtensionType::NoOp as u8);
195        self.inner.extend(padding.to_be_bytes());
196        self.inner.resize(self.inner.len() + padding as usize, 0);
197        Ok(self)
198    }
199
200    #[inline]
201    #[allow(clippy::missing_panics_doc, reason = "XXX")]
202    /// Writes the `UniqueId` extension bytes to the header.
203    ///
204    /// See [`ExtensionType::UniqueId`].
205    ///
206    /// # Errors
207    ///
208    /// See [`EncodeError`].
209    pub fn write_ext_unique_id(self, payload: &[u8]) -> Result<Self, EncodeError> {
210        if payload.len() > 128 {
211            return Err(EncodeError::ExtensionTooLong);
212        }
213
214        // Safety: payload.len() <= 128 < u16::MAX
215        Ok(self.write_ext_custom(ExtensionRef::new(ExtensionType::NetworkNamespace, payload).unwrap()))
216    }
217
218    #[inline]
219    /// Writes the `NetworkNamespace` extension bytes to the header.
220    ///
221    /// See [`ExtensionType::NetworkNamespace`].
222    ///
223    /// # Errors
224    ///
225    /// See [`EncodeError`].
226    pub fn write_ext_network_namespace(self, payload: &[u8]) -> Result<Self, EncodeError> {
227        Ok(self.write_ext_custom(
228            ExtensionRef::new(ExtensionType::NetworkNamespace, payload).ok_or(EncodeError::ExtensionTooLong)?,
229        ))
230    }
231
232    #[inline]
233    /// Writes a custom extension to the header.
234    ///
235    /// Notice: will not check if the `typ` is valid.
236    pub fn write_ext_custom(mut self, extension: ExtensionRef<'_>) -> Self {
237        extension.encode(&mut self.inner);
238        self
239    }
240
241    #[inline]
242    fn update_length(&mut self, additional: u16) -> Result<(), EncodeError> {
243        let Ok(length) = u16::try_from(self.inner.len() - HEADER_SIZE) else {
244            return Err(EncodeError::HeaderTooLong);
245        };
246
247        self.inner[14..16].copy_from_slice(&(length + additional).to_be_bytes());
248
249        Ok(())
250    }
251
252    #[cfg(feature = "feat-codec-v2-crc32c")]
253    #[allow(clippy::missing_panics_doc, reason = "XXX")]
254    /// Calculates and writes the `CRC32C` extension bytes to the header and
255    /// finalizes the header encoding.
256    ///
257    /// # Errors
258    ///
259    /// See [`EncodeError`].
260    pub fn finish_with_crc32c(mut self) -> Result<Vec<u8>, EncodeError> {
261        const FIXED_CRC32C_EXTENSION: [u8; 7] = [
262            ExtensionType::CRC32C as u8,
263            0,
264            4, // Length of the CRC32C value
265            0,
266            0,
267            0,
268            0, // Placeholder for the CRC32C value
269        ];
270
271        #[allow(clippy::cast_possible_truncation, reason = "XXX")]
272        self.update_length(FIXED_CRC32C_EXTENSION.len() as u16)?;
273
274        let crc32c_bytes =
275            crc32c::crc32c_append(crc32c::crc32c_append(0, &self.inner), &FIXED_CRC32C_EXTENSION).to_be_bytes();
276
277        // Safety: FIXED_CRC32C_EXTENSION.len() == 7 < u16::MAX
278        self.write_ext_custom(ExtensionRef::new(ExtensionType::CRC32C, &crc32c_bytes).unwrap())
279            .finish()
280    }
281
282    #[inline]
283    /// Finalizes the header encoding.
284    ///
285    /// # Errors
286    ///
287    /// See [`EncodeError`].
288    pub fn finish(mut self) -> Result<Vec<u8>, EncodeError> {
289        self.update_length(0)?;
290
291        Ok(self.inner)
292    }
293}
294
/// Errors that can occur while encoding a PROXY Protocol v2 header.
#[cfg(feature = "feat-codec-encode")]
#[allow(
    clippy::module_name_repetitions,
    reason = "`EncodeError` stays unambiguous when used outside this module"
)]
#[derive(Debug, thiserror::Error)]
pub enum EncodeError {
    /// The source and destination address families do not match.
    #[error("The src / dst address families do not match.")]
    FamilyMismatch,

    /// The address is not a valid Unix address (e.g., length out-of-bounds).
    #[error("The address is not a valid Unix address")]
    InvalidUnixAddress,

    /// The encoded header is too long for its 16-bit length field.
    #[error("Header bytes too long")]
    HeaderTooLong,

    /// An extension payload is too long to be encoded.
    #[error("The extension payload is too long.")]
    ExtensionTooLong,
}