actix_proxy_protocol/
v2.rs

1use std::{
2    io,
3    net::{IpAddr, SocketAddr},
4};
5
6use smallvec::{SmallVec, ToSmallVec as _};
7use tokio::io::{AsyncWrite, AsyncWriteExt as _};
8
9use crate::{
10    AddressFamily, Command, TransportProtocol, Version,
11    tlv::{Crc32c, Tlv},
12};
13
/// The fixed 12-byte signature that opens every PROXY protocol v2 header.
///
/// Spelled as a byte string it reads `\r\n\r\n\0\r\nQUIT\n`.
pub const SIGNATURE: [u8; 12] = *b"\r\n\r\n\0\r\nQUIT\n";
17
18#[derive(Debug, Clone)]
19pub struct Header {
20    command: Command,
21    transport_protocol: TransportProtocol,
22    address_family: AddressFamily,
23    src: SocketAddr,
24    dst: SocketAddr,
25    tlvs: SmallVec<[(u8, SmallVec<[u8; 16]>); 4]>,
26}
27
28impl Header {
29    pub fn new(
30        command: Command,
31        transport_protocol: TransportProtocol,
32        address_family: AddressFamily,
33        src: impl Into<SocketAddr>,
34        dst: impl Into<SocketAddr>,
35    ) -> Self {
36        Self {
37            command,
38            transport_protocol,
39            address_family,
40            src: src.into(),
41            dst: dst.into(),
42            tlvs: SmallVec::new(),
43        }
44    }
45
46    pub fn new_tcp_ipv4_proxy(src: impl Into<SocketAddr>, dst: impl Into<SocketAddr>) -> Self {
47        Self::new(
48            Command::Proxy,
49            TransportProtocol::Stream,
50            AddressFamily::Inet,
51            src,
52            dst,
53        )
54    }
55
56    pub fn add_tlv(&mut self, typ: u8, value: impl AsRef<[u8]>) {
57        self.tlvs.push((typ, SmallVec::from_slice(value.as_ref())));
58    }
59
60    pub fn add_typed_tlv<T: Tlv>(&mut self, tlv: T) {
61        self.add_tlv(T::TYPE, tlv.value_bytes());
62    }
63
64    fn v2_len(&self) -> u16 {
65        let addr_len = if self.src.is_ipv4() {
66            4 + 2 // 4b IPv4 + 2b port number
67        } else {
68            16 + 2 // 16b IPv6 + 2b port number
69        };
70
71        (addr_len * 2)
72            + self
73                .tlvs
74                .iter()
75                .map(|(_, value)| 1 + 2 + value.len() as u16)
76                .sum::<u16>()
77    }
78
79    pub fn write_to(&self, wrt: &mut impl io::Write) -> io::Result<()> {
80        // PROXY v2 signature
81        wrt.write_all(&SIGNATURE)?;
82
83        // version | command
84        wrt.write_all(&[Version::V2.v2_hi() | self.command.v2_lo()])?;
85
86        // address family | transport protocol
87        wrt.write_all(&[self.address_family.v2_hi() | self.transport_protocol.v2_lo()])?;
88
89        // rest-of-header length
90        wrt.write_all(&self.v2_len().to_be_bytes())?;
91
92        tracing::debug!("proxy rest-of-header len: {}", self.v2_len());
93
94        fn write_ip_bytes_to(wrt: &mut impl io::Write, ip: IpAddr) -> io::Result<()> {
95            match ip {
96                IpAddr::V4(ip) => wrt.write_all(&ip.octets()),
97                IpAddr::V6(ip) => wrt.write_all(&ip.octets()),
98            }
99        }
100
101        // L3 (IP) address
102        write_ip_bytes_to(wrt, self.src.ip())?;
103        write_ip_bytes_to(wrt, self.dst.ip())?;
104
105        // L4 ports
106        wrt.write_all(&self.src.port().to_be_bytes())?;
107        wrt.write_all(&self.dst.port().to_be_bytes())?;
108
109        // TLVs
110        for (typ, value) in &self.tlvs {
111            wrt.write_all(&[*typ])?;
112            wrt.write_all(&(value.len() as u16).to_be_bytes())?;
113            wrt.write_all(value)?;
114        }
115
116        Ok(())
117    }
118
119    pub async fn write_to_tokio(&self, wrt: &mut (impl AsyncWrite + Unpin)) -> io::Result<()> {
120        let buf = self.to_vec();
121        wrt.write_all(&buf).await
122    }
123
124    fn to_vec(&self) -> Vec<u8> {
125        // TODO: figure out cap
126        let mut buf = Vec::with_capacity(64);
127        self.write_to(&mut buf).unwrap();
128        buf
129    }
130
131    pub fn has_tlv<T: Tlv>(&self) -> bool {
132        self.tlvs.iter().any(|&(typ, _)| typ == T::TYPE)
133    }
134
135    /// Calculates and adds a crc32c TLV to the PROXY header.
136    ///
137    /// Uses method defined in spec.
138    ///
139    /// If this is not called last thing it will be wrong.
140    pub fn add_crc23c_checksum(&mut self) {
141        // don't add a checksum if it is already set
142        if self.has_tlv::<Crc32c>() {
143            return;
144        }
145
146        // When the checksum is supported by the sender after constructing the header
147        // the sender MUST:
148        // - initialize the checksum field to '0's.
149        // - calculate the CRC32c checksum of the PROXY header as described in RFC4960,
150        //   Appendix B [8].
151        // - put the resultant value into the checksum field, and leave the rest of
152        //   the bits unchanged.
153
154        // add zeroed checksum field to TLVs
155        self.add_typed_tlv(Crc32c::default());
156
157        // write PROXY header to buffer
158        let mut buf = Vec::new();
159        self.write_to(&mut buf).unwrap();
160
161        // calculate CRC on buffer and update CRC TLV
162        let crc_calc = crc32fast::hash(&buf);
163        self.tlvs.last_mut().unwrap().1 = crc_calc.to_be_bytes().to_smallvec();
164
165        tracing::debug!("checksum is {}", crc_calc);
166    }
167
168    pub fn validate_crc32c_tlv(&self) -> Option<bool> {
169        // extract crc32c TLV or exit early if none is present
170        let crc_sent = self
171            .tlvs
172            .iter()
173            .filter_map(|(typ, value)| Crc32c::try_from_parts(*typ, value))
174            .next()?;
175
176        // If the checksum is provided as part of the PROXY header and the checksum
177        // functionality is supported by the receiver, the receiver MUST:
178        //  - store the received CRC32c checksum value aside.
179        //  - replace the 32 bits of the checksum field in the received PROXY header with
180        //    all '0's and calculate a CRC32c checksum value of the whole PROXY header.
181        //  - verify that the calculated CRC32c checksum is the same as the received
182        //    CRC32c checksum. If it is not, the receiver MUST treat the TCP connection
183        //    providing the header as invalid.
184        // The default procedure for handling an invalid TCP connection is to abort it.
185
186        let mut this = self.clone();
187        for (typ, value) in this.tlvs.iter_mut() {
188            if Crc32c::try_from_parts(*typ, value).is_some() {
189                value.fill(0);
190            }
191        }
192
193        let mut buf = Vec::new();
194        this.write_to(&mut buf).unwrap();
195        let crc_calc = crc32fast::hash(&buf);
196
197        Some(crc_sent.checksum == crc_calc)
198    }
199}
200
201#[cfg(test)]
202mod tests {
203    use std::net::Ipv6Addr;
204
205    use const_str::hex;
206    use pretty_assertions::assert_eq;
207
208    use super::*;
209
210    #[test]
211    fn write_v2_no_tlvs() {
212        let mut exp = Vec::new();
213        exp.extend_from_slice(&SIGNATURE); // 0-11
214        exp.extend_from_slice(&[0x21, 0x11]); // 12-13
215        exp.extend_from_slice(&[0x00, 0x0C]); // 14-15
216        exp.extend_from_slice(&[127, 0, 0, 1, 127, 0, 0, 2]); // 16-23
217        exp.extend_from_slice(&[0x04, 0xd2, 0x00, 80]); // 24-27
218
219        let header = Header::new(
220            Command::Proxy,
221            TransportProtocol::Stream,
222            AddressFamily::Inet,
223            SocketAddr::from(([127, 0, 0, 1], 1234)),
224            SocketAddr::from(([127, 0, 0, 2], 80)),
225        );
226
227        assert_eq!(header.v2_len(), 12);
228        assert_eq!(header.to_vec(), exp);
229    }
230
231    #[test]
232    fn write_v2_ipv6_tlv_noop() {
233        let mut exp = Vec::new();
234        exp.extend_from_slice(&SIGNATURE); // 0-11
235        exp.extend_from_slice(&[0x20, 0x11]); // 12-13
236        exp.extend_from_slice(&[0x00, 0x28]); // 14-15
237        exp.extend_from_slice(&hex!("00000000000000000000000000000001")); // 16-31
238        exp.extend_from_slice(&hex!("000102030405060708090A0B0C0D0E0F")); // 32-45
239        exp.extend_from_slice(&[0x00, 80, 0xff, 0xff]); // 45-49
240        exp.extend_from_slice(&[0x04, 0x00, 0x01, 0x00]); // 50-53 NOOP TLV
241
242        let mut header = Header::new(
243            Command::Local,
244            TransportProtocol::Stream,
245            AddressFamily::Inet,
246            SocketAddr::from((Ipv6Addr::LOCALHOST, 80)),
247            SocketAddr::from((
248                Ipv6Addr::from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]),
249                65535,
250            )),
251        );
252
253        header.add_tlv(0x04, [0]);
254
255        assert_eq!(header.v2_len(), 36 + 4);
256        assert_eq!(header.to_vec(), exp);
257    }
258
259    #[test]
260    fn write_v2_tlv_c2c() {
261        let mut exp = Vec::new();
262        exp.extend_from_slice(&SIGNATURE); // 0-11
263        exp.extend_from_slice(&[0x21, 0x11]); // 12-13
264        exp.extend_from_slice(&[0x00, 0x13]); // 14-15
265        exp.extend_from_slice(&[127, 0, 0, 1, 127, 0, 0, 1]); // 16-23
266        exp.extend_from_slice(&[0x00, 80, 0x00, 80]); // 24-27
267        exp.extend_from_slice(&[0x03, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00]); // 28-35 TLV crc32c
268
269        assert_eq!(
270            crc32fast::hash(&exp),
271            // correct checksum calculated manually
272            u32::from_be_bytes([0x08, 0x70, 0x17, 0x7b]),
273        );
274
275        // re-assign actual checksum to last 4 bytes of expected byte array
276        exp[31..35].copy_from_slice(&[0x08, 0x70, 0x17, 0x7b]);
277
278        let mut header = Header::new(
279            Command::Proxy,
280            TransportProtocol::Stream,
281            AddressFamily::Inet,
282            SocketAddr::from(([127, 0, 0, 1], 80)),
283            SocketAddr::from(([127, 0, 0, 1], 80)),
284        );
285
286        assert!(
287            header.validate_crc32c_tlv().is_none(),
288            "header doesn't have CRC TLV added yet"
289        );
290
291        // add crc32c TLV to header
292        header.add_crc23c_checksum();
293
294        assert_eq!(header.v2_len(), 12 + 7);
295        assert_eq!(header.to_vec(), exp);
296
297        // struct can self-validate checksum
298        assert_eq!(header.validate_crc32c_tlv().unwrap(), true);
299
300        // mangle crc32c TLV and assert that validate now fails
301        *header.tlvs.last_mut().unwrap().1.last_mut().unwrap() = 0x00;
302        assert_eq!(header.validate_crc32c_tlv().unwrap(), false);
303    }
304}