use std::{
future::Future,
io,
path::Path,
pin::Pin,
task::{Context, Poll},
};
use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
use compio_driver::{
BufferRef, impl_raw_fd,
op::{RecvFlags, RecvMsgMultiResult},
};
use compio_io::{
AsyncRead, AsyncReadManaged, AsyncReadMulti, AsyncWrite,
ancillary::{
AsyncReadAncillary, AsyncReadAncillaryManaged, AsyncReadAncillaryMulti, AsyncWriteAncillary,
},
util::Splittable,
};
use compio_runtime::fd::PollFd;
use futures_util::{Stream, StreamExt, stream::FusedStream};
use socket2::{Domain, SockAddr, Socket as Socket2, Type};
use crate::{Incoming, MSG_NOSIGNAL, OwnedReadHalf, OwnedWriteHalf, ReadHalf, Socket, WriteHalf};
/// A Unix domain socket server that accepts stream connections.
#[derive(Debug, Clone)]
pub struct UnixListener {
    inner: Socket,
}

impl UnixListener {
    /// Creates a listener bound to the filesystem `path`.
    ///
    /// The socket listens with a backlog of 1024.
    ///
    /// # Errors
    ///
    /// Fails if `path` cannot be converted into a unix socket address, or for
    /// any reason [`UnixListener::bind_addr`] fails.
    pub async fn bind(path: impl AsRef<Path>) -> io::Result<Self> {
        Self::bind_addr(&SockAddr::unix(path)?).await
    }

    /// Creates a listener bound to the unix socket address `addr`.
    ///
    /// The socket listens with a backlog of 1024; use [`UnixSocket`] to pick a
    /// different backlog.
    ///
    /// # Errors
    ///
    /// Returns [`io::ErrorKind::InvalidInput`] if `addr` is not a unix socket
    /// address, and propagates any error from creating, binding, or listening
    /// on the socket.
    pub async fn bind_addr(addr: &SockAddr) -> io::Result<Self> {
        if !addr.is_unix() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "addr is not unix socket address",
            ));
        }
        let socket = Socket::new(addr.domain(), Type::STREAM, None).await?;
        socket.bind(addr).await?;
        socket.listen(1024).await?;
        Ok(UnixListener { inner: socket })
    }

    /// Wraps a listener that was created with the standard library.
    #[cfg(unix)]
    pub fn from_std(listener: std::os::unix::net::UnixListener) -> io::Result<Self> {
        Ok(Self {
            inner: Socket::from_socket2(Socket2::from(listener))?,
        })
    }

    /// Closes the listener and consumes it.
    ///
    /// The returned future resolves to the result of closing the underlying
    /// socket.
    pub fn close(self) -> impl Future<Output = io::Result<()>> {
        self.inner.close()
    }

    /// Waits for one incoming connection and returns the connected stream
    /// together with the address reported by `accept`.
    pub async fn accept(&self) -> io::Result<(UnixStream, SockAddr)> {
        let (socket, addr) = self.inner.accept().await?;
        let stream = UnixStream { inner: socket };
        Ok((stream, addr))
    }

    /// Returns a [`UnixIncoming`] stream that yields each accepted connection
    /// as a [`UnixStream`].
    pub fn incoming(&self) -> UnixIncoming<'_> {
        UnixIncoming {
            inner: self.inner.incoming(),
        }
    }

    /// Returns the local address this listener is bound to.
    pub fn local_addr(&self) -> io::Result<SockAddr> {
        self.inner.local_addr()
    }

    /// Returns and clears the pending socket error (`SO_ERROR`), if any.
    pub fn take_error(&self) -> io::Result<Option<io::Error>> {
        self.inner.socket.take_error()
    }
}

impl_raw_fd!(UnixListener, socket2::Socket, inner, socket);
/// A stream of connections accepted from a [`UnixListener`].
///
/// Produced by [`UnixListener::incoming`].
pub struct UnixIncoming<'a> {
    inner: Incoming<'a>,
}

impl Stream for UnixIncoming<'_> {
    type Item = io::Result<UnixStream>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let Poll::Ready(next) = self.get_mut().inner.poll_next_unpin(cx) else {
            return Poll::Pending;
        };
        // Wrap each accepted socket; errors and end-of-stream pass through.
        Poll::Ready(next.map(|accepted| -> io::Result<UnixStream> {
            Ok(UnixStream { inner: accepted? })
        }))
    }
}

impl FusedStream for UnixIncoming<'_> {
    fn is_terminated(&self) -> bool {
        let Self { inner } = self;
        inner.is_terminated()
    }
}
/// A Unix domain stream socket connected to a peer.
///
/// The I/O traits are implemented for both `UnixStream` and `&UnixStream`, so
/// a shared reference is enough to read and write (see [`UnixStream::split`]).
#[derive(Debug, Clone)]
pub struct UnixStream {
    inner: Socket,
}

impl UnixStream {
    /// Opens a connection to the unix socket at `path`.
    ///
    /// # Errors
    ///
    /// Fails if `path` cannot be converted into a unix socket address, or for
    /// any reason [`UnixStream::connect_addr`] fails.
    pub async fn connect(path: impl AsRef<Path>) -> io::Result<Self> {
        Self::connect_addr(&SockAddr::unix(path)?).await
    }

    /// Opens a connection to the unix socket address `addr`.
    ///
    /// # Errors
    ///
    /// Returns [`io::ErrorKind::InvalidInput`] if `addr` is not a unix socket
    /// address, and propagates errors from creating, binding (Windows only),
    /// or connecting the socket.
    pub async fn connect_addr(addr: &SockAddr) -> io::Result<Self> {
        if !addr.is_unix() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "addr is not unix socket address",
            ));
        }
        let socket = Socket::new(Domain::UNIX, Type::STREAM, None).await?;
        // On Windows the socket is first bound to an empty unix address;
        // presumably the Windows connect path requires a bound socket.
        #[cfg(windows)]
        {
            let new_addr = empty_unix_socket();
            socket.bind(&new_addr).await?;
        }
        socket.connect_async(addr).await?;
        Ok(UnixStream { inner: socket })
    }

    /// Wraps a stream that was created with the standard library.
    #[cfg(unix)]
    pub fn from_std(stream: std::os::unix::net::UnixStream) -> io::Result<Self> {
        Ok(Self {
            inner: Socket::from_socket2(Socket2::from(stream))?,
        })
    }

    /// Closes the stream and consumes it.
    ///
    /// The returned future resolves to the result of closing the underlying
    /// socket.
    pub fn close(self) -> impl Future<Output = io::Result<()>> {
        self.inner.close()
    }

    /// Returns the address of the connected peer.
    pub fn peer_addr(&self) -> io::Result<SockAddr> {
        let addr = self.inner.peer_addr()?;
        // On Windows, the address length is recomputed from the NUL-terminated
        // path. Shadowing under `cfg` avoids a `mut` binding (and an
        // `#[allow(unused_mut)]`) on every other platform.
        #[cfg(windows)]
        let addr = {
            let mut addr = addr;
            fix_unix_socket_length(&mut addr);
            addr
        };
        Ok(addr)
    }

    /// Returns the local address of this end of the connection.
    pub fn local_addr(&self) -> io::Result<SockAddr> {
        self.inner.local_addr()
    }

    /// Returns and clears the pending socket error (`SO_ERROR`), if any.
    pub fn take_error(&self) -> io::Result<Option<io::Error>> {
        self.inner.socket.take_error()
    }

    /// Splits a borrowed stream into a read half and a write half.
    pub fn split(&self) -> (ReadHalf<'_, Self>, WriteHalf<'_, Self>) {
        crate::split(self)
    }

    /// Splits the stream into owned read and write halves, consuming it.
    pub fn into_split(self) -> (OwnedReadHalf<Self>, OwnedWriteHalf<Self>) {
        crate::into_split(self)
    }

    /// Returns a [`PollFd`] built from the inner socket without consuming
    /// `self`.
    pub fn to_poll_fd(&self) -> io::Result<PollFd<Socket2>> {
        self.inner.to_poll_fd()
    }

    /// Consumes the stream and converts it into a [`PollFd`].
    pub fn into_poll_fd(self) -> io::Result<PollFd<Socket2>> {
        self.inner.into_poll_fd()
    }

    /// Forwards to the inner socket's `sock_nonempty`.
    ///
    /// NOTE(review): `Some(true)` presumably means data is queued for reading
    /// and `None` that the platform cannot tell; confirm against `Socket`.
    pub fn sock_nonempty(&self) -> Option<bool> {
        self.inner.sock_nonempty()
    }

    /// Sends `buf` using a zero-copy send with `MSG_NOSIGNAL`.
    ///
    /// On success, returns the number of bytes sent together with a future
    /// that yields `buf` back. Presumably that happens once the kernel has
    /// released it.
    pub async fn send_zerocopy<T: IoBuf>(
        &self,
        buf: T,
    ) -> BufResult<usize, impl Future<Output = T> + use<T>> {
        self.inner.send_zerocopy(buf, MSG_NOSIGNAL).await
    }

    /// Vectored variant of [`UnixStream::send_zerocopy`].
    pub async fn send_zerocopy_vectored<T: IoVectoredBuf>(
        &self,
        buf: T,
    ) -> BufResult<usize, impl Future<Output = T> + use<T>> {
        self.inner.send_zerocopy_vectored(buf, MSG_NOSIGNAL).await
    }
}
// Reading only needs shared access to the socket, so both methods forward to
// the `&UnixStream` implementation.
impl AsyncRead for UnixStream {
    #[inline]
    async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
        let mut shared: &UnixStream = self;
        shared.read(buf).await
    }

    #[inline]
    async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
        let mut shared: &UnixStream = self;
        shared.read_vectored(buf).await
    }
}
impl AsyncRead for &UnixStream {
    /// Receives into `buf` with no receive flags.
    #[inline]
    async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
        let socket = &self.inner;
        socket.recv(buf, RecvFlags::empty()).await
    }

    /// Receives into the vectored buffer `buf` with no receive flags.
    #[inline]
    async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
        let socket = &self.inner;
        socket.recv_vectored(buf, RecvFlags::empty()).await
    }
}
impl AsyncReadManaged for UnixStream {
    type Buffer = BufferRef;

    /// Forwards to the `&UnixStream` implementation.
    async fn read_managed(&mut self, len: usize) -> io::Result<Option<Self::Buffer>> {
        let mut shared: &UnixStream = self;
        shared.read_managed(len).await
    }
}
impl AsyncReadManaged for &UnixStream {
    type Buffer = BufferRef;

    /// Receives into a driver-managed buffer, passing `len` and no receive
    /// flags to `Socket::recv_managed`.
    async fn read_managed(&mut self, len: usize) -> io::Result<Option<Self::Buffer>> {
        let socket = &self.inner;
        socket.recv_managed(len, RecvFlags::empty()).await
    }
}
impl AsyncReadMulti for UnixStream {
    /// Starts a multi-shot receive on the socket with no receive flags,
    /// passing `len` through to `Socket::recv_multi`.
    fn read_multi(&mut self, len: usize) -> impl Stream<Item = io::Result<Self::Buffer>> {
        let socket = &self.inner;
        socket.recv_multi(len, RecvFlags::empty())
    }
}
impl AsyncReadMulti for &UnixStream {
    /// Starts a multi-shot receive on the socket with no receive flags,
    /// passing `len` through to `Socket::recv_multi`.
    fn read_multi(&mut self, len: usize) -> impl Stream<Item = io::Result<Self::Buffer>> {
        let socket = &self.inner;
        socket.recv_multi(len, RecvFlags::empty())
    }
}
// Both methods forward to the `&UnixStream` implementation.
impl AsyncReadAncillary for UnixStream {
    #[inline]
    async fn read_with_ancillary<T: IoBufMut, C: IoBufMut>(
        &mut self,
        buffer: T,
        control: C,
    ) -> BufResult<(usize, usize), (T, C)> {
        let mut shared: &UnixStream = self;
        shared.read_with_ancillary(buffer, control).await
    }

    #[inline]
    async fn read_vectored_with_ancillary<T: IoVectoredBufMut, C: IoBufMut>(
        &mut self,
        buffer: T,
        control: C,
    ) -> BufResult<(usize, usize), (T, C)> {
        let mut shared: &UnixStream = self;
        shared.read_vectored_with_ancillary(buffer, control).await
    }
}
impl AsyncReadAncillary for &UnixStream {
#[inline]
async fn read_with_ancillary<T: IoBufMut, C: IoBufMut>(
&mut self,
buffer: T,
control: C,
) -> BufResult<(usize, usize), (T, C)> {
self.inner
.recv_msg(buffer, control, RecvFlags::empty())
.await
.map_res(|(res, len, _addr)| (res, len))
}
#[inline]
async fn read_vectored_with_ancillary<T: IoVectoredBufMut, C: IoBufMut>(
&mut self,
buffer: T,
control: C,
) -> BufResult<(usize, usize), (T, C)> {
self.inner
.recv_msg_vectored(buffer, control, RecvFlags::empty())
.await
.map_res(|(res, len, _addr)| (res, len))
}
}
impl AsyncReadAncillaryManaged for UnixStream {
    /// Forwards to the `&UnixStream` implementation.
    #[inline]
    async fn read_managed_with_ancillary<C: IoBufMut>(
        &mut self,
        len: usize,
        control: C,
    ) -> io::Result<Option<(Self::Buffer, C)>> {
        let mut shared: &UnixStream = self;
        shared.read_managed_with_ancillary(len, control).await
    }
}
impl AsyncReadAncillaryManaged for &UnixStream {
    /// Receives into a driver-managed buffer (size hint `len`) and writes
    /// ancillary data into `control`, using no receive flags.
    ///
    /// The `Option` returned by `recv_msg_managed` is propagated unchanged;
    /// only the sender address is stripped from a successful result.
    #[inline]
    async fn read_managed_with_ancillary<C: IoBufMut>(
        &mut self,
        len: usize,
        control: C,
    ) -> io::Result<Option<(Self::Buffer, C)>> {
        self.inner
            .recv_msg_managed(len, control, RecvFlags::empty())
            .await
            // The trait signature has no slot for the sender address, so it is
            // dropped; the tuple holds the data buffer and the control buffer.
            .map(|received| received.map(|(buffer, ctrl, _addr)| (buffer, ctrl)))
    }
}
impl AsyncReadAncillaryMulti for UnixStream {
    type Return = RecvMsgMultiResult;

    /// Starts a multi-shot message receive with no receive flags, passing
    /// `control_len` through to `Socket::recv_msg_multi`.
    #[inline]
    fn read_multi_with_ancillary(
        &mut self,
        control_len: usize,
    ) -> impl Stream<Item = io::Result<Self::Return>> {
        let socket = &self.inner;
        socket.recv_msg_multi(control_len, RecvFlags::empty())
    }
}
impl AsyncReadAncillaryMulti for &UnixStream {
    type Return = RecvMsgMultiResult;

    /// Starts a multi-shot message receive with no receive flags, passing
    /// `control_len` through to `Socket::recv_msg_multi`.
    #[inline]
    fn read_multi_with_ancillary(
        &mut self,
        control_len: usize,
    ) -> impl Stream<Item = io::Result<Self::Return>> {
        let socket = &self.inner;
        socket.recv_msg_multi(control_len, RecvFlags::empty())
    }
}
// Writing only needs shared access to the socket, so every method forwards to
// the `&UnixStream` implementation.
impl AsyncWrite for UnixStream {
    #[inline]
    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
        let mut shared: &UnixStream = self;
        shared.write(buf).await
    }

    #[inline]
    async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
        let mut shared: &UnixStream = self;
        shared.write_vectored(buf).await
    }

    #[inline]
    async fn flush(&mut self) -> io::Result<()> {
        let mut shared: &UnixStream = self;
        shared.flush().await
    }

    #[inline]
    async fn shutdown(&mut self) -> io::Result<()> {
        let mut shared: &UnixStream = self;
        shared.shutdown().await
    }
}
impl AsyncWrite for &UnixStream {
    /// Sends `buf` on the socket with `MSG_NOSIGNAL` (presumably to suppress
    /// `SIGPIPE` where the platform supports it).
    #[inline]
    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
        let socket = &self.inner;
        socket.send(buf, MSG_NOSIGNAL).await
    }

    /// Sends the vectored buffer `buf` on the socket with `MSG_NOSIGNAL`.
    #[inline]
    async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
        let socket = &self.inner;
        socket.send_vectored(buf, MSG_NOSIGNAL).await
    }

    /// Always succeeds immediately. Writes are handed straight to the socket,
    /// and this type keeps no write buffer of its own.
    #[inline]
    async fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }

    /// Shuts down the underlying socket.
    #[inline]
    async fn shutdown(&mut self) -> io::Result<()> {
        let socket = &self.inner;
        socket.shutdown().await
    }
}
// Both methods forward to the `&UnixStream` implementation.
impl AsyncWriteAncillary for UnixStream {
    #[inline]
    async fn write_with_ancillary<T: IoBuf, C: IoBuf>(
        &mut self,
        buffer: T,
        control: C,
    ) -> BufResult<usize, (T, C)> {
        let mut shared: &UnixStream = self;
        shared.write_with_ancillary(buffer, control).await
    }

    #[inline]
    async fn write_vectored_with_ancillary<T: IoVectoredBuf, C: IoBuf>(
        &mut self,
        buffer: T,
        control: C,
    ) -> BufResult<usize, (T, C)> {
        let mut shared: &UnixStream = self;
        shared.write_vectored_with_ancillary(buffer, control).await
    }
}
impl AsyncWriteAncillary for &UnixStream {
    /// Sends `buffer` with ancillary data from `control`.
    ///
    /// No explicit destination address is passed (`None`), and the flags are
    /// `MSG_NOSIGNAL`.
    #[inline]
    async fn write_with_ancillary<T: IoBuf, C: IoBuf>(
        &mut self,
        buffer: T,
        control: C,
    ) -> BufResult<usize, (T, C)> {
        let socket = &self.inner;
        socket.send_msg(buffer, control, None, MSG_NOSIGNAL).await
    }

    /// Vectored variant of `write_with_ancillary`, with no destination
    /// address and `MSG_NOSIGNAL`.
    #[inline]
    async fn write_vectored_with_ancillary<T: IoVectoredBuf, C: IoBuf>(
        &mut self,
        buffer: T,
        control: C,
    ) -> BufResult<usize, (T, C)> {
        let socket = &self.inner;
        socket.send_msg_vectored(buffer, control, None, MSG_NOSIGNAL).await
    }
}
impl Splittable for UnixStream {
type ReadHalf = OwnedReadHalf<Self>;
type WriteHalf = OwnedWriteHalf<Self>;
fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
crate::into_split(self)
}
}
impl<'a> Splittable for &'a UnixStream {
type ReadHalf = ReadHalf<'a, UnixStream>;
type WriteHalf = WriteHalf<'a, UnixStream>;
fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
crate::split(self)
}
}
impl<'a> Splittable for &'a mut UnixStream {
type ReadHalf = ReadHalf<'a, UnixStream>;
type WriteHalf = WriteHalf<'a, UnixStream>;
fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
crate::split(self)
}
}
impl_raw_fd!(UnixStream, socket2::Socket, inner, socket);
/// An unconnected Unix domain stream socket.
///
/// The socket can be configured (for example bound) before it becomes a
/// [`UnixListener`] via [`UnixSocket::listen`] or a [`UnixStream`] via
/// [`UnixSocket::connect`].
#[derive(Debug)]
pub struct UnixSocket {
    inner: Socket,
}

impl UnixSocket {
    /// Creates a new unix stream socket.
    pub async fn new_stream() -> io::Result<UnixSocket> {
        UnixSocket::new(socket2::Type::STREAM).await
    }

    /// Creates a unix socket of type `ty`.
    async fn new(ty: socket2::Type) -> io::Result<UnixSocket> {
        let inner = Socket::new(socket2::Domain::UNIX, ty, None).await?;
        Ok(UnixSocket { inner })
    }

    /// Returns the local address of the socket.
    pub fn local_addr(&self) -> io::Result<SockAddr> {
        self.inner.local_addr()
    }

    /// Returns and clears the pending socket error (`SO_ERROR`), if any.
    pub fn take_error(&self) -> io::Result<Option<io::Error>> {
        self.inner.socket.take_error()
    }

    /// Binds the socket to the filesystem `path`.
    ///
    /// # Errors
    ///
    /// Fails if `path` cannot be converted into a unix socket address, or for
    /// any reason [`UnixSocket::bind_addr`] fails.
    pub async fn bind(&self, path: impl AsRef<Path>) -> io::Result<()> {
        self.bind_addr(&SockAddr::unix(path)?).await
    }

    /// Binds the socket to the unix socket address `addr`.
    ///
    /// # Errors
    ///
    /// Returns [`io::ErrorKind::InvalidInput`] if `addr` is not a unix socket
    /// address, and propagates errors from `bind`.
    pub async fn bind_addr(&self, addr: &SockAddr) -> io::Result<()> {
        if !addr.is_unix() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "addr is not unix socket address",
            ));
        }
        self.inner.bind(addr).await
    }

    /// Starts listening with the given `backlog`, converting the socket into a
    /// [`UnixListener`].
    pub async fn listen(self, backlog: i32) -> io::Result<UnixListener> {
        self.inner.listen(backlog).await?;
        Ok(UnixListener { inner: self.inner })
    }

    /// Connects to the unix socket at `path`, converting the socket into a
    /// [`UnixStream`].
    ///
    /// # Errors
    ///
    /// Fails if `path` cannot be converted into a unix socket address, or for
    /// any reason [`UnixSocket::connect_addr`] fails.
    pub async fn connect(self, path: impl AsRef<Path>) -> io::Result<UnixStream> {
        self.connect_addr(&SockAddr::unix(path)?).await
    }

    /// Connects to the unix socket address `addr`, converting the socket into
    /// a [`UnixStream`].
    ///
    /// # Errors
    ///
    /// Returns [`io::ErrorKind::InvalidInput`] if `addr` is not a unix socket
    /// address, and propagates errors from binding (Windows only) or
    /// connecting.
    pub async fn connect_addr(self, addr: &SockAddr) -> io::Result<UnixStream> {
        if !addr.is_unix() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "addr is not unix socket address",
            ));
        }
        #[cfg(windows)]
        {
            // `UnixStream::connect_addr` binds to an empty unix address before
            // connecting on Windows, because ConnectEx only accepts a
            // previously bound socket. Do the same here, unless the caller
            // already bound this socket. `getsockname` fails (WSAEINVAL) on an
            // unbound socket, and binding twice would itself fail.
            if self.inner.local_addr().is_err() {
                self.inner.bind(&empty_unix_socket()).await?;
            }
        }
        self.inner.connect_async(addr).await?;
        Ok(UnixStream { inner: self.inner })
    }
}

impl_raw_fd!(UnixSocket, socket2::Socket, inner, socket);
/// Builds a unix socket address whose path is empty.
///
/// Used on Windows to bind a stream socket before connecting it.
#[cfg(windows)]
#[inline]
fn empty_unix_socket() -> SockAddr {
    use windows_sys::Win32::Networking::WinSock::{AF_UNIX, SOCKADDR_UN};

    // SAFETY: the closure writes a fully initialised `SOCKADDR_UN` into the
    // storage `try_init` provides (a `SOCKADDR_STORAGE`, large enough to hold
    // it). It also sets a length that stays within that struct, so the
    // family and length are consistent, as `try_init` requires.
    let (_, addr) = unsafe {
        SockAddr::try_init(|storage, len| {
            let unix_addr: *mut SOCKADDR_UN = storage.cast();
            unix_addr.write(SOCKADDR_UN {
                sun_family: AF_UNIX,
                sun_path: [0; 108],
            });
            // 2 bytes of `sun_family` followed by a single NUL path byte.
            len.write(3);
            Ok(())
        })
    }
    .unwrap();
    addr
}
/// Shrinks the stored length of a Windows unix socket address.
///
/// The new length covers the `sun_family` header plus the path, up to and
/// including its first NUL byte. If the path has no NUL byte, the length is
/// set to the full `SOCKADDR_UN` size.
#[cfg(windows)]
#[inline]
fn fix_unix_socket_length(addr: &mut SockAddr) {
    use windows_sys::Win32::Networking::WinSock::SOCKADDR_UN;

    // SAFETY: `SockAddr` storage is at least as large as `SOCKADDR_UN`, and
    // this helper is only applied to addresses obtained from a unix socket.
    let unix_addr: &SOCKADDR_UN = unsafe { &*addr.as_ptr().cast() };
    // Byte offset of `sun_path`, i.e. the size of the `sun_family` header.
    let header_len = std::mem::offset_of!(SOCKADDR_UN, sun_path);
    let addr_len = unix_addr
        .sun_path
        .iter()
        .position(|&byte| byte == 0)
        .map_or(std::mem::size_of::<SOCKADDR_UN>(), |nul| header_len + nul + 1);
    // SAFETY: `addr_len` never exceeds `size_of::<SOCKADDR_UN>()`, which fits
    // within the address storage, and the bytes it covers belong to the
    // already-initialised address.
    unsafe {
        addr.set_length(addr_len as _);
    }
}