Skip to main content

fraiseql_wire/connection/
transport.rs

1//! Transport abstraction (TCP with optional TLS vs Unix socket)
2
3use crate::Result;
4use bytes::BytesMut;
5use sha2::Digest;
6use std::path::Path;
7use tokio::io::{AsyncReadExt, AsyncWriteExt};
8use tokio::net::{TcpStream, UnixStream};
9
10/// TCP stream variant: plain or TLS-encrypted
11#[allow(clippy::large_enum_variant)]
12pub enum TcpVariant {
13    /// Plain TCP connection
14    Plain(TcpStream),
15    /// TLS-encrypted TCP connection
16    Tls(tokio_rustls::client::TlsStream<TcpStream>),
17}
18
19impl std::fmt::Debug for TcpVariant {
20    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21        match self {
22            TcpVariant::Plain(_) => f.write_str("TcpVariant::Plain(TcpStream)"),
23            TcpVariant::Tls(_) => f.write_str("TcpVariant::Tls(TlsStream)"),
24        }
25    }
26}
27
28impl TcpVariant {
29    /// Write all bytes to the stream
30    pub async fn write_all(&mut self, buf: &[u8]) -> Result<()> {
31        match self {
32            TcpVariant::Plain(stream) => stream.write_all(buf).await?,
33            TcpVariant::Tls(stream) => stream.write_all(buf).await?,
34        }
35        Ok(())
36    }
37
38    /// Flush the stream
39    pub async fn flush(&mut self) -> Result<()> {
40        match self {
41            TcpVariant::Plain(stream) => stream.flush().await?,
42            TcpVariant::Tls(stream) => stream.flush().await?,
43        }
44        Ok(())
45    }
46
47    /// Read into buffer
48    pub async fn read_buf(&mut self, buf: &mut BytesMut) -> Result<usize> {
49        let n = match self {
50            TcpVariant::Plain(stream) => stream.read_buf(buf).await?,
51            TcpVariant::Tls(stream) => stream.read_buf(buf).await?,
52        };
53        Ok(n)
54    }
55
56    /// Shutdown the stream
57    pub async fn shutdown(&mut self) -> Result<()> {
58        match self {
59            TcpVariant::Plain(stream) => stream.shutdown().await?,
60            TcpVariant::Tls(stream) => stream.shutdown().await?,
61        }
62        Ok(())
63    }
64
65    /// Extract the `tls-server-end-point` channel binding data from a TLS connection.
66    ///
67    /// Returns `None` for plain TCP connections.
68    /// For TLS connections, returns the SHA-256 hash of the server's DER-encoded certificate.
69    pub fn channel_binding_data(&self) -> Option<Vec<u8>> {
70        match self {
71            TcpVariant::Plain(_) => None,
72            TcpVariant::Tls(stream) => {
73                let (_tcp, conn) = stream.get_ref();
74                let certs = conn.peer_certificates()?;
75                let server_cert = certs.first()?;
76                // tls-server-end-point: SHA-256 hash of the DER-encoded server certificate
77                let hash = sha2::Sha256::digest(server_cert.as_ref());
78                Some(hash.to_vec())
79            }
80        }
81    }
82}
83
84/// Transport layer abstraction
85#[derive(Debug)]
86#[allow(clippy::large_enum_variant)]
87pub enum Transport {
88    /// TCP socket (plain or TLS)
89    Tcp(TcpVariant),
90    /// Unix domain socket
91    Unix(UnixStream),
92}
93
94impl Transport {
95    /// Connect via plain TCP
96    pub async fn connect_tcp(host: &str, port: u16) -> Result<Self> {
97        let stream = TcpStream::connect((host, port)).await?;
98        Ok(Transport::Tcp(TcpVariant::Plain(stream)))
99    }
100
101    /// Connect via TLS-encrypted TCP
102    pub async fn connect_tcp_tls(
103        host: &str,
104        port: u16,
105        tls_config: &crate::connection::TlsConfig,
106    ) -> Result<Self> {
107        let tcp_stream = TcpStream::connect((host, port)).await?;
108
109        // Parse server name for TLS handshake (SNI)
110        let server_name = crate::connection::parse_server_name(host)?;
111        let server_name = rustls_pki_types::ServerName::try_from(server_name)
112            .map_err(|_| crate::Error::Config(format!("Invalid hostname for TLS: {}", host)))?;
113
114        // Perform TLS handshake
115        let client_config = tls_config.client_config();
116        let tls_connector = tokio_rustls::TlsConnector::from(client_config);
117        let tls_stream = tls_connector
118            .connect(server_name, tcp_stream)
119            .await
120            .map_err(|e| crate::Error::Config(format!("TLS handshake failed: {}", e)))?;
121
122        Ok(Transport::Tcp(TcpVariant::Tls(tls_stream)))
123    }
124
125    /// Connect via Unix socket
126    pub async fn connect_unix(path: &Path) -> Result<Self> {
127        let stream = UnixStream::connect(path).await?;
128        Ok(Transport::Unix(stream))
129    }
130
131    /// Write bytes to the transport
132    pub async fn write_all(&mut self, buf: &[u8]) -> Result<()> {
133        match self {
134            Transport::Tcp(variant) => variant.write_all(buf).await?,
135            Transport::Unix(stream) => stream.write_all(buf).await?,
136        }
137        Ok(())
138    }
139
140    /// Flush the transport
141    pub async fn flush(&mut self) -> Result<()> {
142        match self {
143            Transport::Tcp(variant) => variant.flush().await?,
144            Transport::Unix(stream) => stream.flush().await?,
145        }
146        Ok(())
147    }
148
149    /// Read bytes into buffer
150    pub async fn read_buf(&mut self, buf: &mut BytesMut) -> Result<usize> {
151        let n = match self {
152            Transport::Tcp(variant) => variant.read_buf(buf).await?,
153            Transport::Unix(stream) => stream.read_buf(buf).await?,
154        };
155        Ok(n)
156    }
157
158    /// Upgrade a plain TCP transport to TLS after SSLRequest negotiation.
159    ///
160    /// Consumes `self` and returns a new `Transport` with a TLS-encrypted stream.
161    /// Returns an error if the transport is not a plain TCP connection.
162    pub async fn upgrade_to_tls(
163        self,
164        tls_config: &crate::connection::TlsConfig,
165        hostname: &str,
166    ) -> Result<Self> {
167        match self {
168            Transport::Tcp(TcpVariant::Plain(tcp_stream)) => {
169                let server_name = crate::connection::parse_server_name(hostname)?;
170                let server_name =
171                    rustls_pki_types::ServerName::try_from(server_name).map_err(|_| {
172                        crate::Error::Config(format!("Invalid hostname for TLS: {}", hostname))
173                    })?;
174
175                let client_config = tls_config.client_config();
176                let tls_connector = tokio_rustls::TlsConnector::from(client_config);
177                let tls_stream = tls_connector
178                    .connect(server_name, tcp_stream)
179                    .await
180                    .map_err(|e| crate::Error::Config(format!("TLS handshake failed: {}", e)))?;
181
182                Ok(Transport::Tcp(TcpVariant::Tls(tls_stream)))
183            }
184            Transport::Tcp(TcpVariant::Tls(_)) => Err(crate::Error::Config(
185                "transport is already TLS-encrypted".into(),
186            )),
187            Transport::Unix(_) => Err(crate::Error::Config(
188                "cannot upgrade Unix socket to TLS".into(),
189            )),
190        }
191    }
192
193    /// Shutdown the transport
194    pub async fn shutdown(&mut self) -> Result<()> {
195        match self {
196            Transport::Tcp(variant) => variant.shutdown().await?,
197            Transport::Unix(stream) => stream.shutdown().await?,
198        }
199        Ok(())
200    }
201
202    /// Extract channel binding data from the transport (if TLS is active).
203    ///
204    /// Returns `None` for plain TCP or Unix socket connections.
205    pub fn channel_binding_data(&self) -> Option<Vec<u8>> {
206        match self {
207            Transport::Tcp(variant) => variant.channel_binding_data(),
208            Transport::Unix(_) => None,
209        }
210    }
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    /// Return a loopback port that was just bound and released, so nothing should be
    /// listening on it. A small race with other processes remains, but unlike a
    /// hard-coded port it does not depend on the host's configuration.
    fn unused_local_port() -> u16 {
        let listener = std::net::TcpListener::bind("127.0.0.1:0").expect("bind ephemeral port");
        listener.local_addr().expect("read listener address").port()
    }

    #[tokio::test]
    async fn test_tcp_connect_failure() {
        let port = unused_local_port();
        let result = Transport::connect_tcp("127.0.0.1", port).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_unix_connect_failure() {
        let path = std::path::Path::new("/nonexistent/fraiseql_wire/transport-test.sock");
        assert!(Transport::connect_unix(path).await.is_err());
    }

    #[tokio::test]
    async fn test_plain_tcp_has_no_channel_binding() {
        // Keep the listener alive so the connect completes against its backlog.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind local listener");
        let port = listener.local_addr().expect("read listener address").port();
        let transport = match Transport::connect_tcp("127.0.0.1", port).await {
            Ok(transport) => transport,
            Err(_) => panic!("connecting to a live local listener should succeed"),
        };
        assert!(transport.channel_binding_data().is_none());
    }

    #[test]
    fn test_upgrade_to_tls_signature_exists() {
        // Compile-time check that upgrade_to_tls exists with the expected signature
        fn _assert_method_exists(t: Transport, c: &crate::connection::TlsConfig, h: &str) {
            let _fut = t.upgrade_to_tls(c, h);
        }
    }
}