use super::{datagram::GlommioDatagram, stream::GlommioStream};
use crate::{
net::stream::{Buffered, NonBuffered, Preallocated, RxBuf},
reactor::Reactor,
};
use futures_lite::{
future::poll_fn,
io::{AsyncBufRead, AsyncRead, AsyncWrite},
stream::{self, Stream},
};
use nix::sys::socket::{SockAddr, UnixAddr};
use pin_project_lite::pin_project;
use socket2::{Domain, Socket, Type};
use std::{
io,
net::Shutdown,
os::unix::{
io::{AsRawFd, FromRawFd, RawFd},
net::{self, SocketAddr},
},
path::Path,
pin::Pin,
rc::{Rc, Weak},
task::{Context, Poll},
};
/// Module-local shorthand: fallible operations here return [`crate::Result`]
/// with `()` as its second type parameter.
type Result<T> = crate::Result<T, ()>;
/// A Unix-domain stream socket that listens for incoming connections.
///
/// Created with [`UnixListener::bind`]; connections are obtained with
/// [`UnixListener::accept`], [`UnixListener::shared_accept`] or
/// [`UnixListener::incoming`].
#[derive(Debug)]
pub struct UnixListener {
    // Weak handle to the reactor of the executor that called `bind`; it is
    // upgraded (and unwrapped) in `shared_accept`.
    reactor: Weak<Reactor>,
    // The standard-library listener that owns the socket descriptor.
    listener: net::UnixListener,
}
impl UnixListener {
pub fn bind<A: AsRef<Path>>(addr: A) -> Result<UnixListener> {
let sk = Socket::new(Domain::unix(), Type::stream(), None)?;
let addr = socket2::SockAddr::unix(addr.as_ref())?;
sk.bind(&addr)?;
sk.listen(128)?;
let listener = sk.into_unix_listener();
Ok(UnixListener {
reactor: Rc::downgrade(&crate::executor().reactor()),
listener,
})
}
pub async fn shared_accept(&self) -> Result<AcceptedUnixStream> {
let reactor = self.reactor.upgrade().unwrap();
let source = reactor.accept(self.listener.as_raw_fd());
let fd = source.collect_rw().await?;
Ok(AcceptedUnixStream { fd: fd as RawFd })
}
pub async fn accept(&self) -> Result<UnixStream> {
let a = self.shared_accept().await?;
Ok(a.bind_to_executor())
}
pub fn incoming(&self) -> impl Stream<Item = Result<UnixStream>> + Unpin + '_ {
Box::pin(stream::unfold(self, |listener| async move {
let res = listener.accept().await.map_err(Into::into);
Some((res, listener))
}))
}
pub fn local_addr(&self) -> Result<SocketAddr> {
self.listener.local_addr().map_err(Into::into)
}
}
/// A connection returned by [`UnixListener::shared_accept`] that has not yet
/// been turned into a [`UnixStream`].
///
/// It only carries the raw file descriptor; convert it with
/// [`AcceptedUnixStream::bind_to_executor`].
///
/// NOTE(review): because this type is `Copy`, one accepted descriptor could be
/// bound more than once, producing two streams over the same fd — confirm no
/// caller does this.
#[derive(Copy, Clone, Debug)]
pub struct AcceptedUnixStream {
    // Descriptor returned by the reactor's accept operation.
    fd: RawFd,
}
impl AcceptedUnixStream {
pub fn bind_to_executor(self) -> UnixStream {
let stream = unsafe { GlommioStream::from_raw_fd(self.fd as _) };
UnixStream { stream }
}
}
pin_project! {
    /// A Unix-domain stream socket.
    ///
    /// `B` is the receive-buffer type and defaults to [`NonBuffered`]. Only
    /// streams whose `B` implements [`Buffered`] implement `AsyncBufRead`;
    /// convert with [`UnixStream::buffered`] (which uses [`Preallocated`]) or
    /// [`UnixStream::buffered_with`].
    #[derive(Debug)]
    pub struct UnixStream<B: RxBuf = NonBuffered> {
        // All I/O is delegated to this wrapper around the std stream.
        stream: GlommioStream<net::UnixStream, B>
    }
}
impl FromRawFd for UnixStream {
    /// Builds an unbuffered [`UnixStream`] that takes ownership of `fd`.
    ///
    /// # Safety
    /// `fd` must be an open Unix-domain stream socket descriptor that nothing
    /// else owns or will close; it is passed unchanged to
    /// `GlommioStream::from_raw_fd`.
    unsafe fn from_raw_fd(fd: RawFd) -> Self {
        let stream = GlommioStream::from_raw_fd(fd as _);
        Self { stream }
    }
}
impl UnixStream {
    /// Creates a pair of connected, unnamed Unix stream sockets.
    pub fn pair() -> Result<(UnixStream, UnixStream)> {
        let (left, right) = net::UnixStream::pair()?;
        let wrap = |s: net::UnixStream| Self {
            stream: GlommioStream::from(socket2::Socket::from(s)),
        };
        Ok((wrap(left), wrap(right)))
    }

    /// Connects to the Unix stream socket listening at path `addr`.
    ///
    /// # Errors
    /// Fails if the socket cannot be created, if `addr` is not a valid Unix
    /// socket path (reported as [`io::ErrorKind::Other`]), or if the connect
    /// operation itself fails.
    pub async fn connect<A: AsRef<Path>>(addr: A) -> Result<UnixStream> {
        let reactor = crate::executor().reactor();
        let socket = Socket::new(Domain::unix(), Type::stream(), None)?;
        let peer = match SockAddr::new_unix(addr.as_ref()) {
            Ok(peer) => peer,
            Err(e) => return Err(io::Error::new(io::ErrorKind::Other, e).into()),
        };
        let pending = reactor.connect(socket.as_raw_fd(), peer);
        pending.collect_rw().await?;
        let stream = GlommioStream::from(socket);
        Ok(Self { stream })
    }

    /// Switches to a buffered stream using the default [`Preallocated`] buffer.
    pub fn buffered(self) -> UnixStream<Preallocated> {
        let buf = Preallocated::default();
        self.buffered_with(buf)
    }

    /// Switches to a buffered stream that uses `buf` as its receive buffer.
    pub fn buffered_with<B: Buffered>(self, buf: B) -> UnixStream<B> {
        let stream = self.stream.buffered_with(buf);
        UnixStream { stream }
    }
}
impl<B: RxBuf> UnixStream<B> {
    /// Shuts down the read half, the write half, or both halves of the
    /// connection, as selected by `how`.
    pub async fn shutdown(&self, how: Shutdown) -> Result<()> {
        let outcome = poll_fn(|cx| self.stream.poll_shutdown(cx, how)).await;
        Ok(outcome?)
    }

    /// Peeks at incoming data into `buf` through the underlying stream and
    /// returns the number of bytes copied.
    pub async fn peek(&self, buf: &mut [u8]) -> Result<usize> {
        let copied = self.stream.peek(buf).await?;
        Ok(copied)
    }

    /// Returns the address of the connected peer.
    pub fn peer_addr(&self) -> Result<SocketAddr> {
        let inner = self.stream.stream();
        Ok(inner.peer_addr()?)
    }

    /// Returns the local address of this socket.
    pub fn local_addr(&self) -> Result<SocketAddr> {
        let inner = self.stream.stream();
        Ok(inner.local_addr()?)
    }
}
impl<B: Buffered + Unpin> AsyncBufRead for UnixStream<B> {
    /// Fills (if needed) and returns the wrapped stream's receive buffer.
    fn poll_fill_buf<'a>(
        self: Pin<&'a mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<io::Result<&'a [u8]>> {
        self.project().stream.poll_fill_buf(cx)
    }

    /// Marks `amt` buffered bytes as consumed in the wrapped stream.
    fn consume(self: Pin<&mut Self>, amt: usize) {
        self.project().stream.consume(amt);
    }
}
impl<B: RxBuf + Unpin> AsyncRead for UnixStream<B> {
    /// Delegates the read to the wrapped stream.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let inner = self.project().stream;
        Pin::new(inner).poll_read(cx, buf)
    }
}
impl<B: RxBuf + Unpin> AsyncWrite for UnixStream<B> {
    /// Delegates the write to the wrapped stream.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let inner = self.project().stream;
        Pin::new(inner).poll_write(cx, buf)
    }

    /// Delegates the flush to the wrapped stream.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let inner = self.project().stream;
        Pin::new(inner).poll_flush(cx)
    }

    /// Delegates the close to the wrapped stream.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let inner = self.project().stream;
        Pin::new(inner).poll_close(cx)
    }
}
/// A Unix-domain datagram socket.
///
/// Created with [`UnixDatagram::bind`], [`UnixDatagram::unbound`] or
/// [`UnixDatagram::pair`].
#[derive(Debug)]
pub struct UnixDatagram {
    // Wrapper around the std datagram socket; all I/O is delegated to it.
    socket: GlommioDatagram<net::UnixDatagram>,
}
impl UnixDatagram {
pub fn pair() -> Result<(UnixDatagram, UnixDatagram)> {
let (socket1, socket2) = net::UnixDatagram::pair()?;
let socket1 = GlommioDatagram::from(socket2::Socket::from(socket1));
let socket2 = GlommioDatagram::from(socket2::Socket::from(socket2));
let socket1 = Self { socket: socket1 };
let socket2 = Self { socket: socket2 };
Ok((socket1, socket2))
}
pub fn bind<A: AsRef<Path>>(addr: A) -> Result<UnixDatagram> {
let sk = Socket::new(Domain::unix(), Type::dgram(), None)?;
let addr = socket2::SockAddr::unix(addr.as_ref())?;
sk.bind(&addr)?;
Ok(Self {
socket: GlommioDatagram::from(sk),
})
}
pub fn unbound() -> Result<UnixDatagram> {
let sk = Socket::new(Domain::unix(), Type::dgram(), None)?;
Ok(Self {
socket: GlommioDatagram::from(sk),
})
}
pub async fn connect<A: AsRef<Path>>(&self, addr: A) -> Result<()> {
let addr = SockAddr::new_unix(addr.as_ref())
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
let reactor = self.socket.reactor.upgrade().unwrap();
let source = reactor.connect(self.socket.as_raw_fd(), addr);
source.collect_rw().await.map(|_| {}).map_err(Into::into)
}
pub fn set_buffer_size(&mut self, buffer_size: usize) {
self.socket.rx_buf_size = buffer_size;
}
pub fn buffer_size(&mut self) -> usize {
self.socket.rx_buf_size
}
pub async fn peek(&self, buf: &mut [u8]) -> Result<usize> {
let _ = self.peer_addr()?;
self.socket.peek(buf).await.map_err(Into::into)
}
#[track_caller]
pub async fn peek_from(&self, buf: &mut [u8]) -> Result<(usize, UnixAddr)> {
let (sz, addr) = self.socket.peek_from(buf).await?;
let addr = match addr {
nix::sys::socket::SockAddr::Unix(addr) => addr,
x => panic!("invalid socket addr for this family!: {:?}", x),
};
Ok((sz, addr))
}
pub fn peer_addr(&self) -> Result<SocketAddr> {
self.socket.socket.peer_addr().map_err(Into::into)
}
pub fn local_addr(&self) -> Result<SocketAddr> {
self.socket.socket.local_addr().map_err(Into::into)
}
pub async fn recv(&self, buf: &mut [u8]) -> Result<usize> {
self.socket.recv(buf).await.map_err(Into::into)
}
#[track_caller]
pub async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, UnixAddr)> {
let (sz, addr) = self.socket.recv_from(buf).await?;
let addr = match addr {
nix::sys::socket::SockAddr::Unix(addr) => addr,
x => panic!("invalid socket addr for this family!: {:?}", x),
};
Ok((sz, addr))
}
pub async fn send_to<A: AsRef<Path>>(&self, buf: &[u8], addr: A) -> Result<usize> {
let addr = nix::sys::socket::SockAddr::new_unix(addr.as_ref())
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
self.socket.send_to(buf, addr).await.map_err(Into::into)
}
pub async fn send(&self, buf: &[u8]) -> Result<usize> {
self.socket.send(buf).await.map_err(Into::into)
}
}
// Integration tests for the Unix-domain listener, stream and datagram types.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{enclose, test_utils::*};
    use futures_lite::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt};
    use std::cell::Cell;
    // Defines a `#[test] fn $name` that creates a temporary directory named
    // after the test, binds its path to `$dir`, and runs `$code` inside
    // `test_executor!`. `td` lives until the end of the test fn; presumably its
    // drop removes the directory — confirm in `test_utils`.
    macro_rules! unix_socket_test {
        ( $name:ident, $dir:ident, $code:block) => {
            #[test]
            fn $name() {
                let td = make_tmp_test_directory(&format!("uds-{}", stringify!($name)));
                let $dir = td.path.clone();
                test_executor!(async move { $code });
            }
        };
    }
    // A listener task accepts one connection made from the main task. The
    // spawned task sets `coord` to 1 before accepting; the main task yields
    // until it observes that, then connects to the listener's path.
    unix_socket_test!(connect_local_server, dir, {
        let mut file = dir.clone();
        file.push("name");
        let listener = UnixListener::bind(&file).unwrap();
        let addr = listener.local_addr().unwrap();
        let addr = addr.as_pathname().unwrap();
        let coord = Rc::new(Cell::new(0));
        let listener_handle = crate::spawn_local(enclose! { (coord) async move {
            coord.set(1);
            listener.accept().await.unwrap();
        }})
        .detach();
        while coord.get() != 1 {
            crate::executor().yield_task_queue_now().await;
        }
        UnixStream::connect(&addr).await.unwrap();
        listener_handle.await.unwrap();
    });
    // One byte written on one end of a stream pair is read from the other.
    unix_socket_test!(pair, _dir, {
        let (mut p1, mut p2) = UnixStream::pair().unwrap();
        let sz = p1.write(&[65u8; 1]).await.unwrap();
        assert_eq!(sz, 1);
        let mut buf = [0u8; 1];
        let sz = p2.read(&mut buf).await.unwrap();
        assert_eq!(sz, 1);
        assert_eq!(buf[0], 65);
    });
    // The server side of a buffered stream reads up to and including byte 10,
    // which is the last of the 10 bytes the client sends, so it reports 10.
    unix_socket_test!(read_until, dir, {
        let mut file = dir.clone();
        file.push("name");
        let listener = UnixListener::bind(&file).unwrap();
        let listener_handle = crate::spawn_local(async move {
            let mut stream = listener.accept().await?.buffered();
            let mut buf = Vec::new();
            stream.read_until(10, &mut buf).await?;
            io::Result::Ok(buf.len())
        })
        .detach();
        let mut stream = UnixStream::connect(&file).await.unwrap();
        let vec = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let b = stream.write(&vec).await.unwrap();
        assert_eq!(b, 10);
        assert_eq!(listener_handle.await.unwrap().unwrap(), 10);
    });
    // One byte sent on one end of a datagram pair is received on the other.
    unix_socket_test!(datagram_pair_ping_pong, _dir, {
        let (p1, p2) = UnixDatagram::pair().unwrap();
        let sz = p1.send(&[65u8; 1]).await.unwrap();
        assert_eq!(sz, 1);
        let mut buf = [0u8; 1];
        let sz = p2.recv(&mut buf).await.unwrap();
        assert_eq!(sz, 1);
        assert_eq!(buf[0], 65);
    });
    // An unbound socket connects to a bound one and sends with `send`.
    unix_socket_test!(datagram_send_recv, dir, {
        let mut file = dir.clone();
        file.push("name");
        let p1 = UnixDatagram::bind(&file).unwrap();
        let p2 = UnixDatagram::unbound().unwrap();
        p2.connect(&file).await.unwrap();
        p2.send(b"msg1").await.unwrap();
        let mut buf = [0u8; 10];
        let sz = p1.recv(&mut buf).await.unwrap();
        assert_eq!(sz, 4);
    });
    // An unbound socket sends with `send_to`; because the sender is unbound,
    // the address seen by `recv_from` has no path.
    unix_socket_test!(datagram_send_to_recv_from, dir, {
        let mut file = dir.clone();
        file.push("name");
        let p1 = UnixDatagram::bind(&file).unwrap();
        let p2 = UnixDatagram::unbound().unwrap();
        p2.send_to(b"msg1", &file).await.unwrap();
        let mut buf = [0u8; 10];
        let (sz, addr) = p1.recv_from(&mut buf).await.unwrap();
        assert_eq!(sz, 4);
        assert!(addr.path().is_none());
    });
    // Connecting an unbound datagram socket to a bound path succeeds.
    unix_socket_test!(datagram_connect_unbounded, dir, {
        let mut file = dir.clone();
        file.push("name");
        let _p1 = UnixDatagram::bind(&file).unwrap();
        let p2 = UnixDatagram::unbound().unwrap();
        p2.connect(&file).await.unwrap();
    });
}