use crate::cmd::Cmd;
use crate::cmd::cmd;
use crate::connection::ConnectionLike;
use crate::connection::DisconnectNotifier;
use crate::connection::factory::FerrisKeyConnectionOptions;
use crate::connection::info::ConnectionInfo;
use crate::connection::runtime;
use crate::connection::setup_connection;
use crate::pipeline::PipelineRetryStrategy;
use crate::protocol::parser::ValueCodec;
use crate::pubsub::push_manager::PushManager;
use crate::value::{ProtocolVersion, PushKind};
use crate::value::{Error, Result, Value};
use ::tokio::{
io::{AsyncRead, AsyncWrite},
sync::{mpsc, oneshot},
};
use arc_swap::ArcSwap;
use futures_util::{
future::{Future, FutureExt},
ready,
sink::Sink,
stream::{self, Stream, StreamExt, TryStreamExt as _},
};
use pin_project_lite::pin_project;
use std::collections::VecDeque;
use std::fmt;
use std::fmt::Debug;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::task::{self, Poll};
use std::time::Duration;
use tokio_util::codec::Decoder;
/// Response timeout used by `MultiplexedConnectionBuilder::build` when the builder was not
/// given one explicitly (two seconds).
const DEFAULT_CONNECTION_ATTEMPT_TIMEOUT: Duration = Duration::from_secs(2);
type PipelineOutput = oneshot::Sender<Result<Value>>;
/// How replies read from the transport are gathered for one queued request.
enum ResponseAggregate {
    /// A single command: the next reply completes the request.
    SingleCommand,
    /// A pipeline or transaction that is answered by several consecutive replies.
    Pipeline {
        /// Total number of replies this request waits for.
        expected_response_count: usize,
        /// Replies received so far.
        current_response_count: usize,
        /// Collected replies, in arrival order (for transactions, only the successful ones).
        buffer: Vec<Result<Value>>,
        /// For transactions: the first error seen, which replaces the whole response.
        first_err: Option<Error>,
        /// Whether errors collapse the response (transaction) or are kept per item.
        is_transaction: bool,
    },
}
impl ResponseAggregate {
    /// Picks the aggregation strategy for a request.
    ///
    /// `None` selects [`ResponseAggregate::SingleCommand`]. `Some(n)` selects a pipeline that
    /// waits for `n` replies, with `is_transaction` deciding how errors are folded.
    fn new(pipeline_response_count: Option<usize>, is_transaction: bool) -> Self {
        pipeline_response_count.map_or(ResponseAggregate::SingleCommand, |expected| {
            ResponseAggregate::Pipeline {
                expected_response_count: expected,
                current_response_count: 0,
                buffer: Vec::new(),
                first_err: None,
                is_transaction,
            }
        })
    }
}
/// Bookkeeping for a request that has been written to the transport but not yet answered.
struct InFlight {
    /// Channel on which the final outcome is delivered to the requester.
    output: PipelineOutput,
    /// How the replies for this request are gathered.
    response_aggregate: ResponseAggregate,
    /// Set for fenced commands, whose replies take the fenced-response path.
    is_fenced: bool,
    /// Fenced commands only: the first reply, held until the fence reply arrives.
    fenced_result: Option<Result<Value>>,
}
/// A request travelling from a `Pipeline` handle to the driver that owns the transport.
struct PipelineMessage<S> {
    /// Encoded request bytes (or other sink item) to write.
    input: S,
    /// Channel on which the requester is answered.
    output: PipelineOutput,
    /// `None` for a single command, `Some(n)` for a pipeline expecting `n` replies.
    pipeline_response_count: Option<usize>,
    /// Whether pipeline errors collapse into one error (transaction semantics).
    is_transaction: bool,
    /// Whether this request is a fenced command.
    is_fenced: bool,
}
/// Cloneable handle used to submit requests to a connection's driver future.
#[derive(Clone)]
pub(crate) struct Pipeline<SinkItem> {
    /// Bounded channel into the driver; every message carries its own reply channel.
    sender: mpsc::Sender<PipelineMessage<SinkItem>>,
    /// Push manager shared with the driver; storing a new one redirects later push messages.
    push_manager: Arc<ArcSwap<PushManager>>,
    /// Raised by the driver when the reply stream ends or reply synchronization is lost.
    is_stream_closed: Arc<AtomicBool>,
}
impl<SinkItem> Debug for Pipeline<SinkItem>
where
    SinkItem: Debug,
{
    /// Formats as a `Pipeline(..)` tuple showing only the request sender.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut tuple = f.debug_tuple("Pipeline");
        tuple.field(&self.sender);
        tuple.finish()
    }
}
pin_project! {
    // Driver-side state of a multiplexed connection: requests are written into `sink_stream`
    // and replies read back from it are matched, oldest first, against `in_flight`.
    struct PipelineSink<T> {
        // Framed transport used both for writing requests and reading replies.
        #[pin]
        sink_stream: T,
        // Requests already written and still waiting for (all of) their replies.
        in_flight: VecDeque<InFlight>,
        // Transport error captured in `poll_ready`, reported by the next `start_send`.
        error: Option<Error>,
        // Receives push messages read from the transport.
        push_manager: Arc<ArcSwap<PushManager>>,
        // Notified once when the reply stream ends, if present.
        disconnect_notifier: Option<Box<dyn DisconnectNotifier>>,
        // Shared with the `Pipeline` handle so callers can observe closure.
        is_stream_closed: Arc<AtomicBool>,
        // Set once a fenced command revealed that replies no longer line up with requests.
        response_sync_lost: bool,
    }
    impl<T> PinnedDrop for PipelineSink<T> {
        // When the driver state goes away, drop the subscriptions tracked for this node's
        // address (only if both an address and a synchronizer are configured).
        fn drop(this: Pin<&mut Self>) {
            let projected = this.project();
            let manager = projected.push_manager.load();
            let Some(node_address) = manager.get_address() else {
                return;
            };
            let Some(synchronizer) = manager.get_synchronizer() else {
                return;
            };
            let affected = std::collections::HashSet::from([node_address.clone()]);
            synchronizer.remove_current_subscriptions_for_addresses(&affected);
        }
    }
}
impl<T> PipelineSink<T>
where
T: Stream<Item = Result<Value>> + 'static,
{
fn new<SinkItem>(
sink_stream: T,
push_manager: Arc<ArcSwap<PushManager>>,
disconnect_notifier: Option<Box<dyn DisconnectNotifier>>,
is_stream_closed: Arc<AtomicBool>,
) -> Self
where
T: Sink<SinkItem, Error = Error> + Stream<Item = Result<Value>> + 'static,
{
PipelineSink {
sink_stream,
in_flight: VecDeque::with_capacity(128),
error: None,
push_manager,
disconnect_notifier,
is_stream_closed,
response_sync_lost: false,
}
}
fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<std::result::Result<(), ()>> {
loop {
let item = match ready!(self.as_mut().project().sink_stream.poll_next(cx)) {
Some(result) => result,
None => {
if let Some(disconnect_notifier) = self.as_mut().project().disconnect_notifier {
disconnect_notifier.notify_disconnect();
}
self.is_stream_closed.store(true, Ordering::Relaxed);
return Poll::Ready(Err(()));
}
};
self.as_mut().send_result(item);
}
}
fn send_result(self: Pin<&mut Self>, result: Result<Value>) {
let self_ = self.project();
if *self_.response_sync_lost {
while let Some(entry) = self_.in_flight.pop_front() {
let err = Error::from((
crate::value::ErrorKind::ProtocolDesync,
"Response synchronization lost - connection must be reestablished",
));
entry.output.send(Err(err)).ok();
}
return;
}
if let Ok(res) = &result
&& let Value::Push { kind, data: _data } = res
{
self_.push_manager.load().try_send_raw(res);
if !kind.has_reply() {
return;
}
}
let mut entry = match self_.in_flight.pop_front() {
Some(entry) => entry,
None => return,
};
if entry.is_fenced {
Self::handle_fenced_command(entry, result, self_.in_flight, self_.response_sync_lost, self_.is_stream_closed);
return;
}
match &mut entry.response_aggregate {
ResponseAggregate::SingleCommand => {
entry
.output
.send(result.and_then(|v| v.extract_error()))
.ok();
}
ResponseAggregate::Pipeline {
expected_response_count,
current_response_count,
buffer,
first_err,
is_transaction,
} => {
match result {
Ok(item) => {
buffer.push(Ok(item));
}
Err(err) if *is_transaction => {
if first_err.is_none() {
*first_err = Some(err);
}
}
Err(err) => {
buffer.push(Err(err));
}
}
*current_response_count += 1;
if current_response_count < expected_response_count {
self_.in_flight.push_front(entry);
return;
}
let response = match first_err.take() {
Some(err) => Err(err),
None => Ok(Value::Array(std::mem::take(buffer))),
};
entry.output.send(response).ok();
}
}
}
fn handle_fenced_command(
mut entry: InFlight,
result: Result<Value>,
in_flight: &mut VecDeque<InFlight>,
response_sync_lost: &mut bool,
is_stream_closed: &Arc<AtomicBool>,
) {
if let Some(stored_result) = entry.fenced_result.take() {
Self::handle_fenced_second_response(entry, result, stored_result, in_flight, response_sync_lost, is_stream_closed);
} else {
Self::handle_fenced_first_response(entry, result, in_flight);
}
}
fn handle_fenced_first_response(
mut entry: InFlight,
result: Result<Value>,
in_flight: &mut VecDeque<InFlight>,
) {
match result {
Ok(Value::SimpleString(ref s)) if s == "PONG" || s == "pong" => {
entry.output.send(Ok(Value::Nil)).ok();
}
Err(err) => {
entry.fenced_result = Some(Err(err));
in_flight.push_front(entry);
}
Ok(value) => {
entry.fenced_result = Some(Ok(value));
in_flight.push_front(entry);
}
}
}
fn handle_fenced_second_response(
entry: InFlight,
pong_result: Result<Value>,
stored_result: Result<Value>,
in_flight: &mut VecDeque<InFlight>,
response_sync_lost: &mut bool,
is_stream_closed: &Arc<AtomicBool>,
) {
let is_pong = matches!(
&pong_result,
Ok(Value::SimpleString(s)) if s == "PONG"
);
if !is_pong {
*response_sync_lost = true;
is_stream_closed.store(true, Ordering::Relaxed);
tracing::error!("Fenced command - CRITICAL: Expected PONG for fenced command but got unexpected response. Response synchronization lost. All commands will fail until reconnection.");
let err = Error::from((
crate::value::ErrorKind::ProtocolDesync,
"Expected PONG for fenced command but received different response",
format!("Response synchronization lost. Got: {:?}", pong_result),
));
entry.output.send(Err(err)).ok();
while let Some(remaining) = in_flight.pop_front() {
let err = Error::from((
crate::value::ErrorKind::ProtocolDesync,
"Response synchronization lost - connection must be reestablished",
));
remaining.output.send(Err(err)).ok();
}
return;
}
let final_result = stored_result.and_then(|v| v.extract_error());
entry.output.send(final_result).ok();
}
}
impl<SinkItem, T> Sink<PipelineMessage<SinkItem>> for PipelineSink<T>
where
T: Sink<SinkItem, Error = Error> + Stream<Item = Result<Value>> + 'static,
{
type Error = ();
fn poll_ready(
mut self: Pin<&mut Self>,
cx: &mut task::Context,
) -> Poll<std::result::Result<(), Self::Error>> {
match ready!(self.as_mut().project().sink_stream.poll_ready(cx)) {
Ok(()) => Ok(()).into(),
Err(err) => {
*self.project().error = Some(err);
Ok(()).into()
}
}
}
fn start_send(
mut self: Pin<&mut Self>,
PipelineMessage {
input,
output,
pipeline_response_count,
is_transaction,
is_fenced,
}: PipelineMessage<SinkItem>,
) -> std::result::Result<(), Self::Error> {
if output.is_closed() {
return Ok(());
}
let self_ = self.as_mut().project();
if let Some(err) = self_.error.take() {
let _ = output.send(Err(err));
return Err(());
}
if *self_.response_sync_lost {
let err = Error::from((
crate::value::ErrorKind::ProtocolDesync,
"Response synchronization lost - connection must be reestablished",
));
let _ = output.send(Err(err));
return Err(());
}
match self_.sink_stream.start_send(input) {
Ok(()) => {
let response_aggregate =
ResponseAggregate::new(pipeline_response_count, is_transaction);
let entry = InFlight {
output,
response_aggregate,
is_fenced,
fenced_result: None,
};
self_.in_flight.push_back(entry);
Ok(())
}
Err(err) => {
let _ = output.send(Err(err));
Err(())
}
}
}
fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut task::Context,
) -> Poll<std::result::Result<(), Self::Error>> {
ready!(
self.as_mut()
.project()
.sink_stream
.poll_flush(cx)
.map_err(|err| {
self.as_mut().send_result(Err(err));
})
)?;
self.poll_read(cx)
}
fn poll_close(
mut self: Pin<&mut Self>,
cx: &mut task::Context,
) -> Poll<std::result::Result<(), Self::Error>> {
if !self.in_flight.is_empty() {
ready!(self.as_mut().poll_flush(cx))?;
}
let this = self.as_mut().project();
this.sink_stream.poll_close(cx).map_err(|err| {
self.send_result(Err(err));
})
}
}
impl<SinkItem> Pipeline<SinkItem>
where
SinkItem: Send + 'static,
{
fn new<T>(
sink_stream: T,
disconnect_notifier: Option<Box<dyn DisconnectNotifier>>,
) -> (Self, impl Future<Output = ()>)
where
T: Sink<SinkItem, Error = Error> + Stream<Item = Result<Value>> + 'static,
T: Send + 'static,
T::Item: Send,
T::Error: Send,
T::Error: ::std::fmt::Debug,
{
const BUFFER_SIZE: usize = 50;
let (sender, mut receiver) = mpsc::channel(BUFFER_SIZE);
let push_manager: Arc<ArcSwap<PushManager>> =
Arc::new(ArcSwap::new(Arc::new(PushManager::default())));
let is_stream_closed = Arc::new(AtomicBool::new(false));
let sink = PipelineSink::new::<SinkItem>(
sink_stream,
push_manager.clone(),
disconnect_notifier,
is_stream_closed.clone(),
);
let f = stream::poll_fn(move |cx| receiver.poll_recv(cx))
.map(Ok)
.forward(sink)
.map(|_| ());
(
Pipeline {
sender,
push_manager,
is_stream_closed,
},
f,
)
}
async fn send_single(
&mut self,
item: SinkItem,
timeout: Duration,
is_fenced: bool,
) -> Result<Value> {
self.send_recv(item, None, timeout, true, is_fenced).await
}
async fn send_recv(
&mut self,
input: SinkItem,
pipeline_response_count: Option<usize>,
timeout: Duration,
is_atomic: bool,
is_fenced: bool,
) -> Result<Value> {
let (sender, receiver) = oneshot::channel();
self.sender
.send(PipelineMessage {
input,
pipeline_response_count,
output: sender,
is_transaction: is_atomic,
is_fenced,
})
.await
.map_err(|err| {
Error::from((
crate::value::ErrorKind::FatalSendError,
"Failed to send the request to the server",
err.to_string(),
))
})?;
match runtime::timeout(timeout, receiver).await {
Ok(Ok(result)) => result,
Ok(Err(err)) => {
Err(Error::from((
crate::value::ErrorKind::FatalReceiveError,
"Failed to receive a response due to a fatal error",
err.to_string(),
)))
}
Err(elapsed) => Err(elapsed.into()),
}
}
fn set_push_manager(&mut self, push_manager: PushManager) {
self.push_manager.store(Arc::new(push_manager));
}
pub fn is_closed(&self) -> bool {
self.is_stream_closed.load(Ordering::Relaxed)
}
}
/// A connection handle that multiplexes many concurrent requests over a single transport.
///
/// Cloning shares the same underlying pipeline and driver.
#[derive(Clone)]
pub struct MultiplexedConnection {
    /// Handle to the driver that owns the transport.
    pipeline: Pipeline<bytes::Bytes>,
    /// Database index reported through `ConnectionLike::get_db`.
    db: i64,
    /// How long each request waits for its reply.
    response_timeout: Duration,
    /// Protocol in use; RESP2 disables disconnect push notifications and pub/sub helpers.
    protocol: ProtocolVersion,
    /// Push manager used to emit disconnection notifications and exposed to callers.
    push_manager: PushManager,
    /// Availability zone of the node, if known.
    availability_zone: Option<String>,
    /// Password recorded for this connection.
    password: Option<String>,
}
impl Debug for MultiplexedConnection {
    /// Prints only the pipeline and database index; the other fields, including the
    /// password, are left out.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("MultiplexedConnection");
        out.field("pipeline", &self.pipeline);
        out.field("db", &self.db);
        out.finish()
    }
}
impl MultiplexedConnection {
    /// Connects over `stream` without a response timeout (`NO_TIMEOUT`).
    ///
    /// See [`MultiplexedConnection::new_with_response_timeout`] for the returned driver.
    pub async fn new<C>(
        connection_info: ConnectionInfo,
        stream: C,
        ferriskey_connection_options: FerrisKeyConnectionOptions,
    ) -> Result<(Self, impl Future<Output = ()>)>
    where
        C: Unpin + AsyncRead + AsyncWrite + Send + 'static,
    {
        Self::new_with_response_timeout(
            connection_info,
            stream,
            super::factory::NO_TIMEOUT,
            ferriskey_connection_options,
        )
        .await
    }

    /// Connects over `stream`, runs connection setup, and returns the connection together
    /// with its driver future.
    ///
    /// Setup (`setup_connection`) runs while the driver is polled concurrently. If the driver
    /// finishes before setup does, an `IoError` is returned. Setup errors are propagated.
    pub async fn new_with_response_timeout<C>(
        connection_info: ConnectionInfo,
        stream: C,
        response_timeout: std::time::Duration,
        ferriskey_connection_options: FerrisKeyConnectionOptions,
    ) -> Result<(Self, impl Future<Output = ()>)>
    where
        C: Unpin + AsyncRead + AsyncWrite + Send + 'static,
    {
        let codec = ValueCodec.framed(stream).and_then(|msg| async move { msg });
        let (mut pipeline, driver) =
            Pipeline::new(codec, ferriskey_connection_options.disconnect_notifier);
        let driver = Box::pin(driver);
        let pm = PushManager::new(
            ferriskey_connection_options.push_sender,
            ferriskey_connection_options.pubsub_synchronizer,
            Some(connection_info.addr.to_string()),
        );
        pipeline.set_push_manager(pm.clone());
        let mut con = MultiplexedConnection::builder(pipeline)
            .with_db(connection_info.valkey.db)
            .with_response_timeout(response_timeout)
            .with_push_manager(pm)
            .with_protocol(connection_info.valkey.protocol)
            .with_password(connection_info.valkey.password.clone())
            .with_availability_zone(None)
            .build()
            .await?;
        let driver = {
            let auth = setup_connection(
                &connection_info.valkey,
                &mut con,
                ferriskey_connection_options.discover_az,
            );
            futures_util::pin_mut!(auth);
            match futures_util::future::select(auth, driver).await {
                futures_util::future::Either::Left((result, driver)) => {
                    result?;
                    driver
                }
                futures_util::future::Either::Right(((), _)) => {
                    return Err(Error::from((
                        crate::value::ErrorKind::IoError,
                        "Multiplexed connection driver unexpectedly terminated",
                    )));
                }
            }
        };
        Ok((con, driver))
    }

    /// Changes how long subsequent requests wait for their replies.
    pub fn set_response_timeout(&mut self, timeout: std::time::Duration) {
        self.response_timeout = timeout;
    }

    /// Sends one packed command and returns its reply.
    ///
    /// On a dropped connection (non-RESP2 only), a `Disconnection` push is emitted.
    pub async fn send_packed_command(&mut self, cmd: &Cmd) -> Result<Value> {
        let result = self
            .pipeline
            .send_single(
                cmd.get_packed_command(),
                self.response_timeout,
                cmd.is_fenced(),
            )
            .await;
        self.notify_disconnection_if_dropped(&result);
        result
    }

    /// Sends a packed pipeline and returns the replies after skipping the first `offset`.
    ///
    /// `offset + count` replies are awaited. A non-array response is returned as a single
    /// element. On a dropped connection (non-RESP2 only), a `Disconnection` push is emitted.
    pub async fn send_packed_commands(
        &mut self,
        cmd: &crate::pipeline::Pipeline,
        offset: usize,
        count: usize,
    ) -> Result<Vec<Result<Value>>> {
        let result = self
            .pipeline
            .send_recv(
                cmd.get_packed_pipeline(),
                Some(offset + count),
                self.response_timeout,
                cmd.is_atomic(),
                false,
            )
            .await;
        self.notify_disconnection_if_dropped(&result);
        match result? {
            Value::Array(mut values) => {
                // Clamp so a reply array shorter than `offset` cannot make `drain` panic.
                let skip = offset.min(values.len());
                values.drain(..skip);
                Ok(values)
            }
            value => Ok(vec![Ok(value)]),
        }
    }

    /// Emits a `Disconnection` push when `result` reports a dropped connection and the
    /// protocol is not RESP2.
    fn notify_disconnection_if_dropped(&self, result: &Result<Value>) {
        if self.protocol != ProtocolVersion::RESP2
            && let Err(e) = result
            && e.is_connection_dropped()
        {
            self.push_manager.try_send_raw(&Value::Push {
                kind: PushKind::Disconnection,
                data: vec![],
            });
        }
    }

    /// Replaces the push manager on this handle and in the pipeline shared with the driver.
    pub async fn set_push_manager(&mut self, push_manager: PushManager) {
        self.push_manager = push_manager.clone();
        self.pipeline.set_push_manager(push_manager);
    }

    /// Returns the recorded availability zone, if any.
    pub fn get_availability_zone(&self) -> Option<String> {
        self.availability_zone.clone()
    }

    /// Records a new password on this handle and returns `Okay`.
    ///
    /// Only the stored value changes; no command is sent to the server here.
    pub async fn update_connection_password(
        &mut self,
        password: Option<String>,
    ) -> Result<Value> {
        self.password = password;
        Ok(Value::Okay)
    }

    /// Starts a [`MultiplexedConnectionBuilder`] around `pipeline`.
    pub(crate) fn builder(pipeline: Pipeline<bytes::Bytes>) -> MultiplexedConnectionBuilder {
        MultiplexedConnectionBuilder::new(pipeline)
    }

    /// Rebinds the push manager to `address`, in both the pipeline and this handle.
    pub fn update_push_manager_node_address(&mut self, address: String) {
        let updated_pm = self.push_manager.with_address(address);
        self.pipeline.set_push_manager(updated_pm.clone());
        self.push_manager = updated_pm;
    }
}
/// Step-by-step constructor for a `MultiplexedConnection`; options left unset fall back to
/// defaults when `build` is called.
pub struct MultiplexedConnectionBuilder {
    /// Pipeline handle the connection will use.
    pipeline: Pipeline<bytes::Bytes>,
    /// Database index; defaults to `0`.
    db: Option<i64>,
    /// Response timeout; defaults to `DEFAULT_CONNECTION_ATTEMPT_TIMEOUT`.
    response_timeout: Option<Duration>,
    /// Push manager; defaults to `PushManager::default()`.
    push_manager: Option<PushManager>,
    /// Protocol version; defaults to `ProtocolVersion::default()`.
    protocol: Option<ProtocolVersion>,
    /// Password to record on the connection.
    password: Option<String>,
    /// Availability zone to record on the connection.
    availability_zone: Option<String>,
}
impl MultiplexedConnectionBuilder {
    /// Starts a builder around `pipeline` with every optional setting unset.
    pub(crate) fn new(pipeline: Pipeline<bytes::Bytes>) -> Self {
        Self {
            pipeline,
            db: None,
            response_timeout: None,
            push_manager: None,
            protocol: None,
            password: None,
            availability_zone: None,
        }
    }

    /// Sets the database index.
    pub fn with_db(self, db: i64) -> Self {
        Self {
            db: Some(db),
            ..self
        }
    }

    /// Sets the per-request response timeout.
    pub fn with_response_timeout(self, timeout: Duration) -> Self {
        Self {
            response_timeout: Some(timeout),
            ..self
        }
    }

    /// Sets the push manager.
    pub fn with_push_manager(self, push_manager: PushManager) -> Self {
        Self {
            push_manager: Some(push_manager),
            ..self
        }
    }

    /// Sets the protocol version.
    pub fn with_protocol(self, protocol: ProtocolVersion) -> Self {
        Self {
            protocol: Some(protocol),
            ..self
        }
    }

    /// Sets (or clears, with `None`) the password.
    pub fn with_password(self, password: Option<String>) -> Self {
        Self { password, ..self }
    }

    /// Sets (or clears, with `None`) the availability zone.
    pub fn with_availability_zone(self, az: Option<String>) -> Self {
        Self {
            availability_zone: az,
            ..self
        }
    }

    /// Assembles the connection, filling unset options with their defaults. This never
    /// fails.
    pub async fn build(self) -> Result<MultiplexedConnection> {
        Ok(MultiplexedConnection {
            pipeline: self.pipeline,
            db: self.db.unwrap_or_default(),
            response_timeout: self
                .response_timeout
                .unwrap_or(DEFAULT_CONNECTION_ATTEMPT_TIMEOUT),
            push_manager: self.push_manager.unwrap_or_default(),
            protocol: self.protocol.unwrap_or_default(),
            password: self.password,
            availability_zone: self.availability_zone,
        })
    }
}
impl ConnectionLike for MultiplexedConnection {
    /// Delegates to the inherent `send_packed_command`.
    async fn req_packed_command(&mut self, cmd: &Cmd) -> Result<Value> {
        self.send_packed_command(cmd).await
    }

    /// Sends already-packed bytes as a single command.
    ///
    /// On a dropped connection (non-RESP2 only), a `Disconnection` push is emitted.
    async fn send_packed_bytes(
        &mut self,
        packed: bytes::Bytes,
        is_fenced: bool,
    ) -> Result<Value> {
        let timeout = self.response_timeout;
        let outcome = self.pipeline.send_single(packed, timeout, is_fenced).await;
        let dropped = self.protocol != ProtocolVersion::RESP2
            && matches!(&outcome, Err(err) if err.is_connection_dropped());
        if dropped {
            self.push_manager.try_send_raw(&Value::Push {
                kind: PushKind::Disconnection,
                data: vec![],
            });
        }
        outcome
    }

    /// Delegates to the inherent `send_packed_commands`; the retry strategy is ignored.
    async fn req_packed_commands(
        &mut self,
        cmd: &crate::pipeline::Pipeline,
        offset: usize,
        count: usize,
        _pipeline_retry_strategy: Option<PipelineRetryStrategy>,
    ) -> Result<Vec<Result<Value>>> {
        self.send_packed_commands(cmd, offset, count).await
    }

    /// Returns the configured database index.
    fn get_db(&self) -> i64 {
        self.db
    }

    /// Reports whether the driver flagged the pipeline as closed.
    fn is_closed(&self) -> bool {
        self.pipeline.is_closed()
    }

    /// Returns the recorded availability zone, if any.
    fn get_az(&self) -> Option<String> {
        self.availability_zone.as_ref().cloned()
    }

    /// Records (or clears) the availability zone.
    fn set_az(&mut self, az: Option<String>) {
        self.availability_zone = az;
    }

    /// Delegates to the inherent method of the same name.
    fn update_push_manager_node_address(&mut self, address: String) {
        MultiplexedConnection::update_push_manager_node_address(self, address);
    }
}
impl MultiplexedConnection {
pub async fn subscribe(&mut self, channel_name: String) -> Result<()> {
if self.protocol == ProtocolVersion::RESP2 {
return Err(Error::from((
crate::value::ErrorKind::InvalidClientConfig,
"RESP3 is required for this command",
)));
}
let mut cmd = cmd("SUBSCRIBE");
cmd.arg(channel_name.clone());
cmd.query_async::<_, ()>(self).await?;
Ok(())
}
pub async fn unsubscribe(&mut self, channel_name: String) -> Result<()> {
if self.protocol == ProtocolVersion::RESP2 {
return Err(Error::from((
crate::value::ErrorKind::InvalidClientConfig,
"RESP3 is required for this command",
)));
}
let mut cmd = cmd("UNSUBSCRIBE");
cmd.arg(channel_name);
cmd.query_async::<_, ()>(self).await?;
Ok(())
}
pub async fn psubscribe(&mut self, channel_pattern: String) -> Result<()> {
if self.protocol == ProtocolVersion::RESP2 {
return Err(Error::from((
crate::value::ErrorKind::InvalidClientConfig,
"RESP3 is required for this command",
)));
}
let mut cmd = cmd("PSUBSCRIBE");
cmd.arg(channel_pattern.clone());
cmd.query_async::<_, ()>(self).await?;
Ok(())
}
pub async fn punsubscribe(&mut self, channel_pattern: String) -> Result<()> {
if self.protocol == ProtocolVersion::RESP2 {
return Err(Error::from((
crate::value::ErrorKind::InvalidClientConfig,
"RESP3 is required for this command",
)));
}
let mut cmd = cmd("PUNSUBSCRIBE");
cmd.arg(channel_pattern);
cmd.query_async::<_, ()>(self).await?;
Ok(())
}
pub fn get_push_manager(&self) -> PushManager {
self.push_manager.clone()
}
}