1#![forbid(unsafe_code)]
2#[macro_use]
3extern crate log;
4mod socks;
5
6use futures::future::try_join;
7pub use socks::AuthMethod;
8use socks::{AddrType, Command, Response, RESERVED, VERSION5};
9use std::{
10 boxed::Box,
11 error::Error,
12 io,
13 net::{Shutdown, SocketAddr, SocketAddrV4, SocketAddrV6},
14};
15use tokio::{
16 io::{AsyncReadExt, AsyncWriteExt},
17 net::{TcpListener, TcpStream},
18 sync::{mpsc, oneshot},
19};
20
21type AuthCheckMsg = (String, String, oneshot::Sender<bool>);
23
24pub struct SocksServer {
26 listener: TcpListener,
27 allow_no_auth: bool,
28 auth_tx: mpsc::Sender<AuthCheckMsg>,
29}
30impl SocksServer {
31 pub async fn new(
33 socket_addr: SocketAddr,
34 allow_no_auth: bool,
35 auth: Box<dyn Fn(String, String) -> bool + Send>,
36 ) -> SocksServer {
37 let (tx, mut rx) = mpsc::channel::<AuthCheckMsg>(100);
38 tokio::spawn(async move {
39 while let Some((username, password, sender)) = rx.recv().await {
40 if let Err(_) = sender.send(auth(username, password)) {
41 error!("Failed to send back authentication result.");
42 }
43 }
44 });
45 println!("SOCKS5 server listening on {}", socket_addr);
46 SocksServer {
47 listener: TcpListener::bind(socket_addr).await.unwrap(),
48 allow_no_auth,
49 auth_tx: tx,
50 }
51 }
52
53 pub async fn serve(&mut self) {
81 loop {
82 let no_auth = self.allow_no_auth.clone();
83 if let Ok((socket, address)) = self.listener.accept().await {
84 let tx2 = self.auth_tx.clone();
85 tokio::spawn(async move {
86 info!("Client connected: {}", address);
87 let mut client = SocksServerConnection::new(socket, no_auth, tx2);
88 match client.serve().await {
89 Ok(_) => info!("Request was served successfully."),
90 Err(err) => error!("{}", err.to_string()),
91 }
92 });
93 }
94 }
95 }
96}
97
98struct SocksServerConnection {
100 socket: TcpStream,
101 no_auth: bool,
102 auth_ch: mpsc::Sender<AuthCheckMsg>,
103}
104impl SocksServerConnection {
105 fn new(
106 socket: TcpStream,
107 no_auth: bool,
108 auth_ch: mpsc::Sender<(String, String, oneshot::Sender<bool>)>,
109 ) -> SocksServerConnection {
110 SocksServerConnection {
111 socket,
112 no_auth,
113 auth_ch,
114 }
115 }
116
117 fn shutdown(&mut self, msg: &str) -> Result<(), Box<dyn Error>> {
118 self.socket.shutdown(Shutdown::Both)?;
119 warn!("{}", msg);
120 Ok(())
121 }
122
123 async fn serve(&mut self) -> Result<(), Box<dyn Error>> {
124 let mut header = [0u8; 2];
125 self.socket.read_exact(&mut header).await?;
126
127 if header[0] != VERSION5 {
129 self.shutdown("Unsupported version")?;
130 Err(Response::Failure)?;
131 }
132
133 let methods = AuthMethod::get_available_methods(header[1], &mut self.socket).await?;
135
136 self.auth(methods).await?;
138
139 self.handle_req().await?;
141
142 Ok(())
143 }
144
145 async fn auth(&mut self, methods: Vec<AuthMethod>) -> Result<(), Box<dyn Error>> {
146 if methods.contains(&AuthMethod::UserPass) {
147 self.socket
149 .write_all(&[VERSION5, AuthMethod::UserPass as u8])
150 .await?;
151
152 let mut ulen = [0u8; 2];
154 self.socket.read_exact(&mut ulen).await?;
155 let ulen = ulen[1];
156 let mut username: Vec<u8> = Vec::with_capacity(ulen as usize);
157 for _ in 0..ulen {
158 username.push(0)
159 }
160 self.socket.read_exact(&mut username).await?;
161 let username = String::from_utf8(username).unwrap();
162
163 let mut plen = [0u8; 1];
165 self.socket.read_exact(&mut plen).await?;
166 let plen = plen[0];
167 let mut password: Vec<u8> = Vec::with_capacity(plen as usize);
168 for _ in 0..plen {
169 password.push(0)
170 }
171 self.socket.read_exact(&mut password).await?;
172 let password = String::from_utf8(password).unwrap();
173
174 let (tx, rx) = oneshot::channel::<bool>();
176 self.auth_ch.send((username.clone(), password, tx)).await?;
177 if rx.await? {
178 info!("User authenticated: {}", username);
179 self.socket.write_all(&[1, Response::Success as u8]).await?;
180 } else {
181 self.socket
182 .write_all(&[VERSION5, Response::Failure as u8])
183 .await?;
184 self.shutdown("Authentication failed.")?;
185 }
186 } else if self.no_auth && methods.contains(&AuthMethod::NoAuth) {
187 warn!("Client connected with no authentication");
188 self.socket
189 .write_all(&[VERSION5, AuthMethod::NoAuth as u8])
190 .await?
191 } else {
192 self.socket
193 .write_all(&[VERSION5, Response::Failure as u8])
194 .await?;
195 self.shutdown("No acceptable method found.")?;
196 }
197 Ok(())
198 }
199
200 async fn handle_req(&mut self) -> Result<(), Box<dyn Error>> {
201 let mut data = [0u8; 3];
203 self.socket.read(&mut data).await?;
204
205 let addresses = AddrType::get_socket_addrs(&mut self.socket).await?;
207
208 match Command::from(data[1] as usize) {
210 Some(Command::Connect) => self.cmd_connect(addresses).await?,
212 _ => {
213 self.shutdown("Command not supported.")?;
214 Err(Response::CommandNotSupported)?;
215 }
216 };
217
218 Ok(())
219 }
220
221 async fn cmd_connect(&mut self, addrs: Vec<SocketAddr>) -> Result<(), Box<dyn Error>> {
222 let mut dest = TcpStream::connect(&addrs[..]).await?;
223
224 self.socket
225 .write_all(&[
226 VERSION5,
227 Response::Success as u8,
228 RESERVED,
229 1,
230 127,
231 0,
232 0,
233 1,
234 0,
235 0,
236 ])
237 .await
238 .unwrap();
239
240 let (mut ro, mut wo) = dest.split();
241 let (mut ri, mut wi) = self.socket.split();
242
243 let client_to_server = async {
244 tokio::io::copy(&mut ri, &mut wo).await?;
245 wo.shutdown().await
246 };
247
248 let server_to_client = async {
249 tokio::io::copy(&mut ro, &mut wi).await?;
250 wi.shutdown().await
251 };
252
253 try_join(client_to_server, server_to_client).await?;
254
255 Ok(())
256 }
257}
258
259pub struct SocksStream {
261 stream: TcpStream,
262}
263impl SocksStream {
264 pub async fn connect(
290 proxy_addr: SocketAddr,
291 target_addr: impl ToTargetAddr,
292 user_pass: Option<(String, String)>,
293 ) -> Result<TcpStream, Box<dyn Error>> {
294 let mut socks_stream = SocksStream {
295 stream: TcpStream::connect(proxy_addr).await?,
296 };
297 connect_with_stream(&mut socks_stream.stream, target_addr, user_pass).await?;
298 Ok(socks_stream.stream)
299 }
300}
301
302pub async fn socks_handshake(
304 stream: &mut TcpStream,
305 user_pass: Option<(String, String)>
306) -> Result<(), Box<dyn Error>> {
307 let with_userpass = user_pass.is_some();
308 let methods_len = if with_userpass { 2 } else { 1 };
309
310 let mut data = vec![0; methods_len + 2];
312 data[0] = VERSION5; data[1] = methods_len as u8; if with_userpass {
315 data[2] = AuthMethod::UserPass as u8;
316 }
317 data[1 + methods_len] = AuthMethod::NoAuth as u8;
318 stream.write_all(&mut data).await?;
319
320 let mut response = [0u8; 2];
322 stream.read_exact(&mut response).await?;
323
324 if response[0] != VERSION5 {
326 Err(io::Error::new(
327 io::ErrorKind::InvalidData,
328 "Invalid SOCKS version",
329 ))?;
330 }
331
332 if response[1] == AuthMethod::UserPass as u8 {
333 if let Some((username, password)) = user_pass {
334 let mut data = vec![0; username.len() + password.len() + 3];
336 data[0] = VERSION5;
337 data[1] = username.len() as u8;
338 data[2..2 + username.len()].copy_from_slice(username.as_bytes());
339 data[2 + username.len()] = password.len() as u8;
340 data[3 + username.len()..].copy_from_slice(password.as_bytes());
341 stream.write_all(&data).await?;
342
343 let mut response = [0; 2];
345 stream.read_exact(&mut response).await?;
346 if response[1] != Response::Success as u8 {
347 Err(io::Error::new(
348 io::ErrorKind::Other,
349 "Wrong username/password",
350 ))?;
351 }
352 } else {
353 Err(io::Error::new(
354 io::ErrorKind::Other,
355 "Username & password requried",
356 ))?;
357 }
358 } else if response[1] != AuthMethod::NoAuth as u8 {
359 Err(io::Error::new(
360 io::ErrorKind::Other,
361 "Invalid authentication method",
362 ))?;
363 }
364
365 Ok(())
366}
367
368pub async fn cmd_connect(
370 stream: &mut TcpStream,
371 target_addr: impl ToTargetAddr,
372) -> Result<(), Box<dyn Error>> {
373 let target_addr = target_addr.target_addr();
374
375 let mut data = vec![0; 6 + target_addr.len()];
377 data[0] = VERSION5;
378 data[1] = Command::Connect as u8;
379 data[2] = RESERVED;
380 data[3] = target_addr.addr_type() as u8;
381 target_addr.write_to(&mut data[4..]);
382 stream.write_all(&data).await?;
383
384 let mut response = [0u8; 3];
386 stream.read(&mut response).await?;
387
388 AddrType::get_socket_addrs(stream).await?;
390
391 Ok(())
392}
393
394pub async fn connect_with_stream(
396 stream: &mut TcpStream,
397 target_addr: impl ToTargetAddr,
398 user_pass: Option<(String, String)>,
399) -> Result<(), Box<dyn Error>> {
400
401 socks_handshake(stream, user_pass).await?;
402 cmd_connect(stream, target_addr).await?;
403
404 Ok(())
405}
406
/// Destination that a SOCKS5 client asks the proxy to connect to.
#[derive(Clone, Debug)]
pub enum TargetAddr {
    /// IPv4 address and port.
    V4(SocketAddrV4),
    /// IPv6 address and port.
    V6(SocketAddrV6),
    /// Domain name and port.
    Domain((String, u16)),
}
414impl TargetAddr {
415 fn len(&self) -> usize {
416 match self {
417 TargetAddr::V4(_) => 4,
418 TargetAddr::V6(_) => 16,
419 TargetAddr::Domain((domain, _)) => domain.len() + 1,
420 }
421 }
422 fn addr_type(&self) -> AddrType {
423 match self {
424 TargetAddr::V4(_) => AddrType::V4,
425 TargetAddr::V6(_) => AddrType::V4,
426 TargetAddr::Domain(_) => AddrType::Domain,
427 }
428 }
429 fn write_to(&self, buf: &mut [u8]) {
430 match self {
431 TargetAddr::V4(addr) => {
432 let mut ip = addr.ip().octets().to_vec();
433 ip.extend(&addr.port().to_be_bytes());
434 buf[..].copy_from_slice(&ip[..]);
435 }
436 TargetAddr::V6(addr) => {
437 let mut ip = addr.ip().octets().to_vec();
438 ip.extend(&addr.port().to_be_bytes());
439 buf[..].copy_from_slice(&ip[..]);
440 }
441 TargetAddr::Domain((domain, port)) => {
442 let mut ip = domain.as_bytes().to_vec();
443 ip.extend(&port.to_be_bytes());
444 buf[0] = domain.len() as u8;
445 buf[1..].copy_from_slice(&ip[..]);
446 }
447 }
448 }
449}
450
451pub trait ToTargetAddr {
453 fn target_addr(self) -> TargetAddr;
454}
455
456impl ToTargetAddr for TargetAddr {
457 fn target_addr(self) -> TargetAddr {
458 self
459 }
460}
461
462impl ToTargetAddr for SocketAddrV4 {
463 fn target_addr(self) -> TargetAddr {
464 TargetAddr::V4(self)
465 }
466}
467
468impl ToTargetAddr for SocketAddrV6 {
469 fn target_addr(self) -> TargetAddr {
470 TargetAddr::V6(self)
471 }
472}
473
474impl ToTargetAddr for SocketAddr {
475 fn target_addr(self) -> TargetAddr {
476 match self {
477 SocketAddr::V4(addr) => TargetAddr::V4(addr),
478 SocketAddr::V6(addr) => TargetAddr::V6(addr),
479 }
480 }
481}