use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use log::debug;
use tokio::io::AsyncReadExt;
use crate::runtime::TcpStream;
/// Fixed 12-byte signature that opens every PROXY protocol v2 header
/// (`\r\n\r\n\0\r\nQUIT\n`).
const PROXY_V2_SIGNATURE: [u8; 12] = *b"\r\n\r\n\0\r\nQUIT\n";
/// Upper bound, in bytes, on a PROXY v1 text header including its trailing `\r\n`;
/// a line that is not terminated within this many bytes is rejected.
const PROXY_V1_MAX_LENGTH: usize = 108;
/// Size of the fixed PROXY v2 preamble: the signature, followed by the
/// version/command byte, the family/protocol byte and a two-byte big-endian length.
const PROXY_V2_HEADER_SIZE: usize = PROXY_V2_SIGNATURE.len() + 4;
/// Largest address-block length announced by a v2 preamble that is accepted;
/// larger values are rejected before the block is read.
const PROXY_V2_MAX_ADDR_LEN: usize = 512;
/// Outcome of trying to read a PROXY protocol header from a connection.
///
/// Besides `Debug`, the type derives `Clone`, `PartialEq` and `Eq`, so results can
/// be duplicated and compared directly (e.g. with `assert_eq!`) rather than only
/// pattern-matched.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProxyProtocolResult {
    /// A header was parsed; carries the client (source) address it announced.
    Success(SocketAddr),
    /// No usable header. Either no header was present, or the header told the
    /// receiver to fall back to the socket's own peer address. The second case
    /// covers v1 `UNKNOWN`, v2 `LOCAL`, and v2 address families without an IP.
    NotProxyProtocol,
    /// Peeking, reading or parsing a header failed; the message explains why.
    Error(String),
}
pub async fn parse_proxy_protocol(stream: &mut TcpStream) -> ProxyProtocolResult {
let mut peek_buf = [0u8; PROXY_V2_HEADER_SIZE];
match stream.peek(&mut peek_buf).await {
Ok(n) if n >= 5 => {
if n >= 13 && peek_buf[..12] == PROXY_V2_SIGNATURE {
if (peek_buf[12] & 0xF0) == 0x20 {
return parse_proxy_protocol_v2(stream).await;
}
}
if &peek_buf[..5] == b"PROXY" {
return parse_proxy_protocol_v1(stream).await;
}
ProxyProtocolResult::NotProxyProtocol
}
Ok(_) => ProxyProtocolResult::NotProxyProtocol,
Err(e) => ProxyProtocolResult::Error(format!("Failed to peek stream: {}", e)),
}
}
/// Parses one PROXY protocol v1 (text) header line.
///
/// `line` is the header as read from the wire. A trailing `\n` and then `\r` are
/// stripped before parsing. Fields are separated by single spaces.
///
/// * `PROXY UNKNOWN ...` returns [`ProxyProtocolResult::NotProxyProtocol`], so the
///   caller falls back to the socket address. Anything after `UNKNOWN` is ignored.
/// * `PROXY TCP4|TCP6 <src_ip> <dst_ip> <src_port> <dst_port>` returns
///   [`ProxyProtocolResult::Success`] with the source (client) address.
///   - Both IPs and both ports must parse.
///   - Both IPs must belong to the family named by the protocol token: IPv4 for
///     `TCP4`, IPv6 for `TCP6`.
///
/// Every other shape yields [`ProxyProtocolResult::Error`] describing the problem.
fn parse_proxy_v1_line(line: &str) -> ProxyProtocolResult {
    let line = line.trim_end_matches('\n').trim_end_matches('\r');
    let parts: Vec<&str> = line.split(' ').collect();
    // `split` always yields at least one item, so indexing `parts[0]` cannot panic.
    if parts[0] != "PROXY" {
        return ProxyProtocolResult::Error("Invalid PROXY v1 header".into());
    }
    if parts.len() < 2 {
        return ProxyProtocolResult::Error("PROXY v1 header too short".into());
    }
    match parts[1] {
        "UNKNOWN" => {
            debug!("PROXY v1 UNKNOWN protocol, using socket address");
            ProxyProtocolResult::NotProxyProtocol
        }
        "TCP4" | "TCP6" => {
            if parts.len() != 6 {
                return ProxyProtocolResult::Error(format!(
                    "Invalid PROXY v1 header, expected 6 parts, got {}",
                    parts.len()
                ));
            }
            let src_ip: IpAddr = match parts[2].parse() {
                Ok(ip) => ip,
                Err(_) => {
                    return ProxyProtocolResult::Error(format!("Invalid source IP: {}", parts[2]));
                }
            };
            let dst_ip: IpAddr = match parts[3].parse() {
                Ok(ip) => ip,
                Err(_) => {
                    return ProxyProtocolResult::Error(format!(
                        "Invalid destination IP: {}",
                        parts[3]
                    ));
                }
            };
            // The protocol token fixes the address family, so a line such as
            // "TCP4 2001:db8::1 ..." is malformed and must not be trusted.
            let want_ipv4 = parts[1] == "TCP4";
            if src_ip.is_ipv4() != want_ipv4 || dst_ip.is_ipv4() != want_ipv4 {
                return ProxyProtocolResult::Error(format!(
                    "PROXY v1 address family does not match protocol {}",
                    parts[1]
                ));
            }
            let src_port: u16 = match parts[4].parse() {
                Ok(port) => port,
                Err(_) => {
                    return ProxyProtocolResult::Error(format!(
                        "Invalid source port: {}",
                        parts[4]
                    ));
                }
            };
            // The destination port is not returned, but it must still be well-formed
            // for the header as a whole to be accepted.
            if parts[5].parse::<u16>().is_err() {
                return ProxyProtocolResult::Error(format!(
                    "Invalid destination port: {}",
                    parts[5]
                ));
            }
            let src_addr = SocketAddr::new(src_ip, src_port);
            debug!("PROXY v1 parsed: src={}", src_addr);
            ProxyProtocolResult::Success(src_addr)
        }
        proto => ProxyProtocolResult::Error(format!("Unsupported PROXY v1 protocol: {}", proto)),
    }
}
async fn parse_proxy_protocol_v1(stream: &mut TcpStream) -> ProxyProtocolResult {
let mut buf = [0u8; PROXY_V1_MAX_LENGTH];
let mut pos = 0;
loop {
if pos >= PROXY_V1_MAX_LENGTH {
return ProxyProtocolResult::Error("PROXY v1 header too long".into());
}
match stream.read_exact(&mut buf[pos..pos + 1]).await {
Ok(_) => {
pos += 1;
if pos >= 2 && buf[pos - 2] == b'\r' && buf[pos - 1] == b'\n' {
break;
}
}
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
return ProxyProtocolResult::Error(
"Connection closed while reading PROXY header".into(),
);
}
Err(e) => {
return ProxyProtocolResult::Error(format!("Failed to read PROXY header: {}", e));
}
}
}
match std::str::from_utf8(&buf[..pos]) {
Ok(line) => parse_proxy_v1_line(line),
Err(_) => ProxyProtocolResult::Error("PROXY v1 header contains invalid UTF-8".into()),
}
}
async fn parse_proxy_protocol_v2(stream: &mut TcpStream) -> ProxyProtocolResult {
let mut header = [0u8; PROXY_V2_HEADER_SIZE];
if let Err(e) = stream.read_exact(&mut header).await {
return ProxyProtocolResult::Error(format!("Failed to read PROXY v2 header: {}", e));
}
let addr_len = u16::from_be_bytes([header[14], header[15]]) as usize;
if addr_len > PROXY_V2_MAX_ADDR_LEN {
return ProxyProtocolResult::Error(format!(
"PROXY v2 address length {} exceeds maximum {}",
addr_len, PROXY_V2_MAX_ADDR_LEN
));
}
let addr_data = if addr_len > 0 {
let mut buf = vec![0u8; addr_len];
if let Err(e) = stream.read_exact(&mut buf).await {
return ProxyProtocolResult::Error(format!("Failed to read PROXY v2 address: {}", e));
}
buf
} else {
Vec::new()
};
parse_proxy_v2_bytes(&header, &addr_data)
}
fn parse_proxy_v2_bytes(
header: &[u8; PROXY_V2_HEADER_SIZE],
addr_data: &[u8],
) -> ProxyProtocolResult {
if header[..12] != PROXY_V2_SIGNATURE {
return ProxyProtocolResult::Error("Invalid PROXY v2 signature".into());
}
let ver_cmd = header[12];
let version = (ver_cmd & 0xF0) >> 4;
let command = ver_cmd & 0x0F;
if version != 2 {
return ProxyProtocolResult::Error(format!("Unsupported PROXY version: {}", version));
}
let fam_proto = header[13];
let family = (fam_proto & 0xF0) >> 4;
match command {
0x00 => {
debug!("PROXY v2 LOCAL command, using socket address");
ProxyProtocolResult::NotProxyProtocol
}
0x01 => {
parse_proxy_v2_address(family, addr_data)
}
_ => ProxyProtocolResult::Error(format!("Unsupported PROXY v2 command: {}", command)),
}
}
/// Extracts the source address from a PROXY v2 address block.
///
/// `family` is the high nibble of the family/protocol byte.
///
/// * `0x1`: needs at least 12 bytes. The source IPv4 is bytes 0..4 and the source
///   port is bytes 8..10 (big-endian).
/// * `0x2`: needs at least 36 bytes. The source IPv6 is bytes 0..16 and the source
///   port is bytes 32..34 (big-endian).
/// * `0x0`, `0x3` and any other family return
///   [`ProxyProtocolResult::NotProxyProtocol`], so the caller uses the socket address.
///
/// Bytes beyond those read above are ignored. A block shorter than its family
/// requires yields [`ProxyProtocolResult::Error`].
fn parse_proxy_v2_address(family: u8, addr_data: &[u8]) -> ProxyProtocolResult {
    let src_addr = match family {
        0x01 => match addr_data {
            &[a, b, c, d, _, _, _, _, port_hi, port_lo, _, _, ..] => SocketAddr::new(
                IpAddr::V4(Ipv4Addr::new(a, b, c, d)),
                u16::from_be_bytes([port_hi, port_lo]),
            ),
            _ => {
                return ProxyProtocolResult::Error("PROXY v2 IPv4 address data too short".into());
            }
        },
        0x02 => {
            if addr_data.len() < 36 {
                return ProxyProtocolResult::Error("PROXY v2 IPv6 address data too short".into());
            }
            let mut octets = [0u8; 16];
            octets.copy_from_slice(&addr_data[..16]);
            SocketAddr::new(
                IpAddr::V6(Ipv6Addr::from(octets)),
                u16::from_be_bytes([addr_data[32], addr_data[33]]),
            )
        }
        0x00 => {
            debug!("PROXY v2 AF_UNSPEC, using socket address");
            return ProxyProtocolResult::NotProxyProtocol;
        }
        0x03 => {
            debug!("PROXY v2 AF_UNIX, cannot extract IP address, using socket address");
            return ProxyProtocolResult::NotProxyProtocol;
        }
        unknown => {
            debug!("PROXY v2 unknown address family: {:#x}", unknown);
            return ProxyProtocolResult::NotProxyProtocol;
        }
    };
    debug!("PROXY v2 parsed: src={}", src_addr);
    ProxyProtocolResult::Success(src_addr)
}
// Unit tests for the synchronous parsers `parse_proxy_v1_line` and
// `parse_proxy_v2_bytes`; the async stream readers are not exercised here.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a PROXY v2 preamble plus address block for tests.
    ///
    /// The preamble contains:
    /// * the v2 signature;
    /// * version 2 with `command` in the low nibble;
    /// * `family` (high nibble) and `protocol` (low nibble) in byte 13;
    /// * the big-endian length of `addr_data`.
    ///
    /// Returns the preamble and an owned copy of `addr_data`.
    fn build_proxy_v2_parts(
        command: u8,
        family: u8,
        protocol: u8,
        addr_data: &[u8],
    ) -> ([u8; PROXY_V2_HEADER_SIZE], Vec<u8>) {
        let mut header = [0u8; PROXY_V2_HEADER_SIZE];
        header[..12].copy_from_slice(&PROXY_V2_SIGNATURE);
        header[12] = 0x20 | (command & 0x0F);
        header[13] = (family << 4) | (protocol & 0x0F);
        // NOTE(review): `as u16` would silently truncate fixtures longer than 65535 bytes.
        let addr_len = addr_data.len() as u16;
        header[14..16].copy_from_slice(&addr_len.to_be_bytes());
        (header, addr_data.to_vec())
    }

    // A well-formed TCP4 line yields the source IP and port.
    #[test]
    fn test_proxy_v1_tcp4() {
        let line = "PROXY TCP4 192.168.0.1 192.168.0.11 56324 443\r\n";
        match parse_proxy_v1_line(line) {
            ProxyProtocolResult::Success(addr) => {
                assert_eq!(addr.ip(), "192.168.0.1".parse::<IpAddr>().unwrap());
                assert_eq!(addr.port(), 56324);
            }
            other => panic!("Expected Success, got {:?}", other),
        }
    }

    // A well-formed TCP6 line yields the source IPv6 address and port.
    #[test]
    fn test_proxy_v1_tcp6() {
        let line = "PROXY TCP6 2001:db8::1 2001:db8::2 56324 443\r\n";
        match parse_proxy_v1_line(line) {
            ProxyProtocolResult::Success(addr) => {
                assert_eq!(addr.ip(), "2001:db8::1".parse::<IpAddr>().unwrap());
                assert_eq!(addr.port(), 56324);
            }
            other => panic!("Expected Success, got {:?}", other),
        }
    }

    // Bare UNKNOWN falls back to the socket address.
    #[test]
    fn test_proxy_v1_unknown() {
        let line = "PROXY UNKNOWN\r\n";
        match parse_proxy_v1_line(line) {
            ProxyProtocolResult::NotProxyProtocol => {}
            other => panic!("Expected NotProxyProtocol, got {:?}", other),
        }
    }

    // Fields after UNKNOWN are ignored; this is also the longest v1 line shape.
    #[test]
    fn test_proxy_v1_unknown_with_addresses() {
        let line = "PROXY UNKNOWN ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff 65535 65535\r\n";
        match parse_proxy_v1_line(line) {
            ProxyProtocolResult::NotProxyProtocol => {}
            other => panic!("Expected NotProxyProtocol, got {:?}", other),
        }
    }

    // A line not starting with the exact "PROXY" keyword is rejected.
    #[test]
    fn test_proxy_v1_invalid_header() {
        let line = "NOT_PROXY TCP4 192.168.0.1 192.168.0.11 56324 443\r\n";
        match parse_proxy_v1_line(line) {
            ProxyProtocolResult::Error(_) => {}
            other => panic!("Expected Error, got {:?}", other),
        }
    }

    // TCP4 with too few fields reports the field-count error.
    #[test]
    fn test_proxy_v1_missing_fields() {
        let line = "PROXY TCP4 192.168.0.1\r\n";
        match parse_proxy_v1_line(line) {
            ProxyProtocolResult::Error(msg) => {
                assert!(msg.contains("expected 6 parts"));
            }
            other => panic!("Expected Error, got {:?}", other),
        }
    }

    // An unparsable source IP is reported as such.
    #[test]
    fn test_proxy_v1_invalid_ip() {
        let line = "PROXY TCP4 not.an.ip 192.168.0.11 56324 443\r\n";
        match parse_proxy_v1_line(line) {
            ProxyProtocolResult::Error(msg) => {
                assert!(msg.contains("Invalid source IP"));
            }
            other => panic!("Expected Error, got {:?}", other),
        }
    }

    // An unparsable source port is reported as such.
    #[test]
    fn test_proxy_v1_invalid_port() {
        let line = "PROXY TCP4 192.168.0.1 192.168.0.11 notaport 443\r\n";
        match parse_proxy_v1_line(line) {
            ProxyProtocolResult::Error(msg) => {
                assert!(msg.contains("Invalid source port"));
            }
            other => panic!("Expected Error, got {:?}", other),
        }
    }

    // Protocol tokens other than TCP4/TCP6/UNKNOWN are rejected.
    #[test]
    fn test_proxy_v1_unsupported_protocol() {
        let line = "PROXY UDP4 192.168.0.1 192.168.0.11 56324 443\r\n";
        match parse_proxy_v1_line(line) {
            ProxyProtocolResult::Error(msg) => {
                assert!(msg.contains("Unsupported PROXY v1 protocol"));
            }
            other => panic!("Expected Error, got {:?}", other),
        }
    }

    // Sanity check on the signature constant: 12 bytes with a NUL at index 4.
    #[test]
    fn test_proxy_v2_signature() {
        assert_eq!(PROXY_V2_SIGNATURE.len(), 12);
        assert_eq!(PROXY_V2_SIGNATURE[4], 0x00);
    }

    // PROXY command, family 0x1: the source IPv4 and port are extracted.
    #[test]
    fn test_proxy_v2_tcp4() {
        let mut addr_data = Vec::new();
        // Source IP.
        addr_data.extend_from_slice(&[192, 168, 1, 100]);
        // Destination IP.
        addr_data.extend_from_slice(&[192, 168, 1, 1]);
        // Source port.
        addr_data.extend_from_slice(&12345u16.to_be_bytes());
        // Destination port.
        addr_data.extend_from_slice(&443u16.to_be_bytes());
        let (header, addr) = build_proxy_v2_parts(0x01, 0x01, 0x01, &addr_data);
        match parse_proxy_v2_bytes(&header, &addr) {
            ProxyProtocolResult::Success(addr) => {
                assert_eq!(addr.ip(), "192.168.1.100".parse::<IpAddr>().unwrap());
                assert_eq!(addr.port(), 12345);
            }
            other => panic!("Expected Success, got {:?}", other),
        }
    }

    // PROXY command, family 0x2: the source IPv6 and port are extracted.
    #[test]
    fn test_proxy_v2_tcp6() {
        let mut addr_data = Vec::new();
        // Source IP (2001:db8::1).
        addr_data.extend_from_slice(&[0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]);
        // Destination IP (2001:db8::2).
        addr_data.extend_from_slice(&[0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]);
        // Source port, then destination port.
        addr_data.extend_from_slice(&54321u16.to_be_bytes());
        addr_data.extend_from_slice(&8080u16.to_be_bytes());
        let (header, addr) = build_proxy_v2_parts(0x01, 0x02, 0x01, &addr_data);
        match parse_proxy_v2_bytes(&header, &addr) {
            ProxyProtocolResult::Success(addr) => {
                assert_eq!(addr.ip(), "2001:db8::1".parse::<IpAddr>().unwrap());
                assert_eq!(addr.port(), 54321);
            }
            other => panic!("Expected Success, got {:?}", other),
        }
    }

    // LOCAL command (0x0) falls back to the socket address.
    #[test]
    fn test_proxy_v2_local_command() {
        let (header, addr) = build_proxy_v2_parts(0x00, 0x00, 0x00, &[]);
        match parse_proxy_v2_bytes(&header, &addr) {
            ProxyProtocolResult::NotProxyProtocol => {}
            other => panic!("Expected NotProxyProtocol, got {:?}", other),
        }
    }

    // PROXY command with family 0x0 falls back to the socket address.
    #[test]
    fn test_proxy_v2_af_unspec() {
        let (header, addr) = build_proxy_v2_parts(0x01, 0x00, 0x00, &[]);
        match parse_proxy_v2_bytes(&header, &addr) {
            ProxyProtocolResult::NotProxyProtocol => {}
            other => panic!("Expected NotProxyProtocol, got {:?}", other),
        }
    }

    // A preamble whose first 12 bytes are not the v2 signature is rejected.
    #[test]
    fn test_proxy_v2_invalid_signature() {
        let mut header = [0u8; PROXY_V2_HEADER_SIZE];
        header[..12].copy_from_slice(b"WRONG_SIGNAT");
        match parse_proxy_v2_bytes(&header, &[]) {
            ProxyProtocolResult::Error(msg) => {
                assert!(msg.contains("Invalid PROXY v2 signature"));
            }
            other => panic!("Expected Error, got {:?}", other),
        }
    }

    // Family 0x1 with fewer than 12 address bytes is rejected.
    #[test]
    fn test_proxy_v2_ipv4_addr_too_short() {
        let addr_data = vec![0u8; 8];
        let (header, addr) = build_proxy_v2_parts(0x01, 0x01, 0x01, &addr_data);
        match parse_proxy_v2_bytes(&header, &addr) {
            ProxyProtocolResult::Error(msg) => {
                assert!(msg.contains("IPv4 address data too short"));
            }
            other => panic!("Expected Error, got {:?}", other),
        }
    }

    // Family 0x2 with fewer than 36 address bytes is rejected.
    #[test]
    fn test_proxy_v2_ipv6_addr_too_short() {
        let addr_data = vec![0u8; 20];
        let (header, addr) = build_proxy_v2_parts(0x01, 0x02, 0x01, &addr_data);
        match parse_proxy_v2_bytes(&header, &addr) {
            ProxyProtocolResult::Error(msg) => {
                assert!(msg.contains("IPv6 address data too short"));
            }
            other => panic!("Expected Error, got {:?}", other),
        }
    }

    // Commands other than LOCAL (0x0) and PROXY (0x1) are rejected.
    #[test]
    fn test_proxy_v2_unsupported_command() {
        let (header, addr) = build_proxy_v2_parts(0x02, 0x01, 0x01, &[0u8; 12]);
        match parse_proxy_v2_bytes(&header, &addr) {
            ProxyProtocolResult::Error(msg) => {
                assert!(msg.contains("Unsupported PROXY v2 command"));
            }
            other => panic!("Expected Error, got {:?}", other),
        }
    }

    // Family 0x3 with 216 address bytes carries no IP, so it falls back.
    #[test]
    fn test_proxy_v2_af_unix() {
        let (header, addr) = build_proxy_v2_parts(0x01, 0x03, 0x00, &[0u8; 216]);
        match parse_proxy_v2_bytes(&header, &addr) {
            ProxyProtocolResult::NotProxyProtocol => {}
            other => panic!("Expected NotProxyProtocol, got {:?}", other),
        }
    }

    // Unknown family values fall back rather than erroring (current behavior).
    #[test]
    fn test_proxy_v2_unknown_family() {
        let (header, addr) = build_proxy_v2_parts(0x01, 0x04, 0x01, &[0u8; 12]);
        match parse_proxy_v2_bytes(&header, &addr) {
            ProxyProtocolResult::NotProxyProtocol => {}
            other => panic!("Expected NotProxyProtocol, got {:?}", other),
        }
    }

    // Maximum-value TCP4 fields still parse.
    #[test]
    fn test_proxy_v1_haproxy_example() {
        let line = "PROXY TCP4 255.255.255.255 255.255.255.255 65535 65535\r\n";
        match parse_proxy_v1_line(line) {
            ProxyProtocolResult::Success(addr) => {
                assert_eq!(addr.ip(), "255.255.255.255".parse::<IpAddr>().unwrap());
                assert_eq!(addr.port(), 65535);
            }
            other => panic!("Expected Success, got {:?}", other),
        }
    }

    // Extra bytes after the IPv4 address fields are ignored by the address parser.
    #[test]
    fn test_proxy_v2_with_tlv_extensions() {
        let mut addr_data = Vec::new();
        // Source IP, destination IP.
        addr_data.extend_from_slice(&[10, 0, 0, 1]);
        addr_data.extend_from_slice(&[10, 0, 0, 2]);
        // Source port, destination port.
        addr_data.extend_from_slice(&8080u16.to_be_bytes());
        addr_data.extend_from_slice(&80u16.to_be_bytes());
        // Trailing TLV: type 0x20, length 0x0004, 4 value bytes.
        addr_data.extend_from_slice(&[0x20, 0x00, 0x04, 0x01, 0x02, 0x03, 0x04]);
        let (header, addr) = build_proxy_v2_parts(0x01, 0x01, 0x01, &addr_data);
        match parse_proxy_v2_bytes(&header, &addr) {
            ProxyProtocolResult::Success(addr) => {
                assert_eq!(addr.ip(), "10.0.0.1".parse::<IpAddr>().unwrap());
                assert_eq!(addr.port(), 8080);
            }
            other => panic!("Expected Success, got {:?}", other),
        }
    }

    // NOTE(review): the header built below (announcing a 65535-byte address block) is
    // never passed to a parser. This test only pins the value of PROXY_V2_MAX_ADDR_LEN.
    #[test]
    fn test_proxy_v2_max_addr_len() {
        let mut header = [0u8; PROXY_V2_HEADER_SIZE];
        header[..12].copy_from_slice(&PROXY_V2_SIGNATURE);
        // Version 2, command 0x1.
        header[12] = 0x21;
        // Family 0x1, protocol 0x1.
        header[13] = 0x11;
        header[14..16].copy_from_slice(&65535u16.to_be_bytes());
        assert!(PROXY_V2_MAX_ADDR_LEN < 65535);
        assert_eq!(PROXY_V2_MAX_ADDR_LEN, 512);
    }
}