use std::net::SocketAddr;
use bytes::BytesMut;
use crate::io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use crate::v5::{Request, Response, method::Method};
/// A transport of type `T` bundled with the two socket addresses of its connection.
///
/// The async read/write trait implementations in this file forward every
/// operation to the wrapped transport unchanged.
///
/// NOTE(review): the `crate::v5` imports and the `0x05` version byte suggest this
/// models one side of a SOCKS5 connection — confirm against the crate docs.
pub struct Stream<T> {
/// The wrapped transport that all I/O is delegated to.
inner: T,
/// Address returned by `peer_addr()`; presumably the remote endpoint — verify against caller.
peer_addr: SocketAddr,
/// Address returned by `local_addr()`; presumably the locally bound endpoint — verify against caller.
local_addr: SocketAddr,
}
impl<T> Stream<T> {
#[inline]
pub fn version(&self) -> u8 {
0x05
}
#[inline]
pub fn peer_addr(&self) -> SocketAddr {
self.peer_addr
}
#[inline]
pub fn local_addr(&self) -> SocketAddr {
self.local_addr
}
#[inline]
pub fn with(inner: T, peer_addr: SocketAddr, local_addr: SocketAddr) -> Self {
Self {
inner,
peer_addr,
local_addr,
}
}
}
impl<T> Stream<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    /// Reads a method-selection message and returns the methods it lists.
    ///
    /// Consumes a two-byte header (version, method count) followed by exactly
    /// `count` method bytes, each converted with `Method::from_u8`. The
    /// version byte is consumed but not validated. A count of zero yields an
    /// empty vector.
    ///
    /// # Errors
    /// Propagates any I/O error raised while reading from the stream.
    #[inline]
    pub async fn read_methods(&mut self) -> io::Result<Vec<Method>> {
        let mut header = [0u8; 2];
        self.read_exact(&mut header).await?;
        let [_version, method_count] = header;

        // One code path for every count: reading exactly `method_count` bytes
        // covers the single-method case as well.
        let mut raw = vec![0u8; usize::from(method_count)];
        self.read_exact(&mut raw).await?;
        Ok(raw.into_iter().map(Method::from_u8).collect())
    }

    /// Sends the two-byte method-selection reply: version byte, then `method`.
    ///
    /// Returns the number of bytes sent, which is always the full reply length.
    ///
    /// # Errors
    /// Propagates any I/O error raised while writing to the stream.
    #[inline]
    pub async fn write_auth_method(&mut self, method: Method) -> io::Result<usize> {
        let reply = [self.version(), method.as_u8()];
        // A bare `write` may accept only part of the buffer; `write_all`
        // ensures the peer never sees a truncated reply.
        self.write_all(&reply).await?;
        Ok(reply.len())
    }

    /// Reads one request from the stream.
    ///
    /// The leading version byte is consumed and discarded without validation;
    /// the rest is parsed by `Request::from_async_read`.
    ///
    /// # Errors
    /// Propagates any error from reading the version byte or from the parser.
    #[inline]
    pub async fn read_request(&mut self) -> io::Result<Request> {
        let _version = self.read_u8().await?;
        Request::from_async_read(self).await
    }

    /// Sends `resp`, prefixed with the version byte.
    ///
    /// Returns the number of bytes sent, which is always the full frame length
    /// (the encoded response plus the one-byte version prefix).
    ///
    /// # Errors
    /// Propagates any I/O error raised while writing to the stream.
    #[inline]
    pub async fn write_response<'a>(&mut self, resp: &Response<'a>) -> io::Result<usize> {
        let frame = prepend_u8(resp.to_bytes(), self.version());
        // `write_all` rather than `write`: a partial write would leave the
        // peer with an incomplete response.
        self.write_all(&frame).await?;
        Ok(frame.len())
    }

    /// Sends a `Response::Success` carrying `Address::unspecified()`.
    ///
    /// # Errors
    /// Same as [`Stream::write_response`].
    #[inline]
    pub async fn write_response_unspecified(&mut self) -> io::Result<usize> {
        use crate::v5::Address;
        let reply = Response::Success(Address::unspecified());
        self.write_response(&reply).await
    }

    /// Sends a `Response::CommandNotSupported`.
    ///
    /// # Errors
    /// Same as [`Stream::write_response`].
    #[inline]
    pub async fn write_response_unsupported(&mut self) -> io::Result<usize> {
        self.write_response(&Response::CommandNotSupported).await
    }
}
/// Returns `bytes` with `value` inserted as its first byte.
///
/// The existing contents are shifted one position to the right inside the same
/// buffer (growing it by one byte), using only safe slice operations.
#[inline]
fn prepend_u8(mut bytes: BytesMut, value: u8) -> BytesMut {
    let old_len = bytes.len();
    // Grow the buffer by one byte; the appended byte is only a placeholder
    // that the shift below overwrites (or, for an empty buffer, already the result).
    bytes.extend_from_slice(&[value]);
    // Move the original contents right by one, freeing index 0.
    bytes.copy_within(..old_len, 1);
    bytes[0] = value;
    bytes
}
/// Async I/O trait implementations for `Stream`: every poll is handed straight
/// to the wrapped `inner` transport.
mod async_impl {
    use super::Stream;
    use std::io;
    use std::pin::Pin;
    use std::task::{Context, Poll};
    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

    impl<T: Unpin> Stream<T> {
        /// Projects a pinned `Stream` onto its pinned inner transport.
        fn pinned_inner(self: Pin<&mut Self>) -> Pin<&mut T> {
            // `Stream<T>` is `Unpin` whenever `T` is, so unwrapping the pin is allowed.
            let this = self.get_mut();
            Pin::new(&mut this.inner)
        }
    }

    impl<T: AsyncRead + Unpin> AsyncRead for Stream<T> {
        /// Delegates the read poll to the inner transport.
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            self.pinned_inner().poll_read(cx, buf)
        }
    }

    impl<T: AsyncWrite + Unpin> AsyncWrite for Stream<T> {
        /// Delegates the write poll to the inner transport.
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            self.pinned_inner().poll_write(cx, buf)
        }

        /// Delegates the flush poll to the inner transport.
        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            self.pinned_inner().poll_flush(cx)
        }

        /// Delegates the shutdown poll to the inner transport.
        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            self.pinned_inner().poll_shutdown(cx)
        }
    }
}