use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpListener;
use tokio_rustls::TlsAcceptor;
use super::registry::StreamRegistry;
/// Builds a [`TlsAcceptor`] whose server certificate is chosen by `resolver`.
///
/// The resulting server configuration does not request client certificates;
/// every certificate decision is delegated to the supplied resolver.
#[must_use]
pub fn tls_acceptor_from_resolver(
    resolver: Arc<dyn rustls::server::ResolvesServerCert>,
) -> TlsAcceptor {
    // Start from the default builder and disable client authentication first,
    // then plug in the resolver.
    let without_client_auth = rustls::ServerConfig::builder().with_no_client_auth();
    let server_config = Arc::new(without_client_auth.with_cert_resolver(resolver));
    TlsAcceptor::from(server_config)
}
/// TCP stream proxy for a single listen port.
///
/// Each accepted client connection is resolved to a service through the
/// registry, connected to one of that service's backends, and then tunneled
/// in both directions. TLS termination towards the client and a PROXY
/// protocol v2 header towards the backend are both optional.
pub struct TcpStreamService {
/// Registry queried with `listen_port` to find the service for a connection.
registry: Arc<StreamRegistry>,
/// Port used as the registry lookup key and reported in log events.
listen_port: u16,
/// When set, client connections go through a TLS handshake before tunneling.
tls_acceptor: Option<TlsAcceptor>,
/// When `true`, a PROXY protocol v2 header is written to the backend before
/// any client data is forwarded.
proxy_protocol: bool,
/// Listener address recorded once by `serve`; consulted when choosing the
/// destination address for the PROXY protocol header.
local_addr: std::sync::OnceLock<SocketAddr>,
}
impl TcpStreamService {
    /// Creates a service for `listen_port` with TLS termination and the PROXY
    /// protocol both disabled.
    #[must_use]
    pub fn new(registry: Arc<StreamRegistry>, listen_port: u16) -> Self {
        Self {
            registry,
            listen_port,
            tls_acceptor: None,
            proxy_protocol: false,
            local_addr: std::sync::OnceLock::new(),
        }
    }

    /// Enables TLS termination of client connections using `acceptor`.
    #[must_use]
    pub fn with_tls_acceptor(mut self, acceptor: TlsAcceptor) -> Self {
        self.tls_acceptor = Some(acceptor);
        self
    }

    /// Enables or disables sending a PROXY protocol v2 header to the backend.
    #[must_use]
    pub fn with_proxy_protocol(mut self, enabled: bool) -> Self {
        self.proxy_protocol = enabled;
        self
    }

    /// Returns the configured listen port.
    #[must_use]
    pub fn port(&self) -> u16 {
        self.listen_port
    }

    /// Returns the registry used to resolve services and backends.
    #[must_use]
    pub fn registry(&self) -> &Arc<StreamRegistry> {
        &self.registry
    }

    /// Accepts connections from `listener` forever, handling each one on its
    /// own spawned task.
    ///
    /// The listener's bound address is recorded once (if it can be read).
    /// Accept errors are logged and retried after a 50 ms pause; the loop
    /// itself never returns.
    pub async fn serve(self: Arc<Self>, listener: TcpListener) {
        if let Ok(addr) = listener.local_addr() {
            // A second `serve` call on the same service keeps the first address.
            let _ = self.local_addr.set(addr);
        }
        tracing::info!(
            port = self.listen_port,
            tls = self.tls_acceptor.is_some(),
            proxy_protocol = self.proxy_protocol,
            "TCP stream proxy listening"
        );
        loop {
            let (client_stream, client_addr) = match listener.accept().await {
                Ok(conn) => conn,
                Err(e) => {
                    tracing::warn!(
                        port = self.listen_port,
                        error = %e,
                        "TCP accept error, retrying"
                    );
                    // Brief back-off so a persistent accept failure does not spin.
                    tokio::time::sleep(std::time::Duration::from_millis(50)).await;
                    continue;
                }
            };
            let svc = Arc::clone(&self);
            tokio::spawn(async move {
                svc.handle_raw_connection(client_stream, client_addr).await;
            });
        }
    }

    /// Proxies one accepted client connection to a backend.
    ///
    /// Steps, in order: resolve the service for the listen port, pick a
    /// backend, connect to it, optionally send the PROXY protocol v2 header,
    /// optionally perform the client TLS handshake, then tunnel bytes in both
    /// directions. Any failure is logged and ends the connection; nothing is
    /// returned to the caller.
    async fn handle_raw_connection(
        &self,
        client_stream: tokio::net::TcpStream,
        client_addr: SocketAddr,
    ) {
        let Some(service) = self.registry.resolve_tcp(self.listen_port) else {
            tracing::warn!(
                port = self.listen_port,
                client = %client_addr,
                "No service registered for TCP port"
            );
            return;
        };
        let Some(backend) = service.select_backend() else {
            tracing::warn!(
                port = self.listen_port,
                service = %service.name,
                client = %client_addr,
                "No backends available for TCP service"
            );
            return;
        };
        tracing::debug!(
            port = self.listen_port,
            service = %service.name,
            client = %client_addr,
            backend = %backend,
            "Proxying TCP connection"
        );
        let mut upstream = match tokio::net::TcpStream::connect(backend).await {
            Ok(stream) => stream,
            Err(e) => {
                tracing::warn!(
                    error = %e,
                    backend = %backend,
                    service = %service.name,
                    client = %client_addr,
                    "Failed to connect to TCP backend"
                );
                return;
            }
        };
        if self.proxy_protocol {
            // The PROXY header's destination is the address the client actually
            // connected to. The accepted socket knows that address; the
            // listener's bound address is a wildcard when it is bound to all
            // interfaces, so it is used only as a fallback, followed by the
            // backend IP paired with the listen port as a last resort.
            let dst = client_stream
                .local_addr()
                .ok()
                .or_else(|| self.local_addr.get().copied())
                .unwrap_or_else(|| SocketAddr::new(backend.ip(), self.listen_port));
            let header = build_proxy_protocol_v2_header(client_addr, dst);
            if let Err(e) = upstream.write_all(&header).await {
                tracing::warn!(
                    error = %e,
                    backend = %backend,
                    service = %service.name,
                    client = %client_addr,
                    "Failed to write PROXY protocol header to backend"
                );
                return;
            }
        }
        // NOTE(review): the backend is dialed (and the PROXY header sent) before
        // the client TLS handshake, so a failed handshake still costs the
        // backend a short-lived connection.
        if let Some(acceptor) = &self.tls_acceptor {
            match acceptor.accept(client_stream).await {
                Ok(tls_stream) => {
                    Self::duplex(tls_stream, upstream).await;
                }
                Err(e) => {
                    tracing::warn!(
                        error = %e,
                        service = %service.name,
                        client = %client_addr,
                        "TLS handshake with client failed"
                    );
                }
            }
        } else {
            Self::duplex(client_stream, upstream).await;
        }
    }

    /// Copies bytes between `downstream` and `upstream` in both directions
    /// with `tokio::io::copy_bidirectional`, then logs the byte counts (or the
    /// error) at debug level.
    async fn duplex<D, U>(mut downstream: D, mut upstream: U)
    where
        D: AsyncRead + AsyncWrite + Unpin,
        U: AsyncRead + AsyncWrite + Unpin,
    {
        match tokio::io::copy_bidirectional(&mut downstream, &mut upstream).await {
            Ok((down_to_up, up_to_down)) => {
                tracing::debug!(
                    down_to_up = down_to_up,
                    up_to_down = up_to_down,
                    "TCP tunnel closed"
                );
            }
            Err(e) => {
                tracing::debug!(error = %e, "TCP tunnel error");
            }
        }
    }

    /// Crate-visible entry point to the same bidirectional tunnel used for
    /// proxied connections.
    pub(crate) async fn splice<D, U>(downstream: D, upstream: U)
    where
        D: AsyncRead + AsyncWrite + Unpin,
        U: AsyncRead + AsyncWrite + Unpin,
    {
        Self::duplex(downstream, upstream).await;
    }
}
/// Builds a binary PROXY protocol v2 header announcing a proxied TCP
/// connection from `src` to `dst`.
///
/// Layout: 12-byte signature, version/command byte (`0x21`), address
/// family/transport byte, big-endian length of the address block, then source
/// address, destination address, source port and destination port.
///
/// Address family selection:
/// * both addresses IPv4 — TCP over IPv4 (`0x11`, 12-byte address block);
/// * IPv4 source with an IPv4-mapped IPv6 destination — the destination is
///   unwrapped to IPv4 and the header stays TCP over IPv4;
/// * anything else — TCP over IPv6 (`0x21`, 36-byte address block), with any
///   IPv4 address written in its IPv4-mapped IPv6 form so no address is lost.
#[must_use]
pub fn build_proxy_protocol_v2_header(src: SocketAddr, dst: SocketAddr) -> Vec<u8> {
    use std::net::IpAddr;

    const SIG: [u8; 12] = [
        0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
    ];
    // Version 2 in the high nibble, PROXY command in the low nibble.
    const VERSION_COMMAND: u8 = 0x21;
    const FAMILY_TCP4: u8 = 0x11;
    const FAMILY_TCP6: u8 = 0x21;
    // Signature + version/command + family + 2-byte length.
    const FIXED_HEADER_LEN: usize = 16;
    // 2 * 4 address bytes + 2 * 2 port bytes.
    const TCP4_ADDR_LEN: u16 = 12;
    // 2 * 16 address bytes + 2 * 2 port bytes.
    const TCP6_ADDR_LEN: u16 = 36;

    // An IPv4 header is only possible when both ends can be expressed as IPv4.
    let v4_pair = match (src.ip(), dst.ip()) {
        (IpAddr::V4(s), IpAddr::V4(d)) => Some((s, d)),
        (IpAddr::V4(s), IpAddr::V6(d)) => d.to_ipv4_mapped().map(|d| (s, d)),
        (IpAddr::V6(_), _) => None,
    };
    let (family, addr_len) = if v4_pair.is_some() {
        (FAMILY_TCP4, TCP4_ADDR_LEN)
    } else {
        (FAMILY_TCP6, TCP6_ADDR_LEN)
    };

    let mut out = Vec::with_capacity(FIXED_HEADER_LEN + usize::from(addr_len));
    out.extend_from_slice(&SIG);
    out.push(VERSION_COMMAND);
    out.push(family);
    out.extend_from_slice(&addr_len.to_be_bytes());

    match v4_pair {
        Some((src_ip, dst_ip)) => {
            out.extend_from_slice(&src_ip.octets());
            out.extend_from_slice(&dst_ip.octets());
        }
        None => {
            // Write both ends as IPv6, mapping IPv4 addresses where needed.
            let widen = |ip: IpAddr| match ip {
                IpAddr::V4(v4) => v4.to_ipv6_mapped(),
                IpAddr::V6(v6) => v6,
            };
            out.extend_from_slice(&widen(src.ip()).octets());
            out.extend_from_slice(&widen(dst.ip()).octets());
        }
    }

    out.extend_from_slice(&src.port().to_be_bytes());
    out.extend_from_slice(&dst.port().to_be_bytes());
    out
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};

    /// The 12-byte signature every PROXY protocol v2 header starts with.
    const SIGNATURE: [u8; 12] = [
        0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
    ];

    /// Pins the complete byte sequence for an IPv4-to-IPv4 connection.
    #[test]
    fn proxy_protocol_v2_ipv4_exact_bytes() {
        let src = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(192, 168, 1, 50), 0xABCD));
        let dst = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 5432));
        let hdr = build_proxy_protocol_v2_header(src, dst);
        let expected: Vec<u8> = vec![
            0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, // signature
            0x21, 0x11, 0x00, 0x0C, // version/command, family, address block length
            192, 168, 1, 50, // source address
            10, 0, 0, 1, // destination address
            0xAB, 0xCD, // source port
            0x15, 0x38, // destination port (5432)
        ];
        assert_eq!(hdr, expected);
        assert_eq!(hdr.len(), 16 + 12);
    }

    /// Checks every field of an IPv6-to-IPv6 header, including both addresses.
    #[test]
    fn proxy_protocol_v2_ipv6_shape() {
        let src_ip = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1);
        let src = SocketAddr::V6(SocketAddrV6::new(src_ip, 7777, 0, 0));
        let dst = SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 8888, 0, 0));
        let hdr = build_proxy_protocol_v2_header(src, dst);
        assert_eq!(hdr.len(), 16 + 36);
        assert_eq!(&hdr[..12], &SIGNATURE);
        assert_eq!(hdr[12], 0x21);
        assert_eq!(hdr[13], 0x21);
        assert_eq!(&hdr[14..16], &36u16.to_be_bytes());
        assert_eq!(&hdr[16..32], &src_ip.octets());
        assert_eq!(&hdr[32..48], &Ipv6Addr::LOCALHOST.octets());
        assert_eq!(&hdr[48..50], &7777u16.to_be_bytes());
        assert_eq!(&hdr[50..52], &8888u16.to_be_bytes());
    }

    /// An IPv6 source paired with an IPv4 destination yields an IPv6 header
    /// whose destination is the IPv4-mapped form of that address.
    #[test]
    fn proxy_protocol_v2_ipv6_source_maps_ipv4_destination() {
        let src_ip = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1);
        let dst_ip = Ipv4Addr::new(10, 0, 0, 1);
        let src = SocketAddr::V6(SocketAddrV6::new(src_ip, 7777, 0, 0));
        let dst = SocketAddr::V4(SocketAddrV4::new(dst_ip, 5432));
        let hdr = build_proxy_protocol_v2_header(src, dst);
        assert_eq!(hdr.len(), 16 + 36);
        assert_eq!(&hdr[..12], &SIGNATURE);
        assert_eq!(hdr[13], 0x21);
        assert_eq!(&hdr[14..16], &36u16.to_be_bytes());
        assert_eq!(&hdr[16..32], &src_ip.octets());
        assert_eq!(&hdr[32..48], &dst_ip.to_ipv6_mapped().octets());
        assert_eq!(&hdr[48..50], &7777u16.to_be_bytes());
        assert_eq!(&hdr[50..52], &5432u16.to_be_bytes());
    }
}