actix_proxy_protocol/
v1.rs

1use std::{fmt, io, net::SocketAddr};
2
3use arrayvec::ArrayVec;
4use nom::{IResult, Parser as _};
5use tokio::io::{AsyncWrite, AsyncWriteExt as _};
6
7use crate::AddressFamily;
8
/// Maximum length, in bytes, of a serialized v1 header line.
///
/// Used as the capacity of the stack buffer the header is rendered into before writing.
pub const MAX_HEADER_SIZE: usize = 107;

/// Literal token that every v1 header line starts with.
pub const SIGNATURE: &str = "PROXY";
11
12#[derive(Debug, Clone)]
13pub struct Header {
14    /// Address family.
15    af: AddressFamily,
16
17    /// Source address.
18    src: SocketAddr,
19
20    /// Destination address.
21    dst: SocketAddr,
22}
23
24impl Header {
25    pub const fn new(af: AddressFamily, src: SocketAddr, dst: SocketAddr) -> Self {
26        Self { af, src, dst }
27    }
28
29    pub const fn new_inet(src: SocketAddr, dst: SocketAddr) -> Self {
30        Self::new(AddressFamily::Inet, src, dst)
31    }
32
33    pub const fn new_inet6(src: SocketAddr, dst: SocketAddr) -> Self {
34        Self::new(AddressFamily::Inet6, src, dst)
35    }
36
37    pub fn write_to(&self, wrt: &mut impl io::Write) -> io::Result<()> {
38        write!(wrt, "{self}")
39    }
40
41    pub async fn write_to_tokio(&self, wrt: &mut (impl AsyncWrite + Unpin)) -> io::Result<()> {
42        // max length of a V1 header is 107 bytes
43        let mut buf = ArrayVec::<_, MAX_HEADER_SIZE>::new();
44        self.write_to(&mut buf)?;
45        wrt.write_all(&buf).await
46    }
47
48    pub fn try_from_bytes(slice: &[u8]) -> IResult<&[u8], Self> {
49        parsing::parse(slice)
50    }
51}
52
53impl fmt::Display for Header {
54    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
55        write!(
56            f,
57            "{proto_sig} {af} {src_ip} {dst_ip} {src_port} {dst_port}\r\n",
58            proto_sig = SIGNATURE,
59            af = self.af.v1_str(),
60            src_ip = self.src.ip(),
61            dst_ip = self.dst.ip(),
62            src_port = itoa::Buffer::new().format(self.src.port()),
63            dst_port = itoa::Buffer::new().format(self.dst.port()),
64        )
65    }
66}
67
68mod parsing {
69    use std::{
70        net::{Ipv4Addr, SocketAddrV4},
71        str::{self, FromStr},
72    };
73
74    use nom::{
75        IResult,
76        branch::alt,
77        bytes::complete::{tag, take_while},
78        character::complete::char,
79        combinator::{map, map_res},
80    };
81
82    use super::*;
83
84    /// Parses a number from serialized representation (as bytes).
85    fn parse_number<T: FromStr>(input: &[u8]) -> IResult<&[u8], T> {
86        map_res(take_while(|c: u8| c.is_ascii_digit()), |s: &[u8]| {
87            let s = str::from_utf8(s).map_err(|_| "utf8 error")?;
88            let val = s.parse::<T>().map_err(|_| "u8 parse error")?;
89            Ok::<_, Box<dyn std::error::Error>>(val)
90        })
91        .parse(input)
92    }
93
94    /// Parses an address family.
95    fn parse_address_family(input: &[u8]) -> IResult<&[u8], AddressFamily> {
96        map_res(alt((tag("TCP4"), tag("TCP6"))), |af: &[u8]| match af {
97            b"TCP4" => Ok(AddressFamily::Inet),
98            b"TCP6" => Ok(AddressFamily::Inet6),
99            _ => Err(io::Error::new(
100                io::ErrorKind::InvalidData,
101                "invalid address family",
102            )),
103        })
104        .parse(input)
105    }
106
107    /// Parses an IPv4 address from serialized representation (as bytes).
108    fn parse_ipv4(input: &[u8]) -> IResult<&[u8], Ipv4Addr> {
109        map(
110            (
111                parse_number::<u8>,
112                char('.'),
113                parse_number::<u8>,
114                char('.'),
115                parse_number::<u8>,
116                char('.'),
117                parse_number::<u8>,
118            ),
119            |(a, _, b, _, c, _, d)| Ipv4Addr::new(a, b, c, d),
120        )
121        .parse(input)
122    }
123
124    /// Parses an IPv4 address from ASCII bytes.
125    pub(super) fn parse(input: &[u8]) -> IResult<&[u8], Header> {
126        map(
127            (
128                tag(SIGNATURE),
129                char(' '),
130                parse_address_family,
131                char(' '),
132                parse_ipv4,
133                char(' '),
134                parse_ipv4,
135                char(' '),
136                parse_number::<u16>,
137                char(' '),
138                parse_number::<u16>,
139            ),
140            |(_, _, af, _, src_ip, _, dst_ip, _, src_port, _, dst_port)| Header {
141                af,
142                src: SocketAddr::V4(SocketAddrV4::new(src_ip, src_port)),
143                dst: SocketAddr::V4(SocketAddrV4::new(dst_ip, dst_port)),
144            },
145        )
146        .parse(input)
147    }
148}