1use crate::util::stream::{tcp_connect_with_timeout, ConnectError};
2use crate::util::target_addr::{read_address, AddrError, TargetAddr};
3use crate::{
4 consts, new_udp_header, parse_udp_request, read_exact, ready, AuthenticationMethod, ReplyError,
5 Socks5Command, SocksError, UdpHeaderError,
6};
7use anyhow::Context;
8use socket2::{Domain, Socket, Type};
9use std::future::Future;
10use std::io;
11use std::marker::PhantomData;
12use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs as StdToSocketAddrs};
13use std::ops::Deref;
14use std::pin::Pin;
15use std::string::FromUtf8Error;
16use std::sync::Arc;
17use std::task::{Context as AsyncContext, Poll};
18use std::time::Duration;
19use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
20use tokio::net::{TcpListener, TcpStream, ToSocketAddrs as AsyncToSocketAddrs, UdpSocket};
21use tokio::try_join;
22use tokio_stream::Stream;
23
/// Errors produced while serving a SOCKS5 client.
///
/// [`SocksServerError::to_reply_error`] maps an error to the SOCKS5 reply code that is
/// reported to the client.
#[derive(thiserror::Error, Debug)]
pub enum SocksServerError {
    /// An I/O operation failed; `context` names the operation that was in progress.
    #[error("i/o error when {context}: {source}")]
    Io {
        source: io::Error,
        context: &'static str,
    },
    /// Bytes that had to be text were not valid UTF-8; `context` names what was decoded.
    #[error("string error when {context}: {source}")]
    FromUtf8 {
        source: FromUtf8Error,
        context: &'static str,
    },
    /// Connecting to the requested target failed.
    #[error(transparent)]
    ConnectError(#[from] ConnectError),
    /// Error from the SOCKS5 UDP header helpers.
    #[error(transparent)]
    UdpHeaderError(#[from] UdpHeaderError),
    /// Error from the target-address helpers.
    #[error(transparent)]
    AddrError(#[from] AddrError),
    /// An internal invariant did not hold; the message says which one.
    #[error("BUG: {0}")]
    Bug(&'static str),
    /// None of the authentication methods offered by the client (carried here) is acceptable.
    #[error("Auth method unacceptable `{0:?}`.")]
    AuthMethodUnacceptable(Vec<u8>),
    /// The client sent a version byte other than SOCKS5's.
    #[error("Unsupported SOCKS version `{0}`.")]
    UnsupportedSocksVersion(u8),
    /// The client sent a command byte this server does not know.
    #[error("Unsupported SOCKS command `{0}`.")]
    UnknownCommand(u8),
    /// A byte arrived on the TCP connection that keeps a UDP association alive; that
    /// connection is expected to stay silent until it is closed.
    #[error("Unexpected garbage received on TCP stream used for UDP proxy keep-alive: `{0}`")]
    UnexpectedUdpControlGarbage(u8),
    /// The client sent a zero-length username.
    #[error("Empty username received")]
    EmptyUsername,
    /// The client sent a zero-length password.
    #[error("Empty password received")]
    EmptyPassword,
    /// The authentication callback or credential check rejected the client.
    #[error("Authentication rejected")]
    AuthenticationRejected,
    /// The peer closed the stream (used when the UDP control connection ends).
    #[error("End of stream")]
    EOF,
}
61
62impl SocksServerError {
63 pub fn to_reply_error(&self) -> ReplyError {
64 match self {
65 SocksServerError::UnknownCommand(_) => ReplyError::CommandNotSupported,
66 SocksServerError::AddrError(err) => err.to_reply_error(),
67 _ => ReplyError::GeneralFailure,
68 }
69 }
70}
71
72pub trait ErrorContext<T> {
73 fn err_when(self, context: &'static str) -> Result<T, SocksServerError>;
74}
75
76impl<T> ErrorContext<T> for Result<T, io::Error> {
77 fn err_when(self, context: &'static str) -> Result<T, SocksServerError> {
78 self.map_err(|source| SocksServerError::Io { source, context })
79 }
80}
81
82impl<T> ErrorContext<T> for Result<T, FromUtf8Error> {
83 fn err_when(self, context: &'static str) -> Result<T, SocksServerError> {
84 self.map_err(|source| SocksServerError::FromUtf8 { source, context })
85 }
86}
87
88#[derive(Clone)]
89pub struct Config<A: Authentication = DenyAuthentication> {
90 request_timeout: Duration,
92 skip_auth: bool,
94 dns_resolve: bool,
96 execute_command: bool,
98 allow_udp: bool,
100 allow_no_auth: bool,
103 auth: Option<Arc<A>>,
105 nodelay: bool,
107}
108
109impl<A: Authentication> Default for Config<A> {
110 fn default() -> Self {
111 Config {
112 request_timeout: Duration::from_secs(10),
113 skip_auth: false,
114 dns_resolve: true,
115 execute_command: true,
116 allow_udp: false,
117 allow_no_auth: false,
118 auth: None,
119 nodelay: false,
120 }
121 }
122}
123
/// User-supplied authentication callback used by the deprecated `Socks5Socket` API.
///
/// `credentials` is `None` when the client negotiated "no authentication", and
/// `Some((username, password))` after username/password authentication. Returning `None`
/// rejects the client. Returning `Some(item)` accepts it, and `item` is stored as the
/// socket's credentials.
#[async_trait::async_trait]
pub trait Authentication: Send + Sync {
    /// Value produced for an accepted client.
    type Item;

    async fn authenticate(&self, credentials: Option<(String, String)>) -> Option<Self::Item>;
}
131
132async fn authenticate_callback<T: AsyncRead + AsyncWrite + Unpin, A: Authentication>(
133 auth_callback: &A,
134 auth: StandardAuthenticationStarted<T>,
135) -> Result<(Socks5ServerProtocol<T, states::Authenticated>, A::Item), SocksServerError> {
136 match auth {
137 StandardAuthenticationStarted::NoAuthentication(auth) => {
138 if let Some(credentials) = auth_callback.authenticate(None).await {
139 Ok((auth.finish_auth(), credentials))
140 } else {
141 Err(SocksServerError::AuthenticationRejected)
142 }
143 }
144 StandardAuthenticationStarted::PasswordAuthentication(auth) => {
145 let (username, password, auth) = auth.read_username_password().await?;
146 if let Some(credentials) = auth_callback.authenticate(Some((username, password))).await
147 {
148 Ok((auth.accept().await?.finish_auth(), credentials))
149 } else {
150 auth.reject().await?;
151 Err(SocksServerError::AuthenticationRejected)
152 }
153 }
154 }
155}
156
157pub struct SimpleUserPassword {
159 pub username: String,
160 pub password: String,
161}
162
163pub struct AuthSucceeded {
165 pub username: String,
166}
167
168#[async_trait::async_trait]
171impl Authentication for SimpleUserPassword {
172 type Item = AuthSucceeded;
173
174 async fn authenticate(&self, credentials: Option<(String, String)>) -> Option<Self::Item> {
175 if let Some((username, password)) = credentials {
176 if username == self.username && password == self.password {
178 Some(AuthSucceeded { username })
181 } else {
182 None
184 }
185 } else {
186 None
189 }
190 }
191}
192
193#[derive(Copy, Clone, Default)]
195pub struct DenyAuthentication {}
196
197#[async_trait::async_trait]
198impl Authentication for DenyAuthentication {
199 type Item = ();
200
201 async fn authenticate(&self, _credentials: Option<(String, String)>) -> Option<Self::Item> {
202 None
203 }
204}
205
206#[derive(Copy, Clone, Default)]
208pub struct AcceptAuthentication {}
209
210#[async_trait::async_trait]
211impl Authentication for AcceptAuthentication {
212 type Item = ();
213
214 async fn authenticate(&self, _credentials: Option<(String, String)>) -> Option<Self::Item> {
215 Some(())
216 }
217}
218
219impl<A: Authentication> Config<A> {
220 pub fn set_request_timeout(&mut self, d: Duration) -> &mut Self {
222 self.request_timeout = d;
223 self
224 }
225
226 pub fn set_skip_auth(&mut self, value: bool) -> &mut Self {
229 self.skip_auth = value;
230 self.auth = None;
231 self
232 }
233
234 pub fn with_authentication<T: Authentication + 'static>(self, authentication: T) -> Config<T> {
238 Config {
239 request_timeout: self.request_timeout,
240 skip_auth: self.skip_auth,
241 dns_resolve: self.dns_resolve,
242 execute_command: self.execute_command,
243 allow_udp: self.allow_udp,
244 allow_no_auth: self.allow_no_auth,
245 auth: Some(Arc::new(authentication)),
246 nodelay: self.nodelay,
247 }
248 }
249
250 pub fn set_allow_no_auth(&mut self, value: bool) -> &mut Self {
253 self.allow_no_auth = value;
254 self
255 }
256
257 pub fn set_execute_command(&mut self, value: bool) -> &mut Self {
259 self.execute_command = value;
260 self
261 }
262
263 pub fn set_dns_resolve(&mut self, value: bool) -> &mut Self {
265 self.dns_resolve = value;
266 self
267 }
268
269 pub fn set_udp_support(&mut self, value: bool) -> &mut Self {
271 self.allow_udp = value;
272 self
273 }
274}
275
276#[deprecated(
279 since = "0.11.0",
280 note = "Use the new explicit API instead, see examples/server.rs"
281)]
282pub struct Socks5Server<A: Authentication = DenyAuthentication> {
283 listener: TcpListener,
284 config: Arc<Config<A>>,
285}
286
287#[allow(deprecated)]
288impl<A: Authentication + Default> Socks5Server<A> {
289 pub async fn bind<S: AsyncToSocketAddrs>(addr: S) -> io::Result<Self> {
290 let listener = TcpListener::bind(&addr).await?;
291 let config = Arc::new(Config::default());
292
293 Ok(Socks5Server { listener, config })
294 }
295}
296
297#[allow(deprecated)]
298impl<A: Authentication> Socks5Server<A> {
299 pub fn with_config<T: Authentication>(self, config: Config<T>) -> Socks5Server<T> {
301 Socks5Server {
302 listener: self.listener,
303 config: Arc::new(config),
304 }
305 }
306
307 pub fn incoming(&self) -> Incoming<'_, A> {
309 Incoming(self, None)
310 }
311}
312
313#[allow(deprecated)]
317pub struct Incoming<'a, A: Authentication>(
318 &'a Socks5Server<A>,
319 Option<Pin<Box<dyn Future<Output = io::Result<(TcpStream, SocketAddr)>> + Send + Sync + 'a>>>,
320);
321
322#[allow(deprecated)]
325impl<'a, A: Authentication> Stream for Incoming<'a, A> {
326 type Item = Result<Socks5Socket<TcpStream, A>, SocksError>;
327
328 fn poll_next(mut self: Pin<&mut Self>, cx: &mut AsyncContext<'_>) -> Poll<Option<Self::Item>> {
332 loop {
333 if self.1.is_none() {
334 self.1 = Some(Box::pin(self.0.listener.accept()));
335 }
336
337 if let Some(f) = &mut self.1 {
338 let (socket, peer_addr) = ready!(f.as_mut().poll(cx))?;
340 self.1 = None;
341
342 let local_addr = socket.local_addr()?;
343 debug!(
344 "incoming connection from peer {} @ {}",
345 &peer_addr, &local_addr
346 );
347
348 let socket = Socks5Socket::new(socket, self.0.config.clone());
350
351 return Poll::Ready(Some(Ok(socket)));
352 }
353 }
354 }
355}
356
/// Connection produced by the deprecated `Socks5Server` API.
///
/// Wraps the client stream and, after `upgrade_to_socks5`, the details of the request.
#[deprecated(
    since = "0.11.0",
    note = "Use the new explicit API instead, see examples/server.rs"
)]
pub struct Socks5Socket<T: AsyncRead + AsyncWrite + Unpin, A: Authentication> {
    /// The client stream.
    inner: T,
    /// Shared server configuration.
    config: Arc<Config<A>>,
    auth: AuthenticationMethod,
    /// Target of the client's request, once read.
    target_addr: Option<TargetAddr>,
    /// Command read from the client, when recorded.
    cmd: Option<Socks5Command>,
    /// IP advertised in `UDP ASSOCIATE` replies; set through `set_reply_ip`.
    reply_ip: Option<IpAddr>,
    /// Value returned by the authentication callback, if one ran.
    credentials: Option<A::Item>,
}
374
/// Type-state markers for [`Socks5ServerProtocol`].
pub mod states {
    /// Connection accepted; authentication has not been negotiated yet.
    pub struct Opened;
    /// Authentication finished; the client's request can be read.
    pub struct Authenticated;
    /// The client's request has been read; a reply must be sent.
    pub struct CommandRead;
}

/// Server side of a SOCKS5 connection over stream `T`, in protocol state `S`.
///
/// Each state only exposes the operations valid at that point of the handshake.
pub struct Socks5ServerProtocol<T, S> {
    inner: T,
    _state: PhantomData<S>,
}

impl<T, S> Socks5ServerProtocol<T, S> {
    /// Wraps `inner` in state `S`.
    fn new(inner: T) -> Self {
        Self {
            inner,
            _state: PhantomData,
        }
    }
}

impl<T> Socks5ServerProtocol<T, states::Opened> {
    /// Starts the server-side handshake on a freshly accepted stream.
    pub fn start(inner: T) -> Self {
        Socks5ServerProtocol {
            inner,
            _state: PhantomData,
        }
    }
}
401
/// Outcome of a user-supplied credential check: anything that can say "passed".
pub trait CheckResult {
    /// Returns `true` when the check passed.
    fn is_good(&self) -> bool;
}

/// `true` means the check passed.
impl CheckResult for bool {
    fn is_good(&self) -> bool {
        matches!(self, true)
    }
}

/// `Some(_)` means the check passed.
impl<T> CheckResult for Option<T> {
    fn is_good(&self) -> bool {
        matches!(self, Some(_))
    }
}

/// `Ok(_)` means the check passed.
impl<T, E> CheckResult for Result<T, E> {
    fn is_good(&self) -> bool {
        matches!(self, Ok(_))
    }
}
423
424impl<T> Socks5ServerProtocol<T, states::Authenticated> {
425 pub fn finish_auth<A: AuthMethodSuccessState<T>>(auth: A) -> Self {
428 Self::new(auth.into_inner())
429 }
430
431 pub fn skip_auth_this_is_not_rfc_compliant(inner: T) -> Self {
436 Self::new(inner)
437 }
438
439 pub async fn accept_no_auth(inner: T) -> Result<Self, SocksServerError>
441 where
442 T: AsyncWrite + AsyncRead + Unpin,
443 {
444 Ok(Socks5ServerProtocol::start(inner)
445 .negotiate_auth(&[NoAuthentication])
446 .await?
447 .finish_auth())
448 }
449
450 pub async fn accept_password_auth<F, R>(
455 inner: T,
456 mut check: F,
457 ) -> Result<(Self, R), SocksServerError>
458 where
459 T: AsyncWrite + AsyncRead + Unpin,
460 F: FnMut(String, String) -> R,
461 R: CheckResult,
462 {
463 let (user, pass, auth) = Socks5ServerProtocol::start(inner)
464 .negotiate_auth(&[PasswordAuthentication])
465 .await?
466 .read_username_password()
467 .await?;
468 let check_result = check(user, pass);
469 if check_result.is_good() {
470 Ok((auth.accept().await?.finish_auth(), check_result))
471 } else {
472 auth.reject().await?;
473 Err(SocksServerError::AuthenticationRejected)
474 }
475 }
476}
477
/// A successfully finished authentication state from which the stream can be recovered.
pub trait AuthMethodSuccessState<T> {
    /// Returns the underlying stream.
    fn into_inner(self) -> T;

    /// Moves on to the [`states::Authenticated`] protocol state.
    fn finish_auth(self) -> Socks5ServerProtocol<T, states::Authenticated>
    where
        Self: Sized,
    {
        Socks5ServerProtocol::finish_auth(self)
    }
}

/// An authentication method the server can offer during method negotiation.
pub trait AuthMethod<T>: Copy {
    /// State returned once this method has been selected.
    type StartingState;
    /// The SOCKS5 method identifier exchanged on the wire.
    fn method_id(self) -> u8;
    /// Wraps `inner` in this method's starting state.
    fn new(self, inner: T) -> Self::StartingState;
}
504
505pub struct NoAuthenticationImpl<T>(T);
506
507impl<T> AuthMethodSuccessState<T> for NoAuthenticationImpl<T> {
508 fn into_inner(self) -> T {
509 self.0
510 }
511}
512
513#[derive(Debug, Clone, Copy)]
522pub struct NoAuthentication;
523
524impl<T> AuthMethod<T> for NoAuthentication {
525 type StartingState = NoAuthenticationImpl<T>;
526
527 fn method_id(self) -> u8 {
528 0x00
529 }
530
531 fn new(self, inner: T) -> Self::StartingState {
532 NoAuthenticationImpl(inner)
533 }
534}
535
/// Type-state markers for the username/password sub-negotiation.
mod password_states {
    /// Method selected; the client's credentials have not been read yet.
    pub struct Started;
    /// Credentials read; the server must accept or reject them.
    pub struct Received;
    /// Credentials accepted; the stream can move on to the request phase.
    pub struct Finished;
}

/// Username/password authentication in progress over stream `T`, in state `S`.
pub struct PasswordAuthenticationImpl<T, S> {
    inner: T,
    _state: PhantomData<S>,
}

/// Username/password authentication right after the method was negotiated.
pub type PasswordAuthenticationStarted<T> = PasswordAuthenticationImpl<T, password_states::Started>;

impl<T, S> PasswordAuthenticationImpl<T, S> {
    /// Wraps `inner` in state `S`.
    fn new(inner: T) -> Self {
        Self {
            inner,
            _state: PhantomData,
        }
    }
}
557
558impl<T: AsyncRead + Unpin> PasswordAuthenticationImpl<T, password_states::Started> {
559 pub async fn read_username_password(
561 self,
562 ) -> Result<
563 (
564 String,
565 String,
566 PasswordAuthenticationImpl<T, password_states::Received>,
567 ),
568 SocksServerError,
569 > {
570 let mut socket = self.inner;
571 trace!("PasswordAuthenticationStarted: read_username_password()");
572 let [version, user_len] = read_exact!(socket, [0u8; 2]).err_when("reading user len")?;
573 debug!(
574 "Auth: [version: {version}, user len: {len}]",
575 version = version,
576 len = user_len,
577 );
578
579 if user_len < 1 {
580 return Err(SocksServerError::EmptyUsername);
581 }
582
583 let username =
584 read_exact!(socket, vec![0u8; user_len as usize]).err_when("reading username")?;
585 debug!("username bytes: {:?}", &username);
586
587 let [pass_len] = read_exact!(socket, [0u8; 1]).err_when("reading password len")?;
588 debug!("Auth: [pass len: {len}]", len = pass_len,);
589
590 if pass_len < 1 {
591 return Err(SocksServerError::EmptyPassword);
592 }
593
594 let password =
595 read_exact!(socket, vec![0u8; pass_len as usize]).err_when("reading password")?;
596 debug!("password bytes: {:?}", &password);
597
598 let username = String::from_utf8(username).err_when("converting username")?;
599 let password = String::from_utf8(password).err_when("converting password")?;
600
601 Ok((username, password, PasswordAuthenticationImpl::new(socket)))
602 }
603}
604
605impl<T: AsyncWrite + Unpin> PasswordAuthenticationImpl<T, password_states::Received> {
606 pub async fn accept(
608 mut self,
609 ) -> Result<PasswordAuthenticationImpl<T, password_states::Finished>, SocksServerError> {
610 self.inner
611 .write_all(&[1, consts::SOCKS5_REPLY_SUCCEEDED])
612 .await
613 .err_when("replying auth success")?;
614
615 debug!("Password authentication accepted.");
616 Ok(PasswordAuthenticationImpl::new(self.inner))
617 }
618
619 pub async fn reject(mut self) -> Result<(), SocksServerError> {
621 self.inner
622 .write_all(&[1, consts::SOCKS5_AUTH_METHOD_NOT_ACCEPTABLE])
623 .await
624 .err_when("replying with auth method not acceptable")?;
625
626 debug!("Password authentication rejected.");
627 Ok(())
628 }
629}
630
631impl<T> AuthMethodSuccessState<T> for PasswordAuthenticationImpl<T, password_states::Finished> {
632 fn into_inner(self) -> T {
633 self.inner
634 }
635}
636
637#[derive(Debug, Clone, Copy)]
639pub struct PasswordAuthentication;
640
641impl<T> AuthMethod<T> for PasswordAuthentication {
642 type StartingState = PasswordAuthenticationImpl<T, password_states::Started>;
643
644 fn method_id(self) -> u8 {
645 0x02
646 }
647
648 fn new(self, inner: T) -> Self::StartingState {
649 PasswordAuthenticationImpl::new(inner)
650 }
651}
652
/// Declares a pair of enums combining several [`AuthMethod`]s into one:
///
/// * `$enum`: a `Copy` enum with one variant per method, itself implementing
///   [`AuthMethod`] by delegating to the wrapped method;
/// * `$state_enum<T>`: the matching enum of starting states, returned by
///   [`AuthMethod::new`] on `$enum`.
///
/// Each variant name must also be the name of the method type it wraps (the variant is
/// generated as `$method($method)`), and `$state` must be that method's `StartingState`.
/// `AuthMethod` is referenced unqualified, so it must be in scope where the macro is used.
#[macro_export]
macro_rules! auth_method_enums {
    (
        $(#[$enum_meta:meta])*
        $vis:vis enum $enum:ident / $(#[$state_enum_meta:meta])* $state_enum:ident<$state_enum_par:ident> {
            $($method:ident($state:ty)),+ $(,)?
        }
    ) => {
        // Enum of per-method starting states.
        $(#[$state_enum_meta])*
        $vis enum $state_enum<$state_enum_par> {
            $($method($state)),+
        }

        // Enum of methods; `Copy` is required by the `AuthMethod` supertrait.
        #[derive(Clone, Copy)]
        $(#[$enum_meta])*
        $vis enum $enum {
            $($method($method)),+
        }

        impl<T> AuthMethod<T> for $enum {
            type StartingState = $state_enum<T>;

            // Delegates to the wrapped method's id.
            fn method_id(self) -> u8 {
                match self {
                    $($enum::$method(auth) => AuthMethod::<T>::method_id(auth)),+
                }
            }

            // Starts the wrapped method and tags the state with the same variant.
            fn new(self, inner: T) -> Self::StartingState {
                match self {
                    $($enum::$method(auth) => $state_enum::$method(auth.new(inner))),+
                }
            }
        }
    };
}
689
690auth_method_enums! {
691 pub enum StandardAuthentication / StandardAuthenticationStarted<T> {
696 NoAuthentication(NoAuthenticationImpl<T>),
697 PasswordAuthentication(PasswordAuthenticationImpl<T, password_states::Started>),
698 }
699}
700
701impl StandardAuthentication {
702 pub fn allow_no_auth(allow: bool) -> &'static [StandardAuthentication] {
704 if allow {
705 &[
706 StandardAuthentication::PasswordAuthentication(PasswordAuthentication),
709 StandardAuthentication::NoAuthentication(NoAuthentication),
710 ]
711 } else {
712 &[StandardAuthentication::PasswordAuthentication(
713 PasswordAuthentication,
714 )]
715 }
716 }
717}
718
// Inherent API of the deprecated `Socks5Socket`; the allow silences warnings about using
// the deprecated type in its own impl.
#[allow(deprecated)]
impl<T: AsyncRead + AsyncWrite + Unpin, A: Authentication> Socks5Socket<T, A> {
    /// Wraps an accepted stream together with the shared server configuration.
    ///
    /// Nothing is read from or written to `socket` until [`Self::upgrade_to_socks5`] runs.
    pub fn new(socket: T, config: Arc<Config<A>>) -> Self {
        Socks5Socket {
            inner: socket,
            config,
            auth: AuthenticationMethod::None,
            target_addr: None,
            cmd: None,
            reply_ip: None,
            credentials: None,
        }
    }

    /// Sets the IP address advertised to the client in the `UDP ASSOCIATE` reply.
    ///
    /// If this has not been called, [`Self::upgrade_to_socks5`] fails a permitted UDP
    /// associate request with an "invalid reply ip" error.
    pub fn set_reply_ip(&mut self, addr: IpAddr) {
        self.reply_ip = Some(addr);
    }

    /// Runs the SOCKS5 handshake described by the configuration and, unless command
    /// execution is disabled, serves the client's request.
    ///
    /// Authentication:
    /// * `skip_auth`: no method negotiation at all (not RFC compliant);
    /// * no callback configured: only "no authentication" is offered;
    /// * callback configured: username/password is offered (plus "no authentication" when
    ///   `allow_no_auth` is set), and the callback's result is kept as the credentials.
    ///
    /// The target is resolved when `dns_resolve` is set. Then:
    /// * `execute_command == false`: the command is recorded in `cmd` and no reply is sent;
    /// * `TCPConnect`: handled by `run_tcp_proxy`;
    /// * `UDPAssociate` with `allow_udp`: handled by `run_udp_proxy` using the reply IP;
    /// * anything else: the client gets "command not supported" and an error is returned.
    ///
    /// The target address is stored on success. `cmd` is only recorded when execution is
    /// disabled, and the `auth` field is not updated here.
    pub async fn upgrade_to_socks5(mut self) -> Result<Socks5Socket<T, A>, SocksError> {
        trace!("upgrading to socks5...");

        let proto = match self.config.auth.as_ref() {
            // `skip_auth` takes precedence over any configured callback.
            _ if self.config.skip_auth => {
                debug!("skipping auth");
                Socks5ServerProtocol::skip_auth_this_is_not_rfc_compliant(self.inner)
            }
            // No callback: only "no authentication" is acceptable.
            None => Socks5ServerProtocol::start(self.inner)
                .negotiate_auth(&[NoAuthentication])
                .await?
                .finish_auth(),
            Some(auth_callback) => {
                let methods = StandardAuthentication::allow_no_auth(self.config.allow_no_auth);
                let auth = Socks5ServerProtocol::start(self.inner)
                    .negotiate_auth(methods)
                    .await?;
                let (proto, creds) = authenticate_callback(auth_callback.as_ref(), auth).await?;
                self.credentials = Some(creds);
                proto
            }
        };

        let (proto, cmd, target_addr) = {
            let triple = proto.read_command().await?;

            if self.config.dns_resolve {
                triple.resolve_dns().await?
            } else {
                debug!(
                    "Domain won't be resolved because `dns_resolve`'s config has been turned off."
                );
                triple
            }
        };

        match cmd {
            // Execution disabled: keep the raw stream without sending any reply.
            cmd if !self.config.execute_command => {
                self.cmd = Some(cmd);
                self.inner = proto.inner;
            }
            Socks5Command::TCPConnect => {
                self.inner = run_tcp_proxy(
                    proto,
                    &target_addr,
                    self.config.request_timeout,
                    self.config.nodelay,
                )
                .await?;
            }
            Socks5Command::UDPAssociate if self.config.allow_udp => {
                self.inner = run_udp_proxy(
                    proto,
                    &target_addr,
                    None,
                    // Evaluated as a call argument: a missing reply IP returns early and
                    // drops `proto` without sending the client any reply.
                    self.reply_ip.context("invalid reply ip")?,
                    None,
                )
                .await?;
            }
            _ => {
                proto.reply_error(&ReplyError::CommandNotSupported).await?;
                return Err(ReplyError::CommandNotSupported.into());
            }
        };

        self.target_addr = Some(target_addr);
        Ok(self)
    }

    /// Consumes the socket and returns the underlying stream.
    pub fn into_inner(self) -> T {
        self.inner
    }

    /// Resolves a stored domain-name target to an IP address in place.
    ///
    /// IP targets, and a socket with no stored target, are left unchanged.
    /// NOTE(review): the target is `take()`n before resolving, so when resolution fails
    /// `target_addr` is left as `None`.
    pub async fn resolve_dns(&mut self) -> Result<(), SocksError> {
        trace!("resolving dns");
        if let Some(target_addr) = self.target_addr.take() {
            self.target_addr = match target_addr {
                TargetAddr::Domain(_, _) => Some(target_addr.resolve_dns().await?),
                TargetAddr::Ip(_) => Some(target_addr),
            };
        }

        Ok(())
    }

    /// The request's target address, once `upgrade_to_socks5` has stored it.
    pub fn target_addr(&self) -> Option<&TargetAddr> {
        self.target_addr.as_ref()
    }

    /// The authentication method recorded on this socket.
    pub fn auth(&self) -> &AuthenticationMethod {
        &self.auth
    }

    /// The command read from the client, when it was recorded (see `upgrade_to_socks5`).
    pub fn cmd(&self) -> &Option<Socks5Command> {
        &self.cmd
    }

    /// Borrows the authentication callback's result through its `Deref` target, if any.
    pub fn get_credentials(&self) -> Option<&<<A as Authentication>::Item as Deref>::Target>
    where
        <A as Authentication>::Item: Deref,
    {
        self.credentials.as_deref()
    }

    /// Takes the authentication callback's result out of the socket, leaving `None`.
    pub fn take_credentials(&mut self) -> Option<A::Item> {
        self.credentials.take()
    }
}
868
869impl<T: AsyncRead + AsyncWrite + Unpin> Socks5ServerProtocol<T, states::Opened> {
870 pub async fn negotiate_auth<M: AuthMethod<T>>(
878 mut self,
879 server_methods: &[M],
880 ) -> Result<M::StartingState, SocksServerError> {
881 trace!("Socks5ServerProtocol: negotiate_auth()");
882 let [version, methods_len] =
883 read_exact!(self.inner, [0u8; 2]).err_when("reading methods")?;
884 debug!(
885 "Handshake headers: [version: {version}, methods len: {len}]",
886 version = version,
887 len = methods_len,
888 );
889
890 if version != consts::SOCKS5_VERSION {
891 return Err(SocksServerError::UnsupportedSocksVersion(version));
892 }
893
894 let methods =
898 read_exact!(self.inner, vec![0u8; methods_len as usize]).err_when("reading methods")?;
899 debug!("methods supported sent by the client: {:?}", &methods);
900
901 for server_method in server_methods {
904 for client_method_id in methods.iter() {
905 if server_method.method_id() == *client_method_id {
906 debug!("Reply with method {}", *client_method_id);
907 self.inner
908 .write_all(&[consts::SOCKS5_VERSION, *client_method_id])
909 .await
910 .err_when("replying with auth method")?;
911 return Ok(server_method.new(self.inner));
912 }
913 }
914 }
915
916 debug!("No auth method supported by both client and server, reply with (0xff)");
917 self.inner
918 .write_all(&[
919 consts::SOCKS5_VERSION,
920 consts::SOCKS5_AUTH_METHOD_NOT_ACCEPTABLE,
921 ])
922 .await
923 .err_when("replying with method not acceptable")?;
924 Err(SocksServerError::AuthMethodUnacceptable(methods))
925 }
926}
927
928impl<T: AsyncRead + AsyncWrite + Unpin> Socks5ServerProtocol<T, states::CommandRead> {
929 pub async fn reply_success(mut self, sock_addr: SocketAddr) -> Result<T, SocksServerError> {
932 self.inner
933 .write(&new_reply(&ReplyError::Succeeded, sock_addr))
934 .await
935 .err_when("writing successful reply")?;
936
937 self.inner.flush().await.err_when("flushing auth reply")?;
938
939 debug!("Wrote success");
940 Ok(self.inner)
941 }
942
943 pub async fn reply_error(mut self, error: &ReplyError) -> Result<(), SocksServerError> {
945 let reply = new_reply(error, "0.0.0.0:0".parse().unwrap());
946 debug!("reply error to be written: {:?}", &reply);
947
948 self.inner
949 .write(&reply)
950 .await
951 .err_when("writing unsuccessful reply")?;
952
953 self.inner.flush().await.err_when("flushing auth reply")?;
954
955 Ok(())
956 }
957}
958
// Unwraps `$e` (a `Result`). On `Err(err)` it:
//   1. sends `err.to_reply_error()` to the client through `$proto.reply_error`, which
//      consumes `$proto`; a failure to send is only logged, and
//   2. returns `Err(err.into())` from the enclosing function.
// `$e` is evaluated before `$proto` is used, so `$e` may temporarily borrow from `$proto`.
macro_rules! try_notify {
    ($proto:expr, $e:expr) => {
        match $e {
            Ok(res) => res,
            Err(err) => {
                if let Err(rep_err) = $proto.reply_error(&err.to_reply_error()).await {
                    error!(
                        "extra error while reporting an error to the client: {}",
                        rep_err
                    );
                }
                return Err(err.into());
            }
        }
    };
}
975
976impl<T: AsyncRead + AsyncWrite + Unpin> Socks5ServerProtocol<T, states::Authenticated> {
977 pub async fn read_command(
992 mut self,
993 ) -> Result<
994 (
995 Socks5ServerProtocol<T, states::CommandRead>,
996 Socks5Command,
997 TargetAddr,
998 ),
999 SocksServerError,
1000 > {
1001 let [version, cmd, rsv, address_type] =
1002 read_exact!(self.inner, [0u8; 4]).err_when("reading command")?;
1003 debug!(
1004 "Request: [version: {version}, command: {cmd}, rev: {rsv}, address_type: {address_type}]",
1005 version = version,
1006 cmd = cmd,
1007 rsv = rsv,
1008 address_type = address_type,
1009 );
1010
1011 if version != consts::SOCKS5_VERSION {
1012 return Err(SocksServerError::UnsupportedSocksVersion(version));
1013 }
1014
1015 let mut proto = Socks5ServerProtocol::new(self.inner);
1016
1017 let target_addr = try_notify!(proto, read_address(&mut proto.inner, address_type).await);
1019
1020 debug!("Request target is {}", target_addr);
1021
1022 let cmd = try_notify!(
1023 proto,
1024 Socks5Command::from_u8(cmd).ok_or(SocksServerError::UnknownCommand(cmd))
1025 );
1026
1027 Ok((proto, cmd, target_addr))
1028 }
1029}
1030
1031#[allow(async_fn_in_trait)]
1032pub trait DnsResolveHelper
1033where
1034 Self: Sized,
1035{
1036 async fn resolve_dns(self) -> Result<Self, SocksServerError>;
1037}
1038
1039impl<T> DnsResolveHelper
1040 for (
1041 Socks5ServerProtocol<T, states::CommandRead>,
1042 Socks5Command,
1043 TargetAddr,
1044 )
1045where
1046 T: AsyncRead + AsyncWrite + Unpin,
1047{
1048 async fn resolve_dns(self) -> Result<Self, SocksServerError> {
1049 let (proto, cmd, target_addr) = self;
1050 let resolved_addr = try_notify!(proto, target_addr.resolve_dns().await);
1051 Ok((proto, cmd, resolved_addr))
1052 }
1053}
1054
1055pub async fn run_tcp_proxy<T: AsyncRead + AsyncWrite + Unpin>(
1057 proto: Socks5ServerProtocol<T, states::CommandRead>,
1058 addr: &TargetAddr,
1059 request_timeout: Duration,
1060 nodelay: bool,
1061) -> Result<T, SocksServerError> {
1062 let addr = try_notify!(
1063 proto,
1064 addr.to_socket_addrs()
1065 .err_when("converting to socket addr")
1066 .and_then(|mut addrs| addrs.next().ok_or(SocksServerError::Bug("no socket addrs")))
1067 );
1068
1069 let outbound = match tcp_connect_with_timeout(addr, request_timeout).await {
1071 Ok(stream) => stream,
1072 Err(err) => {
1073 proto.reply_error(&err.to_reply_error()).await?;
1074 return Err(err.into());
1075 }
1076 };
1077
1078 try_notify!(
1080 proto,
1081 outbound.set_nodelay(nodelay).err_when("setting nodelay")
1082 );
1083
1084 debug!("Connected to remote destination");
1085
1086 let mut inner = proto
1087 .reply_success(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0))
1088 .await?;
1089
1090 transfer(&mut inner, outbound).await;
1091 Ok(inner)
1092}
1093
1094fn udp_bind_random_port(addr: Option<IpAddr>) -> io::Result<Socket> {
1095 if let Some(addr) = addr {
1096 let sock_addr = SocketAddr::new(addr, 0);
1097 let socket = Socket::new(Domain::for_address(sock_addr), Type::DGRAM, None)?;
1098 socket.bind(&sock_addr.into())?;
1099 Ok(socket)
1100 } else {
1101 const V4_UNSPEC: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0);
1102 const V6_UNSPEC: SocketAddr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0);
1103 Socket::new(Domain::IPV6, Type::DGRAM, None)
1104 .and_then(|socket| socket.set_only_v6(false).map(|_| socket))
1105 .and_then(|socket| socket.bind(&V6_UNSPEC.into()).map(|_| socket))
1106 .or_else(|_| {
1107 Socket::new(Domain::IPV4, Type::DGRAM, None)
1108 .and_then(|socket| socket.bind(&V4_UNSPEC.into()).map(|_| socket))
1109 })
1110 }
1111 .and_then(|socket| socket.set_nonblocking(true).map(|_| socket))
1112}
1113
1114pub async fn run_udp_proxy<T: AsyncRead + AsyncWrite + Unpin>(
1116 proto: Socks5ServerProtocol<T, states::CommandRead>,
1117 addr: &TargetAddr,
1118 peer_bind_ip: Option<IpAddr>,
1119 reply_ip: IpAddr,
1120 outbound_bind_ip: Option<IpAddr>,
1121) -> Result<T, SocksServerError> {
1122 run_udp_proxy_custom(
1123 proto,
1124 addr,
1125 peer_bind_ip,
1126 reply_ip,
1127 move |inbound| async move {
1128 let outbound =
1129 udp_bind_random_port(outbound_bind_ip).err_when("binding outbound udp socket")?;
1130
1131 transfer_udp(inbound, outbound).await
1132 },
1133 )
1134 .await
1135}
1136
1137pub async fn run_udp_proxy_custom<T, F, R>(
1141 proto: Socks5ServerProtocol<T, states::CommandRead>,
1142 _addr: &TargetAddr,
1143 peer_bind_ip: Option<IpAddr>,
1144 reply_ip: IpAddr,
1145 transfer: F,
1146) -> Result<T, SocksServerError>
1147where
1148 T: AsyncRead + AsyncWrite + Unpin,
1149 F: FnOnce(Socket) -> R,
1150 R: Future<Output = Result<(), SocksServerError>>,
1151{
1152 let peer_sock = try_notify!(
1163 proto,
1164 udp_bind_random_port(peer_bind_ip).err_when("binding client udp socket")
1165 );
1166
1167 let peer_addr = try_notify!(
1168 proto,
1169 peer_sock.local_addr().err_when("getting peer's local addr")
1170 );
1171
1172 let reply_port = peer_addr
1173 .as_socket()
1174 .ok_or(SocksServerError::Bug("addr not IP"))?
1175 .port();
1176
1177 let mut inner = proto
1179 .reply_success(SocketAddr::new(reply_ip, reply_port))
1180 .await?;
1181
1182 let udp_fut = transfer(peer_sock);
1183 let tcp_fut = wait_on_tcp(&mut inner);
1184 match try_join!(udp_fut, tcp_fut) {
1185 Ok(_) => warn!("unreachable"),
1186 Err(SocksServerError::EOF) => debug!("EOF on controlling TCP stream, closed UDP proxy"),
1187 Err(err) => warn!("while UDP proxying: {err}"),
1188 }
1189 Ok(inner)
1190}
1191
1192pub async fn wait_on_tcp<I>(stream: &mut I) -> Result<(), SocksServerError>
1196where
1197 I: AsyncRead + Unpin,
1198{
1199 let mut buf = [0; 1];
1200 match stream.read(&mut buf).await {
1201 Ok(0) => Err(SocksServerError::EOF),
1202 Ok(_) => Err(SocksServerError::UnexpectedUdpControlGarbage(buf[0])),
1203 Err(err) => Err(err).err_when("waiting on UDP control stream"),
1204 }
1205}
1206
1207pub async fn transfer<I, O>(mut inbound: I, mut outbound: O)
1210where
1211 I: AsyncRead + AsyncWrite + Unpin,
1212 O: AsyncRead + AsyncWrite + Unpin,
1213{
1214 match tokio::io::copy_bidirectional(&mut inbound, &mut outbound).await {
1215 Ok(res) => debug!("transfer closed ({}, {})", res.0, res.1),
1216 Err(err) => error!("transfer error: {:?}", err),
1217 };
1218}
1219
1220async fn handle_udp_request(
1221 inbound: &UdpSocket,
1222 outbound: &UdpSocket,
1223 outbound_v6: bool,
1224 buf: &mut [u8],
1225) -> Result<(), SocksServerError> {
1226 let (size, client_addr) = inbound
1227 .recv_from(buf)
1228 .await
1229 .err_when("udp receiving from")?;
1230 debug!("Server recieve udp from {}", client_addr);
1231 inbound
1232 .connect(client_addr)
1233 .await
1234 .err_when("connecting udp inbound")?;
1235
1236 let (frag, target_addr, data) = parse_udp_request(&buf[..size]).await?;
1237
1238 if frag != 0 {
1239 debug!("Discard UDP frag packets sliently.");
1240 return Ok(());
1241 }
1242
1243 debug!("Server forward to packet to {}", target_addr);
1244 let mut target_addr = target_addr
1245 .resolve_dns()
1246 .await?
1247 .to_socket_addrs()
1248 .err_when("udp target to socket addrs")?
1249 .next()
1250 .ok_or(SocksServerError::Bug("no socket addrs"))?;
1251
1252 if outbound_v6 {
1253 target_addr.set_ip(match target_addr.ip() {
1254 std::net::IpAddr::V4(v4) => std::net::IpAddr::V6(v4.to_ipv6_mapped()),
1255 v6 @ std::net::IpAddr::V6(_) => v6,
1256 });
1257 }
1258 outbound
1259 .send_to(data, target_addr)
1260 .await
1261 .err_when("udp sending to")?;
1262 Ok(())
1263}
1264
1265async fn handle_udp_requests(
1266 inbound: &UdpSocket,
1267 outbound: &UdpSocket,
1268) -> Result<(), SocksServerError> {
1269 let mut buf = vec![0u8; 8192];
1270 let outbound_v6 = outbound
1271 .local_addr()
1272 .err_when("udp outbound local addr")?
1273 .is_ipv6();
1274 loop {
1275 match handle_udp_request(inbound, outbound, outbound_v6, &mut buf).await {
1276 Ok(_) => trace!("handled udp response"),
1277 Err(err) => debug!("error in handling udp response: {err}"),
1278 }
1279 }
1280}
1281
1282async fn handle_udp_response(
1283 inbound: &UdpSocket,
1284 outbound: &UdpSocket,
1285 buf: &mut [u8],
1286) -> Result<(), SocksServerError> {
1287 let (size, mut remote_addr) = outbound
1288 .recv_from(buf)
1289 .await
1290 .err_when("udp receiving from")?;
1291 debug!("Recieve packet from {}", remote_addr);
1292
1293 if let std::net::IpAddr::V6(v6) = remote_addr.ip() {
1295 if let Some(v4) = v6.to_ipv4_mapped() {
1296 remote_addr.set_ip(std::net::IpAddr::V4(v4));
1297 }
1298 }
1299
1300 let mut data = new_udp_header(remote_addr)?;
1301 data.extend_from_slice(&buf[..size]);
1302 inbound.send(&data).await.err_when("udp sending")?;
1303
1304 Ok(())
1305}
1306
1307async fn handle_udp_responses(
1308 inbound: &UdpSocket,
1309 outbound: &UdpSocket,
1310) -> Result<(), SocksServerError> {
1311 let mut buf = vec![0u8; 8192];
1312 loop {
1313 match handle_udp_response(inbound, outbound, &mut buf).await {
1314 Ok(_) => trace!("handled udp response"),
1315 Err(err) => debug!("error in handling udp response: {err}"),
1316 }
1317 }
1318}
1319
1320pub async fn transfer_udp(inbound: Socket, outbound: Socket) -> Result<(), SocksServerError> {
1322 let inbound = UdpSocket::from_std(inbound.into()).err_when("wrapping inbound socket")?;
1323 let outbound = UdpSocket::from_std(outbound.into()).err_when("wrapping outbound socket")?;
1324 let req_fut = handle_udp_requests(&inbound, &outbound);
1325 let res_fut = handle_udp_responses(&inbound, &outbound);
1326 try_join!(req_fut, res_fut).map(|_| ())
1327}
1328
1329#[allow(deprecated)]
1333impl<T, A: Authentication> Unpin for Socks5Socket<T, A> where T: AsyncRead + AsyncWrite + Unpin {}
1334
1335#[allow(deprecated)]
1337impl<T, A: Authentication> AsyncRead for Socks5Socket<T, A>
1338where
1339 T: AsyncRead + AsyncWrite + Unpin,
1340{
1341 fn poll_read(
1342 mut self: Pin<&mut Self>,
1343 context: &mut std::task::Context,
1344 buf: &mut tokio::io::ReadBuf<'_>,
1345 ) -> Poll<std::io::Result<()>> {
1346 Pin::new(&mut self.inner).poll_read(context, buf)
1347 }
1348}
1349
1350#[allow(deprecated)]
1352impl<T, A: Authentication> AsyncWrite for Socks5Socket<T, A>
1353where
1354 T: AsyncRead + AsyncWrite + Unpin,
1355{
1356 fn poll_write(
1357 mut self: Pin<&mut Self>,
1358 context: &mut std::task::Context,
1359 buf: &[u8],
1360 ) -> Poll<io::Result<usize>> {
1361 Pin::new(&mut self.inner).poll_write(context, buf)
1362 }
1363
1364 fn poll_flush(
1365 mut self: Pin<&mut Self>,
1366 context: &mut std::task::Context,
1367 ) -> Poll<io::Result<()>> {
1368 Pin::new(&mut self.inner).poll_flush(context)
1369 }
1370
1371 fn poll_shutdown(
1372 mut self: Pin<&mut Self>,
1373 context: &mut std::task::Context,
1374 ) -> Poll<io::Result<()>> {
1375 Pin::new(&mut self.inner).poll_shutdown(context)
1376 }
1377}
1378
1379fn new_reply(error: &ReplyError, sock_addr: SocketAddr) -> Vec<u8> {
1381 let (addr_type, mut ip_oct, mut port) = match sock_addr {
1382 SocketAddr::V4(sock) => (
1383 consts::SOCKS5_ADDR_TYPE_IPV4,
1384 sock.ip().octets().to_vec(),
1385 sock.port().to_be_bytes().to_vec(),
1386 ),
1387 SocketAddr::V6(sock) => (
1388 consts::SOCKS5_ADDR_TYPE_IPV6,
1389 sock.ip().octets().to_vec(),
1390 sock.port().to_be_bytes().to_vec(),
1391 ),
1392 };
1393
1394 let mut reply = vec![
1395 consts::SOCKS5_VERSION,
1396 error.as_u8(), 0x00, addr_type, ];
1400 reply.append(&mut ip_oct);
1401 reply.append(&mut port);
1402
1403 reply
1404}
1405
// These tests exercise the deprecated `Socks5Server` API on purpose.
#[cfg(test)]
#[allow(deprecated)]
mod test {
    use crate::server::Socks5Server;
    use tokio_test::block_on;

    use super::AcceptAuthentication;

    #[test]
    fn test_bind() {
        let f = async {
            // Port 0 asks the OS for any free port. Binding the well-known SOCKS port 1080
            // made the test fail whenever a local proxy was already running.
            let _server = Socks5Server::<AcceptAuthentication>::bind("127.0.0.1:0")
                .await
                .unwrap();
        };

        block_on(f);
    }
}