use super::get_valkey_connection_info;
use super::reconnecting_connection::{ReconnectReason, ReconnectingConnection};
use super::{ConnectionRequest, NodeAddress, TlsMode};
use crate::client::types::ReadFrom as ClientReadFrom;
use crate::cluster::routing::{
self as cluster_routing, ResponsePolicy, Routable, RoutingInfo, is_readonly_cmd,
};
use crate::connection::ConnectionLike;
use crate::pubsub::push_manager::PushInfo;
use crate::retry_strategies::RetryStrategy;
use crate::value::{Error, Result, Value};
use futures::{StreamExt, future, stream};
use std::sync::Arc;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::task;
/// Internal read-routing strategy, resolved once at client creation from the
/// user-facing `ReadFrom` configuration (see `get_read_from`).
#[derive(Debug)]
enum ReadFrom {
    /// All reads are served by the primary node.
    Primary,
    /// Round-robin over connected replicas, falling back to the primary
    /// when no replica is available.
    PreferReplica {
        // Rotation cursor shared across tasks; advanced with a relaxed CAS.
        latest_read_replica_index: Arc<AtomicUsize>,
    },
    /// Prefer nodes in the client's availability zone; falls back to plain
    /// replica round-robin when none match.
    AZAffinity {
        client_az: String,
        last_read_replica_index: Arc<AtomicUsize>,
    },
    /// Prefer same-AZ replicas, then the same-AZ primary, then any replica.
    AZAffinityReplicasAndPrimary {
        client_az: String,
        last_read_replica_index: Arc<AtomicUsize>,
    },
}
/// Owns the per-node connections and routing state. Wrapped in an `Arc` by
/// `StandaloneClient`; its `Drop` impl marks every connection as dropped so
/// the background heartbeat/connection-check tasks terminate.
#[derive(Debug)]
struct DropWrapper {
    // Index into `nodes` of the primary connection.
    primary_index: usize,
    nodes: Vec<ReconnectingConnection>,
    read_from: ReadFrom,
    // True when the client was created in read-only mode (no primary
    // discovery; write commands are rejected in `send_command`).
    read_only: bool,
}
impl Drop for DropWrapper {
    /// Flag every node connection as dropped so the spawned heartbeat and
    /// connection-check loops observe it and exit.
    fn drop(&mut self) {
        self.nodes.iter().for_each(|node| node.mark_as_dropped());
    }
}
/// Client for a standalone (non-cluster) server topology.
///
/// Cheap to clone: all clones share the same `DropWrapper` (and therefore
/// the same node connections) through the inner `Arc`.
#[derive(Clone, Debug)]
pub struct StandaloneClient {
    inner: Arc<DropWrapper>,
}
fn format_connection_errors(errors: Vec<(Option<String>, Error)>) -> Error {
if errors.len() == 1 {
return errors.into_iter().next().unwrap().1;
}
let detail: Vec<String> = errors
.iter()
.map(|(addr, err)| match addr {
Some(a) => format!("{a}: {err}"),
None => format!("{err}"),
})
.collect();
Error::from((
crate::value::ErrorKind::ClientError,
"Connection failed",
detail.join("; "),
))
}
impl StandaloneClient {
/// Creates a standalone client: opens one connection per configured address
/// concurrently and determines which node is the primary.
///
/// Validation performed up front: at least one address; read-only mode is
/// rejected together with AZ-affinity strategies; client cert/key must be
/// provided together; TLS certificates require TLS to be enabled.
///
/// Errors when two nodes both report `role:master`, or when no primary is
/// found (non-read-only) / no node connected (read-only) and every attempt
/// failed. Failed nodes are otherwise kept and left to reconnect in the
/// background.
pub async fn create_client(
    connection_request: ConnectionRequest,
    push_sender: Option<mpsc::UnboundedSender<PushInfo>>,
    #[cfg(feature = "iam")] iam_token_manager: Option<&Arc<crate::iam::IAMTokenManager>>,
    pubsub_synchronizer: Option<Arc<dyn crate::pubsub::PubSubSynchronizer>>,
) -> std::result::Result<Self, Error> {
    if connection_request.addresses.is_empty() {
        return Err(Error::from((
            crate::value::ErrorKind::InvalidClientConfig,
            "No addresses provided",
        )));
    }
    // AZ-affinity routing may target the primary; read-only mode skips
    // primary discovery entirely, so the combination is rejected.
    if connection_request.read_only
        && matches!(
            connection_request.read_from,
            Some(ClientReadFrom::AZAffinity(_))
                | Some(ClientReadFrom::AZAffinityReplicasAndPrimary(_))
        )
    {
        return Err(Error::from((
            crate::value::ErrorKind::InvalidClientConfig,
            "read-only mode is not compatible with AZAffinity strategies",
        )));
    }
    #[cfg(feature = "iam")]
    let valkey_connection_info =
        get_valkey_connection_info(&connection_request, iam_token_manager).await;
    #[cfg(not(feature = "iam"))]
    let valkey_connection_info = get_valkey_connection_info(&connection_request).await;
    let retry_strategy = match connection_request.connection_retry_strategy {
        Some(strategy) => RetryStrategy::new(
            strategy.exponent_base,
            strategy.factor,
            strategy.number_of_retries,
            strategy.jitter_percent,
        ),
        None => RetryStrategy::default(),
    };
    let tls_mode = connection_request.tls_mode;
    let node_count = connection_request.addresses.len();
    // AZ discovery on each connection is only needed for AZ-affinity reads.
    let discover_az = matches!(
        connection_request.read_from,
        Some(ClientReadFrom::AZAffinity(_))
            | Some(ClientReadFrom::AZAffinityReplicasAndPrimary(_))
    );
    let connection_timeout = connection_request.get_connection_timeout();
    let tcp_nodelay = connection_request.tcp_nodelay;
    let has_root_certs = !connection_request.root_certs.is_empty();
    let has_client_cert = !connection_request.client_cert.is_empty();
    let has_client_key = !connection_request.client_key.is_empty();
    // mTLS needs both halves of the key pair.
    if has_client_cert != has_client_key {
        return Err(Error::from((
            crate::value::ErrorKind::InvalidClientConfig,
            "client_cert and client_key must both be provided or both be empty",
        )));
    }
    let tls_params = if has_root_certs || has_client_cert || has_client_key {
        // Certificates were supplied, so TLS must actually be enabled.
        if tls_mode.unwrap_or(TlsMode::NoTls) == TlsMode::NoTls {
            return Err(Error::from((
                crate::value::ErrorKind::InvalidClientConfig,
                "TLS certificates provided but TLS is disabled",
            )));
        }
        // All root certificates are concatenated into one PEM-style blob.
        let root_cert = if has_root_certs {
            let mut combined_certs = Vec::new();
            for cert in &connection_request.root_certs {
                combined_certs.extend_from_slice(cert);
            }
            Some(combined_certs)
        } else {
            None
        };
        let client_tls = if has_client_cert && has_client_key {
            Some(crate::connection::tls::ClientTlsConfig {
                client_cert: connection_request.client_cert.clone(),
                client_key: connection_request.client_key.clone(),
            })
        } else {
            None
        };
        let tls_certificates = crate::connection::tls::TlsCertificates {
            client_tls,
            root_cert,
        };
        Some(crate::connection::tls::retrieve_tls_certificates(
            tls_certificates,
        )?)
    } else {
        None
    };
    let read_only = connection_request.read_only;
    let addresses = connection_request.addresses.clone();
    let read_from_option = connection_request.read_from.clone();
    #[cfg(feature = "iam")]
    let iam_token_handle = iam_token_manager.map(|m| m.get_token_handle());
    // Connect to every address concurrently; each attempt yields either
    // (connection, replication info) or (address, (connection, error)) —
    // note the connection is returned even on failure so it can keep
    // reconnecting in the background.
    let mut stream = stream::iter(addresses)
        .map(move |address| {
            let info = valkey_connection_info.clone();
            let retry = retry_strategy;
            let sender = push_sender.clone();
            let tls = tls_mode.unwrap_or(TlsMode::NoTls);
            let discover = discover_az;
            let timeout = connection_timeout;
            let params = tls_params.clone();
            let nodelay = tcp_nodelay;
            let sync = pubsub_synchronizer.clone();
            // Read-only clients skip the INFO REPLICATION role check.
            let skip_replication = read_only;
            #[cfg(feature = "iam")]
            let iam_handle = iam_token_handle.clone();
            async move {
                get_connection_and_replication_info(
                    &address,
                    &retry,
                    &info,
                    tls,
                    &sender,
                    discover,
                    timeout,
                    params,
                    nodelay,
                    &sync,
                    skip_replication,
                    #[cfg(feature = "iam")]
                    iam_handle,
                )
                .await
                .map_err(|err| (format!("{}:{}", address.host, address.port), err))
            }
        })
        .buffer_unordered(node_count);
    let mut nodes = Vec::with_capacity(node_count);
    let mut addresses_and_errors = Vec::with_capacity(node_count);
    // Read-only mode performs no role discovery, so slot 0 stands in as the
    // "primary" used for non-read traffic.
    let mut primary_index = if read_only { Some(0) } else { None };
    while let Some(result) = stream.next().await {
        match result {
            Ok((connection, replication_status)) => {
                nodes.push(connection);
                // A node is the primary when its INFO REPLICATION output
                // contains "role:master".
                let is_primary = replication_status
                    .and_then(|status| {
                        crate::value::from_owned_value::<String>(status).ok()
                    })
                    .is_some_and(|val| val.contains("role:master"));
                if is_primary {
                    // Two primaries in a standalone setup is a hard error:
                    // drop everything and bail out.
                    if let Some(existing_primary) = primary_index {
                        let msg = format!(
                            "Primary nodes: {:?}, {:?}",
                            nodes.last(),
                            nodes.get(existing_primary)
                        );
                        for node in nodes.iter() {
                            node.mark_as_dropped();
                        }
                        return Err(Error::from((
                            crate::value::ErrorKind::ClientError,
                            "Primary conflict in standalone setup",
                            msg,
                        )));
                    }
                    // The connection was just pushed, so it sits at len - 1.
                    primary_index = Some(nodes.len().saturating_sub(1));
                }
            }
            Err((address, (connection, err))) => {
                // Keep the failed connection: it reconnects in the background.
                nodes.push(connection);
                addresses_and_errors.push((Some(address), err));
            }
        }
    }
    let primary_index = if read_only {
        if nodes.is_empty() && !addresses_and_errors.is_empty() {
            for node in nodes.iter() {
                node.mark_as_dropped();
            }
            return Err(format_connection_errors(addresses_and_errors));
        }
        0 } else {
        match primary_index {
            Some(idx) => idx,
            None => {
                // No node reported role:master — the client is unusable.
                for node in nodes.iter() {
                    node.mark_as_dropped();
                }
                if addresses_and_errors.is_empty() {
                    return Err(Error::from((
                        crate::value::ErrorKind::ClientError,
                        "No primary node found",
                    )));
                }
                return Err(format_connection_errors(addresses_and_errors));
            }
        }
    };
    // Partial failures are tolerated; the affected nodes keep reconnecting.
    if !addresses_and_errors.is_empty() {
        tracing::warn!("client creation - Failed to connect to {addresses_and_errors:?}, will attempt to reconnect.");
    }
    // Read-only clients with no explicit preference default to replica reads.
    let read_from = if read_only && read_from_option.is_none() {
        ReadFrom::PreferReplica {
            latest_read_replica_index: Default::default(),
        }
    } else {
        get_read_from(read_from_option)
    };
    // Spawn the background maintenance tasks for every node.
    for node in nodes.iter() {
        Self::start_heartbeat(node.clone());
    }
    for node in nodes.iter() {
        Self::start_periodic_connection_check(node.clone());
    }
    tracing::info!(
        target: "ferriskey",
        event = "client_created",
        nodes = nodes.len(),
        "ferriskey: standalone client connected"
    );
    Ok(Self {
        inner: Arc::new(DropWrapper {
            primary_index,
            nodes,
            read_from,
            read_only,
        }),
    })
}
/// Returns the primary node's connection.
///
/// Panics only if `primary_index` no longer points into `nodes`, which would
/// mean a construction-time invariant was broken.
fn get_primary_connection(&self) -> &ReconnectingConnection {
    let index = self.inner.primary_index;
    self.inner
        .nodes
        .get(index)
        .expect("Primary index out of bounds — client in invalid state")
}
/// Picks the next connected replica in round-robin order, starting just
/// after the cached cursor. Falls back to the primary when a full rotation
/// finds no connected replica.
fn round_robin_read_from_replica(
    &self,
    latest_read_replica_index: &Arc<AtomicUsize>,
) -> &ReconnectingConnection {
    let start = latest_read_replica_index.load(Ordering::Relaxed);
    let node_count = self.inner.nodes.len();
    for step in 1..=node_count {
        let candidate = (start + step) % node_count;
        // The primary is never handed out as a replica read target here.
        if candidate == self.inner.primary_index {
            continue;
        }
        let Some(connection) = self.inner.nodes.get(candidate) else {
            continue;
        };
        if connection.is_connected() {
            // Best-effort cursor bump; losing the CAS race to another task
            // is harmless, so the result is ignored.
            let _ = latest_read_replica_index.compare_exchange_weak(
                start,
                candidate,
                Ordering::Relaxed,
                Ordering::Relaxed,
            );
            return connection;
        }
    }
    // No connected replica found after a full rotation.
    self.get_primary_connection()
}
/// Round-robin selection among nodes located in the client's availability
/// zone: probes up to `nodes.len()` candidates starting after the cached
/// cursor; a node qualifies when a connection is available and its reported
/// AZ equals `client_az`. Falls back to AZ-agnostic replica round-robin
/// when no same-AZ node is found.
async fn round_robin_read_from_replica_az_awareness(
    &self,
    latest_read_replica_index: &Arc<AtomicUsize>,
    client_az: String,
) -> &ReconnectingConnection {
    let initial_index = latest_read_replica_index.load(Ordering::Relaxed);
    let mut retries = 0usize;
    loop {
        retries = retries.saturating_add(1);
        // One full rotation exhausted — use the AZ-agnostic fallback.
        if retries > self.inner.nodes.len() {
            return self.round_robin_read_from_replica(latest_read_replica_index);
        }
        let index = (initial_index + retries) % self.inner.nodes.len();
        let replica = &self.inner.nodes[index];
        // NOTE(review): unlike round_robin_read_from_replica, this scan does
        // not skip primary_index, so the primary can be chosen when it shares
        // the client's AZ — confirm that is intended for this strategy.
        if let Ok(connection) = replica.get_connection().await
            && let Some(replica_az) = connection.get_az().as_deref()
            && replica_az == client_az
        {
            // Remember this node as the next rotation starting point; a lost
            // CAS race just means another task advanced the cursor first.
            let _ = latest_read_replica_index.compare_exchange_weak(
                initial_index,
                index,
                Ordering::Relaxed,
                Ordering::Relaxed,
            );
            return replica;
        }
    }
}
/// AZ-aware selection that prefers same-AZ nodes first, then the same-AZ
/// primary, and finally falls back to plain replica round-robin.
async fn round_robin_read_from_replica_az_awareness_replicas_and_primary(
    &self,
    latest_read_replica_index: &Arc<AtomicUsize>,
    client_az: String,
) -> &ReconnectingConnection {
    let initial_index = latest_read_replica_index.load(Ordering::Relaxed);
    let mut retries = 0usize;
    loop {
        retries = retries.saturating_add(1);
        // NOTE(review): `>=` stops after nodes.len() - 1 probes, so the node
        // at index `initial_index` itself is never examined in this pass
        // (the sibling AZAffinity variant uses `>` and covers all nodes) —
        // confirm whether this off-by-one is intentional.
        if retries >= self.inner.nodes.len() {
            break;
        }
        let index = (initial_index + retries) % self.inner.nodes.len();
        let replica = &self.inner.nodes[index];
        if let Ok(connection) = replica.get_connection().await
            && let Some(replica_az) = connection.get_az().as_deref()
            && replica_az == client_az
        {
            // Cache the winner as the next rotation start; ignore CAS races.
            let _ = latest_read_replica_index.compare_exchange_weak(
                initial_index,
                index,
                Ordering::Relaxed,
                Ordering::Relaxed,
            );
            return replica;
        }
    }
    // Second preference: the primary, when it sits in the client's AZ.
    let primary = self.get_primary_connection();
    if let Ok(connection) = primary.get_connection().await
        && let Some(primary_az) = connection.get_az().as_deref()
        && primary_az == client_az
    {
        return primary;
    }
    // Last resort: any connected replica, ignoring AZ.
    self.round_robin_read_from_replica(latest_read_replica_index)
}
/// Resolves the connection a request should use: writes and single-node
/// topologies always go to the primary; reads are dispatched according to
/// the configured `ReadFrom` strategy.
async fn get_connection(&self, readonly: bool) -> &ReconnectingConnection {
    if !readonly || self.inner.nodes.len() == 1 {
        return self.get_primary_connection();
    }
    match &self.inner.read_from {
        ReadFrom::Primary => self.get_primary_connection(),
        ReadFrom::PreferReplica {
            latest_read_replica_index,
        } => self.round_robin_read_from_replica(latest_read_replica_index),
        ReadFrom::AZAffinity {
            client_az,
            last_read_replica_index,
        } => {
            self.round_robin_read_from_replica_az_awareness(
                last_read_replica_index,
                client_az.to_string(),
            )
            .await
        }
        ReadFrom::AZAffinityReplicasAndPrimary {
            client_az,
            last_read_replica_index,
        } => {
            self.round_robin_read_from_replica_az_awareness_replicas_and_primary(
                last_read_replica_index,
                client_az.to_string(),
            )
            .await
        }
    }
}
async fn send_request(
cmd: &crate::cmd::Cmd,
reconnecting_connection: &ReconnectingConnection,
) -> Result<Value> {
let mut connection = reconnecting_connection.get_connection().await?;
let result = connection.send_packed_command(cmd).await;
match result {
Err(err) if err.is_unrecoverable_error() => {
tracing::warn!("send request - received disconnect error `{err}`");
reconnecting_connection.reconnect(ReconnectReason::ConnectionDropped);
Err(err)
}
_ => result,
}
}
pub(crate) async fn send_request_to_all_nodes(
&mut self,
cmd: &crate::cmd::Cmd,
response_policy: Option<ResponsePolicy>,
) -> Result<Value> {
let requests = self
.inner
.nodes
.iter()
.map(|node| Self::send_request(cmd, node));
match response_policy {
Some(ResponsePolicy::AllSucceeded) => {
future::try_join_all(requests)
.await
.map(|mut results| results.pop().unwrap()) }
Some(ResponsePolicy::OneSucceeded) => future::select_ok(requests.map(Box::pin))
.await
.map(|(result, _)| result),
Some(ResponsePolicy::FirstSucceededNonEmptyOrAllEmpty) => {
future::select_ok(requests.map(|request| {
Box::pin(async move {
let result = request.await?;
match result {
Value::Nil => {
Err((crate::value::ErrorKind::ResponseError, "no value found")
.into())
}
_ => Ok(result),
}
})
}))
.await
.map(|(result, _)| result)
}
Some(ResponsePolicy::Aggregate(op)) => future::try_join_all(requests)
.await
.and_then(|results| cluster_routing::aggregate(results, op)),
Some(ResponsePolicy::AggregateArray(op)) => future::try_join_all(requests)
.await
.and_then(|results| cluster_routing::aggregate_array(results, op)),
Some(ResponsePolicy::AggregateLogical(op)) => future::try_join_all(requests)
.await
.and_then(|results| cluster_routing::logical_aggregate(results, op)),
Some(ResponsePolicy::CombineArrays) => future::try_join_all(requests)
.await
.and_then(cluster_routing::combine_array_results),
Some(ResponsePolicy::CombineMaps) => future::try_join_all(requests)
.await
.and_then(cluster_routing::combine_map_results),
Some(ResponsePolicy::Special) => {
let results = future::try_join_all(requests).await?;
let node_result_pairs = self
.inner
.nodes
.iter()
.zip(results)
.map(|(node, result)| (Value::BulkString(node.node_address().into()), result))
.collect();
Ok(Value::Map(node_result_pairs))
}
None => {
future::try_join_all(requests).await.map(|vals| Value::Array(vals.into_iter().map(Ok).collect()))
}
}
}
/// Routes `cmd` to a single node — a replica when `readonly` permits it —
/// and forwards the reply.
async fn send_request_to_single_node(
    &mut self,
    cmd: &crate::cmd::Cmd,
    readonly: bool,
) -> Result<Value> {
    let target = self.get_connection(readonly).await;
    Self::send_request(cmd, target).await
}
/// Sends a command, routing it to all nodes or a single node based on the
/// command name; in read-only mode, write commands are rejected.
pub async fn send_command(&mut self, cmd: &crate::cmd::Cmd) -> Result<Value> {
    let Some(cmd_bytes) = Routable::command(cmd) else {
        // No resolvable command name: route to the primary.
        return self.send_request_to_single_node(cmd, false).await;
    };
    let is_read = is_readonly_cmd(cmd_bytes.as_slice());
    if self.inner.read_only && !is_read {
        return Err(Error::from((
            crate::value::ErrorKind::ReadOnly,
            "write commands are not allowed in read-only mode",
        )));
    }
    if RoutingInfo::is_all_nodes(cmd_bytes.as_slice()) {
        let response_policy = ResponsePolicy::for_command(cmd_bytes.as_slice());
        return self.send_request_to_all_nodes(cmd, response_policy).await;
    }
    self.send_request_to_single_node(cmd, is_read).await
}
pub async fn send_pipeline(
&mut self,
pipeline: &crate::pipeline::Pipeline,
offset: usize,
count: usize,
) -> Result<Vec<Result<Value>>> {
let reconnecting_connection = self.get_primary_connection();
let mut connection = reconnecting_connection.get_connection().await?;
let result = connection
.send_packed_commands(pipeline, offset, count)
.await;
match result {
Err(err) if err.is_unrecoverable_error() => {
tracing::warn!("pipeline request - received disconnect error `{err}`");
reconnecting_connection.reconnect(ReconnectReason::ConnectionDropped);
Err(err)
}
_ => result,
}
}
/// Spawns a task that periodically PINGs the connection, requesting a
/// reconnect when the PING fails with a dropped/refused connection. The
/// task exits once the connection is marked as dropped.
fn start_heartbeat(reconnecting_connection: ReconnectingConnection) {
    task::spawn(async move {
        loop {
            tokio::time::sleep(super::HEARTBEAT_SLEEP_DURATION).await;
            if reconnecting_connection.is_dropped() {
                tracing::debug!("StandaloneClient - heartbeat stopped after connection was dropped");
                return;
            }
            match reconnecting_connection.try_get_connection().await {
                None => {
                    // A reconnect is already in flight; skip this round.
                    tracing::debug!("StandaloneClient - heartbeat stopped while connection is reconnecting");
                }
                Some(mut connection) => {
                    tracing::debug!("StandaloneClient - performing heartbeat");
                    let ping_result = connection
                        .send_packed_command(&crate::cmd::cmd("PING"))
                        .await;
                    if let Err(err) = ping_result {
                        if err.is_connection_dropped() || err.is_connection_refusal() {
                            tracing::debug!("StandaloneClient - heartbeat triggered reconnect");
                            reconnecting_connection
                                .reconnect(ReconnectReason::ConnectionDropped);
                        }
                    }
                }
            }
        }
    });
}
/// Spawns a task that waits for a disconnect signal (or a timeout) and
/// triggers a reconnect when the underlying connection turns out to be
/// closed. The task exits once the connection is marked as dropped.
///
/// Fix: the skip-while-reconnecting log message was garbled
/// ("skipping a connections since its reconnecting"); corrected wording.
fn start_periodic_connection_check(reconnecting_connection: ReconnectingConnection) {
    task::spawn(async move {
        loop {
            reconnecting_connection
                .wait_for_disconnect_with_timeout(&super::CONNECTION_CHECKS_INTERVAL)
                .await;
            if reconnecting_connection.is_dropped() {
                tracing::debug!("StandaloneClient - connection checker stopped after connection was dropped");
                return;
            }
            // A reconnect is already in progress; check again next round.
            let Some(connection) = reconnecting_connection.try_get_connection().await else {
                tracing::debug!("StandaloneClient - connection checker is skipping a connection since it is reconnecting");
                continue;
            };
            if connection.is_closed() {
                tracing::debug!("StandaloneClient - connection checker has triggered reconnect");
                reconnecting_connection.reconnect(ReconnectReason::ConnectionDropped);
            }
        }
    });
}
/// Updates the stored password used by every node's reconnect configuration.
pub async fn update_connection_password(
    &self,
    new_password: Option<String>,
) -> Result<Value> {
    self.inner
        .nodes
        .iter()
        .for_each(|node| node.update_connection_password(new_password.clone()));
    Ok(Value::Okay)
}
/// Updates the database id used by every node's reconnect configuration.
pub async fn update_connection_database(&self, database_id: i64) -> Result<Value> {
    self.inner
        .nodes
        .iter()
        .for_each(|node| node.update_connection_database(database_id));
    Ok(Value::Okay)
}
/// Updates the client name used by every node's reconnect configuration.
pub async fn update_connection_client_name(
    &self,
    new_client_name: Option<String>,
) -> Result<Value> {
    self.inner
        .nodes
        .iter()
        .for_each(|node| node.update_connection_client_name(new_client_name.clone()));
    Ok(Value::Okay)
}
/// Updates the username used by every node's reconnect configuration.
pub async fn update_connection_username(
    &self,
    new_username: Option<String>,
) -> Result<Value> {
    self.inner
        .nodes
        .iter()
        .for_each(|node| node.update_connection_username(new_username.clone()));
    Ok(Value::Okay)
}
/// Updates the protocol version used by every node's reconnect configuration.
pub async fn update_connection_protocol(
    &self,
    new_protocol: crate::value::ProtocolVersion,
) -> Result<Value> {
    self.inner
        .nodes
        .iter()
        .for_each(|node| node.update_connection_protocol(new_protocol));
    Ok(Value::Okay)
}
/// Returns the username currently configured on the primary connection.
pub fn get_username(&self) -> Option<String> {
    let primary = self.get_primary_connection();
    primary.get_username()
}
}
/// Opens a reconnecting connection to `address` and, unless
/// `skip_replication_check` is set, queries `INFO REPLICATION` so the
/// caller can decide whether the node is the primary.
///
/// On failure after the connection object exists, the error is returned
/// together with that connection so the caller can keep it around while it
/// reconnects in the background.
#[allow(clippy::too_many_arguments)]
async fn get_connection_and_replication_info(
    address: &NodeAddress,
    retry_strategy: &RetryStrategy,
    connection_info: &crate::connection::info::ValkeyConnectionInfo,
    tls_mode: TlsMode,
    push_sender: &Option<mpsc::UnboundedSender<PushInfo>>,
    discover_az: bool,
    connection_timeout: Duration,
    tls_params: Option<crate::connection::tls::TlsConnParams>,
    tcp_nodelay: bool,
    pubsub_synchronizer: &Option<Arc<dyn crate::pubsub::PubSubSynchronizer>>,
    skip_replication_check: bool,
    #[cfg(feature = "iam")] iam_token_handle: Option<super::IAMTokenHandle>,
) -> std::result::Result<(ReconnectingConnection, Option<Value>), (ReconnectingConnection, Error)> {
    let reconnecting_connection = ReconnectingConnection::new(
        address,
        *retry_strategy,
        connection_info.clone(),
        tls_mode,
        push_sender.clone(),
        discover_az,
        connection_timeout,
        tls_params,
        tcp_nodelay,
        pubsub_synchronizer.clone(),
        #[cfg(feature = "iam")]
        iam_token_handle,
    )
    .await?;
    let mut multiplexed_connection = match reconnecting_connection.get_connection().await {
        Ok(multiplexed_connection) => multiplexed_connection,
        Err(err) => {
            // Kick off a background reconnect but still hand the connection
            // back so the client can track it.
            reconnecting_connection.reconnect(ReconnectReason::ConnectionDropped);
            return Err((reconnecting_connection, err));
        }
    };
    // Read-only clients do not need role discovery.
    if skip_replication_check {
        return Ok((reconnecting_connection, None));
    }
    match multiplexed_connection
        .send_packed_command(crate::cmd::cmd("INFO").arg("REPLICATION"))
        .await
    {
        Ok(replication_status) => Ok((reconnecting_connection, Some(replication_status))),
        Err(err) => Err((reconnecting_connection, err)),
    }
}
/// Translates the public read-from preference into the client's internal
/// routing strategy; an unset preference defaults to primary reads.
fn get_read_from(read_from: Option<super::ReadFrom>) -> ReadFrom {
    match read_from {
        Some(super::ReadFrom::PreferReplica) => ReadFrom::PreferReplica {
            latest_read_replica_index: Default::default(),
        },
        Some(super::ReadFrom::AZAffinity(az)) => ReadFrom::AZAffinity {
            client_az: az,
            last_read_replica_index: Default::default(),
        },
        Some(super::ReadFrom::AZAffinityReplicasAndPrimary(az)) => {
            ReadFrom::AZAffinityReplicasAndPrimary {
                client_az: az,
                last_read_replica_index: Default::default(),
            }
        }
        Some(super::ReadFrom::Primary) | None => ReadFrom::Primary,
    }
}