Skip to main content

fraiseql_wire/connection/
transport.rs

1//! Transport abstraction (TCP with optional TLS vs Unix socket)
2
3use crate::Result;
4use bytes::BytesMut;
5use socket2::{SockRef, TcpKeepalive};
6use std::path::Path;
7use std::time::Duration;
8use tokio::io::{AsyncReadExt, AsyncWriteExt};
9use tokio::net::{TcpStream, UnixStream};
10
11/// TCP stream variant: plain or TLS-encrypted
12#[allow(clippy::large_enum_variant)] // Reason: variant size difference is acceptable; boxing would add indirection in hot path
13pub enum TcpVariant {
14    /// Plain TCP connection
15    Plain(TcpStream),
16    /// TLS-encrypted TCP connection
17    Tls(tokio_rustls::client::TlsStream<TcpStream>),
18}
19
20impl std::fmt::Debug for TcpVariant {
21    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22        match self {
23            TcpVariant::Plain(_) => f.write_str("TcpVariant::Plain(TcpStream)"),
24            TcpVariant::Tls(_) => f.write_str("TcpVariant::Tls(TlsStream)"),
25        }
26    }
27}
28
29impl TcpVariant {
30    /// Write all bytes to the stream
31    pub async fn write_all(&mut self, buf: &[u8]) -> Result<()> {
32        match self {
33            TcpVariant::Plain(stream) => stream.write_all(buf).await?,
34            TcpVariant::Tls(stream) => stream.write_all(buf).await?,
35        }
36        Ok(())
37    }
38
39    /// Flush the stream
40    pub async fn flush(&mut self) -> Result<()> {
41        match self {
42            TcpVariant::Plain(stream) => stream.flush().await?,
43            TcpVariant::Tls(stream) => stream.flush().await?,
44        }
45        Ok(())
46    }
47
48    /// Read into buffer
49    pub async fn read_buf(&mut self, buf: &mut BytesMut) -> Result<usize> {
50        let n = match self {
51            TcpVariant::Plain(stream) => stream.read_buf(buf).await?,
52            TcpVariant::Tls(stream) => stream.read_buf(buf).await?,
53        };
54        Ok(n)
55    }
56
57    /// Shutdown the stream
58    pub async fn shutdown(&mut self) -> Result<()> {
59        match self {
60            TcpVariant::Plain(stream) => stream.shutdown().await?,
61            TcpVariant::Tls(stream) => stream.shutdown().await?,
62        }
63        Ok(())
64    }
65
66    /// Apply TCP keepalive settings to the underlying socket.
67    ///
68    /// Extracts the raw socket reference via `socket2::SockRef` and configures
69    /// `SO_KEEPALIVE` with the given idle interval. This is a no-op for TLS
70    /// streams that wrap a `TcpStream`; the keepalive is applied to the inner
71    /// TCP socket before the TLS handshake anyway.
72    pub fn apply_keepalive(&self, idle: Duration) -> Result<()> {
73        let keepalive = TcpKeepalive::new().with_time(idle);
74        match self {
75            TcpVariant::Plain(stream) => {
76                let sock = SockRef::from(stream);
77                sock.set_keepalive(true)?;
78                sock.set_tcp_keepalive(&keepalive)?;
79            }
80            TcpVariant::Tls(stream) => {
81                // The inner TcpStream is accessible via the get_ref() chain.
82                let tcp = stream.get_ref().0;
83                let sock = SockRef::from(tcp);
84                sock.set_keepalive(true)?;
85                sock.set_tcp_keepalive(&keepalive)?;
86            }
87        }
88        Ok(())
89    }
90}
91
92/// Transport layer abstraction
93#[derive(Debug)]
94#[allow(clippy::large_enum_variant)] // Reason: variant size difference is acceptable; boxing would add indirection in hot path
95pub enum Transport {
96    /// TCP socket (plain or TLS)
97    Tcp(TcpVariant),
98    /// Unix domain socket
99    Unix(UnixStream),
100}
101
102impl Transport {
103    /// Connect via plain TCP
104    ///
105    /// # Errors
106    ///
107    /// Returns `Error::Io` if the TCP connection fails.
108    pub async fn connect_tcp(host: &str, port: u16) -> Result<Self> {
109        let stream = TcpStream::connect((host, port)).await?;
110        Ok(Transport::Tcp(TcpVariant::Plain(stream)))
111    }
112
113    /// Connect via TLS-encrypted TCP using PostgreSQL SSL negotiation protocol.
114    ///
115    /// PostgreSQL requires a specific SSL upgrade sequence:
116    /// 1. Send `SSLRequest` message (8 bytes)
117    /// 2. Server responds with 'S' (accept) or 'N' (reject)
118    /// 3. If accepted, perform TLS handshake
119    ///
120    /// # Errors
121    ///
122    /// Returns `Error::Io` if the TCP connection or SSL negotiation fails.
123    /// Returns `Error::Config` if the server rejects SSL, sends an unexpected response,
124    /// the hostname is invalid for TLS, or the TLS handshake fails.
125    pub async fn connect_tcp_tls(
126        host: &str,
127        port: u16,
128        tls_config: &crate::connection::TlsConfig,
129    ) -> Result<Self> {
130        use tokio::io::{AsyncReadExt, AsyncWriteExt};
131
132        let mut tcp_stream = TcpStream::connect((host, port)).await?;
133
134        // PostgreSQL SSLRequest message:
135        // - Length: 8 (4 bytes, big-endian)
136        // - Request code: 80877103 (4 bytes, big-endian) = (1234 << 16) | 5679
137        let ssl_request: [u8; 8] = [
138            0x00, 0x00, 0x00, 0x08, // Length = 8
139            0x04, 0xd2, 0x16, 0x2f, // Request code = 80877103
140        ];
141
142        tcp_stream.write_all(&ssl_request).await?;
143        tcp_stream.flush().await?;
144
145        // Read server response (single byte: 'S' = accept, 'N' = reject)
146        let mut response = [0u8; 1];
147        tcp_stream.read_exact(&mut response).await?;
148
149        match response[0] {
150            b'S' => {
151                // Server accepted SSL - proceed with TLS handshake
152            }
153            b'N' => {
154                return Err(crate::Error::Config(
155                    "Server does not support SSL connections".to_string(),
156                ));
157            }
158            other => {
159                return Err(crate::Error::Config(format!(
160                    "Unexpected SSL response from server: {:02x}",
161                    other
162                )));
163            }
164        }
165
166        // Parse server name for TLS handshake (SNI)
167        let server_name = crate::connection::parse_server_name(host)?;
168        let server_name = rustls_pki_types::ServerName::try_from(server_name)
169            .map_err(|_| crate::Error::Config(format!("Invalid hostname for TLS: {}", host)))?;
170
171        // Perform TLS handshake
172        let client_config = tls_config.client_config();
173        let tls_connector = tokio_rustls::TlsConnector::from(client_config);
174        let tls_stream = tls_connector
175            .connect(server_name, tcp_stream)
176            .await
177            .map_err(|e| crate::Error::Config(format!("TLS handshake failed: {}", e)))?;
178
179        Ok(Transport::Tcp(TcpVariant::Tls(tls_stream)))
180    }
181
182    /// Connect via Unix socket
183    ///
184    /// # Errors
185    ///
186    /// Returns `Error::Io` if the Unix socket connection fails.
187    pub async fn connect_unix(path: &Path) -> Result<Self> {
188        let stream = UnixStream::connect(path).await?;
189        Ok(Transport::Unix(stream))
190    }
191
192    /// Write bytes to the transport
193    ///
194    /// # Errors
195    ///
196    /// Returns `Error::Io` if the write operation fails.
197    pub async fn write_all(&mut self, buf: &[u8]) -> Result<()> {
198        match self {
199            Transport::Tcp(variant) => variant.write_all(buf).await?,
200            Transport::Unix(stream) => stream.write_all(buf).await?,
201        }
202        Ok(())
203    }
204
205    /// Flush the transport
206    ///
207    /// # Errors
208    ///
209    /// Returns `Error::Io` if the flush operation fails.
210    pub async fn flush(&mut self) -> Result<()> {
211        match self {
212            Transport::Tcp(variant) => variant.flush().await?,
213            Transport::Unix(stream) => stream.flush().await?,
214        }
215        Ok(())
216    }
217
218    /// Read bytes into buffer
219    ///
220    /// # Errors
221    ///
222    /// Returns `Error::Io` if the read operation fails.
223    pub async fn read_buf(&mut self, buf: &mut BytesMut) -> Result<usize> {
224        let n = match self {
225            Transport::Tcp(variant) => variant.read_buf(buf).await?,
226            Transport::Unix(stream) => stream.read_buf(buf).await?,
227        };
228        Ok(n)
229    }
230
231    /// Shutdown the transport
232    ///
233    /// # Errors
234    ///
235    /// Returns `Error::Io` if the shutdown operation fails.
236    pub async fn shutdown(&mut self) -> Result<()> {
237        match self {
238            Transport::Tcp(variant) => variant.shutdown().await?,
239            Transport::Unix(stream) => stream.shutdown().await?,
240        }
241        Ok(())
242    }
243
244    /// Apply TCP keepalive to this transport, if it is a TCP connection.
245    ///
246    /// A no-op for Unix socket transports (keepalive is a TCP-layer feature).
247    /// Logs a warning and returns `Ok(())` rather than failing if the platform
248    /// does not support the requested keepalive interval.
249    ///
250    /// # Errors
251    ///
252    /// Returns `Error::Io` if setting the TCP keepalive option fails.
253    pub fn apply_keepalive(&self, idle: Duration) -> Result<()> {
254        match self {
255            Transport::Tcp(variant) => variant.apply_keepalive(idle),
256            Transport::Unix(_) => Ok(()), // keepalive not applicable on Unix sockets
257        }
258    }
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    /// Connecting to a port with no listener must surface an error.
    #[tokio::test]
    async fn test_tcp_connect_failure() {
        // Reserve an ephemeral loopback port and release it again, so the test
        // does not depend on a hard-coded port happening to be unused on the
        // machine running it.
        let listener =
            std::net::TcpListener::bind(("127.0.0.1", 0)).expect("bind ephemeral loopback port");
        let port = listener
            .local_addr()
            .expect("query bound listener address")
            .port();
        drop(listener);

        let result = Transport::connect_tcp("127.0.0.1", port).await;
        assert!(result.is_err());
    }
}