use std::collections::VecDeque;
use std::fmt::Arguments;
use std::io::{self, BufReader, Read, Write};
use std::mem;
use std::net::ToSocketAddrs;
#[cfg(unix)]
use tokio_uds::UnixStream;
use tokio_codec::{Decoder, Framed};
use tokio_io::{self, AsyncWrite};
use tokio_tcp::TcpStream;
use futures::future::{Either, Executor};
use futures::{future, try_ready, Async, AsyncSink, Future, Poll, Sink, StartSend, Stream};
use tokio_sync::{mpsc, oneshot};
use crate::cmd::cmd;
use crate::types::{ErrorKind, RedisError, RedisFuture, Value};
use crate::connection::{ConnectionAddr, ConnectionInfo};
use crate::parser::ValueCodec;
/// Transport underneath a [`Connection`].
///
/// Each variant keeps its stream wrapped in a `BufReader`; reads go through the buffer while
/// writes are routed to the raw stream via `WriteWrapper`.
enum ActualConnection {
    /// A TCP socket.
    Tcp(BufReader<TcpStream>),
    /// A Unix domain socket (only compiled on Unix targets).
    #[cfg(unix)]
    Unix(BufReader<UnixStream>),
}
/// Write adapter over a `BufReader`-wrapped stream.
///
/// Every write is forwarded straight to the inner stream obtained from `BufReader::get_mut`,
/// bypassing the reader's buffer (which only ever holds data that has been *read*).
struct WriteWrapper<T>(BufReader<T>);

impl<T> WriteWrapper<T>
where
    T: Read,
{
    /// Mutable access to the raw stream underneath the `BufReader`.
    fn stream(&mut self) -> &mut T {
        let WriteWrapper(reader) = self;
        reader.get_mut()
    }
}

impl<T> Write for WriteWrapper<T>
where
    T: Read + Write,
{
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.stream().write(buf)
    }

    fn flush(&mut self) -> io::Result<()> {
        self.stream().flush()
    }

    fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
        self.stream().write_all(buf)
    }

    fn write_fmt(&mut self, fmt: Arguments<'_>) -> io::Result<()> {
        self.stream().write_fmt(fmt)
    }
}
impl<T> AsyncWrite for WriteWrapper<T>
where
T: Read + AsyncWrite,
{
fn shutdown(&mut self) -> Poll<(), io::Error> {
self.0.get_mut().shutdown()
}
}
/// An asynchronous, exclusively-owned connection to a Redis server.
///
/// Request methods consume the connection and hand it back together with the reply, so only one
/// request can be outstanding at a time. [`SharedConnection`] multiplexes requests over one socket.
pub struct Connection {
    /// The underlying transport.
    con: ActualConnection,
    /// Database index configured when the connection was opened.
    db: i64,
}
// Runs `$f` on the `BufReader`-wrapped stream held by an `ActualConnection`, then re-wraps the
// stream that `$f`'s future hands back into the same variant.
//
// `$f` takes the stream by value and must return a future yielding `(stream, value)`; the macro
// evaluates to a future yielding `(ActualConnection, value)`.
macro_rules! with_connection {
($con:expr, $f:expr) => {
match $con {
// Only one transport exists on non-Unix targets, so no `Either` is needed.
#[cfg(not(unix))]
ActualConnection::Tcp(con) => {
$f(con).map(|(con, value)| (ActualConnection::Tcp(con), value))
}
// On Unix the TCP and Unix futures have different types; `Either` unifies them.
#[cfg(unix)]
ActualConnection::Tcp(con) => {
Either::A($f(con).map(|(con, value)| (ActualConnection::Tcp(con), value)))
}
#[cfg(unix)]
ActualConnection::Unix(con) => {
Either::B($f(con).map(|(con, value)| (ActualConnection::Unix(con), value)))
}
}
};
}
// Like `with_connection!`, but hands `$f` the stream wrapped in a `WriteWrapper` so it can be
// written to. When `$f`'s future completes, the wrapper is unwrapped (`.0`) and the stream is
// put back into the matching `ActualConnection` variant.
macro_rules! with_write_connection {
($con:expr, $f:expr) => {
match $con {
// Only one transport exists on non-Unix targets, so no `Either` is needed.
#[cfg(not(unix))]
ActualConnection::Tcp(con) => {
$f(WriteWrapper(con)).map(|(con, value)| (ActualConnection::Tcp(con.0), value))
}
// On Unix the TCP and Unix futures have different types; `Either` unifies them.
#[cfg(unix)]
ActualConnection::Tcp(con) => Either::A(
$f(WriteWrapper(con)).map(|(con, value)| (ActualConnection::Tcp(con.0), value)),
),
#[cfg(unix)]
ActualConnection::Unix(con) => Either::B(
$f(WriteWrapper(con)).map(|(con, value)| (ActualConnection::Unix(con.0), value)),
),
}
};
}
impl Connection {
    /// Reads a single reply from the server.
    ///
    /// The connection (with its database index preserved) is handed back alongside the parsed
    /// value so it can be reused for the next request. Parse/I/O errors are passed through.
    pub fn read_response(self) -> impl Future<Item = (Self, Value), Error = RedisError> {
        let Connection { con, db } = self;
        with_connection!(con, crate::parser::parse_redis_value_async)
            .map(move |(con, value)| (Connection { con, db }, value))
    }
}
/// Opens a new asynchronous connection described by `connection_info`.
///
/// For TCP addresses the host is resolved and only the first resulting socket address is
/// used. Unix socket addresses are supported only on Unix targets. After the transport is up,
/// `AUTH` is sent when a password is configured, and `SELECT` is sent when `db` is non-zero.
///
/// # Errors
///
/// * `InvalidClientConfig` if the host resolves to no address, or if a Unix socket is requested
///   on a non-Unix platform.
/// * I/O errors from address resolution or from connecting.
/// * `AuthenticationFailed` (raised via `fail!`) if `AUTH` does not answer `OK`; the underlying
///   error, if any, is discarded.
/// * `ResponseError` (raised via `fail!`) if `SELECT` does not answer `OK`.
///
/// NOTE(review): `to_socket_addrs` is std's synchronous resolver and may block the calling
/// thread (and therefore the executor) during DNS lookup. Only the first resolved address is
/// ever tried. Consider an async resolver and falling back to subsequent addresses.
pub fn connect(
connection_info: ConnectionInfo,
) -> impl Future<Item = Connection, Error = RedisError> {
let connection = match *connection_info.addr {
ConnectionAddr::Tcp(ref host, port) => {
// Resolve the host; only the first address is used.
let socket_addr = match (&host[..], port).to_socket_addrs() {
Ok(mut socket_addrs) => match socket_addrs.next() {
Some(socket_addr) => socket_addr,
None => {
return Either::A(future::err(RedisError::from((
ErrorKind::InvalidClientConfig,
"No address found for host",
))));
}
},
Err(err) => return Either::A(future::err(err.into())),
};
Either::A(
TcpStream::connect(&socket_addr)
.from_err()
.map(|con| ActualConnection::Tcp(BufReader::new(con))),
)
}
#[cfg(unix)]
ConnectionAddr::Unix(ref path) => Either::B(
UnixStream::connect(path).map(|stream| ActualConnection::Unix(BufReader::new(stream))),
),
#[cfg(not(unix))]
ConnectionAddr::Unix(_) => Either::B(future::err(RedisError::from((
ErrorKind::InvalidClientConfig,
"Cannot connect to unix sockets \
on this platform",
)))),
};
// Once the transport is up: wrap it, authenticate if needed, then select the database.
Either::B(connection.from_err().and_then(move |con| {
let rv = Connection {
con,
db: connection_info.db,
};
let login = match connection_info.passwd {
Some(ref passwd) => {
// Any outcome other than an `OK` reply (including transport or server errors)
// is reported as an authentication failure.
Either::A(cmd("AUTH").arg(&**passwd).query_async::<_, Value>(rv).then(
|x| match x {
Ok((rv, Value::Okay)) => Ok(rv),
_ => {
fail!((
ErrorKind::AuthenticationFailed,
"Password authentication failed"
));
}
},
))
}
None => Either::B(future::ok(rv)),
};
// Database 0 needs no `SELECT`.
login.and_then(move |rv| {
if connection_info.db != 0 {
Either::A(
cmd("SELECT")
.arg(connection_info.db)
.query_async::<_, Value>(rv)
.then(|result| match result {
Ok((rv, Value::Okay)) => Ok(rv),
_ => fail!((
ErrorKind::ResponseError,
"Redis server refused to switch database"
)),
}),
)
} else {
Either::B(future::ok(rv))
}
})
}))
}
/// A connection that can execute already-encoded Redis commands asynchronously.
///
/// Every request method consumes the connection and resolves to it together with the result.
/// The same calling pattern therefore works for exclusive ([`Connection`]) and shared
/// ([`SharedConnection`]) handles.
pub trait ConnectionLike: Sized {
    /// Sends one packed command and resolves to its single reply.
    fn req_packed_command(self, cmd: Vec<u8>) -> RedisFuture<(Self, Value)>;

    /// Sends packed commands, reads `offset + count` replies and resolves to the last `count`
    /// of them (the first `offset` replies are discarded).
    fn req_packed_commands(
        self,
        cmd: Vec<u8>,
        offset: usize,
        count: usize,
    ) -> RedisFuture<(Self, Vec<Value>)>;

    /// Returns the database index this connection was configured with.
    fn get_db(&self) -> i64;
}
impl ConnectionLike for Connection {
    /// Writes `cmd` to the socket, then reads exactly one reply.
    fn req_packed_command(self, cmd: Vec<u8>) -> RedisFuture<(Self, Value)> {
        let db = self.db;
        let written = with_write_connection!(self.con, |con| tokio_io::io::write_all(con, cmd));
        Box::new(
            written
                .from_err()
                .and_then(move |(con, _)| Connection { con, db }.read_response()),
        )
    }

    /// Writes `cmd` to the socket, then reads `offset + count` replies one after another and
    /// keeps only those from index `offset` onward.
    fn req_packed_commands(
        self,
        cmd: Vec<u8>,
        offset: usize,
        count: usize,
    ) -> RedisFuture<(Self, Vec<Value>)> {
        let db = self.db;
        let total = offset + count;
        let written = with_write_connection!(self.con, |con| tokio_io::io::write_all(con, cmd));
        Box::new(written.from_err().and_then(move |(con, _)| {
            // While replies are still being read, the connection lives either in `idle`
            // (between reads) or inside the `pending` read future.
            let mut idle = Some(Connection { con, db });
            let mut pending = None;
            let mut kept = Vec::new();
            let mut seen = 0;
            future::poll_fn(move || {
                while seen < total {
                    if pending.is_none() {
                        pending = Some(idle.take().unwrap().read_response());
                    }
                    let (con, reply) = try_ready!(pending.as_mut().unwrap().poll());
                    pending = None;
                    idle = Some(con);
                    if seen >= offset {
                        kept.push(reply);
                    }
                    seen += 1;
                }
                let replies = mem::replace(&mut kept, Vec::new());
                Ok(Async::Ready((idle.take().unwrap(), replies)))
            })
        }))
    }

    fn get_db(&self) -> i64 {
        self.db
    }
}
/// One-shot channel on which a pipelined request receives all of its replies, or an error.
type PipelineOutput<O, E> = oneshot::Sender<Result<Vec<O>, E>>;

/// A request that has been handed to the sink and is waiting for its replies.
struct InFlight<O, E> {
    /// Where the collected replies (or an error) are delivered.
    output: PipelineOutput<O, E>,
    /// How many replies this request expects.
    response_count: usize,
    /// Replies collected so far.
    buffer: Vec<O>,
}

/// A request queued for the pipeline driver task.
struct PipelineMessage<S, I, E> {
    /// The encoded request to hand to the sink.
    input: S,
    /// Where the replies are delivered.
    output: PipelineOutput<I, E>,
    /// How many replies the request expects.
    response_count: usize,
}

/// Cloneable handle that submits requests to the pipeline driver task.
struct Pipeline<T>(mpsc::Sender<PipelineMessage<T::SinkItem, T::Item, T::Error>>)
where
    T: Stream + Sink;

impl<T> Clone for Pipeline<T>
where
    T: Stream + Sink,
{
    fn clone(&self) -> Self {
        let Pipeline(sender) = self;
        Pipeline(sender.clone())
    }
}

/// Driver-side sink: writes requests to `sink_stream` and matches replies read back from it to
/// requests in FIFO order.
struct PipelineSink<T>
where
    T: Sink<SinkError = <T as Stream>::Error> + Stream + 'static,
{
    /// The framed transport, used both to write requests and to read replies.
    sink_stream: T,
    /// Written requests still waiting for replies, oldest first.
    in_flight: VecDeque<InFlight<T::Item, T::Error>>,
}
impl<T> PipelineSink<T>
where
    T: Sink<SinkError = <T as Stream>::Error> + Stream + 'static,
{
    /// Reads replies from the underlying stream and routes each to the oldest in-flight request.
    ///
    /// Returns `Ok(Async::Ready(()))` once no request is waiting for replies, and
    /// `Ok(Async::NotReady)` if replies are still expected but none are available yet. Returns
    /// `Err(())` if the reply stream has ended.
    fn poll_read(&mut self) -> Poll<(), ()> {
        // Stop once nothing is waiting. If this loop could only exit with `NotReady`,
        // `poll_complete` would never report completion. `close` (which waits on it while
        // requests are in flight) would then stall forever after the last reply was delivered,
        // keeping the driver task and its socket alive.
        while !self.in_flight.is_empty() {
            let item = match self.sink_stream.poll() {
                Ok(Async::Ready(Some(item))) => Ok(item),
                // The reply stream is not expected to end while requests are outstanding.
                Ok(Async::Ready(None)) => return Err(()),
                Err(err) => Err(err),
                Ok(Async::NotReady) => return Ok(Async::NotReady),
            };
            self.send_result(item);
        }
        Ok(Async::Ready(()))
    }

    /// Delivers one reply (or error) to the request at the front of `in_flight`.
    ///
    /// Successful replies are buffered until the request has received `response_count` of them;
    /// then the whole batch is sent. An error completes the front request immediately. Results
    /// that arrive while no request is in flight are dropped.
    ///
    /// NOTE(review): if an error arrives while the front request still expects more replies,
    /// those remaining replies will be attributed to the next request. Confirm whether the codec
    /// can yield an `Err` in the middle of a multi-reply response.
    fn send_result(&mut self, result: Result<T::Item, T::Error>) {
        let response = {
            let entry = match self.in_flight.front_mut() {
                Some(entry) => entry,
                None => return,
            };
            match result {
                Ok(item) => {
                    entry.buffer.push(item);
                    if entry.response_count > entry.buffer.len() {
                        // Still waiting for more replies for this request.
                        return;
                    }
                    Ok(mem::replace(&mut entry.buffer, Vec::new()))
                }
                Err(err) => Err(err),
            }
        };
        let entry = self.in_flight.pop_front().unwrap();
        // The requester may have given up and dropped its receiver; that is not an error here.
        entry.output.send(response).ok();
    }
}
impl<T> Sink for PipelineSink<T>
where
    T: Sink<SinkError = <T as Stream>::Error> + Stream + 'static,
{
    type SinkItem = PipelineMessage<T::SinkItem, T::Item, T::Error>;
    type SinkError = ();

    /// Hands the request to the underlying sink and records it as in flight.
    ///
    /// A request expecting zero replies is completed immediately with an empty `Vec`. If the
    /// underlying sink fails, the error is delivered to this request and the pipeline reports
    /// `Err(())`.
    fn start_send(
        &mut self,
        PipelineMessage {
            input,
            output,
            response_count,
        }: Self::SinkItem,
    ) -> StartSend<Self::SinkItem, Self::SinkError> {
        match self.sink_stream.start_send(input) {
            Ok(AsyncSink::NotReady(input)) => Ok(AsyncSink::NotReady(PipelineMessage {
                input,
                output,
                response_count,
            })),
            Ok(AsyncSink::Ready) if response_count == 0 => {
                // Nothing will ever be read for this request. Queueing it in `in_flight` would
                // make it swallow the reply belonging to the *next* request and shift every
                // later reply by one.
                let _ = output.send(Ok(Vec::new()));
                Ok(AsyncSink::Ready)
            }
            Ok(AsyncSink::Ready) => {
                self.in_flight.push_back(InFlight {
                    output,
                    response_count,
                    buffer: Vec::new(),
                });
                Ok(AsyncSink::Ready)
            }
            Err(err) => {
                // Best effort: the requester may already have dropped its receiver.
                let _ = output.send(Err(err));
                Err(())
            }
        }
    }

    /// Flushes pending writes, then reads replies for in-flight requests.
    ///
    /// A flush error is delivered to the oldest in-flight request.
    fn poll_complete(&mut self) -> Poll<(), Self::SinkError> {
        try_ready!(self.sink_stream.poll_complete().map_err(|err| {
            self.send_result(Err(err));
        }));
        self.poll_read()
    }

    /// Finishes outstanding requests, then closes the underlying sink.
    fn close(&mut self) -> Poll<(), Self::SinkError> {
        // No new requests arrive once closing starts, but replies for requests that were already
        // written must still be delivered.
        if !self.in_flight.is_empty() {
            try_ready!(self.poll_complete());
        }
        self.sink_stream.close().map_err(|err| {
            self.send_result(Err(err));
        })
    }
}
impl<T> Pipeline<T>
where
    T: Sink<SinkError = <T as Stream>::Error> + Stream + Send + 'static,
    T::SinkItem: Send,
    T::Item: Send,
    T::Error: Send,
    T::Error: ::std::fmt::Debug,
{
    /// Spawns the driver task for `sink_stream` on `executor` and returns a handle to it.
    ///
    /// The driver forwards queued [`PipelineMessage`]s into a [`PipelineSink`] until every handle
    /// is dropped or the sink fails. The request channel is bounded by `BUFFER_SIZE`.
    ///
    /// # Panics
    ///
    /// Panics if `executor` rejects the driver task.
    fn new<E>(sink_stream: T, executor: E) -> Self
    where
        E: Executor<Box<dyn Future<Item = (), Error = ()> + Send>>,
    {
        const BUFFER_SIZE: usize = 50;
        let (sender, receiver) = mpsc::channel(BUFFER_SIZE);
        let sink = PipelineSink {
            sink_stream,
            in_flight: VecDeque::new(),
        };
        let driver = receiver.map_err(|_| ()).forward(sink).map(|_| ());
        executor.execute(Box::new(driver)).unwrap();
        Pipeline(sender)
    }

    /// Sends a request that expects exactly one reply and resolves to that reply.
    fn send(
        &self,
        item: T::SinkItem,
    ) -> impl Future<Item = T::Item, Error = Option<T::Error>> + Send {
        self.send_recv_multiple(item, 1).map(|mut replies| {
            // A request expecting one reply is completed with exactly one value.
            replies.pop().unwrap()
        })
    }

    /// Sends a request that expects `count` replies and resolves to all of them.
    ///
    /// `Err(Some(e))` carries an error from the underlying transport. `Err(None)` means the
    /// request could not be queued, or was dropped by the driver before a reply arrived.
    fn send_recv_multiple(
        &self,
        input: T::SinkItem,
        count: usize,
    ) -> impl Future<Item = Vec<T::Item>, Error = Option<T::Error>> + Send {
        let (reply_tx, reply_rx) = oneshot::channel();
        let message = PipelineMessage {
            input,
            response_count: count,
            output: reply_tx,
        };
        self.0
            .clone()
            .send(message)
            .map_err(|_| None)
            .and_then(|_| {
                reply_rx.then(|received| match received {
                    Ok(outcome) => outcome.map_err(Some),
                    Err(_canceled) => Err(None),
                })
            })
    }
}
/// Transport-specific pipeline handle backing a [`SharedConnection`].
#[derive(Clone)]
enum ActualPipeline {
    /// Pipeline over a framed TCP socket.
    Tcp(Pipeline<Framed<TcpStream, ValueCodec>>),
    /// Pipeline over a framed Unix domain socket (only compiled on Unix targets).
    #[cfg(unix)]
    Unix(Pipeline<Framed<UnixStream, ValueCodec>>),
}

/// A cloneable connection handle whose clones all share one underlying socket.
///
/// Requests from every clone are funneled to a single driver task. That task writes them in
/// order and matches replies back to requests first-in, first-out.
#[derive(Clone)]
pub struct SharedConnection {
    /// Handle to the driver task.
    pipeline: ActualPipeline,
    /// Database index inherited from the [`Connection`] this was built from.
    db: i64,
}
impl SharedConnection {
pub fn new<E>(con: Connection, executor: E) -> impl Future<Item = Self, Error = RedisError>
where
E: Executor<Box<dyn Future<Item = (), Error = ()> + Send>>,
{
future::lazy(|| {
let pipeline = match con.con {
ActualConnection::Tcp(tcp) => {
let codec = ValueCodec::default().framed(tcp.into_inner());
ActualPipeline::Tcp(Pipeline::new(codec, executor))
}
#[cfg(unix)]
ActualConnection::Unix(unix) => {
let codec = ValueCodec::default().framed(unix.into_inner());
ActualPipeline::Unix(Pipeline::new(codec, executor))
}
};
Ok(SharedConnection {
pipeline,
db: con.db,
})
})
}
}
impl ConnectionLike for SharedConnection {
fn req_packed_command(self, cmd: Vec<u8>) -> RedisFuture<(Self, Value)> {
#[cfg(not(unix))]
let future = match self.pipeline {
ActualPipeline::Tcp(ref pipeline) => pipeline.send(cmd),
};
#[cfg(unix)]
let future = match self.pipeline {
ActualPipeline::Tcp(ref pipeline) => Either::A(pipeline.send(cmd)),
ActualPipeline::Unix(ref pipeline) => Either::B(pipeline.send(cmd)),
};
Box::new(future.map(|value| (self, value)).map_err(|err| {
err.unwrap_or_else(|| RedisError::from(io::Error::from(io::ErrorKind::BrokenPipe)))
}))
}
fn req_packed_commands(
self,
cmd: Vec<u8>,
offset: usize,
count: usize,
) -> RedisFuture<(Self, Vec<Value>)> {
#[cfg(not(unix))]
let future = match self.pipeline {
ActualPipeline::Tcp(ref pipeline) => pipeline.send_recv_multiple(cmd, offset + count),
};
#[cfg(unix)]
let future = match self.pipeline {
ActualPipeline::Tcp(ref pipeline) => {
Either::A(pipeline.send_recv_multiple(cmd, offset + count))
}
ActualPipeline::Unix(ref pipeline) => {
Either::B(pipeline.send_recv_multiple(cmd, offset + count))
}
};
Box::new(
future
.map(move |mut value| {
value.drain(..offset);
(self, value)
})
.map_err(|err| {
err.unwrap_or_else(|| {
RedisError::from(io::Error::from(io::ErrorKind::BrokenPipe))
})
}),
)
}
fn get_db(&self) -> i64 {
self.db
}
}