#[allow(unused_imports)] use crate::error::WireError;
use crate::Result;
use bytes::BytesMut;
use socket2::{SockRef, TcpKeepalive};
use std::path::Path;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpStream, UnixStream};
/// A TCP connection: either a raw stream or one wrapped in a rustls client TLS session.
// NOTE(review): the lint is silenced presumably because `Tls` (which carries TLS session
// state) is far larger than `Plain`; boxing it was not chosen — confirm this is intentional.
#[allow(clippy::large_enum_variant)]
#[non_exhaustive]
pub enum TcpVariant {
    /// Unencrypted TCP stream.
    Plain(TcpStream),
    /// TCP stream wrapped in a tokio-rustls client TLS stream.
    Tls(tokio_rustls::client::TlsStream<TcpStream>),
}

impl std::fmt::Debug for TcpVariant {
    /// Prints only which variant is active; the wrapped stream itself is not inspected.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Plain(_) => "TcpVariant::Plain(TcpStream)",
            Self::Tls(_) => "TcpVariant::Tls(TlsStream)",
        };
        f.write_str(label)
    }
}
impl TcpVariant {
pub async fn write_all(&mut self, buf: &[u8]) -> Result<()> {
match self {
TcpVariant::Plain(stream) => stream.write_all(buf).await?,
TcpVariant::Tls(stream) => stream.write_all(buf).await?,
}
Ok(())
}
pub async fn flush(&mut self) -> Result<()> {
match self {
TcpVariant::Plain(stream) => stream.flush().await?,
TcpVariant::Tls(stream) => stream.flush().await?,
}
Ok(())
}
pub async fn read_buf(&mut self, buf: &mut BytesMut) -> Result<usize> {
let n = match self {
TcpVariant::Plain(stream) => stream.read_buf(buf).await?,
TcpVariant::Tls(stream) => stream.read_buf(buf).await?,
};
Ok(n)
}
pub async fn shutdown(&mut self) -> Result<()> {
match self {
TcpVariant::Plain(stream) => stream.shutdown().await?,
TcpVariant::Tls(stream) => stream.shutdown().await?,
}
Ok(())
}
pub fn apply_keepalive(&self, idle: Duration) -> Result<()> {
let keepalive = TcpKeepalive::new().with_time(idle);
match self {
TcpVariant::Plain(stream) => {
let sock = SockRef::from(stream);
sock.set_keepalive(true)?;
sock.set_tcp_keepalive(&keepalive)?;
}
TcpVariant::Tls(stream) => {
let tcp = stream.get_ref().0;
let sock = SockRef::from(tcp);
sock.set_keepalive(true)?;
sock.set_tcp_keepalive(&keepalive)?;
}
}
Ok(())
}
}
/// The byte stream a connection runs over: TCP (plain or TLS) or a Unix domain socket.
// NOTE(review): `large_enum_variant` is silenced presumably because `Tcp` can hold a TLS
// stream that is much larger than `UnixStream` — confirm boxing was deliberately avoided.
#[derive(Debug)]
#[non_exhaustive]
#[allow(clippy::large_enum_variant)]
pub enum Transport {
    /// A TCP connection, either plain or TLS (see [`TcpVariant`]).
    Tcp(TcpVariant),
    /// A Unix domain socket connection.
    Unix(UnixStream),
}
impl Transport {
pub async fn connect_tcp(host: &str, port: u16) -> Result<Self> {
let stream = TcpStream::connect((host, port)).await?;
Ok(Transport::Tcp(TcpVariant::Plain(stream)))
}
pub async fn connect_tcp_tls(
host: &str,
port: u16,
tls_config: &crate::connection::TlsConfig,
) -> Result<Self> {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
let mut tcp_stream = TcpStream::connect((host, port)).await?;
let ssl_request: [u8; 8] = [
0x00, 0x00, 0x00, 0x08, 0x04, 0xd2, 0x16, 0x2f, ];
tcp_stream.write_all(&ssl_request).await?;
tcp_stream.flush().await?;
let mut response = [0u8; 1];
tcp_stream.read_exact(&mut response).await?;
match response[0] {
b'S' => {
}
b'N' => {
return Err(crate::WireError::Config(
"Server does not support SSL connections".to_string(),
));
}
other => {
return Err(crate::WireError::Config(format!(
"Unexpected SSL response from server: {:02x}",
other
)));
}
}
let server_name = crate::connection::parse_server_name(host)?;
let server_name = rustls_pki_types::ServerName::try_from(server_name)
.map_err(|_| crate::WireError::Config(format!("Invalid hostname for TLS: {}", host)))?;
let client_config = tls_config.client_config();
let tls_connector = tokio_rustls::TlsConnector::from(client_config);
let tls_stream = tls_connector
.connect(server_name, tcp_stream)
.await
.map_err(|e| crate::WireError::Config(format!("TLS handshake failed: {}", e)))?;
Ok(Transport::Tcp(TcpVariant::Tls(tls_stream)))
}
pub async fn connect_unix(path: &Path) -> Result<Self> {
let stream = UnixStream::connect(path).await?;
Ok(Transport::Unix(stream))
}
pub async fn write_all(&mut self, buf: &[u8]) -> Result<()> {
match self {
Transport::Tcp(variant) => variant.write_all(buf).await?,
Transport::Unix(stream) => stream.write_all(buf).await?,
}
Ok(())
}
pub async fn flush(&mut self) -> Result<()> {
match self {
Transport::Tcp(variant) => variant.flush().await?,
Transport::Unix(stream) => stream.flush().await?,
}
Ok(())
}
pub async fn read_buf(&mut self, buf: &mut BytesMut) -> Result<usize> {
let n = match self {
Transport::Tcp(variant) => variant.read_buf(buf).await?,
Transport::Unix(stream) => stream.read_buf(buf).await?,
};
Ok(n)
}
pub async fn shutdown(&mut self) -> Result<()> {
match self {
Transport::Tcp(variant) => variant.shutdown().await?,
Transport::Unix(stream) => stream.shutdown().await?,
}
Ok(())
}
pub fn apply_keepalive(&self, idle: Duration) -> Result<()> {
match self {
Transport::Tcp(variant) => variant.apply_keepalive(idle),
Transport::Unix(_) => Ok(()), }
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Connecting to a port with no listener must fail rather than hang or succeed.
    #[tokio::test]
    async fn test_tcp_connect_failure() {
        // Reserve an OS-assigned port, then release it so nothing is listening there.
        // This avoids assuming a fixed port (e.g. 9999) is free on the test machine.
        let port = {
            let listener = std::net::TcpListener::bind("127.0.0.1:0")
                .expect("binding an ephemeral port on loopback should succeed");
            listener
                .local_addr()
                .expect("a bound listener has a local address")
                .port()
        };

        let result = Transport::connect_tcp("127.0.0.1", port).await;
        assert!(
            result.is_err(),
            "expected Err for connection to closed port {port}, got: {result:?}"
        );
    }
}