Skip to main content

proxy_protocol_rs/
builder.rs

1// Copyright (C) 2025-2026 Michael S. Klishin and Contributors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::net::SocketAddr;
16
17use tokio::io::{self, AsyncWrite, AsyncWriteExt};
18
19use crate::parse::V2_SIGNATURE;
20use crate::types::{
21    AddressFamily, Command, ProxyAddress, SslInfo, Transport, TransportProtocol, Version,
22};
23
24/// Proxy Protocol header builder
25#[must_use]
26pub struct HeaderBuilder {
27    version: Version,
28    command: Command,
29    transport: Option<Transport>,
30    source: Option<ProxyAddress>,
31    destination: Option<ProxyAddress>,
32    tlv_entries: Vec<(u8, Vec<u8>)>,
33    add_crc32c: bool,
34}
35
36impl HeaderBuilder {
37    /// Create a v2 PROXY header with source and destination addresses
38    ///
39    /// # Panics
40    ///
41    /// Panics if source and destination have different address families
42    /// (e.g. one IPv4 and one IPv6)
43    pub fn v2_proxy(source: SocketAddr, destination: SocketAddr) -> Self {
44        assert_eq!(
45            source.is_ipv4(),
46            destination.is_ipv4(),
47            "source and destination must use the same address family"
48        );
49        let family = if source.is_ipv4() {
50            AddressFamily::Inet
51        } else {
52            AddressFamily::Inet6
53        };
54        Self {
55            version: Version::V2,
56            command: Command::Proxy,
57            transport: Some(Transport {
58                family,
59                protocol: TransportProtocol::Stream,
60            }),
61            source: Some(ProxyAddress::Inet(source)),
62            destination: Some(ProxyAddress::Inet(destination)),
63            tlv_entries: Vec::new(),
64            add_crc32c: false,
65        }
66    }
67
68    /// Create a v2 LOCAL header (health-check / proxy-to-self)
69    pub fn v2_local() -> Self {
70        Self {
71            version: Version::V2,
72            command: Command::Local,
73            transport: None,
74            source: None,
75            destination: None,
76            tlv_entries: Vec::new(),
77            add_crc32c: false,
78        }
79    }
80
81    /// Create a v1 PROXY header with source and destination addresses
82    ///
83    /// # Panics
84    ///
85    /// Panics if source and destination have different address families
86    pub fn v1_proxy(source: SocketAddr, destination: SocketAddr) -> Self {
87        assert_eq!(
88            source.is_ipv4(),
89            destination.is_ipv4(),
90            "source and destination must use the same address family"
91        );
92        let family = if source.is_ipv4() {
93            AddressFamily::Inet
94        } else {
95            AddressFamily::Inet6
96        };
97        Self {
98            version: Version::V1,
99            command: Command::Proxy,
100            transport: Some(Transport {
101                family,
102                protocol: TransportProtocol::Stream,
103            }),
104            source: Some(ProxyAddress::Inet(source)),
105            destination: Some(ProxyAddress::Inet(destination)),
106            tlv_entries: Vec::new(),
107            add_crc32c: false,
108        }
109    }
110
111    /// Create a v1 UNKNOWN header (no addresses)
112    pub fn v1_unknown() -> Self {
113        Self {
114            version: Version::V1,
115            command: Command::Proxy,
116            transport: None,
117            source: None,
118            destination: None,
119            tlv_entries: Vec::new(),
120            add_crc32c: false,
121        }
122    }
123
124    /// Create a v2 PROXY header for Unix domain sockets
125    pub fn v2_unix(
126        source: impl Into<Vec<u8>>,
127        destination: impl Into<Vec<u8>>,
128        protocol: TransportProtocol,
129    ) -> Self {
130        Self {
131            version: Version::V2,
132            command: Command::Proxy,
133            transport: Some(Transport {
134                family: AddressFamily::Unix,
135                protocol,
136            }),
137            source: Some(ProxyAddress::Unix(source.into())),
138            destination: Some(ProxyAddress::Unix(destination.into())),
139            tlv_entries: Vec::new(),
140            add_crc32c: false,
141        }
142    }
143
144    /// Override the transport protocol (default is `Stream` for inet headers)
145    pub fn with_transport_protocol(mut self, protocol: TransportProtocol) -> Self {
146        if let Some(ref mut t) = self.transport {
147            t.protocol = protocol;
148        }
149        self
150    }
151
152    /// Add an authority TLV (0x02)
153    pub fn with_authority(mut self, authority: impl Into<String>) -> Self {
154        let v = authority.into().into_bytes();
155        self.tlv_entries.push((0x02, v));
156        self
157    }
158
159    /// Add a unique ID TLV (0x05)
160    ///
161    /// # Panics
162    ///
163    /// Panics if `id` exceeds 128 bytes (the spec maximum for PP2_TYPE_UNIQUE_ID)
164    pub fn with_unique_id(mut self, id: impl Into<Vec<u8>>) -> Self {
165        let id = id.into();
166        assert!(
167            id.len() <= 128,
168            "unique ID length {} exceeds the 128-byte spec maximum",
169            id.len()
170        );
171        self.tlv_entries.push((0x05, id));
172        self
173    }
174
175    /// Add an ALPN TLV (0x01)
176    pub fn with_alpn(mut self, alpn: impl Into<Vec<u8>>) -> Self {
177        self.tlv_entries.push((0x01, alpn.into()));
178        self
179    }
180
181    /// Add an SSL info TLV (0x20)
182    pub fn with_ssl(mut self, ssl: SslInfo) -> Self {
183        self.tlv_entries.push((0x20, encode_ssl_tlv_value(&ssl)));
184        self
185    }
186
187    /// Add a NETNS TLV (0x30)
188    pub fn with_netns(mut self, netns: impl Into<String>) -> Self {
189        self.tlv_entries.push((0x30, netns.into().into_bytes()));
190        self
191    }
192
193    /// Add an arbitrary raw TLV
194    pub fn with_raw_tlv(mut self, type_byte: u8, value: impl Into<Vec<u8>>) -> Self {
195        self.tlv_entries.push((type_byte, value.into()));
196        self
197    }
198
199    /// Add a NOOP padding TLV (0x04) with `len` zero bytes
200    pub fn with_padding(mut self, len: u16) -> Self {
201        self.tlv_entries.push((0x04, vec![0u8; len as usize]));
202        self
203    }
204
205    /// Enable CRC32c checksum TLV; the checksum is computed at build time
206    pub fn with_crc32c(mut self) -> Self {
207        self.add_crc32c = true;
208        self
209    }
210
211    /// Encode the header to bytes
212    ///
213    /// # Panics
214    ///
215    /// Panics if any single TLV value exceeds 65 535 bytes or if the total v2
216    /// payload (addresses + all TLVs) exceeds 65 535 bytes. These are hard
217    /// limits of the v2 wire format (u16 length fields).
218    #[must_use]
219    pub fn build(&self) -> Vec<u8> {
220        match self.version {
221            Version::V1 => self.build_v1(),
222            Version::V2 => self.build_v2(),
223        }
224    }
225
226    /// Write the header directly to an `AsyncWrite` sink
227    ///
228    /// # Panics
229    ///
230    /// Same as [`build()`](Self::build): panics if any TLV value or the total
231    /// v2 payload exceeds the 65 535-byte protocol limit.
232    pub async fn write_to<W: AsyncWrite + Unpin>(&self, writer: &mut W) -> io::Result<usize> {
233        let bytes = self.build();
234        writer.write_all(&bytes).await?;
235        Ok(bytes.len())
236    }
237
238    fn build_v1(&self) -> Vec<u8> {
239        match (&self.source, &self.destination, &self.transport) {
240            (Some(ProxyAddress::Inet(src)), Some(ProxyAddress::Inet(dst)), Some(transport)) => {
241                let proto = match transport.family {
242                    AddressFamily::Inet => "TCP4",
243                    AddressFamily::Inet6 => "TCP6",
244                    _ => unreachable!(),
245                };
246                format!(
247                    "PROXY {} {} {} {} {}\r\n",
248                    proto,
249                    src.ip(),
250                    dst.ip(),
251                    src.port(),
252                    dst.port()
253                )
254                .into_bytes()
255            }
256            _ => b"PROXY UNKNOWN\r\n".to_vec(),
257        }
258    }
259
260    fn build_v2(&self) -> Vec<u8> {
261        let mut buf = Vec::with_capacity(256);
262
263        // 12-byte signature
264        buf.extend_from_slice(V2_SIGNATURE);
265
266        // ver_cmd byte
267        let cmd_nibble = match self.command {
268            Command::Local => 0x00,
269            Command::Proxy => 0x01,
270        };
271        buf.push(0x20 | cmd_nibble);
272
273        // fam_proto byte
274        let (fam, proto) = match &self.transport {
275            Some(t) => {
276                let f = match t.family {
277                    AddressFamily::Inet => 1,
278                    AddressFamily::Inet6 => 2,
279                    AddressFamily::Unix => 3,
280                };
281                let p = match t.protocol {
282                    TransportProtocol::Stream => 1,
283                    TransportProtocol::Datagram => 2,
284                };
285                (f, p)
286            }
287            None => (0, 0),
288        };
289        buf.push((fam << 4) | proto);
290
291        // Placeholder for payload length (2 bytes) — fill in later
292        let len_pos = buf.len();
293        buf.extend_from_slice(&[0, 0]);
294
295        // Addresses
296        match self.command {
297            Command::Local => {}
298            Command::Proxy => {
299                self.encode_addresses(&mut buf);
300            }
301        }
302
303        // TLV entries
304        for (tlv_type, value) in &self.tlv_entries {
305            assert!(
306                value.len() <= u16::MAX as usize,
307                "TLV value length {} exceeds maximum of 65535",
308                value.len()
309            );
310            buf.push(*tlv_type);
311            buf.extend_from_slice(&(value.len() as u16).to_be_bytes());
312            buf.extend_from_slice(value);
313        }
314
315        // CRC32c TLV (must be last)
316        if self.add_crc32c {
317            // Append a placeholder CRC TLV: type(0x03) + len(0x0004) + value(0x00000000)
318            buf.push(0x03);
319            buf.extend_from_slice(&4u16.to_be_bytes());
320            buf.extend_from_slice(&[0, 0, 0, 0]);
321        }
322
323        // Fill in payload length BEFORE computing CRC so the CRC covers the real length
324        let payload_len = buf.len() - 16;
325        assert!(
326            payload_len <= u16::MAX as usize,
327            "v2 payload exceeds maximum size of 65535 bytes ({payload_len} bytes)"
328        );
329        let payload_len = payload_len as u16;
330        buf[len_pos..len_pos + 2].copy_from_slice(&payload_len.to_be_bytes());
331
332        // Now compute and fill in CRC value
333        if self.add_crc32c {
334            let crc = crc32c::crc32c(&buf);
335            let crc_pos = buf.len() - 4;
336            buf[crc_pos..crc_pos + 4].copy_from_slice(&crc.to_be_bytes());
337        }
338
339        buf
340    }
341
342    fn encode_addresses(&self, buf: &mut Vec<u8>) {
343        match (&self.source, &self.destination) {
344            (Some(ProxyAddress::Inet(src)), Some(ProxyAddress::Inet(dst))) => {
345                match (src.ip(), dst.ip()) {
346                    (std::net::IpAddr::V4(s), std::net::IpAddr::V4(d)) => {
347                        buf.extend_from_slice(&s.octets());
348                        buf.extend_from_slice(&d.octets());
349                        buf.extend_from_slice(&src.port().to_be_bytes());
350                        buf.extend_from_slice(&dst.port().to_be_bytes());
351                    }
352                    (std::net::IpAddr::V6(s), std::net::IpAddr::V6(d)) => {
353                        buf.extend_from_slice(&s.octets());
354                        buf.extend_from_slice(&d.octets());
355                        buf.extend_from_slice(&src.port().to_be_bytes());
356                        buf.extend_from_slice(&dst.port().to_be_bytes());
357                    }
358                    _ => {}
359                }
360            }
361            (Some(ProxyAddress::Unix(src)), Some(ProxyAddress::Unix(dst))) => {
362                let mut src_field = [0u8; 108];
363                let src_len = src.len().min(108);
364                src_field[..src_len].copy_from_slice(&src[..src_len]);
365                buf.extend_from_slice(&src_field);
366
367                let mut dst_field = [0u8; 108];
368                let dst_len = dst.len().min(108);
369                dst_field[..dst_len].copy_from_slice(&dst[..dst_len]);
370                buf.extend_from_slice(&dst_field);
371            }
372            _ => {}
373        }
374    }
375}
376
377fn encode_ssl_tlv_value(ssl: &SslInfo) -> Vec<u8> {
378    let mut buf = Vec::new();
379
380    // client flags byte
381    buf.push(ssl.client_flags.bits());
382
383    // verify: 0 = verified, non-zero = not verified
384    let verify: u32 = if ssl.verified { 0 } else { 1 };
385    buf.extend_from_slice(&verify.to_be_bytes());
386
387    // Sub-TLVs
388    if let Some(ref v) = ssl.version {
389        encode_sub_tlv(&mut buf, 0x21, v.as_bytes());
390    }
391    if let Some(ref v) = ssl.cn {
392        encode_sub_tlv(&mut buf, 0x22, v.as_bytes());
393    }
394    if let Some(ref v) = ssl.cipher {
395        encode_sub_tlv(&mut buf, 0x23, v.as_bytes());
396    }
397    if let Some(ref v) = ssl.sig_alg {
398        encode_sub_tlv(&mut buf, 0x24, v.as_bytes());
399    }
400    if let Some(ref v) = ssl.key_alg {
401        encode_sub_tlv(&mut buf, 0x25, v.as_bytes());
402    }
403    if let Some(ref v) = ssl.group {
404        encode_sub_tlv(&mut buf, 0x26, v.as_bytes());
405    }
406    if let Some(ref v) = ssl.sig_scheme {
407        encode_sub_tlv(&mut buf, 0x27, v.as_bytes());
408    }
409    if let Some(ref v) = ssl.client_cert {
410        encode_sub_tlv(&mut buf, 0x28, v);
411    }
412
413    buf
414}
415
/// Append a single sub-TLV to `buf`: one type byte, a big-endian `u16`
/// length, then the raw value bytes.
///
/// # Panics
///
/// Panics if `value` is longer than 65 535 bytes, since its length would not
/// fit in the 16-bit length field.
fn encode_sub_tlv(buf: &mut Vec<u8>, type_byte: u8, value: &[u8]) {
    let value_len = value.len();
    if value_len > usize::from(u16::MAX) {
        panic!("sub-TLV value length {} exceeds maximum of 65535", value_len);
    }
    let [len_hi, len_lo] = (value_len as u16).to_be_bytes();
    buf.extend_from_slice(&[type_byte, len_hi, len_lo]);
    buf.extend_from_slice(value);
}
425}