proxy_protocol_codec/v2/codec/
encode.rs

1//! PROXY Protocol v2 header encoder
2
3pub mod stage;
4
5use alloc::vec::Vec;
6use core::marker::PhantomData;
7
8use crate::v2::codec::encode::stage::{Addr, FamProto, Finished, Len, Magic, VerCmd};
9use crate::v2::model::{
10    AddressPair, Command, ExtensionRef, ExtensionType, Family, Protocol, ADDR_INET6_SIZE, ADDR_INET_SIZE,
11    ADDR_UNIX_SIZE, BYTE_VERSION, HEADER_SIZE,
12};
13use crate::v2::Header;
14
15#[derive(Debug)]
16/// Encoder for a PROXY Protocol v2 header.
17pub struct HeaderEncoder<Stage = Magic> {
18    inner: Vec<u8>,
19
20    /// Marker to indicate the encoding / encoding stage.
21    _stage: PhantomData<Stage>,
22}
23
24impl HeaderEncoder<Magic> {
25    /// Encodes a PROXY Protocol v2 header from the given `Header`.
26    pub fn encode(header: &Header) -> HeaderEncoder<Finished> {
27        let this = Self {
28            inner: Vec::with_capacity(HEADER_SIZE),
29            _stage: PhantomData,
30        };
31
32        match header.command() {
33            Command::Local => this
34                .write_magic()
35                .write_ver_cmd(Command::Local)
36                .write_fam_proto(Family::Unspecified, Protocol::Unspecified)
37                .write_len(0)
38                .write_addr(&AddressPair::Unspecified),
39            Command::Proxy => this
40                .write_magic()
41                .write_ver_cmd(Command::Proxy)
42                .write_fam_proto(header.address_pair().address_family(), *header.protocol())
43                .write_len(0)
44                .write_addr(header.address_pair()),
45        }
46    }
47
48    #[inline(always)]
49    fn write_magic(mut self) -> HeaderEncoder<VerCmd> {
50        self.inner.extend(Header::MAGIC);
51
52        HeaderEncoder {
53            inner: self.inner,
54            _stage: PhantomData,
55        }
56    }
57}
58
59impl HeaderEncoder<VerCmd> {
60    #[inline(always)]
61    fn write_ver_cmd(mut self, command: Command) -> HeaderEncoder<FamProto> {
62        self.inner.push(BYTE_VERSION | command as u8);
63
64        HeaderEncoder {
65            inner: self.inner,
66            _stage: PhantomData,
67        }
68    }
69}
70
71impl HeaderEncoder<FamProto> {
72    #[inline(always)]
73    fn write_fam_proto(mut self, family: Family, protocol: Protocol) -> HeaderEncoder<Len> {
74        self.inner.push(family as u8 | protocol as u8);
75
76        HeaderEncoder {
77            inner: self.inner,
78            _stage: PhantomData,
79        }
80    }
81}
82
83impl HeaderEncoder<Len> {
84    #[inline(always)]
85    fn write_len(mut self, len: u16) -> HeaderEncoder<Addr> {
86        self.inner.extend(len.to_be_bytes());
87
88        HeaderEncoder {
89            inner: self.inner,
90            _stage: PhantomData,
91        }
92    }
93}
94
95impl HeaderEncoder<Addr> {
96    #[inline(always)]
97    fn write_addr(mut self, address_pair: &AddressPair) -> HeaderEncoder<Finished> {
98        match address_pair {
99            AddressPair::Unspecified => HeaderEncoder {
100                inner: self.inner,
101                _stage: PhantomData,
102            },
103            AddressPair::Inet {
104                src_ip,
105                dst_ip,
106                src_port,
107                dst_port,
108            } => {
109                self.inner.reserve(ADDR_INET_SIZE);
110                self.inner.extend(src_ip.octets());
111                self.inner.extend(dst_ip.octets());
112                self.inner.extend(src_port.to_be_bytes());
113                self.inner.extend(dst_port.to_be_bytes());
114
115                HeaderEncoder {
116                    inner: self.inner,
117                    _stage: PhantomData,
118                }
119            }
120            AddressPair::Inet6 {
121                src_ip,
122                dst_ip,
123                src_port,
124                dst_port,
125            } => {
126                self.inner.reserve(ADDR_INET6_SIZE);
127                self.inner.extend(src_ip.octets());
128                self.inner.extend(dst_ip.octets());
129                self.inner.extend(src_port.to_be_bytes());
130                self.inner.extend(dst_port.to_be_bytes());
131
132                HeaderEncoder {
133                    inner: self.inner,
134                    _stage: PhantomData,
135                }
136            }
137            AddressPair::Unix { src_addr, dst_addr } => {
138                self.inner.reserve(ADDR_UNIX_SIZE);
139                self.inner.extend(src_addr);
140                self.inner.extend(dst_addr);
141
142                HeaderEncoder {
143                    inner: self.inner,
144                    _stage: PhantomData,
145                }
146            }
147        }
148    }
149}
150
151impl HeaderEncoder<Finished> {
152    #[inline]
153    /// Writes the `ALPN` extension bytes to the header.
154    ///
155    /// See [`ExtensionType::ALPN`].
156    pub fn write_ext_alpn(self, alpn: &[u8]) -> Result<Self, EncodeError> {
157        Ok(self.write_ext_custom(ExtensionRef::new(ExtensionType::ALPN, alpn).ok_or(EncodeError::ExtensionTooLong)?))
158    }
159
160    #[inline]
161    /// Writes the `Authority` extension bytes to the header.
162    ///
163    /// See [`ExtensionType::Authority`].
164    pub fn write_ext_authority(self, authority: &[u8]) -> Result<Self, EncodeError> {
165        Ok(self.write_ext_custom(
166            ExtensionRef::new(ExtensionType::Authority, authority).ok_or(EncodeError::ExtensionTooLong)?,
167        ))
168    }
169
170    #[inline]
171    /// Writes padding zeros to the header, the total size is `3 + padding`.
172    ///
173    /// See [`ExtensionType::NoOp`].
174    pub fn write_ext_no_op(mut self, padding: u16) -> Result<Self, EncodeError> {
175        self.inner.push(ExtensionType::NoOp as u8);
176        self.inner.extend(padding.to_be_bytes());
177        self.inner.resize(self.inner.len() + padding as usize, 0);
178        Ok(self)
179    }
180
181    #[inline]
182    #[allow(clippy::missing_panics_doc)]
183    /// Writes the `UniqueId` extension bytes to the header.
184    ///
185    /// See [`ExtensionType::UniqueId`].
186    pub fn write_ext_unique_id(self, payload: &[u8]) -> Result<Self, EncodeError> {
187        if payload.len() > 128 {
188            return Err(EncodeError::ExtensionTooLong);
189        }
190
191        // Safety: payload.len() <= 128 < u16::MAX
192        Ok(self.write_ext_custom(ExtensionRef::new(ExtensionType::NetworkNamespace, payload).unwrap()))
193    }
194
195    #[inline]
196    /// Writes the `NetworkNamespace` extension bytes to the header.
197    ///
198    /// See [`ExtensionType::NetworkNamespace`].
199    pub fn write_ext_network_namespace(self, payload: &[u8]) -> Result<Self, EncodeError> {
200        Ok(self.write_ext_custom(
201            ExtensionRef::new(ExtensionType::NetworkNamespace, payload).ok_or(EncodeError::ExtensionTooLong)?,
202        ))
203    }
204
205    #[inline]
206    /// Writes a custom extension to the header.
207    ///
208    /// Notice: will not check if the `typ` is valid.
209    pub fn write_ext_custom(mut self, extension: ExtensionRef<'_>) -> Self {
210        extension.encode(&mut self.inner);
211        self
212    }
213
214    #[inline]
215    fn update_length(&mut self, additional: u16) -> Result<(), EncodeError> {
216        let Ok(length) = u16::try_from(self.inner.len() - HEADER_SIZE) else {
217            return Err(EncodeError::HeaderTooLong);
218        };
219
220        self.inner[14..16].copy_from_slice(&(length + additional).to_be_bytes());
221
222        Ok(())
223    }
224
225    #[cfg(feature = "feat-codec-v2-crc32c")]
226    #[allow(clippy::missing_panics_doc)]
227    /// Calculates and writes the `CRC32C` extension bytes to the header and
228    /// finalizes the header encoding.
229    pub fn finish_with_crc32c(mut self) -> Result<Vec<u8>, EncodeError> {
230        const FIXED_CRC32C_EXTENSION: [u8; 6] = [
231            ExtensionType::CRC32C as u8,
232            (u32::BITS / 8) as u8, // Length of the CRC32C value
233            0,
234            0,
235            0,
236            0, // Placeholder for the CRC32C value
237        ];
238
239        self.update_length(FIXED_CRC32C_EXTENSION.len() as u16)?;
240
241        let crc32c_bytes =
242            crc32c::crc32c_append(crc32c::crc32c_append(0, &self.inner), &FIXED_CRC32C_EXTENSION).to_be_bytes();
243
244        // Safety: FIXED_CRC32C_EXTENSION.len() == 6 < u16::MAX
245        self.write_ext_custom(ExtensionRef::new(ExtensionType::CRC32C, &crc32c_bytes).unwrap())
246            .finish()
247    }
248
249    #[inline(always)]
250    /// Finalizes the header encoding.
251    pub fn finish(mut self) -> Result<Vec<u8>, EncodeError> {
252        self.update_length(0)?;
253
254        Ok(self.inner)
255    }
256}
257
#[cfg(feature = "feat-codec-encode")]
/// Errors that can occur while encoding a PROXY Protocol v2 header.
#[derive(Debug, thiserror::Error)]
pub enum EncodeError {
    /// The src / dst address families do not match.
    #[error("The src / dst address families do not match.")]
    FamilyMismatch,

    /// The address is not a valid Unix address (e.g., length out-of-bounds).
    #[error("The address is not a valid Unix address")]
    InvalidUnixAddress,

    /// The encoded header bytes do not fit the 16-bit length field.
    #[error("Header bytes too long")]
    HeaderTooLong,

    /// An extension payload exceeds the length allowed for it.
    #[error("The extension payload is too long.")]
    ExtensionTooLong,
}