postro 0.1.1

Asynchronous Postgres Driver and Utility
Documentation
use std::io;

/// A connected byte stream: either a tokio `TcpStream` or, on Unix targets,
/// a tokio `UnixStream`.
///
/// With the `tokio` feature enabled it implements `AsyncRead` and `AsyncWrite`
/// by forwarding every call to whichever stream it wraps.
///
/// Requires the `tokio` feature; without it, the connect functions panic at
/// runtime.
pub struct Socket {
    kind: Kind,
}

/// The concrete transport behind a [`Socket`].
///
/// Every variant is feature-gated, so without the `tokio` feature this enum
/// has no variants and no `Socket` value can ever be constructed.
enum Kind {
    /// A TCP connection.
    #[cfg(feature = "tokio")]
    TokioTcp(tokio::net::TcpStream),
    /// A Unix domain socket connection (Unix targets only).
    #[cfg(all(feature = "tokio", unix))]
    TokioUnixSocket(tokio::net::UnixStream),
}

impl Socket {
    pub async fn connect_tcp(host: &str, port: u16) -> io::Result<Socket> {
        #[cfg(feature = "tokio")]
        {
            let socket = tokio::net::TcpStream::connect((host,port)).await?;
            socket.set_nodelay(true)?;
            #[cfg(feature = "log")]
            log::debug!("Connected via TCP Stream: {:?}", socket.local_addr());
            Ok(Socket { kind: Kind::TokioTcp(socket) })
        }

        #[cfg(not(feature = "tokio"))]
        {
            let _ = (host,port);
            panic!("runtime disabled")
        }
    }

    pub async fn connect_socket(path: &str) -> io::Result<Socket> {
        #[cfg(all(feature = "tokio", unix))]
        {
            let socket = tokio::net::UnixStream::connect(path).await?;
            #[cfg(feature = "log")]
            log::debug!("Connected via Unix socket: {:?}", socket.peer_addr()?.as_pathname());
            Ok(Socket { kind: Kind::TokioUnixSocket(socket) })
        }

        #[cfg(not(all(feature = "tokio", unix)))]
        {
            let _ = path;
            panic!("runtime disabled")
        }
    }

    pub fn poll_shutdown(&mut self, _cx: &mut std::task::Context) -> std::task::Poll<io::Result<()>> {
        #[cfg(all(feature = "tokio", unix))]
        {
            tokio::io::AsyncWrite::poll_shutdown(std::pin::Pin::new(self), _cx)
        }

        #[cfg(not(all(feature = "tokio", unix)))]
        {
            panic!("runtime disabled")
        }
    }

    pub fn shutdown(&mut self) -> impl Future<Output = io::Result<()>> {
        std::future::poll_fn(|cx|self.poll_shutdown(cx))
    }
}

#[cfg(feature = "tokio")]
impl tokio::io::AsyncRead for Socket {
    /// Reads from whichever stream this socket wraps.
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        // `Socket` is `Unpin` (both wrapped streams are), so the pin can be
        // unwrapped into a plain mutable reference.
        let this = self.get_mut();
        match this.kind {
            Kind::TokioTcp(ref mut tcp) => {
                tokio::io::AsyncRead::poll_read(std::pin::Pin::new(tcp), cx, buf)
            }
            #[cfg(unix)]
            Kind::TokioUnixSocket(ref mut unix) => {
                tokio::io::AsyncRead::poll_read(std::pin::Pin::new(unix), cx, buf)
            }
        }
    }
}

#[cfg(feature = "tokio")]
impl tokio::io::AsyncWrite for Socket {
    /// Writes `buf` to the wrapped stream.
    fn poll_write(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<io::Result<usize>> {
        use std::pin::Pin;
        match &mut self.kind {
            Kind::TokioTcp(t) => tokio::io::AsyncWrite::poll_write(Pin::new(t), cx, buf),
            #[cfg(unix)]
            Kind::TokioUnixSocket(u) => tokio::io::AsyncWrite::poll_write(Pin::new(u), cx, buf),
        }
    }

    /// Writes `bufs` to the wrapped stream using vectored I/O.
    fn poll_write_vectored(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> std::task::Poll<io::Result<usize>> {
        use std::pin::Pin;
        match &mut self.kind {
            Kind::TokioTcp(t) => tokio::io::AsyncWrite::poll_write_vectored(Pin::new(t), cx, bufs),
            #[cfg(unix)]
            Kind::TokioUnixSocket(u) => {
                tokio::io::AsyncWrite::poll_write_vectored(Pin::new(u), cx, bufs)
            }
        }
    }

    /// Reports whether the wrapped stream has an efficient vectored-write path,
    /// instead of assuming it does.
    #[inline]
    fn is_write_vectored(&self) -> bool {
        match &self.kind {
            Kind::TokioTcp(t) => tokio::io::AsyncWrite::is_write_vectored(t),
            #[cfg(unix)]
            Kind::TokioUnixSocket(u) => tokio::io::AsyncWrite::is_write_vectored(u),
        }
    }

    /// Flushes the wrapped stream.
    ///
    /// Delegated rather than short-circuited so that any buffering done by the
    /// inner stream is honoured.
    fn poll_flush(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::result::Result<(), std::io::Error>> {
        use std::pin::Pin;
        match &mut self.kind {
            Kind::TokioTcp(t) => tokio::io::AsyncWrite::poll_flush(Pin::new(t), cx),
            #[cfg(unix)]
            Kind::TokioUnixSocket(u) => tokio::io::AsyncWrite::poll_flush(Pin::new(u), cx),
        }
    }

    /// Shuts down the wrapped stream's write side.
    fn poll_shutdown(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::result::Result<(), std::io::Error>> {
        use std::pin::Pin;
        match &mut self.kind {
            Kind::TokioTcp(t) => tokio::io::AsyncWrite::poll_shutdown(Pin::new(t), cx),
            #[cfg(unix)]
            Kind::TokioUnixSocket(u) => tokio::io::AsyncWrite::poll_shutdown(Pin::new(u), cx),
        }
    }
}

impl std::fmt::Debug for Socket {
    /// Formats the socket using the `Debug` output of the wrapped stream.
    ///
    /// Without the `tokio` feature `Kind` has no variants, so this writes nothing.
    fn fmt(&self, _f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        #[cfg(feature = "tokio")]
        {
            let inner: &dyn std::fmt::Debug = match &self.kind {
                Kind::TokioTcp(tcp) => tcp,
                #[cfg(unix)]
                Kind::TokioUnixSocket(unix) => unix,
            };
            std::fmt::Debug::fmt(inner, _f)
        }

        #[cfg(not(feature = "tokio"))]
        {
            let _ = &self.kind;
            Ok(())
        }
    }
}