1use crate::consts;
2use crate::consts::SOCKS5_ADDR_TYPE_IPV4;
3use crate::read_exact;
4use crate::ReplyError;
5use std::fmt;
6use std::io;
7use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
8use std::vec::IntoIter;
9use tokio::io::{AsyncRead, AsyncReadExt};
10use tokio::net::lookup_host;
11
12#[derive(thiserror::Error, Debug)]
14pub enum AddrError {
15 #[error("DNS Resolution failed: {0}")]
16 DNSResolutionFailed(#[source] io::Error),
17 #[error("DNS returned no appropriate records")]
18 NoDNSRecords,
19 #[error("Domain length {0} exceeded maximum")]
20 DomainLenTooLong(usize),
21 #[error("Can't read IPv4: {0}")]
22 IPv4Unreadable(#[source] io::Error),
23 #[error("Can't read IPv6: {0}")]
24 IPv6Unreadable(#[source] io::Error),
25 #[error("Can't read port number: {0}")]
26 PortNumberUnreadable(#[source] io::Error),
27 #[error("Can't read domain len: {0}")]
28 DomainLenUnreadable(#[source] io::Error),
29 #[error("Can't read domain content: {0}")]
30 DomainContentUnreadable(#[source] io::Error),
31 #[error("Can't convert address: {0}")]
32 AddrConversionFailed(#[source] io::Error),
33 #[error("Malformed UTF-8")]
34 Utf8(#[source] std::string::FromUtf8Error),
35 #[error("Unknown address type")]
36 IncorrectAddressType,
37}
38
39impl AddrError {
40 pub fn to_reply_error(&self) -> ReplyError {
41 match self {
42 AddrError::IncorrectAddressType => ReplyError::AddressTypeNotSupported,
43 _ => ReplyError::ConnectionRefused,
44 }
45 }
46}
47
/// Destination of a proxied connection: a concrete socket address, or a domain
/// name plus port that must be resolved (see `TargetAddr::resolve_dns`) before use.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TargetAddr {
    /// An IPv4 or IPv6 socket address.
    Ip(SocketAddr),
    /// A domain name together with its port.
    Domain(String, u16),
}
59
60impl TargetAddr {
61 pub async fn resolve_dns(self) -> Result<TargetAddr, AddrError> {
62 match self {
63 TargetAddr::Ip(ip) => Ok(TargetAddr::Ip(ip)),
64 TargetAddr::Domain(domain, port) => {
65 debug!("Attempt to DNS resolve the domain {}...", &domain);
66
67 let socket_addr = lookup_host((&domain[..], port))
68 .await
69 .map_err(|err| AddrError::DNSResolutionFailed(err))?
70 .next()
71 .ok_or(AddrError::NoDNSRecords)?;
72 debug!("domain name resolved to {}", socket_addr);
73
74 Ok(TargetAddr::Ip(socket_addr))
76 }
77 }
78 }
79
80 pub fn is_ip(&self) -> bool {
81 match self {
82 TargetAddr::Ip(_) => true,
83 _ => false,
84 }
85 }
86
87 pub fn is_domain(&self) -> bool {
88 !self.is_ip()
89 }
90
91 pub fn to_be_bytes(&self) -> Result<Vec<u8>, AddrError> {
92 let mut buf = vec![];
93 match self {
94 TargetAddr::Ip(SocketAddr::V4(addr)) => {
95 debug!("TargetAddr::IpV4");
96
97 buf.extend_from_slice(&[SOCKS5_ADDR_TYPE_IPV4]);
98
99 debug!("addr ip {:?}", (*addr.ip()).octets());
100 buf.extend_from_slice(&(addr.ip()).octets()); buf.extend_from_slice(&addr.port().to_be_bytes()); }
103 TargetAddr::Ip(SocketAddr::V6(addr)) => {
104 debug!("TargetAddr::IpV6");
105 buf.extend_from_slice(&[consts::SOCKS5_ADDR_TYPE_IPV6]);
106
107 debug!("addr ip {:?}", (*addr.ip()).octets());
108 buf.extend_from_slice(&(addr.ip()).octets()); buf.extend_from_slice(&addr.port().to_be_bytes()); }
111 TargetAddr::Domain(ref domain, port) => {
112 debug!("TargetAddr::Domain");
113 if domain.len() > u8::max_value() as usize {
114 return Err(AddrError::DomainLenTooLong(domain.len()));
115 }
116 buf.extend_from_slice(&[consts::SOCKS5_ADDR_TYPE_DOMAIN_NAME, domain.len() as u8]);
117 buf.extend_from_slice(domain.as_bytes()); buf.extend_from_slice(&port.to_be_bytes());
119 }
121 }
122 Ok(buf)
123 }
124
125 pub fn into_string_and_port(self) -> (String, u16) {
126 match self {
127 TargetAddr::Ip(socket_addr) => (socket_addr.ip().to_string(), socket_addr.port()),
128 TargetAddr::Domain(domain, port) => (domain, port),
129 }
130 }
131}
132
133impl std::net::ToSocketAddrs for TargetAddr {
136 type Iter = IntoIter<SocketAddr>;
137
138 fn to_socket_addrs(&self) -> io::Result<IntoIter<SocketAddr>> {
139 match *self {
140 TargetAddr::Ip(addr) => Ok(vec![addr].into_iter()),
141 TargetAddr::Domain(_, _) => Err(io::Error::new(
142 io::ErrorKind::Other,
143 "Domain name has to be explicitly resolved, please use TargetAddr::resolve_dns().",
144 )),
145 }
146 }
147}
148
149impl fmt::Display for TargetAddr {
150 #[inline]
151 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
152 match *self {
153 TargetAddr::Ip(ref addr) => write!(f, "{}", addr),
154 TargetAddr::Domain(ref addr, ref port) => write!(f, "{}:{}", addr, port),
155 }
156 }
157}
158
159pub trait ToTargetAddr {
161 fn to_target_addr(&self) -> io::Result<TargetAddr>;
163}
164
165impl<'a> ToTargetAddr for (&'a str, u16) {
166 fn to_target_addr(&self) -> io::Result<TargetAddr> {
167 if let Ok(addr) = self.0.parse::<Ipv4Addr>() {
169 return (addr, self.1).to_target_addr();
170 }
171
172 if let Ok(addr) = self.0.parse::<Ipv6Addr>() {
173 return (addr, self.1).to_target_addr();
174 }
175
176 Ok(TargetAddr::Domain(self.0.to_owned(), self.1))
177 }
178}
179
180impl ToTargetAddr for SocketAddr {
181 fn to_target_addr(&self) -> io::Result<TargetAddr> {
182 Ok(TargetAddr::Ip(*self))
183 }
184}
185
186impl ToTargetAddr for SocketAddrV4 {
187 fn to_target_addr(&self) -> io::Result<TargetAddr> {
188 SocketAddr::V4(*self).to_target_addr()
189 }
190}
191
192impl ToTargetAddr for SocketAddrV6 {
193 fn to_target_addr(&self) -> io::Result<TargetAddr> {
194 SocketAddr::V6(*self).to_target_addr()
195 }
196}
197
198impl ToTargetAddr for (IpAddr, u16) {
199 fn to_target_addr(&self) -> io::Result<TargetAddr> {
200 match self.0 {
201 IpAddr::V4(ipv4_addr) => (ipv4_addr, self.1).to_target_addr(),
202 IpAddr::V6(ipv6_addr) => (ipv6_addr, self.1).to_target_addr(),
203 }
204 }
205}
206
207impl ToTargetAddr for (Ipv4Addr, u16) {
208 fn to_target_addr(&self) -> io::Result<TargetAddr> {
209 SocketAddrV4::new(self.0, self.1).to_target_addr()
210 }
211}
212
213impl ToTargetAddr for (Ipv6Addr, u16) {
214 fn to_target_addr(&self) -> io::Result<TargetAddr> {
215 SocketAddrV6::new(self.0, self.1, 0, 0).to_target_addr()
216 }
217}
218
219impl ToTargetAddr for TargetAddr {
220 fn to_target_addr(&self) -> io::Result<TargetAddr> {
221 Ok(self.clone())
222 }
223}
224
/// Address bytes as read from the stream, before the port is attached.
#[derive(Debug)]
pub enum Addr {
    /// Four IPv4 octets.
    V4([u8; 4]),
    /// Sixteen IPv6 bytes.
    V6([u8; 16]),
    /// A domain name already validated as UTF-8.
    Domain(String),
}
231
232pub async fn read_address<T: AsyncRead + Unpin>(
234 stream: &mut T,
235 atyp: u8,
236) -> Result<TargetAddr, AddrError> {
237 let addr = match atyp {
238 consts::SOCKS5_ADDR_TYPE_IPV4 => {
239 debug!("Address type `IPv4`");
240 Addr::V4(read_exact!(stream, [0u8; 4]).map_err(|err| AddrError::IPv4Unreadable(err))?)
241 }
242 consts::SOCKS5_ADDR_TYPE_IPV6 => {
243 debug!("Address type `IPv6`");
244 Addr::V6(read_exact!(stream, [0u8; 16]).map_err(|err| AddrError::IPv6Unreadable(err))?)
245 }
246 consts::SOCKS5_ADDR_TYPE_DOMAIN_NAME => {
247 debug!("Address type `domain`");
248 let len =
249 read_exact!(stream, [0]).map_err(|err| AddrError::DomainLenUnreadable(err))?[0];
250 let domain = read_exact!(stream, vec![0u8; len as usize])
251 .map_err(|err| AddrError::DomainContentUnreadable(err))?;
252 let domain = String::from_utf8(domain).map_err(|err| AddrError::Utf8(err))?;
254
255 Addr::Domain(domain)
256 }
257 _ => return Err(AddrError::IncorrectAddressType),
258 };
259
260 let port = read_exact!(stream, [0u8; 2]).map_err(|err| AddrError::PortNumberUnreadable(err))?;
262 let port = (port[0] as u16) << 8 | port[1] as u16;
264
265 let addr: TargetAddr = match addr {
267 Addr::V4([a, b, c, d]) => (Ipv4Addr::new(a, b, c, d), port)
268 .to_target_addr()
269 .map_err(|err| AddrError::AddrConversionFailed(err))?,
270 Addr::V6(x) => (Ipv6Addr::from(x), port)
271 .to_target_addr()
272 .map_err(|err| AddrError::AddrConversionFailed(err))?,
273 Addr::Domain(domain) => TargetAddr::Domain(domain, port),
274 };
275
276 Ok(addr)
277}