use std::{
collections::VecDeque,
future::Future,
net::SocketAddr,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use futures_channel::{mpsc, oneshot};
use futures_sink::Sink;
use futures_util::{future, stream::StreamExt};
use super::connect::{connect, RespConnection};
use crate::{
error,
reconnect::{reconnect, Reconnect},
resp,
};
/// Where the sending half of the driver currently stands.
enum SendStatus {
    /// Nothing is parked; the next command may be pulled from the queue.
    Ok,
    /// The command queue has been closed; no further commands will be sent.
    End,
    /// The connection sink was not ready, so this command is held back and
    /// retried on the next poll.
    Full(resp::RespValue),
}
/// Outcome of a single attempt to read a reply from the connection.
#[derive(Debug)]
enum ReceiveStatus {
    /// The command queue is closed and no replies are outstanding; the driver
    /// can finish.
    ReadyFinished,
    /// A reply was delivered; more may already be available.
    ReadyMore,
    /// The connection has nothing to read right now.
    NotReady,
}
/// One-shot channel half through which the reply to a single command is handed
/// back to whoever issued it.
type Responder = oneshot::Sender<resp::RespValue>;
/// A queued command paired with the responder that should receive its reply.
type SendPayload = (resp::RespValue, Responder);
/// State of the background task that drives one Redis connection: it forwards
/// queued commands to the server and routes each reply back to the responder of
/// the command it answers.
struct PairedConnectionInner {
    /// Underlying RESP connection: commands are written to it through its `Sink`
    /// methods and replies are read from it as a stream.
    connection: RespConnection,
    /// Receiving half of the command queue; each item is a command plus the
    /// responder for its reply.
    out_rx: mpsc::UnboundedReceiver<SendPayload>,
    /// Responders for commands taken from `out_rx` whose replies have not yet
    /// arrived, oldest first; replies are matched to them in FIFO order.
    waiting: VecDeque<Responder>,
    /// Progress of the sending side (idle, queue closed, or holding a command
    /// the sink was not ready to accept).
    send_status: SendStatus,
}
impl PairedConnectionInner {
fn new(
con: RespConnection,
out_rx: mpsc::UnboundedReceiver<(resp::RespValue, oneshot::Sender<resp::RespValue>)>,
) -> Self {
PairedConnectionInner {
connection: con,
out_rx,
waiting: VecDeque::new(),
send_status: SendStatus::Ok,
}
}
fn impl_start_send(
&mut self,
cx: &mut Context,
msg: resp::RespValue,
) -> Result<bool, error::Error> {
match Pin::new(&mut self.connection).poll_ready(cx) {
Poll::Ready(Ok(())) => (),
Poll::Ready(Err(e)) => return Err(e.into()),
Poll::Pending => {
self.send_status = SendStatus::Full(msg);
return Ok(false);
}
}
self.send_status = SendStatus::Ok;
Pin::new(&mut self.connection).start_send(msg)?;
Ok(true)
}
fn poll_start_send(&mut self, cx: &mut Context) -> Result<bool, error::Error> {
let mut status = SendStatus::Ok;
::std::mem::swap(&mut status, &mut self.send_status);
let message = match status {
SendStatus::End => {
self.send_status = SendStatus::End;
return Ok(false);
}
SendStatus::Full(msg) => msg,
SendStatus::Ok => match self.out_rx.poll_next_unpin(cx) {
Poll::Ready(Some((msg, tx))) => {
self.waiting.push_back(tx);
msg
}
Poll::Ready(None) => {
self.send_status = SendStatus::End;
return Ok(false);
}
Poll::Pending => return Ok(false),
},
};
self.impl_start_send(cx, message)
}
fn poll_complete(&mut self, cx: &mut Context) -> Result<(), error::Error> {
let _ = Pin::new(&mut self.connection).poll_flush(cx)?;
Ok(())
}
fn receive(&mut self, cx: &mut Context) -> Result<ReceiveStatus, error::Error> {
if let SendStatus::End = self.send_status {
if self.waiting.is_empty() {
return Ok(ReceiveStatus::ReadyFinished);
}
}
match self.connection.poll_next_unpin(cx) {
Poll::Ready(None) => Err(error::unexpected("Connection to Redis closed unexpectedly")),
Poll::Ready(Some(msg)) => {
let tx = match self.waiting.pop_front() {
Some(tx) => tx,
None => panic!("Received unexpected message: {:?}", msg),
};
let _ = tx.send(msg?);
Ok(ReceiveStatus::ReadyMore)
}
Poll::Pending => Ok(ReceiveStatus::NotReady),
}
}
fn handle_error(&self, e: &error::Error) {
log::error!("Internal error in PairedConnectionInner: {}", e);
}
}
impl Future for PairedConnectionInner {
    type Output = ();

    /// Drives the connection for one wake-up.
    ///
    /// Each poll pushes queued commands into the sink, flushes them, and then
    /// routes every reply that is already available to its waiting responder.
    ///
    /// Resolves with `()` once the command queue has closed and every
    /// outstanding reply has been delivered. It also resolves right after an
    /// error has been logged via `handle_error`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let inner = self.get_mut();

        // Phase 1: keep handing commands to the sink until nothing more can be
        // sent on this poll.
        loop {
            match inner.poll_start_send(cx) {
                Ok(true) => continue,
                Ok(false) => break,
                Err(e) => {
                    inner.handle_error(&e);
                    return Poll::Ready(());
                }
            }
        }

        // Phase 2: push whatever was buffered out to the server.
        if let Err(e) = inner.poll_complete(cx) {
            inner.handle_error(&e);
            return Poll::Ready(());
        }

        // Phase 3: drain every reply that is currently available.
        loop {
            match inner.receive(cx) {
                Ok(ReceiveStatus::ReadyMore) => continue,
                Ok(ReceiveStatus::NotReady) => return Poll::Pending,
                Ok(ReceiveStatus::ReadyFinished) => return Poll::Ready(()),
                Err(e) => {
                    inner.handle_error(&e);
                    return Poll::Ready(());
                }
            }
        }
    }
}
/// A cloneable handle for sending commands to Redis, where every command is
/// paired with the reply that answers it.
///
/// Clones share the same state through an `Arc`. The handle is built by
/// `paired_connect`, which wraps `inner_conn_fn` in `crate::reconnect::reconnect`.
/// NOTE(review): the exact reconnection policy lives in `crate::reconnect` —
/// confirm there before documenting it further.
#[derive(Debug, Clone)]
pub struct PairedConnection {
    // Reconnecting wrapper around the sender half of the command queue that a
    // spawned `PairedConnectionInner` task consumes.
    out_tx_c: Arc<Reconnect<SendPayload, mpsc::UnboundedSender<SendPayload>>>,
}
/// Opens a new connection to `addr` and spawns a `PairedConnectionInner` task
/// to drive it.
///
/// Returns the sender used to queue commands for that task.
///
/// # Errors
/// Returns any error produced while establishing the connection.
async fn inner_conn_fn(
    addr: SocketAddr,
) -> Result<mpsc::UnboundedSender<SendPayload>, error::Error> {
    let resp_con = connect(&addr).await?;
    let (cmd_tx, cmd_rx) = mpsc::unbounded();
    // The driver runs detached; it stops on its own once `cmd_tx` and every
    // clone of it are dropped and all replies have been delivered.
    tokio::spawn(PairedConnectionInner::new(resp_con, cmd_rx));
    Ok(cmd_tx)
}
/// Connects to the Redis server at `addr` and returns a [`PairedConnection`].
///
/// The underlying connection is produced by `inner_conn_fn` and managed through
/// `crate::reconnect::reconnect`. Commands are forwarded to the current
/// connection's queue with `unbounded_send`.
///
/// # Errors
/// Returns an error if `reconnect` cannot produce an initial connection.
pub async fn paired_connect(addr: &SocketAddr) -> Result<PairedConnection, error::Error> {
    // Copy the address so the `move` factory closure owns it instead of
    // borrowing the caller's value.
    let target = *addr;
    let reconnecting = reconnect(
        |sender: &mpsc::UnboundedSender<SendPayload>, payload| {
            sender.unbounded_send(payload).map_err(|e| e.into())
        },
        move || {
            let connect_fut = inner_conn_fn(target);
            Box::new(Box::pin(connect_fut))
        },
    )
    .await?;
    Ok(PairedConnection {
        out_tx_c: Arc::new(reconnecting),
    })
}
impl PairedConnection {
pub fn send<T>(&self, msg: resp::RespValue) -> impl Future<Output = Result<T, error::Error>>
where
T: resp::FromResp,
{
match &msg {
resp::RespValue::Array(_) => (),
_ => {
return future::Either::Right(future::ready(Err(error::internal(
"Command must be a RespValue::Array",
))));
}
}
let (tx, rx) = oneshot::channel();
match self.out_tx_c.do_work((msg, tx)) {
Ok(()) => future::Either::Left(async move {
match rx.await {
Ok(v) => Ok(T::from_resp(v)?),
Err(_) => Err(error::internal(
"Connection closed before response received",
)),
}
}),
Err(e) => future::Either::Right(future::ready(Err(e))),
}
}
pub fn send_and_forget(&self, msg: resp::RespValue) {
let send_f = self.send::<resp::RespValue>(msg);
let forget_f = async {
if let Err(e) = send_f.await {
log::error!("Error in send_and_forget: {}", e);
}
};
tokio::spawn(forget_f);
}
}
#[cfg(test)]
mod test {
    /// Connects to the local Redis server, panicking with `msg` on failure.
    async fn connect_or_panic(msg: &str) -> super::PairedConnection {
        let addr = "127.0.0.1:6379".parse().unwrap();
        super::paired_connect(&addr).await.expect(msg)
    }

    #[tokio::test]
    async fn can_paired_connect() {
        let connection = connect_or_panic("Cannot establish connection").await;

        let ping = connection.send(resp_array!["PING", "TEST"]);
        connection.send_and_forget(resp_array!["SET", "X", "123"]);
        let get = connection.send(resp_array!["GET", "X"]);

        let first: String = ping.await.expect("Cannot read result of first thing");
        let second: String = get.await.expect("Cannot read result of second thing");
        assert_eq!(first, "TEST");
        assert_eq!(second, "123");
    }

    #[tokio::test]
    async fn complex_paired_connect() {
        let connection = connect_or_panic("Cannot establish connection").await;

        let counter: String = connection
            .send(resp_array!["INCR", "CTR"])
            .await
            .expect("Cannot increment counter");
        let reply: String = connection
            .send(resp_array!["SET", "LASTCTR", counter])
            .await
            .expect("Cannot set value");
        assert_eq!(reply, "OK");
    }

    #[tokio::test]
    async fn sending_a_lot_of_data_test() {
        let connection = connect_or_panic("Cannot connect to Redis").await;

        let mut pending = Vec::with_capacity(1000);
        for i in 0..1000 {
            let key = format!("X_{}", i);
            connection.send_and_forget(resp_array!["SET", &key, i.to_string()]);
            pending.push(connection.send(resp_array!["GET", key]));
        }

        let last = pending.remove(999);
        let value: String = last.await.expect("Cannot wait for result");
        assert_eq!(value, "999");
    }
}