1use std::net::SocketAddr;
2
3use bytes::BytesMut;
4
5use crate::io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
6use crate::v5::{Request, Response, method::Method};
7
/// A SOCKS5 stream: the wrapped transport `T` paired with the address of the peer
/// on the other end of it.
pub struct Stream<T>(T, SocketAddr);

impl<T> Stream<T> {
    /// Wraps an already-connected transport `inner` whose remote end is `addr`.
    #[inline]
    pub fn with(inner: T, addr: SocketAddr) -> Self {
        Stream(inner, addr)
    }

    /// Returns the peer address supplied to [`Stream::with`].
    #[inline]
    pub fn peer_addr(&self) -> SocketAddr {
        let Stream(_, addr) = self;
        *addr
    }

    /// Protocol version byte this stream writes; always `0x05` (SOCKS5).
    #[inline]
    pub fn version(&self) -> u8 {
        0x05
    }
}
26
27impl<T> Stream<T>
29where
30 T: AsyncRead + AsyncWrite + Unpin,
31{
32 #[inline]
42 pub async fn read_methods(&mut self) -> io::Result<Vec<Method>> {
43 let mut buffer = [0u8; 2];
44 self.0.read_exact(&mut buffer).await?;
45
46 let method_num = buffer[1];
47 if method_num == 1 {
48 let method = self.0.read_u8().await?;
49 return Ok(vec![Method::from_u8(method)]);
50 }
51
52 let mut methods = vec![0u8; method_num as usize];
53 self.0.read_exact(&mut methods).await?;
54
55 let result = methods.into_iter().map(Method::from_u8).collect();
56
57 Ok(result)
58 }
59
60 #[inline]
69 pub async fn write_auth_method(&mut self, method: Method) -> io::Result<usize> {
70 let bytes = [self.version(), method.as_u8()];
71 self.0.write(&bytes).await
72 }
73
74 #[inline]
84 pub async fn read_request(&mut self) -> io::Result<Request> {
85 let _version = self.0.read_u8().await?;
86 Request::from_async_read(&mut self.0).await
87 }
88
89 #[inline]
99 pub async fn write_response<'a>(&mut self, resp: &Response<'a>) -> io::Result<usize> {
100 let bytes = prepend_u8(resp.to_bytes(), self.version());
101 self.0.write(&bytes).await
102 }
103
104 #[inline]
105 pub async fn write_response_unspecified(&mut self) -> io::Result<usize> {
106 use crate::v5::Address;
107 self.write_response(&Response::Success(Address::unspecified()))
108 .await
109 }
110
111 #[inline]
112 pub async fn write_response_unsupported(&mut self) -> io::Result<usize> {
113 self.write_response(&Response::CommandNotSupported).await
114 }
115}
116
117#[inline]
118fn prepend_u8(mut bytes: BytesMut, value: u8) -> BytesMut {
119 bytes.reserve(1);
120
121 unsafe {
122 let ptr = bytes.as_mut_ptr();
123 std::ptr::copy(ptr, ptr.add(1), bytes.len());
124 std::ptr::write(ptr, value);
125 let new_len = bytes.len() + 1;
126 bytes.set_len(new_len);
127 }
128
129 bytes
130}
131
132mod async_impl {
133 use std::io;
134 use std::pin::Pin;
135 use std::task::{Context, Poll};
136
137 use tokio::io::{AsyncRead, AsyncWrite};
138
139 use super::Stream;
140
141 impl<T> AsyncRead for Stream<T>
142 where
143 T: AsyncRead + AsyncWrite + Unpin,
144 {
145 fn poll_read(
146 mut self: Pin<&mut Self>,
147 cx: &mut Context<'_>,
148 buf: &mut tokio::io::ReadBuf<'_>,
149 ) -> Poll<io::Result<()>> {
150 AsyncRead::poll_read(Pin::new(&mut self.0), cx, buf)
151 }
152 }
153
154 impl<T> AsyncWrite for Stream<T>
155 where
156 T: AsyncRead + AsyncWrite + Unpin,
157 {
158 fn poll_write(
159 mut self: Pin<&mut Self>,
160 cx: &mut Context<'_>,
161 buf: &[u8],
162 ) -> Poll<Result<usize, io::Error>> {
163 AsyncWrite::poll_write(Pin::new(&mut self.0), cx, buf)
164 }
165
166 fn poll_flush(
167 mut self: Pin<&mut Self>,
168 cx: &mut Context<'_>,
169 ) -> Poll<Result<(), io::Error>> {
170 AsyncWrite::poll_flush(Pin::new(&mut self.0), cx)
171 }
172
173 fn poll_shutdown(
174 mut self: Pin<&mut Self>,
175 cx: &mut Context<'_>,
176 ) -> Poll<Result<(), io::Error>> {
177 AsyncWrite::poll_shutdown(Pin::new(&mut self.0), cx)
178 }
179 }
180}