fast_socks5/util/
target_addr.rs

1use crate::consts;
2use crate::consts::SOCKS5_ADDR_TYPE_IPV4;
3use crate::read_exact;
4use crate::SocksError;
5use anyhow::Context;
6use std::fmt;
7use std::io;
8use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
9use std::vec::IntoIter;
10use thiserror::Error;
11use tokio::io::{AsyncRead, AsyncReadExt};
12use tokio::net::lookup_host;
13
14/// SOCKS5 reply code
15#[derive(Error, Debug)]
16pub enum AddrError {
17    #[error("DNS Resolution failed")]
18    DNSResolutionFailed,
19    #[error("Can't read IPv4")]
20    IPv4Unreadable,
21    #[error("Can't read IPv6")]
22    IPv6Unreadable,
23    #[error("Can't read port number")]
24    PortNumberUnreadable,
25    #[error("Can't read domain len")]
26    DomainLenUnreadable,
27    #[error("Can't read Domain content")]
28    DomainContentUnreadable,
29    #[error("Malformed UTF-8")]
30    Utf8,
31    #[error("Unknown address type")]
32    IncorrectAddressType,
33    #[error("{0}")]
34    Custom(String),
35}
36
/// The destination a connection should be made to.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum TargetAddr {
    /// A socket address (IPv4 or IPv6 plus port) that needs no name lookup.
    Ip(SocketAddr),
    /// A domain name plus port.
    ///
    /// The name is handed over to the proxy server untouched, so the DNS
    /// lookup takes place on the server side.
    Domain(String, u16),
}
48
49impl TargetAddr {
50    pub async fn resolve_dns(self) -> anyhow::Result<TargetAddr> {
51        match self {
52            TargetAddr::Ip(ip) => Ok(TargetAddr::Ip(ip)),
53            TargetAddr::Domain(domain, port) => {
54                debug!("Attempt to DNS resolve the domain {}...", &domain);
55
56                let socket_addr = lookup_host((&domain[..], port))
57                    .await
58                    .context(AddrError::DNSResolutionFailed)?
59                    .next()
60                    .ok_or(AddrError::Custom(
61                        "Can't fetch DNS to the domain.".to_string(),
62                    ))?;
63                debug!("domain name resolved to {}", socket_addr);
64
65                // has been converted to an ip
66                Ok(TargetAddr::Ip(socket_addr))
67            }
68        }
69    }
70
71    pub fn is_ip(&self) -> bool {
72        match self {
73            TargetAddr::Ip(_) => true,
74            _ => false,
75        }
76    }
77
78    pub fn is_domain(&self) -> bool {
79        !self.is_ip()
80    }
81
82    pub fn to_be_bytes(&self) -> anyhow::Result<Vec<u8>> {
83        let mut buf = vec![];
84        match self {
85            TargetAddr::Ip(SocketAddr::V4(addr)) => {
86                debug!("TargetAddr::IpV4");
87
88                buf.extend_from_slice(&[SOCKS5_ADDR_TYPE_IPV4]);
89
90                debug!("addr ip {:?}", (*addr.ip()).octets());
91                buf.extend_from_slice(&(addr.ip()).octets()); // ip
92                buf.extend_from_slice(&addr.port().to_be_bytes()); // port
93            }
94            TargetAddr::Ip(SocketAddr::V6(addr)) => {
95                debug!("TargetAddr::IpV6");
96                buf.extend_from_slice(&[consts::SOCKS5_ADDR_TYPE_IPV6]);
97
98                debug!("addr ip {:?}", (*addr.ip()).octets());
99                buf.extend_from_slice(&(addr.ip()).octets()); // ip
100                buf.extend_from_slice(&addr.port().to_be_bytes()); // port
101            }
102            TargetAddr::Domain(ref domain, port) => {
103                debug!("TargetAddr::Domain");
104                if domain.len() > u8::max_value() as usize {
105                    return Err(SocksError::ExceededMaxDomainLen(domain.len()).into());
106                }
107                buf.extend_from_slice(&[consts::SOCKS5_ADDR_TYPE_DOMAIN_NAME, domain.len() as u8]);
108                buf.extend_from_slice(domain.as_bytes()); // domain content
109                buf.extend_from_slice(&port.to_be_bytes());
110                // port content (.to_be_bytes() convert from u16 to u8 type)
111            }
112        }
113        Ok(buf)
114    }
115}
116
117// async-std ToSocketAddrs doesn't supports external trait implementation
118// @see https://github.com/async-rs/async-std/issues/539
119impl std::net::ToSocketAddrs for TargetAddr {
120    type Iter = IntoIter<SocketAddr>;
121
122    fn to_socket_addrs(&self) -> io::Result<IntoIter<SocketAddr>> {
123        match *self {
124            TargetAddr::Ip(addr) => Ok(vec![addr].into_iter()),
125            TargetAddr::Domain(_, _) => Err(io::Error::new(
126                io::ErrorKind::Other,
127                "Domain name has to be explicitly resolved, please use TargetAddr::resolve_dns().",
128            )),
129        }
130    }
131}
132
133impl fmt::Display for TargetAddr {
134    #[inline]
135    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
136        match *self {
137            TargetAddr::Ip(ref addr) => write!(f, "{}", addr),
138            TargetAddr::Domain(ref addr, ref port) => write!(f, "{}:{}", addr, port),
139        }
140    }
141}
142
143/// A trait for objects that can be converted to `TargetAddr`.
144pub trait ToTargetAddr {
145    /// Converts the value of `self` to a `TargetAddr`.
146    fn to_target_addr(&self) -> io::Result<TargetAddr>;
147}
148
149impl<'a> ToTargetAddr for (&'a str, u16) {
150    fn to_target_addr(&self) -> io::Result<TargetAddr> {
151        // try to parse as an IP first
152        if let Ok(addr) = self.0.parse::<Ipv4Addr>() {
153            return (addr, self.1).to_target_addr();
154        }
155
156        if let Ok(addr) = self.0.parse::<Ipv6Addr>() {
157            return (addr, self.1).to_target_addr();
158        }
159
160        Ok(TargetAddr::Domain(self.0.to_owned(), self.1))
161    }
162}
163
164impl ToTargetAddr for SocketAddr {
165    fn to_target_addr(&self) -> io::Result<TargetAddr> {
166        Ok(TargetAddr::Ip(*self))
167    }
168}
169
170impl ToTargetAddr for SocketAddrV4 {
171    fn to_target_addr(&self) -> io::Result<TargetAddr> {
172        SocketAddr::V4(*self).to_target_addr()
173    }
174}
175
176impl ToTargetAddr for SocketAddrV6 {
177    fn to_target_addr(&self) -> io::Result<TargetAddr> {
178        SocketAddr::V6(*self).to_target_addr()
179    }
180}
181
182impl ToTargetAddr for (IpAddr, u16) {
183    fn to_target_addr(&self) -> io::Result<TargetAddr> {
184        match self.0 {
185            IpAddr::V4(ipv4_addr) => (ipv4_addr, self.1).to_target_addr(),
186            IpAddr::V6(ipv6_addr) => (ipv6_addr, self.1).to_target_addr(),
187        }
188    }
189}
190
191impl ToTargetAddr for (Ipv4Addr, u16) {
192    fn to_target_addr(&self) -> io::Result<TargetAddr> {
193        SocketAddrV4::new(self.0, self.1).to_target_addr()
194    }
195}
196
197impl ToTargetAddr for (Ipv6Addr, u16) {
198    fn to_target_addr(&self) -> io::Result<TargetAddr> {
199        SocketAddrV6::new(self.0, self.1, 0, 0).to_target_addr()
200    }
201}
202
/// Raw address payload, without a port. `read_address` builds one of these
/// before pairing it with the port.
#[derive(Debug)]
pub enum Addr {
    /// The 4 octets of an IPv4 address.
    V4([u8; 4]),
    /// The 16 octets of an IPv6 address.
    V6([u8; 16]),
    /// A domain name, already checked to be valid UTF-8.
    // NOTE(review): open question whether Vec<u8>, Box<[u8]> or String is the
    // best representation here.
    Domain(String),
}
209
210/// This function is used by the client & the server
211pub async fn read_address<T: AsyncRead + Unpin>(
212    stream: &mut T,
213    atyp: u8,
214) -> anyhow::Result<TargetAddr> {
215    let addr = match atyp {
216        consts::SOCKS5_ADDR_TYPE_IPV4 => {
217            debug!("Address type `IPv4`");
218            Addr::V4(read_exact!(stream, [0u8; 4]).context(AddrError::IPv4Unreadable)?)
219        }
220        consts::SOCKS5_ADDR_TYPE_IPV6 => {
221            debug!("Address type `IPv6`");
222            Addr::V6(read_exact!(stream, [0u8; 16]).context(AddrError::IPv6Unreadable)?)
223        }
224        consts::SOCKS5_ADDR_TYPE_DOMAIN_NAME => {
225            debug!("Address type `domain`");
226            let len = read_exact!(stream, [0]).context(AddrError::DomainLenUnreadable)?[0];
227            let domain = read_exact!(stream, vec![0u8; len as usize])
228                .context(AddrError::DomainContentUnreadable)?;
229            // make sure the bytes are correct utf8 string
230            let domain = String::from_utf8(domain).context(AddrError::Utf8)?;
231
232            Addr::Domain(domain)
233        }
234        _ => return Err(anyhow::anyhow!(AddrError::IncorrectAddressType)),
235    };
236
237    // Find port number
238    let port = read_exact!(stream, [0u8; 2]).context(AddrError::PortNumberUnreadable)?;
239    // Convert (u8 * 2) into u16
240    let port = (port[0] as u16) << 8 | port[1] as u16;
241
242    // Merge ADDRESS + PORT into a TargetAddr
243    let addr: TargetAddr = match addr {
244        Addr::V4([a, b, c, d]) => (Ipv4Addr::new(a, b, c, d), port).to_target_addr()?,
245        Addr::V6(x) => (Ipv6Addr::from(x), port).to_target_addr()?,
246        Addr::Domain(domain) => TargetAddr::Domain(domain, port),
247    };
248
249    Ok(addr)
250}