fast_socks5/util/
stream.rs

1use crate::ReplyError;
2use std::io;
3use std::time::Duration;
4use tokio::io::ErrorKind as IOErrorKind;
5use tokio::net::{TcpStream, ToSocketAddrs};
6use tokio::time::timeout;
7
/// Fills a buffer from a stream and hands the filled buffer back, so the
/// bytes can be destructured into named fields in a single statement.
///
/// `$array` is evaluated once and used as the buffer; `$stream` must provide
/// an awaitable `read_exact(&mut buffer)` method. The macro expands to an
/// expression containing `.await`, so it can only be used in an async
/// context. It evaluates to the `Result` of `read_exact`, with the success
/// value replaced by the filled buffer; errors are passed through unchanged.
///
/// # Examples (before)
///
/// ```ignore
/// let mut buf = [0u8; 2];
/// stream.read_exact(&mut buf).await?;
/// let [version, method_len] = buf;
///
/// assert_eq!(version, 0x05);
/// ```
///
/// # Examples (after)
///
/// ```ignore
/// let [version, method_len] = read_exact!(stream, [0u8; 2])?;
///
/// assert_eq!(version, 0x05);
/// ```
#[macro_export]
macro_rules! read_exact {
    ($stream: expr, $array: expr) => {{
        let mut buffer = $array;
        // The success value reported by `read_exact` is discarded; callers
        // only care about the buffer that was just filled.
        match $stream.read_exact(&mut buffer).await {
            Ok(_) => Ok(buffer),
            Err(err) => Err(err),
        }
    }};
}
38
/// Unwraps a `std::task::Poll` value inside a poll-style function.
///
/// Evaluates to the inner value when the expression is `Poll::Ready`;
/// otherwise returns `Poll::Pending` from the enclosing function. A trailing
/// comma after the expression is accepted.
#[macro_export]
macro_rules! ready {
    ($e:expr $(,)?) => {
        match $e {
            std::task::Poll::Pending => return std::task::Poll::Pending,
            std::task::Poll::Ready(value) => value,
        }
    };
}
48
49#[derive(thiserror::Error, Debug)]
50pub enum ConnectError {
51    #[error("Connection timed out")]
52    ConnectionTimeout,
53    #[error("Connection refused: {0}")]
54    ConnectionRefused(#[source] io::Error),
55    #[error("Connection aborted: {0}")]
56    ConnectionAborted(#[source] io::Error),
57    #[error("Connection reset: {0}")]
58    ConnectionReset(#[source] io::Error),
59    #[error("Not connected: {0}")]
60    NotConnected(#[source] io::Error),
61    #[error("Other i/o error: {0}")]
62    Other(#[source] io::Error),
63}
64
65impl ConnectError {
66    pub fn to_reply_error(&self) -> ReplyError {
67        match self {
68            ConnectError::ConnectionTimeout => ReplyError::ConnectionTimeout,
69            ConnectError::ConnectionRefused(_) => ReplyError::ConnectionRefused,
70            ConnectError::ConnectionAborted(_) | ConnectError::ConnectionReset(_) => {
71                ReplyError::ConnectionNotAllowed
72            }
73            ConnectError::NotConnected(_) => ReplyError::NetworkUnreachable,
74            ConnectError::Other(_) => ReplyError::GeneralFailure,
75        }
76    }
77}
78
79pub async fn tcp_connect_with_timeout<T>(
80    addr: T,
81    request_timeout: Duration,
82) -> Result<TcpStream, ConnectError>
83where
84    T: ToSocketAddrs,
85{
86    let fut = tcp_connect(addr);
87    match timeout(request_timeout, fut).await {
88        Ok(result) => result,
89        Err(_) => Err(ConnectError::ConnectionTimeout),
90    }
91}
92
93pub async fn tcp_connect<T>(addr: T) -> Result<TcpStream, ConnectError>
94where
95    T: ToSocketAddrs,
96{
97    match TcpStream::connect(addr).await {
98        Ok(o) => Ok(o),
99        Err(e) => match e.kind() {
100            IOErrorKind::ConnectionRefused => Err(ConnectError::ConnectionRefused(e)),
101            IOErrorKind::ConnectionAborted => Err(ConnectError::ConnectionAborted(e)),
102            IOErrorKind::ConnectionReset => Err(ConnectError::ConnectionReset(e)),
103            IOErrorKind::NotConnected => Err(ConnectError::NotConnected(e)),
104            _ => Err(ConnectError::Other(e)),
105        },
106    }
107}