1use std::sync::Arc;
2use std::time::Duration;
3
4use bytes::{BufMut, BytesMut};
5use tokio::io::{AsyncReadExt, AsyncWriteExt};
6use tokio::net::TcpStream;
7use tokio::sync::RwLock;
8use tokio::time::timeout;
9
10use crate::ProxyAuth;
11use crate::error::{ErrorName, ProxyError};
12use crate::result::ProxyResult;
13
14#[derive(Debug, Clone)]
33pub struct Proxy {
34 proxy_type: ProxyType,
35 proxy_address: String,
36 target: Arc<RwLock<TargetServer>>,
37 timeout: u64,
38 auth: Option<ProxyAuth>,
39}
40
/// SOCKS protocol version spoken to the proxy server.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ProxyType {
    /// SOCKS version 5.
    Socks5,
    /// SOCKS version 4.
    Socks4,
}
47
/// Destination the proxy is asked to connect to.
///
/// Both fields are `None` in the default value, i.e. until a target has been set.
/// `Default` is derived: for two `Option` fields it yields exactly
/// `Self { host: None, port: None }`, so no hand-written impl is needed.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
struct TargetServer {
    /// Target host: an IPv4/IPv6 literal or a domain name.
    host: Option<String>,
    /// Target TCP port.
    port: Option<u16>,
}
60
61async fn write_all_to(stream: &mut TcpStream, buffer: Vec<u8>) -> ProxyResult<()> {
63 match timeout(Duration::from_secs(10), stream.write_all(&buffer)).await {
64 Ok(result) => match result {
65 Ok(_) => ProxyResult::Ok(()),
66 Err(e) => ProxyResult::Err(ProxyError::new(ErrorName::StreamError, e.to_string())),
67 },
68 Err(_) => ProxyResult::Err(ProxyError::new(ErrorName::StreamError, "failed to write buffer to stream")),
69 }
70}
71
72async fn read_exact_from<'a>(stream: &mut TcpStream, buffer: &'a mut [u8]) -> ProxyResult<()> {
74 match timeout(Duration::from_secs(10), stream.read_exact(buffer)).await {
75 Ok(result) => match result {
76 Ok(_) => ProxyResult::Ok(()),
77 Err(e) => ProxyResult::Err(ProxyError::new(ErrorName::StreamError, e.to_string())),
78 },
79 Err(_) => ProxyResult::Err(ProxyError::new(ErrorName::StreamError, "failed to read buffer from stream")),
80 }
81}
82
83impl From<String> for Proxy {
84 fn from(value: String) -> Self {
85 let split = value.split("://").collect::<Vec<&str>>();
86 let (protocol, proxy) = (split.get(0).unwrap_or(&"socks5"), split.get(1).unwrap_or(&"127.0.0.1"));
87
88 Self {
89 proxy_address: (*proxy).to_string(),
90 proxy_type: match *protocol {
91 "socks5" => ProxyType::Socks5,
92 "socks4" => ProxyType::Socks4,
93 _ => ProxyType::Socks5,
94 },
95 target: Arc::new(RwLock::new(TargetServer::default())),
96 timeout: 20000,
97 auth: None,
98 }
99 }
100}
101
102impl From<&str> for Proxy {
103 fn from(value: &str) -> Self {
104 let split = value.split("://").collect::<Vec<&str>>();
105 let (protocol, proxy) = (split.get(0).unwrap_or(&"socks5"), split.get(1).unwrap_or(&"127.0.0.1"));
106
107 Self {
108 proxy_address: (*proxy).to_string(),
109 proxy_type: match *protocol {
110 "socks5" => ProxyType::Socks5,
111 "socks4" => ProxyType::Socks4,
112 _ => ProxyType::Socks5,
113 },
114 target: Arc::new(RwLock::new(TargetServer::default())),
115 timeout: 20000,
116 auth: None,
117 }
118 }
119}
120
121impl Proxy {
122 pub fn new(proxy_address: impl Into<String>, proxy_type: ProxyType) -> Self {
124 Self {
125 proxy_address: proxy_address.into(),
126 proxy_type: proxy_type,
127 target: Arc::new(RwLock::new(TargetServer::default())),
128 timeout: 20000,
129 auth: None,
130 }
131 }
132
133 pub fn new_with_auth(proxy_address: impl Into<String>, proxy_type: ProxyType, auth: ProxyAuth) -> Self {
135 Self {
136 proxy_address: proxy_address.into(),
137 proxy_type: proxy_type,
138 target: Arc::new(RwLock::new(TargetServer::default())),
139 timeout: 20000,
140 auth: Some(auth),
141 }
142 }
143
144 pub fn bind(&self, target_host: String, target_port: u16) {
146 match self.target.try_write() {
147 Ok(mut g) => {
148 g.host = Some(target_host);
149 g.port = Some(target_port);
150 }
151 Err(_) => {}
152 }
153 }
154
155 pub fn set_timeout(mut self, timeout: u64) -> Self {
157 self.timeout = timeout;
158 self
159 }
160
161 pub fn set_proxy_type(mut self, proxy_type: ProxyType) -> Self {
163 self.proxy_type = proxy_type;
164 self
165 }
166
167 pub async fn is_available(&self) -> bool {
169 match timeout(Duration::from_millis(self.timeout), TcpStream::connect(&self.proxy_address)).await {
170 Ok(result) => match result {
171 Ok(_) => return true,
172 Err(_) => return false,
173 },
174 Err(_) => return false,
175 }
176 }
177
178 pub fn get_ip(&self) -> Option<String> {
180 if let Some(ip) = self.proxy_address.split(":").collect::<Vec<&str>>().get(0) {
181 Some(ip.to_string())
182 } else {
183 None
184 }
185 }
186
187 pub async fn connect(&self) -> ProxyResult<TcpStream> {
189 let (target_host, target_port) = {
190 let guard = self.target.read().await;
191
192 let Some(host) = guard.host.clone() else {
193 return ProxyResult::Err(ProxyError::new(ErrorName::InvalidData, "target server host not specified"));
194 };
195
196 let Some(port) = guard.port else {
197 return ProxyResult::Err(ProxyError::new(ErrorName::InvalidData, "target server port not specified"));
198 };
199
200 (host, port)
201 };
202
203 let mut stream = match timeout(Duration::from_millis(self.timeout), TcpStream::connect(&self.proxy_address)).await {
204 Ok(result) => match result {
205 Ok(s) => s,
206 Err(_) => return ProxyResult::Err(ProxyError::new(ErrorName::NotConnected, "could not connect to specified server")),
207 },
208 Err(_) => return ProxyResult::Err(ProxyError::new(ErrorName::Timeout, "failed to connect to server within specified time")),
209 };
210
211 match self.proxy_type {
212 ProxyType::Socks5 => self.connect_socks5(&mut stream, target_host, target_port).await?,
213 ProxyType::Socks4 => self.connect_socks4(&mut stream, target_host, target_port).await?,
214 }
215
216 ProxyResult::Ok(stream)
217 }
218
219 async fn connect_socks5(&self, stream: &mut TcpStream, target_host: String, target_port: u16) -> ProxyResult<()> {
221 let greet = if self.auth.is_some() { vec![0x05, 0x02, 0x00, 0x02] } else { vec![0x05, 0x01, 0x00] };
222
223 write_all_to(stream, greet).await?;
224
225 let mut response = [0u8; 2];
226
227 read_exact_from(stream, &mut response).await?;
228
229 if response[0] != 0x05 {
230 return ProxyResult::Err(ProxyError::new(ErrorName::InvalidVersion, "invalid response version"));
231 }
232
233 match response[1] {
234 0x00 => {}
235 0x02 => {
236 if let Some(auth) = &self.auth {
237 let username = auth.username();
238 let password = auth.password();
239
240 if username.len() > 255 || password.len() > 255 {
241 return Err(ProxyError::new(ErrorName::InvalidData, "username or password is too long"));
242 }
243
244 let mut buffer = BytesMut::with_capacity(2 + username.len() + password.len());
245 buffer.put_u8(0x01);
246 buffer.put_u8(username.len() as u8);
247 buffer.put_slice(username.as_bytes());
248 buffer.put_u8(password.len() as u8);
249 buffer.put_slice(password.as_bytes());
250
251 write_all_to(stream, buffer.into()).await?;
252
253 let mut resp = [0u8; 2];
254
255 read_exact_from(stream, &mut resp).await?;
256
257 if resp[0] != 0x01 {
258 return Err(ProxyError::new(ErrorName::AuthFailed, "invalid authorization version"));
259 }
260
261 if resp[1] != 0x00 {
262 return Err(ProxyError::new(ErrorName::AuthFailed, "authorization failed (possibly incorrect password or username)"));
263 }
264 } else {
265 return ProxyResult::Err(ProxyError::new(ErrorName::AuthFailed, "proxy requires authorization (username, password)"));
266 }
267 }
268 _ => return ProxyResult::Err(ProxyError::new(ErrorName::Unsupported, "unsupported authorization method")),
269 }
270
271 let mut request = BytesMut::with_capacity(512);
272 request.put_u8(0x05);
273 request.put_u8(0x01);
274 request.put_u8(0x00);
275
276 if let Ok(ipv4) = target_host.parse::<std::net::Ipv4Addr>() {
277 request.put_u8(0x01);
278 request.put_slice(&ipv4.octets());
279 } else if let Ok(ipv6) = target_host.parse::<std::net::Ipv6Addr>() {
280 request.put_u8(0x04);
281 request.put_slice(&ipv6.octets());
282 } else {
283 request.put_u8(0x03);
284 let host_bytes = target_host.as_bytes();
285
286 if host_bytes.len() > 255 {
287 return ProxyResult::Err(ProxyError::new(ErrorName::InvalidData, "target host is too long"));
288 }
289
290 request.put_u8(host_bytes.len() as u8);
291 request.put_slice(host_bytes);
292 }
293
294 request.put_u16(target_port);
295
296 write_all_to(stream, request.into()).await?;
297
298 let mut header = [0u8; 4];
299
300 read_exact_from(stream, &mut header).await?;
301
302 if header[0] != 0x05 {
303 return ProxyResult::Err(ProxyError::new(ErrorName::InvalidVersion, "invalid response version"));
304 }
305
306 let rep = header[1];
307
308 if rep != 0x00 {
309 return ProxyResult::Err(ProxyError::new(ErrorName::NotConnected, format!("proxy connection error (rep: 0x{:02x})", rep)));
310 }
311
312 let atyp = header[3];
313
314 match atyp {
315 0x01 => {
316 let mut addr = [0u8; 4 + 2];
317 read_exact_from(stream, &mut addr).await?;
318 }
319 0x04 => {
320 let mut addr = [0u8; 16 + 2];
321 read_exact_from(stream, &mut addr).await?;
322 }
323 0x03 => {
324 let mut len = [0u8; 1];
325 read_exact_from(stream, &mut len).await?;
326 let mut rest = vec![0u8; len[0] as usize + 2];
327 read_exact_from(stream, &mut rest).await?;
328 }
329 _ => return ProxyResult::Err(ProxyError::new(ErrorName::InvalidData, format!("unknown address type in reply: 0x{:02x}", atyp))),
330 }
331
332 ProxyResult::Ok(())
333 }
334
335 async fn connect_socks4(&self, stream: &mut TcpStream, target_host: String, target_port: u16) -> ProxyResult<()> {
337 let mut request = BytesMut::with_capacity(512);
338 request.put_u8(0x04);
339 request.put_u8(0x01);
340 request.put_u16(target_port);
341
342 if let Ok(ipv4) = target_host.parse::<std::net::Ipv4Addr>() {
343 request.put_slice(&ipv4.octets());
344
345 if let Some(auth) = &self.auth {
346 request.put_slice(auth.username().as_bytes());
347 } else {
348 request.put_u8(0x00);
349 }
350 } else {
351 request.put_slice(&[0x00, 0x00, 0x00, 0x01]);
352
353 if let Some(auth) = &self.auth {
354 request.put_slice(auth.username().as_bytes());
355 } else {
356 request.put_u8(0x00);
357 }
358
359 if target_host.len() > 255 {
360 return Err(ProxyError::new(ErrorName::InvalidData, "target host is too long"));
361 }
362
363 request.put_slice(target_host.as_bytes());
364 request.put_u8(0x00);
365 }
366
367 write_all_to(stream, request.into()).await?;
368
369 let mut response = [0u8; 8];
370 read_exact_from(stream, &mut response).await?;
371
372 if response[0] != 0x00 {
373 return Err(ProxyError::new(ErrorName::InvalidVersion, "invalid response version"));
374 }
375
376 match response[1] {
377 0x5a => Ok(()),
378 0x5b => Err(ProxyError::new(ErrorName::NotConnected, "request rejected or failed")),
379 0x5c => Err(ProxyError::new(ErrorName::AuthFailed, "client not identd-authenticated")),
380 0x5d => Err(ProxyError::new(ErrorName::AuthFailed, "client identd-user mismatch")),
381 _ => Err(ProxyError::new(ErrorName::Unsupported, format!("unknown response code 0x{:02x}", response[1]))),
382 }
383 }
384}
385
#[cfg(test)]
mod tests {
    use std::io::{Error, ErrorKind};

    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    use crate::result::ProxyResult;
    use crate::{Proxy, ProxyType};

    /// Tunnels a plain HTTP/1.0 request to `ipinfo.io:80` through `proxy` and
    /// prints whatever comes back.
    ///
    /// NOTE: this talks to the real network, so the tests below need the
    /// hard-coded proxies to be reachable.
    async fn fetch_ipinfo_through(proxy: Proxy) -> std::io::Result<()> {
        proxy.bind("ipinfo.io".to_string(), 80);

        let mut conn = match proxy.connect().await {
            ProxyResult::Ok(stream) => stream,
            ProxyResult::Err(error) => return Err(Error::new(ErrorKind::NotConnected, error.text())),
        };

        conn.write_all(b"GET / HTTP/1.0\r\nHost: ipinfo.io\r\n\r\n").await?;

        let mut body = Vec::new();
        conn.read_to_end(&mut body).await?;

        println!("{}", String::from_utf8_lossy(&body));

        Ok(())
    }

    #[tokio::test]
    async fn test_socks5_proxy() -> std::io::Result<()> {
        fetch_ipinfo_through(Proxy::new("212.58.132.5:1080", ProxyType::Socks5)).await
    }

    #[tokio::test]
    async fn test_socks4_proxy() -> std::io::Result<()> {
        fetch_ipinfo_through(Proxy::new("68.71.242.118:4145", ProxyType::Socks4)).await
    }
}
434}