async_socks5/
lib.rs

//! An `async`/`.await` [SOCKS5] client implementation.
//!
//! [SOCKS5]: https://tools.ietf.org/html/rfc1928
4
5#![deny(missing_debug_implementations)]
6
7use std::{
8    fmt::Debug,
9    io::Cursor,
10    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
11    string::FromUtf8Error,
12};
13use tokio::{
14    io,
15    io::{AsyncReadExt, AsyncWriteExt},
16    net::UdpSocket,
17};
18
19// Error and Result
20// *****************************************************************************
21
22/// The library's error type.
23#[derive(Debug, thiserror::Error)]
24pub enum Error {
25    #[error("{0}")]
26    Io(
27        #[from]
28        #[source]
29        io::Error,
30    ),
31    #[error("{0}")]
32    FromUtf8(
33        #[from]
34        #[source]
35        FromUtf8Error,
36    ),
37    #[error("Invalid SOCKS version: {0:x}")]
38    InvalidVersion(u8),
39    #[error("Invalid command: {0:x}")]
40    InvalidCommand(u8),
41    #[error("Invalid address type: {0:x}")]
42    InvalidAtyp(u8),
43    #[error("Invalid reserved bytes: {0:x}")]
44    InvalidReserved(u8),
45    #[error("Invalid authentication status: {0:x}")]
46    InvalidAuthStatus(u8),
47    #[error("Invalid authentication version of subnegotiation: {0:x}")]
48    InvalidAuthSubnegotiation(u8),
49    #[error("Invalid fragment id: {0:x}")]
50    InvalidFragmentId(u8),
51    #[error("Invalid authentication method: {0:?}")]
52    InvalidAuthMethod(AuthMethod),
53    #[error("SOCKS version is 4 when 5 is expected")]
54    WrongVersion,
55    #[error("No acceptable methods")]
56    NoAcceptableMethods,
57    #[error("Unsuccessful reply: {0:?}")]
58    Response(UnsuccessfulReply),
59    #[error("{0:?} length is more than 255 bytes")]
60    TooLongString(StringKind),
61}
62
/// Names the string that exceeded the 255-byte limit of a SOCKS5 length
/// prefix; carried by [`Error::TooLongString`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum StringKind {
    /// The target domain name.
    Domain,
    /// The username used for username/password authentication.
    Username,
    /// The password used for username/password authentication.
    Password,
}
73
74/// The library's `Result` type alias.
75pub type Result<T, E = Error> = std::result::Result<T, E>;
76
77// Utilities
78// *****************************************************************************
79
80trait ReadExt: AsyncReadExt + Unpin {
81    async fn read_version(&mut self) -> Result<()> {
82        let value = self.read_u8().await?;
83
84        match value {
85            0x04 => Err(Error::WrongVersion),
86            0x05 => Ok(()),
87            _ => Err(Error::InvalidVersion(value)),
88        }
89    }
90
91    async fn read_method(&mut self) -> Result<AuthMethod> {
92        let value = self.read_u8().await?;
93
94        let method = match value {
95            0x00 => AuthMethod::None,
96            0x01 => AuthMethod::GssApi,
97            0x02 => AuthMethod::UsernamePassword,
98            0x03..=0x7f => AuthMethod::IanaReserved(value),
99            0x80..=0xfe => AuthMethod::Private(value),
100            0xff => return Err(Error::NoAcceptableMethods),
101        };
102
103        Ok(method)
104    }
105
106    async fn read_command(&mut self) -> Result<Command> {
107        let value = self.read_u8().await?;
108
109        let command = match value {
110            0x01 => Command::Connect,
111            0x02 => Command::Bind,
112            0x03 => Command::UdpAssociate,
113            _ => return Err(Error::InvalidCommand(value)),
114        };
115
116        Ok(command)
117    }
118
119    async fn read_atyp(&mut self) -> Result<Atyp> {
120        let value = self.read_u8().await?;
121        let atyp = match value {
122            0x01 => Atyp::V4,
123            0x03 => Atyp::Domain,
124            0x04 => Atyp::V6,
125            _ => return Err(Error::InvalidAtyp(value)),
126        };
127        Ok(atyp)
128    }
129
130    async fn read_reserved(&mut self) -> Result<()> {
131        let value = self.read_u8().await?;
132
133        match value {
134            0x00 => Ok(()),
135            _ => Err(Error::InvalidReserved(value)),
136        }
137    }
138
139    async fn read_fragment_id(&mut self) -> Result<()> {
140        let value = self.read_u8().await?;
141
142        if value == 0x00 {
143            Ok(())
144        } else {
145            Err(Error::InvalidFragmentId(value))
146        }
147    }
148
149    async fn read_reply(&mut self) -> Result<()> {
150        let value = self.read_u8().await?;
151
152        let reply = match value {
153            0x00 => return Ok(()),
154            0x01 => UnsuccessfulReply::GeneralFailure,
155            0x02 => UnsuccessfulReply::ConnectionNotAllowedByRules,
156            0x03 => UnsuccessfulReply::NetworkUnreachable,
157            0x04 => UnsuccessfulReply::HostUnreachable,
158            0x05 => UnsuccessfulReply::ConnectionRefused,
159            0x06 => UnsuccessfulReply::TtlExpired,
160            0x07 => UnsuccessfulReply::CommandNotSupported,
161            0x08 => UnsuccessfulReply::AddressTypeNotSupported,
162            _ => UnsuccessfulReply::Unassigned(value),
163        };
164
165        Err(Error::Response(reply))
166    }
167
168    async fn read_target_addr(&mut self) -> Result<AddrKind> {
169        let atyp: Atyp = self.read_atyp().await?;
170
171        let addr = match atyp {
172            Atyp::V4 => {
173                let mut ip = [0; 4];
174                self.read_exact(&mut ip).await?;
175                let port = self.read_u16().await?;
176                AddrKind::Ip(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(ip), port)))
177            }
178            Atyp::V6 => {
179                let mut ip = [0; 16];
180                self.read_exact(&mut ip).await?;
181                let port = self.read_u16().await?;
182                AddrKind::Ip(SocketAddr::V6(SocketAddrV6::new(
183                    Ipv6Addr::from(ip),
184                    port,
185                    0,
186                    0,
187                )))
188            }
189            Atyp::Domain => {
190                let str = self.read_string().await?;
191                let port = self.read_u16().await?;
192                AddrKind::Domain(str, port)
193            }
194        };
195
196        Ok(addr)
197    }
198
199    async fn read_string(&mut self) -> Result<String> {
200        let len = self.read_u8().await?;
201        let mut str = vec![0; len as usize];
202        self.read_exact(&mut str).await?;
203        let str = String::from_utf8(str)?;
204        Ok(str)
205    }
206
207    async fn read_auth_version(&mut self) -> Result<()> {
208        let value = self.read_u8().await?;
209
210        if value != 0x01 {
211            return Err(Error::InvalidAuthSubnegotiation(value));
212        }
213
214        Ok(())
215    }
216
217    async fn read_auth_status(&mut self) -> Result<()> {
218        let value = self.read_u8().await?;
219
220        if value != 0x00 {
221            return Err(Error::InvalidAuthStatus(value));
222        }
223
224        Ok(())
225    }
226
227    async fn read_selection_msg(&mut self) -> Result<AuthMethod> {
228        self.read_version().await?;
229        self.read_method().await
230    }
231
232    async fn read_final(&mut self) -> Result<AddrKind> {
233        self.read_version().await?;
234        self.read_reply().await?;
235        self.read_reserved().await?;
236        let addr = self.read_target_addr().await?;
237        Ok(addr)
238    }
239}
240
241impl<T: AsyncReadExt + Unpin> ReadExt for T {}
242
243trait WriteExt: AsyncWriteExt + Unpin {
244    async fn write_version(&mut self) -> Result<()> {
245        self.write_u8(0x05).await?;
246        Ok(())
247    }
248
249    async fn write_method(&mut self, method: AuthMethod) -> Result<()> {
250        let value = match method {
251            AuthMethod::None => 0x00,
252            AuthMethod::GssApi => 0x01,
253            AuthMethod::UsernamePassword => 0x02,
254            AuthMethod::IanaReserved(value) => value,
255            AuthMethod::Private(value) => value,
256        };
257        self.write_u8(value).await?;
258        Ok(())
259    }
260
261    async fn write_command(&mut self, command: Command) -> Result<()> {
262        self.write_u8(command as u8).await?;
263        Ok(())
264    }
265
266    async fn write_atyp(&mut self, atyp: Atyp) -> Result<()> {
267        self.write_u8(atyp as u8).await?;
268        Ok(())
269    }
270
271    async fn write_reserved(&mut self) -> Result<()> {
272        self.write_u8(0x00).await?;
273        Ok(())
274    }
275
276    async fn write_fragment_id(&mut self) -> Result<()> {
277        self.write_u8(0x00).await?;
278        Ok(())
279    }
280
281    async fn write_target_addr(&mut self, target_addr: &AddrKind) -> Result<()> {
282        match target_addr {
283            AddrKind::Ip(SocketAddr::V4(addr)) => {
284                self.write_atyp(Atyp::V4).await?;
285                self.write_all(&addr.ip().octets()).await?;
286                self.write_u16(addr.port()).await?;
287            }
288            AddrKind::Ip(SocketAddr::V6(addr)) => {
289                self.write_atyp(Atyp::V6).await?;
290                self.write_all(&addr.ip().octets()).await?;
291                self.write_u16(addr.port()).await?;
292            }
293            AddrKind::Domain(domain, port) => {
294                self.write_atyp(Atyp::Domain).await?;
295                self.write_string(domain, StringKind::Domain).await?;
296                self.write_u16(*port).await?;
297            }
298        }
299        Ok(())
300    }
301
302    async fn write_string(&mut self, string: &str, kind: StringKind) -> Result<()> {
303        let bytes = string.as_bytes();
304        if bytes.len() > 255 {
305            return Err(Error::TooLongString(kind));
306        }
307        self.write_u8(bytes.len() as u8).await?;
308        self.write_all(bytes).await?;
309        Ok(())
310    }
311
312    async fn write_auth_version(&mut self) -> Result<()> {
313        self.write_u8(0x01).await?;
314        Ok(())
315    }
316
317    async fn write_methods(&mut self, methods: &[AuthMethod]) -> Result<()> {
318        self.write_u8(methods.len() as u8).await?;
319        for method in methods {
320            self.write_method(*method).await?;
321        }
322        Ok(())
323    }
324
325    async fn write_selection_msg(&mut self, methods: &[AuthMethod]) -> Result<()> {
326        self.write_version().await?;
327        self.write_methods(methods).await?;
328        self.flush().await?;
329        Ok(())
330    }
331
332    async fn write_final(&mut self, command: Command, addr: &AddrKind) -> Result<()> {
333        self.write_version().await?;
334        self.write_command(command).await?;
335        self.write_reserved().await?;
336        self.write_target_addr(addr).await?;
337        self.flush().await?;
338        Ok(())
339    }
340}
341
342impl<T: AsyncWriteExt + Unpin> WriteExt for T {}
343
344async fn username_password_auth<S>(stream: &mut S, auth: Auth) -> Result<()>
345where
346    S: WriteExt + ReadExt + Send,
347{
348    stream.write_auth_version().await?;
349    stream
350        .write_string(&auth.username, StringKind::Username)
351        .await?;
352    stream
353        .write_string(&auth.password, StringKind::Password)
354        .await?;
355    stream.flush().await?;
356
357    stream.read_auth_version().await?;
358    stream.read_auth_status().await
359}
360
361async fn init<S, A>(
362    stream: &mut S,
363    command: Command,
364    addr: A,
365    auth: Option<Auth>,
366) -> Result<AddrKind>
367where
368    S: WriteExt + ReadExt + Send,
369    A: Into<AddrKind>,
370{
371    let addr: AddrKind = addr.into();
372
373    let mut methods = Vec::with_capacity(2);
374    methods.push(AuthMethod::None);
375    if auth.is_some() {
376        methods.push(AuthMethod::UsernamePassword);
377    }
378    stream.write_selection_msg(&methods).await?;
379
380    let method: AuthMethod = stream.read_selection_msg().await?;
381    match method {
382        AuthMethod::None => {}
383        // FIXME: until if let in match is stabilized
384        AuthMethod::UsernamePassword if auth.is_some() => {
385            username_password_auth(stream, auth.unwrap()).await?;
386        }
387        _ => return Err(Error::InvalidAuthMethod(method)),
388    }
389
390    stream.write_final(command, &addr).await?;
391    stream.read_final().await
392}
393
394// Types
395// *****************************************************************************
396
/// Credentials for the username/password authentication method.
///
/// The `Debug` implementation redacts the password so credentials do not end
/// up in logs or panic messages.
#[derive(Eq, PartialEq, Clone, Hash)]
pub struct Auth {
    pub username: String,
    pub password: String,
}

impl std::fmt::Debug for Auth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The password is deliberately never printed.
        f.debug_struct("Auth")
            .field("username", &self.username)
            .field("password", &"<redacted>")
            .finish()
    }
}

impl Auth {
    /// Constructs `Auth` from a username and a password.
    ///
    /// Both are sent as length-prefixed strings, so each must be at most 255
    /// bytes; longer values are rejected with [`Error::TooLongString`] when
    /// authentication is attempted.
    pub fn new<U, P>(username: U, password: P) -> Self
    where
        U: Into<String>,
        P: Into<String>,
    {
        Self {
            username: username.into(),
            password: password.into(),
        }
    }
}
417
/// A proxy authentication method, as carried in the METHOD byte of the
/// method-selection messages.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum AuthMethod {
    /// `0x00`: no authentication required.
    None,
    /// `0x01`: GSS-API.
    GssApi,
    /// `0x02`: username + password authentication.
    UsernamePassword,
    /// `0x03..=0x7f`: an IANA-reserved method code (written verbatim).
    IanaReserved(u8),
    /// `0x80..=0xfe`: a private method code (written verbatim).
    Private(u8),
}
432
/// The CMD field of a SOCKS5 request; each discriminant is its wire value.
#[repr(u8)]
enum Command {
    Connect = 0x01,
    Bind = 0x02,
    UdpAssociate = 0x03,
}
438
/// The ATYP field of a SOCKS5 address; each discriminant is its wire value.
#[repr(u8)]
enum Atyp {
    V4 = 0x01,
    Domain = 0x03,
    V6 = 0x04,
}
444
/// A non-success REP code from a proxy reply.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum UnsuccessfulReply {
    /// `0x01`
    GeneralFailure,
    /// `0x02`
    ConnectionNotAllowedByRules,
    /// `0x03`
    NetworkUnreachable,
    /// `0x04`
    HostUnreachable,
    /// `0x05`
    ConnectionRefused,
    /// `0x06`
    TtlExpired,
    /// `0x07`
    CommandNotSupported,
    /// `0x08`
    AddressTypeNotSupported,
    /// Any other non-zero code (`0x09..=0xff`), kept verbatim.
    Unassigned(u8),
}
458
/// A SOCKS5 address: either a resolved [`SocketAddr`] or a domain name plus a
/// port. Domains are sent to the proxy unresolved.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum AddrKind {
    /// An IPv4 or IPv6 socket address.
    Ip(SocketAddr),
    /// A domain name and a port.
    Domain(String, u16),
}
467
468impl AddrKind {
469    const MAX_SIZE: usize = 1 // atyp
470        + 1 // domain len
471        + 255 // domain
472        + 2; // port
473
474    // FIXME: until ToSocketAddrs is allowed to implement
475    fn to_socket_addr(&self) -> String {
476        match self {
477            AddrKind::Ip(addr) => addr.to_string(),
478            AddrKind::Domain(domain, port) => format!("{}:{}", domain, port),
479        }
480    }
481
482    fn size(&self) -> usize {
483        1 + // atyp
484            2 + // port
485            match self {
486                AddrKind::Ip(SocketAddr::V4(_)) => 4,
487                AddrKind::Ip(SocketAddr::V6(_)) => 16,
488                AddrKind::Domain(domain, _) =>
489                    1 // string len
490                        + domain.len(),
491            }
492    }
493}
494
495impl From<(IpAddr, u16)> for AddrKind {
496    fn from(value: (IpAddr, u16)) -> Self {
497        Self::Ip(value.into())
498    }
499}
500
501impl From<(Ipv4Addr, u16)> for AddrKind {
502    fn from(value: (Ipv4Addr, u16)) -> Self {
503        Self::Ip(value.into())
504    }
505}
506
507impl From<(Ipv6Addr, u16)> for AddrKind {
508    fn from(value: (Ipv6Addr, u16)) -> Self {
509        Self::Ip(value.into())
510    }
511}
512
513impl From<(String, u16)> for AddrKind {
514    fn from((domain, port): (String, u16)) -> Self {
515        Self::Domain(domain, port)
516    }
517}
518
519impl From<(&'_ str, u16)> for AddrKind {
520    fn from((domain, port): (&'_ str, u16)) -> Self {
521        Self::Domain(domain.to_owned(), port)
522    }
523}
524
525impl From<SocketAddr> for AddrKind {
526    fn from(value: SocketAddr) -> Self {
527        Self::Ip(value)
528    }
529}
530
531impl From<SocketAddrV4> for AddrKind {
532    fn from(value: SocketAddrV4) -> Self {
533        Self::Ip(value.into())
534    }
535}
536
537impl From<SocketAddrV6> for AddrKind {
538    fn from(value: SocketAddrV6) -> Self {
539        Self::Ip(value.into())
540    }
541}
542
543// Public API
544// *****************************************************************************
545
546/// Proxifies a TCP connection. Performs the [`CONNECT`] command under the hood.
547///
548/// [`CONNECT`]: https://tools.ietf.org/html/rfc1928#page-6
549///
550/// ```no_run
551/// # use async_socks5::Result;
552/// # #[tokio::main(flavor = "current_thread")]
553/// # async fn main() -> Result<()> {
554/// use async_socks5::connect;
555/// use tokio::{io::BufStream, net::TcpStream};
556///
557/// let stream = TcpStream::connect("my-proxy-server.com:54321").await?;
558/// let mut stream = BufStream::new(stream);
559/// connect(&mut stream, ("google.com", 80), None).await?;
560///
561/// # Ok(())
562/// # }
563/// ```
564pub async fn connect<S, A>(socket: &mut S, addr: A, auth: Option<Auth>) -> Result<AddrKind>
565where
566    S: AsyncWriteExt + AsyncReadExt + Send + Unpin,
567    A: Into<AddrKind>,
568{
569    init(socket, Command::Connect, addr, auth).await
570}
571
572/// A listener that accepts TCP connections through a proxy.
573///
574/// ```no_run
575/// # use async_socks5::Result;
576/// # #[tokio::main(flavor = "current_thread")]
577/// # async fn main() -> Result<()> {
578/// use async_socks5::SocksListener;
579/// use tokio::{io::BufStream, net::TcpStream};
580///
581/// let stream = TcpStream::connect("my-proxy-server.com:54321").await?;
582/// let mut stream = BufStream::new(stream);
583/// let (stream, addr) = SocksListener::bind(stream, ("ftp-server.org", 21), None)
584///     .await?
585///     .accept()
586///     .await?;
587///
588/// # Ok(())
589/// # }
590/// ```
591#[derive(Debug)]
592pub struct SocksListener<S> {
593    stream: S,
594    proxy_addr: AddrKind,
595}
596
597impl<S> SocksListener<S>
598where
599    S: AsyncWriteExt + AsyncReadExt + Send + Unpin,
600{
601    /// Creates `SocksListener`. Performs the [`BIND`] command under the hood.
602    ///
603    /// [`BIND`]: https://tools.ietf.org/html/rfc1928#page-6
604    pub async fn bind<A>(mut stream: S, addr: A, auth: Option<Auth>) -> Result<Self>
605    where
606        A: Into<AddrKind>,
607    {
608        let addr = init(&mut stream, Command::Bind, addr, auth).await?;
609        Ok(Self {
610            stream,
611            proxy_addr: addr,
612        })
613    }
614
615    pub fn proxy_addr(&self) -> &AddrKind {
616        &self.proxy_addr
617    }
618
619    pub async fn accept(mut self) -> Result<(S, AddrKind)> {
620        let addr = self.stream.read_final().await?;
621        Ok((self.stream, addr))
622    }
623}
624
625/// A UDP socket that sends packets through a proxy.
626#[derive(Debug)]
627pub struct SocksDatagram<S> {
628    socket: UdpSocket,
629    proxy_addr: AddrKind,
630    stream: S,
631}
632
633impl<S> SocksDatagram<S>
634where
635    S: AsyncWriteExt + AsyncReadExt + Send + Unpin,
636{
637    /// Creates `SocksDatagram`. Performs [`UDP ASSOCIATE`] under the hood.
638    ///
639    /// [`UDP ASSOCIATE`]: https://tools.ietf.org/html/rfc1928#page-7
640    pub async fn associate<A>(
641        mut proxy_stream: S,
642        socket: UdpSocket,
643        auth: Option<Auth>,
644        association_addr: Option<A>,
645    ) -> Result<Self>
646    where
647        A: Into<AddrKind>,
648    {
649        let addr = association_addr
650            .map(Into::into)
651            .unwrap_or_else(|| AddrKind::Ip(SocketAddr::new(IpAddr::from([0, 0, 0, 0]), 0)));
652        let proxy_addr = init(&mut proxy_stream, Command::UdpAssociate, addr, auth).await?;
653        socket.connect(proxy_addr.to_socket_addr()).await?;
654        Ok(Self {
655            socket,
656            proxy_addr,
657            stream: proxy_stream,
658        })
659    }
660
661    pub fn proxy_addr(&self) -> &AddrKind {
662        &self.proxy_addr
663    }
664
665    pub fn get_ref(&self) -> &UdpSocket {
666        &self.socket
667    }
668
669    pub fn get_mut(&mut self) -> &mut UdpSocket {
670        &mut self.socket
671    }
672
673    pub fn into_inner(self) -> (S, UdpSocket) {
674        (self.stream, self.socket)
675    }
676
677    async fn write_request(buf: &[u8], addr: AddrKind) -> Result<Vec<u8>> {
678        let bytes_size = Self::get_buf_size(addr.size(), buf.len());
679        let bytes = Vec::with_capacity(bytes_size);
680
681        let mut cursor = Cursor::new(bytes);
682        cursor.write_reserved().await?;
683        cursor.write_reserved().await?;
684        cursor.write_fragment_id().await?;
685        cursor.write_target_addr(&addr).await?;
686        cursor.write_all(buf).await?;
687
688        let bytes = cursor.into_inner();
689        Ok(bytes)
690    }
691
692    pub async fn send_to<A>(&self, buf: &[u8], addr: A) -> Result<usize>
693    where
694        A: Into<AddrKind>,
695    {
696        let addr: AddrKind = addr.into();
697        let bytes = Self::write_request(buf, addr).await?;
698        Ok(self.socket.send(&bytes).await?)
699    }
700
701    async fn read_response(
702        len: usize,
703        buf: &mut [u8],
704        bytes: &mut [u8],
705    ) -> Result<(usize, AddrKind)> {
706        let mut cursor = Cursor::new(bytes);
707        cursor.read_reserved().await?;
708        cursor.read_reserved().await?;
709        cursor.read_fragment_id().await?;
710        let addr = cursor.read_target_addr().await?;
711        let header_len = cursor.position() as usize;
712        cursor.read_exact(buf).await?;
713        Ok((len - header_len, addr))
714    }
715
716    pub async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, AddrKind)> {
717        let bytes_size = Self::get_buf_size(AddrKind::MAX_SIZE, buf.len());
718        let mut bytes = vec![0; bytes_size];
719
720        let len = self.socket.recv(&mut bytes).await?;
721        let (read, addr) = Self::read_response(len, buf, &mut bytes).await?;
722        Ok((read, addr))
723    }
724
725    fn get_buf_size(addr_size: usize, buf_len: usize) -> usize {
726        2 // reserved
727                + 1 // fragment id
728                + addr_size
729                + buf_len
730    }
731}
732
733// Tests
734// *****************************************************************************
735
#[cfg(test)]
mod tests {
    // These tests talk to real SOCKS5 proxies: one without authentication on
    // `PROXY_ADDR` and one expecting the `hyper`/`proxy` credentials on
    // `PROXY_AUTH_ADDR`. They panic if those proxies are not reachable.
    use super::*;
    use std::sync::Arc;
    use tokio::{io::BufStream, net::TcpStream};

    const PROXY_ADDR: &str = "127.0.0.1:1080";
    const PROXY_AUTH_ADDR: &str = "127.0.0.1:1081";
    const DATA: &[u8] = b"Hello, world!";

    // Port 0 lets the OS pick a free port. The UDP tests run in parallel, so
    // fixed ports made one of them fail with "address already in use".
    const EPHEMERAL_ADDR: &str = "127.0.0.1:0";

    /// Connects through the proxy at `addr` to `google.com:80`, panicking on failure.
    async fn connect(addr: &str, auth: Option<Auth>) {
        let socket = TcpStream::connect(addr).await.unwrap();
        let mut socket = BufStream::new(socket);
        super::connect(
            &mut socket,
            AddrKind::Domain("google.com".to_string(), 80),
            auth,
        )
        .await
        .unwrap();
    }

    #[tokio::test]
    async fn connect_auth() {
        connect(PROXY_AUTH_ADDR, Some(Auth::new("hyper", "proxy"))).await;
    }

    #[tokio::test]
    async fn connect_no_auth() {
        connect(PROXY_ADDR, None).await;
    }

    // Without credentials the authenticating proxy is expected to refuse the request.
    #[should_panic = "ConnectionNotAllowedByRules"]
    #[tokio::test]
    async fn connect_no_auth_panic() {
        connect(PROXY_AUTH_ADDR, None).await;
    }

    #[tokio::test]
    async fn bind() {
        let server_addr = AddrKind::Domain("127.0.0.1".to_string(), 80);

        let client = TcpStream::connect(PROXY_ADDR).await.unwrap();
        let client = BufStream::new(client);
        let client = SocksListener::bind(client, server_addr, None)
            .await
            .unwrap();

        let server_addr = client.proxy_addr.to_socket_addr();
        let mut server = TcpStream::connect(&server_addr).await.unwrap();

        let (mut client, _) = client.accept().await.unwrap();

        server.write_all(DATA).await.unwrap();

        let mut buf = [0; DATA.len()];
        client.read_exact(&mut buf).await.unwrap();
        assert_eq!(buf, DATA);
    }

    type TestStream = BufStream<TcpStream>;
    type TestDatagram = SocksDatagram<TestStream>;
    type TestHalves = (Arc<TestDatagram>, Arc<TestDatagram>);

    /// Common interface over a whole `SocksDatagram` and a pair of shared halves.
    trait UdpClient {
        async fn send_to<A>(&mut self, buf: &[u8], addr: A) -> Result<usize>
        where
            A: Into<AddrKind> + Send;

        async fn recv_from(&mut self, buf: &mut [u8]) -> Result<(usize, AddrKind)>;
    }

    impl UdpClient for TestDatagram {
        async fn send_to<A>(&mut self, buf: &[u8], addr: A) -> Result<usize, Error>
        where
            A: Into<AddrKind> + Send,
        {
            SocksDatagram::send_to(self, buf, addr).await
        }

        async fn recv_from(&mut self, buf: &mut [u8]) -> Result<(usize, AddrKind), Error> {
            SocksDatagram::recv_from(self, buf).await
        }
    }

    impl UdpClient for TestHalves {
        async fn send_to<A>(&mut self, buf: &[u8], addr: A) -> Result<usize, Error>
        where
            A: Into<AddrKind> + Send,
        {
            self.1.send_to(buf, addr).await
        }

        async fn recv_from(&mut self, buf: &mut [u8]) -> Result<(usize, AddrKind), Error> {
            self.0.recv_from(buf).await
        }
    }

    /// Associates a fresh, OS-assigned UDP socket with the unauthenticated proxy.
    async fn create_client() -> TestDatagram {
        let proxy = TcpStream::connect(PROXY_ADDR).await.unwrap();
        let proxy = BufStream::new(proxy);
        let client = UdpSocket::bind(EPHEMERAL_ADDR).await.unwrap();
        SocksDatagram::associate(proxy, client, None, None::<SocketAddr>)
            .await
            .unwrap()
    }

    struct UdpTest<C> {
        client: C,
        server: UdpSocket,
        server_addr: AddrKind,
    }

    impl<C: UdpClient> UdpTest<C> {
        /// Round-trips `DATA` client -> server -> client through the relay.
        async fn test(mut self) {
            let mut buf = vec![0; DATA.len()];
            self.client.send_to(DATA, self.server_addr).await.unwrap();
            let (len, addr) = self.server.recv_from(&mut buf).await.unwrap();
            assert_eq!(len, buf.len());
            assert_eq!(buf.as_slice(), DATA);

            let mut buf = vec![0; DATA.len()];
            self.server.send_to(DATA, addr).await.unwrap();
            let (len, _) = self.client.recv_from(&mut buf).await.unwrap();
            assert_eq!(len, buf.len());
            assert_eq!(buf.as_slice(), DATA);
        }
    }

    impl UdpTest<TestDatagram> {
        async fn datagram() -> Self {
            let client = create_client().await;

            let server = UdpSocket::bind(EPHEMERAL_ADDR).await.unwrap();
            let server_addr = AddrKind::Ip(server.local_addr().unwrap());

            Self {
                client,
                server,
                server_addr,
            }
        }
    }

    impl UdpTest<TestHalves> {
        async fn halves() -> Self {
            let this = UdpTest::<TestDatagram>::datagram().await;
            let client = Arc::new(this.client);
            Self {
                client: (client.clone(), client),
                server: this.server,
                server_addr: this.server_addr,
            }
        }
    }

    #[tokio::test]
    async fn udp_associate() {
        UdpTest::datagram().await.test().await
    }

    #[tokio::test]
    async fn udp_datagram_halves() {
        UdpTest::halves().await.test().await
    }
}