use super::{AsyncPushSender, ConnectionLike, Runtime, SharedHandleContainer, TaskHandle};
#[cfg(feature = "cache-aio")]
use crate::caching::{CacheManager, CacheStatistics, PrepareCacheResult};
use crate::{
AsyncConnectionConfig, ProtocolVersion, PushInfo, RedisConnectionInfo, ServerError,
ToRedisArgs,
aio::setup_connection,
check_resp3, cmd,
cmd::Cmd,
errors::{RedisError, closed_connection_error},
parser::ValueCodec,
types::{RedisFuture, RedisResult, Value},
};
use ::tokio::{
io::{AsyncRead, AsyncWrite},
sync::{mpsc, oneshot},
};
#[cfg(feature = "token-based-authentication")]
use {
crate::errors::ErrorKind,
arcstr::ArcStr,
log::{debug, error},
std::sync::atomic::{AtomicBool, Ordering},
};
use futures_util::{
future::{Future, FutureExt},
ready,
sink::Sink,
stream::{self, Stream, StreamExt},
};
use pin_project_lite::pin_project;
use std::collections::VecDeque;
use std::fmt;
use std::fmt::Debug;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{self, Poll};
use std::time::Duration;
use tokio_util::codec::Decoder;
/// Channel on which the result of a single request is delivered back to the caller.
type PipelineOutput = oneshot::Sender<RedisResult<Value>>;

/// Error state accumulated while aggregating the responses of a pipeline request.
enum ErrorOrErrors {
    // Per-command server errors, paired with the command's index inside the
    // pipeline. An empty vector means no error has been seen so far.
    Errors(Vec<(usize, ServerError)>),
    // A fatal `RedisError`; once set it takes precedence over collected
    // server errors and the first such error wins.
    FirstError(RedisError),
}
/// How incoming responses are matched to an in-flight request.
enum ResponseAggregate {
    // A single command: the first response completes the request.
    SingleCommand,
    // A pipeline (possibly a transaction): responses are buffered until the
    // expected count has been collected.
    Pipeline {
        buffer: Vec<Value>,
        error_or_errors: ErrorOrErrors,
        expectation: PipelineResponseExpectation,
    },
}
/// The number of responses to skip and to collect for a pipeline request,
/// plus bookkeeping of how many have been observed so far.
struct PipelineResponseExpectation {
    // Leading responses to discard before buffering starts (e.g. the `offset`
    // of a pipeline, or `MULTI`/queued replies of a transaction).
    skipped_response_count: usize,
    // Responses to buffer and hand back to the caller.
    expected_response_count: usize,
    // Whether this request is a transaction (affects error collection while skipping).
    is_transaction: bool,
    // Total responses seen so far for this request, including skipped ones.
    seen_responses: usize,
}
impl ResponseAggregate {
    /// Builds the aggregation state for a request: a `Pipeline` aggregate with
    /// empty buffers when an expectation is supplied, otherwise `SingleCommand`.
    fn new(expectation: Option<PipelineResponseExpectation>) -> Self {
        expectation.map_or(ResponseAggregate::SingleCommand, |expectation| {
            ResponseAggregate::Pipeline {
                buffer: Vec::new(),
                error_or_errors: ErrorOrErrors::Errors(Vec::new()),
                expectation,
            }
        })
    }
}
/// A request that has been written to the connection and is awaiting its response(s).
struct InFlight {
    // `None` when the caller does not want the response (fire-and-forget).
    output: Option<PipelineOutput>,
    response_aggregate: ResponseAggregate,
}
/// A request message sent from a connection handle to the pipeline task.
struct PipelineMessage {
    // Already-encoded command bytes to write to the connection.
    input: Vec<u8>,
    // Channel on which to deliver the result; `None` for no-response commands.
    output: Option<PipelineOutput>,
    // `Some` for pipelines/transactions, `None` for a single command.
    expectation: Option<PipelineResponseExpectation>,
}
/// Cheap-to-clone handle that submits requests to the shared pipeline task
/// over a bounded channel.
#[derive(Clone)]
struct Pipeline {
    sender: mpsc::Sender<PipelineMessage>,
}
impl Debug for Pipeline {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("Pipeline").field(&self.sender).finish()
}
}
// Sink side of the pipeline task: writes requests to the underlying
// connection and matches incoming responses to in-flight requests. With
// `cache-aio` it additionally carries the cache manager so invalidation push
// messages can update the client-side cache.
#[cfg(feature = "cache-aio")]
pin_project! {
    struct PipelineSink<T> {
        #[pin]
        sink_stream: T,
        // Requests written but not yet fully answered, in FIFO order.
        in_flight: VecDeque<InFlight>,
        // A sink error observed in `poll_ready`, delivered on the next `start_send`.
        error: Option<RedisError>,
        // Receiver for RESP3 push messages, if the user registered one.
        push_sender: Option<Arc<dyn AsyncPushSender>>,
        cache_manager: Option<CacheManager>,
    }
}
// Same sink as above, minus the cache manager, for builds without `cache-aio`.
#[cfg(not(feature = "cache-aio"))]
pin_project! {
    struct PipelineSink<T> {
        #[pin]
        sink_stream: T,
        // Requests written but not yet fully answered, in FIFO order.
        in_flight: VecDeque<InFlight>,
        // A sink error observed in `poll_ready`, delivered on the next `start_send`.
        error: Option<RedisError>,
        // Receiver for RESP3 push messages, if the user registered one.
        push_sender: Option<Arc<dyn AsyncPushSender>>,
    }
}
/// Forwards `info` to the configured push sender, if any. Send failures
/// (e.g. a dropped receiver) are deliberately ignored.
fn send_push(push_sender: &Option<Arc<dyn AsyncPushSender>>, info: PushInfo) {
    match push_sender {
        Some(sender) => {
            _ = sender.send(info);
        }
        None => {}
    }
}
/// Notifies the push sender (if any) that the connection was disconnected.
pub(crate) fn send_disconnect(push_sender: &Option<Arc<dyn AsyncPushSender>>) {
    send_push(push_sender, PushInfo::disconnect());
}
impl<T> PipelineSink<T>
where
    T: Stream<Item = RedisResult<Value>> + 'static,
{
    /// Creates a sink around `sink_stream` with no in-flight requests and no
    /// pending error.
    fn new(
        sink_stream: T,
        push_sender: Option<Arc<dyn AsyncPushSender>>,
        #[cfg(feature = "cache-aio")] cache_manager: Option<CacheManager>,
    ) -> Self
    where
        T: Sink<Vec<u8>, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
    {
        PipelineSink {
            sink_stream,
            in_flight: VecDeque::new(),
            error: None,
            push_sender,
            #[cfg(feature = "cache-aio")]
            cache_manager,
        }
    }

    /// Drains decoded values from the stream, routing each one via
    /// `send_result`. Returns `Ready(Err(()))` once an unrecoverable error is
    /// seen (after notifying the push sender of the disconnect); otherwise
    /// stays `Pending` until more data arrives.
    fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Result<(), ()>> {
        loop {
            let item = ready!(self.as_mut().project().sink_stream.poll_next(cx));
            let item = match item {
                Some(result) => result,
                // Stream exhausted: treat as a closed connection.
                None => Err(closed_connection_error()),
            };
            let is_unrecoverable = item.as_ref().is_err_and(|err| err.is_unrecoverable_error());
            self.as_mut().send_result(item);
            if is_unrecoverable {
                let self_ = self.project();
                send_disconnect(self_.push_sender);
                return Poll::Ready(Err(()));
            }
        }
    }

    /// Routes one decoded value: out-of-band push messages go to the push
    /// sender (and, with caching, to invalidation handling); everything else
    /// is matched against the oldest in-flight request.
    fn send_result(self: Pin<&mut Self>, result: RedisResult<Value>) {
        let self_ = self.project();
        let result = match result {
            // Pure out-of-band push (no pending reply): deliver to the push
            // sender and stop — it does not consume an in-flight request.
            Ok(Value::Push { kind, data }) if !kind.has_reply() => {
                #[cfg(feature = "cache-aio")]
                if let Some(cache_manager) = &self_.cache_manager {
                    cache_manager.handle_push_value(&kind, &data);
                }
                send_push(self_.push_sender, PushInfo { kind, data });
                return;
            }
            // Push that doubles as a reply (e.g. subscription confirmations):
            // clone it to the push sender AND pass it on as the response.
            Ok(Value::Push { kind, data }) if kind.has_reply() => {
                send_push(
                    self_.push_sender,
                    PushInfo {
                        kind: kind.clone(),
                        data: data.clone(),
                    },
                );
                Ok(Value::Push { kind, data })
            }
            _ => result,
        };
        let mut entry = match self_.in_flight.pop_front() {
            Some(entry) => entry,
            // A response with no matching request is silently dropped.
            None => return,
        };
        match &mut entry.response_aggregate {
            ResponseAggregate::SingleCommand => {
                if let Some(output) = entry.output.take() {
                    _ = output.send(result);
                }
            }
            ResponseAggregate::Pipeline {
                buffer,
                error_or_errors,
                expectation:
                    PipelineResponseExpectation {
                        expected_response_count,
                        skipped_response_count,
                        is_transaction,
                        seen_responses,
                    },
            } => {
                *seen_responses += 1;
                if *skipped_response_count > 0 {
                    // In a transaction, errors on the skipped (queued) replies
                    // must still be recorded so `EXEC` can report them.
                    if *is_transaction {
                        if let ErrorOrErrors::Errors(errs) = error_or_errors {
                            match result {
                                Ok(Value::ServerError(err)) => {
                                    // `seen_responses` also counts the `MULTI`
                                    // reply, hence `- 2` for the command's
                                    // zero-based index. NOTE(review): assumes
                                    // the first skipped reply (MULTI's OK)
                                    // never arrives as a ServerError, or this
                                    // would underflow — confirm.
                                    errs.push((*seen_responses - 2, err));
                                }
                                Err(err) => *error_or_errors = ErrorOrErrors::FirstError(err),
                                _ => {}
                            }
                        }
                    }
                    *skipped_response_count -= 1;
                    // Still waiting for more responses; requeue at the front.
                    self_.in_flight.push_front(entry);
                    return;
                }
                match result {
                    Ok(item) => {
                        buffer.push(item);
                    }
                    Err(err) => {
                        // Only the first fatal error is kept.
                        if matches!(error_or_errors, ErrorOrErrors::Errors(_)) {
                            *error_or_errors = ErrorOrErrors::FirstError(err)
                        }
                    }
                }
                if buffer.len() < *expected_response_count {
                    // The pipeline hasn't received all of its responses yet;
                    // requeue and wait for the rest.
                    self_.in_flight.push_front(entry);
                    return;
                }
                // All responses are in: resolve to the buffered array, an
                // aborted-transaction error, or the first fatal error.
                let response =
                    match std::mem::replace(error_or_errors, ErrorOrErrors::Errors(Vec::new())) {
                        ErrorOrErrors::Errors(errors) => {
                            if errors.is_empty() {
                                Ok(Value::Array(std::mem::take(buffer)))
                            } else {
                                Err(RedisError::make_aborted_transaction(errors))
                            }
                        }
                        ErrorOrErrors::FirstError(redis_error) => Err(redis_error),
                    };
                if let Some(output) = entry.output.take() {
                    _ = output.send(response);
                }
            }
        }
    }
}
impl<T> Sink<PipelineMessage> for PipelineSink<T>
where
    T: Sink<Vec<u8>, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
{
    // Errors are reported to the requester via its oneshot channel, so the
    // sink's own error type carries no information.
    type Error = ();
    fn poll_ready(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context,
    ) -> Poll<Result<(), Self::Error>> {
        match ready!(self.as_mut().project().sink_stream.poll_ready(cx)) {
            Ok(()) => Ok(()).into(),
            Err(err) => {
                // Claim readiness anyway and stash the error; it is delivered
                // to the requester on the next `start_send`.
                *self.project().error = Some(err);
                Ok(()).into()
            }
        }
    }
    fn start_send(
        mut self: Pin<&mut Self>,
        PipelineMessage {
            input,
            mut output,
            expectation,
        }: PipelineMessage,
    ) -> Result<(), Self::Error> {
        // The requester already gave up (receiver dropped, e.g. on timeout):
        // skip writing the command entirely.
        if output.as_ref().is_some_and(|output| output.is_closed()) {
            return Ok(());
        }
        let self_ = self.as_mut().project();
        // Deliver an error stashed by `poll_ready` to this requester.
        if let Some(err) = self_.error.take() {
            if let Some(output) = output.take() {
                _ = output.send(Err(err));
            }
            return Err(());
        }
        match self_.sink_stream.start_send(input) {
            Ok(()) => {
                // Written successfully: register the request as in-flight so
                // incoming responses can be matched to it.
                let response_aggregate = ResponseAggregate::new(expectation);
                let entry = InFlight {
                    output,
                    response_aggregate,
                };
                self_.in_flight.push_back(entry);
                Ok(())
            }
            Err(err) => {
                if let Some(output) = output.take() {
                    _ = output.send(Err(err));
                }
                Err(())
            }
        }
    }
    fn poll_flush(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context,
    ) -> Poll<Result<(), Self::Error>> {
        ready!(
            self.as_mut()
                .project()
                .sink_stream
                .poll_flush(cx)
                .map_err(|err| {
                    // A flush failure is routed to the oldest in-flight request.
                    self.as_mut().send_result(Err(err));
                })
        )?;
        // Flushing is also the point where responses are read and dispatched.
        self.poll_read(cx)
    }
    fn poll_close(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context,
    ) -> Poll<Result<(), Self::Error>> {
        // Drain outstanding requests before closing the underlying stream.
        if !self.in_flight.is_empty() {
            ready!(self.as_mut().poll_flush(cx))?;
        }
        let this = self.as_mut().project();
        this.sink_stream.poll_close(cx).map_err(|err| {
            self.send_result(Err(err));
        })
    }
}
impl Pipeline {
    // Default bound of the request channel between connection handles and the
    // pipeline task.
    const DEFAULT_BUFFER_SIZE: usize = 50;

    /// Resolves a user-supplied buffer size, falling back to the default.
    fn resolve_buffer_size(size: Option<usize>) -> usize {
        size.unwrap_or(Self::DEFAULT_BUFFER_SIZE)
    }

    /// Creates a pipeline handle together with the driver future that
    /// services it: the driver forwards queued `PipelineMessage`s into a
    /// `PipelineSink` and must be spawned for the connection to make progress.
    fn new<T>(
        sink_stream: T,
        push_sender: Option<Arc<dyn AsyncPushSender>>,
        #[cfg(feature = "cache-aio")] cache_manager: Option<CacheManager>,
        buffer_size: usize,
    ) -> (Self, impl Future<Output = ()>)
    where
        T: Sink<Vec<u8>, Error = RedisError>,
        T: Stream<Item = RedisResult<Value>>,
        T: Unpin + Send + 'static,
    {
        let (sender, mut receiver) = mpsc::channel(buffer_size);
        let sink = PipelineSink::new(
            sink_stream,
            push_sender,
            #[cfg(feature = "cache-aio")]
            cache_manager,
        );
        let f = stream::poll_fn(move |cx| receiver.poll_recv(cx))
            .map(Ok)
            .forward(sink)
            .map(|_| ());
        (Pipeline { sender }, f)
    }

    /// Sends `input` over the pipeline and awaits the (aggregated) response,
    /// optionally bounded by `timeout`.
    ///
    /// Internally errors are `Option<RedisError>`: `None` encodes "the
    /// pipeline channel or response channel closed", which is converted to a
    /// closed-connection error at the end.
    async fn send_recv(
        &mut self,
        input: Vec<u8>,
        // `Some` for a pipeline/transaction whose responses are aggregated;
        // `None` for a single command.
        expectation: Option<PipelineResponseExpectation>,
        timeout: Option<Duration>,
        // `true` for fire-and-forget commands: resolve to `Nil` once queued.
        skip_response: bool,
    ) -> Result<Value, RedisError> {
        if input.is_empty() {
            return Err(RedisError::make_empty_command());
        }
        let request = async {
            if skip_response {
                self.sender
                    .send(PipelineMessage {
                        input,
                        expectation,
                        output: None,
                    })
                    .await
                    .map_err(|_| None)?;
                return Ok(Value::Nil);
            }
            let (sender, receiver) = oneshot::channel();
            self.sender
                .send(PipelineMessage {
                    input,
                    expectation,
                    output: Some(sender),
                })
                .await
                .map_err(|_| None)?;
            receiver.await
                .map_err(|_| None)
                .and_then(|res| res.map_err(Some))
        };
        match timeout {
            Some(timeout) => match Runtime::locate().timeout(timeout, request).await {
                Ok(res) => res,
                Err(elapsed) => Err(Some(elapsed.into())),
            },
            None => request.await,
        }
        .map_err(|err| err.unwrap_or_else(closed_connection_error))
    }
}
/// A cloneable connection handle: all clones share one underlying connection,
/// multiplexing their requests through a single pipeline task.
#[derive(Clone)]
pub struct MultiplexedConnection {
    pipeline: Pipeline,
    // Database index this connection was set up with.
    db: i64,
    // Per-request timeout; `None` means wait indefinitely.
    response_timeout: Option<Duration>,
    protocol: ProtocolVersion,
    // Optional cap on concurrently outstanding requests across all clones.
    concurrency_limiter: Option<Arc<async_lock::Semaphore>>,
    // Keeps the driver task alive; aborted when the last clone drops.
    _task_handle: Option<SharedHandleContainer>,
    #[cfg(feature = "cache-aio")]
    pub(crate) cache_manager: Option<CacheManager>,
    #[cfg(feature = "token-based-authentication")]
    _credentials_subscription_task_handle: Option<SharedHandleContainer>,
    // Set once re-authentication fails or the credentials stream ends; makes
    // all subsequent requests fail fast.
    #[cfg(feature = "token-based-authentication")]
    re_authentication_failed: Arc<AtomicBool>,
}
impl Debug for MultiplexedConnection {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Exhaustive destructuring (rather than direct field access) makes
        // this impl fail to compile when a field is added, forcing a decision
        // about whether the new field belongs in the debug output.
        let MultiplexedConnection {
            pipeline,
            db,
            response_timeout,
            protocol,
            concurrency_limiter: _,
            _task_handle,
            #[cfg(feature = "cache-aio")]
            cache_manager: _,
            #[cfg(feature = "token-based-authentication")]
            _credentials_subscription_task_handle: _,
            #[cfg(feature = "token-based-authentication")]
            re_authentication_failed: _,
        } = self;
        f.debug_struct("MultiplexedConnection")
            .field("pipeline", &pipeline)
            .field("db", &db)
            .field("response_timeout", &response_timeout)
            .field("protocol", &protocol)
            .finish()
    }
}
impl MultiplexedConnection {
    /// Constructs a new `MultiplexedConnection` from connection info and an
    /// async stream, using the default configuration. Returns the connection
    /// together with the driver future that must be spawned to pump requests.
    pub async fn new<C>(
        connection_info: &RedisConnectionInfo,
        stream: C,
    ) -> RedisResult<(Self, impl Future<Output = ()>)>
    where
        C: Unpin + AsyncRead + AsyncWrite + Send + 'static,
    {
        Self::new_with_config(connection_info, stream, AsyncConnectionConfig::default()).await
    }

    /// Like [`MultiplexedConnection::new`], but with an explicit
    /// [`AsyncConnectionConfig`] (push sender, caching, credentials provider,
    /// timeouts, concurrency limit, pipeline buffer size).
    pub async fn new_with_config<C>(
        connection_info: &RedisConnectionInfo,
        stream: C,
        config: AsyncConnectionConfig,
    ) -> RedisResult<(Self, impl Future<Output = ()> + 'static)>
    where
        C: Unpin + AsyncRead + AsyncWrite + Send + 'static,
    {
        let mut codec = ValueCodec::default().framed(stream);
        // Push messages are a RESP3 feature; reject a misconfiguration early.
        if config.push_sender.is_some() {
            check_resp3!(
                connection_info.protocol,
                "Can only pass push sender to a connection using RESP3"
            );
        }
        #[cfg(feature = "cache-aio")]
        let cache_config = config.cache.as_ref().map(|cache| match cache {
            crate::client::Cache::Config(cache_config) => *cache_config,
            #[cfg(any(feature = "connection-manager", feature = "cluster-async"))]
            crate::client::Cache::Manager(cache_manager) => cache_manager.cache_config,
        });
        #[cfg(feature = "cache-aio")]
        let cache_manager_opt = config
            .cache
            .map(|cache| {
                // Client-side caching relies on RESP3 invalidation pushes.
                check_resp3!(
                    connection_info.protocol,
                    "Can only enable client side caching in a connection using RESP3"
                );
                match cache {
                    crate::client::Cache::Config(cache_config) => {
                        Ok(CacheManager::new(cache_config))
                    }
                    #[cfg(any(feature = "connection-manager", feature = "cluster-async"))]
                    crate::client::Cache::Manager(cache_manager) => Ok(cache_manager),
                }
            })
            .transpose()?;
        // With token-based auth the credentials may be overwritten below,
        // so the clone must be mutable only in that configuration.
        #[cfg(feature = "token-based-authentication")]
        let mut connection_info = connection_info.clone();
        #[cfg(not(feature = "token-based-authentication"))]
        let connection_info = connection_info.clone();
        // Block until the provider yields initial credentials, so the
        // handshake below authenticates with up-to-date values.
        #[cfg(feature = "token-based-authentication")]
        if let Some(ref credentials_provider) = config.credentials_provider {
            match credentials_provider.subscribe().next().await {
                Some(Ok(credentials)) => {
                    connection_info.username = Some(ArcStr::from(credentials.username));
                    connection_info.password = Some(ArcStr::from(credentials.password));
                }
                Some(Err(err)) => {
                    error!("Error while receiving credentials from stream: {err}");
                    return Err(err);
                }
                None => {
                    let err = RedisError::from((
                        ErrorKind::AuthenticationFailed,
                        "Credentials stream closed unexpectedly before yielding credentials!",
                    ));
                    error!("{err}");
                    return Err(err);
                }
            }
        }
        setup_connection(
            &mut codec,
            &connection_info,
            #[cfg(feature = "cache-aio")]
            cache_config,
        )
        .await?;
        // NOTE(review): this repeats the RESP3/push-sender check performed
        // before the handshake above — presumably redundant; confirm before
        // removing.
        if config.push_sender.is_some() {
            check_resp3!(
                connection_info.protocol,
                "Can only pass push sender to a connection using RESP3"
            );
        }
        let (pipeline, driver) = Pipeline::new(
            codec,
            config.push_sender,
            #[cfg(feature = "cache-aio")]
            cache_manager_opt.clone(),
            Pipeline::resolve_buffer_size(config.pipeline_buffer_size),
        );
        let concurrency_limiter = config
            .concurrency_limit
            .map(|n| Arc::new(async_lock::Semaphore::new(n)));
        let con = MultiplexedConnection {
            pipeline,
            db: connection_info.db,
            response_timeout: config.response_timeout,
            protocol: connection_info.protocol,
            concurrency_limiter,
            _task_handle: None,
            #[cfg(feature = "cache-aio")]
            cache_manager: cache_manager_opt,
            #[cfg(feature = "token-based-authentication")]
            _credentials_subscription_task_handle: None,
            #[cfg(feature = "token-based-authentication")]
            re_authentication_failed: Arc::new(AtomicBool::new(false)),
        };
        // Spawn a background task that re-authenticates the connection every
        // time the credentials provider yields fresh credentials, and marks
        // the connection unusable once re-authentication can no longer succeed.
        #[cfg(feature = "token-based-authentication")]
        if let Some(streaming_provider) = config.credentials_provider {
            let mut inner_connection = con.clone();
            let re_authentication_failed_arc = Arc::clone(&con.re_authentication_failed);
            let mut stream = streaming_provider.subscribe();
            let subscription_task_handle = Runtime::locate().spawn(async move {
                let mut error_cause_logged = false;
                while let Some(result) = stream.next().await {
                    match result {
                        Ok(credentials) => {
                            if let Err(err) = inner_connection
                                .re_authenticate_with_credentials(&credentials)
                                .await
                            {
                                // A failed AUTH poisons the connection.
                                error!("Failed to re-authenticate async connection: {err}.");
                                error_cause_logged = true;
                                re_authentication_failed_arc.store(true, Ordering::Relaxed);
                                break;
                            } else {
                                debug!("Re-authenticated async connection");
                            }
                        }
                        Err(err) => {
                            error!("Credentials stream error for async connection: {err}.");
                            error_cause_logged = true;
                        }
                    }
                }
                // The stream ending for any reason also disables the connection.
                if !re_authentication_failed_arc.load(Ordering::Relaxed) {
                    if !error_cause_logged {
                        error!("Re-authentication stream ended unexpectedly.");
                    }
                    re_authentication_failed_arc.store(true, Ordering::Relaxed);
                }
            });
            return Ok((
                Self {
                    _credentials_subscription_task_handle: Some(SharedHandleContainer::new(
                        subscription_task_handle,
                    )),
                    ..con
                },
                driver,
            ));
        }
        Ok((con, driver))
    }

    /// Stores the handle of the task driving this connection so it is kept
    /// alive (and eventually cleaned up) together with the connection.
    pub(crate) fn set_task_handle(&mut self, handle: TaskHandle) {
        self._task_handle = Some(SharedHandleContainer::new(handle));
    }

    /// Sets the time the multiplexer will wait for responses before failing.
    pub fn set_response_timeout(&mut self, timeout: std::time::Duration) {
        self.response_timeout = Some(timeout);
    }

    /// Sends an already encoded (packed) command to the server and reads its
    /// single response.
    pub async fn send_packed_command(&mut self, cmd: &Cmd) -> RedisResult<Value> {
        // Acquire a concurrency permit unless the command opts out (e.g. the
        // re-authentication AUTH, which must not queue behind user traffic).
        let _permit = if cmd.skip_concurrency_limit {
            None
        } else if let Some(limiter) = &self.concurrency_limiter {
            Some(limiter.acquire().await)
        } else {
            None
        };
        // Fail fast once re-authentication is known to have failed.
        #[cfg(feature = "token-based-authentication")]
        if self.re_authentication_failed.load(Ordering::Relaxed) {
            return Err(RedisError::from((
                ErrorKind::AuthenticationFailed,
                "Connection is no longer usable due to re-authentication failure",
            )));
        }
        // Try the client-side cache first; on a miss, the command is rewritten
        // into a pipeline that also refreshes the cache.
        #[cfg(feature = "cache-aio")]
        if let Some(cache_manager) = &self.cache_manager {
            match cache_manager.get_cached_cmd(cmd) {
                PrepareCacheResult::Cached(value) => return Ok(value),
                PrepareCacheResult::NotCached(cacheable_command) => {
                    let mut pipeline = crate::Pipeline::new();
                    cacheable_command.pack_command(cache_manager, &mut pipeline);
                    let result = self
                        .pipeline
                        .send_recv(
                            pipeline.get_packed_pipeline(),
                            Some(PipelineResponseExpectation {
                                skipped_response_count: 0,
                                expected_response_count: pipeline.commands.len(),
                                is_transaction: false,
                                seen_responses: 0,
                            }),
                            self.response_timeout,
                            cmd.is_no_response(),
                        )
                        .await?;
                    let replies: Vec<Value> = crate::types::from_redis_value(result)?;
                    return cacheable_command.resolve(cache_manager, replies.into_iter());
                }
                _ => (),
            }
        }
        self.pipeline
            .send_recv(
                cmd.get_packed_command(),
                None,
                self.response_timeout,
                cmd.is_no_response(),
            )
            .await
    }

    /// Sends an already encoded (packed) pipeline, skipping the first
    /// `offset` responses and returning the next `count`.
    pub async fn send_packed_commands(
        &mut self,
        cmd: &crate::Pipeline,
        offset: usize,
        count: usize,
    ) -> RedisResult<Vec<Value>> {
        // Best-effort permit acquisition: block for the first permit only,
        // then take as many extra permits as are immediately available — the
        // pipeline proceeds even when it cannot hold `count` permits.
        let _permits = if let Some(limiter) = &self.concurrency_limiter {
            let mut permits = Vec::with_capacity(count.max(1));
            permits.push(limiter.acquire().await);
            for _ in 1..count {
                match limiter.try_acquire() {
                    Some(permit) => permits.push(permit),
                    None => break,
                }
            }
            permits
        } else {
            Vec::new()
        };
        // Fail fast once re-authentication is known to have failed.
        #[cfg(feature = "token-based-authentication")]
        if self.re_authentication_failed.load(Ordering::Relaxed) {
            return Err(RedisError::from((
                ErrorKind::AuthenticationFailed,
                "Connection is no longer usable due to re-authentication failure",
            )));
        }
        // With caching, only the non-cached subset of the pipeline is sent.
        #[cfg(feature = "cache-aio")]
        if let Some(cache_manager) = &self.cache_manager {
            let (cacheable_pipeline, pipeline, (skipped_response_count, expected_response_count)) =
                cache_manager.get_cached_pipeline(cmd);
            // Everything was served from the cache; nothing to send.
            if pipeline.is_empty() {
                return cacheable_pipeline.resolve(cache_manager, Value::Array(Vec::new()));
            }
            let result = self
                .pipeline
                .send_recv(
                    pipeline.get_packed_pipeline(),
                    Some(PipelineResponseExpectation {
                        skipped_response_count,
                        expected_response_count,
                        is_transaction: cacheable_pipeline.transaction_mode,
                        seen_responses: 0,
                    }),
                    self.response_timeout,
                    false,
                )
                .await?;
            return cacheable_pipeline.resolve(cache_manager, result);
        }
        let value = self
            .pipeline
            .send_recv(
                cmd.get_packed_pipeline(),
                Some(PipelineResponseExpectation {
                    skipped_response_count: offset,
                    expected_response_count: count,
                    is_transaction: cmd.is_transaction(),
                    seen_responses: 0,
                }),
                self.response_timeout,
                false,
            )
            .await?;
        match value {
            Value::Array(values) => Ok(values),
            // A single non-array value is wrapped for a uniform return type.
            _ => Ok(vec![value]),
        }
    }

    /// Returns [`CacheStatistics`] for this connection if caching is enabled.
    #[cfg(feature = "cache-aio")]
    #[cfg_attr(docsrs, doc(cfg(feature = "cache-aio")))]
    pub fn get_cache_statistics(&self) -> Option<CacheStatistics> {
        self.cache_manager.as_ref().map(|cm| cm.statistics())
    }
}
impl ConnectionLike for MultiplexedConnection {
    /// Boxes the lazy future returned by [`MultiplexedConnection::send_packed_command`].
    fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> {
        self.send_packed_command(cmd).boxed()
    }

    /// Boxes the lazy future returned by [`MultiplexedConnection::send_packed_commands`].
    fn req_packed_commands<'a>(
        &'a mut self,
        cmd: &'a crate::Pipeline,
        offset: usize,
        count: usize,
    ) -> RedisFuture<'a, Vec<Value>> {
        self.send_packed_commands(cmd, offset, count).boxed()
    }

    /// Returns the database index this connection was configured with.
    fn get_db(&self) -> i64 {
        self.db
    }
}
impl MultiplexedConnection {
    /// Shared implementation of the four subscription methods below: checks
    /// for RESP3, then executes `command_name` with a single argument.
    async fn subscription_command(
        &mut self,
        command_name: &str,
        arg: impl ToRedisArgs,
    ) -> RedisResult<()> {
        check_resp3!(self.protocol);
        let mut command = cmd(command_name);
        command.arg(arg);
        command.exec_async(self).await?;
        Ok(())
    }

    /// Subscribes to the given channel(s). Requires RESP3.
    pub async fn subscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
        self.subscription_command("SUBSCRIBE", channel_name).await
    }

    /// Unsubscribes from the given channel(s). Requires RESP3.
    pub async fn unsubscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
        self.subscription_command("UNSUBSCRIBE", channel_name).await
    }

    /// Subscribes to the given channel pattern(s). Requires RESP3.
    pub async fn psubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
        self.subscription_command("PSUBSCRIBE", channel_pattern).await
    }

    /// Unsubscribes from the given channel pattern(s). Requires RESP3.
    pub async fn punsubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
        self.subscription_command("PUNSUBSCRIBE", channel_pattern).await
    }
}
#[cfg(feature = "token-based-authentication")]
impl MultiplexedConnection {
    /// Re-issues `AUTH` on this connection with the given credentials.
    /// Bypasses the concurrency limiter so the command cannot queue behind
    /// user traffic, and converts an in-band server error into `Err`.
    async fn re_authenticate_with_credentials(
        &mut self,
        credentials: &crate::auth::BasicAuth,
    ) -> RedisResult<()> {
        let mut auth_cmd =
            crate::connection::authenticate_cmd(Some(&credentials.username), &credentials.password);
        auth_cmd.skip_concurrency_limit = true;
        self.send_packed_command(&auth_cmd)
            .await?
            .extract_error()
            .map(|_| ())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pipeline_resolve_buffer_size_default() {
        assert_eq!(Pipeline::resolve_buffer_size(None), 50);
    }

    #[test]
    fn test_pipeline_resolve_buffer_size_custom() {
        assert_eq!(Pipeline::resolve_buffer_size(Some(100)), 100);
    }

    // Minimal connection info for the mock server; `skip_set_lib_name`
    // presumably suppresses an extra setup command during the handshake —
    // keeps the mock exchange minimal.
    fn mock_conn_info() -> RedisConnectionInfo {
        RedisConnectionInfo {
            skip_set_lib_name: true,
            ..Default::default()
        }
    }

    // Builds a `MultiplexedConnection` over an in-memory duplex stream.
    // Returns the connection, a receiver that fires once per command the mock
    // server decodes, and a sender that makes the server emit one `+OK`.
    async fn create_mock_connection(
        concurrency_limit: usize,
    ) -> (
        MultiplexedConnection,
        tokio::sync::mpsc::Receiver<()>,
        tokio::sync::mpsc::Sender<()>,
    ) {
        use futures_util::StreamExt;
        use tokio::io::AsyncWriteExt;
        use tokio_util::codec::FramedRead;
        let (client_half, server_half) = tokio::io::duplex(4096);
        let (cmd_received_tx, cmd_received_rx) = tokio::sync::mpsc::channel::<()>(10);
        let (send_response_tx, mut send_response_rx) = tokio::sync::mpsc::channel::<()>(10);
        let (server_read, mut server_write) = tokio::io::split(server_half);
        // Reader task: signal every successfully decoded command.
        tokio::spawn(async move {
            let mut reader = FramedRead::new(server_read, ValueCodec::default());
            while let Some(Ok(_)) = reader.next().await {
                let _ = cmd_received_tx.send(()).await;
            }
        });
        // Writer task: emit one `+OK` per trigger message.
        tokio::spawn(async move {
            while send_response_rx.recv().await.is_some() {
                let _ = server_write.write_all(b"+OK\r\n").await;
                let _ = server_write.flush().await;
            }
        });
        let config = AsyncConnectionConfig::new()
            .set_concurrency_limit(concurrency_limit)
            .set_response_timeout(None)
            .set_connection_timeout(None);
        let (conn, driver) =
            MultiplexedConnection::new_with_config(&mock_conn_info(), client_half, config)
                .await
                .unwrap();
        tokio::spawn(driver);
        (conn, cmd_received_rx, send_response_tx)
    }

    // With a limit of 2, a third concurrent request must not reach the server
    // until one of the first two completes.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_concurrency_limit_enforced() {
        let (conn, mut cmd_received_rx, send_response_tx) = create_mock_connection(2).await;
        let h1 = tokio::spawn({
            let mut c = conn.clone();
            async move { c.send_packed_command(&cmd("PING")).await }
        });
        let h2 = tokio::spawn({
            let mut c = conn.clone();
            async move { c.send_packed_command(&cmd("PING")).await }
        });
        let h3 = tokio::spawn({
            let mut c = conn.clone();
            async move { c.send_packed_command(&cmd("PING")).await }
        });
        cmd_received_rx.recv().await.unwrap();
        cmd_received_rx.recv().await.unwrap();
        let third = tokio::time::timeout(Duration::from_millis(100), cmd_received_rx.recv()).await;
        assert!(
            third.is_err(),
            "3rd request should be blocked by concurrency limit"
        );
        // Complete one request; the third should now go through.
        send_response_tx.send(()).await.unwrap();
        cmd_received_rx.recv().await.unwrap();
        send_response_tx.send(()).await.unwrap();
        send_response_tx.send(()).await.unwrap();
        h1.await.unwrap().unwrap();
        h2.await.unwrap().unwrap();
        h3.await.unwrap().unwrap();
    }

    // A command flagged with `skip_concurrency_limit` must be sent even while
    // all permits are held.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_no_limit_bypasses_concurrency_limit() {
        let (conn, mut cmd_received_rx, send_response_tx) = create_mock_connection(1).await;
        let h1 = tokio::spawn({
            let mut c = conn.clone();
            async move { c.send_packed_command(&cmd("PING")).await }
        });
        cmd_received_rx.recv().await.unwrap();
        let h2 = tokio::spawn({
            let mut c = conn.clone();
            async move {
                let mut ping = cmd("PING");
                ping.skip_concurrency_limit = true;
                c.send_packed_command(&ping).await
            }
        });
        let received =
            tokio::time::timeout(Duration::from_millis(100), cmd_received_rx.recv()).await;
        assert!(
            received.is_ok(),
            "no_limit request should bypass concurrency limit"
        );
        send_response_tx.send(()).await.unwrap();
        send_response_tx.send(()).await.unwrap();
        h1.await.unwrap().unwrap();
        h2.await.unwrap().unwrap();
    }

    // A pipeline of 3 commands with limit 3 takes all permits, so a single
    // command is blocked until the pipeline completes.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_pipeline_acquires_multiple_permits() {
        let (conn, mut cmd_received_rx, send_response_tx) = create_mock_connection(3).await;
        let pipeline_handle = tokio::spawn({
            let mut c = conn.clone();
            async move {
                let mut pipe = crate::Pipeline::new();
                pipe.cmd("SET").arg("a").arg("1");
                pipe.cmd("SET").arg("b").arg("2");
                pipe.cmd("SET").arg("c").arg("3");
                c.send_packed_commands(&pipe, 0, 3).await
            }
        });
        for _ in 0..3 {
            cmd_received_rx.recv().await.unwrap();
        }
        let single_handle = tokio::spawn({
            let mut c = conn.clone();
            async move { c.send_packed_command(&cmd("PING")).await }
        });
        let blocked =
            tokio::time::timeout(Duration::from_millis(100), cmd_received_rx.recv()).await;
        assert!(
            blocked.is_err(),
            "single command should be blocked while pipeline holds all permits"
        );
        // Answer all pipeline commands, releasing the permits.
        for _ in 0..3 {
            send_response_tx.send(()).await.unwrap();
        }
        cmd_received_rx.recv().await.unwrap();
        send_response_tx.send(()).await.unwrap();
        pipeline_handle.await.unwrap().unwrap();
        single_handle.await.unwrap().unwrap();
    }

    // A pipeline only blocks for its first permit; with the rest unavailable
    // it must still be sent (best-effort permit acquisition).
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_pipeline_proceeds_with_partial_permits() {
        let (conn, mut cmd_received_rx, send_response_tx) = create_mock_connection(2).await;
        let single_handle = tokio::spawn({
            let mut c = conn.clone();
            async move { c.send_packed_command(&cmd("PING")).await }
        });
        cmd_received_rx.recv().await.unwrap();
        let pipeline_handle = tokio::spawn({
            let mut c = conn.clone();
            async move {
                let mut pipe = crate::Pipeline::new();
                for i in 0..5 {
                    pipe.cmd("SET").arg(format!("k{i}")).arg(i);
                }
                c.send_packed_commands(&pipe, 0, 5).await
            }
        });
        let received =
            tokio::time::timeout(Duration::from_millis(100), cmd_received_rx.recv()).await;
        assert!(
            received.is_ok(),
            "pipeline should proceed even with only partial permits"
        );
        for _ in 1..5 {
            cmd_received_rx.recv().await.unwrap();
        }
        // 1 single command + 5 pipeline commands = 6 responses.
        for _ in 0..6 {
            send_response_tx.send(()).await.unwrap();
        }
        single_handle.await.unwrap().unwrap();
        pipeline_handle.await.unwrap().unwrap();
    }

    // Aborting a task that is waiting on a permit must release its place so a
    // later request can acquire the permit.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_permit_released_on_cancellation() {
        let (conn, mut cmd_received_rx, send_response_tx) = create_mock_connection(1).await;
        let h1 = tokio::spawn({
            let mut c = conn.clone();
            async move { c.send_packed_command(&cmd("PING")).await }
        });
        cmd_received_rx.recv().await.unwrap();
        let h2 = tokio::spawn({
            let mut c = conn.clone();
            async move { c.send_packed_command(&cmd("PING")).await }
        });
        tokio::time::sleep(Duration::from_millis(50)).await;
        h2.abort();
        let _ = h2.await;
        send_response_tx.send(()).await.unwrap();
        h1.await.unwrap().unwrap();
        let h3 = tokio::spawn({
            let mut c = conn.clone();
            async move { c.send_packed_command(&cmd("PING")).await }
        });
        let received =
            tokio::time::timeout(Duration::from_millis(100), cmd_received_rx.recv()).await;
        assert!(
            received.is_ok(),
            "request after cancellation should acquire the permit"
        );
        send_response_tx.send(()).await.unwrap();
        h3.await.unwrap().unwrap();
    }

    // A server that never responds: the response timeout must fail the
    // request AND release its permit so the next request can be sent.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_permit_released_on_response_timeout() {
        use futures_util::StreamExt;
        use tokio::io::AsyncWriteExt;
        use tokio_util::codec::FramedRead;
        let (client_half, server_half) = tokio::io::duplex(4096);
        let (cmd_received_tx, mut cmd_received_rx) = tokio::sync::mpsc::channel::<()>(10);
        let (server_read, mut server_write) = tokio::io::split(server_half);
        tokio::spawn(async move {
            let mut reader = FramedRead::new(server_read, ValueCodec::default());
            while let Some(Ok(_)) = reader.next().await {
                let _ = cmd_received_tx.send(()).await;
            }
        });
        // Keep the write half alive forever without ever responding.
        tokio::spawn(async move {
            futures_util::future::pending::<()>().await;
            let _ = server_write.write_all(b"").await;
        });
        let config = AsyncConnectionConfig::new()
            .set_concurrency_limit(1)
            .set_response_timeout(Some(Duration::from_millis(100)))
            .set_connection_timeout(None);
        let (conn, driver) =
            MultiplexedConnection::new_with_config(&mock_conn_info(), client_half, config)
                .await
                .unwrap();
        tokio::spawn(driver);
        let mut c1 = conn.clone();
        let err = c1.send_packed_command(&cmd("PING")).await.unwrap_err();
        assert!(err.is_io_error(), "expected IO error from timeout");
        cmd_received_rx.recv().await.unwrap();
        // The permit from the timed-out request must be free again.
        let mut c2 = conn.clone();
        let err = c2.send_packed_command(&cmd("PING")).await.unwrap_err();
        assert!(err.is_io_error(), "expected IO error from timeout");
        cmd_received_rx.recv().await.unwrap();
    }
}