zero_mysql/tokio/
stream.rs

1use core::mem::MaybeUninit;
2use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
3use tokio::net::{TcpStream, UnixStream};
4
5#[cfg(feature = "tokio-tls")]
6use tokio_native_tls::TlsStream;
7
8pub enum Stream {
9    Tcp(BufReader<TcpStream>),
10    #[cfg(feature = "tokio-tls")]
11    Tls(BufReader<TlsStream<TcpStream>>),
12    Unix(BufReader<UnixStream>),
13}
14
15impl Stream {
16    pub fn tcp(stream: TcpStream) -> Self {
17        Self::Tcp(BufReader::new(stream))
18    }
19
20    pub fn unix(stream: UnixStream) -> Self {
21        Self::Unix(BufReader::new(stream))
22    }
23
24    #[cfg(feature = "tokio-tls")]
25    pub async fn upgrade_to_tls(self, host: &str) -> std::io::Result<Self> {
26        let tcp = match self {
27            Self::Tcp(buf_reader) => buf_reader.into_inner(),
28            #[cfg(feature = "tokio-tls")]
29            Self::Tls(_) => {
30                return Err(std::io::Error::new(
31                    std::io::ErrorKind::InvalidInput,
32                    "Already using TLS",
33                ));
34            }
35            Self::Unix(_) => {
36                return Err(std::io::Error::new(
37                    std::io::ErrorKind::InvalidInput,
38                    "TLS not supported for Unix sockets",
39                ));
40            }
41        };
42
43        let connector = native_tls::TlsConnector::new()
44            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
45        let connector = tokio_native_tls::TlsConnector::from(connector);
46        let tls_stream = connector
47            .connect(host, tcp)
48            .await
49            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
50
51        Ok(Self::Tls(BufReader::new(tls_stream)))
52    }
53
54    pub async fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
55        match self {
56            Self::Tcp(reader) => reader.read_exact(buf).await.map(|_| ()),
57            #[cfg(feature = "tokio-tls")]
58            Self::Tls(reader) => reader.read_exact(buf).await.map(|_| ()),
59            Self::Unix(reader) => reader.read_exact(buf).await.map(|_| ()),
60        }
61    }
62
63    pub async fn read_buf_exact(&mut self, buf: &mut [MaybeUninit<u8>]) -> std::io::Result<()> {
64        match self {
65            Self::Tcp(reader) => read_buf_exact_impl(reader, buf).await,
66            #[cfg(feature = "tokio-tls")]
67            Self::Tls(reader) => read_buf_exact_impl(reader, buf).await,
68            Self::Unix(reader) => read_buf_exact_impl(reader, buf).await,
69        }
70    }
71
72    pub async fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
73        match self {
74            Self::Tcp(reader) => reader.get_mut().write_all(buf).await,
75            #[cfg(feature = "tokio-tls")]
76            Self::Tls(reader) => reader.get_mut().write_all(buf).await,
77            Self::Unix(reader) => reader.get_mut().write_all(buf).await,
78        }
79    }
80
81    pub async fn flush(&mut self) -> std::io::Result<()> {
82        match self {
83            Self::Tcp(reader) => reader.get_mut().flush().await,
84            #[cfg(feature = "tokio-tls")]
85            Self::Tls(reader) => reader.get_mut().flush().await,
86            Self::Unix(reader) => reader.get_mut().flush().await,
87        }
88    }
89
90    /// Returns true if this is a TCP connection to a loopback address
91    pub fn is_tcp_loopback(&self) -> bool {
92        match self {
93            Self::Tcp(r) => r
94                .get_ref()
95                .peer_addr()
96                .map(|addr| addr.ip().is_loopback())
97                .unwrap_or(false),
98            #[cfg(feature = "tokio-tls")]
99            Self::Tls(r) => r
100                .get_ref()
101                .get_ref()
102                .get_ref()
103                .get_ref()
104                .peer_addr()
105                .map(|addr| addr.ip().is_loopback())
106                .unwrap_or(false),
107            Self::Unix(_) => false,
108        }
109    }
110}
111
112async fn read_buf_exact_impl<R: AsyncReadExt + Unpin>(
113    reader: &mut R,
114    mut buf: &mut [MaybeUninit<u8>],
115) -> std::io::Result<()> {
116    while !buf.is_empty() {
117        let n = reader.read_buf(&mut buf).await?;
118        if n == 0 {
119            return Err(std::io::Error::new(
120                std::io::ErrorKind::UnexpectedEof,
121                "failed to fill whole buffer",
122            ));
123        }
124    }
125    Ok(())
126}