fast_socks5/util/
stream.rs1use crate::ReplyError;
2use std::io;
3use std::time::Duration;
4use tokio::io::ErrorKind as IOErrorKind;
5use tokio::net::{TcpStream, ToSocketAddrs};
6use tokio::time::timeout;
7
/// Reads from `$stream` until `$array` (taken by value as the buffer) is completely filled.
///
/// Expands to an expression that must be used inside an `async` context, because it
/// `.await`s the stream's `read_exact` method. It evaluates to `Ok(buffer)` with the filled
/// buffer on success, or to the error returned by `read_exact`.
///
/// NOTE(review): `$stream` must expose an async `read_exact(&mut buf)` method at the call
/// site, presumably through an extension trait such as tokio's `AsyncReadExt`. Confirm that
/// the trait is imported wherever this macro is used.
#[macro_export]
macro_rules! read_exact {
    ($stream: expr, $array: expr) => {{
        let mut buffer = $array;
        match $stream.read_exact(&mut buffer).await {
            Ok(_) => Ok(buffer),
            Err(err) => Err(err),
        }
    }};
}
38
/// Unwraps a `Poll` inside a `poll_*`-style function.
///
/// `Poll::Ready(value)` evaluates to `value`; `Poll::Pending` makes the *enclosing*
/// function return `Poll::Pending` immediately. This has the same shape as
/// `std::task::ready!`. A trailing comma after the argument is accepted.
#[macro_export]
macro_rules! ready {
    ($e:expr $(,)?) => {
        match $e {
            ::std::task::Poll::Pending => return ::std::task::Poll::Pending,
            ::std::task::Poll::Ready(value) => value,
        }
    };
}
48
49#[derive(thiserror::Error, Debug)]
50pub enum ConnectError {
51 #[error("Connection timed out")]
52 ConnectionTimeout,
53 #[error("Connection refused: {0}")]
54 ConnectionRefused(#[source] io::Error),
55 #[error("Connection aborted: {0}")]
56 ConnectionAborted(#[source] io::Error),
57 #[error("Connection reset: {0}")]
58 ConnectionReset(#[source] io::Error),
59 #[error("Not connected: {0}")]
60 NotConnected(#[source] io::Error),
61 #[error("Other i/o error: {0}")]
62 Other(#[source] io::Error),
63}
64
65impl ConnectError {
66 pub fn to_reply_error(&self) -> ReplyError {
67 match self {
68 ConnectError::ConnectionTimeout => ReplyError::ConnectionTimeout,
69 ConnectError::ConnectionRefused(_) => ReplyError::ConnectionRefused,
70 ConnectError::ConnectionAborted(_) | ConnectError::ConnectionReset(_) => {
71 ReplyError::ConnectionNotAllowed
72 }
73 ConnectError::NotConnected(_) => ReplyError::NetworkUnreachable,
74 ConnectError::Other(_) => ReplyError::GeneralFailure,
75 }
76 }
77}
78
79pub async fn tcp_connect_with_timeout<T>(
80 addr: T,
81 request_timeout: Duration,
82) -> Result<TcpStream, ConnectError>
83where
84 T: ToSocketAddrs,
85{
86 let fut = tcp_connect(addr);
87 match timeout(request_timeout, fut).await {
88 Ok(result) => result,
89 Err(_) => Err(ConnectError::ConnectionTimeout),
90 }
91}
92
93pub async fn tcp_connect<T>(addr: T) -> Result<TcpStream, ConnectError>
94where
95 T: ToSocketAddrs,
96{
97 match TcpStream::connect(addr).await {
98 Ok(o) => Ok(o),
99 Err(e) => match e.kind() {
100 IOErrorKind::ConnectionRefused => Err(ConnectError::ConnectionRefused(e)),
101 IOErrorKind::ConnectionAborted => Err(ConnectError::ConnectionAborted(e)),
102 IOErrorKind::ConnectionReset => Err(ConnectError::ConnectionReset(e)),
103 IOErrorKind::NotConnected => Err(ConnectError::NotConnected(e)),
104 _ => Err(ConnectError::Other(e)),
105 },
106 }
107}