1use std::{
24 borrow::Cow,
25 os::{
26 fd::{AsFd, BorrowedFd, OwnedFd},
27 unix::io::AsRawFd,
28 },
29 time::Duration,
30};
31
32use flate2::{read::DeflateDecoder, write::DeflateEncoder, Compression};
33#[cfg(any(target_os = "openbsd", target_os = "netbsd", target_os = "haiku"))]
34use libc::SO_KEEPALIVE as KEEPALIVE_OPTION;
35#[cfg(any(target_os = "macos", target_os = "ios"))]
36use libc::TCP_KEEPALIVE as KEEPALIVE_OPTION;
37#[cfg(not(any(
38 target_os = "openbsd",
39 target_os = "netbsd",
40 target_os = "haiku",
41 target_os = "macos",
42 target_os = "ios"
43)))]
44use libc::TCP_KEEPIDLE as KEEPALIVE_OPTION;
45use libc::{self, c_int, c_void};
46
47pub mod std_net;
49
/// Delay between successive connection attempts when racing several resolved addresses.
///
/// The name and value match the "Connection Attempt Delay" of Happy Eyeballs v2 (RFC 8305),
/// whose recommended default is 250 ms; presumably consumed by the `std_net`/`smol` connect
/// helpers — verify against those modules.
pub const CONNECTION_ATTEMPT_DELAY: Duration = Duration::from_millis(250);
51
/// A bidirectional byte stream used for I/O over a socket or an arbitrary file descriptor.
///
/// Every variant carries the same two bookkeeping fields:
/// - `id`: an optional static label; when set, trace log lines are prefixed with `[id]: `.
/// - `trace`: when `true`, operations on this connection are logged with `log::trace!`.
///
/// The wrapping variants (`Tls`, `Deflate`) hold another `Connection`. Their constructors
/// (`new_tls`, `deflate`) move the trace flag and id to the outer value and switch tracing off
/// on the wrapped one.
pub enum Connection {
    /// A plain TCP socket.
    Tcp {
        inner: std::net::TcpStream,
        id: Option<&'static str>,
        trace: bool,
    },
    /// An arbitrary owned file descriptor, not necessarily a socket. `set_read_timeout`,
    /// `set_write_timeout` and `set_keepalive` are no-ops for this variant.
    Fd {
        inner: OwnedFd,
        id: Option<&'static str>,
        trace: bool,
    },
    /// A TLS session layered over another `Connection` (only with the `tls` feature).
    #[cfg(feature = "tls")]
    Tls {
        inner: native_tls::TlsStream<Self>,
        id: Option<&'static str>,
        trace: bool,
    },
    /// A DEFLATE-compressed stream over another (boxed) `Connection`. Writes are compressed by
    /// the outer encoder and reads are decompressed by the inner decoder.
    Deflate {
        inner: DeflateEncoder<DeflateDecoder<Box<Self>>>,
        id: Option<&'static str>,
        trace: bool,
    },
}
75
76impl std::fmt::Debug for Connection {
77 fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
78 match self {
79 Tcp {
80 ref trace,
81 ref inner,
82 ref id,
83 } => fmt
84 .debug_struct(crate::identify!(Connection))
85 .field("variant", &stringify!(Tcp))
86 .field(stringify!(trace), trace)
87 .field(stringify!(id), id)
88 .field(stringify!(inner), inner)
89 .finish(),
90 #[cfg(feature = "tls")]
91 Tls {
92 ref trace,
93 ref inner,
94 ref id,
95 } => fmt
96 .debug_struct(crate::identify!(Connection))
97 .field("variant", &stringify!(Tls))
98 .field(stringify!(trace), trace)
99 .field(stringify!(id), id)
100 .field(stringify!(inner), inner.get_ref())
101 .finish(),
102 Fd {
103 ref trace,
104 ref inner,
105 ref id,
106 } => fmt
107 .debug_struct(crate::identify!(Connection))
108 .field("variant", &stringify!(Fd))
109 .field(stringify!(trace), trace)
110 .field(stringify!(id), id)
111 .field(stringify!(inner), inner)
112 .finish(),
113 Deflate {
114 ref trace,
115 ref inner,
116 ref id,
117 } => fmt
118 .debug_struct(crate::identify!(Connection))
119 .field("variant", &stringify!(Deflate))
120 .field(stringify!(trace), trace)
121 .field(stringify!(id), id)
122 .field(stringify!(inner), inner)
123 .finish(),
124 }
125 }
126}
127
128use Connection::*;
129
/// Invokes `libc::$fn(args...)` and converts the C convention of returning `-1` on failure into
/// a `std::io::Result`.
///
/// On failure the error is built from `errno` with `std::io::Error::last_os_error()` right after
/// the call. Any other return value is passed through as `Ok`.
macro_rules! syscall {
    ($fn: ident ( $($arg: expr),* $(,)* ) ) => {{
        #[allow(unused_unsafe)]
        let ret = unsafe { libc::$fn($($arg, )*) };
        match ret {
            -1 => Err(std::io::Error::last_os_error()),
            _ => Ok(ret),
        }
    }};
}
141
/// Socket options that can be applied with `Connection::setsockopt`.
pub enum SockOpts {
    /// TCP keep-alive probing.
    KeepAlive {
        /// Turns `SO_KEEPALIVE` on (`true`) or off (`false`).
        enable: bool,
        /// Only used when `enable` is `true`. If `Some`, the platform's `KEEPALIVE_OPTION` is
        /// also set at the TCP level, truncated to whole seconds. On Linux this is
        /// `TCP_KEEPIDLE`, the idle time before the first probe.
        duration: Option<Duration>,
    },
    /// Control of `TCP_NODELAY`.
    TcpNoDelay {
        /// `true` sets `TCP_NODELAY`, disabling Nagle's algorithm so small writes are sent
        /// without being coalesced; `false` clears it.
        enable: bool,
    },
}
201
202impl Connection {
203 pub const IO_BUF_SIZE: usize = 64 * 1024;
204
205 pub fn deflate(mut self) -> Self {
206 let trace = self.is_trace_enabled();
207 let id = self.id();
208 self.set_trace(false);
209 Self::Deflate {
210 inner: DeflateEncoder::new(
211 DeflateDecoder::new_with_buf(Box::new(self), vec![0; Self::IO_BUF_SIZE]),
212 Compression::default(),
213 ),
214 id,
215 trace,
216 }
217 }
218
219 #[cfg(feature = "tls")]
220 pub fn new_tls(mut inner: native_tls::TlsStream<Self>) -> Self {
221 let trace = inner.get_ref().is_trace_enabled();
222 let id = inner.get_ref().id();
223 if trace {
224 inner.get_mut().set_trace(false);
225 }
226 Self::Tls { inner, id, trace }
227 }
228
229 pub fn new_tcp(inner: std::net::TcpStream) -> Self {
230 let ret = Self::Tcp {
231 inner,
232 id: None,
233 trace: false,
234 };
235 _ = ret.setsockopt(SockOpts::TcpNoDelay { enable: true });
236
237 ret
238 }
239
240 pub fn trace(mut self, val: bool) -> Self {
241 match self {
242 Tcp { ref mut trace, .. } => *trace = val,
243 #[cfg(feature = "tls")]
244 Tls { ref mut trace, .. } => *trace = val,
245 Fd { ref mut trace, .. } => {
246 *trace = val;
247 }
248 Deflate { ref mut trace, .. } => *trace = val,
249 }
250 self
251 }
252
253 pub fn with_id(mut self, val: &'static str) -> Self {
254 match self {
255 Tcp { ref mut id, .. } => *id = Some(val),
256 #[cfg(feature = "tls")]
257 Tls { ref mut id, .. } => *id = Some(val),
258 Fd { ref mut id, .. } => {
259 *id = Some(val);
260 }
261 Deflate { ref mut id, .. } => *id = Some(val),
262 }
263 self
264 }
265
266 pub fn set_trace(&mut self, val: bool) {
267 match self {
268 Tcp { ref mut trace, .. } => *trace = val,
269 #[cfg(feature = "tls")]
270 Tls { ref mut trace, .. } => *trace = val,
271 Fd { ref mut trace, .. } => {
272 *trace = val;
273 }
274 Deflate { ref mut trace, .. } => *trace = val,
275 }
276 }
277
278 pub fn set_nonblocking(&self, nonblocking: bool) -> std::io::Result<()> {
279 if self.is_trace_enabled() {
280 let id = self.id();
281 log::trace!(
282 "{}{}{}{:?} set_nonblocking({:?})",
283 if id.is_some() { "[" } else { "" },
284 if let Some(id) = id.as_ref() { id } else { "" },
285 if id.is_some() { "]: " } else { "" },
286 self,
287 nonblocking
288 );
289 }
290 match self {
291 Tcp { ref inner, .. } => inner.set_nonblocking(nonblocking),
292 #[cfg(feature = "tls")]
293 Tls { ref inner, .. } => inner.get_ref().set_nonblocking(nonblocking),
294 Fd { inner, .. } => {
295 nix::fcntl::fcntl(
297 inner.as_raw_fd(),
298 nix::fcntl::FcntlArg::F_SETFL(if nonblocking {
299 nix::fcntl::OFlag::O_NONBLOCK
300 } else {
301 !nix::fcntl::OFlag::O_NONBLOCK
302 }),
303 )
304 .map_err(|err| std::io::Error::from_raw_os_error(err as i32))?;
305 Ok(())
306 }
307 Deflate { ref inner, .. } => inner.get_ref().get_ref().set_nonblocking(nonblocking),
308 }
309 }
310
311 pub fn set_read_timeout(&self, dur: Option<Duration>) -> std::io::Result<()> {
312 if self.is_trace_enabled() {
313 let id = self.id();
314 log::trace!(
315 "{}{}{}{:?} set_read_timeout({:?})",
316 if id.is_some() { "[" } else { "" },
317 if let Some(id) = id.as_ref() { id } else { "" },
318 if id.is_some() { "]: " } else { "" },
319 self,
320 dur
321 );
322 }
323 match self {
324 Tcp { ref inner, .. } => inner.set_read_timeout(dur),
325 #[cfg(feature = "tls")]
326 Tls { ref inner, .. } => inner.get_ref().set_read_timeout(dur),
327 Fd { .. } => Ok(()),
328 Deflate { ref inner, .. } => inner.get_ref().get_ref().set_read_timeout(dur),
329 }
330 }
331
332 pub fn set_write_timeout(&self, dur: Option<Duration>) -> std::io::Result<()> {
333 if self.is_trace_enabled() {
334 let id = self.id();
335 log::trace!(
336 "{}{}{}{:?} set_write_timeout({:?})",
337 if id.is_some() { "[" } else { "" },
338 if let Some(id) = id.as_ref() { id } else { "" },
339 if id.is_some() { "]: " } else { "" },
340 self,
341 dur
342 );
343 }
344 match self {
345 Tcp { ref inner, .. } => inner.set_write_timeout(dur),
346 #[cfg(feature = "tls")]
347 Tls { ref inner, .. } => inner.get_ref().set_write_timeout(dur),
348 Fd { .. } => Ok(()),
349 Deflate { ref inner, .. } => inner.get_ref().get_ref().set_write_timeout(dur),
350 }
351 }
352
353 pub fn keepalive(&self) -> std::io::Result<Option<Duration>> {
354 if self.is_trace_enabled() {
355 log::trace!("{:?} keepalive()", self);
356 }
357 if matches!(self, Fd { .. }) {
358 return Ok(None);
359 }
360 unsafe {
361 let raw: c_int = self.__getsockopt(libc::SOL_SOCKET, libc::SO_KEEPALIVE)?;
362 if raw == 0 {
363 return Ok(None);
364 }
365 let secs: c_int = self.__getsockopt(libc::IPPROTO_TCP, KEEPALIVE_OPTION)?;
366 Ok(Some(Duration::new(secs as u64, 0)))
367 }
368 }
369
370 pub fn set_keepalive(&self, keepalive: Option<Duration>) -> std::io::Result<()> {
371 if self.is_trace_enabled() {
372 let id = self.id();
373 log::trace!(
374 "{}{}{}{:?} set_keepalive({:?})",
375 if id.is_some() { "[" } else { "" },
376 if let Some(id) = id.as_ref() { id } else { "" },
377 if id.is_some() { "]: " } else { "" },
378 self,
379 keepalive
380 );
381 }
382 if matches!(self, Fd { .. }) {
383 return Ok(());
384 }
385 self.setsockopt(SockOpts::KeepAlive {
386 enable: keepalive.is_some(),
387 duration: keepalive,
388 })
389 }
390
391 unsafe fn inner_setsockopt<T>(&self, opt: c_int, val: c_int, payload: T) -> std::io::Result<()>
392 where
393 T: Copy,
394 {
395 let payload = std::ptr::addr_of!(payload) as *const c_void;
396 syscall!(setsockopt(
397 self.as_raw_fd(),
398 opt,
399 val,
400 payload,
401 std::mem::size_of::<T>() as libc::socklen_t,
402 ))?;
403 Ok(())
404 }
405
406 pub fn setsockopt(&self, option: SockOpts) -> std::io::Result<()> {
407 match option {
408 SockOpts::KeepAlive {
409 enable: true,
410 duration,
411 } => {
412 unsafe {
413 self.inner_setsockopt(libc::SOL_SOCKET, libc::SO_KEEPALIVE, <c_int>::from(true))
414 }?;
415 if let Some(dur) = duration {
416 unsafe {
417 self.inner_setsockopt(
418 libc::IPPROTO_TCP,
419 KEEPALIVE_OPTION,
420 dur.as_secs() as c_int,
421 )
422 }?;
423 }
424 Ok(())
425 }
426 SockOpts::KeepAlive {
427 enable: false,
428 duration: _,
429 } => unsafe {
430 self.inner_setsockopt(libc::SOL_SOCKET, libc::SO_KEEPALIVE, <c_int>::from(false))
431 },
432 SockOpts::TcpNoDelay { enable } => unsafe {
433 #[cfg(any(
434 target_os = "openbsd",
435 target_os = "netbsd",
436 target_os = "haiku",
437 target_os = "macos",
438 target_os = "ios"
439 ))]
440 {
441 self.inner_setsockopt(
442 libc::IPPROTO_TCP,
443 libc::TCP_NODELAY,
444 if enable { c_int::from(1_u8) } else { 0 },
445 )
446 }
447 #[cfg(not(any(
448 target_os = "openbsd",
449 target_os = "netbsd",
450 target_os = "haiku",
451 target_os = "macos",
452 target_os = "ios"
453 )))]
454 {
455 self.inner_setsockopt(
456 libc::SOL_TCP,
457 libc::TCP_NODELAY,
458 if enable { c_int::from(1_u8) } else { 0 },
459 )
460 }
461 },
462 }
463 }
464
465 #[inline]
466 unsafe fn __getsockopt<T: Copy>(&self, opt: c_int, val: c_int) -> std::io::Result<T> {
467 let mut slot: T = unsafe { std::mem::zeroed() };
468 let mut len = std::mem::size_of::<T>() as libc::socklen_t;
469 syscall!(getsockopt(
470 self.as_raw_fd(),
471 opt,
472 val,
473 std::ptr::addr_of_mut!(slot) as *mut _,
474 &mut len,
475 ))?;
476 assert_eq!(len as usize, std::mem::size_of::<T>());
477 Ok(slot)
478 }
479
480 fn is_trace_enabled(&self) -> bool {
481 match self {
482 Fd { trace, .. } | Tcp { trace, .. } => *trace,
483 #[cfg(feature = "tls")]
484 Tls { trace, .. } => *trace,
485 Deflate { trace, .. } => *trace,
486 }
487 }
488
489 fn id(&self) -> Option<&'static str> {
490 match self {
491 Fd { id, .. } | Tcp { id, .. } => *id,
492 #[cfg(feature = "tls")]
493 Tls { id, .. } => *id,
494 Deflate { id, .. } => *id,
495 }
496 }
497}
498
499impl std::io::Read for Connection {
500 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
501 let res = match self {
502 Tcp { ref mut inner, .. } => inner.read(buf),
503 #[cfg(feature = "tls")]
504 Tls { ref mut inner, .. } => inner.read(buf),
505 Fd { ref inner, .. } => {
506 use std::os::unix::io::{FromRawFd, IntoRawFd};
507 let mut f = unsafe { std::fs::File::from_raw_fd(inner.as_raw_fd()) };
508 let ret = f.read(buf);
509 let _ = f.into_raw_fd();
510 ret
511 }
512 Deflate { ref mut inner, .. } => inner.read(buf),
513 };
514 if self.is_trace_enabled() {
515 let id = self.id();
516 match &res {
517 Ok(len) => {
518 let slice = &buf[..*len];
519 log::trace!(
520 "{}{}{}{:?} read {:?} bytes:{}",
521 if id.is_some() { "[" } else { "" },
522 if let Some(id) = id.as_ref() { id } else { "" },
523 if id.is_some() { "]: " } else { "" },
524 self,
525 len,
526 std::str::from_utf8(slice)
527 .map(Cow::Borrowed)
528 .or_else(|_| crate::text::hex::bytes_to_hex(slice).map(Cow::Owned))
529 .unwrap_or(Cow::Borrowed("Could not convert to hex."))
530 );
531 }
532 Err(err) if matches!(err.kind(), std::io::ErrorKind::WouldBlock) => {}
533 Err(err) => {
534 log::trace!(
535 "{}{}{}{:?} could not read {:?}",
536 if id.is_some() { "[" } else { "" },
537 if let Some(id) = id.as_ref() { id } else { "" },
538 if id.is_some() { "]: " } else { "" },
539 self,
540 err,
541 );
542 }
543 }
544 }
545 res
546 }
547}
548
549impl std::io::Write for Connection {
550 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
551 if self.is_trace_enabled() {
552 let id = self.id();
553 log::trace!(
554 "{}{}{}{:?} writing {} bytes:{}",
555 if id.is_some() { "[" } else { "" },
556 if let Some(id) = id.as_ref() { id } else { "" },
557 if id.is_some() { "]: " } else { "" },
558 self,
559 buf.len(),
560 std::str::from_utf8(buf)
561 .map(Cow::Borrowed)
562 .or_else(|_| crate::text::hex::bytes_to_hex(buf).map(Cow::Owned))
563 .unwrap_or(Cow::Borrowed("Could not convert to hex."))
564 );
565 }
566 match self {
567 Tcp { ref mut inner, .. } => inner.write(buf),
568 #[cfg(feature = "tls")]
569 Tls { ref mut inner, .. } => inner.write(buf),
570 Fd { ref inner, .. } => {
571 use std::os::unix::io::{FromRawFd, IntoRawFd};
572 let mut f = unsafe { std::fs::File::from_raw_fd(inner.as_raw_fd()) };
573 let ret = f.write(buf);
574 let _ = f.into_raw_fd();
575 ret
576 }
577 Deflate { ref mut inner, .. } => inner.write(buf),
578 }
579 }
580
581 fn flush(&mut self) -> std::io::Result<()> {
582 match self {
583 Tcp { ref mut inner, .. } => inner.flush(),
584 #[cfg(feature = "tls")]
585 Tls { ref mut inner, .. } => inner.flush(),
586 Fd { ref inner, .. } => {
587 use std::os::unix::io::{FromRawFd, IntoRawFd};
588 let mut f = unsafe { std::fs::File::from_raw_fd(inner.as_raw_fd()) };
589 let ret = f.flush();
590 let _ = f.into_raw_fd();
591 ret
592 }
593 Deflate { ref mut inner, .. } => inner.flush(),
594 }
595 }
596}
597
598impl std::os::unix::io::AsRawFd for Connection {
599 fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
600 match self {
601 Tcp { ref inner, .. } => inner.as_raw_fd(),
602 #[cfg(feature = "tls")]
603 Tls { ref inner, .. } => inner.get_ref().as_raw_fd(),
604 Fd { ref inner, .. } => inner.as_raw_fd(),
605 Deflate { ref inner, .. } => inner.get_ref().get_ref().as_raw_fd(),
606 }
607 }
608}
609
610impl AsFd for Connection {
611 fn as_fd(&'_ self) -> BorrowedFd<'_> {
612 match self {
613 Tcp { ref inner, .. } => inner.as_fd(),
614 #[cfg(feature = "tls")]
615 Tls { ref inner, .. } => inner.get_ref().as_fd(),
616 Fd { ref inner, .. } => inner.as_fd(),
617 Deflate { ref inner, .. } => inner.get_ref().get_ref().as_fd(),
618 }
619 }
620}
621
// SAFETY: `IoSafe` requires that the type's I/O trait implementations never drop or replace the
// underlying I/O source. `Connection`'s `Read`/`Write` impls only borrow their transport. For
// `Fd`, a temporary `File` is built from the raw descriptor and is never allowed to close it.
// No variant swaps out its `inner` field while performing I/O.
unsafe impl async_io::IoSafe for Connection {}
623
624#[deprecated = "While it supports IPv6, it does not implement the happy eyeballs algorithm. Use \
625 {std_net,smol}::tcp_stream_connect instead."]
626pub fn lookup_ip(host: &str, port: u16) -> crate::Result<std::net::SocketAddr> {
627 use std::net::ToSocketAddrs;
628
629 use crate::error::{Error, ErrorKind, NetworkErrorKind};
630
631 let addrs = (host, port).to_socket_addrs()?;
632 for addr in addrs {
633 if matches!(
634 addr,
635 std::net::SocketAddr::V4(_) | std::net::SocketAddr::V6(_)
636 ) {
637 return Ok(addr);
638 }
639 }
640
641 Err(
642 Error::new(format!("Could not lookup address {}:{}", host, port))
643 .set_kind(ErrorKind::Network(NetworkErrorKind::HostLookupFailed)),
644 )
645}