socks_http_kit/
lib.rs

1//! A lightweight library for SOCKS5 and HTTP proxy protocol encoding and parsing,
2//! designed to facilitate complex proxy applications.
3//!
4//! This library serves as a foundation layer for higher-level proxy protocols.
5//! It provides a set of Tokio-based asynchronous functions specifically for
6//! parsing and processing SOCKS5 and HTTP proxy protocol requests and responses.
//! The library employs an I/O-agnostic design, meaning it doesn't spawn internal
//! threads, establish network connections, or perform DNS resolution.
9//! Instead, it delegates these controls entirely to the user code,
10//! enabling flexible integration with various proxy applications.
11//!
12//! Socks-Http-Kit supports:
13//!
14//! - SOCKS5 client and server implementations
15//!     - Full support for `CONNECT`, `BIND`, and `UDP_ASSOCIATE` commands.
16//!     - Username/password authentication mechanism.
17//!
18//! - HTTP proxy client and server implementations.
19//!     - HTTP BASIC authentication support.
20//!
21//! ### SOCKS5
22//!
23//! - Use [`socks5_connect`] to send handshake to SOCKS5 servers,
24//!   with optional authentication information.
25//! - Use [`socks5_accept`] to receive and parse handshake from SOCKS5 clients,
26//!   returning the SOCKS5 command type and target address.
27//! - Use [`socks5_finalize_accept`] to send request processing results back
28//!   to SOCKS5 clients, completing the handshake process.
29//! - [`socks5_read_udp_header`] parses SOCKS5 UDP protocol headers from
30//!   UDP packet buffers.
31//! - [`socks5_write_udp_header`] writes SOCKS5 UDP protocol headers to
32//!   specified buffers.
33//!
34//! ### HTTP
35//!
36//! - Use [`http_connect`] to send handshake to HTTP proxy servers,
37//!   with optional authentication information.
38//! - Use [`http_accept`] to receive and parse handshake from HTTP clients,
39//!   extracting target address information.
40//! - Use [`http_finalize_accept`] to send processing results back to HTTP clients,
41//!   completing the proxy handshake process.
42//!
43//! ### Address
44//!
//! - [`decode_from_reader`] and [`encode_to_writer`] decode SOCKS5-style addresses from,
//!   and encode them to, asynchronous streams.
47//! - [`decode_from_buf`] and [`encode_to_buf`] support decoding/encoding
48//!   SOCKS5-style addresses in memory buffers, suitable for UDP transport and similar scenarios.
49//!
50//! # Cargo Features
51//! The library provides two optional features, both disabled by default:
52//!
53//! - `socks5`: Enables SOCKS5 proxy protocol functionality, including client
54//!   and server communication, authentication, and UDP parsing and encoding functions.
55//! - `http`: Enables HTTP proxy protocol functionality.
56//!
57//! [`decode_from_reader`]: Address::decode_from_reader
58//! [`encode_to_writer`]: Address::encode_to_writer
59//! [`decode_from_buf`]: Address::decode_from_buf
60//! [`encode_to_buf`]: Address::encode_to_buf
61#![warn(missing_debug_implementations, missing_docs, unreachable_pub)]
62#![cfg_attr(docsrs, feature(doc_cfg))]
63use std::{
64    fmt::{Display, Formatter},
65    io::{Error, ErrorKind, Result},
66    net::{Ipv4Addr, Ipv6Addr},
67    result,
68};
69
70use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
71
72#[cfg(feature = "http")]
73mod http;
74
75#[cfg(feature = "socks5")]
76mod socks5;
77
78#[cfg(test)]
79#[doc(hidden)]
80pub mod test_utils;
81
82#[cfg(feature = "http")]
83#[cfg_attr(docsrs, doc(cfg(feature = "http")))]
84pub use http::{HttpError, HttpReply, http_accept, http_connect, http_finalize_accept};
85#[cfg(feature = "socks5")]
86#[cfg_attr(docsrs, doc(cfg(feature = "socks5")))]
87pub use socks5::{
88    Socks5Command,
89    Socks5Error,
90    Socks5Reply,
91    socks5_accept,
92    socks5_connect,
93    socks5_finalize_accept,
94    socks5_read_udp_header,
95    socks5_write_udp_header,
96};
97
/// Represents a network address in various supported formats.
///
/// This enum is used to specify target addresses for SOCKS5 and HTTP proxy connections,
/// supporting IPv4, IPv6, and domain name address types as defined in RFC 1928.
///
/// Every variant pairs a host with a port number. In the SOCKS5 wire format the variants
/// are tagged with the `ATYP` values `0x01` (IPv4), `0x03` (domain name) and `0x04` (IPv6);
/// in text form (`String::from(&address)`) they render as `host:port`, with IPv6 hosts
/// wrapped in square brackets.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub enum Address {
    /// An IPv4 address with a port number.
    IPv4((Ipv4Addr, u16)),

    /// A domain name with a port number.
    ///
    /// The name is carried as-is (this crate performs no DNS resolution). Encoding it in
    /// SOCKS5 form fails if it is longer than 255 bytes.
    DomainName((String, u16)),

    /// An IPv6 address with a port number.
    IPv6((Ipv6Addr, u16)),
}
113
114impl Address {
115    /// Decodes a SOCKS5-like address from an asynchronous reader.
116    ///
117    /// This method reads a network address from the provided asynchronous reader using the
118    /// SOCKS5 address format (RFC 1928). It reads the address type byte, followed by the
119    /// appropriate address data and port number.
120    ///
121    /// According to RFC 1928, SOCKS5 address format is:
122    /// ```text
123    /// +------+----------+----------+
124    /// | ATYP | DST.ADDR | DST.PORT |
125    /// +------+----------+----------+
126    /// |  1   | Variable |    2     |
127    /// +------+----------+----------+
128    /// ```
129    /// `ATYP`: Address type - 0x01 (IPv4), 0x03 (domain name), 0x04 (IPv6)
130    ///
131    /// `ADDR`: Destination address, format depends on ATYP
132    ///
133    /// `PORT`: Destination port, network byte order (big-endian)
134    pub async fn decode_from_reader<T>(reader: &mut T) -> Result<(Self, usize)>
135    where
136        T: AsyncRead + Unpin,
137    {
138        let addr_type = AddressType::try_from(reader.read_u8().await?)?;
139        match addr_type {
140            AddressType::IPv4 => {
141                let mut ip = [0u8; 4];
142                reader.read_exact(&mut ip).await?;
143                let port = reader.read_u16().await?;
144
145                // len(addr_type) + len(ip) + len(port)
146                Ok((Address::IPv4((Ipv4Addr::from(ip), port)), 1 + 4 + 2))
147            }
148            AddressType::DomainName => {
149                let len = reader.read_u8().await? as usize;
150                let mut domain = vec![0u8; len];
151                reader.read_exact(&mut domain).await?;
152                let domain_str =
153                    String::from_utf8(domain).map_err(|_| AddrError::InvalidDomainNameEncoding)?;
154                let port = reader.read_u16().await?;
155
156                // len(addr_type) + len(domain_len) + len(domain) + len(port)
157                Ok((Address::DomainName((domain_str, port)), 1 + 1 + len + 2))
158            }
159            AddressType::IPv6 => {
160                let mut ip = [0u8; 16];
161                reader.read_exact(&mut ip).await?;
162                let port = reader.read_u16().await?;
163
164                // len(addr_type) + len(ip) + len(port)
165                Ok((Address::IPv6((Ipv6Addr::from(ip), port)), 1 + 16 + 2))
166            }
167        }
168    }
169
170    /// Encodes the address to a SOCKS5-like format and writes it to an asynchronous writer.
171    pub async fn encode_to_writer<T>(&self, writer: &mut T) -> Result<usize>
172    where
173        T: AsyncWrite + Unpin,
174    {
175        match self {
176            Address::IPv4((ip, port)) => {
177                writer.write_u8(AddressType::IPv4 as u8).await?;
178                writer.write_all(&ip.octets()).await?;
179                writer.write_u16(*port).await?;
180                Ok(1 + 4 + 2)
181            }
182            Address::DomainName((domain, port)) => {
183                let domain_bytes = domain.as_bytes();
184                if domain_bytes.len() > 255 {
185                    return Err(AddrError::DomainNameTooLong.into());
186                }
187                writer.write_u8(AddressType::DomainName as u8).await?;
188                writer.write_u8(domain_bytes.len() as u8).await?;
189                writer.write_all(domain_bytes).await?;
190                writer.write_u16(*port).await?;
191                Ok(1 + 1 + domain_bytes.len() + 2)
192            }
193            Address::IPv6((ip, port)) => {
194                writer.write_u8(AddressType::IPv6 as u8).await?;
195                writer.write_all(&ip.octets()).await?;
196                writer.write_u16(*port).await?;
197                Ok(1 + 16 + 2)
198            }
199        }
200    }
201
202    /// Decodes a SOCKS5-like address from a byte buffer.
203    pub fn decode_from_buf(buf: &[u8]) -> Result<(Self, usize)> {
204        let mut cursor = Cursor::new(buf);
205
206        let addr_type = AddressType::try_from(cursor.read_u8()?)?;
207        match addr_type {
208            AddressType::IPv4 => {
209                let mut ip = [0u8; 4];
210                cursor.read_slice(&mut ip)?;
211                let port = cursor.read_u16()?;
212
213                // len(addr_type) + len(ip) + len(port)
214                Ok((Address::IPv4((Ipv4Addr::from(ip), port)), 1 + 4 + 2))
215            }
216            AddressType::DomainName => {
217                let len = cursor.read_u8()? as usize;
218                let mut domain = vec![0u8; len];
219                cursor.read_slice(&mut domain)?;
220                let domain_str =
221                    String::from_utf8(domain).map_err(|_| AddrError::InvalidDomainNameEncoding)?;
222                let port = cursor.read_u16()?;
223
224                // len(addr_type) + len(domain_len) + len(domain) + len(port)
225                Ok((Address::DomainName((domain_str, port)), 1 + 1 + len + 2))
226            }
227            AddressType::IPv6 => {
228                let mut ip = [0u8; 16];
229                cursor.read_slice(&mut ip)?;
230                let port = cursor.read_u16()?;
231
232                // len(addr_type) + len(ip) + len(port)
233                Ok((Address::IPv6((Ipv6Addr::from(ip), port)), 1 + 16 + 2))
234            }
235        }
236    }
237
238    /// Encodes the address to a SOCKS5-like format and writes it to a byte buffer.
239    pub fn encode_to_buf(&self, buf: &mut [u8]) -> Result<usize> {
240        let mut cursor = CursorMut::new(buf);
241        match self {
242            Address::IPv4((ip, port)) => {
243                cursor.write_u8(AddressType::IPv4 as u8)?;
244                cursor.write_slice(&ip.octets())?;
245                cursor.write_u16(*port)?;
246                Ok(1 + 4 + 2)
247            }
248            Address::DomainName((domain, port)) => {
249                let domain_bytes = domain.as_bytes();
250                if domain_bytes.len() > 255 {
251                    return Err(AddrError::DomainNameTooLong.into());
252                }
253                cursor.write_u8(AddressType::DomainName as u8)?;
254                cursor.write_u8(domain_bytes.len() as u8)?;
255                cursor.write_slice(domain_bytes)?;
256                cursor.write_u16(*port)?;
257                Ok(1 + 1 + domain_bytes.len() + 2)
258            }
259            Address::IPv6((ip, port)) => {
260                cursor.write_u8(AddressType::IPv6 as u8)?;
261                cursor.write_slice(&ip.octets())?;
262                cursor.write_u16(*port)?;
263                Ok(1 + 16 + 2)
264            }
265        }
266    }
267}
268
269impl From<Address> for String {
270    /// Converts an `Address` into an HTTP-style text representation.
271    ///
272    /// This implementation formats the address in HTTP-style notation:
273    /// - IPv4: "`192.168.1.1:8080`"
274    /// - IPv6: "`[2001:db8::1]:8080`"
275    /// - Domain: "`example.com:443`"
276    ///
277    /// This format is suitable for use in HTTP headers and other textual representations.
278    fn from(value: Address) -> Self {
279        (&value).into()
280    }
281}
282
283impl From<&Address> for String {
284    /// Converts an `&Address` into an HTTP-style text representation.
285    fn from(address: &Address) -> Self {
286        match address {
287            Address::IPv4((ip, port)) => format!("{}:{}", ip, port),
288            // IPv6 addresses need to be enclosed in square brackets
289            Address::IPv6((ip, port)) => format!("[{}]:{}", ip, port),
290            Address::DomainName((domain, port)) => format!("{}:{}", domain, port),
291        }
292    }
293}
294
295impl TryFrom<String> for Address {
296    type Error = AddrError;
297
298    /// Attempts to parse an HTTP-style text address into an `Address`.
299    fn try_from(value: String) -> result::Result<Self, Self::Error> {
300        Address::try_from(value.as_str())
301    }
302}
303
304impl TryFrom<&str> for Address {
305    type Error = AddrError;
306
307    /// Attempts to parse an HTTP-style text address into an `Address`.
308    fn try_from(string: &str) -> result::Result<Self, Self::Error> {
309        if string.starts_with('[') {
310            // IPv6 format: [IPv6]:port
311            let end_bracket_pos = string
312                .rfind(']')
313                .ok_or(AddrError::InvalidIPv6MissingClosingBracket)?;
314
315            if end_bracket_pos + 1 >= string.len()
316                || &string[end_bracket_pos + 1..end_bracket_pos + 2] != ":"
317            {
318                return Err(AddrError::InvalidIPv6MissingPortSeparator);
319            }
320
321            let host = &string[1..end_bracket_pos]; // Remove brackets
322            let port_str = &string[end_bracket_pos + 2..];
323
324            // Parse port and IPv6 address
325            let port = port_str
326                .parse::<u16>()
327                .map_err(|_| AddrError::InvalidPortNumber)?;
328            let ipv6 = host
329                .parse::<Ipv6Addr>()
330                .map_err(|_| AddrError::InvalidIPv6Address)?;
331
332            Ok(Address::IPv6((ipv6, port)))
333        } else {
334            // IPv4 or domain name format: host:port
335            let last_colon_pos = string
336                .rfind(':')
337                .ok_or(AddrError::InvalidTargetAddressMissingPortSeparator)?;
338
339            let host = &string[0..last_colon_pos];
340            let port_str = &string[last_colon_pos + 1..];
341
342            // Parse port
343            let port = port_str
344                .parse::<u16>()
345                .map_err(|_| AddrError::InvalidPortNumber)?;
346
347            // Try to parse as IPv4 address, otherwise treat as domain name
348            if let Ok(ipv4) = host.parse::<Ipv4Addr>() {
349                Ok(Address::IPv4((ipv4, port)))
350            } else {
351                Ok(Address::DomainName((host.to_string(), port)))
352            }
353        }
354    }
355}
356
/// Authentication methods supported by the proxy protocol.
///
/// This enum represents the authentication methods that can be used
/// for SOCKS5 (as defined in RFC 1928 and RFC 1929) and HTTP proxy protocols.
/// On the HTTP side the crate advertises BASIC authentication support, which is
/// presumably what `UserPass` maps to there (NOTE(review): confirm in the `http` module).
#[derive(Clone, Debug, Default, Eq, PartialEq, Hash)]
pub enum AuthMethod {
    /// No authentication required. This is the default method.
    #[default]
    NoAuth,
    /// Username and password authentication (RFC 1929 for SOCKS5).
    UserPass {
        /// Username. Its UTF-8 encoding must not exceed 255 bytes
        /// (the RFC 1929 field is length-prefixed with one byte).
        username: String,

        /// Password. Its UTF-8 encoding must not exceed 255 bytes
        /// (the RFC 1929 field is length-prefixed with one byte).
        password: String,
    },
}
375
376#[derive(Debug, Clone, Copy, PartialEq)]
377enum AddressType {
378    IPv4 = 0x01,
379    DomainName = 0x03,
380    IPv6 = 0x04,
381}
382
383impl TryFrom<u8> for AddressType {
384    type Error = Error;
385
386    fn try_from(value: u8) -> Result<Self> {
387        match value {
388            0x01 => Ok(AddressType::IPv4),
389            0x03 => Ok(AddressType::DomainName),
390            0x04 => Ok(AddressType::IPv6),
391            _ => Err(AddrError::UnsupportedAddressType.into()),
392        }
393    }
394}
395
/// Errors that can occur while decoding, encoding, or parsing an [`Address`].
///
/// Each variant converts into a [`std::io::Error`] (with this value as the inner error),
/// so these errors propagate through the crate's `io::Result`-returning functions.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum AddrError {
    /// The address type byte (`ATYP`) is not one of the supported values
    /// `0x01` (IPv4), `0x03` (domain name) or `0x04` (IPv6).
    UnsupportedAddressType,
    /// The domain name is longer than 255 bytes, so it cannot be encoded with the
    /// single-byte length prefix of the SOCKS5 address format.
    DomainNameTooLong,
    /// The domain name bytes are not valid UTF-8.
    InvalidDomainNameEncoding,

    /// A bracketed IPv6 text address has no closing `]`.
    InvalidIPv6MissingClosingBracket,
    /// A bracketed IPv6 text address is not followed by the `:` port separator.
    InvalidIPv6MissingPortSeparator,
    /// A `host:port` text address contains no `:` port separator.
    InvalidTargetAddressMissingPortSeparator,
    /// The port is not a valid integer in the range 0–65535.
    InvalidPortNumber,
    /// The text between the brackets is not a valid IPv6 address.
    InvalidIPv6Address,
}
418
419impl Display for AddrError {
420    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
421        match self {
422            Self::UnsupportedAddressType => write!(f, "Unsupported address type"),
423            Self::DomainNameTooLong => write!(f, "Domain name too long"),
424            Self::InvalidDomainNameEncoding => write!(f, "Invalid domain name encoding"),
425            Self::InvalidIPv6MissingClosingBracket => {
426                write!(f, "Invalid IPv6 address format: missing closing bracket")
427            }
428            Self::InvalidIPv6MissingPortSeparator => {
429                write!(f, "Invalid IPv6 address format: missing port separator")
430            }
431            Self::InvalidTargetAddressMissingPortSeparator => {
432                write!(f, "Invalid target address format: missing port separator")
433            }
434            Self::InvalidPortNumber => write!(f, "Invalid port number"),
435            Self::InvalidIPv6Address => write!(f, "Invalid IPv6 address"),
436        }
437    }
438}
439
440impl std::error::Error for AddrError {}
441
442impl From<AddrError> for Error {
443    fn from(e: AddrError) -> Self {
444        match e {
445            AddrError::UnsupportedAddressType => Error::new(ErrorKind::InvalidData, e),
446            AddrError::DomainNameTooLong => Error::new(ErrorKind::InvalidInput, e),
447            AddrError::InvalidDomainNameEncoding => Error::new(ErrorKind::InvalidData, e),
448            AddrError::InvalidIPv6MissingClosingBracket => Error::new(ErrorKind::InvalidData, e),
449            AddrError::InvalidIPv6MissingPortSeparator => Error::new(ErrorKind::InvalidData, e),
450            AddrError::InvalidTargetAddressMissingPortSeparator => {
451                Error::new(ErrorKind::InvalidData, e)
452            }
453            AddrError::InvalidPortNumber => Error::new(ErrorKind::InvalidData, e),
454            AddrError::InvalidIPv6Address => Error::new(ErrorKind::InvalidData, e),
455        }
456    }
457}
458
/// Read cursor over an immutable byte slice, used by `Address::decode_from_buf`.
///
/// Every read is bounds-checked. A read that would run past the end fails with
/// `ErrorKind::UnexpectedEof` and leaves the position unchanged.
struct Cursor<'a> {
    buf: &'a [u8],
    pos: usize,
}

impl<'a> Cursor<'a> {
    fn new(buf: &'a [u8]) -> Self {
        Self { buf, pos: 0 }
    }

    /// Builds the error for a read past the end of the buffer.
    ///
    /// It is only called lazily (via `ok_or_else`), so successful reads never allocate
    /// the boxed `io::Error`.
    fn underflow() -> Error {
        Error::new(ErrorKind::UnexpectedEof, "buffer underflow")
    }

    /// Returns the next `len` bytes and advances past them, or fails without advancing.
    fn read_bytes(&mut self, len: usize) -> Result<&'a [u8]> {
        let buf: &'a [u8] = self.buf;
        let start = self.pos;
        // `checked_add` guards against `start + len` wrapping for absurd lengths.
        let end = start
            .checked_add(len)
            .filter(|&end| end <= buf.len())
            .ok_or_else(Self::underflow)?;
        self.pos = end;
        Ok(&buf[start..end])
    }

    fn read_u8(&mut self) -> Result<u8> {
        Ok(self.read_bytes(1)?[0])
    }

    /// Reads a big-endian (network byte order) `u16`.
    fn read_u16(&mut self) -> Result<u16> {
        let bytes = self.read_bytes(2)?;
        Ok(u16::from_be_bytes([bytes[0], bytes[1]]))
    }

    /// Fills `buf` completely from the cursor.
    fn read_slice(&mut self, buf: &mut [u8]) -> Result<()> {
        let src = self.read_bytes(buf.len())?;
        buf.copy_from_slice(src);
        Ok(())
    }
}
497
/// Write cursor over a mutable byte slice, used by `Address::encode_to_buf`.
///
/// Every write is bounds-checked. A write that does not fit fails with
/// `ErrorKind::WriteZero`, writes nothing, and leaves the position unchanged.
struct CursorMut<'a> {
    buf: &'a mut [u8],
    pos: usize,
}

impl<'a> CursorMut<'a> {
    fn new(buf: &'a mut [u8]) -> Self {
        Self { buf, pos: 0 }
    }

    /// Builds the error for a write past the end of the buffer.
    ///
    /// It is only called lazily (via `ok_or_else`), so successful writes never allocate
    /// the boxed `io::Error`.
    fn overflow() -> Error {
        Error::new(ErrorKind::WriteZero, "buffer overflow")
    }

    /// Claims the next `len` bytes for writing and advances past them, or fails without
    /// advancing.
    fn claim(&mut self, len: usize) -> Result<&mut [u8]> {
        let start = self.pos;
        // `checked_add` guards against `start + len` wrapping for absurd lengths.
        let end = start
            .checked_add(len)
            .filter(|&end| end <= self.buf.len())
            .ok_or_else(Self::overflow)?;
        self.pos = end;
        Ok(&mut self.buf[start..end])
    }

    fn write_u8(&mut self, value: u8) -> Result<()> {
        self.claim(1)?[0] = value;
        Ok(())
    }

    /// Writes `value` in big-endian (network byte order).
    fn write_u16(&mut self, value: u16) -> Result<()> {
        self.claim(2)?.copy_from_slice(&value.to_be_bytes());
        Ok(())
    }

    fn write_slice(&mut self, value: &[u8]) -> Result<()> {
        self.claim(value.len())?.copy_from_slice(value);
        Ok(())
    }
}
538
#[cfg(test)]
mod test {
    use super::*;
    use crate::test_utils::*;

    /// One address of each variant; every round-trip test below iterates over these.
    fn sample_addresses() -> [Address; 3] {
        [
            Address::IPv4((Ipv4Addr::new(192, 168, 1, 1), 8080)),
            Address::DomainName(("example.com".to_string(), 443)),
            Address::IPv6((
                Ipv6Addr::new(0x20, 0x01, 0x0d, 0xb8, 0x00, 0x00, 0x00, 0x01),
                8080,
            )),
        ]
    }

    /// Credentials used by the username/password handshake tests.
    #[cfg(any(feature = "http", feature = "socks5"))]
    fn user_pass() -> AuthMethod {
        AuthMethod::UserPass {
            username: "user".to_string(),
            password: "pass".to_string(),
        }
    }

    /// Runs a full HTTP CONNECT handshake (client and server) for every sample address.
    #[cfg(feature = "http")]
    async fn run_http_handshake(auth_method: AuthMethod) {
        for target in sample_addresses() {
            let (mut client_stream, mut server_stream) = create_mock_stream();

            let target_s = target.clone();
            let auth_s = auth_method.clone();
            let server_task = tokio::task::spawn(async move {
                let received_addr = http_accept(&mut server_stream, &auth_s).await?;
                assert_eq!(received_addr, target_s);

                http_finalize_accept(&mut server_stream, &HttpReply::Ok).await?;
                Ok::<_, Error>(())
            });

            let auth_c = auth_method.clone();
            let client_task = tokio::task::spawn(async move {
                http_connect(&mut client_stream, &target, &auth_c).await?;
                Ok::<_, Error>(())
            });

            let (server_result, client_result) = tokio::join!(server_task, client_task);
            server_result.unwrap().unwrap();
            client_result.unwrap().unwrap();
        }
    }

    /// Runs a full SOCKS5 handshake for every sample address and every command.
    #[cfg(feature = "socks5")]
    async fn run_socks5_handshake(auth_method: AuthMethod) {
        let commands = [
            Socks5Command::Connect,
            Socks5Command::Bind,
            Socks5Command::UdpAssociate,
        ];

        for target in sample_addresses() {
            for command in commands {
                let (mut client_stream, mut server_stream) = create_mock_stream();

                let target_s = target.clone();
                let auth_s = auth_method.clone();
                let server_task = tokio::task::spawn(async move {
                    let (cmd, received_addr) = socks5_accept(&mut server_stream, &auth_s).await?;
                    assert_eq!(cmd, command);
                    assert_eq!(received_addr, target_s);

                    socks5_finalize_accept(
                        &mut server_stream,
                        &Socks5Reply::Succeeded,
                        &received_addr,
                    )
                    .await?;
                    Ok::<_, Error>(())
                });

                let target_c = target.clone();
                let auth_c = auth_method.clone();
                let client_task = tokio::task::spawn(async move {
                    let received_addr =
                        socks5_connect(&mut client_stream, &command, &target_c, &[auth_c]).await?;
                    assert_eq!(received_addr, target_c);
                    Ok::<_, Error>(())
                });

                let (server_result, client_result) = tokio::join!(server_task, client_task);
                server_result.unwrap().unwrap();
                client_result.unwrap().unwrap();
            }
        }
    }

    #[cfg(feature = "http")]
    #[tokio::test]
    async fn test_http_connect_accept_finalize_no_auth() {
        run_http_handshake(AuthMethod::NoAuth).await;
    }

    #[cfg(feature = "http")]
    #[tokio::test]
    async fn test_http_connect_accept_finalize_userpass() {
        run_http_handshake(user_pass()).await;
    }

    #[cfg(feature = "socks5")]
    #[tokio::test]
    async fn test_socks5_connect_accept_finalize_no_auth() {
        run_socks5_handshake(AuthMethod::NoAuth).await;
    }

    #[cfg(feature = "socks5")]
    #[tokio::test]
    async fn test_socks5_connect_accept_finalize_userpass() {
        run_socks5_handshake(user_pass()).await;
    }

    #[cfg(feature = "socks5")]
    #[test]
    fn test_socks5_udp_encode_decode() {
        for original_addr in sample_addresses() {
            let mut buffer = vec![0u8; 300];

            let write_len = socks5_write_udp_header(&original_addr, &mut buffer).unwrap();

            // Verify that the first 3 bytes of the header are [0, 0, 0]
            // (two reserved bytes and the fragment byte)
            assert_eq!(&buffer[0..3], &[0, 0, 0]);

            let (decoded_addr, read_len) = socks5_read_udp_header(&buffer).unwrap();

            assert_eq!(write_len, read_len);
            assert_eq!(original_addr, decoded_addr);
        }
    }

    #[tokio::test]
    async fn test_encode_decode_with_stream() {
        for original_addr in sample_addresses() {
            let (mut stream1, mut stream2) = create_mock_stream();

            let write_len = original_addr.encode_to_writer(&mut stream1).await.unwrap();
            let (decoded_addr, read_len) = Address::decode_from_reader(&mut stream2).await.unwrap();

            assert_eq!(write_len, read_len);
            assert_eq!(original_addr, decoded_addr);
        }
    }

    #[test]
    fn test_encode_decode_with_buffer() {
        for original_addr in sample_addresses() {
            let mut buffer = vec![0u8; 300];

            let write_len = original_addr.encode_to_buf(&mut buffer).unwrap();
            let (decoded_addr, read_len) = Address::decode_from_buf(&buffer).unwrap();

            assert_eq!(write_len, read_len);
            assert_eq!(original_addr, decoded_addr);
        }
    }

    #[tokio::test]
    async fn test_encode_domain_length_limit() {
        let mut buffer = [0u8; 300];

        // 255 bytes is the largest length the one-byte prefix can express.
        let longest = Address::DomainName(("a".repeat(255), 80));
        assert_eq!(longest.encode_to_buf(&mut buffer).unwrap(), 1 + 1 + 255 + 2);

        let too_long = Address::DomainName(("a".repeat(256), 80));
        let err = too_long.encode_to_buf(&mut buffer).unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidInput);

        let (mut stream, _peer) = create_mock_stream();
        let err = too_long.encode_to_writer(&mut stream).await.unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidInput);
    }

    #[test]
    fn test_buffer_too_small() {
        let addr = Address::IPv4((Ipv4Addr::LOCALHOST, 80));
        let mut buffer = [0u8; 6];
        let err = addr.encode_to_buf(&mut buffer).unwrap_err();
        assert_eq!(err.kind(), ErrorKind::WriteZero);

        // IPv4 address whose port is truncated to a single byte.
        let truncated = [0x01, 127, 0, 0, 1, 0];
        let err = Address::decode_from_buf(&truncated).unwrap_err();
        assert_eq!(err.kind(), ErrorKind::UnexpectedEof);
    }

    #[test]
    fn test_encode_decode_text() {
        let expected_strings = ["192.168.1.1:8080", "example.com:443", "[20:1:d:b8::1]:8080"];

        for (addr, expected_str) in sample_addresses().into_iter().zip(expected_strings) {
            let addr_to_string = String::from(&addr);
            assert_eq!(addr_to_string, expected_str);

            let string_to_addr = Address::try_from(expected_str).unwrap();
            assert_eq!(string_to_addr, addr);

            let round_trip = Address::try_from(String::from(&addr)).unwrap();
            assert_eq!(round_trip, addr);
        }
    }

    #[test]
    fn test_parse_text_errors() {
        let cases = [
            ("[::1", AddrError::InvalidIPv6MissingClosingBracket),
            ("[::1]", AddrError::InvalidIPv6MissingPortSeparator),
            ("[::1]:http", AddrError::InvalidPortNumber),
            ("[not-an-ip]:80", AddrError::InvalidIPv6Address),
            ("example.com", AddrError::InvalidTargetAddressMissingPortSeparator),
            ("example.com:65536", AddrError::InvalidPortNumber),
        ];

        for (text, expected) in cases {
            assert_eq!(Address::try_from(text), Err(expected), "input: {text}");
        }
    }
}