1use std::net::SocketAddr;
16
17use tokio::io::{self, AsyncWrite, AsyncWriteExt};
18
19use crate::parse::V2_SIGNATURE;
20use crate::types::{
21 AddressFamily, Command, ProxyAddress, SslInfo, Transport, TransportProtocol, Version,
22};
23
/// Builder for PROXY protocol headers, serialized either as the v1 text line
/// or the v2 binary format depending on `version`.
///
/// Obtain one from a constructor (`v2_proxy`, `v2_local`, `v1_proxy`,
/// `v1_unknown`, `v2_unix`), optionally chain `with_*` calls, then serialize
/// with `build` or `write_to`.
#[must_use]
pub struct HeaderBuilder {
    /// Selects which wire format `build` produces.
    version: Version,
    /// Command nibble written into v2 headers; the v1 encoder does not read it.
    command: Command,
    /// Address family and transport protocol; `None` for headers that carry
    /// no addresses (the v2 LOCAL and v1 UNKNOWN constructors).
    transport: Option<Transport>,
    /// Source endpoint encoded in the address block, if any.
    source: Option<ProxyAddress>,
    /// Destination endpoint encoded in the address block, if any.
    destination: Option<ProxyAddress>,
    /// v2 TLVs as `(type, value)` pairs, emitted in insertion order.
    /// Not encoded by the v1 builder.
    tlv_entries: Vec<(u8, Vec<u8>)>,
    /// When set, the v2 encoder appends a CRC32C TLV and fills in its checksum.
    add_crc32c: bool,
}
35
36impl HeaderBuilder {
37 pub fn v2_proxy(source: SocketAddr, destination: SocketAddr) -> Self {
44 assert_eq!(
45 source.is_ipv4(),
46 destination.is_ipv4(),
47 "source and destination must use the same address family"
48 );
49 let family = if source.is_ipv4() {
50 AddressFamily::Inet
51 } else {
52 AddressFamily::Inet6
53 };
54 Self {
55 version: Version::V2,
56 command: Command::Proxy,
57 transport: Some(Transport {
58 family,
59 protocol: TransportProtocol::Stream,
60 }),
61 source: Some(ProxyAddress::Inet(source)),
62 destination: Some(ProxyAddress::Inet(destination)),
63 tlv_entries: Vec::new(),
64 add_crc32c: false,
65 }
66 }
67
68 pub fn v2_local() -> Self {
70 Self {
71 version: Version::V2,
72 command: Command::Local,
73 transport: None,
74 source: None,
75 destination: None,
76 tlv_entries: Vec::new(),
77 add_crc32c: false,
78 }
79 }
80
81 pub fn v1_proxy(source: SocketAddr, destination: SocketAddr) -> Self {
87 assert_eq!(
88 source.is_ipv4(),
89 destination.is_ipv4(),
90 "source and destination must use the same address family"
91 );
92 let family = if source.is_ipv4() {
93 AddressFamily::Inet
94 } else {
95 AddressFamily::Inet6
96 };
97 Self {
98 version: Version::V1,
99 command: Command::Proxy,
100 transport: Some(Transport {
101 family,
102 protocol: TransportProtocol::Stream,
103 }),
104 source: Some(ProxyAddress::Inet(source)),
105 destination: Some(ProxyAddress::Inet(destination)),
106 tlv_entries: Vec::new(),
107 add_crc32c: false,
108 }
109 }
110
111 pub fn v1_unknown() -> Self {
113 Self {
114 version: Version::V1,
115 command: Command::Proxy,
116 transport: None,
117 source: None,
118 destination: None,
119 tlv_entries: Vec::new(),
120 add_crc32c: false,
121 }
122 }
123
124 pub fn v2_unix(
126 source: impl Into<Vec<u8>>,
127 destination: impl Into<Vec<u8>>,
128 protocol: TransportProtocol,
129 ) -> Self {
130 Self {
131 version: Version::V2,
132 command: Command::Proxy,
133 transport: Some(Transport {
134 family: AddressFamily::Unix,
135 protocol,
136 }),
137 source: Some(ProxyAddress::Unix(source.into())),
138 destination: Some(ProxyAddress::Unix(destination.into())),
139 tlv_entries: Vec::new(),
140 add_crc32c: false,
141 }
142 }
143
144 pub fn with_transport_protocol(mut self, protocol: TransportProtocol) -> Self {
146 if let Some(ref mut t) = self.transport {
147 t.protocol = protocol;
148 }
149 self
150 }
151
152 pub fn with_authority(mut self, authority: impl Into<String>) -> Self {
154 let v = authority.into().into_bytes();
155 self.tlv_entries.push((0x02, v));
156 self
157 }
158
159 pub fn with_unique_id(mut self, id: impl Into<Vec<u8>>) -> Self {
165 let id = id.into();
166 assert!(
167 id.len() <= 128,
168 "unique ID length {} exceeds the 128-byte spec maximum",
169 id.len()
170 );
171 self.tlv_entries.push((0x05, id));
172 self
173 }
174
175 pub fn with_alpn(mut self, alpn: impl Into<Vec<u8>>) -> Self {
177 self.tlv_entries.push((0x01, alpn.into()));
178 self
179 }
180
181 pub fn with_ssl(mut self, ssl: SslInfo) -> Self {
183 self.tlv_entries.push((0x20, encode_ssl_tlv_value(&ssl)));
184 self
185 }
186
187 pub fn with_netns(mut self, netns: impl Into<String>) -> Self {
189 self.tlv_entries.push((0x30, netns.into().into_bytes()));
190 self
191 }
192
193 pub fn with_raw_tlv(mut self, type_byte: u8, value: impl Into<Vec<u8>>) -> Self {
195 self.tlv_entries.push((type_byte, value.into()));
196 self
197 }
198
199 pub fn with_padding(mut self, len: u16) -> Self {
201 self.tlv_entries.push((0x04, vec![0u8; len as usize]));
202 self
203 }
204
205 pub fn with_crc32c(mut self) -> Self {
207 self.add_crc32c = true;
208 self
209 }
210
211 #[must_use]
219 pub fn build(&self) -> Vec<u8> {
220 match self.version {
221 Version::V1 => self.build_v1(),
222 Version::V2 => self.build_v2(),
223 }
224 }
225
226 pub async fn write_to<W: AsyncWrite + Unpin>(&self, writer: &mut W) -> io::Result<usize> {
233 let bytes = self.build();
234 writer.write_all(&bytes).await?;
235 Ok(bytes.len())
236 }
237
238 fn build_v1(&self) -> Vec<u8> {
239 match (&self.source, &self.destination, &self.transport) {
240 (Some(ProxyAddress::Inet(src)), Some(ProxyAddress::Inet(dst)), Some(transport)) => {
241 let proto = match transport.family {
242 AddressFamily::Inet => "TCP4",
243 AddressFamily::Inet6 => "TCP6",
244 _ => unreachable!(),
245 };
246 format!(
247 "PROXY {} {} {} {} {}\r\n",
248 proto,
249 src.ip(),
250 dst.ip(),
251 src.port(),
252 dst.port()
253 )
254 .into_bytes()
255 }
256 _ => b"PROXY UNKNOWN\r\n".to_vec(),
257 }
258 }
259
260 fn build_v2(&self) -> Vec<u8> {
261 let mut buf = Vec::with_capacity(256);
262
263 buf.extend_from_slice(V2_SIGNATURE);
265
266 let cmd_nibble = match self.command {
268 Command::Local => 0x00,
269 Command::Proxy => 0x01,
270 };
271 buf.push(0x20 | cmd_nibble);
272
273 let (fam, proto) = match &self.transport {
275 Some(t) => {
276 let f = match t.family {
277 AddressFamily::Inet => 1,
278 AddressFamily::Inet6 => 2,
279 AddressFamily::Unix => 3,
280 };
281 let p = match t.protocol {
282 TransportProtocol::Stream => 1,
283 TransportProtocol::Datagram => 2,
284 };
285 (f, p)
286 }
287 None => (0, 0),
288 };
289 buf.push((fam << 4) | proto);
290
291 let len_pos = buf.len();
293 buf.extend_from_slice(&[0, 0]);
294
295 match self.command {
297 Command::Local => {}
298 Command::Proxy => {
299 self.encode_addresses(&mut buf);
300 }
301 }
302
303 for (tlv_type, value) in &self.tlv_entries {
305 assert!(
306 value.len() <= u16::MAX as usize,
307 "TLV value length {} exceeds maximum of 65535",
308 value.len()
309 );
310 buf.push(*tlv_type);
311 buf.extend_from_slice(&(value.len() as u16).to_be_bytes());
312 buf.extend_from_slice(value);
313 }
314
315 if self.add_crc32c {
317 buf.push(0x03);
319 buf.extend_from_slice(&4u16.to_be_bytes());
320 buf.extend_from_slice(&[0, 0, 0, 0]);
321 }
322
323 let payload_len = buf.len() - 16;
325 assert!(
326 payload_len <= u16::MAX as usize,
327 "v2 payload exceeds maximum size of 65535 bytes ({payload_len} bytes)"
328 );
329 let payload_len = payload_len as u16;
330 buf[len_pos..len_pos + 2].copy_from_slice(&payload_len.to_be_bytes());
331
332 if self.add_crc32c {
334 let crc = crc32c::crc32c(&buf);
335 let crc_pos = buf.len() - 4;
336 buf[crc_pos..crc_pos + 4].copy_from_slice(&crc.to_be_bytes());
337 }
338
339 buf
340 }
341
342 fn encode_addresses(&self, buf: &mut Vec<u8>) {
343 match (&self.source, &self.destination) {
344 (Some(ProxyAddress::Inet(src)), Some(ProxyAddress::Inet(dst))) => {
345 match (src.ip(), dst.ip()) {
346 (std::net::IpAddr::V4(s), std::net::IpAddr::V4(d)) => {
347 buf.extend_from_slice(&s.octets());
348 buf.extend_from_slice(&d.octets());
349 buf.extend_from_slice(&src.port().to_be_bytes());
350 buf.extend_from_slice(&dst.port().to_be_bytes());
351 }
352 (std::net::IpAddr::V6(s), std::net::IpAddr::V6(d)) => {
353 buf.extend_from_slice(&s.octets());
354 buf.extend_from_slice(&d.octets());
355 buf.extend_from_slice(&src.port().to_be_bytes());
356 buf.extend_from_slice(&dst.port().to_be_bytes());
357 }
358 _ => {}
359 }
360 }
361 (Some(ProxyAddress::Unix(src)), Some(ProxyAddress::Unix(dst))) => {
362 let mut src_field = [0u8; 108];
363 let src_len = src.len().min(108);
364 src_field[..src_len].copy_from_slice(&src[..src_len]);
365 buf.extend_from_slice(&src_field);
366
367 let mut dst_field = [0u8; 108];
368 let dst_len = dst.len().min(108);
369 dst_field[..dst_len].copy_from_slice(&dst[..dst_len]);
370 buf.extend_from_slice(&dst_field);
371 }
372 _ => {}
373 }
374 }
375}
376
377fn encode_ssl_tlv_value(ssl: &SslInfo) -> Vec<u8> {
378 let mut buf = Vec::new();
379
380 buf.push(ssl.client_flags.bits());
382
383 let verify: u32 = if ssl.verified { 0 } else { 1 };
385 buf.extend_from_slice(&verify.to_be_bytes());
386
387 if let Some(ref v) = ssl.version {
389 encode_sub_tlv(&mut buf, 0x21, v.as_bytes());
390 }
391 if let Some(ref v) = ssl.cn {
392 encode_sub_tlv(&mut buf, 0x22, v.as_bytes());
393 }
394 if let Some(ref v) = ssl.cipher {
395 encode_sub_tlv(&mut buf, 0x23, v.as_bytes());
396 }
397 if let Some(ref v) = ssl.sig_alg {
398 encode_sub_tlv(&mut buf, 0x24, v.as_bytes());
399 }
400 if let Some(ref v) = ssl.key_alg {
401 encode_sub_tlv(&mut buf, 0x25, v.as_bytes());
402 }
403 if let Some(ref v) = ssl.group {
404 encode_sub_tlv(&mut buf, 0x26, v.as_bytes());
405 }
406 if let Some(ref v) = ssl.sig_scheme {
407 encode_sub_tlv(&mut buf, 0x27, v.as_bytes());
408 }
409 if let Some(ref v) = ssl.client_cert {
410 encode_sub_tlv(&mut buf, 0x28, v);
411 }
412
413 buf
414}
415
/// Appends a single sub-TLV to `buf`: the type byte, the value length as a
/// big-endian `u16`, then the value bytes.
///
/// # Panics
///
/// Panics if `value` is longer than 65535 bytes, which the 16-bit length
/// field cannot express.
fn encode_sub_tlv(buf: &mut Vec<u8>, type_byte: u8, value: &[u8]) {
    let encoded_len = match u16::try_from(value.len()) {
        Ok(len) => len,
        Err(_) => panic!(
            "sub-TLV value length {} exceeds maximum of 65535",
            value.len()
        ),
    };
    buf.push(type_byte);
    buf.extend_from_slice(&encoded_len.to_be_bytes());
    buf.extend_from_slice(value);
}