zero_mysql/tokio/
stream.rs

1use core::mem::MaybeUninit;
2use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
3use tokio::net::TcpStream;
4#[cfg(unix)]
5use tokio::net::UnixStream;
6
7#[cfg(feature = "tokio-tls")]
8use tokio_native_tls::TlsStream;
9
10pub enum Stream {
11    Tcp(BufReader<TcpStream>),
12    #[cfg(feature = "tokio-tls")]
13    Tls(BufReader<TlsStream<TcpStream>>),
14    #[cfg(unix)]
15    Unix(BufReader<UnixStream>),
16}
17
18impl Stream {
19    pub fn tcp(stream: TcpStream) -> Self {
20        Self::Tcp(BufReader::new(stream))
21    }
22
23    #[cfg(unix)]
24    pub fn unix(stream: UnixStream) -> Self {
25        Self::Unix(BufReader::new(stream))
26    }
27
28    #[cfg(feature = "tokio-tls")]
29    pub async fn upgrade_to_tls(self, host: &str) -> std::io::Result<Self> {
30        let tcp = match self {
31            Self::Tcp(buf_reader) => buf_reader.into_inner(),
32            #[cfg(feature = "tokio-tls")]
33            Self::Tls(_) => {
34                return Err(std::io::Error::new(
35                    std::io::ErrorKind::InvalidInput,
36                    "Already using TLS",
37                ));
38            }
39            #[cfg(unix)]
40            Self::Unix(_) => {
41                return Err(std::io::Error::new(
42                    std::io::ErrorKind::InvalidInput,
43                    "TLS not supported for Unix sockets",
44                ));
45            }
46        };
47
48        let connector = native_tls::TlsConnector::new()
49            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
50        let connector = tokio_native_tls::TlsConnector::from(connector);
51        let tls_stream = connector
52            .connect(host, tcp)
53            .await
54            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
55
56        Ok(Self::Tls(BufReader::new(tls_stream)))
57    }
58
59    pub async fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
60        match self {
61            Self::Tcp(reader) => reader.read_exact(buf).await.map(|_| ()),
62            #[cfg(feature = "tokio-tls")]
63            Self::Tls(reader) => reader.read_exact(buf).await.map(|_| ()),
64            #[cfg(unix)]
65            Self::Unix(reader) => reader.read_exact(buf).await.map(|_| ()),
66        }
67    }
68
69    pub async fn read_buf_exact(&mut self, buf: &mut [MaybeUninit<u8>]) -> std::io::Result<()> {
70        match self {
71            Self::Tcp(reader) => read_buf_exact_impl(reader, buf).await,
72            #[cfg(feature = "tokio-tls")]
73            Self::Tls(reader) => read_buf_exact_impl(reader, buf).await,
74            #[cfg(unix)]
75            Self::Unix(reader) => read_buf_exact_impl(reader, buf).await,
76        }
77    }
78
79    pub async fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
80        match self {
81            Self::Tcp(reader) => reader.get_mut().write_all(buf).await,
82            #[cfg(feature = "tokio-tls")]
83            Self::Tls(reader) => reader.get_mut().write_all(buf).await,
84            #[cfg(unix)]
85            Self::Unix(reader) => reader.get_mut().write_all(buf).await,
86        }
87    }
88
89    pub async fn flush(&mut self) -> std::io::Result<()> {
90        match self {
91            Self::Tcp(reader) => reader.get_mut().flush().await,
92            #[cfg(feature = "tokio-tls")]
93            Self::Tls(reader) => reader.get_mut().flush().await,
94            #[cfg(unix)]
95            Self::Unix(reader) => reader.get_mut().flush().await,
96        }
97    }
98
99    /// Returns true if this is a TCP connection to a loopback address
100    pub fn is_tcp_loopback(&self) -> bool {
101        match self {
102            Self::Tcp(r) => r
103                .get_ref()
104                .peer_addr()
105                .map(|addr| addr.ip().is_loopback())
106                .unwrap_or(false),
107            #[cfg(feature = "tokio-tls")]
108            Self::Tls(r) => r
109                .get_ref()
110                .get_ref()
111                .get_ref()
112                .get_ref()
113                .peer_addr()
114                .map(|addr| addr.ip().is_loopback())
115                .unwrap_or(false),
116            #[cfg(unix)]
117            Self::Unix(_) => false,
118        }
119    }
120}
121
122async fn read_buf_exact_impl<R: AsyncReadExt + Unpin>(
123    reader: &mut R,
124    mut buf: &mut [MaybeUninit<u8>],
125) -> std::io::Result<()> {
126    while !buf.is_empty() {
127        let n = reader.read_buf(&mut buf).await?;
128        if n == 0 {
129            return Err(std::io::Error::new(
130                std::io::ErrorKind::UnexpectedEof,
131                "failed to fill whole buffer",
132            ));
133        }
134    }
135    Ok(())
136}