rama_haproxy/protocol/v2/
builder.rs

//! Builder pattern to generate both valid and invalid PROXY protocol v2 headers.
2
3use crate::protocol::v2::{
4    Addresses, Protocol, Type, TypeLengthValue, TypeLengthValues, LENGTH, MINIMUM_LENGTH,
5    MINIMUM_TLV_LENGTH, PROTOCOL_PREFIX,
6};
7use std::io::{self, Write};
8
/// In-memory byte sink used to serialize parts of a PROXY protocol v2 header.
///
/// It implements [`std::io::Write`] on top of a `Vec<u8>`, so any [`WriteToHeader`]
/// value can be written into it; [`Writer::finish`] hands the accumulated bytes back.
///
/// ## Examples
/// ```rust
/// use rama_haproxy::protocol::v2::{Addresses, Writer, WriteToHeader};
/// use std::net::SocketAddr;
///
/// let addresses: Addresses = ("127.0.0.1:80".parse::<SocketAddr>().unwrap(), "192.168.1.1:443".parse::<SocketAddr>().unwrap()).into();
/// let mut writer = Writer::default();
///
/// addresses.write_to(&mut writer).unwrap();
///
/// assert_eq!(addresses.to_bytes().unwrap(), writer.finish());
/// ```
#[derive(Debug, Default)]
pub struct Writer {
    /// Every byte written so far, in write order.
    bytes: Vec<u8>,
}
28
29/// Implementation of the builder pattern for PROXY protocol v2 headers.
30/// Supports both valid and invalid headers via the `write_payload` and `write_payloads` functions.
31///
32/// ## Examples
33/// ```rust
34/// use rama_haproxy::protocol::v2::{Addresses, AddressFamily, Builder, Command, IPv4, Protocol, PROTOCOL_PREFIX, Type, Version};
35/// let mut expected = Vec::from(PROTOCOL_PREFIX);
36/// expected.extend([
37///    0x21, 0x12, 0, 16, 127, 0, 0, 1, 192, 168, 1, 1, 0, 80, 1, 187, 4, 0, 1, 42
38/// ]);
39///
40/// let addresses: Addresses = IPv4::new([127, 0, 0, 1], [192, 168, 1, 1], 80, 443).into();
41/// let header = Builder::with_addresses(
42///     Version::Two | Command::Proxy,
43///     Protocol::Datagram,
44///     addresses
45/// )
46/// .write_tlv(Type::NoOp, [42].as_slice())
47/// .unwrap()
48/// .build()
49/// .unwrap();
50///
51/// assert_eq!(header, expected);
52/// ```
53#[derive(Debug)]
54pub struct Builder {
55    header: Option<Vec<u8>>,
56    version_command: u8,
57    address_family_protocol: u8,
58    addresses: Addresses,
59    length: Option<u16>,
60    additional_capacity: usize,
61}
62
63impl Writer {
64    /// Consumes this `Writer` and returns the buffer holding the proxy protocol header payloads.
65    /// The returned bytes are not guaranteed to be a valid proxy protocol header.
66    pub fn finish(self) -> Vec<u8> {
67        self.bytes
68    }
69}
70
71impl From<Vec<u8>> for Writer {
72    fn from(bytes: Vec<u8>) -> Self {
73        Writer { bytes }
74    }
75}
76
77impl Write for Writer {
78    fn write(&mut self, buffer: &[u8]) -> io::Result<usize> {
79        if self.bytes.len() > (u16::MAX as usize) + MINIMUM_LENGTH {
80            Err(io::ErrorKind::WriteZero.into())
81        } else {
82            self.bytes.extend_from_slice(buffer);
83            Ok(buffer.len())
84        }
85    }
86
87    fn flush(&mut self) -> io::Result<()> {
88        Ok(())
89    }
90}
91
92/// Defines how to write a type as part of a binary PROXY protocol header.
93pub trait WriteToHeader {
94    /// Write this instance to the given `Writer`.
95    /// The `Writer` returns an IO error when an individual byte slice is longer than `u16::MAX`.
96    /// However, the total length of the buffer may exceed `u16::MAX`.
97    fn write_to(&self, writer: &mut Writer) -> io::Result<usize>;
98
99    /// Writes this instance to a temporary buffer and returns the buffer.
100    fn to_bytes(&self) -> io::Result<Vec<u8>> {
101        let mut writer = Writer::default();
102
103        self.write_to(&mut writer)?;
104
105        Ok(writer.finish())
106    }
107}
108
109impl WriteToHeader for Addresses {
110    fn write_to(&self, writer: &mut Writer) -> io::Result<usize> {
111        match self {
112            Addresses::Unspecified => (),
113            Addresses::IPv4(a) => {
114                writer.write_all(a.source_address.octets().as_slice())?;
115                writer.write_all(a.destination_address.octets().as_slice())?;
116                writer.write_all(a.source_port.to_be_bytes().as_slice())?;
117                writer.write_all(a.destination_port.to_be_bytes().as_slice())?;
118            }
119            Addresses::IPv6(a) => {
120                writer.write_all(a.source_address.octets().as_slice())?;
121                writer.write_all(a.destination_address.octets().as_slice())?;
122                writer.write_all(a.source_port.to_be_bytes().as_slice())?;
123                writer.write_all(a.destination_port.to_be_bytes().as_slice())?;
124            }
125            Addresses::Unix(a) => {
126                writer.write_all(a.source.as_slice())?;
127                writer.write_all(a.destination.as_slice())?;
128            }
129        };
130
131        Ok(self.len())
132    }
133}
134
135impl WriteToHeader for TypeLengthValue<'_> {
136    fn write_to(&self, writer: &mut Writer) -> io::Result<usize> {
137        if self.value.len() > u16::MAX as usize {
138            return Err(io::ErrorKind::WriteZero.into());
139        }
140
141        writer.write_all([self.kind].as_slice())?;
142        writer.write_all((self.value.len() as u16).to_be_bytes().as_slice())?;
143        writer.write_all(self.value.as_ref())?;
144
145        Ok(MINIMUM_TLV_LENGTH + self.value.len())
146    }
147}
148
149impl<T: Copy + Into<u8>> WriteToHeader for (T, &[u8]) {
150    fn write_to(&self, writer: &mut Writer) -> io::Result<usize> {
151        let kind = self.0.into();
152        let value = self.1;
153
154        if value.len() > u16::MAX as usize {
155            return Err(io::ErrorKind::WriteZero.into());
156        }
157
158        writer.write_all([kind].as_slice())?;
159        writer.write_all((value.len() as u16).to_be_bytes().as_slice())?;
160        writer.write_all(value)?;
161
162        Ok(MINIMUM_TLV_LENGTH + value.len())
163    }
164}
165
166impl WriteToHeader for TypeLengthValues<'_> {
167    fn write_to(&self, writer: &mut Writer) -> io::Result<usize> {
168        let bytes = self.as_bytes();
169
170        writer.write_all(bytes)?;
171
172        Ok(bytes.len())
173    }
174}
175
176impl WriteToHeader for [u8] {
177    fn write_to(&self, writer: &mut Writer) -> io::Result<usize> {
178        let slice = self;
179
180        if slice.len() > u16::MAX as usize {
181            return Err(io::ErrorKind::WriteZero.into());
182        }
183
184        writer.write_all(slice)?;
185
186        Ok(slice.len())
187    }
188}
189
190impl<T: ?Sized + WriteToHeader> WriteToHeader for &T {
191    fn write_to(&self, writer: &mut Writer) -> io::Result<usize> {
192        (*self).write_to(writer)
193    }
194}
195
196impl WriteToHeader for Type {
197    fn write_to(&self, writer: &mut Writer) -> io::Result<usize> {
198        writer.write([(*self).into()].as_slice())
199    }
200}
201
/// Implements [`WriteToHeader`] for an integer type by writing its big-endian byte
/// representation (`to_be_bytes`) and returning the number of bytes written.
/// For `usize` / `isize` that width depends on the target's pointer size.
macro_rules! impl_write_to_header {
    ($t:ident) => {
        impl WriteToHeader for $t {
            fn write_to(&self, writer: &mut Writer) -> io::Result<usize> {
                let encoded = <$t>::to_be_bytes(*self);
                writer.write_all(&encoded).map(|()| encoded.len())
            }
        }
    };
}
215
216impl_write_to_header!(u8);
217impl_write_to_header!(u16);
218impl_write_to_header!(u32);
219impl_write_to_header!(u64);
220impl_write_to_header!(u128);
221impl_write_to_header!(usize);
222
223impl_write_to_header!(i8);
224impl_write_to_header!(i16);
225impl_write_to_header!(i32);
226impl_write_to_header!(i64);
227impl_write_to_header!(i128);
228impl_write_to_header!(isize);
229
230impl Builder {
231    /// Creates an instance of a `Builder` with the given header bytes.
232    /// No guarantee is made that any address bytes written as a payload will match the header's address family.
233    /// The length is determined on `build` unless `set_length` is called to set an explicit value.
234    pub const fn new(version_command: u8, address_family_protocol: u8) -> Self {
235        Builder {
236            header: None,
237            version_command,
238            address_family_protocol,
239            addresses: Addresses::Unspecified,
240            length: None,
241            additional_capacity: 0,
242        }
243    }
244
245    /// Creates an instance of a `Builder` with the given header bytes and `Addresses`.
246    /// The address family is determined from the variant of the `Addresses` given.
247    /// The length is determined on `build` unless `set_length` is called to set an explicit value.
248    pub fn with_addresses<T: Into<Addresses>>(
249        version_command: u8,
250        protocol: Protocol,
251        addresses: T,
252    ) -> Self {
253        let addresses = addresses.into();
254
255        Builder {
256            header: None,
257            version_command,
258            address_family_protocol: addresses.address_family() | protocol,
259            addresses,
260            length: None,
261            additional_capacity: 0,
262        }
263    }
264
265    /// Reserves the requested additional capacity in the underlying buffer.
266    /// Helps to prevent resizing the underlying buffer when called before `write_payload`, `write_payloads`.
267    /// When called after `write_payload`, `write_payloads`, useful as a hint on how to resize the buffer.
268    pub fn reserve_capacity(mut self, capacity: usize) -> Self {
269        match self.header {
270            None => self.additional_capacity += capacity,
271            Some(ref mut header) => header.reserve(capacity),
272        }
273
274        self
275    }
276
277    /// Reserves the requested additional capacity in the underlying buffer.
278    /// Helps to prevent resizing the underlying buffer when called before `write_payload`, `write_payloads`.
279    /// When called after `write_payload`, `write_payloads`, useful as a hint on how to resize the buffer.
280    pub fn set_reserve_capacity(&mut self, capacity: usize) -> &mut Self {
281        match self.header {
282            None => self.additional_capacity += capacity,
283            Some(ref mut header) => header.reserve(capacity),
284        }
285
286        self
287    }
288
289    /// Overrides the length in the header.
290    /// When set to `Some` value, the length may be smaller or larger than the actual payload in the buffer.
291    pub fn set_length<T: Into<Option<u16>>>(mut self, length: T) -> Self {
292        self.length = length.into();
293        self
294    }
295
296    /// Writes a iterable set of payloads in order to the buffer.
297    /// No bytes are added by this `Builder` as a delimiter.
298    pub fn write_payloads<T, I, II>(mut self, payloads: II) -> io::Result<Self>
299    where
300        T: WriteToHeader,
301        I: Iterator<Item = T>,
302        II: IntoIterator<IntoIter = I, Item = T>,
303    {
304        self.write_header()?;
305
306        let mut writer = Writer::from(self.header.take().unwrap_or_default());
307
308        for item in payloads {
309            item.write_to(&mut writer)?;
310        }
311
312        self.header = Some(writer.finish());
313
314        Ok(self)
315    }
316
317    /// Writes a single payload to the buffer.
318    /// No surrounding bytes (terminal or otherwise) are added by this `Builder`.
319    pub fn write_payload<T: WriteToHeader>(mut self, payload: T) -> io::Result<Self> {
320        self.write_header()?;
321        self.write_internal(payload)?;
322
323        Ok(self)
324    }
325
326    /// Writes a Type-Length-Value as a payload.
327    /// No surrounding bytes (terminal or otherwise) are added by this `Builder`.
328    /// The length is determined by the length of the slice.
329    /// An error is returned when the length of the slice exceeds `u16::MAX`.
330    pub fn write_tlv(self, kind: impl Into<u8>, value: &[u8]) -> io::Result<Self> {
331        self.write_payload(TypeLengthValue::new(kind, value))
332    }
333
334    /// Writes to the underlying buffer without first writing the header bytes.
335    fn write_internal<T: WriteToHeader>(&mut self, payload: T) -> io::Result<()> {
336        let mut writer = Writer::from(self.header.take().unwrap_or_default());
337
338        payload.write_to(&mut writer)?;
339
340        self.header = Some(writer.finish());
341
342        Ok(())
343    }
344
345    /// Writes the protocol prefix, version, command, address family, protocol, and optional addresses to the buffer.
346    /// Does nothing if the buffer is not empty.
347    fn write_header(&mut self) -> io::Result<()> {
348        if self.header.is_some() {
349            return Ok(());
350        }
351
352        let mut header =
353            Vec::with_capacity(MINIMUM_LENGTH + self.addresses.len() + self.additional_capacity);
354
355        let length = self.length.unwrap_or_default();
356
357        header.extend_from_slice(PROTOCOL_PREFIX);
358        header.push(self.version_command);
359        header.push(self.address_family_protocol);
360        header.extend_from_slice(length.to_be_bytes().as_slice());
361
362        let mut writer = Writer::from(header);
363
364        self.addresses.write_to(&mut writer)?;
365        self.header = Some(writer.finish());
366
367        Ok(())
368    }
369
370    /// Builds the header and returns the underlying buffer.
371    /// If no length was explicitly set, returns an error when the length of the payload portion exceeds `u16::MAX`.
372    pub fn build(mut self) -> io::Result<Vec<u8>> {
373        self.write_header()?;
374
375        let mut header = self.header.take().unwrap_or_default();
376
377        if self.length.is_some() {
378            return Ok(header);
379        }
380
381        if let Ok(payload_length) = u16::try_from(header[MINIMUM_LENGTH..].len()) {
382            let length = payload_length.to_be_bytes();
383            header[LENGTH..LENGTH + length.len()].copy_from_slice(length.as_slice());
384            Ok(header)
385        } else {
386            Err(io::ErrorKind::WriteZero.into())
387        }
388    }
389}
390
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::v2::{AddressFamily, Command, IPv4, IPv6, Protocol, Type, Unix, Version};

    /// IPv6 source address shared by the IPv6 tests; it differs from the destination only
    /// in the last byte.
    const IPV6_SOURCE: [u8; 16] = [
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
        0xFF, 0xF2,
    ];
    /// IPv6 destination address shared by the IPv6 tests.
    const IPV6_DESTINATION: [u8; 16] = [
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
        0xFF, 0xF1,
    ];

    /// Unix source/destination bytes shared by the Unix tests.
    const UNIX_SOURCE: [u8; 108] = [0xFF; 108];
    const UNIX_DESTINATION: [u8; 108] = [0xAA; 108];

    /// Expected encoding of `ipv4_addresses()`: both addresses, then both big-endian ports.
    const IPV4_ADDRESS_BYTES: [u8; 12] = [127, 0, 0, 1, 192, 168, 1, 1, 0, 80, 1, 187];

    /// The protocol prefix followed by `tail`.
    fn prefixed(tail: &[u8]) -> Vec<u8> {
        let mut bytes = Vec::from(PROTOCOL_PREFIX);
        bytes.extend_from_slice(tail);
        bytes
    }

    /// 127.0.0.1:80 -> 192.168.1.1:443
    fn ipv4_addresses() -> Addresses {
        IPv4::new([127, 0, 0, 1], [192, 168, 1, 1], 80, 443).into()
    }

    #[test]
    fn build_length_too_small() {
        let expected = prefixed(&[0x21, 0x12, 0, 1, 0, 0, 0, 1]);

        let actual = Builder::new(
            Version::Two | Command::Proxy,
            AddressFamily::IPv4 | Protocol::Datagram,
        )
        .set_length(1)
        .write_payload(1u32)
        .unwrap()
        .build()
        .unwrap();

        assert_eq!(actual, expected);
    }

    #[test]
    fn build_payload_too_long() {
        let oversized = vec![0u8; usize::from(u16::MAX) + 1];

        let error = Builder::new(
            Version::Two | Command::Proxy,
            AddressFamily::IPv4 | Protocol::Datagram,
        )
        .write_payload(oversized.as_slice())
        .unwrap_err();

        assert_eq!(error.kind(), io::ErrorKind::WriteZero);
    }

    #[test]
    fn build_no_payload() {
        let expected = prefixed(&[0x21, 0x01, 0, 0]);

        let header = Builder::new(
            Version::Two | Command::Proxy,
            AddressFamily::Unspecified | Protocol::Stream,
        )
        .build()
        .unwrap();

        assert_eq!(header, expected);
    }

    #[test]
    fn build_arbitrary_payload() {
        let expected = prefixed(&[0x21, 0x01, 0, 1, 42]);

        let header = Builder::new(
            Version::Two | Command::Proxy,
            AddressFamily::Unspecified | Protocol::Stream,
        )
        .write_payload(42u8)
        .unwrap()
        .build()
        .unwrap();

        assert_eq!(header, expected);
    }

    #[test]
    fn build_ipv4() {
        let mut expected = prefixed(&[0x21, 0x12, 0, 12]);
        expected.extend(IPV4_ADDRESS_BYTES);

        let addresses = ipv4_addresses();
        let header = Builder::new(
            Version::Two | Command::Proxy,
            AddressFamily::IPv4 | Protocol::Datagram,
        )
        .set_length(addresses.len() as u16)
        .write_payload(addresses)
        .unwrap()
        .build()
        .unwrap();

        assert_eq!(header, expected);
    }

    #[test]
    fn build_ipv6() {
        let mut expected = prefixed(&[0x20, 0x20, 0, 36]);
        expected.extend(IPV6_SOURCE);
        expected.extend(IPV6_DESTINATION);
        expected.extend([0, 80, 1, 187]);

        let header = Builder::with_addresses(
            Version::Two | Command::Local,
            Protocol::Unspecified,
            IPv6::new(IPV6_SOURCE, IPV6_DESTINATION, 80, 443),
        )
        .build()
        .unwrap();

        assert_eq!(header, expected);
    }

    #[test]
    fn build_unix() {
        let addresses: Addresses = Unix::new(UNIX_SOURCE, UNIX_DESTINATION).into();
        let mut expected = prefixed(&[0x20, 0x31, 0, 216]);
        expected.extend(UNIX_SOURCE);
        expected.extend(UNIX_DESTINATION);

        let header = Builder::new(
            Version::Two | Command::Local,
            AddressFamily::Unix | Protocol::Stream,
        )
        .reserve_capacity(addresses.len())
        .write_payload(addresses)
        .unwrap()
        .build()
        .unwrap();

        assert_eq!(header, expected);
    }

    #[test]
    fn build_ipv4_with_tlv() {
        let mut expected = prefixed(&[0x21, 0x12, 0, 17]);
        expected.extend(IPV4_ADDRESS_BYTES);
        expected.extend([4, 0, 2, 0, 42]);

        let header = Builder::with_addresses(
            Version::Two | Command::Proxy,
            Protocol::Datagram,
            ipv4_addresses(),
        )
        .reserve_capacity(5)
        .write_tlv(Type::NoOp, [0, 42].as_slice())
        .unwrap()
        .build()
        .unwrap();

        assert_eq!(header, expected);
    }

    #[test]
    fn build_ipv4_with_nested_tlv() {
        let mut expected = prefixed(&[0x21, 0x12, 0, 20]);
        expected.extend(IPV4_ADDRESS_BYTES);
        expected.extend([0x20, 0, 5, 0, 0, 0, 0, 0]);

        let header = Builder::new(
            Version::Two | Command::Proxy,
            AddressFamily::IPv4 | Protocol::Datagram,
        )
        .write_payload(ipv4_addresses())
        .unwrap()
        .write_payload(Type::SSL)
        .unwrap()
        .write_payload(5u16)
        .unwrap()
        .write_payload([0u8; 5].as_slice())
        .unwrap()
        .build()
        .unwrap();

        assert_eq!(header, expected);
    }

    #[test]
    fn build_ipv6_with_tlvs() {
        let addresses: Addresses = IPv6::new(IPV6_SOURCE, IPV6_DESTINATION, 80, 443).into();
        let mut expected = prefixed(&[0x20, 0x20, 0, 48]);
        expected.extend(IPV6_SOURCE);
        expected.extend(IPV6_DESTINATION);
        expected.extend([0, 80, 1, 187]);
        expected.extend([4, 0, 1, 0, 4, 0, 1, 0, 4, 0, 1, 42]);

        let header = Builder::new(
            Version::Two | Command::Local,
            AddressFamily::IPv6 | Protocol::Unspecified,
        )
        .write_payload(addresses)
        .unwrap()
        .write_payloads([
            (Type::NoOp, [0].as_slice()),
            (Type::NoOp, [0].as_slice()),
            (Type::NoOp, [42].as_slice()),
        ])
        .unwrap()
        .build()
        .unwrap();

        assert_eq!(header, expected);
    }

    #[test]
    fn build_unix_with_tlv() {
        let addresses: Addresses = Unix::new(UNIX_SOURCE, UNIX_DESTINATION).into();
        let mut expected = prefixed(&[0x20, 0x31, 0, 216]);
        expected.extend(UNIX_SOURCE);
        expected.extend(UNIX_DESTINATION);
        expected.extend([0x20, 0, 0]);

        let header = Builder::new(
            Version::Two | Command::Local,
            AddressFamily::Unix | Protocol::Stream,
        )
        .set_length(216)
        .write_payload(addresses)
        .unwrap()
        .write_tlv(Type::SSL, &[])
        .unwrap()
        .build()
        .unwrap();

        assert_eq!(header, expected);
    }
}