1use crate::consts;
2use crate::consts::SOCKS5_ADDR_TYPE_IPV4;
3use crate::read_exact;
4use crate::SocksError;
5use anyhow::Context;
6use std::fmt;
7use std::io;
8use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
9use std::vec::IntoIter;
10use thiserror::Error;
11use tokio::io::{AsyncRead, AsyncReadExt};
12use tokio::net::lookup_host;
13
14#[derive(Error, Debug)]
16pub enum AddrError {
17 #[error("DNS Resolution failed")]
18 DNSResolutionFailed,
19 #[error("Can't read IPv4")]
20 IPv4Unreadable,
21 #[error("Can't read IPv6")]
22 IPv6Unreadable,
23 #[error("Can't read port number")]
24 PortNumberUnreadable,
25 #[error("Can't read domain len")]
26 DomainLenUnreadable,
27 #[error("Can't read Domain content")]
28 DomainContentUnreadable,
29 #[error("Malformed UTF-8")]
30 Utf8,
31 #[error("Unknown address type")]
32 IncorrectAddressType,
33 #[error("{0}")]
34 Custom(String),
35}
36
/// Destination requested by a SOCKS5 client.
///
/// Either a concrete socket address or a domain name that still has to be
/// resolved, paired with its port.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TargetAddr {
    /// An IPv4 or IPv6 socket address.
    Ip(SocketAddr),
    /// A host name and port; not resolved yet.
    Domain(String, u16),
}
48
49impl TargetAddr {
50 pub async fn resolve_dns(self) -> anyhow::Result<TargetAddr> {
51 match self {
52 TargetAddr::Ip(ip) => Ok(TargetAddr::Ip(ip)),
53 TargetAddr::Domain(domain, port) => {
54 debug!("Attempt to DNS resolve the domain {}...", &domain);
55
56 let socket_addr = lookup_host((&domain[..], port))
57 .await
58 .context(AddrError::DNSResolutionFailed)?
59 .next()
60 .ok_or(AddrError::Custom(
61 "Can't fetch DNS to the domain.".to_string(),
62 ))?;
63 debug!("domain name resolved to {}", socket_addr);
64
65 Ok(TargetAddr::Ip(socket_addr))
67 }
68 }
69 }
70
71 pub fn is_ip(&self) -> bool {
72 match self {
73 TargetAddr::Ip(_) => true,
74 _ => false,
75 }
76 }
77
78 pub fn is_domain(&self) -> bool {
79 !self.is_ip()
80 }
81
82 pub fn to_be_bytes(&self) -> anyhow::Result<Vec<u8>> {
83 let mut buf = vec![];
84 match self {
85 TargetAddr::Ip(SocketAddr::V4(addr)) => {
86 debug!("TargetAddr::IpV4");
87
88 buf.extend_from_slice(&[SOCKS5_ADDR_TYPE_IPV4]);
89
90 debug!("addr ip {:?}", (*addr.ip()).octets());
91 buf.extend_from_slice(&(addr.ip()).octets()); buf.extend_from_slice(&addr.port().to_be_bytes()); }
94 TargetAddr::Ip(SocketAddr::V6(addr)) => {
95 debug!("TargetAddr::IpV6");
96 buf.extend_from_slice(&[consts::SOCKS5_ADDR_TYPE_IPV6]);
97
98 debug!("addr ip {:?}", (*addr.ip()).octets());
99 buf.extend_from_slice(&(addr.ip()).octets()); buf.extend_from_slice(&addr.port().to_be_bytes()); }
102 TargetAddr::Domain(ref domain, port) => {
103 debug!("TargetAddr::Domain");
104 if domain.len() > u8::max_value() as usize {
105 return Err(SocksError::ExceededMaxDomainLen(domain.len()).into());
106 }
107 buf.extend_from_slice(&[consts::SOCKS5_ADDR_TYPE_DOMAIN_NAME, domain.len() as u8]);
108 buf.extend_from_slice(domain.as_bytes()); buf.extend_from_slice(&port.to_be_bytes());
110 }
112 }
113 Ok(buf)
114 }
115}
116
117impl std::net::ToSocketAddrs for TargetAddr {
120 type Iter = IntoIter<SocketAddr>;
121
122 fn to_socket_addrs(&self) -> io::Result<IntoIter<SocketAddr>> {
123 match *self {
124 TargetAddr::Ip(addr) => Ok(vec![addr].into_iter()),
125 TargetAddr::Domain(_, _) => Err(io::Error::new(
126 io::ErrorKind::Other,
127 "Domain name has to be explicitly resolved, please use TargetAddr::resolve_dns().",
128 )),
129 }
130 }
131}
132
133impl fmt::Display for TargetAddr {
134 #[inline]
135 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
136 match *self {
137 TargetAddr::Ip(ref addr) => write!(f, "{}", addr),
138 TargetAddr::Domain(ref addr, ref port) => write!(f, "{}:{}", addr, port),
139 }
140 }
141}
142
143pub trait ToTargetAddr {
145 fn to_target_addr(&self) -> io::Result<TargetAddr>;
147}
148
149impl<'a> ToTargetAddr for (&'a str, u16) {
150 fn to_target_addr(&self) -> io::Result<TargetAddr> {
151 if let Ok(addr) = self.0.parse::<Ipv4Addr>() {
153 return (addr, self.1).to_target_addr();
154 }
155
156 if let Ok(addr) = self.0.parse::<Ipv6Addr>() {
157 return (addr, self.1).to_target_addr();
158 }
159
160 Ok(TargetAddr::Domain(self.0.to_owned(), self.1))
161 }
162}
163
164impl ToTargetAddr for SocketAddr {
165 fn to_target_addr(&self) -> io::Result<TargetAddr> {
166 Ok(TargetAddr::Ip(*self))
167 }
168}
169
170impl ToTargetAddr for SocketAddrV4 {
171 fn to_target_addr(&self) -> io::Result<TargetAddr> {
172 SocketAddr::V4(*self).to_target_addr()
173 }
174}
175
176impl ToTargetAddr for SocketAddrV6 {
177 fn to_target_addr(&self) -> io::Result<TargetAddr> {
178 SocketAddr::V6(*self).to_target_addr()
179 }
180}
181
182impl ToTargetAddr for (IpAddr, u16) {
183 fn to_target_addr(&self) -> io::Result<TargetAddr> {
184 match self.0 {
185 IpAddr::V4(ipv4_addr) => (ipv4_addr, self.1).to_target_addr(),
186 IpAddr::V6(ipv6_addr) => (ipv6_addr, self.1).to_target_addr(),
187 }
188 }
189}
190
191impl ToTargetAddr for (Ipv4Addr, u16) {
192 fn to_target_addr(&self) -> io::Result<TargetAddr> {
193 SocketAddrV4::new(self.0, self.1).to_target_addr()
194 }
195}
196
197impl ToTargetAddr for (Ipv6Addr, u16) {
198 fn to_target_addr(&self) -> io::Result<TargetAddr> {
199 SocketAddrV6::new(self.0, self.1, 0, 0).to_target_addr()
200 }
201}
202
/// Raw address bytes as read from the wire, before the port is attached.
#[derive(Debug)]
pub enum Addr {
    /// Four IPv4 octets.
    V4([u8; 4]),
    /// Sixteen IPv6 octets.
    V6([u8; 16]),
    /// A UTF-8 decoded domain name.
    Domain(String),
}
209
210pub async fn read_address<T: AsyncRead + Unpin>(
212 stream: &mut T,
213 atyp: u8,
214) -> anyhow::Result<TargetAddr> {
215 let addr = match atyp {
216 consts::SOCKS5_ADDR_TYPE_IPV4 => {
217 debug!("Address type `IPv4`");
218 Addr::V4(read_exact!(stream, [0u8; 4]).context(AddrError::IPv4Unreadable)?)
219 }
220 consts::SOCKS5_ADDR_TYPE_IPV6 => {
221 debug!("Address type `IPv6`");
222 Addr::V6(read_exact!(stream, [0u8; 16]).context(AddrError::IPv6Unreadable)?)
223 }
224 consts::SOCKS5_ADDR_TYPE_DOMAIN_NAME => {
225 debug!("Address type `domain`");
226 let len = read_exact!(stream, [0]).context(AddrError::DomainLenUnreadable)?[0];
227 let domain = read_exact!(stream, vec![0u8; len as usize])
228 .context(AddrError::DomainContentUnreadable)?;
229 let domain = String::from_utf8(domain).context(AddrError::Utf8)?;
231
232 Addr::Domain(domain)
233 }
234 _ => return Err(anyhow::anyhow!(AddrError::IncorrectAddressType)),
235 };
236
237 let port = read_exact!(stream, [0u8; 2]).context(AddrError::PortNumberUnreadable)?;
239 let port = (port[0] as u16) << 8 | port[1] as u16;
241
242 let addr: TargetAddr = match addr {
244 Addr::V4([a, b, c, d]) => (Ipv4Addr::new(a, b, c, d), port).to_target_addr()?,
245 Addr::V6(x) => (Ipv6Addr::from(x), port).to_target_addr()?,
246 Addr::Domain(domain) => TargetAddr::Domain(domain, port),
247 };
248
249 Ok(addr)
250}