1use std::{fmt, io, net::SocketAddr};
2
3use arrayvec::ArrayVec;
4use nom::{IResult, Parser as _};
5use tokio::io::{AsyncWrite, AsyncWriteExt as _};
6
7use crate::AddressFamily;
8
/// Token every PROXY protocol v1 header line starts with.
pub const SIGNATURE: &str = "PROXY";
/// Upper bound, in bytes, on a rendered v1 header line including the trailing `\r\n`.
///
/// 107 is the v1 worst case: `PROXY UNKNOWN` followed by two full-length IPv6 addresses and
/// two 5-digit ports. It is used as the stack-buffer capacity when writing a header.
pub const MAX_HEADER_SIZE: usize = 107;
11
12#[derive(Debug, Clone)]
13pub struct Header {
14 af: AddressFamily,
16
17 src: SocketAddr,
19
20 dst: SocketAddr,
22}
23
24impl Header {
25 pub const fn new(af: AddressFamily, src: SocketAddr, dst: SocketAddr) -> Self {
26 Self { af, src, dst }
27 }
28
29 pub const fn new_inet(src: SocketAddr, dst: SocketAddr) -> Self {
30 Self::new(AddressFamily::Inet, src, dst)
31 }
32
33 pub const fn new_inet6(src: SocketAddr, dst: SocketAddr) -> Self {
34 Self::new(AddressFamily::Inet6, src, dst)
35 }
36
37 pub fn write_to(&self, wrt: &mut impl io::Write) -> io::Result<()> {
38 write!(wrt, "{self}")
39 }
40
41 pub async fn write_to_tokio(&self, wrt: &mut (impl AsyncWrite + Unpin)) -> io::Result<()> {
42 let mut buf = ArrayVec::<_, MAX_HEADER_SIZE>::new();
44 self.write_to(&mut buf)?;
45 wrt.write_all(&buf).await
46 }
47
48 pub fn try_from_bytes(slice: &[u8]) -> IResult<&[u8], Self> {
49 parsing::parse(slice)
50 }
51}
52
53impl fmt::Display for Header {
54 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
55 write!(
56 f,
57 "{proto_sig} {af} {src_ip} {dst_ip} {src_port} {dst_port}\r\n",
58 proto_sig = SIGNATURE,
59 af = self.af.v1_str(),
60 src_ip = self.src.ip(),
61 dst_ip = self.dst.ip(),
62 src_port = itoa::Buffer::new().format(self.src.port()),
63 dst_port = itoa::Buffer::new().format(self.dst.port()),
64 )
65 }
66}
67
68mod parsing {
69 use std::{
70 net::{Ipv4Addr, SocketAddrV4},
71 str::{self, FromStr},
72 };
73
74 use nom::{
75 IResult,
76 branch::alt,
77 bytes::complete::{tag, take_while},
78 character::complete::char,
79 combinator::{map, map_res},
80 };
81
82 use super::*;
83
84 fn parse_number<T: FromStr>(input: &[u8]) -> IResult<&[u8], T> {
86 map_res(take_while(|c: u8| c.is_ascii_digit()), |s: &[u8]| {
87 let s = str::from_utf8(s).map_err(|_| "utf8 error")?;
88 let val = s.parse::<T>().map_err(|_| "u8 parse error")?;
89 Ok::<_, Box<dyn std::error::Error>>(val)
90 })
91 .parse(input)
92 }
93
94 fn parse_address_family(input: &[u8]) -> IResult<&[u8], AddressFamily> {
96 map_res(alt((tag("TCP4"), tag("TCP6"))), |af: &[u8]| match af {
97 b"TCP4" => Ok(AddressFamily::Inet),
98 b"TCP6" => Ok(AddressFamily::Inet6),
99 _ => Err(io::Error::new(
100 io::ErrorKind::InvalidData,
101 "invalid address family",
102 )),
103 })
104 .parse(input)
105 }
106
107 fn parse_ipv4(input: &[u8]) -> IResult<&[u8], Ipv4Addr> {
109 map(
110 (
111 parse_number::<u8>,
112 char('.'),
113 parse_number::<u8>,
114 char('.'),
115 parse_number::<u8>,
116 char('.'),
117 parse_number::<u8>,
118 ),
119 |(a, _, b, _, c, _, d)| Ipv4Addr::new(a, b, c, d),
120 )
121 .parse(input)
122 }
123
124 pub(super) fn parse(input: &[u8]) -> IResult<&[u8], Header> {
126 map(
127 (
128 tag(SIGNATURE),
129 char(' '),
130 parse_address_family,
131 char(' '),
132 parse_ipv4,
133 char(' '),
134 parse_ipv4,
135 char(' '),
136 parse_number::<u16>,
137 char(' '),
138 parse_number::<u16>,
139 ),
140 |(_, _, af, _, src_ip, _, dst_ip, _, src_port, _, dst_port)| Header {
141 af,
142 src: SocketAddr::V4(SocketAddrV4::new(src_ip, src_port)),
143 dst: SocketAddr::V4(SocketAddrV4::new(dst_ip, dst_port)),
144 },
145 )
146 .parse(input)
147 }
148}