// redis-driver 0.3.0
//
// Redis async driver for Rust
// Documentation
use crate::{Result};
#[cfg(feature = "tokio-runtime")]
use crate::Error;
#[cfg(feature = "tls")]
use crate::TlsConfig;
use futures::{Future, FutureExt};
use log::{debug, info};
use std::{
    pin::Pin,
    task::{Context, Poll},
    time::Duration,
};

/// Read half of a plain TCP connection when the `tokio-runtime` feature is enabled.
#[cfg(feature = "tokio-runtime")]
pub(crate) type TcpStreamReader = tokio::io::ReadHalf<tokio::net::TcpStream>;
/// Write half of a plain TCP connection when the `tokio-runtime` feature is enabled.
#[cfg(feature = "tokio-runtime")]
pub(crate) type TcpStreamWriter = tokio::io::WriteHalf<tokio::net::TcpStream>;
/// Read half of a `tokio_native_tls` stream layered over TCP (`tokio-tls` feature).
#[cfg(feature = "tokio-tls")]
pub(crate) type TcpTlsStreamReader =
    tokio::io::ReadHalf<tokio_native_tls::TlsStream<tokio::net::TcpStream>>;
/// Write half of a `tokio_native_tls` stream layered over TCP (`tokio-tls` feature).
#[cfg(feature = "tokio-tls")]
pub(crate) type TcpTlsStreamWriter =
    tokio::io::WriteHalf<tokio_native_tls::TlsStream<tokio::net::TcpStream>>;

/// Read half of a plain async-std TCP connection (`async-std-runtime` feature),
/// wrapped in `tokio_util::compat::Compat`.
#[cfg(feature = "async-std-runtime")]
pub(crate) type TcpStreamReader =
    tokio_util::compat::Compat<futures::io::ReadHalf<async_std::net::TcpStream>>;
/// Write half of a plain async-std TCP connection (`async-std-runtime` feature),
/// wrapped in `tokio_util::compat::Compat`.
#[cfg(feature = "async-std-runtime")]
pub(crate) type TcpStreamWriter =
    tokio_util::compat::Compat<futures::io::WriteHalf<async_std::net::TcpStream>>;
/// Read half of an `async_native_tls` stream over async-std TCP (`async-std-tls` feature),
/// wrapped in `tokio_util::compat::Compat`.
#[cfg(feature = "async-std-tls")]
pub(crate) type TcpTlsStreamReader = tokio_util::compat::Compat<
    futures::io::ReadHalf<async_native_tls::TlsStream<async_std::net::TcpStream>>,
>;
/// Write half of an `async_native_tls` stream over async-std TCP (`async-std-tls` feature),
/// wrapped in `tokio_util::compat::Compat`.
#[cfg(feature = "async-std-tls")]
pub(crate) type TcpTlsStreamWriter = tokio_util::compat::Compat<
    futures::io::WriteHalf<async_native_tls::TlsStream<async_std::net::TcpStream>>,
>;

/// Opens a plain TCP connection to `host:port` on the tokio runtime.
///
/// The connected stream is split with `tokio::io::split` and returned as
/// `(reader, writer)`. Progress is logged at `debug` before connecting and at
/// `info` once the stream has been split.
///
/// # Errors
///
/// Propagates the error from `tokio::net::TcpStream::connect` through `?`.
#[cfg(feature = "tokio-runtime")]
pub(crate) async fn tcp_connect(
    host: &str,
    port: u16,
) -> Result<(TcpStreamReader, TcpStreamWriter)> {
    debug!("Connecting to {host}:{port}...");

    let socket = tokio::net::TcpStream::connect((host, port)).await?;
    let halves = tokio::io::split(socket);

    info!("Connected to {host}:{port}");

    Ok(halves)
}

/// Opens a TLS-protected TCP connection to `host:port` on the tokio runtime.
///
/// A TCP stream is connected first, then a `native_tls` connector is built from
/// `tls_config` and used (through `tokio_native_tls`) to perform the handshake,
/// passing `host` as the domain. The TLS stream is split with `tokio::io::split`
/// and returned as `(reader, writer)`.
///
/// # Errors
///
/// Propagates, through `?`, a failure of the TCP connect, of building the TLS
/// connector, or of the TLS handshake.
#[cfg(all(feature = "tokio-runtime", feature = "tokio-tls", feature = "tls"))]
pub(crate) async fn tcp_tls_connect(
    host: &str,
    port: u16,
    tls_config: &TlsConfig,
) -> Result<(TcpTlsStreamReader, TcpTlsStreamWriter)> {
    debug!("Connecting to {host}:{port}...");

    let tcp = tokio::net::TcpStream::connect((host, port)).await?;

    // Build the synchronous native-tls connector, then adapt it for tokio.
    let native_connector: native_tls::TlsConnector =
        tls_config.into_tls_connector_builder().build()?;
    let connector = tokio_native_tls::TlsConnector::from(native_connector);

    let secured = connector.connect(host, tcp).await?;
    let halves = tokio::io::split(secured);

    info!("Connected to {host}:{port}");

    Ok(halves)
}

/// Opens a plain TCP connection to `host:port` on the async-std runtime.
///
/// The stream is split with `futures::AsyncReadExt::split`; each half is then
/// wrapped through `tokio_util::compat` (`compat()` for the reader,
/// `compat_write()` for the writer) and returned as `(reader, writer)`.
///
/// # Errors
///
/// Propagates the error from `async_std::net::TcpStream::connect` through `?`.
#[cfg(feature = "async-std-runtime")]
pub(crate) async fn tcp_connect(
    host: &str,
    port: u16,
) -> Result<(TcpStreamReader, TcpStreamWriter)> {
    use futures::AsyncReadExt;
    use tokio_util::compat::{FuturesAsyncReadCompatExt, FuturesAsyncWriteCompatExt};

    debug!("Connecting to {host}:{port}...");

    let socket = async_std::net::TcpStream::connect((host, port)).await?;
    let (read_half, write_half) = socket.split();
    let halves = (read_half.compat(), write_half.compat_write());

    info!("Connected to {host}:{port}");

    Ok(halves)
}

/// Opens a TLS-protected TCP connection to `host:port` on the async-std runtime.
///
/// A TCP stream is connected first, then the builder obtained from `tls_config`
/// is converted into an `async_native_tls::TlsConnector`, which performs the
/// handshake with `host` as the domain. The TLS stream is split and each half
/// is wrapped through `tokio_util::compat` before being returned as
/// `(reader, writer)`.
///
/// # Errors
///
/// Propagates, through `?`, a failure of the TCP connect or of the TLS handshake.
#[cfg(all(
    feature = "async-std-runtime",
    feature = "async-std-tls",
    feature = "tls"
))]
pub(crate) async fn tcp_tls_connect(
    host: &str,
    port: u16,
    tls_config: &TlsConfig,
) -> Result<(TcpTlsStreamReader, TcpTlsStreamWriter)> {
    use futures::AsyncReadExt;
    use tokio_util::compat::{FuturesAsyncReadCompatExt, FuturesAsyncWriteCompatExt};

    debug!("Connecting to {host}:{port}...");

    let tcp = async_std::net::TcpStream::connect((host, port)).await?;

    let connector: async_native_tls::TlsConnector =
        tls_config.into_tls_connector_builder().into();
    let secured = connector.connect(host, tcp).await?;

    let (read_half, write_half) = secured.split();
    let halves = (read_half.compat(), write_half.compat_write());

    info!("Connected to {host}:{port}");

    Ok(halves)
}

/// Handle to a spawned task, wrapping the join handle of whichever runtime
/// feature is enabled at compile time.
///
/// Each variant exists only when its runtime feature is enabled.
pub enum JoinHandle<T> {
    /// Join handle of a task spawned on tokio (`tokio-runtime` feature).
    #[cfg(feature = "tokio-runtime")]
    Tokio(tokio::task::JoinHandle<T>),
    /// Join handle of a task spawned on async-std (`async-std-runtime` feature).
    #[cfg(feature = "async-std-runtime")]
    AsyncStd(async_std::task::JoinHandle<T>),
}

/// Awaiting a [`JoinHandle`] yields the task's output wrapped in the crate's `Result`.
///
/// * tokio: a join error is converted to `Error::Client` with the message
///   `"JoinError: {e}"`; a successful join yields `Ok(output)`.
/// * async-std: the task's output is always wrapped in `Ok`.
impl<T> Future for JoinHandle<T> {
    type Output = Result<T>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        match self.get_mut() {
            #[cfg(feature = "tokio-runtime")]
            JoinHandle::Tokio(handle) => handle
                .poll_unpin(cx)
                .map(|joined| joined.map_err(|e| Error::Client(format!("JoinError: {e}")))),
            #[cfg(feature = "async-std-runtime")]
            JoinHandle::AsyncStd(handle) => handle.poll_unpin(cx).map(Ok),
        }
    }
}

/// Spawns `future` with `tokio::spawn` and returns its handle wrapped in
/// [`JoinHandle::Tokio`].
#[cfg(feature = "tokio-runtime")]
pub(crate) fn spawn<F, T>(future: F) -> JoinHandle<T>
where
    F: Future<Output = T> + Send + 'static,
    T: Send + 'static,
{
    let handle = tokio::spawn(future);
    JoinHandle::Tokio(handle)
}

/// Spawns `future` with `async_std::task::spawn` and returns its handle wrapped
/// in [`JoinHandle::AsyncStd`].
#[cfg(feature = "async-std-runtime")]
pub(crate) fn spawn<F, T>(future: F) -> JoinHandle<T>
where
    F: Future<Output = T> + Send + 'static,
    T: Send + 'static,
{
    let handle = async_std::task::spawn(future);
    JoinHandle::AsyncStd(handle)
}

/// Suspends the calling task for `duration` using `tokio::time::sleep`.
// NOTE(review): `dead_code` is allowed presumably because not every feature
// combination calls this helper — confirm against callers.
#[cfg(feature = "tokio-runtime")]
#[allow(dead_code)]
pub(crate) async fn sleep(duration: Duration) {
    let timer = tokio::time::sleep(duration);
    timer.await;
}

/// Suspends the calling task for `duration` using `async_std::task::sleep`.
// NOTE(review): `dead_code` is allowed presumably because not every feature
// combination calls this helper — confirm against callers.
#[cfg(feature = "async-std-runtime")]
#[allow(dead_code)]
pub(crate) async fn sleep(duration: Duration) {
    let timer = async_std::task::sleep(duration);
    timer.await;
}