use std::{future::Future, io, path::Path};
use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
use compio_driver::impl_raw_fd;
use compio_io::{AsyncRead, AsyncReadManaged, AsyncWrite, util::Splittable};
use compio_runtime::{BorrowedBuffer, BufferPool};
use socket2::{SockAddr, Socket as Socket2, Type};
use crate::{OwnedReadHalf, OwnedWriteHalf, PollFd, ReadHalf, Socket, SocketOpts, WriteHalf};
/// A listener for unix-domain stream sockets.
///
/// Obtained from the `bind*` constructors (or `from_std` on unix) and used to
/// `accept` incoming connections as [`UnixStream`]s.
#[derive(Debug, Clone)]
pub struct UnixListener {
    // The wrapped compio socket. `Clone` is derived, so cloning the listener
    // clones this `Socket` — presumably sharing the same descriptor; confirm
    // against `Socket`'s `Clone` impl.
    inner: Socket,
}
impl UnixListener {
    /// Creates a listener bound to the filesystem `path`.
    ///
    /// `path` is converted with [`SockAddr::unix`] and then passed to
    /// [`UnixListener::bind_addr`].
    ///
    /// # Errors
    ///
    /// Returns any error from the address conversion, binding, or listening.
    pub async fn bind(path: impl AsRef<Path>) -> io::Result<Self> {
        Self::bind_addr(&SockAddr::unix(path)?).await
    }

    /// Creates a listener bound to `addr` using [`SocketOpts::default`].
    ///
    /// # Errors
    ///
    /// See [`UnixListener::bind_with_options`].
    pub async fn bind_addr(addr: &SockAddr) -> io::Result<Self> {
        Self::bind_with_options(addr, &SocketOpts::default()).await
    }

    /// Creates a listener bound to `addr`, applies `opts`, and starts
    /// listening with a backlog of 1024.
    ///
    /// # Errors
    ///
    /// Returns an [`io::ErrorKind::InvalidInput`] error if `addr` is not a
    /// unix-domain address. Otherwise it propagates failures from binding the
    /// socket, applying `opts`, and `listen`.
    pub async fn bind_with_options(addr: &SockAddr, opts: &SocketOpts) -> io::Result<Self> {
        if !addr.is_unix() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "addr is not unix socket address",
            ));
        }
        let socket = Socket::bind(addr, Type::STREAM, None).await?;
        // NOTE(review): options are applied only after `bind`, so an option that
        // must be set before binding cannot influence the bind itself.
        opts.setup_socket(&socket)?;
        socket.listen(1024)?;
        Ok(UnixListener { inner: socket })
    }

    /// Wraps an existing standard-library unix listener.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `Socket::from_socket2`.
    #[cfg(unix)]
    pub fn from_std(listener: std::os::unix::net::UnixListener) -> io::Result<Self> {
        Ok(Self {
            inner: Socket::from_socket2(Socket2::from(listener))?,
        })
    }

    /// Consumes the listener and returns the future that closes the
    /// underlying socket.
    pub fn close(self) -> impl Future<Output = io::Result<()>> {
        self.inner.close()
    }

    /// Waits for an incoming connection.
    ///
    /// Returns the connected stream together with the address reported by the
    /// underlying accept.
    ///
    /// # Errors
    ///
    /// Propagates any error from the accept operation.
    pub async fn accept(&self) -> io::Result<(UnixStream, SockAddr)> {
        let (socket, addr) = self.inner.accept().await?;
        let stream = UnixStream { inner: socket };
        Ok((stream, addr))
    }

    /// Like [`UnixListener::accept`], but also applies `options` to the
    /// accepted socket before returning it.
    ///
    /// # Errors
    ///
    /// Propagates errors from the accept operation or from applying
    /// `options`. If applying `options` fails, the accepted socket is dropped.
    pub async fn accept_with_options(
        &self,
        options: &SocketOpts,
    ) -> io::Result<(UnixStream, SockAddr)> {
        // Reuse `accept` instead of duplicating it; wrapping the socket in a
        // `UnixStream` has no side effects, so applying options afterwards is
        // equivalent to applying them to the raw socket first.
        let (stream, addr) = self.accept().await?;
        options.setup_socket(&stream.inner)?;
        Ok((stream, addr))
    }

    /// Returns the local address this listener is bound to.
    pub fn local_addr(&self) -> io::Result<SockAddr> {
        self.inner.local_addr()
    }
}
// Generates raw descriptor/socket accessor impls for `UnixListener` by forwarding
// to its `inner` field — presumably `AsRawFd`/`AsRawSocket`-style traits; see
// `compio_driver::impl_raw_fd` for the exact set.
impl_raw_fd!(UnixListener, socket2::Socket, inner, socket);
/// A connected unix-domain stream socket.
///
/// Produced by the `connect*` constructors, by `from_std` on unix, or by
/// accepting on a [`UnixListener`]. Implements `AsyncRead`, `AsyncReadManaged`
/// and `AsyncWrite` both by value and through a shared reference.
#[derive(Debug, Clone)]
pub struct UnixStream {
    // The wrapped compio socket; all I/O is forwarded to it.
    inner: Socket,
}
impl UnixStream {
pub async fn connect(path: impl AsRef<Path>) -> io::Result<Self> {
Self::connect_addr(&SockAddr::unix(path)?).await
}
pub async fn connect_addr(addr: &SockAddr) -> io::Result<Self> {
Self::connect_with_options(addr, &SocketOpts::default()).await
}
pub async fn connect_with_options(addr: &SockAddr, options: &SocketOpts) -> io::Result<Self> {
if !addr.is_unix() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"addr is not unix socket address",
));
}
#[cfg(windows)]
let socket = {
let new_addr = empty_unix_socket();
Socket::bind(&new_addr, Type::STREAM, None).await?
};
#[cfg(unix)]
let socket = {
use socket2::Domain;
Socket::new(Domain::UNIX, Type::STREAM, None).await?
};
options.setup_socket(&socket)?;
socket.connect_async(addr).await?;
let unix_stream = UnixStream { inner: socket };
Ok(unix_stream)
}
#[cfg(unix)]
pub fn from_std(stream: std::os::unix::net::UnixStream) -> io::Result<Self> {
Ok(Self {
inner: Socket::from_socket2(Socket2::from(stream))?,
})
}
pub fn close(self) -> impl Future<Output = io::Result<()>> {
self.inner.close()
}
pub fn peer_addr(&self) -> io::Result<SockAddr> {
#[allow(unused_mut)]
let mut addr = self.inner.peer_addr()?;
#[cfg(windows)]
{
fix_unix_socket_length(&mut addr);
}
Ok(addr)
}
pub fn local_addr(&self) -> io::Result<SockAddr> {
self.inner.local_addr()
}
pub fn split(&self) -> (ReadHalf<'_, Self>, WriteHalf<'_, Self>) {
crate::split(self)
}
pub fn into_split(self) -> (OwnedReadHalf<Self>, OwnedWriteHalf<Self>) {
crate::into_split(self)
}
pub fn to_poll_fd(&self) -> io::Result<PollFd<Socket2>> {
self.inner.to_poll_fd()
}
pub fn into_poll_fd(self) -> io::Result<PollFd<Socket2>> {
self.inner.into_poll_fd()
}
}
impl AsyncRead for UnixStream {
    /// Reads into `buf` by delegating to the `&UnixStream` implementation.
    #[inline]
    async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
        <&UnixStream as AsyncRead>::read(&mut &*self, buf).await
    }

    /// Vectored read, delegating to the `&UnixStream` implementation.
    #[inline]
    async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
        <&UnixStream as AsyncRead>::read_vectored(&mut &*self, buf).await
    }
}
impl AsyncRead for &UnixStream {
    /// Receives into `buf` from the inner socket with flags set to `0`.
    #[inline]
    async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
        let socket = &self.inner;
        socket.recv(buf, 0).await
    }

    /// Vectored receive from the inner socket with flags set to `0`.
    #[inline]
    async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
        let socket = &self.inner;
        socket.recv_vectored(buf, 0).await
    }
}
impl AsyncReadManaged for UnixStream {
    type Buffer<'a> = BorrowedBuffer<'a>;
    type BufferPool = BufferPool;

    /// Reads up to `len` bytes into a buffer taken from `buffer_pool`,
    /// delegating to the `&UnixStream` implementation.
    async fn read_managed<'a>(
        &mut self,
        buffer_pool: &'a Self::BufferPool,
        len: usize,
    ) -> io::Result<Self::Buffer<'a>> {
        <&UnixStream as AsyncReadManaged>::read_managed(&mut &*self, buffer_pool, len).await
    }
}
impl AsyncReadManaged for &UnixStream {
    type Buffer<'a> = BorrowedBuffer<'a>;
    type BufferPool = BufferPool;

    /// Receives up to `len` bytes into a buffer from `buffer_pool` via the
    /// inner socket, with flags set to `0`.
    async fn read_managed<'a>(
        &mut self,
        buffer_pool: &'a Self::BufferPool,
        len: usize,
    ) -> io::Result<Self::Buffer<'a>> {
        let socket = &self.inner;
        // NOTE(review): `as _` converts to whatever integer type `recv_managed`
        // takes; if that is narrower than `usize`, large `len` values truncate.
        socket.recv_managed(buffer_pool, len as _, 0).await
    }
}
impl AsyncWrite for UnixStream {
    /// Writes `buf` by delegating to the `&UnixStream` implementation.
    #[inline]
    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
        <&UnixStream as AsyncWrite>::write(&mut &*self, buf).await
    }

    /// Vectored write, delegating to the `&UnixStream` implementation.
    #[inline]
    async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
        <&UnixStream as AsyncWrite>::write_vectored(&mut &*self, buf).await
    }

    /// Flushes via the `&UnixStream` implementation.
    #[inline]
    async fn flush(&mut self) -> io::Result<()> {
        <&UnixStream as AsyncWrite>::flush(&mut &*self).await
    }

    /// Shuts down via the `&UnixStream` implementation.
    #[inline]
    async fn shutdown(&mut self) -> io::Result<()> {
        <&UnixStream as AsyncWrite>::shutdown(&mut &*self).await
    }
}
impl AsyncWrite for &UnixStream {
    /// Sends `buf` on the inner socket with flags set to `0`.
    #[inline]
    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
        let socket = &self.inner;
        socket.send(buf, 0).await
    }

    /// Vectored send on the inner socket with flags set to `0`.
    #[inline]
    async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
        let socket = &self.inner;
        socket.send_vectored(buf, 0).await
    }

    /// No-op: the writes above go straight to the socket, so this type holds
    /// no buffered data to flush.
    #[inline]
    async fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }

    /// Shuts down the inner socket.
    #[inline]
    async fn shutdown(&mut self) -> io::Result<()> {
        let socket = &self.inner;
        socket.shutdown().await
    }
}
impl Splittable for UnixStream {
type ReadHalf = OwnedReadHalf<Self>;
type WriteHalf = OwnedWriteHalf<Self>;
fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
crate::into_split(self)
}
}
impl<'a> Splittable for &'a UnixStream {
type ReadHalf = ReadHalf<'a, UnixStream>;
type WriteHalf = WriteHalf<'a, UnixStream>;
fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
crate::split(self)
}
}
// Generates raw descriptor/socket accessor impls for `UnixStream` by forwarding
// to its `inner` field — presumably `AsRawFd`/`AsRawSocket`-style traits; see
// `compio_driver::impl_raw_fd` for the exact set.
impl_raw_fd!(UnixStream, socket2::Socket, inner, socket);
/// Builds an `AF_UNIX` address whose `sun_path` is all zeros and whose length
/// is set to 3.
///
/// Used on Windows by `UnixStream::connect_with_options` as the local address
/// the client socket is bound to before connecting.
#[cfg(windows)]
#[inline]
fn empty_unix_socket() -> SockAddr {
    use windows_sys::Win32::Networking::WinSock::{AF_UNIX, SOCKADDR_UN};
    // SAFETY (review): `SockAddr::try_init` passes the closure pointers to the
    // address storage and to its length. The closure writes a complete
    // `SOCKADDR_UN` whose family is `AF_UNIX`, so the family matches the
    // structure written. This relies on the storage being large enough for a
    // `SOCKADDR_UN`; socket2 documents it as a `sockaddr_storage`.
    unsafe {
        SockAddr::try_init(|addr, len| {
            let addr: *mut SOCKADDR_UN = addr.cast();
            std::ptr::write(
                addr,
                SOCKADDR_UN {
                    sun_family: AF_UNIX,
                    sun_path: [0; 108],
                },
            );
            // Presumably 2 bytes of `sun_family` plus one NUL byte of
            // `sun_path`, i.e. an empty path — TODO confirm against the
            // Windows AF_UNIX documentation.
            std::ptr::write(len, 3);
            Ok(())
        })
    }
    // The closure always returns `Ok(())`. Presumably `try_init` fails only
    // when the closure does, so this unwrap is not expected to panic.
    .unwrap()
    // Keep the `SockAddr`, discarding the closure's `()` result.
    .1
}
/// Recomputes the stored length of a unix socket address from its `sun_path`.
///
/// Called by `UnixStream::peer_addr` on Windows. If `sun_path` contains a NUL
/// byte, the length becomes the path bytes up to and including the first NUL,
/// plus 2. Otherwise it becomes `size_of::<SOCKADDR_UN>()`.
#[cfg(windows)]
#[inline]
fn fix_unix_socket_length(addr: &mut SockAddr) {
    use windows_sys::Win32::Networking::WinSock::SOCKADDR_UN;
    // SAFETY (review): reinterprets the address storage as a `SOCKADDR_UN`.
    // This is only meaningful for `AF_UNIX` addresses; the caller in this file
    // passes the peer address of a `UnixStream`.
    let unix_addr: &SOCKADDR_UN = unsafe { &*addr.as_ptr().cast() };
    // SAFETY: views the fixed-size `sun_path` array as bytes; the pointer and
    // element count both come from that array. This assumes its elements are
    // 1 byte wide — TODO confirm for the `windows_sys` version in use.
    let sun_path = unsafe {
        std::slice::from_raw_parts(
            unix_addr.sun_path.as_ptr() as *const u8,
            unix_addr.sun_path.len(),
        )
    };
    let addr_len = match std::ffi::CStr::from_bytes_until_nul(sun_path) {
        // Path bytes including the terminating NUL, plus 2 — presumably the
        // size of the `sun_family` field (ADDRESS_FAMILY) — TODO confirm.
        Ok(str) => str.to_bytes_with_nul().len() + 2,
        // No NUL anywhere in `sun_path`: report the full structure size.
        Err(_) => std::mem::size_of::<SOCKADDR_UN>(),
    };
    // SAFETY (review): `set_length` must not exceed the initialised storage.
    // In the `Ok` arm the value is at most `sun_path.len() + 2`; in the `Err`
    // arm it is exactly the size of `SOCKADDR_UN`.
    unsafe {
        addr.set_length(addr_len as _);
    }
}