use std::io::Write as _;
use std::mem::MaybeUninit;
use std::net::Shutdown;
use std::os::fd::{AsFd, AsRawFd, BorrowedFd, RawFd};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{ready, Context, Poll};
use std::{fmt, io};
use socket2::{SockRef, Socket};
use tokio::io::unix::AsyncFd;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use uni_addr::{UniAddr, UniAddrInner};
// A connected stream socket (TCP or Unix-domain, per the `TryFrom` impls
// below) registered with tokio through `AsyncFd`, together with the local and
// peer addresses captured when it was constructed.
//
// The `wrapper_impl` attributes ask `wrapper_lite` to generate `AsRef`,
// `AsMut`, `BorrowMut` and `DerefMut` impls; presumably these target the first
// field (`inner`, i.e. `AsyncFd<Socket>`) — confirm against `wrapper_lite` docs.
wrapper_lite::wrapper!(
#[wrapper_impl(AsRef)]
#[wrapper_impl(AsMut)]
#[wrapper_impl(BorrowMut)]
#[wrapper_impl(DerefMut)]
pub struct UniStream {
// Socket registered for readiness notifications; every constructor in this
// file hands it over in non-blocking mode.
inner: AsyncFd<Socket>,
// Local endpoint, recorded once at construction time.
local_addr: UniAddr,
// Remote endpoint, recorded once at construction time.
peer_addr: UniAddr,
}
);
// The registered socket itself is intentionally not printed, hence the allow.
#[allow(clippy::missing_fields_in_debug)]
impl fmt::Debug for UniStream {
    /// Renders `UniStream { local_addr: .., peer_addr: .. }`, omitting the
    /// underlying `AsyncFd<Socket>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            local_addr,
            peer_addr,
            ..
        } = self;
        let mut builder = f.debug_struct("UniStream");
        builder.field("local_addr", local_addr);
        builder.field("peer_addr", peer_addr);
        builder.finish()
    }
}
impl AsFd for UniStream {
    /// Borrows the descriptor of the socket wrapped by the inner `AsyncFd`.
    #[inline]
    fn as_fd(&self) -> BorrowedFd<'_> {
        AsFd::as_fd(&self.inner)
    }
}
impl AsRawFd for UniStream {
    /// Returns the raw descriptor of the socket wrapped by the inner `AsyncFd`.
    #[inline]
    fn as_raw_fd(&self) -> RawFd {
        AsRawFd::as_raw_fd(&self.inner)
    }
}
impl TryFrom<tokio::net::TcpStream> for UniStream {
    type Error = io::Error;

    /// Re-registers a connected tokio TCP stream as a `UniStream`.
    ///
    /// The peer and local addresses are queried first, then the stream is
    /// detached from tokio with `into_std` and registered again via `AsyncFd`.
    ///
    /// # Errors
    ///
    /// Fails if either address lookup, `into_std`, or `AsyncFd::new` fails.
    #[inline]
    fn try_from(stream: tokio::net::TcpStream) -> Result<Self, Self::Error> {
        let peer_addr = UniAddr::from(stream.peer_addr()?);
        let local_addr = UniAddr::from(stream.local_addr()?);
        let detached = stream.into_std()?;
        let inner = AsyncFd::new(detached.into())?;
        Ok(Self {
            inner,
            local_addr,
            peer_addr,
        })
    }
}
impl TryFrom<tokio::net::UnixStream> for UniStream {
    type Error = io::Error;

    /// Re-registers a connected tokio Unix-domain stream as a `UniStream`.
    ///
    /// The peer and local addresses are queried first, then the stream is
    /// detached from tokio with `into_std` and registered again via `AsyncFd`.
    ///
    /// # Errors
    ///
    /// Fails if either address lookup, `into_std`, or `AsyncFd::new` fails.
    #[inline]
    fn try_from(stream: tokio::net::UnixStream) -> Result<Self, Self::Error> {
        let peer_addr = UniAddr::from(stream.peer_addr()?);
        let local_addr = UniAddr::from(stream.local_addr()?);
        let detached = stream.into_std()?;
        let inner = AsyncFd::new(detached.into())?;
        Ok(Self {
            inner,
            local_addr,
            peer_addr,
        })
    }
}
impl TryFrom<std::net::TcpStream> for UniStream {
    type Error = io::Error;

    /// Wraps a connected blocking-std TCP stream as a `UniStream`.
    ///
    /// The socket is switched to non-blocking mode before anything else, the
    /// endpoint addresses are recorded, and the socket is registered with
    /// tokio through `AsyncFd`.
    ///
    /// # Errors
    ///
    /// Fails if `set_nonblocking`, either address lookup, or `AsyncFd::new`
    /// fails.
    #[inline]
    fn try_from(stream: std::net::TcpStream) -> Result<Self, Self::Error> {
        stream.set_nonblocking(true)?;
        let peer_addr = UniAddr::from(stream.peer_addr()?);
        let local_addr = UniAddr::from(stream.local_addr()?);
        let inner = AsyncFd::new(stream.into())?;
        Ok(Self {
            inner,
            local_addr,
            peer_addr,
        })
    }
}
impl TryFrom<std::os::unix::net::UnixStream> for UniStream {
    type Error = io::Error;

    /// Wraps a connected blocking-std Unix-domain stream as a `UniStream`.
    ///
    /// The socket is switched to non-blocking mode before anything else, the
    /// endpoint addresses are recorded, and the socket is registered with
    /// tokio through `AsyncFd`.
    ///
    /// # Errors
    ///
    /// Fails if `set_nonblocking`, either address lookup, or `AsyncFd::new`
    /// fails.
    #[inline]
    fn try_from(stream: std::os::unix::net::UnixStream) -> Result<Self, Self::Error> {
        stream.set_nonblocking(true)?;
        let peer_addr = UniAddr::from(stream.peer_addr()?);
        let local_addr = UniAddr::from(stream.local_addr()?);
        let inner = AsyncFd::new(stream.into())?;
        Ok(Self {
            inner,
            local_addr,
            peer_addr,
        })
    }
}
impl UniStream {
pub async fn connect(addr: &UniAddr) -> io::Result<Self> {
match addr.as_inner() {
UniAddrInner::Inet(addr) => tokio::net::TcpStream::connect(addr)
.await
.and_then(Self::try_from),
UniAddrInner::Unix(addr) => tokio::net::UnixStream::connect(addr.to_os_string())
.await
.and_then(Self::try_from),
UniAddrInner::Host(addr) => tokio::net::TcpStream::connect(&**addr)
.await
.and_then(Self::try_from),
_ => Err(io::Error::new(
io::ErrorKind::Other,
"unsupported address type",
)),
}
}
pub fn as_socket_ref(&self) -> SockRef<'_> {
self.inner.get_ref().into()
}
#[inline]
pub const fn local_addr(&self) -> &UniAddr {
&self.local_addr
}
#[inline]
pub const fn peer_addr(&self) -> &UniAddr {
&self.peer_addr
}
pub async fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
loop {
let mut guard = self.inner.readable().await?;
#[allow(unsafe_code)]
let buf = unsafe { &mut *(buf as *mut [u8] as *mut [MaybeUninit<u8>]) };
match guard.try_io(|inner| inner.get_ref().peek(buf)) {
Ok(result) => return result,
Err(_would_block) => {}
}
}
}
pub fn poll_peek(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<usize>> {
loop {
let mut guard = ready!(self.inner.poll_read_ready(cx))?;
#[allow(unsafe_code)]
let unfilled = unsafe { buf.unfilled_mut() };
match guard.try_io(|inner| inner.get_ref().peek(unfilled)) {
Ok(Ok(len)) => {
#[allow(unsafe_code)]
unsafe {
buf.assume_init(len);
};
buf.advance(len);
return Poll::Ready(Ok(len));
}
Ok(Err(err)) => return Poll::Ready(Err(err)),
Err(_would_block) => {}
}
}
}
#[inline]
pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) {
let this = Arc::new(self);
(
OwnedReadHalf::const_from(this.clone()),
OwnedWriteHalf::const_from(this),
)
}
fn poll_read_priv(&self, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
loop {
let mut guard = ready!(self.inner.poll_read_ready(cx))?;
#[allow(unsafe_code)]
let unfilled = unsafe { buf.unfilled_mut() };
match guard.try_io(|inner| inner.get_ref().recv(unfilled)) {
Ok(Ok(len)) => {
#[allow(unsafe_code)]
unsafe {
buf.assume_init(len);
};
buf.advance(len);
return Poll::Ready(Ok(()));
}
Ok(Err(err)) => return Poll::Ready(Err(err)),
Err(_would_block) => {}
}
}
}
fn poll_write_priv(&self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
loop {
let mut guard = ready!(self.inner.poll_write_ready(cx))?;
match guard.try_io(|inner| inner.get_ref().send(buf)) {
Ok(result) => return Poll::Ready(result),
Err(_would_block) => {}
}
}
}
#[inline]
fn flush_priv(&self) -> io::Result<()> {
self.inner.get_ref().flush()
}
#[inline]
fn shutdown_priv(&self, shutdown: Shutdown) -> io::Result<()> {
self.inner.get_ref().shutdown(shutdown)
}
}
impl AsyncRead for UniStream {
    /// Receives into the unfilled part of `buf` once the socket is readable.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        Self::poll_read_priv(&self, cx, buf)
    }
}
impl AsyncWrite for UniStream {
    /// Sends from `buf` once the socket is writable.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        Self::poll_write_priv(&self, cx, buf)
    }

    /// Completes immediately with the socket's flush result; nothing is
    /// buffered at this layer.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let flushed = self.flush_priv();
        Poll::Ready(flushed)
    }

    /// Half-closes the write direction (`Shutdown::Write`) synchronously.
    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let closed = self.shutdown_priv(Shutdown::Write);
        Poll::Ready(closed)
    }
}
// Owned read half returned by `UniStream::into_split`; it shares the stream
// with the matching `OwnedWriteHalf` through an `Arc`.
wrapper_lite::wrapper!(
#[wrapper_impl(AsRef<UniStream>)]
#[derive(Debug)]
pub struct OwnedReadHalf(Arc<UniStream>);
);
impl AsyncRead for OwnedReadHalf {
    /// Reads from the shared stream exactly as `UniStream::poll_read` does.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let stream = &*self.inner;
        stream.poll_read_priv(cx, buf)
    }
}
// Owned write half returned by `UniStream::into_split`; it shares the stream
// with the matching `OwnedReadHalf` through an `Arc`. Dropping it shuts down
// the write direction of the socket (see its `Drop` impl).
wrapper_lite::wrapper!(
#[wrapper_impl(AsRef<UniStream>)]
#[derive(Debug)]
pub struct OwnedWriteHalf(Arc<UniStream>);
);
impl AsyncWrite for OwnedWriteHalf {
    /// Writes to the shared stream exactly as `UniStream::poll_write` does.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let stream = &*self.inner;
        stream.poll_write_priv(cx, buf)
    }

    /// Completes immediately with the socket's flush result.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let flushed = self.inner.flush_priv();
        Poll::Ready(flushed)
    }

    /// Half-closes the write direction (`Shutdown::Write`) synchronously.
    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let closed = self.inner.shutdown_priv(Shutdown::Write);
        Poll::Ready(closed)
    }
}
impl Drop for OwnedWriteHalf {
    /// Shuts down the socket's write direction when the write half is dropped,
    /// even if the read half is still alive.
    ///
    /// This is best-effort: any error from the shutdown is deliberately ignored.
    fn drop(&mut self) {
        let _ = self.inner.shutdown_priv(Shutdown::Write);
    }
}
#[cfg(feature = "splice")]
impl tokio_splice2::AsyncReadFd for UniStream {
    /// Resolves once tokio reports the socket readable.
    ///
    /// The readiness guard is discarded without clearing readiness; it is
    /// cleared by `try_io_read` when the operation reports `WouldBlock`.
    fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let _guard = ready!(self.inner.poll_read_ready(cx))?;
        Poll::Ready(Ok(()))
    }

    /// Runs `f` under the socket's read interest via `AsyncFd::try_io`.
    fn try_io_read<R>(&self, f: impl FnOnce() -> io::Result<R>) -> io::Result<R> {
        self.inner
            .try_io(tokio::io::Interest::READABLE, |_socket| f())
    }
}
#[cfg(feature = "splice")]
impl tokio_splice2::AsyncWriteFd for UniStream {
    /// Resolves once tokio reports the socket writable.
    ///
    /// The readiness guard is discarded without clearing readiness; it is
    /// cleared by `try_io_write` when the operation reports `WouldBlock`.
    fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let _guard = ready!(self.inner.poll_write_ready(cx))?;
        Poll::Ready(Ok(()))
    }

    /// Runs `f` under the socket's write interest via `AsyncFd::try_io`.
    fn try_io_write<R>(&self, f: impl FnOnce() -> io::Result<R>) -> io::Result<R> {
        self.inner
            .try_io(tokio::io::Interest::WRITABLE, |_socket| f())
    }
}
// Marker impl for `tokio_splice2`; judging by the trait name it declares that
// this descriptor is not a regular file (it is a socket) — confirm against the
// `tokio_splice2` docs.
#[cfg(feature = "splice")]
impl tokio_splice2::IsNotFile for UniStream {}