// socks_lib/v5/mod.rs

1pub mod server;
2
3use std::io;
4use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
5use std::sync::LazyLock;
6
7use bytes::{Buf, BufMut, Bytes, BytesMut};
8use tokio::io::{AsyncRead, AsyncReadExt, BufReader};
9
/// # Method
///
/// An authentication method identifier exchanged during SOCKS5 method
/// negotiation.
///
/// ```text
///  +--------+
///  | METHOD |
///  +--------+
///  |   1    |
///  +--------+
/// ```
///
/// Wire values:
///
/// * `X'00'` — no authentication required
/// * `X'01'` — GSSAPI
/// * `X'02'` — username/password
/// * `X'03'`..=`X'7F'` — IANA assigned
/// * `X'80'`..=`X'FE'` — reserved for private methods
/// * `X'FF'` — no acceptable methods
///
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum Method {
    NoAuthentication,
    GSSAPI,
    UsernamePassword,
    IanaAssigned(u8),
    ReservedPrivate(u8),
    NoAcceptableMethod,
}

impl Method {
    /// Encodes this method as its one-byte wire value.
    ///
    /// This is the inverse of [`Method::from_u8`]: for every byte `b`,
    /// `Method::from_u8(b).as_u8() == b`.
    #[rustfmt::skip]
    #[inline]
    fn as_u8(&self) -> u8 {
        match self {
            Self::NoAuthentication            => 0x00,
            Self::GSSAPI                      => 0x01,
            // USERNAME/PASSWORD is X'02'. X'03' is the first IANA-assigned
            // code and would be decoded back as `IanaAssigned(3)`.
            Self::UsernamePassword            => 0x02,
            Self::IanaAssigned(value)         => *value,
            Self::ReservedPrivate(value)      => *value,
            Self::NoAcceptableMethod          => 0xFF,
        }
    }

    /// Decodes a one-byte wire value into a [`Method`].
    ///
    /// Every byte maps to some variant, so decoding never fails.
    #[rustfmt::skip]
    #[inline]
    fn from_u8(value: u8) -> Self {
        match value {
            0x00        => Self::NoAuthentication,
            0x01        => Self::GSSAPI,
            0x02        => Self::UsernamePassword,
            0x03..=0x7F => Self::IanaAssigned(value),
            0x80..=0xFE => Self::ReservedPrivate(value),
            0xFF        => Self::NoAcceptableMethod,
        }
    }
}
57
58/// # Request
59///
60/// ```text
61///  +-----+-------+------+----------+----------+
62///  | CMD |  RSV  | ATYP | DST.ADDR | DST.PORT |
63///  +-----+-------+------+----------+----------+
64///  |  1  | X'00' |  1   | Variable |    2     |
65///  +-----+-------+------+----------+----------+
66/// ```
67///
68#[derive(Debug, Clone, PartialEq)]
69pub enum Request {
70    Bind(Address),
71    Connect(Address),
72    Associate(Address),
73}
74
75#[rustfmt::skip]
76impl Request {
77    const SOCKS5_CMD_CONNECT:   u8 = 0x01;
78    const SOCKS5_CMD_BIND:      u8 = 0x02;
79    const SOCKS5_CMD_ASSOCIATE: u8 = 0x03;
80}
81
82impl Request {
83    pub async fn from_async_read<R: AsyncRead + Unpin>(
84        reader: &mut BufReader<R>,
85    ) -> io::Result<Self> {
86        let mut buf = [0u8; 2];
87        reader.read_exact(&mut buf).await?;
88
89        let command = buf[0];
90
91        let request = match command {
92            Self::SOCKS5_CMD_BIND => Self::Bind(Address::from_async_read(reader).await?),
93            Self::SOCKS5_CMD_CONNECT => Self::Connect(Address::from_async_read(reader).await?),
94            Self::SOCKS5_CMD_ASSOCIATE => Self::Associate(Address::from_async_read(reader).await?),
95            command => {
96                return Err(io::Error::new(
97                    io::ErrorKind::InvalidData,
98                    format!("Invalid request command: {}", command),
99                ));
100            }
101        };
102
103        Ok(request)
104    }
105}
106
107/// # Address
108///
109/// ```text
110///  +------+----------+----------+
111///  | ATYP | DST.ADDR | DST.PORT |
112///  +------+----------+----------+
113///  |  1   | Variable |    2     |
114///  +------+----------+----------+
115/// ```
116///
117/// ## DST.ADDR BND.ADDR
118///   In an address field (DST.ADDR, BND.ADDR), the ATYP field specifies
119///   the type of address contained within the field:
120///   
121/// o ATYP: X'01'
122///   the address is a version-4 IP address, with a length of 4 octets
123///   
124/// o ATYP: X'03'
125///   the address field contains a fully-qualified domain name.  The first
126///   octet of the address field contains the number of octets of name that
127///   follow, there is no terminating NUL octet.
128///   
129/// o ATYP: X'04'  
130///   the address is a version-6 IP address, with a length of 16 octets.
131///
132#[derive(Debug, Clone, PartialEq)]
133pub enum Address {
134    IPv4(SocketAddrV4),
135    IPv6(SocketAddrV6),
136    Domain(Domain, u16),
137}
138
139static UNSPECIFIED_ADDRESS: LazyLock<Address> =
140    LazyLock::new(|| Address::IPv4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)));
141
142#[rustfmt::skip]
143impl Address {
144    const PORT_LENGTH:         usize = 2;
145    const IPV4_ADDRESS_LENGTH: usize = 4;
146    const IPV6_ADDRESS_LENGTH: usize = 16;
147
148    const SOCKS5_ADDRESS_TYPE_IPV4:        u8 = 0x01;
149    const SOCKS5_ADDRESS_TYPE_DOMAIN_NAME: u8 = 0x03;
150    const SOCKS5_ADDRESS_TYPE_IPV6:        u8 = 0x04;
151}
152
153impl Address {
154    #[inline]
155    pub fn unspecified() -> &'static Self {
156        &UNSPECIFIED_ADDRESS
157    }
158
159    pub fn from_socket_addr(addr: SocketAddr) -> Self {
160        match addr {
161            SocketAddr::V4(addr) => Self::IPv4(addr),
162            SocketAddr::V6(addr) => Self::IPv6(addr),
163        }
164    }
165
166    pub async fn from_async_read<R: AsyncRead + Unpin>(
167        reader: &mut BufReader<R>,
168    ) -> io::Result<Self> {
169        let address_type = reader.read_u8().await?;
170
171        match address_type {
172            Self::SOCKS5_ADDRESS_TYPE_IPV4 => {
173                let mut buf = [0u8; Self::IPV4_ADDRESS_LENGTH + Self::PORT_LENGTH];
174                reader.read_exact(&mut buf).await?;
175
176                let ip = Ipv4Addr::new(buf[0], buf[1], buf[2], buf[3]);
177                let port = u16::from_be_bytes([buf[4], buf[5]]);
178
179                Ok(Address::IPv4(SocketAddrV4::new(ip, port)))
180            }
181
182            Self::SOCKS5_ADDRESS_TYPE_IPV6 => {
183                let mut buf = [0u8; Self::IPV6_ADDRESS_LENGTH + Self::PORT_LENGTH];
184                reader.read_exact(&mut buf).await?;
185
186                let ip = Ipv6Addr::from([
187                    buf[0], buf[1], buf[2], buf[3], buf[4], buf[5], buf[6], buf[7], buf[8], buf[9],
188                    buf[10], buf[11], buf[12], buf[13], buf[14], buf[15],
189                ]);
190                let port = u16::from_be_bytes([buf[16], buf[17]]);
191
192                Ok(Address::IPv6(SocketAddrV6::new(ip, port, 0, 0)))
193            }
194
195            Self::SOCKS5_ADDRESS_TYPE_DOMAIN_NAME => {
196                let domain_len = reader.read_u8().await? as usize;
197
198                let mut buf = vec![0u8; domain_len + Self::PORT_LENGTH];
199                reader.read_exact(&mut buf).await?;
200
201                let domain = Bytes::copy_from_slice(&buf[..domain_len]);
202                let port = u16::from_be_bytes([buf[domain_len], buf[domain_len + 1]]);
203
204                Ok(Address::Domain(Domain(domain), port))
205            }
206
207            n => Err(io::Error::new(
208                io::ErrorKind::InvalidData,
209                format!("Invalid address type: {}", n),
210            )),
211        }
212    }
213
214    pub fn from_bytes<B: Buf>(buf: &mut B) -> io::Result<Self> {
215        if buf.remaining() < 1 {
216            return Err(io::Error::new(
217                io::ErrorKind::InvalidData,
218                "Insufficient data for address",
219            ));
220        }
221
222        let address_type = buf.get_u8();
223
224        match address_type {
225            Self::SOCKS5_ADDRESS_TYPE_IPV4 => {
226                if buf.remaining() < Self::IPV4_ADDRESS_LENGTH + Self::PORT_LENGTH {
227                    return Err(io::Error::new(
228                        io::ErrorKind::InvalidData,
229                        "Insufficient data for IPv4 address",
230                    ));
231                }
232
233                let mut ip = [0u8; Self::IPV4_ADDRESS_LENGTH];
234                buf.copy_to_slice(&mut ip);
235
236                let port = buf.get_u16();
237
238                Ok(Address::IPv4(SocketAddrV4::new(Ipv4Addr::from(ip), port)))
239            }
240
241            Self::SOCKS5_ADDRESS_TYPE_IPV6 => {
242                if buf.remaining() < Self::IPV6_ADDRESS_LENGTH + Self::PORT_LENGTH {
243                    return Err(io::Error::new(
244                        io::ErrorKind::InvalidData,
245                        "Insufficient data for IPv6 address",
246                    ));
247                }
248
249                let mut ip = [0u8; Self::IPV6_ADDRESS_LENGTH];
250                buf.copy_to_slice(&mut ip);
251
252                let port = buf.get_u16();
253
254                Ok(Address::IPv6(SocketAddrV6::new(
255                    Ipv6Addr::from(ip),
256                    port,
257                    0,
258                    0,
259                )))
260            }
261
262            Self::SOCKS5_ADDRESS_TYPE_DOMAIN_NAME => {
263                if buf.remaining() < 1 {
264                    return Err(io::Error::new(
265                        io::ErrorKind::InvalidData,
266                        "Insufficient data for domain length",
267                    ));
268                }
269
270                let domain_len = buf.get_u8() as usize;
271
272                if buf.remaining() < domain_len + Self::PORT_LENGTH {
273                    return Err(io::Error::new(
274                        io::ErrorKind::InvalidData,
275                        "Insufficient data for domain name",
276                    ));
277                }
278
279                let mut domain = vec![0u8; domain_len];
280                buf.copy_to_slice(&mut domain);
281
282                let port = buf.get_u16();
283
284                Ok(Address::Domain(Domain(Bytes::from(domain)), port))
285            }
286
287            n => Err(io::Error::new(
288                io::ErrorKind::InvalidData,
289                format!("Invalid address type: {}", n),
290            )),
291        }
292    }
293
294    #[inline]
295    pub fn to_bytes(&self) -> Bytes {
296        let mut bytes = BytesMut::new();
297
298        match self {
299            Self::Domain(domain, port) => {
300                let domain_bytes = domain.as_bytes();
301                bytes.put_u8(Self::SOCKS5_ADDRESS_TYPE_DOMAIN_NAME);
302                bytes.put_u8(domain_bytes.len() as u8);
303                bytes.extend_from_slice(domain_bytes);
304                bytes.extend_from_slice(&port.to_be_bytes());
305            }
306            Self::IPv4(addr) => {
307                bytes.put_u8(Self::SOCKS5_ADDRESS_TYPE_IPV4);
308                bytes.extend_from_slice(&addr.ip().octets());
309                bytes.extend_from_slice(&addr.port().to_be_bytes());
310            }
311            Self::IPv6(addr) => {
312                bytes.put_u8(Self::SOCKS5_ADDRESS_TYPE_IPV6);
313                bytes.extend_from_slice(&addr.ip().octets());
314                bytes.extend_from_slice(&addr.port().to_be_bytes());
315            }
316        }
317
318        bytes.freeze()
319    }
320
321    #[inline]
322    pub fn port(&self) -> u16 {
323        match self {
324            Self::IPv4(addr) => addr.port(),
325            Self::IPv6(addr) => addr.port(),
326            Self::Domain(_, port) => *port,
327        }
328    }
329
330    #[inline]
331    pub fn format_as_string(&self) -> io::Result<String> {
332        match self {
333            Self::IPv4(addr) => Ok(addr.to_string()),
334            Self::IPv6(addr) => Ok(addr.to_string()),
335            Self::Domain(domain, port) => Ok(format!("{}:{}", domain.format_as_str()?, port)),
336        }
337    }
338}
339
340#[derive(Debug, Clone, PartialEq)]
341pub struct Domain(Bytes);
342
343impl Into<Domain> for String {
344    #[inline]
345    fn into(self) -> Domain {
346        Domain(Bytes::from(self))
347    }
348}
349
350impl Into<Domain> for &[u8] {
351    #[inline]
352    fn into(self) -> Domain {
353        Domain(Bytes::copy_from_slice(self))
354    }
355}
356
357impl Into<Domain> for &str {
358    #[inline]
359    fn into(self) -> Domain {
360        Domain(Bytes::copy_from_slice(self.as_bytes()))
361    }
362}
363
364impl Domain {
365    #[inline]
366    pub fn format_as_str(&self) -> io::Result<&str> {
367        use std::str::from_utf8;
368
369        from_utf8(&self.0).map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Invalid UTF-8"))
370    }
371
372    #[inline]
373    pub fn as_bytes(&self) -> &[u8] {
374        &self.0
375    }
376
377    #[inline]
378    pub fn to_bytes(self) -> Bytes {
379        self.0
380    }
381
382    #[inline]
383    pub fn from_bytes(bytes: Bytes) -> Self {
384        Self(bytes)
385    }
386
387    #[inline]
388    pub fn from_string(string: String) -> Self {
389        string.into()
390    }
391}
392
393impl AsRef<[u8]> for Domain {
394    #[inline]
395    fn as_ref(&self) -> &[u8] {
396        self.as_bytes()
397    }
398}
399
400/// # Response
401///
402/// ```text
403///  +-----+-------+------+----------+----------+
404///  | REP |  RSV  | ATYP | BND.ADDR | BND.PORT |
405///  +-----+-------+------+----------+----------+
406///  |  1  | X'00' |  1   | Variable |    2     |
407///  +-----+-------+------+----------+----------+
408/// ```
409///
410#[derive(Debug, Clone)]
411pub enum Response<'a> {
412    Success(&'a Address),
413    GeneralFailure,
414    ConnectionNotAllowed,
415    NetworkUnreachable,
416    HostUnreachable,
417    ConnectionRefused,
418    TTLExpired,
419    CommandNotSupported,
420    AddressTypeNotSupported,
421    Unassigned(u8),
422}
423
424#[rustfmt::skip]
425impl Response<'_> {
426    const SOCKS5_REPLY_SUCCEEDED:                  u8 = 0x00;
427    const SOCKS5_REPLY_GENERAL_FAILURE:            u8 = 0x01;
428    const SOCKS5_REPLY_CONNECTION_NOT_ALLOWED:     u8 = 0x02;
429    const SOCKS5_REPLY_NETWORK_UNREACHABLE:        u8 = 0x03;
430    const SOCKS5_REPLY_HOST_UNREACHABLE:           u8 = 0x04;
431    const SOCKS5_REPLY_CONNECTION_REFUSED:         u8 = 0x05;
432    const SOCKS5_REPLY_TTL_EXPIRED:                u8 = 0x06;
433    const SOCKS5_REPLY_COMMAND_NOT_SUPPORTED:      u8 = 0x07;
434    const SOCKS5_REPLY_ADDRESS_TYPE_NOT_SUPPORTED: u8 = 0x08;
435}
436
437impl Response<'_> {
438    #[inline]
439    pub fn to_bytes(&self) -> BytesMut {
440        let mut bytes = BytesMut::new();
441
442        let (reply, address) = match &self {
443            Self::GeneralFailure
444            | Self::ConnectionNotAllowed
445            | Self::NetworkUnreachable
446            | Self::HostUnreachable
447            | Self::ConnectionRefused
448            | Self::TTLExpired
449            | Self::CommandNotSupported
450            | Self::AddressTypeNotSupported => (self.as_u8(), Address::unspecified()),
451            Self::Unassigned(code) => (*code, Address::unspecified()),
452            Self::Success(address) => (self.as_u8(), *address),
453        };
454
455        bytes.put_u8(reply);
456        bytes.put_u8(0x00);
457        bytes.extend(address.to_bytes());
458
459        bytes
460    }
461
462    #[rustfmt::skip]
463    #[inline]
464    fn as_u8(&self) -> u8 {
465        match self {
466            Self::Success(_)                 => Self::SOCKS5_REPLY_SUCCEEDED,
467            Self::GeneralFailure             => Self::SOCKS5_REPLY_GENERAL_FAILURE,
468            Self::ConnectionNotAllowed       => Self::SOCKS5_REPLY_CONNECTION_NOT_ALLOWED,
469            Self::NetworkUnreachable         => Self::SOCKS5_REPLY_NETWORK_UNREACHABLE,
470            Self::HostUnreachable            => Self::SOCKS5_REPLY_HOST_UNREACHABLE,
471            Self::ConnectionRefused          => Self::SOCKS5_REPLY_CONNECTION_REFUSED,
472            Self::TTLExpired                 => Self::SOCKS5_REPLY_TTL_EXPIRED,
473            Self::CommandNotSupported        => Self::SOCKS5_REPLY_COMMAND_NOT_SUPPORTED,
474            Self::AddressTypeNotSupported    => Self::SOCKS5_REPLY_ADDRESS_TYPE_NOT_SUPPORTED,
475            Self::Unassigned(code)           => *code
476        }
477    }
478}
479
480/// # UDP Packet
481///
482///
483/// ```text
484///  +-----+------+------+----------+----------+----------+
485///  | RSV | FRAG | ATYP | DST.ADDR | DST.PORT |   DATA   |
486///  +-----+------+------+----------+----------+----------+
487///  |  2  |  1   |  1   | Variable |    2     | Variable |
488///  +-----+------+------+----------+----------+----------+
489/// ```
490///
491#[derive(Debug)]
492pub struct UdpPacket {
493    pub frag: u8,
494    pub address: Address,
495    pub data: Bytes,
496}
497
498impl UdpPacket {
499    pub fn from_bytes<B: Buf>(buf: &mut B) -> io::Result<Self> {
500        if buf.remaining() < 2 {
501            return Err(io::Error::new(
502                io::ErrorKind::InvalidData,
503                "Insufficient data for RSV",
504            ));
505        }
506        buf.advance(2);
507
508        if buf.remaining() < 1 {
509            return Err(io::Error::new(
510                io::ErrorKind::InvalidData,
511                "Insufficient data for FRAG",
512            ));
513        }
514        let frag = buf.get_u8();
515
516        let address = Address::from_bytes(buf)?;
517
518        let data = buf.copy_to_bytes(buf.remaining());
519
520        Ok(Self {
521            frag,
522            address,
523            data,
524        })
525    }
526
527    pub fn to_bytes(&self) -> Bytes {
528        let mut bytes = BytesMut::new();
529
530        bytes.put_u8(0x00);
531        bytes.put_u8(0x00);
532
533        bytes.put_u8(self.frag);
534        bytes.extend(self.address.to_bytes());
535        bytes.extend_from_slice(&self.data);
536
537        bytes.freeze()
538    }
539
540    pub fn un_frag(address: Address, data: Bytes) -> Self {
541        Self {
542            frag: 0,
543            address,
544            data,
545        }
546    }
547}
548
549pub struct Stream<T> {
550    version: u8,
551    from: SocketAddr,
552    inner: BufReader<T>,
553}
554
555impl<T> Stream<T> {
556    pub fn version(&self) -> u8 {
557        self.version
558    }
559
560    pub fn from_addr(&self) -> SocketAddr {
561        self.from
562    }
563}
564
565mod async_impl {
566    use std::io;
567    use std::pin::Pin;
568    use std::task::{Context, Poll};
569
570    use tokio::io::{AsyncRead, AsyncWrite};
571
572    use super::Stream;
573
574    impl<T> AsyncRead for Stream<T>
575    where
576        T: AsyncRead + AsyncWrite + Unpin,
577    {
578        fn poll_read(
579            mut self: Pin<&mut Self>,
580            cx: &mut Context<'_>,
581            buf: &mut tokio::io::ReadBuf<'_>,
582        ) -> Poll<io::Result<()>> {
583            AsyncRead::poll_read(Pin::new(&mut self.inner.get_mut()), cx, buf)
584        }
585    }
586
587    impl<T> AsyncWrite for Stream<T>
588    where
589        T: AsyncRead + AsyncWrite + Unpin,
590    {
591        fn poll_write(
592            mut self: Pin<&mut Self>,
593            cx: &mut Context<'_>,
594            buf: &[u8],
595        ) -> Poll<Result<usize, io::Error>> {
596            AsyncWrite::poll_write(Pin::new(&mut self.inner.get_mut()), cx, buf)
597        }
598
599        fn poll_flush(
600            mut self: Pin<&mut Self>,
601            cx: &mut Context<'_>,
602        ) -> Poll<Result<(), io::Error>> {
603            AsyncWrite::poll_flush(Pin::new(&mut self.inner.get_mut()), cx)
604        }
605
606        fn poll_shutdown(
607            mut self: Pin<&mut Self>,
608            cx: &mut Context<'_>,
609        ) -> Poll<Result<(), io::Error>> {
610            AsyncWrite::poll_shutdown(Pin::new(&mut self.inner.get_mut()), cx)
611        }
612    }
613}
614
/// Conversions between [`Address`] and `ombrac::address::Address`.
#[cfg(feature = "ombrac")]
mod ombrac {
    use super::Address;

    use ombrac::address::Address as OmbracAddress;

    // `From` is implemented rather than `Into`; the standard library's blanket
    // impl still provides `.into()` in both directions.
    impl From<Address> for OmbracAddress {
        // Domain octets are stored as received, without UTF-8 validation, so
        // they may be invalid UTF-8. A lossy conversion (invalid sequences
        // become U+FFFD) avoids the previous `unwrap()` panic on such input.
        #[inline]
        fn from(address: Address) -> Self {
            match address {
                Address::Domain(domain, port) => {
                    let name = String::from_utf8_lossy(domain.as_bytes()).into_owned();
                    OmbracAddress::Domain(name, port)
                }
                Address::IPv4(addr) => OmbracAddress::IPv4(addr),
                Address::IPv6(addr) => OmbracAddress::IPv6(addr),
            }
        }
    }

    impl From<OmbracAddress> for Address {
        #[inline]
        fn from(address: OmbracAddress) -> Self {
            match address {
                OmbracAddress::Domain(domain, port) => Address::Domain(domain.into(), port),
                OmbracAddress::IPv4(addr) => Address::IPv4(addr),
                OmbracAddress::IPv6(addr) => Address::IPv6(addr),
            }
        }
    }
}
645
646#[cfg(test)]
647mod tests {
648    mod test_request {
649        use crate::v5::{Address, Request};
650
651        use bytes::{BufMut, BytesMut};
652        use std::io::Cursor;
653        use tokio::io::BufReader;
654
655        #[tokio::test]
656        async fn test_request_from_async_read_connect_ipv4() {
657            let mut buffer = BytesMut::new();
658
659            // Command + Reserved
660            buffer.put_u8(Request::SOCKS5_CMD_CONNECT);
661            buffer.put_u8(0x00); // Reserved
662
663            // Address type + Address + Port
664            buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_IPV4);
665            buffer.put_slice(&[192, 168, 1, 1]); // IP
666            buffer.put_u16(80); // Port
667
668            let bytes = buffer.freeze();
669            let mut cursor = Cursor::new(bytes);
670            let mut reader = BufReader::new(&mut cursor);
671
672            let request = Request::from_async_read(&mut reader).await.unwrap();
673
674            match request {
675                Request::Connect(addr) => match addr {
676                    Address::IPv4(socket_addr) => {
677                        assert_eq!(socket_addr.ip().octets(), [192, 168, 1, 1]);
678                        assert_eq!(socket_addr.port(), 80);
679                    }
680                    _ => panic!("Should be IPv4 address"),
681                },
682                _ => panic!("Should be Connect request"),
683            }
684        }
685
686        #[tokio::test]
687        async fn test_request_from_async_read_bind_ipv6() {
688            let mut buffer = BytesMut::new();
689
690            // Command + Reserved
691            buffer.put_u8(Request::SOCKS5_CMD_BIND);
692            buffer.put_u8(0x00); // Reserved
693
694            // Address type + Address + Port
695            buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_IPV6);
696            buffer.put_slice(&[0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]); // IPv6
697            buffer.put_u16(443); // Port
698
699            let bytes = buffer.freeze();
700            let mut cursor = Cursor::new(bytes);
701            let mut reader = BufReader::new(&mut cursor);
702
703            let request = Request::from_async_read(&mut reader).await.unwrap();
704
705            match request {
706                Request::Bind(addr) => match addr {
707                    Address::IPv6(socket_addr) => {
708                        assert_eq!(
709                            socket_addr.ip().octets(),
710                            [0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
711                        );
712                        assert_eq!(socket_addr.port(), 443);
713                    }
714                    _ => panic!("Should be IPv6 address"),
715                },
716                _ => panic!("Should be Bind request"),
717            }
718        }
719
720        #[tokio::test]
721        async fn test_request_from_async_read_associate_domain() {
722            let mut buffer = BytesMut::new();
723
724            // Command + Reserved
725            buffer.put_u8(Request::SOCKS5_CMD_ASSOCIATE);
726            buffer.put_u8(0x00); // Reserved
727
728            // Address type + Address + Port
729            buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_DOMAIN_NAME);
730            buffer.put_u8(11); // Length of domain name
731            buffer.put_slice(b"example.com"); // Domain name
732            buffer.put_u16(8080); // Port
733
734            let bytes = buffer.freeze();
735            let mut cursor = Cursor::new(bytes);
736            let mut reader = BufReader::new(&mut cursor);
737
738            let request = Request::from_async_read(&mut reader).await.unwrap();
739
740            match request {
741                Request::Associate(addr) => match addr {
742                    Address::Domain(domain, port) => {
743                        assert_eq!(domain.as_bytes(), b"example.com");
744                        assert_eq!(port, 8080);
745                    }
746                    _ => panic!("Should be domain address"),
747                },
748                _ => panic!("Should be Associate request"),
749            }
750        }
751
752        #[tokio::test]
753        async fn test_request_from_async_read_invalid_command() {
754            let mut buffer = BytesMut::new();
755
756            // Invalid Command + Reserved
757            buffer.put_u8(0xFF); // Invalid command
758            buffer.put_u8(0x00); // Reserved
759
760            let bytes = buffer.freeze();
761            let mut cursor = Cursor::new(bytes);
762            let mut reader = BufReader::new(&mut cursor);
763
764            let result = Request::from_async_read(&mut reader).await;
765
766            assert!(result.is_err());
767            if let Err(e) = result {
768                assert_eq!(e.kind(), std::io::ErrorKind::InvalidData);
769            }
770        }
771
772        #[tokio::test]
773        async fn test_request_from_async_read_incomplete_data() {
774            let mut buffer = BytesMut::new();
775
776            // Command only, missing reserved byte
777            buffer.put_u8(Request::SOCKS5_CMD_CONNECT);
778
779            let bytes = buffer.freeze();
780            let mut cursor = Cursor::new(bytes);
781            let mut reader = BufReader::new(&mut cursor);
782
783            let result = Request::from_async_read(&mut reader).await;
784
785            assert!(result.is_err());
786        }
787    }
788
789    mod test_address {
790        use crate::v5::{Address, Domain};
791
792        use bytes::{BufMut, Bytes, BytesMut};
793        use std::io::Cursor;
794        use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
795        use tokio::io::BufReader;
796
797        #[tokio::test]
798        async fn test_address_unspecified() {
799            let unspecified = Address::unspecified();
800            match unspecified {
801                Address::IPv4(addr) => {
802                    assert_eq!(addr.ip(), &Ipv4Addr::UNSPECIFIED);
803                    assert_eq!(addr.port(), 0);
804                }
805                _ => panic!("Unspecified address should be IPv4"),
806            }
807        }
808
809        #[tokio::test]
810        async fn test_address_from_socket_addr_ipv4() {
811            let socket = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 8080));
812            let address = Address::from_socket_addr(socket);
813
814            match address {
815                Address::IPv4(addr) => {
816                    assert_eq!(addr.ip().octets(), [127, 0, 0, 1]);
817                    assert_eq!(addr.port(), 8080);
818                }
819                _ => panic!("Should be IPv4 address"),
820            }
821        }
822
823        #[tokio::test]
824        async fn test_address_from_socket_addr_ipv6() {
825            let socket = SocketAddr::V6(SocketAddrV6::new(
826                Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1),
827                8080,
828                0,
829                0,
830            ));
831            let address = Address::from_socket_addr(socket);
832
833            match address {
834                Address::IPv6(addr) => {
835                    assert_eq!(
836                        addr.ip().octets(),
837                        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
838                    );
839                    assert_eq!(addr.port(), 8080);
840                }
841                _ => panic!("Should be IPv6 address"),
842            }
843        }
844
845        #[tokio::test]
846        async fn test_address_to_bytes_ipv4() {
847            let addr = Address::IPv4(SocketAddrV4::new(Ipv4Addr::new(192, 168, 1, 1), 80));
848            let bytes = addr.to_bytes();
849
850            assert_eq!(bytes[0], Address::SOCKS5_ADDRESS_TYPE_IPV4);
851            assert_eq!(bytes[1..5], [192, 168, 1, 1]);
852            assert_eq!(bytes[5..7], [0, 80]); // Port 80 in big-endian
853        }
854
855        #[tokio::test]
856        async fn test_address_to_bytes_ipv6() {
857            let addr = Address::IPv6(SocketAddrV6::new(
858                Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1),
859                443,
860                0,
861                0,
862            ));
863            let bytes = addr.to_bytes();
864
865            assert_eq!(bytes[0], Address::SOCKS5_ADDRESS_TYPE_IPV6);
866            assert_eq!(
867                bytes[1..17],
868                [0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
869            );
870            assert_eq!(bytes[17..19], [1, 187]); // Port 443 in big-endian
871        }
872
873        #[tokio::test]
874        async fn test_address_to_bytes_domain() {
875            let domain = Domain(Bytes::from("example.com"));
876            let addr = Address::Domain(domain, 8080);
877            let bytes = addr.to_bytes();
878
879            assert_eq!(bytes[0], Address::SOCKS5_ADDRESS_TYPE_DOMAIN_NAME);
880            assert_eq!(bytes[1], 11); // Length of "example.com"
881            assert_eq!(&bytes[2..13], b"example.com");
882            assert_eq!(bytes[13..15], [31, 144]); // Port 8080 in big-endian
883        }
884
885        #[tokio::test]
886        async fn test_address_from_bytes_ipv4() {
887            let mut buffer = BytesMut::new();
888            buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_IPV4);
889            buffer.put_slice(&[192, 168, 1, 1]); // IP
890            buffer.put_u16(80); // Port
891
892            let mut bytes = buffer.freeze();
893            let addr = Address::from_bytes(&mut bytes).unwrap();
894
895            match addr {
896                Address::IPv4(socket_addr) => {
897                    assert_eq!(socket_addr.ip().octets(), [192, 168, 1, 1]);
898                    assert_eq!(socket_addr.port(), 80);
899                }
900                _ => panic!("Should be IPv4 address"),
901            }
902        }
903
904        #[tokio::test]
905        async fn test_address_from_bytes_ipv6() {
906            let mut buffer = BytesMut::new();
907            buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_IPV6);
908            buffer.put_slice(&[0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]); // IPv6
909            buffer.put_u16(443); // Port
910
911            let mut bytes = buffer.freeze();
912            let addr = Address::from_bytes(&mut bytes).unwrap();
913
914            match addr {
915                Address::IPv6(socket_addr) => {
916                    assert_eq!(
917                        socket_addr.ip().octets(),
918                        [0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
919                    );
920                    assert_eq!(socket_addr.port(), 443);
921                }
922                _ => panic!("Should be IPv6 address"),
923            }
924        }
925
926        #[tokio::test]
927        async fn test_address_from_bytes_domain() {
928            let mut buffer = BytesMut::new();
929            buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_DOMAIN_NAME);
930            buffer.put_u8(11); // Length of domain name
931            buffer.put_slice(b"example.com"); // Domain name
932            buffer.put_u16(8080); // Port
933
934            let mut bytes = buffer.freeze();
935            let addr = Address::from_bytes(&mut bytes).unwrap();
936
937            match addr {
938                Address::Domain(domain, port) => {
939                    assert_eq!(domain.as_bytes(), b"example.com");
940                    assert_eq!(port, 8080);
941                }
942                _ => panic!("Should be domain address"),
943            }
944        }
945
946        #[tokio::test]
947        async fn test_address_from_async_read_ipv4() {
948            let mut buffer = BytesMut::new();
949            buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_IPV4);
950            buffer.put_slice(&[192, 168, 1, 1]); // IP
951            buffer.put_u16(80); // Port
952
953            let bytes = buffer.freeze();
954            let mut cursor = Cursor::new(bytes);
955            let mut reader = BufReader::new(&mut cursor);
956
957            let addr = Address::from_async_read(&mut reader).await.unwrap();
958
959            match addr {
960                Address::IPv4(socket_addr) => {
961                    assert_eq!(socket_addr.ip().octets(), [192, 168, 1, 1]);
962                    assert_eq!(socket_addr.port(), 80);
963                }
964                _ => panic!("Should be IPv4 address"),
965            }
966        }
967
968        #[tokio::test]
969        async fn test_address_from_async_read_ipv6() {
970            let mut buffer = BytesMut::new();
971            buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_IPV6);
972            buffer.put_slice(&[0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]); // IPv6
973            buffer.put_u16(443); // Port
974
975            let bytes = buffer.freeze();
976            let mut cursor = Cursor::new(bytes);
977            let mut reader = BufReader::new(&mut cursor);
978
979            let addr = Address::from_async_read(&mut reader).await.unwrap();
980
981            match addr {
982                Address::IPv6(socket_addr) => {
983                    assert_eq!(
984                        socket_addr.ip().octets(),
985                        [0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
986                    );
987                    assert_eq!(socket_addr.port(), 443);
988                }
989                _ => panic!("Should be IPv6 address"),
990            }
991        }
992
993        #[tokio::test]
994        async fn test_address_from_async_read_domain() {
995            let mut buffer = BytesMut::new();
996            buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_DOMAIN_NAME);
997            buffer.put_u8(11); // Length of domain name
998            buffer.put_slice(b"example.com"); // Domain name
999            buffer.put_u16(8080); // Port
1000
1001            let bytes = buffer.freeze();
1002            let mut cursor = Cursor::new(bytes);
1003            let mut reader = BufReader::new(&mut cursor);
1004
1005            let addr = Address::from_async_read(&mut reader).await.unwrap();
1006
1007            match addr {
1008                Address::Domain(domain, port) => {
1009                    assert_eq!(domain.as_bytes(), b"example.com");
1010                    assert_eq!(port, 8080);
1011                }
1012                _ => panic!("Should be domain address"),
1013            }
1014        }
1015
1016        #[tokio::test]
1017        async fn test_address_from_bytes_invalid_type() {
1018            let mut buffer = BytesMut::new();
1019            buffer.put_u8(0xFF); // Invalid address type
1020
1021            let mut bytes = buffer.freeze();
1022            let result = Address::from_bytes(&mut bytes);
1023
1024            assert!(result.is_err());
1025        }
1026
1027        #[tokio::test]
1028        async fn test_address_from_bytes_insufficient_data() {
1029            // IPv4 with incomplete data
1030            let mut buffer = BytesMut::new();
1031            buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_IPV4);
1032            buffer.put_slice(&[192, 168]); // Incomplete IP
1033
1034            let mut bytes = buffer.freeze();
1035            let result = Address::from_bytes(&mut bytes);
1036
1037            assert!(result.is_err());
1038        }
1039
1040        #[tokio::test]
1041        async fn test_address_port() {
1042            let addr1 = Address::IPv4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 8080));
1043            assert_eq!(addr1.port(), 8080);
1044
1045            let addr2 = Address::IPv6(SocketAddrV6::new(
1046                Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1),
1047                443,
1048                0,
1049                0,
1050            ));
1051            assert_eq!(addr2.port(), 443);
1052
1053            let addr3 = Address::Domain(Domain(Bytes::from("example.com")), 80);
1054            assert_eq!(addr3.port(), 80);
1055        }
1056
1057        #[tokio::test]
1058        async fn test_address_format_as_string() {
1059            let addr1 = Address::IPv4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 8080));
1060            assert_eq!(addr1.format_as_string().unwrap(), "127.0.0.1:8080");
1061
1062            let addr2 = Address::IPv6(SocketAddrV6::new(
1063                Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1),
1064                443,
1065                0,
1066                0,
1067            ));
1068            assert_eq!(addr2.format_as_string().unwrap(), "[::1]:443");
1069
1070            // This test assumes Domain::domain_str() returns Ok with the domain string
1071            let addr3 = Address::Domain(Domain(Bytes::from("example.com")), 80);
1072            assert_eq!(addr3.format_as_string().unwrap(), "example.com:80");
1073        }
1074    }
1075}