socks_lib/v5/
stream.rs

1use std::net::SocketAddr;
2
3use bytes::BytesMut;
4
5use crate::io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
6use crate::v5::{Request, Response, method::Method};
7
/// A SOCKS v5 stream: an inner I/O value `T` paired with the address of the
/// peer on the other end of it.
pub struct Stream<T>(T, SocketAddr);

impl<T> Stream<T> {
    /// SOCKS protocol version byte reported by [`Stream::version`].
    const SOCKS_VERSION: u8 = 0x05;

    /// Wraps `inner`, remembering `addr` as the peer's address.
    #[inline]
    pub fn with(inner: T, addr: SocketAddr) -> Self {
        Stream(inner, addr)
    }

    /// Returns the SOCKS protocol version, which is always `0x05`.
    #[inline]
    pub fn version(&self) -> u8 {
        Self::SOCKS_VERSION
    }

    /// Returns the peer address that was supplied to [`Stream::with`].
    #[inline]
    pub fn peer_addr(&self) -> SocketAddr {
        let Stream(_, addr) = self;
        *addr
    }
}
26
27// ===== STREAM Server Side Impl =====
28impl<T> Stream<T>
29where
30    T: AsyncRead + AsyncWrite + Unpin,
31{
32    /// # Methods
33    ///
34    /// ```text
35    ///  +----+----------+----------+
36    ///  |VER | NMETHODS | METHODS  |
37    ///  +----+----------+----------+
38    ///  | 1  |    1     | 1 to 255 |
39    ///  +----+----------+----------+
40    /// ```
41    #[inline]
42    pub async fn read_methods(&mut self) -> io::Result<Vec<Method>> {
43        let mut buffer = [0u8; 2];
44        self.0.read_exact(&mut buffer).await?;
45
46        let method_num = buffer[1];
47        if method_num == 1 {
48            let method = self.0.read_u8().await?;
49            return Ok(vec![Method::from_u8(method)]);
50        }
51
52        let mut methods = vec![0u8; method_num as usize];
53        self.0.read_exact(&mut methods).await?;
54
55        let result = methods.into_iter().map(Method::from_u8).collect();
56
57        Ok(result)
58    }
59
60    ///
61    /// ```text
62    ///  +----+--------+
63    ///  |VER | METHOD |
64    ///  +----+--------+
65    ///  | 1  |   1    |
66    ///  +----+--------+
67    ///  ```
68    #[inline]
69    pub async fn write_auth_method(&mut self, method: Method) -> io::Result<usize> {
70        let bytes = [self.version(), method.as_u8()];
71        self.0.write(&bytes).await
72    }
73
74    ///
75    /// ```text
76    ///  +----+-----+-------+------+----------+----------+
77    ///  |VER | CMD |  RSV  | ATYP | DST.ADDR | DST.PORT |
78    ///  +----+-----+-------+------+----------+----------+
79    ///  | 1  |  1  | X'00' |  1   | Variable |    2     |
80    ///  +----+-----+-------+------+----------+----------+
81    /// ```
82    ///
83    #[inline]
84    pub async fn read_request(&mut self) -> io::Result<Request> {
85        let _version = self.0.read_u8().await?;
86        Request::from_async_read(&mut self.0).await
87    }
88
89    ///
90    /// ```text
91    ///  +----+-----+-------+------+----------+----------+
92    ///  |VER | REP |  RSV  | ATYP | BND.ADDR | BND.PORT |
93    ///  +----+-----+-------+------+----------+----------+
94    ///  | 1  |  1  | X'00' |  1   | Variable |    2     |
95    ///  +----+-----+-------+------+----------+----------+
96    /// ```
97    ///
98    #[inline]
99    pub async fn write_response<'a>(&mut self, resp: &Response<'a>) -> io::Result<usize> {
100        let bytes = prepend_u8(resp.to_bytes(), self.version());
101        self.0.write(&bytes).await
102    }
103
104    #[inline]
105    pub async fn write_response_unspecified(&mut self) -> io::Result<usize> {
106        use crate::v5::Address;
107        self.write_response(&Response::Success(Address::unspecified()))
108            .await
109    }
110
111    #[inline]
112    pub async fn write_response_unsupported(&mut self) -> io::Result<usize> {
113        self.write_response(&Response::CommandNotSupported).await
114    }
115}
116
117#[inline]
118fn prepend_u8(mut bytes: BytesMut, value: u8) -> BytesMut {
119    bytes.reserve(1);
120
121    unsafe {
122        let ptr = bytes.as_mut_ptr();
123        std::ptr::copy(ptr, ptr.add(1), bytes.len());
124        std::ptr::write(ptr, value);
125        let new_len = bytes.len() + 1;
126        bytes.set_len(new_len);
127    }
128
129    bytes
130}
131
132mod async_impl {
133    use std::io;
134    use std::pin::Pin;
135    use std::task::{Context, Poll};
136
137    use tokio::io::{AsyncRead, AsyncWrite};
138
139    use super::Stream;
140
141    impl<T> AsyncRead for Stream<T>
142    where
143        T: AsyncRead + AsyncWrite + Unpin,
144    {
145        fn poll_read(
146            mut self: Pin<&mut Self>,
147            cx: &mut Context<'_>,
148            buf: &mut tokio::io::ReadBuf<'_>,
149        ) -> Poll<io::Result<()>> {
150            AsyncRead::poll_read(Pin::new(&mut self.0), cx, buf)
151        }
152    }
153
154    impl<T> AsyncWrite for Stream<T>
155    where
156        T: AsyncRead + AsyncWrite + Unpin,
157    {
158        fn poll_write(
159            mut self: Pin<&mut Self>,
160            cx: &mut Context<'_>,
161            buf: &[u8],
162        ) -> Poll<Result<usize, io::Error>> {
163            AsyncWrite::poll_write(Pin::new(&mut self.0), cx, buf)
164        }
165
166        fn poll_flush(
167            mut self: Pin<&mut Self>,
168            cx: &mut Context<'_>,
169        ) -> Poll<Result<(), io::Error>> {
170            AsyncWrite::poll_flush(Pin::new(&mut self.0), cx)
171        }
172
173        fn poll_shutdown(
174            mut self: Pin<&mut Self>,
175            cx: &mut Context<'_>,
176        ) -> Poll<Result<(), io::Error>> {
177            AsyncWrite::poll_shutdown(Pin::new(&mut self.0), cx)
178        }
179    }
180}