dlzht_socks5/
client.rs

//! The client module includes [`SocksClientBuilder`] and [`SocksClient`].
2//!
//! ### Run client without any authentication
4//! ```
5//! use dlzht_socks5::client::SocksClientBuilder;
6//! use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
7//!
8//! #[tokio::main]
9//! async fn main() {
10//!     let address = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 8080));
11//!     let mut client = SocksClientBuilder::new()
12//!         .server_address(address)
13//!         .allow_auth_skip(true)
14//!         .build()
15//!         .unwrap();
16//!     let mut stream = client
17//!         .connect(("127.0.0.1".to_string(), 9000))
18//!         .await
19//!         .unwrap();
20//! }
21//! ```
22//!
//! ### Run client with password authentication
24//!
25//! ```
26//! use dlzht_socks5::client::SocksClientBuilder;
27//! use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
28//!
29//! #[tokio::main]
30//! async fn main() {
31//!     let address = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 8080));
32//!     let mut client = SocksClientBuilder::new()
33//!         .server_address(address)
34//!         .credential(b"username", b"password")
35//!         .build()
36//!         .unwrap();
37//!     let mut stream = client
38//!         .connect(("127.0.0.1".to_string(), 9000))
39//!         .await
40//!         .unwrap();
41//! }
42//! ```
43
44use crate::errors::{BuildSocksKind, ExecuteCmdKind, SocksError, SocksResult};
45use crate::package::{
46    read_package, write_package, AuthMethodsPackage, AuthSelectPackage, PasswordReqPackage,
47    PasswordResPackage, RepliesPackage, RequestsPackage,
48};
49use crate::{
50    is_invalid_password, is_invalid_username, AuthMethod, AuthMethods, PrivateStruct, RepliesRep,
51    RequestCmd, ToSocksAddress, DEFAULT_SERVER_ADDR,
52};
53use bytes::{Bytes, BytesMut};
54use std::net::SocketAddr;
55use tokio::net::TcpStream;
56use tracing::error;
57
58pub struct SocksClientBuilder {
59    server_address: SocketAddr,
60    allow_auth_skip: bool,
61    allow_auth_pass: bool,
62    username: Option<Bytes>,
63    password: Option<Bytes>,
64    _private: PrivateStruct,
65}
66
67impl SocksClientBuilder {
68    pub fn new() -> SocksClientBuilder {
69        return SocksClientBuilder {
70            server_address: DEFAULT_SERVER_ADDR,
71            allow_auth_skip: true,
72            allow_auth_pass: false,
73            username: None,
74            password: None,
75            _private: PrivateStruct,
76        };
77    }
78
79    pub fn server_address(mut self, address: SocketAddr) -> Self {
80        self.server_address = address;
81        self
82    }
83
84    pub fn allow_auth_skip(mut self, allow: bool) -> Self {
85        self.allow_auth_skip = allow;
86        self
87    }
88
89    pub fn credential(mut self, username: &[u8], password: &[u8]) -> Self {
90        self.allow_auth_pass = true;
91        self.username = Some(Bytes::copy_from_slice(username));
92        self.password = Some(Bytes::copy_from_slice(password));
93        self
94    }
95
96    pub fn build(self) -> SocksResult<SocksClient> {
97        let SocksClientBuilder {
98            server_address,
99            allow_auth_skip,
100            allow_auth_pass,
101            username,
102            password,
103            _private,
104        } = self;
105        let mut methods = AuthMethods::new();
106        if allow_auth_skip {
107            methods.insert(AuthMethod::SKIP);
108        }
109        if allow_auth_pass {
110            if username
111                .as_ref()
112                .map(|v| is_invalid_username(v.as_ref()))
113                .unwrap_or(true)
114            {
115                return Err(SocksError::BuildSocksClientErr(
116                    BuildSocksKind::InvalidUsername,
117                ));
118            }
119            if password
120                .as_ref()
121                .map(|v| is_invalid_password(v.as_ref()))
122                .unwrap_or(true)
123            {
124                return Err(SocksError::BuildSocksClientErr(
125                    BuildSocksKind::InvalidPassword,
126                ));
127            }
128            methods.insert(AuthMethod::PASS);
129        }
130        if methods.len() == 0 {
131            return Err(SocksError::BuildSocksClientErr(
132                BuildSocksKind::InvalidAuthMethod,
133            ));
134        }
135        let client = SocksClient {
136            server_addr: server_address,
137            auth_methods: methods,
138            username,
139            password,
140            _private: PrivateStruct,
141        };
142        return Ok(client);
143    }
144}
145
146pub struct SocksClient {
147    server_addr: SocketAddr,
148    auth_methods: AuthMethods,
149    username: Option<Bytes>,
150    password: Option<Bytes>,
151    _private: PrivateStruct,
152}
153
154impl SocksClient {
155    pub async fn connect(&mut self, addr: impl ToSocksAddress) -> SocksResult<TcpStream> {
156        let connection = self.handshake(addr, RequestCmd::CONNECT).await?;
157        return Ok(connection.proxy_stream);
158    }
159
160    async fn handshake(
161        &mut self,
162        addr: impl ToSocksAddress,
163        cmd: RequestCmd,
164    ) -> SocksResult<ClientConnection> {
165        let mut stream = TcpStream::connect(self.server_addr).await?;
166        let local_addr = stream.local_addr()?;
167        let peer_addr = stream.peer_addr()?;
168
169        let mut buffer = BytesMut::with_capacity(512);
170
171        let methods_pac = AuthMethodsPackage::new(self.auth_methods.clone());
172        write_package(&methods_pac, &mut buffer, &mut stream).await?;
173
174        let select_pac: AuthSelectPackage = read_package(&mut buffer, &mut stream).await?;
175        let method = select_pac.auth_method();
176        if !self.auth_methods.contains(&method) {
177            return Err(SocksError::UnsupportedAuthMethod);
178        }
179
180        if method == AuthMethod::PASS {
181            let password_pac = PasswordReqPackage::new(
182                self.username.as_ref().unwrap(),
183                self.password.as_ref().unwrap(),
184            );
185            write_package(&password_pac, &mut buffer, &mut stream).await?;
186
187            let password_pac: PasswordResPackage = read_package(&mut buffer, &mut stream).await?;
188            if !password_pac.is_success() {
189                return Err(SocksError::PasswordAuthNotPassed);
190            }
191        }
192        let requests_pac = RequestsPackage::new(cmd, addr.to_socks_addr());
193        write_package(&requests_pac, &mut buffer, &mut stream).await?;
194
195        let replies_pac: RepliesPackage = read_package(&mut buffer, &mut stream).await?;
196        if !replies_pac.is_success() {
197            let rep = RepliesRep::from_byte(replies_pac.req_ref().to_byte())?;
198            error!("handshake replies error: {}", rep.message());
199            return Err(SocksError::ExecuteCommandErr(ExecuteCmdKind::Client(
200                rep.to_byte(),
201            )));
202        }
203        let stream = ClientConnection {
204            identifier: 0,
205            local_addr,
206            peer_addr,
207            auth_method: AuthMethod::SKIP,
208            proxy_stream: stream,
209        };
210        return Ok(stream);
211    }
212}
213
214#[derive(Debug)]
215pub(crate) struct ClientConnection {
216    identifier: u64,
217    local_addr: SocketAddr,
218    peer_addr: SocketAddr,
219    auth_method: AuthMethod,
220    proxy_stream: TcpStream,
221}
222
223impl ClientConnection {
224    #[allow(dead_code)]
225    fn identifier(&self) -> u64 {
226        return self.identifier;
227    }
228
229    #[allow(dead_code)]
230    fn local_addr(&self) -> SocketAddr {
231        return self.local_addr;
232    }
233
234    #[allow(dead_code)]
235    fn peer_addr(&self) -> SocketAddr {
236        return self.peer_addr;
237    }
238
239    #[allow(dead_code)]
240    fn auth_method(&self) -> AuthMethod {
241        return self.auth_method;
242    }
243}