fraiseql_wire/connection/
transport.rs1use crate::Result;
4use bytes::BytesMut;
5use socket2::{SockRef, TcpKeepalive};
6use std::path::Path;
7use std::time::Duration;
8use tokio::io::{AsyncReadExt, AsyncWriteExt};
9use tokio::net::{TcpStream, UnixStream};
10
11#[allow(clippy::large_enum_variant)] pub enum TcpVariant {
14 Plain(TcpStream),
16 Tls(tokio_rustls::client::TlsStream<TcpStream>),
18}
19
20impl std::fmt::Debug for TcpVariant {
21 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22 match self {
23 TcpVariant::Plain(_) => f.write_str("TcpVariant::Plain(TcpStream)"),
24 TcpVariant::Tls(_) => f.write_str("TcpVariant::Tls(TlsStream)"),
25 }
26 }
27}
28
29impl TcpVariant {
30 pub async fn write_all(&mut self, buf: &[u8]) -> Result<()> {
32 match self {
33 TcpVariant::Plain(stream) => stream.write_all(buf).await?,
34 TcpVariant::Tls(stream) => stream.write_all(buf).await?,
35 }
36 Ok(())
37 }
38
39 pub async fn flush(&mut self) -> Result<()> {
41 match self {
42 TcpVariant::Plain(stream) => stream.flush().await?,
43 TcpVariant::Tls(stream) => stream.flush().await?,
44 }
45 Ok(())
46 }
47
48 pub async fn read_buf(&mut self, buf: &mut BytesMut) -> Result<usize> {
50 let n = match self {
51 TcpVariant::Plain(stream) => stream.read_buf(buf).await?,
52 TcpVariant::Tls(stream) => stream.read_buf(buf).await?,
53 };
54 Ok(n)
55 }
56
57 pub async fn shutdown(&mut self) -> Result<()> {
59 match self {
60 TcpVariant::Plain(stream) => stream.shutdown().await?,
61 TcpVariant::Tls(stream) => stream.shutdown().await?,
62 }
63 Ok(())
64 }
65
66 pub fn apply_keepalive(&self, idle: Duration) -> Result<()> {
73 let keepalive = TcpKeepalive::new().with_time(idle);
74 match self {
75 TcpVariant::Plain(stream) => {
76 let sock = SockRef::from(stream);
77 sock.set_keepalive(true)?;
78 sock.set_tcp_keepalive(&keepalive)?;
79 }
80 TcpVariant::Tls(stream) => {
81 let tcp = stream.get_ref().0;
83 let sock = SockRef::from(tcp);
84 sock.set_keepalive(true)?;
85 sock.set_tcp_keepalive(&keepalive)?;
86 }
87 }
88 Ok(())
89 }
90}
91
92#[derive(Debug)]
94#[allow(clippy::large_enum_variant)] pub enum Transport {
96 Tcp(TcpVariant),
98 Unix(UnixStream),
100}
101
102impl Transport {
103 pub async fn connect_tcp(host: &str, port: u16) -> Result<Self> {
109 let stream = TcpStream::connect((host, port)).await?;
110 Ok(Transport::Tcp(TcpVariant::Plain(stream)))
111 }
112
113 pub async fn connect_tcp_tls(
126 host: &str,
127 port: u16,
128 tls_config: &crate::connection::TlsConfig,
129 ) -> Result<Self> {
130 use tokio::io::{AsyncReadExt, AsyncWriteExt};
131
132 let mut tcp_stream = TcpStream::connect((host, port)).await?;
133
134 let ssl_request: [u8; 8] = [
138 0x00, 0x00, 0x00, 0x08, 0x04, 0xd2, 0x16, 0x2f, ];
141
142 tcp_stream.write_all(&ssl_request).await?;
143 tcp_stream.flush().await?;
144
145 let mut response = [0u8; 1];
147 tcp_stream.read_exact(&mut response).await?;
148
149 match response[0] {
150 b'S' => {
151 }
153 b'N' => {
154 return Err(crate::Error::Config(
155 "Server does not support SSL connections".to_string(),
156 ));
157 }
158 other => {
159 return Err(crate::Error::Config(format!(
160 "Unexpected SSL response from server: {:02x}",
161 other
162 )));
163 }
164 }
165
166 let server_name = crate::connection::parse_server_name(host)?;
168 let server_name = rustls_pki_types::ServerName::try_from(server_name)
169 .map_err(|_| crate::Error::Config(format!("Invalid hostname for TLS: {}", host)))?;
170
171 let client_config = tls_config.client_config();
173 let tls_connector = tokio_rustls::TlsConnector::from(client_config);
174 let tls_stream = tls_connector
175 .connect(server_name, tcp_stream)
176 .await
177 .map_err(|e| crate::Error::Config(format!("TLS handshake failed: {}", e)))?;
178
179 Ok(Transport::Tcp(TcpVariant::Tls(tls_stream)))
180 }
181
182 pub async fn connect_unix(path: &Path) -> Result<Self> {
188 let stream = UnixStream::connect(path).await?;
189 Ok(Transport::Unix(stream))
190 }
191
192 pub async fn write_all(&mut self, buf: &[u8]) -> Result<()> {
198 match self {
199 Transport::Tcp(variant) => variant.write_all(buf).await?,
200 Transport::Unix(stream) => stream.write_all(buf).await?,
201 }
202 Ok(())
203 }
204
205 pub async fn flush(&mut self) -> Result<()> {
211 match self {
212 Transport::Tcp(variant) => variant.flush().await?,
213 Transport::Unix(stream) => stream.flush().await?,
214 }
215 Ok(())
216 }
217
218 pub async fn read_buf(&mut self, buf: &mut BytesMut) -> Result<usize> {
224 let n = match self {
225 Transport::Tcp(variant) => variant.read_buf(buf).await?,
226 Transport::Unix(stream) => stream.read_buf(buf).await?,
227 };
228 Ok(n)
229 }
230
231 pub async fn shutdown(&mut self) -> Result<()> {
237 match self {
238 Transport::Tcp(variant) => variant.shutdown().await?,
239 Transport::Unix(stream) => stream.shutdown().await?,
240 }
241 Ok(())
242 }
243
244 pub fn apply_keepalive(&self, idle: Duration) -> Result<()> {
254 match self {
255 Transport::Tcp(variant) => variant.apply_keepalive(idle),
256 Transport::Unix(_) => Ok(()), }
258 }
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a loopback port that was bound a moment ago and is now closed, so a connect
    /// attempt to it should be refused. Avoids assuming that any fixed port is unused.
    fn closed_local_port() -> u16 {
        let listener =
            std::net::TcpListener::bind("127.0.0.1:0").expect("bind an ephemeral loopback port");
        let port = listener
            .local_addr()
            .expect("read the listener's local address")
            .port();
        drop(listener);
        port
    }

    #[tokio::test]
    async fn test_tcp_connect_failure() {
        // A hard-coded port such as 9999 may be in use on the test machine; use a freed one.
        let port = closed_local_port();
        let result = Transport::connect_tcp("127.0.0.1", port).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_unix_connect_failure() {
        let path = std::env::temp_dir().join(format!(
            "fraiseql-wire-missing-{}.sock",
            std::process::id()
        ));
        let result = Transport::connect_unix(&path).await;
        assert!(result.is_err());
    }
}