dlzht_socks5/
server.rs

//! Server module containing [`SocksServerBuilder`] and [`SocksServer`].
2//!
3//! ### Run server without any authorization
4//!
//! ```no_run
6//! use dlzht_socks5::server::SocksServerBuilder;
7//!
8//! #[tokio::main]
9//! async fn main() {
10//!     let server = SocksServerBuilder::new()
11//!         .allow_auth_skip(true)
12//!         .build().unwrap();
13//!     let _ = server.start().await;
14//! }
15//! ```
16//!
//! Invoking `allow_auth_skip(true)` makes the server support the auth method
//! `NO AUTHENTICATION REQUIRED`, which means the auth phase can be skipped.
19//!
20//!
21//! ### Run server with password authorization
22//!
//! ```no_run
24//! use dlzht_socks5::server::SocksServerBuilder;
25//!
26//! #[tokio::main]
27//! async fn main() {
28//!     let server = SocksServerBuilder::new()
29//!         .credential(b"username", b"password")
30//!         .build().unwrap();
31//!     let _ = server.start().await;
32//! }
33//! ```
34//!
//! Invoking `credential(...)` makes the server support the auth method `USERNAME/PASSWORD`:
//! `allow_auth_pass` is automatically set to `true` (call `allow_auth_pass(false)` to disable
//! password auth again).
37//!
//! If there are multiple username/password pairs, invoke `credential(...)` repeatedly,
//! or invoke `credentials(...)` for convenience.
40//!
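//! ### Custom username/password validation
//!
//! Full support is coming soon. For now, [`SocksServerBuilder::custom_auth_pass`] accepts a
//! [`PasswordAuthority`] that is consulted when the in-memory credentials reject a login; it
//! does not enable password authentication by itself.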
44
45use crate::errors::{BuildSocksKind, ExecuteCmdKind, InvalidPackageKind, SocksError, SocksResult};
46use crate::package::{
47    read_package, write_package, AuthMethodsPackage, AuthSelectPackage, PasswordReqPackage,
48    PasswordResPackage, RepliesPackage, RequestsPackage,
49};
50use crate::{
51    is_invalid_password, is_invalid_username, AuthMethod, AuthMethods, PrivateStruct, RepliesRep,
52    RequestCmd, SocksAddr, DEFAULT_SERVER_ADDR,
53};
54use async_trait::async_trait;
55use bytes::{Bytes, BytesMut};
56use std::collections::HashMap;
57use std::fmt::Debug;
58use std::net::SocketAddr;
59use std::sync::Arc;
60use tokio::io::AsyncWriteExt;
61use tokio::net::{TcpListener, TcpStream};
62use tracing::{debug, warn};
63
64pub struct SocksServerBuilder {
65    server_address: SocketAddr,
66    allow_auth_skip: bool,
67    allow_auth_pass: bool,
68    memory_auth_pass: HashMap<Bytes, Bytes>,
69    custom_auth_pass: Option<Box<dyn PasswordAuthority>>,
70    _private: PrivateStruct,
71}
72
73impl SocksServerBuilder {
74    pub fn new() -> SocksServerBuilder {
75        return SocksServerBuilder {
76            server_address: DEFAULT_SERVER_ADDR,
77            allow_auth_skip: false,
78            allow_auth_pass: false,
79            memory_auth_pass: Default::default(),
80            custom_auth_pass: None,
81            _private: PrivateStruct,
82        };
83    }
84
85    pub fn server_address(mut self, address: SocketAddr) -> Self {
86        self.server_address = address;
87        self
88    }
89
90    pub fn allow_auth_skip(mut self, allow: bool) -> Self {
91        self.allow_auth_skip = allow;
92        self
93    }
94
95    pub fn allow_auth_pass(mut self, allow: bool) -> Self {
96        self.allow_auth_pass = allow;
97        self
98    }
99
100    pub fn credential(mut self, username: &[u8], password: &[u8]) -> Self {
101        self.allow_auth_pass = true;
102        self.memory_auth_pass.insert(
103            Bytes::copy_from_slice(username.as_ref()),
104            Bytes::copy_from_slice(password.as_ref()),
105        );
106        self
107    }
108
109    pub fn credentials(mut self, credentials: HashMap<Bytes, Bytes>) -> Self {
110        self.allow_auth_pass = true;
111        self.memory_auth_pass.extend(credentials);
112        self
113    }
114
115    pub fn custom_auth_pass<T: PasswordAuthority>(&mut self, authority: T) -> &mut Self {
116        self.custom_auth_pass = Some(Box::new(authority));
117        self
118    }
119
120    pub fn build(self) -> SocksResult<SocksServer> {
121        let SocksServerBuilder {
122            server_address: address,
123            allow_auth_skip,
124            allow_auth_pass,
125            memory_auth_pass,
126            custom_auth_pass,
127            _private,
128        } = self;
129
130        if !allow_auth_skip && !allow_auth_pass {
131            return Err(SocksError::BuildSocksServerErr(
132                BuildSocksKind::InvalidAuthMethod,
133            ));
134        }
135
136        for (username, password) in memory_auth_pass.iter() {
137            if is_invalid_username(username.as_ref()) {
138                return Err(SocksError::BuildSocksServerErr(
139                    BuildSocksKind::InvalidUsername,
140                ));
141            }
142            if is_invalid_password(password.as_ref()) {
143                return Err(SocksError::BuildSocksServerErr(
144                    BuildSocksKind::InvalidPassword,
145                ));
146            }
147        }
148
149        let authority = DefaultAuthority::new(memory_auth_pass);
150
151        let server = SocksServer {
152            address,
153            allow_auth_skip,
154            allow_auth_pass,
155            memory_auth_pass: authority,
156            custom_auth_pass,
157            _private: PrivateStruct,
158        };
159        return Ok(server);
160    }
161}
162
163pub struct SocksServer {
164    address: SocketAddr,
165    allow_auth_skip: bool,
166    allow_auth_pass: bool,
167    memory_auth_pass: DefaultAuthority,
168    custom_auth_pass: Option<Box<dyn PasswordAuthority>>,
169    _private: PrivateStruct,
170}
171
172impl SocksServer {
173    pub async fn start(self) -> SocksResult<()> {
174        let listener = TcpListener::bind(self.address).await?;
175        let server = Arc::new(self);
176        loop {
177            match listener.accept().await {
178                Err(err) => {
179                    warn!("accept error: {}", err);
180                }
181                Ok((stream, addr)) => {
182                    debug!("accept success: {}", addr);
183                    let server = server.clone();
184                    tokio::spawn(async move {
185                        match server.handshake(stream, addr).await {
186                            Ok(mut connection) => {
187                                let _ = connection.transfer().await;
188                            }
189                            Err(err) => {
190                                warn!("socks handshake error: {}", err);
191                            }
192                        }
193                    });
194                }
195            }
196        }
197    }
198
199    async fn handshake(
200        &self,
201        mut stream: TcpStream,
202        peer_addr: SocketAddr,
203    ) -> SocksResult<ServerConnection> {
204        let local_addr = stream.local_addr()?;
205        let mut buffer = BytesMut::with_capacity(512);
206
207        return match self.inner_handshake(&mut buffer, &mut stream).await {
208            Ok((identifier, method, target_stream)) => {
209                let connection = ServerConnection {
210                    identifier,
211                    local_addr,
212                    peer_addr,
213                    auth_method: method,
214                    proxy_stream: stream,
215                    target_stream,
216                };
217                Ok(connection)
218            }
219            Err(err) => {
220                let _ = stream.shutdown().await;
221                Err(err)
222            }
223        };
224    }
225
226    async fn inner_handshake(
227        &self,
228        buffer: &mut BytesMut,
229        stream: &mut TcpStream,
230    ) -> SocksResult<(u64, AuthMethod, TcpStream)> {
231        let auth_methods_pac: AuthMethodsPackage = read_package(buffer, stream).await?;
232
233        let method = self
234            .select_auth_method(auth_methods_pac.methods_ref())
235            .unwrap_or(AuthMethod::FAIL);
236        if method == AuthMethod::FAIL {
237            let auth_select_pac = AuthSelectPackage::new(AuthMethod::FAIL);
238            write_package(&auth_select_pac, buffer, stream).await?;
239            return Err(SocksError::UnsupportedAuthMethod);
240        }
241
242        let auth_select_pac = AuthSelectPackage::new(method);
243        write_package(&auth_select_pac, buffer, stream).await?;
244
245        let mut identifier = 0;
246        if method == AuthMethod::PASS {
247            let password_req_pac: PasswordReqPackage = read_package(buffer, stream).await?;
248            let authed = self
249                .process_pass_auth(
250                    password_req_pac.username_ref(),
251                    password_req_pac.password_ref(),
252                )
253                .await;
254            if authed.is_none() {
255                let password_res_pac = PasswordResPackage::new(false);
256                write_package(&password_res_pac, buffer, stream).await?;
257                return Err(SocksError::PasswordAuthNotPassed);
258            }
259            identifier = authed.unwrap_or(0);
260            let password_res_pac = PasswordResPackage::new(true);
261            write_package(&password_res_pac, buffer, stream).await?;
262        }
263
264        let requests_pac: RequestsPackage = match read_package(buffer, stream).await {
265            Ok(pac) => pac,
266            Err(err) => {
267                if matches!(
268                    err,
269                    SocksError::InvalidPackageErr(InvalidPackageKind::InvalidRequestsCmd(_))
270                ) {
271                    let replies_pac = RepliesPackage::new(
272                        RepliesRep::COMMAND_NOT_SUPPORTED,
273                        SocksAddr::UNSPECIFIED_ADDR,
274                    );
275                    write_package(&replies_pac, buffer, stream).await?;
276                }
277                return Err(err);
278            }
279        };
280        if &RequestCmd::CONNECT != requests_pac.cmd_ref() {
281            let replies_pac = RepliesPackage::new(
282                RepliesRep::COMMAND_NOT_SUPPORTED,
283                SocksAddr::UNSPECIFIED_ADDR,
284            );
285            write_package(&replies_pac, buffer, stream).await?;
286            return Err(SocksError::UnsupportedCommand(
287                requests_pac.cmd_ref().to_byte(),
288            ));
289        }
290        let target_stream = match self.connect_target_peer(requests_pac.addr_ref()).await {
291            Ok(stream) => stream,
292            Err(SocksError::ExecuteCommandErr(ExecuteCmdKind::Server(err))) => {
293                let replies_pac = RepliesPackage::new((&err).into(), SocksAddr::UNSPECIFIED_ADDR);
294                write_package(&replies_pac, buffer, stream).await?;
295                return Err(SocksError::ExecuteCommandErr(ExecuteCmdKind::Server(err)));
296            }
297            Err(err) => {
298                return Err(err);
299            }
300        };
301        let replies_pac = RepliesPackage::new(RepliesRep::SUCCESS, SocksAddr::UNSPECIFIED_ADDR);
302        write_package(&replies_pac, buffer, stream).await?;
303        return Ok((identifier, method, target_stream));
304    }
305
306    fn select_auth_method(&self, methods: &AuthMethods) -> Option<AuthMethod> {
307        if self.allow_auth_skip && methods.contains(&AuthMethod::SKIP) {
308            return Some(AuthMethod::SKIP);
309        }
310        if self.allow_auth_pass && methods.contains(&AuthMethod::PASS) {
311            return Some(AuthMethod::PASS);
312        }
313        return None;
314    }
315
316    async fn process_pass_auth(&self, username: &[u8], password: &[u8]) -> Option<u64> {
317        let res = self.memory_auth_pass.auth(username, password).await;
318        if res.is_some() {
319            return res;
320        }
321        return match &self.custom_auth_pass {
322            None => None,
323            Some(authority) => authority.auth(username, password).await,
324        };
325    }
326
327    async fn connect_target_peer(&self, addr: &SocksAddr) -> SocksResult<TcpStream> {
328        let stream = match addr {
329            SocksAddr::IPV4(ipv4) => TcpStream::connect(ipv4).await,
330            SocksAddr::IPV6(ipv6) => TcpStream::connect(ipv6).await,
331            SocksAddr::Domain(domain, port) => TcpStream::connect((domain.as_str(), *port)).await,
332        };
333        return stream.map_err(|err| SocksError::ExecuteCommandErr(ExecuteCmdKind::Server(err)));
334    }
335}
336
/// An established proxy session: the client socket, the connected target socket, and the
/// metadata gathered while performing the handshake.
#[derive(Debug)]
pub(crate) struct ServerConnection {
    /// Identifier returned by password authentication, or `0` when no password auth ran.
    identifier: u64,
    /// Local address of the accepted client socket.
    local_addr: SocketAddr,
    /// Address of the connecting client.
    peer_addr: SocketAddr,
    /// Authentication method negotiated with the client.
    auth_method: AuthMethod,
    /// Socket connected to the SOCKS client.
    proxy_stream: TcpStream,
    /// Socket connected to the requested target.
    target_stream: TcpStream,
}
346
347impl ServerConnection {
348    #[allow(dead_code)]
349    fn identifier(&self) -> u64 {
350        return self.identifier;
351    }
352
353    #[allow(dead_code)]
354    fn local_addr(&self) -> SocketAddr {
355        return self.local_addr;
356    }
357
358    #[allow(dead_code)]
359    fn peer_addr(&self) -> SocketAddr {
360        return self.peer_addr;
361    }
362
363    #[allow(dead_code)]
364    fn auth_method(&self) -> AuthMethod {
365        return self.auth_method;
366    }
367
368    async fn transfer(&mut self) -> SocksResult<()> {
369        tokio::io::copy_bidirectional(&mut self.proxy_stream, &mut self.target_stream).await?;
370        return Ok(());
371    }
372}
373
/// Pluggable username/password check used by the server's `USERNAME/PASSWORD` method.
///
/// Implementations must be `Send + Sync + 'static`. The server stores them as
/// `Box<dyn PasswordAuthority>` inside a value shared (via `Arc`) with every connection task.
#[async_trait]
pub trait PasswordAuthority: Send + Sync + 'static {
    /// Checks a client's credentials.
    ///
    /// Return `Some(identifier)` to accept; the identifier is recorded on the resulting
    /// connection. Return `None` to reject. The client then receives a failure status and
    /// the handshake ends with `SocksError::PasswordAuthNotPassed`.
    async fn auth(&self, username: &[u8], password: &[u8]) -> Option<u64>;
}
378
/// Password authority backed by an in-memory username -> password table.
pub(crate) struct DefaultAuthority {
    /// Maps each registered username to its expected password.
    passwords: HashMap<Bytes, Bytes>,
}
382
383impl DefaultAuthority {
384    pub fn new(passwords: HashMap<Bytes, Bytes>) -> DefaultAuthority {
385        return DefaultAuthority { passwords };
386    }
387}
388
389#[async_trait]
390impl PasswordAuthority for DefaultAuthority {
391    async fn auth(&self, username: &[u8], password: &[u8]) -> Option<u64> {
392        let result = self
393            .passwords
394            .get(username)
395            .map(|p| p == password)
396            .unwrap_or(false);
397        return if result { Some(1) } else { None };
398    }
399}