1#![deny(unsafe_code)]
50#![warn(missing_docs)]
51
52use std::{
53 fmt,
54 net::{SocketAddrV4, SocketAddrV6},
55};
56
57use futures::{
58 io::{self, AsyncRead, AsyncWrite},
59 prelude::*,
60};
61
62use koibumi_net::{
63 domain::{Domain, SocketDomain},
64 socks::SocketAddr,
65};
66
// Wire-level constants from RFC 1928 (SOCKS Protocol Version 5).

/// Protocol version byte (`VER`) carried by every SOCKS5 message.
const SOCKS_VERSION_5: u8 = 0x05;

// Authentication methods offered during method selection.
/// Method `0x00`: no authentication required.
const SOCKS_NO_AUTHENTICATION_REQUIRED: u8 = 0x00;
/// Method `0x02`: username/password (RFC 1929).
const SOCKS_USERNAME_AND_PASSWORD: u8 = 0x02;

// Request header fields.
/// `CMD` value for CONNECT.
const SOCKS_COMMAND_CONNECT: u8 = 0x01;
/// `RSV` byte, always zero.
const SOCKS_RESERVED: u8 = 0x00;

// Address types (`ATYP`).
/// IPv4 address: 4 octets follow.
const SOCKS_ADDRESS_IPV4: u8 = 0x01;
/// Domain name: a one-byte length followed by that many octets.
const SOCKS_ADDRESS_DOMAIN_NAME: u8 = 0x03;
/// IPv6 address: 16 octets follow.
const SOCKS_ADDRESS_IPV6: u8 = 0x04;

// Reply codes (`REP`). RFC 1928 also defines 0x02 (connection not allowed by
// ruleset) and 0x03 (network unreachable); there are no constants for them here.
/// Request succeeded.
const SOCKS_REPLY_SUCCEEDED: u8 = 0x00;
/// General SOCKS server failure.
const SOCKS_REPLY_GENERAL_SOCKS_SERVER_FAILURE: u8 = 0x01;
/// Host unreachable.
const SOCKS_REPLY_HOST_UNREACHABLE: u8 = 0x04;
/// Connection refused.
const SOCKS_REPLY_CONNECTION_REFUSED: u8 = 0x05;
/// TTL expired.
const SOCKS_REPLY_TTL_EXPIRED: u8 = 0x06;
/// Command not supported.
const SOCKS_REPLY_COMMAND_NOT_SUPPORTED: u8 = 0x07;
/// Address type not supported.
const SOCKS_REPLY_ADDRESS_TYPE_NOT_SUPPORTED: u8 = 0x08;

// Username/password sub-negotiation (RFC 1929).
/// Sub-negotiation version byte.
const SOCKS_SUBNEGOTIATION_VERSION: u8 = 0x01;
/// `STATUS` value meaning success; any other value is treated as failure.
const SOCKS_SUBNEGOTIATION_REPLY_SUCCEEDED: u8 = 0x00;
85
/// Authentication method offered to the SOCKS5 server during method selection.
#[derive(Clone, PartialEq, Eq, Hash)]
enum Auth {
    /// No authentication (method `0x00`).
    None,
    /// Username/password authentication (method `0x02`, RFC 1929).
    Password {
        /// Raw username bytes; `connect_with_auth` rejects lengths outside 1..=255.
        username: Vec<u8>,
        /// Raw password bytes; `connect_with_auth` rejects lengths outside 1..=255.
        password: Vec<u8>,
    },
}

// Hand-written instead of derived so that formatting an `Auth` (for example in a
// log line or `dbg!`) never reveals the password.
impl fmt::Debug for Auth {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::None => f.write_str("None"),
            Self::Password { username, .. } => f
                .debug_struct("Password")
                .field("username", username)
                .field("password", &format_args!("<redacted>"))
                .finish(),
        }
    }
}
99
/// An error returned when establishing a connection through a SOCKS5 server fails.
#[derive(Debug)]
pub enum ConnectError {
    /// The server answered with a SOCKS version other than 5 (the value received).
    UnsupportedVersion(u8),

    /// The server selected an authentication method other than the one offered
    /// (the value received; RFC 1928 uses `0xff` for "no acceptable methods").
    UnsupportedMethod(u8),

    /// The server replied `0x01`: general SOCKS server failure.
    GeneralServerFailure,

    /// The server replied `0x04`: host unreachable.
    HostUnreachable,

    /// The server replied `0x05`: connection refused.
    ConnectionRefused,

    /// The server replied `0x06`: TTL expired.
    TtlExpired,

    /// The server replied `0x07`: command not supported.
    CommandNotSupported,

    /// The server replied `0x08`: address type not supported.
    AddressTypeNotSupported,

    /// The server replied with a failure code that has no dedicated variant
    /// (the value received).
    UnknownFailure(u8),

    /// The server's reply carried an address type other than IPv4, domain name
    /// or IPv6 (the value received).
    UnsupportedAddressType(u8),

    /// An I/O error occurred while communicating with the server.
    IoError(io::Error),

    /// The supplied username is empty or longer than 255 bytes (its length).
    InvalidUsernameLength(usize),

    /// The supplied password is empty or longer than 255 bytes (its length).
    InvalidPasswordLength(usize),

    /// The server answered the username/password sub-negotiation with a version
    /// other than 1 (the value received).
    UnsupportedSubnegotiationVersion(u8),

    /// The server rejected the supplied username and password.
    AuthenticationFailure,
}

impl fmt::Display for ConnectError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // With the `#` flag the width includes the `0x` prefix, so `{:#04x}` is
        // required to always print two hex digits (`0x05`); `{:#02x}` gave `0x5`.
        match self {
            Self::UnsupportedVersion(ver) => {
                write!(f, "Unsupported SOCKS version: {:#04x}", ver)
            }
            Self::UnsupportedMethod(method) => {
                write!(f, "Unsupported SOCKS method: {:#04x}", method)
            }
            Self::GeneralServerFailure => f.write_str("General SOCKS server failure"),
            Self::HostUnreachable => f.write_str("Host unreachable"),
            Self::ConnectionRefused => f.write_str("Connection refused"),
            Self::TtlExpired => f.write_str("TTL expired"),
            Self::CommandNotSupported => f.write_str("Command not supported"),
            Self::AddressTypeNotSupported => f.write_str("Address type not supported"),
            Self::UnknownFailure(rep) => write!(f, "Unknown SOCKS failure: {:#04x}", rep),
            Self::UnsupportedAddressType(atyp) => {
                write!(f, "Unsupported address type: {:#04x}", atyp)
            }

            Self::IoError(err) => fmt::Display::fmt(err, f),

            Self::InvalidUsernameLength(len) => {
                write!(f, "username length must be 1..255, but {}", len)
            }
            Self::InvalidPasswordLength(len) => {
                write!(f, "password length must be 1..255, but {}", len)
            }
            Self::UnsupportedSubnegotiationVersion(ver) => {
                write!(f, "Unsupported SOCKS subnegotiation version: {:#04x}", ver)
            }
            Self::AuthenticationFailure => f.write_str("authentication failure"),
        }
    }
}

impl std::error::Error for ConnectError {
    /// Exposes the wrapped I/O error so error-chain reporters can show it.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::IoError(err) => Some(err),
            _ => None,
        }
    }
}

impl From<io::Error> for ConnectError {
    fn from(err: io::Error) -> Self {
        ConnectError::IoError(err)
    }
}

/// A specialized `Result` type whose error is [`ConnectError`].
pub type Result<T> = std::result::Result<T, ConnectError>;
205
206async fn read_u8<R>(s: &mut R) -> Result<u8>
207where
208 R: AsyncRead + Unpin,
209{
210 let mut bytes = [0; 1];
211 s.read_exact(&mut bytes).await?;
212 Ok(bytes[0])
213}
214
215async fn read_u16<R>(s: &mut R) -> Result<u16>
216where
217 R: AsyncRead + Unpin,
218{
219 let mut bytes = [0; 2];
220 s.read_exact(&mut bytes).await?;
221 Ok(u16::from_be_bytes(bytes))
222}
223
224pub async fn connect<S>(server: &mut S, destination: SocketAddr) -> Result<SocketAddr>
270where
271 S: AsyncRead + AsyncWrite + Unpin,
272{
273 connect_with_auth(server, Auth::None, destination).await
274}
275
276pub async fn connect_with_password<S>(
284 server: &mut S,
285 username: impl AsRef<[u8]>,
286 password: impl AsRef<[u8]>,
287 destination: SocketAddr,
288) -> Result<SocketAddr>
289where
290 S: AsyncRead + AsyncWrite + Unpin,
291{
292 connect_with_auth(
293 server,
294 Auth::Password {
295 username: username.as_ref().to_vec(),
296 password: password.as_ref().to_vec(),
297 },
298 destination,
299 )
300 .await
301}
302
303#[allow(clippy::len_zero)]
310async fn connect_with_auth<S>(
311 server: &mut S,
312 auth: Auth,
313 destination: SocketAddr,
314) -> Result<SocketAddr>
315where
316 S: AsyncRead + AsyncWrite + Unpin,
317{
318 if let Auth::Password { username, password } = &auth {
321 if username.len() < 1 || username.len() > 255 {
322 return Err(ConnectError::InvalidUsernameLength(username.len()));
323 }
324 if password.len() < 1 || password.len() > 255 {
325 return Err(ConnectError::InvalidPasswordLength(password.len()));
326 }
327 }
328
329 let mut packet: Vec<u8> = Vec::with_capacity(3);
332 packet.push(SOCKS_VERSION_5);
334 packet.push(1);
336 let requested_method = match auth {
338 Auth::None => SOCKS_NO_AUTHENTICATION_REQUIRED,
339 Auth::Password { .. } => SOCKS_USERNAME_AND_PASSWORD,
340 };
341 packet.push(requested_method);
342
343 server.write_all(&packet).await?;
344 server.flush().await?;
345
346 let ver = read_u8(server).await?;
349 let method = read_u8(server).await?;
350 if ver != SOCKS_VERSION_5 {
351 return Err(ConnectError::UnsupportedVersion(ver));
352 }
353 if method != requested_method {
354 return Err(ConnectError::UnsupportedMethod(method));
355 }
356
357 if let Auth::Password { username, password } = auth {
360 let mut packet = Vec::new();
361 packet.push(SOCKS_SUBNEGOTIATION_VERSION);
362 packet.push(username.len() as u8);
363 packet.extend_from_slice(&username);
364 packet.push(password.len() as u8);
365 packet.extend_from_slice(&password);
366
367 server.write_all(&packet).await?;
368 server.flush().await?;
369
370 let ver = read_u8(server).await?;
371 let status = read_u8(server).await?;
372 if ver != SOCKS_SUBNEGOTIATION_VERSION {
373 return Err(ConnectError::UnsupportedSubnegotiationVersion(ver));
374 }
375 if status != SOCKS_SUBNEGOTIATION_REPLY_SUCCEEDED {
376 return Err(ConnectError::AuthenticationFailure);
377 }
378 }
379
380 let mut packet: Vec<u8> = Vec::new();
383 packet.push(SOCKS_VERSION_5);
384 packet.push(SOCKS_COMMAND_CONNECT);
385 packet.push(SOCKS_RESERVED);
386 match destination {
387 SocketAddr::Ipv4(addr) => {
388 packet.push(SOCKS_ADDRESS_IPV4);
389 packet.extend_from_slice(&addr.ip().octets());
390 packet.extend_from_slice(&addr.port().to_be_bytes());
391 }
392 SocketAddr::Domain(domain) => {
393 packet.push(SOCKS_ADDRESS_DOMAIN_NAME);
394 packet.push(domain.domain().as_ref().len() as u8);
395 packet.extend_from_slice(domain.domain().as_bytes());
396 packet.extend_from_slice(&domain.port().as_u16().to_be_bytes());
397 }
398 SocketAddr::Ipv6(addr) => {
399 packet.push(SOCKS_ADDRESS_IPV6);
400 packet.extend_from_slice(&addr.ip().octets());
401 packet.extend_from_slice(&addr.port().to_be_bytes());
402 }
403 }
404
405 server.write_all(&packet).await?;
406 server.flush().await?;
407
408 let ver = read_u8(server).await?;
411 let rep = read_u8(server).await?;
412 if ver != SOCKS_VERSION_5 {
413 return Err(ConnectError::UnsupportedVersion(ver));
414 }
415 match rep {
416 SOCKS_REPLY_SUCCEEDED => {}
417 SOCKS_REPLY_GENERAL_SOCKS_SERVER_FAILURE => return Err(ConnectError::GeneralServerFailure),
418 SOCKS_REPLY_HOST_UNREACHABLE => return Err(ConnectError::HostUnreachable),
419 SOCKS_REPLY_CONNECTION_REFUSED => return Err(ConnectError::ConnectionRefused),
420 SOCKS_REPLY_TTL_EXPIRED => return Err(ConnectError::TtlExpired),
421 SOCKS_REPLY_COMMAND_NOT_SUPPORTED => return Err(ConnectError::CommandNotSupported),
422 SOCKS_REPLY_ADDRESS_TYPE_NOT_SUPPORTED => {
423 return Err(ConnectError::AddressTypeNotSupported)
424 }
425 _ => return Err(ConnectError::UnknownFailure(rep)),
426 }
427 let _rsv = read_u8(server).await?;
428 let atyp = read_u8(server).await?;
429 match atyp {
430 SOCKS_ADDRESS_IPV4 => {
431 let mut bytes = [0; 4];
432 server.read_exact(&mut bytes).await?;
433 let port = read_u16(server).await?;
434 Ok(SocketAddr::Ipv4(SocketAddrV4::new(bytes.into(), port)))
435 }
436 SOCKS_ADDRESS_DOMAIN_NAME => {
437 let len = read_u8(server).await?;
438 let mut r = server.take(len as u64);
439 let mut bytes = Vec::with_capacity(len as usize);
440 r.read_to_end(&mut bytes).await?;
441 let domain = Domain::from_bytes(&bytes).unwrap();
442 let port = read_u16(server).await?;
443 Ok(SocketAddr::Domain(SocketDomain::new(domain, port.into())))
444 }
445 SOCKS_ADDRESS_IPV6 => {
446 let mut bytes = [0; 16];
447 server.read_exact(&mut bytes).await?;
448 let port = read_u16(server).await?;
449 Ok(SocketAddr::Ipv6(SocketAddrV6::new(
450 bytes.into(),
451 port,
452 0,
453 0,
454 )))
455 }
456 _ => Err(ConnectError::UnsupportedAddressType(atyp)),
457 }
458}