1use std::io;
2use std::net::SocketAddr;
3
4use bytes::BytesMut;
5use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader};
6use tokio::net::{TcpListener, ToSocketAddrs};
7
8use super::{Method, Request, Response, Stream};
9
10pub struct Server {
11 listener: TcpListener,
12}
13
14impl Server {
15 const VERSION_5: u8 = 0x05;
16
17 pub async fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
18 Ok(Self {
19 listener: TcpListener::bind(addr).await?,
20 })
21 }
22
23 pub fn local_addr(&self) -> io::Result<SocketAddr> {
24 self.listener.local_addr()
25 }
26
27 #[inline]
28 pub async fn accept(
29 &self,
30 ) -> io::Result<(
31 Request,
32 Stream<impl AsyncRead + AsyncWrite + Unpin + 'static>,
33 )> {
34 let (inner, from) = self.listener.accept().await?;
35 let inner = BufReader::new(inner);
36 let mut stream = Stream::with(Self::VERSION_5, from, inner);
37
38 let _methods = stream.read_methods().await?;
39
40 stream.write_auth_method(Method::NoAuthentication).await?;
42
43 let request = stream.read_request().await?;
44
45 Ok((request, stream))
46 }
47}
48
49impl<T> Stream<T>
50where
51 T: AsyncRead + AsyncWrite + Unpin,
52{
53 fn with<A: Into<SocketAddr>>(version: u8, from: A, inner: BufReader<T>) -> Self {
54 Self {
55 version,
56 from: from.into(),
57 inner,
58 }
59 }
60
61 #[inline]
71 async fn read_methods(&mut self) -> io::Result<Vec<Method>> {
72 let mut buffer = [0u8; 2];
73 self.inner.read_exact(&mut buffer).await?;
74
75 let method_num = buffer[1];
76 if method_num == 1 {
77 let method = self.inner.read_u8().await?;
78 return Ok(vec![Method::from_u8(method)]);
79 }
80
81 let mut methods = vec![0u8; method_num as usize];
82 self.inner.read_exact(&mut methods).await?;
83
84 let result = methods.into_iter().map(|e| Method::from_u8(e)).collect();
85
86 Ok(result)
87 }
88
89 #[inline]
98 async fn write_auth_method(&mut self, method: Method) -> io::Result<usize> {
99 let bytes = [self.version, method.as_u8()];
100 self.inner.write(&bytes).await
101 }
102
103 #[inline]
113 async fn read_request(&mut self) -> io::Result<Request> {
114 let _version = self.inner.read_u8().await?;
115 Request::from_async_read(&mut self.inner).await
116 }
117
118 #[inline]
128 pub async fn write_response<'a>(&mut self, resp: &Response<'a>) -> io::Result<usize> {
129 let bytes = prepend_u8(resp.to_bytes(), self.version);
130 self.inner.write(&bytes).await
131 }
132}
133
134fn prepend_u8(mut bytes: BytesMut, value: u8) -> BytesMut {
135 bytes.reserve(1);
136
137 unsafe {
138 let ptr = bytes.as_mut_ptr();
139 std::ptr::copy(ptr, ptr.add(1), bytes.len());
140 std::ptr::write(ptr, value);
141 let new_len = bytes.len() + 1;
142 bytes.set_len(new_len);
143 }
144
145 bytes
146}