1use crate::utils::*;
2
3use std::{
4 io,
5 net::SocketAddr,
6 ops::{Deref, DerefMut},
7};
8use tokio::{
9 io::{AsyncReadExt, AsyncWriteExt, Result},
10 net::{TcpStream, ToSocketAddrs},
11};
12
13pub async fn new(
14 server: impl ToSocketAddrs,
15 dest: &Addr,
16 auth: Option<AuthMethod>,
17) -> Result<TcpStream> {
18 let conn = TcpStream::connect(server).await?;
19 let auth = auth.unwrap_or(AuthMethod::NoAuth);
20
21 let client = PendingHandshake(conn);
22 let client = client.handshake(&auth).await?;
23 let client = client.authenticate(&auth).await?;
24 let client = client.connect(dest).await?;
25
26 Ok(client)
27}
28
29impl_deref!(PendingHandshake, TcpStream);
30impl PendingHandshake {
31 #[inline]
32 async fn handshake(mut self, method: &AuthMethod) -> Result<PendingAuthenticate> {
33 let msg: &[u8] = &[SOCKS_VER, 0x01, method.to_code()];
34 self.write_all(msg).await?;
35 self.flush().await?;
36
37 let mut buffer = [0; 2];
38 self.read_exact(&mut buffer).await?;
39
40 if buffer[0] != SOCKS_VER {
41 return Err(io::Error::new(
42 io::ErrorKind::ConnectionAborted,
43 "unsupported protocol",
44 ));
45 }
46
47 let auth = AuthMethod::from_code(buffer[1])?;
48
49 if let AuthMethod::NoAvailable = auth {
50 Err(io::Error::new(
51 io::ErrorKind::ConnectionRefused,
52 "no supported authenticate method available",
53 ))
54 } else if auth.to_code() != method.to_code() {
55 Err(io::Error::new(
56 io::ErrorKind::ConnectionAborted,
57 "unsupported protocol",
58 ))
59 } else {
60 Ok(PendingAuthenticate(self.0))
61 }
62 }
63}
64
65impl_deref!(PendingAuthenticate, TcpStream);
66impl PendingAuthenticate {
67 #[inline]
68 async fn authenticate(self, auth: &AuthMethod) -> Result<PendingConnect> {
69 match auth {
70 AuthMethod::NoAuth => Ok(PendingConnect(self.0)),
71 _ => Err(io::Error::new(
72 io::ErrorKind::Other,
73 format!("authenticate method {:?} not implemented", &auth),
74 )),
75 }
76 }
77}
78
79impl_deref!(PendingConnect, TcpStream);
80impl PendingConnect {
81 #[inline]
82 async fn connect(mut self, dest: &Addr) -> Result<TcpStream> {
83 let mut buffer = [0u8; 4 + 255 + 2];
84 let mut request = Buffer::from(&mut buffer);
85 request.extend(&[SOCKS_RSV, SOCKS_COMMAND_CONNECT, SOCKS_RSV]);
86
87 parse_dest(&mut request, dest)?;
88
89 self.write_all(request.content()).await?;
90 self.flush().await?;
91
92 let header: &mut [u8] = &mut buffer[..4];
93
94 self.read_exact(header).await?;
95
96 if header[0] != SOCKS_VER || header[02] != SOCKS_RSV {
97 return Err(io::Error::new(
98 io::ErrorKind::ConnectionAborted,
99 "unsupported protocol",
100 ));
101 }
102 if header[1] != SocksError::SUCCESS as u8 {
103 return Err(SocksError::from(header[1]).into());
104 }
105
106 self.extract_address(header[3], &mut buffer).await?;
107
108 Ok(self.0)
109 }
110
111 async fn extract_address(&mut self, addr_type: u8, buffer: &mut [u8]) -> Result<()> {
112 match addr_type {
113 SOCKS_ADDR_IPV4 => self.read_exact(&mut buffer[..4 + 2]).await?,
114 SOCKS_ADDR_IPV6 => self.read_exact(&mut buffer[..16 + 2]).await?,
115 SOCKS_ADDR_DOMAINNAME => {
116 self.read_exact(&mut buffer[..1]).await?;
117 let len = buffer[0] as usize;
118 self.read_exact(&mut buffer[..(len + 2)]).await?
119 }
120 _ => {
121 return Err(io::Error::new(
122 io::ErrorKind::ConnectionAborted,
123 "unsupported address type",
124 ))
125 }
126 };
127 Ok(())
128 }
129}
130
/// Appends `ATYP | ADDR | PORT` for a binary socket address to a request buffer.
///
/// `$buffer` is anything with `push` (for the address-type byte) and `extend` (for byte
/// arrays) methods. `$addr` is a `SocketAddrV4`/`SocketAddrV6`. Its IP octets and its
/// big-endian port are written after `$addr_type`.
macro_rules! write_addr_binary {
    ($buffer:ident, $addr_type:ident, $addr:ident) => {{
        let octets = $addr.ip().octets();
        let port_be = $addr.port().to_be_bytes();
        $buffer.push($addr_type);
        $buffer.extend(&octets);
        $buffer.extend(&port_be);
    }};
}
138
139#[inline]
140fn parse_dest(request: &mut Buffer, dest: &Addr) -> Result<()> {
141 match dest {
142 Addr::SocketAddr(addr) => {
143 match addr {
144 SocketAddr::V4(v4) => write_addr_binary!(request, SOCKS_ADDR_IPV4, v4),
145 SocketAddr::V6(v6) => write_addr_binary!(request, SOCKS_ADDR_IPV6, v6),
146 };
147 }
148 Addr::HostnamePort(hostname_port) => {
149 request.push(SOCKS_ADDR_DOMAINNAME);
150 let mut hostname_port = hostname_port.split(":");
151 let parse_err =
152 io::Error::new(io::ErrorKind::InvalidInput, "bad pattern in hostname:port");
153 let hostname = hostname_port.next();
154 let port = hostname_port.next();
155 let none = hostname_port.next();
156
157 if let (Some(hostname), Some(port), None) = (hostname, port, none) {
158 let hostname = hostname.as_bytes();
159 if hostname.len() > u8::MAX as usize {
160 return Err(io::Error::new(
161 io::ErrorKind::InvalidInput,
162 "hostname too long",
163 ));
164 }
165 request.push(hostname.len() as u8);
166 request.extend(hostname);
167 let port = port.parse::<u16>().map_err(|_| parse_err)?;
168 request.extend(&port.to_be_bytes());
169 } else {
170 return Err(parse_err);
171 }
172 }
173 }
174 Ok(())
175}