fraiseql_wire/connection/
transport.rs1#[allow(unused_imports)] use crate::error::WireError;
5use crate::Result;
6use bytes::BytesMut;
7use socket2::{SockRef, TcpKeepalive};
8use std::path::Path;
9use std::time::Duration;
10use tokio::io::{AsyncReadExt, AsyncWriteExt};
11use tokio::net::{TcpStream, UnixStream};
12
13#[allow(clippy::large_enum_variant)] #[non_exhaustive]
16pub enum TcpVariant {
17 Plain(TcpStream),
19 Tls(tokio_rustls::client::TlsStream<TcpStream>),
21}
22
23impl std::fmt::Debug for TcpVariant {
24 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25 match self {
26 TcpVariant::Plain(_) => f.write_str("TcpVariant::Plain(TcpStream)"),
27 TcpVariant::Tls(_) => f.write_str("TcpVariant::Tls(TlsStream)"),
28 }
29 }
30}
31
32impl TcpVariant {
33 pub async fn write_all(&mut self, buf: &[u8]) -> Result<()> {
39 match self {
40 TcpVariant::Plain(stream) => stream.write_all(buf).await?,
41 TcpVariant::Tls(stream) => stream.write_all(buf).await?,
42 }
43 Ok(())
44 }
45
46 pub async fn flush(&mut self) -> Result<()> {
52 match self {
53 TcpVariant::Plain(stream) => stream.flush().await?,
54 TcpVariant::Tls(stream) => stream.flush().await?,
55 }
56 Ok(())
57 }
58
59 pub async fn read_buf(&mut self, buf: &mut BytesMut) -> Result<usize> {
65 let n = match self {
66 TcpVariant::Plain(stream) => stream.read_buf(buf).await?,
67 TcpVariant::Tls(stream) => stream.read_buf(buf).await?,
68 };
69 Ok(n)
70 }
71
72 pub async fn shutdown(&mut self) -> Result<()> {
78 match self {
79 TcpVariant::Plain(stream) => stream.shutdown().await?,
80 TcpVariant::Tls(stream) => stream.shutdown().await?,
81 }
82 Ok(())
83 }
84
85 pub fn apply_keepalive(&self, idle: Duration) -> Result<()> {
96 let keepalive = TcpKeepalive::new().with_time(idle);
97 match self {
98 TcpVariant::Plain(stream) => {
99 let sock = SockRef::from(stream);
100 sock.set_keepalive(true)?;
101 sock.set_tcp_keepalive(&keepalive)?;
102 }
103 TcpVariant::Tls(stream) => {
104 let tcp = stream.get_ref().0;
106 let sock = SockRef::from(tcp);
107 sock.set_keepalive(true)?;
108 sock.set_tcp_keepalive(&keepalive)?;
109 }
110 }
111 Ok(())
112 }
113}
114
115#[derive(Debug)]
117#[allow(clippy::large_enum_variant)] #[non_exhaustive]
119pub enum Transport {
120 Tcp(TcpVariant),
122 Unix(UnixStream),
124}
125
126impl Transport {
127 pub async fn connect_tcp(host: &str, port: u16) -> Result<Self> {
133 let stream = TcpStream::connect((host, port)).await?;
134 Ok(Transport::Tcp(TcpVariant::Plain(stream)))
135 }
136
137 pub async fn connect_tcp_tls(
149 host: &str,
150 port: u16,
151 tls_config: &crate::connection::TlsConfig,
152 ) -> Result<Self> {
153 use tokio::io::{AsyncReadExt, AsyncWriteExt};
154
155 let mut tcp_stream = TcpStream::connect((host, port)).await?;
156
157 let ssl_request: [u8; 8] = [
161 0x00, 0x00, 0x00, 0x08, 0x04, 0xd2, 0x16, 0x2f, ];
164
165 tcp_stream.write_all(&ssl_request).await?;
166 tcp_stream.flush().await?;
167
168 let mut response = [0u8; 1];
170 tcp_stream.read_exact(&mut response).await?;
171
172 match response[0] {
173 b'S' => {
174 }
176 b'N' => {
177 return Err(crate::WireError::Config(
178 "Server does not support SSL connections".to_string(),
179 ));
180 }
181 other => {
182 return Err(crate::WireError::Config(format!(
183 "Unexpected SSL response from server: {:02x}",
184 other
185 )));
186 }
187 }
188
189 let server_name = crate::connection::parse_server_name(host)?;
191 let server_name = rustls_pki_types::ServerName::try_from(server_name)
192 .map_err(|_| crate::WireError::Config(format!("Invalid hostname for TLS: {}", host)))?;
193
194 let client_config = tls_config.client_config();
196 let tls_connector = tokio_rustls::TlsConnector::from(client_config);
197 let tls_stream = tls_connector
198 .connect(server_name, tcp_stream)
199 .await
200 .map_err(|e| crate::WireError::Config(format!("TLS handshake failed: {}", e)))?;
201
202 Ok(Transport::Tcp(TcpVariant::Tls(tls_stream)))
203 }
204
205 pub async fn connect_unix(path: &Path) -> Result<Self> {
211 let stream = UnixStream::connect(path).await?;
212 Ok(Transport::Unix(stream))
213 }
214
215 pub async fn write_all(&mut self, buf: &[u8]) -> Result<()> {
221 match self {
222 Transport::Tcp(variant) => variant.write_all(buf).await?,
223 Transport::Unix(stream) => stream.write_all(buf).await?,
224 }
225 Ok(())
226 }
227
228 pub async fn flush(&mut self) -> Result<()> {
234 match self {
235 Transport::Tcp(variant) => variant.flush().await?,
236 Transport::Unix(stream) => stream.flush().await?,
237 }
238 Ok(())
239 }
240
241 pub async fn read_buf(&mut self, buf: &mut BytesMut) -> Result<usize> {
247 let n = match self {
248 Transport::Tcp(variant) => variant.read_buf(buf).await?,
249 Transport::Unix(stream) => stream.read_buf(buf).await?,
250 };
251 Ok(n)
252 }
253
254 pub async fn shutdown(&mut self) -> Result<()> {
260 match self {
261 Transport::Tcp(variant) => variant.shutdown().await?,
262 Transport::Unix(stream) => stream.shutdown().await?,
263 }
264 Ok(())
265 }
266
267 pub fn apply_keepalive(&self, idle: Duration) -> Result<()> {
277 match self {
278 Transport::Tcp(variant) => variant.apply_keepalive(idle),
279 Transport::Unix(_) => Ok(()), }
281 }
282}
283
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::TcpListener;

    /// Returns a loopback port that nothing is listening on.
    ///
    /// Binds an OS-assigned ephemeral port and drops the listener before
    /// returning, so connecting to the port should be refused. Another process
    /// could grab the port in between, but that is far less likely than a
    /// fixed port such as 9999 already being in use.
    fn unused_local_port() -> u16 {
        let listener = TcpListener::bind(("127.0.0.1", 0)).expect("bind ephemeral loopback port");
        listener
            .local_addr()
            .expect("read ephemeral listener address")
            .port()
    }

    #[tokio::test]
    async fn test_tcp_connect_failure() {
        let port = unused_local_port();
        let result = Transport::connect_tcp("127.0.0.1", port).await;
        assert!(
            result.is_err(),
            "expected Err for connection to closed port {port}, got: {result:?}"
        );
    }
}