use super::{NodeAddress, TlsMode};
use crate::connection::factory::FerrisKeyConnectionOptions;
#[cfg(feature = "iam")]
use crate::connection::factory::IAMTokenProvider;
use crate::connection::info::ValkeyConnectionInfo;
use crate::connection::{DisconnectNotifier, MultiplexedConnection};
use crate::pubsub::push_manager::PushInfo;
use crate::retry_strategies::RetryStrategy;
use crate::value::{Error, Result};
use async_trait::async_trait;
use futures_intrusive::sync::ManualResetEvent;
use std::fmt;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{RwLock, RwLockReadGuard};
use std::time::Duration;
use tokio::sync::{Notify, mpsc};
use tokio::task;
use tokio::time::timeout;
use tokio_retry2::{Retry, RetryError};
use super::{run_with_timeout, types::DEFAULT_CONNECTION_TIMEOUT};
/// Why `ReconnectingConnection::reconnect` was invoked.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReconnectReason {
    /// A previously established connection was lost; `reconnect` logs a
    /// `connection_closed` warning for this reason.
    ConnectionDropped,
    /// The initial connection attempt made by `create_connection` failed.
    CreateError,
}
/// Shared handle to the cached IAM authentication token used when (re)connecting.
///
/// The token and its creation timestamp sit behind `Arc`s, so every clone of the handle
/// observes and updates the same cached values.
#[cfg(feature = "iam")]
#[derive(Clone)]
pub struct IAMTokenHandle {
    // Most recently obtained token; an empty string is treated as "no token" by
    // `get_valid_token_inner`.
    pub(crate) cached_token: Arc<tokio::sync::RwLock<String>>,
    // When `cached_token` was last refreshed; compared against `TOKEN_TTL_SECONDS` to decide
    // whether a new token must be generated.
    pub(crate) token_created_at: Arc<tokio::sync::RwLock<tokio::time::Instant>>,
    // State handed to `IAMTokenManager::generate_token_with_backoff` when refreshing.
    pub(crate) iam_token_state: crate::iam::IamTokenState,
}
#[cfg(feature = "iam")]
impl IAMTokenHandle {
    /// Returns the token to authenticate the next connection attempt with, or `None` when no
    /// usable (non-empty) token is cached.
    ///
    /// If the cached token is at least `TOKEN_TTL_SECONDS` old, a new one is generated (with
    /// backoff) and stored together with a fresh timestamp. If generation fails, the failure is
    /// logged and the cached token is returned anyway, even though it has expired.
    pub(crate) async fn get_valid_token_inner(&self) -> Option<String> {
        use crate::iam::TOKEN_TTL_SECONDS;
        let ttl = std::time::Duration::from_secs(TOKEN_TTL_SECONDS);
        // The read guard is a temporary and is released at the end of this statement.
        let expired = self.token_created_at.read().await.elapsed() >= ttl;

        if expired {
            tracing::info!("IAM reconnect - Token expired, generating a fresh token before reconnection");
            let generated =
                crate::iam::IAMTokenManager::generate_token_with_backoff(&self.iam_token_state)
                    .await;
            match generated {
                Ok(fresh) => {
                    // Each write guard is dropped at the end of its own statement, so the two
                    // locks are never held at the same time.
                    *self.cached_token.write().await = fresh.clone();
                    *self.token_created_at.write().await = tokio::time::Instant::now();
                    return Some(fresh);
                }
                Err(err) => {
                    tracing::error!("IAM reconnect - Failed to generate fresh IAM token, using cached token: {err}");
                }
            }
        }

        let cached = self.cached_token.read().await.clone();
        (!cached.is_empty()).then_some(cached)
    }
}
#[cfg(feature = "iam")]
#[async_trait::async_trait]
impl IAMTokenProvider for IAMTokenHandle {
    // Thin adapter: the refresh-if-expired logic lives in the inherent method.
    async fn get_valid_token(&self) -> Option<String> {
        IAMTokenHandle::get_valid_token_inner(self).await
    }
}
/// Per-node state shared by every clone of a `ReconnectingConnection`, independent of which
/// underlying connection is currently live.
struct ConnectionBackend {
    // Created set; reset when a reconnect starts and set again once a new connection is
    // stored. `get_connection` waits on it.
    connection_available_signal: ManualResetEvent,
    // Client used to open connections. Modified by the `update_connection_*` methods and by
    // the IAM password refresh in the reconnect loop.
    connection_info: RwLock<crate::connection::factory::Client>,
    // Set by `mark_as_dropped`; the reconnect loop checks it before each attempt and stops.
    client_dropped_flagged: AtomicBool,
    // Present when IAM authentication is configured; consulted before reconnect attempts to
    // refresh the stored password.
    #[cfg(feature = "iam")]
    iam_token_handle: Option<IAMTokenHandle>,
}
/// Lifecycle state of a node's underlying multiplexed connection.
enum ConnectionState {
    // A live connection; `try_get_connection` hands out clones of it.
    Connected(MultiplexedConnection),
    // A background reconnect task is running; `reconnect` returns early in this state.
    Reconnecting,
    // Set by `create_connection` when the initial attempt fails, before it calls `reconnect`.
    InitializedDisconnected,
}
/// Shared interior of `ReconnectingConnection`, held behind an `Arc` so clones share it.
struct InnerReconnectingConnection {
    // Current connection state. In this file the std `Mutex` is only held for short
    // synchronous sections (never across an `.await`).
    state: Mutex<ConnectionState>,
    backend: ConnectionBackend,
}
/// Handle to a single node's connection that can re-establish itself in a background task
/// (see `reconnect`). Clones share the same state and backend.
#[derive(Clone)]
pub(super) struct ReconnectingConnection {
    inner: Arc<InnerReconnectingConnection>,
    // Options used for every connection attempt. `create_connection` installs the disconnect
    // notifier and retry strategy here.
    connection_options: FerrisKeyConnectionOptions,
}
impl fmt::Debug for ReconnectingConnection {
    // Debug output is just the node address, which is what identifies a connection in logs.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let address = self.node_address();
        f.write_str(&address)
    }
}
/// Makes a single connection attempt with `client`, bounded by the configured connection
/// timeout (or `DEFAULT_CONNECTION_TIMEOUT` when none is set).
async fn get_multiplexed_connection(
    client: &crate::connection::factory::Client,
    connection_options: &FerrisKeyConnectionOptions,
) -> Result<MultiplexedConnection> {
    let attempt_timeout = connection_options
        .connection_timeout
        .unwrap_or(DEFAULT_CONNECTION_TIMEOUT);
    let connect = client.get_multiplexed_async_connection(connection_options.clone());
    run_with_timeout(Some(attempt_timeout), connect).await
}
/// `DisconnectNotifier` backed by a tokio `Notify`; clones share the same `Notify` through
/// the `Arc`, so a signal from any clone reaches waiters on any other clone.
#[derive(Clone)]
struct TokioDisconnectNotifier {
    disconnect_notifier: Arc<Notify>,
}
#[async_trait]
impl DisconnectNotifier for TokioDisconnectNotifier {
    /// Signals that the connection was lost.
    ///
    /// `notify_one` is used instead of `notify_waiters`. `notify_waiters` only wakes tasks that
    /// are already waiting and stores nothing, so a disconnect signalled while no task is inside
    /// `wait_for_disconnect_with_timeout` would be lost; it would only be noticed after that
    /// wait's full timeout. `notify_one` stores a permit when nobody is waiting, so the next
    /// wait returns immediately.
    ///
    /// If several tasks wait concurrently, only one of them is woken.
    fn notify_disconnect(&mut self) {
        self.disconnect_notifier.notify_one();
    }

    /// Waits until a disconnect is signalled or `max_wait` elapses, whichever comes first.
    async fn wait_for_disconnect_with_timeout(&self, max_wait: &Duration) {
        // Timing out is the normal "nothing happened" outcome, so `Elapsed` is ignored.
        let _ = timeout(*max_wait, self.disconnect_notifier.notified()).await;
    }

    fn clone_box(&self) -> Box<dyn DisconnectNotifier> {
        Box::new(self.clone())
    }
}
impl TokioDisconnectNotifier {
    /// Creates a notifier around a fresh `Notify` that is not shared with any other notifier.
    fn new() -> Self {
        let disconnect_notifier = Arc::new(Notify::new());
        Self { disconnect_notifier }
    }
}
async fn create_connection(
connection_backend: ConnectionBackend,
retry_strategy: RetryStrategy,
push_sender: Option<mpsc::UnboundedSender<PushInfo>>,
discover_az: bool,
connection_timeout: Duration,
tcp_nodelay: bool,
pubsub_synchronizer: Option<Arc<dyn crate::pubsub::PubSubSynchronizer>>,
) -> std::result::Result<ReconnectingConnection, (ReconnectingConnection, Error)> {
let client = {
let guard = connection_backend
.connection_info
.read()
.unwrap_or_else(|e| e.into_inner());
guard.clone()
};
let connection_options = FerrisKeyConnectionOptions {
push_sender,
disconnect_notifier: Some::<Box<dyn DisconnectNotifier>>(Box::new(
TokioDisconnectNotifier::new(),
)),
discover_az,
connection_timeout: Some(connection_timeout),
connection_retry_strategy: Some(retry_strategy),
tcp_nodelay,
pubsub_synchronizer,
iam_token_provider: None,
};
let action = || async {
client
.get_multiplexed_async_connection(connection_options.clone())
.await
.map_err(|e| {
let is_permanent = matches!(
e.kind(),
crate::value::ErrorKind::AuthenticationFailed
| crate::value::ErrorKind::InvalidClientConfig
| crate::value::ErrorKind::RESP3NotSupported
) || e.to_string().contains("NOAUTH")
|| e.to_string().contains("WRONGPASS");
if is_permanent {
RetryError::permanent(e)
} else {
RetryError::transient(e)
}
})
};
let retry_future = Retry::spawn(retry_strategy.get_bounded_backoff_dur_iterator(), action);
let result = timeout(connection_timeout, retry_future).await;
match result {
Ok(Ok(connection)) => {
{
let client = connection_backend.get_backend_client();
let addr = &client.get_connection_info().addr;
tracing::debug!("connection creation - Connection to {addr} created");
}
tracing::info!(
target: "ferriskey",
event = "connection_opened",
"ferriskey: connection opened"
);
Ok(ReconnectingConnection {
inner: Arc::new(InnerReconnectingConnection {
state: Mutex::new(ConnectionState::Connected(connection)),
backend: connection_backend,
}),
connection_options,
})
}
err => {
let err: Error = match err {
Ok(Err(e)) => e,
_ => std::io::Error::from(std::io::ErrorKind::TimedOut).into(),
};
{
let client = connection_backend.get_backend_client();
let addr = &client.get_connection_info().addr;
tracing::warn!("connection creation - Failed connecting to {addr}, due to {err}");
}
let connection = ReconnectingConnection {
inner: Arc::new(InnerReconnectingConnection {
state: Mutex::new(ConnectionState::InitializedDisconnected),
backend: connection_backend,
}),
connection_options,
};
connection.reconnect(ReconnectReason::CreateError);
Err((connection, err))
}
}
}
/// Builds the client for `address` from the TLS and Valkey connection settings.
///
/// # Panics
/// Panics if `Client::open` rejects the connection info built by `get_connection_info`; the
/// `expect` message records that this is treated as an invariant.
fn get_client(
    address: &NodeAddress,
    tls_mode: TlsMode,
    valkey_connection_info: crate::connection::info::ValkeyConnectionInfo,
    tls_params: Option<crate::connection::tls::TlsConnParams>,
) -> crate::connection::factory::Client {
    crate::connection::factory::Client::open(super::get_connection_info(
        address,
        tls_mode,
        valkey_connection_info,
        tls_params,
    ))
    .expect("Client::open from ConnectionInfo")
}
impl ConnectionBackend {
    /// Read access to the stored client. A poisoned lock is recovered rather than propagated,
    /// since the client value itself stays usable.
    fn get_backend_client(&self) -> RwLockReadGuard<'_, crate::connection::factory::Client> {
        match self.connection_info.read() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        }
    }
}
impl ReconnectingConnection {
    /// Builds the client for `address` and makes the initial connection to it.
    ///
    /// # Errors
    /// When the initial connection cannot be established, returns the error together with a
    /// `ReconnectingConnection` on which a background reconnect has already been started.
    // Each connection setting is forwarded individually to `create_connection`.
    #[allow(clippy::too_many_arguments)]
    pub(super) async fn new(
        address: &NodeAddress,
        connection_retry_strategy: RetryStrategy,
        valkey_connection_info: ValkeyConnectionInfo,
        tls_mode: TlsMode,
        push_sender: Option<mpsc::UnboundedSender<PushInfo>>,
        discover_az: bool,
        connection_timeout: Duration,
        tls_params: Option<crate::connection::tls::TlsConnParams>,
        tcp_nodelay: bool,
        pubsub_synchronizer: Option<Arc<dyn crate::pubsub::PubSubSynchronizer>>,
        #[cfg(feature = "iam")] iam_token_handle: Option<IAMTokenHandle>,
    ) -> std::result::Result<ReconnectingConnection, (ReconnectingConnection, Error)> {
        tracing::debug!("connection creation - Attempting connection to {address}");
        let connection_info = get_client(address, tls_mode, valkey_connection_info, tls_params);
        let backend = ConnectionBackend {
            connection_info: RwLock::new(connection_info),
            connection_available_signal: ManualResetEvent::new(true),
            client_dropped_flagged: AtomicBool::new(false),
            #[cfg(feature = "iam")]
            iam_token_handle,
        };
        create_connection(
            backend,
            connection_retry_strategy,
            push_sender,
            discover_az,
            connection_timeout,
            tcp_nodelay,
            pubsub_synchronizer,
        )
        .await
    }

    /// The address of the node this connection targets, as a string.
    pub(crate) fn node_address(&self) -> String {
        self.inner
            .backend
            .get_backend_client()
            .get_connection_info()
            .addr
            .to_string()
    }

    /// Whether `mark_as_dropped` has been called on this connection or one of its clones.
    pub(super) fn is_dropped(&self) -> bool {
        self.inner
            .backend
            .client_dropped_flagged
            .load(Ordering::Relaxed)
    }

    /// Flags the connection as dropped and logs a `connection_closed` event. A running
    /// reconnect loop stops when it next checks the flag, before its next attempt.
    pub(super) fn mark_as_dropped(&self) {
        tracing::info!(
            target: "ferriskey",
            event = "connection_closed",
            reason = "mark_as_dropped",
            "ferriskey: connection closed"
        );
        self.inner
            .backend
            .client_dropped_flagged
            .store(true, Ordering::Relaxed)
    }

    /// Returns a clone of the live connection, or `None` while not connected. Never waits.
    pub(super) async fn try_get_connection(&self) -> Option<MultiplexedConnection> {
        match &*self.lock_connection_state() {
            ConnectionState::Connected(connection) => Some(connection.clone()),
            ConnectionState::Reconnecting | ConnectionState::InitializedDisconnected => None,
        }
    }

    /// Waits until a connection is available and returns a clone of it.
    ///
    /// There is no timeout. If reconnection never succeeds (for example after
    /// `mark_as_dropped`), this waits indefinitely. It currently never returns `Err`.
    pub(super) async fn get_connection(&self) -> std::result::Result<MultiplexedConnection, Error> {
        loop {
            self.inner.backend.connection_available_signal.wait().await;
            if let Some(connection) = self.try_get_connection().await {
                return Ok(connection);
            }
        }
    }

    /// Starts a background task that re-establishes the connection.
    ///
    /// Does nothing if a reconnect is already in progress. Otherwise the state becomes
    /// `Reconnecting` and the availability signal is reset, so `get_connection` callers wait.
    /// The task retries with the configured infinite backoff until a new connection answers
    /// `PING`, or until the connection is marked as dropped.
    pub(super) fn reconnect(&self, reason: ReconnectReason) {
        {
            let mut state = self.lock_connection_state();
            if matches!(*state, ConnectionState::Reconnecting) {
                tracing::trace!("reconnect - already started");
                return;
            }
            self.inner.backend.connection_available_signal.reset();
            *state = ConnectionState::Reconnecting;
        }
        tracing::debug!("reconnect - starting");
        let connection_clone = self.clone();
        if reason == ReconnectReason::ConnectionDropped {
            tracing::warn!(
                target: "ferriskey",
                event = "connection_closed",
                reason = "connection_dropped",
                "ferriskey: connection dropped, reconnecting"
            );
        }
        task::spawn(async move {
            let retry_strategy = connection_clone
                .connection_options
                .connection_retry_strategy
                .expect("retry_strategy set by create_connection");
            for sleep_duration in retry_strategy.get_infinite_backoff_dur_iterator() {
                if connection_clone.is_dropped() {
                    tracing::debug!(
                        "ReconnectingConnection - reconnect stopped after client was dropped"
                    );
                    return;
                }
                #[cfg(feature = "iam")]
                if let Some(handle) = &connection_clone.inner.backend.iam_token_handle
                    && let Some(valid_token) = handle.get_valid_token_inner().await
                {
                    connection_clone.update_connection_password(Some(valid_token));
                    tracing::debug!("reconnect - Updated connection password with valid IAM token before reconnection attempt");
                }
                // Snapshot the client on every attempt, not once per reconnect. Otherwise
                // password/username/database updates made while this loop runs would be ignored,
                // and e.g. a rotated password would make the loop retry stale credentials forever.
                let client = {
                    let guard = connection_clone.inner.backend.get_backend_client();
                    guard.clone()
                };
                let mut connection = match get_multiplexed_connection(
                    &client,
                    &connection_clone.connection_options,
                )
                .await
                {
                    Ok(connection) => connection,
                    Err(err) => {
                        tracing::debug!("reconnect - connection attempt failed: {err}");
                        tokio::time::sleep(sleep_duration).await;
                        continue;
                    }
                };
                // Only hand out a connection that has actually answered a command.
                let ping = connection
                    .send_packed_command(&crate::cmd::cmd("PING"))
                    .await;
                if let Err(err) = ping {
                    tracing::debug!("reconnect - PING on new connection failed: {err}");
                    tokio::time::sleep(sleep_duration).await;
                    continue;
                }
                {
                    // The signal is set while the state lock is held, so a woken waiter blocks
                    // on the lock until the new connection is stored.
                    let mut state = connection_clone.lock_connection_state();
                    tracing::debug!("reconnect - completed successfully");
                    connection_clone
                        .inner
                        .backend
                        .connection_available_signal
                        .set();
                    *state = ConnectionState::Connected(connection);
                }
                tracing::info!(
                    target: "ferriskey",
                    event = "connection_opened",
                    reason = "reconnect",
                    "ferriskey: reconnect completed"
                );
                return;
            }
        });
    }

    /// Whether the state currently holds a live connection.
    pub fn is_connected(&self) -> bool {
        matches!(
            *self.lock_connection_state(),
            ConnectionState::Connected(_)
        )
    }

    /// Waits up to `max_wait` for the disconnect notifier to fire. If no notifier is
    /// configured, logs an error and sleeps for `CONNECTION_CHECKS_INTERVAL` instead.
    pub async fn wait_for_disconnect_with_timeout(&self, max_wait: &Duration) {
        if let Some(disconnect_notifier) = &self.connection_options.disconnect_notifier {
            disconnect_notifier
                .wait_for_disconnect_with_timeout(max_wait)
                .await;
        } else {
            tracing::error!("disconnect notifier - BUG! Disconnect notifier is not set");
            tokio::time::sleep(super::CONNECTION_CHECKS_INTERVAL).await;
        }
    }

    /// Replaces the password stored on the client used for future connection attempts.
    /// The live connection is not modified.
    pub(crate) fn update_connection_password(&self, new_password: Option<String>) {
        self.update_backend_client(|client| client.update_password(new_password));
    }

    /// Replaces the database id stored on the client used for future connection attempts.
    /// The live connection is not modified.
    pub(crate) fn update_connection_database(&self, new_database_id: i64) {
        self.update_backend_client(|client| client.update_database(new_database_id));
    }

    /// Replaces the client name stored on the client used for future connection attempts.
    /// The live connection is not modified.
    pub(crate) fn update_connection_client_name(&self, new_client_name: Option<String>) {
        self.update_backend_client(|client| client.update_client_name(new_client_name));
    }

    /// Replaces the username stored on the client used for future connection attempts.
    /// The live connection is not modified.
    pub(crate) fn update_connection_username(&self, new_username: Option<String>) {
        self.update_backend_client(|client| client.update_username(new_username));
    }

    /// Replaces the protocol version stored on the client used for future connection attempts.
    /// The live connection is not modified.
    pub(crate) fn update_connection_protocol(&self, new_protocol: crate::value::ProtocolVersion) {
        self.update_backend_client(|client| client.update_protocol(new_protocol));
    }

    /// The username from the stored client's connection info, if any.
    pub(crate) fn get_username(&self) -> Option<String> {
        let client = self.inner.backend.get_backend_client();
        client.get_connection_info().valkey.username.clone()
    }

    /// Locks the connection state.
    ///
    /// Poisoning is recovered from, consistent with how the client `RwLock` is handled. Every
    /// critical section on this lock is a plain assignment, read or signal toggle, so a panicked
    /// holder cannot leave the state half-updated.
    fn lock_connection_state(&self) -> std::sync::MutexGuard<'_, ConnectionState> {
        self.inner.state.lock().unwrap_or_else(|e| e.into_inner())
    }

    /// Applies `update` to the stored client under its write lock, recovering from poisoning.
    fn update_backend_client(&self, update: impl FnOnce(&mut crate::connection::factory::Client)) {
        let mut client = self
            .inner
            .backend
            .connection_info
            .write()
            .unwrap_or_else(|e| e.into_inner());
        update(&mut *client);
    }
}