// fraiseql_wire/connection/transport.rs
use crate::Result;
4use bytes::BytesMut;
5use std::path::Path;
6use tokio::io::{AsyncReadExt, AsyncWriteExt};
7use tokio::net::{TcpStream, UnixStream};
8
9#[allow(clippy::large_enum_variant)]
11pub enum TcpVariant {
12 Plain(TcpStream),
14 Tls(tokio_rustls::client::TlsStream<TcpStream>),
16}
17
18impl std::fmt::Debug for TcpVariant {
19 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20 match self {
21 TcpVariant::Plain(_) => f.write_str("TcpVariant::Plain(TcpStream)"),
22 TcpVariant::Tls(_) => f.write_str("TcpVariant::Tls(TlsStream)"),
23 }
24 }
25}
26
27impl TcpVariant {
28 pub async fn write_all(&mut self, buf: &[u8]) -> Result<()> {
30 match self {
31 TcpVariant::Plain(stream) => stream.write_all(buf).await?,
32 TcpVariant::Tls(stream) => stream.write_all(buf).await?,
33 }
34 Ok(())
35 }
36
37 pub async fn flush(&mut self) -> Result<()> {
39 match self {
40 TcpVariant::Plain(stream) => stream.flush().await?,
41 TcpVariant::Tls(stream) => stream.flush().await?,
42 }
43 Ok(())
44 }
45
46 pub async fn read_buf(&mut self, buf: &mut BytesMut) -> Result<usize> {
48 let n = match self {
49 TcpVariant::Plain(stream) => stream.read_buf(buf).await?,
50 TcpVariant::Tls(stream) => stream.read_buf(buf).await?,
51 };
52 Ok(n)
53 }
54
55 pub async fn shutdown(&mut self) -> Result<()> {
57 match self {
58 TcpVariant::Plain(stream) => stream.shutdown().await?,
59 TcpVariant::Tls(stream) => stream.shutdown().await?,
60 }
61 Ok(())
62 }
63}
64
65#[derive(Debug)]
67#[allow(clippy::large_enum_variant)]
68pub enum Transport {
69 Tcp(TcpVariant),
71 Unix(UnixStream),
73}
74
75impl Transport {
76 pub async fn connect_tcp(host: &str, port: u16) -> Result<Self> {
78 let stream = TcpStream::connect((host, port)).await?;
79 Ok(Transport::Tcp(TcpVariant::Plain(stream)))
80 }
81
82 pub async fn connect_tcp_tls(
89 host: &str,
90 port: u16,
91 tls_config: &crate::connection::TlsConfig,
92 ) -> Result<Self> {
93 use tokio::io::{AsyncReadExt, AsyncWriteExt};
94
95 let mut tcp_stream = TcpStream::connect((host, port)).await?;
96
97 let ssl_request: [u8; 8] = [
101 0x00, 0x00, 0x00, 0x08, 0x04, 0xd2, 0x16, 0x2f, ];
104
105 tcp_stream.write_all(&ssl_request).await?;
106 tcp_stream.flush().await?;
107
108 let mut response = [0u8; 1];
110 tcp_stream.read_exact(&mut response).await?;
111
112 match response[0] {
113 b'S' => {
114 }
116 b'N' => {
117 return Err(crate::Error::Config(
118 "Server does not support SSL connections".to_string(),
119 ));
120 }
121 other => {
122 return Err(crate::Error::Config(format!(
123 "Unexpected SSL response from server: {:02x}",
124 other
125 )));
126 }
127 }
128
129 let server_name = crate::connection::parse_server_name(host)?;
131 let server_name = rustls_pki_types::ServerName::try_from(server_name)
132 .map_err(|_| crate::Error::Config(format!("Invalid hostname for TLS: {}", host)))?;
133
134 let client_config = tls_config.client_config();
136 let tls_connector = tokio_rustls::TlsConnector::from(client_config);
137 let tls_stream = tls_connector
138 .connect(server_name, tcp_stream)
139 .await
140 .map_err(|e| crate::Error::Config(format!("TLS handshake failed: {}", e)))?;
141
142 Ok(Transport::Tcp(TcpVariant::Tls(tls_stream)))
143 }
144
145 pub async fn connect_unix(path: &Path) -> Result<Self> {
147 let stream = UnixStream::connect(path).await?;
148 Ok(Transport::Unix(stream))
149 }
150
151 pub async fn write_all(&mut self, buf: &[u8]) -> Result<()> {
153 match self {
154 Transport::Tcp(variant) => variant.write_all(buf).await?,
155 Transport::Unix(stream) => stream.write_all(buf).await?,
156 }
157 Ok(())
158 }
159
160 pub async fn flush(&mut self) -> Result<()> {
162 match self {
163 Transport::Tcp(variant) => variant.flush().await?,
164 Transport::Unix(stream) => stream.flush().await?,
165 }
166 Ok(())
167 }
168
169 pub async fn read_buf(&mut self, buf: &mut BytesMut) -> Result<usize> {
171 let n = match self {
172 Transport::Tcp(variant) => variant.read_buf(buf).await?,
173 Transport::Unix(stream) => stream.read_buf(buf).await?,
174 };
175 Ok(n)
176 }
177
178 pub async fn shutdown(&mut self) -> Result<()> {
180 match self {
181 Transport::Tcp(variant) => variant.shutdown().await?,
182 Transport::Unix(stream) => stream.shutdown().await?,
183 }
184 Ok(())
185 }
186}
187
#[cfg(test)]
mod tests {
    use super::*;

    /// Connecting to a loopback port with no listener must yield an error.
    #[tokio::test]
    async fn test_tcp_connect_failure() {
        // Ask the OS for a free loopback port, remember it, and release it
        // again by dropping the listener. A hard-coded port (e.g. 9999) makes
        // the test fail spuriously whenever some local service listens there.
        let port = {
            let listener = std::net::TcpListener::bind(("127.0.0.1", 0))
                .expect("binding an ephemeral loopback port should succeed");
            listener
                .local_addr()
                .expect("a bound listener has a local address")
                .port()
        };

        let result = Transport::connect_tcp("127.0.0.1", port).await;
        assert!(result.is_err());
    }
}