use std::{
fmt,
net::{SocketAddr, SocketAddrV4, SocketAddrV6},
};
/// A PROXY protocol header in either of the two versions this module can emit.
#[derive(PartialEq, Debug)]
pub enum ProxyProtocolHeader {
    /// Version 1 header (see [`HeaderV1`]).
    V1(HeaderV1),
    /// Version 2 header (see [`HeaderV2`]).
    V2(HeaderV2),
}

impl ProxyProtocolHeader {
    /// Serializes the wrapped header by delegating to the `into_bytes` method
    /// of whichever version is held.
    ///
    /// Despite the `into_` prefix, this only borrows `self`.
    pub fn into_bytes(&self) -> Vec<u8> {
        match self {
            Self::V1(v1) => v1.into_bytes(),
            Self::V2(v2) => v2.into_bytes(),
        }
    }
}
/// Protocol token written on a version 1 header line.
#[derive(Debug, PartialEq, Eq)]
pub enum ProtocolSupportedV1 {
    /// Rendered as `TCP4`.
    TCP4,
    /// Rendered as `TCP6`.
    TCP6,
    /// Rendered as `UNKNOWN`; the header line then carries no addresses.
    UNKNOWN,
}

impl fmt::Display for ProtocolSupportedV1 {
    /// Writes the token exactly as it appears on the header line.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(match self {
            ProtocolSupportedV1::TCP4 => "TCP4",
            ProtocolSupportedV1::TCP6 => "TCP6",
            ProtocolSupportedV1::UNKNOWN => "UNKNOWN",
        })
    }
}

/// Text (version 1) PROXY protocol header.
#[derive(Debug, PartialEq, Eq)]
pub struct HeaderV1 {
    /// Protocol token written right after the `PROXY` identifier.
    pub protocol: ProtocolSupportedV1,
    /// Source endpoint; its IP and port are written before the destination's.
    pub addr_src: SocketAddr,
    /// Destination endpoint.
    pub addr_dst: SocketAddr,
}

/// Identifier that opens every version 1 header line.
const PROXY_PROTO_IDENTIFIER: &str = "PROXY";

impl HeaderV1 {
    /// Builds a version 1 header for the given endpoints.
    ///
    /// The protocol is `TCP4` when both addresses are IPv4 and `TCP6` when
    /// both are IPv6. A mixed-family pair cannot be expressed on a single
    /// `TCP4`/`TCP6` line, so it falls back to `UNKNOWN`, which serializes
    /// without addresses. This mirrors `ProxyAddr::from`, which maps mixed
    /// pairs to `AfUnspec` for version 2 headers.
    pub fn new(addr_src: SocketAddr, addr_dst: SocketAddr) -> Self {
        let protocol = match (&addr_src, &addr_dst) {
            (SocketAddr::V4(_), SocketAddr::V4(_)) => ProtocolSupportedV1::TCP4,
            (SocketAddr::V6(_), SocketAddr::V6(_)) => ProtocolSupportedV1::TCP6,
            // Families disagree: never advertise a family one endpoint lacks.
            _ => ProtocolSupportedV1::UNKNOWN,
        };

        HeaderV1 {
            protocol,
            addr_src,
            addr_dst,
        }
    }

    /// Serializes the header as a single CRLF-terminated ASCII line.
    ///
    /// * `UNKNOWN` produces `PROXY UNKNOWN\r\n`.
    /// * Otherwise the line is
    ///   `PROXY <proto> <src ip> <dst ip> <src port> <dst port>\r\n`.
    ///
    /// Despite the `into_` prefix, this only borrows `self`.
    pub fn into_bytes(&self) -> Vec<u8> {
        let line = match self.protocol {
            ProtocolSupportedV1::UNKNOWN => {
                format!("{} {}\r\n", PROXY_PROTO_IDENTIFIER, self.protocol)
            }
            ProtocolSupportedV1::TCP4 | ProtocolSupportedV1::TCP6 => format!(
                "{} {} {} {} {} {}\r\n",
                PROXY_PROTO_IDENTIFIER,
                self.protocol,
                self.addr_src.ip(),
                self.addr_dst.ip(),
                self.addr_src.port(),
                self.addr_dst.port(),
            ),
        };
        let bytes = line.into_bytes();

        debug_assert!(
            bytes.starts_with(PROXY_PROTO_IDENTIFIER.as_bytes()),
            "v1 header must start with the PROXY identifier"
        );
        debug_assert!(
            bytes.ends_with(b"\r\n"),
            "v1 header must be CRLF-terminated"
        );
        bytes
    }
}
/// Command selector of a version 2 header.
///
/// `HeaderV2::into_bytes` ORs the value chosen here into the `0x20`
/// version/command byte.
#[derive(Debug, PartialEq, Eq)]
pub enum Command {
/// Encoded as `0` in the version/command byte.
Local,
/// Encoded as `1` in the version/command byte.
Proxy,
}
/// Binary (version 2) PROXY protocol header.
#[derive(Debug, PartialEq)]
pub struct HeaderV2 {
    /// Command encoded into the version/command byte.
    pub command: Command,
    /// Family/transport byte; `new` computes it from `addr` via `get_family`.
    pub family: u8,
    /// Address block written after the 16-byte fixed prefix.
    pub addr: ProxyAddr,
}

impl HeaderV2 {
    /// Builds a version 2 header for the given command and endpoints.
    ///
    /// The address block comes from `ProxyAddr::from` and the family byte is
    /// derived from that block, so the two always agree on construction.
    pub fn new(command: Command, addr_src: SocketAddr, addr_dst: SocketAddr) -> Self {
        let addr = ProxyAddr::from(addr_src, addr_dst);
        let family = get_family(&addr);
        let is_unspec = matches!(addr, ProxyAddr::AfUnspec);

        debug_assert_eq!(
            family,
            get_family(&addr),
            "cached family must match the address it describes"
        );
        debug_assert!(
            is_unspec == (family == 0x00),
            "AfUnspec iff zero family byte"
        );

        Self {
            command,
            family,
            addr,
        }
    }

    /// Serializes the header.
    ///
    /// Layout: the 12-byte signature, one version/command byte, the family
    /// byte, the big-endian length of the address block, then the address
    /// block itself.
    ///
    /// Despite the `into_` prefix, this only borrows `self`.
    pub fn into_bytes(&self) -> Vec<u8> {
        // Fixed byte sequence that opens every header produced here.
        const SIGNATURE: [u8; 12] = [
            0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
        ];
        // High bits of the version/command byte.
        const VERSION_BITS: u8 = 0x20;

        let expected_len = self.len();
        let addr_len = usize::from(self.addr.len());
        let mut out = Vec::with_capacity(expected_len);

        out.extend_from_slice(&SIGNATURE);
        debug_assert_eq!(
            out.len(),
            SIGNATURE.len(),
            "v2 header must open with exactly the 12-byte signature"
        );

        let command_bits: u8 = match self.command {
            Command::Local => 0,
            Command::Proxy => 1,
        };
        out.push(VERSION_BITS | command_bits);
        out.push(self.family);
        out.extend_from_slice(&u16_to_array_of_u8(self.addr.len()));
        debug_assert_eq!(
            out.len(),
            16,
            "v2 fixed prefix (signature + ver/cmd + family + length) must be 16 bytes"
        );

        self.addr.write_bytes_to(&mut out);
        debug_assert_eq!(
            out.len(),
            expected_len,
            "serialized v2 header length must match HeaderV2::len()"
        );
        debug_assert_eq!(
            out.len(),
            16 + addr_len,
            "serialized v2 header must be the 16-byte prefix plus the address block"
        );

        out
    }

    /// Total serialized size in bytes: the 16-byte fixed prefix
    /// (12 signature + 1 version/command + 1 family + 2 length) plus the
    /// address block.
    pub fn len(&self) -> usize {
        let address_block = usize::from(self.addr.len());
        let total = 12 + 1 + 1 + 2 + address_block;

        debug_assert!(
            total >= 16,
            "v2 header is at least its 16-byte fixed prefix"
        );
        debug_assert!(
            total <= 16 + 216,
            "v2 header never exceeds the 16-byte prefix plus the largest (unix) address block"
        );

        total
    }

    /// Whether the serialized header would be zero bytes long.
    ///
    /// `len` always includes the 16-byte prefix, so this is never `true`; it
    /// exists as the conventional companion to `len`.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// Address block of a version 2 header.
pub enum ProxyAddr {
    /// Both endpoints are IPv4 socket addresses.
    Ipv4Addr {
        src_addr: SocketAddrV4,
        dst_addr: SocketAddrV4,
    },
    /// Both endpoints are IPv6 socket addresses.
    Ipv6Addr {
        src_addr: SocketAddrV6,
        dst_addr: SocketAddrV6,
    },
    /// Two raw 108-byte address buffers.
    UnixAddr {
        src_addr: [u8; 108],
        dst_addr: [u8; 108],
    },
    /// No address information; serializes to zero bytes.
    AfUnspec,
}

impl ProxyAddr {
    /// Builds the address block for a pair of socket addresses.
    ///
    /// Matching families give `Ipv4Addr` or `Ipv6Addr`; a mixed pair gives
    /// `AfUnspec`.
    pub fn from(addr_src: SocketAddr, addr_dst: SocketAddr) -> Self {
        let addr = match (addr_src, addr_dst) {
            (SocketAddr::V4(src_addr), SocketAddr::V4(dst_addr)) => {
                Self::Ipv4Addr { src_addr, dst_addr }
            }
            (SocketAddr::V6(src_addr), SocketAddr::V6(dst_addr)) => {
                Self::Ipv6Addr { src_addr, dst_addr }
            }
            (SocketAddr::V4(_), SocketAddr::V6(_)) | (SocketAddr::V6(_), SocketAddr::V4(_)) => {
                Self::AfUnspec
            }
        };

        debug_assert_eq!(
            matches!(addr, ProxyAddr::Ipv4Addr { .. }),
            addr_src.is_ipv4() && addr_dst.is_ipv4(),
            "Ipv4Addr variant iff both endpoints are IPv4"
        );
        debug_assert_eq!(
            matches!(addr, ProxyAddr::Ipv6Addr { .. }),
            addr_src.is_ipv6() && addr_dst.is_ipv6(),
            "Ipv6Addr variant iff both endpoints are IPv6"
        );

        addr
    }

    /// Number of bytes `write_bytes_to` appends for this variant.
    fn len(&self) -> u16 {
        match self {
            Self::AfUnspec => 0,
            Self::Ipv4Addr { .. } => 12,
            Self::Ipv6Addr { .. } => 36,
            Self::UnixAddr { .. } => 216,
        }
    }

    /// Source endpoint for the IP variants; `None` for `UnixAddr` and `AfUnspec`.
    pub fn source(&self) -> Option<SocketAddr> {
        match *self {
            Self::Ipv4Addr { src_addr, .. } => Some(SocketAddr::V4(src_addr)),
            Self::Ipv6Addr { src_addr, .. } => Some(SocketAddr::V6(src_addr)),
            Self::UnixAddr { .. } | Self::AfUnspec => None,
        }
    }

    /// Destination endpoint for the IP variants; `None` for `UnixAddr` and
    /// `AfUnspec`.
    pub fn destination(&self) -> Option<SocketAddr> {
        match *self {
            Self::Ipv4Addr { dst_addr, .. } => Some(SocketAddr::V4(dst_addr)),
            Self::Ipv6Addr { dst_addr, .. } => Some(SocketAddr::V6(dst_addr)),
            Self::UnixAddr { .. } | Self::AfUnspec => None,
        }
    }

    /// Appends the serialized address block to `buf`.
    ///
    /// IP variants write source IP, destination IP, source port, destination
    /// port (ports big-endian). `UnixAddr` writes the two raw buffers, source
    /// first. `AfUnspec` writes nothing.
    fn write_bytes_to(&self, buf: &mut Vec<u8>) {
        let start = buf.len();
        let declared = usize::from(self.len());

        match self {
            Self::Ipv4Addr { src_addr, dst_addr } => {
                buf.extend_from_slice(&src_addr.ip().octets());
                buf.extend_from_slice(&dst_addr.ip().octets());
                buf.extend_from_slice(&u16_to_array_of_u8(src_addr.port()));
                buf.extend_from_slice(&u16_to_array_of_u8(dst_addr.port()));
            }
            Self::Ipv6Addr { src_addr, dst_addr } => {
                buf.extend_from_slice(&src_addr.ip().octets());
                buf.extend_from_slice(&dst_addr.ip().octets());
                buf.extend_from_slice(&u16_to_array_of_u8(src_addr.port()));
                buf.extend_from_slice(&u16_to_array_of_u8(dst_addr.port()));
            }
            Self::UnixAddr { src_addr, dst_addr } => {
                buf.extend_from_slice(&src_addr[..]);
                buf.extend_from_slice(&dst_addr[..]);
            }
            Self::AfUnspec => {}
        }

        debug_assert!(
            buf.len() >= start,
            "write_bytes_to must never shrink the buffer"
        );
        debug_assert_eq!(
            buf.len() - start,
            declared,
            "appended address bytes must equal the declared ProxyAddr::len()"
        );
    }
}
impl fmt::Debug for ProxyAddr {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
ProxyAddr::Ipv4Addr { src_addr, dst_addr } => {
write!(f, "{dst_addr:?} {src_addr:?}")
}
ProxyAddr::Ipv6Addr { src_addr, dst_addr } => {
write!(f, "{dst_addr:?} {src_addr:?}")
}
ProxyAddr::UnixAddr { src_addr, dst_addr } => {
write!(f, "{:?} {:?}", &dst_addr[..], &src_addr[..])
}
ProxyAddr::AfUnspec => write!(f, "AFUNSPEC"),
}
}
}
impl PartialEq for ProxyAddr {
fn eq(&self, other: &ProxyAddr) -> bool {
match *self {
ProxyAddr::Ipv4Addr { src_addr, dst_addr } => match other {
ProxyAddr::Ipv4Addr {
src_addr: src_other,
dst_addr: dst_other,
} => *src_other == src_addr && *dst_other == dst_addr,
_ => false,
},
ProxyAddr::Ipv6Addr { src_addr, dst_addr } => match other {
ProxyAddr::Ipv6Addr {
src_addr: src_other,
dst_addr: dst_other,
} => *src_other == src_addr && *dst_other == dst_addr,
_ => false,
},
ProxyAddr::UnixAddr { src_addr, dst_addr } => match other {
ProxyAddr::UnixAddr {
src_addr: src_other,
dst_addr: dst_other,
} => src_other[..] == src_addr[..] && dst_other[..] == dst_addr[..],
_ => false,
},
ProxyAddr::AfUnspec => {
if let ProxyAddr::AfUnspec = other {
return true;
}
false
}
}
}
}
/// Computes the family/transport byte of a version 2 header for `addr`.
///
/// The high nibble identifies the address family and the low nibble the
/// transport. Every concrete family is paired with the stream transport
/// (`0x01`); `AfUnspec` is the all-zero byte.
fn get_family(addr: &ProxyAddr) -> u8 {
    // Transport nibble shared by every concrete address family.
    const STREAM: u8 = 0x01;

    let is_unspec = matches!(addr, ProxyAddr::AfUnspec);
    let family: u8 = match addr {
        ProxyAddr::AfUnspec => 0x00,
        ProxyAddr::Ipv4Addr { .. } => 0x10 | STREAM,
        ProxyAddr::Ipv6Addr { .. } => 0x20 | STREAM,
        ProxyAddr::UnixAddr { .. } => 0x30 | STREAM,
    };

    debug_assert!(
        (family >> 4) <= 0x03,
        "address family nibble must be one of AF_UNSPEC/INET/INET6/UNIX"
    );
    debug_assert!(
        is_unspec == (family == 0x00),
        "only AfUnspec maps to the all-zero family byte"
    );
    debug_assert!(
        is_unspec || (family & 0x0f) == STREAM,
        "concrete address families must advertise the STREAM transport"
    );

    family
}
/// Splits `x` into its two bytes in big-endian order: most significant byte
/// first, least significant byte second.
fn u16_to_array_of_u8(x: u16) -> [u8; 2] {
    // The standard library already provides the big-endian split, so there is
    // no need for manual shifting and masking.
    x.to_be_bytes()
}
// Unit tests for the version 2 header serialization and its byte helper.
#[cfg(test)]
mod test_v2 {
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use super::*;
// 65534 is 0xFFFE, so the most significant byte (0xFF) must come first.
#[test]
fn test_u16_to_array_of_u8() {
let val_u16: u16 = 65534;
let expected = [0xff, 0xfe];
assert_eq!(expected, u16_to_array_of_u8(val_u16));
}
// NOTE: despite its name, this test exercises serialization (`into_bytes`),
// not deserialization.
#[test]
fn test_deserialize_tcp_ipv4_proxy_protocol_header() {
let src_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(125, 25, 10, 1)), 8080);
let dst_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 4, 5, 8)), 4200);
let header = HeaderV2::new(Command::Local, src_addr, dst_addr);
// Expected bytes, in order: 12-byte signature, 0x20 (version bits | Local = 0),
// 0x11 (IPv4 family byte), 0x000C (12-byte address block), source IP
// 125.25.10.1, destination IP 10.4.5.8, source port 8080 (0x1F90),
// destination port 4200 (0x1068).
let expected = &[
0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54,
0x0A, 0x20, 0x11, 0x00, 0x0C, 0x7D, 0x19, 0x0A, 0x01, 0x0A, 0x04, 0x05, 0x08, 0x1F, 0x90, 0x10, 0x68, ];
assert_eq!(expected, &header.into_bytes()[..]);
}
// NOTE: despite its name, this test exercises serialization (`into_bytes`),
// not deserialization.
#[test]
fn test_deserialize_tcp_ipv6_proxy_protocol_header() {
let src_addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), 8080);
let dst_addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), 4200);
let header = HeaderV2::new(Command::Proxy, src_addr, dst_addr);
// Expected bytes, in order: 12-byte signature, 0x21 (version bits | Proxy = 1),
// 0x21 (IPv6 family byte), 0x0024 (36-byte address block), source IP ::1,
// destination IP ::1, source port 8080 (0x1F90), destination port 4200 (0x1068).
let expected = [
0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54,
0x0A, 0x21, 0x21, 0x00, 0x24, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x01, 0x1F, 0x90, 0x10, 0x68,
];
assert_eq!(&expected[..], &header.into_bytes()[..]);
}
}