use super::{AsyncPushSender, HandleContainer, RedisFuture};
#[cfg(feature = "cache-aio")]
use crate::caching::CacheManager;
use crate::{
AsyncConnectionConfig, Client, Cmd, Pipeline, PushInfo, PushKind, ToRedisArgs,
aio::{ConnectionLike, MultiplexedConnection, Runtime},
check_resp3,
client::{DEFAULT_CONNECTION_TIMEOUT, DEFAULT_RESPONSE_TIMEOUT},
cmd,
errors::RedisError,
subscription_tracker::{SubscriptionAction, SubscriptionTracker},
types::{RedisResult, Value},
};
use arc_swap::ArcSwap;
use backon::{ExponentialBuilder, Retryable};
use futures_channel::oneshot;
use futures_util::future::{BoxFuture, FutureExt, Shared};
use std::sync::{Arc, Weak};
use std::time::Duration;
use tokio::sync::Mutex;
use tokio::sync::mpsc::{UnboundedReceiver, unbounded_channel};
/// Optional user-supplied sink that push messages received by the manager are forwarded to.
type OptionalPushSender = Option<Arc<dyn AsyncPushSender>>;
/// Settings used by [`ConnectionManager`]: reconnection backoff, timeouts,
/// push-message handling, and optional feature-gated extras.
///
/// Start from [`ConnectionManagerConfig::new`] and adjust with the `set_*` builder methods.
#[derive(Clone)]
pub struct ConnectionManagerConfig {
    /// Growth factor of the exponential reconnection backoff.
    exponent_base: f32,
    /// Minimum delay between reconnection attempts.
    min_delay: Duration,
    /// Optional cap on the delay between reconnection attempts.
    max_delay: Option<Duration>,
    /// Maximum number of retries for a failed connection attempt.
    number_of_retries: usize,
    /// Response timeout handed to every connection the manager opens.
    response_timeout: Option<Duration>,
    /// Connection timeout handed to every connection the manager opens.
    connection_timeout: Option<Duration>,
    /// Where push messages are forwarded, if anywhere.
    push_sender: Option<Arc<dyn AsyncPushSender>>,
    /// Whether subscriptions are replayed after a reconnection.
    resubscribe_automatically: bool,
    /// Client-side caching configuration.
    #[cfg(feature = "cache-aio")]
    pub(crate) cache_config: Option<crate::caching::CacheConfig>,
    /// Optional override for the pipeline buffer size of each connection.
    pipeline_buffer_size: Option<usize>,
    /// Provider of credentials for token-based authentication.
    #[cfg(feature = "token-based-authentication")]
    credentials_provider: Option<Arc<dyn crate::auth::StreamingCredentialsProvider>>,
}
impl std::fmt::Debug for ConnectionManagerConfig {
    /// Formats every setting. The `push_sender` and `credentials_provider` trait
    /// objects are shown only as `"set"` or `"not set"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
        // Exhaustive destructuring: adding a field without updating this impl fails to compile.
        let Self {
            exponent_base,
            min_delay,
            number_of_retries,
            max_delay,
            response_timeout,
            connection_timeout,
            push_sender,
            resubscribe_automatically,
            #[cfg(feature = "cache-aio")]
            cache_config,
            pipeline_buffer_size,
            #[cfg(feature = "token-based-authentication")]
            credentials_provider,
        } = self;

        let presence = |configured: bool| -> &'static str {
            if configured {
                "set"
            } else {
                "not set"
            }
        };

        let mut builder = f.debug_struct("ConnectionManagerConfig");
        builder.field("exponent_base", exponent_base);
        builder.field("min_delay", min_delay);
        builder.field("max_delay", max_delay);
        builder.field("number_of_retries", number_of_retries);
        builder.field("response_timeout", response_timeout);
        builder.field("connection_timeout", connection_timeout);
        builder.field("resubscribe_automatically", resubscribe_automatically);
        builder.field("pipeline_buffer_size", pipeline_buffer_size);
        builder.field("push_sender", &presence(push_sender.is_some()));
        #[cfg(feature = "cache-aio")]
        builder.field("cache_config", cache_config);
        #[cfg(feature = "token-based-authentication")]
        builder.field(
            "credentials_provider",
            &presence(credentials_provider.is_some()),
        );
        builder.finish()
    }
}
impl ConnectionManagerConfig {
    /// Default growth factor of the exponential reconnection backoff.
    const DEFAULT_CONNECTION_RETRY_EXPONENT_BASE: f32 = 2.0;
    /// Default minimum delay between reconnection attempts.
    const DEFAULT_CONNECTION_RETRY_MIN_DELAY: Duration = Duration::from_millis(100);
    /// Default number of reconnection retries.
    const DEFAULT_NUMBER_OF_CONNECTION_RETRIES: usize = 6;

    /// Creates a configuration populated with the default values.
    pub fn new() -> Self {
        Default::default()
    }

    /// Minimum delay between reconnection attempts.
    pub fn min_delay(&self) -> Duration {
        self.min_delay
    }

    /// Optional cap on the delay between reconnection attempts.
    pub fn max_delay(&self) -> Option<Duration> {
        self.max_delay
    }

    /// Growth factor of the exponential reconnection backoff.
    pub fn exponent_base(&self) -> f32 {
        self.exponent_base
    }

    /// Maximum number of retries for a failed connection attempt.
    pub fn number_of_retries(&self) -> usize {
        self.number_of_retries
    }

    /// Response timeout applied to managed connections.
    pub fn response_timeout(&self) -> Option<Duration> {
        self.response_timeout
    }

    /// Connection timeout applied to managed connections.
    pub fn connection_timeout(&self) -> Option<Duration> {
        self.connection_timeout
    }

    /// Whether subscriptions are replayed after a reconnection.
    pub fn automatic_resubscription(&self) -> bool {
        self.resubscribe_automatically
    }

    /// Client-side caching configuration, if one was set.
    #[cfg(feature = "cache-aio")]
    pub fn cache_config(&self) -> Option<&crate::caching::CacheConfig> {
        self.cache_config.as_ref()
    }

    /// Sets the minimum delay between reconnection attempts.
    pub fn set_min_delay(self, min_delay: Duration) -> Self {
        Self { min_delay, ..self }
    }

    /// Caps the delay between reconnection attempts at `time`.
    pub fn set_max_delay(self, time: Duration) -> Self {
        Self {
            max_delay: Some(time),
            ..self
        }
    }

    /// Sets the growth factor of the exponential reconnection backoff.
    pub fn set_exponent_base(self, base: f32) -> Self {
        Self {
            exponent_base: base,
            ..self
        }
    }

    /// Sets how many times a failed connection attempt is retried.
    pub fn set_number_of_retries(self, amount: usize) -> Self {
        Self {
            number_of_retries: amount,
            ..self
        }
    }

    /// Sets (or clears, with `None`) the response timeout of managed connections.
    pub fn set_response_timeout(self, duration: Option<Duration>) -> Self {
        Self {
            response_timeout: duration,
            ..self
        }
    }

    /// Sets (or clears, with `None`) the connection timeout of managed connections.
    pub fn set_connection_timeout(self, duration: Option<Duration>) -> Self {
        Self {
            connection_timeout: duration,
            ..self
        }
    }

    /// Installs `sender` as the destination for push messages.
    pub fn set_push_sender(self, sender: impl AsyncPushSender) -> Self {
        Self {
            push_sender: Some(Arc::new(sender)),
            ..self
        }
    }

    /// Enables replaying subscriptions after a reconnection. This requires a push
    /// sender to also be set.
    pub fn set_automatic_resubscription(self) -> Self {
        Self {
            resubscribe_automatically: true,
            ..self
        }
    }

    /// Enables client-side caching with `cache_config`.
    #[cfg(feature = "cache-aio")]
    pub fn set_cache_config(self, cache_config: crate::caching::CacheConfig) -> Self {
        Self {
            cache_config: Some(cache_config),
            ..self
        }
    }

    /// Overrides the pipeline buffer size of managed connections.
    pub fn set_pipeline_buffer_size(self, size: usize) -> Self {
        Self {
            pipeline_buffer_size: Some(size),
            ..self
        }
    }

    /// Uses `provider` to supply credentials for token-based authentication.
    #[cfg(feature = "token-based-authentication")]
    pub fn set_credentials_provider<P>(self, provider: P) -> Self
    where
        P: crate::auth::StreamingCredentialsProvider + 'static,
    {
        Self {
            credentials_provider: Some(Arc::new(provider)),
            ..self
        }
    }
}
impl Default for ConnectionManagerConfig {
    /// Retry settings come from the `DEFAULT_*` associated constants and timeouts
    /// from the `client` module defaults. Every optional extra starts disabled.
    fn default() -> Self {
        Self {
            // Reconnection backoff.
            exponent_base: Self::DEFAULT_CONNECTION_RETRY_EXPONENT_BASE,
            min_delay: Self::DEFAULT_CONNECTION_RETRY_MIN_DELAY,
            max_delay: None,
            number_of_retries: Self::DEFAULT_NUMBER_OF_CONNECTION_RETRIES,
            // Timeouts.
            connection_timeout: DEFAULT_CONNECTION_TIMEOUT,
            response_timeout: DEFAULT_RESPONSE_TIMEOUT,
            // Push handling and other opt-in features.
            push_sender: None,
            resubscribe_automatically: false,
            pipeline_buffer_size: None,
            #[cfg(feature = "cache-aio")]
            cache_config: None,
            #[cfg(feature = "token-based-authentication")]
            credentials_provider: None,
        }
    }
}
/// State shared by every clone of a `ConnectionManager`.
// NOTE: field declaration order is also drop order; `_task_handle` is dropped last.
struct Internals {
// Client used to open the initial connection and every reconnection.
client: Client,
// The current connection, stored as a shared future so all callers await the same
// connection attempt; atomically replaced by `reconnect` via compare-and-swap.
connection: ArcSwap<SharedRedisFuture<MultiplexedConnection>>,
// Runtime on which reconnection futures are spawned.
runtime: Runtime,
// Backoff policy applied to each (re)connection attempt.
retry_strategy: ExponentialBuilder,
// Template configuration cloned for each reconnection.
connection_config: AsyncConnectionConfig,
// Present only when automatic resubscription is enabled; records active subscriptions
// so they can be replayed on a new connection.
subscription_tracker: Option<Mutex<SubscriptionTracker>>,
// Client-side cache manager, if caching was configured; its epoch is bumped on reconnect.
#[cfg(feature = "cache-aio")]
cache_manager: Option<CacheManager>,
// Handle of the background task that watches push messages for disconnections.
// NOTE(review): presumably aborts the task when dropped — confirm in HandleContainer.
_task_handle: HandleContainer,
}
/// A cheaply cloneable handle to a multiplexed connection that is re-established
/// automatically after I/O or unrecoverable errors.
///
/// All clones share the same internal state through an `Arc`, so a reconnection
/// triggered through one clone is observed by every other clone.
#[derive(Clone)]
pub struct ConnectionManager(Arc<Internals>);
impl std::fmt::Debug for ConnectionManager {
    /// Shows the client and the retry strategy; connection state is not included.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let internals = &self.0;
        let mut out = f.debug_struct("ConnectionManager");
        out.field("client", &internals.client);
        out.field("retry_strategy", &internals.retry_strategy);
        out.finish()
    }
}
/// A boxed `'static` future resolving to `RedisResult<T>`, made cloneable with
/// `Shared` so several awaiters observe the same single connection attempt.
type SharedRedisFuture<T> = Shared<BoxFuture<'static, RedisResult<T>>>;
// Triggers a reconnection when `$result` (a `&RedisResult<_>`) holds an unrecoverable
// error. `$current` is the connection guard the failed request used; it is moved into
// `reconnect` so only that exact connection gets replaced. The result is not consumed.
macro_rules! reconnect_if_dropped {
    ($self:expr, $result:expr, $current:expr) => {
        match $result {
            Err(error) if error.is_unrecoverable_error() => {
                Self::reconnect(Arc::downgrade(&$self.0), $current);
            }
            _ => {}
        }
    };
}
// Early-returns from the enclosing function with the error held in `$result` (an owned
// `RedisResult<_>`). If that error is an I/O error, a reconnection replacing `$current`
// is triggered first. On `Ok`, nothing happens and `$result` stays usable.
macro_rules! reconnect_if_io_error {
    ($self:expr, $result:expr, $current:expr) => {
        match $result {
            Ok(_) => {}
            Err(error) => {
                if error.is_io_error() {
                    Self::reconnect(Arc::downgrade(&$self.0), $current);
                }
                return Err(error);
            }
        }
    };
}
impl ConnectionManager {
    /// Creates a connection manager with the default [`ConnectionManagerConfig`]
    /// and waits for the initial connection (including its retries) to finish.
    ///
    /// # Errors
    ///
    /// Returns an error if the configuration is rejected by
    /// [`ConnectionManager::new_lazy_with_config`] or if the initial connection
    /// could not be established.
    pub async fn new(client: Client) -> RedisResult<Self> {
        let config = ConnectionManagerConfig::new();
        Self::new_with_config(client, config).await
    }

    /// Creates a connection manager from `config` and waits for the initial
    /// connection (including its retries) to finish.
    ///
    /// # Errors
    ///
    /// Returns an error if the configuration is rejected by
    /// [`ConnectionManager::new_lazy_with_config`] or if the initial connection
    /// could not be established.
    pub async fn new_with_config(
        client: Client,
        config: ConnectionManagerConfig,
    ) -> RedisResult<Self> {
        let manager = Self::new_lazy_with_config(client, config)?;
        // Clone the shared future out of the `ArcSwap` guard first, so the guard is
        // released before awaiting rather than held for the whole connect.
        let initial_connection = (**manager.0.connection.load()).clone();
        initial_connection.await.map_err(|e| e.clone())?;
        Ok(manager)
    }

    /// Creates a connection manager without connecting.
    ///
    /// The initial connection is wrapped in a shared future that only starts when it
    /// is first awaited, either by [`ConnectionManager::new_with_config`] or by the
    /// first command. A background task is spawned on the located runtime. It is fed
    /// an internal push channel whenever pushes can arrive, which is when a push
    /// sender is configured or the client speaks RESP3.
    ///
    /// # Errors
    ///
    /// - [`crate::ErrorKind::Client`] if automatic resubscription is requested
    ///   without a push sender.
    /// - An error if a push sender is configured but the client protocol is not RESP3.
    /// - [`crate::ErrorKind::Client`] if the push-handling task could not receive
    ///   its channel.
    pub fn new_lazy_with_config(
        client: Client,
        config: ConnectionManagerConfig,
    ) -> RedisResult<Self> {
        let runtime = Runtime::locate();

        if config.resubscribe_automatically && config.push_sender.is_none() {
            return Err((crate::ErrorKind::Client, "Cannot set resubscribe_automatically without setting a push sender to receive messages.").into());
        }

        let mut retry_strategy = ExponentialBuilder::default()
            .with_factor(config.exponent_base)
            .with_min_delay(config.min_delay)
            .with_max_times(config.number_of_retries)
            .with_jitter();
        if let Some(max_delay) = config.max_delay {
            retry_strategy = retry_strategy.with_max_delay(max_delay);
        }

        let mut connection_config = AsyncConnectionConfig::new()
            .set_connection_timeout(config.connection_timeout)
            .set_response_timeout(config.response_timeout);
        connection_config.pipeline_buffer_size = config.pipeline_buffer_size;

        #[cfg(feature = "cache-aio")]
        let cache_manager = config
            .cache_config
            .as_ref()
            .map(|cache_config| CacheManager::new(*cache_config));
        #[cfg(feature = "cache-aio")]
        if let Some(cache_manager) = cache_manager.as_ref() {
            connection_config = connection_config.set_cache_manager(cache_manager.clone());
        }

        #[cfg(feature = "token-based-authentication")]
        if let Some(credentials_provider) = config.credentials_provider {
            connection_config.credentials_provider = Some(credentials_provider);
        }

        // The push-handling task needs a `Weak` to the manager, which only exists once
        // `Internals` is built below, so its inputs are delivered through a oneshot.
        let (oneshot_sender, oneshot_receiver) = oneshot::channel();
        let _task_handle = HandleContainer::new(
            runtime.spawn(Self::check_for_disconnect_pushes(oneshot_receiver)),
        );

        // Route pushes through an internal channel whenever they can arrive, so that
        // disconnection notifications trigger reconnects even without a user sender.
        let mut components_for_reconnection_on_push = None;
        if let Some(push_sender) = config.push_sender.clone() {
            check_resp3!(
                client.connection_info.redis.protocol,
                "Can only pass push sender to a connection using RESP3"
            );
            let (internal_sender, internal_receiver) = unbounded_channel();
            components_for_reconnection_on_push = Some((internal_receiver, Some(push_sender)));
            connection_config =
                connection_config.set_push_sender_internal(Arc::new(internal_sender));
        } else if client.connection_info.redis.protocol.supports_resp3() {
            let (internal_sender, internal_receiver) = unbounded_channel();
            components_for_reconnection_on_push = Some((internal_receiver, None));
            connection_config =
                connection_config.set_push_sender_internal(Arc::new(internal_sender));
        }

        let subscription_tracker = if config.resubscribe_automatically {
            Some(Mutex::new(SubscriptionTracker::default()))
        } else {
            None
        };

        let client_clone = client.clone();
        let retry_strategy_clone = retry_strategy;
        let connection_config_clone = connection_config.clone();
        let lazy_connection: SharedRedisFuture<MultiplexedConnection> = async move {
            Self::new_connection(
                &client_clone,
                retry_strategy_clone,
                &connection_config_clone,
                None,
            )
            .await
        }
        .boxed()
        .shared();

        let new_self = Self(Arc::new(Internals {
            client,
            connection: ArcSwap::from_pointee(lazy_connection),
            runtime,
            retry_strategy,
            connection_config,
            subscription_tracker,
            #[cfg(feature = "cache-aio")]
            cache_manager,
            _task_handle,
        }));

        if let Some((internal_receiver, external_sender)) = components_for_reconnection_on_push {
            oneshot_sender
                .send((
                    Arc::downgrade(&new_self.0),
                    internal_receiver,
                    external_sender,
                ))
                .map_err(|_| {
                    crate::RedisError::from((
                        crate::ErrorKind::Client,
                        "Failed to set automatic resubscription",
                    ))
                })?;
        }

        Ok(new_self)
    }

    /// Opens a multiplexed connection with `connection_config`. Failed attempts are
    /// retried according to `exponential_backoff`, sleeping on the located runtime
    /// between attempts.
    ///
    /// If `additional_commands` is given (the resubscription pipeline on reconnect),
    /// it is executed on the fresh connection best-effort. Its result is ignored, so
    /// a failing pipeline does not fail the connection.
    async fn new_connection(
        client: &Client,
        exponential_backoff: ExponentialBuilder,
        connection_config: &AsyncConnectionConfig,
        additional_commands: Option<Pipeline>,
    ) -> RedisResult<MultiplexedConnection> {
        // Each attempt only borrows the config, so cloning it here is unnecessary.
        let get_conn = || async {
            client
                .get_multiplexed_async_connection_with_config(connection_config)
                .await
        };
        let mut conn = get_conn
            .retry(exponential_backoff)
            .sleep(|duration| async move { Runtime::locate().sleep(duration).await })
            .await?;
        if let Some(pipeline) = additional_commands {
            // Best-effort: the connection itself is usable even if this fails.
            let _ = pipeline.exec_async(&mut conn).await;
        }
        Ok(conn)
    }

    /// Replaces the shared connection future with a fresh connection attempt.
    ///
    /// The swap only happens if the stored future is still `current`, so concurrent
    /// callers that all observed the same broken connection start a single
    /// reconnection. The new future is spawned on the runtime and makes progress even
    /// if nobody awaits it. Nothing happens if the manager has already been dropped.
    /// With `cache-aio`, the new connection gets a cache manager with an increased
    /// epoch (presumably to invalidate entries from the old connection — confirm).
    fn reconnect(
        internals: Weak<Internals>,
        current: arc_swap::Guard<Arc<SharedRedisFuture<MultiplexedConnection>>>,
    ) {
        let Some(internals) = internals.upgrade() else {
            return;
        };
        let internals_clone = internals.clone();

        #[cfg(not(feature = "cache-aio"))]
        let connection_config = internals.connection_config.clone();
        #[cfg(feature = "cache-aio")]
        let mut connection_config = internals.connection_config.clone();
        #[cfg(feature = "cache-aio")]
        if let Some(manager) = internals.cache_manager.as_ref() {
            let new_cache_manager = manager.clone_and_increase_epoch();
            connection_config = connection_config.set_cache_manager(new_cache_manager);
        }

        let new_connection: SharedRedisFuture<MultiplexedConnection> = async move {
            // Replay tracked subscriptions on the new connection when enabled.
            let additional_commands = match &internals_clone.subscription_tracker {
                Some(subscription_tracker) => Some(
                    subscription_tracker
                        .lock()
                        .await
                        .get_subscription_pipeline(),
                ),
                None => None,
            };
            let con = Self::new_connection(
                &internals_clone.client,
                internals_clone.retry_strategy,
                &connection_config,
                additional_commands,
            )
            .await?;
            Ok(con)
        }
        .boxed()
        .shared();

        let new_connection_arc = Arc::new(new_connection.clone());
        let prev = internals
            .connection
            .compare_and_swap(&current, new_connection_arc);
        // Only the caller whose swap succeeded drives the new connection attempt.
        if Arc::ptr_eq(&prev, &current) {
            internals.runtime.spawn(new_connection.map(|_| ())).detach();
        }
    }

    /// Background task that handles push messages.
    ///
    /// It first waits for the manager to hand over the internal push receiver and the
    /// optional user sender. It then reconnects on every `PushKind::Disconnection`
    /// push and forwards every push, disconnections included, to the user sender if
    /// one is set. It exits when the oneshot is dropped without a value, when the push
    /// channel closes, or when a disconnection arrives after the manager was dropped.
    async fn check_for_disconnect_pushes(
        receiver: oneshot::Receiver<(
            Weak<Internals>,
            UnboundedReceiver<PushInfo>,
            OptionalPushSender,
        )>,
    ) {
        let Ok((this, mut internal_receiver, external_sender)) = receiver.await else {
            return;
        };
        while let Some(push_info) = internal_receiver.recv().await {
            if push_info.kind == PushKind::Disconnection {
                let Some(internals) = this.upgrade() else {
                    return;
                };
                Self::reconnect(Arc::downgrade(&internals), internals.connection.load());
            }
            if let Some(sender) = external_sender.as_ref() {
                // Best-effort forwarding: a failing user sender must not stop this task.
                let _ = sender.send(push_info);
            }
        }
    }

    /// Sends an already packed command over the current connection.
    ///
    /// If obtaining the connection fails with an I/O error, a reconnection is
    /// triggered. If the command fails with an unrecoverable error, a reconnection is
    /// triggered too. In both cases the error is still returned to the caller.
    pub async fn send_packed_command(&mut self, cmd: &Cmd) -> RedisResult<Value> {
        // The guard identifies which connection this request used, so a reconnect only
        // replaces that connection and not one another caller already installed.
        let guard = self.0.connection.load();
        let connection_result = (**guard).clone().await.map_err(|e| e.clone());
        reconnect_if_io_error!(self, connection_result, guard);
        let result = connection_result?.send_packed_command(cmd).await;
        reconnect_if_dropped!(self, &result, guard);
        result
    }

    /// Sends `count` replies' worth of a packed pipeline, skipping the first `offset`
    /// replies. Reconnection behaves as in [`ConnectionManager::send_packed_command`].
    pub async fn send_packed_commands(
        &mut self,
        cmd: &crate::Pipeline,
        offset: usize,
        count: usize,
    ) -> RedisResult<Vec<Value>> {
        let guard = self.0.connection.load();
        let connection_result = (**guard).clone().await.map_err(|e| e.clone());
        reconnect_if_io_error!(self, connection_result, guard);
        let result = connection_result?
            .send_packed_commands(cmd, offset, count)
            .await;
        reconnect_if_dropped!(self, &result, guard);
        result
    }

    /// Records a (un)subscription in the tracker. This is a no-op unless automatic
    /// resubscription is enabled.
    async fn update_subscription_tracker(
        &self,
        action: SubscriptionAction,
        args: impl ToRedisArgs,
    ) {
        let Some(subscription_tracker) = &self.0.subscription_tracker else {
            return;
        };
        let args = args.to_redis_args().into_iter();
        subscription_tracker
            .lock()
            .await
            .update_with_request(action, args);
    }

    /// Shared body of the (p)(un)subscribe methods.
    ///
    /// Requires RESP3 and sends `command` with `args`. On success it records `action`
    /// in the subscription tracker, so the change is replayed after a reconnect.
    async fn send_subscription_command(
        &mut self,
        command: &str,
        action: SubscriptionAction,
        args: impl ToRedisArgs,
    ) -> RedisResult<()> {
        check_resp3!(self.0.client.connection_info.redis.protocol);
        let mut request = cmd(command);
        request.arg(&args);
        request.exec_async(self).await?;
        self.update_subscription_tracker(action, args).await;
        Ok(())
    }

    /// Subscribes to `channel_name`. Requires RESP3.
    ///
    /// # Errors
    ///
    /// Fails if the protocol is not RESP3 or the command fails.
    pub async fn subscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
        self.send_subscription_command("SUBSCRIBE", SubscriptionAction::Subscribe, channel_name)
            .await
    }

    /// Unsubscribes from `channel_name`. Requires RESP3.
    ///
    /// # Errors
    ///
    /// Fails if the protocol is not RESP3 or the command fails.
    pub async fn unsubscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
        self.send_subscription_command(
            "UNSUBSCRIBE",
            SubscriptionAction::Unsubscribe,
            channel_name,
        )
        .await
    }

    /// Subscribes to channels matching `channel_pattern`. Requires RESP3.
    ///
    /// # Errors
    ///
    /// Fails if the protocol is not RESP3 or the command fails.
    pub async fn psubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
        self.send_subscription_command(
            "PSUBSCRIBE",
            SubscriptionAction::PSubscribe,
            channel_pattern,
        )
        .await
    }

    /// Unsubscribes from channels matching `channel_pattern`. Requires RESP3.
    ///
    /// # Errors
    ///
    /// Fails if the protocol is not RESP3 or the command fails.
    pub async fn punsubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
        self.send_subscription_command(
            "PUNSUBSCRIBE",
            SubscriptionAction::PUnsubscribe,
            channel_pattern,
        )
        .await
    }

    /// Returns client-side cache statistics, or `None` if caching is not configured.
    #[cfg(feature = "cache-aio")]
    #[cfg_attr(docsrs, doc(cfg(feature = "cache-aio")))]
    pub fn get_cache_statistics(&self) -> Option<crate::caching::CacheStatistics> {
        self.0.cache_manager.as_ref().map(|cm| cm.statistics())
    }
}
impl ConnectionLike for ConnectionManager {
    /// Boxes [`ConnectionManager::send_packed_command`], keeping its reconnect handling.
    fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> {
        Box::pin(self.send_packed_command(cmd))
    }

    /// Boxes [`ConnectionManager::send_packed_commands`], keeping its reconnect handling.
    fn req_packed_commands<'a>(
        &'a mut self,
        cmd: &'a crate::Pipeline,
        offset: usize,
        count: usize,
    ) -> RedisFuture<'a, Vec<Value>> {
        Box::pin(self.send_packed_commands(cmd, offset, count))
    }

    /// Database index taken from the client's connection info.
    fn get_db(&self) -> i64 {
        let info = self.0.client.connection_info();
        info.redis.db
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_connection_manager_config_pipeline_buffer_size_default() {
        assert_eq!(ConnectionManagerConfig::new().pipeline_buffer_size, None);
    }

    #[test]
    fn test_connection_manager_config_pipeline_buffer_size_custom() {
        let custom = ConnectionManagerConfig::new().set_pipeline_buffer_size(100);
        assert_eq!(custom.pipeline_buffer_size, Some(100));
    }

    #[test]
    fn test_lazy_connection_manager_with_config() {
        let client = Client::open("redis://127.0.0.1/").unwrap();
        let settings = ConnectionManagerConfig::new()
            .set_number_of_retries(3)
            .set_pipeline_buffer_size(100);
        let manager = ConnectionManager::new_lazy_with_config(client, settings);
        assert!(manager.is_ok());
    }

    #[test]
    fn test_lazy_connection_manager_rejects_invalid_config() {
        // Automatic resubscription without a push sender must be refused up front.
        let client = Client::open("redis://127.0.0.1/?protocol=resp3").unwrap();
        let settings = ConnectionManagerConfig::new().set_automatic_resubscription();
        let outcome = ConnectionManager::new_lazy_with_config(client, settings);
        assert_matches::assert_matches!(
            outcome,
            Err(err) if err.kind() == crate::ErrorKind::Client
        );
    }
}