1use crate::protocol::v2::{
4 Addresses, Protocol, Type, TypeLengthValue, TypeLengthValues, LENGTH, MINIMUM_LENGTH,
5 MINIMUM_TLV_LENGTH, PROTOCOL_PREFIX,
6};
7use std::io::{self, Write};
8
/// In-memory sink that accumulates serialized header bytes.
///
/// The collected bytes are handed back by [`Writer::finish`].
#[derive(Debug, Default)]
pub struct Writer {
    /// Bytes written so far.
    bytes: Vec<u8>,
}
28
29#[derive(Debug)]
54pub struct Builder {
55 header: Option<Vec<u8>>,
56 version_command: u8,
57 address_family_protocol: u8,
58 addresses: Addresses,
59 length: Option<u16>,
60 additional_capacity: usize,
61}
62
63impl Writer {
64 pub fn finish(self) -> Vec<u8> {
67 self.bytes
68 }
69}
70
71impl From<Vec<u8>> for Writer {
72 fn from(bytes: Vec<u8>) -> Self {
73 Writer { bytes }
74 }
75}
76
77impl Write for Writer {
78 fn write(&mut self, buffer: &[u8]) -> io::Result<usize> {
79 if self.bytes.len() > (u16::MAX as usize) + MINIMUM_LENGTH {
80 Err(io::ErrorKind::WriteZero.into())
81 } else {
82 self.bytes.extend_from_slice(buffer);
83 Ok(buffer.len())
84 }
85 }
86
87 fn flush(&mut self) -> io::Result<()> {
88 Ok(())
89 }
90}
91
92pub trait WriteToHeader {
94 fn write_to(&self, writer: &mut Writer) -> io::Result<usize>;
98
99 fn to_bytes(&self) -> io::Result<Vec<u8>> {
101 let mut writer = Writer::default();
102
103 self.write_to(&mut writer)?;
104
105 Ok(writer.finish())
106 }
107}
108
109impl WriteToHeader for Addresses {
110 fn write_to(&self, writer: &mut Writer) -> io::Result<usize> {
111 match self {
112 Addresses::Unspecified => (),
113 Addresses::IPv4(a) => {
114 writer.write_all(a.source_address.octets().as_slice())?;
115 writer.write_all(a.destination_address.octets().as_slice())?;
116 writer.write_all(a.source_port.to_be_bytes().as_slice())?;
117 writer.write_all(a.destination_port.to_be_bytes().as_slice())?;
118 }
119 Addresses::IPv6(a) => {
120 writer.write_all(a.source_address.octets().as_slice())?;
121 writer.write_all(a.destination_address.octets().as_slice())?;
122 writer.write_all(a.source_port.to_be_bytes().as_slice())?;
123 writer.write_all(a.destination_port.to_be_bytes().as_slice())?;
124 }
125 Addresses::Unix(a) => {
126 writer.write_all(a.source.as_slice())?;
127 writer.write_all(a.destination.as_slice())?;
128 }
129 };
130
131 Ok(self.len())
132 }
133}
134
135impl WriteToHeader for TypeLengthValue<'_> {
136 fn write_to(&self, writer: &mut Writer) -> io::Result<usize> {
137 if self.value.len() > u16::MAX as usize {
138 return Err(io::ErrorKind::WriteZero.into());
139 }
140
141 writer.write_all([self.kind].as_slice())?;
142 writer.write_all((self.value.len() as u16).to_be_bytes().as_slice())?;
143 writer.write_all(self.value.as_ref())?;
144
145 Ok(MINIMUM_TLV_LENGTH + self.value.len())
146 }
147}
148
149impl<T: Copy + Into<u8>> WriteToHeader for (T, &[u8]) {
150 fn write_to(&self, writer: &mut Writer) -> io::Result<usize> {
151 let kind = self.0.into();
152 let value = self.1;
153
154 if value.len() > u16::MAX as usize {
155 return Err(io::ErrorKind::WriteZero.into());
156 }
157
158 writer.write_all([kind].as_slice())?;
159 writer.write_all((value.len() as u16).to_be_bytes().as_slice())?;
160 writer.write_all(value)?;
161
162 Ok(MINIMUM_TLV_LENGTH + value.len())
163 }
164}
165
166impl WriteToHeader for TypeLengthValues<'_> {
167 fn write_to(&self, writer: &mut Writer) -> io::Result<usize> {
168 let bytes = self.as_bytes();
169
170 writer.write_all(bytes)?;
171
172 Ok(bytes.len())
173 }
174}
175
176impl WriteToHeader for [u8] {
177 fn write_to(&self, writer: &mut Writer) -> io::Result<usize> {
178 let slice = self;
179
180 if slice.len() > u16::MAX as usize {
181 return Err(io::ErrorKind::WriteZero.into());
182 }
183
184 writer.write_all(slice)?;
185
186 Ok(slice.len())
187 }
188}
189
190impl<T: ?Sized + WriteToHeader> WriteToHeader for &T {
191 fn write_to(&self, writer: &mut Writer) -> io::Result<usize> {
192 (*self).write_to(writer)
193 }
194}
195
196impl WriteToHeader for Type {
197 fn write_to(&self, writer: &mut Writer) -> io::Result<usize> {
198 writer.write([(*self).into()].as_slice())
199 }
200}
201
202macro_rules! impl_write_to_header {
203 ($t:ident) => {
204 impl WriteToHeader for $t {
205 fn write_to(&self, writer: &mut Writer) -> io::Result<usize> {
206 let bytes = self.to_be_bytes();
207
208 writer.write_all(bytes.as_slice())?;
209
210 Ok(bytes.len())
211 }
212 }
213 };
214}
215
216impl_write_to_header!(u8);
217impl_write_to_header!(u16);
218impl_write_to_header!(u32);
219impl_write_to_header!(u64);
220impl_write_to_header!(u128);
221impl_write_to_header!(usize);
222
223impl_write_to_header!(i8);
224impl_write_to_header!(i16);
225impl_write_to_header!(i32);
226impl_write_to_header!(i64);
227impl_write_to_header!(i128);
228impl_write_to_header!(isize);
229
230impl Builder {
231 pub const fn new(version_command: u8, address_family_protocol: u8) -> Self {
235 Builder {
236 header: None,
237 version_command,
238 address_family_protocol,
239 addresses: Addresses::Unspecified,
240 length: None,
241 additional_capacity: 0,
242 }
243 }
244
245 pub fn with_addresses<T: Into<Addresses>>(
249 version_command: u8,
250 protocol: Protocol,
251 addresses: T,
252 ) -> Self {
253 let addresses = addresses.into();
254
255 Builder {
256 header: None,
257 version_command,
258 address_family_protocol: addresses.address_family() | protocol,
259 addresses,
260 length: None,
261 additional_capacity: 0,
262 }
263 }
264
265 pub fn reserve_capacity(mut self, capacity: usize) -> Self {
269 match self.header {
270 None => self.additional_capacity += capacity,
271 Some(ref mut header) => header.reserve(capacity),
272 }
273
274 self
275 }
276
277 pub fn set_reserve_capacity(&mut self, capacity: usize) -> &mut Self {
281 match self.header {
282 None => self.additional_capacity += capacity,
283 Some(ref mut header) => header.reserve(capacity),
284 }
285
286 self
287 }
288
289 pub fn set_length<T: Into<Option<u16>>>(mut self, length: T) -> Self {
292 self.length = length.into();
293 self
294 }
295
296 pub fn write_payloads<T, I, II>(mut self, payloads: II) -> io::Result<Self>
299 where
300 T: WriteToHeader,
301 I: Iterator<Item = T>,
302 II: IntoIterator<IntoIter = I, Item = T>,
303 {
304 self.write_header()?;
305
306 let mut writer = Writer::from(self.header.take().unwrap_or_default());
307
308 for item in payloads {
309 item.write_to(&mut writer)?;
310 }
311
312 self.header = Some(writer.finish());
313
314 Ok(self)
315 }
316
317 pub fn write_payload<T: WriteToHeader>(mut self, payload: T) -> io::Result<Self> {
320 self.write_header()?;
321 self.write_internal(payload)?;
322
323 Ok(self)
324 }
325
326 pub fn write_tlv(self, kind: impl Into<u8>, value: &[u8]) -> io::Result<Self> {
331 self.write_payload(TypeLengthValue::new(kind, value))
332 }
333
334 fn write_internal<T: WriteToHeader>(&mut self, payload: T) -> io::Result<()> {
336 let mut writer = Writer::from(self.header.take().unwrap_or_default());
337
338 payload.write_to(&mut writer)?;
339
340 self.header = Some(writer.finish());
341
342 Ok(())
343 }
344
345 fn write_header(&mut self) -> io::Result<()> {
348 if self.header.is_some() {
349 return Ok(());
350 }
351
352 let mut header =
353 Vec::with_capacity(MINIMUM_LENGTH + self.addresses.len() + self.additional_capacity);
354
355 let length = self.length.unwrap_or_default();
356
357 header.extend_from_slice(PROTOCOL_PREFIX);
358 header.push(self.version_command);
359 header.push(self.address_family_protocol);
360 header.extend_from_slice(length.to_be_bytes().as_slice());
361
362 let mut writer = Writer::from(header);
363
364 self.addresses.write_to(&mut writer)?;
365 self.header = Some(writer.finish());
366
367 Ok(())
368 }
369
370 pub fn build(mut self) -> io::Result<Vec<u8>> {
373 self.write_header()?;
374
375 let mut header = self.header.take().unwrap_or_default();
376
377 if self.length.is_some() {
378 return Ok(header);
379 }
380
381 if let Ok(payload_length) = u16::try_from(header[MINIMUM_LENGTH..].len()) {
382 let length = payload_length.to_be_bytes();
383 header[LENGTH..LENGTH + length.len()].copy_from_slice(length.as_slice());
384 Ok(header)
385 } else {
386 Err(io::ErrorKind::WriteZero.into())
387 }
388 }
389}
390
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::v2::{AddressFamily, Command, IPv4, IPv6, Protocol, Type, Unix, Version};

    const IPV6_SOURCE: [u8; 16] = [
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
        0xF2,
    ];
    const IPV6_DESTINATION: [u8; 16] = [
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
        0xF1,
    ];
    const UNIX_SOURCE: [u8; 108] = [0xFF; 108];
    const UNIX_DESTINATION: [u8; 108] = [0xAA; 108];

    /// Builds the expected bytes: the protocol prefix followed by each chunk in order.
    fn expected_header(chunks: &[&[u8]]) -> Vec<u8> {
        let mut bytes = Vec::from(PROTOCOL_PREFIX);

        for chunk in chunks {
            bytes.extend_from_slice(chunk);
        }

        bytes
    }

    /// IPv4 addresses shared by the IPv4 tests.
    fn ipv4_addresses() -> Addresses {
        IPv4::new([127, 0, 0, 1], [192, 168, 1, 1], 80, 443).into()
    }

    /// Unix addresses shared by the Unix tests.
    fn unix_addresses() -> Addresses {
        Unix::new(UNIX_SOURCE, UNIX_DESTINATION).into()
    }

    #[test]
    fn build_length_too_small() {
        let header = Builder::new(
            Version::Two | Command::Proxy,
            AddressFamily::IPv4 | Protocol::Datagram,
        )
        .set_length(1)
        .write_payload(1u32)
        .and_then(Builder::build)
        .unwrap();

        assert_eq!(
            header,
            expected_header(&[&[0x21, 0x12, 0, 1, 0, 0, 0, 1]])
        );
    }

    #[test]
    fn build_payload_too_long() {
        let oversized = vec![0u8; (u16::MAX as usize) + 1];

        let error = Builder::new(
            Version::Two | Command::Proxy,
            AddressFamily::IPv4 | Protocol::Datagram,
        )
        .write_payload(oversized.as_slice())
        .unwrap_err();

        assert_eq!(error.kind(), io::ErrorKind::WriteZero);
    }

    #[test]
    fn build_no_payload() {
        let header = Builder::new(
            Version::Two | Command::Proxy,
            AddressFamily::Unspecified | Protocol::Stream,
        )
        .build()
        .unwrap();

        assert_eq!(header, expected_header(&[&[0x21, 0x01, 0, 0]]));
    }

    #[test]
    fn build_arbitrary_payload() {
        let header = Builder::new(
            Version::Two | Command::Proxy,
            AddressFamily::Unspecified | Protocol::Stream,
        )
        .write_payload(42u8)
        .and_then(Builder::build)
        .unwrap();

        assert_eq!(header, expected_header(&[&[0x21, 0x01, 0, 1, 42]]));
    }

    #[test]
    fn build_ipv4() {
        let addresses = ipv4_addresses();

        let header = Builder::new(
            Version::Two | Command::Proxy,
            AddressFamily::IPv4 | Protocol::Datagram,
        )
        .set_length(addresses.len() as u16)
        .write_payload(addresses)
        .and_then(Builder::build)
        .unwrap();

        assert_eq!(
            header,
            expected_header(&[
                &[0x21, 0x12, 0, 12],
                &[127, 0, 0, 1, 192, 168, 1, 1, 0, 80, 1, 187],
            ])
        );
    }

    #[test]
    fn build_ipv6() {
        let header = Builder::with_addresses(
            Version::Two | Command::Local,
            Protocol::Unspecified,
            IPv6::new(IPV6_SOURCE, IPV6_DESTINATION, 80, 443),
        )
        .build()
        .unwrap();

        assert_eq!(
            header,
            expected_header(&[
                &[0x20, 0x20, 0, 36],
                &IPV6_SOURCE,
                &IPV6_DESTINATION,
                &[0, 80, 1, 187],
            ])
        );
    }

    #[test]
    fn build_unix() {
        let addresses = unix_addresses();

        let header = Builder::new(
            Version::Two | Command::Local,
            AddressFamily::Unix | Protocol::Stream,
        )
        .reserve_capacity(addresses.len())
        .write_payload(addresses)
        .and_then(Builder::build)
        .unwrap();

        assert_eq!(
            header,
            expected_header(&[&[0x20, 0x31, 0, 216], &UNIX_SOURCE, &UNIX_DESTINATION])
        );
    }

    #[test]
    fn build_ipv4_with_tlv() {
        let header = Builder::with_addresses(
            Version::Two | Command::Proxy,
            Protocol::Datagram,
            ipv4_addresses(),
        )
        .reserve_capacity(5)
        .write_tlv(Type::NoOp, &[0, 42])
        .and_then(Builder::build)
        .unwrap();

        assert_eq!(
            header,
            expected_header(&[
                &[0x21, 0x12, 0, 17],
                &[127, 0, 0, 1, 192, 168, 1, 1, 0, 80, 1, 187],
                &[4, 0, 2, 0, 42],
            ])
        );
    }

    #[test]
    fn build_ipv4_with_nested_tlv() {
        let value = [0u8; 5];

        let header = Builder::new(
            Version::Two | Command::Proxy,
            AddressFamily::IPv4 | Protocol::Datagram,
        )
        .write_payload(ipv4_addresses())
        .and_then(|builder| builder.write_payload(Type::SSL))
        .and_then(|builder| builder.write_payload(5u16))
        .and_then(|builder| builder.write_payload(value.as_slice()))
        .and_then(Builder::build)
        .unwrap();

        assert_eq!(
            header,
            expected_header(&[
                &[0x21, 0x12, 0, 20],
                &[127, 0, 0, 1, 192, 168, 1, 1, 0, 80, 1, 187],
                &[0x20, 0, 5, 0, 0, 0, 0, 0],
            ])
        );
    }

    #[test]
    fn build_ipv6_with_tlvs() {
        let addresses: Addresses = IPv6::new(IPV6_SOURCE, IPV6_DESTINATION, 80, 443).into();
        let values = [0u8, 0, 42];

        let header = Builder::new(
            Version::Two | Command::Local,
            AddressFamily::IPv6 | Protocol::Unspecified,
        )
        .write_payload(addresses)
        .unwrap()
        .write_payloads(
            values
                .iter()
                .map(|value| (Type::NoOp, std::slice::from_ref(value))),
        )
        .and_then(Builder::build)
        .unwrap();

        assert_eq!(
            header,
            expected_header(&[
                &[0x20, 0x20, 0, 48],
                &IPV6_SOURCE,
                &IPV6_DESTINATION,
                &[0, 80, 1, 187],
                &[4, 0, 1, 0, 4, 0, 1, 0, 4, 0, 1, 42],
            ])
        );
    }

    #[test]
    fn build_unix_with_tlv() {
        let header = Builder::new(
            Version::Two | Command::Local,
            AddressFamily::Unix | Protocol::Stream,
        )
        .set_length(216)
        .write_payload(unix_addresses())
        .and_then(|builder| builder.write_tlv(Type::SSL, &[]))
        .and_then(Builder::build)
        .unwrap();

        assert_eq!(
            header,
            expected_header(&[
                &[0x20, 0x31, 0, 216],
                &UNIX_SOURCE,
                &UNIX_DESTINATION,
                &[0x20, 0, 0],
            ])
        );
    }
}