1use std::io::{self, ErrorKind, Read, Write};
2use std::net::{SocketAddr, TcpStream};
3use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
4use std::sync::Arc;
5
6#[derive(Debug)]
7pub enum Connection {
8 Plain(TcpStream),
9 Tls {
10 tls: rustls::ClientConnection,
11 stream: TcpStream,
12 },
13}
14
/// Readiness a caller should wait for before driving a connection's TLS handshake again.
///
/// Produced by `Connection::tls_handshake_step`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum IoHint {
    /// Wait until the socket is readable.
    Read,
    /// Wait until the socket is writable.
    Write,
    /// Both a read and a write are pending; wait on both readiness events.
    ReadWrite,
    /// Nothing to wait for: the connection is plain or its handshake has finished.
    Ready,
}
22
23impl Connection {
24 pub fn connect_nonblocking(host: &str, port: u16) -> Result<TcpStream, String> {
25 let addr = format!("{host}:{port}");
26 let addrs: Vec<SocketAddr> = std::net::ToSocketAddrs::to_socket_addrs(&addr)
27 .map_err(|e| format!("dns resolve failed for {host}:{port}: {e}"))?
28 .collect();
29
30 for addr in addrs {
31 let domain = if addr.is_ipv6() {
32 libc::AF_INET6
33 } else {
34 libc::AF_INET
35 };
36 let sock = unsafe {
37 libc::socket(
38 domain,
39 libc::SOCK_STREAM | libc::SOCK_NONBLOCK | libc::SOCK_CLOEXEC,
40 0,
41 )
42 };
43 if sock < 0 {
44 continue;
45 }
46
47 let (addr_ptr, addr_len) = socket_addr_to_raw(&addr);
48 let ret = unsafe { libc::connect(sock, addr_ptr, addr_len) };
49
50 if ret == 0 {
51 return Ok(unsafe { TcpStream::from_raw_fd(sock) });
52 }
53
54 let errno = unsafe { *libc::__errno_location() };
55 if errno == libc::EINPROGRESS {
56 return Ok(unsafe { TcpStream::from_raw_fd(sock) });
57 }
58
59 unsafe { libc::close(sock) };
60 }
61
62 Err(format!("connect to {host}:{port} failed"))
63 }
64
65 pub fn check_connect(stream: &TcpStream) -> Result<(), String> {
66 let mut err: i32 = 0;
67 let mut err_len: u32 = std::mem::size_of::<i32>() as u32;
68 let ret = unsafe {
69 libc::getsockopt(
70 stream.as_raw_fd(),
71 libc::SOL_SOCKET,
72 libc::SO_ERROR,
73 &mut err as *mut _ as *mut _,
74 &mut err_len,
75 )
76 };
77 if ret < 0 {
78 return Err("getsockopt failed".into());
79 }
80 if err != 0 {
81 return Err(format!("connect failed: errno {err}"));
82 }
83 Ok(())
84 }
85
86 pub fn start_tls(
87 host: &str,
88 stream: TcpStream,
89 extra_roots: &[Vec<u8>],
90 ) -> Result<Self, String> {
91 let mut root_certs =
92 rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
93 for cert_der in extra_roots {
94 root_certs
95 .add(cert_der.clone().into())
96 .map_err(|e| format!("add root cert failed: {e}"))?;
97 }
98 let config = rustls::ClientConfig::builder()
99 .with_root_certificates(root_certs)
100 .with_no_client_auth();
101 let server_name = rustls::pki_types::ServerName::try_from(host)
102 .map_err(|e| format!("invalid server name: {e}"))?
103 .to_owned();
104 let tls_conn = rustls::ClientConnection::new(Arc::new(config), server_name)
105 .map_err(|e| format!("tls init failed: {e}"))?;
106 Ok(Connection::Tls {
107 tls: tls_conn,
108 stream,
109 })
110 }
111
112 pub fn tls_handshake_step(&mut self) -> Result<IoHint, String> {
113 match self {
114 Connection::Plain(_) => Ok(IoHint::Ready),
115 Connection::Tls { tls, stream } => {
116 if !tls.is_handshaking() {
117 return Ok(IoHint::Ready);
118 }
119
120 let mut need_read = false;
121 let mut need_write = false;
122
123 if tls.wants_read() {
124 match tls.read_tls(stream) {
125 Ok(_) => {
126 tls.process_new_packets()
127 .map_err(|e| format!("tls process error: {e}"))?;
128 }
129 Err(e) if e.kind() == ErrorKind::WouldBlock => {
130 need_read = true;
131 }
132 Err(e) => return Err(format!("tls read error: {e}")),
133 }
134 }
135
136 if tls.wants_write() {
137 match tls.write_tls(stream) {
138 Ok(_) => {}
139 Err(e) if e.kind() == ErrorKind::WouldBlock => {
140 need_write = true;
141 }
142 Err(e) => return Err(format!("tls write error: {e}")),
143 }
144 }
145
146 if !tls.is_handshaking() {
147 Ok(IoHint::Ready)
148 } else {
149 match (need_read, need_write) {
150 (true, true) => Ok(IoHint::ReadWrite),
151 (true, false) => Ok(IoHint::Read),
152 (false, true) => Ok(IoHint::Write),
153 (false, false) => {
154 if tls.wants_read() && tls.wants_write() {
155 Ok(IoHint::ReadWrite)
156 } else if tls.wants_read() {
157 Ok(IoHint::Read)
158 } else {
159 Ok(IoHint::Write)
160 }
161 }
162 }
163 }
164 }
165 }
166 }
167
168 pub fn tls_wants_read(&self) -> bool {
169 match self {
170 Connection::Plain(_) => false,
171 Connection::Tls { tls, .. } => tls.wants_read() || tls.is_handshaking(),
172 }
173 }
174
175 pub fn tls_wants_write(&self) -> bool {
176 match self {
177 Connection::Plain(_) => false,
178 Connection::Tls { tls, .. } => tls.wants_write() || tls.is_handshaking(),
179 }
180 }
181
182 pub fn set_nonblocking(&self, nonblocking: bool) -> Result<(), String> {
183 match self {
184 Connection::Plain(stream) => stream
185 .set_nonblocking(nonblocking)
186 .map_err(|e| format!("set_nonblocking failed: {e}")),
187 Connection::Tls { stream, .. } => stream
188 .set_nonblocking(nonblocking)
189 .map_err(|e| format!("set_nonblocking failed: {e}")),
190 }
191 }
192
193 pub fn raw_fd(&self) -> RawFd {
194 match self {
195 Connection::Plain(stream) => stream.as_raw_fd(),
196 Connection::Tls { stream, .. } => stream.as_raw_fd(),
197 }
198 }
199
200 pub fn is_tls(&self) -> bool {
201 matches!(self, Connection::Tls { .. })
202 }
203
204 pub fn set_read_timeout(&self, dur: Option<std::time::Duration>) -> Result<(), String> {
205 let stream = match self {
206 Connection::Plain(s) => s,
207 Connection::Tls { stream, .. } => stream,
208 };
209 stream
210 .set_read_timeout(dur)
211 .map_err(|e| format!("set_read_timeout failed: {e}"))
212 }
213}
214
215impl Read for Connection {
216 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
217 match self {
218 Connection::Plain(stream) => stream.read(buf),
219 Connection::Tls { tls, stream } => loop {
220 match tls.read_tls(stream) {
221 Ok(0) => {
222 tls.process_new_packets()
223 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
224 return tls.reader().read(buf);
225 }
226 Ok(_) => {
227 tls.process_new_packets()
228 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
229 match tls.reader().read(buf) {
230 Ok(n) => return Ok(n),
231 Err(e) if e.kind() == ErrorKind::WouldBlock => continue,
232 Err(e) => return Err(e),
233 }
234 }
235 Err(e) if e.kind() == ErrorKind::WouldBlock => match tls.reader().read(buf) {
236 Ok(n) => return Ok(n),
237 Err(e2) if e2.kind() == ErrorKind::WouldBlock => return Err(e),
238 Err(e2) => return Err(e2),
239 },
240 Err(e) => return Err(e),
241 }
242 },
243 }
244 }
245}
246
247impl Write for Connection {
248 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
249 match self {
250 Connection::Plain(stream) => stream.write(buf),
251 Connection::Tls { tls, stream } => {
252 let n = tls.writer().write(buf)?;
253 let _ = tls.write_tls(stream);
254 Ok(n)
255 }
256 }
257 }
258
259 fn flush(&mut self) -> io::Result<()> {
260 match self {
261 Connection::Plain(stream) => stream.flush(),
262 Connection::Tls { stream, .. } => stream.flush(),
263 }
264 }
265}
266
267impl AsRawFd for Connection {
268 fn as_raw_fd(&self) -> RawFd {
269 self.raw_fd()
270 }
271}
272
273fn socket_addr_to_raw(addr: &SocketAddr) -> (*const libc::sockaddr, u32) {
274 match addr {
275 SocketAddr::V4(v4) => {
276 let raw: libc::sockaddr_in = libc::sockaddr_in {
277 sin_family: libc::AF_INET as u16,
278 sin_port: v4.port().to_be(),
279 sin_addr: libc::in_addr {
280 s_addr: u32::from_ne_bytes(v4.ip().octets()),
281 },
282 sin_zero: [0; 8],
283 };
284 (
285 &raw as *const _ as *const libc::sockaddr,
286 std::mem::size_of::<libc::sockaddr_in>() as u32,
287 )
288 }
289 SocketAddr::V6(v6) => {
290 let raw = libc::sockaddr_in6 {
291 sin6_family: libc::AF_INET6 as u16,
292 sin6_port: v6.port().to_be(),
293 sin6_flowinfo: v6.flowinfo(),
294 sin6_addr: libc::in6_addr {
295 s6_addr: v6.ip().octets(),
296 },
297 sin6_scope_id: v6.scope_id(),
298 };
299 (
300 &raw as *const _ as *const libc::sockaddr,
301 std::mem::size_of::<libc::sockaddr_in6>() as u32,
302 )
303 }
304 }
305}