melib/utils/
connections.rs

1/*
2 * meli - melib library
3 *
4 * Copyright 2020  Manos Pitsidianakis
5 *
6 * This file is part of meli.
7 *
8 * meli is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * meli is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with meli. If not, see <http://www.gnu.org/licenses/>.
20 */
21
//! Connection layers (TCP/fd/TLS/Deflate) to use with remote backends.
23use std::{
24    borrow::Cow,
25    os::{
26        fd::{AsFd, BorrowedFd, OwnedFd},
27        unix::io::AsRawFd,
28    },
29    time::Duration,
30};
31
32use flate2::{read::DeflateDecoder, write::DeflateEncoder, Compression};
33#[cfg(any(target_os = "openbsd", target_os = "netbsd", target_os = "haiku"))]
34use libc::SO_KEEPALIVE as KEEPALIVE_OPTION;
35#[cfg(any(target_os = "macos", target_os = "ios"))]
36use libc::TCP_KEEPALIVE as KEEPALIVE_OPTION;
37#[cfg(not(any(
38    target_os = "openbsd",
39    target_os = "netbsd",
40    target_os = "haiku",
41    target_os = "macos",
42    target_os = "ios"
43)))]
44use libc::TCP_KEEPIDLE as KEEPALIVE_OPTION;
45use libc::{self, c_int, c_void};
46
47// pub mod smol;
48pub mod std_net;
49
/// Delay between successive connection attempts (250 milliseconds).
pub const CONNECTION_ATTEMPT_DELAY: Duration = Duration::from_millis(250);
51
52pub enum Connection {
53    Tcp {
54        inner: std::net::TcpStream,
55        id: Option<&'static str>,
56        trace: bool,
57    },
58    Fd {
59        inner: OwnedFd,
60        id: Option<&'static str>,
61        trace: bool,
62    },
63    #[cfg(feature = "tls")]
64    Tls {
65        inner: native_tls::TlsStream<Self>,
66        id: Option<&'static str>,
67        trace: bool,
68    },
69    Deflate {
70        inner: DeflateEncoder<DeflateDecoder<Box<Self>>>,
71        id: Option<&'static str>,
72        trace: bool,
73    },
74}
75
76impl std::fmt::Debug for Connection {
77    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
78        match self {
79            Tcp {
80                ref trace,
81                ref inner,
82                ref id,
83            } => fmt
84                .debug_struct(crate::identify!(Connection))
85                .field("variant", &stringify!(Tcp))
86                .field(stringify!(trace), trace)
87                .field(stringify!(id), id)
88                .field(stringify!(inner), inner)
89                .finish(),
90            #[cfg(feature = "tls")]
91            Tls {
92                ref trace,
93                ref inner,
94                ref id,
95            } => fmt
96                .debug_struct(crate::identify!(Connection))
97                .field("variant", &stringify!(Tls))
98                .field(stringify!(trace), trace)
99                .field(stringify!(id), id)
100                .field(stringify!(inner), inner.get_ref())
101                .finish(),
102            Fd {
103                ref trace,
104                ref inner,
105                ref id,
106            } => fmt
107                .debug_struct(crate::identify!(Connection))
108                .field("variant", &stringify!(Fd))
109                .field(stringify!(trace), trace)
110                .field(stringify!(id), id)
111                .field(stringify!(inner), inner)
112                .finish(),
113            Deflate {
114                ref trace,
115                ref inner,
116                ref id,
117            } => fmt
118                .debug_struct(crate::identify!(Connection))
119                .field("variant", &stringify!(Deflate))
120                .field(stringify!(trace), trace)
121                .field(stringify!(id), id)
122                .field(stringify!(inner), inner)
123                .finish(),
124        }
125    }
126}
127
128use Connection::*;
129
// Calls `libc::$fn(args...)` and maps the C error sentinel `-1` to
// `Err(std::io::Error::last_os_error())`; every other return value is passed
// through unchanged as `Ok`.
macro_rules! syscall {
    ($fn: ident ( $($arg: expr),* $(,)* ) ) => {{
        #[allow(unused_unsafe)]
        let ret = unsafe { libc::$fn($($arg, )*) };
        match ret {
            -1 => Err(std::io::Error::last_os_error()),
            value => Ok(value),
        }
    }};
}
141
/// Typed, hardcoded `setsockopt` arguments accepted by
/// [`Connection::setsockopt`].
///
/// `Connection::setsockopt` is a safe method that performs the `unsafe`
/// `setsockopt(2)` calls internally; restricting its input to this enum keeps
/// the option value types it passes to the OS correct.
///
/// Add new variants when you need to call `setsockopt` with new arguments.
pub enum SockOpts {
    /// Set TCP Keep Alive.
    ///
    /// Following text is sourced from <https://tldp.org/HOWTO/html_single/TCP-Keepalive-HOWTO/>.
    ///
    /// ```text
    /// 4.2. The setsockopt function call
    ///
    /// All you need to enable keepalive for a specific socket is to set the specific socket option
    /// on the socket itself. The prototype of the function is as follows:
    ///
    ///
    ///   int setsockopt(int s, int level, int optname,
    ///                  const void *optval, socklen_t optlen)
    ///
    ///
    /// The first parameter is the socket, previously created with the socket(2); the second one
    /// must be SOL_SOCKET, and the third must be SO_KEEPALIVE . The fourth parameter must be a
    /// boolean integer value, indicating that we want to enable the option, while the last is the
    /// size of the value passed before.
    ///
    /// According to the manpage, 0 is returned upon success, and -1 is returned on error (and
    /// errno is properly set).
    ///
    /// There are also three other socket options you can set for keepalive when you write your
    /// application. They all use the SOL_TCP level instead of SOL_SOCKET, and they override
    /// system-wide variables only for the current socket. If you read without writing first, the
    /// current system-wide parameters will be returned.
    ///
    ///     TCP_KEEPCNT: overrides tcp_keepalive_probes
    ///
    ///     TCP_KEEPIDLE: overrides tcp_keepalive_time
    ///
    ///     TCP_KEEPINTVL: overrides tcp_keepalive_intvl
    /// ```
    ///
    /// Field `duration` overrides `tcp_keepalive_time`:
    ///
    /// ```text
    /// tcp_keepalive_time
    ///
    ///    the interval between the last data packet sent (simple ACKs are not considered data) and the
    ///    first keepalive probe; after the connection is marked to need keepalive, this counter is not
    ///    used any further
    /// ```
    ///
    /// The default value in the Linux kernel is 7200 seconds (2 hours).
    ///
    /// `duration` is applied in whole seconds (any sub-second part is
    /// discarded) and is ignored when `enable` is `false`. It is set through
    /// `TCP_KEEPIDLE` on most platforms (including Linux), `TCP_KEEPALIVE` on
    /// macOS/iOS and `SO_KEEPALIVE` on OpenBSD/NetBSD/Haiku.
    KeepAlive {
        /// Whether `SO_KEEPALIVE` is turned on or off.
        enable: bool,
        /// Idle time before the first keepalive probe; `None` leaves the
        /// current value unchanged.
        duration: Option<Duration>,
    },
    /// Set or clear `TCP_NODELAY` (when set, Nagle's algorithm is disabled).
    TcpNoDelay {
        /// Whether `TCP_NODELAY` is turned on or off.
        enable: bool,
    },
}
201
202impl Connection {
203    pub const IO_BUF_SIZE: usize = 64 * 1024;
204
205    pub fn deflate(mut self) -> Self {
206        let trace = self.is_trace_enabled();
207        let id = self.id();
208        self.set_trace(false);
209        Self::Deflate {
210            inner: DeflateEncoder::new(
211                DeflateDecoder::new_with_buf(Box::new(self), vec![0; Self::IO_BUF_SIZE]),
212                Compression::default(),
213            ),
214            id,
215            trace,
216        }
217    }
218
219    #[cfg(feature = "tls")]
220    pub fn new_tls(mut inner: native_tls::TlsStream<Self>) -> Self {
221        let trace = inner.get_ref().is_trace_enabled();
222        let id = inner.get_ref().id();
223        if trace {
224            inner.get_mut().set_trace(false);
225        }
226        Self::Tls { inner, id, trace }
227    }
228
229    pub fn new_tcp(inner: std::net::TcpStream) -> Self {
230        let ret = Self::Tcp {
231            inner,
232            id: None,
233            trace: false,
234        };
235        _ = ret.setsockopt(SockOpts::TcpNoDelay { enable: true });
236
237        ret
238    }
239
240    pub fn trace(mut self, val: bool) -> Self {
241        match self {
242            Tcp { ref mut trace, .. } => *trace = val,
243            #[cfg(feature = "tls")]
244            Tls { ref mut trace, .. } => *trace = val,
245            Fd { ref mut trace, .. } => {
246                *trace = val;
247            }
248            Deflate { ref mut trace, .. } => *trace = val,
249        }
250        self
251    }
252
253    pub fn with_id(mut self, val: &'static str) -> Self {
254        match self {
255            Tcp { ref mut id, .. } => *id = Some(val),
256            #[cfg(feature = "tls")]
257            Tls { ref mut id, .. } => *id = Some(val),
258            Fd { ref mut id, .. } => {
259                *id = Some(val);
260            }
261            Deflate { ref mut id, .. } => *id = Some(val),
262        }
263        self
264    }
265
266    pub fn set_trace(&mut self, val: bool) {
267        match self {
268            Tcp { ref mut trace, .. } => *trace = val,
269            #[cfg(feature = "tls")]
270            Tls { ref mut trace, .. } => *trace = val,
271            Fd { ref mut trace, .. } => {
272                *trace = val;
273            }
274            Deflate { ref mut trace, .. } => *trace = val,
275        }
276    }
277
278    pub fn set_nonblocking(&self, nonblocking: bool) -> std::io::Result<()> {
279        if self.is_trace_enabled() {
280            let id = self.id();
281            log::trace!(
282                "{}{}{}{:?} set_nonblocking({:?})",
283                if id.is_some() { "[" } else { "" },
284                if let Some(id) = id.as_ref() { id } else { "" },
285                if id.is_some() { "]: " } else { "" },
286                self,
287                nonblocking
288            );
289        }
290        match self {
291            Tcp { ref inner, .. } => inner.set_nonblocking(nonblocking),
292            #[cfg(feature = "tls")]
293            Tls { ref inner, .. } => inner.get_ref().set_nonblocking(nonblocking),
294            Fd { inner, .. } => {
295                // [ref:VERIFY]
296                nix::fcntl::fcntl(
297                    inner.as_raw_fd(),
298                    nix::fcntl::FcntlArg::F_SETFL(if nonblocking {
299                        nix::fcntl::OFlag::O_NONBLOCK
300                    } else {
301                        !nix::fcntl::OFlag::O_NONBLOCK
302                    }),
303                )
304                .map_err(|err| std::io::Error::from_raw_os_error(err as i32))?;
305                Ok(())
306            }
307            Deflate { ref inner, .. } => inner.get_ref().get_ref().set_nonblocking(nonblocking),
308        }
309    }
310
311    pub fn set_read_timeout(&self, dur: Option<Duration>) -> std::io::Result<()> {
312        if self.is_trace_enabled() {
313            let id = self.id();
314            log::trace!(
315                "{}{}{}{:?} set_read_timeout({:?})",
316                if id.is_some() { "[" } else { "" },
317                if let Some(id) = id.as_ref() { id } else { "" },
318                if id.is_some() { "]: " } else { "" },
319                self,
320                dur
321            );
322        }
323        match self {
324            Tcp { ref inner, .. } => inner.set_read_timeout(dur),
325            #[cfg(feature = "tls")]
326            Tls { ref inner, .. } => inner.get_ref().set_read_timeout(dur),
327            Fd { .. } => Ok(()),
328            Deflate { ref inner, .. } => inner.get_ref().get_ref().set_read_timeout(dur),
329        }
330    }
331
332    pub fn set_write_timeout(&self, dur: Option<Duration>) -> std::io::Result<()> {
333        if self.is_trace_enabled() {
334            let id = self.id();
335            log::trace!(
336                "{}{}{}{:?} set_write_timeout({:?})",
337                if id.is_some() { "[" } else { "" },
338                if let Some(id) = id.as_ref() { id } else { "" },
339                if id.is_some() { "]: " } else { "" },
340                self,
341                dur
342            );
343        }
344        match self {
345            Tcp { ref inner, .. } => inner.set_write_timeout(dur),
346            #[cfg(feature = "tls")]
347            Tls { ref inner, .. } => inner.get_ref().set_write_timeout(dur),
348            Fd { .. } => Ok(()),
349            Deflate { ref inner, .. } => inner.get_ref().get_ref().set_write_timeout(dur),
350        }
351    }
352
353    pub fn keepalive(&self) -> std::io::Result<Option<Duration>> {
354        if self.is_trace_enabled() {
355            log::trace!("{:?} keepalive()", self);
356        }
357        if matches!(self, Fd { .. }) {
358            return Ok(None);
359        }
360        unsafe {
361            let raw: c_int = self.__getsockopt(libc::SOL_SOCKET, libc::SO_KEEPALIVE)?;
362            if raw == 0 {
363                return Ok(None);
364            }
365            let secs: c_int = self.__getsockopt(libc::IPPROTO_TCP, KEEPALIVE_OPTION)?;
366            Ok(Some(Duration::new(secs as u64, 0)))
367        }
368    }
369
370    pub fn set_keepalive(&self, keepalive: Option<Duration>) -> std::io::Result<()> {
371        if self.is_trace_enabled() {
372            let id = self.id();
373            log::trace!(
374                "{}{}{}{:?} set_keepalive({:?})",
375                if id.is_some() { "[" } else { "" },
376                if let Some(id) = id.as_ref() { id } else { "" },
377                if id.is_some() { "]: " } else { "" },
378                self,
379                keepalive
380            );
381        }
382        if matches!(self, Fd { .. }) {
383            return Ok(());
384        }
385        self.setsockopt(SockOpts::KeepAlive {
386            enable: keepalive.is_some(),
387            duration: keepalive,
388        })
389    }
390
391    unsafe fn inner_setsockopt<T>(&self, opt: c_int, val: c_int, payload: T) -> std::io::Result<()>
392    where
393        T: Copy,
394    {
395        let payload = std::ptr::addr_of!(payload) as *const c_void;
396        syscall!(setsockopt(
397            self.as_raw_fd(),
398            opt,
399            val,
400            payload,
401            std::mem::size_of::<T>() as libc::socklen_t,
402        ))?;
403        Ok(())
404    }
405
406    pub fn setsockopt(&self, option: SockOpts) -> std::io::Result<()> {
407        match option {
408            SockOpts::KeepAlive {
409                enable: true,
410                duration,
411            } => {
412                unsafe {
413                    self.inner_setsockopt(libc::SOL_SOCKET, libc::SO_KEEPALIVE, <c_int>::from(true))
414                }?;
415                if let Some(dur) = duration {
416                    unsafe {
417                        self.inner_setsockopt(
418                            libc::IPPROTO_TCP,
419                            KEEPALIVE_OPTION,
420                            dur.as_secs() as c_int,
421                        )
422                    }?;
423                }
424                Ok(())
425            }
426            SockOpts::KeepAlive {
427                enable: false,
428                duration: _,
429            } => unsafe {
430                self.inner_setsockopt(libc::SOL_SOCKET, libc::SO_KEEPALIVE, <c_int>::from(false))
431            },
432            SockOpts::TcpNoDelay { enable } => unsafe {
433                #[cfg(any(
434                    target_os = "openbsd",
435                    target_os = "netbsd",
436                    target_os = "haiku",
437                    target_os = "macos",
438                    target_os = "ios"
439                ))]
440                {
441                    self.inner_setsockopt(
442                        libc::IPPROTO_TCP,
443                        libc::TCP_NODELAY,
444                        if enable { c_int::from(1_u8) } else { 0 },
445                    )
446                }
447                #[cfg(not(any(
448                    target_os = "openbsd",
449                    target_os = "netbsd",
450                    target_os = "haiku",
451                    target_os = "macos",
452                    target_os = "ios"
453                )))]
454                {
455                    self.inner_setsockopt(
456                        libc::SOL_TCP,
457                        libc::TCP_NODELAY,
458                        if enable { c_int::from(1_u8) } else { 0 },
459                    )
460                }
461            },
462        }
463    }
464
465    #[inline]
466    unsafe fn __getsockopt<T: Copy>(&self, opt: c_int, val: c_int) -> std::io::Result<T> {
467        let mut slot: T = unsafe { std::mem::zeroed() };
468        let mut len = std::mem::size_of::<T>() as libc::socklen_t;
469        syscall!(getsockopt(
470            self.as_raw_fd(),
471            opt,
472            val,
473            std::ptr::addr_of_mut!(slot) as *mut _,
474            &mut len,
475        ))?;
476        assert_eq!(len as usize, std::mem::size_of::<T>());
477        Ok(slot)
478    }
479
480    fn is_trace_enabled(&self) -> bool {
481        match self {
482            Fd { trace, .. } | Tcp { trace, .. } => *trace,
483            #[cfg(feature = "tls")]
484            Tls { trace, .. } => *trace,
485            Deflate { trace, .. } => *trace,
486        }
487    }
488
489    fn id(&self) -> Option<&'static str> {
490        match self {
491            Fd { id, .. } | Tcp { id, .. } => *id,
492            #[cfg(feature = "tls")]
493            Tls { id, .. } => *id,
494            Deflate { id, .. } => *id,
495        }
496    }
497}
498
499impl std::io::Read for Connection {
500    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
501        let res = match self {
502            Tcp { ref mut inner, .. } => inner.read(buf),
503            #[cfg(feature = "tls")]
504            Tls { ref mut inner, .. } => inner.read(buf),
505            Fd { ref inner, .. } => {
506                use std::os::unix::io::{FromRawFd, IntoRawFd};
507                let mut f = unsafe { std::fs::File::from_raw_fd(inner.as_raw_fd()) };
508                let ret = f.read(buf);
509                let _ = f.into_raw_fd();
510                ret
511            }
512            Deflate { ref mut inner, .. } => inner.read(buf),
513        };
514        if self.is_trace_enabled() {
515            let id = self.id();
516            match &res {
517                Ok(len) => {
518                    let slice = &buf[..*len];
519                    log::trace!(
520                        "{}{}{}{:?} read {:?} bytes:{}",
521                        if id.is_some() { "[" } else { "" },
522                        if let Some(id) = id.as_ref() { id } else { "" },
523                        if id.is_some() { "]: " } else { "" },
524                        self,
525                        len,
526                        std::str::from_utf8(slice)
527                            .map(Cow::Borrowed)
528                            .or_else(|_| crate::text::hex::bytes_to_hex(slice).map(Cow::Owned))
529                            .unwrap_or(Cow::Borrowed("Could not convert to hex."))
530                    );
531                }
532                Err(err) if matches!(err.kind(), std::io::ErrorKind::WouldBlock) => {}
533                Err(err) => {
534                    log::trace!(
535                        "{}{}{}{:?} could not read {:?}",
536                        if id.is_some() { "[" } else { "" },
537                        if let Some(id) = id.as_ref() { id } else { "" },
538                        if id.is_some() { "]: " } else { "" },
539                        self,
540                        err,
541                    );
542                }
543            }
544        }
545        res
546    }
547}
548
549impl std::io::Write for Connection {
550    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
551        if self.is_trace_enabled() {
552            let id = self.id();
553            log::trace!(
554                "{}{}{}{:?} writing {} bytes:{}",
555                if id.is_some() { "[" } else { "" },
556                if let Some(id) = id.as_ref() { id } else { "" },
557                if id.is_some() { "]: " } else { "" },
558                self,
559                buf.len(),
560                std::str::from_utf8(buf)
561                    .map(Cow::Borrowed)
562                    .or_else(|_| crate::text::hex::bytes_to_hex(buf).map(Cow::Owned))
563                    .unwrap_or(Cow::Borrowed("Could not convert to hex."))
564            );
565        }
566        match self {
567            Tcp { ref mut inner, .. } => inner.write(buf),
568            #[cfg(feature = "tls")]
569            Tls { ref mut inner, .. } => inner.write(buf),
570            Fd { ref inner, .. } => {
571                use std::os::unix::io::{FromRawFd, IntoRawFd};
572                let mut f = unsafe { std::fs::File::from_raw_fd(inner.as_raw_fd()) };
573                let ret = f.write(buf);
574                let _ = f.into_raw_fd();
575                ret
576            }
577            Deflate { ref mut inner, .. } => inner.write(buf),
578        }
579    }
580
581    fn flush(&mut self) -> std::io::Result<()> {
582        match self {
583            Tcp { ref mut inner, .. } => inner.flush(),
584            #[cfg(feature = "tls")]
585            Tls { ref mut inner, .. } => inner.flush(),
586            Fd { ref inner, .. } => {
587                use std::os::unix::io::{FromRawFd, IntoRawFd};
588                let mut f = unsafe { std::fs::File::from_raw_fd(inner.as_raw_fd()) };
589                let ret = f.flush();
590                let _ = f.into_raw_fd();
591                ret
592            }
593            Deflate { ref mut inner, .. } => inner.flush(),
594        }
595    }
596}
597
598impl std::os::unix::io::AsRawFd for Connection {
599    fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
600        match self {
601            Tcp { ref inner, .. } => inner.as_raw_fd(),
602            #[cfg(feature = "tls")]
603            Tls { ref inner, .. } => inner.get_ref().as_raw_fd(),
604            Fd { ref inner, .. } => inner.as_raw_fd(),
605            Deflate { ref inner, .. } => inner.get_ref().get_ref().as_raw_fd(),
606        }
607    }
608}
609
610impl AsFd for Connection {
611    fn as_fd(&'_ self) -> BorrowedFd<'_> {
612        match self {
613            Tcp { ref inner, .. } => inner.as_fd(),
614            #[cfg(feature = "tls")]
615            Tls { ref inner, .. } => inner.get_ref().as_fd(),
616            Fd { ref inner, .. } => inner.as_fd(),
617            Deflate { ref inner, .. } => inner.get_ref().get_ref().as_fd(),
618        }
619    }
620}
621
// SAFETY: `async_io::IoSafe` asserts that the type's I/O trait implementations
// do not drop or replace the underlying I/O source. `Connection`'s `Read`/`Write`
// impls only read from and write to the wrapped transport, and the `Fd` arms
// borrow the descriptor temporarily without closing it.
// NOTE(review): confirm this also holds for the `Tls` and `Deflate` inner stream
// types whenever they are upgraded.
unsafe impl async_io::IoSafe for Connection {}
623
624#[deprecated = "While it supports IPv6, it does not implement the happy eyeballs algorithm. Use \
625                {std_net,smol}::tcp_stream_connect instead."]
626pub fn lookup_ip(host: &str, port: u16) -> crate::Result<std::net::SocketAddr> {
627    use std::net::ToSocketAddrs;
628
629    use crate::error::{Error, ErrorKind, NetworkErrorKind};
630
631    let addrs = (host, port).to_socket_addrs()?;
632    for addr in addrs {
633        if matches!(
634            addr,
635            std::net::SocketAddr::V4(_) | std::net::SocketAddr::V6(_)
636        ) {
637            return Ok(addr);
638        }
639    }
640
641    Err(
642        Error::new(format!("Could not lookup address {}:{}", host, port))
643            .set_kind(ErrorKind::Network(NetworkErrorKind::HostLookupFailed)),
644    )
645}