socks5_server/connection/
connect.rs1use socks5_proto::{Address, Reply, Response};
4use std::{
5 io::Error,
6 marker::PhantomData,
7 net::SocketAddr,
8 pin::Pin,
9 task::{Context, Poll},
10};
11use tokio::{
12 io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf},
13 net::TcpStream,
14};
15
/// Zero-sized marker types used as the `S` parameter of `Connect<S>`.
///
/// They carry no data; they only exist so that the set of methods available
/// on a `Connect` can depend on its type-level state.
pub mod state {
    /// Marker for a `Connect` that has not been answered yet: `reply` is only
    /// available in this state.
    #[derive(Debug)]
    pub struct NeedReply;

    /// Marker for a `Connect` produced by a successful `reply`; `AsyncRead`
    /// and `AsyncWrite` are implemented for `Connect` in this state.
    #[derive(Debug)]
    pub struct Ready;
}
24
25#[derive(Debug)]
29pub struct Connect<S> {
30 stream: TcpStream,
31 _state: PhantomData<S>,
32}
33
34impl Connect<state::NeedReply> {
35 pub async fn reply(
39 mut self,
40 reply: Reply,
41 addr: Address,
42 ) -> Result<Connect<state::Ready>, (Error, TcpStream)> {
43 let resp = Response::new(reply, addr);
44
45 if let Err(err) = resp.write_to(&mut self.stream).await {
46 return Err((err, self.stream));
47 }
48
49 Ok(Connect::new(self.stream))
50 }
51}
52
53impl<S> Connect<S> {
54 #[inline]
55 pub(super) fn new(stream: TcpStream) -> Self {
56 Self {
57 stream,
58 _state: PhantomData,
59 }
60 }
61
62 #[inline]
64 pub async fn close(&mut self) -> Result<(), Error> {
65 self.stream.shutdown().await
66 }
67
68 #[inline]
70 pub fn local_addr(&self) -> Result<SocketAddr, Error> {
71 self.stream.local_addr()
72 }
73
74 #[inline]
76 pub fn peer_addr(&self) -> Result<SocketAddr, Error> {
77 self.stream.peer_addr()
78 }
79
80 #[inline]
84 pub fn get_ref(&self) -> &TcpStream {
85 &self.stream
86 }
87
88 #[inline]
92 pub fn get_mut(&mut self) -> &mut TcpStream {
93 &mut self.stream
94 }
95
96 #[inline]
98 pub fn into_inner(self) -> TcpStream {
99 self.stream
100 }
101}
102
103impl AsyncRead for Connect<state::Ready> {
104 #[inline]
105 fn poll_read(
106 mut self: Pin<&mut Self>,
107 cx: &mut Context<'_>,
108 buf: &mut ReadBuf<'_>,
109 ) -> Poll<Result<(), Error>> {
110 Pin::new(&mut self.stream).poll_read(cx, buf)
111 }
112}
113
114impl AsyncWrite for Connect<state::Ready> {
115 #[inline]
116 fn poll_write(
117 mut self: Pin<&mut Self>,
118 cx: &mut Context<'_>,
119 buf: &[u8],
120 ) -> Poll<Result<usize, Error>> {
121 Pin::new(&mut self.stream).poll_write(cx, buf)
122 }
123
124 #[inline]
125 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
126 Pin::new(&mut self.stream).poll_flush(cx)
127 }
128
129 #[inline]
130 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
131 Pin::new(&mut self.stream).poll_shutdown(cx)
132 }
133}