zero_mysql/tokio/
stream.rs1use core::mem::MaybeUninit;
2use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
3use tokio::net::TcpStream;
4#[cfg(unix)]
5use tokio::net::UnixStream;
6
7#[cfg(feature = "tokio-tls")]
8use tokio_native_tls::TlsStream;
9
10pub enum Stream {
11 Tcp(BufReader<TcpStream>),
12 #[cfg(feature = "tokio-tls")]
13 Tls(BufReader<TlsStream<TcpStream>>),
14 #[cfg(unix)]
15 Unix(BufReader<UnixStream>),
16}
17
18impl Stream {
19 pub fn tcp(stream: TcpStream) -> Self {
20 Self::Tcp(BufReader::new(stream))
21 }
22
23 #[cfg(unix)]
24 pub fn unix(stream: UnixStream) -> Self {
25 Self::Unix(BufReader::new(stream))
26 }
27
28 #[cfg(feature = "tokio-tls")]
29 pub async fn upgrade_to_tls(self, host: &str) -> std::io::Result<Self> {
30 let tcp = match self {
31 Self::Tcp(buf_reader) => buf_reader.into_inner(),
32 #[cfg(feature = "tokio-tls")]
33 Self::Tls(_) => {
34 return Err(std::io::Error::new(
35 std::io::ErrorKind::InvalidInput,
36 "Already using TLS",
37 ));
38 }
39 #[cfg(unix)]
40 Self::Unix(_) => {
41 return Err(std::io::Error::new(
42 std::io::ErrorKind::InvalidInput,
43 "TLS not supported for Unix sockets",
44 ));
45 }
46 };
47
48 let connector = native_tls::TlsConnector::new()
49 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
50 let connector = tokio_native_tls::TlsConnector::from(connector);
51 let tls_stream = connector
52 .connect(host, tcp)
53 .await
54 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
55
56 Ok(Self::Tls(BufReader::new(tls_stream)))
57 }
58
59 pub async fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
60 match self {
61 Self::Tcp(reader) => reader.read_exact(buf).await.map(|_| ()),
62 #[cfg(feature = "tokio-tls")]
63 Self::Tls(reader) => reader.read_exact(buf).await.map(|_| ()),
64 #[cfg(unix)]
65 Self::Unix(reader) => reader.read_exact(buf).await.map(|_| ()),
66 }
67 }
68
69 pub async fn read_buf_exact(&mut self, buf: &mut [MaybeUninit<u8>]) -> std::io::Result<()> {
70 match self {
71 Self::Tcp(reader) => read_buf_exact_impl(reader, buf).await,
72 #[cfg(feature = "tokio-tls")]
73 Self::Tls(reader) => read_buf_exact_impl(reader, buf).await,
74 #[cfg(unix)]
75 Self::Unix(reader) => read_buf_exact_impl(reader, buf).await,
76 }
77 }
78
79 pub async fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
80 match self {
81 Self::Tcp(reader) => reader.get_mut().write_all(buf).await,
82 #[cfg(feature = "tokio-tls")]
83 Self::Tls(reader) => reader.get_mut().write_all(buf).await,
84 #[cfg(unix)]
85 Self::Unix(reader) => reader.get_mut().write_all(buf).await,
86 }
87 }
88
89 pub async fn flush(&mut self) -> std::io::Result<()> {
90 match self {
91 Self::Tcp(reader) => reader.get_mut().flush().await,
92 #[cfg(feature = "tokio-tls")]
93 Self::Tls(reader) => reader.get_mut().flush().await,
94 #[cfg(unix)]
95 Self::Unix(reader) => reader.get_mut().flush().await,
96 }
97 }
98
99 pub fn is_tcp_loopback(&self) -> bool {
101 match self {
102 Self::Tcp(r) => r
103 .get_ref()
104 .peer_addr()
105 .map(|addr| addr.ip().is_loopback())
106 .unwrap_or(false),
107 #[cfg(feature = "tokio-tls")]
108 Self::Tls(r) => r
109 .get_ref()
110 .get_ref()
111 .get_ref()
112 .get_ref()
113 .peer_addr()
114 .map(|addr| addr.ip().is_loopback())
115 .unwrap_or(false),
116 #[cfg(unix)]
117 Self::Unix(_) => false,
118 }
119 }
120}
121
122async fn read_buf_exact_impl<R: AsyncReadExt + Unpin>(
123 reader: &mut R,
124 mut buf: &mut [MaybeUninit<u8>],
125) -> std::io::Result<()> {
126 while !buf.is_empty() {
127 let n = reader.read_buf(&mut buf).await?;
128 if n == 0 {
129 return Err(std::io::Error::new(
130 std::io::ErrorKind::UnexpectedEof,
131 "failed to fill whole buffer",
132 ));
133 }
134 }
135 Ok(())
136}