1use std::{
4 fmt::{self, Display, Formatter},
5 io,
6 net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6},
7};
8
9use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
10
/// Wire-level byte values of the SOCKS5 protocol used by this module.
pub mod constants {
    /// Protocol version byte expected at the start of every client message.
    pub const VERSION: u8 = 0x05;

    /// Authentication method identifier: no authentication.
    pub const METHOD_NO_AUTHENTICATION: u8 = 0x00;

    /// Request command identifier: CONNECT.
    pub const COMMAND_CONNECT: u8 = 0x01;

    /// Address type tag: IPv4 address.
    pub const ATYP_IPV4: u8 = 0x01;
    /// Address type tag: length-prefixed domain name.
    pub const ATYP_DOMAIN_NAME: u8 = 0x03;
    /// Address type tag: IPv6 address.
    pub const ATYP_IPV6: u8 = 0x04;
}
26
/// A destination address as carried in a SOCKS5 message.
///
/// Each variant pairs the host part with a port number.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Socks5Addr {
    /// An IPv4 address and port.
    Ipv4(SocketAddrV4),
    /// An IPv6 address and port.
    Ipv6(SocketAddrV6),
    /// A domain name and port, stored as `(name, port)`.
    DomainName((String, u16)),
}
33
34impl Socks5Addr {
35 pub async fn construct<R>(reader: &mut R) -> io::Result<Self>
37 where
38 R: AsyncRead + Unpin + ?Sized,
39 {
40 let mut buf = [0u8];
41 reader.read_exact(&mut buf).await?;
42 let atyp = buf[0];
43
44 match atyp {
45 constants::ATYP_IPV4 => {
46 let mut buf = [0u8; 6];
47 reader.read_exact(&mut buf).await?;
48
49 let ipv4_addr = Ipv4Addr::new(buf[0], buf[1], buf[2], buf[3]);
50 let port = u16::from_be_bytes([buf[4], buf[5]]);
51
52 Ok(Socks5Addr::Ipv4(SocketAddrV4::new(ipv4_addr, port)))
53 }
54 constants::ATYP_DOMAIN_NAME => {
55 let mut buf = [0u8];
56 reader.read_exact(&mut buf).await?;
57 let len = buf[0] as usize;
58
59 let mut buf = vec![0u8; len + 2];
60 reader.read_exact(&mut buf).await?;
61
62 let domain_name_bytes = buf[..len].to_vec();
63 let domain_name = match String::from_utf8(domain_name_bytes) {
64 Ok(x) => x,
65 Err(_) => return Err(io::Error::new(io::ErrorKind::Other, Error::DomainName)),
66 };
67
68 let port = u16::from_be_bytes([buf[len], buf[len + 1]]);
69
70 Ok(Socks5Addr::DomainName((domain_name, port)))
71 }
72 constants::ATYP_IPV6 => {
73 let mut buf = [0u8; 18];
74 reader.read_exact(&mut buf).await?;
75
76 let a = u16::from_be_bytes([buf[0], buf[1]]);
77 let b = u16::from_be_bytes([buf[2], buf[3]]);
78 let c = u16::from_be_bytes([buf[4], buf[5]]);
79 let d = u16::from_be_bytes([buf[6], buf[7]]);
80 let e = u16::from_be_bytes([buf[8], buf[9]]);
81 let f = u16::from_be_bytes([buf[10], buf[11]]);
82 let g = u16::from_be_bytes([buf[12], buf[13]]);
83 let h = u16::from_be_bytes([buf[14], buf[15]]);
84
85 let ipv6_addr = Ipv6Addr::new(a, b, c, d, e, f, g, h);
86 let port = u16::from_be_bytes([buf[16], buf[17]]);
87
88 Ok(Socks5Addr::Ipv6(SocketAddrV6::new(ipv6_addr, port, 0, 0)))
89 }
90 x => Err(io::Error::new(
91 io::ErrorKind::Other,
92 format!("{} is a invalid address type", x),
93 )),
94 }
95 }
96
97 pub fn get_raw_parts(&self) -> Vec<u8> {
99 let mut addr = Vec::<u8>::new();
100
101 match self {
102 Socks5Addr::Ipv4(v4) => {
103 addr.push(constants::ATYP_IPV4);
104 addr.append(&mut v4.ip().octets().to_vec());
105 addr.append(&mut v4.port().to_be_bytes().to_vec());
106 }
107 Socks5Addr::Ipv6(v6) => {
108 addr.push(constants::ATYP_IPV6);
109 addr.append(&mut v6.ip().octets().to_vec());
110 addr.append(&mut v6.port().to_be_bytes().to_vec());
111 }
112 Socks5Addr::DomainName((domain_name, port)) => {
113 addr.push(constants::ATYP_DOMAIN_NAME);
114 addr.push(domain_name.len() as u8);
115 addr.append(&mut domain_name.clone().into_bytes());
116 addr.append(&mut port.to_be_bytes().to_vec());
117 }
118 };
119
120 addr
121 }
122}
123
124impl Display for Socks5Addr {
125 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
126 match self {
127 Socks5Addr::Ipv4(v4) => write!(f, "{}", v4.to_string()),
128 Socks5Addr::Ipv6(v6) => write!(f, "{}", v6.to_string()),
129 Socks5Addr::DomainName((host, port)) => write!(f, "{}:{}", host, port),
130 }
131 }
132}
133
/// Protocol-level failures reported while handling a SOCKS5 client.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Error {
    /// The client announced a protocol version other than the supported one;
    /// carries the version byte that was received.
    Version(u8),

    /// A later message carried a version byte (`now`) different from the one
    /// seen earlier in the same exchange (`before`).
    VersionInconsistent { now: u8, before: u8 },

    /// The client did not offer the NO AUTHENTICATION method.
    Method,

    /// The client requested a command other than CONNECT; carries the
    /// command byte that was received.
    Command(u8),

    /// The requested domain name was not valid UTF-8.
    DomainName,
}
152
153impl Display for Error {
154 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
155 match self {
156 Error::Version(v) => write!(f, "{} is the unsupported socks version", v),
157 Error::VersionInconsistent { now, before } => {
158 write!(
159 f,
160 "socks version number({}) is inconsistent with before({})",
161 now, before
162 )
163 }
164 Error::Method => write!(f, "only support the NO AUTHENTICATION method"),
165 Error::Command(cmd) => write!(f, "only support the CONNECT method, request {}", cmd),
166 Error::DomainName => write!(f, "the requested domain name is not a string"),
167 }
168 }
169}
170
171impl std::error::Error for Error {}
172
173pub async fn handshake<S>(stream: &mut S) -> io::Result<Socks5Addr>
175where
176 S: AsyncRead + AsyncWrite + Unpin + ?Sized,
177{
178 let mut buf = [0u8; 2];
180 stream.read_exact(&mut buf).await?;
181
182 let ver = buf[0];
183 if ver != constants::VERSION {
184 return Err(io::Error::new(io::ErrorKind::Other, Error::Version(ver)));
185 }
186
187 let mut methods = vec![0u8; buf[1] as usize];
188 stream.read_exact(&mut methods).await?;
189
190 if !methods
191 .iter()
192 .any(|&x| x == constants::METHOD_NO_AUTHENTICATION)
193 {
194 return Err(io::Error::new(io::ErrorKind::Other, Error::Method));
195 }
196
197 let rsp = [constants::VERSION, constants::METHOD_NO_AUTHENTICATION];
198 stream.write_all(&rsp).await?;
199
200 let mut buf = [0u8; 3];
202 stream.read_exact(&mut buf).await?;
203
204 let ver = buf[0];
205 if ver != constants::VERSION {
206 return Err(io::Error::new(
207 io::ErrorKind::Other,
208 Error::VersionInconsistent {
209 now: ver,
210 before: 0x05,
211 },
212 ));
213 }
214
215 let cmd = buf[1];
216 if cmd != constants::COMMAND_CONNECT {
217 return Err(io::Error::new(io::ErrorKind::Other, Error::Command(cmd)));
218 }
219
220 let addr = Socks5Addr::construct(stream).await?;
221
222 let rsp = [
223 constants::VERSION,
224 0x00,
225 0x00,
226 constants::ATYP_IPV4,
227 0x00,
228 0x00,
229 0x00,
230 0x00,
231 0x00,
232 0x00,
233 ];
234 stream.write_all(&rsp).await?;
235
236 Ok(addr)
237}