async-memcached 0.6.0

A Tokio-based memcached client for Rust.
Documentation
use pin_project::pin_project;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, BufReader, BufWriter};
use tokio::net::{lookup_host, TcpSocket, TcpStream, UnixStream};

use crate::Error;

/// A buffered stream connection, either over TCP or over a Unix domain socket.
///
/// Each variant wraps its stream in a `BufWriter` for outgoing data and a
/// `BufReader` (outermost) for incoming data. `Connection` implements
/// `AsyncRead`, `AsyncWrite` and `AsyncBufRead` by delegating to whichever
/// variant is active.
#[pin_project(project = ConnectionProjection)]
#[derive(Debug)]
pub enum Connection {
    /// Buffered TCP stream.
    Tcp(#[pin] BufReader<BufWriter<TcpStream>>),
    /// Buffered Unix domain socket stream.
    Unix(#[pin] BufReader<BufWriter<UnixStream>>),
}

impl AsyncRead for Connection {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        match self.project() {
            ConnectionProjection::Tcp(s) => s.poll_read(cx, buf),
            ConnectionProjection::Unix(s) => s.poll_read(cx, buf),
        }
    }
}

impl AsyncWrite for Connection {
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
        match self.project() {
            ConnectionProjection::Tcp(s) => s.poll_write(cx, buf),
            ConnectionProjection::Unix(s) => s.poll_write(cx, buf),
        }
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        match self.project() {
            ConnectionProjection::Tcp(s) => s.poll_flush(cx),
            ConnectionProjection::Unix(s) => s.poll_flush(cx),
        }
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        match self.project() {
            ConnectionProjection::Tcp(s) => s.poll_shutdown(cx),
            ConnectionProjection::Unix(s) => s.poll_shutdown(cx),
        }
    }
}

impl AsyncBufRead for Connection {
    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<&[u8]>> {
        match self.project() {
            ConnectionProjection::Tcp(s) => s.poll_fill_buf(cx),
            ConnectionProjection::Unix(s) => s.poll_fill_buf(cx),
        }
    }

    fn consume(self: Pin<&mut Self>, amt: usize) {
        match self.project() {
            ConnectionProjection::Tcp(s) => s.consume(amt),
            ConnectionProjection::Unix(s) => s.consume(amt),
        }
    }
}

/// A classified connection target produced from a DSN string.
///
/// Every payload is an owned `String`, so the type is cheaply comparable
/// (`Eq`) and cloneable.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Addr {
    /// A `tcp://` DSN normalised to `host:port`.
    Tcp(String),
    /// A `unix://` DSN; holds the URL path of the socket.
    Unix(String),
    /// A DSN without a recognised scheme, kept verbatim.
    Unknown(String),
}

impl Addr {
    const DEFAULT_PORT: u16 = 11211;

    fn parse(dsn: &str) -> Result<Self, Error> {
        let url = url::Url::parse(dsn).map_err(|e| {
            Error::Connect(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("failed to parse DSN: {}", e),
            ))
        })?;

        match url.scheme() {
            "unix" => Ok(Addr::Unix(url.path().to_string())),
            "tcp" => Ok(Addr::Tcp(format!(
                "{}:{}",
                url.host_str().ok_or_else(|| {
                    Error::Connect(io::Error::new(
                        io::ErrorKind::InvalidInput,
                        "no host found in DSN",
                    ))
                })?,
                url.port().unwrap_or(Self::DEFAULT_PORT)
            ))),
            _ => Ok(Addr::Unknown(dsn.to_string())),
        }
    }
}

impl Connection {
    pub async fn new<S: AsRef<str>>(dsn: S) -> Result<Self, Error> {
        match Addr::parse(dsn.as_ref())? {
            Addr::Unix(path) => UnixStream::connect(path)
                .await
                .map(|c| Connection::Unix(BufReader::new(BufWriter::new(c))))
                .map_err(Error::Connect),
            Addr::Tcp(url) | Addr::Unknown(url) => {
                let addrs = lookup_host(url).await.map_err(Error::Connect)?;

                let mut last_err = None;

                for addr in addrs {
                    let socket = TcpSocket::new_v4().map_err(Error::Connect)?;
                    socket.set_nodelay(true).map_err(Error::Connect)?;
                    match socket.connect(addr).await {
                        Ok(stream) => {
                            return Ok(Connection::Tcp(BufReader::new(BufWriter::new(stream))))
                        }
                        Err(e) => last_err = Some(Error::Connect(e)),
                    }
                }
                Err(last_err.unwrap_or_else(|| {
                    Error::Connect(io::Error::new(
                        io::ErrorKind::InvalidInput,
                        "could not resolve to any address",
                    ))
                }))
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::{Addr, Error};

    // `Addr::parse` is synchronous, so plain `#[test]` functions suffice; there
    // is no need to start a Tokio runtime for each case.

    #[test]
    fn test_unknown_scheme() {
        assert_eq!(
            Addr::parse("localhost:11211"),
            Ok(Addr::Unknown("localhost:11211".to_string()))
        );
    }

    #[test]
    fn test_tcp_scheme() {
        assert_eq!(
            Addr::parse("tcp://localhost:11211"),
            Ok(Addr::Tcp("localhost:11211".to_string()))
        );
    }

    #[test]
    fn test_tcp_scheme_default_port() {
        // A `tcp://` DSN without an explicit port falls back to 11211.
        assert_eq!(
            Addr::parse("tcp://localhost"),
            Ok(Addr::Tcp("localhost:11211".to_string()))
        );
    }

    #[test]
    fn test_unix_scheme() {
        assert_eq!(
            Addr::parse("unix:///tmp/memcached.sock"),
            Ok(Addr::Unix("/tmp/memcached.sock".to_string()))
        );
    }

    #[test]
    fn test_invalid_url() {
        assert!(matches!(
            Addr::parse("tcp://").unwrap_err(),
            Error::Connect(_)
        ));
    }
}