1use crate::errors::{BuildSocksKind, ExecuteCmdKind, InvalidPackageKind, SocksError, SocksResult};
46use crate::package::{
47 read_package, write_package, AuthMethodsPackage, AuthSelectPackage, PasswordReqPackage,
48 PasswordResPackage, RepliesPackage, RequestsPackage,
49};
50use crate::{
51 is_invalid_password, is_invalid_username, AuthMethod, AuthMethods, PrivateStruct, RepliesRep,
52 RequestCmd, SocksAddr, DEFAULT_SERVER_ADDR,
53};
54use async_trait::async_trait;
55use bytes::{Bytes, BytesMut};
56use std::collections::HashMap;
57use std::fmt::Debug;
58use std::net::SocketAddr;
59use std::sync::Arc;
60use tokio::io::AsyncWriteExt;
61use tokio::net::{TcpListener, TcpStream};
62use tracing::{debug, warn};
63
/// Builder that collects the listen address and authentication settings for a [`SocksServer`].
///
/// Configuration is validated in [`SocksServerBuilder::build`].
pub struct SocksServerBuilder {
    /// Address the resulting server binds its listener to.
    server_address: SocketAddr,
    /// Whether `AuthMethod::SKIP` may be selected during method negotiation.
    allow_auth_skip: bool,
    /// Whether `AuthMethod::PASS` (username/password) may be selected during method negotiation.
    allow_auth_pass: bool,
    /// In-memory credentials, keyed by username with the password as value.
    memory_auth_pass: HashMap<Bytes, Bytes>,
    /// Optional user-supplied authority, consulted when the in-memory credentials do not match.
    custom_auth_pass: Option<Box<dyn PasswordAuthority>>,
    // NOTE(review): presumably a marker that blocks struct-literal construction outside the
    // crate — confirm against the definition of `PrivateStruct`.
    _private: PrivateStruct,
}
72
73impl SocksServerBuilder {
74 pub fn new() -> SocksServerBuilder {
75 return SocksServerBuilder {
76 server_address: DEFAULT_SERVER_ADDR,
77 allow_auth_skip: false,
78 allow_auth_pass: false,
79 memory_auth_pass: Default::default(),
80 custom_auth_pass: None,
81 _private: PrivateStruct,
82 };
83 }
84
85 pub fn server_address(mut self, address: SocketAddr) -> Self {
86 self.server_address = address;
87 self
88 }
89
90 pub fn allow_auth_skip(mut self, allow: bool) -> Self {
91 self.allow_auth_skip = allow;
92 self
93 }
94
95 pub fn allow_auth_pass(mut self, allow: bool) -> Self {
96 self.allow_auth_pass = allow;
97 self
98 }
99
100 pub fn credential(mut self, username: &[u8], password: &[u8]) -> Self {
101 self.allow_auth_pass = true;
102 self.memory_auth_pass.insert(
103 Bytes::copy_from_slice(username.as_ref()),
104 Bytes::copy_from_slice(password.as_ref()),
105 );
106 self
107 }
108
109 pub fn credentials(mut self, credentials: HashMap<Bytes, Bytes>) -> Self {
110 self.allow_auth_pass = true;
111 self.memory_auth_pass.extend(credentials);
112 self
113 }
114
115 pub fn custom_auth_pass<T: PasswordAuthority>(&mut self, authority: T) -> &mut Self {
116 self.custom_auth_pass = Some(Box::new(authority));
117 self
118 }
119
120 pub fn build(self) -> SocksResult<SocksServer> {
121 let SocksServerBuilder {
122 server_address: address,
123 allow_auth_skip,
124 allow_auth_pass,
125 memory_auth_pass,
126 custom_auth_pass,
127 _private,
128 } = self;
129
130 if !allow_auth_skip && !allow_auth_pass {
131 return Err(SocksError::BuildSocksServerErr(
132 BuildSocksKind::InvalidAuthMethod,
133 ));
134 }
135
136 for (username, password) in memory_auth_pass.iter() {
137 if is_invalid_username(username.as_ref()) {
138 return Err(SocksError::BuildSocksServerErr(
139 BuildSocksKind::InvalidUsername,
140 ));
141 }
142 if is_invalid_password(password.as_ref()) {
143 return Err(SocksError::BuildSocksServerErr(
144 BuildSocksKind::InvalidPassword,
145 ));
146 }
147 }
148
149 let authority = DefaultAuthority::new(memory_auth_pass);
150
151 let server = SocksServer {
152 address,
153 allow_auth_skip,
154 allow_auth_pass,
155 memory_auth_pass: authority,
156 custom_auth_pass,
157 _private: PrivateStruct,
158 };
159 return Ok(server);
160 }
161}
162
/// A configured SOCKS proxy server, produced by [`SocksServerBuilder::build`].
///
/// Call [`SocksServer::start`] to bind the listener and begin serving clients.
pub struct SocksServer {
    /// Address the listener is bound to in `start`.
    address: SocketAddr,
    /// Whether `AuthMethod::SKIP` may be selected; it is preferred over PASS when both apply.
    allow_auth_skip: bool,
    /// Whether `AuthMethod::PASS` may be selected.
    allow_auth_pass: bool,
    /// Authority backed by the in-memory credentials; checked first for every login.
    memory_auth_pass: DefaultAuthority,
    /// Optional custom authority; checked only when the in-memory authority rejects the login.
    custom_auth_pass: Option<Box<dyn PasswordAuthority>>,
    // NOTE(review): presumably a marker that blocks struct-literal construction outside the
    // crate — confirm against the definition of `PrivateStruct`.
    _private: PrivateStruct,
}
171
172impl SocksServer {
173 pub async fn start(self) -> SocksResult<()> {
174 let listener = TcpListener::bind(self.address).await?;
175 let server = Arc::new(self);
176 loop {
177 match listener.accept().await {
178 Err(err) => {
179 warn!("accept error: {}", err);
180 }
181 Ok((stream, addr)) => {
182 debug!("accept success: {}", addr);
183 let server = server.clone();
184 tokio::spawn(async move {
185 match server.handshake(stream, addr).await {
186 Ok(mut connection) => {
187 let _ = connection.transfer().await;
188 }
189 Err(err) => {
190 warn!("socks handshake error: {}", err);
191 }
192 }
193 });
194 }
195 }
196 }
197 }
198
199 async fn handshake(
200 &self,
201 mut stream: TcpStream,
202 peer_addr: SocketAddr,
203 ) -> SocksResult<ServerConnection> {
204 let local_addr = stream.local_addr()?;
205 let mut buffer = BytesMut::with_capacity(512);
206
207 return match self.inner_handshake(&mut buffer, &mut stream).await {
208 Ok((identifier, method, target_stream)) => {
209 let connection = ServerConnection {
210 identifier,
211 local_addr,
212 peer_addr,
213 auth_method: method,
214 proxy_stream: stream,
215 target_stream,
216 };
217 Ok(connection)
218 }
219 Err(err) => {
220 let _ = stream.shutdown().await;
221 Err(err)
222 }
223 };
224 }
225
226 async fn inner_handshake(
227 &self,
228 buffer: &mut BytesMut,
229 stream: &mut TcpStream,
230 ) -> SocksResult<(u64, AuthMethod, TcpStream)> {
231 let auth_methods_pac: AuthMethodsPackage = read_package(buffer, stream).await?;
232
233 let method = self
234 .select_auth_method(auth_methods_pac.methods_ref())
235 .unwrap_or(AuthMethod::FAIL);
236 if method == AuthMethod::FAIL {
237 let auth_select_pac = AuthSelectPackage::new(AuthMethod::FAIL);
238 write_package(&auth_select_pac, buffer, stream).await?;
239 return Err(SocksError::UnsupportedAuthMethod);
240 }
241
242 let auth_select_pac = AuthSelectPackage::new(method);
243 write_package(&auth_select_pac, buffer, stream).await?;
244
245 let mut identifier = 0;
246 if method == AuthMethod::PASS {
247 let password_req_pac: PasswordReqPackage = read_package(buffer, stream).await?;
248 let authed = self
249 .process_pass_auth(
250 password_req_pac.username_ref(),
251 password_req_pac.password_ref(),
252 )
253 .await;
254 if authed.is_none() {
255 let password_res_pac = PasswordResPackage::new(false);
256 write_package(&password_res_pac, buffer, stream).await?;
257 return Err(SocksError::PasswordAuthNotPassed);
258 }
259 identifier = authed.unwrap_or(0);
260 let password_res_pac = PasswordResPackage::new(true);
261 write_package(&password_res_pac, buffer, stream).await?;
262 }
263
264 let requests_pac: RequestsPackage = match read_package(buffer, stream).await {
265 Ok(pac) => pac,
266 Err(err) => {
267 if matches!(
268 err,
269 SocksError::InvalidPackageErr(InvalidPackageKind::InvalidRequestsCmd(_))
270 ) {
271 let replies_pac = RepliesPackage::new(
272 RepliesRep::COMMAND_NOT_SUPPORTED,
273 SocksAddr::UNSPECIFIED_ADDR,
274 );
275 write_package(&replies_pac, buffer, stream).await?;
276 }
277 return Err(err);
278 }
279 };
280 if &RequestCmd::CONNECT != requests_pac.cmd_ref() {
281 let replies_pac = RepliesPackage::new(
282 RepliesRep::COMMAND_NOT_SUPPORTED,
283 SocksAddr::UNSPECIFIED_ADDR,
284 );
285 write_package(&replies_pac, buffer, stream).await?;
286 return Err(SocksError::UnsupportedCommand(
287 requests_pac.cmd_ref().to_byte(),
288 ));
289 }
290 let target_stream = match self.connect_target_peer(requests_pac.addr_ref()).await {
291 Ok(stream) => stream,
292 Err(SocksError::ExecuteCommandErr(ExecuteCmdKind::Server(err))) => {
293 let replies_pac = RepliesPackage::new((&err).into(), SocksAddr::UNSPECIFIED_ADDR);
294 write_package(&replies_pac, buffer, stream).await?;
295 return Err(SocksError::ExecuteCommandErr(ExecuteCmdKind::Server(err)));
296 }
297 Err(err) => {
298 return Err(err);
299 }
300 };
301 let replies_pac = RepliesPackage::new(RepliesRep::SUCCESS, SocksAddr::UNSPECIFIED_ADDR);
302 write_package(&replies_pac, buffer, stream).await?;
303 return Ok((identifier, method, target_stream));
304 }
305
306 fn select_auth_method(&self, methods: &AuthMethods) -> Option<AuthMethod> {
307 if self.allow_auth_skip && methods.contains(&AuthMethod::SKIP) {
308 return Some(AuthMethod::SKIP);
309 }
310 if self.allow_auth_pass && methods.contains(&AuthMethod::PASS) {
311 return Some(AuthMethod::PASS);
312 }
313 return None;
314 }
315
316 async fn process_pass_auth(&self, username: &[u8], password: &[u8]) -> Option<u64> {
317 let res = self.memory_auth_pass.auth(username, password).await;
318 if res.is_some() {
319 return res;
320 }
321 return match &self.custom_auth_pass {
322 None => None,
323 Some(authority) => authority.auth(username, password).await,
324 };
325 }
326
327 async fn connect_target_peer(&self, addr: &SocksAddr) -> SocksResult<TcpStream> {
328 let stream = match addr {
329 SocksAddr::IPV4(ipv4) => TcpStream::connect(ipv4).await,
330 SocksAddr::IPV6(ipv6) => TcpStream::connect(ipv6).await,
331 SocksAddr::Domain(domain, port) => TcpStream::connect((domain.as_str(), *port)).await,
332 };
333 return stream.map_err(|err| SocksError::ExecuteCommandErr(ExecuteCmdKind::Server(err)));
334 }
335}
336
/// State of one proxied client connection after a successful handshake.
#[derive(Debug)]
pub(crate) struct ServerConnection {
    /// Identifier returned by the password authority; `0` when no password auth took place.
    identifier: u64,
    /// Local address of the client-facing socket.
    local_addr: SocketAddr,
    /// Remote address of the client, as reported by `accept`.
    peer_addr: SocketAddr,
    /// Authentication method negotiated during the handshake.
    auth_method: AuthMethod,
    /// Stream to the client.
    proxy_stream: TcpStream,
    /// Stream to the target the client asked to connect to.
    target_stream: TcpStream,
}
346
347impl ServerConnection {
348 #[allow(dead_code)]
349 fn identifier(&self) -> u64 {
350 return self.identifier;
351 }
352
353 #[allow(dead_code)]
354 fn local_addr(&self) -> SocketAddr {
355 return self.local_addr;
356 }
357
358 #[allow(dead_code)]
359 fn peer_addr(&self) -> SocketAddr {
360 return self.peer_addr;
361 }
362
363 #[allow(dead_code)]
364 fn auth_method(&self) -> AuthMethod {
365 return self.auth_method;
366 }
367
368 async fn transfer(&mut self) -> SocksResult<()> {
369 tokio::io::copy_bidirectional(&mut self.proxy_stream, &mut self.target_stream).await?;
370 return Ok(());
371 }
372}
373
/// Pluggable username/password verifier used during the password sub-negotiation.
///
/// Implementations must be `Send + Sync + 'static` because the server shares them across
/// spawned connection tasks.
#[async_trait]
pub trait PasswordAuthority: Send + Sync + 'static {
    /// Verifies a login attempt.
    ///
    /// Returns `Some(identifier)` to accept the login — the identifier is stored on the
    /// resulting connection — or `None` to reject it.
    async fn auth(&self, username: &[u8], password: &[u8]) -> Option<u64>;
}
378
/// [`PasswordAuthority`] backed by a fixed in-memory credential table.
pub(crate) struct DefaultAuthority {
    /// Credential table: username as key, expected password as value.
    passwords: HashMap<Bytes, Bytes>,
}
382
383impl DefaultAuthority {
384 pub fn new(passwords: HashMap<Bytes, Bytes>) -> DefaultAuthority {
385 return DefaultAuthority { passwords };
386 }
387}
388
389#[async_trait]
390impl PasswordAuthority for DefaultAuthority {
391 async fn auth(&self, username: &[u8], password: &[u8]) -> Option<u64> {
392 let result = self
393 .passwords
394 .get(username)
395 .map(|p| p == password)
396 .unwrap_or(false);
397 return if result { Some(1) } else { None };
398 }
399}