rd_std/socks5/
protocol.rs

1use std::{convert::TryInto, io};
2use thiserror::Error;
3use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
4
5use super::common::Address;
6
7#[derive(Debug, Error)]
8pub enum Error {
9    #[error("Invalid version: {0}")]
10    InvalidVersion(u8),
11    #[error("Too many methods")]
12    TooManyMethods,
13    #[error("Invalid handshake")]
14    InvalidHandshake,
15    #[error("Invalid command: {0}")]
16    InvalidCommand(u8),
17    #[error("Invalid command reply: {0}")]
18    InvalidCommandReply(u8),
19    #[error("Command reply with error: {0:?}")]
20    CommandReply(CommandReply),
21    #[error("IO error: {0:?}")]
22    Io(#[from] io::Error),
23}
24pub type Result<T, E = Error> = ::std::result::Result<T, E>;
25
26#[derive(Debug)]
27pub enum Version {
28    V5,
29}
30impl Version {
31    pub async fn read(mut reader: impl AsyncRead + Unpin) -> Result<Version> {
32        let version = &mut [0u8];
33        reader.read_exact(version).await?;
34        match version[0] {
35            5 => Ok(Version::V5),
36            other => Err(Error::InvalidVersion(other)),
37        }
38    }
39    pub async fn write(&self, mut writer: impl AsyncWrite + Unpin) -> Result<()> {
40        let v = match self {
41            Version::V5 => 5u8,
42        };
43        writer.write_all(&[v]).await?;
44        Ok(())
45    }
46}
47
/// Authentication method identifier exchanged during the SOCKS5 handshake.
///
/// Converts losslessly to and from its one-byte wire value: every byte maps to
/// exactly one variant via [`From<u8>`], and converting that variant back
/// yields the same byte.
#[derive(Debug, Eq, PartialEq, Clone, Copy)]
pub enum AuthMethod {
    /// No authentication (`0x00`).
    Noauth,
    /// GSSAPI (`0x01`).
    Gssapi,
    /// Username/password (`0x02`).
    UsernamePassword,
    /// No offered method is acceptable (`0xff`).
    NoAcceptableMethod,
    /// Any other method byte, kept verbatim.
    Other(u8),
}

impl From<u8> for AuthMethod {
    /// Decodes a method byte; bytes without a dedicated variant become
    /// [`AuthMethod::Other`].
    fn from(n: u8) -> Self {
        match n {
            0x00 => AuthMethod::Noauth,
            0x01 => AuthMethod::Gssapi,
            0x02 => AuthMethod::UsernamePassword,
            0xff => AuthMethod::NoAcceptableMethod,
            other => AuthMethod::Other(other),
        }
    }
}

// Implemented as `From` rather than `Into`: the standard blanket impl still
// provides `Into<u8> for AuthMethod`, so existing `.into()` call sites keep
// working, and `u8::from(method)` becomes available as well.
impl From<AuthMethod> for u8 {
    /// Encodes the method as its wire byte.
    fn from(method: AuthMethod) -> u8 {
        match method {
            AuthMethod::Noauth => 0x00,
            AuthMethod::Gssapi => 0x01,
            AuthMethod::UsernamePassword => 0x02,
            AuthMethod::NoAcceptableMethod => 0xff,
            AuthMethod::Other(other) => other,
        }
    }
}
80
81#[derive(Debug, Eq, PartialEq, Clone)]
82pub struct AuthRequest(Vec<AuthMethod>);
83
84impl AuthRequest {
85    pub fn new(methods: impl Into<Vec<AuthMethod>>) -> AuthRequest {
86        AuthRequest(methods.into())
87    }
88    pub async fn read(mut reader: impl AsyncRead + Unpin) -> Result<AuthRequest> {
89        let count = &mut [0u8];
90        reader.read_exact(count).await?;
91        let mut methods = vec![0u8; count[0] as usize];
92        reader.read_exact(&mut methods).await?;
93
94        Ok(AuthRequest(methods.into_iter().map(Into::into).collect()))
95    }
96    pub async fn write(&self, mut writer: impl AsyncWrite + Unpin) -> Result<()> {
97        let count = self.0.len();
98        if count > 255 {
99            return Err(Error::TooManyMethods);
100        }
101
102        writer.write_all(&[count as u8]).await?;
103        writer
104            .write_all(
105                &self
106                    .0
107                    .iter()
108                    .map(|i| Into::<u8>::into(*i))
109                    .collect::<Vec<_>>(),
110            )
111            .await?;
112
113        Ok(())
114    }
115    pub fn select_from(&self, auth: &[AuthMethod]) -> AuthMethod {
116        self.0
117            .iter()
118            .enumerate()
119            .find(|(_, m)| auth.contains(*m))
120            .map(|(v, _)| AuthMethod::from(v as u8))
121            .unwrap_or(AuthMethod::NoAcceptableMethod)
122    }
123}
124
125#[derive(Debug, Eq, PartialEq, Clone)]
126pub struct AuthResponse(AuthMethod);
127
128impl AuthResponse {
129    pub fn new(method: AuthMethod) -> AuthResponse {
130        AuthResponse(method)
131    }
132    pub async fn read(mut reader: impl AsyncRead + Unpin) -> Result<AuthResponse> {
133        let method = &mut [0u8];
134        reader.read_exact(method).await?;
135        Ok(AuthResponse(method[0].into()))
136    }
137    pub async fn write(&self, mut writer: impl AsyncWrite + Unpin) -> Result<()> {
138        writer.write_all(&[self.0.into()]).await?;
139        Ok(())
140    }
141    pub fn method(&self) -> AuthMethod {
142        self.0
143    }
144}
145
/// Command carried by a SOCKS5 command request.
///
/// Derives `Clone`, `Copy`, `PartialEq` and `Eq` so values can be compared and
/// passed by value, as with the other fieldless protocol enums in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Command {
    /// Connect command, encoded as command byte `1`.
    Connect,
    /// Bind command, encoded as command byte `2`.
    Bind,
    /// UDP associate command, encoded as command byte `3`.
    UdpAssociate,
}
152#[derive(Debug)]
153pub struct CommandRequest {
154    pub command: Command,
155    pub address: Address,
156}
157
158impl CommandRequest {
159    pub fn connect(address: Address) -> CommandRequest {
160        CommandRequest {
161            command: Command::Connect,
162            address,
163        }
164    }
165    pub fn udp_associate(address: Address) -> CommandRequest {
166        CommandRequest {
167            command: Command::UdpAssociate,
168            address,
169        }
170    }
171    pub async fn read(mut reader: impl AsyncRead + Unpin) -> Result<CommandRequest> {
172        let buf = &mut [0u8; 3];
173        reader.read_exact(buf).await?;
174        if buf[0] != 5 {
175            return Err(Error::InvalidVersion(buf[0]));
176        }
177        if buf[2] != 0 {
178            return Err(Error::InvalidHandshake);
179        }
180        let cmd = match buf[1] {
181            1 => Command::Connect,
182            2 => Command::Bind,
183            3 => Command::UdpAssociate,
184            _ => return Err(Error::InvalidCommand(buf[1])),
185        };
186
187        let address = Address::read(reader).await?;
188
189        Ok(CommandRequest {
190            command: cmd,
191            address,
192        })
193    }
194    pub async fn write(&self, mut writer: impl AsyncWrite + Unpin) -> Result<()> {
195        let cmd = match self.command {
196            Command::Connect => 1u8,
197            Command::Bind => 2,
198            Command::UdpAssociate => 3,
199        };
200        writer.write_all(&[0x05, cmd, 0x00]).await?;
201        self.address.write(writer).await?;
202        Ok(())
203    }
204}
205
206#[derive(Debug, PartialEq, PartialOrd)]
207pub enum CommandReply {
208    Succeeded,
209    GeneralSocksServerFailure,
210    ConnectionNotAllowedByRuleset,
211    NetworkUnreachable,
212    HostUnreachable,
213    ConnectionRefused,
214    TtlExpired,
215    CommandNotSupported,
216    AddressTypeNotSupported,
217}
218
219impl CommandReply {
220    pub fn from_u8(n: u8) -> Result<CommandReply> {
221        Ok(match n {
222            0 => CommandReply::Succeeded,
223            1 => CommandReply::GeneralSocksServerFailure,
224            2 => CommandReply::ConnectionNotAllowedByRuleset,
225            3 => CommandReply::NetworkUnreachable,
226            4 => CommandReply::HostUnreachable,
227            5 => CommandReply::ConnectionRefused,
228            6 => CommandReply::TtlExpired,
229            7 => CommandReply::CommandNotSupported,
230            8 => CommandReply::AddressTypeNotSupported,
231            _ => return Err(Error::InvalidCommandReply(n)),
232        })
233    }
234    pub fn to_u8(&self) -> u8 {
235        match self {
236            CommandReply::Succeeded => 0,
237            CommandReply::GeneralSocksServerFailure => 1,
238            CommandReply::ConnectionNotAllowedByRuleset => 2,
239            CommandReply::NetworkUnreachable => 3,
240            CommandReply::HostUnreachable => 4,
241            CommandReply::ConnectionRefused => 5,
242            CommandReply::TtlExpired => 6,
243            CommandReply::CommandNotSupported => 7,
244            CommandReply::AddressTypeNotSupported => 8,
245        }
246    }
247}
248
249#[derive(Debug)]
250pub struct CommandResponse {
251    pub reply: CommandReply,
252    pub address: Address,
253}
254
255impl CommandResponse {
256    pub fn success(address: Address) -> CommandResponse {
257        CommandResponse {
258            reply: CommandReply::Succeeded,
259            address,
260        }
261    }
262    pub fn reply_error(reply: CommandReply) -> CommandResponse {
263        CommandResponse {
264            reply,
265            address: Default::default(),
266        }
267    }
268    pub fn error(e: impl TryInto<io::Error>) -> CommandResponse {
269        match e.try_into() {
270            Ok(v) => {
271                use io::ErrorKind;
272                let reply = match v.kind() {
273                    ErrorKind::ConnectionRefused => CommandReply::ConnectionRefused,
274                    _ => CommandReply::GeneralSocksServerFailure,
275                };
276                CommandResponse {
277                    reply,
278                    address: Default::default(),
279                }
280            }
281            Err(_) => CommandResponse {
282                reply: CommandReply::GeneralSocksServerFailure,
283                address: Default::default(),
284            },
285        }
286    }
287    pub async fn read(mut reader: impl AsyncRead + Unpin) -> Result<CommandResponse> {
288        let buf = &mut [0u8; 3];
289        reader.read_exact(buf).await?;
290        if buf[0] != 5 {
291            return Err(Error::InvalidVersion(buf[0]));
292        }
293        if buf[2] != 0 {
294            return Err(Error::InvalidHandshake);
295        }
296        let reply = CommandReply::from_u8(buf[1])?;
297
298        let address = Address::read(reader).await?;
299
300        if reply != CommandReply::Succeeded {
301            return Err(Error::CommandReply(reply));
302        }
303
304        Ok(CommandResponse { reply, address })
305    }
306    pub async fn write(&self, mut writer: impl AsyncWrite + Unpin) -> Result<()> {
307        writer.write_all(&[0x05, self.reply.to_u8(), 0x00]).await?;
308        self.address.write(writer).await?;
309        Ok(())
310    }
311}