fraiseql_wire/connection/
transport.rs1use crate::Result;
4use bytes::BytesMut;
5use sha2::Digest;
6use std::path::Path;
7use tokio::io::{AsyncReadExt, AsyncWriteExt};
8use tokio::net::{TcpStream, UnixStream};
9
10#[allow(clippy::large_enum_variant)]
12pub enum TcpVariant {
13 Plain(TcpStream),
15 Tls(tokio_rustls::client::TlsStream<TcpStream>),
17}
18
19impl std::fmt::Debug for TcpVariant {
20 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21 match self {
22 TcpVariant::Plain(_) => f.write_str("TcpVariant::Plain(TcpStream)"),
23 TcpVariant::Tls(_) => f.write_str("TcpVariant::Tls(TlsStream)"),
24 }
25 }
26}
27
28impl TcpVariant {
29 pub async fn write_all(&mut self, buf: &[u8]) -> Result<()> {
31 match self {
32 TcpVariant::Plain(stream) => stream.write_all(buf).await?,
33 TcpVariant::Tls(stream) => stream.write_all(buf).await?,
34 }
35 Ok(())
36 }
37
38 pub async fn flush(&mut self) -> Result<()> {
40 match self {
41 TcpVariant::Plain(stream) => stream.flush().await?,
42 TcpVariant::Tls(stream) => stream.flush().await?,
43 }
44 Ok(())
45 }
46
47 pub async fn read_buf(&mut self, buf: &mut BytesMut) -> Result<usize> {
49 let n = match self {
50 TcpVariant::Plain(stream) => stream.read_buf(buf).await?,
51 TcpVariant::Tls(stream) => stream.read_buf(buf).await?,
52 };
53 Ok(n)
54 }
55
56 pub async fn shutdown(&mut self) -> Result<()> {
58 match self {
59 TcpVariant::Plain(stream) => stream.shutdown().await?,
60 TcpVariant::Tls(stream) => stream.shutdown().await?,
61 }
62 Ok(())
63 }
64
65 pub fn channel_binding_data(&self) -> Option<Vec<u8>> {
70 match self {
71 TcpVariant::Plain(_) => None,
72 TcpVariant::Tls(stream) => {
73 let (_tcp, conn) = stream.get_ref();
74 let certs = conn.peer_certificates()?;
75 let server_cert = certs.first()?;
76 let hash = sha2::Sha256::digest(server_cert.as_ref());
78 Some(hash.to_vec())
79 }
80 }
81 }
82}
83
84#[derive(Debug)]
86#[allow(clippy::large_enum_variant)]
87pub enum Transport {
88 Tcp(TcpVariant),
90 Unix(UnixStream),
92}
93
94impl Transport {
95 pub async fn connect_tcp(host: &str, port: u16) -> Result<Self> {
97 let stream = TcpStream::connect((host, port)).await?;
98 Ok(Transport::Tcp(TcpVariant::Plain(stream)))
99 }
100
101 pub async fn connect_tcp_tls(
103 host: &str,
104 port: u16,
105 tls_config: &crate::connection::TlsConfig,
106 ) -> Result<Self> {
107 let tcp_stream = TcpStream::connect((host, port)).await?;
108
109 let server_name = crate::connection::parse_server_name(host)?;
111 let server_name = rustls_pki_types::ServerName::try_from(server_name)
112 .map_err(|_| crate::Error::Config(format!("Invalid hostname for TLS: {}", host)))?;
113
114 let client_config = tls_config.client_config();
116 let tls_connector = tokio_rustls::TlsConnector::from(client_config);
117 let tls_stream = tls_connector
118 .connect(server_name, tcp_stream)
119 .await
120 .map_err(|e| crate::Error::Config(format!("TLS handshake failed: {}", e)))?;
121
122 Ok(Transport::Tcp(TcpVariant::Tls(tls_stream)))
123 }
124
125 pub async fn connect_unix(path: &Path) -> Result<Self> {
127 let stream = UnixStream::connect(path).await?;
128 Ok(Transport::Unix(stream))
129 }
130
131 pub async fn write_all(&mut self, buf: &[u8]) -> Result<()> {
133 match self {
134 Transport::Tcp(variant) => variant.write_all(buf).await?,
135 Transport::Unix(stream) => stream.write_all(buf).await?,
136 }
137 Ok(())
138 }
139
140 pub async fn flush(&mut self) -> Result<()> {
142 match self {
143 Transport::Tcp(variant) => variant.flush().await?,
144 Transport::Unix(stream) => stream.flush().await?,
145 }
146 Ok(())
147 }
148
149 pub async fn read_buf(&mut self, buf: &mut BytesMut) -> Result<usize> {
151 let n = match self {
152 Transport::Tcp(variant) => variant.read_buf(buf).await?,
153 Transport::Unix(stream) => stream.read_buf(buf).await?,
154 };
155 Ok(n)
156 }
157
158 pub async fn upgrade_to_tls(
163 self,
164 tls_config: &crate::connection::TlsConfig,
165 hostname: &str,
166 ) -> Result<Self> {
167 match self {
168 Transport::Tcp(TcpVariant::Plain(tcp_stream)) => {
169 let server_name = crate::connection::parse_server_name(hostname)?;
170 let server_name =
171 rustls_pki_types::ServerName::try_from(server_name).map_err(|_| {
172 crate::Error::Config(format!("Invalid hostname for TLS: {}", hostname))
173 })?;
174
175 let client_config = tls_config.client_config();
176 let tls_connector = tokio_rustls::TlsConnector::from(client_config);
177 let tls_stream = tls_connector
178 .connect(server_name, tcp_stream)
179 .await
180 .map_err(|e| crate::Error::Config(format!("TLS handshake failed: {}", e)))?;
181
182 Ok(Transport::Tcp(TcpVariant::Tls(tls_stream)))
183 }
184 Transport::Tcp(TcpVariant::Tls(_)) => Err(crate::Error::Config(
185 "transport is already TLS-encrypted".into(),
186 )),
187 Transport::Unix(_) => Err(crate::Error::Config(
188 "cannot upgrade Unix socket to TLS".into(),
189 )),
190 }
191 }
192
193 pub async fn shutdown(&mut self) -> Result<()> {
195 match self {
196 Transport::Tcp(variant) => variant.shutdown().await?,
197 Transport::Unix(stream) => stream.shutdown().await?,
198 }
199 Ok(())
200 }
201
202 pub fn channel_binding_data(&self) -> Option<Vec<u8>> {
206 match self {
207 Transport::Tcp(variant) => variant.channel_binding_data(),
208 Transport::Unix(_) => None,
209 }
210 }
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    /// Connecting to a port with no listener must surface an error.
    #[tokio::test]
    async fn test_tcp_connect_failure() {
        // Reserve an ephemeral port and release it again, so the test does
        // not depend on a hard-coded port happening to be unused on the host.
        let probe = std::net::TcpListener::bind(("127.0.0.1", 0))
            .expect("binding an ephemeral loopback port should succeed");
        let port = probe
            .local_addr()
            .expect("bound listener should have a local address")
            .port();
        drop(probe);

        let result = Transport::connect_tcp("127.0.0.1", port).await;
        assert!(result.is_err());
    }

    /// Compile-time check that `upgrade_to_tls` keeps its by-value signature.
    #[test]
    fn test_upgrade_to_tls_signature_exists() {
        // Never called: this only has to type-check.
        fn _assert_method_exists(t: Transport, c: &crate::connection::TlsConfig, h: &str) {
            let _fut = t.upgrade_to_tls(c, h);
        }
    }
}