rama_haproxy/protocol/v2/
model.rs

1use crate::protocol::ip::{IPv4, IPv6};
2use crate::protocol::v2::error::ParseError;
3use std::borrow::Cow;
4use std::fmt;
5use std::net::SocketAddr;
6use std::ops::BitOr;
7
/// The 12-byte signature that opens every binary (v2) PROXY protocol header.
///
/// As an escaped byte string this is `b"\r\n\r\n\0\r\nQUIT\n"`.
pub const PROTOCOL_PREFIX: &[u8] = &[
    0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
];
/// The minimum length in bytes of a PROXY protocol header: the signature,
/// the version/command byte, the address-family/protocol byte and the 2-byte
/// payload length.
pub(crate) const MINIMUM_LENGTH: usize = PROTOCOL_PREFIX.len() + 1 + 1 + 2;
/// The minimum length in bytes of a Type-Length-Value payload: a 1-byte type
/// followed by a 2-byte big-endian length.
pub(crate) const MINIMUM_TLV_LENGTH: usize = 1 + 2;

/// The number of bytes for an IPv4 addresses payload: two 4-byte addresses
/// followed by two 2-byte ports.
const IPV4_ADDRESSES_BYTES: usize = 2 * 4 + 2 * 2;
/// The number of bytes for an IPv6 addresses payload: two 16-byte addresses
/// followed by two 2-byte ports.
const IPV6_ADDRESSES_BYTES: usize = 2 * 16 + 2 * 2;
/// The number of bytes for a unix addresses payload: two 108-byte addresses.
const UNIX_ADDRESSES_BYTES: usize = 2 * 108;
21
/// A proxy protocol version 2 header.
///
/// Besides the decoded fields, a `Header` keeps the raw bytes it was built on
/// (`header`); the payload accessors (`address_bytes`, `tlv_bytes`, `tlvs`)
/// return views into those bytes.
///
/// ## Examples
/// ```rust
/// use rama_haproxy::protocol::v2::{Addresses, AddressFamily, Command, Header, IPv4, ParseError, Protocol, PROTOCOL_PREFIX, Type, TypeLengthValue, Version};
/// let mut header = Vec::from(PROTOCOL_PREFIX);
/// header.extend([
///    0x21, 0x12, 0, 16, 127, 0, 0, 1, 192, 168, 1, 1, 0, 80, 1, 187, 4, 0, 1, 42
/// ]);
///
/// let addresses: Addresses = IPv4::new([127, 0, 0, 1], [192, 168, 1, 1], 80, 443).into();
/// let expected = Header {
///    header: header.as_slice().into(),
///    version: Version::Two,
///    command: Command::Proxy,
///    protocol: Protocol::Datagram,
///    addresses
/// };
/// let actual = Header::try_from(header.as_slice()).unwrap();
///
/// assert_eq!(actual, expected);
/// assert_eq!(actual.tlvs().collect::<Vec<Result<TypeLengthValue<'_>, ParseError>>>(), vec![Ok(TypeLengthValue::new(Type::NoOp, &[42]))]);
/// ```
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Header<'a> {
    /// The underlying byte slice this `Header` is built on: the complete
    /// header, starting with [`PROTOCOL_PREFIX`] and including the payload.
    pub header: Cow<'a, [u8]>,
    /// The version of the PROXY protocol.
    pub version: Version,
    /// The command of the PROXY protocol.
    pub command: Command,
    /// The transport protocol of the proxied connection.
    pub protocol: Protocol,
    /// The source and destination addresses of the PROXY protocol; their
    /// variant also determines the header's [`AddressFamily`].
    pub addresses: Addresses,
}
58
/// The supported `Version`s for binary headers.
///
/// The discriminant occupies the high nibble of the version/command byte;
/// combine it with a [`Command`] via `|` to obtain that byte.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum Version {
    /// Version two of the PROXY protocol (`0x20`).
    Two = 0x20,
}
65
/// The `Command`s a PROXY protocol header can carry.
///
/// The discriminant occupies the low nibble of the version/command byte;
/// combine it with a [`Version`] via `|` to obtain that byte.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum Command {
    /// A local connection (`0x0`).
    Local = 0x00,
    /// A proxied connection (`0x1`).
    Proxy = 0x01,
}
74
/// The supported `AddressFamily` for a PROXY protocol header.
///
/// The discriminant occupies the high nibble of the address-family/protocol
/// byte; combine it with a [`Protocol`] via `|` to obtain that byte.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum AddressFamily {
    /// The address family is unspecified; no fixed number of address bytes.
    Unspecified = 0x00,
    /// The address family is IPv4; addresses take 12 payload bytes.
    IPv4 = 0x10,
    /// The address family is IPv6; addresses take 36 payload bytes.
    IPv6 = 0x20,
    /// The address family is Unix; addresses take 216 payload bytes.
    Unix = 0x30,
}
87
/// The transport `Protocol`s a PROXY protocol header can describe.
///
/// The discriminant occupies the low nibble of the address-family/protocol
/// byte; combine it with an [`AddressFamily`] via `|` to obtain that byte.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum Protocol {
    /// No protocol is specified (`0x0`).
    Unspecified = 0x0,
    /// A stream protocol (`0x1`).
    Stream = 0x1,
    /// A datagram protocol (`0x2`).
    Datagram = 0x2,
}
98
/// The source and destination address information for a given `AddressFamily`.
///
/// A `(SocketAddr, SocketAddr)` pair converts into `Addresses`; a pair whose
/// two sides belong to different IP families becomes `Addresses::Unspecified`.
///
/// ## Examples
/// ```rust
/// use rama_haproxy::protocol::v2::{Addresses, AddressFamily};
/// use std::net::SocketAddr;
///
/// let addresses: Addresses = ("127.0.0.1:80".parse::<SocketAddr>().unwrap(), "192.168.1.1:443".parse::<SocketAddr>().unwrap()).into();
///
/// assert_eq!(addresses.address_family(), AddressFamily::IPv4);
/// ```
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum Addresses {
    /// The source and destination addresses are unspecified; they occupy no
    /// payload bytes.
    Unspecified,
    /// The source and destination addresses are IPv4.
    IPv4(IPv4),
    /// The source and destination addresses are IPv6.
    IPv6(IPv6),
    /// The source and destination addresses are Unix.
    Unix(Unix),
}
121
/// The source and destination addresses of UNIX sockets.
///
/// Each side is stored as a fixed-size 108-byte buffer, so a pair accounts for
/// 216 bytes of the header payload.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct Unix {
    /// The source address of the UNIX socket (108 raw bytes).
    pub source: [u8; 108],
    /// The destination address of the UNIX socket (108 raw bytes).
    pub destination: [u8; 108],
}
130
/// An `Iterator` of `TypeLengthValue`s stored in a byte slice.
///
/// Entries are decoded lazily; after the first malformed entry is reported the
/// iterator is exhausted.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct TypeLengthValues<'a> {
    // The raw, still-encoded TLV region.
    bytes: &'a [u8],
    // Position within `bytes` of the next entry to decode.
    offset: usize,
}
137
/// A Type-Length-Value payload.
///
/// The length is not stored separately: it is the length of `value`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct TypeLengthValue<'a> {
    /// The type of the `TypeLengthValue` as a raw byte (e.g. a [`Type`]
    /// converted into a `u8`); decoded entries keep whatever byte was on the
    /// wire, including types not listed in [`Type`].
    pub kind: u8,
    /// The value of the `TypeLengthValue`, either borrowed from the header or owned.
    pub value: Cow<'a, [u8]>,
}
146
/// Known values for the type byte of a `TypeLengthValue`.
///
/// Values `0x20..=0x25` form the SSL group; every discriminant is spelled out
/// so the wire value of each variant is visible at a glance.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum Type {
    /// The ALPN of the connection.
    ALPN = 0x01,
    /// The authority of the connection.
    Authority = 0x02,
    /// A CRC32C checksum.
    CRC32C = 0x03,
    /// A no-operation entry.
    NoOp = 0x04,
    /// The unique ID of the connection.
    UniqueId = 0x05,
    /// The SSL information.
    SSL = 0x20,
    /// The SSL version.
    SSLVersion = 0x21,
    /// The SSL common name.
    SSLCommonName = 0x22,
    /// The SSL cipher.
    SSLCipher = 0x23,
    /// The SSL signature algorithm.
    SSLSignatureAlgorithm = 0x24,
    /// The SSL key algorithm.
    SSLKeyAlgorithm = 0x25,
    /// The network namespace of the connection (not part of the SSL group).
    NetworkNamespace = 0x30,
}
175
176impl fmt::Display for Header<'_> {
177    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
178        write!(
179            f,
180            "{:?} {:#X} {:#X} ({} bytes)",
181            PROTOCOL_PREFIX,
182            self.version | self.command,
183            self.protocol | self.address_family(),
184            self.length()
185        )
186    }
187}
188
189impl Header<'_> {
190    /// Creates an owned clone of this [`Header`].
191    pub fn to_owned(&self) -> Header<'static> {
192        Header {
193            header: Cow::Owned(self.header.to_vec()),
194            version: self.version,
195            command: self.command,
196            protocol: self.protocol,
197            addresses: self.addresses,
198        }
199    }
200
201    /// The length of this `Header`'s payload in bytes.
202    pub fn length(&self) -> usize {
203        self.header[MINIMUM_LENGTH..].len()
204    }
205
206    /// The total length of this `Header` in bytes.
207    pub fn len(&self) -> usize {
208        self.header.len()
209    }
210
211    /// Tests whether this `Header`'s underlying byte slice is empty.
212    pub fn is_empty(&self) -> bool {
213        self.header.is_empty()
214    }
215
216    /// The `AddressFamily` of this `Header`.
217    pub fn address_family(&self) -> AddressFamily {
218        self.addresses.address_family()
219    }
220
221    /// The length in bytes of the address portion of the payload.
222    fn address_bytes_end(&self) -> usize {
223        let length = self.length();
224        let address_bytes = self.address_family().byte_length().unwrap_or(length);
225
226        MINIMUM_LENGTH + std::cmp::min(address_bytes, length)
227    }
228
229    /// The bytes of the address portion of the payload.
230    pub fn address_bytes(&self) -> &[u8] {
231        &self.header[MINIMUM_LENGTH..self.address_bytes_end()]
232    }
233
234    /// The bytes of the `TypeLengthValue` portion of the payload.
235    pub fn tlv_bytes(&self) -> &[u8] {
236        &self.header[self.address_bytes_end()..]
237    }
238
239    /// An `Iterator` of `TypeLengthValue`s.
240    pub fn tlvs(&self) -> TypeLengthValues<'_> {
241        TypeLengthValues {
242            bytes: self.tlv_bytes(),
243            offset: 0,
244        }
245    }
246
247    /// The underlying byte slice this `Header` is built on.
248    pub fn as_bytes(&self) -> &[u8] {
249        self.header.as_ref()
250    }
251}
252
253impl TypeLengthValues<'_> {
254    /// The underlying byte slice of the `TypeLengthValue`s portion of the `Header` payload.
255    pub fn as_bytes(&self) -> &[u8] {
256        self.bytes
257    }
258}
259
260impl<'a> From<&'a [u8]> for TypeLengthValues<'a> {
261    fn from(bytes: &'a [u8]) -> Self {
262        TypeLengthValues { bytes, offset: 0 }
263    }
264}
265
266impl<'a> Iterator for TypeLengthValues<'a> {
267    type Item = Result<TypeLengthValue<'a>, ParseError>;
268
269    fn next(&mut self) -> Option<Self::Item> {
270        if self.offset >= self.bytes.len() {
271            return None;
272        }
273
274        let remaining = &self.bytes[self.offset..];
275
276        if remaining.len() < MINIMUM_TLV_LENGTH {
277            self.offset = self.bytes.len();
278            return Some(Err(ParseError::Leftovers(self.bytes.len())));
279        }
280
281        let tlv_type = remaining[0];
282        let length = u16::from_be_bytes([remaining[1], remaining[2]]);
283        let tlv_length = MINIMUM_TLV_LENGTH + length as usize;
284
285        if remaining.len() < tlv_length {
286            self.offset = self.bytes.len();
287            return Some(Err(ParseError::InvalidTLV(tlv_type, length)));
288        }
289
290        self.offset += tlv_length;
291
292        Some(Ok(TypeLengthValue {
293            kind: tlv_type,
294            value: Cow::Borrowed(&remaining[MINIMUM_TLV_LENGTH..tlv_length]),
295        }))
296    }
297}
298
299impl TypeLengthValues<'_> {
300    /// The number of bytes in the `TypeLengthValue` portion of the `Header`.
301    pub fn len(&self) -> u16 {
302        self.bytes.len() as u16
303    }
304
305    /// Whether there are any bytes to be interpreted as `TypeLengthValue`s.
306    pub fn is_empty(&self) -> bool {
307        self.bytes.is_empty()
308    }
309}
310
311impl BitOr<Command> for Version {
312    type Output = u8;
313
314    fn bitor(self, command: Command) -> Self::Output {
315        (self as u8) | (command as u8)
316    }
317}
318
319impl BitOr<Version> for Command {
320    type Output = u8;
321
322    fn bitor(self, version: Version) -> Self::Output {
323        (self as u8) | (version as u8)
324    }
325}
326
327impl BitOr<Protocol> for AddressFamily {
328    type Output = u8;
329
330    fn bitor(self, protocol: Protocol) -> Self::Output {
331        (self as u8) | (protocol as u8)
332    }
333}
334
335impl AddressFamily {
336    /// The length in bytes for this `AddressFamily`.
337    /// `AddressFamily::Unspecified` does not require any bytes, and is represented as `None`.
338    pub fn byte_length(&self) -> Option<usize> {
339        match self {
340            AddressFamily::IPv4 => Some(IPV4_ADDRESSES_BYTES),
341            AddressFamily::IPv6 => Some(IPV6_ADDRESSES_BYTES),
342            AddressFamily::Unix => Some(UNIX_ADDRESSES_BYTES),
343            AddressFamily::Unspecified => None,
344        }
345    }
346}
347
348impl From<AddressFamily> for u16 {
349    fn from(address_family: AddressFamily) -> Self {
350        address_family.byte_length().unwrap_or_default() as u16
351    }
352}
353
354impl From<(SocketAddr, SocketAddr)> for Addresses {
355    fn from(addresses: (SocketAddr, SocketAddr)) -> Self {
356        match addresses {
357            (SocketAddr::V4(source), SocketAddr::V4(destination)) => Addresses::IPv4(IPv4::new(
358                *source.ip(),
359                *destination.ip(),
360                source.port(),
361                destination.port(),
362            )),
363            (SocketAddr::V6(source), SocketAddr::V6(destination)) => Addresses::IPv6(IPv6::new(
364                *source.ip(),
365                *destination.ip(),
366                source.port(),
367                destination.port(),
368            )),
369            _ => Addresses::Unspecified,
370        }
371    }
372}
373
374impl From<IPv4> for Addresses {
375    fn from(addresses: IPv4) -> Self {
376        Addresses::IPv4(addresses)
377    }
378}
379
380impl From<IPv6> for Addresses {
381    fn from(addresses: IPv6) -> Self {
382        Addresses::IPv6(addresses)
383    }
384}
385
386impl From<Unix> for Addresses {
387    fn from(addresses: Unix) -> Self {
388        Addresses::Unix(addresses)
389    }
390}
391
392impl Addresses {
393    /// The `AddressFamily` for this `Addresses`.
394    pub fn address_family(&self) -> AddressFamily {
395        match self {
396            Addresses::Unspecified => AddressFamily::Unspecified,
397            Addresses::IPv4(..) => AddressFamily::IPv4,
398            Addresses::IPv6(..) => AddressFamily::IPv6,
399            Addresses::Unix(..) => AddressFamily::Unix,
400        }
401    }
402
403    /// The length in bytes of the `Addresses` in the `Header`'s payload.
404    pub fn len(&self) -> usize {
405        self.address_family().byte_length().unwrap_or_default()
406    }
407
408    /// Tests whether the `Addresses` consume any space in the `Header`'s payload.
409    /// `AddressFamily::Unspecified` does not require any bytes, and always returns true.
410    pub fn is_empty(&self) -> bool {
411        self.address_family().byte_length().is_none()
412    }
413}
414
415impl Unix {
416    /// Creates a new instance of a source and destination address pair for Unix sockets.
417    pub const fn new(source: [u8; 108], destination: [u8; 108]) -> Self {
418        Unix {
419            source,
420            destination,
421        }
422    }
423}
424
425impl BitOr<AddressFamily> for Protocol {
426    type Output = u8;
427
428    fn bitor(self, address_family: AddressFamily) -> Self::Output {
429        (self as u8) | (address_family as u8)
430    }
431}
432
433impl<'a, T: Into<u8>> From<(T, &'a [u8])> for TypeLengthValue<'a> {
434    fn from((kind, value): (T, &'a [u8])) -> Self {
435        TypeLengthValue {
436            kind: kind.into(),
437            value: value.into(),
438        }
439    }
440}
441
442impl<'a> TypeLengthValue<'a> {
443    /// Creates an owned clone of this [`TypeLengthValue`].
444    pub fn to_owned(&self) -> TypeLengthValue<'static> {
445        TypeLengthValue {
446            kind: self.kind,
447            value: Cow::Owned(self.value.to_vec()),
448        }
449    }
450
451    /// Creates a new instance of a `TypeLengthValue`, where the length is determine by the length of the byte slice.
452    /// No check is done to ensure the byte slice's length fits in a `u16`.
453    pub fn new<T: Into<u8>>(kind: T, value: &'a [u8]) -> Self {
454        TypeLengthValue {
455            kind: kind.into(),
456            value: value.into(),
457        }
458    }
459
460    /// The length in bytes of this `TypeLengthValue`'s value.
461    pub fn len(&self) -> usize {
462        self.value.len()
463    }
464
465    /// Tests whether the value of this `TypeLengthValue` is empty.
466    pub fn is_empty(&self) -> bool {
467        self.value.is_empty()
468    }
469}
470
471impl From<Type> for u8 {
472    fn from(kind: Type) -> Self {
473        kind as u8
474    }
475}