1#![deny(unsafe_code)]
49#![warn(missing_docs)]
50
51use std::{
52 fmt,
53 io::{self, Read, Write},
54 net::{SocketAddrV4, SocketAddrV6},
55};
56
57use koibumi_net::{
58 domain::{Domain, SocketDomain},
59 socks::SocketAddr,
60};
61
// Wire values of the SOCKS protocol version 5 (RFC 1928) and its username/password
// subnegotiation (RFC 1929).

/// Protocol version byte (`VER`) used in every SOCKS5 message.
const SOCKS_VERSION_5: u8 = 0x05;

// -- Authentication methods offered during method selection --------------------------------

/// Method "no authentication required".
const SOCKS_NO_AUTHENTICATION_REQUIRED: u8 = 0x00;
/// Method "username/password".
const SOCKS_USERNAME_AND_PASSWORD: u8 = 0x02;

// -- Request header fields ---------------------------------------------------------------

/// Command byte requesting an outgoing TCP connection.
const SOCKS_COMMAND_CONNECT: u8 = 0x01;
/// Value sent in the reserved header byte.
const SOCKS_RESERVED: u8 = 0x00;

// -- Address types, shared by requests and replies ---------------------------------------

/// Address is 4 bytes of IPv4.
const SOCKS_ADDRESS_IPV4: u8 = 0x01;
/// Address is a length byte followed by a domain name.
const SOCKS_ADDRESS_DOMAIN_NAME: u8 = 0x03;
/// Address is 16 bytes of IPv6.
const SOCKS_ADDRESS_IPV6: u8 = 0x04;

// -- Reply codes ---------------------------------------------------------------------------
// Codes without a constant here are surfaced as `ConnectError::UnknownFailure`.

/// The request was granted.
const SOCKS_REPLY_SUCCEEDED: u8 = 0x00;
/// General SOCKS server failure.
const SOCKS_REPLY_GENERAL_SOCKS_SERVER_FAILURE: u8 = 0x01;
/// The destination host could not be reached.
const SOCKS_REPLY_HOST_UNREACHABLE: u8 = 0x04;
/// The destination refused the connection.
const SOCKS_REPLY_CONNECTION_REFUSED: u8 = 0x05;
/// TTL expired.
const SOCKS_REPLY_TTL_EXPIRED: u8 = 0x06;
/// The command byte was not supported.
const SOCKS_REPLY_COMMAND_NOT_SUPPORTED: u8 = 0x07;
/// The address type was not supported.
const SOCKS_REPLY_ADDRESS_TYPE_NOT_SUPPORTED: u8 = 0x08;

// -- Username/password subnegotiation ------------------------------------------------------

/// Version byte of the username/password subnegotiation.
const SOCKS_SUBNEGOTIATION_VERSION: u8 = 0x01;
/// Subnegotiation status meaning the credentials were accepted.
const SOCKS_SUBNEGOTIATION_REPLY_SUCCEEDED: u8 = 0x00;
80
/// Authentication to use during the SOCKS5 handshake.
///
/// `Debug` is implemented by hand (instead of derived) so the password never ends up in logs
/// or panic messages.
#[derive(Clone, PartialEq, Eq, Hash)]
enum Auth {
    /// Offer only the "no authentication required" method.
    None,
    /// Offer only the "username/password" method and send these credentials.
    Password {
        /// Username bytes; the handshake accepts 1 to 255 bytes.
        username: Vec<u8>,
        /// Password bytes; the handshake accepts 1 to 255 bytes.
        password: Vec<u8>,
    },
}

impl fmt::Debug for Auth {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::None => f.write_str("None"),
            // Same shape as the derived output, with the password replaced by a placeholder.
            Self::Password { username, .. } => f
                .debug_struct("Password")
                .field("username", username)
                .field("password", &format_args!("<redacted>"))
                .finish(),
        }
    }
}
94
/// Errors that can occur while establishing a connection through a SOCKS5 server.
#[derive(Debug)]
pub enum ConnectError {
    /// The server replied with a protocol version other than SOCKS5.
    ///
    /// Carries the version byte that was received.
    UnsupportedVersion(u8),

    /// The server selected an authentication method other than the one that was offered.
    ///
    /// Carries the method byte the server selected.
    UnsupportedMethod(u8),

    /// The server reported a general SOCKS server failure (reply code `0x01`).
    GeneralServerFailure,

    /// The server reported that the destination host is unreachable (reply code `0x04`).
    HostUnreachable,

    /// The server reported that the destination refused the connection (reply code `0x05`).
    ConnectionRefused,

    /// The server reported that the TTL expired (reply code `0x06`).
    TtlExpired,

    /// The server reported that the command is not supported (reply code `0x07`).
    CommandNotSupported,

    /// The server reported that the address type is not supported (reply code `0x08`).
    AddressTypeNotSupported,

    /// The server replied with a failure code that has no dedicated variant.
    ///
    /// Carries the reply code that was received.
    UnknownFailure(u8),

    /// The server's reply carried an address type this client cannot decode.
    ///
    /// Carries the address-type byte that was received.
    UnsupportedAddressType(u8),

    /// An I/O error occurred while communicating with the server.
    IoError(io::Error),

    /// The username is empty or longer than 255 bytes.
    ///
    /// Carries the rejected length.
    InvalidUsernameLength(usize),

    /// The password is empty or longer than 255 bytes.
    ///
    /// Carries the rejected length.
    InvalidPasswordLength(usize),

    /// The server answered the username/password subnegotiation with an unexpected version.
    ///
    /// Carries the version byte that was received.
    UnsupportedSubnegotiationVersion(u8),

    /// The server rejected the username and password.
    AuthenticationFailure,
}

impl fmt::Display for ConnectError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Byte values use `{:#04x}`: the width counts the `0x` prefix, so 4 is required to
        // print e.g. `0x05` (`{:#02x}` would print `0x5`).
        match self {
            Self::UnsupportedVersion(ver) => write!(f, "Unsupported SOCKS version: {:#04x}", ver),
            Self::UnsupportedMethod(method) => {
                write!(f, "Unsupported SOCKS method: {:#04x}", method)
            }
            // `pad` is what `str`'s `Display` uses, so width/alignment flags keep working.
            Self::GeneralServerFailure => f.pad("General SOCKS server failure"),
            Self::HostUnreachable => f.pad("Host unreachable"),
            Self::ConnectionRefused => f.pad("Connection refused"),
            Self::TtlExpired => f.pad("TTL expired"),
            Self::CommandNotSupported => f.pad("Command not supported"),
            Self::AddressTypeNotSupported => f.pad("Address type not supported"),
            Self::UnknownFailure(rep) => write!(f, "Unknown SOCKS failure: {:#04x}", rep),
            Self::UnsupportedAddressType(atyp) => {
                write!(f, "Unsupported address type: {:#04x}", atyp)
            }

            Self::IoError(err) => fmt::Display::fmt(err, f),

            // The accepted range includes 255, hence the inclusive `..=` notation.
            Self::InvalidUsernameLength(len) => {
                write!(f, "username length must be 1..=255, but {}", len)
            }
            Self::InvalidPasswordLength(len) => {
                write!(f, "password length must be 1..=255, but {}", len)
            }
            Self::UnsupportedSubnegotiationVersion(ver) => {
                write!(f, "Unsupported SOCKS subnegotiation version: {:#04x}", ver)
            }
            Self::AuthenticationFailure => f.pad("authentication failure"),
        }
    }
}

impl std::error::Error for ConnectError {}

impl From<io::Error> for ConnectError {
    fn from(err: io::Error) -> Self {
        ConnectError::IoError(err)
    }
}

/// A specialized [`std::result::Result`] whose error type is [`ConnectError`].
pub type Result<T> = std::result::Result<T, ConnectError>;
200
201fn read_u8<R>(s: &mut R) -> Result<u8>
202where
203 R: Read,
204{
205 let mut bytes = [0; 1];
206 s.read_exact(&mut bytes)?;
207 Ok(bytes[0])
208}
209
210fn read_u16<R>(s: &mut R) -> Result<u16>
211where
212 R: Read,
213{
214 let mut bytes = [0; 2];
215 s.read_exact(&mut bytes)?;
216 Ok(u16::from_be_bytes(bytes))
217}
218
219pub fn connect<S>(server: &mut S, destination: SocketAddr) -> Result<SocketAddr>
265where
266 S: Read + Write,
267{
268 connect_with_auth(server, Auth::None, destination)
269}
270
271pub fn connect_with_password<S>(
279 server: &mut S,
280 username: impl AsRef<[u8]>,
281 password: impl AsRef<[u8]>,
282 destination: SocketAddr,
283) -> Result<SocketAddr>
284where
285 S: Read + Write,
286{
287 connect_with_auth(
288 server,
289 Auth::Password {
290 username: username.as_ref().to_vec(),
291 password: password.as_ref().to_vec(),
292 },
293 destination,
294 )
295}
296
297#[allow(clippy::len_zero)]
304fn connect_with_auth<S>(server: &mut S, auth: Auth, destination: SocketAddr) -> Result<SocketAddr>
305where
306 S: Read + Write,
307{
308 if let Auth::Password { username, password } = &auth {
311 if username.len() < 1 || username.len() > 255 {
312 return Err(ConnectError::InvalidUsernameLength(username.len()));
313 }
314 if password.len() < 1 || password.len() > 255 {
315 return Err(ConnectError::InvalidPasswordLength(password.len()));
316 }
317 }
318
319 let mut packet: Vec<u8> = Vec::with_capacity(3);
322 packet.push(SOCKS_VERSION_5);
324 packet.push(1);
326 let requested_method = match auth {
328 Auth::None => SOCKS_NO_AUTHENTICATION_REQUIRED,
329 Auth::Password { .. } => SOCKS_USERNAME_AND_PASSWORD,
330 };
331 packet.push(requested_method);
332
333 server.write_all(&packet)?;
334 server.flush()?;
335
336 let ver = read_u8(server)?;
339 let method = read_u8(server)?;
340 if ver != SOCKS_VERSION_5 {
341 return Err(ConnectError::UnsupportedVersion(ver));
342 }
343 if method != requested_method {
344 return Err(ConnectError::UnsupportedMethod(method));
345 }
346
347 if let Auth::Password { username, password } = auth {
350 let mut packet = Vec::new();
351 packet.push(SOCKS_SUBNEGOTIATION_VERSION);
352 packet.push(username.len() as u8);
353 packet.extend_from_slice(&username);
354 packet.push(password.len() as u8);
355 packet.extend_from_slice(&password);
356
357 server.write_all(&packet)?;
358 server.flush()?;
359
360 let ver = read_u8(server)?;
361 let status = read_u8(server)?;
362 if ver != SOCKS_SUBNEGOTIATION_VERSION {
363 return Err(ConnectError::UnsupportedSubnegotiationVersion(ver));
364 }
365 if status != SOCKS_SUBNEGOTIATION_REPLY_SUCCEEDED {
366 return Err(ConnectError::AuthenticationFailure);
367 }
368 }
369
370 let mut packet: Vec<u8> = Vec::new();
373 packet.push(SOCKS_VERSION_5);
374 packet.push(SOCKS_COMMAND_CONNECT);
375 packet.push(SOCKS_RESERVED);
376 match destination {
377 SocketAddr::Ipv4(addr) => {
378 packet.push(SOCKS_ADDRESS_IPV4);
379 packet.extend_from_slice(&addr.ip().octets());
380 packet.extend_from_slice(&addr.port().to_be_bytes());
381 }
382 SocketAddr::Domain(domain) => {
383 packet.push(SOCKS_ADDRESS_DOMAIN_NAME);
384 packet.push(domain.domain().as_ref().len() as u8);
385 packet.extend_from_slice(domain.domain().as_bytes());
386 packet.extend_from_slice(&domain.port().as_u16().to_be_bytes());
387 }
388 SocketAddr::Ipv6(addr) => {
389 packet.push(SOCKS_ADDRESS_IPV6);
390 packet.extend_from_slice(&addr.ip().octets());
391 packet.extend_from_slice(&addr.port().to_be_bytes());
392 }
393 }
394
395 server.write_all(&packet)?;
396 server.flush()?;
397
398 let ver = read_u8(server)?;
401 let rep = read_u8(server)?;
402 if ver != SOCKS_VERSION_5 {
403 return Err(ConnectError::UnsupportedVersion(ver));
404 }
405 match rep {
406 SOCKS_REPLY_SUCCEEDED => {}
407 SOCKS_REPLY_GENERAL_SOCKS_SERVER_FAILURE => return Err(ConnectError::GeneralServerFailure),
408 SOCKS_REPLY_HOST_UNREACHABLE => return Err(ConnectError::HostUnreachable),
409 SOCKS_REPLY_CONNECTION_REFUSED => return Err(ConnectError::ConnectionRefused),
410 SOCKS_REPLY_TTL_EXPIRED => return Err(ConnectError::TtlExpired),
411 SOCKS_REPLY_COMMAND_NOT_SUPPORTED => return Err(ConnectError::CommandNotSupported),
412 SOCKS_REPLY_ADDRESS_TYPE_NOT_SUPPORTED => {
413 return Err(ConnectError::AddressTypeNotSupported)
414 }
415 _ => return Err(ConnectError::UnknownFailure(rep)),
416 }
417 let _rsv = read_u8(server)?;
418 let atyp = read_u8(server)?;
419 match atyp {
420 SOCKS_ADDRESS_IPV4 => {
421 let mut bytes = [0; 4];
422 server.read_exact(&mut bytes)?;
423 let port = read_u16(server)?;
424 Ok(SocketAddr::Ipv4(SocketAddrV4::new(bytes.into(), port)))
425 }
426 SOCKS_ADDRESS_DOMAIN_NAME => {
427 let len = read_u8(server)?;
428 let mut r = server.take(len as u64);
429 let mut bytes = Vec::with_capacity(len as usize);
430 r.read_to_end(&mut bytes)?;
431 let domain = Domain::from_bytes(&bytes).unwrap();
432 let port = read_u16(server)?;
433 Ok(SocketAddr::Domain(SocketDomain::new(domain, port.into())))
434 }
435 SOCKS_ADDRESS_IPV6 => {
436 let mut bytes = [0; 16];
437 server.read_exact(&mut bytes)?;
438 let port = read_u16(server)?;
439 Ok(SocketAddr::Ipv6(SocketAddrV6::new(
440 bytes.into(),
441 port,
442 0,
443 0,
444 )))
445 }
446 _ => Err(ConnectError::UnsupportedAddressType(atyp)),
447 }
448}