1use std::io::{self, Read, Write};
2use std::net::TcpStream;
3use std::os::unix::io::{AsRawFd, RawFd};
4use std::sync::mpsc;
5use std::thread;
6use std::time::Duration;
7
8pub enum Connection {
9 Plain(TcpStream),
10 Tls {
11 tls: rustls::StreamOwned<rustls::ClientConnection, TcpStream>,
12 },
13}
14
15impl Connection {
16 pub fn connect_async(
17 host: String,
18 port: u16,
19 use_tls: bool,
20 ) -> Result<mpsc::Receiver<Result<Self, String>>, String> {
21 Self::connect_async_with_roots(host, port, use_tls, Vec::new())
22 }
23
24 pub fn connect_async_with_roots(
25 host: String,
26 port: u16,
27 use_tls: bool,
28 extra_roots: Vec<Vec<u8>>,
29 ) -> Result<mpsc::Receiver<Result<Self, String>>, String> {
30 let (tx, rx) = mpsc::channel();
31 thread::spawn(move || {
32 let result = Self::connect_blocking(&host, port, use_tls, &extra_roots);
33 let _ = tx.send(result);
34 });
35 Ok(rx)
36 }
37
38 fn connect_blocking(
39 host: &str,
40 port: u16,
41 use_tls: bool,
42 extra_roots: &[Vec<u8>],
43 ) -> Result<Self, String> {
44 let addr = format!("{host}:{port}");
45 let stream =
46 TcpStream::connect(&addr).map_err(|e| format!("connect to {addr} failed: {e}"))?;
47 stream
48 .set_read_timeout(None)
49 .map_err(|e| format!("set_read_timeout failed: {e}"))?;
50 stream
51 .set_write_timeout(None)
52 .map_err(|e| format!("set_write_timeout failed: {e}"))?;
53 stream
54 .set_nonblocking(true)
55 .map_err(|e| format!("set_nonblocking failed: {e}"))?;
56
57 if use_tls {
58 Self::wrap_tls(host, stream, extra_roots)
59 } else {
60 Ok(Connection::Plain(stream))
61 }
62 }
63
64 fn wrap_tls(host: &str, stream: TcpStream, extra_roots: &[Vec<u8>]) -> Result<Self, String> {
65 let mut root_certs =
66 rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
67 for cert_der in extra_roots {
68 root_certs
69 .add(cert_der.clone().into())
70 .map_err(|e| format!("add root cert failed: {e}"))?;
71 }
72 let config = rustls::ClientConfig::builder()
73 .with_root_certificates(root_certs)
74 .with_no_client_auth();
75 let server_name = rustls::pki_types::ServerName::try_from(host)
76 .map_err(|e| format!("invalid server name: {e}"))?
77 .to_owned();
78 let tls_conn = rustls::ClientConnection::new(std::sync::Arc::new(config), server_name)
79 .map_err(|e| format!("tls handshake failed: {e}"))?;
80 let tls = rustls::StreamOwned::new(tls_conn, stream);
81 Ok(Connection::Tls { tls })
82 }
83
84 pub fn set_nonblocking(&self, nonblocking: bool) -> Result<(), String> {
85 match self {
86 Connection::Plain(stream) => stream
87 .set_nonblocking(nonblocking)
88 .map_err(|e| format!("set_nonblocking failed: {e}")),
89 Connection::Tls { tls } => tls
90 .sock
91 .set_nonblocking(nonblocking)
92 .map_err(|e| format!("set_nonblocking failed: {e}")),
93 }
94 }
95
96 pub fn raw_fd(&self) -> RawFd {
97 match self {
98 Connection::Plain(stream) => stream.as_raw_fd(),
99 Connection::Tls { tls } => tls.sock.as_raw_fd(),
100 }
101 }
102
103 pub fn is_tls(&self) -> bool {
104 matches!(self, Connection::Tls { .. })
105 }
106
107 pub fn set_read_timeout(&self, dur: Option<Duration>) -> Result<(), String> {
108 match self {
109 Connection::Plain(stream) => stream
110 .set_read_timeout(dur)
111 .map_err(|e| format!("set_read_timeout failed: {e}")),
112 Connection::Tls { tls } => tls
113 .sock
114 .set_read_timeout(dur)
115 .map_err(|e| format!("set_read_timeout failed: {e}")),
116 }
117 }
118
119 pub fn try_clone(&self) -> Result<Self, String> {
120 match self {
121 Connection::Plain(stream) => stream
122 .try_clone()
123 .map(Connection::Plain)
124 .map_err(|e| format!("try_clone failed: {e}")),
125 Connection::Tls { .. } => Err("cannot clone TLS connection".into()),
126 }
127 }
128}
129
130impl Read for Connection {
131 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
132 match self {
133 Connection::Plain(stream) => stream.read(buf),
134 Connection::Tls { tls } => tls.read(buf),
135 }
136 }
137}
138
139impl Write for Connection {
140 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
141 match self {
142 Connection::Plain(stream) => stream.write(buf),
143 Connection::Tls { tls } => tls.write(buf),
144 }
145 }
146
147 fn flush(&mut self) -> io::Result<()> {
148 match self {
149 Connection::Plain(stream) => stream.flush(),
150 Connection::Tls { tls } => tls.sock.flush(),
151 }
152 }
153}
154
155impl AsRawFd for Connection {
156 fn as_raw_fd(&self) -> RawFd {
157 self.raw_fd()
158 }
159}