// zero_mysql/sync/stream.rs
use std::io::{BufReader, Read, Write};
use std::mem::MaybeUninit;
use std::net::{IpAddr, TcpStream};
#[cfg(unix)]
use std::os::unix::net::UnixStream;

#[cfg(feature = "sync-tls")]
use native_tls::TlsStream;

use crate::nightly::read_uninit_exact;
11
/// Byte transport underneath a synchronous connection.
///
/// Every variant wraps its socket in a [`BufReader`], so reads are buffered in
/// user space while the socket itself stays reachable through the reader's
/// `get_ref` / `get_mut` accessors.
pub enum Stream {
    /// Plain, unencrypted TCP connection.
    Tcp(BufReader<TcpStream>),
    /// TLS (via `native-tls`) layered over a TCP connection.
    /// Only compiled in with the `sync-tls` feature.
    #[cfg(feature = "sync-tls")]
    Tls(BufReader<TlsStream<TcpStream>>),
    /// Unix domain socket connection; only available on Unix targets.
    #[cfg(unix)]
    Unix(BufReader<UnixStream>),
}
19
20impl Stream {
21 pub fn tcp(stream: TcpStream) -> Self {
22 Self::Tcp(BufReader::new(stream))
23 }
24
25 #[cfg(unix)]
26 pub fn unix(stream: UnixStream) -> Self {
27 Self::Unix(BufReader::new(stream))
28 }
29
30 #[cfg(feature = "sync-tls")]
31 pub fn upgrade_to_tls(self, host: &str) -> std::io::Result<Self> {
32 let tcp = match self {
33 Self::Tcp(buf_reader) => buf_reader.into_inner(),
34 #[cfg(feature = "sync-tls")]
35 Self::Tls(_) => {
36 return Err(std::io::Error::new(
37 std::io::ErrorKind::InvalidInput,
38 "Already using TLS",
39 ));
40 }
41 #[cfg(unix)]
42 Self::Unix(_) => {
43 return Err(std::io::Error::new(
44 std::io::ErrorKind::InvalidInput,
45 "TLS not supported for Unix sockets",
46 ));
47 }
48 };
49
50 let connector = native_tls::TlsConnector::new()
51 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
52 let tls_stream = connector
53 .connect(host, tcp)
54 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
55
56 Ok(Self::Tls(BufReader::new(tls_stream)))
57 }
58
59 pub fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
60 match self {
61 Self::Tcp(r) => r.read_exact(buf),
62 #[cfg(feature = "sync-tls")]
63 Self::Tls(r) => r.read_exact(buf),
64 #[cfg(unix)]
65 Self::Unix(r) => r.read_exact(buf),
66 }
67 }
68
69 pub fn read_buf_exact(&mut self, buf: &mut [MaybeUninit<u8>]) -> std::io::Result<()> {
70 match self {
71 Self::Tcp(r) => read_uninit_exact(r, buf),
72 #[cfg(feature = "sync-tls")]
73 Self::Tls(r) => read_uninit_exact(r, buf),
74 #[cfg(unix)]
75 Self::Unix(r) => read_uninit_exact(r, buf),
76 }
77 }
78
79 pub fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
80 match self {
81 Self::Tcp(r) => r.get_mut().write_all(buf),
82 #[cfg(feature = "sync-tls")]
83 Self::Tls(r) => r.get_mut().write_all(buf),
84 #[cfg(unix)]
85 Self::Unix(r) => r.get_mut().write_all(buf),
86 }
87 }
88
89 pub fn flush(&mut self) -> std::io::Result<()> {
90 match self {
91 Self::Tcp(r) => r.get_mut().flush(),
92 #[cfg(feature = "sync-tls")]
93 Self::Tls(r) => r.get_mut().flush(),
94 #[cfg(unix)]
95 Self::Unix(r) => r.get_mut().flush(),
96 }
97 }
98
99 pub fn is_tcp_loopback(&self) -> bool {
101 match self {
102 Self::Tcp(r) => r
103 .get_ref()
104 .peer_addr()
105 .map(|addr| addr.ip().is_loopback())
106 .unwrap_or(false),
107 #[cfg(feature = "sync-tls")]
108 Self::Tls(r) => r
109 .get_ref()
110 .get_ref()
111 .peer_addr()
112 .map(|addr| addr.ip().is_loopback())
113 .unwrap_or(false),
114 #[cfg(unix)]
115 Self::Unix(_) => false,
116 }
117 }
118}