Skip to main content

fraiseql_wire/connection/
transport.rs

1//! Transport abstraction (TCP with optional TLS vs Unix socket)
2
3use crate::Result;
4use bytes::BytesMut;
5use std::path::Path;
6use tokio::io::{AsyncReadExt, AsyncWriteExt};
7use tokio::net::{TcpStream, UnixStream};
8
9/// TCP stream variant: plain or TLS-encrypted
10#[allow(clippy::large_enum_variant)]
11pub enum TcpVariant {
12    /// Plain TCP connection
13    Plain(TcpStream),
14    /// TLS-encrypted TCP connection
15    Tls(tokio_rustls::client::TlsStream<TcpStream>),
16}
17
18impl std::fmt::Debug for TcpVariant {
19    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20        match self {
21            TcpVariant::Plain(_) => f.write_str("TcpVariant::Plain(TcpStream)"),
22            TcpVariant::Tls(_) => f.write_str("TcpVariant::Tls(TlsStream)"),
23        }
24    }
25}
26
27impl TcpVariant {
28    /// Write all bytes to the stream
29    pub async fn write_all(&mut self, buf: &[u8]) -> Result<()> {
30        match self {
31            TcpVariant::Plain(stream) => stream.write_all(buf).await?,
32            TcpVariant::Tls(stream) => stream.write_all(buf).await?,
33        }
34        Ok(())
35    }
36
37    /// Flush the stream
38    pub async fn flush(&mut self) -> Result<()> {
39        match self {
40            TcpVariant::Plain(stream) => stream.flush().await?,
41            TcpVariant::Tls(stream) => stream.flush().await?,
42        }
43        Ok(())
44    }
45
46    /// Read into buffer
47    pub async fn read_buf(&mut self, buf: &mut BytesMut) -> Result<usize> {
48        let n = match self {
49            TcpVariant::Plain(stream) => stream.read_buf(buf).await?,
50            TcpVariant::Tls(stream) => stream.read_buf(buf).await?,
51        };
52        Ok(n)
53    }
54
55    /// Shutdown the stream
56    pub async fn shutdown(&mut self) -> Result<()> {
57        match self {
58            TcpVariant::Plain(stream) => stream.shutdown().await?,
59            TcpVariant::Tls(stream) => stream.shutdown().await?,
60        }
61        Ok(())
62    }
63}
64
65/// Transport layer abstraction
66#[derive(Debug)]
67#[allow(clippy::large_enum_variant)]
68pub enum Transport {
69    /// TCP socket (plain or TLS)
70    Tcp(TcpVariant),
71    /// Unix domain socket
72    Unix(UnixStream),
73}
74
75impl Transport {
76    /// Connect via plain TCP
77    pub async fn connect_tcp(host: &str, port: u16) -> Result<Self> {
78        let stream = TcpStream::connect((host, port)).await?;
79        Ok(Transport::Tcp(TcpVariant::Plain(stream)))
80    }
81
82    /// Connect via TLS-encrypted TCP using PostgreSQL SSL negotiation protocol.
83    ///
84    /// PostgreSQL requires a specific SSL upgrade sequence:
85    /// 1. Send SSLRequest message (8 bytes)
86    /// 2. Server responds with 'S' (accept) or 'N' (reject)
87    /// 3. If accepted, perform TLS handshake
88    pub async fn connect_tcp_tls(
89        host: &str,
90        port: u16,
91        tls_config: &crate::connection::TlsConfig,
92    ) -> Result<Self> {
93        use tokio::io::{AsyncReadExt, AsyncWriteExt};
94
95        let mut tcp_stream = TcpStream::connect((host, port)).await?;
96
97        // PostgreSQL SSLRequest message:
98        // - Length: 8 (4 bytes, big-endian)
99        // - Request code: 80877103 (4 bytes, big-endian) = (1234 << 16) | 5679
100        let ssl_request: [u8; 8] = [
101            0x00, 0x00, 0x00, 0x08, // Length = 8
102            0x04, 0xd2, 0x16, 0x2f, // Request code = 80877103
103        ];
104
105        tcp_stream.write_all(&ssl_request).await?;
106        tcp_stream.flush().await?;
107
108        // Read server response (single byte: 'S' = accept, 'N' = reject)
109        let mut response = [0u8; 1];
110        tcp_stream.read_exact(&mut response).await?;
111
112        match response[0] {
113            b'S' => {
114                // Server accepted SSL - proceed with TLS handshake
115            }
116            b'N' => {
117                return Err(crate::Error::Config(
118                    "Server does not support SSL connections".to_string(),
119                ));
120            }
121            other => {
122                return Err(crate::Error::Config(format!(
123                    "Unexpected SSL response from server: {:02x}",
124                    other
125                )));
126            }
127        }
128
129        // Parse server name for TLS handshake (SNI)
130        let server_name = crate::connection::parse_server_name(host)?;
131        let server_name = rustls_pki_types::ServerName::try_from(server_name)
132            .map_err(|_| crate::Error::Config(format!("Invalid hostname for TLS: {}", host)))?;
133
134        // Perform TLS handshake
135        let client_config = tls_config.client_config();
136        let tls_connector = tokio_rustls::TlsConnector::from(client_config);
137        let tls_stream = tls_connector
138            .connect(server_name, tcp_stream)
139            .await
140            .map_err(|e| crate::Error::Config(format!("TLS handshake failed: {}", e)))?;
141
142        Ok(Transport::Tcp(TcpVariant::Tls(tls_stream)))
143    }
144
145    /// Connect via Unix socket
146    pub async fn connect_unix(path: &Path) -> Result<Self> {
147        let stream = UnixStream::connect(path).await?;
148        Ok(Transport::Unix(stream))
149    }
150
151    /// Write bytes to the transport
152    pub async fn write_all(&mut self, buf: &[u8]) -> Result<()> {
153        match self {
154            Transport::Tcp(variant) => variant.write_all(buf).await?,
155            Transport::Unix(stream) => stream.write_all(buf).await?,
156        }
157        Ok(())
158    }
159
160    /// Flush the transport
161    pub async fn flush(&mut self) -> Result<()> {
162        match self {
163            Transport::Tcp(variant) => variant.flush().await?,
164            Transport::Unix(stream) => stream.flush().await?,
165        }
166        Ok(())
167    }
168
169    /// Read bytes into buffer
170    pub async fn read_buf(&mut self, buf: &mut BytesMut) -> Result<usize> {
171        let n = match self {
172            Transport::Tcp(variant) => variant.read_buf(buf).await?,
173            Transport::Unix(stream) => stream.read_buf(buf).await?,
174        };
175        Ok(n)
176    }
177
178    /// Shutdown the transport
179    pub async fn shutdown(&mut self) -> Result<()> {
180        match self {
181            Transport::Tcp(variant) => variant.shutdown().await?,
182            Transport::Unix(stream) => stream.shutdown().await?,
183        }
184        Ok(())
185    }
186}
187
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_tcp_connect_failure() {
        // Reserve an ephemeral port and release it, so the port is known to have no
        // listener. A hard-coded port (e.g. 9999) may be in use on the test machine.
        let listener =
            std::net::TcpListener::bind("127.0.0.1:0").expect("bind ephemeral port");
        let port = listener.local_addr().expect("query local addr").port();
        drop(listener);

        let result = Transport::connect_tcp("127.0.0.1", port).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_unix_connect_failure() {
        // A socket path inside a directory that should not exist cannot be connected.
        let path = std::env::temp_dir()
            .join("fraiseql_wire_transport_test_missing_dir")
            .join("nonexistent.sock");

        let result = Transport::connect_unix(&path).await;
        assert!(result.is_err());
    }
}