proxy_protocol_codec/v2/codec/
encode.rs

1//! PROXY Protocol v2 header encoder
2
3pub mod stage;
4
5use alloc::vec::Vec;
6use core::marker::PhantomData;
7
8use crate::v2::codec::encode::stage::{Addr, FamProto, Finished, Len, Magic, VerCmd};
9use crate::v2::model::{
10    AddressPair, Command, ExtensionRef, ExtensionType, Family, Protocol, ADDR_INET6_SIZE, ADDR_INET_SIZE,
11    ADDR_UNIX_SIZE, BYTE_VERSION, HEADER_SIZE,
12};
13use crate::v2::Header;
14
15#[derive(Debug)]
16/// Encoder for a PROXY Protocol v2 header.
17pub struct HeaderEncoder<Stage = Magic> {
18    inner: Vec<u8>,
19
20    /// Marker to indicate the encoding / encoding stage.
21    _stage: PhantomData<Stage>,
22}
23
24/// The encoded header type.
25pub type Encoded = HeaderEncoder<Finished>;
26
27impl HeaderEncoder<Magic> {
28    /// Encodes a PROXY Protocol v2 header from the given `Header`.
29    ///
30    /// This returns a [`Encoded`] which can be finalized to produce the encoded
31    /// header bytes. To write some extensions to the header, use the
32    /// methods like [`Encoded::write_ext_alpn`], etc.
33    pub fn encode(header: &Header) -> Encoded {
34        let this = Self {
35            inner: Vec::with_capacity(HEADER_SIZE),
36            _stage: PhantomData,
37        };
38
39        match header.command() {
40            Command::Local => this
41                .write_magic()
42                .write_ver_cmd(Command::Local)
43                .write_fam_proto(Family::Unspecified, Protocol::Unspecified)
44                .write_len(0)
45                .write_addr(&AddressPair::Unspecified),
46            Command::Proxy => this
47                .write_magic()
48                .write_ver_cmd(Command::Proxy)
49                .write_fam_proto(header.address_pair().address_family(), *header.protocol())
50                .write_len(0)
51                .write_addr(header.address_pair()),
52        }
53    }
54
55    #[inline(always)]
56    fn write_magic(mut self) -> HeaderEncoder<VerCmd> {
57        self.inner.extend(Header::MAGIC);
58
59        HeaderEncoder {
60            inner: self.inner,
61            _stage: PhantomData,
62        }
63    }
64}
65
66impl HeaderEncoder<VerCmd> {
67    #[inline(always)]
68    fn write_ver_cmd(mut self, command: Command) -> HeaderEncoder<FamProto> {
69        self.inner.push(BYTE_VERSION | command as u8);
70
71        HeaderEncoder {
72            inner: self.inner,
73            _stage: PhantomData,
74        }
75    }
76}
77
78impl HeaderEncoder<FamProto> {
79    #[inline(always)]
80    fn write_fam_proto(mut self, family: Family, protocol: Protocol) -> HeaderEncoder<Len> {
81        self.inner.push(family as u8 | protocol as u8);
82
83        HeaderEncoder {
84            inner: self.inner,
85            _stage: PhantomData,
86        }
87    }
88}
89
90impl HeaderEncoder<Len> {
91    #[inline(always)]
92    fn write_len(mut self, len: u16) -> HeaderEncoder<Addr> {
93        self.inner.extend(len.to_be_bytes());
94
95        HeaderEncoder {
96            inner: self.inner,
97            _stage: PhantomData,
98        }
99    }
100}
101
102impl HeaderEncoder<Addr> {
103    #[inline(always)]
104    fn write_addr(mut self, address_pair: &AddressPair) -> HeaderEncoder<Finished> {
105        match address_pair {
106            AddressPair::Unspecified => HeaderEncoder {
107                inner: self.inner,
108                _stage: PhantomData,
109            },
110            AddressPair::Inet {
111                src_ip,
112                dst_ip,
113                src_port,
114                dst_port,
115            } => {
116                self.inner.reserve(ADDR_INET_SIZE);
117                self.inner.extend(src_ip.octets());
118                self.inner.extend(dst_ip.octets());
119                self.inner.extend(src_port.to_be_bytes());
120                self.inner.extend(dst_port.to_be_bytes());
121
122                HeaderEncoder {
123                    inner: self.inner,
124                    _stage: PhantomData,
125                }
126            }
127            AddressPair::Inet6 {
128                src_ip,
129                dst_ip,
130                src_port,
131                dst_port,
132            } => {
133                self.inner.reserve(ADDR_INET6_SIZE);
134                self.inner.extend(src_ip.octets());
135                self.inner.extend(dst_ip.octets());
136                self.inner.extend(src_port.to_be_bytes());
137                self.inner.extend(dst_port.to_be_bytes());
138
139                HeaderEncoder {
140                    inner: self.inner,
141                    _stage: PhantomData,
142                }
143            }
144            AddressPair::Unix { src_addr, dst_addr } => {
145                self.inner.reserve(ADDR_UNIX_SIZE);
146                self.inner.extend(src_addr);
147                self.inner.extend(dst_addr);
148
149                HeaderEncoder {
150                    inner: self.inner,
151                    _stage: PhantomData,
152                }
153            }
154        }
155    }
156}
157
158impl HeaderEncoder<Finished> {
159    #[inline]
160    /// Writes the `ALPN` extension bytes to the header.
161    ///
162    /// See [`ExtensionType::ALPN`].
163    pub fn write_ext_alpn(self, alpn: &[u8]) -> Result<Self, EncodeError> {
164        Ok(self.write_ext_custom(ExtensionRef::new(ExtensionType::ALPN, alpn).ok_or(EncodeError::ExtensionTooLong)?))
165    }
166
167    #[inline]
168    /// Writes the `Authority` extension bytes to the header.
169    ///
170    /// See [`ExtensionType::Authority`].
171    pub fn write_ext_authority(self, authority: &[u8]) -> Result<Self, EncodeError> {
172        Ok(self.write_ext_custom(
173            ExtensionRef::new(ExtensionType::Authority, authority).ok_or(EncodeError::ExtensionTooLong)?,
174        ))
175    }
176
177    #[inline]
178    /// Writes padding zeros to the header, the total size is `3 + padding`.
179    ///
180    /// See [`ExtensionType::NoOp`].
181    pub fn write_ext_no_op(mut self, padding: u16) -> Result<Self, EncodeError> {
182        self.inner.push(ExtensionType::NoOp as u8);
183        self.inner.extend(padding.to_be_bytes());
184        self.inner.resize(self.inner.len() + padding as usize, 0);
185        Ok(self)
186    }
187
188    #[inline]
189    #[allow(clippy::missing_panics_doc)]
190    /// Writes the `UniqueId` extension bytes to the header.
191    ///
192    /// See [`ExtensionType::UniqueId`].
193    pub fn write_ext_unique_id(self, payload: &[u8]) -> Result<Self, EncodeError> {
194        if payload.len() > 128 {
195            return Err(EncodeError::ExtensionTooLong);
196        }
197
198        // Safety: payload.len() <= 128 < u16::MAX
199        Ok(self.write_ext_custom(ExtensionRef::new(ExtensionType::NetworkNamespace, payload).unwrap()))
200    }
201
202    #[inline]
203    /// Writes the `NetworkNamespace` extension bytes to the header.
204    ///
205    /// See [`ExtensionType::NetworkNamespace`].
206    pub fn write_ext_network_namespace(self, payload: &[u8]) -> Result<Self, EncodeError> {
207        Ok(self.write_ext_custom(
208            ExtensionRef::new(ExtensionType::NetworkNamespace, payload).ok_or(EncodeError::ExtensionTooLong)?,
209        ))
210    }
211
212    #[inline]
213    /// Writes a custom extension to the header.
214    ///
215    /// Notice: will not check if the `typ` is valid.
216    pub fn write_ext_custom(mut self, extension: ExtensionRef<'_>) -> Self {
217        extension.encode(&mut self.inner);
218        self
219    }
220
221    #[inline]
222    fn update_length(&mut self, additional: u16) -> Result<(), EncodeError> {
223        let Ok(length) = u16::try_from(self.inner.len() - HEADER_SIZE) else {
224            return Err(EncodeError::HeaderTooLong);
225        };
226
227        self.inner[14..16].copy_from_slice(&(length + additional).to_be_bytes());
228
229        Ok(())
230    }
231
232    #[cfg(feature = "feat-codec-v2-crc32c")]
233    #[allow(clippy::missing_panics_doc)]
234    /// Calculates and writes the `CRC32C` extension bytes to the header and
235    /// finalizes the header encoding.
236    pub fn finish_with_crc32c(mut self) -> Result<Vec<u8>, EncodeError> {
237        const FIXED_CRC32C_EXTENSION: [u8; 7] = [
238            ExtensionType::CRC32C as u8,
239            0,
240            4, // Length of the CRC32C value
241            0,
242            0,
243            0,
244            0, // Placeholder for the CRC32C value
245        ];
246
247        self.update_length(FIXED_CRC32C_EXTENSION.len() as u16)?;
248
249        let crc32c_bytes =
250            crc32c::crc32c_append(crc32c::crc32c_append(0, &self.inner), &FIXED_CRC32C_EXTENSION).to_be_bytes();
251
252        // Safety: FIXED_CRC32C_EXTENSION.len() == 7 < u16::MAX
253        self.write_ext_custom(ExtensionRef::new(ExtensionType::CRC32C, &crc32c_bytes).unwrap())
254            .finish()
255    }
256
257    #[inline(always)]
258    /// Finalizes the header encoding.
259    pub fn finish(mut self) -> Result<Vec<u8>, EncodeError> {
260        self.update_length(0)?;
261
262        Ok(self.inner)
263    }
264}
265
/// Failures that encoding a PROXY Protocol v2 header can report.
#[cfg(feature = "feat-codec-encode")]
#[derive(Debug, thiserror::Error)]
pub enum EncodeError {
    /// The source and destination addresses belong to different address
    /// families.
    #[error("The src / dst address families do not match.")]
    FamilyMismatch,

    /// The address is not a valid Unix address (e.g., its length is
    /// out of bounds).
    #[error("The address is not a valid Unix address")]
    InvalidUnixAddress,

    /// The encoded header is too long.
    #[error("Header bytes too long")]
    HeaderTooLong,

    /// The payload of an extension is too long.
    #[error("The extension payload is too long.")]
    ExtensionTooLong,
}
284    /// The extension payload is too long.
285    ExtensionTooLong,
286}