Skip to main content

fraiseql_wire/connection/
transport.rs

1//! Transport abstraction (TCP with optional TLS vs Unix socket)
2
3use crate::Result;
4use bytes::BytesMut;
5use std::path::Path;
6use tokio::io::{AsyncReadExt, AsyncWriteExt};
7use tokio::net::{TcpStream, UnixStream};
8
9/// TCP stream variant: plain or TLS-encrypted
10#[allow(clippy::large_enum_variant)]
11pub enum TcpVariant {
12    /// Plain TCP connection
13    Plain(TcpStream),
14    /// TLS-encrypted TCP connection
15    Tls(tokio_rustls::client::TlsStream<TcpStream>),
16}
17
18impl std::fmt::Debug for TcpVariant {
19    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20        match self {
21            TcpVariant::Plain(_) => f.write_str("TcpVariant::Plain(TcpStream)"),
22            TcpVariant::Tls(_) => f.write_str("TcpVariant::Tls(TlsStream)"),
23        }
24    }
25}
26
27impl TcpVariant {
28    /// Write all bytes to the stream
29    pub async fn write_all(&mut self, buf: &[u8]) -> Result<()> {
30        match self {
31            TcpVariant::Plain(stream) => stream.write_all(buf).await?,
32            TcpVariant::Tls(stream) => stream.write_all(buf).await?,
33        }
34        Ok(())
35    }
36
37    /// Flush the stream
38    pub async fn flush(&mut self) -> Result<()> {
39        match self {
40            TcpVariant::Plain(stream) => stream.flush().await?,
41            TcpVariant::Tls(stream) => stream.flush().await?,
42        }
43        Ok(())
44    }
45
46    /// Read into buffer
47    pub async fn read_buf(&mut self, buf: &mut BytesMut) -> Result<usize> {
48        let n = match self {
49            TcpVariant::Plain(stream) => stream.read_buf(buf).await?,
50            TcpVariant::Tls(stream) => stream.read_buf(buf).await?,
51        };
52        Ok(n)
53    }
54
55    /// Shutdown the stream
56    pub async fn shutdown(&mut self) -> Result<()> {
57        match self {
58            TcpVariant::Plain(stream) => stream.shutdown().await?,
59            TcpVariant::Tls(stream) => stream.shutdown().await?,
60        }
61        Ok(())
62    }
63}
64
65/// Transport layer abstraction
66#[derive(Debug)]
67#[allow(clippy::large_enum_variant)]
68pub enum Transport {
69    /// TCP socket (plain or TLS)
70    Tcp(TcpVariant),
71    /// Unix domain socket
72    Unix(UnixStream),
73}
74
75impl Transport {
76    /// Connect via plain TCP
77    pub async fn connect_tcp(host: &str, port: u16) -> Result<Self> {
78        let stream = TcpStream::connect((host, port)).await?;
79        Ok(Transport::Tcp(TcpVariant::Plain(stream)))
80    }
81
82    /// Connect via TLS-encrypted TCP
83    pub async fn connect_tcp_tls(
84        host: &str,
85        port: u16,
86        tls_config: &crate::connection::TlsConfig,
87    ) -> Result<Self> {
88        let tcp_stream = TcpStream::connect((host, port)).await?;
89
90        // Parse server name for TLS handshake (SNI)
91        let server_name = crate::connection::parse_server_name(host)?;
92        let server_name = rustls_pki_types::ServerName::try_from(server_name)
93            .map_err(|_| crate::Error::Config(format!("Invalid hostname for TLS: {}", host)))?;
94
95        // Perform TLS handshake
96        let client_config = tls_config.client_config();
97        let tls_connector = tokio_rustls::TlsConnector::from(client_config);
98        let tls_stream = tls_connector
99            .connect(server_name, tcp_stream)
100            .await
101            .map_err(|e| crate::Error::Config(format!("TLS handshake failed: {}", e)))?;
102
103        Ok(Transport::Tcp(TcpVariant::Tls(tls_stream)))
104    }
105
106    /// Connect via Unix socket
107    pub async fn connect_unix(path: &Path) -> Result<Self> {
108        let stream = UnixStream::connect(path).await?;
109        Ok(Transport::Unix(stream))
110    }
111
112    /// Write bytes to the transport
113    pub async fn write_all(&mut self, buf: &[u8]) -> Result<()> {
114        match self {
115            Transport::Tcp(variant) => variant.write_all(buf).await?,
116            Transport::Unix(stream) => stream.write_all(buf).await?,
117        }
118        Ok(())
119    }
120
121    /// Flush the transport
122    pub async fn flush(&mut self) -> Result<()> {
123        match self {
124            Transport::Tcp(variant) => variant.flush().await?,
125            Transport::Unix(stream) => stream.flush().await?,
126        }
127        Ok(())
128    }
129
130    /// Read bytes into buffer
131    pub async fn read_buf(&mut self, buf: &mut BytesMut) -> Result<usize> {
132        let n = match self {
133            Transport::Tcp(variant) => variant.read_buf(buf).await?,
134            Transport::Unix(stream) => stream.read_buf(buf).await?,
135        };
136        Ok(n)
137    }
138
139    /// Shutdown the transport
140    pub async fn shutdown(&mut self) -> Result<()> {
141        match self {
142            Transport::Tcp(variant) => variant.shutdown().await?,
143            Transport::Unix(stream) => stream.shutdown().await?,
144        }
145        Ok(())
146    }
147}
148
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::net::TcpListener;

    /// Connecting to a TCP port with no listener must fail.
    ///
    /// A hard-coded port (previously 9999) makes this test flaky: it breaks
    /// whenever something on the machine happens to listen there. Instead, ask
    /// the OS for a free ephemeral loopback port, release it, and connect to
    /// it while nothing is bound. A tiny race remains if another process grabs
    /// the port in between, but it is far less likely than a fixed-port clash.
    #[tokio::test]
    async fn test_tcp_connect_failure() {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("binding an ephemeral loopback port should succeed");
        let port = listener
            .local_addr()
            .expect("a bound listener has a local address")
            .port();
        drop(listener);

        let result = Transport::connect_tcp("127.0.0.1", port).await;
        assert!(result.is_err());
    }
}