use async_trait::async_trait;
use std::collections::VecDeque;
use std::io;
use std::mem;
use std::net::SocketAddr;
use std::net::ToSocketAddrs;
#[cfg(unix)]
use std::path::Path;
use std::pin::Pin;
use std::task::{self, Poll};
use combine::{parser::combinator::AnySendSyncPartialState, stream::PointerOffset};
use ::tokio::{
io::{AsyncRead, AsyncWrite, AsyncWriteExt},
sync::{mpsc, oneshot},
};
#[cfg(feature = "tls")]
use native_tls::TlsConnector;
#[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))]
use tokio_util::codec::Decoder;
use futures_util::{
future::{Future, FutureExt, TryFutureExt},
ready,
sink::Sink,
stream::{self, Stream, StreamExt, TryStreamExt as _},
};
use pin_project_lite::pin_project;
use crate::cmd::{cmd, Cmd};
use crate::connection::Msg;
use crate::connection::{ConnectionAddr, ConnectionInfo};
#[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))]
use crate::parser::ValueCodec;
use crate::types::{ErrorKind, RedisError, RedisFuture, RedisResult, Value};
use crate::{from_redis_value, ToRedisArgs};
#[cfg(feature = "async-std-comp")]
#[cfg_attr(docsrs, doc(cfg(feature = "async-std-comp")))]
pub mod async_std;
#[cfg(feature = "tokio-comp")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio-comp")))]
pub mod tokio;
/// Abstraction over an async runtime (tokio or async-std) that can open the
/// transports a Redis connection runs over and spawn background tasks.
///
/// Implementors are themselves the stream type (the trait requires
/// `AsyncStream`), so a value returned by one of the `connect_*` functions
/// is used directly as the connection's I/O.
#[async_trait]
pub(crate) trait RedisRuntime: AsyncStream + Send + Sync + Sized + 'static {
/// Opens a plain TCP connection to `socket_addr`.
async fn connect_tcp(socket_addr: SocketAddr) -> RedisResult<Self>;
/// Opens a TLS-wrapped TCP connection to `socket_addr`.
///
/// NOTE(review): `hostname` is presumably used for SNI / certificate
/// verification and `insecure` to skip that verification — confirm against
/// the runtime implementations.
#[cfg(feature = "tls")]
async fn connect_tcp_tls(
hostname: &str,
socket_addr: SocketAddr,
insecure: bool,
) -> RedisResult<Self>;
/// Opens a Unix domain socket connection to `path`.
#[cfg(unix)]
async fn connect_unix(path: &Path) -> RedisResult<Self>;
/// Spawns `f` as a background task on this runtime; no handle is returned.
fn spawn(f: impl Future<Output = ()> + Send + 'static);
/// Type-erases this stream into a pinned, boxed `AsyncStream` trait object.
fn boxed(self) -> Pin<Box<dyn AsyncStream + Send + Sync>> {
Box::pin(self)
}
}
/// The async runtime used for background work. Only the variants whose
/// cargo feature is enabled exist; see `Runtime::locate` for selection.
#[derive(Clone, Debug)]
pub(crate) enum Runtime {
/// The tokio runtime (`tokio-comp` feature).
#[cfg(feature = "tokio-comp")]
Tokio,
/// The async-std runtime (`async-std-comp` feature).
#[cfg(feature = "async-std-comp")]
AsyncStd,
}
impl Runtime {
    /// Chooses the runtime to use for background work.
    ///
    /// With exactly one runtime feature enabled, that runtime is returned.
    /// With both enabled, tokio is chosen when called from inside a tokio
    /// runtime context and async-std otherwise. Building with neither
    /// feature is a compile error.
    pub(crate) fn locate() -> Self {
        #[cfg(all(feature = "tokio-comp", feature = "async-std-comp"))]
        {
            match ::tokio::runtime::Handle::try_current() {
                Ok(_) => Runtime::Tokio,
                Err(_) => Runtime::AsyncStd,
            }
        }
        #[cfg(all(feature = "tokio-comp", not(feature = "async-std-comp")))]
        {
            Runtime::Tokio
        }
        #[cfg(all(feature = "async-std-comp", not(feature = "tokio-comp")))]
        {
            Runtime::AsyncStd
        }
        #[cfg(not(any(feature = "tokio-comp", feature = "async-std-comp")))]
        {
            compile_error!("tokio-comp or async-std-comp features required for aio feature")
        }
    }

    /// Spawns `f` as a background task on the selected runtime.
    // In this file the only caller lives behind the `connection-manager`
    // feature, so without it this method is unused.
    #[allow(dead_code)]
    fn spawn(&self, f: impl Future<Output = ()> + Send + 'static) {
        match *self {
            #[cfg(feature = "tokio-comp")]
            Runtime::Tokio => tokio::Tokio::spawn(f),
            #[cfg(feature = "async-std-comp")]
            Runtime::AsyncStd => async_std::AsyncStd::spawn(f),
        }
    }
}
/// A bidirectional async byte stream (`AsyncRead + AsyncWrite`), usable as a
/// single trait object such as `Pin<Box<dyn AsyncStream + Send + Sync>>`.
pub trait AsyncStream: AsyncRead + AsyncWrite {}

/// Every type that is both `AsyncRead` and `AsyncWrite` is an `AsyncStream`.
impl<T: AsyncRead + AsyncWrite> AsyncStream for T {}
/// A connection in Redis pub/sub mode, wrapping a `Connection`.
///
/// Created by `Connection::into_pubsub`; `PubSub::into_connection` leaves
/// pub/sub mode and hands the plain connection back.
pub struct PubSub<C = Pin<Box<dyn AsyncStream + Send + Sync>>>(Connection<C>);
impl<C> PubSub<C>
where
    C: Unpin + AsyncRead + AsyncWrite + Send,
{
    fn new(con: Connection<C>) -> Self {
        Self(con)
    }

    /// Sends one pub/sub management command (`SUBSCRIBE`, `PSUBSCRIBE`, ...)
    /// with `arg` and awaits its reply, whose content is ignored.
    async fn send_subscription_command<T: ToRedisArgs>(
        &mut self,
        command: &str,
        arg: T,
    ) -> RedisResult<()> {
        cmd(command).arg(arg).query_async(&mut self.0).await
    }

    /// Subscribes to `channel` by sending `SUBSCRIBE`.
    ///
    /// `channel` is anything implementing `ToRedisArgs`; every argument it
    /// expands to is sent with the command.
    pub async fn subscribe<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
        self.send_subscription_command("SUBSCRIBE", channel).await
    }

    /// Subscribes to the pattern(s) in `pchannel` by sending `PSUBSCRIBE`.
    pub async fn psubscribe<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
        self.send_subscription_command("PSUBSCRIBE", pchannel).await
    }

    /// Unsubscribes from `channel` by sending `UNSUBSCRIBE`.
    pub async fn unsubscribe<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
        self.send_subscription_command("UNSUBSCRIBE", channel).await
    }

    /// Unsubscribes from the pattern(s) in `pchannel` by sending
    /// `PUNSUBSCRIBE`.
    pub async fn punsubscribe<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
        self.send_subscription_command("PUNSUBSCRIBE", pchannel).await
    }

    /// Returns a stream of messages read from the connection, borrowing it
    /// for the stream's lifetime.
    ///
    /// Values that fail to decode, and values for which `Msg::from_value`
    /// returns `None`, are silently skipped.
    pub fn on_message<'a>(&'a mut self) -> impl Stream<Item = Msg> + 'a {
        ValueCodec::default()
            .framed(&mut self.0.con)
            .into_stream()
            .filter_map(|msg| Box::pin(async move { Msg::from_value(&msg.ok()?) }))
    }

    /// Like `on_message`, but consumes the connection.
    pub fn into_on_message(self) -> impl Stream<Item = Msg> {
        ValueCodec::default()
            .framed(self.0.con)
            .into_stream()
            .filter_map(|msg| Box::pin(async move { Msg::from_value(&msg.ok()?) }))
    }

    /// Leaves pub/sub mode and returns the plain connection.
    ///
    /// Unsubscribing is best effort here. On failure the connection keeps its
    /// `pubsub` flag set and retries the unsubscribe before its next request.
    pub async fn into_connection(mut self) -> Connection<C> {
        self.0.exit_pubsub().await.ok();
        self.0
    }
}
/// An async Redis connection over the stream `C`.
///
/// Requests take `&mut self` and write a command, then read its reply, so
/// they run one at a time. See `MultiplexedConnection` for a shareable
/// alternative.
pub struct Connection<C = Pin<Box<dyn AsyncStream + Send + Sync>>> {
// The underlying transport.
con: C,
// Scratch buffer reused to pack outgoing commands.
buf: Vec<u8>,
// Incremental parser state for replies read from `con`.
decoder: combine::stream::Decoder<AnySendSyncPartialState, PointerOffset<[u8]>>,
// Database index taken from the connection info (reported by `get_db`).
db: i64,
// Set when leaving pub/sub mode failed. The next request first retries the
// unsubscribe.
pubsub: bool,
}
fn assert_sync<T: Sync>() {}
#[allow(unused)]
fn test() {
assert_sync::<Connection>();
}
impl<C> Connection<C> {
    /// Replaces the underlying stream with `f(con)`. The write buffer,
    /// parser state, database index and pub/sub flag are kept.
    pub(crate) fn map<D>(self, f: impl FnOnce(C) -> D) -> Connection<D> {
        let mapped = f(self.con);
        Connection {
            con: mapped,
            buf: self.buf,
            decoder: self.decoder,
            db: self.db,
            pubsub: self.pubsub,
        }
    }
}
impl<C> Connection<C>
where
C: Unpin + AsyncRead + AsyncWrite + Send,
{
pub(crate) async fn new(connection_info: &ConnectionInfo, con: C) -> RedisResult<Self> {
let mut rv = Connection {
con,
buf: Vec::new(),
decoder: combine::stream::Decoder::new(),
db: connection_info.db,
pubsub: false,
};
authenticate(connection_info, &mut rv).await?;
Ok(rv)
}
pub fn into_pubsub(self) -> PubSub<C> {
PubSub::new(self)
}
async fn read_response(&mut self) -> RedisResult<Value> {
crate::parser::parse_redis_value_async(&mut self.decoder, &mut self.con).await
}
async fn exit_pubsub(&mut self) -> RedisResult<()> {
let res = self.clear_active_subscriptions().await;
if res.is_ok() {
self.pubsub = false;
} else {
self.pubsub = true;
}
res
}
async fn clear_active_subscriptions(&mut self) -> RedisResult<()> {
{
let unsubscribe = crate::Pipeline::new()
.add_command(cmd("UNSUBSCRIBE"))
.add_command(cmd("PUNSUBSCRIBE"))
.get_packed_pipeline();
self.con.write_all(&unsubscribe).await?;
}
let mut received_unsub = false;
let mut received_punsub = false;
loop {
let res: (Vec<u8>, (), isize) = from_redis_value(&self.read_response().await?)?;
match res.0.first() {
Some(&b'u') => received_unsub = true,
Some(&b'p') => received_punsub = true,
_ => (),
}
if received_unsub && received_punsub && res.2 == 0 {
break;
}
}
Ok(())
}
}
/// Opens a transport for `connection_info` with runtime `C`, then runs the
/// connection handshake (`AUTH` / `SELECT`) over it.
pub(crate) async fn connect<C>(connection_info: &ConnectionInfo) -> RedisResult<Connection<C>>
where
    C: Unpin + RedisRuntime + AsyncRead + AsyncWrite + Send,
{
    let stream: C = connect_simple(connection_info).await?;
    Connection::new(connection_info, stream).await
}
async fn authenticate<C>(connection_info: &ConnectionInfo, con: &mut C) -> RedisResult<()>
where
C: ConnectionLike,
{
if let Some(passwd) = &connection_info.passwd {
let mut command = cmd("AUTH");
if let Some(username) = &connection_info.username {
command.arg(username);
}
match command.arg(passwd).query_async(con).await {
Ok(Value::Okay) => (),
Err(e) => {
let err_msg = e.detail().ok_or((
ErrorKind::AuthenticationFailed,
"Password authentication failed",
))?;
if !err_msg.contains("wrong number of arguments for 'auth' command") {
fail!((
ErrorKind::AuthenticationFailed,
"Password authentication failed",
));
}
let mut command = cmd("AUTH");
match command.arg(passwd).query_async(con).await {
Ok(Value::Okay) => (),
_ => {
fail!((
ErrorKind::AuthenticationFailed,
"Password authentication failed"
));
}
}
}
_ => {
fail!((
ErrorKind::AuthenticationFailed,
"Password authentication failed"
));
}
}
}
if connection_info.db != 0 {
match cmd("SELECT").arg(connection_info.db).query_async(con).await {
Ok(Value::Okay) => (),
_ => fail!((
ErrorKind::ResponseError,
"Redis server refused to switch database"
)),
}
}
Ok(())
}
/// Opens the raw transport described by `connection_info.addr` with runtime
/// `T`. No Redis commands are sent.
///
/// # Errors
/// `InvalidClientConfig` when this build cannot handle the address kind (TLS
/// without the `tls` feature, or Unix sockets on a non-Unix platform) or
/// when the host resolves to no address. Otherwise, whatever the runtime's
/// connect function returns.
pub(crate) async fn connect_simple<T: RedisRuntime>(
    connection_info: &ConnectionInfo,
) -> RedisResult<T> {
    let stream = match *connection_info.addr {
        ConnectionAddr::Tcp(ref host, port) => {
            let addr = get_socket_addrs(host, port)?;
            T::connect_tcp(addr).await?
        }
        #[cfg(feature = "tls")]
        ConnectionAddr::TcpTls {
            ref host,
            port,
            insecure,
        } => {
            let addr = get_socket_addrs(host, port)?;
            T::connect_tcp_tls(host, addr, insecure).await?
        }
        #[cfg(not(feature = "tls"))]
        ConnectionAddr::TcpTls { .. } => fail!((
            ErrorKind::InvalidClientConfig,
            "Cannot connect to TCP with TLS without the tls feature"
        )),
        #[cfg(unix)]
        ConnectionAddr::Unix(ref path) => T::connect_unix(path).await?,
        #[cfg(not(unix))]
        ConnectionAddr::Unix(_) => fail!((
            ErrorKind::InvalidClientConfig,
            "Cannot connect to unix sockets on this platform"
        )),
    };
    Ok(stream)
}
/// Resolves `host:port` and returns the first address found.
///
/// # Errors
/// Propagates resolver I/O errors. Returns `InvalidClientConfig` if
/// resolution yields no addresses.
fn get_socket_addrs(host: &str, port: u16) -> RedisResult<SocketAddr> {
    (host, port)
        .to_socket_addrs()?
        .next()
        .ok_or_else(|| {
            RedisError::from((ErrorKind::InvalidClientConfig, "No address found for host"))
        })
}
/// An async connection that can execute packed Redis commands and pipelines.
///
/// Implemented in this file by `Connection`, `MultiplexedConnection` and,
/// with the `connection-manager` feature, `ConnectionManager`.
pub trait ConnectionLike {
/// Sends the packed form of `cmd` and returns its single reply.
fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value>;
/// Sends the packed form of the pipeline `cmd`. Returns the `count`
/// replies that follow the first `offset` replies, which are discarded.
fn req_packed_commands<'a>(
&'a mut self,
cmd: &'a crate::Pipeline,
offset: usize,
count: usize,
) -> RedisFuture<'a, Vec<Value>>;
/// Returns the database index this connection was configured with.
fn get_db(&self) -> i64;
}
impl<C> ConnectionLike for Connection<C>
where
    C: Unpin + AsyncRead + AsyncWrite + Send,
{
    /// Writes `cmd` and reads back its reply.
    ///
    /// If an earlier attempt to leave pub/sub mode failed, it is retried
    /// first so that ordinary replies are not mixed with pub/sub frames.
    fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> {
        (async move {
            if self.pubsub {
                self.exit_pubsub().await?;
            }
            self.buf.clear();
            cmd.write_packed_command(&mut self.buf);
            self.con.write_all(&self.buf).await?;
            self.read_response().await
        })
        .boxed()
    }

    /// Writes the pipeline, then reads `offset + count` replies and returns
    /// the last `count` of them.
    ///
    /// Every reply is consumed even after an error reply, so the stream stays
    /// aligned with the next request. The first error is returned once all
    /// replies are read. An I/O error aborts immediately because the stream
    /// is unusable at that point.
    fn req_packed_commands<'a>(
        &'a mut self,
        cmd: &'a crate::Pipeline,
        offset: usize,
        count: usize,
    ) -> RedisFuture<'a, Vec<Value>> {
        (async move {
            if self.pubsub {
                self.exit_pubsub().await?;
            }
            self.buf.clear();
            cmd.write_packed_pipeline(&mut self.buf);
            self.con.write_all(&self.buf).await?;

            let mut first_err = None;
            let mut rv = Vec::with_capacity(count);
            for idx in 0..offset + count {
                match self.read_response().await {
                    Ok(item) => {
                        if idx >= offset {
                            rv.push(item);
                        }
                    }
                    Err(err) => {
                        if err.is_io_error() {
                            return Err(err);
                        }
                        if first_err.is_none() {
                            first_err = Some(err);
                        }
                    }
                }
            }

            match first_err {
                Some(err) => Err(err),
                None => Ok(rv),
            }
        })
        .boxed()
    }

    fn get_db(&self) -> i64 {
        self.db
    }
}
/// One-shot channel on which the pipeline driver delivers a request's
/// replies, or its first error.
type PipelineOutput<O, E> = oneshot::Sender<Result<Vec<O>, E>>;

/// A request that has been written to the sink and is waiting for replies.
struct InFlight<O, E> {
    /// Where the collected result is delivered.
    output: PipelineOutput<O, E>,
    /// How many replies this request expects.
    response_count: usize,
    /// How many replies (successful or error) have arrived so far.
    received: usize,
    /// Successful replies collected so far, in arrival order.
    buffer: Vec<O>,
    /// First error reply. Once every expected reply has been consumed, it is
    /// reported instead of `buffer`.
    first_err: Option<E>,
}

/// A request submitted to the driver through a `Pipeline` handle.
struct PipelineMessage<S, I, E> {
    /// Encoded request to write to the sink.
    input: S,
    /// Where the result is delivered.
    output: PipelineOutput<I, E>,
    /// Number of replies the request expects.
    response_count: usize,
}

/// Cloneable handle for submitting requests to a pipeline driver.
struct Pipeline<SinkItem, I, E>(mpsc::Sender<PipelineMessage<SinkItem, I, E>>);

impl<SinkItem, I, E> Clone for Pipeline<SinkItem, I, E> {
    // Written by hand because `#[derive(Clone)]` would require
    // `SinkItem`, `I` and `E` to be `Clone` as well.
    fn clone(&self) -> Self {
        Pipeline(self.0.clone())
    }
}

// Adapter that writes requests to `sink_stream` and routes the replies it
// reads back, in order, to the matching in-flight request.
pin_project! {
    struct PipelineSink<T, I, E> {
        #[pin]
        sink_stream: T,
        // Requests awaiting replies, oldest first.
        in_flight: VecDeque<InFlight<I, E>>,
        // Error from `poll_ready`, reported to the next submitted request.
        error: Option<E>,
    }
}

impl<T, I, E> PipelineSink<T, I, E>
where
    T: Stream<Item = Result<I, E>> + 'static,
{
    fn new<SinkItem>(sink_stream: T) -> Self
    where
        T: Sink<SinkItem, Error = E> + Stream<Item = Result<I, E>> + 'static,
    {
        PipelineSink {
            sink_stream,
            in_flight: VecDeque::new(),
            error: None,
        }
    }

    /// Routes every reply currently available on the stream to the in-flight
    /// queue.
    ///
    /// Returns `Pending` once no further reply is ready and `Ready(Err(()))`
    /// when the stream ends. It never returns `Ready(Ok(()))`.
    fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Result<(), ()>> {
        loop {
            let item = match ready!(self.as_mut().project().sink_stream.poll_next(cx)) {
                Some(result) => result,
                None => return Poll::Ready(Err(())),
            };
            self.as_mut().send_result(item);
        }
    }

    /// Records one reply for the oldest in-flight request. The request's
    /// result is delivered once all of its expected replies have arrived.
    ///
    /// An error reply does not complete a multi-reply request early. Its
    /// remaining replies still belong to it, and completing early would hand
    /// them to the next request. The first error is kept and reported once
    /// the request's replies are all consumed. Replies that arrive with
    /// nothing in flight are dropped.
    fn send_result(self: Pin<&mut Self>, result: Result<I, E>) {
        let self_ = self.project();
        let response = {
            let entry = match self_.in_flight.front_mut() {
                Some(entry) => entry,
                None => return,
            };
            match result {
                Ok(item) => entry.buffer.push(item),
                Err(err) => {
                    if entry.first_err.is_none() {
                        entry.first_err = Some(err);
                    }
                }
            }
            entry.received += 1;
            if entry.received < entry.response_count {
                return;
            }
            match entry.first_err.take() {
                Some(err) => Err(err),
                None => Ok(mem::take(&mut entry.buffer)),
            }
        };
        let entry = self_
            .in_flight
            .pop_front()
            .expect("the front entry was present above");
        // A send error only means the requester stopped waiting.
        entry.output.send(response).ok();
    }
}

impl<SinkItem, T, I, E> Sink<PipelineMessage<SinkItem, I, E>> for PipelineSink<T, I, E>
where
    T: Sink<SinkItem, Error = E> + Stream<Item = Result<I, E>> + 'static,
{
    type Error = ();

    // A readiness error does not fail this sink. It is stashed and reported
    // to the next request passed to `start_send`.
    fn poll_ready(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context,
    ) -> Poll<Result<(), Self::Error>> {
        match ready!(self.as_mut().project().sink_stream.poll_ready(cx)) {
            Ok(()) => Ok(()).into(),
            Err(err) => {
                *self.project().error = Some(err);
                Ok(()).into()
            }
        }
    }

    fn start_send(
        mut self: Pin<&mut Self>,
        PipelineMessage {
            input,
            output,
            response_count,
        }: PipelineMessage<SinkItem, I, E>,
    ) -> Result<(), Self::Error> {
        let self_ = self.as_mut().project();
        // Report a readiness error stashed by `poll_ready` to this request.
        if let Some(err) = self_.error.take() {
            let _ = output.send(Err(err));
            return Err(());
        }
        match self_.sink_stream.start_send(input) {
            Ok(()) => {
                self_.in_flight.push_back(InFlight {
                    output,
                    response_count,
                    received: 0,
                    buffer: Vec::new(),
                    first_err: None,
                });
                Ok(())
            }
            Err(err) => {
                let _ = output.send(Err(err));
                Err(())
            }
        }
    }

    // Flushes pending writes, then reads replies. `poll_read` only completes
    // when the stream ends, so this keeps the driver reading for as long as
    // the connection is open.
    fn poll_flush(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context,
    ) -> Poll<Result<(), Self::Error>> {
        ready!(self
            .as_mut()
            .project()
            .sink_stream
            .poll_flush(cx)
            .map_err(|err| {
                self.as_mut().send_result(Err(err));
            }))?;
        self.poll_read(cx)
    }

    // If requests are still in flight, keeps flushing and reading before
    // closing the underlying sink.
    fn poll_close(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context,
    ) -> Poll<Result<(), Self::Error>> {
        if !self.in_flight.is_empty() {
            ready!(self.as_mut().poll_flush(cx))?;
        }
        let this = self.as_mut().project();
        this.sink_stream.poll_close(cx).map_err(|err| {
            self.send_result(Err(err));
        })
    }
}
impl<SinkItem, I, E> Pipeline<SinkItem, I, E>
where
    SinkItem: Send + 'static,
    I: Send + 'static,
    E: Send + 'static,
{
    /// Creates a pipeline over `sink_stream`.
    ///
    /// Returns the submission handle and the driver future. The driver must
    /// be polled (for example, spawned) for any request to make progress. Up
    /// to `BUFFER_SIZE` requests can wait in the channel ahead of the driver.
    fn new<T>(sink_stream: T) -> (Self, impl Future<Output = ()>)
    where
        T: Sink<SinkItem, Error = E> + Stream<Item = Result<I, E>> + 'static,
        T: Send + 'static,
        T::Item: Send,
        T::Error: Send,
        T::Error: ::std::fmt::Debug,
    {
        const BUFFER_SIZE: usize = 50;
        let (sender, mut receiver) = mpsc::channel(BUFFER_SIZE);
        let sink = PipelineSink::new::<SinkItem>(sink_stream);
        let driver = stream::poll_fn(move |cx| receiver.poll_recv(cx))
            .map(Ok)
            .forward(sink)
            .map(|_| ());
        (Pipeline(sender), driver)
    }

    /// Submits `item`, which expects exactly one reply, and returns that
    /// reply. Errors are the same as for `send_recv_multiple`.
    async fn send(&mut self, item: SinkItem) -> Result<I, Option<E>> {
        let mut replies = self.send_recv_multiple(item, 1).await?;
        Ok(replies.pop().unwrap())
    }

    /// Submits `input` and waits for its `count` replies.
    ///
    /// # Errors
    /// `Err(Some(e))` carries the error produced for this request.
    /// `Err(None)` means the driver is gone: the request could not be queued,
    /// or its result channel was dropped before a result was sent.
    async fn send_recv_multiple(
        &mut self,
        input: SinkItem,
        count: usize,
    ) -> Result<Vec<I>, Option<E>> {
        let (sender, receiver) = oneshot::channel();
        let message = PipelineMessage {
            input,
            response_count: count,
            output: sender,
        };
        if self.0.send(message).await.is_err() {
            return Err(None);
        }
        match receiver.await {
            Ok(result) => result.map_err(Some),
            Err(_) => Err(None),
        }
    }
}
/// A connection whose clones all share one underlying stream, with requests
/// multiplexed over it.
///
/// Requests are queued to a driver future, returned by
/// `MultiplexedConnection::new`, which must be polled for requests to
/// complete. Cloning is cheap: it clones a channel sender and the database
/// index.
#[derive(Clone)]
pub struct MultiplexedConnection {
// Handle for submitting requests to the driver that owns the stream.
pipeline: Pipeline<Vec<u8>, Value, RedisError>,
// Database index taken from the connection info (reported by `get_db`).
db: i64,
}
impl MultiplexedConnection {
pub(crate) async fn new<C>(
connection_info: &ConnectionInfo,
con: C,
) -> RedisResult<(Self, impl Future<Output = ()>)>
where
C: Unpin + AsyncRead + AsyncWrite + Send + 'static,
{
fn boxed(
f: impl Future<Output = ()> + Send + 'static,
) -> Pin<Box<dyn Future<Output = ()> + Send>> {
Box::pin(f)
}
#[cfg(all(not(feature = "tokio-comp"), not(feature = "async-std-comp")))]
compile_error!("tokio-comp or async-std-comp features required for aio feature");
let codec = ValueCodec::default().framed(con);
let (pipeline, driver) = Pipeline::new(codec);
let driver = boxed(driver);
let mut con = MultiplexedConnection {
pipeline,
db: connection_info.db,
};
let driver = {
let auth = authenticate(connection_info, &mut con);
futures_util::pin_mut!(auth);
match futures_util::future::select(auth, driver).await {
futures_util::future::Either::Left((result, driver)) => {
result?;
driver
}
futures_util::future::Either::Right(((), _)) => {
unreachable!("Multiplexed connection driver unexpectedly terminated")
}
}
};
Ok((con, driver))
}
}
impl ConnectionLike for MultiplexedConnection {
fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> {
(async move {
let value = self
.pipeline
.send(cmd.get_packed_command())
.await
.map_err(|err| {
err.unwrap_or_else(|| {
RedisError::from(io::Error::from(io::ErrorKind::BrokenPipe))
})
})?;
Ok(value)
})
.boxed()
}
fn req_packed_commands<'a>(
&'a mut self,
cmd: &'a crate::Pipeline,
offset: usize,
count: usize,
) -> RedisFuture<'a, Vec<Value>> {
(async move {
let mut value = self
.pipeline
.send_recv_multiple(cmd.get_packed_pipeline(), offset + count)
.await
.map_err(|err| {
err.unwrap_or_else(|| {
RedisError::from(io::Error::from(io::ErrorKind::BrokenPipe))
})
})?;
value.drain(..offset);
Ok(value)
})
.boxed()
}
fn get_db(&self) -> i64 {
self.db
}
}
#[cfg(feature = "connection-manager")]
mod connection_manager {
    use super::*;

    use std::sync::Arc;

    use arc_swap::{self, ArcSwap};
    use futures::future::{self, Shared};
    use futures_util::future::BoxFuture;

    use crate::Client;

    /// A cloneable wrapper around a `MultiplexedConnection` that installs a
    /// fresh connection after I/O or dropped-connection errors.
    ///
    /// Reconnection is lazy. The failing request still returns its error,
    /// and a new connection attempt is started in the background for later
    /// requests to use.
    #[derive(Clone)]
    pub struct ConnectionManager {
        // Used to open replacement connections.
        client: Client,
        // Current (possibly still connecting) connection. It is shared by all
        // clones and swapped atomically on reconnect.
        connection: Arc<ArcSwap<SharedRedisFuture<MultiplexedConnection>>>,
        // Runtime that drives reconnection attempts in the background.
        runtime: Runtime,
    }

    // `Shared` futures need a cloneable output, so the error is behind an `Arc`.
    type CloneableRedisResult<T> = Result<T, Arc<RedisError>>;
    type SharedRedisFuture<T> = Shared<BoxFuture<'static, CloneableRedisResult<T>>>;

    impl ConnectionManager {
        /// Connects with `client` and returns a manager around the new
        /// multiplexed connection. The background runtime is chosen via
        /// `Runtime::locate`.
        pub async fn new(client: Client) -> RedisResult<Self> {
            let runtime = Runtime::locate();
            let connection = client.get_multiplexed_async_connection().await?;
            Ok(Self {
                client,
                connection: Arc::new(ArcSwap::from_pointee(
                    future::ok(connection).boxed().shared(),
                )),
                runtime,
            })
        }

        /// Starts replacing the connection, unless another caller already has.
        ///
        /// `current` is the guard for the connection the caller saw fail. The
        /// swap only happens if that connection is still installed, so
        /// concurrent failures trigger a single reconnect.
        fn reconnect(
            &self,
            current: arc_swap::Guard<Arc<SharedRedisFuture<MultiplexedConnection>>>,
        ) {
            let client = self.client.clone();
            let new_connection: SharedRedisFuture<MultiplexedConnection> =
                async move { Ok(client.get_multiplexed_async_connection().await?) }
                    .boxed()
                    .shared();
            let new_connection_arc = Arc::new(new_connection.clone());
            let prev = self
                .connection
                .compare_and_swap(&current, new_connection_arc);
            if Arc::ptr_eq(&prev, &current) {
                // This caller won the swap. Drive the connection attempt in
                // the background so it progresses before anyone awaits it.
                self.runtime.spawn(new_connection.map(|_| ()));
            }
        }
    }

    // Triggers a reconnect when a request failed because the connection
    // was dropped; the result itself is left for the caller to return.
    macro_rules! reconnect_if_dropped {
        ($self:expr, $result:expr, $current:expr) => {
            if let Err(ref e) = $result {
                if e.is_connection_dropped() {
                    $self.reconnect($current);
                }
            }
        };
    }

    // On an error obtaining the connection, triggers a reconnect for I/O
    // errors and returns the error from the enclosing async block.
    macro_rules! reconnect_if_io_error {
        ($self:expr, $result:expr, $current:expr) => {
            if let Err(e) = $result {
                if e.is_io_error() {
                    $self.reconnect($current);
                }
                return Err(e);
            }
        };
    }

    impl ConnectionLike for ConnectionManager {
        fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> {
            (async move {
                let guard = self.connection.load();
                let connection_result = (**guard)
                    .clone()
                    .await
                    .map_err(|e| e.clone_mostly("Reconnecting failed"));
                reconnect_if_io_error!(self, connection_result, guard);
                let result = connection_result?.req_packed_command(cmd).await;
                reconnect_if_dropped!(self, &result, guard);
                result
            })
            .boxed()
        }

        fn req_packed_commands<'a>(
            &'a mut self,
            cmd: &'a crate::Pipeline,
            offset: usize,
            count: usize,
        ) -> RedisFuture<'a, Vec<Value>> {
            (async move {
                let guard = self.connection.load();
                let connection_result = (**guard)
                    .clone()
                    .await
                    .map_err(|e| e.clone_mostly("Reconnecting failed"));
                reconnect_if_io_error!(self, connection_result, guard);
                let result = connection_result?
                    .req_packed_commands(cmd, offset, count)
                    .await;
                reconnect_if_dropped!(self, &result, guard);
                result
            })
            .boxed()
        }

        fn get_db(&self) -> i64 {
            self.client.connection_info().db
        }
    }
}
#[cfg(feature = "connection-manager")]
pub use connection_manager::ConnectionManager;