1use std::{convert::TryInto, io};
2use thiserror::Error;
3use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
4
5use super::common::Address;
6
7#[derive(Debug, Error)]
8pub enum Error {
9 #[error("Invalid version: {0}")]
10 InvalidVersion(u8),
11 #[error("Too many methods")]
12 TooManyMethods,
13 #[error("Invalid handshake")]
14 InvalidHandshake,
15 #[error("Invalid command: {0}")]
16 InvalidCommand(u8),
17 #[error("Invalid command reply: {0}")]
18 InvalidCommandReply(u8),
19 #[error("Command reply with error: {0:?}")]
20 CommandReply(CommandReply),
21 #[error("IO error: {0:?}")]
22 Io(#[from] io::Error),
23}
24pub type Result<T, E = Error> = ::std::result::Result<T, E>;
25
26#[derive(Debug)]
27pub enum Version {
28 V5,
29}
30impl Version {
31 pub async fn read(mut reader: impl AsyncRead + Unpin) -> Result<Version> {
32 let version = &mut [0u8];
33 reader.read_exact(version).await?;
34 match version[0] {
35 5 => Ok(Version::V5),
36 other => Err(Error::InvalidVersion(other)),
37 }
38 }
39 pub async fn write(&self, mut writer: impl AsyncWrite + Unpin) -> Result<()> {
40 let v = match self {
41 Version::V5 => 5u8,
42 };
43 writer.write_all(&[v]).await?;
44 Ok(())
45 }
46}
47
/// An authentication method code used during SOCKS5 method negotiation.
///
/// Conversion to and from `u8` is lossless for every byte value. Codes without
/// a named variant are kept in [`AuthMethod::Other`].
#[derive(Debug, Eq, PartialEq, Clone, Copy)]
pub enum AuthMethod {
    /// Code `0x00`: no authentication.
    Noauth,
    /// Code `0x01`: GSSAPI.
    Gssapi,
    /// Code `0x02`: username/password.
    UsernamePassword,
    /// Code `0xff`: the server accepts none of the offered methods.
    NoAcceptableMethod,
    /// Any other code, preserved verbatim.
    Other(u8),
}

impl From<u8> for AuthMethod {
    fn from(n: u8) -> Self {
        match n {
            0x00 => AuthMethod::Noauth,
            0x01 => AuthMethod::Gssapi,
            0x02 => AuthMethod::UsernamePassword,
            0xff => AuthMethod::NoAcceptableMethod,
            other => AuthMethod::Other(other),
        }
    }
}

// Implement `From` rather than `Into`: the standard blanket impl still provides
// `Into<u8> for AuthMethod`, and callers can now also write `u8::from(method)`.
impl From<AuthMethod> for u8 {
    fn from(method: AuthMethod) -> u8 {
        match method {
            AuthMethod::Noauth => 0x00,
            AuthMethod::Gssapi => 0x01,
            AuthMethod::UsernamePassword => 0x02,
            AuthMethod::NoAcceptableMethod => 0xff,
            AuthMethod::Other(other) => other,
        }
    }
}
80
81#[derive(Debug, Eq, PartialEq, Clone)]
82pub struct AuthRequest(Vec<AuthMethod>);
83
84impl AuthRequest {
85 pub fn new(methods: impl Into<Vec<AuthMethod>>) -> AuthRequest {
86 AuthRequest(methods.into())
87 }
88 pub async fn read(mut reader: impl AsyncRead + Unpin) -> Result<AuthRequest> {
89 let count = &mut [0u8];
90 reader.read_exact(count).await?;
91 let mut methods = vec![0u8; count[0] as usize];
92 reader.read_exact(&mut methods).await?;
93
94 Ok(AuthRequest(methods.into_iter().map(Into::into).collect()))
95 }
96 pub async fn write(&self, mut writer: impl AsyncWrite + Unpin) -> Result<()> {
97 let count = self.0.len();
98 if count > 255 {
99 return Err(Error::TooManyMethods);
100 }
101
102 writer.write_all(&[count as u8]).await?;
103 writer
104 .write_all(
105 &self
106 .0
107 .iter()
108 .map(|i| Into::<u8>::into(*i))
109 .collect::<Vec<_>>(),
110 )
111 .await?;
112
113 Ok(())
114 }
115 pub fn select_from(&self, auth: &[AuthMethod]) -> AuthMethod {
116 self.0
117 .iter()
118 .enumerate()
119 .find(|(_, m)| auth.contains(*m))
120 .map(|(v, _)| AuthMethod::from(v as u8))
121 .unwrap_or(AuthMethod::NoAcceptableMethod)
122 }
123}
124
125#[derive(Debug, Eq, PartialEq, Clone)]
126pub struct AuthResponse(AuthMethod);
127
128impl AuthResponse {
129 pub fn new(method: AuthMethod) -> AuthResponse {
130 AuthResponse(method)
131 }
132 pub async fn read(mut reader: impl AsyncRead + Unpin) -> Result<AuthResponse> {
133 let method = &mut [0u8];
134 reader.read_exact(method).await?;
135 Ok(AuthResponse(method[0].into()))
136 }
137 pub async fn write(&self, mut writer: impl AsyncWrite + Unpin) -> Result<()> {
138 writer.write_all(&[self.0.into()]).await?;
139 Ok(())
140 }
141 pub fn method(&self) -> AuthMethod {
142 self.0
143 }
144}
145
/// The operation requested by a SOCKS5 command request.
///
/// `CommandRequest` encodes these on the wire as `1` (`Connect`), `2` (`Bind`)
/// and `3` (`UdpAssociate`). Like `AuthMethod`, this fieldless enum is `Copy`
/// and comparable, so callers can test it with `==` instead of a `match`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Command {
    /// Wire code `1`.
    Connect,
    /// Wire code `2`.
    Bind,
    /// Wire code `3`.
    UdpAssociate,
}
152#[derive(Debug)]
153pub struct CommandRequest {
154 pub command: Command,
155 pub address: Address,
156}
157
158impl CommandRequest {
159 pub fn connect(address: Address) -> CommandRequest {
160 CommandRequest {
161 command: Command::Connect,
162 address,
163 }
164 }
165 pub fn udp_associate(address: Address) -> CommandRequest {
166 CommandRequest {
167 command: Command::UdpAssociate,
168 address,
169 }
170 }
171 pub async fn read(mut reader: impl AsyncRead + Unpin) -> Result<CommandRequest> {
172 let buf = &mut [0u8; 3];
173 reader.read_exact(buf).await?;
174 if buf[0] != 5 {
175 return Err(Error::InvalidVersion(buf[0]));
176 }
177 if buf[2] != 0 {
178 return Err(Error::InvalidHandshake);
179 }
180 let cmd = match buf[1] {
181 1 => Command::Connect,
182 2 => Command::Bind,
183 3 => Command::UdpAssociate,
184 _ => return Err(Error::InvalidCommand(buf[1])),
185 };
186
187 let address = Address::read(reader).await?;
188
189 Ok(CommandRequest {
190 command: cmd,
191 address,
192 })
193 }
194 pub async fn write(&self, mut writer: impl AsyncWrite + Unpin) -> Result<()> {
195 let cmd = match self.command {
196 Command::Connect => 1u8,
197 Command::Bind => 2,
198 Command::UdpAssociate => 3,
199 };
200 writer.write_all(&[0x05, cmd, 0x00]).await?;
201 self.address.write(writer).await?;
202 Ok(())
203 }
204}
205
206#[derive(Debug, PartialEq, PartialOrd)]
207pub enum CommandReply {
208 Succeeded,
209 GeneralSocksServerFailure,
210 ConnectionNotAllowedByRuleset,
211 NetworkUnreachable,
212 HostUnreachable,
213 ConnectionRefused,
214 TtlExpired,
215 CommandNotSupported,
216 AddressTypeNotSupported,
217}
218
219impl CommandReply {
220 pub fn from_u8(n: u8) -> Result<CommandReply> {
221 Ok(match n {
222 0 => CommandReply::Succeeded,
223 1 => CommandReply::GeneralSocksServerFailure,
224 2 => CommandReply::ConnectionNotAllowedByRuleset,
225 3 => CommandReply::NetworkUnreachable,
226 4 => CommandReply::HostUnreachable,
227 5 => CommandReply::ConnectionRefused,
228 6 => CommandReply::TtlExpired,
229 7 => CommandReply::CommandNotSupported,
230 8 => CommandReply::AddressTypeNotSupported,
231 _ => return Err(Error::InvalidCommandReply(n)),
232 })
233 }
234 pub fn to_u8(&self) -> u8 {
235 match self {
236 CommandReply::Succeeded => 0,
237 CommandReply::GeneralSocksServerFailure => 1,
238 CommandReply::ConnectionNotAllowedByRuleset => 2,
239 CommandReply::NetworkUnreachable => 3,
240 CommandReply::HostUnreachable => 4,
241 CommandReply::ConnectionRefused => 5,
242 CommandReply::TtlExpired => 6,
243 CommandReply::CommandNotSupported => 7,
244 CommandReply::AddressTypeNotSupported => 8,
245 }
246 }
247}
248
249#[derive(Debug)]
250pub struct CommandResponse {
251 pub reply: CommandReply,
252 pub address: Address,
253}
254
255impl CommandResponse {
256 pub fn success(address: Address) -> CommandResponse {
257 CommandResponse {
258 reply: CommandReply::Succeeded,
259 address,
260 }
261 }
262 pub fn reply_error(reply: CommandReply) -> CommandResponse {
263 CommandResponse {
264 reply,
265 address: Default::default(),
266 }
267 }
268 pub fn error(e: impl TryInto<io::Error>) -> CommandResponse {
269 match e.try_into() {
270 Ok(v) => {
271 use io::ErrorKind;
272 let reply = match v.kind() {
273 ErrorKind::ConnectionRefused => CommandReply::ConnectionRefused,
274 _ => CommandReply::GeneralSocksServerFailure,
275 };
276 CommandResponse {
277 reply,
278 address: Default::default(),
279 }
280 }
281 Err(_) => CommandResponse {
282 reply: CommandReply::GeneralSocksServerFailure,
283 address: Default::default(),
284 },
285 }
286 }
287 pub async fn read(mut reader: impl AsyncRead + Unpin) -> Result<CommandResponse> {
288 let buf = &mut [0u8; 3];
289 reader.read_exact(buf).await?;
290 if buf[0] != 5 {
291 return Err(Error::InvalidVersion(buf[0]));
292 }
293 if buf[2] != 0 {
294 return Err(Error::InvalidHandshake);
295 }
296 let reply = CommandReply::from_u8(buf[1])?;
297
298 let address = Address::read(reader).await?;
299
300 if reply != CommandReply::Succeeded {
301 return Err(Error::CommandReply(reply));
302 }
303
304 Ok(CommandResponse { reply, address })
305 }
306 pub async fn write(&self, mut writer: impl AsyncWrite + Unpin) -> Result<()> {
307 writer.write_all(&[0x05, self.reply.to_u8(), 0x00]).await?;
308 self.address.write(writer).await?;
309 Ok(())
310 }
311}