use super::RedisFuture;
use crate::cmd::Cmd;
use crate::types::{RedisError, RedisResult, Value};
use crate::{
aio::{ConnectionLike, MultiplexedConnection, Runtime},
Client,
};
#[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))]
use ::async_std::net::ToSocketAddrs;
use arc_swap::ArcSwap;
use futures::{
future::{self, Shared},
FutureExt,
};
use futures_util::future::BoxFuture;
use std::sync::Arc;
use tokio_retry::strategy::{jitter, ExponentialBackoff};
use tokio_retry::Retry;
/// A cloneable handle to a [`MultiplexedConnection`] that transparently reconnects.
///
/// Every clone shares the same connection slot (the slot lives behind an `Arc`), so a
/// reconnect triggered through one clone is observed by all of them.
#[derive(Clone)]
pub struct ConnectionManager {
    /// Client used to open replacement connections.
    client: Client,
    /// Runtime on which background reconnect futures are spawned.
    runtime: Runtime,
    /// Shared slot holding the current (possibly still pending) connection future.
    connection: Arc<ArcSwap<SharedRedisFuture<MultiplexedConnection>>>,
    /// Back-off schedule applied between connection attempts.
    retry_strategy: ExponentialBackoff,
    /// How many retries a single (re)connection may perform.
    number_of_retries: usize,
    /// Connection timeout passed to the client on every (re)connect.
    connection_timeout: std::time::Duration,
    /// Response timeout passed to the client on every (re)connect.
    response_timeout: std::time::Duration,
}
/// A boxed, `'static` connection future that any number of callers can await at once.
type SharedRedisFuture<T> = Shared<BoxFuture<'static, CloneableRedisResult<T>>>;
/// Result whose error is reference-counted, because a [`Shared`] future hands every
/// waiter its own clone of the output.
type CloneableRedisResult<T> = Result<T, Arc<RedisError>>;
/// Triggers `$self.reconnect($current)` when `$result` holds an unrecoverable error.
///
/// `$result` is only inspected by reference (`ref` binding), so the caller can still use
/// it afterwards; `$current` is evaluated only when a reconnect is actually requested.
macro_rules! reconnect_if_dropped {
    ($self:expr, $result:expr, $current:expr) => {
        match $result {
            Err(ref failure) if failure.is_unrecoverable_error() => {
                $self.reconnect($current);
            }
            _ => {}
        }
    };
}
/// Early-returns `Err(e)` from the enclosing function when `$result` is an error,
/// first calling `$self.reconnect($current)` if that error is an I/O error.
///
/// On `Ok` nothing is moved out of `$result`, so the caller may keep using it.
macro_rules! reconnect_if_io_error {
    ($self:expr, $result:expr, $current:expr) => {
        match $result {
            Ok(_) => {}
            Err(failure) => {
                if failure.is_io_error() {
                    $self.reconnect($current);
                }
                return Err(failure);
            }
        }
    };
}
impl ConnectionManager {
    /// Default base of the exponential back-off used by [`ConnectionManager::new`].
    const DEFAULT_CONNECTION_RETRY_EXPONENT_BASE: u64 = 2;
    /// Default multiplicative factor applied to each back-off delay.
    const DEFAULT_CONNECTION_RETRY_FACTOR: u64 = 100;
    /// Default number of retries after the first failed connection attempt.
    // NOTE(review): the trailing "E" is a typo; the name is kept so any other code in
    // this module that refers to it keeps compiling.
    const DEFAULT_NUMBER_OF_CONNECTION_RETRIESE: usize = 6;

    /// Connects using the default retry policy (exponent base 2, factor 100, 6 retries)
    /// and `Duration::MAX` for both the response and connection timeouts.
    ///
    /// # Errors
    /// Returns the error of the last attempt if every connection attempt fails.
    pub async fn new(client: Client) -> RedisResult<Self> {
        Self::new_with_backoff(
            client,
            Self::DEFAULT_CONNECTION_RETRY_EXPONENT_BASE,
            Self::DEFAULT_CONNECTION_RETRY_FACTOR,
            Self::DEFAULT_NUMBER_OF_CONNECTION_RETRIESE,
        )
        .await
    }

    /// Connects with a custom exponential back-off and `Duration::MAX` for both the
    /// response and connection timeouts.
    ///
    /// See [`ConnectionManager::new_with_backoff_and_timeouts`] for the meaning of the
    /// back-off parameters.
    ///
    /// # Errors
    /// Returns the error of the last attempt if every connection attempt fails.
    pub async fn new_with_backoff(
        client: Client,
        exponent_base: u64,
        factor: u64,
        number_of_retries: usize,
    ) -> RedisResult<Self> {
        Self::new_with_backoff_and_timeouts(
            client,
            exponent_base,
            factor,
            number_of_retries,
            std::time::Duration::MAX,
            std::time::Duration::MAX,
        )
        .await
    }

    /// Connects with a custom exponential back-off and explicit timeouts.
    ///
    /// * `exponent_base` / `factor` are passed to `tokio_retry`'s `ExponentialBackoff`.
    ///   Per its documentation, the n-th delay is roughly `factor * exponent_base^n`
    ///   milliseconds, and it is further randomized with `jitter`.
    /// * `number_of_retries` caps how many retries follow the first attempt. The same
    ///   policy is reused for every later reconnect.
    /// * `response_timeout` / `connection_timeout` are forwarded to
    ///   `Client::get_multiplexed_async_connection_with_timeouts` on every (re)connect.
    ///
    /// # Errors
    /// Returns the error of the last attempt if every connection attempt fails.
    pub async fn new_with_backoff_and_timeouts(
        client: Client,
        exponent_base: u64,
        factor: u64,
        number_of_retries: usize,
        response_timeout: std::time::Duration,
        connection_timeout: std::time::Duration,
    ) -> RedisResult<Self> {
        let runtime = Runtime::locate();
        let retry_strategy = ExponentialBackoff::from_millis(exponent_base).factor(factor);
        let connection = Self::new_connection(
            client.clone(),
            retry_strategy.clone(),
            number_of_retries,
            response_timeout,
            connection_timeout,
        )
        .await?;
        Ok(Self {
            client,
            // The initial connection is already established; wrap it in an
            // already-resolved shared future so it has the same type as a pending
            // reconnect.
            connection: Arc::new(ArcSwap::from_pointee(
                future::ok(connection).boxed().shared(),
            )),
            runtime,
            number_of_retries,
            retry_strategy,
            response_timeout,
            connection_timeout,
        })
    }

    /// Opens a multiplexed connection, retrying according to `exponential_backoff`
    /// (jittered) at most `number_of_retries` times after the first attempt.
    async fn new_connection(
        client: Client,
        exponential_backoff: ExponentialBackoff,
        number_of_retries: usize,
        response_timeout: std::time::Duration,
        connection_timeout: std::time::Duration,
    ) -> RedisResult<MultiplexedConnection> {
        let retry_strategy = exponential_backoff.map(jitter).take(number_of_retries);
        Retry::spawn(retry_strategy, || {
            client.get_multiplexed_async_connection_with_timeouts(
                response_timeout,
                connection_timeout,
            )
        })
        .await
    }

    /// Replaces the connection `current` with a new, shared reconnection future.
    ///
    /// The replacement is installed with a compare-and-swap against `current`. When
    /// several callers observe the same broken connection, only the first swap succeeds
    /// and starts a reconnect. The others find the slot already changed and do nothing.
    fn reconnect(&self, current: arc_swap::Guard<Arc<SharedRedisFuture<MultiplexedConnection>>>) {
        let client = self.client.clone();
        let retry_strategy = self.retry_strategy.clone();
        let number_of_retries = self.number_of_retries;
        let response_timeout = self.response_timeout;
        let connection_timeout = self.connection_timeout;
        let new_connection: SharedRedisFuture<MultiplexedConnection> = async move {
            // `?` converts `RedisError` into the `Arc<RedisError>` the shared output needs.
            Ok(Self::new_connection(
                client,
                retry_strategy,
                number_of_retries,
                response_timeout,
                connection_timeout,
            )
            .await?)
        }
        .boxed()
        .shared();
        let new_connection_arc = Arc::new(new_connection.clone());
        let prev = self
            .connection
            .compare_and_swap(&current, new_connection_arc);
        // `compare_and_swap` returns the value stored before the call. If that is still
        // `current`, our future was installed, so drive it in the background so the
        // reconnect progresses even if no request is awaiting it yet.
        if Arc::ptr_eq(&prev, &current) {
            self.runtime.spawn(new_connection.map(|_| ()));
        }
    }

    /// Sends one packed command over the managed connection.
    ///
    /// The call first awaits the current connection future, which may be a reconnect
    /// still in progress.
    ///
    /// # Errors
    /// * If obtaining the connection failed, returns that error (message prefixed with
    ///   "Reconnecting failed"). A new reconnect is started when it is an I/O error.
    /// * If the command itself fails with an unrecoverable error, a reconnect is started
    ///   and the error is returned.
    pub async fn send_packed_command(&mut self, cmd: &Cmd) -> RedisResult<Value> {
        let guard = self.connection.load();
        let connection_result = (**guard)
            .clone()
            .await
            .map_err(|e| e.clone_mostly("Reconnecting failed"));
        reconnect_if_io_error!(self, connection_result, guard);
        let result = connection_result?.send_packed_command(cmd).await;
        reconnect_if_dropped!(self, &result, guard);
        result
    }

    /// Sends a packed pipeline and returns `count` replies starting at `offset`.
    ///
    /// Connection and reconnect handling are the same as in
    /// [`ConnectionManager::send_packed_command`].
    ///
    /// # Errors
    /// Same conditions as [`ConnectionManager::send_packed_command`].
    pub async fn send_packed_commands(
        &mut self,
        cmd: &crate::Pipeline,
        offset: usize,
        count: usize,
    ) -> RedisResult<Vec<Value>> {
        let guard = self.connection.load();
        let connection_result = (**guard)
            .clone()
            .await
            .map_err(|e| e.clone_mostly("Reconnecting failed"));
        reconnect_if_io_error!(self, connection_result, guard);
        let result = connection_result?
            .send_packed_commands(cmd, offset, count)
            .await;
        reconnect_if_dropped!(self, &result, guard);
        result
    }
}
impl ConnectionLike for ConnectionManager {
    /// Boxes the future of [`ConnectionManager::send_packed_command`].
    fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> {
        self.send_packed_command(cmd).boxed()
    }

    /// Boxes the future of [`ConnectionManager::send_packed_commands`].
    fn req_packed_commands<'a>(
        &'a mut self,
        cmd: &'a crate::Pipeline,
        offset: usize,
        count: usize,
    ) -> RedisFuture<'a, Vec<Value>> {
        self.send_packed_commands(cmd, offset, count).boxed()
    }

    /// Database index taken from the client's connection info.
    fn get_db(&self) -> i64 {
        let info = self.client.connection_info();
        info.redis.db
    }
}