fraiseql_wire/connection/
transport.rs1use crate::Result;
4use bytes::BytesMut;
5use std::path::Path;
6use tokio::io::{AsyncReadExt, AsyncWriteExt};
7use tokio::net::{TcpStream, UnixStream};
8
9#[allow(clippy::large_enum_variant)]
11pub enum TcpVariant {
12 Plain(TcpStream),
14 Tls(tokio_rustls::client::TlsStream<TcpStream>),
16}
17
18impl std::fmt::Debug for TcpVariant {
19 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20 match self {
21 TcpVariant::Plain(_) => f.write_str("TcpVariant::Plain(TcpStream)"),
22 TcpVariant::Tls(_) => f.write_str("TcpVariant::Tls(TlsStream)"),
23 }
24 }
25}
26
27impl TcpVariant {
28 pub async fn write_all(&mut self, buf: &[u8]) -> Result<()> {
30 match self {
31 TcpVariant::Plain(stream) => stream.write_all(buf).await?,
32 TcpVariant::Tls(stream) => stream.write_all(buf).await?,
33 }
34 Ok(())
35 }
36
37 pub async fn flush(&mut self) -> Result<()> {
39 match self {
40 TcpVariant::Plain(stream) => stream.flush().await?,
41 TcpVariant::Tls(stream) => stream.flush().await?,
42 }
43 Ok(())
44 }
45
46 pub async fn read_buf(&mut self, buf: &mut BytesMut) -> Result<usize> {
48 let n = match self {
49 TcpVariant::Plain(stream) => stream.read_buf(buf).await?,
50 TcpVariant::Tls(stream) => stream.read_buf(buf).await?,
51 };
52 Ok(n)
53 }
54
55 pub async fn shutdown(&mut self) -> Result<()> {
57 match self {
58 TcpVariant::Plain(stream) => stream.shutdown().await?,
59 TcpVariant::Tls(stream) => stream.shutdown().await?,
60 }
61 Ok(())
62 }
63}
64
65#[derive(Debug)]
67#[allow(clippy::large_enum_variant)]
68pub enum Transport {
69 Tcp(TcpVariant),
71 Unix(UnixStream),
73}
74
75impl Transport {
76 pub async fn connect_tcp(host: &str, port: u16) -> Result<Self> {
78 let stream = TcpStream::connect((host, port)).await?;
79 Ok(Transport::Tcp(TcpVariant::Plain(stream)))
80 }
81
82 pub async fn connect_tcp_tls(
84 host: &str,
85 port: u16,
86 tls_config: &crate::connection::TlsConfig,
87 ) -> Result<Self> {
88 let tcp_stream = TcpStream::connect((host, port)).await?;
89
90 let server_name = crate::connection::parse_server_name(host)?;
92 let server_name = rustls_pki_types::ServerName::try_from(server_name)
93 .map_err(|_| crate::Error::Config(format!("Invalid hostname for TLS: {}", host)))?;
94
95 let client_config = tls_config.client_config();
97 let tls_connector = tokio_rustls::TlsConnector::from(client_config);
98 let tls_stream = tls_connector
99 .connect(server_name, tcp_stream)
100 .await
101 .map_err(|e| crate::Error::Config(format!("TLS handshake failed: {}", e)))?;
102
103 Ok(Transport::Tcp(TcpVariant::Tls(tls_stream)))
104 }
105
106 pub async fn connect_unix(path: &Path) -> Result<Self> {
108 let stream = UnixStream::connect(path).await?;
109 Ok(Transport::Unix(stream))
110 }
111
112 pub async fn write_all(&mut self, buf: &[u8]) -> Result<()> {
114 match self {
115 Transport::Tcp(variant) => variant.write_all(buf).await?,
116 Transport::Unix(stream) => stream.write_all(buf).await?,
117 }
118 Ok(())
119 }
120
121 pub async fn flush(&mut self) -> Result<()> {
123 match self {
124 Transport::Tcp(variant) => variant.flush().await?,
125 Transport::Unix(stream) => stream.flush().await?,
126 }
127 Ok(())
128 }
129
130 pub async fn read_buf(&mut self, buf: &mut BytesMut) -> Result<usize> {
132 let n = match self {
133 Transport::Tcp(variant) => variant.read_buf(buf).await?,
134 Transport::Unix(stream) => stream.read_buf(buf).await?,
135 };
136 Ok(n)
137 }
138
139 pub async fn shutdown(&mut self) -> Result<()> {
141 match self {
142 Transport::Tcp(variant) => variant.shutdown().await?,
143 Transport::Unix(stream) => stream.shutdown().await?,
144 }
145 Ok(())
146 }
147}
148
#[cfg(test)]
mod tests {
    use super::*;

    /// Connecting to a loopback port with no listener must yield an error.
    #[tokio::test]
    async fn test_tcp_connect_failure() {
        // A hard-coded port may be in use by an unrelated local service, which
        // would make the connection succeed and the test fail spuriously.
        // Instead, let the OS hand out an ephemeral port and release it again,
        // so the port is known to have no listener when we connect.
        let closed_port = {
            let probe = std::net::TcpListener::bind(("127.0.0.1", 0))
                .expect("binding an ephemeral loopback port should succeed");
            probe
                .local_addr()
                .expect("bound listener should report its local address")
                .port()
            // `probe` is dropped at the end of this block, closing the port.
        };

        let result = Transport::connect_tcp("127.0.0.1", closed_port).await;
        assert!(result.is_err());
    }
}