zlink-smol 0.3.0

zlink library for the smol runtime
Documentation
use crate::{
    connection::socket::{self, Socket},
    Result,
};
use async_io::Async;
use std::{
    os::{
        fd::{AsFd, BorrowedFd, OwnedFd},
        unix::net::UnixStream as StdUnixStream,
    },
    sync::Arc,
};

/// The connection type that uses Unix Domain Sockets for transport.
pub type Connection = crate::Connection<Stream>;

/// Connect to Unix Domain Socket at the given path.
pub async fn connect<P>(path: P) -> Result<Connection>
where
    P: AsRef<std::path::Path>,
{
    Async::<StdUnixStream>::connect(path)
        .await
        .map(Stream)
        .map(Connection::new)
        .map_err(Into::into)
}

/// The [`Socket`] implementation using Unix Domain Sockets.
#[derive(Debug)]
pub struct Stream(Async<StdUnixStream>);

impl Socket for Stream {
    type ReadHalf = ReadHalf;
    type WriteHalf = WriteHalf;

    const CAN_TRANSFER_FDS: bool = true;

    fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
        let stream = Arc::new(self.0);

        (ReadHalf(Arc::clone(&stream)), WriteHalf(stream))
    }
}

impl From<Async<StdUnixStream>> for Stream {
    fn from(stream: Async<StdUnixStream>) -> Self {
        Self(stream)
    }
}

impl AsFd for Stream {
    fn as_fd(&self) -> BorrowedFd<'_> {
        self.0.as_fd()
    }
}

impl socket::UnixSocket for Stream {}

/// The [`ReadHalf`] implementation using Unix Domain Sockets.
#[derive(Debug)]
pub struct ReadHalf(Arc<Async<StdUnixStream>>);

impl socket::ReadHalf for ReadHalf {
    async fn read(&mut self, buf: &mut [u8]) -> Result<(usize, Vec<OwnedFd>)> {
        use std::{future::poll_fn, task::Poll};

        poll_fn(|cx| loop {
            match crate::unix_utils::recvmsg(self.0.as_ref(), buf) {
                Ok(result) => return Poll::Ready(Ok(result)),
                Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                    match self.0.poll_readable(cx) {
                        Poll::Pending => return Poll::Pending,
                        Poll::Ready(res) => res?,
                    }
                }
                Err(e) => return Poll::Ready(Err(e.into())),
            }
        })
        .await
    }
}

impl AsFd for ReadHalf {
    fn as_fd(&self) -> BorrowedFd<'_> {
        self.0.as_ref().as_fd()
    }
}

impl socket::UnixSocket for ReadHalf {}

/// The [`WriteHalf`] implementation using Unix Domain Sockets.
#[derive(Debug)]
pub struct WriteHalf(Arc<Async<StdUnixStream>>);

impl socket::WriteHalf for WriteHalf {
    async fn write(&mut self, buf: &[u8], fds: &[impl AsFd]) -> Result<()> {
        use std::{future::poll_fn, task::Poll};

        // Convert to BorrowedFd for rustix.
        let borrowed_fds: Vec<BorrowedFd<'_>> = fds.iter().map(|f| f.as_fd()).collect();

        let mut pos = 0;
        while pos < buf.len() {
            // Use FDs on first write, empty slice on subsequent writes.
            let fds_to_send = if pos == 0 { &borrowed_fds[..] } else { &[] };

            let n: usize = poll_fn(|cx| loop {
                match crate::unix_utils::sendmsg(self.0.as_ref(), &buf[pos..], fds_to_send) {
                    Ok(bytes_sent) => return Poll::Ready(Ok::<_, crate::Error>(bytes_sent)),
                    Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                        match self.0.poll_writable(cx) {
                            Poll::Pending => return Poll::Pending,
                            Poll::Ready(res) => res?,
                        }
                    }
                    Err(e) => return Poll::Ready(Err(e.into())),
                }
            })
            .await?;

            pos += n;
        }

        Ok(())
    }
}

impl AsFd for WriteHalf {
    fn as_fd(&self) -> BorrowedFd<'_> {
        self.0.as_ref().as_fd()
    }
}

impl socket::UnixSocket for WriteHalf {}