1use crate::errors::{BuildSocksKind, ExecuteCmdKind, SocksError, SocksResult};
45use crate::package::{
46 read_package, write_package, AuthMethodsPackage, AuthSelectPackage, PasswordReqPackage,
47 PasswordResPackage, RepliesPackage, RequestsPackage,
48};
49use crate::{
50 is_invalid_password, is_invalid_username, AuthMethod, AuthMethods, PrivateStruct, RepliesRep,
51 RequestCmd, ToSocksAddress, DEFAULT_SERVER_ADDR,
52};
53use bytes::{Bytes, BytesMut};
54use std::net::SocketAddr;
55use tokio::net::TcpStream;
56use tracing::error;
57
58pub struct SocksClientBuilder {
59 server_address: SocketAddr,
60 allow_auth_skip: bool,
61 allow_auth_pass: bool,
62 username: Option<Bytes>,
63 password: Option<Bytes>,
64 _private: PrivateStruct,
65}
66
67impl SocksClientBuilder {
68 pub fn new() -> SocksClientBuilder {
69 return SocksClientBuilder {
70 server_address: DEFAULT_SERVER_ADDR,
71 allow_auth_skip: true,
72 allow_auth_pass: false,
73 username: None,
74 password: None,
75 _private: PrivateStruct,
76 };
77 }
78
79 pub fn server_address(mut self, address: SocketAddr) -> Self {
80 self.server_address = address;
81 self
82 }
83
84 pub fn allow_auth_skip(mut self, allow: bool) -> Self {
85 self.allow_auth_skip = allow;
86 self
87 }
88
89 pub fn credential(mut self, username: &[u8], password: &[u8]) -> Self {
90 self.allow_auth_pass = true;
91 self.username = Some(Bytes::copy_from_slice(username));
92 self.password = Some(Bytes::copy_from_slice(password));
93 self
94 }
95
96 pub fn build(self) -> SocksResult<SocksClient> {
97 let SocksClientBuilder {
98 server_address,
99 allow_auth_skip,
100 allow_auth_pass,
101 username,
102 password,
103 _private,
104 } = self;
105 let mut methods = AuthMethods::new();
106 if allow_auth_skip {
107 methods.insert(AuthMethod::SKIP);
108 }
109 if allow_auth_pass {
110 if username
111 .as_ref()
112 .map(|v| is_invalid_username(v.as_ref()))
113 .unwrap_or(true)
114 {
115 return Err(SocksError::BuildSocksClientErr(
116 BuildSocksKind::InvalidUsername,
117 ));
118 }
119 if password
120 .as_ref()
121 .map(|v| is_invalid_password(v.as_ref()))
122 .unwrap_or(true)
123 {
124 return Err(SocksError::BuildSocksClientErr(
125 BuildSocksKind::InvalidPassword,
126 ));
127 }
128 methods.insert(AuthMethod::PASS);
129 }
130 if methods.len() == 0 {
131 return Err(SocksError::BuildSocksClientErr(
132 BuildSocksKind::InvalidAuthMethod,
133 ));
134 }
135 let client = SocksClient {
136 server_addr: server_address,
137 auth_methods: methods,
138 username,
139 password,
140 _private: PrivateStruct,
141 };
142 return Ok(client);
143 }
144}
145
146pub struct SocksClient {
147 server_addr: SocketAddr,
148 auth_methods: AuthMethods,
149 username: Option<Bytes>,
150 password: Option<Bytes>,
151 _private: PrivateStruct,
152}
153
154impl SocksClient {
155 pub async fn connect(&mut self, addr: impl ToSocksAddress) -> SocksResult<TcpStream> {
156 let connection = self.handshake(addr, RequestCmd::CONNECT).await?;
157 return Ok(connection.proxy_stream);
158 }
159
160 async fn handshake(
161 &mut self,
162 addr: impl ToSocksAddress,
163 cmd: RequestCmd,
164 ) -> SocksResult<ClientConnection> {
165 let mut stream = TcpStream::connect(self.server_addr).await?;
166 let local_addr = stream.local_addr()?;
167 let peer_addr = stream.peer_addr()?;
168
169 let mut buffer = BytesMut::with_capacity(512);
170
171 let methods_pac = AuthMethodsPackage::new(self.auth_methods.clone());
172 write_package(&methods_pac, &mut buffer, &mut stream).await?;
173
174 let select_pac: AuthSelectPackage = read_package(&mut buffer, &mut stream).await?;
175 let method = select_pac.auth_method();
176 if !self.auth_methods.contains(&method) {
177 return Err(SocksError::UnsupportedAuthMethod);
178 }
179
180 if method == AuthMethod::PASS {
181 let password_pac = PasswordReqPackage::new(
182 self.username.as_ref().unwrap(),
183 self.password.as_ref().unwrap(),
184 );
185 write_package(&password_pac, &mut buffer, &mut stream).await?;
186
187 let password_pac: PasswordResPackage = read_package(&mut buffer, &mut stream).await?;
188 if !password_pac.is_success() {
189 return Err(SocksError::PasswordAuthNotPassed);
190 }
191 }
192 let requests_pac = RequestsPackage::new(cmd, addr.to_socks_addr());
193 write_package(&requests_pac, &mut buffer, &mut stream).await?;
194
195 let replies_pac: RepliesPackage = read_package(&mut buffer, &mut stream).await?;
196 if !replies_pac.is_success() {
197 let rep = RepliesRep::from_byte(replies_pac.req_ref().to_byte())?;
198 error!("handshake replies error: {}", rep.message());
199 return Err(SocksError::ExecuteCommandErr(ExecuteCmdKind::Client(
200 rep.to_byte(),
201 )));
202 }
203 let stream = ClientConnection {
204 identifier: 0,
205 local_addr,
206 peer_addr,
207 auth_method: AuthMethod::SKIP,
208 proxy_stream: stream,
209 };
210 return Ok(stream);
211 }
212}
213
214#[derive(Debug)]
215pub(crate) struct ClientConnection {
216 identifier: u64,
217 local_addr: SocketAddr,
218 peer_addr: SocketAddr,
219 auth_method: AuthMethod,
220 proxy_stream: TcpStream,
221}
222
223impl ClientConnection {
224 #[allow(dead_code)]
225 fn identifier(&self) -> u64 {
226 return self.identifier;
227 }
228
229 #[allow(dead_code)]
230 fn local_addr(&self) -> SocketAddr {
231 return self.local_addr;
232 }
233
234 #[allow(dead_code)]
235 fn peer_addr(&self) -> SocketAddr {
236 return self.peer_addr;
237 }
238
239 #[allow(dead_code)]
240 fn auth_method(&self) -> AuthMethod {
241 return self.auth_method;
242 }
243}