use crate::types::{RedisResult, Value};
use crate::{
FromRedisValue, Msg, RedisConnectionInfo, ToRedisArgs, aio::Runtime, cmd, errors::RedisError,
errors::closed_connection_error, from_redis_value, parser::ValueCodec,
};
use ::tokio::{
io::{AsyncRead, AsyncWrite},
sync::oneshot,
};
use futures_util::{
future::{Future, FutureExt},
ready,
sink::Sink,
stream::{self, Stream, StreamExt},
};
use pin_project_lite::pin_project;
use std::collections::VecDeque;
use std::pin::Pin;
use std::task::{self, Poll};
use tokio::sync::mpsc::UnboundedSender;
use tokio::sync::mpsc::unbounded_channel;
use tokio_util::codec::Decoder;
use super::{SharedHandleContainer, setup_connection};
type RequestResultSender = oneshot::Sender<RedisResult<Value>>;
struct PipelineMessage {
input: Vec<u8>,
output: RequestResultSender,
}
#[derive(Clone)]
pub struct PubSubSink {
sender: UnboundedSender<PipelineMessage>,
}
pin_project! {
pub struct PubSubStream {
#[pin]
receiver: tokio::sync::mpsc::UnboundedReceiver<Msg>,
_task_handle: Option<SharedHandleContainer>,
}
}
pin_project! {
struct PipelineSink<T> {
#[pin]
sink_stream: T,
in_flight: VecDeque<RequestResultSender>,
sender: UnboundedSender<Msg>,
}
}
impl<T> PipelineSink<T>
where
T: Stream<Item = RedisResult<Value>> + 'static,
{
fn new(sink_stream: T, sender: UnboundedSender<Msg>) -> Self
where
T: Sink<Vec<u8>, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
{
PipelineSink {
sink_stream,
in_flight: VecDeque::new(),
sender,
}
}
fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Result<(), ()>> {
loop {
let self_ = self.as_mut().project();
if self_.sender.is_closed() {
return Poll::Ready(Err(()));
}
let item = match ready!(self.as_mut().project().sink_stream.poll_next(cx)) {
Some(result) => result,
None => return Poll::Ready(Err(())),
};
self.as_mut().handle_message(item)?;
}
}
fn handle_message(self: Pin<&mut Self>, result: RedisResult<Value>) -> Result<(), ()> {
let self_ = self.project();
match result {
Ok(Value::Array(value)) => {
if let Some(Value::BulkString(kind)) = value.first() {
if matches!(
kind.as_slice(),
b"subscribe" | b"psubscribe" | b"unsubscribe" | b"punsubscribe" | b"pong"
) {
if let Some(entry) = self_.in_flight.pop_front() {
let _ = entry.send(Ok(Value::Array(value)));
};
return Ok(());
}
}
if let Some(msg) = Msg::from_owned_value(Value::Array(value)) {
let _ = self_.sender.send(msg);
Ok(())
} else {
Err(())
}
}
Ok(Value::Push { kind, data }) => {
if kind.has_reply() {
if let Some(entry) = self_.in_flight.pop_front() {
let _ = entry.send(Ok(Value::Push { kind, data }));
};
return Ok(());
}
if let Some(msg) = Msg::from_push_info(crate::PushInfo { kind, data }) {
let _ = self_.sender.send(msg);
Ok(())
} else {
Err(())
}
}
Err(err) if err.is_unrecoverable_error() => Err(()),
_ => {
if let Some(entry) = self_.in_flight.pop_front() {
let _ = entry.send(result);
Ok(())
} else {
Err(())
}
}
}
}
}
impl<T> Sink<PipelineMessage> for PipelineSink<T>
where
T: Sink<Vec<u8>, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
{
type Error = ();
fn poll_ready(
mut self: Pin<&mut Self>,
cx: &mut task::Context,
) -> Poll<Result<(), Self::Error>> {
self.as_mut()
.project()
.sink_stream
.poll_ready(cx)
.map_err(|_| ())
}
fn start_send(
mut self: Pin<&mut Self>,
PipelineMessage { input, output }: PipelineMessage,
) -> Result<(), Self::Error> {
let self_ = self.as_mut().project();
match self_.sink_stream.start_send(input) {
Ok(()) => {
self_.in_flight.push_back(output);
Ok(())
}
Err(err) => {
let _ = output.send(Err(err));
Err(())
}
}
}
fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut task::Context,
) -> Poll<Result<(), Self::Error>> {
ready!(
self.as_mut()
.project()
.sink_stream
.poll_flush(cx)
.map_err(|err| {
let _ = self.as_mut().handle_message(Err(err));
})
)?;
self.poll_read(cx)
}
fn poll_close(
mut self: Pin<&mut Self>,
cx: &mut task::Context,
) -> Poll<Result<(), Self::Error>> {
if !self.in_flight.is_empty() {
ready!(self.as_mut().poll_flush(cx))?;
}
let this = self.as_mut().project();
if this.sender.is_closed() {
return Poll::Ready(Ok(()));
}
match ready!(this.sink_stream.poll_next(cx)) {
Some(result) => {
let _ = self.handle_message(result);
Poll::Pending
}
None => Poll::Ready(Ok(())),
}
}
}
impl PubSubSink {
fn new<T>(
sink_stream: T,
messages_sender: UnboundedSender<Msg>,
) -> (Self, impl Future<Output = ()>)
where
T: Sink<Vec<u8>, Error = RedisError>,
T: Stream<Item = RedisResult<Value>>,
T: Unpin + Send + 'static,
{
let (sender, mut receiver) = unbounded_channel();
let sink = PipelineSink::new(sink_stream, messages_sender);
let f = stream::poll_fn(move |cx| {
let res = receiver.poll_recv(cx);
match res {
Poll::Ready(None) => Poll::Pending,
_ => res,
}
})
.map(Ok)
.forward(sink)
.map(|_| ());
(PubSubSink { sender }, f)
}
async fn send_recv(&mut self, input: Vec<u8>) -> Result<Value, RedisError> {
let (sender, receiver) = oneshot::channel();
self.sender
.send(PipelineMessage {
input,
output: sender,
})
.map_err(|_| closed_connection_error())?;
match receiver.await {
Ok(result) => result,
Err(_) => Err(closed_connection_error()),
}
}
pub async fn subscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
let cmd = cmd("SUBSCRIBE").arg(channel_name).get_packed_command();
self.send_recv(cmd).await.map(|_| ())
}
pub async fn unsubscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
let cmd = cmd("UNSUBSCRIBE").arg(channel_name).get_packed_command();
self.send_recv(cmd).await.map(|_| ())
}
pub async fn psubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
let cmd = cmd("PSUBSCRIBE").arg(channel_pattern).get_packed_command();
self.send_recv(cmd).await.map(|_| ())
}
pub async fn punsubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
let cmd = cmd("PUNSUBSCRIBE")
.arg(channel_pattern)
.get_packed_command();
self.send_recv(cmd).await.map(|_| ())
}
pub async fn ping_message<T: FromRedisValue>(
&mut self,
message: impl ToRedisArgs,
) -> RedisResult<T> {
let cmd = cmd("PING").arg(message).get_packed_command();
let response = self.send_recv(cmd).await?;
Ok(from_redis_value(response)?)
}
pub async fn ping<T: FromRedisValue>(&mut self) -> RedisResult<T> {
let cmd = cmd("PING").get_packed_command();
let response = self.send_recv(cmd).await?;
Ok(from_redis_value(response)?)
}
}
pub struct PubSub {
sink: PubSubSink,
stream: PubSubStream,
}
impl PubSub {
pub async fn new<C>(connection_info: &RedisConnectionInfo, stream: C) -> RedisResult<Self>
where
C: Unpin + AsyncRead + AsyncWrite + Send + 'static,
{
let mut codec = ValueCodec::default().framed(stream);
setup_connection(
&mut codec,
connection_info,
#[cfg(feature = "cache-aio")]
None,
)
.await?;
let (sender, receiver) = unbounded_channel();
let (sink, driver) = PubSubSink::new(codec, sender);
let handle = Runtime::locate().spawn(driver);
let _task_handle = Some(SharedHandleContainer::new(handle));
let stream = PubSubStream {
receiver,
_task_handle,
};
let con = PubSub { sink, stream };
Ok(con)
}
pub async fn subscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
self.sink.subscribe(channel_name).await
}
pub async fn unsubscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
self.sink.unsubscribe(channel_name).await
}
pub async fn psubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
self.sink.psubscribe(channel_pattern).await
}
pub async fn punsubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
self.sink.punsubscribe(channel_pattern).await
}
pub async fn ping<T: FromRedisValue>(&mut self) -> RedisResult<T> {
self.sink.ping().await
}
pub async fn ping_message<T: FromRedisValue>(
&mut self,
message: impl ToRedisArgs,
) -> RedisResult<T> {
self.sink.ping_message(message).await
}
pub fn on_message(&mut self) -> impl Stream<Item = Msg> + '_ {
&mut self.stream
}
pub fn into_on_message(self) -> PubSubStream {
self.stream
}
pub fn split(self) -> (PubSubSink, PubSubStream) {
(self.sink, self.stream)
}
}
impl Stream for PubSubStream {
type Item = Msg;
fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
self.project().receiver.poll_recv(cx)
}
}