use std::convert::Infallible;
use std::future::Future;
use std::net::IpAddr;
use std::net::Ipv4Addr;
use std::net::Ipv6Addr;
use std::net::SocketAddr;
use std::sync::Arc;
use hyper::server::conn::http1;
use hyper::service::service_fn;
use tako_rs_core::body::TakoBody;
use tako_rs_core::conn_info::ConnInfo;
use tako_rs_core::router::Router;
use tako_rs_core::types::BoxError;
use tokio::io::AsyncReadExt;
use tokio::task::JoinSet;
use crate::ServerConfig;
/// The fixed 12-byte signature that opens every PROXY protocol v2 header.
///
/// `read_proxy_protocol` compares the first 12 bytes of a connection with this
/// value to choose between binary (v2) and text (v1) parsing.
const PROXY_V2_SIG: [u8; 12] = *b"\r\n\r\n\0\r\nQUIT\n";
/// Which PROXY protocol version a [`ProxyHeader`] was parsed from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProxyVersion {
    /// Human-readable text header (`PROXY TCP4 ...\r\n`).
    V1,
    /// Binary header introduced by the 12-byte v2 signature.
    V2,
}
/// Transport protocol announced by a PROXY header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProxyTransport {
    /// A v1 `TCP4`/`TCP6` line, or a v2 header whose protocol nibble is 1.
    Tcp,
    /// A v2 header whose protocol nibble is 2.
    Udp,
    /// A v1 `UNKNOWN` line, a v2 `LOCAL` command, or any other v2 protocol nibble.
    Unknown,
}
/// One raw type-length-value record from a PROXY v2 header.
///
/// Every well-formed TLV is recorded, including those that are also decoded into
/// dedicated [`ProxyHeader`] fields.
#[derive(Debug, Clone)]
pub struct ProxyTlv {
    /// The TLV type byte.
    pub kind: u8,
    /// The TLV payload, without the 3-byte type/length prefix.
    pub value: Vec<u8>,
}
/// TLS details decoded from a PROXY v2 `PP2_TYPE_SSL` (0x20) TLV.
#[derive(Debug, Clone, Default)]
pub struct ProxyTlsInfo {
    /// Byte 0 of the SSL TLV payload (client flags).
    pub client_flags: u8,
    /// Bytes 1..5 of the SSL TLV payload, read big-endian.
    pub verify: u32,
    /// `PP2_SUBTYPE_SSL_VERSION` (0x21) sub-TLV, when present and valid UTF-8.
    pub version: Option<String>,
    /// `PP2_SUBTYPE_SSL_CN` (0x22) sub-TLV, when present and valid UTF-8.
    pub common_name: Option<String>,
    /// `PP2_SUBTYPE_SSL_CIPHER` (0x23) sub-TLV, when present and valid UTF-8.
    pub cipher: Option<String>,
    /// `PP2_SUBTYPE_SSL_SIG_ALG` (0x24) sub-TLV, when present and valid UTF-8.
    pub sig_alg: Option<String>,
    /// `PP2_SUBTYPE_SSL_KEY_ALG` (0x25) sub-TLV, when present and valid UTF-8.
    pub key_alg: Option<String>,
}
/// Largest v2 address-plus-TLV block `parse_v2` will read; longer headers are rejected.
const MAX_PROXY_ADDR_LEN: usize = 536;
/// v2 TLV type 0x01: ALPN bytes, kept raw in `ProxyHeader::alpn`.
const PP2_TYPE_ALPN: u8 = 0x01;
/// v2 TLV type 0x02: authority string, kept in `ProxyHeader::authority` if UTF-8.
const PP2_TYPE_AUTHORITY: u8 = 0x02;
/// v2 TLV type 0x03: CRC32C over the whole header, checked by `verify_v2_crc32c`.
const PP2_TYPE_CRC32C: u8 = 0x03;
/// v2 TLV type 0x04: no-op; nothing is decoded from it.
const PP2_TYPE_NOOP: u8 = 0x04;
/// v2 TLV type 0x05: opaque unique ID, kept raw in `ProxyHeader::unique_id`.
const PP2_TYPE_UNIQUE_ID: u8 = 0x05;
/// v2 TLV type 0x20: TLS information containing the `PP2_SUBTYPE_SSL_*` sub-TLVs.
const PP2_TYPE_SSL: u8 = 0x20;
/// SSL sub-TLV 0x21: TLS version string.
const PP2_SUBTYPE_SSL_VERSION: u8 = 0x21;
/// SSL sub-TLV 0x22: certificate common name.
const PP2_SUBTYPE_SSL_CN: u8 = 0x22;
/// SSL sub-TLV 0x23: cipher name.
const PP2_SUBTYPE_SSL_CIPHER: u8 = 0x23;
/// SSL sub-TLV 0x24: signature algorithm.
const PP2_SUBTYPE_SSL_SIG_ALG: u8 = 0x24;
/// SSL sub-TLV 0x25: key algorithm.
const PP2_SUBTYPE_SSL_KEY_ALG: u8 = 0x25;
/// v2 TLV type 0x30: network namespace; nothing is decoded from it.
const PP2_TYPE_NETNS: u8 = 0x30;
// NOTE(review): AWS documents 0xEA as PP2_TYPE_AWS, whose payload begins with a
// sub-type byte (0x01 = VPC endpoint ID) — confirm consumers strip that byte.
/// AWS-specific v2 TLV type 0xEA, source of `ProxyHeader::aws_vpc_endpoint_id`.
const PP2_TYPE_AWS_VPC_ENDPOINT_ID: u8 = 0xEA;
/// Everything decoded from a PROXY protocol header.
///
/// The HTTP server clones this value into each request's extensions, so handlers
/// can retrieve it by type.
#[derive(Debug, Clone)]
pub struct ProxyHeader {
    /// Protocol version the header used.
    pub version: ProxyVersion,
    /// Transport announced by the header.
    pub transport: ProxyTransport,
    /// Source address from a v1 `TCP4`/`TCP6` line or a v2 INET/INET6 block.
    pub source: Option<SocketAddr>,
    /// Destination address, taken from the same fields as `source`.
    pub destination: Option<SocketAddr>,
    /// v2 UNIX-family source path; `None` if empty or not valid UTF-8.
    pub source_unix: Option<std::path::PathBuf>,
    /// v2 UNIX-family destination path; `None` if empty or not valid UTF-8.
    pub destination_unix: Option<std::path::PathBuf>,
    /// `PP2_TYPE_AUTHORITY` TLV payload, when valid UTF-8.
    pub authority: Option<String>,
    /// Raw `PP2_TYPE_ALPN` TLV payload.
    pub alpn: Option<Vec<u8>>,
    /// VPC endpoint ID decoded from the AWS (0xEA) TLV.
    pub aws_vpc_endpoint_id: Option<String>,
    /// Decoded `PP2_TYPE_SSL` TLV.
    pub tls: Option<ProxyTlsInfo>,
    /// Raw `PP2_TYPE_UNIQUE_ID` TLV payload.
    pub unique_id: Option<Vec<u8>>,
    /// Every well-formed v2 TLV, in wire order, recognised or not.
    pub tlvs: Vec<ProxyTlv>,
    /// `Some(true)` / `Some(false)` when a v2 header carried a 4-byte CRC32C TLV
    /// that did / did not match; `None` for v1, v2 `LOCAL`, or no CRC32C TLV.
    pub crc32c_verified: Option<bool>,
}
impl ProxyHeader {
    /// Starts a header that records only `version` and `transport`.
    ///
    /// Every address, TLV-derived field and the CRC32C result start out unset;
    /// the parsers fill in whatever the wire data provides.
    fn empty(version: ProxyVersion, transport: ProxyTransport) -> Self {
        Self {
            version,
            transport,
            // Address information.
            source: None,
            destination: None,
            source_unix: None,
            destination_unix: None,
            // Decoded TLV values.
            alpn: None,
            authority: None,
            aws_vpc_endpoint_id: None,
            unique_id: None,
            tls: None,
            // Raw TLVs and integrity status.
            tlvs: vec![],
            crc32c_verified: None,
        }
    }
}
fn apply_tlvs(header: &mut ProxyHeader, mut buf: &[u8]) {
while buf.len() >= 3 {
let kind = buf[0];
let len = u16::from_be_bytes([buf[1], buf[2]]) as usize;
if buf.len() < 3 + len {
break;
}
let value = &buf[3..3 + len];
match kind {
PP2_TYPE_ALPN => header.alpn = Some(value.to_vec()),
PP2_TYPE_AUTHORITY => {
if let Ok(s) = std::str::from_utf8(value) {
header.authority = Some(s.to_string());
}
}
PP2_TYPE_AWS_VPC_ENDPOINT_ID => {
if let Ok(s) = std::str::from_utf8(value) {
header.aws_vpc_endpoint_id = Some(s.to_string());
}
}
PP2_TYPE_UNIQUE_ID => header.unique_id = Some(value.to_vec()),
PP2_TYPE_SSL => {
#[allow(clippy::collapsible_match)]
if value.len() >= 5 {
let mut tls = ProxyTlsInfo {
client_flags: value[0],
verify: u32::from_be_bytes([value[1], value[2], value[3], value[4]]),
..Default::default()
};
let mut sub = &value[5..];
while sub.len() >= 3 {
let sk = sub[0];
let slen = u16::from_be_bytes([sub[1], sub[2]]) as usize;
if sub.len() < 3 + slen {
break;
}
let sval = &sub[3..3 + slen];
match sk {
PP2_SUBTYPE_SSL_VERSION => {
tls.version = std::str::from_utf8(sval).ok().map(str::to_string);
}
PP2_SUBTYPE_SSL_CN => {
tls.common_name = std::str::from_utf8(sval).ok().map(str::to_string);
}
PP2_SUBTYPE_SSL_CIPHER => {
tls.cipher = std::str::from_utf8(sval).ok().map(str::to_string);
}
PP2_SUBTYPE_SSL_SIG_ALG => {
tls.sig_alg = std::str::from_utf8(sval).ok().map(str::to_string);
}
PP2_SUBTYPE_SSL_KEY_ALG => {
tls.key_alg = std::str::from_utf8(sval).ok().map(str::to_string);
}
_ => {}
}
sub = &sub[3 + slen..];
}
header.tls = Some(tls);
}
}
PP2_TYPE_CRC32C | PP2_TYPE_NOOP | PP2_TYPE_NETNS => {}
_ => {}
}
header.tlvs.push(ProxyTlv {
kind,
value: value.to_vec(),
});
buf = &buf[3 + len..];
}
}
/// Reads a PROXY protocol header (v1 text or v2 binary) from the start of `reader`.
///
/// Exactly 12 bytes are read first to tell the two versions apart. The matching
/// parser then consumes the rest of the header, leaving `reader` positioned at the
/// first byte after it.
///
/// # Errors
/// Returns `InvalidData` when the first 12 bytes are neither the v2 signature nor
/// begin with `"PROXY "`. Also propagates any error from the version-specific
/// parser or the underlying reader.
pub async fn read_proxy_protocol<R: AsyncReadExt + Unpin>(
    reader: &mut R,
) -> std::io::Result<ProxyHeader> {
    let mut prefix = [0u8; 12];
    reader.read_exact(&mut prefix).await?;
    match prefix {
        PROXY_V2_SIG => parse_v2(reader, &prefix).await,
        _ if prefix.starts_with(b"PROXY ") => parse_v1(reader, &prefix).await,
        _ => Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            "invalid PROXY protocol header: unrecognized signature",
        )),
    }
}
/// Parses the rest of a PROXY protocol v1 (text) header.
///
/// `initial` holds the 12 bytes [`read_proxy_protocol`] already consumed; they
/// begin with `"PROXY "`. The remainder is read one byte at a time, so nothing
/// after the terminating CRLF is taken from `reader`.
///
/// # Errors
/// Returns `InvalidData` when the line:
///
/// * is longer than 107 bytes including the CRLF,
/// * is not UTF-8,
/// * names a protocol other than `TCP4`/`TCP6`/`UNKNOWN`,
/// * has fewer than six fields for TCP,
/// * contains an unparsable address or port, or
/// * carries addresses whose family does not match `TCP4`/`TCP6`.
///
/// I/O errors from `reader` are propagated.
async fn parse_v1<R: AsyncReadExt + Unpin>(
    reader: &mut R,
    initial: &[u8; 12],
) -> std::io::Result<ProxyHeader> {
    // Longest legal v1 line, CRLF included.
    const MAX_V1_LEN: usize = 107;

    // Builds the `InvalidData` error used for every malformed-header case.
    fn invalid<E>(msg: E) -> std::io::Error
    where
        E: Into<Box<dyn std::error::Error + Send + Sync>>,
    {
        std::io::Error::new(std::io::ErrorKind::InvalidData, msg)
    }

    let mut line = Vec::with_capacity(MAX_V1_LEN);
    line.extend_from_slice(initial);
    // Check for the terminator before reading, so that a CRLF already inside
    // `initial` is honoured and the line never exceeds MAX_V1_LEN bytes.
    while !line.ends_with(b"\r\n") {
        if line.len() >= MAX_V1_LEN {
            return Err(invalid("PROXY v1 header exceeds maximum length"));
        }
        let mut byte = [0u8; 1];
        reader.read_exact(&mut byte).await?;
        line.push(byte[0]);
    }
    let text = std::str::from_utf8(&line)
        .map_err(|_| invalid("invalid UTF-8 in PROXY v1 header"))?
        .trim_end_matches("\r\n");
    let parts: Vec<&str> = text.split(' ').collect();
    if parts.len() < 2 {
        return Err(invalid("malformed PROXY v1 header"));
    }
    match parts[1] {
        "UNKNOWN" => Ok(ProxyHeader::empty(
            ProxyVersion::V1,
            ProxyTransport::Unknown,
        )),
        proto @ ("TCP4" | "TCP6") => {
            if parts.len() < 6 {
                return Err(invalid("incomplete PROXY v1 TCP header"));
            }
            let src_ip: IpAddr = parts[2]
                .parse()
                .map_err(|e| invalid(format!("bad source IP: {e}")))?;
            let dst_ip: IpAddr = parts[3]
                .parse()
                .map_err(|e| invalid(format!("bad dest IP: {e}")))?;
            let src_port: u16 = parts[4]
                .parse()
                .map_err(|e| invalid(format!("bad source port: {e}")))?;
            let dst_port: u16 = parts[5]
                .parse()
                .map_err(|e| invalid(format!("bad dest port: {e}")))?;
            // TCP4 must carry IPv4 addresses and TCP6 IPv6 addresses.
            let family_ok = if proto == "TCP4" {
                src_ip.is_ipv4() && dst_ip.is_ipv4()
            } else {
                src_ip.is_ipv6() && dst_ip.is_ipv6()
            };
            if !family_ok {
                return Err(invalid(format!(
                    "PROXY v1 {proto} header has addresses of the wrong family"
                )));
            }
            let mut header = ProxyHeader::empty(ProxyVersion::V1, ProxyTransport::Tcp);
            header.source = Some(SocketAddr::new(src_ip, src_port));
            header.destination = Some(SocketAddr::new(dst_ip, dst_port));
            Ok(header)
        }
        other => Err(invalid(format!("unknown PROXY v1 protocol: {other}"))),
    }
}
/// Finds the first 4-byte `PP2_TYPE_CRC32C` record in a v2 TLV section.
///
/// Returns the offset of that record's value within `buf`, together with the
/// big-endian checksum it carries. CRC32C records of any other length are
/// skipped. Returns `None` if no such record exists or a truncated record is
/// reached first.
fn locate_crc32c_tlv(buf: &[u8]) -> Option<(usize, u32)> {
    let mut pos = 0usize;
    while let Some(prefix) = buf.get(pos..pos + 3) {
        let kind = prefix[0];
        let len = usize::from(u16::from_be_bytes([prefix[1], prefix[2]]));
        let value_start = pos + 3;
        let value = buf.get(value_start..value_start + len)?;
        if kind == PP2_TYPE_CRC32C && len == 4 {
            let expected = u32::from_be_bytes([value[0], value[1], value[2], value[3]]);
            return Some((value_start, expected));
        }
        pos = value_start + len;
    }
    None
}
/// Checks the optional CRC32C TLV of a PROXY v2 header.
///
/// The checksum is recomputed over the complete header — the signature, the
/// 4-byte preamble and the address block — with the CRC32C TLV's 4 value bytes
/// replaced by zeros. The result is compared with the transmitted value.
///
/// `tlv_start` is the index in `addr_buf` where the TLV section begins. Returns
/// `None` when there is no TLV data after it or no CRC32C TLV, otherwise
/// `Some(checksum_matches)`.
fn verify_v2_crc32c(
    sig: &[u8; 12],
    hdr: &[u8; 4],
    addr_buf: &[u8],
    tlv_start: usize,
) -> Option<bool> {
    let tlvs = addr_buf.get(tlv_start..).filter(|rest| !rest.is_empty())?;
    let (crc_offset, expected) = locate_crc32c_tlv(tlvs)?;
    let mut whole: Vec<u8> = sig
        .iter()
        .chain(hdr.iter())
        .chain(addr_buf.iter())
        .copied()
        .collect();
    // The checksum is defined over the header with its own value bytes zeroed.
    let crc_at = sig.len() + hdr.len() + tlv_start + crc_offset;
    whole[crc_at..crc_at + 4].fill(0);
    Some(crc32c::crc32c(&whole) == expected)
}
/// Parses the rest of a PROXY protocol v2 header whose 12-byte signature `sig`
/// has already been consumed from `reader`.
///
/// Reads the 4-byte version/command/family/length preamble, then exactly
/// `addr_len` bytes of address and TLV data. This leaves `reader` at the first
/// byte after the header. A `LOCAL` command yields an empty header. For `PROXY`,
/// the INET/INET6/UNIX address section is decoded, the optional CRC32C TLV is
/// verified (a mismatch is logged), and the remaining TLVs are applied.
///
/// # Errors
/// Returns `InvalidData` when:
///
/// * the version is not 2,
/// * the command is neither `LOCAL` (0x0) nor `PROXY` (0x1),
/// * the address block exceeds `MAX_PROXY_ADDR_LEN`, or
/// * the block is too short for the announced INET/INET6/UNIX family.
///
/// I/O errors from `reader` are propagated.
async fn parse_v2<R: AsyncReadExt + Unpin>(
    reader: &mut R,
    sig: &[u8; 12],
) -> std::io::Result<ProxyHeader> {
    // Builds the `InvalidData` error used for every malformed-header case.
    fn invalid<E>(msg: E) -> std::io::Error
    where
        E: Into<Box<dyn std::error::Error + Send + Sync>>,
    {
        std::io::Error::new(std::io::ErrorKind::InvalidData, msg)
    }

    let mut hdr = [0u8; 4];
    reader.read_exact(&mut hdr).await?;
    let version = (hdr[0] >> 4) & 0x0F;
    let command = hdr[0] & 0x0F;
    if version != 2 {
        return Err(invalid(format!("unsupported PROXY v2 version: {version}")));
    }
    // Only LOCAL (0x0) and PROXY (0x1) are defined. Any other value must be refused
    // rather than silently treated as PROXY.
    if command > 0x1 {
        return Err(invalid(format!("unsupported PROXY v2 command: {command}")));
    }
    let family = (hdr[1] >> 4) & 0x0F;
    let protocol = hdr[1] & 0x0F;
    let addr_len = usize::from(u16::from_be_bytes([hdr[2], hdr[3]]));
    if addr_len > MAX_PROXY_ADDR_LEN {
        return Err(invalid(format!(
            "PROXY v2 addr_len {addr_len} exceeds {MAX_PROXY_ADDR_LEN}"
        )));
    }
    // Always consume the whole address block so the stream is left at the payload.
    let mut addr_buf = vec![0u8; addr_len];
    if addr_len > 0 {
        reader.read_exact(&mut addr_buf).await?;
    }
    if command == 0x0 {
        // LOCAL: the address block is read but ignored.
        return Ok(ProxyHeader::empty(
            ProxyVersion::V2,
            ProxyTransport::Unknown,
        ));
    }
    let transport = match protocol {
        1 => ProxyTransport::Tcp,
        2 => ProxyTransport::Udp,
        _ => ProxyTransport::Unknown,
    };
    let mut header = ProxyHeader::empty(ProxyVersion::V2, transport);
    // Size of the fixed address section per family:
    // INET = 2 × 4-byte IPs + 2 ports, INET6 = 2 × 16-byte IPs + 2 ports,
    // UNIX = 2 × 108-byte paths. Other families have no address section, so the
    // whole block is treated as TLVs.
    let consumed: usize = match family {
        0x1 => 12,
        0x2 => 36,
        0x3 => 216,
        _ => 0,
    };
    // A block shorter than its family requires is malformed. Falling back to TLV
    // parsing would misread the address bytes as TLV records.
    if addr_buf.len() < consumed {
        return Err(invalid(format!(
            "PROXY v2 address block of {} bytes is too short for family {family} (needs {consumed})",
            addr_buf.len()
        )));
    }
    match family {
        0x1 => {
            let src_ip = Ipv4Addr::new(addr_buf[0], addr_buf[1], addr_buf[2], addr_buf[3]);
            let dst_ip = Ipv4Addr::new(addr_buf[4], addr_buf[5], addr_buf[6], addr_buf[7]);
            let src_port = u16::from_be_bytes([addr_buf[8], addr_buf[9]]);
            let dst_port = u16::from_be_bytes([addr_buf[10], addr_buf[11]]);
            header.source = Some(SocketAddr::new(IpAddr::V4(src_ip), src_port));
            header.destination = Some(SocketAddr::new(IpAddr::V4(dst_ip), dst_port));
        }
        0x2 => {
            let src_ip = Ipv6Addr::from(
                <[u8; 16]>::try_from(&addr_buf[0..16]).expect("slice is exactly 16 bytes"),
            );
            let dst_ip = Ipv6Addr::from(
                <[u8; 16]>::try_from(&addr_buf[16..32]).expect("slice is exactly 16 bytes"),
            );
            let src_port = u16::from_be_bytes([addr_buf[32], addr_buf[33]]);
            let dst_port = u16::from_be_bytes([addr_buf[34], addr_buf[35]]);
            header.source = Some(SocketAddr::new(IpAddr::V6(src_ip), src_port));
            header.destination = Some(SocketAddr::new(IpAddr::V6(dst_ip), dst_port));
        }
        0x3 => {
            header.source_unix = parse_unix_path(&addr_buf[0..108]);
            header.destination_unix = parse_unix_path(&addr_buf[108..216]);
        }
        _ => {}
    }
    header.crc32c_verified = verify_v2_crc32c(sig, &hdr, &addr_buf, consumed);
    if header.crc32c_verified == Some(false) {
        tracing::warn!("PROXY v2 CRC32C mismatch — header may be corrupt or spoofed");
    }
    if consumed < addr_buf.len() {
        apply_tlvs(&mut header, &addr_buf[consumed..]);
    }
    Ok(header)
}
/// Decodes one fixed-size, NUL-padded UNIX socket path from a PROXY v2 header.
///
/// The path ends at the first NUL byte, or at the end of `bytes`. Returns `None`
/// when that path is empty (including when `bytes` starts with NUL) or is not
/// valid UTF-8.
fn parse_unix_path(bytes: &[u8]) -> Option<std::path::PathBuf> {
    // `split` always yields at least one (possibly empty) segment.
    let raw = bytes.split(|&b| b == 0).next().unwrap_or_default();
    if raw.is_empty() {
        return None;
    }
    let text = std::str::from_utf8(raw).ok()?;
    Some(std::path::PathBuf::from(text))
}
/// Renders `addr` as a quoted `Forwarded` header `for=` parameter.
///
/// IPv4 renders as `for="1.2.3.4:80"`. IPv6 hosts are bracketed, as in
/// `for="[::1]:80"`.
fn format_forwarded(addr: SocketAddr) -> String {
    let host = match addr.ip() {
        IpAddr::V4(ip) => ip.to_string(),
        IpAddr::V6(ip) => format!("[{ip}]"),
    };
    format!("for=\"{host}:{}\"", addr.port())
}
pub async fn serve_http_with_proxy_protocol(listener: tokio::net::TcpListener, router: Router) {
if let Err(e) = run_proxy_http(
listener,
router,
None::<std::future::Pending<()>>,
ServerConfig::default(),
)
.await
{
tracing::error!("PROXY protocol HTTP server error: {e}");
}
}
pub async fn serve_http_with_proxy_protocol_and_shutdown(
listener: tokio::net::TcpListener,
router: Router,
signal: impl Future<Output = ()> + Send + 'static,
) {
if let Err(e) = run_proxy_http(listener, router, Some(signal), ServerConfig::default()).await {
tracing::error!("PROXY protocol HTTP server error: {e}");
}
}
/// Serves HTTP/1 with PROXY protocol support using the supplied `config` and no
/// shutdown signal.
///
/// A fatal server error is logged, not returned.
pub async fn serve_http_with_proxy_protocol_and_config(
    listener: tokio::net::TcpListener,
    router: Router,
    config: ServerConfig,
) {
    let no_signal: Option<std::future::Pending<()>> = None;
    let outcome = run_proxy_http(listener, router, no_signal, config).await;
    if let Err(e) = outcome {
        tracing::error!("PROXY protocol HTTP server error: {e}");
    }
}
/// Serves HTTP/1 with PROXY protocol support using `config`, shutting down
/// gracefully once `signal` resolves.
///
/// A fatal server error is logged, not returned.
pub async fn serve_http_with_proxy_protocol_shutdown_and_config(
    listener: tokio::net::TcpListener,
    router: Router,
    signal: impl Future<Output = ()> + Send + 'static,
    config: ServerConfig,
) {
    let outcome = run_proxy_http(listener, router, Some(signal), config).await;
    if let Err(e) = outcome {
        tracing::error!("PROXY protocol HTTP server error: {e}");
    }
}
/// Accept loop behind every `serve_http_with_proxy_protocol*` entry point.
///
/// Each accepted TCP connection is handled on its own task in a `JoinSet`:
///
/// 1. If `config.max_connections` is set, wait for a permit. The permit is held
///    until the connection finishes.
/// 2. Read the PROXY header within `config.proxy_read_timeout`. On a parse error
///    or timeout, the connection is dropped with a warning.
/// 3. Serve HTTP/1 with upgrades:
///    * Incoming `Forwarded` / `X-Forwarded-*` headers are stripped.
///    * When the PROXY header carries a source address, it is re-emitted as
///      `Forwarded: for=...` and inserted as `SocketAddr` and `ConnInfo`
///      extensions.
///    * The `ProxyHeader` itself is always inserted.
///
/// Finished connection tasks are reaped as the loop runs. When `signal` resolves:
///
/// * accepting stops,
/// * every live connection is asked to shut down gracefully (finish its in-flight
///   request, then close), and
/// * connections get `config.drain_timeout` to finish before being aborted.
///
/// # Errors
/// Returns an error only if the listener's local address cannot be read.
async fn run_proxy_http(
    listener: tokio::net::TcpListener,
    router: Router,
    signal: Option<impl Future<Output = ()> + Send + 'static>,
    config: ServerConfig,
) -> Result<(), BoxError> {
    let router = Arc::new(router);
    #[cfg(feature = "plugins")]
    router.setup_plugins_once();
    tracing::debug!(
        "Tako PROXY protocol HTTP listening on {}",
        listener.local_addr()?
    );
    let mut join_set = JoinSet::new();
    let mut accept_backoff = config.accept_backoff;
    let max_conn_semaphore = config
        .max_connections
        .map(|n| Arc::new(tokio::sync::Semaphore::new(n)));
    let drain_timeout = config.drain_timeout;
    let header_read_timeout = config.header_read_timeout;
    let keep_alive = config.keep_alive;
    let proxy_read_timeout = config.proxy_read_timeout;
    let cancel = tokio_util::sync::CancellationToken::new();
    if let Some(s) = signal {
        let cancel_for_signal = cancel.clone();
        tokio::spawn(async move {
            s.await;
            cancel_for_signal.cancel();
        });
    }
    loop {
        tokio::select! {
            // Reap finished connection tasks as we go. A JoinSet keeps every completed
            // task until it is joined, so without this branch a long-running server
            // would accumulate one entry per connection it has ever served.
            Some(joined) = join_set.join_next(), if !join_set.is_empty() => {
                if let Some(err) = joined.err().filter(|e| e.is_panic()) {
                    tracing::error!("PROXY protocol connection task panicked: {err}");
                }
            }
            result = listener.accept() => {
                let (mut stream, _tcp_addr) = match result {
                    Ok(v) => { accept_backoff.reset(); v }
                    Err(err) => {
                        tracing::warn!("PROXY accept failed: {err}; backing off");
                        accept_backoff.sleep_and_grow().await;
                        continue;
                    }
                };
                let permit = if let Some(sem) = &max_conn_semaphore {
                    tokio::select! {
                        biased;
                        () = cancel.cancelled() => break,
                        permit = sem.clone().acquire_owned() => match permit {
                            Ok(p) => Some(p),
                            Err(_) => continue,
                        },
                    }
                } else {
                    None
                };
                let _ = stream.set_nodelay(true);
                let router = router.clone();
                let conn_cancel = cancel.clone();
                join_set.spawn(async move {
                    let proxy_header =
                        match tokio::time::timeout(proxy_read_timeout, read_proxy_protocol(&mut stream)).await {
                            Ok(Ok(h)) => h,
                            Ok(Err(e)) => {
                                tracing::warn!("Failed to parse PROXY protocol: {e}");
                                return;
                            }
                            Err(_) => {
                                tracing::warn!(
                                    "PROXY protocol read deadline ({:?}) elapsed; dropping connection",
                                    proxy_read_timeout,
                                );
                                return;
                            }
                        };
                    let real_addr = proxy_header.source;
                    let io = hyper_util::rt::TokioIo::new(stream);
                    let svc = service_fn(move |mut req| {
                        let router = router.clone();
                        let proxy_header = proxy_header.clone();
                        async move {
                            // Never trust client-supplied forwarding headers; only the
                            // PROXY header is authoritative for the client address.
                            req.headers_mut().remove(http::header::FORWARDED);
                            req.headers_mut().remove("x-forwarded-for");
                            req.headers_mut().remove("x-forwarded-host");
                            req.headers_mut().remove("x-forwarded-proto");
                            if let Some(addr) = real_addr {
                                let forwarded_value = format_forwarded(addr);
                                if let Ok(v) = http::HeaderValue::from_str(&forwarded_value) {
                                    req.headers_mut().insert(http::header::FORWARDED, v);
                                }
                                req.extensions_mut().insert(addr);
                                req.extensions_mut().insert(ConnInfo::tcp(addr));
                            }
                            req.extensions_mut().insert(proxy_header);
                            let response = router.dispatch(req.map(TakoBody::incoming)).await;
                            Ok::<_, Infallible>(response)
                        }
                    });
                    let mut http = http1::Builder::new();
                    http.keep_alive(keep_alive);
                    http.timer(hyper_util::rt::TokioTimer::new());
                    if let Some(t) = header_read_timeout {
                        http.header_read_timeout(t);
                    }
                    let conn = http.serve_connection(io, svc).with_upgrades();
                    tokio::pin!(conn);
                    // On shutdown, let the in-flight request finish and then close. Without
                    // this, idle keep-alive connections would hold the drain open until
                    // the timeout expires and they are aborted.
                    let served = tokio::select! {
                        res = conn.as_mut() => res,
                        () = conn_cancel.cancelled() => {
                            conn.as_mut().graceful_shutdown();
                            conn.as_mut().await
                        }
                    };
                    if let Err(err) = served {
                        if err.is_incomplete_message() {
                            tracing::debug!("client disconnected mid-message on PROXY protocol connection: {err}");
                        } else {
                            tracing::error!("Error serving PROXY protocol connection: {err}");
                        }
                    }
                    drop(permit);
                });
            }
            () = cancel.cancelled() => {
                tracing::info!("PROXY protocol HTTP server shutting down...");
                break;
            }
        }
    }
    let drain = tokio::time::timeout(drain_timeout, async {
        while join_set.join_next().await.is_some() {}
    });
    if drain.await.is_err() {
        tracing::warn!(
            "Drain timeout exceeded, aborting {} remaining connections",
            join_set.len()
        );
        join_set.abort_all();
    }
    tracing::info!("PROXY protocol HTTP server shut down gracefully");
    Ok(())
}