Skip to main content

fraiseql_wire/connection/
transport.rs

1//! Transport abstraction (TCP with optional TLS vs Unix socket)
2
3#[allow(unused_imports)] // Reason: used only in doc links for `# Errors` sections
4use crate::error::WireError;
5use crate::Result;
6use bytes::BytesMut;
7use socket2::{SockRef, TcpKeepalive};
8use std::path::Path;
9use std::time::Duration;
10use tokio::io::{AsyncReadExt, AsyncWriteExt};
11use tokio::net::{TcpStream, UnixStream};
12
13/// TCP stream variant: plain or TLS-encrypted
14#[allow(clippy::large_enum_variant)] // Reason: variant size difference is acceptable; boxing would add indirection in hot path
15#[non_exhaustive]
16pub enum TcpVariant {
17    /// Plain TCP connection
18    Plain(TcpStream),
19    /// TLS-encrypted TCP connection
20    Tls(tokio_rustls::client::TlsStream<TcpStream>),
21}
22
23impl std::fmt::Debug for TcpVariant {
24    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25        match self {
26            TcpVariant::Plain(_) => f.write_str("TcpVariant::Plain(TcpStream)"),
27            TcpVariant::Tls(_) => f.write_str("TcpVariant::Tls(TlsStream)"),
28        }
29    }
30}
31
32impl TcpVariant {
33    /// Write all bytes to the stream
34    ///
35    /// # Errors
36    ///
37    /// Returns [`WireError`] if the underlying I/O write fails.
38    pub async fn write_all(&mut self, buf: &[u8]) -> Result<()> {
39        match self {
40            TcpVariant::Plain(stream) => stream.write_all(buf).await?,
41            TcpVariant::Tls(stream) => stream.write_all(buf).await?,
42        }
43        Ok(())
44    }
45
46    /// Flush the stream
47    ///
48    /// # Errors
49    ///
50    /// Returns [`WireError`] if the underlying I/O flush fails.
51    pub async fn flush(&mut self) -> Result<()> {
52        match self {
53            TcpVariant::Plain(stream) => stream.flush().await?,
54            TcpVariant::Tls(stream) => stream.flush().await?,
55        }
56        Ok(())
57    }
58
59    /// Read into buffer
60    ///
61    /// # Errors
62    ///
63    /// Returns [`WireError`] if the underlying I/O read fails.
64    pub async fn read_buf(&mut self, buf: &mut BytesMut) -> Result<usize> {
65        let n = match self {
66            TcpVariant::Plain(stream) => stream.read_buf(buf).await?,
67            TcpVariant::Tls(stream) => stream.read_buf(buf).await?,
68        };
69        Ok(n)
70    }
71
72    /// Shutdown the stream
73    ///
74    /// # Errors
75    ///
76    /// Returns [`WireError`] if the underlying I/O shutdown fails.
77    pub async fn shutdown(&mut self) -> Result<()> {
78        match self {
79            TcpVariant::Plain(stream) => stream.shutdown().await?,
80            TcpVariant::Tls(stream) => stream.shutdown().await?,
81        }
82        Ok(())
83    }
84
85    /// Apply TCP keepalive settings to the underlying socket.
86    ///
87    /// Extracts the raw socket reference via `socket2::SockRef` and configures
88    /// `SO_KEEPALIVE` with the given idle interval. This is a no-op for TLS
89    /// streams that wrap a `TcpStream`; the keepalive is applied to the inner
90    /// TCP socket before the TLS handshake anyway.
91    ///
92    /// # Errors
93    ///
94    /// Returns [`WireError`] if setting the socket keepalive options fails.
95    pub fn apply_keepalive(&self, idle: Duration) -> Result<()> {
96        let keepalive = TcpKeepalive::new().with_time(idle);
97        match self {
98            TcpVariant::Plain(stream) => {
99                let sock = SockRef::from(stream);
100                sock.set_keepalive(true)?;
101                sock.set_tcp_keepalive(&keepalive)?;
102            }
103            TcpVariant::Tls(stream) => {
104                // The inner TcpStream is accessible via the get_ref() chain.
105                let tcp = stream.get_ref().0;
106                let sock = SockRef::from(tcp);
107                sock.set_keepalive(true)?;
108                sock.set_tcp_keepalive(&keepalive)?;
109            }
110        }
111        Ok(())
112    }
113}
114
115/// Transport layer abstraction
116#[derive(Debug)]
117#[allow(clippy::large_enum_variant)] // Reason: variant size difference is acceptable; boxing would add indirection in hot path
118#[non_exhaustive]
119pub enum Transport {
120    /// TCP socket (plain or TLS)
121    Tcp(TcpVariant),
122    /// Unix domain socket
123    Unix(UnixStream),
124}
125
126impl Transport {
127    /// Connect via plain TCP
128    ///
129    /// # Errors
130    ///
131    /// Returns [`WireError`] if the TCP connection to `host:port` fails.
132    pub async fn connect_tcp(host: &str, port: u16) -> Result<Self> {
133        let stream = TcpStream::connect((host, port)).await?;
134        Ok(Transport::Tcp(TcpVariant::Plain(stream)))
135    }
136
137    /// Connect via TLS-encrypted TCP using PostgreSQL SSL negotiation protocol.
138    ///
139    /// PostgreSQL requires a specific SSL upgrade sequence:
140    /// 1. Send `SSLRequest` message (8 bytes)
141    /// 2. Server responds with 'S' (accept) or 'N' (reject)
142    /// 3. If accepted, perform TLS handshake
143    ///
144    /// # Errors
145    ///
146    /// Returns `WireError::Io` if the TCP connection or TLS handshake fails.
147    /// Returns `WireError::Config` if the server rejects the SSL request.
148    pub async fn connect_tcp_tls(
149        host: &str,
150        port: u16,
151        tls_config: &crate::connection::TlsConfig,
152    ) -> Result<Self> {
153        use tokio::io::{AsyncReadExt, AsyncWriteExt};
154
155        let mut tcp_stream = TcpStream::connect((host, port)).await?;
156
157        // PostgreSQL SSLRequest message:
158        // - Length: 8 (4 bytes, big-endian)
159        // - Request code: 80877103 (4 bytes, big-endian) = (1234 << 16) | 5679
160        let ssl_request: [u8; 8] = [
161            0x00, 0x00, 0x00, 0x08, // Length = 8
162            0x04, 0xd2, 0x16, 0x2f, // Request code = 80877103
163        ];
164
165        tcp_stream.write_all(&ssl_request).await?;
166        tcp_stream.flush().await?;
167
168        // Read server response (single byte: 'S' = accept, 'N' = reject)
169        let mut response = [0u8; 1];
170        tcp_stream.read_exact(&mut response).await?;
171
172        match response[0] {
173            b'S' => {
174                // Server accepted SSL - proceed with TLS handshake
175            }
176            b'N' => {
177                return Err(crate::WireError::Config(
178                    "Server does not support SSL connections".to_string(),
179                ));
180            }
181            other => {
182                return Err(crate::WireError::Config(format!(
183                    "Unexpected SSL response from server: {:02x}",
184                    other
185                )));
186            }
187        }
188
189        // Parse server name for TLS handshake (SNI)
190        let server_name = crate::connection::parse_server_name(host)?;
191        let server_name = rustls_pki_types::ServerName::try_from(server_name)
192            .map_err(|_| crate::WireError::Config(format!("Invalid hostname for TLS: {}", host)))?;
193
194        // Perform TLS handshake
195        let client_config = tls_config.client_config();
196        let tls_connector = tokio_rustls::TlsConnector::from(client_config);
197        let tls_stream = tls_connector
198            .connect(server_name, tcp_stream)
199            .await
200            .map_err(|e| crate::WireError::Config(format!("TLS handshake failed: {}", e)))?;
201
202        Ok(Transport::Tcp(TcpVariant::Tls(tls_stream)))
203    }
204
205    /// Connect via Unix socket
206    ///
207    /// # Errors
208    ///
209    /// Returns [`WireError`] if the Unix domain socket connection to `path` fails.
210    pub async fn connect_unix(path: &Path) -> Result<Self> {
211        let stream = UnixStream::connect(path).await?;
212        Ok(Transport::Unix(stream))
213    }
214
215    /// Write bytes to the transport
216    ///
217    /// # Errors
218    ///
219    /// Returns [`WireError`] if the underlying I/O write fails.
220    pub async fn write_all(&mut self, buf: &[u8]) -> Result<()> {
221        match self {
222            Transport::Tcp(variant) => variant.write_all(buf).await?,
223            Transport::Unix(stream) => stream.write_all(buf).await?,
224        }
225        Ok(())
226    }
227
228    /// Flush the transport
229    ///
230    /// # Errors
231    ///
232    /// Returns [`WireError`] if the underlying I/O flush fails.
233    pub async fn flush(&mut self) -> Result<()> {
234        match self {
235            Transport::Tcp(variant) => variant.flush().await?,
236            Transport::Unix(stream) => stream.flush().await?,
237        }
238        Ok(())
239    }
240
241    /// Read bytes into buffer
242    ///
243    /// # Errors
244    ///
245    /// Returns [`WireError`] if the underlying I/O read fails.
246    pub async fn read_buf(&mut self, buf: &mut BytesMut) -> Result<usize> {
247        let n = match self {
248            Transport::Tcp(variant) => variant.read_buf(buf).await?,
249            Transport::Unix(stream) => stream.read_buf(buf).await?,
250        };
251        Ok(n)
252    }
253
254    /// Shutdown the transport
255    ///
256    /// # Errors
257    ///
258    /// Returns [`WireError`] if the underlying I/O shutdown fails.
259    pub async fn shutdown(&mut self) -> Result<()> {
260        match self {
261            Transport::Tcp(variant) => variant.shutdown().await?,
262            Transport::Unix(stream) => stream.shutdown().await?,
263        }
264        Ok(())
265    }
266
267    /// Apply TCP keepalive to this transport, if it is a TCP connection.
268    ///
269    /// A no-op for Unix socket transports (keepalive is a TCP-layer feature).
270    /// Logs a warning and returns `Ok(())` rather than failing if the platform
271    /// does not support the requested keepalive interval.
272    ///
273    /// # Errors
274    ///
275    /// Returns [`WireError`] if setting the socket keepalive options fails on a TCP transport.
276    pub fn apply_keepalive(&self, idle: Duration) -> Result<()> {
277        match self {
278            Transport::Tcp(variant) => variant.apply_keepalive(idle),
279            Transport::Unix(_) => Ok(()), // keepalive not applicable on Unix sockets
280        }
281    }
282}
283
#[cfg(test)]
mod tests {
    use super::*;

    /// Connecting to a loopback port with no listener must yield an error.
    #[tokio::test]
    async fn test_tcp_connect_failure() {
        // Ask the OS for a free ephemeral port, then release it again. A
        // hard-coded port could be in use by another process on the test
        // machine, which would make the connection succeed unexpectedly.
        let probe = std::net::TcpListener::bind(("127.0.0.1", 0))
            .expect("binding an ephemeral loopback port should succeed");
        let port = probe
            .local_addr()
            .expect("a bound listener should report its local address")
            .port();
        drop(probe);

        let result = Transport::connect_tcp("127.0.0.1", port).await;
        assert!(
            result.is_err(),
            "expected Err for connection to closed port {port}, got: {result:?}"
        );
    }
}