1use crate::{
2 error::{Error, Result},
3 protocol::{Address, AddressType, AsyncStreamOperation, AuthMethod, Command, Reply, StreamOperation, UserKey, Version},
4};
5use std::{
6 fmt::Debug,
7 io::Cursor,
8 net::{SocketAddr, ToSocketAddrs},
9 time::Duration,
10};
11use tokio::{
12 io::{AsyncReadExt, AsyncWriteExt, BufStream},
13 net::{TcpStream, UdpSocket},
14};
15
16#[async_trait::async_trait]
17pub trait Socks5Reader: AsyncReadExt + Unpin {
18 async fn read_version(&mut self) -> Result<()> {
19 let value = Version::try_from(self.read_u8().await?)?;
20 match value {
21 Version::V4 => Err(Error::WrongVersion),
22 Version::V5 => Ok(()),
23 }
24 }
25
26 async fn read_method(&mut self) -> Result<AuthMethod> {
27 let value = AuthMethod::from(self.read_u8().await?);
28 match value {
29 AuthMethod::NoAuth | AuthMethod::UserPass => Ok(value),
30 _ => Err(Error::InvalidAuthMethod(value)),
31 }
32 }
33
34 async fn read_command(&mut self) -> Result<Command> {
35 let value = self.read_u8().await?;
36 Ok(Command::try_from(value)?)
37 }
38
39 async fn read_atyp(&mut self) -> Result<AddressType> {
40 let value = self.read_u8().await?;
41 Ok(AddressType::try_from(value)?)
42 }
43
44 async fn read_reserved(&mut self) -> Result<()> {
45 let value = self.read_u8().await?;
46 match value {
47 0x00 => Ok(()),
48 _ => Err(Error::InvalidReserved(value)),
49 }
50 }
51
52 async fn read_fragment_id(&mut self) -> Result<()> {
53 let value = self.read_u8().await?;
54 if value == 0x00 {
55 Ok(())
56 } else {
57 Err(Error::InvalidFragmentId(value))
58 }
59 }
60
61 async fn read_reply(&mut self) -> Result<()> {
62 let value = self.read_u8().await?;
63 match Reply::try_from(value)? {
64 Reply::Succeeded => Ok(()),
65 reply => Err(format!("{reply}").into()),
66 }
67 }
68
69 async fn read_address(&mut self) -> Result<Address> {
70 Ok(Address::retrieve_from_async_stream(self).await?)
71 }
72
73 async fn read_string(&mut self) -> Result<String> {
74 let len = self.read_u8().await? as usize;
75 let mut str = vec![0; len];
76 self.read_exact(&mut str).await?;
77 let str = String::from_utf8(str)?;
78 Ok(str)
79 }
80
81 async fn read_auth_version(&mut self) -> Result<()> {
82 let value = self.read_u8().await?;
83 if value != 0x01 {
84 return Err(Error::InvalidAuthSubnegotiation(value));
85 }
86 Ok(())
87 }
88
89 async fn read_auth_status(&mut self) -> Result<()> {
90 let value = self.read_u8().await?;
91 if value != 0x00 {
92 return Err(Error::InvalidAuthStatus(value));
93 }
94 Ok(())
95 }
96
97 async fn read_selection_msg(&mut self) -> Result<AuthMethod> {
98 self.read_version().await?;
99 self.read_method().await
100 }
101
102 async fn read_final(&mut self) -> Result<Address> {
103 self.read_version().await?;
104 self.read_reply().await?;
105 self.read_reserved().await?;
106 let addr = self.read_address().await?;
107 Ok(addr)
108 }
109}
110
111#[async_trait::async_trait]
112impl<T: AsyncReadExt + Unpin> Socks5Reader for T {}
113
114#[async_trait::async_trait]
115pub trait Socks5Writer: AsyncWriteExt + Unpin {
116 async fn write_version(&mut self) -> Result<()> {
117 self.write_u8(0x05).await?;
118 Ok(())
119 }
120
121 async fn write_method(&mut self, method: AuthMethod) -> Result<()> {
122 self.write_u8(u8::from(method)).await?;
123 Ok(())
124 }
125
126 async fn write_command(&mut self, command: Command) -> Result<()> {
127 self.write_u8(u8::from(command)).await?;
128 Ok(())
129 }
130
131 async fn write_atyp(&mut self, atyp: AddressType) -> Result<()> {
132 self.write_u8(u8::from(atyp)).await?;
133 Ok(())
134 }
135
136 async fn write_reserved(&mut self) -> Result<()> {
137 self.write_u8(0x00).await?;
138 Ok(())
139 }
140
141 async fn write_fragment_id(&mut self, id: u8) -> Result<()> {
142 self.write_u8(id).await?;
143 Ok(())
144 }
145
146 async fn write_address(&mut self, address: &Address) -> Result<()> {
147 address.write_to_async_stream(self).await?;
148 Ok(())
149 }
150
151 async fn write_string(&mut self, string: &str) -> Result<()> {
152 let bytes = string.as_bytes();
153 if bytes.len() > 255 {
154 return Err("Too long string".into());
155 }
156 self.write_u8(bytes.len() as u8).await?;
157 self.write_all(bytes).await?;
158 Ok(())
159 }
160
161 async fn write_auth_version(&mut self) -> Result<()> {
162 self.write_u8(0x01).await?;
163 Ok(())
164 }
165
166 async fn write_methods(&mut self, methods: &[AuthMethod]) -> Result<()> {
167 let method_count = u8::try_from(methods.len()).map_err(|_| "Too many authentication methods")?;
168 self.write_u8(method_count).await?;
169 for method in methods {
170 self.write_method(*method).await?;
171 }
172 Ok(())
173 }
174
175 async fn write_selection_msg(&mut self, methods: &[AuthMethod]) -> Result<()> {
176 self.write_version().await?;
177 self.write_methods(methods).await?;
178 self.flush().await?;
179 Ok(())
180 }
181
182 async fn write_final(&mut self, command: Command, addr: &Address) -> Result<()> {
183 self.write_version().await?;
184 self.write_command(command).await?;
185 self.write_reserved().await?;
186 self.write_address(addr).await?;
187 self.flush().await?;
188 Ok(())
189 }
190}
191
192#[async_trait::async_trait]
193impl<T: AsyncWriteExt + Unpin> Socks5Writer for T {}
194
195async fn username_password_auth<S>(stream: &mut S, auth: &UserKey) -> Result<()>
196where
197 S: Socks5Writer + Socks5Reader + Send,
198{
199 stream.write_auth_version().await?;
200 stream.write_string(&auth.username).await?;
201 stream.write_string(&auth.password).await?;
202 stream.flush().await?;
203
204 stream.read_auth_version().await?;
205 stream.read_auth_status().await
206}
207
208async fn init<S, A>(stream: &mut S, command: Command, addr: A, auth: Option<UserKey>) -> Result<Address>
209where
210 S: Socks5Writer + Socks5Reader + Send,
211 A: Into<Address>,
212{
213 let addr: Address = addr.into();
214
215 let mut methods = Vec::with_capacity(2);
216 methods.push(AuthMethod::NoAuth);
217 if auth.is_some() {
218 methods.push(AuthMethod::UserPass);
219 }
220 stream.write_selection_msg(&methods).await?;
221 stream.flush().await?;
222
223 let method: AuthMethod = stream.read_selection_msg().await?;
224 match method {
225 AuthMethod::NoAuth => {}
226 AuthMethod::UserPass if auth.is_some() => {
227 username_password_auth(stream, auth.as_ref().unwrap()).await?;
228 }
229 _ => return Err(Error::InvalidAuthMethod(method)),
230 }
231
232 stream.write_final(command, &addr).await?;
233 stream.read_final().await
234}
235
236pub async fn connect<S, A>(socket: &mut S, addr: A, auth: Option<UserKey>) -> Result<Address>
255where
256 S: AsyncWriteExt + AsyncReadExt + Send + Unpin,
257 A: Into<Address>,
258{
259 init(socket, Command::Connect, addr, auth).await
260}
261
262#[derive(Debug)]
282pub struct SocksListener<S> {
283 stream: S,
284 proxy_addr: Address,
285}
286
287impl<S> SocksListener<S>
288where
289 S: AsyncWriteExt + AsyncReadExt + Send + Unpin,
290{
291 pub async fn bind<A>(mut stream: S, addr: A, auth: Option<UserKey>) -> Result<Self>
295 where
296 A: Into<Address>,
297 {
298 let addr = init(&mut stream, Command::Bind, addr, auth).await?;
299 Ok(Self { stream, proxy_addr: addr })
300 }
301
302 pub fn proxy_addr(&self) -> &Address {
303 &self.proxy_addr
304 }
305
306 pub async fn accept(mut self) -> Result<(S, Address)> {
307 let addr = self.stream.read_final().await?;
308 Ok((self.stream, addr))
309 }
310}
311
312#[derive(Debug)]
314pub struct SocksDatagram<S> {
315 socket: UdpSocket,
316 proxy_addr: Address,
317 stream: S,
318}
319
320impl<S> SocksDatagram<S>
321where
322 S: AsyncWriteExt + AsyncReadExt + Send + Unpin,
323{
324 pub async fn udp_associate(mut stream: S, socket: UdpSocket, auth: Option<UserKey>) -> Result<Self> {
328 let addr = if socket.local_addr()?.is_ipv4() { "0.0.0.0:0" } else { "[::]:0" };
329 let addr = addr.parse::<SocketAddr>()?;
330 let proxy_addr = init(&mut stream, Command::UdpAssociate, addr, auth).await?;
331 let addr = proxy_addr.to_socket_addrs()?.next().ok_or("InvalidAddress")?;
332 socket.connect(addr).await?;
333 Ok(Self {
334 socket,
335 proxy_addr,
336 stream,
337 })
338 }
339
340 pub fn proxy_addr(&self) -> &Address {
342 &self.proxy_addr
343 }
344
345 pub fn get_ref(&self) -> &UdpSocket {
347 &self.socket
348 }
349
350 pub fn get_mut(&mut self) -> &mut UdpSocket {
352 &mut self.socket
353 }
354
355 pub fn into_inner(self) -> (S, UdpSocket) {
357 (self.stream, self.socket)
358 }
359
360 pub async fn build_socks5_udp_datagram(buf: &[u8], addr: &Address) -> Result<Vec<u8>> {
370 let bytes_size = Self::get_buf_size(addr.len(), buf.len());
371 let bytes = Vec::with_capacity(bytes_size);
372
373 let mut cursor = Cursor::new(bytes);
374 cursor.write_reserved().await?;
375 cursor.write_reserved().await?;
376 cursor.write_fragment_id(0x00).await?;
377 cursor.write_address(addr).await?;
378 cursor.write_all(buf).await?;
379
380 let bytes = cursor.into_inner();
381 Ok(bytes)
382 }
383
384 pub async fn send_to<A>(&self, buf: &[u8], addr: A) -> Result<usize>
386 where
387 A: Into<Address>,
388 {
389 let addr: Address = addr.into();
390 let bytes = Self::build_socks5_udp_datagram(buf, &addr).await?;
391 Ok(self.socket.send(&bytes).await?)
392 }
393
394 async fn parse_socks5_udp_response(bytes: &mut [u8], buf: &mut Vec<u8>) -> Result<(usize, Address)> {
396 let len = bytes.len();
397 let mut cursor = Cursor::new(bytes);
398 cursor.read_reserved().await?;
399 cursor.read_reserved().await?;
400 cursor.read_fragment_id().await?;
401 let addr = cursor.read_address().await?;
402 let header_len = cursor.position() as usize;
403 buf.resize(len - header_len, 0);
404 _ = cursor.read_exact(buf).await?;
405 Ok((len - header_len, addr))
406 }
407
408 pub async fn recv_from(&self, timeout: Duration, buf: &mut Vec<u8>) -> Result<(usize, Address)> {
410 const UDP_MTU: usize = 1500;
411 let bytes_size = UDP_MTU;
413 let mut bytes = vec![0; bytes_size];
414 let len = tokio::time::timeout(timeout, self.socket.recv(&mut bytes)).await??;
415 bytes.truncate(len);
416 let (read, addr) = Self::parse_socks5_udp_response(&mut bytes, buf).await?;
417 Ok((read, addr))
418 }
419
420 fn get_buf_size(addr_size: usize, buf_len: usize) -> usize {
421 2 + 1 + addr_size + buf_len
423 }
424}
425
426pub type GuardTcpStream = BufStream<TcpStream>;
427pub type SocksUdpClient = SocksDatagram<GuardTcpStream>;
428
429#[async_trait::async_trait]
430pub trait UdpClientTrait {
431 async fn send_to<A>(&mut self, buf: &[u8], addr: A) -> Result<usize>
432 where
433 A: Into<Address> + Send + Unpin;
434
435 async fn recv_from(&mut self, timeout: Duration, buf: &mut Vec<u8>) -> Result<(usize, Address)>;
436}
437
438#[async_trait::async_trait]
439impl UdpClientTrait for SocksUdpClient {
440 async fn send_to<A>(&mut self, buf: &[u8], addr: A) -> Result<usize, Error>
441 where
442 A: Into<Address> + Send + Unpin,
443 {
444 SocksDatagram::send_to(self, buf, addr).await
445 }
446
447 async fn recv_from(&mut self, timeout: Duration, buf: &mut Vec<u8>) -> Result<(usize, Address), Error> {
448 SocksDatagram::recv_from(self, timeout, buf).await
449 }
450}
451
452pub async fn create_udp_client<A: Into<SocketAddr>>(proxy_addr: A, auth: Option<UserKey>) -> Result<SocksUdpClient> {
453 let proxy_addr = proxy_addr.into();
454 let client_addr = if proxy_addr.is_ipv4() { "0.0.0.0:0" } else { "[::]:0" };
455 let proxy = TcpStream::connect(proxy_addr).await?;
456 let proxy = BufStream::new(proxy);
457 let client = UdpSocket::bind(client_addr).await?;
458 SocksDatagram::udp_associate(proxy, client, auth).await
459}
460
461pub struct UdpClientImpl<C> {
462 client: C,
463 server_addr: Address,
464}
465
466impl UdpClientImpl<SocksUdpClient> {
467 pub async fn transfer_data(&self, data: &[u8], timeout: Duration) -> Result<Vec<u8>> {
468 let len = self.client.send_to(data, &self.server_addr).await?;
469 let buf = SocksDatagram::<GuardTcpStream>::build_socks5_udp_datagram(data, &self.server_addr).await?;
470 assert_eq!(len, buf.len());
471
472 let mut buf = Vec::with_capacity(data.len());
473 let (_len, _) = self.client.recv_from(timeout, &mut buf).await?;
474 Ok(buf)
475 }
476
477 pub async fn datagram<A1, A2>(proxy_addr: A1, udp_server_addr: A2, auth: Option<UserKey>) -> Result<Self>
478 where
479 A1: Into<SocketAddr>,
480 A2: Into<Address>,
481 {
482 let client = create_udp_client(proxy_addr, auth).await?;
483
484 let server_addr = udp_server_addr.into();
485
486 Ok(Self { client, server_addr })
487 }
488}
489
// Integration-style tests. Every test is `#[ignore]`d and connects to the
// addresses in `PROXY_ADDR` / `PROXY_AUTH_ADDR`.
#[cfg(test)]
mod tests {
    use crate::{
        Error, Result,
        client::{self, SocksListener, SocksUdpClient, UdpClientTrait},
        protocol::{Address, UserKey},
    };
    use std::{
        net::{SocketAddr, ToSocketAddrs},
        sync::Arc,
        time::Duration,
    };
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt, BufStream},
        net::{TcpStream, UdpSocket},
    };

    const PROXY_ADDR: &str = "127.0.0.1:1080";
    const PROXY_AUTH_ADDR: &str = "127.0.0.1:1081";
    const DATA: &[u8] = b"Hello, world!";

    /// Opens a connection to the proxy at `addr` and issues a CONNECT request,
    /// panicking on any failure.
    async fn connect(addr: &str, auth: Option<UserKey>) {
        let mut proxy = BufStream::new(TcpStream::connect(addr).await.unwrap());
        let target = Address::from(("baidu.com", 80));
        client::connect(&mut proxy, target, auth).await.unwrap();
    }

    #[ignore]
    #[tokio::test]
    async fn connect_auth() {
        let credentials = UserKey::new("hyper", "proxy");
        connect(PROXY_AUTH_ADDR, Some(credentials)).await;
    }

    #[ignore]
    #[tokio::test]
    async fn connect_no_auth() {
        connect(PROXY_ADDR, None).await;
    }

    #[ignore]
    #[should_panic = "InvalidAuthMethod(NoAcceptableMethods)"]
    #[tokio::test]
    async fn connect_no_auth_panic() {
        connect(PROXY_AUTH_ADDR, None).await;
    }

    #[ignore]
    #[tokio::test]
    async fn bind() {
        let scenario = async {
            let requested = Address::from(("127.0.0.1", 8000));

            let control = BufStream::new(TcpStream::connect(PROXY_ADDR).await?);
            let listener = SocksListener::bind(control, requested, None).await?;

            // Connect to the address the proxy reported for the bind request.
            let relay_addr = listener.proxy_addr().to_socket_addrs()?.next().ok_or("Invalid address")?;
            let mut server = TcpStream::connect(&relay_addr).await?;

            let (mut accepted, _) = listener.accept().await?;

            server.write_all(DATA).await?;

            let mut received = [0; DATA.len()];
            accepted.read_exact(&mut received).await?;
            assert_eq!(received, DATA);
            Ok::<_, Error>(())
        };
        // Failures of the scenario are only printed, not turned into a test failure.
        if let Err(e) = scenario.await {
            println!("{e:?}");
        }
    }

    /// Two handles to the same client: `.0` is used for receiving, `.1` for sending.
    type TestHalves = (Arc<SocksUdpClient>, Arc<SocksUdpClient>);

    #[async_trait::async_trait]
    impl UdpClientTrait for TestHalves {
        async fn send_to<A>(&mut self, buf: &[u8], addr: A) -> Result<usize, Error>
        where
            A: Into<Address> + Send,
        {
            let sender = &self.1;
            sender.send_to(buf, addr).await
        }

        async fn recv_from(&mut self, timeout: Duration, buf: &mut Vec<u8>) -> Result<(usize, Address), Error> {
            let receiver = &self.0;
            receiver.recv_from(timeout, buf).await
        }
    }

    const SERVER_ADDR: &str = "127.0.0.1:23456";

    /// Fixture: a client under test plus a local UDP server socket and its address.
    struct UdpTest<C> {
        client: C,
        server: UdpSocket,
        server_addr: Address,
    }

    impl<C: UdpClientTrait> UdpTest<C> {
        /// Sends `DATA` client -> server, then server -> client, checking both hops.
        async fn test(self) {
            let Self {
                mut client,
                server,
                server_addr,
            } = self;

            // Client to server.
            let mut at_server = vec![0; DATA.len()];
            client.send_to(DATA, server_addr).await.unwrap();
            let (len, peer) = server.recv_from(&mut at_server).await.unwrap();
            assert_eq!(len, at_server.len());
            assert_eq!(at_server.as_slice(), DATA);

            // Server back to client.
            let mut at_client = vec![0; DATA.len()];
            server.send_to(DATA, peer).await.unwrap();
            let timeout = Duration::from_secs(5);
            let (len, _) = client.recv_from(timeout, &mut at_client).await.unwrap();
            assert_eq!(len, at_client.len());
            assert_eq!(at_client.as_slice(), DATA);
        }
    }

    impl UdpTest<SocksUdpClient> {
        /// Creates a UDP client through `PROXY_ADDR` and binds the server on `SERVER_ADDR`.
        async fn datagram() -> Self {
            let proxy: SocketAddr = PROXY_ADDR.parse().unwrap();
            let client = client::create_udp_client(proxy, None).await.unwrap();

            let bind_addr = SERVER_ADDR.parse::<SocketAddr>().unwrap();
            let server = UdpSocket::bind(bind_addr).await.unwrap();

            Self {
                client,
                server,
                server_addr: Address::from(bind_addr),
            }
        }
    }

    impl UdpTest<TestHalves> {
        /// Same fixture as `datagram`, with the client shared between two `Arc` handles.
        async fn halves() -> Self {
            let UdpTest {
                client,
                server,
                server_addr,
            } = UdpTest::<SocksUdpClient>::datagram().await;
            let shared = Arc::new(client);
            Self {
                client: (Arc::clone(&shared), shared),
                server,
                server_addr,
            }
        }
    }

    #[ignore]
    #[tokio::test]
    async fn udp_datagram_halves() {
        let fixture = UdpTest::halves().await;
        fixture.test().await
    }
}