zlayer-proxy 0.13.0

High-performance reverse proxy with TLS termination and L4/L7 routing
Documentation
//! TCP stream proxy service
//!
//! Implements raw TCP proxying with a standalone `serve()` method.
//! Provides bidirectional tunneling between clients and backends.
//!
//! Optionally terminates TLS at the proxy (driven by the endpoint's
//! `stream.tls` config) and/or prepends a PROXY protocol v2 header to the
//! upstream connection (driven by `stream.proxy_protocol`).

use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpListener;
use tokio_rustls::TlsAcceptor;

use super::registry::StreamRegistry;

/// Create the [`TlsAcceptor`] used to terminate TLS on L4 TCP endpoints.
///
/// The rustls server config mirrors the L7 HTTPS listener's: no client
/// authentication, and certificates chosen per connection by `resolver`.
/// An endpoint configured with `stream.tls = true` therefore serves the same
/// shared certificate set (ACME / hot-loaded certs) as HTTPS endpoints.
/// Pass the `ProxyManager`'s shared `Arc<SniCertResolver>` here.
///
/// NOTE(review): in rustls 0.23+ `ServerConfig::builder()` relies on a
/// process-wide default crypto provider. Confirm the pinned rustls version and
/// feature set make that unambiguous.
#[must_use]
pub fn tls_acceptor_from_resolver(
    resolver: Arc<dyn rustls::server::ResolvesServerCert>,
) -> TlsAcceptor {
    let server_config = Arc::new(
        rustls::ServerConfig::builder()
            .with_no_client_auth()
            .with_cert_resolver(resolver),
    );
    TlsAcceptor::from(server_config)
}

/// TCP stream proxy service.
///
/// Accepts TCP connections on a single port and tunnels each one to a backend
/// picked (round-robin) from the service registered for that port.
pub struct TcpStreamService {
    /// Registry consulted for every connection to find the service bound to
    /// `listen_port` and its backends.
    registry: Arc<StreamRegistry>,
    /// Port this service handles; also the key used for registry lookups.
    listen_port: u16,
    /// Listener local address, recorded once by `serve()`. It is consulted when
    /// building the destination address of the PROXY protocol header.
    local_addr: std::sync::OnceLock<SocketAddr>,
    /// `Some(acceptor)` means client connections are TLS-terminated here, using
    /// an acceptor built from the shared SNI cert resolver. The backend then
    /// receives the decrypted plaintext.
    tls_acceptor: Option<TlsAcceptor>,
    /// `true` means every upstream connection starts with a PROXY protocol v2
    /// header, so the backend can see the real client address.
    proxy_protocol: bool,
}

impl TcpStreamService {
    /// Create a new TCP stream service for `listen_port`.
    ///
    /// TLS termination and PROXY protocol start disabled. Enable them with
    /// [`Self::with_tls_acceptor`] and [`Self::with_proxy_protocol`].
    #[must_use]
    pub fn new(registry: Arc<StreamRegistry>, listen_port: u16) -> Self {
        Self {
            registry,
            listen_port,
            tls_acceptor: None,
            proxy_protocol: false,
            local_addr: std::sync::OnceLock::new(),
        }
    }

    /// Enable TLS termination using the given acceptor (builder-style).
    #[must_use]
    pub fn with_tls_acceptor(mut self, acceptor: TlsAcceptor) -> Self {
        self.tls_acceptor = Some(acceptor);
        self
    }

    /// Enable PROXY protocol v2 toward the upstream backend (builder-style).
    #[must_use]
    pub fn with_proxy_protocol(mut self, enabled: bool) -> Self {
        self.proxy_protocol = enabled;
        self
    }

    /// Get the listen port
    #[must_use]
    pub fn port(&self) -> u16 {
        self.listen_port
    }

    /// Get a reference to the registry
    #[must_use]
    pub fn registry(&self) -> &Arc<StreamRegistry> {
        &self.registry
    }

    /// Run a standalone TCP accept loop on the given listener.
    ///
    /// Each accepted connection gets its own spawned task, which resolves a
    /// backend from the registry and tunnels bytes in both directions.
    ///
    /// The loop has no exit path, so this future never completes on its own.
    /// Accept errors (e.g. too many open files) are logged and retried after a
    /// 50 ms back-off.
    pub async fn serve(self: Arc<Self>, listener: TcpListener) {
        // Remember the listener's bound address. It is the fallback PROXY
        // protocol destination when a connection's own local address cannot
        // be read.
        match listener.local_addr() {
            Ok(addr) => {
                // `set` only fails if `serve()` already ran on this instance;
                // the first recorded address is kept in that case.
                let _ = self.local_addr.set(addr);
            }
            Err(e) => {
                tracing::warn!(
                    port = self.listen_port,
                    error = %e,
                    "Failed to read TCP listener local address"
                );
            }
        }

        tracing::info!(
            port = self.listen_port,
            tls = self.tls_acceptor.is_some(),
            proxy_protocol = self.proxy_protocol,
            "TCP stream proxy listening"
        );

        loop {
            let (client_stream, client_addr) = match listener.accept().await {
                Ok(conn) => conn,
                Err(e) => {
                    // Transient errors (too many open files, etc.) -- log and retry
                    tracing::warn!(
                        port = self.listen_port,
                        error = %e,
                        "TCP accept error, retrying"
                    );
                    tokio::time::sleep(std::time::Duration::from_millis(50)).await;
                    continue;
                }
            };

            let svc = Arc::clone(&self);
            tokio::spawn(async move {
                svc.handle_raw_connection(client_stream, client_addr).await;
            });
        }
    }

    /// Handle a single accepted TCP connection.
    ///
    /// Steps, in order:
    /// 1. Resolve the service and a backend for this port. This happens before
    ///    any TLS work, so clients of an unconfigured port are dropped cheaply.
    /// 2. Terminate TLS, if configured.
    /// 3. Connect to the backend and, if enabled, send a PROXY v2 header.
    /// 4. Relay bytes in both directions until the tunnel closes.
    ///
    /// The client handshake completes *before* the upstream connection is
    /// opened. A failed or stalled handshake therefore never occupies a backend
    /// connection, and never leaves a PROXY header on a connection that carries
    /// no data.
    async fn handle_raw_connection(
        &self,
        client_stream: tokio::net::TcpStream,
        client_addr: SocketAddr,
    ) {
        /// Client-facing side of the tunnel: raw TCP or a TLS-terminated stream.
        enum Downstream<T> {
            Plain(tokio::net::TcpStream),
            Tls(T),
        }

        // Resolve service for this port
        let Some(service) = self.registry.resolve_tcp(self.listen_port) else {
            tracing::warn!(
                port = self.listen_port,
                client = %client_addr,
                "No service registered for TCP port"
            );
            return;
        };

        // Select backend using round-robin
        let Some(backend) = service.select_backend() else {
            tracing::warn!(
                port = self.listen_port,
                service = %service.name,
                client = %client_addr,
                "No backends available for TCP service"
            );
            return;
        };

        tracing::debug!(
            port = self.listen_port,
            service = %service.name,
            client = %client_addr,
            backend = %backend,
            "Proxying TCP connection"
        );

        // PROXY protocol destination: the address this particular connection
        // was accepted on. On a wildcard listener (0.0.0.0 / [::]) the
        // listener's own address is unspecified, so the per-connection local
        // address is preferred. Computed here because `client_stream` is moved
        // into the TLS acceptor below.
        let proxy_dst = self.proxy_protocol.then(|| {
            client_stream
                .local_addr()
                .ok()
                .or_else(|| self.local_addr.get().copied())
                .unwrap_or_else(|| SocketAddr::new(backend.ip(), self.listen_port))
        });

        // Terminate TLS (if configured) before touching the backend.
        let downstream = match &self.tls_acceptor {
            Some(acceptor) => match acceptor.accept(client_stream).await {
                Ok(tls_stream) => Downstream::Tls(tls_stream),
                Err(e) => {
                    tracing::warn!(
                        error = %e,
                        service = %service.name,
                        client = %client_addr,
                        "TLS handshake with client failed"
                    );
                    return;
                }
            },
            None => Downstream::Plain(client_stream),
        };

        // Connect to the upstream backend
        let mut upstream = match tokio::net::TcpStream::connect(backend).await {
            Ok(stream) => stream,
            Err(e) => {
                tracing::warn!(
                    error = %e,
                    backend = %backend,
                    service = %service.name,
                    client = %client_addr,
                    "Failed to connect to TCP backend"
                );
                return;
            }
        };

        // The PROXY header must be the first bytes the backend sees, ahead of
        // any relayed client data.
        if let Some(dst) = proxy_dst {
            let header = build_proxy_protocol_v2_header(client_addr, dst);
            if let Err(e) = upstream.write_all(&header).await {
                tracing::warn!(
                    error = %e,
                    backend = %backend,
                    service = %service.name,
                    client = %client_addr,
                    "Failed to write PROXY protocol header to backend"
                );
                return;
            }
        }

        match downstream {
            Downstream::Plain(stream) => Self::duplex(stream, upstream).await,
            Downstream::Tls(stream) => Self::duplex(stream, upstream).await,
        }
    }

    /// Bidirectional data copy between a downstream (client-facing) and an
    /// upstream (backend-facing) stream.
    ///
    /// Generic over any `AsyncRead + AsyncWrite` so it can relay either a raw
    /// `TcpStream` or a TLS-terminated `TlsStream` on the downstream side.
    /// Uses `tokio::io::copy_bidirectional`. The byte counts, or any error,
    /// are logged at debug level only.
    async fn duplex<D, U>(mut downstream: D, mut upstream: U)
    where
        D: AsyncRead + AsyncWrite + Unpin,
        U: AsyncRead + AsyncWrite + Unpin,
    {
        match tokio::io::copy_bidirectional(&mut downstream, &mut upstream).await {
            Ok((down_to_up, up_to_down)) => {
                tracing::debug!(
                    down_to_up = down_to_up,
                    up_to_down = up_to_down,
                    "TCP tunnel closed"
                );
            }
            Err(e) => {
                tracing::debug!(error = %e, "TCP tunnel error");
            }
        }
    }

    /// `pub(crate)` bidirectional splice between a downstream and upstream
    /// stream, reusing the same [`copy_bidirectional`](tokio::io::copy_bidirectional)
    /// machinery as [`Self::duplex`].
    ///
    /// Exposed so the HTTPS ingress (`server.rs`) can splice an unmanaged SNI
    /// connection straight through to its real upstream without terminating TLS.
    pub(crate) async fn splice<D, U>(downstream: D, upstream: U)
    where
        D: AsyncRead + AsyncWrite + Unpin,
        U: AsyncRead + AsyncWrite + Unpin,
    {
        Self::duplex(downstream, upstream).await;
    }
}

/// Build a PROXY protocol v2 header describing a proxied TCP connection from
/// `src` (the real client) to `dst` (the address the client connected to).
///
/// The header layout (see the `HAProxy` PROXY protocol v2 spec):
/// - 12-byte signature `0D 0A 0D 0A 00 0D 0A 51 55 49 54 0A`
/// - byte 13: version/command `0x21` (v2, PROXY command)
/// - byte 14: address family + transport — `0x11` (`AF_INET` + STREAM) or
///   `0x21` (`AF_INET6` + STREAM)
/// - bytes 15-16: big-endian length of the following address block
///   (12 for IPv4, 36 for IPv6)
/// - address block: src IP, dst IP, src port, dst port (all big-endian)
///
/// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are first converted back to
/// plain IPv4. A dual-stack `[::]` listener reports IPv4 clients that way, and
/// backends should see the real IPv4 client address. Other IPv6 addresses,
/// including `::1`, are left untouched.
///
/// After that, the address family is chosen from the client (`src`) address.
/// When the families still differ, `dst` is coerced to the client's family so
/// the header stays internally consistent:
/// - an IPv4 `dst` is IPv4-mapped into an IPv6 header;
/// - an IPv6 `dst` in an IPv4 header becomes `0.0.0.0`.
#[must_use]
pub fn build_proxy_protocol_v2_header(src: SocketAddr, dst: SocketAddr) -> Vec<u8> {
    const SIG: [u8; 12] = [
        0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
    ];
    /// Signature + version/command + family + 2-byte length.
    const FIXED_LEN: usize = 16;

    /// Convert an IPv4-mapped IPv6 socket address to its IPv4 form.
    fn unmap_v4(addr: SocketAddr) -> SocketAddr {
        match addr {
            SocketAddr::V6(v6) => match v6.ip().to_ipv4_mapped() {
                Some(v4) => SocketAddr::new(v4.into(), v6.port()),
                None => addr,
            },
            SocketAddr::V4(_) => addr,
        }
    }

    let src = unmap_v4(src);
    let dst = unmap_v4(dst);

    // Sized for the larger (IPv6) address block so no reallocation occurs.
    let mut out = Vec::with_capacity(FIXED_LEN + 36);
    out.extend_from_slice(&SIG);
    out.push(0x21); // version 2 + PROXY command

    match src {
        SocketAddr::V4(src_v4) => {
            out.push(0x11); // AF_INET + STREAM
            out.extend_from_slice(&12u16.to_be_bytes()); // addr block len

            let dst_ip = match dst {
                SocketAddr::V4(d) => *d.ip(),
                // A native IPv6 destination has no IPv4 representation.
                SocketAddr::V6(_) => std::net::Ipv4Addr::UNSPECIFIED,
            };
            out.extend_from_slice(&src_v4.ip().octets());
            out.extend_from_slice(&dst_ip.octets());
            out.extend_from_slice(&src_v4.port().to_be_bytes());
            out.extend_from_slice(&dst.port().to_be_bytes());
        }
        SocketAddr::V6(src_v6) => {
            out.push(0x21); // AF_INET6 + STREAM
            out.extend_from_slice(&36u16.to_be_bytes()); // addr block len

            let dst_ip = match dst {
                SocketAddr::V6(d) => *d.ip(),
                SocketAddr::V4(d) => d.ip().to_ipv6_mapped(),
            };
            out.extend_from_slice(&src_v6.ip().octets());
            out.extend_from_slice(&dst_ip.octets());
            out.extend_from_slice(&src_v6.port().to_be_bytes());
            out.extend_from_slice(&dst.port().to_be_bytes());
        }
    }

    out
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};

    /// The 12-byte PROXY protocol v2 signature every header must start with.
    const SIGNATURE: [u8; 12] = [
        0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
    ];

    #[test]
    fn proxy_protocol_v2_ipv4_exact_bytes() {
        let src = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(192, 168, 1, 50), 0xABCD));
        let dst = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 5432));
        let hdr = build_proxy_protocol_v2_header(src, dst);

        let expected: Vec<u8> = vec![
            // 12-byte signature
            0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
            0x21, // v2 + PROXY
            0x11, // AF_INET + STREAM
            0x00, 0x0C, // addr block length = 12
            192, 168, 1, 50, // src IP
            10, 0, 0, 1, // dst IP
            0xAB, 0xCD, // src port 0xABCD
            0x15, 0x38, // dst port 5432
        ];
        assert_eq!(hdr, expected);
        assert_eq!(hdr.len(), 16 + 12);
    }

    #[test]
    fn proxy_protocol_v2_ipv6_shape() {
        let src_ip = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1);
        let src = SocketAddr::V6(SocketAddrV6::new(src_ip, 7777, 0, 0));
        let dst = SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 8888, 0, 0));
        let hdr = build_proxy_protocol_v2_header(src, dst);

        // 12 sig + 1 ver/cmd + 1 fam + 2 len + 36 addr block = 52 bytes
        assert_eq!(hdr.len(), 16 + 36);
        assert_eq!(&hdr[..12], &SIGNATURE);
        assert_eq!(hdr[12], 0x21); // v2 + PROXY
        assert_eq!(hdr[13], 0x21); // AF_INET6 + STREAM
        assert_eq!(&hdr[14..16], &36u16.to_be_bytes());
        // Address block: src IP, dst IP, src port, dst port.
        assert_eq!(&hdr[16..32], &src_ip.octets());
        assert_eq!(&hdr[32..48], &Ipv6Addr::LOCALHOST.octets());
        assert_eq!(&hdr[48..50], &7777u16.to_be_bytes());
        assert_eq!(&hdr[50..52], &8888u16.to_be_bytes());
    }

    #[test]
    fn proxy_protocol_v2_v4_client_native_v6_dst_keeps_client_family() {
        let src = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(192, 0, 2, 7), 1234));
        let dst = SocketAddr::V6(SocketAddrV6::new(
            Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 2),
            443,
            0,
            0,
        ));
        let hdr = build_proxy_protocol_v2_header(src, dst);

        assert_eq!(hdr.len(), 16 + 12);
        assert_eq!(hdr[13], 0x11); // AF_INET + STREAM (client's family)
        assert_eq!(&hdr[16..20], &[192, 0, 2, 7]);
        // A native IPv6 destination cannot be expressed in IPv4.
        assert_eq!(&hdr[20..24], &Ipv4Addr::UNSPECIFIED.octets());
        assert_eq!(&hdr[24..26], &1234u16.to_be_bytes());
        assert_eq!(&hdr[26..28], &443u16.to_be_bytes());
    }

    #[test]
    fn proxy_protocol_v2_v6_client_v4_dst_maps_dst() {
        let src_ip = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1);
        let src = SocketAddr::V6(SocketAddrV6::new(src_ip, 5555, 0, 0));
        let dst_ip = Ipv4Addr::new(10, 0, 0, 1);
        let dst = SocketAddr::V4(SocketAddrV4::new(dst_ip, 80));
        let hdr = build_proxy_protocol_v2_header(src, dst);

        assert_eq!(hdr.len(), 16 + 36);
        assert_eq!(hdr[13], 0x21); // AF_INET6 + STREAM (client's family)
        assert_eq!(&hdr[16..32], &src_ip.octets());
        assert_eq!(&hdr[32..48], &dst_ip.to_ipv6_mapped().octets());
        assert_eq!(&hdr[48..50], &5555u16.to_be_bytes());
        assert_eq!(&hdr[50..52], &80u16.to_be_bytes());
    }
}