fraiseql-wire 2.2.1

Streaming JSON query engine for Postgres 17
Documentation
//! Transport abstraction (TCP with optional TLS vs Unix socket)

#[allow(unused_imports)] // Reason: used only in doc links for `# Errors` sections
use crate::error::WireError;
use crate::Result;
use bytes::BytesMut;
use socket2::{SockRef, TcpKeepalive};
use std::path::Path;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpStream, UnixStream};

/// TCP stream variant: plain or TLS-encrypted
///
/// Both variants ultimately own a tokio `TcpStream`; the `Tls` variant wraps
/// it in a `tokio_rustls` client session. `Debug` is implemented by hand for
/// this type rather than derived.
#[allow(clippy::large_enum_variant)] // Reason: variant size difference is acceptable; boxing would add indirection in hot path
#[non_exhaustive]
pub enum TcpVariant {
    /// Plain TCP connection
    ///
    /// Holds the unencrypted tokio `TcpStream` directly.
    Plain(TcpStream),
    /// TLS-encrypted TCP connection
    ///
    /// Holds a client-side `tokio_rustls` TLS stream layered over a `TcpStream`.
    Tls(tokio_rustls::client::TlsStream<TcpStream>),
}

impl std::fmt::Debug for TcpVariant {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TcpVariant::Plain(_) => f.write_str("TcpVariant::Plain(TcpStream)"),
            TcpVariant::Tls(_) => f.write_str("TcpVariant::Tls(TlsStream)"),
        }
    }
}

impl TcpVariant {
    /// Write all bytes to the stream
    ///
    /// # Errors
    ///
    /// Returns [`WireError`] if the underlying I/O write fails.
    pub async fn write_all(&mut self, buf: &[u8]) -> Result<()> {
        match self {
            TcpVariant::Plain(stream) => stream.write_all(buf).await?,
            TcpVariant::Tls(stream) => stream.write_all(buf).await?,
        }
        Ok(())
    }

    /// Flush the stream
    ///
    /// # Errors
    ///
    /// Returns [`WireError`] if the underlying I/O flush fails.
    pub async fn flush(&mut self) -> Result<()> {
        match self {
            TcpVariant::Plain(stream) => stream.flush().await?,
            TcpVariant::Tls(stream) => stream.flush().await?,
        }
        Ok(())
    }

    /// Read into buffer
    ///
    /// # Errors
    ///
    /// Returns [`WireError`] if the underlying I/O read fails.
    pub async fn read_buf(&mut self, buf: &mut BytesMut) -> Result<usize> {
        let n = match self {
            TcpVariant::Plain(stream) => stream.read_buf(buf).await?,
            TcpVariant::Tls(stream) => stream.read_buf(buf).await?,
        };
        Ok(n)
    }

    /// Shutdown the stream
    ///
    /// # Errors
    ///
    /// Returns [`WireError`] if the underlying I/O shutdown fails.
    pub async fn shutdown(&mut self) -> Result<()> {
        match self {
            TcpVariant::Plain(stream) => stream.shutdown().await?,
            TcpVariant::Tls(stream) => stream.shutdown().await?,
        }
        Ok(())
    }

    /// Apply TCP keepalive settings to the underlying socket.
    ///
    /// Extracts the raw socket reference via `socket2::SockRef` and configures
    /// `SO_KEEPALIVE` with the given idle interval. This is a no-op for TLS
    /// streams that wrap a `TcpStream`; the keepalive is applied to the inner
    /// TCP socket before the TLS handshake anyway.
    ///
    /// # Errors
    ///
    /// Returns [`WireError`] if setting the socket keepalive options fails.
    pub fn apply_keepalive(&self, idle: Duration) -> Result<()> {
        let keepalive = TcpKeepalive::new().with_time(idle);
        match self {
            TcpVariant::Plain(stream) => {
                let sock = SockRef::from(stream);
                sock.set_keepalive(true)?;
                sock.set_tcp_keepalive(&keepalive)?;
            }
            TcpVariant::Tls(stream) => {
                // The inner TcpStream is accessible via the get_ref() chain.
                let tcp = stream.get_ref().0;
                let sock = SockRef::from(tcp);
                sock.set_keepalive(true)?;
                sock.set_tcp_keepalive(&keepalive)?;
            }
        }
        Ok(())
    }
}

/// Transport layer abstraction
///
/// Wraps either a TCP connection (plain or TLS, see [`TcpVariant`]) or a Unix
/// domain socket.
#[derive(Debug)]
#[allow(clippy::large_enum_variant)] // Reason: variant size difference is acceptable; boxing would add indirection in hot path
#[non_exhaustive]
pub enum Transport {
    /// TCP socket (plain or TLS)
    Tcp(TcpVariant),
    /// Unix domain socket
    ///
    /// Holds a tokio `UnixStream`.
    Unix(UnixStream),
}

impl Transport {
    /// Connect via plain TCP
    ///
    /// # Errors
    ///
    /// Returns [`WireError`] if the TCP connection to `host:port` fails.
    pub async fn connect_tcp(host: &str, port: u16) -> Result<Self> {
        let stream = TcpStream::connect((host, port)).await?;
        Ok(Transport::Tcp(TcpVariant::Plain(stream)))
    }

    /// Connect via TLS-encrypted TCP using PostgreSQL SSL negotiation protocol.
    ///
    /// PostgreSQL requires a specific SSL upgrade sequence:
    /// 1. Send `SSLRequest` message (8 bytes)
    /// 2. Server responds with 'S' (accept) or 'N' (reject)
    /// 3. If accepted, perform TLS handshake
    ///
    /// # Errors
    ///
    /// Returns `WireError::Io` if the TCP connection, sending the `SSLRequest`,
    /// or reading the server's one-byte reply fails.
    /// Returns `WireError::Config` if the server rejects the SSL request,
    /// replies with an unexpected byte, `host` cannot be used as a TLS server
    /// name, or the TLS handshake fails.
    /// Any error returned by `parse_server_name` is propagated unchanged.
    pub async fn connect_tcp_tls(
        host: &str,
        port: u16,
        tls_config: &crate::connection::TlsConfig,
    ) -> Result<Self> {
        let mut tcp_stream = TcpStream::connect((host, port)).await?;

        // PostgreSQL SSLRequest message:
        // - Length: 8 (4 bytes, big-endian)
        // - Request code: 80877103 (4 bytes, big-endian) = (1234 << 16) | 5679
        let ssl_request: [u8; 8] = [
            0x00, 0x00, 0x00, 0x08, // Length = 8
            0x04, 0xd2, 0x16, 0x2f, // Request code = 80877103
        ];

        tcp_stream.write_all(&ssl_request).await?;
        tcp_stream.flush().await?;

        // Read server response (single byte: 'S' = accept, 'N' = reject).
        // Exactly one byte is consumed so that everything after it is left
        // on the socket for the TLS handshake.
        let mut response = [0u8; 1];
        tcp_stream.read_exact(&mut response).await?;

        match response[0] {
            // Server accepted SSL - proceed with TLS handshake
            b'S' => {}
            b'N' => {
                return Err(crate::WireError::Config(
                    "Server does not support SSL connections".to_string(),
                ));
            }
            other => {
                return Err(crate::WireError::Config(format!(
                    "Unexpected SSL response from server: {other:02x}"
                )));
            }
        }

        // Parse server name for TLS handshake (SNI)
        let server_name = crate::connection::parse_server_name(host)?;
        let server_name = rustls_pki_types::ServerName::try_from(server_name)
            .map_err(|_| crate::WireError::Config(format!("Invalid hostname for TLS: {host}")))?;

        // Perform TLS handshake; failures are reported as configuration errors.
        let client_config = tls_config.client_config();
        let tls_connector = tokio_rustls::TlsConnector::from(client_config);
        let tls_stream = tls_connector
            .connect(server_name, tcp_stream)
            .await
            .map_err(|e| crate::WireError::Config(format!("TLS handshake failed: {e}")))?;

        Ok(Transport::Tcp(TcpVariant::Tls(tls_stream)))
    }

    /// Connect via Unix socket
    ///
    /// # Errors
    ///
    /// Returns [`WireError`] if the Unix domain socket connection to `path` fails.
    pub async fn connect_unix(path: &Path) -> Result<Self> {
        let stream = UnixStream::connect(path).await?;
        Ok(Transport::Unix(stream))
    }

    /// Write bytes to the transport
    ///
    /// Writes the whole of `buf` to the TCP variant or the Unix stream.
    ///
    /// # Errors
    ///
    /// Returns [`WireError`] if the underlying I/O write fails.
    pub async fn write_all(&mut self, buf: &[u8]) -> Result<()> {
        match self {
            Transport::Tcp(variant) => variant.write_all(buf).await?,
            Transport::Unix(stream) => stream.write_all(buf).await?,
        }
        Ok(())
    }

    /// Flush the transport
    ///
    /// # Errors
    ///
    /// Returns [`WireError`] if the underlying I/O flush fails.
    pub async fn flush(&mut self) -> Result<()> {
        match self {
            Transport::Tcp(variant) => variant.flush().await?,
            Transport::Unix(stream) => stream.flush().await?,
        }
        Ok(())
    }

    /// Read bytes into buffer
    ///
    /// Appends to `buf` and returns the byte count reported by the
    /// underlying stream's `read_buf`.
    ///
    /// # Errors
    ///
    /// Returns [`WireError`] if the underlying I/O read fails.
    pub async fn read_buf(&mut self, buf: &mut BytesMut) -> Result<usize> {
        let n = match self {
            Transport::Tcp(variant) => variant.read_buf(buf).await?,
            Transport::Unix(stream) => stream.read_buf(buf).await?,
        };
        Ok(n)
    }

    /// Shutdown the transport
    ///
    /// # Errors
    ///
    /// Returns [`WireError`] if the underlying I/O shutdown fails.
    pub async fn shutdown(&mut self) -> Result<()> {
        match self {
            Transport::Tcp(variant) => variant.shutdown().await?,
            Transport::Unix(stream) => stream.shutdown().await?,
        }
        Ok(())
    }

    /// Apply TCP keepalive to this transport, if it is a TCP connection.
    ///
    /// A no-op for Unix socket transports (keepalive is a TCP-layer feature).
    /// For TCP transports the request is forwarded to
    /// [`TcpVariant::apply_keepalive`] and any failure is returned to the
    /// caller; nothing is logged or suppressed here.
    ///
    /// # Errors
    ///
    /// Returns [`WireError`] if setting the socket keepalive options fails on a TCP transport.
    pub fn apply_keepalive(&self, idle: Duration) -> Result<()> {
        match self {
            Transport::Tcp(variant) => variant.apply_keepalive(idle),
            Transport::Unix(_) => Ok(()), // keepalive not applicable on Unix sockets
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Connecting to a port with no listener must yield an error.
    #[tokio::test]
    async fn test_tcp_connect_failure() {
        // Reserve an ephemeral loopback port and release it straight away, so
        // the target is closed without assuming that a fixed port number is
        // unused on the machine running the tests.
        let port = {
            let listener = std::net::TcpListener::bind(("127.0.0.1", 0))
                .expect("binding an ephemeral loopback port should succeed");
            listener
                .local_addr()
                .expect("bound listener should report its local address")
                .port()
        };

        let result = Transport::connect_tcp("127.0.0.1", port).await;
        assert!(
            result.is_err(),
            "expected Err for connection to closed port {port}, got: {result:?}"
        );
    }
}