peerlink/
connector.rs

1use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs};
2use std::{io, net};
3
/// A connect target that can be either a socket address or a resolvable domain name.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub enum Target {
    /// The target is a socket.
    Socket(SocketAddr),
    /// The target is a domain name to be resolved by the connector, along with a port.
    Domain(String, u16),
}

impl From<SocketAddr> for Target {
    fn from(value: SocketAddr) -> Self {
        Self::Socket(value)
    }
}

impl From<SocketAddrV4> for Target {
    fn from(value: SocketAddrV4) -> Self {
        Self::Socket(value.into())
    }
}

impl From<SocketAddrV6> for Target {
    fn from(value: SocketAddrV6) -> Self {
        Self::Socket(value.into())
    }
}

impl From<(net::Ipv4Addr, u16)> for Target {
    fn from(value: (net::Ipv4Addr, u16)) -> Self {
        Self::Socket(value.into())
    }
}

impl From<(net::Ipv6Addr, u16)> for Target {
    fn from(value: (net::Ipv6Addr, u16)) -> Self {
        Self::Socket(value.into())
    }
}

impl From<(&str, u16)> for Target {
    fn from((domain, port): (&str, u16)) -> Self {
        Self::Domain(domain.to_owned(), port)
    }
}

impl From<(String, u16)> for Target {
    fn from((domain, port): (String, u16)) -> Self {
        Self::Domain(domain, port)
    }
}

impl std::str::FromStr for Target {
    type Err = io::Error;

    /// Parses either a socket address (`1.2.3.4:80`, `[::1]:80`) or a `domain:port` pair.
    ///
    /// # Errors
    ///
    /// Returns an error of kind [`io::ErrorKind::InvalidInput`] when the string is not a
    /// socket address and either has no `:` separator, has an empty domain part, or has a
    /// port that is not a valid `u16`.
    fn from_str(value: &str) -> Result<Self, Self::Err> {
        if let Ok(addr) = value.parse::<SocketAddr>() {
            return Ok(Self::Socket(addr));
        }

        // An empty host (e.g. ":80") is not a usable target, so treat it like a missing
        // separator. The error is built lazily so successful parses do not allocate it.
        let (domain, port) = value
            .split_once(':')
            .filter(|(domain, _)| !domain.is_empty())
            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "not a target address"))?;

        let port: u16 = port
            .parse()
            .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "not a valid port"))?;

        Ok(Self::Domain(domain.to_owned(), port))
    }
}

impl std::fmt::Display for Target {
    /// Formats the target as `ip:port` or `domain:port`.
    ///
    /// Width and precision flags (e.g. `{:>20}`) are honored for both variants.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Target::Socket(socket_addr) => std::fmt::Display::fmt(socket_addr, f),
            Target::Domain(domain, port) => {
                if f.width().is_none() && f.precision().is_none() {
                    write!(f, "{domain}:{port}")
                } else {
                    // Padding needs the complete string, which is the same approach the
                    // standard library takes when formatting `SocketAddr` with flags.
                    f.pad(&format!("{domain}:{port}"))
                }
            }
        }
    }
}
83
84/// Types implementing this trait can connect to a target address in a custom manner before
85/// returning a [`mio::net::TcpStream`]. This can be used for proxying and other custom scenarios.
86/// It is the responsibility of the caller to put the stream into nonblocking mode. Failing
87/// to do so will block the reactor indefinitely and render it inoperable.
88pub trait Connector: Clone + Send + Sync + 'static {
89    /// Sometimes it is not possible to connect in a non-blocking manner, such as when using 3rd
90    /// party libraries. If anything in the connect logic blocks, this must be set to `true`. In
91    /// that case connects are performed on a dedicated thread in order to not block the reactor.
92    /// Setting this to `false` when it should be `true` will interfere with the operation of the
93    /// reactor.
94    const CONNECT_IN_BACKGROUND: bool;
95
96    /// Connect to a target address and return a [`mio`] TCP stream.
97    fn connect(&self, target: &Target) -> io::Result<mio::net::TcpStream>;
98}
99
/// Default [`Connector`] implementation for [`mio`] that just connects to a target address.
///
/// This is a stateless unit type, so it is cheap to copy and has a trivial [`Default`].
#[derive(Debug, Clone, Copy, Default)]
pub struct DefaultConnector;
103
104impl Connector for DefaultConnector {
105    const CONNECT_IN_BACKGROUND: bool = false;
106
107    fn connect(&self, target: &Target) -> io::Result<mio::net::TcpStream> {
108        let socket_addr = match target {
109            Target::Socket(socket) => *socket,
110            Target::Domain(domain, port) => (domain.as_str(), *port)
111                .to_socket_addrs()?
112                .next()
113                .ok_or(io::Error::new(
114                    io::ErrorKind::AddrNotAvailable,
115                    "target -> socket address: DNS resolution failure",
116                ))?,
117        };
118
119        mio::net::TcpStream::connect(socket_addr)
120    }
121}
122
/// A [`Connector`] that establishes connections by way of a SOCKS5 proxy.
#[cfg(feature = "socks")]
#[derive(Clone)]
pub struct Socks5Connector {
    /// Address of the SOCKS5 proxy server to connect through.
    pub proxy: SocketAddr,
    /// Optional `(username, password)` pair used to authenticate with the proxy.
    pub credentials: Option<(String, String)>,
}
132
#[cfg(feature = "socks")]
impl Connector for Socks5Connector {
    const CONNECT_IN_BACKGROUND: bool = true;

    /// Performs the SOCKS5 handshake with the proxy using the `socks` crate's blocking API.
    /// It then switches the resulting stream to nonblocking mode and hands it over to mio.
    fn connect(&self, target: &Target) -> io::Result<mio::net::TcpStream> {
        use socks::ToTargetAddr;
        use std::time::Duration;

        // Upper bound on the total time spent retrying `set_nonblocking`.
        const MAX_WAIT: Duration = Duration::from_millis(5000);
        // Upper bound on a single back-off sleep.
        const MAX_STEP: Duration = Duration::from_millis(1000);

        let destination = match target.to_target_addr() {
            Ok(addr) => addr,
            Err(_) => {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "not a target address",
                ))
            }
        };

        let socks_stream = if let Some((username, password)) = &self.credentials {
            socks::Socks5Stream::connect_with_password(self.proxy, destination, username, password)?
        } else {
            socks::Socks5Stream::connect(self.proxy, destination)?
        };
        let stream = socks_stream.into_inner();

        // Retry with exponential back-off (2ms, 4ms, ... capped at MAX_STEP) while the call
        // reports `WouldBlock`. NOTE(review): `set_nonblocking` is not expected to report
        // `WouldBlock` on common platforms; this retry appears defensive — confirm before
        // removing.
        let mut waited = Duration::ZERO;
        let mut step = Duration::from_millis(1);
        loop {
            let err = match stream.set_nonblocking(true) {
                Ok(()) => return Ok(mio::net::TcpStream::from_std(stream)),
                Err(err) => err,
            };

            if err.kind() != io::ErrorKind::WouldBlock || waited >= MAX_WAIT {
                return Err(err);
            }

            step = (step * 2).min(MAX_STEP);
            std::thread::sleep(step);
            waited += step;
        }
    }
}
175
/// Converts a [`Target`] into the `socks` crate's target address type.
#[cfg(feature = "socks")]
impl socks::ToTargetAddr for Target {
    fn to_target_addr(&self) -> io::Result<socks::TargetAddr> {
        match *self {
            // Already an IP endpoint: forward it unchanged.
            Target::Socket(addr) => Ok(socks::TargetAddr::Ip(addr)),
            // Delegate to the `(&str, u16)` implementation provided by `socks`.
            Target::Domain(ref host, port) => (host.as_str(), port).to_target_addr(),
        }
    }
}