use std::{net::SocketAddr, time::Duration};
use tokio::{io::AsyncReadExt, net::TcpStream, time::timeout};
use tracing::{debug, error, warn};
use super::{ProtocolResult, errors::ProxyProtocolError};
/// ASCII prefix (`PROXY` followed by one space) that `read_header` looks for
/// at the start of the stream to detect a version 1 (text) header.
static PROXY_PROTOCOL_V1_SIGNATURE: &[u8] = b"PROXY ";
/// 12-byte binary prefix that `read_header` looks for at the start of the
/// stream to detect a version 2 header; `read_v2_header` re-checks it against
/// the first 12 bytes it reads.
static PROXY_PROTOCOL_V2_SIGNATURE: &[u8] = b"\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A";
/// Configuration for reading an optional PROXY protocol header from the
/// start of an accepted TCP connection.
#[derive(Debug, Clone)]
pub struct ProxyProtocolReader {
/// When `false`, `read_header` returns `Ok(None)` without touching the stream.
pub enabled: bool,
/// Time limit `read_header` applies while waiting for proxy protocol header data.
pub timeout: Duration,
/// NOTE(review): presumably the protocol versions (1 and/or 2) that should be
/// accepted, with `None` meaning "any" — nothing in this file reads the field,
/// so confirm the intended semantics before relying on it.
pub allowed_versions: Option<Vec<u8>>,
}
impl Default for ProxyProtocolReader {
fn default() -> Self {
Self {
enabled: false,
timeout: Duration::from_secs(5),
allowed_versions: None, }
}
}
impl ProxyProtocolReader {
pub fn new(enabled: bool, timeout_secs: u64, allowed_versions: Option<Vec<u8>>) -> Self {
Self {
enabled,
timeout: Duration::from_secs(timeout_secs),
allowed_versions,
}
}
pub async fn read_header(&self, stream: &mut TcpStream) -> ProtocolResult<Option<SocketAddr>> {
if !self.enabled {
debug!("Proxy protocol reading disabled, skipping header read");
return Ok(None);
}
debug!("Reading proxy protocol header");
let mut peek_buf = [0u8; 12];
let read_result = timeout(self.timeout, stream.peek(&mut peek_buf)).await;
match read_result {
Ok(Ok(n)) if n >= 5 => {
if peek_buf.starts_with(PROXY_PROTOCOL_V1_SIGNATURE) {
debug!("Detected Proxy Protocol v1");
self.read_v1_header(stream).await
} else if n >= 12 && peek_buf.starts_with(PROXY_PROTOCOL_V2_SIGNATURE) {
debug!("Detected Proxy Protocol v2");
self.read_v2_header(stream).await
} else {
debug!("No proxy protocol header detected");
Ok(None)
}
}
Ok(Ok(_)) => {
debug!("Not enough bytes to determine proxy protocol version");
Ok(None)
}
Ok(Err(e)) => {
error!("Error peeking stream: {}", e);
Err(ProxyProtocolError::Io(e.to_string()))
}
Err(_) => {
warn!("Timeout waiting for proxy protocol header");
Err(ProxyProtocolError::Other(
"Timeout waiting for proxy protocol header".to_string(),
))
}
}
}
async fn read_v1_header(&self, stream: &mut TcpStream) -> ProtocolResult<Option<SocketAddr>> {
let mut header_bytes = Vec::with_capacity(108); let mut buf = [0u8; 1];
let mut found_cr = false;
while header_bytes.len() < 108 {
if let Err(e) = stream.read_exact(&mut buf).await {
return Err(ProxyProtocolError::Io(e.to_string()));
}
header_bytes.push(buf[0]);
if found_cr && buf[0] == b'\n' {
break;
}
found_cr = buf[0] == b'\r';
}
let header_str = match std::str::from_utf8(&header_bytes) {
Ok(s) => s,
Err(e) => {
error!("Invalid UTF-8 in proxy protocol header: {}", e);
return Err(ProxyProtocolError::InvalidHeader(format!(
"Invalid UTF-8: {}",
e
)));
}
};
let parts: Vec<&str> = header_str.split_whitespace().collect();
if parts.len() < 6 {
error!("Invalid proxy protocol v1 header format");
return Err(ProxyProtocolError::InvalidHeader(
"Invalid v1 format".to_string(),
));
}
if parts[0] != "PROXY" {
error!("Invalid proxy protocol header, doesn't start with PROXY");
return Err(ProxyProtocolError::InvalidHeader(
"Missing PROXY prefix".to_string(),
));
}
let proto = parts[1];
let src_addr = parts[2];
let src_port = match parts[4].parse::<u16>() {
Ok(p) => p,
Err(e) => {
error!("Invalid source port in proxy protocol: {}", e);
return Err(ProxyProtocolError::InvalidHeader(format!(
"Invalid source port: {}",
e
)));
}
};
let addr = match proto {
"TCP4" => match src_addr.parse() {
Ok(ipv4) => Some(SocketAddr::new(std::net::IpAddr::V4(ipv4), src_port)),
Err(e) => {
error!("Invalid IPv4 address in proxy protocol: {}", e);
return Err(ProxyProtocolError::InvalidHeader(format!(
"Invalid IPv4: {}",
e
)));
}
},
"TCP6" => match src_addr.parse() {
Ok(ipv6) => Some(SocketAddr::new(std::net::IpAddr::V6(ipv6), src_port)),
Err(e) => {
error!("Invalid IPv6 address in proxy protocol: {}", e);
return Err(ProxyProtocolError::InvalidHeader(format!(
"Invalid IPv6: {}",
e
)));
}
},
"UNKNOWN" => None,
_ => {
error!("Unknown protocol family in proxy protocol: {}", proto);
return Err(ProxyProtocolError::InvalidHeader(format!(
"Unknown protocol: {}",
proto
)));
}
};
debug!("Parsed proxy protocol v1, client addr: {:?}", addr);
Ok(addr)
}
async fn read_v2_header(&self, stream: &mut TcpStream) -> ProtocolResult<Option<SocketAddr>> {
let mut header = [0u8; 16];
if let Err(e) = stream.read_exact(&mut header).await {
return Err(ProxyProtocolError::Io(e.to_string()));
}
if !header[0..12].eq(PROXY_PROTOCOL_V2_SIGNATURE) {
return Err(ProxyProtocolError::InvalidHeader(
"Invalid v2 signature".to_string(),
));
}
let addr_len = ((header[14] as u16) << 8) | (header[15] as u16);
let mut addr_data = vec![0u8; addr_len as usize];
if let Err(e) = stream.read_exact(&mut addr_data).await {
return Err(ProxyProtocolError::Io(e.to_string()));
}
let family = header[13] & 0xF0; match family {
0x10 => {
if addr_data.len() >= 12 {
let mut src_ip = [0u8; 4];
src_ip.copy_from_slice(&addr_data[0..4]);
let src_port = ((addr_data[8] as u16) << 8) | addr_data[9] as u16;
let addr = SocketAddr::new(
std::net::IpAddr::V4(std::net::Ipv4Addr::from(src_ip)),
src_port,
);
debug!("Parsed proxy protocol v2 IPv4, client addr: {}", addr);
Ok(Some(addr))
} else {
error!("IPv4 address data too short: {}", addr_data.len());
Err(ProxyProtocolError::InvalidLength(addr_data.len()))
}
}
0x20 => {
if addr_data.len() >= 36 {
let mut src_ip = [0u8; 16];
src_ip.copy_from_slice(&addr_data[0..16]);
let src_port = ((addr_data[32] as u16) << 8) | addr_data[33] as u16;
let addr = SocketAddr::new(
std::net::IpAddr::V6(std::net::Ipv6Addr::from(src_ip)),
src_port,
);
debug!("Parsed proxy protocol v2 IPv6, client addr: {}", addr);
Ok(Some(addr))
} else {
error!("IPv6 address data too short: {}", addr_data.len());
Err(ProxyProtocolError::InvalidLength(addr_data.len()))
}
}
0x30 => {
debug!("Proxy protocol v2 with Unix domain socket, not extracting address");
Ok(None)
}
0x00 => {
debug!("Proxy protocol v2 with unspecified address family");
Ok(None)
}
_ => {
error!("Unknown address family in proxy protocol v2: {:#x}", family);
Err(ProxyProtocolError::InvalidHeader(format!(
"Unknown family: {:#x}",
family
)))
}
}
}
}