socks5_async/
lib.rs

1#![forbid(unsafe_code)]
2#[macro_use]
3extern crate log;
4mod socks;
5
6use futures::future::try_join;
7pub use socks::AuthMethod;
8use socks::{AddrType, Command, Response, RESERVED, VERSION5};
9use std::{
10    boxed::Box,
11    error::Error,
12    io,
13    net::{Shutdown, SocketAddr, SocketAddrV4, SocketAddrV6},
14};
15use tokio::{
16    io::{AsyncReadExt, AsyncWriteExt},
17    net::{TcpListener, TcpStream},
18    sync::{mpsc, oneshot},
19};
20
21// Transmited over mpsc channel to check user authentication
22type AuthCheckMsg = (String, String, oneshot::Sender<bool>);
23
24/// A SOCKS5 Server
25pub struct SocksServer {
26    listener: TcpListener,
27    allow_no_auth: bool,
28    auth_tx: mpsc::Sender<AuthCheckMsg>,
29}
30impl SocksServer {
31    /// Creates and returns a new `SocksServer`
32    pub async fn new(
33        socket_addr: SocketAddr,
34        allow_no_auth: bool,
35        auth: Box<dyn Fn(String, String) -> bool + Send>,
36    ) -> SocksServer {
37        let (tx, mut rx) = mpsc::channel::<AuthCheckMsg>(100);
38        tokio::spawn(async move {
39            while let Some((username, password, sender)) = rx.recv().await {
40                if let Err(_) = sender.send(auth(username, password)) {
41                    error!("Failed to send back authentication result.");
42                }
43            }
44        });
45        println!("SOCKS5 server listening on {}", socket_addr);
46        SocksServer {
47            listener: TcpListener::bind(socket_addr).await.unwrap(),
48            allow_no_auth,
49            auth_tx: tx,
50        }
51    }
52
53    /// Starts the server. It **should** be called after initializing server
54    ///
55    /// # Example
56    /// ```
57    /// use socks5_async::SocksServer;
58    /// use std::{
59    ///     boxed::Box,
60    ///     error::Error,
61    ///     net::SocketAddr,
62    /// };
63    ///
64    /// let users = vec![
65    ///     (String::from("user1"), String::from("123456"))
66    /// ];
67    ///
68    /// // Server address
69    /// let address: SocketAddr = "127.0.0.1:1080".parse().unwrap();
70    /// let mut socks5 = SocksServer::new(address, true,
71    ///     Box::new(move |username, password| {
72    ///         // Authenticate user
73    ///         return users.contains(&(username, password));
74    ///     }),
75    /// ).await;
76    /// socks5.serve().await;
77    ///
78    /// ```
79
80    pub async fn serve(&mut self) {
81        loop {
82            let no_auth = self.allow_no_auth.clone();
83            if let Ok((socket, address)) = self.listener.accept().await {
84                let tx2 = self.auth_tx.clone();
85                tokio::spawn(async move {
86                    info!("Client connected: {}", address);
87                    let mut client = SocksServerConnection::new(socket, no_auth, tx2);
88                    match client.serve().await {
89                        Ok(_) => info!("Request was served successfully."),
90                        Err(err) => error!("{}", err.to_string()),
91                    }
92                });
93            }
94        }
95    }
96}
97
98// Represents a SOCKS5 Client (connected to SocksServer)
99struct SocksServerConnection {
100    socket: TcpStream,
101    no_auth: bool,
102    auth_ch: mpsc::Sender<AuthCheckMsg>,
103}
104impl SocksServerConnection {
105    fn new(
106        socket: TcpStream,
107        no_auth: bool,
108        auth_ch: mpsc::Sender<(String, String, oneshot::Sender<bool>)>,
109    ) -> SocksServerConnection {
110        SocksServerConnection {
111            socket,
112            no_auth,
113            auth_ch,
114        }
115    }
116
117    fn shutdown(&mut self, msg: &str) -> Result<(), Box<dyn Error>> {
118        self.socket.shutdown(Shutdown::Both)?;
119        warn!("{}", msg);
120        Ok(())
121    }
122
123    async fn serve(&mut self) -> Result<(), Box<dyn Error>> {
124        let mut header = [0u8; 2];
125        self.socket.read_exact(&mut header).await?;
126
127        // Accept only version 5
128        if header[0] != VERSION5 {
129            self.shutdown("Unsupported version")?;
130            Err(Response::Failure)?;
131        }
132
133        // Get available methods
134        let methods = AuthMethod::get_available_methods(header[1], &mut self.socket).await?;
135
136        // Authenticate the user
137        self.auth(methods).await?;
138
139        // Handle the request
140        self.handle_req().await?;
141
142        Ok(())
143    }
144
145    async fn auth(&mut self, methods: Vec<AuthMethod>) -> Result<(), Box<dyn Error>> {
146        if methods.contains(&AuthMethod::UserPass) {
147            // Authenticate with username/password
148            self.socket
149                .write_all(&[VERSION5, AuthMethod::UserPass as u8])
150                .await?;
151
152            // Read username
153            let mut ulen = [0u8; 2];
154            self.socket.read_exact(&mut ulen).await?;
155            let ulen = ulen[1];
156            let mut username: Vec<u8> = Vec::with_capacity(ulen as usize);
157            for _ in 0..ulen {
158                username.push(0)
159            }
160            self.socket.read_exact(&mut username).await?;
161            let username = String::from_utf8(username).unwrap();
162
163            // Read Password
164            let mut plen = [0u8; 1];
165            self.socket.read_exact(&mut plen).await?;
166            let plen = plen[0];
167            let mut password: Vec<u8> = Vec::with_capacity(plen as usize);
168            for _ in 0..plen {
169                password.push(0)
170            }
171            self.socket.read_exact(&mut password).await?;
172            let password = String::from_utf8(password).unwrap();
173
174            // Authenticate user
175            let (tx, rx) = oneshot::channel::<bool>();
176            self.auth_ch.send((username.clone(), password, tx)).await?;
177            if rx.await? {
178                info!("User authenticated: {}", username);
179                self.socket.write_all(&[1, Response::Success as u8]).await?;
180            } else {
181                self.socket
182                    .write_all(&[VERSION5, Response::Failure as u8])
183                    .await?;
184                self.shutdown("Authentication failed.")?;
185            }
186        } else if self.no_auth && methods.contains(&AuthMethod::NoAuth) {
187            warn!("Client connected with no authentication");
188            self.socket
189                .write_all(&[VERSION5, AuthMethod::NoAuth as u8])
190                .await?
191        } else {
192            self.socket
193                .write_all(&[VERSION5, Response::Failure as u8])
194                .await?;
195            self.shutdown("No acceptable method found.")?;
196        }
197        Ok(())
198    }
199
200    async fn handle_req(&mut self) -> Result<(), Box<dyn Error>> {
201        // Read request header
202        let mut data = [0u8; 3];
203        self.socket.read(&mut data).await?;
204
205        // Read socket address
206        let addresses = AddrType::get_socket_addrs(&mut self.socket).await?;
207
208        // Proccess the command
209        match Command::from(data[1] as usize) {
210            // Note: Currently only connect is accepted
211            Some(Command::Connect) => self.cmd_connect(addresses).await?,
212            _ => {
213                self.shutdown("Command not supported.")?;
214                Err(Response::CommandNotSupported)?;
215            }
216        };
217
218        Ok(())
219    }
220
221    async fn cmd_connect(&mut self, addrs: Vec<SocketAddr>) -> Result<(), Box<dyn Error>> {
222        let mut dest = TcpStream::connect(&addrs[..]).await?;
223
224        self.socket
225            .write_all(&[
226                VERSION5,
227                Response::Success as u8,
228                RESERVED,
229                1,
230                127,
231                0,
232                0,
233                1,
234                0,
235                0,
236            ])
237            .await
238            .unwrap();
239
240        let (mut ro, mut wo) = dest.split();
241        let (mut ri, mut wi) = self.socket.split();
242
243        let client_to_server = async {
244            tokio::io::copy(&mut ri, &mut wo).await?;
245            wo.shutdown().await
246        };
247
248        let server_to_client = async {
249            tokio::io::copy(&mut ro, &mut wi).await?;
250            wi.shutdown().await
251        };
252
253        try_join(client_to_server, server_to_client).await?;
254
255        Ok(())
256    }
257}
258
259/// A SOCKS5 Stream
260pub struct SocksStream {
261    stream: TcpStream,
262}
263impl SocksStream {
264    /// Connects to `proxy_addr` and returns a `TcpStream` which
265    /// is authenticated via provided methods and ready to transfer data.
266    ///
267    /// # Example
268    /// ```
269    /// use socks5_async::SocksStream;
270    ///
271    /// // SOCKS5 proxy server address
272    /// let proxy: SocketAddr = "127.0.0.1:1080".parse().unwrap();
273    ///
274    /// // Target address
275    /// let target: SocketAddrV4 = "127.0.0.1:3033".parse().unwrap();
276    ///
277    /// // Connect to server
278    /// let stream = SocksStream::connect(
279    ///     proxy,
280    ///     target,
281    ///     // Pass None if you want to use NoAuth method
282    ///     Some(("user1".to_string(), "123456".to_string())),
283    /// ).await?;
284    ///
285    /// // Use tcp stream ...
286    /// ```
287    /// # Note
288    /// This methods uses `connect_with_stream()` under the hood.
289    pub async fn connect(
290        proxy_addr: SocketAddr,
291        target_addr: impl ToTargetAddr,
292        user_pass: Option<(String, String)>,
293    ) -> Result<TcpStream, Box<dyn Error>> {
294        let mut socks_stream = SocksStream {
295            stream: TcpStream::connect(proxy_addr).await?,
296        };
297        connect_with_stream(&mut socks_stream.stream, target_addr, user_pass).await?;
298        Ok(socks_stream.stream)
299    }
300}
301
302/// Perform SOCKS5 handshake through a TCP stream
303pub async fn socks_handshake(
304    stream: &mut TcpStream,
305    user_pass: Option<(String, String)>
306) -> Result<(), Box<dyn Error>> {
307    let with_userpass = user_pass.is_some();
308    let methods_len = if with_userpass { 2 } else { 1 };
309    
310    // Start SOCKS5 communication
311    let mut data = vec![0; methods_len + 2];
312    data[0] = VERSION5; // Set SOCKS version
313    data[1] = methods_len as u8; // Set authentiaction methods count
314    if with_userpass {
315        data[2] = AuthMethod::UserPass as u8;
316    }
317    data[1 + methods_len] = AuthMethod::NoAuth as u8;
318    stream.write_all(&mut data).await?;
319
320    // Read method selection response
321    let mut response = [0u8; 2];
322    stream.read_exact(&mut response).await?;
323
324    // Check SOCKS version
325    if response[0] != VERSION5 {
326        Err(io::Error::new(
327            io::ErrorKind::InvalidData,
328            "Invalid SOCKS version",
329        ))?;
330    }
331
332    if response[1] == AuthMethod::UserPass as u8 {
333        if let Some((username, password)) = user_pass {
334            // Send username & password
335            let mut data = vec![0; username.len() + password.len() + 3];
336            data[0] = VERSION5;
337            data[1] = username.len() as u8;
338            data[2..2 + username.len()].copy_from_slice(username.as_bytes());
339            data[2 + username.len()] = password.len() as u8;
340            data[3 + username.len()..].copy_from_slice(password.as_bytes());
341            stream.write_all(&data).await?;
342
343            // Read & check server response
344            let mut response = [0; 2];
345            stream.read_exact(&mut response).await?;
346            if response[1] != Response::Success as u8 {
347                Err(io::Error::new(
348                    io::ErrorKind::Other,
349                    "Wrong username/password",
350                ))?;
351            }
352        } else {
353            Err(io::Error::new(
354                io::ErrorKind::Other,
355                "Username & password requried",
356            ))?;
357        }
358    } else if response[1] != AuthMethod::NoAuth as u8 {
359        Err(io::Error::new(
360            io::ErrorKind::Other,
361            "Invalid authentication method",
362        ))?;
363    }
364
365    Ok(())
366}
367
368/// Send `CONNECT` command to a SOCKS server
369pub async fn cmd_connect(
370    stream: &mut TcpStream,
371    target_addr: impl ToTargetAddr,
372) -> Result<(), Box<dyn Error>> {
373    let target_addr = target_addr.target_addr();
374    
375    // Send connect command
376    let mut data = vec![0; 6 + target_addr.len()];
377    data[0] = VERSION5;
378    data[1] = Command::Connect as u8;
379    data[2] = RESERVED;
380    data[3] = target_addr.addr_type() as u8;
381    target_addr.write_to(&mut data[4..]);
382    stream.write_all(&data).await?;
383
384    // Read server response
385    let mut response = [0u8; 3];
386    stream.read(&mut response).await?;
387
388    // Read socket address
389    AddrType::get_socket_addrs(stream).await?;
390
391    Ok(())
392}
393
394/// Perform SOCKS5 handshake and send `CONNECT` command through a TCP stream 
395pub async fn connect_with_stream(
396    stream: &mut TcpStream,
397    target_addr: impl ToTargetAddr,
398    user_pass: Option<(String, String)>,
399) -> Result<(), Box<dyn Error>> {
400    
401    socks_handshake(stream, user_pass).await?;
402    cmd_connect(stream, target_addr).await?;
403
404    Ok(())
405}
406
407/// Socket Address of the target, required by `SocksStream`
408#[derive(Debug, Clone)]
409pub enum TargetAddr {
410    V4(SocketAddrV4),
411    V6(SocketAddrV6),
412    Domain((String, u16)),
413}
414impl TargetAddr {
415    fn len(&self) -> usize {
416        match self {
417            TargetAddr::V4(_) => 4,
418            TargetAddr::V6(_) => 16,
419            TargetAddr::Domain((domain, _)) => domain.len() + 1,
420        }
421    }
422    fn addr_type(&self) -> AddrType {
423        match self {
424            TargetAddr::V4(_) => AddrType::V4,
425            TargetAddr::V6(_) => AddrType::V4,
426            TargetAddr::Domain(_) => AddrType::Domain,
427        }
428    }
429    fn write_to(&self, buf: &mut [u8]) {
430        match self {
431            TargetAddr::V4(addr) => {
432                let mut ip = addr.ip().octets().to_vec();
433                ip.extend(&addr.port().to_be_bytes());
434                buf[..].copy_from_slice(&ip[..]);
435            }
436            TargetAddr::V6(addr) => {
437                let mut ip = addr.ip().octets().to_vec();
438                ip.extend(&addr.port().to_be_bytes());
439                buf[..].copy_from_slice(&ip[..]);
440            }
441            TargetAddr::Domain((domain, port)) => {
442                let mut ip = domain.as_bytes().to_vec();
443                ip.extend(&port.to_be_bytes());
444                buf[0] = domain.len() as u8;
445                buf[1..].copy_from_slice(&ip[..]);
446            }
447        }
448    }
449}
450
451/// A trait implemented by types that can be converted to `TargetAddr`
452pub trait ToTargetAddr {
453    fn target_addr(self) -> TargetAddr;
454}
455
456impl ToTargetAddr for TargetAddr {
457    fn target_addr(self) -> TargetAddr {
458        self
459    }
460}
461
462impl ToTargetAddr for SocketAddrV4 {
463    fn target_addr(self) -> TargetAddr {
464        TargetAddr::V4(self)
465    }
466}
467
468impl ToTargetAddr for SocketAddrV6 {
469    fn target_addr(self) -> TargetAddr {
470        TargetAddr::V6(self)
471    }
472}
473
474impl ToTargetAddr for SocketAddr {
475    fn target_addr(self) -> TargetAddr {
476        match self {
477            SocketAddr::V4(addr) => TargetAddr::V4(addr),
478            SocketAddr::V6(addr) => TargetAddr::V6(addr),
479        }
480    }
481}