Skip to main content

socks5_impl/client/
mod.rs

1use crate::{
2    error::{Error, Result},
3    protocol::{Address, AddressType, AsyncStreamOperation, AuthMethod, Command, Reply, StreamOperation, UserKey, Version},
4};
5use std::{
6    fmt::Debug,
7    io::Cursor,
8    net::{SocketAddr, ToSocketAddrs},
9    time::Duration,
10};
11use tokio::{
12    io::{AsyncReadExt, AsyncWriteExt, BufStream},
13    net::{TcpStream, UdpSocket},
14};
15
16#[async_trait::async_trait]
17pub trait Socks5Reader: AsyncReadExt + Unpin {
18    async fn read_version(&mut self) -> Result<()> {
19        let value = Version::try_from(self.read_u8().await?)?;
20        match value {
21            Version::V4 => Err(Error::WrongVersion),
22            Version::V5 => Ok(()),
23        }
24    }
25
26    async fn read_method(&mut self) -> Result<AuthMethod> {
27        let value = AuthMethod::from(self.read_u8().await?);
28        match value {
29            AuthMethod::NoAuth | AuthMethod::UserPass => Ok(value),
30            _ => Err(Error::InvalidAuthMethod(value)),
31        }
32    }
33
34    async fn read_command(&mut self) -> Result<Command> {
35        let value = self.read_u8().await?;
36        Ok(Command::try_from(value)?)
37    }
38
39    async fn read_atyp(&mut self) -> Result<AddressType> {
40        let value = self.read_u8().await?;
41        Ok(AddressType::try_from(value)?)
42    }
43
44    async fn read_reserved(&mut self) -> Result<()> {
45        let value = self.read_u8().await?;
46        match value {
47            0x00 => Ok(()),
48            _ => Err(Error::InvalidReserved(value)),
49        }
50    }
51
52    async fn read_fragment_id(&mut self) -> Result<()> {
53        let value = self.read_u8().await?;
54        if value == 0x00 {
55            Ok(())
56        } else {
57            Err(Error::InvalidFragmentId(value))
58        }
59    }
60
61    async fn read_reply(&mut self) -> Result<()> {
62        let value = self.read_u8().await?;
63        match Reply::try_from(value)? {
64            Reply::Succeeded => Ok(()),
65            reply => Err(format!("{reply}").into()),
66        }
67    }
68
69    async fn read_address(&mut self) -> Result<Address> {
70        Ok(Address::retrieve_from_async_stream(self).await?)
71    }
72
73    async fn read_string(&mut self) -> Result<String> {
74        let len = self.read_u8().await? as usize;
75        let mut str = vec![0; len];
76        self.read_exact(&mut str).await?;
77        let str = String::from_utf8(str)?;
78        Ok(str)
79    }
80
81    async fn read_auth_version(&mut self) -> Result<()> {
82        let value = self.read_u8().await?;
83        if value != 0x01 {
84            return Err(Error::InvalidAuthSubnegotiation(value));
85        }
86        Ok(())
87    }
88
89    async fn read_auth_status(&mut self) -> Result<()> {
90        let value = self.read_u8().await?;
91        if value != 0x00 {
92            return Err(Error::InvalidAuthStatus(value));
93        }
94        Ok(())
95    }
96
97    async fn read_selection_msg(&mut self) -> Result<AuthMethod> {
98        self.read_version().await?;
99        self.read_method().await
100    }
101
102    async fn read_final(&mut self) -> Result<Address> {
103        self.read_version().await?;
104        self.read_reply().await?;
105        self.read_reserved().await?;
106        let addr = self.read_address().await?;
107        Ok(addr)
108    }
109}
110
111#[async_trait::async_trait]
112impl<T: AsyncReadExt + Unpin> Socks5Reader for T {}
113
114#[async_trait::async_trait]
115pub trait Socks5Writer: AsyncWriteExt + Unpin {
116    async fn write_version(&mut self) -> Result<()> {
117        self.write_u8(0x05).await?;
118        Ok(())
119    }
120
121    async fn write_method(&mut self, method: AuthMethod) -> Result<()> {
122        self.write_u8(u8::from(method)).await?;
123        Ok(())
124    }
125
126    async fn write_command(&mut self, command: Command) -> Result<()> {
127        self.write_u8(u8::from(command)).await?;
128        Ok(())
129    }
130
131    async fn write_atyp(&mut self, atyp: AddressType) -> Result<()> {
132        self.write_u8(u8::from(atyp)).await?;
133        Ok(())
134    }
135
136    async fn write_reserved(&mut self) -> Result<()> {
137        self.write_u8(0x00).await?;
138        Ok(())
139    }
140
141    async fn write_fragment_id(&mut self, id: u8) -> Result<()> {
142        self.write_u8(id).await?;
143        Ok(())
144    }
145
146    async fn write_address(&mut self, address: &Address) -> Result<()> {
147        address.write_to_async_stream(self).await?;
148        Ok(())
149    }
150
151    async fn write_string(&mut self, string: &str) -> Result<()> {
152        let bytes = string.as_bytes();
153        if bytes.len() > 255 {
154            return Err("Too long string".into());
155        }
156        self.write_u8(bytes.len() as u8).await?;
157        self.write_all(bytes).await?;
158        Ok(())
159    }
160
161    async fn write_auth_version(&mut self) -> Result<()> {
162        self.write_u8(0x01).await?;
163        Ok(())
164    }
165
166    async fn write_methods(&mut self, methods: &[AuthMethod]) -> Result<()> {
167        let method_count = u8::try_from(methods.len()).map_err(|_| "Too many authentication methods")?;
168        self.write_u8(method_count).await?;
169        for method in methods {
170            self.write_method(*method).await?;
171        }
172        Ok(())
173    }
174
175    async fn write_selection_msg(&mut self, methods: &[AuthMethod]) -> Result<()> {
176        self.write_version().await?;
177        self.write_methods(methods).await?;
178        self.flush().await?;
179        Ok(())
180    }
181
182    async fn write_final(&mut self, command: Command, addr: &Address) -> Result<()> {
183        self.write_version().await?;
184        self.write_command(command).await?;
185        self.write_reserved().await?;
186        self.write_address(addr).await?;
187        self.flush().await?;
188        Ok(())
189    }
190}
191
192#[async_trait::async_trait]
193impl<T: AsyncWriteExt + Unpin> Socks5Writer for T {}
194
195async fn username_password_auth<S>(stream: &mut S, auth: &UserKey) -> Result<()>
196where
197    S: Socks5Writer + Socks5Reader + Send,
198{
199    stream.write_auth_version().await?;
200    stream.write_string(&auth.username).await?;
201    stream.write_string(&auth.password).await?;
202    stream.flush().await?;
203
204    stream.read_auth_version().await?;
205    stream.read_auth_status().await
206}
207
208async fn init<S, A>(stream: &mut S, command: Command, addr: A, auth: Option<UserKey>) -> Result<Address>
209where
210    S: Socks5Writer + Socks5Reader + Send,
211    A: Into<Address>,
212{
213    let addr: Address = addr.into();
214
215    let mut methods = Vec::with_capacity(2);
216    methods.push(AuthMethod::NoAuth);
217    if auth.is_some() {
218        methods.push(AuthMethod::UserPass);
219    }
220    stream.write_selection_msg(&methods).await?;
221    stream.flush().await?;
222
223    let method: AuthMethod = stream.read_selection_msg().await?;
224    match method {
225        AuthMethod::NoAuth => {}
226        AuthMethod::UserPass if auth.is_some() => {
227            username_password_auth(stream, auth.as_ref().unwrap()).await?;
228        }
229        _ => return Err(Error::InvalidAuthMethod(method)),
230    }
231
232    stream.write_final(command, &addr).await?;
233    stream.read_final().await
234}
235
236/// Proxifies a TCP connection. Performs the [`CONNECT`] command under the hood.
237///
238/// [`CONNECT`]: https://tools.ietf.org/html/rfc1928#page-6
239///
240/// ```no_run
241/// # use socks5_impl::Result;
242/// # #[tokio::main(flavor = "current_thread")]
243/// # async fn main() -> Result<()> {
244/// use socks5_impl::client;
245/// use tokio::{io::BufStream, net::TcpStream};
246///
247/// let stream = TcpStream::connect("my-proxy-server.com:54321").await?;
248/// let mut stream = BufStream::new(stream);
249/// client::connect(&mut stream, ("google.com", 80), None).await?;
250///
251/// # Ok(())
252/// # }
253/// ```
254pub async fn connect<S, A>(socket: &mut S, addr: A, auth: Option<UserKey>) -> Result<Address>
255where
256    S: AsyncWriteExt + AsyncReadExt + Send + Unpin,
257    A: Into<Address>,
258{
259    init(socket, Command::Connect, addr, auth).await
260}
261
262/// A listener that accepts TCP connections through a proxy.
263///
264/// ```no_run
265/// # use socks5_impl::Result;
266/// # #[tokio::main(flavor = "current_thread")]
267/// # async fn main() -> Result<()> {
268/// use socks5_impl::client::SocksListener;
269/// use tokio::{io::BufStream, net::TcpStream};
270///
271/// let stream = TcpStream::connect("my-proxy-server.com:54321").await?;
272/// let mut stream = BufStream::new(stream);
273/// let (stream, addr) = SocksListener::bind(stream, ("ftp-server.org", 21), None)
274///     .await?
275///     .accept()
276///     .await?;
277///
278/// # Ok(())
279/// # }
280/// ```
281#[derive(Debug)]
282pub struct SocksListener<S> {
283    stream: S,
284    proxy_addr: Address,
285}
286
287impl<S> SocksListener<S>
288where
289    S: AsyncWriteExt + AsyncReadExt + Send + Unpin,
290{
291    /// Creates `SocksListener`. Performs the [`BIND`] command under the hood.
292    ///
293    /// [`BIND`]: https://tools.ietf.org/html/rfc1928#page-6
294    pub async fn bind<A>(mut stream: S, addr: A, auth: Option<UserKey>) -> Result<Self>
295    where
296        A: Into<Address>,
297    {
298        let addr = init(&mut stream, Command::Bind, addr, auth).await?;
299        Ok(Self { stream, proxy_addr: addr })
300    }
301
302    pub fn proxy_addr(&self) -> &Address {
303        &self.proxy_addr
304    }
305
306    pub async fn accept(mut self) -> Result<(S, Address)> {
307        let addr = self.stream.read_final().await?;
308        Ok((self.stream, addr))
309    }
310}
311
312/// A UDP socket that sends packets through a proxy.
313#[derive(Debug)]
314pub struct SocksDatagram<S> {
315    socket: UdpSocket,
316    proxy_addr: Address,
317    stream: S,
318}
319
320impl<S> SocksDatagram<S>
321where
322    S: AsyncWriteExt + AsyncReadExt + Send + Unpin,
323{
324    /// Creates `SocksDatagram`. Performs [`UDP ASSOCIATE`] under the hood.
325    ///
326    /// [`UDP ASSOCIATE`]: https://tools.ietf.org/html/rfc1928#page-7
327    pub async fn udp_associate(mut stream: S, socket: UdpSocket, auth: Option<UserKey>) -> Result<Self> {
328        let addr = if socket.local_addr()?.is_ipv4() { "0.0.0.0:0" } else { "[::]:0" };
329        let addr = addr.parse::<SocketAddr>()?;
330        let proxy_addr = init(&mut stream, Command::UdpAssociate, addr, auth).await?;
331        let addr = proxy_addr.to_socket_addrs()?.next().ok_or("InvalidAddress")?;
332        socket.connect(addr).await?;
333        Ok(Self {
334            socket,
335            proxy_addr,
336            stream,
337        })
338    }
339
340    /// Returns the address of the associated udp address.
341    pub fn proxy_addr(&self) -> &Address {
342        &self.proxy_addr
343    }
344
345    /// Returns a reference to the underlying udp socket.
346    pub fn get_ref(&self) -> &UdpSocket {
347        &self.socket
348    }
349
350    /// Returns a mutable reference to the underlying udp socket.
351    pub fn get_mut(&mut self) -> &mut UdpSocket {
352        &mut self.socket
353    }
354
355    /// Returns the associated stream and udp socket.
356    pub fn into_inner(self) -> (S, UdpSocket) {
357        (self.stream, self.socket)
358    }
359
360    //  Builds a udp-based client request packet, the format is as follows:
361    //  +----+------+------+----------+----------+----------+
362    //  |RSV | FRAG | ATYP | DST.ADDR | DST.PORT |   DATA   |
363    //  +----+------+------+----------+----------+----------+
364    //  | 2  |  1   |  1   | Variable |    2     | Variable |
365    //  +----+------+------+----------+----------+----------+
366    //  The reference link is as follows:
367    //  https://tools.ietf.org/html/rfc1928#page-8
368    //
369    pub async fn build_socks5_udp_datagram(buf: &[u8], addr: &Address) -> Result<Vec<u8>> {
370        let bytes_size = Self::get_buf_size(addr.len(), buf.len());
371        let bytes = Vec::with_capacity(bytes_size);
372
373        let mut cursor = Cursor::new(bytes);
374        cursor.write_reserved().await?;
375        cursor.write_reserved().await?;
376        cursor.write_fragment_id(0x00).await?;
377        cursor.write_address(addr).await?;
378        cursor.write_all(buf).await?;
379
380        let bytes = cursor.into_inner();
381        Ok(bytes)
382    }
383
384    /// Sends data via the udp socket to the given address.
385    pub async fn send_to<A>(&self, buf: &[u8], addr: A) -> Result<usize>
386    where
387        A: Into<Address>,
388    {
389        let addr: Address = addr.into();
390        let bytes = Self::build_socks5_udp_datagram(buf, &addr).await?;
391        Ok(self.socket.send(&bytes).await?)
392    }
393
394    /// Parses the udp-based server response packet, the format is same as the client request packet.
395    async fn parse_socks5_udp_response(bytes: &mut [u8], buf: &mut Vec<u8>) -> Result<(usize, Address)> {
396        let len = bytes.len();
397        let mut cursor = Cursor::new(bytes);
398        cursor.read_reserved().await?;
399        cursor.read_reserved().await?;
400        cursor.read_fragment_id().await?;
401        let addr = cursor.read_address().await?;
402        let header_len = cursor.position() as usize;
403        buf.resize(len - header_len, 0);
404        _ = cursor.read_exact(buf).await?;
405        Ok((len - header_len, addr))
406    }
407
408    /// Receives data from the udp socket and returns the number of bytes read and the origin of the data.
409    pub async fn recv_from(&self, timeout: Duration, buf: &mut Vec<u8>) -> Result<(usize, Address)> {
410        const UDP_MTU: usize = 1500;
411        // let bytes_size = Self::get_buf_size(Address::max_serialized_len(), buf.len());
412        let bytes_size = UDP_MTU;
413        let mut bytes = vec![0; bytes_size];
414        let len = tokio::time::timeout(timeout, self.socket.recv(&mut bytes)).await??;
415        bytes.truncate(len);
416        let (read, addr) = Self::parse_socks5_udp_response(&mut bytes, buf).await?;
417        Ok((read, addr))
418    }
419
420    fn get_buf_size(addr_size: usize, buf_len: usize) -> usize {
421        // reserved + fragment id + addr_size + buf_len
422        2 + 1 + addr_size + buf_len
423    }
424}
425
426pub type GuardTcpStream = BufStream<TcpStream>;
427pub type SocksUdpClient = SocksDatagram<GuardTcpStream>;
428
429#[async_trait::async_trait]
430pub trait UdpClientTrait {
431    async fn send_to<A>(&mut self, buf: &[u8], addr: A) -> Result<usize>
432    where
433        A: Into<Address> + Send + Unpin;
434
435    async fn recv_from(&mut self, timeout: Duration, buf: &mut Vec<u8>) -> Result<(usize, Address)>;
436}
437
438#[async_trait::async_trait]
439impl UdpClientTrait for SocksUdpClient {
440    async fn send_to<A>(&mut self, buf: &[u8], addr: A) -> Result<usize, Error>
441    where
442        A: Into<Address> + Send + Unpin,
443    {
444        SocksDatagram::send_to(self, buf, addr).await
445    }
446
447    async fn recv_from(&mut self, timeout: Duration, buf: &mut Vec<u8>) -> Result<(usize, Address), Error> {
448        SocksDatagram::recv_from(self, timeout, buf).await
449    }
450}
451
452pub async fn create_udp_client<A: Into<SocketAddr>>(proxy_addr: A, auth: Option<UserKey>) -> Result<SocksUdpClient> {
453    let proxy_addr = proxy_addr.into();
454    let client_addr = if proxy_addr.is_ipv4() { "0.0.0.0:0" } else { "[::]:0" };
455    let proxy = TcpStream::connect(proxy_addr).await?;
456    let proxy = BufStream::new(proxy);
457    let client = UdpSocket::bind(client_addr).await?;
458    SocksDatagram::udp_associate(proxy, client, auth).await
459}
460
461pub struct UdpClientImpl<C> {
462    client: C,
463    server_addr: Address,
464}
465
466impl UdpClientImpl<SocksUdpClient> {
467    pub async fn transfer_data(&self, data: &[u8], timeout: Duration) -> Result<Vec<u8>> {
468        let len = self.client.send_to(data, &self.server_addr).await?;
469        let buf = SocksDatagram::<GuardTcpStream>::build_socks5_udp_datagram(data, &self.server_addr).await?;
470        assert_eq!(len, buf.len());
471
472        let mut buf = Vec::with_capacity(data.len());
473        let (_len, _) = self.client.recv_from(timeout, &mut buf).await?;
474        Ok(buf)
475    }
476
477    pub async fn datagram<A1, A2>(proxy_addr: A1, udp_server_addr: A2, auth: Option<UserKey>) -> Result<Self>
478    where
479        A1: Into<SocketAddr>,
480        A2: Into<Address>,
481    {
482        let client = create_udp_client(proxy_addr, auth).await?;
483
484        let server_addr = udp_server_addr.into();
485
486        Ok(Self { client, server_addr })
487    }
488}
489
#[cfg(test)]
mod tests {
    //! Tests against live SOCKS5 proxies. They are `#[ignore]`d by default
    //! because they expect a no-auth proxy at `PROXY_ADDR` and a
    //! username/password proxy (`hyper`/`proxy`) at `PROXY_AUTH_ADDR`.
    use crate::{
        Error, Result,
        client::{self, SocksListener, SocksUdpClient, UdpClientTrait},
        protocol::{Address, UserKey},
    };
    use std::{
        net::{SocketAddr, ToSocketAddrs},
        sync::Arc,
        time::Duration,
    };
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt, BufStream},
        net::{TcpStream, UdpSocket},
    };

    const PROXY_ADDR: &str = "127.0.0.1:1080";
    const PROXY_AUTH_ADDR: &str = "127.0.0.1:1081";
    const DATA: &[u8] = b"Hello, world!";

    async fn connect(addr: &str, auth: Option<UserKey>) {
        let socket = TcpStream::connect(addr).await.unwrap();
        let mut socket = BufStream::new(socket);
        client::connect(&mut socket, Address::from(("baidu.com", 80)), auth).await.unwrap();
    }

    #[ignore]
    #[tokio::test]
    async fn connect_auth() {
        connect(PROXY_AUTH_ADDR, Some(UserKey::new("hyper", "proxy"))).await;
    }

    #[ignore]
    #[tokio::test]
    async fn connect_no_auth() {
        connect(PROXY_ADDR, None).await;
    }

    #[ignore]
    #[should_panic = "InvalidAuthMethod(NoAcceptableMethods)"]
    #[tokio::test]
    async fn connect_no_auth_panic() {
        connect(PROXY_AUTH_ADDR, None).await;
    }

    /// Performs BIND through the proxy, connects to the address it reports,
    /// and checks that data written by that peer arrives on the accepted stream.
    ///
    /// Errors are propagated so that a failure actually fails the test. The
    /// previous version printed them and passed.
    #[ignore]
    #[tokio::test]
    async fn bind() -> Result<()> {
        let server_addr = Address::from(("127.0.0.1", 8000));

        let client = TcpStream::connect(PROXY_ADDR).await?;
        let client = BufStream::new(client);
        let client = SocksListener::bind(client, server_addr, None).await?;

        let server_addr = client.proxy_addr.to_socket_addrs()?.next().ok_or("Invalid address")?;
        let mut server = TcpStream::connect(&server_addr).await?;

        let (mut client, _) = client.accept().await?;

        server.write_all(DATA).await?;

        let mut buf = [0; DATA.len()];
        client.read_exact(&mut buf).await?;
        assert_eq!(buf, DATA);
        Ok(())
    }

    type TestHalves = (Arc<SocksUdpClient>, Arc<SocksUdpClient>);

    // Sends through the second handle and receives through the first.
    #[async_trait::async_trait]
    impl UdpClientTrait for TestHalves {
        async fn send_to<A>(&mut self, buf: &[u8], addr: A) -> Result<usize, Error>
        where
            A: Into<Address> + Send,
        {
            self.1.send_to(buf, addr).await
        }

        async fn recv_from(&mut self, timeout: Duration, buf: &mut Vec<u8>) -> Result<(usize, Address), Error> {
            self.0.recv_from(timeout, buf).await
        }
    }

    const SERVER_ADDR: &str = "127.0.0.1:23456";

    struct UdpTest<C> {
        client: C,
        server: UdpSocket,
        server_addr: Address,
    }

    impl<C: UdpClientTrait> UdpTest<C> {
        /// Round-trips `DATA` client -> server -> client through the proxy.
        async fn test(mut self) {
            let mut buf = vec![0; DATA.len()];
            self.client.send_to(DATA, self.server_addr).await.unwrap();
            let (len, addr) = self.server.recv_from(&mut buf).await.unwrap();
            assert_eq!(len, buf.len());
            assert_eq!(buf.as_slice(), DATA);

            let mut buf = vec![0; DATA.len()];
            self.server.send_to(DATA, addr).await.unwrap();
            let timeout = Duration::from_secs(5);
            let (len, _) = self.client.recv_from(timeout, &mut buf).await.unwrap();
            assert_eq!(len, buf.len());
            assert_eq!(buf.as_slice(), DATA);
        }
    }

    impl UdpTest<SocksUdpClient> {
        async fn datagram() -> Self {
            let addr = PROXY_ADDR.parse::<SocketAddr>().unwrap();
            let client = client::create_udp_client(addr, None).await.unwrap();

            let server_addr: SocketAddr = SERVER_ADDR.parse().unwrap();
            let server = UdpSocket::bind(server_addr).await.unwrap();
            let server_addr = Address::from(server_addr);

            Self {
                client,
                server,
                server_addr,
            }
        }
    }

    impl UdpTest<TestHalves> {
        async fn halves() -> Self {
            let this = UdpTest::<SocksUdpClient>::datagram().await;
            let client = Arc::new(this.client);
            Self {
                client: (client.clone(), client),
                server: this.server,
                server_addr: this.server_addr,
            }
        }
    }

    #[ignore]
    #[tokio::test]
    async fn udp_datagram_halves() {
        UdpTest::halves().await.test().await
    }
}