1use std::{
2 io,
3 net::{IpAddr, SocketAddr},
4};
5
6use smallvec::{SmallVec, ToSmallVec as _};
7use tokio::io::{AsyncWrite, AsyncWriteExt as _};
8
9use crate::{
10 AddressFamily, Command, TransportProtocol, Version,
11 tlv::{Crc32c, Tlv},
12};
13
/// The fixed 12-byte signature that begins every PROXY protocol v2 header
/// (`\r\n\r\n\0\r\nQUIT\n`).
pub const SIGNATURE: [u8; 12] = *b"\r\n\r\n\0\r\nQUIT\n";
17
18#[derive(Debug, Clone)]
19pub struct Header {
20 command: Command,
21 transport_protocol: TransportProtocol,
22 address_family: AddressFamily,
23 src: SocketAddr,
24 dst: SocketAddr,
25 tlvs: SmallVec<[(u8, SmallVec<[u8; 16]>); 4]>,
26}
27
28impl Header {
29 pub fn new(
30 command: Command,
31 transport_protocol: TransportProtocol,
32 address_family: AddressFamily,
33 src: impl Into<SocketAddr>,
34 dst: impl Into<SocketAddr>,
35 ) -> Self {
36 Self {
37 command,
38 transport_protocol,
39 address_family,
40 src: src.into(),
41 dst: dst.into(),
42 tlvs: SmallVec::new(),
43 }
44 }
45
46 pub fn new_tcp_ipv4_proxy(src: impl Into<SocketAddr>, dst: impl Into<SocketAddr>) -> Self {
47 Self::new(
48 Command::Proxy,
49 TransportProtocol::Stream,
50 AddressFamily::Inet,
51 src,
52 dst,
53 )
54 }
55
56 pub fn add_tlv(&mut self, typ: u8, value: impl AsRef<[u8]>) {
57 self.tlvs.push((typ, SmallVec::from_slice(value.as_ref())));
58 }
59
60 pub fn add_typed_tlv<T: Tlv>(&mut self, tlv: T) {
61 self.add_tlv(T::TYPE, tlv.value_bytes());
62 }
63
64 fn v2_len(&self) -> u16 {
65 let addr_len = if self.src.is_ipv4() {
66 4 + 2 } else {
68 16 + 2 };
70
71 (addr_len * 2)
72 + self
73 .tlvs
74 .iter()
75 .map(|(_, value)| 1 + 2 + value.len() as u16)
76 .sum::<u16>()
77 }
78
79 pub fn write_to(&self, wrt: &mut impl io::Write) -> io::Result<()> {
80 wrt.write_all(&SIGNATURE)?;
82
83 wrt.write_all(&[Version::V2.v2_hi() | self.command.v2_lo()])?;
85
86 wrt.write_all(&[self.address_family.v2_hi() | self.transport_protocol.v2_lo()])?;
88
89 wrt.write_all(&self.v2_len().to_be_bytes())?;
91
92 tracing::debug!("proxy rest-of-header len: {}", self.v2_len());
93
94 fn write_ip_bytes_to(wrt: &mut impl io::Write, ip: IpAddr) -> io::Result<()> {
95 match ip {
96 IpAddr::V4(ip) => wrt.write_all(&ip.octets()),
97 IpAddr::V6(ip) => wrt.write_all(&ip.octets()),
98 }
99 }
100
101 write_ip_bytes_to(wrt, self.src.ip())?;
103 write_ip_bytes_to(wrt, self.dst.ip())?;
104
105 wrt.write_all(&self.src.port().to_be_bytes())?;
107 wrt.write_all(&self.dst.port().to_be_bytes())?;
108
109 for (typ, value) in &self.tlvs {
111 wrt.write_all(&[*typ])?;
112 wrt.write_all(&(value.len() as u16).to_be_bytes())?;
113 wrt.write_all(value)?;
114 }
115
116 Ok(())
117 }
118
119 pub async fn write_to_tokio(&self, wrt: &mut (impl AsyncWrite + Unpin)) -> io::Result<()> {
120 let buf = self.to_vec();
121 wrt.write_all(&buf).await
122 }
123
124 fn to_vec(&self) -> Vec<u8> {
125 let mut buf = Vec::with_capacity(64);
127 self.write_to(&mut buf).unwrap();
128 buf
129 }
130
131 pub fn has_tlv<T: Tlv>(&self) -> bool {
132 self.tlvs.iter().any(|&(typ, _)| typ == T::TYPE)
133 }
134
135 pub fn add_crc23c_checksum(&mut self) {
141 if self.has_tlv::<Crc32c>() {
143 return;
144 }
145
146 self.add_typed_tlv(Crc32c::default());
156
157 let mut buf = Vec::new();
159 self.write_to(&mut buf).unwrap();
160
161 let crc_calc = crc32fast::hash(&buf);
163 self.tlvs.last_mut().unwrap().1 = crc_calc.to_be_bytes().to_smallvec();
164
165 tracing::debug!("checksum is {}", crc_calc);
166 }
167
168 pub fn validate_crc32c_tlv(&self) -> Option<bool> {
169 let crc_sent = self
171 .tlvs
172 .iter()
173 .filter_map(|(typ, value)| Crc32c::try_from_parts(*typ, value))
174 .next()?;
175
176 let mut this = self.clone();
187 for (typ, value) in this.tlvs.iter_mut() {
188 if Crc32c::try_from_parts(*typ, value).is_some() {
189 value.fill(0);
190 }
191 }
192
193 let mut buf = Vec::new();
194 this.write_to(&mut buf).unwrap();
195 let crc_calc = crc32fast::hash(&buf);
196
197 Some(crc_sent.checksum == crc_calc)
198 }
199}
200
201#[cfg(test)]
202mod tests {
203 use std::net::Ipv6Addr;
204
205 use const_str::hex;
206 use pretty_assertions::assert_eq;
207
208 use super::*;
209
210 #[test]
211 fn write_v2_no_tlvs() {
212 let mut exp = Vec::new();
213 exp.extend_from_slice(&SIGNATURE); exp.extend_from_slice(&[0x21, 0x11]); exp.extend_from_slice(&[0x00, 0x0C]); exp.extend_from_slice(&[127, 0, 0, 1, 127, 0, 0, 2]); exp.extend_from_slice(&[0x04, 0xd2, 0x00, 80]); let header = Header::new(
220 Command::Proxy,
221 TransportProtocol::Stream,
222 AddressFamily::Inet,
223 SocketAddr::from(([127, 0, 0, 1], 1234)),
224 SocketAddr::from(([127, 0, 0, 2], 80)),
225 );
226
227 assert_eq!(header.v2_len(), 12);
228 assert_eq!(header.to_vec(), exp);
229 }
230
231 #[test]
232 fn write_v2_ipv6_tlv_noop() {
233 let mut exp = Vec::new();
234 exp.extend_from_slice(&SIGNATURE); exp.extend_from_slice(&[0x20, 0x11]); exp.extend_from_slice(&[0x00, 0x28]); exp.extend_from_slice(&hex!("00000000000000000000000000000001")); exp.extend_from_slice(&hex!("000102030405060708090A0B0C0D0E0F")); exp.extend_from_slice(&[0x00, 80, 0xff, 0xff]); exp.extend_from_slice(&[0x04, 0x00, 0x01, 0x00]); let mut header = Header::new(
243 Command::Local,
244 TransportProtocol::Stream,
245 AddressFamily::Inet,
246 SocketAddr::from((Ipv6Addr::LOCALHOST, 80)),
247 SocketAddr::from((
248 Ipv6Addr::from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]),
249 65535,
250 )),
251 );
252
253 header.add_tlv(0x04, [0]);
254
255 assert_eq!(header.v2_len(), 36 + 4);
256 assert_eq!(header.to_vec(), exp);
257 }
258
259 #[test]
260 fn write_v2_tlv_c2c() {
261 let mut exp = Vec::new();
262 exp.extend_from_slice(&SIGNATURE); exp.extend_from_slice(&[0x21, 0x11]); exp.extend_from_slice(&[0x00, 0x13]); exp.extend_from_slice(&[127, 0, 0, 1, 127, 0, 0, 1]); exp.extend_from_slice(&[0x00, 80, 0x00, 80]); exp.extend_from_slice(&[0x03, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00]); assert_eq!(
270 crc32fast::hash(&exp),
271 u32::from_be_bytes([0x08, 0x70, 0x17, 0x7b]),
273 );
274
275 exp[31..35].copy_from_slice(&[0x08, 0x70, 0x17, 0x7b]);
277
278 let mut header = Header::new(
279 Command::Proxy,
280 TransportProtocol::Stream,
281 AddressFamily::Inet,
282 SocketAddr::from(([127, 0, 0, 1], 80)),
283 SocketAddr::from(([127, 0, 0, 1], 80)),
284 );
285
286 assert!(
287 header.validate_crc32c_tlv().is_none(),
288 "header doesn't have CRC TLV added yet"
289 );
290
291 header.add_crc23c_checksum();
293
294 assert_eq!(header.v2_len(), 12 + 7);
295 assert_eq!(header.to_vec(), exp);
296
297 assert_eq!(header.validate_crc32c_tlv().unwrap(), true);
299
300 *header.tlvs.last_mut().unwrap().1.last_mut().unwrap() = 0x00;
302 assert_eq!(header.validate_crc32c_tlv().unwrap(), false);
303 }
304}