zero_mysql/sync/
stream.rs1use core::io::BorrowedCursor;
2use std::io::{BufReader, Read, Write};
3use std::net::TcpStream;
4use std::os::unix::net::UnixStream;
5
6#[cfg(feature = "sync-tls")]
7use native_tls::TlsStream;
8
/// Byte transport used by the synchronous client.
///
/// Every variant wraps its socket in a [`BufReader`], so reads are buffered. The
/// `BufReader` does not buffer writes.
#[derive(Debug)]
pub enum Stream {
    /// Plain TCP connection.
    Tcp(BufReader<TcpStream>),
    /// TLS over TCP, produced by [`Stream::upgrade_to_tls`].
    #[cfg(feature = "sync-tls")]
    Tls(BufReader<TlsStream<TcpStream>>),
    /// Unix domain socket connection.
    Unix(BufReader<UnixStream>),
}
15
16impl Stream {
17 pub fn tcp(stream: TcpStream) -> Self {
18 Self::Tcp(BufReader::new(stream))
19 }
20
21 pub fn unix(stream: UnixStream) -> Self {
22 Self::Unix(BufReader::new(stream))
23 }
24
25 #[cfg(feature = "sync-tls")]
26 pub fn upgrade_to_tls(self, host: &str) -> std::io::Result<Self> {
27 let tcp = match self {
28 Self::Tcp(buf_reader) => buf_reader.into_inner(),
29 #[cfg(feature = "sync-tls")]
30 Self::Tls(_) => {
31 return Err(std::io::Error::new(
32 std::io::ErrorKind::InvalidInput,
33 "Already using TLS",
34 ));
35 }
36 Self::Unix(_) => {
37 return Err(std::io::Error::new(
38 std::io::ErrorKind::InvalidInput,
39 "TLS not supported for Unix sockets",
40 ));
41 }
42 };
43
44 let connector = native_tls::TlsConnector::new()
45 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
46 let tls_stream = connector
47 .connect(host, tcp)
48 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
49
50 Ok(Self::Tls(BufReader::new(tls_stream)))
51 }
52
53 pub fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
54 match self {
55 Self::Tcp(r) => r.read_exact(buf),
56 #[cfg(feature = "sync-tls")]
57 Self::Tls(r) => r.read_exact(buf),
58 Self::Unix(r) => r.read_exact(buf),
59 }
60 }
61
62 pub fn read_buf_exact(&mut self, cursor: BorrowedCursor<'_>) -> std::io::Result<()> {
63 match self {
64 Self::Tcp(r) => r.read_buf_exact(cursor),
65 #[cfg(feature = "sync-tls")]
66 Self::Tls(r) => r.read_buf_exact(cursor),
67 Self::Unix(r) => r.read_buf_exact(cursor),
68 }
69 }
70
71 pub fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
72 match self {
73 Self::Tcp(r) => r.get_mut().write_all(buf),
74 #[cfg(feature = "sync-tls")]
75 Self::Tls(r) => r.get_mut().write_all(buf),
76 Self::Unix(r) => r.get_mut().write_all(buf),
77 }
78 }
79
80 pub fn flush(&mut self) -> std::io::Result<()> {
81 match self {
82 Self::Tcp(r) => r.get_mut().flush(),
83 #[cfg(feature = "sync-tls")]
84 Self::Tls(r) => r.get_mut().flush(),
85 Self::Unix(r) => r.get_mut().flush(),
86 }
87 }
88
89 pub fn is_tcp_loopback(&self) -> bool {
91 match self {
92 Self::Tcp(r) => r
93 .get_ref()
94 .peer_addr()
95 .map(|addr| addr.ip().is_loopback())
96 .unwrap_or(false),
97 #[cfg(feature = "sync-tls")]
98 Self::Tls(r) => r
99 .get_ref()
100 .get_ref()
101 .peer_addr()
102 .map(|addr| addr.ip().is_loopback())
103 .unwrap_or(false),
104 Self::Unix(_) => false,
105 }
106 }
107}